用PyTorch复现NeRF:从零开始手把手教你训练自己的乐高小车3D模型(附完整代码)
当你第一次看到NeRF生成的3D模型时,那种震撼感是难以言喻的——从一个简单的神经网络中,竟然能重建出如此精细的三维场景。作为计算机视觉领域近年来的重大突破,神经辐射场技术正在改变我们对3D重建的认知。本文将带你从零开始,用PyTorch完整实现一个NeRF模型,并以经典的乐高小车数据集为例,一步步教你训练出自己的3D重建模型。
1. 环境准备与数据加载
在开始之前,我们需要搭建一个合适的开发环境。建议使用Python 3.8+和PyTorch 1.10+版本,这些组合经过验证能够很好地支持NeRF的实现。
基础环境配置:
conda create -n nerf python=3.8 conda activate nerf pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113 pip install numpy matplotlib imageio scikit-image opencv-python乐高数据集是NeRF论文中使用的标准数据集之一,它包含了从不同角度拍摄的乐高小车图像及对应的相机参数。我们可以从官方仓库下载这个数据集:
import os import json import numpy as np import imageio def load_lego_data(data_dir): """加载乐高数据集""" with open(os.path.join(data_dir, 'transforms_train.json'), 'r') as f: train_meta = json.load(f) images = [] poses = [] for frame in train_meta['frames']: img_path = os.path.join(data_dir, frame['file_path'][2:]) images.append(imageio.imread(img_path)) poses.append(np.array(frame['transform_matrix'])) images = np.stack(images, axis=0) poses = np.stack(poses, axis=0) return images, poses, train_meta['camera_angle_x']注意:数据集中的图像需要归一化到0-1范围,相机参数需要正确处理世界坐标系和相机坐标系之间的转换关系。
2. NeRF模型架构实现
NeRF的核心是一个多层感知机(MLP),它接收5D坐标(3D位置+2D视角方向)作为输入,输出体积密度和视角相关的RGB颜色值。让我们用PyTorch来实现这个网络:
import torch import torch.nn as nn import torch.nn.functional as F class PositionalEncoding(nn.Module): """位置编码模块,将低维输入映射到高维空间""" def __init__(self, L=10): super().__init__() self.L = L def forward(self, x): encodings = [] for i in range(self.L): encodings.append(torch.sin(2**i * torch.pi * x)) encodings.append(torch.cos(2**i * torch.pi * x)) return torch.cat(encodings, dim=-1) class NeRF(nn.Module): """NeRF核心网络结构""" def __init__(self, L_pos=10, L_dir=4): super().__init__() # 位置编码 self.pos_encoder = PositionalEncoding(L_pos) self.dir_encoder = PositionalEncoding(L_dir) # 主干网络 self.backbone = nn.Sequential( nn.Linear(3 + 3*2*L_pos, 256), nn.ReLU(), nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256), nn.ReLU(), ) # 密度预测头 self.density_head = nn.Sequential( nn.Linear(256, 1), nn.ReLU() ) # 颜色预测头 self.color_head = nn.Sequential( nn.Linear(256 + 3*2*L_dir, 128), nn.ReLU(), nn.Linear(128, 3), nn.Sigmoid() ) def forward(self, x, d): # 编码位置和方向 x_encoded = self.pos_encoder(x) d_encoded = self.dir_encoder(d) # 通过主干网络 features = self.backbone(x_encoded) # 预测密度 sigma = self.density_head(features) # 预测颜色 color_features = torch.cat([features, d_encoded], dim=-1) rgb = self.color_head(color_features) return rgb, sigma提示:位置编码是NeRF能够捕捉高频细节的关键,确保L_pos和L_dir参数设置与论文一致。
3. 体渲染实现与训练流程
有了模型架构后,我们需要实现NeRF的核心创新之一——可微分的体渲染过程。这个过程将模型预测的密度和颜色转换为最终的2D图像。
def render_rays(model, rays_o, rays_d, near, far, N_samples, rand=False): """渲染单条射线上的颜色""" # 计算采样点 t_vals = torch.linspace(near, far, N_samples) if rand: t_vals = t_vals + torch.rand_like(t_vals) * (far-near)/N_samples # 采样点坐标 pts = rays_o[...,None,:] + rays_d[...,None,:] * t_vals[...,:,None] # 扩展视角方向 dirs = rays_d / torch.norm(rays_d, dim=-1, keepdim=True) dirs = dirs.expand(N_samples, -1) # 预测颜色和密度 rgb, sigma = model(pts, dirs) # 计算透明度 delta = t_vals[...,1:] - t_vals[...,:-1] delta = torch.cat([delta, torch.tensor([1e10]).expand(delta[...,:1].shape)], dim=-1) alpha = 1 - torch.exp(-sigma.squeeze() * delta) # 累积透明度 weights = alpha * torch.cumprod(1.-alpha + 1e-10, dim=-1) # 计算最终颜色 rgb_map = torch.sum(weights[...,None] * rgb, dim=-2) return rgb_map def train(model, optimizer, images, poses, hwf, N_samples=64, N_rand=1024): """训练循环""" H, W, focal = hwf for epoch in range(100000): # 随机选择一条射线 img_i = np.random.randint(0, len(images)) pose = poses[img_i] # 随机选择像素 i = np.random.randint(0, H, size=N_rand) j = np.random.randint(0, W, size=N_rand) # 生成射线 rays_o, rays_d = get_rays(H, W, focal, pose, i, j) # 渲染 rgb_pred = render_rays(model, rays_o, rays_d, near=2., far=6., N_samples=N_samples) # 计算损失 target = images[img_i][i, j] loss = F.mse_loss(rgb_pred, target) # 反向传播 optimizer.zero_grad() loss.backward() optimizer.step()在实际训练中,我们还需要实现多层级采样策略,即先用粗采样估计场景结构,再在重要区域进行精细采样。这可以显著提高渲染质量。
4. 实战技巧与性能优化
训练NeRF模型时,有几个关键技巧可以显著提高训练效率和最终效果:
GPU内存优化:NeRF的体渲染过程需要处理大量射线和采样点,容易导致GPU内存不足。可以通过以下方式优化:
- 分块处理:将射线分成小块进行处理
chunk_size = 1024 * 16 # 根据GPU内存调整 for i in range(0, rays_o.shape[0], chunk_size): rgb_chunk = render_rays(model, rays_o[i:i+chunk_size], rays_d[i:i+chunk_size], near, far, N_samples)- 混合精度训练:使用PyTorch的自动混合精度
scaler = torch.cuda.amp.GradScaler() with torch.cuda.amp.autocast(): rgb_pred = render_rays(model, rays_o, rays_d, near, far, N_samples) loss = F.mse_loss(rgb_pred, target) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()训练加速技巧:
| 技巧 | 实现方式 | 效果提升 |
|---|---|---|
| 学习率预热 | 前1000步线性增加学习率 | 训练更稳定 |
| 权重衰减 | Adam优化器设置weight_decay=1e-5 | 防止过拟合 |
| 学习率调度 | 每10k步学习率减半 | 更好收敛 |
| 早停策略 | 验证集PSNR不再提升时停止 | 节省时间 |
常见问题解决方案:
模型收敛慢:
- 检查位置编码是否正确实现
- 确保相机参数处理正确
- 尝试增加网络容量
渲染结果模糊:
- 增加采样点数量
- 检查位置编码的频带数量
- 确保视角方向输入正确
训练不稳定:
- 降低学习率
- 添加梯度裁剪
- 使用学习率预热
5. 结果可视化与模型评估
训练完成后,我们需要评估模型的质量并可视化结果。NeRF论文中常用的评估指标包括PSNR、SSIM和LPIPS。
渲染新视角:
def render_path(model, poses, hwf, save_path): """渲染相机路径上的所有图像""" H, W, focal = hwf images = [] for pose in poses: rays_o, rays_d = get_rays(H, W, focal, pose) rgb = render_rays(model, rays_o, rays_d, near=2., far=6., N_samples=128) images.append((rgb.cpu().numpy() * 255).astype(np.uint8)) # 保存为视频 imageio.mimsave(save_path, images, fps=30)评估指标计算:
def compute_metrics(gt_images, pred_images): """计算PSNR和SSIM""" psnrs = [] ssims = [] for gt, pred in zip(gt_images, pred_images): # 计算PSNR mse = np.mean((gt - pred) ** 2) psnr = -10. * np.log(mse) / np.log(10) psnrs.append(psnr) # 计算SSIM ssim = structural_similarity(gt, pred, multichannel=True) ssims.append(ssim) return np.mean(psnrs), np.mean(ssims)在乐高数据集上,一个训练良好的NeRF模型应该能达到以下指标:
| 指标 | 粗网络 | 精细网络 |
|---|---|---|
| PSNR | 28.5 dB | 31.2 dB |
| SSIM | 0.92 | 0.94 |
| 训练时间 | 8小时 | 15小时 |
6. 进阶技巧与扩展应用
掌握了基础NeRF实现后,你可以尝试以下进阶技巧来提升模型性能:
动态场景处理:通过添加时间维度,NeRF可以用于动态场景建模。这需要修改网络结构以接受时间输入:
class DynamicNeRF(nn.Module): def __init__(self): super().__init__() self.time_encoder = PositionalEncoding(L=4) # 其余部分与标准NeRF类似大规模场景优化:对于大场景,可以使用以下策略:
- 空间哈希编码加速训练
- 场景分块处理
- 渐进式训练策略
实时渲染技术:NeRF的传统渲染速度较慢,但可以通过以下方法加速:
- 网络蒸馏为轻量级模型
- 预计算辐射场数据
- 使用专用推理引擎
在实际项目中,我发现最耗时的部分往往是数据准备和参数调试。特别是相机参数的准确性对最终结果影响极大,建议在开始训练前仔细验证数据质量。另一个实用技巧是在训练初期使用低分辨率图像,待模型初步收敛后再切换到高分辨率,这可以显著节省训练时间。