本文正在参与「金石方案 . 分割6万现金大奖」
作者简介:秃头小苏,致力于用最浅显的语言描述问题
往期回忆:对立生成网络GAN系列——GAN原理及手写数字生成小事例 对立生成网络GAN系列——DCGAN简介及人脸图像生成事例
近期目标:写好专栏的每一篇文章
支撑小苏:点赞、收藏⭐、留言
pytorch保存与加载模型详解篇
写在前面
最近,看到不少小伙伴问pytorch怎么保存和加载模型,其实这部分pytorch官网介绍的也是很清楚的,感兴趣的点击☞☞☞了解概况
可是必定有很多人是不愿意看官网的,所以我仍是花一篇文章来为咱们介绍介绍。当然了,在介绍中我会参加自己的一些了解,让咱们有一个更深的知道。假如预备好了的话,就让咱们开端吧。⏳⏳⏳
模型保存与加载
pytorch中介绍了几种不同的模型保存和加载办法,我会在下文逐个为咱们介绍。首要先让咱们来随意界说一个模型,如下:【用的是pytorch官网的例子】
# 模型界说
class TheModelClass(nn.Module):
def __init__(self):
super(TheModelClass, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
界说好模型结构后,咱们能够实例化这个模型:
#模型初始化
model = TheModelClass()
模型初始化过后,咱们就一起来看看模型保存和加载的办法吧。
办法1
办法1是官方引荐的一种办法,咱们直接来看代码好了,如下:
# 保存模型
torch.save(model.state_dict(), './model/model_state_dict.pth')
该办法后边的参数'./model/model_state_dict.pth'
为模型的保存途径,模型后缀名官方引荐运用.pth
和.pt
,当然了,你取别的后缀名也是完全可行的。☘☘☘
介绍了模型的保存,下面就来看看办法1是怎么加载模型的。【这儿我说明一点,模型保存往往是在练习中进行的,而模型加载大都用在模型推理中,它们存在两个文件中,故咱们在推理过程中要先实列化模型】
# 加载模型
model_test1 = TheModelClass() # 加载模型时应先实例化模型
# load_state_dict()函数接收一个字典,所以不能直接将'./model/model_state_dict.pth'传入,而是先运用load函数将保存的模型参数反序列化
model_test1.load_state_dict(torch.load('./model/model_state_dict.pth'))
model_test1.eval() # 模型推理时设置
在上述的代码注释中我有写到,咱们运用load_state_dict()
加载模型时先需求运用load办法将保存的模型参数==反序列化==,load后的结果是一个字典,这时就能够经过load_state_dict()
办法来加载了。
这儿我来简略说一下我了解的反序列化,其和序列化是相对应的一个概念。序列化便是把内存中的数据保存到磁盘中,像咱们运用torch.save()
办法保存模型便是序列化;而反序列化则是将硬盘中的数据加载到内存当中,显然咱们加载模型的过程便是反序列化过程。【大致的意思如下图所示,偶尔在水群的时分看到一个画图软件,是不是还挺好看的】
办法2
办法2十分简略,直接上代码:
# 保存模型
torch.save(model, './model/model.pt') #这儿咱们保存模型的后缀名取.pt
# 加载模型
model_test2 = torch.load('./model/model.pt')
model_test2.eval() # 模型推理时设置
可是这种办法是不引荐运用的,因为你运用这种办法保存模型,然后再加载时会遇到各式各样的过错。为了加深咱们了解,咱们来看这样的一个例子。文件的结构如下图所示:
models.py
文件中存储的是模型的界说,其坐落文件夹models下。save_model.py
文件中写的是保存模型的代码,如下:
from models.models import TheModelClass
from torch import optim
import torch
#模型初始化
model = TheModelClass()
# 初始化优化器
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# ## 保存加载办法2——save/load
# # 保存模型
# torch.save(models, './models/models.pt')
履行此文件后,会生成models.pt
文件,咱们在履行load_mode.py
文件即可完成加载,load_mode.py
内容如下:
from models.models import TheModelClass
import torch
## 加载办法2
# 加载模型
model_test2 = TheModelClass()
model_test2 = torch.load('./models/models.pt')
model_test2.eval() # 模型推理时设置
print(model_test2)
此刻咱们能够正常加载。但假如咱们将models文件夹修改为model,如下:
此刻咱们在运用如下代码加载模型的话就会呈现过错:
from models.models import TheModelClass
import torch
## 加载办法2
# 加载模型
model_test2 = TheModelClass()
model_test2 = torch.load('./model/models.pt') #这儿需求修改一下文件途径
model_test2.eval() # 模型推理时设置
print(model_test2)
呈现这种过错的原因是运用办法2进行模型保存的时分会把模型结构界说文件途径记载下来,加载的时分就会依据途径解析它然后装载参数;当把模型界说文件途径修改以后,运用torch.load(path)就会报错。
其实运用办法2进行模型的保存和加载还会存在各种问题,感兴趣的能够看看这篇博文。总归,在咱们往后的运用中,尽量不要用办法2来加载模型。
办法3
pytorch还为咱们提供了一种模型保存与加载的办法——checkpoint。这种办法保存的是一个字典,假如咱们程序在运转中因为某种原因异常中止,那么这种办法能够很方便的让咱们接着上次练习,正因为这样,我十分引荐咱们运用这种办法进行模型的保存与加载。下面就让咱们一起来看看办法3是怎么运用的吧!!!
首要,咱们同样运用torch.save
来保存模型,可是这儿保存的是一个字典,里面能够填入你需求保存的参数,如下:
# 保存checkpoint
torch.save({
'epoch':epoch,
'model_state_dict':model.state_dict(),
'optimizer_state_dict':optimizer.state_dict(),
'loss':loss
}, './model/model_checkpoint.tar' #这儿的后缀名官方引荐运用.tar
)
接着咱们来看看怎么加载checkpoint,代码如下:
# 加载checkpoint
model_checkpoint = TheModelClass()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
checkpoint = torch.load('./model/model_checkpoint.tar') # 先反序列化模型
model_checkpoint.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
看了我上文的介绍,咱们是否知道怎么运用checkpoint
了呢,我想咱们都会觉得这个不是很难,但要自己写可能仍是不好掌握,那么第一次就让我来带领咱们看看怎么在代码中运用checkpoint
吧!!!
这节我选用cifar10数据集完成物体分类的例子,我的这篇博文对其进行了详细介绍,那么这儿介绍checkpoint
我将利用这个demo来为咱们讲解。首要咱们直接来看模型保存的完好代码,如下:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#1、预备数据集
train_dataset = torchvision.datasets.CIFAR10("./data", train=True, transform=torchvision.transforms.ToTensor(), download= True)
test_dataset = torchvision.datasets.CIFAR10("./data", train=False, transform=torchvision.transforms.ToTensor(), download= True)
#2、加载数据集
train_dataset_loader = DataLoader(dataset=train_dataset, batch_size=100)
test_dataset_loader = DataLoader(dataset=test_dataset, batch_size=100)
#3、建立神经网络
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.model1 = nn.Sequential(
nn.Conv2d(3, 32, 5, padding=2),
nn.MaxPool2d(2),
nn.Conv2d(32, 32, 5, padding=2),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, 5, padding=2),
nn.MaxPool2d(2),
nn.Flatten(),
nn.Linear(1024, 64),
nn.Linear(64, 10)
)
def forward(self, input):
input = self.model1(input)
return input
#4、创立网络模型
net = Net()
#5、设置丢失函数、优化器
#丢失函数
loss_fun = nn.CrossEntropyLoss() #交叉熵
loss_fun = loss_fun.to(device)
#优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(net.parameters(), learning_rate) #SGD:梯度下降算法
#6、设置网络练习中的一些参数
total_train_step = 0 #记载总计练习次数
total_test_step = 0 #记载总计测验次数
Max_epoch = 10 #规划练习轮数
#7、开端进行练习
for epoch in range(Max_epoch):
print("---第{}轮练习开端---".format(epoch))
net.train() #开端练习,不是有必要的,在网络中有BN,dropout时需求
#因为练习集数据较多,这儿我没用练习集练习,而是选用测验集(test_dataset_loader)当练习集,但思维是共同的
for data in test_dataset_loader:
imgs, targets = data
targets = targets.to(device)
outputs = net(imgs)
#比较输出与真实值,计算Loss
loss = loss_fun(outputs, targets)
#反向传达,调整参数
optimizer.zero_grad() #每次让梯度重置
loss.backward()
optimizer.step()
total_train_step += 1
if total_train_step % 50 == 0:
print("---第{}次练习完毕, Loss:{})".format(total_train_step, loss.item()))
if (epoch+1) % 2 == 0:
# 保存checkpoint
torch.save({
'epoch': epoch,
'model_state_dict': net.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss
}, './model/model_checkpoint_epoch_{}.tar'.format(epoch) # 这儿的后缀名官方引荐运用.tar
)
if epoch > 5:
print("---意外中断---")
break
整个流程和这篇文章根本共同,不清楚的主张先花几分钟阅览一下哈。主要区别便是在最后保存模型的时分我运用了checkpoint
进行保存,且两个epoch保存一次。当epoch=6时,我设置了一个break模拟程序意外中断,中断后能够来看一下终端的输出信息,如下图所示:
咱们能够看到在进行第6轮循环时,程序中断了,此刻最新的保存的模型是第五次练习结果,如下:
一起注意到第5次练习完毕的loss在2.0左右,假如咱们下次接着练习,丢失应该是在2.0附近。
好了,上面因为一些糟糕的原因导致程序中断了,现在我想接着上次练习的结果继续练习,我该怎么办呢?代码如下:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#1、预备数据集
train_dataset = torchvision.datasets.CIFAR10("./data", train=True, transform=torchvision.transforms.ToTensor(), download= True)
test_dataset = torchvision.datasets.CIFAR10("./data", train=False, transform=torchvision.transforms.ToTensor(), download= True)
#2、加载数据集
train_dataset_loader = DataLoader(dataset=train_dataset, batch_size=100)
test_dataset_loader = DataLoader(dataset=test_dataset, batch_size=100)
#3、建立神经网络
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.model1 = nn.Sequential(
nn.Conv2d(3, 32, 5, padding=2),
nn.MaxPool2d(2),
nn.Conv2d(32, 32, 5, padding=2),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, 5, padding=2),
nn.MaxPool2d(2),
nn.Flatten(),
nn.Linear(1024, 64),
nn.Linear(64, 10)
)
def forward(self, input):
input = self.model1(input)
return input
#4、创立网络模型
net = Net()
#5、设置丢失函数、优化器
#丢失函数
loss_fun = nn.CrossEntropyLoss() #交叉熵
loss_fun = loss_fun.to(device)
#优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(net.parameters(), learning_rate) #SGD:梯度下降算法
#6、设置网络练习中的一些参数
total_train_step = 0 #记载总计练习次数
total_test_step = 0 #记载总计测验次数
Max_epoch = 10 #规划练习轮数
##########################################################################################
# 加载checkpoint
checkpoint = torch.load('./model/model_checkpoint_epoch_5.tar') # 先反序列化模型
net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
loss = checkpoint['loss']
##########################################################################################
#7、开端进行练习
for epoch in range(start_epoch+1, Max_epoch):
print("---第{}轮练习开端---".format(epoch))
net.train() #开端练习,不是有必要的,在网络中有BN,dropout时需求
for data in test_dataset_loader:
imgs, targets = data
targets = targets.to(device)
outputs = net(imgs)
#比较输出与真实值,计算Loss
loss = loss_fun(outputs, targets)
#反向传达,调整参数
optimizer.zero_grad() #每次让梯度重置
loss.backward()
optimizer.step()
total_train_step += 1
if total_train_step % 50 == 0:
print("---第{}次练习完毕, Loss:{})".format(total_train_step, loss.item()))
if (epoch+1) % 2 == 0:
# 保存checkpoint
torch.save({
'epoch': epoch,
'model_state_dict': net.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss
}, './model/model_checkpoint_epoch_{}.tar'.format(epoch) # 这儿的后缀名官方引荐运用.tar
)
这儿的代码相较之前的多了一个加载checkpoint
的过程,我将其截取出来,如下图所示:
经过加载checkpoint
咱们就保存了之前练习的参数,从而完成断点续练习,咱们直接来看履行此代码的结果,如下图所示:
从上图能够看出咱们的练习是从第6轮开端的,并且初始的loss为1.99,和2.0挨近。这就说明了咱们现已完成了中断后恢复练习的操作。
这儿我简略的说两句,上文介绍checkpoint
的用法时,练习中断和练习恢复我是放在两个文件中的进行的,可是在实践中咱们必定是在一个文件中运转,那这该怎么办呢?其实办法很简略啦,咱们只需求设置一个if条件将加载checkpoint
的部分放在练习文件中,然后设置一个参数来控制if条件的履行即可。具体细节我就不给咱们介绍了,假如有不明白的谈论区见吧!!!
总结
这部分仍是蛮简略的,但一些细节仍是需求咱们自行考量,我就为咱们介绍到这儿啦,希望咱们都能够有所收成吧。
如若文章对你有所帮助,那就