强化学习:Asynchronous Advantage Actor-Critic (A3C) 学习笔记

一、A3C 是什么?🤔

1.1 核心概念

  • A3C = Asynchronous + Advantage + Actor-Critic
    • Asynchronous:多个智能体并行训练
    • Advantage:使用优势函数降低方差
    • Actor-Critic:结合策略和价值学习
    • 核心思想:多个"平行宇宙"的智能体同时探索环境,共享学习成果

1.2 类比理解

想象多个自己同时学习不同技能:

分享经验
分享经验
分享经验
更新知识
更新知识
更新知识
中央大脑
你1号: 学习数学
你2号: 学习编程
你3号: 学习绘画

1.3 锁的工作原理

  • 获取锁:线程尝试获取锁,如果锁已被其他线程持有,则阻塞等待。
  • 释放锁:线程执行完关键操作后释放锁,允许其他线程获取。
  • 原子性:同一时间只有一个线程能持有锁,确保关键代码段的操作是原子的。

二、为什么需要 A3C?🚀

2.1 解决的核心问题

问题传统方法局限A3C 解决方案
训练速度慢单线程顺序训练并行训练 加速3-10倍
样本相关性经验回放占用内存异步探索 天然解相关
收敛困难单一轨迹更新多源经验 稳定学习

三、A3C 核心原理 🧠

3.1 系统架构

Worker N
Worker 2
Worker 1
Global Network
参数同步
参数同步
参数同步
梯度更新
梯度更新
梯度更新
Actor-Critic Local Net
环境N
Actor-Critic Local Net
环境2
Actor-Critic Local Net
环境1
Actor-Critic Global Net

3.2 工作流程

工作线程 i
同步本地模型
线程开始
重置环境
获取状态s_t
选择动作a_t
执行动作a_t
获得奖励r_t和状态s_{t+1}
达到更新条件?
计算优势函数A
计算梯度
更新全局模型
同步本地模型
主线程
创建全局模型
初始化全局网络参数
创建多个工作线程
全局模型更新
所有线程完成?
测试模型

3.3 算法步骤

1. 初始化阶段

(1) 创建全局模型

  • 共享的 Actor-Critic 网络:同时输出动作概率分布(Actor)和状态价值估计(Critic)。
  • 初始化优化器(如 Adam)和线程锁(用于参数同步)。

(2) 创建多个工作线程

  • 每个线程有独立的环境副本和本地模型(结构与全局模型相同)。
2. 训练阶段(多线程并行)

每个线程循环执行以下步骤
(1) 同步本地模型

  • 从全局模型复制最新参数到本地模型。

(2) 与环境交互

  • 使用本地模型选择动作(基于策略概率分布)。
  • 执行动作,观察奖励和下一个状态。

(3) 计算优势函数

  • 估计每个时间步的优势值 A(s,a)=Q(s,a)−V(s)A(s,a) = Q(s,a) - V(s)A(s,a)=Q(s,a)V(s),通常用折扣累积回报近似 Q(s,a)Q(s,a)Q(s,a)

(4) 计算损失

  • Actor 损失(策略梯度):最大化优势加权的动作对数概率。
  • Critic 损失:最小化价值估计与实际回报的误差。
  • 总损失:通常是 Actor 损失与 Critic 损失的加权和。

(5) 更新全局模型

  • 计算本地模型的梯度。
  • 通过线程锁保护,将梯度应用到全局模型。
3. 测试阶段
  • 环境初始化:创建一个新的CartPole-v1环境,并开启渲染模式。
  • 模型推理:将初始状态输入全局模型,得到动作概率分布,选择概率最大的动作执行。
  • 环境交互:在环境中执行选择的动作,获取下一个状态和奖励,累计奖励并渲染环境。
  • 测试结束:当达到最大步数或回合结束时,关闭环境并打印测试奖励。

四、关键算法解析 🔧

4.1 优势函数计算

A(st,at)=∑i=0k−1γirt+i+γkV(st+k)−V(st) A(s_t, a_t) = \sum_{i=0}^{k-1} \gamma^i r_{t+i} + \gamma^k V(s_{t+k}) - V(s_t) A(st,at)=i=0k1γirt+i+γkV(st+k)V(st)

  • kkk:前瞻步数(通常5-20步)
  • V(s)V(s)V(s):状态价值函数

4.2 Actor 更新(策略梯度)

∇θlog⁡π(at∣st;θ)⋅A(st,at) \nabla_\theta \log \pi(a_t|s_t; \theta) \cdot A(s_t, a_t) θlogπ(atst;θ)A(st,at)

4.3 Critic 更新(价值学习)

L=12(∑i=0k−1γirt+i+γkV(st+k)−V(st))2 L = \frac{1}{2} \left( \sum_{i=0}^{k-1} \gamma^i r_{t+i} + \gamma^k V(s_{t+k}) - V(s_t) \right)^2 L=21(i=0k1γirt+i+γkV(st+k)V(st))2

五、TensorFlow 2.x 实现 💻

5.1 全局网络

