一、网络结构
网络模型计算方法
from torch import nn
class mnist_net(nn.Module):
    """Small CNN classifier for 32x32 single-channel digit images.

    Three stages of conv(5x5, stride 1, padding 2) + ReLU + max-pool(2) shrink a
    (N, 1, 32, 32) input to (N, 64, 4, 4). Two fully connected layers then map
    the flattened 1024 features to 10 class logits.

    Input:  float tensor of shape (N, 1, 32, 32).
    Output: unnormalized logits of shape (N, 10), suitable for nn.CrossEntropyLoss.
    """

    def __init__(self):
        super().__init__()
        # A ReLU follows every conv stage and the hidden linear layer. Without a
        # nonlinearity, Linear(1024, 64) -> Linear(64, 10) collapses into one
        # affine map and the conv stack is linear apart from max-pooling.
        self.model = nn.Sequential(
            nn.Conv2d(1, 32, 5, stride=1, padding=2),   # -> (N, 32, 32, 32)
            nn.ReLU(),
            nn.MaxPool2d(2),                            # -> (N, 32, 16, 16)
            nn.Conv2d(32, 32, 5, stride=1, padding=2),  # -> (N, 32, 16, 16)
            nn.ReLU(),
            nn.MaxPool2d(2),                            # -> (N, 32, 8, 8)
            nn.Conv2d(32, 64, 5, stride=1, padding=2),  # -> (N, 64, 8, 8)
            nn.ReLU(),
            nn.MaxPool2d(2),                            # -> (N, 64, 4, 4)
            nn.Flatten(),                               # -> (N, 1024)
            nn.Linear(64 * 4 * 4, 64),
            nn.ReLU(),
            nn.Linear(64, 10),                          # -> (N, 10) logits
        )

    def forward(self, x):
        """Return class logits of shape (N, 10) for a batch x of shape (N, 1, 32, 32)."""
        return self.model(x)
二、读取数据
import torch
from torch import nn
from torch.utils.data import DataLoader
import torchvision
# Resize MNIST to 32x32 so that the network's three 2x max-pools leave a 4x4
# feature map, matching the 64*4*4 input size of its first Linear layer.
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                                            torchvision.transforms.Resize((32, 32))])
train_set = torchvision.datasets.MNIST(root="./data", train=True, transform=transform, download=True)
test_set = torchvision.datasets.MNIST(root="./data", train=False, transform=transform, download=True)
train_set_size = len(train_set)
test_set_size = len(test_set)
# Reshuffle the training samples every epoch (standard for SGD) so batches are
# not visited in the same fixed order each time; evaluation order is irrelevant.
train_dataloader = DataLoader(dataset=train_set, batch_size=64, shuffle=True)
test_dataloader = DataLoader(dataset=test_set, batch_size=64)
显示数据
# Sanity-check the data: print both dataset sizes, the shape of one training
# batch, then show the batch's first sample and print its label.
for prefix, count in (("训练集的长度", train_set_size),
                      ("测试集的长度", test_set_size)):
    print(f"{prefix}{count}")

batch_images, batch_targets = next(iter(train_dataloader))
# The model's first Conv2d takes 1 channel and resizing gives 32x32, so this
# should be (batch, 1, 32, 32).
print(batch_images.shape)

# Converter from a (C, H, W) tensor back to a PIL image for viewing.
to_pil_image = torchvision.transforms.ToPILImage()
sample_view = to_pil_image(batch_images[0])
sample_view.show()
print(batch_targets[0])
训练过程
import os

# --- Training: `epoch` passes of SGD over the training split, each followed
# --- by an accuracy check on the test split.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net = mnist_net().to(device)
epoch = 100
loss_fn = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(net.parameters(), lr=0.001)

# Save checkpoints where the inference section loads them from
# ("./model/model_mnist_90.pth"); create the directory if it is missing.
checkpoint_dir = "./model"
os.makedirs(checkpoint_dir, exist_ok=True)

