持续创造,加快成长!这是我参与「日新方案 10 月更文挑战」的第32天,点击查看活动详情

前语

本文将介绍如安在 PyTorch 中构建一个简单的卷积神经网络,并练习它运用 MNIST 数据集辨认手写数字,这将能够被看做是图像辨认的 “Hello, World!”;

在 【项目实战】MNIST 手写数字辨认(上) 中,我现已介绍过了怎么配置环境,准备数据集以及运用数据集,接下来将要进行构建网络、练习模型、评价模型、优化模型等;

构建网络

现在让咱们持续构建咱们的网络。咱们将运用两个二维卷积层,然后是两个全连接(或线性)层。作为激活函数,咱们将挑选校正线性单元(简称 ReLU),作为正则化的手段,咱们将运用两个 dropout 层。在 PyTorch 中,构建网络的一种好方法是为咱们希望构建的网络创建一个新类。让咱们在这儿导入一些子模块以获得更易读的代码。

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)
    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x)

从广义上讲,咱们能够以为 torch.nn 层包含可练习的参数,而 torch.nn.functional 是纯函数式的。 forward() 传递界说了咱们运用给定层和函数核算输出的方法。在前向传递中的某处打印张量以便利调试是非常好的。这在尝试更杂乱的模型时会派上用场。请注意,前向传递能够利用例如一个成员变量甚至数据自身来确认执行路径——它也能够运用多个参数!

现在让咱们初始化网络和优化器。

network = Net()
optimizer = optim.SGD(network.parameters(), lr=learning_rate,
                      momentum=momentum)

注意:假如咱们运用 GPU 进行练习,咱们还应该将网络参数发送到 GPU,例如 network.cuda()。在将网络参数传递给优化器之前,将它们传输到适当的设备非常重要,不然优化器将无法以正确的方法盯梢它们。

练习模型

是时分树立咱们的练习循环了。

首要,咱们要确保咱们的网络处于练习模式。

然后咱们每个 epoch 对一切练习数据进行一次迭代。

再由 DataLoader 加载单个批次。咱们需求运用 optimizer.zero_grad() 手动将梯度设置为零,因为 PyTorch 默认会累积梯度。然后发生网络的输出(前向传递)并核算输出和地面实况标签之间的负对数似然丢失。

现在,backward() 调用收集了一组新的梯度,咱们运用 optimizer.step() 将其传播回每个网络参数。

咱们还将经过一些打印输出盯梢进展。为了稍后创建一个美丽的练习曲线,咱们还创建了两个列表来保存练习和测验丢失。在 x 轴上,咱们希望显现网络在练习期间看到的练习示例的数量。

train_losses = []
train_counter = []
test_losses = []
test_counter = [i*len(train_loader.dataset) for i in range(n_epochs + 1)]

咱们将在开端练习之前运转一次测验循环,看看咱们仅运用随机初始化的网络参数实现了什么样的准确度丢失。你能猜出咱们在这种情况下的准确性怎么吗?

