lenet/train.py
#!/usr/bin/env python
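"""Train LeNet-5 on MNIST and export the trained model to ONNX."""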
import onnx
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets.mnist import MNIST

from lenet import LeNet5


def train(epoch):
    """Run one epoch of training over the MNIST training set."""
    net.train()
    loss_list, batch_list = [], []  # per-batch loss trace, kept for inspection
    for i, (images, labels) in enumerate(data_train_loader):
        optimizer.zero_grad()
        output = net(images)
        loss = criterion(output, labels)
        loss_list.append(loss.detach().cpu().item())
        batch_list.append(i + 1)
        if i % 10 == 0:
            print('Train - Epoch %d, Batch: %d, Loss: %f' % (epoch, i, loss.detach().cpu().item()))
        loss.backward()
        optimizer.step()


def test():
    """Evaluate on the MNIST test set and report average loss and accuracy."""
    net.eval()
    total_correct = 0
    avg_loss = 0.0
    with torch.no_grad():
        for images, labels in data_test_loader:
            output = net(images)
            # criterion returns the per-batch mean, so weight by batch size before
            # dividing by the dataset size to get the true per-sample average loss
            avg_loss += criterion(output, labels).item() * images.size(0)
            pred = output.max(1)[1]
            total_correct += pred.eq(labels.view_as(pred)).sum().item()
    avg_loss /= len(data_test)
    print('Test Avg. Loss: %f, Accuracy: %f' % (avg_loss, total_correct / len(data_test)))
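

# MNIST images are 28x28; resize to 32x32, the input size LeNet-5 expects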
data_train = MNIST('./data',
                   download=True,
                   transform=transforms.Compose([
                       transforms.Resize((32, 32)),
                       transforms.ToTensor()]))
data_test = MNIST('./data',
                  train=False,
                  download=True,
                  transform=transforms.Compose([
                      transforms.Resize((32, 32)),
                      transforms.ToTensor()]))

data_train_loader = DataLoader(data_train, batch_size=256, shuffle=True, num_workers=8)
data_test_loader = DataLoader(data_test, batch_size=1024, num_workers=8)
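
# LeNet-5 model, cross-entropy loss, and Adam optimizer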
net = LeNet5()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=2e-3)


if __name__ == '__main__':
    # Guard the entry point so DataLoader worker processes can re-import this module safely
    for epoch in range(1, 16):
        train(epoch)
        test()

    # Export the trained model to ONNX and sanity-check the exported graph
    dummy_input = torch.randn(1, 1, 32, 32, requires_grad=True)
    torch.onnx.export(net, dummy_input, 'lenet.onnx')
    onnx_model = onnx.load('lenet.onnx')
    onnx.checker.check_model(onnx_model)
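
    # Optionally, the exported model can be exercised with ONNX Runtime (a separate
    # `onnxruntime` dependency, not imported above), for example:
    #   import onnxruntime as ort
    #   sess = ort.InferenceSession('lenet.onnx')
    #   ort_out = sess.run(None, {sess.get_inputs()[0].name: dummy_input.detach().numpy()})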