一、网络结构

image-1677074247320

网络模型计算方法

image-1677486778494

from torch import nn

class mnist_net(nn.Module):
    """Small CNN that maps single-channel 32x32 images to 10 class logits.

    Three stages of Conv2d(5x5, stride 1, padding 2) + MaxPool2d(2) halve the
    spatial size each time (32 -> 16 -> 8 -> 4), so the flattened feature vector
    holds 64 * 4 * 4 values before the two fully connected layers.

    NOTE(review): no activation functions sit between the layers, so the two
    trailing Linear layers compose into a single affine map.
    """

    def __init__(self):
        super().__init__()
        layers = []
        # (in_channels, out_channels) for each conv + pool stage, in order.
        for in_ch, out_ch in ((1, 32), (32, 32), (32, 64)):
            layers.append(nn.Conv2d(in_ch, out_ch, 5, stride=1, padding=2))
            layers.append(nn.MaxPool2d(2))
        layers += [
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            nn.Linear(64, 10),
        ]
        # The attribute must stay named ``model`` (a Sequential with this exact
        # layer order): saved checkpoints and state_dict keys ("model.0.weight", ...)
        # depend on it.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the stacked layers and return the 10 logits per sample."""
        return self.model(x)

二、读取数据

import torch
from torch import nn
from torch.utils.data import DataLoader
import torchvision

# Convert each PIL image to a tensor, then resize it to 32x32, the input size
# mnist_net's Linear(64*4*4, 64) layer is built around.
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                                            torchvision.transforms.Resize((32, 32))])
train_set = torchvision.datasets.MNIST(root="./data", train=True, transform=transform, download=True)
test_set = torchvision.datasets.MNIST(root="./data", train=False, transform=transform, download=True)

train_set_size = len(train_set)
test_set_size = len(test_set)

# Reshuffle the training set every epoch so SGD does not see the mini-batches
# in one fixed order; evaluation order does not affect accuracy, so the test
# loader stays unshuffled.
train_dataloader = DataLoader(dataset=train_set, batch_size=64, shuffle=True)
test_dataloader = DataLoader(dataset=test_set, batch_size=64)

显示数据

# Report how many samples each split holds.
print(f"训练集的长度{train_set_size}")
print(f"测试集的长度{test_set_size}")

# Pull a single batch from the training loader and inspect its first sample.
images, target = next(iter(train_dataloader))
print(images.shape)

to_pil_image = torchvision.transforms.ToPILImage()
image = to_pil_image(images[0])  # tensor -> PIL image for on-screen display
image.show()
print(target[0])  # label of the displayed sample

训练过程

# Train mnist_net for `epoch` epochs and evaluate test-set accuracy after each one.
# Results are collected in loss_list (summed training loss per epoch) and
# acc_list (test accuracy per epoch) for plotting.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = mnist_net().to(device)
epoch = 100
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.001)
loss_list = []
acc_list = []
for i in range(epoch):
    print("------第{}轮训练开始-----".format(i))
    # Re-enter training mode every epoch because the evaluation pass below
    # switches the network to eval mode.
    net.train()
    train_step = 0
    total_loss = 0.0
    for image, label in train_dataloader:
        image = image.to(device)
        label = label.to(device)
        output = net(image)
        loss = loss_fn(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_step += 1
        total_loss += loss.item()
        if train_step % 100 == 0:
            # .item() prints the plain number instead of the tensor repr
            # (which includes device and grad_fn).
            print("训练次数:{},loss:{}".format(train_step, loss.item()))

    # Checkpoint the whole module every 10 epochs (written to the working directory).
    if (i + 1) % 10 == 0:
        torch.save(net, "model_mnist_{}.pth".format(i + 1))

    net.eval()
    total_acc = 0  # number of correctly classified test samples
    with torch.no_grad():
        for img, true_y in test_dataloader:
            img = img.to(device)
            true_y = true_y.to(device)
            output = net(img)
            total_acc += (output.argmax(1) == true_y).sum().item()
    accuracy = total_acc / test_set_size
    print("正确率:{}".format(accuracy))
    loss_list.append(total_loss)
    acc_list.append(accuracy)

image-1677074439435

结果

import matplotlib.pyplot as plt
# Draw the per-epoch summed training loss (left) and test accuracy (right)
# side by side, save the figure to disk, then display it.
epoch_axis = list(range(epoch))
fig, (loss_ax, acc_ax) = plt.subplots(1, 2)

loss_ax.plot(epoch_axis, loss_list, label='loss')
loss_ax.set_xlabel("epochs")
loss_ax.set_ylabel("total_loss")

acc_ax.plot(epoch_axis, acc_list, label='acc')
acc_ax.set_xlabel("epochs")
acc_ax.set_ylabel("acc")

plt.savefig("resout.png")
plt.show()

image-1677074665661

测试

import torchvision.transforms
from torch import nn
from  mnist_net import mnist_net
from PIL import  Image
import  torch
# Classify a single hand-made digit image with a saved mnist_net checkpoint.
path = "./eight.png"
# Convert to single-channel grayscale while still a PIL image: a PNG may be RGB
# or RGBA, and a 4-channel tensor would be rejected by the tensor Grayscale transform.
img = Image.open(path).convert("L")
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                                           torchvision.transforms.Resize((32, 32))])
img = transform(img)

# The checkpoint is a whole pickled module written by torch.save(net, ...).
# The training loop writes model_mnist_<N>.pth to its working directory; this
# path assumes the file was moved into ./model/.
# - map_location="cpu": a model trained on GPU is loaded onto the CPU, where the
#   input tensor lives, and this also works on machines without CUDA.
# - weights_only=False: needed to unpickle a full module since PyTorch 2.6 made
#   weights_only=True the default (the argument itself requires PyTorch >= 1.13).
#   Unpickling can execute arbitrary code, so only load trusted checkpoints.
net = torch.load("./model/model_mnist_90.pth", map_location="cpu", weights_only=False)
img = torch.reshape(img, (1, 1, 32, 32))  # add the batch dimension
net.eval()
with torch.no_grad():
    pred = net(img)

# This script imports mnist_net from its own module and does not share the
# training script's globals, so the tensor -> PIL converter is created here.
to_pil_image = torchvision.transforms.ToPILImage()
img = torch.reshape(img, (1, 32, 32))
img = to_pil_image(img)
img.show()
print("预测结果:{}".format(pred.argmax(1).item()))