news 2026/6/13 9:11:57

别光看理论了!手把手带你用PyGame和Keras,看AI贪吃蛇从‘智障’到‘大神’的进化实录

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
别光看理论了!手把手带你用PyGame和Keras,看AI贪吃蛇从‘智障’到‘大神’的进化实录

从零到英雄:用PyGame和Keras打造会学习的贪吃蛇AI

第一次看到AI玩贪吃蛇时,我被它从"无头苍蝇"到"策略大师"的转变过程深深吸引。这不仅仅是代码的堆砌,更像是在见证一个数字生命的成长历程。本文将带你完整经历这个奇妙旅程——不需要深厚的数学背景,只要对Python有基本了解,就能亲手培养出一个会学习的贪吃蛇AI。

我们将使用PyGame构建游戏环境,Keras搭建神经网络,通过强化学习让AI从零开始掌握游戏技巧。不同于传统教程只展示最终成果,我会重点呈现训练过程中的关键转折点:AI如何从随机移动,到学会寻找食物,再到避免自我碰撞,最终形成高效策略。每个阶段都会配有行为分析和可视化展示,让你直观感受机器学习的魅力。

1. 环境搭建与基础架构

1.1 游戏引擎搭建

我们先从游戏本体开始。PyGame的轻量级特性使其成为理想选择:

import pygame import random import numpy as np # 初始化游戏参数 SCREEN_SIZE = 600 GRID_SIZE = 20 GRID_WIDTH = SCREEN_SIZE // GRID_SIZE class SnakeGame: def __init__(self): pygame.init() self.screen = pygame.display.set_mode((SCREEN_SIZE, SCREEN_SIZE)) self.clock = pygame.time.Clock() self.reset() def reset(self): self.snake_pos = [(GRID_WIDTH//2, GRID_WIDTH//2)] self.snake_dir = random.choice([(0,1), (0,-1), (1,0), (-1,0)]) self.food_pos = self._place_food() self.score = 0 self.steps = 0 def _place_food(self): while True: pos = (random.randint(0,GRID_WIDTH-1), random.randint(0,GRID_WIDTH-1)) if pos not in self.snake_pos: return pos

这个基础框架包含了蛇的初始化、食物生成和游戏重置功能。注意到我们采用了网格系统而非像素级移动,这能大幅简化后续的状态表示。

1.2 游戏核心逻辑

接下来实现移动、碰撞检测等核心机制:

def move(self): head_x, head_y = self.snake_pos[0] dir_x, dir_y = self.snake_dir new_head = ((head_x + dir_x) % GRID_WIDTH, (head_y + dir_y) % GRID_WIDTH) if new_head in self.snake_pos[:-1]: # 撞到自己身体 return True # 游戏结束 self.snake_pos.insert(0, new_head) if new_head == self.food_pos: # 吃到食物 self.score += 1 self.food_pos = self._place_food() else: self.snake_pos.pop() # 没吃到食物时移除尾部 self.steps += 1 return False # 游戏继续

这里有几个关键设计点:

  • 网格边界采用循环处理(取模运算),避免撞墙死亡
  • 碰撞检测只检查蛇头是否碰到身体
  • 每次移动都记录步数,用于后续奖励计算

2. 强化学习模型设计

2.1 状态表示的艺术

如何将游戏状态转化为神经网络能理解的输入至关重要。我们采用12维向量表示:

状态维度描述
0-3四个方向是否有障碍(蛇身)
4-7食物相对于蛇头的方位
8-11当前移动方向

对应的实现代码:

def get_state(self): head_x, head_y = self.snake_pos[0] food_x, food_y = self.food_pos # 四个方向的相邻格子 points = [ ((head_x-1)%GRID_WIDTH, head_y), # 左 ((head_x+1)%GRID_WIDTH, head_y), # 右 (head_x, (head_y-1)%GRID_WIDTH), # 上 (head_x, (head_y+1)%GRID_WIDTH) # 下 ] state = [ # 障碍物检测 *(point in self.snake_pos for point in points), # 食物方位 food_x < head_x, # 食物在左侧 food_x > head_x, # 食物在右侧 food_y < head_y, # 食物在上方 food_y > head_y, # 食物在下方 # 当前方向 self.snake_dir == (-1,0), # 向左 self.snake_dir == (1,0), # 向右 self.snake_dir == (0,-1), # 向上 self.snake_dir == (0,1) # 向下 ] return np.array(state, dtype=np.float32)

这种表示方式既包含了局部环境信息(障碍物),也包含了全局目标信息(食物位置),还保留了运动状态。

2.2 深度Q网络实现

采用Keras构建DQN模型,包含经验回放机制和目标网络:

