-
动机:传统的扩散模型(如 DDPM 或 SDE 模型)生成的是无条件图像(即随机生成的图像)。类别条件扩散模型的目的是引入 外部知识(如类别标签
$y$ 、文本序列或其他模态)来 引导(steer) 图像生成过程,以生成特定类别的图像。
单纯的 SDE 模型(作为 DDPM 和 SMLD 的统一框架)训练的是一个无条件的分数函数或噪声预测网络 $\mathbf{s}{\theta}(\mathbf{x}, t)$ 或 $\epsilon{\theta}(\mathbf{x}_t, t)$。
-
核心转变:将模型从 无条件生成
$\mathbf{x} \sim p(\mathbf{x})$ 扩展到 条件生成$\mathbf{x} \sim p(\mathbf{x}|y)$ 。 -
网络功能扩展:这要求模型的预测网络必须能够接收和处理类别信息
$y$ 。 -
数学形式:
- 无条件模型预测:$\epsilon_{\theta}(\mathbf{x}_t, t)$
- 类别条件模型预测:$\epsilon_{\theta}(\mathbf{x}_t, t, y)$
通过在训练和推理过程中始终将类别
类别条件扩散模型的训练过程与无条件 DDPM 或 SDE 模型的训练过程基本相同,但增加了对类别条件
-
数据选取:从训练数据集中选取原始图像
$\mathbf{x}_0$ 时,必须同时获取其对应的类别标签$y$ 。 -
噪声与时间步:随机选择时间步
$t$ ,并生成相应的噪声$\epsilon$ 来创建$\mathbf{x}_t$ 。 -
网络输入:将
$\mathbf{x}_t$ 、时间步$t$ 和 类别$y$ 三者同时输入到修改后的 U-Net 模型$\epsilon_{\theta}(\mathbf{x}_t, t, y)$ 中。 -
损失函数:继续使用原始 DDPM 的简化 L2 损失进行优化:
$$\nabla_{\theta} | \epsilon - \epsilon_{\theta}(\mathbf{x}t, t, y) |^2$$
(即训练 $\epsilon{\theta}$ 去预测真实的采样噪声
$\epsilon$ )。 - 稳定性:训练过程与原始 DDPM 教程中描述的完全相同,这表明了扩散模型在加入条件约束后的训练稳定性。
推理过程遵循 DDPM 的逆向去噪链,但每一步都受到目标类别的约束。
-
初始噪声:选取白噪声
$\mathbf{x}_T \sim \mathcal{N}(\mathbf{0},\mathbf{I})$ 。 -
设定条件:需要人为设定一个目标类别标签
$y$ 作为输入。 -
去噪循环:循环
$t=T, \ldots, 1$ 执行去噪步骤:- 模型预测条件噪声
$\epsilon_{\theta}(\mathbf{x}_t, t, y)$ 。 - 使用预测的噪声和设定的类别
$y$ 来计算下一时间步step的去噪状态$\mathbf{x}_{t-1}$ 。
- 模型预测条件噪声
为了实现类别条件化,DDPM 的核心 U-Net 结构必须进行修改,以接受和处理类别信息
-
基础架构:仍使用 U-Net 模型作为噪声预测网络
$\epsilon_{\theta}$ 。 -
条件输入:与时间步
$t$ 类似,类别信息$y$ 必须被 嵌入 (embedding),然后与中间激活张量融合。
1. Adaptive Group Normalization (AdaGN):这是资料中推荐和实现的融合机制。
- 动机:AdaGN 被认为比简单的相加(Addition + GroupNorm)效果更好。
-
方法:类别嵌入
$y$ 和时间步嵌入$t$ 通过线性投影生成用于动态调整 Group Normalization 的缩放和平移参数:$$\displaystyle y_s \verb|GroupNorm|(h) + y_b$$ -
$h$ :残差块(第一个卷积之后)的中间激活。 -
$y_s$ :来自时间步嵌入的线性投影(缩放因子)。 -
$y_b$ :来自类别嵌入的线性投影(平移因子)。
-
- 应用位置:AdaGN 应用于 U-Net 内部的每个残差块(Residual Block)中。
2. 替代方法(简单堆叠):
- 方法:可以简单地将类别嵌入 堆叠(stacking) 到输入张量的通道中。
- 优点:这种方法无需修改现有的 U-Net 架构,只需要调整输入通道的数量。
-
扩散模型的灵活性:SDE 框架统一了 DDPM 和 SMLD,使我们能够将复杂的生成过程视为连续的 SDE。类别条件化进一步展示了扩散模型的强大灵活性:通过简单地将条件
$y$ 纳入到噪声预测网络$\epsilon_{\theta}(\mathbf{x}_t, t, y)$ 中,模型就能实现对连续时间去噪过程的精确控制。 -
融合机制的关键性:在扩散模型中,有效且精细地将条件信息(如
$y$ 或$t$ )注入到 U-Net 的中间层至关重要。使用 AdaGN 等技术,通过动态调整归一化参数,实现了比简单堆叠更强大的性能。 -
时间步和类别条件的耦合:在 AdaGN 中,$y_s$ 来自时间步嵌入,$y_b$ 来自类别嵌入。这意味着模型在特定时间步
$t$ 调整激活时,也同时考虑了目标类别$y$ 的特征,从而实现细粒度的条件控制。 -
低成本扩展:通过在训练中加入
$y$ ,并且保持与原始 DDPM 相同的 L2 损失,类别条件扩散模型在不引入新的复杂损失函数或额外计算开销的情况下,实现了从无条件到条件生成的关键跨越。