news 2026/4/17 21:29:53

用PyTorch从零实现DQN算法:以CartPole游戏为例(附完整代码)

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
用PyTorch从零实现DQN算法:以CartPole游戏为例(附完整代码)

用PyTorch从零实现DQN算法:以CartPole游戏为例(附完整代码)

在强化学习领域,深度Q网络(DQN)算法无疑是一座重要的里程碑。它将深度学习的强大表征能力与强化学习的决策框架完美结合,为解决复杂环境中的决策问题提供了新思路。对于已经掌握Python和PyTorch基础,想要深入实践强化学习的开发者来说,从零实现一个DQN算法并将其应用于经典控制问题CartPole,是一次绝佳的学习机会。

本文将带你一步步构建完整的DQN系统,从网络架构设计到训练策略优化,每个环节都配有详细的代码解析和实战技巧。不同于理论推导为主的教程,我们更关注工程实现中的"坑"与"解",比如如何设置合理的奖励机制、调试探索率衰减策略、优化经验回放缓冲区等实际问题。通过这个项目,你不仅能理解DQN的核心思想,更能获得可直接复用的代码模板。

1. 环境准备与问题定义

在开始编码之前,我们需要明确CartPole问题的具体定义。这是一个经典的强化学习测试环境:一根杆子通过非驱动关节连接到小车上,小车沿着无摩擦的轨道移动。系统的状态由四个连续变量描述:

  • 小车位置(-4.8到4.8)
  • 小车速度(无限制)
  • 杆子角度(约-24°到24°)
  • 杆子顶端速度(无限制)

动作空间是离散的:向左施加力(0)或向右施加力(1)。每步的奖励为+1,当杆子倾斜超过15度、小车移动超出边界(中心点2.4单位距离)或持续200步时,回合结束。

安装必要依赖

pip install gym torch numpy

关键参数初始化

import gym import torch import numpy as np env = gym.make('CartPole-v1') state_size = env.observation_space.shape[0] # 4 action_size = env.action_space.n # 2

2. DQN核心组件实现

2.1 Q网络架构设计

DQN的核心是用神经网络近似Q函数。我们设计一个三层的全连接网络,输入维度与状态空间匹配(4),输出维度与动作空间匹配(2)。隐藏层使用ReLU激活函数引入非线性。

import torch.nn as nn import torch.nn.functional as F class QNetwork(nn.Module): def __init__(self, state_size, action_size, hidden_size=24): super(QNetwork, self).__init__() self.fc1 = nn.Linear(state_size, hidden_size) self.fc2 = nn.Linear(hidden_size, hidden_size) self.fc3 = nn.Linear(hidden_size, action_size) def forward(self, x): x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) return self.fc3(x)

提示:隐藏层大小是重要的超参数。过小会导致欠拟合,过大则可能过拟合。24-64之间的值对CartPole通常效果不错。

2.2 经验回放机制

经验回放是DQN稳定训练的关键技术,它通过存储并随机采样过往经验,打破数据间的相关性。

from collections import deque import random class ReplayBuffer: def __init__(self, capacity=2000): self.buffer = deque(maxlen=capacity) def push(self, state, action, reward, next_state, done): self.buffer.append((state, action, reward, next_state, done)) def sample(self, batch_size): return random.sample(self.buffer, batch_size) def __len__(self): return len(self.buffer)

经验回放的三个优势

  1. 提高数据效率:每条经验可被多次使用
  2. 减少相关性:随机采样打破时序依赖
  3. 稳定训练:平滑学习过程

3. DQN智能体实现

3.1 智能体核心逻辑

DQN智能体需要管理探索与利用的平衡(ε-greedy策略)、目标网络更新和经验回放等关键功能。

class DQNAgent: def __init__(self, state_size, action_size): self.state_size = state_size self.action_size = action_size self.memory = ReplayBuffer() self.gamma = 0.95 # 未来奖励折扣因子 self.epsilon = 1.0 # 初始探索率 self.epsilon_min = 0.01 self.epsilon_decay = 0.995 self.learning_rate = 0.001 self.model = QNetwork(state_size, action_size) self.target_model = QNetwork(state_size, action_size) self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate) self.update_target_model() def update_target_model(self): self.target_model.load_state_dict(self.model.state_dict()) def act(self, state): if np.random.rand() <= self.epsilon: return random.randrange(self.action_size) state = torch.FloatTensor(state) with torch.no_grad(): q_values = self.model(state) return torch.argmax(q_values).item() def train(self, batch_size): if len(self.memory) < batch_size: return minibatch = self.memory.sample(batch_size) states = torch.FloatTensor([t[0] for t in minibatch]) actions = torch.LongTensor([t[1] for t in minibatch]) rewards = torch.FloatTensor([t[2] for t in minibatch]) next_states = torch.FloatTensor([t[3] for t in minibatch]) dones = torch.FloatTensor([t[4] for t in minibatch]) current_q = self.model(states).gather(1, actions.unsqueeze(1)) next_q = self.target_model(next_states).max(1)[0].detach() target = rewards + (1 - dones) * self.gamma * next_q loss = F.mse_loss(current_q.squeeze(), target) self.optimizer.zero_grad() loss.backward() self.optimizer.step() if self.epsilon > self.epsilon_min: self.epsilon *= self.epsilon_decay

