Denoising Diffusion GANs代码解析:从扩散系数到生成器设计的关键细节
【免费下载链接】denoising-diffusion-ganTackling the Generative Learning Trilemma with Denoising Diffusion GANs https://arxiv.org/abs/2112.07804项目地址: https://gitcode.com/gh_mirrors/de/denoising-diffusion-gan
欢迎来到Denoising Diffusion GANs(去噪扩散生成对抗网络)的深度解析指南!🚀 本文将带你深入了解这个创新性的生成模型,它巧妙地将扩散模型的高质量生成能力与GANs的高效性相结合,实现了生成学习三难问题的突破。无论你是深度学习新手还是有一定经验的开发者,这篇文章都将为你揭开Denoising Diffusion GANs的神秘面纱。
📊 什么是Denoising Diffusion GANs?
Denoising Diffusion GANs是一种创新的生成模型架构,它解决了传统扩散模型需要数千步采样过程的效率问题。传统的去噪扩散模型通常假设去噪分布可以用高斯分布建模,这个假设只在小的去噪步骤中成立,导致实践中需要数千步去噪步骤才能完成合成过程。
在Denoising Diffusion GANs中,我们使用多模态和复杂的条件GAN来表示去噪模型,使我们能够高效地在少至两步内生成数据。这种设计巧妙地将扩散模型的稳定训练特性与GANs的高效采样能力相结合。
🔧 核心架构解析
扩散系数设计
项目的核心从扩散系数开始。在train_ddgan.py文件中,我们找到了关键的扩散系数计算函数:
# 扩散系数计算 def var_func_vp(t, beta_min, beta_max): log_mean_coeff = -0.25 * t ** 2 * (beta_max - beta_min) - 0.5 * t * beta_min var = 1. - torch.exp(2. * log_mean_coeff) return var def var_func_geometric(t, beta_min, beta_max): return beta_min * ((beta_max / beta_min) ** t)这些函数定义了两种不同的方差调度策略:VP(Variance Preserving)调度和几何调度。VP调度是连续时间扩散模型中的标准选择,而几何调度提供了一种替代方案。
时间调度系统
时间调度系统是Denoising Diffusion GANs的关键组成部分。在get_sigma_schedule函数中,我们可以看到如何生成噪声调度:
def get_sigma_schedule(args, device): n_timestep = args.num_timesteps beta_min = args.beta_min beta_max = args.beta_max eps_small = 1e-3 t = np.arange(0, n_timestep + 1, dtype=np.float64) t = t / n_timestep t = torch.from_numpy(t) * (1. - eps_small) + eps_small if args.use_geometric: var = var_func_geometric(t, beta_min, beta_max) else: var = var_func_vp(t, beta_min, beta_max) alpha_bars = 1.0 - var betas = 1 - alpha_bars[1:] / alpha_bars[:-1] first = torch.tensor(1e-8) betas = torch.cat((first[None], betas)).to(device) betas = betas.type(torch.float32) sigmas = betas**0.5 a_s = torch.sqrt(1-betas) return sigmas, a_s, betas这个函数根据用户选择的调度策略(几何或VP)生成噪声调度,这对于控制扩散过程的质量至关重要。
🏗️ 生成器设计细节
NCSN++架构
Denoising Diffusion GANs使用NCSN++(Noise Conditional Score Network++)作为生成器架构。这个架构位于score_sde/models/ncsnpp_generator_adagn.py文件中,是项目中最复杂的组件之一。
@utils.register_model(name='ncsnpp') class NCSNpp(nn.Module): """NCSN++ model""" def __init__(self, config): super().__init__() self.config = config self.not_use_tanh = config.not_use_tanh self.act = act = nn.SiLU() self.z_emb_dim = z_emb_dim = config.z_emb_dim self.nf = nf = config.num_channels_dae ch_mult = config.ch_mult self.num_res_blocks = num_res_blocks = config.num_res_blocks self.attn_resolutions = attn_resolutions = config.attn_resolutions dropout = config.dropout resamp_with_conv = config.resamp_with_conv self.num_resolutions = num_resolutions = len(ch_mult) self.all_resolutions = all_resolutions = [config.image_size // (2 ** i) for i in range(num_resolutions)]NCSN++架构的特点包括:
- 多分辨率处理:支持从高分辨率到低分辨率的渐进式处理
- 注意力机制:在特定分辨率上应用注意力层
- 残差块设计:使用两种类型的残差块(DDPM和BigGAN风格)
- 条件嵌入:支持时间步和噪声级别的条件嵌入
自适应组归一化(AdaGN)
Denoising Diffusion GANs的一个关键创新是使用自适应组归一化(AdaGN)来注入时间步信息:
ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp_Adagn ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp_Adagn ResnetBlockBigGAN_one = layerspp.ResnetBlockBigGANpp_Adagn_one这些残差块变体集成了AdaGN,允许网络根据当前时间步自适应地调整归一化参数。
🎯 判别器架构
时间条件判别器
判别器设计在score_sde/models/discriminator.py中,它接收两个输入:原始图像和对应的噪声版本。这种设计使得判别器能够学习评估去噪质量:
class DownConvBlock(nn.Module): def __init__( self, in_channel, out_channel, kernel_size=3, padding=1, t_emb_dim = 128, downsample=False, act = nn.LeakyReLU(0.2), fir_kernel=(1, 3, 3, 1) ): super().__init__() self.fir_kernel = fir_kernel self.downsample = downsample self.conv1 = nn.Sequential( conv2d(in_channel, out_channel, kernel_size, padding=padding), )判别器架构包括:
- 时间嵌入:将时间步信息嵌入到网络中
- 下采样块:逐步降低特征图分辨率
- 条件机制:通过密集层将时间信息注入每个块
🔄 训练流程解析
扩散过程采样
在训练过程中,模型使用q_sample_pairs函数生成噪声图像对:
def q_sample_pairs(coeff, x_start, t): """ 生成一对扰动图像用于训练 :param x_start: x_0 :param t: 时间步t :return: x_t, x_{t+1} """ noise = torch.randn_like(x_start) x_t = q_sample(coeff, x_start, t) x_t_plus_one = extract(coeff.a_s, t+1, x_start.shape) * x_t + \ extract(coeff.sigmas, t+1, x_start.shape) * noise return x_t, x_t_plus_one这个函数生成当前时间步t和下一个时间步t+1的噪声图像对,用于训练生成器学习从x_{t+1}到x_t的去噪映射。
损失函数设计
Denoising Diffusion GANs使用对抗性损失来训练生成器,同时结合了R1正则化来稳定训练:
# 判别器损失 real_loss = F.softplus(-real_predict) fake_loss = F.softplus(fake_predict) d_loss = real_loss + fake_loss # R1正则化 r1_loss = compute_grad2(d_out_real, x_t).mean() d_loss = d_loss + args.r1_gamma * 0.5 * r1_loss🚀 快速采样优势
后验采样系数
与传统扩散模型需要数千步采样不同,Denoising Diffusion GANs可以在很少的步骤内生成高质量样本。这得益于后验采样系数的设计:
class Posterior_Coefficients(): def __init__(self, args, device): _, _, self.betas = get_sigma_schedule(args, device=device) # 我们不需要零 self.betas = self.betas.type(torch.float32)[1:] self.alphas = 1 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, 0) self.alphas_cumprod_prev = torch.cat( (torch.tensor([1.], dtype=torch.float32,device=device), self.alphas_cumprod[:-1]), 0 ) self.posterior_variance = self.betas * (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod)这些系数允许模型在推理时使用更少的步骤,同时保持生成质量。
📈 实际应用配置
数据集支持
Denoising Diffusion GANs支持多种数据集,包括:
- CIFAR-10:32×32彩色图像
- LSUN Church Outdoor:256×256场景图像
- CelebA HQ 256:256×256人脸图像
每个数据集都有专门的配置参数,在train_ddgan.py中可以看到相应的数据加载和处理逻辑。
训练命令示例
对于CIFAR-10数据集,训练命令如下:
python3 train_ddgan.py --dataset cifar10 --exp ddgan_cifar10_exp1 --num_channels 3 --num_channels_dae 128 --num_timesteps 4 \ --num_res_blocks 2 --batch_size 64 --num_epoch 1800 --ngf 64 --nz 100 --z_emb_dim 256 --n_mlp 4 --embedding_type positional \ --use_ema --ema_decay 0.9999 --r1_gamma 0.02 --lr_d 1.25e-4 --lr_g 1.6e-4 --lazy_reg 15 --num_process_per_node 4 \ --ch_mult 1 2 2 2 --save_content关键参数说明:
num_timesteps=4:仅需4个时间步num_channels_dae=128:去噪自编码器的通道数embedding_type=positional:使用位置嵌入use_ema:启用指数移动平均
🎨 生成过程详解
采样算法
在test_ddgan.py中,我们可以看到生成过程的核心算法:
# 生成样本的主循环 for i in reversed(range(1, args.num_timesteps)): t = torch.full((x_t.shape[0],), i, device=device, dtype=torch.long) x_t = netG(x_t, t) # 添加噪声进行下一步 if i > 1: z = torch.randn_like(x_t) x_t = x_t + sigmas[i] * z这个过程从纯噪声开始,通过生成器逐步去噪,最终得到清晰的图像。与传统扩散模型相比,这个过程的步骤数大大减少。
🔍 性能评估
FID分数计算
项目使用PyTorch FID实现来评估生成质量:
from pytorch_fid.fid_score import calculate_fid_given_pathsFID(Fréchet Inception Distance)是评估生成模型质量的常用指标,它衡量生成图像分布与真实图像分布之间的距离。
Inception Score计算
除了FID,项目还支持Inception Score计算:
python ./pytorch_fid/inception_score.py --sample_dir /path/to/sampled_images💡 关键设计理念
1.多时间步条件生成
生成器接收当前时间步作为条件输入,使其能够学习不同噪声水平下的去噪映射。
2.对抗性训练
使用判别器评估去噪质量,而不是传统的均方误差损失,这允许模型学习更复杂的多模态分布。
3.高效采样
通过减少时间步数(通常为2-4步),Denoising Diffusion GANs实现了比传统扩散模型快几个数量级的采样速度。
4.稳定训练
结合了多种稳定技术,包括梯度惩罚、指数移动平均和学习率调度。
🛠️ 实用技巧
超参数调优建议
时间步数选择:对于简单数据集(如CIFAR-10),4个时间步通常足够;对于复杂数据集(如CelebA HQ 256),可能需要更多时间步。
学习率设置:生成器和判别器使用不同的学习率,通常生成器的学习率略高于判别器。
批量大小:根据GPU内存调整,更大的批量大小通常有助于稳定训练。
EMA衰减:使用较高的EMA衰减值(如0.9999)有助于生成更稳定的样本。
📊 实验结果
根据论文报告,Denoising Diffusion GANs在多个数据集上取得了优异的结果:
- CIFAR-10:FID 3.75(仅4步采样)
- CelebA HQ 256:FID 5.57(仅2步采样)
- LSUN Church 256:FID 4.98(仅4步采样)
这些结果证明了该模型在保持高质量生成的同时,显著提高了采样效率。
🎯 总结
Denoising Diffusion GANs代表了生成模型领域的一个重要进展,它成功地将扩散模型的稳定训练特性与GANs的高效采样能力相结合。通过创新的架构设计和训练策略,该模型在生成质量、多样性和采样效率之间取得了良好的平衡。
项目的代码结构清晰,模块化设计良好,使得研究人员和开发者能够轻松理解和修改。无论你是想在自己的项目中应用这项技术,还是想深入研究生成模型的前沿,Denoising Diffusion GANs都提供了一个优秀的起点。
希望这篇解析能帮助你更好地理解这个创新性的模型!🌟 如果你对生成模型感兴趣,不妨克隆项目并亲自尝试一下:
git clone https://gitcode.com/gh_mirrors/de/denoising-diffusion-gan开始你的生成模型探索之旅吧!
【免费下载链接】denoising-diffusion-ganTackling the Generative Learning Trilemma with Denoising Diffusion GANs https://arxiv.org/abs/2112.07804项目地址: https://gitcode.com/gh_mirrors/de/denoising-diffusion-gan
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考