GPEN训练全流程指南:数据对生成与学习率设置实战
1. 镜像环境说明
本镜像基于GPEN人像修复增强模型构建,预装了完整的深度学习开发环境,集成了推理及评估所需的所有依赖,开箱即用。适用于人脸超分辨率、图像增强、老照片修复等场景的快速实验与模型训练。
| 组件 | 版本 |
|---|---|
| 核心框架 | PyTorch 2.5.0 |
| CUDA 版本 | 12.4 |
| Python 版本 | 3.11 |
| 推理代码位置 | /root/GPEN |
主要依赖库:
facexlib: 用于人脸检测与对齐basicsr: 基础超分框架支持opencv-python,numpy<2.0,datasets==2.21.0,pyarrow==12.0.1sortedcontainers,addict,yapf
该环境已配置好训练和推理所需的全部组件,用户无需手动安装依赖即可直接进入开发阶段。
2. 快速上手
2.1 激活环境
在使用前,请先激活预设的 Conda 环境:
conda activate torch252.2 模型推理 (Inference)
进入项目主目录并运行推理脚本:
cd /root/GPEN场景 1:运行默认测试图
python inference_gpen.py此命令将处理内置测试图像Solvay_conference_1927.jpg,输出结果为output_Solvay_conference_1927.png。
场景 2:修复自定义图片
python inference_gpen.py --input ./my_photo.jpg输入文件路径通过--input参数指定,输出自动保存为output_my_photo.jpg。
场景 3:自定义输入与输出文件名
python inference_gpen.py -i test.jpg -o custom_name.png支持简写参数-i和-o分别指定输入与输出路径。
注意:所有推理结果将保存在项目根目录下,建议提前确认图片路径正确性以避免报错。
3. 已包含权重文件
为保障离线可用性和部署效率,镜像中已预下载官方训练好的权重文件,位于 ModelScope 缓存路径:
~/.cache/modelscope/hub/iic/cv_gpen_image-portrait-enhancement包含以下关键模型组件:
- GPEN 生成器(Generator):负责高保真人脸细节重建
- Face Detection 模型:基于 RetinaFace 实现精准人脸定位
- Landmark Alignment 模型:实现五点对齐,提升修复稳定性
若未执行过推理任务,系统会在首次调用时自动加载并缓存模型权重,后续无需重复下载。
4. 训练全流程详解
4.1 数据准备:构建高质量-低质量图像对
GPEN 采用监督式训练方式,需提供成对的高清原图(HR)与对应降质图像(LR)。理想情况下,这些 LR 图像是通过对 HR 图像进行模拟退化生成的。
推荐数据来源
- FFHQ(Flickr-Faces-HQ):广泛用于人脸生成任务,共7万张高分辨率人像(1024×1024)
- 下载地址:https://github.com/NVlabs/ffhq-dataset
数据对生成策略
可使用RealESRGAN或BSRGAN提供的退化流程生成逼真的低质量样本。
示例代码片段(使用 RealESRGAN 的 degradation 流程):
import cv2 import numpy as np from basicsr.data.degradations import random_add_gaussian_noise, random_add_poisson_noise from basicsr.utils import img2tensor, tensor2img def degrade_image(hr_img_path, save_lr_path): # 读取高清图像 hr_img = cv2.imread(hr_img_path) hr_img = cv2.cvtColor(hr_img, cv2.COLOR_BGR2RGB) # 添加模糊核 kernel_size = np.random.choice([7, 9, 11]) sigma = np.random.uniform(0.2, 2.0) blur_kernel = cv2.getGaussianKernel(kernel_size, sigma) degraded = cv2.filter2D(hr_img, -1, blur_kernel @ blur_kernel.T) # 下采样(模拟低分辨率) scale = 4 h, w = degraded.shape[:2] degraded = cv2.resize(degraded, (w // scale, h // scale), interpolation=cv2.INTER_LINEAR) # 上采样回原始尺寸(保持空间一致) degraded = cv2.resize(degraded, (w, h), interpolation=cv2.INTER_LINEAR) # 添加噪声 degraded = img2tensor(degraded.astype(np.float32) / 255., bgr2rgb=False, float32=True) degraded = random_add_gaussian_noise(degraded, sigma_range=[1, 30], clip=True, rounds=False, gray_prob=0.4) degraded = random_add_poisson_noise(degraded, scale_range=[0, 1.0], gray_prob=0.4, clip=True, rounds=False) # 转换回图像格式并保存 degraded = tensor2img(degraded, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)) cv2.imwrite(save_lr_path, degraded) # 使用示例 degrade_image('./data/hr/example.jpg', './data/lr/example.jpg')提示:建议统一将图像缩放到目标分辨率(如 512×512),避免训练过程中的尺寸不匹配问题。
4.2 训练配置与参数设置
目录结构要求
确保训练数据组织如下:
datasets/ ├── train/ │ ├── HR/ │ └── LR/ └── val/ ├── HR/ └── LR/修改训练配置文件
编辑options/train_GAN_stage.json中的关键参数:
{ "datasets": { "train": { "name": "gpen_train", "type": "PairedImageDataset", "dataroot_gt": "./datasets/train/HR", "dataroot_lq": "./datasets/train/LR", "io_backend": {"type": "disk"} }, "val": { "name": "gpen_val", "type": "PairedImageDataset", "dataroot_gt": "./datasets/val/HR", "dataroot_lq": "./datasets/val/LR" } }, "network_g": { "type": "GPENModel", "in_channel": 3, "out_channel": 3, "channel": 256, "n_res": 12 }, "train": { "optim_g": { "type": "Adam", "lr": 0.0001, "weight_decay": 0, "betas": [0.9, 0.99] }, "scheduler": { "type": "CosineAnnealingRestartLR", "periods": [250000, 250000, 250000], "restart_weights": [1, 1, 1], "eta_min": 1e-7 }, "total_iter": 750000, "warmup_iter": -1 } }学习率设置建议
| 模块 | 初始学习率 | 调整策略 | 说明 |
|---|---|---|---|
| 生成器(Generator) | 1e-4 | Cosine + Restart | 稳定收敛,防止震荡 |
| 判别器(Discriminator) | 1e-4 ~ 5e-4 | 同步或稍高于生成器 | 提升对抗能力 |
| 特征提取网络(如 VGG) | 冻结或 1e-5 | 小幅微调 | 防止破坏已有特征 |
经验法则:判别器学习率可略高于生成器(例如 1.2~1.5 倍),但不宜过高以免导致模式崩溃。
4.3 启动训练
执行以下命令开始训练:
python train.py -opt options/train_GAN_stage.json训练过程中日志将输出至控制台,并记录于./experiments目录下的时间戳子文件夹中,包括:
- 模型权重(每10k iter保存一次)
- 可视化中间结果(每epoch保存)
- 损失曲线(可通过 TensorBoard 查看)
5. 性能优化与常见问题
5.1 显存不足解决方案
- 降低 batch_size:从默认 8 降至 4 或 2
- 启用梯度累积(Gradient Accumulation)
修改配置文件添加:
"train": { "accumulations": 2, ... }相当于每两次前向传播后才更新一次参数,等效增大 batch size。
5.2 训练不稳定应对措施
| 现象 | 可能原因 | 解决方案 |
|---|---|---|
| 输出图像模糊 | 判别器过强或 L1 损失权重不足 | 降低判别器学习率,增加pixel_loss_weight |
| 出现伪影或畸变 | GAN loss 收敛异常 | 引入谱归一化(Spectral Norm),调整 adversarial weight |
| 模式崩塌 | 判别器主导训练 | 使用 R1 正则化,增加数据多样性 |
5.3 推理加速技巧
- FP16 推理:在支持 Tensor Core 的 GPU 上启用半精度推断
- ONNX 导出:将
.pth模型导出为 ONNX 格式,结合 TensorRT 加速部署 - TorchScript 编译:使用
torch.jit.script()提升推理吞吐量
6. 总结
本文围绕 GPEN 人像修复增强模型,系统介绍了从镜像环境搭建、推理使用到完整训练流程的实践方法。重点涵盖:
- 数据对生成机制:利用 RealESRGAN/BSRGAN 的退化流程构造真实感 LR-HR 对;
- 学习率配置策略:推荐生成器初始学习率为
1e-4,配合余弦重启调度器实现稳定收敛; - 训练工程优化:包括显存管理、梯度累积、损失平衡等实用技巧;
- 开箱即用优势:本镜像预置完整依赖与权重,极大降低入门门槛。
通过合理设置训练参数与数据质量控制,可在 512×512 分辨率下实现高质量的人脸细节恢复,适用于老旧照片修复、安防图像增强等多种应用场景。
获取更多AI镜像
想探索更多AI镜像和应用场景?访问 CSDN星图镜像广场,提供丰富的预置镜像,覆盖大模型推理、图像生成、视频生成、模型微调等多个领域,支持一键部署。