
The PPO Algorithm


This series consists of excerpts made while working through Hands-on RL (《动手学强化学习》).

TRPO has been applied successfully in many settings, but its computation is complicated and every update step is expensive. PPO improves on TRPO: its optimization objective is similar, but PPO solves it with comparatively simple methods.

TRPO's optimization objective (solved directly with a Taylor-expansion approximation, conjugate gradients, line search, and so on):

$$ \max_{\theta} \; \mathbb{E}_{s \sim \nu^{\pi_{\theta_{k}}}} \mathbb{E}_{a \sim \pi_{\theta_{k}} (\cdot|s)} \left[ \frac{\pi_{\theta}(a|s)}{\pi_{\theta_{k}}(a|s)} A^{\pi_{\theta_{k}}}(s,a) \right] \\ \text{s.t.} \quad \mathbb{E}_{s \sim \nu^{\pi_{\theta_{k}}}} \left[ D_{KL}\left(\pi_{\theta_{k}}(\cdot|s), \pi_{\theta}(\cdot|s)\right) \right] \leq \delta $$

Here the temporal-difference (TD) residual is taken as the advantage function: \(A^{\pi_{\theta}}(s_t,a_t)=r(s_t,a_t)+\gamma V^{\pi_{\theta}}(s_{t+1})-V^{\pi_{\theta}}(s_t)\).

Note that both TRPO and PPO are on-policy algorithms: even though the objective involves importance sampling, it only uses data collected by the previous round's policy, not data from all past policies.
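
The compute_advantage function in the implementation at the end of this post accumulates these single-step TD residuals into a generalized advantage estimate (GAE); \(\lambda\) below corresponds to the lmbda hyperparameter in the code, and setting \(\lambda = 0\) recovers the single-step residual:

$$ A_t = \sum_{l=0}^{T-t-1} (\gamma \lambda)^{l} \delta_{t+l}, \qquad \delta_t = r(s_t,a_t) + \gamma V(s_{t+1}) - V(s_t) $$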

9.1 PPO-Penalty

PPO-Penalty uses the Lagrange-multiplier method to move the KL-divergence constraint directly into the objective, turning the problem into an unconstrained optimization, and keeps updating the coefficient in front of the KL term during training. That is:

$$ \arg\max_{\theta} \; \mathbb{E}_{s \sim \nu^{\pi_{\theta_{k}}}} \mathbb{E}_{a \sim \pi_{\theta_{k}} (\cdot|s)} \left[ \frac{\pi_{\theta}(a|s)}{\pi_{\theta_{k}}(a|s)} A^{\pi_{\theta_{k}}}(s,a) - \beta D_{KL}\left(\pi_{\theta_{k}}(\cdot|s), \pi_{\theta}(\cdot|s)\right) \right] $$

Let \(d_k = D_{KL}^{\nu^{\pi_{\theta_{k}}}} (\pi_{\theta_{k}}, \pi_{\theta})\). The coefficient \(\beta\) is updated as follows:

  1. If \(d_k < \delta / 1.5\), then \(\beta_{k+1} = \beta_{k} / 2\);
  2. If \(d_k > \delta \times 1.5\), then \(\beta_{k+1} = \beta_{k} \times 2\);
  3. Otherwise \(\beta_{k+1} = \beta_{k}\).

Here \(\delta\) is a hyperparameter set in advance that limits how far the learned policy may move from the previous round's policy; a minimal sketch of this update rule follows.
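
A minimal sketch of the adaptive rule in code (the function update_beta is only for illustration; the implementation later in this post covers only the clipped variant):

def update_beta(beta: float, d_k: float, delta: float) -> float:
    """Adaptive update of the KL-penalty coefficient in PPO-Penalty (illustrative sketch)."""
    if d_k < delta / 1.5:
        return beta / 2  # KL divergence well below the target: weaken the penalty
    if d_k > delta * 1.5:
        return beta * 2  # KL divergence well above the target: strengthen the penalty
    return beta  # KL divergence close to the target: keep the coefficient unchanged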

9.2 PPO-Clip

PPO-Clip is another variant of PPO. It places the restriction directly inside the objective function, ensuring that the new parameters do not drift too far from the old ones:

$$ \arg\max_{\theta} \; \mathbb{E}_{s \sim \nu^{\pi_{\theta_{k}}}} \mathbb{E}_{a \sim \pi_{\theta_{k}} (\cdot|s)} \left[ \min \left( \frac{\pi_{\theta}(a|s)}{\pi_{\theta_{k}}(a|s)} A^{\pi_{\theta_{k}}}(s,a), \ \text{clip} \left( \frac{\pi_{\theta}(a|s)}{\pi_{\theta_{k}}(a|s)}, 1-\epsilon, 1+\epsilon \right) A^{\pi_{\theta_{k}}}(s,a) \right) \right] $$

Here \(\text{clip}(x,l,r)=\max (\min(x,r),l)\), i.e., \(x\) is clamped to the interval \([l,r]\), and \(\epsilon\) is a hyperparameter specifying the clipping range. There are two cases (a small numeric check of both follows the schematic below):

  • If \(A^{\pi_{\theta_{k}}}(s,a)>0\), the action's value is above average; maximizing this expression increases \(\frac{\pi_{\theta}(a|s)}{\pi_{\theta_{k}}(a|s)}\), but does not let it exceed \(1+\epsilon\);
  • Conversely, if \(A^{\pi_{\theta_{k}}}(s,a)<0\), maximizing this expression decreases \(\frac{\pi_{\theta}(a|s)}{\pi_{\theta_{k}}(a|s)}\), but does not let it drop below \(1-\epsilon\).

(Figure: PPO-Clip schematic)
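
Before the full implementation, a small standalone check of the clipping behavior (eps = 0.2 and the sample ratios are arbitrary values chosen for illustration):

import torch

eps = 0.2
ratios = torch.tensor([0.5, 0.9, 1.0, 1.1, 1.5])  # pi_theta(a|s) / pi_theta_k(a|s)

for A in (1.0, -1.0):  # positive / negative advantage
    surr1 = ratios * A  # unclipped surrogate
    surr2 = torch.clamp(ratios, 1 - eps, 1 + eps) * A  # clipped surrogate
    print(f"A = {A:+.0f}:", torch.min(surr1, surr2).tolist())
# A = +1: roughly [0.5, 0.9, 1.0, 1.1, 1.2] -- capped at 1 + eps, so no gain from pushing the ratio above 1.2
# A = -1: roughly [-0.8, -0.9, -1.0, -1.1, -1.5] -- flat at -(1 - eps) once the ratio falls below 0.8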

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import torch
import torch.nn.functional as F


def compute_advantage(gamma: float, lmbda: float, td_delta: torch.Tensor):
    """Compute generalized advantage estimation (GAE): A_t = delta_t + gamma * lmbda * A_{t+1}, accumulated backwards."""
    td_delta = td_delta.detach().numpy()
    advantage_list = []
    advantage = 0.0
    for delta in td_delta[::-1]:
        advantage = gamma * lmbda * advantage + delta
        advantage_list.append(advantage)
    advantage_list.reverse()
    return torch.tensor(advantage_list, dtype=torch.float32)


class PolicyNet(torch.nn.Module):
    def __init__(self, state_dim: int, hidden_dim: int, action_dim: int):
        """Policy network: maps a state to a probability distribution over actions."""
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x: torch.Tensor):
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=1)


class ValueNet(torch.nn.Module):
    def __init__(self, state_dim: int, hidden_dim: int):
        """Value network: maps a state to a scalar state-value estimate."""
        super(ValueNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x: torch.Tensor):
        x = F.relu(self.fc1(x))
        return self.fc2(x)


class PPO:
    def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr, lmbda, epochs, eps, gamma, device):
        """PPO agent using the clipped objective (PPO-Clip)."""
        self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.critic = ValueNet(state_dim, hidden_dim).to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
        self.gamma = gamma
        self.lmbda = lmbda
        self.epochs = epochs  # number of training epochs over each batch of trajectory data
        self.eps = eps  # clipping range parameter of PPO-Clip
        self.device = device

    def take_action(self, state):
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        probs = self.actor(state)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()

    def update(self, transition_dict):
        states = torch.tensor(transition_dict["states"], dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict["actions"]).view(-1, 1).to(self.device)
        rewards = torch.tensor(transition_dict["rewards"], dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict["next_states"], dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict["dones"], dtype=torch.float).view(-1, 1).to(self.device)

        td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones)
        td_delta = td_target - self.critic(states)

        advantage = compute_advantage(self.gamma, self.lmbda, td_delta.cpu()).to(self.device)
        old_log_probs = torch.log(self.actor(states).gather(1, actions)).detach()  # log-probs under the old (fixed) policy

        for _ in range(self.epochs):
            log_probs = torch.log(self.actor(states).gather(1, actions))
            ratio = torch.exp(log_probs - old_log_probs)
            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * advantage  # clipped surrogate
            actor_loss = torch.mean(-torch.min(surr1, surr2))  # PPO-Clip objective (negated for gradient descent)
            critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))
            self.actor_optimizer.zero_grad()
            self.critic_optimizer.zero_grad()
            actor_loss.backward()
            critic_loss.backward()
            self.actor_optimizer.step()
            self.critic_optimizer.step()
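
A minimal usage sketch follows. The environment interaction is replaced by randomly generated dummy transitions just to show the expected keys and shapes of transition_dict; the dimensions (state_dim = 4, action_dim = 2) and the hyperparameter values are arbitrary, CartPole-like choices, and in real training the data would come from environment rollouts.

import numpy as np

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    agent = PPO(state_dim=4, hidden_dim=128, action_dim=2,
                actor_lr=3e-4, critic_lr=1e-2, lmbda=0.95,
                epochs=10, eps=0.2, gamma=0.98, device=device)

    T = 20  # length of one fake trajectory
    states = np.random.randn(T, 4).astype(np.float32)
    transition_dict = {
        "states": states,
        "actions": [agent.take_action(s) for s in states],
        "rewards": np.random.rand(T).astype(np.float32),
        "next_states": np.random.randn(T, 4).astype(np.float32),
        "dones": np.zeros(T, dtype=np.float32),
    }
    agent.update(transition_dict)  # one round of PPO-Clip updates on the dummy batch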