3.2 训练流程优化

训练过程中有几个关键点需要特别注意:

  1. 奖励设计:CartPole默认每步+1奖励,但可以调整终止惩罚
  2. 探索策略:ε的初始值和衰减率需要调优
  3. 目标网络更新:可以定期更新或软更新
def train_agent(env, agent, episodes=1000, batch_size=32): scores = [] for e in range(episodes): state = env.reset() total_reward = 0 for t in range(500): # 最大步数 action = agent.act(state) next_state, reward, done, _ = env.step(action) # 自定义终止惩罚 reward = reward if not done else -10 agent.memory.push(state, action, reward, next_state, done) state = next_state total_reward += reward agent.train(batch_size) if done: break scores.append(total_reward) # 定期更新目标网络 if e % 10 == 0: agent.update_target_model() print(f"Episode: {e}, Score: {total_reward}, Epsilon: {agent.epsilon:.2f}") return scores

4. 高级技巧与性能优化

4.1 双重DQN(Double DQN)

原始DQN存在Q值高估问题。双重DQN通过解耦动作选择和Q值评估来缓解这个问题:

# 在DQNAgent类的train方法中修改目标Q计算 next_actions = self.model(next_states).max(1)[1].unsqueeze(1) next_q = self.target_model(next_states).gather(1, next_actions).squeeze() target = rewards + (1 - dones) * self.gamma * next_q

4.2 优先级经验回放

不是所有经验都同等重要。可以为缓冲区中的经验分配优先级,更频繁地回放"重要"经验:

class PrioritizedReplayBuffer: def __init__(self, capacity=2000, alpha=0.6): self.buffer = deque(maxlen=capacity) self.priorities = deque(maxlen=capacity) self.alpha = alpha def push(self, state, action, reward, next_state, done): max_prio = max(self.priorities) if self.priorities else 1.0 self.buffer.append((state, action, reward, next_state, done)) self.priorities.append(max_prio) def sample(self, batch_size, beta=0.4): prios = np.array(self.priorities) probs = prios ** self.alpha probs /= probs.sum() indices = np.random.choice(len(self.buffer), batch_size, p=probs) samples = [self.buffer[idx] for idx in indices] weights = (len(self.buffer) * probs[indices]) ** (-beta) weights /= weights.max() return samples, indices, np.array(weights, dtype=np.float32) def update_priorities(self, indices, priorities): for idx, prio in zip(indices, priorities): self.priorities[idx] = prio

4.3 超参数调优指南

DQN性能对超参数敏感。以下是经过实验验证的推荐范围:

超参数推荐值作用
γ (gamma)0.9-0.99未来奖励折扣因子
ε初始值1.0初始探索率
ε最小值0.01-0.1最小探索率
ε衰减率0.99-0.999探索率衰减速度
学习率1e-4到1e-3优化器步长
批量大小32-128每次训练样本数
目标网络更新频率每10-100步稳定训练

在实际项目中,我发现ε衰减策略对最终性能影响显著。一个实用的技巧是在训练初期保持较高探索率(ε=1.0),然后随着训练逐步衰减,但不要降得太低(保持在0.01左右),以保留一定的探索能力。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/17 21:29:07

从程序员到AI大模型专家:一份超全转行攻略与学习资源大放送!

本文为有志于转行至AI大模型领域的开发程序员提供了一份详细攻略&#xff0c;涵盖数学知识、编程技能、机器学习基础、深度学习、特定领域知识、实践项目、阅读论文、参加会议、跟踪行业动态、面试准备、持续学习与适应变化以及心态调整等方面。此外&#xff0c;还提供了丰富的…

作者头像 李华
网站建设 2026/4/17 21:27:56

ABAP CDS注解实战:从元数据定义到系统交互的深度解析

1. ABAP CDS注解&#xff1a;数据模型的元数据桥梁 第一次接触ABAP CDS注解时&#xff0c;我把它当成了普通的代码注释——直到发现用AbapCatalog.sqlViewName定义的视图竟然自动出现在SE11事务码里&#xff0c;才意识到这完全是另一种存在。CDS注解本质上是一套结构化元数据标…

作者头像 李华
网站建设 2026/4/17 21:23:37

log2对数二阶多项式近似计算

目录 0. 目标 1. 对数核心分解 2. 为什么只需要近似 f ∈ [1,2)&#xff1f; 3. 二阶多项式近似公式 4. Q8 定点化&#xff08;系数 369、185 的由来&#xff09; 5. 归一化 f&#xff08;代码最关键一步&#xff09; 6. d 的 Q8 表示 7. 二阶多项式计算 8. 最终结果合…

作者头像 李华
网站建设 2026/4/17 21:21:59

终极指南:3步在Windows上安装安卓应用,告别臃肿模拟器

终极指南&#xff1a;3步在Windows上安装安卓应用&#xff0c;告别臃肿模拟器 【免费下载链接】APK-Installer An Android Application Installer for Windows 项目地址: https://gitcode.com/GitHub_Trending/ap/APK-Installer APK-Installer是一款专为Windows系统设计…

作者头像 李华