1. 项目概述:当贪吃蛇遇上强化学习
最近在GitHub上看到一个挺有意思的项目,叫“linyiLYi/snake-ai”。光看名字,你大概就能猜到,这是一个用人工智能来玩贪吃蛇游戏的项目。贪吃蛇,这个几乎刻在每个人童年记忆里的经典游戏,规则简单到极致:控制一条蛇在网格里移动,吃到食物变长,撞到墙壁或自己的身体就结束。但就是这样一个简单的游戏,却成了检验和展示各种AI算法,特别是强化学习算法的绝佳“试验场”。
这个项目本质上是一个强化学习的实战演练。它不只是一个简单的脚本,而是一个完整的、结构清晰的代码库,旨在通过贪吃蛇这个直观的载体,来演示如何从零开始构建、训练并最终得到一个能自主玩游戏的AI智能体。对于刚接触强化学习的朋友来说,看那些关于Q-Learning、DQN、策略梯度的理论公式,可能远不如亲眼看着一条蛇从四处乱撞、频繁自杀,到逐渐学会规划路径、高效觅食来得震撼和直观。这个项目就提供了这样一个窗口。
它适合谁呢?首先,当然是强化学习的入门者和爱好者。如果你已经看过了David Silver的课程或者Sutton的经典教材,正苦于找不到一个合适的、轻量级的项目来练手,那么这个项目再合适不过了。其次,它也适合对游戏AI感兴趣的程序员,想了解一个简单的游戏环境是如何被抽象成标准的强化学习问题(状态、动作、奖励)。最后,哪怕你只是想找一个结构清晰、注释良好的Python项目来学习代码组织,它也能给你不少启发。
接下来,我会带你深入这个项目的内部,拆解它的核心设计思路、代码实现的关键细节,并分享在复现和实验过程中可能遇到的“坑”以及如何填平它们。我们的目标不仅仅是让代码跑起来,更是要理解每一步背后的“为什么”,从而真正掌握用强化学习解决实际问题的完整流程。
2. 核心设计思路与算法选型
2.1 问题定义:将游戏转化为强化学习框架
任何强化学习任务的第一步,都是清晰地定义智能体与环境交互的框架。对于贪吃蛇AI,我们需要明确三个核心要素:状态(State)、动作(Action)和奖励(Reward)。
状态(State):智能体“看到”的世界是什么样子?最直观的想法是把整个游戏棋盘(比如10x10的网格)的像素或格子类型(空、蛇身、食物、墙壁)直接作为输入。但这对于简单的贪吃蛇来说可能信息冗余,且维度较高。更高效的做法是进行特征工程。在这个项目中,常见的状态表示可能包括:
- 相对方向特征:食物相对于蛇头的位置(上、下、左、右)。
- 危险感知特征:蛇头前方、左方、右方一格是否是墙壁或自己的身体(即是否危险)。
- 蛇身方向:蛇当前的前进方向。 这种特征化的状态表示,大大降低了状态空间的维度,让模型更容易学习。
动作(Action):智能体能做什么?贪吃蛇的动作空间很简单,通常就是四个方向:上、下、左、右。但需要注意的是,在代码实现时,要防止蛇做出“自杀式”的180度掉头(例如从向右移动瞬间变为向左),这通常通过规则进行限制。
奖励(Reward):环境如何评价智能体的行为?设计奖励函数是强化学习中最具艺术性也最关键的一环。一个朴素的设计是:
- 吃到食物:+10分(正向激励)。
- 撞墙或撞到自己:-10分,并结束游戏(严厉惩罚)。
- 其他每一步:-0.1分或-0.01分(鼓励快速找到食物,避免无效徘徊)。 这个项目可能会采用类似的奖励结构。细微的调整,比如增加“越来越靠近食物”的小额正向奖励,或调整每步惩罚的系数,都会显著影响AI的学习效率和最终策略。
2.2 算法选择:为什么是Deep Q-Network (DQN)?
贪吃蛇的状态,即使经过特征化,其可能组合对于传统的表格型Q-Learning来说也太多了(因为蛇的长度、食物位置不断变化)。因此,使用函数逼近器(如神经网络)来估计Q值(状态-动作价值)是更可行的方案。Deep Q-Network (DQN) 正是将深度学习与Q-Learning结合的里程碑式算法,也是此类项目最经典的选择。
项目选择DQN或它的变种(如Double DQN, Dueling DQN),主要基于以下几点考量:
- 适用性:DQN非常适合处理像贪吃蛇这样具有离散动作空间(四个方向)的问题。
- 成熟度:DQN算法非常经典,有大量的开源实现、教程和调参经验可供参考,降低了项目的不确定性。
- 教育意义:通过实现DQN,可以深入理解经验回放(Experience Replay)、目标网络(Target Network)这两个稳定训练的关键技术。
- 扩展性:在基础DQN上,可以相对容易地集成其他改进,如Double DQN(解决Q值过估计)、Dueling DQN(更好地区分状态价值和动作优势),让项目有持续的迭代和优化空间。
注意:虽然DQN是主流选择,但并不意味着它是唯一的。像Policy Gradient(策略梯度)方法,例如REINFORCE或A2C/A3C,也可以用于解决这个问题。它们直接学习策略(给定状态选择动作的概率分布),在某些情况下可能更稳定。但这个项目选择DQN作为起点,无疑是一个稳健且教育意义最大的决定。
2.3 项目结构设计
一个清晰的项目结构是代码可读、可维护的基础。linyiLYi/snake-ai项目通常会包含以下模块:
game/:贪吃蛇游戏环境的核心逻辑。包括网格渲染、蛇的移动、食物生成、碰撞检测等。这部分通常会被封装成一个类(如SnakeGame),提供类似OpenAI Gym的接口:reset(),step(action),render()。agent/:AI智能体的定义。包含神经网络模型的定义(如DuelingQNetwork)、经验回放缓冲区(ReplayBuffer)以及核心的学习算法(learn方法)。models/:存放训练好的模型权重文件(.pth或.h5)。utils/:一些工具函数,如绘制训练曲线、保存/加载模型、配置参数管理等。train.py:训练脚本的主入口。包含训练循环,控制智能体与环境交互,并定期保存模型和日志。test.py/play.py:测试或演示脚本,加载训练好的模型,直观展示AI的表现。config.py或params.yaml:集中管理所有超参数(学习率、折扣因子、回放缓冲区大小等),方便实验和调参。
这种模块化的设计,使得游戏逻辑、AI算法、训练流程相互解耦,无论是想要更换游戏环境还是尝试新的强化学习算法,都可以在最小范围内修改代码,非常优雅。
3. 关键代码实现细节解析
3.1 游戏环境(Environment)封装
游戏环境是智能体交互的对象,其设计的优劣直接影响训练效率。一个标准的强化学习环境需要提供几个关键方法。
首先是reset()方法。它负责初始化游戏状态,返回初始观察(状态)。在贪吃蛇中,这包括将蛇重置为初始长度(通常为1),在随机空闲位置生成食物,并计算初始的状态特征向量。
def reset(self): # 初始化蛇:通常放在棋盘中央,长度为1 self.snake = [(self.board_size // 2, self.board_size // 2)] self.direction = random.choice([UP, DOWN, LEFT, RIGHT]) # 随机初始方向 self.food = self._generate_food() # 在非蛇身位置生成食物 self.score = 0 self.done = False self.steps = 0 return self._get_state() # 返回特征化后的状态核心是step(self, action)方法。它接收智能体选择的动作,更新游戏状态,并返回下一个状态、奖励、是否结束标志以及其他信息。
def step(self, action): # 1. 处理动作:防止180度反向 if (action == UP and self.direction != DOWN) or ...: self.direction = action # 2. 计算新的蛇头位置 head_x, head_y = self.snake[0] if self.direction == UP: new_head = (head_x, head_y - 1) # ... 其他方向 # 3. 碰撞检测 game_over = False reward = self.step_penalty # 默认每步的小惩罚,例如-0.01 # 撞墙? if not (0 <= new_head[0] < self.board_size and 0 <= new_head[1] < self.board_size): game_over = True reward = self.collision_penalty # 例如 -10 # 撞自己? elif new_head in self.snake: game_over = True reward = self.collision_penalty else: # 4. 移动蛇身 self.snake.insert(0, new_head) # 5. 判断是否吃到食物 if new_head == self.food: self.score += 1 reward = self.food_reward # 例如 +10 self.food = self._generate_food() # 生成新食物 else: # 没吃到食物,移除蛇尾,保持长度不变 self.snake.pop() self.steps += 1 # 可选:防止游戏无限循环,设置最大步数限制 if self.steps > self.max_steps: game_over = True next_state = self._get_state() if not game_over else None return next_state, reward, game_over, {'score': self.score}_get_state()函数是实现特征工程的地方。一个简单的例子是返回一个包含7个布尔值的向量:
- 危险:正前方有障碍吗?
- 危险:右前方有障碍吗?
- 危险:左前方有障碍吗?
- 食物在蛇头的上方吗?
- 食物在蛇头的下方吗?
- 食物在蛇头的左方吗?
- 食物在蛇头的右方吗? 这种表示非常紧凑,且包含了决策所需的关键信息。
3.2 DQN智能体(Agent)实现
智能体类是项目的大脑,它包含策略(如何根据状态选择动作)和学习(如何从经验中更新模型)两部分。
神经网络模型:输入是状态特征向量(例如维度为7),输出是4个Q值,分别对应4个动作。网络结构通常很简单,两到三个全连接层足以。
import torch.nn as nn import torch.nn.functional as F class QNetwork(nn.Module): def __init__(self, state_size, action_size, hidden_size=128): super(QNetwork, self).__init__() self.fc1 = nn.Linear(state_size, hidden_size) self.fc2 = nn.Linear(hidden_size, hidden_size) self.fc3 = nn.Linear(hidden_size, action_size) def forward(self, state): x = F.relu(self.fc1(state)) x = F.relu(self.fc2(x)) return self.fc3(x) # 输出Q值,不经过softmax经验回放缓冲区(Replay Buffer):这是DQN稳定训练的关键。它存储智能体与环境交互的经验元组(state, action, reward, next_state, done)。训练时,随机从缓冲区中采样一小批(mini-batch)经验,打破了数据间的时序相关性,大大提高了数据利用效率和训练的稳定性。
from collections import deque import random class ReplayBuffer: def __init__(self, capacity): self.buffer = deque(maxlen=capacity) # 固定大小的队列,满时自动丢弃旧经验 def push(self, state, action, reward, next_state, done): self.buffer.append((state, action, reward, next_state, done)) def sample(self, batch_size): return random.sample(self.buffer, batch_size) def __len__(self): return len(self.buffer)学习过程(learn方法):这是DQN算法的核心。其步骤如下:
- 采样:从经验回放缓冲区中随机采样一个批次的经验。
- 计算当前Q值:将批次中的
state输入在线网络(online network),得到对应每个动作的Q值,然后根据实际执行的action索引出对应的Q值,记为Q_current。 - 计算目标Q值:对于批次中的每个经验:
- 如果
done为True(游戏结束),则目标Q值就是reward。 - 如果
done为False,则目标Q值 =reward+gamma*max(Q_target(next_state))。这里Q_target是目标网络(target network)的输出。目标网络是在线网络的一个滞后副本,定期从在线网络同步参数,用于稳定目标值的计算,防止“追逐移动目标”导致的振荡。
- 如果
- 计算损失:使用均方误差损失(MSE Loss)计算
Q_current和目标Q值的差异。 - 反向传播:通过优化器(如Adam)更新在线网络的参数。
- 软更新目标网络:每隔一定步数(C步),将在线网络的参数以微小比例(τ,如0.001)更新到目标网络:
target_params = τ * online_params + (1 - τ) * target_params。
实操心得:目标网络的更新频率(C)和软更新系数(τ)是需要仔细调参的。更新太频繁(C太小)或硬更新(τ=1)容易导致训练不稳定;更新太慢(C太大)则学习效率低下。通常C在1000到10000步之间,τ在0.001到0.01之间开始尝试。
3.3 探索与利用的平衡:ε-贪婪策略
在训练初期,智能体对世界一无所知,需要大量探索(尝试随机动作)来收集经验。随着学习的进行,它应该更多地利用(exploit)已学到的知识,选择当前认为最优的动作。ε-贪婪策略完美地实现了这一平衡。
def select_action(self, state, epsilon): # state: 当前状态,可能是numpy数组 # epsilon: 探索概率,随着训练进行从高(如1.0)衰减到低(如0.01) if random.random() < epsilon: return random.randrange(self.action_size) # 探索:随机选动作 else: # 利用:选择Q值最大的动作 with torch.no_grad(): # 不计算梯度,节省内存 state = torch.FloatTensor(state).unsqueeze(0).to(self.device) q_values = self.online_net(state) return q_values.argmax().item()在训练脚本中,epsilon通常会随着训练步数(episode)线性或指数衰减。例如,从epsilon_start=1.0衰减到epsilon_end=0.01,在epsilon_decay步内完成。这个衰减策略对最终性能影响很大。
4. 训练流程与超参数调优实战
4.1 完整的训练循环搭建
有了环境和智能体,训练循环的骨架就清晰了。下面是一个简化版的train.py主循环逻辑:
import torch from game import SnakeGame from agent import DQNAgent from utils import plot_scores # 初始化 env = SnakeGame(board_size=10) agent = DQNAgent(state_size=env.state_size, action_size=env.action_size) scores = [] # 记录每个episode的得分 epsilons = [] # 记录epsilon变化 num_episodes = 2000 batch_size = 64 gamma = 0.99 tau = 0.001 target_update_freq = 1000 # C步更新一次目标网络(硬更新时用) epsilon_start = 1.0 epsilon_end = 0.01 epsilon_decay = 0.995 # 指数衰减因子 epsilon = epsilon_start for episode in range(1, num_episodes+1): state = env.reset() total_reward = 0 done = False while not done: # 1. 选择动作 action = agent.select_action(state, epsilon) # 2. 执行动作,与环境交互 next_state, reward, done, info = env.step(action) total_reward += reward # 3. 存储经验 agent.memory.push(state, action, reward, next_state, done) # 4. 学习 if len(agent.memory) > batch_size: agent.learn(batch_size, gamma) # 5. 转移到下一个状态 state = next_state # 6. (可选)硬更新目标网络 if agent.steps_done % target_update_freq == 0: agent.hard_update_target_net() # 7. 软更新目标网络(如果采用软更新,则每步或每episode进行) # agent.soft_update_target_net(tau) # 记录分数和衰减epsilon scores.append(info['score']) epsilons.append(epsilon) epsilon = max(epsilon_end, epsilon_decay * epsilon) # 指数衰减 # 打印进度 if episode % 100 == 0: avg_score = np.mean(scores[-100:]) print(f'Episode {episode}, Avg Score (last 100): {avg_score:.2f}, Epsilon: {epsilon:.3f}') # 可选:保存模型 if avg_score > best_avg_score: torch.save(agent.online_net.state_dict(), 'best_model.pth') # 训练结束后绘图 plot_scores(scores, epsilons)4.2 超参数调优:寻找最佳组合
超参数是算法的“旋钮”,调优是获得好性能的必经之路。以下是贪吃蛇DQN中几个最关键的超参数及其影响:
| 超参数 | 典型范围/值 | 作用与影响 | 调优建议 |
|---|---|---|---|
| 学习率 (lr) | 1e-4 到 1e-3 | 控制每次参数更新的步长。太大导致震荡不收敛,太小导致学习过慢。 | 从1e-3或5e-4开始尝试。如果训练曲线剧烈抖动,尝试调小;如果长期没有提升,可以尝试调大或检查其他参数。 |
| 折扣因子 (gamma) | 0.9 到 0.999 | 衡量未来奖励的重要性。接近1表示智能体很有远见,接近0表示目光短浅。 | 对于贪吃蛇这种需要一定规划(绕开自己身体去追食物)的游戏,建议设置在0.95-0.99之间。 |
| 经验回放缓冲区大小 | 1e4 到 1e6 | 存储过去经验的容量。太小导致数据相关性高、容易过拟合旧经验;太大导致学习缓慢,且内存占用高。 | 对于10x10的简单环境,5万到10万足够。可以观察训练效果,如果智能体很快陷入局部最优(比如只会转圈),可能是缓冲区太小或探索不足。 |
| 批次大小 (batch_size) | 32 到 256 | 每次从缓冲区采样用于训练的经验数量。影响梯度估计的稳定性和训练速度。 | 常用64或128。较小的批次(如32)可能带来正则化效果但噪声大;较大的批次(如256)训练更稳定但可能泛化能力稍差且内存需求高。 |
| 探索率衰减策略 | ε_start=1.0, ε_end=0.01, ε_decay | 控制探索与利用的平衡。衰减太快可能导致探索不足,衰减太慢导致收敛慢。 | ε_decay可以设为每episode乘以一个固定数(如0.995),也可以线性衰减。确保在训练中期(如一半episode时)ε已降到较低水平(如0.1)。 |
| 目标网络更新频率/系数 | C=1000, τ=0.001 | 控制目标网络的更新方式。硬更新(C步)或软更新(每步按τ比例混合)。 | 软更新通常更稳定。τ是一个很小的数,如0.001或0.005。如果训练不稳定(Q值或损失爆炸),尝试减小τ或降低更新频率。 |
| 网络结构 | [state, 128, 128, action] | 神经网络的层数和每层神经元数。太简单可能拟合能力不足,太复杂容易过拟合且训练慢。 | 对于简单的特征化状态,两层隐藏层,每层128或256个神经元是一个不错的起点。可以用更复杂的网络,但未必有显著提升。 |
调参是一个系统性的实验过程。强烈建议使用config.py文件管理所有超参数,并使用TensorBoard或简单的日志文件来记录每个实验配置下的训练曲线(得分、平均奖励、损失值)。通过对比不同配置下的学习曲线,才能科学地判断哪个组合更优。
4.3 训练过程监控与可视化
“黑箱”训练是痛苦的。我们需要一些工具来洞察训练过程。
1. 实时渲染游戏画面:在训练循环中,可以每隔N个episode调用一次env.render(),直观地观察AI的实时表现。但注意,渲染会极大拖慢训练速度,建议只在调试或演示时开启。
2. 绘制关键指标曲线:
- 每轮得分(Score per Episode):这是最直接的性能指标。理想情况下,它应该随着训练轮次总体呈上升趋势,并最终稳定在一个较高的水平(例如,在10x10棋盘上能稳定拿到20分以上)。
- 每轮总奖励(Total Reward per Episode):由于奖励函数中包含每步惩罚,总奖励可能为负。观察其趋势是否上升。
- 探索率ε(Epsilon):确保其按计划衰减。
- 训练损失(Training Loss):DQN的损失值通常会震荡下降。如果损失值突然变成NaN或急剧上升,通常意味着学习率过高、梯度爆炸或奖励函数设计有问题。
3. 使用TensorBoard:PyTorch和TensorFlow都提供了强大的可视化工具TensorBoard。你可以记录上面所有的标量指标,还可以记录模型参数的分布直方图、梯度等信息,对于深度调试非常有用。
# PyTorch 使用 TensorBoard 的简单示例 from torch.utils.tensorboard import SummaryWriter writer = SummaryWriter('runs/snake_experiment_1') for episode in range(num_episodes): # ... 训练循环 ... if episode % 10 == 0: writer.add_scalar('Score', episode_score, episode) writer.add_scalar('Avg Reward', total_reward, episode) writer.add_scalar('Loss', loss.item(), episode) # 假设loss在agent.learn中返回 writer.add_scalar('Epsilon', epsilon, episode)5. 常见问题、调试技巧与性能优化
5.1 训练中遇到的典型问题与解决方案
即使代码没有语法错误,训练过程也可能不如预期。以下是一些常见问题及其排查思路:
问题1:智能体完全不学习,得分始终为0或1。
- 检查奖励函数:这是最常见的原因。确认吃到食物有显著的正奖励,死亡有显著的负奖励。确保每步的小惩罚(如-0.01)不会在死亡惩罚(如-10)面前被淹没,但又能起到鼓励效率的作用。可以尝试暂时去掉每步惩罚,看智能体是否开始学习“吃食物”这个基本任务。
- 检查探索率ε:初始ε是否足够高(如1.0)?在训练早期,智能体需要大量随机探索来获得正向经验(吃到食物)。如果一开始ε就很低,它可能永远尝试不到正确的动作,无法获得正向反馈。
- 检查网络输入/输出:打印出状态向量的值,看是否合理。确认网络输出的Q值不是全零或NaN。检查前向传播和损失计算代码是否正确。
- 检查经验回放:经验是否被正确存入缓冲区?采样出的批次数据形状是否正确?
问题2:训练初期有提升,但很快陷入平台期,得分不再增长。
- 探索衰减过快:ε可能衰减得太快了,智能体过早地停止了探索,陷入了一个局部最优策略(比如只会绕着小圈转)。尝试放缓ε的衰减速度,或者使用更复杂的探索策略,如噪声探索(NoisyNet)。
- 奖励函数设计:当前的奖励函数可能无法引导智能体学习更复杂的策略(如规划路径、避免被困)。可以考虑增加一些启发式奖励,例如给予“向食物移动”的小额正向奖励,或对“长时间吃不到食物”给予惩罚。
- 网络容量不足:也许状态-动作价值函数过于复杂,当前的小网络无法很好地拟合。可以尝试增加网络的层数或宽度。
- 过拟合:智能体可能过拟合了早期探索到的某些特定局面。确保经验回放缓冲区足够大,能覆盖更多样化的状态。
问题3:训练不稳定,损失值或Q值剧烈波动甚至变成NaN。
- 学习率过高:这是导致梯度爆炸和数值不稳定的首要原因。立即降低学习率(例如从1e-3降到1e-4)。
- 梯度裁剪(Gradient Clipping):在反向传播后、优化器更新前,对梯度进行裁剪,限制其最大范数,可以有效防止梯度爆炸。
torch.nn.utils.clip_grad_norm_(agent.online_net.parameters(), max_norm=1.0) - 奖励尺度:如果奖励值非常大(比如+1000),可能会导致Q值变得非常大。考虑对奖励进行缩放,使其在一个合理的范围内(如[-1, 1]或[-10, 10])。
- 检查数学运算:确保在计算目标Q值时,对
done为True的项,没有错误地加上未来的折扣奖励。
问题4:智能体学会了吃食物,但非常“短视”,经常把自己困死。
- 调整折扣因子gamma:提高
gamma(例如从0.9提高到0.99),让智能体更重视未来的奖励,从而有动力去规划更长的存活路径,而不是只盯着眼前的一块食物。 - 在奖励函数中惩罚“危险”:除了在撞上时给予大惩罚,还可以在蛇头非常接近自己身体或墙壁时,给予一个小的负奖励,作为一种“危险预警”,让智能体学会提前规避。
5.2 性能优化与进阶技巧
当基础DQN能工作后,可以尝试以下改进来提升性能和稳定性:
1. 实现Double DQN (DDQN):标准DQN在计算目标Q值时,使用目标网络选择并评估动作,这会导致对Q值的高估。Double DQN将动作选择和动作评估解耦:用在线网络选择下一个状态的最佳动作,用目标网络评估这个动作的Q值。这通常能带来更稳定、更准确的Q值估计。只需修改目标Q值的计算部分:
# 标准DQN with torch.no_grad(): next_q_values_target = target_net(next_states) max_next_q_values = next_q_values_target.max(1)[0] # 用目标网络选和评 target_q_values = rewards + (gamma * max_next_q_values * (1 - dones)) # Double DQN with torch.no_grad(): # 用在线网络选择下一个状态的最佳动作 next_actions_online = online_net(next_states).argmax(1).unsqueeze(1) # 用目标网络评估这个动作的Q值 next_q_values_target = target_net(next_states) max_next_q_values = next_q_values_target.gather(1, next_actions_online).squeeze(1) target_q_values = rewards + (gamma * max_next_q_values * (1 - dones))2. 实现Dueling DQN:Dueling网络架构将Q值分解为状态价值V(s)和动作优势A(s, a)两部分:Q(s, a) = V(s) + A(s, a) - mean(A(s, a))。这样,网络可以独立地学习某个状态本身的好坏,以及每个动作相对于平均水平的优势。在某些游戏中,这能带来更快的收敛和更好的策略。网络结构需要相应调整。
3. 优先经验回放(Prioritized Experience Replay):不是均匀地从缓冲区采样,而是根据经验的“重要性”(通常用时序差分误差TD-error的绝对值来衡量)来采样。TD-error大的经验,即预测与实际差距大的经验,被认为更有学习价值。这可以显著提高数据效率,让智能体更快地从关键经验中学习。
4. 使用卷积神经网络处理图像状态:如果你不使用特征化状态,而是直接将游戏屏幕的RGB图像作为输入,那么就需要使用CNN来提取视觉特征。这会大大增加模型的复杂度和训练时间,但更接近通用游戏AI的设定。
5.3 模型评估与部署
训练完成后,如何评价你的AI蛇到底有多聪明?
1. 定量评估:在测试模式(epsilon=0,即完全利用)下,让AI运行足够多的局数(比如100局),计算平均得分、平均存活步数、最高得分等指标。一个优秀的AI在10x10棋盘上,平均得分应该能轻松超过20,甚至达到30以上(棋盘格子总数为100,理论最高得分99,但受蛇身阻挡限制,实际很难达到)。
2. 定性评估(可视化):这是最有成就感的部分!运行test.py或play.py脚本,关闭训练模式,加载最佳模型,然后静静地观看你的AI表演。观察它是否:
- 能高效地直奔食物而去。
- 在身体较长时,懂得规划路径,不会把自己困死在角落里。
- 在食物出现在身体包围圈内时,懂得“绕路”而不是直冲送死。
你可以将游戏过程录制成视频或GIF,这是展示项目成果的最佳方式。
3. 模型保存与加载:使用PyTorch的torch.save和torch.load来保存和加载模型的状态字典。通常只需保存online_net的参数。
# 保存 torch.save(agent.online_net.state_dict(), 'snake_dqn_final.pth') # 加载(在测试脚本中) agent = DQNAgent(state_size, action_size) agent.online_net.load_state_dict(torch.load('snake_dqn_final.pth')) agent.online_net.eval() # 切换到评估模式回过头看,“linyiLYi/snake-ai”这样一个项目,其价值远不止于让一条像素蛇自动吃东西。它是一个完整的强化学习微缩实践平台,涵盖了从问题定义、环境构建、算法实现、训练调试到评估展示的全流程。过程中遇到的每一个问题——奖励函数怎么设计、神经网络为什么不收敛、探索和利用如何平衡——都是强化学习领域的核心挑战。
我个人的体会是,成功训练出一个像样的贪吃蛇AI,其关键往往不在于使用了多么复杂的网络或玄妙的算法,而在于对基础细节的扎实把握:一个合理的奖励函数、一组稳定的超参数、一个正确的训练循环实现。很多时候,代码中的一个微小bug(比如忘记处理done标志后的next_state)就足以让整个训练失败。因此,耐心地添加日志、可视化中间结果、进行消融实验(比如对比有/无目标网络的效果),是通往成功的必经之路。
最后,当你看到那条曾经笨拙的蛇,最终能在棋盘上优雅地穿梭、觅食时,那种感觉就像教会了一个孩子走路。它验证了你对算法的理解,也给了你继续探索更复杂智能体(比如玩Atari游戏、控制机器人)的信心。这个项目是一个完美的起点,而强化学习的海洋,才刚刚展现在你面前。