#! /usr/bin/env python
#
# Adapted from the PyTorch CIFAR-10 tutorial (extended with per-epoch loss tracking):
# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
import matplotlib.pyplot as plt
import numpy as np
from NetForStatistics import *
# (1) Init: data loaders, CIFAR-10 class names and the network.
(trainloader, testloader) = getDataset()
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
net = Net("cpu")

# Number of passes over the training set.
NUM_EPOCHS = 3

# (2) Training. Every loss list has NUM_EPOCHS + 1 entries: index 0 is
# measured before any training, index k after the k-th epoch.
# NOTE(review): net.eval() is called only once, here. Whether trainmodel /
# evalmodel switch the network between train and eval mode themselves is
# not visible in this file -- confirm in NetForStatistics.
net.eval()
# There is no training pass before epoch 1, so the "during training" curve
# starts from the same evaluated training-set loss as trainloss2.
tloss = evalmodel(net, trainloader)
trainloss = [tloss]    # loss returned by trainmodel for each epoch
trainloss2 = [tloss]   # evalmodel loss on the training set after each epoch
testloss = [evalmodel(net, testloader)]  # evalmodel loss on the test set
for epoch in range(NUM_EPOCHS):  # loop over the dataset multiple times
    tloss = trainmodel(net, trainloader)
    trainloss.append(tloss)
    trainloss2.append(evalmodel(net, trainloader))
    testloss.append(evalmodel(net, testloader))
print('Finished Training')

# (3) Report the recorded losses.
print("loss during training:", trainloss)
print("loss on training set:", trainloss2)
print("loss on test set:", testloss)

# (4) Plot the three loss curves against the number of completed epochs and
# save as SVG. They are labelled so the saved figure is readable on its own.
x = list(range(len(testloss)))
plt.plot(x, trainloss, "b", label="loss during training")
plt.plot(x, trainloss2, "k", label="loss on training set")
plt.plot(x, testloss, "r", label="loss on test set")
plt.xlabel("epochs completed")
plt.ylabel("loss")
plt.legend()
plt.savefig("plot.svg")