#! /usr/bin/env python
#
# Train the CIFAR-10 classifier for a few epochs, print the loss histories and
# save a labelled plot of the loss curves.
#
# This is taken from https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

import matplotlib.pyplot as plt
import numpy as np

# Supplies getDataset, Net, trainmodel and evalmodel used below.
from NetForStatistics import *

# Number of training passes over the dataset.
NUM_EPOCHS = 3
# Output file for the loss plot; matplotlib picks the format from the extension.
PLOT_PATH = "plot.svg"

# (1) Init
trainloader, testloader = getDataset()
# CIFAR-10 label names; not used by the loss statistics below.
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
net = Net("cpu")

# (4) Training
net.eval()
# NOTE(review): eval mode is set only once, here. Whether the evaluations that
# follow each training epoch also run in eval mode depends on whether
# trainmodel/evalmodel switch the module mode themselves -- confirm in
# NetForStatistics.

# Index 0 of every history is the loss of the untrained network (evaluated
# before any call to trainmodel), so all three lists end up with
# NUM_EPOCHS + 1 entries and share the same x-axis.
tloss = evalmodel(net, trainloader)
trainloss = [tloss]   # loss value returned by trainmodel for each epoch
trainloss2 = [tloss]  # evalmodel loss on the training set after each epoch
testloss = [evalmodel(net, testloader)]  # evalmodel loss on the test set
for epoch in range(NUM_EPOCHS):  # loop over the dataset multiple times
    tloss = trainmodel(net, trainloader)
    trainloss.append(tloss)
    trainloss2.append(evalmodel(net, trainloader))
    testloss.append(evalmodel(net, testloader))

print('Finished Training')
print("loss during training:", trainloss)
print("loss on training set:", trainloss2)
print("loss on test set:", testloss)

# x = 0 is the untrained network; x = k corresponds to training epoch k.
x = list(range(len(testloss)))
fig, ax = plt.subplots()
ax.plot(x, trainloss, "b", label="loss during training")
ax.plot(x, trainloss2, "k", label="loss on training set")
ax.plot(x, testloss, "r", label="loss on test set")
ax.set_xlabel("epoch")
ax.set_ylabel("loss")
# Without a legend the three curves are distinguishable only by colour.
ax.legend()
fig.savefig(PLOT_PATH)
# Release the figure so repeated runs in one interpreter do not pile up
# open figures.
plt.close(fig)