本文为稀土技术社区首发签约文章,14天内制止转载,14天后未获授权制止转载,侵权必究!

作者简介:秃头小苏,致力于用最浅显的语言描述问题

往期回顾:对立生成网络GAN系列——GAN原理及手写数字生成小事例   对立生成网络GAN系列——DCGAN简介及人脸图画生成事例   对立生成网络GAN系列——AnoGAN原理及缺点检测实战   对立生成网络GAN系列——EGBAD原理及缺点检测实战   对立生成网络GAN系列——WGAN原理及实战演练 对立生成网络GAN系列——GANomaly原理及源码解析

近期目标:写好专栏的每一篇文章

支撑小苏:点赞、收藏⭐、留言

对立生成网络GAN系列——f-AnoGAN原理及缺点检测实战

写在前面

​  在前面我已经介绍了好几种用于缺点检测的GAN网络了,感兴趣的能够重视一下我的专栏:深度学习网络原理与实战 。现在专栏主要更新了GAN系列文章和Transformer系列文章,都有理论详解和代码实战,文中的解说都比较浅显易懂,假如你希望丰厚这方面的常识,主张你阅览试试,相信你会有蛮不错的收成。

​  在阅览本篇教程之前,我觉得你有必要读读下面三篇文章:

  • [1]对立生成网络GAN系列——AnoGAN原理及缺点检测实战
  • [2]对立生成网络GAN系列——GANomaly原理及源码解析
  • [3]对立生成网络GAN系列——WGAN原理及实战演练

​  [1]是运用GAN网络完成缺点检测的开山之作,也算是这篇文章的基础,所以这是你必须要读且要了解透彻的。[2]算是[1]比较经典的改善,和本篇文章也有必定的相似之处,了解它会对你看透此篇文章有很大帮助。[3]提出了一种使原始GAN练习更加稳定的方法,本篇文章在练习GA时运用了,主张你也要有所了解。

​  假如你预备好了的话,就让咱们一起来看看f-AnoGAN吧!!!

f-AnoGAN原理详解✨✨✨

​  咱们先来看看f-AnoGAN的全称吧——f-AnoGAN: Fast unsupervised anomaly detection with generative adversarial networks。点击☞☞☞下载论文了解概况。

​  假如你对我上文提到的三篇文章都有所了解的话,再来看这篇文章,你会发现它是真滴简略。这就带咱们一起来看看f-AnoGAN的网路架构。首要,咱们先来看看f-AnoGAN的练习进程,练习主要分两步进行,第一步是练习一个生成对立网络,第二步运用第一步生成对立网络的权重,练习一个encoder编码器。咱们直接来看下图好了:

对抗生成网络GAN系列——f-AnoGAN原理及缺陷检测实战

​  在进程①中,咱们练习的是一个WGAN,关于WGAN的细节能够看[3]这篇博客,在后文的代码实战中我也会谈谈这部分的内容。假如你对WGAN不了解的话,也不必太担心,这里你完全能够练习一个原始GAN,仅仅效果或许没有WGAN好罢了,但是对于了解f-AnoGAN的进程是完全没用影响的。当WGAN练习完毕后,生成器G和判别器D的权重就会冻住,进程②的G和D的权重不会发生变化。在进程②中,咱们意图是练习一个编码器E。论文中给出了三种练习E的结构,分别为ziz结构,izi结构和izif结构,咱们一个个来看一下:

  • ziz结构

    ​ 咱们直接来看下图吧:

    对抗生成网络GAN系列——f-AnoGAN原理及缺陷检测实战

    ​  z表明的是潜在变量,i表明的是图片(image),ziz结构即表明潜在变量z经过固定的生成器生成image,然后再经过编码器编码成潜在变量z。此部分的丢失函数用LzizL_{ziz}表明,其表达式为:

    ​          Lziz(z)=1d∣∣z−E(G(z))∣∣2L_{ziz}(z)=\frac{1}{d}||z-E(G(z))||^2

    ​ 其间d表明z的维度,其实上述公式便是核算z和E(G(z))E(G(z))的MSE丢失啦。

  • izi结构

    ​ 相同的,咱们直接看图:

    对抗生成网络GAN系列——f-AnoGAN原理及缺陷检测实战

    ​  izi结构表明real image先经过编码器E将图片映射到潜在空间,然后再经过生成器G生成图片。此部分的丢失函数用LiziL_{izi}表明,其表达式为:

    ​            Lizi(x)=1n∣∣x−G(E(x))∣∣2L_{izi}(x)=\frac{1}{n}||x-G(E(x))||^2

    ​  其间,n表明输入image的像素点数量,这个公式相同表明x和G(E(x))G(E(x))的MSE丢失。

  • izif结构

    ​   izif结构相比izi结构在后面加了一个判别器D,如下图所示:【论文终究选择了这个结构练习编码器E】

    对抗生成网络GAN系列——f-AnoGAN原理及缺陷检测实战

    ​  不知道咱们发现没有,这个结构和GANomaly是十分像的。izif的丢失函数由两部分构成,一部分为LiziL_{izi},另一部分为LDL_D。izif丢失函数表达式如下:

    ​    Lizif(x)=1n∣∣x−G(E(x))∣∣2+knd⋅∣∣f(x)−f(G(E(x)))∣∣2L_{izif}(x)=\frac{1}{n}||x-G(E(x))||^2 \ + \ \frac{k}{n_d} \cdot ||f(x)-f(G(E(x)))||^2

    ​  其间,k为两个丢失函数的权重参数,代码中k=1。f(*)表明判别器中心层的输出,ndn_d表明判别器中心输出层的维度。



