What You'll Learn in This Lesson
- Understand the basic concepts and framework of reinforcement learning
- Master the principles of the Q-learning algorithm
- Implement a Snake game AI in Python
- Train an agent that can play the game on its own
Before You Start
Environment Requirements
- Python 3.8+
- PyTorch 2.0+
- Gymnasium (formerly OpenAI Gym)
- NumPy
- Matplotlib
- Jupyter Notebook is recommended for experimenting
Prerequisites
- Basic Python programming (Lessons 1-8)
- Basic math concepts (functions, derivatives)
- Neural network fundamentals (Lesson 23)
Core Concepts
What Is Reinforcement Learning?
Reinforcement learning is a lot like training a pet:
- Environment: the space the pet lives in (e.g. a room)
- Agent: your pet
- Action: the behaviors the pet can perform (sit, shake hands)
- Reward: a treat when it does the right thing, a gentle tap when it does the wrong thing
Key Terms
- State: the current situation of the environment (e.g. the game screen)
- Policy: the agent's rule of behavior (which action to take in which state)
- Value function: a prediction of the long-term return
- Exploration vs. exploitation: trying new actions vs. choosing the best-known action
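To see how these pieces fit together in code, here is a minimal sketch of the state-action-reward loop using Gymnasium's built-in CartPole environment (listed in the requirements above). The agent simply picks random actions, so it only illustrates the interaction loop, not learning:
import gymnasium as gym

# Agent-environment loop: observe the state, pick an action, receive a reward
env = gym.make("CartPole-v1")
state, info = env.reset(seed=0)          # state: the initial observation
total_reward = 0.0
done = False
while not done:
    action = env.action_space.sample()   # action: a purely random (exploratory) policy
    state, reward, terminated, truncated, info = env.step(action)  # reward: feedback from the environment
    total_reward += reward
    done = terminated or truncated
env.close()
print("Episode return:", total_reward)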
The Q-learning Algorithm
Q-learning is like building up an "experience handbook":
- Q-table: records the value of each action in each state
- Update rule (illustrated in the sketch after this list):
  new Q = old Q + α * (immediate reward + γ * max future Q - old Q)
- Parameters:
  - α (alpha): learning rate (e.g. 0.1)
  - γ (gamma): discount factor (e.g. 0.9)
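As a minimal illustration of this update rule, here is a sketch of tabular Q-learning with a plain dictionary as the Q-table. It assumes states are hashable (e.g. tuples); the constants simply mirror the α and γ values above:
from collections import defaultdict
import random

ALPHA, GAMMA, EPSILON = 0.1, 0.9, 0.1              # learning rate, discount factor, exploration rate
N_ACTIONS = 3
q_table = defaultdict(lambda: [0.0] * N_ACTIONS)   # Q-table: state -> value of each action

def choose_action(state):
    # Epsilon-greedy: explore with probability EPSILON, otherwise exploit the best-known action
    if random.random() < EPSILON:
        return random.randrange(N_ACTIONS)
    return max(range(N_ACTIONS), key=lambda a: q_table[state][a])

def q_update(state, action, reward, next_state, done):
    # new Q = old Q + alpha * (immediate reward + gamma * max future Q - old Q)
    best_future = 0.0 if done else max(q_table[next_state])
    q_table[state][action] += ALPHA * (reward + GAMMA * best_future - q_table[state][action])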
Hands-on Code
1. Building the Game Environment
We use a simplified Snake game as the training environment.
import numpy as np
import pygame
import random
from enum import Enum
class Direction(Enum):
RIGHT = 0
LEFT = 1
UP = 2
DOWN = 3
class SnakeGame:
def __init__(self, width=10, height=10):
self.width = width
self.height = height
self.reset()
def reset(self):
        # Initialize the snake: 3 cells long, moving right
self.direction = Direction.RIGHT
self.head = [self.width//2, self.height//2]
self.snake = [
self.head,
[self.head[0]-1, self.head[1]],
[self.head[0]-2, self.head[1]]
]
self.score = 0
self.food = self._place_food()
self.frame_iteration = 0
return self._get_state()
def _place_food(self):
while True:
food = [random.randint(0, self.width-1),
random.randint(0, self.height-1)]
if food not in self.snake:
return food
def _get_state(self):
        # State representation: 3 danger flags + 4 direction flags + 4 food-position flags = 11 values
head = self.head
point_l = [head[0] - 1, head[1]]
point_r = [head[0] + 1, head[1]]
point_u = [head[0], head[1] - 1]
point_d = [head[0], head[1] + 1]
dir_l = self.direction == Direction.LEFT
dir_r = self.direction == Direction.RIGHT
dir_u = self.direction == Direction.UP
dir_d = self.direction == Direction.DOWN
        state = [
            # Danger straight ahead
            (dir_r and self._is_collision(point_r)) or
            (dir_l and self._is_collision(point_l)) or
            (dir_u and self._is_collision(point_u)) or
            (dir_d and self._is_collision(point_d)),
            # Danger to the right
            (dir_u and self._is_collision(point_r)) or
            (dir_d and self._is_collision(point_l)) or
            (dir_l and self._is_collision(point_u)) or
            (dir_r and self._is_collision(point_d)),
            # Danger to the left
            (dir_d and self._is_collision(point_r)) or
            (dir_u and self._is_collision(point_l)) or
            (dir_r and self._is_collision(point_u)) or
            (dir_l and self._is_collision(point_d)),
            # Current movement direction
            dir_l, dir_r, dir_u, dir_d,
            # Food location relative to the head
            self.food[0] < self.head[0],  # food is to the left
            self.food[0] > self.head[0],  # food is to the right
            self.food[1] < self.head[1],  # food is above
            self.food[1] > self.head[1]   # food is below
        ]
return np.array(state, dtype=int)
def _is_collision(self, point=None):
if point is None:
point = self.head
        # Hit a wall
if (point[0] < 0 or point[0] >= self.width or
point[1] < 0 or point[1] >= self.height):
return True
        # Hit its own body
if point in self.snake[:-1]:
return True
return False
def step(self, action):
self.frame_iteration += 1
        # 1. Move the snake
self._move(action)
        # 2. Check whether the game is over
reward = 0
game_over = False
if self._is_collision() or self.frame_iteration > 100*len(self.snake):
game_over = True
reward = -10
return self._get_state(), reward, game_over, self.score
        # 3. Place new food, or just advance the snake
if self.head == self.food:
self.score += 1
reward = 10
self.food = self._place_food()
else:
self.snake.pop()
        # 4. Return the new state and reward
return self._get_state(), reward, game_over, self.score
def _move(self, action):
        # action: [go straight, turn right, turn left]
clock_wise = [Direction.RIGHT, Direction.DOWN, Direction.LEFT, Direction.UP]
idx = clock_wise.index(self.direction)
        if np.array_equal(action, [1, 0, 0]):  # go straight
new_dir = clock_wise[idx]
        elif np.array_equal(action, [0, 1, 0]):  # turn right
next_idx = (idx + 1) % 4
new_dir = clock_wise[next_idx]
        else:  # turn left
next_idx = (idx - 1) % 4
new_dir = clock_wise[next_idx]
self.direction = new_dir
        # Update the head position
x, y = self.head
if self.direction == Direction.RIGHT:
x += 1
elif self.direction == Direction.LEFT:
x -= 1
elif self.direction == Direction.DOWN:
y += 1
elif self.direction == Direction.UP:
y -= 1
self.head = [x, y]
self.snake.insert(0, self.head)
# ⚠️ Common mistake 1: inconsistent action-space definitions
# Make sure that:
# 1. The action encoding matches the game logic
# 2. The action dimension matches the model's output size (see the sanity check below)
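One quick way to catch this kind of mismatch early is a small sanity check before training. The sketch below only uses the SnakeGame class defined above and the three-element action encoding from _move:
# Sanity check: state and action shapes must be consistent with the model built later
game = SnakeGame()
state = game._get_state()
assert len(state) == 11, f"expected an 11-dimensional state, got {len(state)}"

action = [1, 0, 0]   # [go straight, turn right, turn left]
next_state, reward, done, score = game.step(action)
assert len(action) == 3   # must equal the Q-network's output size
print("state shape:", state.shape, "reward:", reward, "done:", done)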
2. Implementing Q-learning
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from collections import deque
import random
class QTrainer:
def __init__(self, model, lr, gamma):
self.model = model
self.lr = lr
self.gamma = gamma
self.optimizer = optim.Adam(model.parameters(), lr=self.lr)
self.criterion = nn.MSELoss()
def train_step(self, state, action, reward, next_state, done):
        # Wrapping in np.array first avoids PyTorch's slow "list of ndarrays" conversion path
        state = torch.tensor(np.array(state), dtype=torch.float)
        next_state = torch.tensor(np.array(next_state), dtype=torch.float)
        action = torch.tensor(np.array(action), dtype=torch.long)
        reward = torch.tensor(np.array(reward), dtype=torch.float)
if len(state.shape) == 1:
            # A single sample: add a batch dimension, shape (1, x)
state = torch.unsqueeze(state, 0)
next_state = torch.unsqueeze(next_state, 0)
action = torch.unsqueeze(action, 0)
reward = torch.unsqueeze(reward, 0)
done = (done, )
        # 1. Predicted Q-values for the current state
        pred = self.model(state)
        # 2. Compute the target Q-values
        target = pred.clone()
for idx in range(len(done)):
Q_new = reward[idx]
if not done[idx]:
Q_new = reward[idx] + self.gamma * torch.max(self.model(next_state[idx]))
target[idx][torch.argmax(action[idx]).item()] = Q_new
        # 3. Compute the loss and update the weights
self.optimizer.zero_grad()
loss = self.criterion(target, pred)
loss.backward()
self.optimizer.step()
class Linear_QNet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = self.linear2(x)
        return x

    def save(self, file_name="model.pth"):
        # Save only the weights; reload them later with load_state_dict()
        torch.save(self.state_dict(), file_name)
class Agent:
    def __init__(self):
        self.n_games = 0
        self.epsilon = 0   # exploration rate (recomputed each step in get_action)
        self.gamma = 0.9   # discount factor
        self.memory = deque(maxlen=100_000)   # experience replay buffer
        self.model = Linear_QNet(11, 256, 3)  # 11-dimensional state in, 3 actions out
        self.trainer = QTrainer(self.model, lr=0.001, gamma=self.gamma)
def get_state(self, game):
return game._get_state()
def remember(self, state, action, reward, next_state, done):
self.memory.append((state, action, reward, next_state, done))
def train_long_memory(self):
if len(self.memory) > 1000:
            mini_sample = random.sample(self.memory, 1000)  # sample a mini-batch
else:
mini_sample = self.memory
states, actions, rewards, next_states, dones = zip(*mini_sample)
self.trainer.train_step(states, actions, rewards, next_states, dones)
def train_short_memory(self, state, action, reward, next_state, done):
self.trainer.train_step(state, action, reward, next_state, done)
def get_action(self, state):
        # Exploration-exploitation trade-off
        self.epsilon = 80 - self.n_games  # explore less as more games have been played
final_move = [0, 0, 0]
if random.randint(0, 200) < self.epsilon:
move = random.randint(0, 2)
final_move[move] = 1
else:
state0 = torch.tensor(state, dtype=torch.float)
prediction = self.model(state0)
move = torch.argmax(prediction).item()
final_move[move] = 1
return final_move
# ⚠️ Common mistake 2: exploding Q-values
# Possible remedies:
# 1. Lower the learning rate
# 2. Lower the discount factor gamma
# 3. Apply gradient clipping (see the sketch below)
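If you try gradient clipping, the change is a single torch.nn.utils.clip_grad_norm_ call placed between loss.backward() and optimizer.step() in QTrainer.train_step. The standalone sketch below demonstrates it on a tiny throwaway network; max_norm=1.0 is only a starting point:
import torch
import torch.nn as nn

# Minimal, self-contained demonstration of gradient clipping
model = nn.Linear(11, 3)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()

pred = model(torch.randn(4, 11))   # a fake batch of 4 states
target = torch.randn(4, 3)         # fake target Q-values

optimizer.zero_grad()
loss = criterion(pred, target)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # cap the gradient norm
optimizer.step()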
3. Training Loop
import matplotlib.pyplot as plt
from IPython import display
def plot(scores, mean_scores):
display.clear_output(wait=True)
display.display(plt.gcf())
plt.clf()
plt.title('Training...')
plt.xlabel('Number of Games')
plt.ylabel('Score')
plt.plot(scores)
plt.plot(mean_scores)
plt.ylim(ymin=0)
plt.text(len(scores)-1, scores[-1], str(scores[-1]))
plt.text(len(mean_scores)-1, mean_scores[-1], str(mean_scores[-1]))
plt.show()
def train():
record = 0
agent = Agent()
game = SnakeGame()
scores = []
mean_scores = []
total_score = 0
while True:
        # Get the current state
        state_old = agent.get_state(game)
        # Choose an action
        final_move = agent.get_action(state_old)
        # Execute the action; step() returns (new state, reward, done, score)
        state_new, reward, done, score = game.step(final_move)
        # Train on this single transition (short-term memory)
        agent.train_short_memory(state_old, final_move, reward, state_new, done)
        # Store the transition for experience replay
        agent.remember(state_old, final_move, reward, state_new, done)
if done:
            # Train long-term memory (experience replay)
game.reset()
agent.n_games += 1
agent.train_long_memory()
if score > record:
record = score
agent.model.save()
scores.append(score)
total_score += score
mean_score = total_score / agent.n_games
mean_scores.append(mean_score)
plot(scores, mean_scores)
print(f"Game: {agent.n_games}, Score: {score}, Record: {record}, Mean Score: {mean_score}")
            # Stopping condition
if agent.n_games > 500:
break
if __name__ == "__main__":
train()
4. Visualizing the Game
import pygame
import time
def play_with_ai(model_path="model.pth"):
pygame.init()
game = SnakeGame(width=20, height=20)
    agent = Agent()
    agent.model.load_state_dict(torch.load(model_path))
    agent.model.eval()
    agent.n_games = 100  # get_action uses epsilon = 80 - n_games, so this disables random moves
clock = pygame.time.Clock()
screen = pygame.display.set_mode((800, 600))
pygame.display.set_caption('Snake AI')
while True:
state = agent.get_state(game)
final_move = agent.get_action(state)
        _, _, done, score = game.step(final_move)
        # Keep the window responsive and allow quitting
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                return
        # Draw the game
screen.fill((0, 0, 0))
for point in game.snake:
pygame.draw.rect(screen, (0, 255, 0),
pygame.Rect(point[0]*30, point[1]*30, 30, 30))
pygame.draw.rect(screen, (255, 0, 0),
pygame.Rect(game.food[0]*30, game.food[1]*30, 30, 30))
pygame.display.flip()
clock.tick(10)
if done:
print(f"Game Over! Score: {score}")
time.sleep(3)
game.reset()
Complete Project
Project structure:
lesson_33_reinforcement_learning/
├── game/
│   ├── snake.py          # game environment
│   └── display.py        # visualization helpers
├── model/
│   ├── qnet.py           # Q-network
│   └── trainer.py        # training logic
├── agent.py              # agent implementation
├── train.py              # training script
├── play.py               # gameplay demo
├── requirements.txt      # dependency list
└── README.md             # project description
requirements.txt
torch==2.0.1
numpy==1.24.3
matplotlib==3.7.1
pygame==2.3.0
gymnasium==0.28.1
ipython==8.12.0
train.py main program
from game.snake import SnakeGame
from model.qnet import Linear_QNet
from model.trainer import QTrainer
from agent import Agent
import torch
def main():
    game = SnakeGame()
    agent = Agent()
    # Training hyperparameters
    num_episodes = 1000
    batch_size = 1000
    for episode in range(num_episodes):
        # Training loop (same as above)
        # ...
        # Save the model periodically
        if episode % 100 == 0:
            torch.save(agent.model.state_dict(), f"model_{episode}.pth")
if __name__ == "__main__":
main()
What to Expect
Training output
Game: 1, Score: 2, Record: 2, Mean Score: 2.0
Game: 2, Score: 3, Record: 3, Mean Score: 2.5
...
Game: 500, Score: 25, Record: 32, Mean Score: 18.7
FAQ
Q1: The agent keeps circling and never goes for the food
Solutions:
- Adjust the reward function (e.g. increase the reward for reaching food), as sketched below
- Increase the exploration rate epsilon
- Check that the state representation includes the food's position
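One common form of reward shaping is a small bonus for moving closer to the food. The helper below is purely illustrative: you would call it inside SnakeGame.step after the head has moved and add its result to the existing +10/-10 rewards:
def distance_reward(old_head, new_head, food, bonus=0.1):
    # Manhattan distance to the food before and after the move
    old_dist = abs(old_head[0] - food[0]) + abs(old_head[1] - food[1])
    new_dist = abs(new_head[0] - food[0]) + abs(new_head[1] - food[1])
    # Small positive reward for getting closer, small penalty for moving away
    return bonus if new_dist < old_dist else -bonus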
Q2: Training doesn't converge and scores fluctuate wildly
Possible causes:
- Learning rate too high (try lowering it to 0.0001)
- Unsuitable batch size (try increasing it to 512 or 1024)
- Insufficient network capacity (increase the hidden layer size); one possible starting point is sketched below
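If you want to experiment along these lines, a hedged starting point is to rebuild the agent with a smaller learning rate and a wider hidden layer. The values below are guesses to tune, not recommendations, and the replay batch size of 1000 is hardcoded in train_long_memory, so change it there directly:
agent = Agent()
agent.model = Linear_QNet(11, 512, 3)                              # wider hidden layer
agent.trainer = QTrainer(agent.model, lr=1e-4, gamma=agent.gamma)  # smaller learning rate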
Q3: How do I apply this to other games?
Steps:
- Implement the new game environment (it needs to expose state, action, and reward interfaces); a minimal skeleton is sketched after this list
- Adjust the state representation and the action space
- You may need a more complex neural network
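A minimal environment skeleton exposing the same interface as SnakeGame (reset, plus a step that returns state, reward, done, score) might look like this. The class name and its contents are purely illustrative placeholders:
import numpy as np

class MyGameEnv:
    # Illustrative skeleton: any game exposing this interface can reuse the Agent above
    def __init__(self):
        self.score = 0
        self.reset()

    def reset(self):
        self.score = 0
        # ...initialize your game's internal state here...
        return self._get_state()

    def _get_state(self):
        # Encode whatever the agent needs to observe as a fixed-length numeric vector
        return np.zeros(11, dtype=int)

    def step(self, action):
        # Apply the action, compute the reward, and detect game over
        reward, done = 0, False
        return self._get_state(), reward, done, self.score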
Exercises
- Reward function experiment: modify the reward function (e.g. add a small survival reward) and observe how it affects training
- Deep Q-Network (DQN): replace the linear network with a CNN that works directly on the game's pixel frames
- Two-snake battle: change the game rules to let two snakes compete, and train competing agents
- Transfer learning: transfer the policy learned on Snake to a similar game (e.g. Pac-Man)
Further Reading
Coming up next: Lesson 34 dives into Deep Q-Networks (DQN) and builds an AI for Atari games!