#! /usr/bin/env python
#
# This is taken from https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
import matplotlib.pyplot as plt
import numpy as np
from NetForStatistics import *
# (1) Init: data loaders, class names and the network.
trainloader, testloader = getDataset()
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)
net = Net()
# CPU-only alternative: net = Net("cpu")

# (4) Training
# Record evalmodel()'s value (presumably a loss) for the untrained network.
# trainloss2 / testloss are seeded here but only grow if the per-epoch
# evaluations inside the loop are re-enabled.
net.eval()
tloss = evalmodel(net, trainloader)
trainloss, trainloss2 = [tloss], [tloss]
testloss = [evalmodel(net, testloader)]
# NOTE(review): net is still in eval() mode when the loop starts; confirm
# that trainmodel() calls net.train(), otherwise any dropout / batch-norm
# layers in Net would be trained in inference mode.
for epoch in range(12):  # 12 epochs; each trainmodel call presumably makes one pass
    tloss = trainmodel(net, trainloader)
    trainloss.append(tloss)
    # Per-epoch evaluation (disabled, costs two extra passes per epoch):
    #   trainloss2.append(evalmodel(net, trainloader))
    #   testloss.append(evalmodel(net, testloader))
def errorrate(net, dataloader, *, verbose=False):
    """Return the misclassification rate of ``net`` over ``dataloader``.

    The network is switched to evaluation mode (and left there), and the
    forward passes run under ``torch.no_grad()`` so no autograd graph is
    built while scanning the data.

    Args:
        net: model whose output's argmax along dim 1 is the predicted class;
            it must expose a ``device`` attribute that batches are moved to.
        dataloader: iterable of batches; item ``[0]`` of each batch is the
            input tensor and item ``[1]`` the label tensor.
        verbose: if True, print the labels and predictions of every batch
            (debug output; off by default).

    Returns:
        ``(rate, count)``. ``rate`` is a 0-dim tensor equal to
        ``misclassified / count``; call ``.item()`` for a Python float.
        ``count`` is the number of samples seen.

    Raises:
        ValueError: if ``dataloader`` yields no samples (the rate is undefined).
    """
    tcount = 0
    terror = 0
    net.eval()
    with torch.no_grad():  # pure evaluation: no gradients needed
        for data in dataloader:
            inputs = data[0].to(net.device)
            labels = data[1].to(net.device)
            pred = net(inputs).argmax(dim=1)
            if verbose:
                print(labels, pred)
            # Accumulate as a tensor on the device; avoids a host sync per batch.
            terror += (labels != pred).sum()
            tcount += len(labels)
    if tcount == 0:
        raise ValueError("errorrate: dataloader yielded no samples")
    return (terror / tcount, tcount)
# Error rate of the trained network on the training set, plus its binomial
# standard error hatsigma = sqrt(rE * (1 - rE) / N).
# NOTE(review): training-set error is usually optimistic; use testloader
# instead if a held-out estimate is what is wanted.
rE, N = errorrate(net, trainloader)
rE = rE.item()  # 0-dim tensor -> Python float
hatsigma = float(np.sqrt(rE * (1 - rE) / N))
print(f"rE = {rE}; hatsigma = {hatsigma}")