​  f-AnoGAN的练习进程就为咱们介绍到这里了,是不是很简略呢。【假如你觉得有难度的话主张你看看我写在前面中提到的三篇博文,或许结合我下文的代码了解了解】练习完毕后,咱们保存生成器G、判别器D和编码器E的权重,然后将它们用于缺点检测中。缺点检测就更加简略啦,反常得分函数便是咱们上文所说的izif结构的丢失函数,如下图所示:

对抗生成网络GAN系列——f-AnoGAN原理及缺陷检测实战

f-AnoGAN代码实战✨✨✨

代码目录结构剖析

​  这部分我在paperswithcode上看到了一个用pytorch完成的f-AnoGAN的代码:f-AnoGAN源码地址。这个代码的逻辑十分清晰,所以我就以这个代码来为咱们介绍f-AnoGAN的完成了。

​  首要咱们来看一下整个代码的结构,如下图所示:

对抗生成网络GAN系列——f-AnoGAN原理及缺陷检测实战

​ 咱们需求留意一下,mnistmvtec_adyour_own_dataset是针对不同数据集进行试验的。考虑到咱们对mnist数据集相对熟悉,故本文以mnist数据集为例为咱们介绍。【也便是说mvtec_adyour_own_dataset文件夹下的文件都不会运用到,这里咱们留意一下就好】

数据集加载

​  这部分定义在mnist文件夹下的tools.py中,首要咱们获取MNIST数据集,经过torchvision下的datasets包直接下载即可,如下:

train = datasets.MNIST(path, train=True, download=download)
test = datasets.MNIST(path, train=False, download=download)

​  咱们知道,minst数据集train中有60000条数据,test中有10000条数据。这些数据的targets为0-9,首要咱们获取train中targets为0的数据,代码如下:

_x_train = train.data[train.targets == training_label]      #传入的training_label为0

​  经过调试能够发现,_x_train的维度为(5923,28,28),即targets=0的数据一共有5923条。

​  接着咱们将_x_train依照8:2的比列划分为练习集和测验集的一部分,代码如下:

x_train, x_test_normal = _x_train.split((int(len(_x_train) * split_rate)), dim=0)   #传入的split_rate为0.8

​  运行后x_train有4738条数据,x_test_normal有1185条数据。

对抗生成网络GAN系列——f-AnoGAN原理及缺陷检测实战

​  上文说到x_test_normal仅仅测验集的一部分,完整的测验数据集包含x_test_normal、train数据集中除去targets=0以外的其它数据和test中的一切数据,代码如下:

x_test = torch.cat([x_test_normal,
                        train.data[train.targets != training_label],
                        test.data], dim=0)

​  这样终究测验集的数据共有65262条。


​  上文咱们获得了练习集和测验集的数据,咱们还需求获取练习集和测验集的标签,代码如下:

_y_train = train.targets[train.targets == training_label]
y_train, y_test_normal = _y_train.split((int(len(_y_train) * split_rate)), dim=0)
y_test = torch.cat([y_test_normal,
                        train.targets[train.targets != training_label],
                        test.targets], dim=0)                                       

​  相同,练习集的标签y_train有4738个,测验集的标签y_test有65262个。

对抗生成网络GAN系列——f-AnoGAN原理及缺陷检测实战


​  有了数据后,咱们对数据做一些预处理,然后用DataLoader加载数据集,代码如下:

train_mnist = SimpleDataset(x_train, y_train,
                                transform=transforms.Compose(
                                    [transforms.ToPILImage(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.5], [0.5])])
                                )
train_dataloader = DataLoader(train_mnist, batch_size=opt.batch_size,
                                  shuffle=True)