from tensorflow.keras.models import Sequential from tensorflow.keras.layers import Dense from collections import deque class DQNAgent: def __init__(self, state_size, action_size): self.state_size = state_size self.action_size = action_size self.memory = deque(maxlen=2000) self.gamma = 0.95 # 折扣因子 self.epsilon = 1.0 # 探索率 self.epsilon_min = 0.01 self.epsilon_decay = 0.995 self.model = self._build_model() self.target_model = self._build_model() def _build_model(self): model = Sequential() model.add(Dense(64, input_dim=self.state_size, activation='relu')) model.add(Dense(64, activation='relu')) model.add(Dense(self.action_size, activation='linear')) model.compile(loss='mse', optimizer='adam') return model def update_target_model(self): self.target_model.set_weights(self.model.get_weights())

网络结构虽然简单(仅两个隐藏层),但已经足够处理贪吃蛇的决策问题。关键组件包括:

  • 经验回放缓冲区(memory)
  • ε-贪婪策略(epsilon)
  • 双网络设计(model和target_model)

3. 训练策略与奖励设计

3.1 动态奖励机制

奖励函数是强化学习的灵魂。我们采用分阶段奖励策略:

def get_reward(self, done, prev_distance): current_distance = self._get_food_distance() if done: # 游戏结束 return -10 elif self.snake_pos[0] == self.food_pos: # 吃到食物 return 10 elif current_distance < prev_distance: # 靠近食物 return 1 elif current_distance > prev_distance: # 远离食物 return -1 else: # 保持距离 return -0.1 # 小惩罚鼓励探索

训练初期常见问题与解决方案:

  1. 原地转圈问题:AI发现转圈不会撞死自己还能获得时间奖励

    • 解决方案:增加步数惩罚,每步给予-0.1奖励
  2. 食物回避问题:AI害怕接近食物因为可能引发危险

    • 解决方案:调整奖励比例,增加吃到食物的正奖励
  3. 局部最优陷阱:AI找到一种能得分的简单策略后停止改进

    • 解决方案:阶段性重置ε值,重新鼓励探索

3.2 训练流程优化

完整的训练循环需要考虑多个因素:

def train(self, episodes=1000): batch_size = 32 agent = DQNAgent(state_size=12, action_size=4) for e in range(episodes): game = SnakeGame() state = game.get_state() total_reward = 0 while True: action = agent.act(state) # 执行动作并获取新状态 done = game.move(action) next_state = game.get_state() reward = game.get_reward(done) # 存储经验 agent.remember(state, action, reward, next_state, done) state = next_state total_reward += reward if done: print(f"Episode: {e}, Score: {game.score}, Epsilon: {agent.epsilon:.2f}") break if len(agent.memory) > batch_size: agent.replay(batch_size) if e % 10 == 0: agent.update_target_model()

关键优化点:

  • 每10轮更新一次目标网络
  • 只有当经验池足够大时才开始训练
  • 动态显示训练进度和探索率

4. 训练过程可视化与分析

4.1 典型训练阶段

阶段一:随机探索期(0-100轮)

  • 平均得分:0-2分
  • 行为特征:蛇经常直行直到撞到自己
  • 学习重点:理解移动与碰撞的关系

阶段二:基础觅食期(100-300轮)

  • 平均得分:3-5分
  • 行为特征:能主动接近食物但常陷入循环
  • 学习重点:建立食物与奖励的关联

阶段三:避障学习期(300-600轮)

  • 平均得分:6-10分
  • 行为特征:开始绕开自己身体,形成简单策略
  • 学习重点:平衡觅食与安全

阶段四:策略优化期(600+轮)

  • 平均得分:15+分
  • 行为特征:形成高效螺旋路径,能预测多步后的位置
  • 学习重点:长期规划能力

4.2 关键突破点记录

  1. 第一次吃到食物(通常在50-100轮):

    • 平均需要:约3000次尝试
    • 典型反应:之后几轮得分快速上升
  2. 避开第一个自我碰撞

    • 通常发生在200轮左右
    • 需要理解身体位置的时空关系
  3. 形成稳定策略

    • 约500轮后出现可重复的高分策略
    • 开始展现类似"螺旋前进"的智能行为

4.3 性能评估指标

指标初期值中期值后期值
平均得分0.24.518.7
最大得分31232
食物获取率2%35%78%
平均存活步数50200500+

这些指标可以通过简单的统计代码实现:

def log_performance(episode, scores, window=100): avg_score = np.mean(scores[-window:]) max_score = np.max(scores[-window:]) survival_steps = np.mean([s['steps'] for s in scores[-window:]]) print(f"Episode {episode} - Avg: {avg_score:.1f}, Max: {max_score}, Steps: {survival_steps:.0f}")

5. 高级优化技巧

5.1 网络架构改进

基础网络可以扩展为更复杂的Dueling DQN:

from tensorflow.keras.layers import Input, Dense, Lambda from tensorflow.keras.models import Model def build_dueling_dqn(input_shape, action_size): inputs = Input(shape=(input_shape,)) x = Dense(64, activation='relu')(inputs) x = Dense(64, activation='relu')(x) # 分离价值流和优势流 value_stream = Dense(1)(x) advantage_stream = Dense(action_size)(x) # 合并两个流 q_values = value_stream + (advantage_stream - tf.reduce_mean(advantage_stream, axis=1, keepdims=True)) return Model(inputs=inputs, outputs=q_values)