def train(epoch):
    network.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = network(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            train_losses.append(loss.item())
            train_counter.append(
                (batch_idx*64) + ((epoch-1)*len(train_loader.dataset)))
            torch.save(network.state_dict(), f'{BASEPATH}/results/model.pth')
            torch.save(optimizer.state_dict(), f'{BASEPATH}/results/optimizer.pth')

神经网络模块和优化器能够运用 .state_dict() 保存和加载它们的内部状况。有了这个,咱们能够经过调用 .load_state_dict(state_dict),持续从曾经保存的状况字典中练习。

现在为咱们的测验循环。在这儿,咱们总结了测验丢失并盯梢正确分类的数字以核算网络的准确性。

def test():
    network.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = network(data)
            test_loss += F.nll_loss(output, target, size_average=False).item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

运用上下文管理器 no_grad() 咱们能够避免在核算图中存储生成网络输出的核算。

是时分进行培训了,在循环 n_epochs 之前,咱们将手动添加一个 test() 调用,以运用随机初始化的参数评价咱们的模型。

test()
for epoch in range(1, n_epochs + 1):
    train(epoch)
    test()

【项目实战】MNIST 手写数字识别(下)

评价模型

仅经过 3 个 epoch 的练习,咱们就现已成功地在测验集上达到了 97% 的准确率!咱们从随机初始化的参数开端,正如预期的那样,在开端练习之前,测验集的准确率只有大约 10%。

制作一下咱们的练习曲线:

fig = plt.figure()
plt.plot(train_counter, train_losses, color='blue')
plt.scatter(test_counter, test_losses, color='red')
plt.legend(['Train Loss', 'Test Loss'], loc='upper right')
plt.xlabel('number of training examples seen')
plt.ylabel('negative log likelihood loss')

【项目实战】MNIST 手写数字识别(下)

看起来咱们甚至能够持续练习几个 epoch!

但在此之前,让咱们再看几个例子,就像咱们之前所做的一样,并比较模型的输出。

with torch.no_grad():
    output = network(example_data)
fig = plt.figure()
for i in range(6):
    plt.subplot(2,3,i+1)
    plt.tight_layout()
    plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
    plt.title("Prediction: {}".format(
        output.data.max(1, keepdim=True)[1][i].item()))
    plt.xticks([])
    plt.yticks([])

【项目实战】MNIST 手写数字识别(下)

咱们模型的预测好像与这些示例相符!

持续练习

现在让咱们持续练习网络,或许更切当地说,看看咱们怎么从咱们在第一次练习运转期间保存的 state_dicts 持续练习。咱们将初始化一组新的网络和优化器。

continued_network = Net()
continued_optimizer = optim.SGD(network.parameters(), lr=learning_rate,
                                momentum=momentum)

运用 .load_state_dict() 能够加载前次保存时网络和优化器的内部状况。

network_state_dict = torch.load(f'{BASEPATH}/results/model.pth')
continued_network.load_state_dict(network_state_dict)
optimizer_state_dict = torch.load(f'{BASEPATH}/results/optimizer.pth')
continued_optimizer.load_state_dict(optimizer_state_dict)

再次运转一个练习循环,应该从咱们离开的地方持续练习。

要查看这一点,让咱们简单地运用与曾经相同的列表来盯梢丢失值

因为咱们为看到的练习示例数量构建测验计数器,因而咱们必须在此处手动追加。

for i in range(4,9):
    test_counter.append(i*len(train_loader.dataset))
    train(i)
    test()

【项目实战】MNIST 手写数字识别(下)

咱们再次看到测验集准确性从一个时期到另一个时期的添加(慢得多)。让咱们将其可视化以进一步查看练习进展。

fig = plt.figure()
plt.plot(train_counter, train_losses, color='blue')
plt.scatter(test_counter, test_losses, color='red')
plt.legend(['Train Loss', 'Test Loss'], loc='upper right')
plt.xlabel('number of training examples seen')
plt.ylabel('negative log likelihood loss')

【项目实战】MNIST 手写数字识别(下)

这看起来依然是一条相当平滑的学习曲线,就像咱们最初会练习 8 个 epoch 一样!请记住,咱们仅仅从第 5 个红点开端将值附加到相同的列表中。

由此咱们能够得出两个定论:

  1. 从查看点内部状况持续按预期工作。
  2. 咱们好像依然没有遇到过拟合问题!看起来咱们的 dropout 层在规范化模型方面做得很好。

跋文

MNIST 手写数字辨认的内容到这儿就完毕了;

PyTorch 和 TorchVision 构建了一个新环境,用它来分类 MNIST 数据会集的手写数字,并希望运用 PyTorch 开宣布良好的直觉。

上篇精讲:【项目实战】MNIST 手写数字辨认(上)

我是,期待你的关注;

创造不易,请多多支持;

系列专栏:项目实战 AI