模型建立

class Generator(nn.Module):
    def __init__(self, opt):
        super().__init__()
        self.img_shape = (opt.channels, opt.img_size, opt.img_size)
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(self.img_shape))),
            nn.Tanh()
            )
    def forward(self, z):
        img = self.model(z)
        img = img.view(img.shape[0], *self.img_shape)
        return img
class Discriminator(nn.Module):
    def __init__(self, opt):
        super().__init__()
        img_shape = (opt.channels, opt.img_size, opt.img_size)
        self.features = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True)
            )
        self.last_layer = nn.Sequential(
            nn.Linear(256, 1)
            )
    def forward(self, img):
        features = self.forward_features(img)
        validity = self.last_layer(features)
        return validity
    def forward_features(self, img):
        img_flat = img.view(img.shape[0], -1)
        features = self.features(img_flat)
        return features
class Encoder(nn.Module):
    def __init__(self, opt):
        super().__init__()
        img_shape = (opt.channels, opt.img_size, opt.img_size)
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, opt.latent_dim),
            nn.Tanh()
        )
    def forward(self, img):
        img_flat = img.view(img.shape[0], -1)
        validity = self.model(img_flat)
        return validity

​  由所以教学,所以建立的模型很简略,乃至都没有卷积,都是全连接层,咱们肯定一看就能了解。

练习WGAN

​  阅览这部分之前主张阅览WGAN的相关常识喔,点击☞☞☞了解概况。WGAN属所以开山之作,而f-AnoGAN用的是WGAN-GP,其是一种WGAN的改善。关于WGAN-GP我还没做相关介绍,咱们自行补充常识,引荐链接:WGAN-GP。假如觉得学起来有困难的话欢迎评论区留言讨论,要是人多的话后期或许会出一起WGAN-GP的教程。【其实WGAN-GP和WGAN的思维是一样的,仅仅在于完成lipschitz条件的方法不同】

注:在咱们了解WGAN-GP时或许会遇到直线段的另一种定义方法,是凸优化中的相关内容,不清楚的能够参阅我此前一篇关于凸优化介绍的文章:凸优化理论基础1–仿射集

​  咱们来看看练习WGAN的代码吧:

def train_wgangp(opt, generator, discriminator,
                 dataloader, device, lambda_gp=10):
    generator.to(device)
    discriminator.to(device)
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr, betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr, betas=(opt.b1, opt.b2))
    os.makedirs("results/images", exist_ok=True)
    padding_epoch = len(str(opt.n_epochs))
    padding_i = len(str(len(dataloader)))
    batches_done = 0
    for epoch in range(opt.n_epochs):
        for i, (imgs, _)in enumerate(dataloader):
            # Configure input
            real_imgs = imgs.to(device)
            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizer_D.zero_grad()
            # Sample noise as generator input
            z = torch.randn(imgs.shape[0], opt.latent_dim, device=device)
            # Generate a batch of images
            fake_imgs = generator(z)
            # Real images
            real_validity = discriminator(real_imgs)
            # Fake images
            fake_validity = discriminator(fake_imgs.detach())   #运用.detach()方法能够不更新generator的值
            # Gradient penalty
            gradient_penalty = compute_gradient_penalty(discriminator,
                                                        real_imgs.data,
                                                        fake_imgs.data,
                                                        device)
            # Adversarial loss
            d_loss = (-torch.mean(real_validity) + torch.mean(fake_validity)
                      + lambda_gp * gradient_penalty)
            d_loss.backward()
            optimizer_D.step()
            optimizer_G.zero_grad()
            # Train the generator and output log every n_critic steps
            if i % opt.n_critic == 0:
                # -----------------
                #  Train Generator
                # -----------------
                # Generate a batch of images
                fake_imgs = generator(z)
                # Loss measures generator's ability to fool the discriminator
                # Train on fake images
                fake_validity = discriminator(fake_imgs)
                g_loss = -torch.mean(fake_validity)
                g_loss.backward()
                optimizer_G.step()
                print(f"[Epoch {epoch:{padding_epoch}}/{opt.n_epochs}] "
                      f"[Batch {i:{padding_i}}/{len(dataloader)}] "
                      f"[D loss: {d_loss.item():3f}] "
                      f"[G loss: {g_loss.item():3f}]")
                if batches_done % opt.sample_interval == 0:
                    save_image(fake_imgs.data[:25],
                               f"results/images/{batches_done:06}.png",
                               nrow=5, normalize=True)
                batches_done += opt.n_critic
    torch.save(generator.state_dict(), "results/generator")
    torch.save(discriminator.state_dict(), "results/discriminator")