这种架构能更好地区分状态价值和动作优势,特别适合像贪吃蛇这种某些动作价值差异明显的场景。

5.2 课程学习策略

逐步提高训练难度能显著提升最终性能:

  1. 初期:小地图(10×10),高探索率(ε=1.0)
  2. 中期:中等地图(15×15),中等探索率(ε=0.3)
  3. 后期:标准地图(20×20),低探索率(ε=0.01)

实现方法只需简单修改游戏初始化:

def __init__(self, grid_size=20): self.GRID_WIDTH = grid_size # 其余初始化代码...

5.3 集成学习方法

训练多个AI智能体并集成它们的决策:

class EnsembleAgent: def __init__(self, num_agents=3): self.agents = [DQNAgent(12,4) for _ in range(num_agents)] def act(self, state): # 收集所有agent的Q值 q_values = [agent.model.predict(state[np.newaxis])[0] for agent in self.agents] # 取平均Q值最大的动作 return np.argmax(np.mean(q_values, axis=0))

这种方法能减少过拟合,提高决策的稳健性,特别在游戏后期复杂场景中表现优异。

6. 实战经验分享

在实际训练中,有几个关键点需要特别注意:

  1. 学习率的选择:开始可以使用较高的学习率(如0.001),当分数停滞时降低到0.0001

  2. 探索率的衰减:不要线性衰减ε,而应该在分数提升时阶段性降低

  3. 记忆缓冲区大小:太小会导致过拟合,太大会减慢学习,2000-5000是个不错的范围

  4. 批处理大小:32或64都是常用选择,更大的批次需要更多内存但更稳定

  5. 训练中断与恢复:定期保存模型权重,可以使用回调函数:

checkpoint = tf.keras.callbacks.ModelCheckpoint( 'snake_weights.h5', save_weights_only=True, save_best_only=True, monitor='score', mode='max' )

在多次实验中,我发现最有趣的不是最终的高分AI,而是观察AI如何突破自己的局限。有一次,一个AI在300轮时陷入了总是向右转的循环,但在调整奖励函数后,它突然"顿悟"般地发展出了复杂的螺旋策略,分数从平均5分直接跃升到15分。这种突破时刻正是强化学习最迷人的部分。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/6/13 9:04:55

告别手动搜索!百度网盘资源工具一键获取提取码的终极方案

告别手动搜索&#xff01;百度网盘资源工具一键获取提取码的终极方案 【免费下载链接】baidupankey 项目地址: https://gitcode.com/gh_mirrors/ba/baidupankey 还在为百度网盘分享链接的提取码而烦恼吗&#xff1f;每次遇到需要输入提取码的资源&#xff0c;都要在多个…

作者头像 李华
网站建设 2026/6/13 9:03:52

别再只盯着Datasheet了!手把手教你用DRV8313驱动三相无刷电机(附完整Arduino代码)

从零玩转DRV8313&#xff1a;三相无刷电机驱动实战指南在创客圈和嵌入式开发领域&#xff0c;无刷电机因其高效率、长寿命和低噪音特性正逐渐取代传统有刷电机。但许多开发者面对电机驱动芯片时&#xff0c;往往陷入数据手册的海洋而迟迟无法让电机转起来。本文将用最直观的方式…

作者头像 李华
网站建设 2026/6/13 8:55:58

周志华《Machine Learning》学习笔记(13)--特征选择与稀疏学习

上篇主要介绍了经典的降维方法与度量学习&#xff0c;首先从“维数灾难”导致的样本稀疏以及距离难计算两大难题出发&#xff0c;引出了降维的概念&#xff0c;即通过某种数学变换将原始高维空间转变到一个低维的子空间&#xff0c;接着分别介绍了kNN、MDS、PCA、KPCA以及两种经…

作者头像 李华
网站建设 2026/6/13 8:46:52

DARTH-PUM架构:混合内存计算的能效优化与实现

1. DARTH-PUM架构概述&#xff1a;混合内存计算的能效突破DARTH-PUM&#xff08;Digital-Analog Reconfigurable Technology for Hybrid Processing-Using-Memory&#xff09;是近年来内存计算领域最具突破性的架构之一。它通过创新的混合设计理念&#xff0c;将模拟PIM的高能效…

作者头像 李华
网站建设 2026/6/13 8:42:59

解锁学术研究新境界:3步掌握Zotero SciHub插件的文献自动下载

解锁学术研究新境界&#xff1a;3步掌握Zotero SciHub插件的文献自动下载 【免费下载链接】zotero-scihub A plugin that will automatically download PDFs of zotero items from sci-hub 项目地址: https://gitcode.com/gh_mirrors/zo/zotero-scihub 探索学术文献管理…

作者头像 李华