class GlobalAC(tf.keras.Model):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        # Actor 网络
        self.actor = tf.keras.Sequential([
            tf.keras.layers.Dense(256, activation='relu'),
            tf.keras.layers.Dense(256, activation='relu'),
            tf.keras.layers.Dense(action_dim * 2)  # 输出均值和标准差
        ])
        
        # Critic 网络
        self.critic = tf.keras.Sequential([
            tf.keras.layers.Dense(256, activation='relu'),
            tf.keras.layers.Dense(256, activation='relu'),
            tf.keras.layers.Dense(1)  # 输出状态价值
        ])
    
    def call(self, state):
        # 返回动作分布参数和状态价值
        mu_sigma = self.actor(state)
        mu, sigma = tf.split(mu_sigma, 2, axis=-1)
        sigma = tf.nn.softplus(sigma) + 1e-5  # 确保标准差为正
        
        v = self.critic(state)
        return mu, sigma, v

5.2 工作线程

class Worker:
    def __init__(self, global_ac, env_name, worker_id):
        self.global_ac = global_ac
        self.env = gym.make(env_name)
        self.worker_id = worker_id
        
        # 创建本地网络(全局网络的副本)
        self.local_ac = GlobalAC(global_ac.state_dim, global_ac.action_dim)
        self.local_ac.set_weights(global_ac.get_weights())
        
        self.optimizer = tf.keras.optimizers.Adam(0.0001)
    
    def train(self, n_steps=5):
        # 1. 从全局网络同步参数
        self.local_ac.set_weights(self.global_ac.get_weights())
        
        # 2. 收集n步经验
        states, actions, rewards = [], [], []
        state = self.env.reset()
        
        for _ in range(n_steps):
            # 选择动作(带探索)
            mu, sigma, _ = self.local_ac(state[np.newaxis, :])
            dist = tfp.distributions.Normal(mu, sigma)
            action = dist.sample().numpy()[0]
            
            # 执行动作
            next_state, reward, done, _ = self.env.step(action)
            
            # 存储经验
            states.append(state)
            actions.append(action)
            rewards.append(reward)
            
            if done:
                state = self.env.reset()
            else:
                state = next_state
        
        # 3. 计算优势函数
        _, _, last_v = self.local_ac(state[np.newaxis, :])
        advantages = self.compute_advantage(rewards, last_v)
        
        # 4. 计算梯度并更新全局网络
        self.update_global_net(np.array(states), np.array(actions), advantages)
    
    def compute_advantage(self, rewards, last_v):
        """计算优势函数"""
        advantages = np.zeros_like(rewards)
        R = last_v
        
        # 反向计算
        for t in reversed(range(len(rewards))):
            R = rewards[t] + 0.99 * R
            advantages[t] = R - self.values[t]
        
        return advantages
    
    def update_global_net(self, states, actions, advantages):
        with tf.GradientTape() as tape:
            # 计算策略损失
            mu, sigma, values = self.local_ac(states)
            dist = tfp.distributions.Normal(mu, sigma)
            log_probs = dist.log_prob(actions)
            actor_loss = -tf.reduce_mean(log_probs * advantages)
            
            # 计算价值损失
            critic_loss = tf.reduce_mean(tf.square(advantages))
            
            # 总损失
            total_loss = actor_loss + 0.5 * critic_loss
        
        # 计算梯度并更新全局网络
        grads = tape.gradient(total_loss, self.local_ac.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.global_ac.trainable_variables))

5.3 并行训练

def train_a3c(env_name, num_workers=4, max_steps=1000000):
    env = gym.make(env_name)
    global_ac = GlobalAC(env.observation_space.shape[0], env.action_space.shape[0])
    
    # 创建工作线程
    workers = [Worker(global_ac, env_name, i) for i in range(num_workers)]
    
    # 启动并行训练
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(worker.train) for worker in workers]
        
        step_count = 0
        while step_count < max_steps:
            # 等待任意一个工作线程完成
            done, _ = concurrent.futures.wait(futures, return_when=concurrent.futures.FIRST_COMPLETED)
            
            # 更新进度并重新提交任务
            step_count += 1
            for future in done:
                futures.remove(future)
                futures.append(executor.submit(worker.train))
    
    return global_ac

六、网络架构详解 🧱

6.1 Actor 网络结构

输入状态
全连接层256
ReLU激活
全连接层256
ReLU激活
全连接层2*动作维度
拆分均值和标准差
动作均值 μ
动作标准差 σ
正态分布采样

6.2 Critic 网络结构

输入状态
全连接层256
ReLU激活
全连接层256
ReLU激活
全连接层1
状态价值 V(s)

七、优势分析 ⚖️

7.1 核心优势

40%30%20%10%A3C 优势分布训练速度样本效率无需经验回放更好探索

7.2 性能对比

指标DQNA3C提升
Atari训练时间7-10天4-8小时20-50倍
CPU利用率10-20%70-90%4-8倍
内存占用减少50-70%

7.3 对比其他算法

  • 与 DQN 对比:直接优化策略(而非动作价值函数),适用于连续动作空间。
  • 与 A2C 对比:A3C 采用异步更新,无需经验回放缓冲区,更高效。
  • 与 PPO 对比:A3C 使用全局模型和多线程,PPO 采用单线程和批量更新。

八、实战建议 🎯

8.1 超参数设置

a3c_params = {
    'num_workers': 8,          # 工作线程数(通常设为CPU核心数)
    'n_steps': 5,              # 每个线程的前瞻步数
    'gamma': 0.99,             # 折扣因子
    'learning_rate': 0.0001,   # 学习率
    'entropy_coef': 0.01,      # 熵正则化系数
}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值