​  上述代码的中心是compute_gradient_penalty函数,是用来核算梯度赏罚的,这也是WGAN-GP最中心的地方,代码如下:

def compute_gradient_penalty(D, real_samples, fake_samples, device):
    """Calculates the gradient penalty loss for WGAN GP"""
    # Random weight term for interpolation between real and fake samples
    alpha = torch.rand(*real_samples.shape[:2], 1, 1, device=device)
    # Get random interpolation between real and fake samples
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples)
    # 能够直接对变量进行操作,现在pytorch已经舍弃autograd.Variable
    interpolates.requires_grad_(requires_grad=True)
    # interpolates = autograd.Variable(interpolates, requires_grad=True)
    d_interpolates = D(interpolates)
    fake = torch.ones(*d_interpolates.shape, device=device)
    # Get gradient w.r.t. interpolates
    # https://zhuanlan.zhihu.com/p/83172023
    gradients = autograd.grad(outputs=d_interpolates, inputs=interpolates,
                              grad_outputs=fake, create_graph=True,
                              retain_graph=True, only_inputs=True)[0]
    gradients = gradients.view(gradients.shape[0], -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

注:想要了解这部分代码,需求了解pytorch中的auograd包。引荐咱们去阅览此篇博文:Pytorch autograd,backward详解

  练习完毕后,咱们保存了生成器和判别器的权重,一起保存了一些生成图片成果,部分展现如下,是不是效果还不错呢。

对抗生成网络GAN系列——f-AnoGAN原理及缺陷检测实战

【由于咱们练习集图片都是0,一切咱们生成的图片都是0喔!!!】

练习编码器E

​ 话不多说,让咱们直接上代码吧!!!

def train_encoder_izif(opt, generator, discriminator, encoder,
                       dataloader, device, kappa=1.0):
    generator.load_state_dict(torch.load("results/generator"))
    discriminator.load_state_dict(torch.load("results/discriminator"))
    generator.to(device).eval()
    discriminator.to(device).eval()
    encoder.to(device)
    criterion = nn.MSELoss()
    optimizer_E = torch.optim.Adam(encoder.parameters(),
                                   lr=opt.lr, betas=(opt.b1, opt.b2))
    os.makedirs("results/images_e", exist_ok=True)
    padding_epoch = len(str(opt.n_epochs))
    padding_i = len(str(len(dataloader)))
    batches_done = 0
    for epoch in range(opt.n_epochs):
        for i, (imgs, _) in enumerate(dataloader):
            # Configure input
            real_imgs = imgs.to(device)
            # ----------------
            #  Train Encoder
            # ----------------
            optimizer_E.zero_grad()
            # Generate a batch of latent variables
            z = encoder(real_imgs)
            # Generate a batch of images
            fake_imgs = generator(z)
            # Real features
            real_features = discriminator.forward_features(real_imgs)
            # Fake features
            fake_features = discriminator.forward_features(fake_imgs)
            # izif architecture
            loss_imgs = criterion(fake_imgs, real_imgs)
            loss_features = criterion(fake_features, real_features)
            e_loss = loss_imgs + kappa * loss_features
            e_loss.backward()
            optimizer_E.step()
            # Output training log every n_critic steps
            if i % opt.n_critic == 0:
                print(f"[Epoch {epoch:{padding_epoch}}/{opt.n_epochs}] "
                      f"[Batch {i:{padding_i}}/{len(dataloader)}] "
                      f"[E loss: {e_loss.item():3f}]")
                if batches_done % opt.sample_interval == 0:
                    fake_z = encoder(fake_imgs)
                    reconfiguration_imgs = generator(fake_z)
                    save_image(reconfiguration_imgs.data[:25],
                               f"results/images_e/{batches_done:06}.png",
                               nrow=5, normalize=True)
                batches_done += opt.n_critic
    torch.save(encoder.state_dict(), "results/encoder")

你会发现这些代码真滴很简略。练习完毕后咱们会保存编码器E的权重和重构后的一些图片。重构后图片效果也还是蛮好的。

对抗生成网络GAN系列——f-AnoGAN原理及缺陷检测实战

测验反常得分

咱们将检测的反常得分保存在score.csv文件中,保存四项参数,分别为label、img_distance、anomaly_score和z_distance。

def test_anomaly_detection(opt, generator, discriminator, encoder,
                           dataloader, device, kappa=1.0):
    generator.load_state_dict(torch.load("results/generator"))
    discriminator.load_state_dict(torch.load("results/discriminator"))
    encoder.load_state_dict(torch.load("results/encoder"))
    generator.to(device).eval()
    discriminator.to(device).eval()
    encoder.to(device).eval()
    criterion = nn.MSELoss()
    with open("results/score.csv", "w") as f:
        f.write("label,img_distance,anomaly_score,z_distance\n")
    for (img, label) in tqdm(dataloader):
        real_img = img.to(device)
        real_z = encoder(real_img)
        fake_img = generator(real_z)
        fake_z = encoder(fake_img)
        real_feature = discriminator.forward_features(real_img)
        fake_feature = discriminator.forward_features(fake_img)
        # Scores for anomaly detection
        img_distance = criterion(fake_img, real_img)
        loss_feature = criterion(fake_feature, real_feature)
        anomaly_score = img_distance + kappa * loss_feature
        z_distance = criterion(fake_z, real_z)
        with open("results/score.csv", "a") as f:
            f.write(f"{label.item()},{img_distance},"
                    f"{anomaly_score},{z_distance}\n")

在得到score.csv文件后,咱们能够来读取文件内容制作精度曲线。首要导入一些必要的包:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import roc_curve, precision_recall_curve, auc

然后读取刚刚得到的score.csv文件:

df = pd.read_csv("./results/score.csv")

df的内容如下:能够看到一共有65262行数据,这和咱们数据读取时测验集数据巨细是一致的。

对抗生成网络GAN系列——f-AnoGAN原理及缺陷检测实战

接着咱们读取各列的数据,并把标签为0的标签设置为0,其它的设置为1.

trainig_label = 0
labels = np.where(df["label"].values == trainig_label, 0, 1)
anomaly_score = df["anomaly_score"].values
img_distance = df["img_distance"].values
z_distance = df["z_distance"].values

然后能够根据上面的值得到一些画图所需值:

fpr, tpr, _ = roc_curve(labels, img_distance)
precision, recall, _ = precision_recall_curve(labels, img_distance)
roc_auc = auc(fpr, tpr)
pr_auc =  auc(recall, precision)

接下来就能够画图了:

plt.plot(fpr, tpr, label=f"AUC = {roc_auc:3f}")
plt.plot([0, 1], [0, 1], linestyle="--")
plt.title("ROC-AUC")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.legend()
plt.show()

对抗生成网络GAN系列——f-AnoGAN原理及缺陷检测实战

plt.plot(recall, precision, label=f"PR = {pr_auc:3f}")
plt.title("PR-AUC")
plt.xlabel("Recall")
plt.ylabel("Pecision")
plt.legend()
plt.show()

对抗生成网络GAN系列——f-AnoGAN原理及缺陷检测实战

plt.hist([anomaly_score[labels == 0], anomaly_score[labels == 1]],
          bins=100, density=True, stacked=True,
          label=["Normal", "Abnormal"])
plt.title("Discrete distributions of anomaly scores")
plt.xlabel("Anomaly scores A(x)")
plt.ylabel("h")
plt.legend()
plt.show()

对抗生成网络GAN系列——f-AnoGAN原理及缺陷检测实战

保存差异图画

​  代码中还定义了保存原图和生成图差异的图画,即将真实图画和生成图画做差,看看它们的差异,代码很简略,咱们来看看:

def save_compared_images(opt, generator, encoder, dataloader, device):
    generator.load_state_dict(torch.load("results/generator"))
    encoder.load_state_dict(torch.load("results/encoder"))
    generator.to(device).eval()
    encoder.to(device).eval()
    os.makedirs("results/images_diff", exist_ok=True)
    for i, (img, label) in enumerate(dataloader):
        real_img = img.to(device)
        real_z = encoder(real_img)
        fake_img = generator(real_z)
        compared_images = torch.empty(real_img.shape[0] * 3,
                                      *real_img.shape[1:])
        compared_images[0::3] = real_img
        compared_images[1::3] = fake_img
        compared_images[2::3] = real_img - fake_img
        save_image(compared_images.data,
                   f"results/images_diff/{opt.n_grid_lines*(i+1):06}.png",
                   nrow=3, normalize=True)
        if opt.n_iters is not None and opt.n_iters == i:
            break

​ 我也抽取一张保存的图画来给咱们看看成果:

对抗生成网络GAN系列——f-AnoGAN原理及缺陷检测实战

​  经过上图能够发现,无论原始输入即原图是什么,生成图都会将其生成0,原图和生成图做差后得到的图片因而也会发生不同的差异。

总结

​  f-AnoGAN就为咱们介绍到这里了,其实你细细探索下来会觉得十分简略。代码部分咱们要勤着手,多调试,这样你会有不一样的收成。

如若文章对你有所帮助,那就

        

对抗生成网络GAN系列——f-AnoGAN原理及缺陷检测实战