loss_list = []  # per-epoch sum of training batch losses
acc_list = []   # per-epoch test-set accuracy (float in [0, 1])
for i in range(epoch):
    print("------第{}轮训练开始-----".format(i))

    # Re-enter train mode every epoch, because evaluation below switches to eval mode.
    net.train()
    train_step = 0
    total_loss = 0
    for image, label in train_dataloader:
        image = image.to(device)
        label = label.to(device)
        output = net(image)
        loss = loss_fn(output, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_step += 1
        total_loss += loss.item()
        if train_step % 100 == 0:
            # .item() prints the scalar value rather than the tensor repr (device, grad_fn).
            print("训练次数:{},loss:{}".format(train_step, loss.item()))

    if (i + 1) % 10 == 0:
        # Pickles the whole module (not just a state_dict); the inference code
        # relies on this by calling torch.load() and using the result directly.
        torch.save(net, os.path.join(checkpoint_dir, "model_mnist_{}.pth".format(i + 1)))

    # Evaluate on the test split without building autograd graphs.
    net.eval()
    total_acc = 0
    with torch.no_grad():
        for img, true_y in test_dataloader:
            img = img.to(device)
            true_y = true_y.to(device)
            output = net(img)
            # Accumulate a plain int so the accuracy below is a Python float.
            total_acc += (output.argmax(1) == true_y).sum().item()
    accuracy = total_acc / test_set_size
    print("正确率:{}".format(accuracy))
    loss_list.append(total_loss)
    acc_list.append(accuracy)
结果：绘制每轮训练总损失与测试集正确率曲线
import matplotlib.pyplot as plt

# Draw two side-by-side panels over the epoch index: the summed training loss
# (left) and the test accuracy (right). Save the figure, then display it.
# NOTE(review): the `label=` values only become visible if plt.legend() is called.
epoch_axis = list(range(epoch))
panels = (
    (loss_list, "loss", "total_loss"),
    (acc_list, "acc", "acc"),
)
plt.figure()
for position, (series, series_label, y_title) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.plot(epoch_axis, series, label=series_label)
    plt.xlabel("epochs")
    plt.ylabel(y_title)
plt.savefig("resout.png")
plt.show()
测试：加载保存的模型，识别单张自定义手写数字图片
import torchvision.transforms
from torch import nn
from mnist_net import mnist_net
from PIL import Image
import torch
# --- Inference: classify one custom image with a saved checkpoint. ---
# This is a standalone script. `from mnist_net import mnist_net` above must
# stay, because torch.load unpickles the full model object and needs that class.
path = "./eight.png"
# convert("L") forces a single channel. PNGs are often RGB or RGBA, and a
# 3- or 4-channel tensor does not fit the model's 1-channel input (torchvision's
# Grayscale even raises on 4 channels).
# NOTE(review): MNIST digits are light strokes on a dark background; if the
# input is dark-on-light, invert it first (e.g. PIL.ImageOps.invert) — confirm.
img = Image.open(path).convert("L")
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                                            torchvision.transforms.Resize((32, 32))])
img = transform(img)  # (1, 32, 32)

# Defined here because this script does not share the training script's globals.
to_pil_image = torchvision.transforms.ToPILImage()

model_path = "./model/model_mnist_90.pth"
# The checkpoint may have been saved from a CUDA model. map_location="cpu" puts
# its weights on the CPU, next to the input tensor.
# weights_only=False is needed on torch >= 2.6 to unpickle a full nn.Module;
# only load checkpoint files you trust (this runs pickle).
try:
    net = torch.load(model_path, map_location="cpu", weights_only=False)
except TypeError:  # torch < 1.13 has no weights_only argument
    net = torch.load(model_path, map_location="cpu")

net.eval()
with torch.no_grad():
    pred = net(img.unsqueeze(0))  # add a batch dimension -> (1, 1, 32, 32)

to_pil_image(img).show()
print("预测结果:{}".format(pred.argmax(1).item()))