从CartPole到Pendulum:用Python代码‘玩转’Gym两大经典控制环境
强化学习作为机器学习的重要分支,其魅力在于让智能体通过与环境交互来自主学习最优策略。而OpenAI Gym作为最流行的强化学习实验平台,提供了丰富的训练环境。本文将带您深入探索Gym中两个经典控制问题——CartPole-v1和Pendulum-v1,通过完整的Python代码示例,让您亲身体验强化学习环境的运行机制。
1. 环境搭建与基础API
在开始之前,确保已安装最新版gym库:
pip install gym==0.26.2Gym环境的核心交互流程遵循"重置-执行-观察"循环:
import gym # 创建环境 env = gym.make('CartPole-v1') # 初始化环境 observation = env.reset() for _ in range(1000): # 渲染当前状态 env.render() # 随机选择动作 action = env.action_space.sample() # 执行动作并获取反馈 observation, reward, done, info = env.step(action) # 检查是否终止 if done: observation = env.reset() # 关闭环境 env.close()每个环境都提供以下关键属性:
| 属性 | 描述 | 示例值 |
|---|---|---|
observation_space | 状态空间定义 | Box(4,) |
action_space | 动作空间定义 | Discrete(2) |
reward_range | 奖励值范围 | (0,1) |
2. CartPole-v1:平衡杆问题解析
CartPole-v1是一个经典的物理控制问题,目标是让小车上的杆子保持直立不倒下。让我们深入分析其核心要素:
2.1 状态空间详解
CartPole的状态由4个连续值组成:
- 小车位置:水平坐标,范围[-4.8,4.8]
- 小车速度:移动速率,无固定界限
- 杆子角度:偏离垂直方向的角度,范围[-0.418,0.418]弧度
- 杆子角速度:角度变化率,无固定界限
通过以下代码可以查看状态空间的具体参数:
env = gym.make('CartPole-v1') print("状态空间:", env.observation_space) print("动作空间:", env.action_space)2.2 动作与奖励机制
CartPole采用离散动作空间:
- 0:向左施加力
- 1:向右施加力
奖励设计非常简单:每存活一步获得+1奖励。终止条件包括:
- 杆子倾斜超过12度
- 小车移动超出边界
- 持续500步(v1版本)
下面是一个完整的交互示例,包含状态可视化:
import numpy as np import matplotlib.pyplot as plt def run_cartpole(episodes=10): env = gym.make('CartPole-v1') rewards = [] for ep in range(episodes): obs = env.reset() ep_reward = 0 states = [] while True: states.append(obs) action = env.action_space.sample() obs, reward, done, _ = env.step(action) ep_reward += reward if done: rewards.append(ep_reward) visualize_states(np.array(states), ep) break env.close() return rewards def visualize_states(states, ep_num): plt.figure(figsize=(10,6)) labels = ['Position', 'Velocity', 'Angle', 'Angular Velocity'] for i in range(4): plt.plot(states[:,i], label=labels[i]) plt.title(f"Episode {ep_num} - State Trajectory") plt.xlabel("Time Step") plt.ylabel("Value") plt.legend() plt.grid() plt.show()3. Pendulum-v1:钟摆控制挑战
Pendulum-v1比CartPole更具挑战性,需要控制钟摆使其保持直立状态。与CartPole不同,Pendulum采用连续动作空间。
3.1 状态表示解析
Pendulum的状态空间包含3个值:
cos(theta):角度θ的余弦值sin(theta):角度θ的正弦值theta_dot:角速度
这种表示方式避免了角度周期性带来的不连续问题。θ=0表示钟摆垂直向下,范围是[-π,π]弧度。
3.2 连续动作空间
Pendulum的动作是施加在钟摆上的扭矩,范围[-2,2]:
env = gym.make('Pendulum-v1') print("动作空间:", env.action_space) # Box(-2.0, 2.0, (1,), float32)奖励函数设计更为复杂:
reward = -(θ² + 0.1*θ_dot² + 0.001*action²)这意味着:
- 角度θ越接近0(垂直向下),奖励越高
- 角速度和动作幅度越小越好
下面是一个Pendulum的完整交互示例:
def run_pendulum(episodes=5): env = gym.make('Pendulum-v1') for ep in range(episodes): state = env.reset() total_reward = 0 frames = [] for t in range(200): # 最大200步 # 随机动作 action = env.action_space.sample() # 执行动作 next_state, reward, done, _ = env.step(action) total_reward += reward # 渲染并保存帧 frames.append(env.render(mode='rgb_array')) if done: break print(f"Episode {ep+1}, Total Reward: {total_reward:.2f}") display_animation(frames) env.close() def display_animation(frames): from IPython.display import HTML import matplotlib.animation as animation fig = plt.figure(figsize=(6,6)) plt.axis('off') img = plt.imshow(frames[0]) def animate(i): img.set_array(frames[i]) return img, anim = animation.FuncAnimation(fig, animate, frames=len(frames), interval=50) plt.close() return HTML(anim.to_html5_video())4. 高级技巧与性能优化
在实际应用中,我们通常需要超越随机动作的策略。以下是几个实用技巧:
4.1 状态预处理
对于Pendulum,将状态转换为更直观的角度表示:
def pendulum_state_to_angle(state): cos_theta, sin_theta, theta_dot = state theta = np.arctan2(sin_theta, cos_theta) return np.array([theta, theta_dot])4.2 简单控制策略
基于规则的策略可以作为baseline:
def simple_pendulum_policy(state): theta, theta_dot = pendulum_state_to_angle(state) # PD控制器 kp = 15.0 # 比例系数 kd = 1.0 # 微分系数 action = kp * theta + kd * theta_dot # 限制在[-2,2]范围内 return np.clip(action, -2, 2).reshape(1,)4.3 性能评估指标
建立评估函数比较不同策略:
def evaluate_policy(env, policy, n_episodes=10): total_rewards = [] for _ in range(n_episodes): state = env.reset() episode_reward = 0 while True: action = policy(state) state, reward, done, _ = env.step(action) episode_reward += reward if done: total_rewards.append(episode_reward) break return { 'mean_reward': np.mean(total_rewards), 'std_reward': np.std(total_rewards), 'max_reward': np.max(total_rewards), 'min_reward': np.min(total_rewards) }5. 从模拟到实战:训练建议
当您准备开始训练真正的强化学习算法时,考虑以下实践要点:
- 帧堆叠:连续几帧作为输入,捕捉动态信息
- 奖励塑形:设计中间奖励加速学习
- 超参数调优:学习率、折扣因子等对性能影响显著
# 示例:帧堆叠实现 class FrameStack: def __init__(self, env, k=4): self.env = env self.k = k self.frames = deque([], maxlen=k) def reset(self): obs = self.env.reset() for _ in range(self.k): self.frames.append(obs) return self._get_obs() def step(self, action): obs, reward, done, info = self.env.step(action) self.frames.append(obs) return self._get_obs(), reward, done, info def _get_obs(self): return np.concatenate(self.frames, axis=0)