This series collects excerpts made while studying Hands-on Reinforcement Learning (《动手学强化学习》).
The policy gradient algorithm and the Actor-Critic algorithm are both policy-based methods. Although they are simple and intuitive, they often suffer from unstable training in practice. For example, the policy gradient algorithm iteratively updates the policy parameters \(\theta\) along the direction \(\nabla_{\theta} J(\theta)\); when the policy network is a deep model, a single step along this gradient may be too large, causing the policy to suddenly become much worse and degrading training.
To address this problem, the main idea of the Trust Region Policy Optimization (TRPO) algorithm is to find a trust region at each update; as long as the policy is updated within this region, a certain safety guarantee on the policy's performance can be obtained.
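Concretely, TRPO maximizes a surrogate objective while constraining the average KL divergence between the old policy \(\pi_{\theta_{\mathrm{old}}}\) and the new policy \(\pi_{\theta}\). In the standard formulation (consistent with the surrogate objective and kl_constraint used in the code below, where \(\delta\) plays the role of kl_constraint):

\[
\max_{\theta}\ \mathbb{E}_{s,a\sim\pi_{\theta_{\mathrm{old}}}}\left[\frac{\pi_{\theta}(a\mid s)}{\pi_{\theta_{\mathrm{old}}}(a\mid s)}\,A^{\pi_{\theta_{\mathrm{old}}}}(s,a)\right]
\quad\text{s.t.}\quad
\mathbb{E}_{s\sim\pi_{\theta_{\mathrm{old}}}}\left[D_{\mathrm{KL}}\!\left(\pi_{\theta_{\mathrm{old}}}(\cdot\mid s)\,\Vert\,\pi_{\theta}(\cdot\mid s)\right)\right]\le\delta
\]

The implementation below solves an approximation of this constrained problem with a natural-gradient step: the objective gradient \(g\) and the KL Hessian \(H\) give the update direction \(H^{-1}g\) (computed by conjugate gradient), which is scaled to the boundary of the trust region and then adjusted by a backtracking line search.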
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def compute_advantage(gamma: float, lmbda: float, td_delta: torch.Tensor):
    """Compute the generalized advantage estimate (GAE)."""
    td_delta = td_delta.detach().numpy()
    advantage_list = []
    advantage = 0.0
    # Traverse the TD errors backwards in time: A_t = delta_t + gamma * lambda * A_{t+1}
    for delta in td_delta[::-1]:
        advantage = gamma * lmbda * advantage + delta
        advantage_list.append(advantage)
    advantage_list.reverse()
    return torch.tensor(np.array(advantage_list), dtype=torch.float32)
class PolicyNet(nn.Module):
    """Policy network: maps a state to a probability distribution over actions."""

    def __init__(self, state_dim: int, hidden_dim: int, action_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, action_dim)

    def forward(self, x: torch.Tensor):
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=1)
class ValueNet(nn.Module):
    """State-value network: maps a state to a scalar value estimate."""

    def __init__(self, state_dim: int, hidden_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, x: torch.Tensor):
        x = F.relu(self.fc1(x))
        return self.fc2(x)
class TRPO:
    """Trust Region Policy Optimization."""

    def __init__(
        self,
        hidden_dim: int,
        state_space,
        action_space,
        lmbda: float,
        kl_constraint: float,
        alpha: float,
        critic_lr: float,
        gamma: float,
        device: torch.device,
    ):
        state_dim = state_space.shape[0]
        action_dim = action_space.n
        # The policy network is updated by the trust-region step below,
        # so it does not need its own optimizer.
        self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.critic = ValueNet(state_dim, hidden_dim).to(device)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
        self.gamma = gamma
        self.lmbda = lmbda  # GAE parameter
        self.kl_constraint = kl_constraint  # maximum allowed KL divergence
        self.alpha = alpha  # line-search shrinkage factor
        self.device = device

    def take_action(self, state):
        state = torch.tensor([state], dtype=torch.float32).to(self.device)
        probs = self.actor(state)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()
    def hessian_matrix_vector_product(self, states, old_action_dists, vector):
        """Compute the product of the KL-divergence Hessian with a vector."""
        new_action_dists = torch.distributions.Categorical(self.actor(states))
        # Mean KL divergence between the old policy and the current policy
        kl = torch.mean(torch.distributions.kl.kl_divergence(old_action_dists, new_action_dists))
        kl_grad = torch.autograd.grad(kl, self.actor.parameters(), create_graph=True)
        kl_grad_vector = torch.cat([grad.view(-1) for grad in kl_grad])
        # Dot the KL gradient with the vector, then differentiate again:
        # grad(grad(KL) . v) = H v
        kl_grad_vector_product = torch.dot(kl_grad_vector, vector)
        grad2 = torch.autograd.grad(kl_grad_vector_product, self.actor.parameters())
        grad2_vector = torch.cat([grad.view(-1) for grad in grad2])
        return grad2_vector
    def conjugate_gradient(self, grad, states, old_action_dists):
        """Solve Hx = g with the conjugate gradient method."""
        x = torch.zeros_like(grad)
        r = grad.clone()
        p = grad.clone()
        rdotr = torch.dot(r, r)
        for i in range(10):  # main conjugate-gradient loop
            Hp = self.hessian_matrix_vector_product(states, old_action_dists, p)
            alpha = rdotr / torch.dot(p, Hp)
            x += alpha * p
            r -= alpha * Hp
            new_rdotr = torch.dot(r, r)
            if new_rdotr < 1e-10:
                break
            beta = new_rdotr / rdotr
            p = r + beta * p
            rdotr = new_rdotr
        return x
    def compute_surrogate_obj(self, states, actions, advantage, old_log_probs, actor):
        """Compute the surrogate policy objective E[ratio * advantage]."""
        log_probs = torch.log(actor(states).gather(1, actions))
        ratio = torch.exp(log_probs - old_log_probs)
        return torch.mean(ratio * advantage)
    def line_search(self, states, actions, advantage, old_log_probs, old_action_dists, max_vec):
        """Backtracking line search along the natural-gradient direction."""
        old_para = torch.nn.utils.convert_parameters.parameters_to_vector(self.actor.parameters())
        old_obj = self.compute_surrogate_obj(states, actions, advantage, old_log_probs, self.actor)
        for i in range(15):  # main line-search loop
            coef = self.alpha**i
            new_para = old_para + coef * max_vec
            new_actor = copy.deepcopy(self.actor)
            torch.nn.utils.convert_parameters.vector_to_parameters(new_para, new_actor.parameters())
            new_action_dists = torch.distributions.Categorical(new_actor(states))
            kl_div = torch.mean(torch.distributions.kl.kl_divergence(old_action_dists, new_action_dists))
            new_obj = self.compute_surrogate_obj(states, actions, advantage, old_log_probs, new_actor)
            # Accept the step only if it improves the objective and respects the KL constraint
            if new_obj > old_obj and kl_div < self.kl_constraint:
                return new_para
        return old_para
    def policy_learn(self, states, actions, old_action_dists, old_log_probs, advantage):
        """Update the policy with a trust-region (natural gradient) step."""
        surrogate_obj = self.compute_surrogate_obj(states, actions, advantage, old_log_probs, self.actor)
        grads = torch.autograd.grad(surrogate_obj, self.actor.parameters())
        obj_grad = torch.cat([grad.view(-1) for grad in grads]).detach()
        # Solve x = H^{-1} g with the conjugate gradient method
        descent_direction = self.conjugate_gradient(obj_grad, states, old_action_dists)
        Hd = self.hessian_matrix_vector_product(states, old_action_dists, descent_direction)
        # Maximal step size keeping the quadratic KL estimate within the constraint: sqrt(2 * delta / (x^T H x))
        max_coef = torch.sqrt(2 * self.kl_constraint / (torch.dot(descent_direction, Hd) + 1e-8))
        new_para = self.line_search(states, actions, advantage, old_log_probs, old_action_dists, descent_direction * max_coef)
        # Write the parameters found by the line search back into the policy network
        torch.nn.utils.convert_parameters.vector_to_parameters(new_para, self.actor.parameters())
    def update(self, transition_dict):
        states = torch.tensor(transition_dict["states"], dtype=torch.float32).to(self.device)
        actions = torch.tensor(transition_dict["actions"]).view(-1, 1).to(self.device)
        rewards = torch.tensor(transition_dict["rewards"], dtype=torch.float32).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict["next_states"], dtype=torch.float32).to(self.device)
        dones = torch.tensor(transition_dict["dones"], dtype=torch.float32).view(-1, 1).to(self.device)
        # TD target and TD error for the critic; the TD errors feed the GAE computation
        td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones)
        td_delta = td_target - self.critic(states)
        advantage = compute_advantage(self.gamma, self.lmbda, td_delta.cpu()).to(self.device)
        old_log_probs = torch.log(self.actor(states).gather(1, actions)).detach()
        old_action_dists = torch.distributions.Categorical(self.actor(states).detach())
        critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()  # update the value function
        # Update the policy function
        self.policy_learn(states, actions, old_action_dists, old_log_probs, advantage)
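For reference, here is a minimal sketch of how this agent could be driven, assuming the Gymnasium API and the CartPole-v1 environment; the environment name, episode count, and all hyperparameter values below are illustrative assumptions, not part of the original excerpt.

import gymnasium as gym

# Hypothetical environment and hyperparameters, for illustration only
env = gym.make("CartPole-v1")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
agent = TRPO(
    hidden_dim=128,
    state_space=env.observation_space,
    action_space=env.action_space,
    lmbda=0.95,
    kl_constraint=0.0005,
    alpha=0.5,
    critic_lr=1e-2,
    gamma=0.98,
    device=device,
)

for episode in range(500):
    transition_dict = {"states": [], "actions": [], "next_states": [], "rewards": [], "dones": []}
    state, _ = env.reset()
    done = False
    while not done:
        action = agent.take_action(state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        transition_dict["states"].append(state)
        transition_dict["actions"].append(action)
        transition_dict["next_states"].append(next_state)
        transition_dict["rewards"].append(reward)
        transition_dict["dones"].append(done)
        state = next_state
    agent.update(transition_dict)  # one TRPO update per collected episode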