import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdm

from model import AlexNet


def _build_transforms():
    """Return the preprocessing pipelines keyed by split name.

    ``"train"`` applies RandomResizedCrop(224) and a random horizontal flip;
    ``"val"`` only resizes to 224x224. Both convert to a tensor and normalize
    each of the three channels with mean 0.5 and std 0.5.
    """
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    return {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     normalize]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   normalize]),
    }


def _num_workers(batch_size: int) -> int:
    """Return the DataLoader worker count.

    The result is capped at 8 and at the CPU count. It is 0 (load in the main
    process) when ``batch_size`` is 1 or less.
    """
    # os.cpu_count() returns None when the count cannot be determined;
    # fall back to 1 so that min() does not raise TypeError.
    cpu_count = os.cpu_count() or 1
    return min([cpu_count, batch_size if batch_size > 1 else 0, 8])


def _write_class_indices(class_to_idx, path: str = "class_indices.json") -> None:
    """Write the inverted ImageFolder mapping (index -> class name) as JSON.

    JSON object keys are strings, so the integer indices are serialized as
    strings. Presumably this is read by a separate prediction script. Verify.
    """
    cla_dict = {idx: name for name, idx in class_to_idx.items()}
    with open(path, "w", encoding="utf-8") as json_file:
        json.dump(cla_dict, json_file, indent=4)


def _train_one_epoch(net, loader, loss_function, optimizer, device,
                     epoch: int, epochs: int) -> float:
    """Run one optimization pass over ``loader`` and return the mean batch loss."""
    net.train()
    running_loss = 0.0
    train_bar = tqdm(loader, file=sys.stdout)
    for images, labels in train_bar:
        optimizer.zero_grad()
        outputs = net(images.to(device))
        loss = loss_function(outputs, labels.to(device))
        loss.backward()
        optimizer.step()

        # Convert once to a Python float for both the running sum and the bar.
        batch_loss = loss.item()
        running_loss += batch_loss
        train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, batch_loss)
    return running_loss / len(loader)


def _evaluate(net, loader, device) -> int:
    """Return the number of correct top-1 predictions over ``loader``."""
    net.eval()
    correct = 0
    with torch.no_grad():
        for val_images, val_labels in tqdm(loader, file=sys.stdout):
            outputs = net(val_images.to(device))
            predict_y = outputs.argmax(dim=1)
            correct += torch.eq(predict_y, val_labels.to(device)).sum().item()
    return correct


def main(epochs: int = 10, batch_size: int = 32, lr: float = 0.0002,
         save_path: str = './AlexNet.pth') -> None:
    """Train AlexNet on the ``flower_data`` ImageFolder dataset.

    Reads ``<cwd>/../../data_set/flower_data/{train,val}``. Writes
    ``class_indices.json`` to the working directory. Saves the state_dict with
    the best validation accuracy seen so far to ``save_path``.

    Args:
        epochs: Number of passes over the training set.
        batch_size: Batch size for both the training and validation loaders.
        lr: Adam learning rate.
        save_path: Destination of the best checkpoint's state_dict.

    Raises:
        FileNotFoundError: If the dataset directory does not exist.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = _build_transforms()
    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
    image_path = os.path.join(data_root, "data_set", "flower_data")
    # Raise explicitly rather than assert: asserts are stripped under `python -O`.
    if not os.path.exists(image_path):
        raise FileNotFoundError("{} path does not exist.".format(image_path))

    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)
    _write_class_indices(train_dataset.class_to_idx)

    nw = _num_workers(batch_size)
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                               shuffle=True, num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    # Sample order does not affect accuracy, so validation is not shuffled.
    validate_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=batch_size,
                                                  shuffle=False, num_workers=nw)
    print("using {} images for training, {} images for validation.".format(train_num, val_num))

    # Size the classifier head from the dataset instead of hard-coding 5, so
    # the labels ImageFolder produces always fit the output layer.
    net = AlexNet(num_classes=len(train_dataset.classes), init_weights=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=lr)

    best_acc = 0.0
    for epoch in range(epochs):
        train_loss = _train_one_epoch(net, train_loader, loss_function, optimizer,
                                      device, epoch, epochs)
        val_accurate = _evaluate(net, validate_loader, device) / val_num
        print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
              (epoch + 1, train_loss, val_accurate))

        # Keep only the checkpoint with the best validation accuracy so far.
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()