前语
于 11 月底正式开课的分散模型课程正在火热进行中,在中国社区成员们的协助下,咱们组织了「抱抱脸中文本地化志愿者小组」并完成了分散模型课程的中文翻译,感谢 @darcula1993、@XhrLeokk、@hoi2022、@SuSung-boy 对课程的翻译!
假如你还没有开端课程的学习,咱们主张你从 榜首单元:分散模型简介 开端。
分散模型从零到一
这个 Notebook 咱们将展现相同的进程(向数据添加噪声、创建模型、练习和采样),并尽或许简略地在 PyTorch 中从头开端完成。然后,咱们将这个「玩具示例」与 diffusers 版别进行比较,并重视两者的差异以及改善之处。这儿的方针是了解不同的组件和其间的规划决策,以便在检查新的完成时能够快速确定要害思维。
让咱们开端吧!
有时,只考虑一些业务最简略的状况会有助于更好地理解其作业原理。咱们将在本笔记本中测验这一点,从“玩具”分散模型开端,看看不同的部分是怎样作业的,然后再检查它们与更杂乱的完成有何不同。
你将跟随本文的 Notebook 学习到
- 损坏进程(向数据添加噪声)
- 什么是 UNet,以及怎样从零开端完成一个极小的 UNet
- 分散模型练习
- 抽样理论
然后,咱们将比较咱们的版别与 diffusers 库中的 DDPM 完成的差异
- 对小型 UNet 的改善
- DDPM 噪声方案
- 练习方针的差异
- timestep 调理
- 抽样办法
这个笔记本适当深化,假如你对从零开端的深化研讨不感爱好,能够放心地跳过!
还值得注意的是,这儿的大多数代码都是出于说明的意图,我不主张直接将其用于您自己的作业(除非您只是为了学习意图而测验改善这儿展现的示例)。
准备环境与导入:
!pipinstall-qdiffusers
importtorch
importtorchvision
fromtorchimportnn
fromtorch.nnimportfunctionalasF
fromtorch.utils.dataimportDataLoader
fromdiffusersimportDDPMScheduler,UNet2DModel
frommatplotlibimportpyplotasplt
device=torch.device("cuda"iftorch.cuda.is_available()else"cpu")
print(f'Usingdevice:{device}')
数据
在这儿,咱们将运用一个十分小的经典数据集 mnist 来进行测试。假如您想在不改动任何其他内容的状况下给模型一个略微困难一点的应战,请运用 torchvision.dataset
,FashionMNIST 应作为替代品。
dataset=torchvision.datasets.MNIST(root="mnist/",train=True,download=True,transform=torchvision.transforms.ToTensor())
train_dataloader=DataLoader(dataset,batch_size=8,shuffle=True)
x,y=next(iter(train_dataloader))
print('Inputshape:',x.shape)
print('Labels:',y)
plt.imshow(torchvision.utils.make_grid(x)[0],cmap='Greys');
该数据会集的每张图都是一个数字的 28×28 像素的灰度图,像素值的规模是从 0 到 1。
损坏进程
假设你没有读过任何分散模型的论文,但你知道这个进程会添加噪声。你会怎样做?
咱们或许想要一个简略的办法来操控损坏的程度。那么,假如咱们要引入一个参数来操控输入的“噪声量”,那么咱们会这么做:
noise = torch.rand_like(x)
noisy_x = (1-amount)*x + amount*noise
假如 amount = 0,则返回输入而不做任何更改。假如 amount = 1,咱们将得到一个纯粹的噪声。经过这种办法将输入与噪声混合,咱们将输出坚持在相同的规模(0 to 1)。
咱们能够很容易地完成这一点(可是要注意 tensor 的 shape,以防被广播 (broadcasting) 机制不正确的影响到):
defcorrupt(x,amount):
"""Corrupttheinput`x`bymixingitwithnoiseaccordingto`amount`"""
noise=torch.rand_like(x)
amount=amount.view(-1,1,1,1)#Sortshapesobroadcastingworks
returnx*(1-amount)+noise*amount
让咱们来可视化一下输出的成果,以了解是否符合咱们的预期:
#Plottingtheinputdata
fig,axs=plt.subplots(2,1,figsize=(12,5))
axs[0].set_title('Inputdata')
axs[0].imshow(torchvision.utils.make_grid(x)[0],cmap='Greys')
#Addingnoise
amount=torch.linspace(0,1,x.shape[0])#Lefttoright->morecorruption
noised_x=corrupt(x,amount)
#Plottinfthenoisedversion
axs[1].set_title('Corrupteddata(--amountincreases-->)')
axs[1].imshow(torchvision.utils.make_grid(noised_x)[0],cmap='Greys');
当噪声量接近 1 时,咱们的数据开端看起来像纯随机噪声。但关于大多数的噪声状况下,您仍是能够很好地识别出数字。你认为这是最佳的吗?
模型
咱们想要一个模型,它能够接纳 28px 的噪声图画,并输出相同形状的猜测。一个比较流行的挑选是一个叫做 UNet 的架构。开端被发明用于医学图画中的分割任务,UNet 由一个“紧缩途径”和一个“扩展途径”组成。“紧缩途径”会使经过该途径的数据被紧缩,而经过“扩展途径”会将数据扩展回原始维度(类似于自动编码器)。模型中的残差衔接也答应信息和梯度在不同层级之间流动。
一些 UNet 的规划在每个阶段都有杂乱的 blocks,但关于这个玩具 demo,咱们只会构建一个最简略的示例,它接纳一个单通道图画,并经过下行途径上的三个卷积层(图和代码中的 down_layers)和上行途径上的 3 个卷积层,在下行和上行层之间具有残差衔接。咱们将运用 max pooling 进行下采样和 nn.Upsample
用于上采样。某些比较杂乱的 UNets 的规划会运用带有可学习参数的上采样和下采样 layer。下面的结构图大致展现了每个 layer 的输出通道数:
代码完成如下:
classBasicUNet(nn.Module):
"""AminimalUNetimplementation."""
def__init__(self,in_channels=1,out_channels=1):
super().__init__()
self.down_layers=torch.nn.ModuleList([
nn.Conv2d(in_channels,32,kernel_size=5,padding=2),
nn.Conv2d(32,64,kernel_size=5,padding=2),
nn.Conv2d(64,64,kernel_size=5,padding=2),
])
self.up_layers=torch.nn.ModuleList([
nn.Conv2d(64,64,kernel_size=5,padding=2),
nn.Conv2d(64,32,kernel_size=5,padding=2),
nn.Conv2d(32,out_channels,kernel_size=5,padding=2),
])
self.act=nn.SiLU()#Theactivationfunction
self.downscale=nn.MaxPool2d(2)
self.upscale=nn.Upsample(scale_factor=2)
defforward(self,x):
h=[]
fori,linenumerate(self.down_layers):
x=self.act(l(x))#Throughthelayerntheactivationfunction
ifi<2:#Forallbutthethird(final)downlayer:
h.append(x)#Storingoutputforskipconnection
x=self.downscale(x)#Downscalereadyforthenextlayer
fori,linenumerate(self.up_layers):
ifi>0:#Forallexceptthefirstuplayer
x=self.upscale(x)#Upscale
x+=h.pop()#Fetchingstoredoutput(skipconnection)
x=self.act(l(x))#Throughthelayerntheactivationfunction
returnx
咱们能够验证输出 shape 是否如咱们期望的那样与输入相同:
net=BasicUNet()
x=torch.rand(8,1,28,28)
net(x).shape
torch.Size([8, 1, 28, 28])
该网络有 30 多万个参数:
sum([p.numel()forpinnet.parameters()])
309057
您能够测验更改每个 layer 中的通道数或测验不同的结构规划。
练习模型
那么,模型到底应该做什么呢?同样,对这个问题有各种不同的观点,但关于这个演示,让咱们挑选一个简略的框架:给定一个损坏的输入 noisy_x
,模型应该输出它对本来 x
的最佳猜想。咱们将经过均方误差将猜测与真实值进行比较。
咱们现在能够测验练习网络了。
- 获取一批数据
- 添加随机噪声
- 将数据输入模型
- 将模型猜测与洁净图画进行比较,以核算 loss
- 更新模型的参数
你能够自在进行修改来测验取得更好的成果!
#Dataloader(youcanmesswithbatchsize)
batch_size=128
train_dataloader=DataLoader(dataset,batch_size=batch_size,shuffle=True)
#Howmanyrunsthroughthedatashouldwedo?
n_epochs=3
#Createthenetwork
net=BasicUNet()
net.to(device)
#Ourlossfinction
loss_fn=nn.MSELoss()
#Theoptimizer
opt=torch.optim.Adam(net.parameters(),lr=1e-3)
#Keepingarecordofthelossesforlaterviewing
losses=[]
#Thetrainingloop
forepochinrange(n_epochs):
forx,yintrain_dataloader:
#Getsomedataandpreparethecorruptedversion
x=x.to(device)#DataontheGPU
noise_amount=torch.rand(x.shape[0]).to(device)#Pickrandomnoiseamounts
noisy_x=corrupt(x,noise_amount)#Createournoisyx
#Getthemodelprediction
pred=net(noisy_x)
#Calculatetheloss
loss=loss_fn(pred,x)#Howcloseistheoutputtothetrue'clean'x?
#Backpropandupdatetheparams:
opt.zero_grad()
loss.backward()
opt.step()
#Storethelossforlater
losses.append(loss.item())
#Printourtheaverageofthelossvaluesforthisepoch:
avg_loss=sum(losses[-len(train_dataloader):])/len(train_dataloader)
print(f'Finishedepoch{epoch}.Averagelossforthisepoch:{avg_loss:05f}')
#Viewthelosscurve
plt.plot(losses)
plt.ylim(0,0.1);
Finished epoch 0. Average loss for this epoch: 0.026736
Finished epoch 1. Average loss for this epoch: 0.020692
Finished epoch 2. Average loss for this epoch: 0.018887
咱们能够测验经过抓取一批数据,以不同的数量损坏数据,然后喂进模型取得猜测来观察成果:
#@markdownVisualizingmodelpredictionsonnoisyinputs:
#Fetchsomedata
x,y=next(iter(train_dataloader))
x=x[:8]#Onlyusingthefirst8foreasyplotting
#Corruptwitharangeofamounts
amount=torch.linspace(0,1,x.shape[0])#Lefttoright->morecorruption
noised_x=corrupt(x,amount)
#Getthemodelpredictions
withtorch.no_grad():
preds=net(noised_x.to(device)).detach().cpu()
#Plot
fig,axs=plt.subplots(3,1,figsize=(12,7))
axs[0].set_title('Inputdata')
axs[0].imshow(torchvision.utils.make_grid(x)[0].clip(0,1),cmap='Greys')
axs[1].set_title('Corrupteddata')
axs[1].imshow(torchvision.utils.make_grid(noised_x)[0].clip(0,1),cmap='Greys')
axs[2].set_title('NetworkPredictions')
axs[2].imshow(torchvision.utils.make_grid(preds)[0].clip(0,1),cmap='Greys');
你能够看到,关于较低的噪声水平数量,猜测的成果适当不错!可是,当噪声水平十分高时,模型能够取得的信息就开端逐步削减。而当咱们到达 amount = 1 时,模型会输出一个模糊的猜测,该猜测会很接近数据集的平均值。模型经过这样的办法来猜想原始输入。
取样(采样)
假如咱们在高噪声水平下的猜测不是很好,咱们怎样才干生成图画呢?
假如咱们从彻底随机的噪声开端,检查一下模型猜测的成果,然后只朝着猜测方向移动一小部分,比如说 20%。现在咱们有一个噪声许多的图画,其间或许隐藏了一些关于输入数据的结构的提示,咱们能够将其输入到模型中以取得新的猜测。期望这个新的猜测比榜首个略微好一点(因为咱们这一次的输入略微削减了一点噪声),所以咱们能够用这个新的更好的猜测再往前迈出一小步。
假如一切顺利的话,以上进程重复几回今后咱们就会得到一个新的图画!以下图例是迭代了五次今后的成果,左边是每个阶段的模型输入的可视化,右侧则是猜测的去噪图画。请注意,即使模型在第 1 步就猜测了去噪图画,咱们也只是将输入向去噪图画变换了一小部分。重复几回今后,图画的结构开端逐步呈现并得到改善 , 直到取得咱们的最终成果停止。
#@markdownSamplingstrategy:Breaktheprocessinto5stepsandmove1/5'thofthewaythereeachtime:
n_steps=5
x=torch.rand(8,1,28,28).to(device)#Startfromrandom
step_history=[x.detach().cpu()]
pred_output_history=[]
foriinrange(n_steps):
withtorch.no_grad():#Noneedtotrackgradientsduringinference
pred=net(x)#Predictthedenoisedx0
pred_output_history.append(pred.detach().cpu())#Storemodeloutputforplotting
mix_factor=1/(n_steps-i)#Howmuchwemovetowardstheprediction
x=x*(1-mix_factor)+pred*mix_factor#Movepartofthewaythere
step_history.append(x.detach().cpu())#Storestepforplotting
fig,axs=plt.subplots(n_steps,2,figsize=(9,4),sharex=True)
axs[0,0].set_title('x(modelinput)')
axs[0,1].set_title('modelprediction')
foriinrange(n_steps):
axs[i,0].imshow(torchvision.utils.make_grid(step_history[i])[0].clip(0,1),cmap='Greys')
axs[i,1].imshow(torchvision.utils.make_grid(pred_output_history[i])[0].clip(0,1),cmap='Greys')
咱们能够将流程分成更多进程,并期望经过这种办法取得更好的图画:
#@markdownShowingmoreresults,using40samplingsteps
n_steps=40
x=torch.rand(64,1,28,28).to(device)
foriinrange(n_steps):
noise_amount=torch.ones((x.shape[0],)).to(device)*(1-(i/n_steps))#Startinghighgoinglow
withtorch.no_grad():
pred=net(x)
mix_factor=1/(n_steps-i)
x=x*(1-mix_factor)+pred*mix_factor
fig,ax=plt.subplots(1,1,figsize=(12,12))
ax.imshow(torchvision.utils.make_grid(x.detach().cpu(),nrow=8)[0].clip(0,1),cmap='Greys')
<matplotlib.image.AxesImage at 0x7f27567d8210>
成果并不是十分好,可是现已呈现了一些能够被认出来的数字!您能够测验练习更长时刻(例如,10 或 20 个 epoch),并调整模型装备、学习率、优化器等。此外,假如您想测验略微困难一点的数据集,您能够测验一下 fashionMNIST,只需要一行代码的替换就能够了。
与 DDPM 做比较
在本节中,咱们将看看咱们的“玩具”完成与其他笔记本中运用的根据 DDPM 论文的办法有何不同: 分散器简介 Notebook。
咱们将会看到的
- 模型的体现受限于随迭代周期 (timesteps) 改动的操控条件,在前向传导中时刻步 (t) 是作为一个参数被传入的
- 有许多不同的取样策略可挑选,或许会比咱们上面所运用的最简略的版别更好
- diffusers
UNet2DModel
比咱们的 BasicUNet 更先进 - 损坏进程的处理办法不同
- 练习方针不同,包含猜测噪声而不是去噪图画
- 该模型经过调理 timestep 来调理噪声水平 , 其间 t 作为一个附加参数传入前向进程中。
- 有许多不同的采样策略可供挑选,它们应该比咱们上面简略的版别更有效。
自 DDPM 论文发表以来,现已有人提出了许多改善主张,但这个比如关于不同的可用规划决策具有指导意义。读完这篇文章后,你或许会想要深化了解这篇论文《Elucidating the Design Space of Diffusion-Based Generative Models》,它对所有这些组件进行了详细的讨论,并就怎样取得最佳功能提出了新的主张。
假如你觉得这些内容对你来说过分艰深了,请不要忧虑!你能够随意跳过本笔记本的其余部分或将其保存以备不时之需。
UNet
diffusers 中的 UNet2DModel 模型比上述根本 UNet 模型有许多改善:
- GroupNorm 层对每个 blocks 的输入进行了组标准化(group normalization)
- Dropout 层能使练习更滑润
- 每个块有多个 resnet 层(假如 layers_per_block 未设置为 1)
- 注意机制(一般仅用于输入分辨率较低的 blocks)
- timestep 的调理。
- 具有可学习参数的下采样和上采样块
让咱们来创建并细心研讨一下 UNet2DModel:
model=UNet2DModel(
sample_size=28,#thetargetimageresolution
in_channels=1,#thenumberofinputchannels,3forRGBimages
out_channels=1,#thenumberofoutputchannels
layers_per_block=2,#howmanyResNetlayerstouseperUNetblock
block_out_channels=(32,64,64),#Roughlymatchingourbasicunetexample
down_block_types=(
"DownBlock2D",#aregularResNetdownsamplingblock
"AttnDownBlock2D",#aResNetdownsamplingblockw/spatialself-attention
"AttnDownBlock2D",
),
up_block_types=(
"AttnUpBlock2D",
"AttnUpBlock2D",#aResNetupsamplingblockwithspatialself-attention
"UpBlock2D",#aregularResNetupsamplingblock
),
)
print(model)
UNet2DModel(
(conv_in): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_proj): Timesteps()
(time_embedding): TimestepEmbedding(
(linear_1): Linear(in_features=32, out_features=128, bias=True)
(act): SiLU()
(linear_2): Linear(in_features=128, out_features=128, bias=True)
)
(down_blocks): ModuleList(
(0): DownBlock2D(
(resnets): ModuleList(
(0): ResnetBlock2D(
(norm1): GroupNorm(32, 32, eps=1e-05, affine=True)
(conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=32, bias=True)
(norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
)
(1): ResnetBlock2D(
(norm1): GroupNorm(32, 32, eps=1e-05, affine=True)
(conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=32, bias=True)
(norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
)
)
(downsamplers): ModuleList(
(0): Downsample2D(
(conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
)
)
)
(1): AttnDownBlock2D(
(attentions): ModuleList(
(0): AttentionBlock(
(group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
(query): Linear(in_features=64, out_features=64, bias=True)
(key): Linear(in_features=64, out_features=64, bias=True)
(value): Linear(in_features=64, out_features=64, bias=True)
(proj_attn): Linear(in_features=64, out_features=64, bias=True)
)
(1): AttentionBlock(
(group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
(query): Linear(in_features=64, out_features=64, bias=True)
(key): Linear(in_features=64, out_features=64, bias=True)
(value): Linear(in_features=64, out_features=64, bias=True)
(proj_attn): Linear(in_features=64, out_features=64, bias=True)
)
)
(resnets): ModuleList(
(0): ResnetBlock2D(
(norm1): GroupNorm(32, 32, eps=1e-05, affine=True)
(conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
)
(1): ResnetBlock2D(
(norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
)
)
(downsamplers): ModuleList(
(0): Downsample2D(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
)
)
)
(2): AttnDownBlock2D(
(attentions): ModuleList(
(0): AttentionBlock(
(group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
(query): Linear(in_features=64, out_features=64, bias=True)
(key): Linear(in_features=64, out_features=64, bias=True)
(value): Linear(in_features=64, out_features=64, bias=True)
(proj_attn): Linear(in_features=64, out_features=64, bias=True)
)
(1): AttentionBlock(
(group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
(query): Linear(in_features=64, out_features=64, bias=True)
(key): Linear(in_features=64, out_features=64, bias=True)
(value): Linear(in_features=64, out_features=64, bias=True)
(proj_attn): Linear(in_features=64, out_features=64, bias=True)
)
)
(resnets): ModuleList(
(0): ResnetBlock2D(
(norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
)
(1): ResnetBlock2D(
(norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
)
)
)
)
(up_blocks): ModuleList(
(0): AttnUpBlock2D(
(attentions): ModuleList(
(0): AttentionBlock(
(group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
(query): Linear(in_features=64, out_features=64, bias=True)
(key): Linear(in_features=64, out_features=64, bias=True)
(value): Linear(in_features=64, out_features=64, bias=True)
(proj_attn): Linear(in_features=64, out_features=64, bias=True)
)
(1): AttentionBlock(
(group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
(query): Linear(in_features=64, out_features=64, bias=True)
(key): Linear(in_features=64, out_features=64, bias=True)
(value): Linear(in_features=64, out_features=64, bias=True)
(proj_attn): Linear(in_features=64, out_features=64, bias=True)
)
(2): AttentionBlock(
(group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
(query): Linear(in_features=64, out_features=64, bias=True)
(key): Linear(in_features=64, out_features=64, bias=True)
(value): Linear(in_features=64, out_features=64, bias=True)
(proj_attn): Linear(in_features=64, out_features=64, bias=True)
)
)
(resnets): ModuleList(
(0): ResnetBlock2D(
(norm1): GroupNorm(32, 128, eps=1e-05, affine=True)
(conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
)
(1): ResnetBlock2D(
(norm1): GroupNorm(32, 128, eps=1e-05, affine=True)
(conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
)
(2): ResnetBlock2D(
(norm1): GroupNorm(32, 128, eps=1e-05, affine=True)
(conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
)
)
(upsamplers): ModuleList(
(0): Upsample2D(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
)
)
(1): AttnUpBlock2D(
(attentions): ModuleList(
(0): AttentionBlock(
(group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
(query): Linear(in_features=64, out_features=64, bias=True)
(key): Linear(in_features=64, out_features=64, bias=True)
(value): Linear(in_features=64, out_features=64, bias=True)
(proj_attn): Linear(in_features=64, out_features=64, bias=True)
)
(1): AttentionBlock(
(group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
(query): Linear(in_features=64, out_features=64, bias=True)
(key): Linear(in_features=64, out_features=64, bias=True)
(value): Linear(in_features=64, out_features=64, bias=True)
(proj_attn): Linear(in_features=64, out_features=64, bias=True)
)
(2): AttentionBlock(
(group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
(query): Linear(in_features=64, out_features=64, bias=True)
(key): Linear(in_features=64, out_features=64, bias=True)
(value): Linear(in_features=64, out_features=64, bias=True)
(proj_attn): Linear(in_features=64, out_features=64, bias=True)
)
)
(resnets): ModuleList(
(0): ResnetBlock2D(
(norm1): GroupNorm(32, 128, eps=1e-05, affine=True)
(conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
)
(1): ResnetBlock2D(
(norm1): GroupNorm(32, 128, eps=1e-05, affine=True)
(conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
)
(2): ResnetBlock2D(
(norm1): GroupNorm(32, 96, eps=1e-05, affine=True)
(conv1): Conv2d(96, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): Conv2d(96, 64, kernel_size=(1, 1), stride=(1, 1))
)
)
(upsamplers): ModuleList(
(0): Upsample2D(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
)
)
(2): UpBlock2D(
(resnets): ModuleList(
(0): ResnetBlock2D(
(norm1): GroupNorm(32, 96, eps=1e-05, affine=True)
(conv1): Conv2d(96, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=32, bias=True)
(norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1))
)
(1): ResnetBlock2D(
(norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
(conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=32, bias=True)
(norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
)
(2): ResnetBlock2D(
(norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
(conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=32, bias=True)
(norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
)
)
)
)
(mid_block): UNetMidBlock2D(
(attentions): ModuleList(
(0): AttentionBlock(
(group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
(query): Linear(in_features=64, out_features=64, bias=True)
(key): Linear(in_features=64, out_features=64, bias=True)
(value): Linear(in_features=64, out_features=64, bias=True)
(proj_attn): Linear(in_features=64, out_features=64, bias=True)
)
)
(resnets): ModuleList(
(0): ResnetBlock2D(
(norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
)
(1): ResnetBlock2D(
(norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
)
)
)
(conv_norm_out): GroupNorm(32, 32, eps=1e-05, affine=True)
(conv_act): SiLU()
(conv_out): Conv2d(32, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
正如你所看到的,还有更多!它比咱们的 BasicUNet 有多得多的参数量:
sum([p.numel()forpinmodel.parameters()])#1.7Mvsthe~309kparametersoftheBasicUNet
1707009
咱们能够用这个模型替代本来的模型来重复一遍上面展现的练习进程。咱们需要将 x 和 timestep 传递给模型(这儿我会传递 t = 0,以表明它在没有 timestep 条件的状况下作业,并坚持采样代码简略,但您也能够测验输入 (amount*1000)
,使 timestep 与噪声水平适当)。假如要检查代码,更改的行将显现为“#<<<
。
#@markdownTryingUNet2DModelinsteadofBasicUNet:
#Dataloader(youcanmesswithbatchsize)
batch_size=128
train_dataloader=DataLoader(dataset,batch_size=batch_size,shuffle=True)
#Howmanyrunsthroughthedatashouldwedo?
n_epochs=3
#Createthenetwork
net=UNet2DModel(
sample_size=28,#thetargetimageresolution
in_channels=1,#thenumberofinputchannels,3forRGBimages
out_channels=1,#thenumberofoutputchannels
layers_per_block=2,#howmanyResNetlayerstouseperUNetblock
block_out_channels=(32,64,64),#Roughlymatchingourbasicunetexample
down_block_types=(
"DownBlock2D",#aregularResNetdownsamplingblock
"AttnDownBlock2D",#aResNetdownsamplingblockwithspatialself-attention
"AttnDownBlock2D",
),
up_block_types=(
"AttnUpBlock2D",
"AttnUpBlock2D",#aResNetupsamplingblockwithspatialself-attention
"UpBlock2D",#aregularResNetupsamplingblock
),
)#<<<
net.to(device)
#Ourlossfinction
loss_fn=nn.MSELoss()
#Theoptimizer
opt=torch.optim.Adam(net.parameters(),lr=1e-3)
#Keepingarecordofthelossesforlaterviewing
losses=[]
#Thetrainingloop
forepochinrange(n_epochs):
forx,yintrain_dataloader:
#Getsomedataandpreparethecorruptedversion
x=x.to(device)#DataontheGPU
noise_amount=torch.rand(x.shape[0]).to(device)#Pickrandomnoiseamounts
noisy_x=corrupt(x,noise_amount)#Createournoisyx
#Getthemodelprediction
pred=net(noisy_x,0).sample#<<<Usingtimestep0always,adding.sample
#Calculatetheloss
loss=loss_fn(pred,x)#Howcloseistheoutputtothetrue'clean'x?
#Backpropandupdatetheparams:
opt.zero_grad()
loss.backward()
opt.step()
#Storethelossforlater
losses.append(loss.item())
#Printourtheaverageofthelossvaluesforthisepoch:
avg_loss=sum(losses[-len(train_dataloader):])/len(train_dataloader)
print(f'Finishedepoch{epoch}.Averagelossforthisepoch:{avg_loss:05f}')
#Plotlossesandsomesamples
fig,axs=plt.subplots(1,2,figsize=(12,5))
#Losses
axs[0].plot(losses)
axs[0].set_ylim(0,0.1)
axs[0].set_title('Lossovertime')
#Samples
n_steps=40
x=torch.rand(64,1,28,28).to(device)
foriinrange(n_steps):
noise_amount=torch.ones((x.shape[0],)).to(device)*(1-(i/n_steps))#Startinghighgoinglow
withtorch.no_grad():
pred=net(x,0).sample
mix_factor=1/(n_steps-i)
x=x*(1-mix_factor)+pred*mix_factor
axs[1].imshow(torchvision.utils.make_grid(x.detach().cpu(),nrow=8)[0].clip(0,1),cmap='Greys')
axs[1].set_title('GeneratedSamples');
Finished epoch 0. Average loss for this epoch: 0.018925
Finished epoch 1. Average loss for this epoch: 0.012785
Finished epoch 2. Average loss for this epoch: 0.011694
这看起来比咱们的榜首组成果好多了!您能够测验调整 UNet 装备或更长时刻的练习,以取得更好的功能。
损坏进程
DDPM 论文描绘了一个为每个“timestep”添加少量噪声的损坏进程。为某些 timestep 给定 , 咱们能够得到一个噪声稍稍添加的 :
这就是说,咱们取 , 给他一个 的系数,然后加上带有 系数的噪声。这儿 是依据一些管理器来为每一个 t 设定的,来决定每一个迭代周期中添加多少噪声。现在,咱们不想把这个推演进行 500 次来得到 ,所以咱们用另一个公式来依据给出的 核算得到任意 t 时刻的 :
数学符号看起来总是很吓人!幸运的是,调度器为咱们处理了所有这些(撤销下一个单元格的注释以检查代码)。咱们能够画出 (标记为 sqrt_alpha_prod
) 和 (标记为 sqrt_one_minus_alpha_prod
) 来看一下输入 (x) 与噪声是怎样在不同迭代周期中量化和叠加的 :
#??noise_scheduler.add_noise
noise_scheduler=DDPMScheduler(num_train_timesteps=1000)
plt.plot(noise_scheduler.alphas_cumprod.cpu()**0.5,label=r"${\sqrt{\bar{\alpha}_t}}$")
plt.plot((1-noise_scheduler.alphas_cumprod.cpu())**0.5,label=r"$\sqrt{(1-\bar{\alpha}_t)}$")
plt.legend(fontsize="x-large");
一开端 , 噪声 x 里绝大部分都是 x 自身的值 (sqrt_alpha_prod ~= 1),可是随着时刻的推移,x 的成分逐步下降而噪声的成分逐步添加。与咱们依据 amount
对 x 和噪声进行线性混合不同,这个噪声的添加相对较快。咱们能够在一些数据上看到这一点:
#@markdownvisualizetheDDPMnoisingprocessfordifferenttimesteps:
#Noiseabatchofimagestoviewtheeffect
fig,axs=plt.subplots(3,1,figsize=(16,10))
xb,yb=next(iter(train_dataloader))
xb=xb.to(device)[:8]
xb=xb*2.-1.#Mapto(-1,1)
print('Xshape',xb.shape)
#Showcleaninputs
axs[0].imshow(torchvision.utils.make_grid(xb[:8])[0].detach().cpu(),cmap='Greys')
axs[0].set_title('CleanX')
#Addnoisewithscheduler
timesteps=torch.linspace(0,999,8).long().to(device)
noise=torch.randn_like(xb)#<<NB:randnnotrand
noisy_xb=noise_scheduler.add_noise(xb,noise,timesteps)
print('NoisyXshape',noisy_xb.shape)
#Shownoisyversion(withandwithoutclipping)
axs[1].imshow(torchvision.utils.make_grid(noisy_xb[:8])[0].detach().cpu().clip(-1,1),cmap='Greys')
axs[1].set_title('NoisyX(clippedto(-1,1)')
axs[2].imshow(torchvision.utils.make_grid(noisy_xb[:8])[0].detach().cpu(),cmap='Greys')
axs[2].set_title('NoisyX');
X shape torch.Size([8, 1, 28, 28])
Noisy X shape torch.Size([8, 1, 28, 28])
在运转中的另一个改动:在 DDPM 版别中,参加的噪声是取自一个高斯分布(来自均值 0 方差 1 的 torch.randn),而不是在咱们原始 corrupt
函数中运用的 0-1 之间的均匀分布(torch.rand),当然对练习数据做正则化也能够理解。在另一篇笔记中,你会看到 Normalize(0.5, 0.5)
函数在改动列表中,它把图片数据从 (0, 1) 区间映射到 (-1, 1),对咱们的方针来说也‘满足用了’。咱们在此篇笔记中没运用这个办法,但在上面的可视化中为了更好的展现添加了这种做法。
练习方针
在咱们的玩具示例中,咱们让模型测验猜测去噪图画。在 DDPM 和许多其他分散模型完成中,模型则会猜测损坏进程中运用的噪声(在缩放之前,因此是单位方差噪声)。在代码中,它看起来像是这样:
noise=torch.randn_like(xb)#<<NB:randnnotrand
noisy_x=noise_scheduler.add_noise(x,noise,timesteps)
model_prediction=model(noisy_x,timesteps).sample
loss=mse_loss(model_prediction,noise)#noiseasthetarget
你或许认为猜测噪声(咱们能够从中得出去噪图画的样子)等同于直接猜测去噪图画。那么,为什么要这么做呢?这只是是为了数学上的方便吗?
这儿其实还有另一些精妙之处。咱们在练习进程中,会核算不同(随机挑选)timestep 的 loss。这些不同的方针将导致这些 loss 的不同的“隐含权重”,其间猜测噪声会将更多的权重放在较低的噪声水平上。你能够挑选更杂乱的方针来改动这种“隐性损失权重”。或许,您挑选的噪声管理器将在较高的噪声水平下产生更多的示例。或许你让模型规划成猜测 “velocity” v,咱们将其界说为由噪声水平影响的图画和噪声组合(请参阅“分散模型快速采样的渐进蒸馏”- ‘PROGRESSIVE DISTILLATION FOR FAST SAMPLING OF DIFFUSION MODELS’)。或许你将模型规划成猜测噪声,然后根据某些因子来对 loss 进行缩放:比如有些理论指出能够参阅噪声水平(参见“分散模型的感知优先练习”-‘Perception Prioritized Training of Diffusion Models’),或许根据一些探究模型最佳噪声水平的实验(参见“根据分散的生成模型的规划空间说明”-‘Elucidating the Design Space of Diffusion-Based Generative Models’)。
一句话解说:挑选方针对模型功能有影响,现在有许多研讨者正在探究“最佳”选项是什么。目前,猜测噪声(epsilon 或 eps)是最流行的办法,但随着时刻的推移,咱们很或许会看到库中支撑的其他方针,并在不同的状况下运用。
迭代周期(Timestep)调理
UNet2DModel 以 x 和 timestep 为输入。后者被转化为一个嵌入(embedding),并在多个当地被输入到模型中。
这背面的理论支撑是这样的:经过向模型供给有关噪声水平的信息,它能够更好地执行任务。虽然在没有这种 timestep 条件的状况下也能够练习模型,但在某些状况下,它好像的确有助于功能,目前来说绝大多数的模型完成都包含了这一输入。
取样(采样)
有一个模型能够用来猜测在带噪样本中的噪声(或许说能猜测其去噪版别),咱们怎样用它来生成图画呢?
咱们能够给入纯噪声,然后就期望模型能一步就输出一个不带噪声的好图画。可是,就咱们上面所见到的来看,这一般行不通。所以,咱们在模型猜测的基础上运用满足多的小步,迭代着来每次去除一点点噪声。
具体咱们怎样走这些小步,取决于运用上面取样办法。咱们不会去深化讨论太多的理论细节,可是一些顶层想法是这样:
- 每一步你想走多大?也就是说,你遵从什么样的“噪声方案(噪声管理)”?
- 你只运用模型当前步的猜测成果来指导下一步的更新方向吗(像 DDPM,DDIM 或是其他的什么那样)?你是否要运用模型来多猜测几回来估计一个更高阶的梯度来更新一步更大更准确的成果(更高阶的办法和一些离散 ODE 处理器)?或许保存前史猜测值来测验更好的指导当前步的更新(线性多步或遗传取样器)?
- 你是否会在取样进程中额外再加一些随机噪声,或你彻底已知的(deterministic)来添加噪声?许多取样器经过参数(如 DDIM 中的 ‘eta’)来供用户挑选。
关于分散模型取样器的研讨演进的很快,随之开发出了越来越多能够运用更少步就找到好成果的办法。英勇和有好奇心的人或许会在浏览 diffusers library 中不同部署办法时感到十分有意思,能够检查 Schedulers 代码 或看看 Schedulers 文档,这儿常常有一些相关的论文。
结语
期望这能够从一些不同的角度来审视分散模型供给一些协助。这篇笔记是 Jonathan Whitaker 为 Hugging Face 课程所写的,假如你对从噪声和束缚分类来生成样本的比如感爱好。问题与 bug 能够经过 GitHub issues 或 Discord 来沟通。
称谢榜首单元第二部分社区奉献者
感谢社区成员们对本课程的奉献:
@darcula1993、@XhrLeokk:魔都强人工智能孵化者,二里街调参记载坚持人,一切爱好使然的 AIGC 色图创作家的保护者,图灵神在五角场的仅有指定路上行走。
感谢茶叶蛋蛋对本文奉献规划素材!
欢迎经过链接参加咱们的本地化小组与大家共同沟通:
bit.ly/3G40j6U