跳过正文

Actor-Critic 算法

·785 字·2 分钟
RL Hands-on-Rl
Hands-on-RL - 这篇文章属于一个选集。
§ 7: 本文

本系列是学习《动手学强化学习》过程中做的摘抄。

Actor-Critic 既学习价值函数,又学习策略函数,它是囊括一系列算法的整体架构。Actor-Critic 算法本质上是基于策略的算法,因为这一系列算法的目标都是优化一个带参数的策略,只是会额外学习价值函数,从而帮助策略函数更好地学习。

REINFORCE 算法基于蒙特卡洛采样,只能在序列结束后进行更新,这同时也要求任务具有有限的步数,而 Actor-Critic 算法则可以在每一步之后都进行更新,并且不对任务的步数做限制。

  • Actor(策略网络)与环境交互,并在 Critic 价值函数的指导下用策略梯度学习一个更好的策略;
  • Critic(价值网络)要做的是通过 Actor 与环境交互收集的数据学习一个价值函数,这个价值函数会用于判断当前状态什么动作是好的、什么动作是不好的,从而帮助 Actor 进行策略更新。

Actor 和 Critic 的关系

Actor 的更新采用策略梯度的原则;将 Critic 价值网络表示为 \(V_{\omega}\),参数为 \(\omega\)。于是,可以采用时序差分残差的学习方式,对单个数据定义如下价值函数的损失函数:

$$ L(\omega) = \frac{1}{2} (r + \gamma V_{\omega} (s_{t+1}) - V_{\omega} (s_t))^2 $$

将上式中的 \(r + \gamma V_{\omega} (s_{t+1})\) 作为时序差分目标,不会产生梯度来更新价值函数。因此,价值函数的梯度为:

$$ \nabla_{\omega} L(\omega) = -(r + \gamma V_{\omega}(s_{t+1}) - V_{\omega}(s_t)) \nabla_{\omega} V_{\omega} (s_t) $$

然后使用梯度下降方法来更新 Critic 价值网络即可。

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import torch
import torch.nn.functional as F


class PolicyNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        """策略网络"""
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x: torch.Tensor):
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=1)


class ValueNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim):
        """价值网络"""
        super(ValueNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x: torch.Tensor):
        x = F.relu(self.fc1(x))
        return self.fc2(x)


class ActorCritic:
    def __init__(self, state_dim, hidden_dim, action_dim, actor_lr: float, critic_lr: float, gamma: float, device):
        self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.critic = ValueNet(state_dim, hidden_dim).to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
        self.gamma = gamma
        self.device = device

    def take_action(self, state):
        state = torch.tensor([state], dtype=torch.float32).to(self.device)
        probs = self.actor(state)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()

    def update(self, transition_dict):
        states = torch.tensor(transition_dict["states"], dtype=torch.float32).to(self.device)
        actions = torch.tensor(transition_dict["actions"]).view(-1, 1).to(self.device)
        rewards = torch.tensor(transition_dict["rewards"], dtype=torch.float32).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict["next_states"], dtype=torch.float32).to(self.device)
        dones = torch.tensor(transition_dict["dones"], dtype=torch.float32).view(-1, 1).to(self.device)

        # 时序差分目标
        td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones)
        td_delta = td_target - self.critic(states)  # 时序差分误差
        log_probs = torch.log(self.actor(states).gather(1, actions))
        actor_loss = torch.mean(-log_probs * td_delta.detach())

        # 均方误差损失函数
        critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))

        self.actor_optimizer.zero_grad()
        self.critic_optimizer.zero_grad()
        actor_loss.backward()
        critic_loss.backward()
        self.actor_optimizer.step()
        self.critic_optimizer.step()
Hands-on-RL - 这篇文章属于一个选集。
§ 7: 本文