本文记录异常检测23年性能最佳的工作 DDAD 的原理以及官方源码解析。
DDAD 是 2024 年以前 MVTec AD 数据集上性能最好的异常检测模型,本文解读相关论文并对源码进行解读
项目 | 内容 | 备注 |
---|---|---|
方法名称 | DDAD | |
论文题目 | Anomaly Detection with Conditioned Denoising Diffusion Models^25 | |
论文连接 | | |
开源代码 | | |
发表时间 | 2023.12.03 | |
方法类别 | 深度学习 -> 基于重构 -> 扩散模型 | |
Detection AU-ROC | 99.8% | |
Segmentation AU-ROC | 98.1% | |
Segmentation AU-PRO | 92.3% | |
核心思想 |
| |
输入图像 X , 经过扩散过程得到随机的 X_{T’} , 之后需要通过 X_{T’} 经过反扩散过程重构图像生成 x_0 , 目标是 y . 因为重构目标是 y , 所以假设重构输出与目标接近, 即:x_0 \approx y , 假设从 X_{T’} 到 X_0 每一步添加的噪声为 \epsilon_\theta^{(t)} , 得到过程中的 X_{T’},X_{T’-1},X_{T’-2}, …,X_2,X_1,X_0 , 反过来, 我们向 y 逐步加入 \epsilon_\theta^{(t)} , 得到 y, y_1,y_2, …,y_{T’-2},y_{T’-1},y_{T’} , 那么根据假设可以推断 y_t\approx x_t , 那么就可以用 y_t 指导每一步去噪产生 x_t 的训练过程, 也就是带条件的扩散模型.
该步骤训练完成后会得到可以重构出和目标图像类似的扩散模型, 训练过程中仅使用 OK 数据进行训练, 这样扩散模型仅学会了重构 OK 数据的能力.
在异常检测推断流程中, 重构的目标图像会被设置为输入图像 x , 目的是基于 x 生成一幅没有缺陷的重构图 x_0 , 之后比对 x_0 和 x 之间的差异判断是否存在异常.
条件扩散模型将 AU-ROC 从 85.7% 提高到 92.4%
现在已经得到了 x_0 和 x , 如何对比二者得到异常分数图效果比较好呢. 最直接的想法是将二者直接在像素空间上作差, 结果用 D_p 表示, 该方法确实直接有效, 但是无法抵抗一些重构过程中产生的噪声, 因此论文使用预训练的骨干网络提取特征作为额外的分数判定依据.
选择一个 ImageNet 预训练的骨干网络, 提取 x_0 和 x 的特征 (主要用下采样 2x 和 4x 的特征), , 计算二者特征的余弦距离作为特征度量差异距离 D_f .
最后将二者归一化加权叠加在一起得到异常分数:
其中
按照算法的完备性至此已经可以完成异常检测工作了, 但是文章还试图解决 ImageNet 对当前数据适应性不是最优的问题, 尝试用项目数据对预训练模型进行微调, 使其适应当前的数据以获得更好的特征提取能力.
核心思想仍然基于之前的假设 x_0 \approx y , 那么我们就希望网络对重构产生的误差不那么敏感, 也就是让网络觉得 x_0 和 y 的特征相近, 依此可以进行模型微调. 但是仅用这一个 loss 容易使得模型坍缩退化, 为了使得模型在保持原本的泛化能力的同时适应我们的需求, 作者在刚刚的损失函数基础上增加了当前模型对原始模型的特征蒸馏损失
$$ \begin{gathered} \mathcal{L}_{DA} =\mathcal{L}_{Similarity}(\mathbf{x_0},\mathbf{y})+\lambda_{DL}\mathcal{L}_{DL}(\mathbf{x_0},\mathbf{y}) \ =\sum_{j\in J}\left(1-\cos(\phi_j(\mathbf{x}_0),\phi_j(\mathbf{y}))\right) \ +\lambda_{DL}\sum_{j\in J}\left(1-\cos(\phi_j(\mathbf{y}),\overline{\phi}_j(\mathbf{y}))\right) \ +\lambda_{DL}\sum_{j\in J}\left(1-\cos(\phi_{j}(\mathbf{x}_{0}),\overline{\phi}_{j}(\mathbf{x}_{0}))\right), \end{gathered} $$
如此完成模型的微调.
其中 j\in {1,2,3}
域适应性将 AU-ROC 从 92.4% 提高到99.8%
在 MVTec
数据集得到 99.8% 的图像 AU-ROC 和 97.2% 的分割 AU-ROC.
测试数据使用 MVTec AD 数据集,下载链接
开源仓库:https://github.com/arimousa/DDAD
当前 Commit ID: e4e11f1b4ff5cf0a2762c4d8a5dfdfb6bfa64303
将数据集放在仓库根目录 datasets/MVTec
文件夹中:
环境 | 版本 | 备注 |
---|---|---|
Python | 3.8.+ | |
kornia | 0.6.12 | |
matplotlib | 3.7.1 | |
numpy | 1.24.3 | |
omegaconf | 2.1.2 | |
opencv-python-headless | 4.5.5.64 | |
pandas | 2.0.1 | |
Pillow | 9.5.0 | |
scikit-image | 0.19.2 | |
scikit-learn | 1.2.2 | |
scipy | 1.10.1 | |
torch | 2.0.1 | 2.0.1+cu118 |
torchvision | 0.15.2 | 0.15.2+cu118 |
torchmetrics | 0.11.4 | |
sklearn | 0.0.post5 | 没有成功安装,未发现对程序运行的影响 |
核心代码都在根目录中:
12345678910111213141516171819202122232425262728 | ├── anomaly_map.py // 生成异常得分图├── checkpoints // 模型保存路径│ └── MVTec│ └── screw│ ├── 100│ └── 200├── config.yaml // 核心配置文件├── dataset.py // 数据集控制├── datasets // 数据保存文件夹│ └── MVTec├── ddad.py // DDAD 网络框架├── feature_extractor.py // FineTune 核心代码├── images // 论文示意图│ ├── DDAD_Framework.png│ └── Qualitative.png├── LICENSE├── loss.py // 损失函数├── main.py // 入口代码,支持训练、微调、测试├── metrics.py // 评价指标├── __pycache__├── README.md├── reconstruction.py // 重建模块├── requirements.txt // 依赖环境├── resnet.py // 基础特征提取 Backbone, ResNet├── train.py // 训练代码├── unet.py // Unet 网络结构└── visualize.py // 可视化代码 |
---|
我们的 3080 显卡 10g 显存,可以使用如下配置训练 DDAD 模型:
配置文件为
config.yaml
12345678910111213141516171819202122232425262728293031323334353637383940414243444546 | data : name: MVTec #MVTec #MTD #VisA data_dir: datasets/MVTec #MVTec #VisA #MTD category: screw #'carpet', 'bottle', 'hazelnut', 'leather', 'cable', 'capsule', 'grid', 'pill', 'transistor', 'metal_nut', 'screw','toothbrush', 'zipper', 'tile', 'wood' # 'candle', 'capsules', 'cashew', 'chewinggum', 'fryum', 'macaroni1', 'macaroni2', 'pcb1', 'pcb2' ,'pcb3', 'pcb4', 'pipe_fryum' image_size: 256 batch_size: 12 # 32 for DDAD and 16 for DDADS DA_batch_size: 16 #16 for MVTec and macaroni2, pcb1 in VisA, and 32 for other categories in VisA test_batch_size: 16 #16 for MVTec, 32 for VisA mask : True input_channel : 3model: DDADS: True checkpoint_dir: checkpoints/MVTec #MTD #MVTec #VisA checkpoint_name: weights exp_name: default feature_extractor: resnet50 #wide_resnet101_2 # wide_resnet50_2 #resnet50 learning_rate: 3e-4 weight_decay: 0.05 epochs: 3000 load_chp : 2000 # From this epoch checkpoint will be loaded. Every 250 epochs a checkpoint is saved. Try to load 750 or 1000 epochs for Visa and 1000-1500-2000 for MVTec. DA_epochs: 4 # Number of epochs for Domain adaptation. DA_chp: 4 v : 1 #7 # 1 for MVTec and cashew in VisA, and 7 for VisA (1.5 for cashew). Control parameter for pixel-wise and feature-wise comparison. v * D_p + D_f w : 2 # Conditionig parameter. The higher the value, the more the model is conditioned on the target image. "Fine tuninig this parameter results in better performance". w_DA : 3 #3 # Conditionig parameter for domain adaptation. The higher the value, the more the model is conditioned on the target image. DLlambda : 0.1 # 0.1 for MVTec and 0.01 for VisA trajectory_steps: 1000 test_trajectoy_steps: 250 # Starting point for denoining trajectory. test_trajectoy_steps_DA: 250 # Starting point for denoining trajectory for domain adaptation. skip : 25 # Number of steps to skip for denoising trajectory. skip_DA : 25 eta : 1 # Stochasticity parameter for denoising process. beta_start : 0.0001 beta_end : 0.02 device: 'cuda' #<"cpu", "gpu", "tpu", "ipu"> save_model: True num_workers : 2 seed : 42metrics: auroc: True pro: True misclassifications: False visualisation: True |
---|
DDAD 实现异常检测需要分两阶段训练
推断时需要加载训练好的 Unet 和特征提取器
构建 Unet 模型的函数为 main.py -> build_model
,通过实例化 unet.py -> UNetModel
类实现。
核心函数在 dataset.py -> Dataset_maker
类中,根据文件夹名称构建所需数据集。
入口函数在 main.py -> train
,核心代码在 train.py -> trainer
函数中
此处训练的是 去噪Unet
网络,期望网络可以将叠加在图像上的噪声恢复出来
向 Unet 输入带噪的图像,输出张量与噪声的二范数距离作为损失
123 | x = at.sqrt() * x_0 + (1- at).sqrt() * e output = model(x, t.float())return (e - output).square().sum(dim=(1, 2, 3)).mean(dim=0) |
---|
经过训练,可以使得 Unet 网络较好地预测添加到数据中的噪声
入口函数 main.py -> finetuning
,核心代码在 reconstruction.py
中
微调特征提取器,这里使用的是 Resnet,由于算力有限这里采用 Resnet50
输入一个 batch 的数据一半为输入 一半为目标,训练重构器
同时兼顾原始模型的蒸馏损失
训练完成后保留特征提取器
入口函数 main.py -> detection
,核心代码在 ddad.py
的 DDAD
类中
过程中可以在配置文件配置可视化参数为 True
结果保存可视化结果
测试结果
12 | AUROC: (92.5,97.6)PRO: 90.9 |
---|
结果被 center crop 到 224*224,可以实现一定的检测能力
这个结果远没有达到论文描述的水准,仓库中作者解释说这个去噪网络训练很不稳定,建议下载他们训练好的 Unet 模型,我下载了同规格的模型后性能得到一定提升
12 | AUROC: (95.4,98.6)PRO: 94.0 |
---|
可视化的图像的效果也要好一些
使用官网更大的模型在这组数据集上可以取得更好的结果
12 | AUROC: (97.3,99.4)PRO: 97.0 |
---|
但是本机的 3080 显存不足以支撑该训练,而且 WideResnet101 模型较大
graph TD
E(Unet)
B--> C
D(config)
G(noise)
H(trainer)
D -->H
E --> H
I(train)
H -->I
C--提取-->I
G--加入数据-->I
J(loss)
I --预测噪声-->J
J --更新参数--> E
A--组合-->B
A(DATA)
B(dataset)
C(dataloader)
graph TD
A(DATA)
B(dataset)
C(dataloader)
A--> B
B--> C
D(config)
E(Unet)
F(FeatureExtractor)
G(image)
H(noise)
C -->G
I(+)
G-->I
H-->I
J(noised image)
I-->J
J --输入--> E
D --> E
K(target image)
L(predicted noise)
E --> L
M(reconstructed image)
L -->M
K-->M
M --> F
N(image feature)
F-->N