This series collects excerpts made while working through 《动手学强化学习》 (Hands-on Reinforcement Learning).
Actor-Critic learns both a value function and a policy function; it is an overarching framework that encompasses a family of algorithms. Actor-Critic methods are essentially policy-based: the goal of this family of algorithms is to optimize a parameterized policy, and the value function is learned only as an extra aid that helps the policy learn better.
REINFORCE is based on Monte Carlo sampling and can only perform an update after an episode ends, which also requires the task to have a finite number of steps. Actor-Critic algorithms, by contrast, can update after every single step and place no restriction on the episode length.
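For contrast, recall the REINFORCE gradient estimator (the standard form, not part of the excerpt above): each log-probability is weighted by the Monte Carlo return \(G_t\), which depends on every reward up to the terminal step \(T\), so the update cannot be computed until the episode ends.
$$ \nabla_{\theta} J(\theta) \propto \mathbb{E}_{\pi_{\theta}} \left[ \sum_{t=0}^{T} \nabla_{\theta} \log \pi_{\theta}(a_t \mid s_t) \, G_t \right], \qquad G_t = \sum_{k=t}^{T} \gamma^{k-t} r_k $$
Actor-Critic replaces \(G_t\) with a bootstrapped estimate built from a learned value function, which is available after every single step.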
- The Actor (policy network) interacts with the environment and, guided by the Critic's value function, learns a better policy via the policy gradient;
- The Critic (value network) uses the data collected from the Actor's interaction with the environment to learn a value function, which judges which actions are good or bad in the current state and thereby helps the Actor update its policy.

The Actor is updated following the policy-gradient principle. Denote the Critic value network by \(V_{\omega}\), with parameters \(\omega\). Using the temporal-difference (TD) residual as the learning signal, the value-function loss for a single transition is defined as:
$$ L(\omega) = \frac{1}{2} \left( r + \gamma V_{\omega}(s_{t+1}) - V_{\omega}(s_t) \right)^2 $$
In the expression above, \(r + \gamma V_{\omega}(s_{t+1})\) is taken as the TD target and treated as a constant, so no gradient flows through it when updating the value function. The gradient of the value-function loss is therefore:
$$ \nabla_{\omega} L(\omega) = -\left( r + \gamma V_{\omega}(s_{t+1}) - V_{\omega}(s_t) \right) \nabla_{\omega} V_{\omega}(s_t) $$
The Critic value network is then updated by gradient descent.
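For completeness (the excerpt only states that the Actor follows the policy-gradient principle), the standard Actor-Critic update uses the TD residual \(\delta_t = r + \gamma V_{\omega}(s_{t+1}) - V_{\omega}(s_t)\) as the advantage estimate in the policy gradient:
$$ \nabla_{\theta} J(\theta) \propto \mathbb{E}_{\pi_{\theta}} \left[ \nabla_{\theta} \log \pi_{\theta}(a_t \mid s_t) \, \delta_t \right] $$
In the code below, this corresponds to `actor_loss = torch.mean(-log_probs * td_delta.detach())`; detaching \(\delta_t\) keeps the Actor's loss from backpropagating into the Critic.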
```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import torch
import torch.nn.functional as F


class PolicyNet(torch.nn.Module):
    """Policy network (the Actor): maps a state to a probability distribution over actions."""

    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x: torch.Tensor):
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=1)


class ValueNet(torch.nn.Module):
    """Value network (the Critic): maps a state to a scalar state-value estimate V(s)."""

    def __init__(self, state_dim, hidden_dim):
        super(ValueNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x: torch.Tensor):
        x = F.relu(self.fc1(x))
        return self.fc2(x)


class ActorCritic:
    def __init__(self, state_dim, hidden_dim, action_dim, actor_lr: float, critic_lr: float, gamma: float, device):
        self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.critic = ValueNet(state_dim, hidden_dim).to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
        self.gamma = gamma
        self.device = device

    def take_action(self, state):
        """Sample an action from the current policy for a single state."""
        state = torch.tensor([state], dtype=torch.float32).to(self.device)
        probs = self.actor(state)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()

    def update(self, transition_dict):
        states = torch.tensor(transition_dict["states"], dtype=torch.float32).to(self.device)
        actions = torch.tensor(transition_dict["actions"]).view(-1, 1).to(self.device)
        rewards = torch.tensor(transition_dict["rewards"], dtype=torch.float32).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict["next_states"], dtype=torch.float32).to(self.device)
        dones = torch.tensor(transition_dict["dones"], dtype=torch.float32).view(-1, 1).to(self.device)
        # TD target: r + gamma * V(s'); the next-state value is zeroed out for terminal transitions
        td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones)
        td_delta = td_target - self.critic(states)  # TD error, used as the advantage estimate
        log_probs = torch.log(self.actor(states).gather(1, actions))
        # Policy-gradient loss; detach the TD error so the Actor's loss does not backpropagate into the Critic
        actor_loss = torch.mean(-log_probs * td_delta.detach())
        # Mean-squared-error loss against the (detached) TD target
        critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))
        self.actor_optimizer.zero_grad()
        self.critic_optimizer.zero_grad()
        actor_loss.backward()
        critic_loss.backward()
        self.actor_optimizer.step()
        self.critic_optimizer.step()
```
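As a usage sketch (not from the excerpt; it assumes the classic `CartPole-v0` task, the old gym API in which `env.reset()` returns only the observation and `env.step()` returns four values, and placeholder hyperparameter values), one episode of transitions is collected into a `transition_dict` and then passed to `update`:

```python
import gym
import torch

# Hypothetical hyperparameters, for illustration only
actor_lr, critic_lr, gamma = 1e-3, 1e-2, 0.98
hidden_dim, num_episodes = 128, 1000
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

env = gym.make("CartPole-v0")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = ActorCritic(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, gamma, device)

for episode in range(num_episodes):
    transition_dict = {"states": [], "actions": [], "rewards": [], "next_states": [], "dones": []}
    state = env.reset()  # old gym API: reset() returns only the observation
    done = False
    while not done:
        action = agent.take_action(state)
        next_state, reward, done, _ = env.step(action)  # old gym API: 4 return values
        transition_dict["states"].append(state)
        transition_dict["actions"].append(action)
        transition_dict["rewards"].append(reward)
        transition_dict["next_states"].append(next_state)
        transition_dict["dones"].append(done)
        state = next_state
    agent.update(transition_dict)  # one batched update per episode, built from per-step TD targets
```

Although `update` is called once per episode here, each transition contributes its own one-step TD target, so the gradient estimate itself does not require waiting for the full Monte Carlo return.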