本文为稀土技能社区首发签约文章,30天内禁止转载,30天后未获授权禁止转载,侵权必究!
beginning
上期给咱们介绍了常识蒸馏的中心原理,时刻有点久不知道咱们还记不记得。一句话概括就是——将教师模型的常识经过soft targets传递给轻量化的学生模型,然后进步学生模型功能,减少核算需求。还没看过或者忘记的小伙伴赶紧来看看叭➡从教师到学生:神奇的“常识蒸馏”之旅——原理详解篇。明白了原理之后,咱们今天就来实战一下,看看教师模型、学生模型是怎样用代码构建的,学习如何用常识蒸馏来进步学生模型的功能,让咱们对蒸馏有一个更加直观的感受。除此之外,上期还有一些额外的常识蒸馏常识点没讲完,这次也一口气介绍完叭。废话不多说啦,假如你也对此感兴趣,想着手完成常识蒸馏看看作用,让咱们一同愉快的学习叭
1.常识蒸馏代码实战
在介绍代码之前呢,给咱们共享两个好用的常识蒸馏代码库:
- 模型压缩工具箱MMRazor开源库
- 12个SOTA常识蒸馏算法的pytorch复现
第一个开源库包含剪枝、蒸馏、神经架构搜索和量化;第二个是大神发表的RepDistiller,里边有12种用pytorch完成的流行常识蒸馏算法。都对常识蒸馏的学习很有协助滴
1.1不同温度下softmax可视化
经过上期的学习,咱们知道了蒸馏温度T越高,soft targets就越soft,所以温度是至关重要滴,那首先咱们就来学着画一下不同温度关于softmax的影响叭
- 导入工具包:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
- 输入各类别的logits:
logits = np.array([-5,2,7,9])
4个类别的logit,你能够了解成是神经网络最终一层的线性分类层输出的4个类别的logit,它们有正有负有大有小。
- 一般softmax( T=1 ):
softmax_1=np.exp(logits) / sum(np.exp(logits))
softmax_1
plt.plot(softmax_1,label='softmax_1')
plt.legend()
plt.show()
一般的softmax是蒸馏温度等于1,softmax_1=np.exp(logits) / sum(np.exp(logits))代表着把e−5+e2+e7+e9e^{-5}+e^2+e^7+e^9作为分母,e−5、e2、e7、e9e^{-5}、e^2、e^7、e^9别离作为分子算出来的各个数值,其代表了每一个softmax的后验概率。此刻画出的图如下
- 常识蒸馏softmax( T=3 ):
plt.plot(softmax_1,label='T=1')
T=3
softmax_3 = np.exp(logits/T) / sum(np.exp(logits/T))
plt.plot(softmax_3,label='T=3')
T=5
softmax_5 = np.exp(logits/T) / sum(np.exp(logits/T))
plt.plot(softmax_5,label='T=5')
T=10
softmax_10 = np.exp(logits/T) / sum(np.exp(logits/T))
plt.plot(softmax_10,label='T=10')
T=100
softmax_100 = np.exp(logits/T) / sum(np.exp(logits/T))
plt.plot(softmax_100,label='T=100')
plt.xticks(np.arange(4), ['Cat', 'Dog','Donkey','Horse'])
plt.legend()
plt.show()
别离尝试温度T=3,T=5,T=10,T=100,画出它们的图如下所示,能够发现T越大,soft targets越soft,贫富差距就越小;T越小,两极分化就越大。所以T的选取很重要,若是过小的话就和没有蒸馏是相同的,过大又会陷入平均主义。
1.2载入数据集
下面就以MNIST数据集为例,利用pytorch从头练习教师网络、从头练习学生网络,并用常识蒸馏练习学生网络比较功能。
- 导入工具包:
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transform
from torch.utils.data import DataLoader
from torchinfo import summary
from tqdm import tqdm
#设置随机种子,便于复现
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#使用cuDNN加速卷积运算
torch.backends.cudnn.benchmark = True
- 载入MNIST数据集:
#载入数据集
train_dataset = torchvision.datasets.MNIST(
root="dataset/",
train=True,
transform=transforms.ToTensor(),
download=True
)
#载入测验集
test_dataset = torchvision.datasets.MNIST(
root="dataset/",
train=False,
transform=transforms.ToTensor(),
download=True
)
#生成dataloader
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)
先导入工具包,一般的代码是要放在云gpu上的(最好是);然后载入练习集和测验集,生成练习集的DataLoader和测验集的DataLoader。
1.3构建并练习教师模型
构建教师模型:
class TeacherModel(nn.Module):
def __init__(self, in_channels=1,num_classes=10):
super(TeacherModel, self).__init__()
self.relu = nn.ReLU()
self.fc1 = nn.Linear(784,1200)
self.fc2 = nn.Linear(1200,1200)
self.fc3 = nn.Linear(1200,num_classes)
self.dropout = nn.Dropout(p=0.5)
def forward(self, x):
x = x.view(-1,784)
x = self.fc1(x)
x = self.dropout(x)
x = self.relu(x)
x = self.fc2(x)
x = self.dropout(x)
x = self.relu(x)
x = self.fc3(x)
return x
结构一个教师网络,这个教师网络有三层隐含层,每一层都加了dropout,避免过拟合。第一层是把输入的MNIST中784个像素映射到1200个神经元,第二层是把1200个神经元映射成1200个神经元,第三层是把1200个神经元映射成10个类别。
从头练习教师模型:
model = TeacherModel()
model = model.to(device)
summary(model)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
epochs = 6
for epoch in range(epochs):
model.train()
#练习集上练习模型权重
for data, targets in tqdm(train_loader):
data = data.to(device)
targets = targets.to(device)
#前向猜测
preds = model(data)
loss = criterion(preds, targets)
#反向传达,优化权重
optimizer.zero_grad()
loss.backward()
optimizer.step()
#测验集上评价模型功能
model.eval()
num_correct = 0
num_samples = 0
with torch.no_grad():
for x, y in test_loader:
x = x.to(device)
y = y.to(device)
preds = model(x)
predictions = preds.max(1).indices
num_correct += (predictions == y).sum()
num_samples += predictions.size(0)
acc = (num_correct/num_samples).item()
model.train()
print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch+1, acc))
teacher_model = model
首先指定一个穿插熵分类丢失函数CrossEntropyLoss,指定优化器和学习率,开端练习6轮,每一次练习都是从前向再反向,每一轮之后再在测验集上评价模型的功能。运行之后看到如下所示的成果,准确率为0.9762(PS:其实这些代码都是很简单的基础常识,在前面的学习中也详解讲过啦,这儿就不再细说辽)
1.4构建并练习学生模型
构建学生模型:
class StudentModel(nn.Module):
def __init__(self, in_channels=1,num_classes=10):
super(StudentModel, self).__init__()
self.relu = nn.ReLU()
self.fc1 = nn.Linear(784,20)
self.fc2 = nn.Linear(20, 20)
self.fc3 = nn.Linear(20, num_classes)
def forward(self, x):
x = x.view(-1,784)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
构建的学生模型就要小得多啦,它的每一层只要20个神经元,构建办法和上面的相同。
从头练习学生模型:
model = StudentModel()
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
epochs = 3
for epoch in range(epochs):
model.train()
#练习集上练习模型权重
for data, targets in tqdm(train_loader):
data = data.to(device)
targets = targets.to(device)
#前向猜测
preds = model(data)
loss = criterion(preds, targets)
#反向传达,优化权重
optimizer.zero_grad()
loss.backward()
optimizer.step()
#测验集上评价模型功能
model.eval()
num_correct = 0
num_samples = 0
with torch.no_grad():
for x, y in test_loader:
x = x.to(device)
y = y.to(device)
preds = model(x)
predictions = preds.max(1).indices
num_correct += (predictions == y).sum()
num_samples += predictions.size(0)
acc = (num_correct/num_samples).item()
model.train()
print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch+1, acc))
student_model_scratch = model
从头练习学生模型和上面练习教师模型也是相同的,最终运行得到的成果如下,准确率只要0.8986,所以咱们要用常识蒸馏来练习学生模型,进步它的功能。
1.5常识蒸馏练习学生模型
常识蒸馏练习学生模型:
#预备预练习好的教师模型
teacher_model.eval()
#预备新的学生模型
model = StudentModel()
model = model.to(device)
model.train()
#蒸馏温度
temp = 7
#hard_loss
hard_loss = nn.CrossEntropyLoss()
#hard_loss 权重
alpha = 0.3
# soft_loss
soft_loss = nn.KLDivLoss(reduction="batchmean")
optimizer = torch.optim.Adam(model.paramaters(), lr=1e-4)
epochs = 3
for epoch in range(epochs):
#练习集上练习模型权重
for data, targets in tqdm(train_loader):
data = data.to(device)
targets = targets.to(device)
#教师模型猜测
with torch.no_grad():
teacher_preds = teacher_model(data)
#学生模型猜测
student_preds = model(data)
#核算hard_loss
student_loss = hard_loss(student_preds, targets)
#核算蒸馏后的猜测成果及soft_loss
ditillation_loss = soft_loss(
F.softmax(student_preds / temp, dim=1),
F.softmax(teacher_preds / temp, dim=1)
)
#将hard_loss和soft_loss加权求和
loss = alpha * student_loss + (1-alpha) * ditillation_loss
#反向传达,优化权重
optimizer.zero_grad()
loss.backward()
optimizer.step()
#测验集上评价模型功能
model.eval()
num_correct = 0
num_samples = 0
with torch.no_grad():
for x, y in test_loader:
x = x.to(device)
y = y.to(device)
preds = model(x)
predictions = preds.max(1).indices
num_correct += (predictions == y).sum()
num_samples += predictions.size(0)
acc = (num_correct/num_samples).item()
model.train()
print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch+1, acc))
蒸馏温度选为7,hard_loss是一个一般的分类穿插熵丢失函数,而soft_loss是一个KL散度(差不多也是穿插熵丢失函数)。练习时也是从前向后反向,前向时先获取教师网络的猜测成果,对教师网络猜测成果进行蒸馏和softmax,然后把学生网络温度为temp和教师网络温度为temp时别离算出来softmax,一同作为soft_loss,算出来一个总的丢失函数loss = alpha * student_loss + (1-alpha) * ditillation_loss。其他反向传达、评价功能等步骤和刚刚是相同的。运行得到的成果如下,能够看到准确率相比没蒸馏前有了进步(虽然进步不大,但这仅仅一个小demo,详细的还要进行调参优化)
其实,咱们并不能用最终的分数来衡量它常识蒸馏是好是坏,由于常识蒸馏并不仅仅能涨点,并不仅仅能压缩模型进步功能,它还有许多潜在的好处,比如咱们能够用海量的无监督的大数据集,能够避免过拟合,能够完成常识从大模型到小模型的搬迁,这才是关于常识蒸馏咱们要把握的点✨✨✨
2.常识蒸馏的补充常识
常识蒸馏为什么work???\color{blue}{常识蒸馏为什么work???}
学完了常识蒸馏的原理与代码完成之后,小伙伴们有没有仔细想一想常识蒸馏为什么会有用呢比较让人服气的一个机了解释是:
如上图,绿色是教师网络的求解空间,由于教师网络比较大嘛,所以它的表达能力和拟合能力比较强;学生网络是比较小的蓝色区域,它的表达能力比较差,求解空间比较小。练习教师网络之后,假如教师网络收敛到了红圈里边,假如咱们独自练习学生网络(不蒸馏,直接用本来的数据集和标签),那么学生网络会收敛到黄色区域。毫无疑问,此刻的学生网络和教师网络是有一定间隔的,假如单纯的用hard label来练习学生网络,它是无法到达教师网络的水平滴;但咱们加上常识蒸馏(橙色区域)之后,教师网络就会引导这个黄圈,告知它怎么去收敛,那么它最终会收敛到这个橙圈里,而橙圈是教师网络的一个子集,它离原生的学生网络的收敛空间更挨近,离教师网络越近,作用就越好
常识蒸馏与搬迁学习???\color{blue}{常识蒸馏与搬迁学习???}
学完常识蒸馏后,小伙伴们有没有感觉和搬迁学习很像,究竟常识蒸馏是从教师模型搬迁到学生模型上的,那它俩到底是一个什么关系腻其实,常识蒸馏和搬迁学习是没关系滴,它俩的概念是正交的,搬迁学习指的是把一个范畴练习的模型,让其泛化到另一个范畴,比如说用X胸片的数据集去练习一个原本识别猫狗的模型,然后猫狗模型就慢慢学会去分辩x光胸片的各种病,这种把猫狗域搬迁到了医疗域属于搬迁学习(侧重于范畴的搬迁);而常识蒸馏是把一个模型的常识搬迁到另一个模型上,通常是大模型搬迁到小模型(侧重于模型的搬迁)。所以这俩是能够穿插的,能够用常识蒸馏完成搬迁学习……也能够完全没有任何关系
ending
看到这儿相信盆友们都对如何用代码完成常识蒸馏有了一个全面深化的了解啦,小伙伴们学废了没呀很高兴能把学到的常识以文章的形式共享给咱们假如你也觉得我的共享对你有所协助,please一键三连嗷!!!下期见