#!/usr/bin/env python3

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn.metrics as skl
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset


def plotImages(train_images,train_labels,fn="test.png"):
    """Plot the six first images in the given dataset.

    Each image is reshaped to 28x28 and drawn in grayscale in a single
    row of six subplots.

    Parameters:
        train_images: indexable collection of images, each holding
            28*28 pixel values.
        train_labels: labels of the images; accepted for symmetry with
            the dataset, but not used by the plot.
        fn: file name to save the figure to.  If empty or None, the
            figure is only drawn on the current matplotlib figure.
    """
    # Do not index past the end of datasets with fewer than six images.
    count = min(6, len(train_images))
    for i in range(count):
        # One row, six columns; same layout as the shorthand 161..166.
        plt.subplot(1, 6, i + 1)
        plt.imshow(train_images[i].reshape(28,28), cmap=plt.get_cmap('gray'))
    if fn:
        # NOTE(review): the body of this branch was missing in the source;
        # saving to `fn` is the assumed intent -- confirm.
        plt.savefig(fn)

def getTensorDataSet(images,labels,debug=None):
    """Transform the dataset given as images and labels into the
    tensor format used by PyTorch.

    From Listing 6-3 (Ketkar 2021).

    Parameters:
        images: pixel values in the range 0..255, in any layout that
            can be viewed as (N, 1, 28, 28).
        labels: one class label per image.
        debug: if true, print the shapes of the resulting tensors.
            When omitted, the module-level DEBUG flag is used if it is
            defined, otherwise no debug output is printed.

    Returns:
        A TensorDataset of (image, label) pairs, with the images scaled
        to the range 0..1 and shaped (1, 28, 28).
    """
    if debug is None:
        # DEBUG is optional; resolve it at call time so that the module
        # still loads when the flag has not been defined.
        debug = globals().get("DEBUG", False)
    images_tensor = torch.tensor(images)/255.0
    images_tensor = images_tensor.view(-1,1,28,28)
    labels_tensor = torch.tensor(labels)
    if debug:
        print("Labels Shape:",labels_tensor.shape)
        print("Images Shape:",images_tensor.shape)
    return TensorDataset(images_tensor, labels_tensor)

# Listing 6-4. Defining the CNN and the Helper Functions

class ConvNet(nn.Module):
    """Simple convolutional network for demonstration.

    Two convolution + max-pooling units followed by two fully
    connected layers.  The input is expected as (N, 1, 28, 28); each
    pooling step halves the spatial size (28 -> 14 -> 7), which is why
    the first linear layer takes 7*7*32 inputs.
    """

    def __init__(self, num_classes=10):
        """The constructor defines the elements of the network.

        Parameters:
            num_classes: number of output classes (size of the last
                layer).
        """
        super(ConvNet, self).__init__()

        # NOTE(review): there is no activation function between the
        # convolution and the pooling as the code stands; confirm
        # whether an nn.ReLU() was intended in each unit.
        self.conv_unit_1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2))

        self.conv_unit_2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2))

        self.fc1 = nn.Linear(7*7*32, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """The forward method defines how data flows through the network.

        Returns the log-probabilities per class, shaped
        (N, num_classes).
        """
        out = self.conv_unit_1(x)
        out = self.conv_unit_2(out)
        # Flatten everything but the batch dimension for the linear layers.
        out = out.view(out.size(0), -1)
        out = self.fc1(out)
        out = self.fc2(out)
        out = F.log_softmax(out,dim=1)
        return out

class ML:
    """Class to set up a CNN system with all the necessary components.

    It bundles the model, the loss function, the optimiser and the
    data loaders, and provides methods to train and evaluate the model.
    """

    def __init__(self):
        """The constructor defines the machine learning system."""
        # Set this to a torch.device to move the model and the data
        # there; None leaves everything where it is.
        self.device = None # torch.device('cpu')
        self.model = ConvNet(10)
        if self.device is not None:
            self.model = self.model.to(self.device)
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)

    def _to_device(self, tensor):
        """Move the tensor to the configured device, if one is set."""
        if self.device is None:
            return tensor
        return tensor.to(self.device)

    @staticmethod
    def _report(tr, vr):
        """Print the (loss, accuracy) pairs for training and validation."""
        print(f'Average Training Loss: {tr[0]:.4f} '
              f'Accuracy: {100.0*tr[1]:.3f}%)')
        print(f'Average Validation Loss: {vr[0]:.4f} '
              f'Accuracy: {100.0*vr[1]:.3f}%)')

    def addData(self,train_tensor,val_tensor):
        """Load Train and Validation TensorDatasets into the data generator
        for Training"""
        self.train_loader = DataLoader(train_tensor, batch_size=64,
                 num_workers=2, shuffle=True)
        self.val_loader = DataLoader(val_tensor, batch_size=64,
                 num_workers=2, shuffle=True)

    def trainModel(self,num_epochs=5):
        """Train the ML system.

        Return loss and accuracy for every epoch, both for the training
        and validation sets.

        The return value is a pair (vl, tl) -- validation first, then
        training.  Each is a list of (loss, accuracy) pairs as returned
        by `evaluate`, with one entry for the untrained model followed
        by one entry per epoch.
        """
        tr = self.evaluate(self.train_loader)
        vr = self.evaluate(self.val_loader)
        tl = [tr]
        vl = [vr]
        self._report(tr, vr)

        for epoch in range(num_epochs):
            # Stays NaN if the loader yields no batches.
            last_loss = float("nan")
            for images, labels in self.train_loader:
                images = self._to_device(images)
                labels = self._to_device(labels)

                outputs = self.model(images)
                loss = self.criterion(outputs, labels)

                # NOTE(review): the optimisation step was missing from
                # the source; this is the standard PyTorch sequence.
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                last_loss = loss.item()

            # The loss reported here is that of the last batch only.
            print ('Epoch [{}/{}], Loss: {:.4f}' .format(epoch+1,
                   num_epochs, last_loss))
            tr = self.evaluate(self.train_loader)
            vr = self.evaluate(self.val_loader)
            # Record the results so that they can be plotted per epoch.
            tl.append(tr)
            vl.append(vr)
            self._report(tr, vr)
        return ( vl, tl )

    def evaluate(self,data_loader):
        """Evaluate the model on the given data set.

        Return the average loss (according to the cross-entropy loss
        function) and the accuracy, as a pair.  The loss is a float and
        the accuracy is a zero-dimensional float tensor.
        """
        loss = 0.0
        correct = torch.tensor(0)
        # Evaluate in eval mode, but leave the model as we found it.
        was_training = self.model.training
        self.model.eval()
        try:
            # No gradients are needed when only measuring performance.
            with torch.no_grad():
                for data, target in data_loader:
                    data = self._to_device(data)
                    target = self._to_device(target)
                    output = self.model(data)
                    # Sum over the batch; the average is taken at the end.
                    loss += F.cross_entropy(output, target,
                                            reduction='sum').item()
                    # Index of the largest output is the predicted class.
                    predicted = output.max(1, keepdim=True)[1]
                    correct += (target.reshape(-1,1) == predicted).sum().cpu()
        finally:
            self.model.train(was_training)
        loss /= len(data_loader.dataset)
        return ( loss, correct.float() / len(data_loader.dataset) )

    def training_predictions(self):
        """Get predictions on the the training set, returning a pair
        of actual classes and predicted class.
        """
        return self.make_predictions(self.train_loader)

    def make_predictions(self,data_loader=None):
        """Get predictions on the the validation set, returning a pair
        of actual classes and predicted class.

        The `data_loader` argument can be given to make predictions on
        a different dataset.

        The actual classes are returned with shape (N,) and the
        predicted classes with shape (N, 1), both on the CPU.
        """
        if data_loader is None: data_loader = self.val_loader
        pred_batches = []
        actual_batches = []
        with torch.no_grad():
            for data, target in data_loader:
                data = self._to_device(data)
                output = self.model(data)
                # Predict  output/Take the index of the output with max value
                preds = output.cpu().max(1, keepdim=True)[1]
                pred_batches.append(preds)
                actual_batches.append(target)

        if not pred_batches:
            # Empty data set: nothing to concatenate.
            return torch.LongTensor(), torch.LongTensor()
        # Combine tensors from each batch
        test_preds = torch.cat(pred_batches, dim=0)
        actual = torch.cat(actual_batches, dim=0)
        return actual,test_preds

    def getData( self, fn="kaggle/train.csv", testsize=0.2 ):
        """Load a CSV file and split it into training and validation sets.

        The CSV contains a flat file of images, i.e. each 28*28 image is
        flattened into a row of 784 colums (1 column represents a pixel value).
        For CNN, we would need to reshape this to our desired shape.
        The file must have a `label` column first, followed by the pixel
        columns (presumably the Kaggle digit recogniser training file --
        confirm).

        From Listing 6-3 (Ketkar 2021).

        Parameters:
            fn: path of the CSV file.
            testsize: share of the data to hold out for validation.
        """
        train_df = pd.read_csv(fn)
        train_labels = train_df['label'].values
        train_images = (train_df.iloc[:,1:].values).astype('float32')
        # The fixed random state makes the split reproducible.
        self.train_images, self.val_images, \
                self.train_labels, self.val_labels = train_test_split(
                  train_images, train_labels, random_state=2020, test_size=testsize)
        # Reshape the flat row into [#images,#Channels,#Width, Height]
        # grayscale -> just 1 channel
        self.train_images = self.train_images.reshape(
                self.train_images.shape[0],1,28, 28)
        self.val_images = self.val_images.reshape(
                self.val_images.shape[0],1,28, 28)
        self.addData( getTensorDataSet(self.train_images, self.train_labels),
                      getTensorDataSet(self.val_images,self.val_labels) )

    def plotImages(self):
        """Plot the six first images of the training set"""
        return plotImages(self.train_images,self.train_labels)

if __name__ == "__main__":

    ml = ML()

    # Normally, we would use a large training set; taking only 20% of
    # the data for testing is a typical rule of thumb.  However, it
    # is interesting to see how the network behaves with very small
    # training sets.  Here we have taken 99.84% of the data for testing,
    # leaving 67 images for training.  You should play with the numbers.

    ml.getData( fn="kaggle/train.csv", testsize=0.9984 )

    # For large training sets, 250 epochs is a lot, but with small
    # training sets, the results may be quite erratic, and we use a lot
    # of epochs to see if it stabilises.

    (vl, tl) = ml.trainModel(num_epochs=250)

    # The training method returns loss and accuracy per epoch, both
    # for the training set (tl) and the validation set (vl).
    # We plot the accuracies below.

    vl = [ x[1].item() for x in vl ]
    tl = [ x[1].item() for x in tl ]
    epochs = range(len(vl))

    plt.plot( epochs, vl, "+r", label="Validation" )
    plt.plot( epochs, tl, "xb", label="Training" )
    # Without a legend, the labels given above are never shown.
    plt.legend()
    plt.savefig( "accuracy.png" )

    # The confusion matrix is a useful tool to assess the performance.
    # This is calculated for us by the SciKit library.

    actual, predicted = ml.make_predictions()
    actual = np.array(actual).reshape(-1,1)
    predicted = np.array(predicted).reshape(-1,1)
    print("Validation Accuracy-",round(skl.accuracy_score(actual, predicted),4)*100)
    print("\nConfusion Matrix\n",skl.confusion_matrix(actual,predicted))