Want an agent to learn to predict how its environment will change from image inputs alone and make intelligent decisions? The Recurrent State Space Model (RSSM) is the core technique behind this goal, and this article walks through both a deep understanding and a hands-on implementation of this powerful model.
In modern AI, a central challenge is getting intelligent systems to learn environment dynamics from high-dimensional sensory inputs (such as image pixels) and to plan effectively on top of them. The Recurrent State Space Model (RSSM), a key technique in model-based reinforcement learning, offers an elegant and effective solution to this problem.
RSSM was introduced by Danijar Hafner et al. in the paper "Learning Latent Dynamics for Planning from Pixels". It learns a low-dimensional representation of the environment from pixel inputs and, on that basis, predicts future states and rewards.
This article dissects the principles and core components of RSSM and provides hands-on PyTorch code so you can fully master this powerful model.
I. The Core Idea of RSSM
The key innovation of RSSM is to split the environment's latent state into a deterministic part and a stochastic part:
- Deterministic state (h): a compressed representation of history, captured by a recurrent neural network (such as a GRU), that evolves deterministically
- Stochastic state (s or z): captures the unpredictable randomness in the environment and the multiple possible futures
This hybrid design lets RSSM maintain stable long-term memory while flexibly handling environmental uncertainty, providing a solid foundation for an agent's long-horizon planning and decision making.
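To make the split concrete, here is a minimal single-step sketch (the dimensions and names are illustrative assumptions, not code reused later in this article): the deterministic path updates h with a GRU cell, while the stochastic path samples s from a Gaussian whose parameters are predicted from h.
import torch
import torch.nn as nn

# Illustrative dimensions (assumptions for this sketch only)
hidden_dim, state_dim, action_dim = 200, 30, 4
gru = nn.GRUCell(state_dim + action_dim, hidden_dim)  # deterministic path
to_stats = nn.Linear(hidden_dim, state_dim * 2)       # predicts mean and log-variance

h = torch.zeros(1, hidden_dim)
s, a = torch.zeros(1, state_dim), torch.zeros(1, action_dim)

h = gru(torch.cat([s, a], dim=-1), h)                 # h_t = f(h_{t-1}, s_{t-1}, a_{t-1})
mean, logvar = to_stats(h).chunk(2, dim=-1)
s = mean + torch.exp(0.5 * logvar) * torch.randn_like(mean)  # reparameterized sample of s_t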
II. Core Components of RSSM
RSSM consists of four key modules, each with its own distinct function and role.
1. Encoder
The encoder compresses high-dimensional raw observations (such as image pixels) into a low-dimensional latent representation, typically implemented with a convolutional neural network (CNN):
import torch
import torch.nn as nn
import torch.nn.functional as F

class EncoderCNN(nn.Module):
    def __init__(self, in_channels, embedding_dim=2048, input_shape=(128, 128)):
        super(EncoderCNN, self).__init__()
        # Four stride-2 convolutions, each halving the spatial resolution
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
        self.fc1 = nn.Linear(self._compute_conv_output((in_channels, *input_shape)), embedding_dim)
        # Batch-normalization layers
        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)

    def _compute_conv_output(self, shape):
        # Pass a dummy tensor through the conv stack to infer the flattened feature size
        with torch.no_grad():
            x = torch.randn(1, *shape)
            x = F.relu(self.conv1(x))
            x = F.relu(self.conv2(x))
            x = F.relu(self.conv3(x))
            x = self.conv4(x)
            return x.shape[1] * x.shape[2] * x.shape[3]

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.bn1(x)
        x = F.relu(self.conv2(x))
        x = self.bn2(x)
        x = F.relu(self.conv3(x))
        x = self.bn3(x)
        x = self.conv4(x)
        x = self.bn4(x)
        # Flatten and project to the embedding dimension
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x
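As a quick sanity check, the following usage snippet (with an assumed batch size and the default 128x128 input) runs a dummy image batch through the encoder:
# Hypothetical smoke test: 8 RGB frames at 128x128
encoder = EncoderCNN(in_channels=3, embedding_dim=2048, input_shape=(128, 128))
dummy_obs = torch.randn(8, 3, 128, 128)
embedding = encoder(dummy_obs)
print(embedding.shape)  # torch.Size([8, 2048])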
2. Dynamics Model
The dynamics model is the most complex component of RSSM. It models the state transition, producing both prior and posterior estimates:
class DynamicsModel(nn.Module):
    def __init__(self, hidden_dim, action_dim, state_dim, embedding_dim, rnn_layer=1):
        super(DynamicsModel, self).__init__()
        self.hidden_dim = hidden_dim
        self.state_dim = state_dim
        self.embedding_dim = embedding_dim
        # Recurrent core; supports stacked GRU cells
        self.rnn = nn.ModuleList([nn.GRUCell(hidden_dim, hidden_dim) for _ in range(rnn_layer)])
        # Projection of the (state, action) pair into the RNN input space
        self.project_state_action = nn.Linear(action_dim + state_dim, hidden_dim)
        # Prior network: outputs the parameters of a diagonal Gaussian
        self.prior = nn.Linear(hidden_dim, state_dim * 2)
        self.project_hidden_action = nn.Linear(hidden_dim + action_dim, hidden_dim)
        # Posterior network: outputs the parameters of a diagonal Gaussian
        self.posterior = nn.Linear(hidden_dim, state_dim * 2)
        self.project_hidden_obs = nn.Linear(hidden_dim + embedding_dim, hidden_dim)
        self.act_fn = nn.ReLU()

    def forward(self, prev_hidden, prev_state, actions, obs=None, dones=None):
        batch_size, seq_len, _ = actions.size()
        # Buffers for the unrolled sequence
        hiddens_list = []
        prior_means_list = []
        prior_logvars_list = []
        posterior_means_list = []
        posterior_logvars_list = []
        prior_states_list = []
        posterior_states_list = []
        # Store the initial state
        hiddens_list.append(prev_hidden.unsqueeze(1))
        prior_states_list.append(prev_state.unsqueeze(1))
        posterior_states_list.append(prev_state.unsqueeze(1))
        # Unroll over time
        for t in range(seq_len - 1):
            # Current action; feed back the posterior state when observations are
            # available, the prior state otherwise
            action_t = actions[:, t, :]
            state_t = posterior_states_list[-1][:, 0, :] if obs is not None else prior_states_list[-1][:, 0, :]
            if dones is not None:
                # Reset the state at episode boundaries
                state_t = state_t * (1 - dones[:, t, :])
            hidden_t = hiddens_list[-1][:, 0, :]
            # Combine state and action
            state_action = torch.cat([state_t, action_t], dim=-1)
            state_action = self.act_fn(self.project_state_action(state_action))
            # Deterministic RNN update
            for i in range(len(self.rnn)):
                hidden_t = self.rnn[i](state_action, hidden_t)
            # Prior distribution
            hidden_action = torch.cat([hidden_t, action_t], dim=-1)
            hidden_action = self.act_fn(self.project_hidden_action(hidden_action))
            prior_params = self.prior(hidden_action)
            prior_mean, prior_logvar = torch.chunk(prior_params, 2, dim=-1)
            # Reparameterized sample from the prior; std = exp(0.5 * logvar) so the
            # sampled variance matches the log-variance used in the KL loss below
            prior_dist = torch.distributions.Normal(prior_mean, torch.exp(0.5 * prior_logvar))
            prior_state_t = prior_dist.rsample()
            # Posterior distribution
            if obs is None:
                posterior_mean = prior_mean
                posterior_logvar = prior_logvar
            else:
                obs_t = obs[:, t, :]
                hidden_obs = torch.cat([hidden_t, obs_t], dim=-1)
                hidden_obs = self.act_fn(self.project_hidden_obs(hidden_obs))
                posterior_params = self.posterior(hidden_obs)
                posterior_mean, posterior_logvar = torch.chunk(posterior_params, 2, dim=-1)
            # Reparameterized sample from the posterior
            posterior_dist = torch.distributions.Normal(posterior_mean, torch.exp(0.5 * posterior_logvar))
            posterior_state_t = posterior_dist.rsample()
            # Store this step
            posterior_means_list.append(posterior_mean.unsqueeze(1))
            posterior_logvars_list.append(posterior_logvar.unsqueeze(1))
            prior_means_list.append(prior_mean.unsqueeze(1))
            prior_logvars_list.append(prior_logvar.unsqueeze(1))
            prior_states_list.append(prior_state_t.unsqueeze(1))
            posterior_states_list.append(posterior_state_t.unsqueeze(1))
            hiddens_list.append(hidden_t.unsqueeze(1))
        # Stack along the time dimension
        hiddens = torch.cat(hiddens_list, dim=1)
        prior_states = torch.cat(prior_states_list, dim=1)
        posterior_states = torch.cat(posterior_states_list, dim=1)
        prior_means = torch.cat(prior_means_list, dim=1)
        prior_logvars = torch.cat(prior_logvars_list, dim=1)
        posterior_means = torch.cat(posterior_means_list, dim=1)
        posterior_logvars = torch.cat(posterior_logvars_list, dim=1)
        return hiddens, prior_states, posterior_states, prior_means, prior_logvars, posterior_means, posterior_logvars
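To see the tensor shapes the dynamics model produces, here is a hypothetical rollout over a short action sequence (all dimensions are assumptions chosen for illustration):
# Assumed dimensions for illustration
B, T = 4, 10
dyn = DynamicsModel(hidden_dim=200, action_dim=2, state_dim=30, embedding_dim=2048)
h0 = torch.zeros(B, 200)
s0 = torch.zeros(B, 30)
actions = torch.randn(B, T, 2)
obs_embed = torch.randn(B, T, 2048)  # stand-in for encoder outputs, one per step

out = dyn(h0, s0, actions, obs=obs_embed)
hiddens, prior_states, posterior_states = out[:3]
print(hiddens.shape, posterior_states.shape)  # (B, T, 200) and (B, T, 30)
# Note: the returned means/log-variances cover the T-1 predicted steps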
3. Decoder
The decoder reconstructs the original observation from the latent state, typically using transposed convolutions:
class DecoderCNN(nn.Module):
    def __init__(self, hidden_size, state_size, embedding_size, use_bn=True, output_shape=(3, 128, 128)):
        super(DecoderCNN, self).__init__()
        self.output_shape = output_shape
        self.embedding_size = embedding_size
        # Fully connected layers map (h, s) to a spatial feature map
        self.fc1 = nn.Linear(hidden_size + state_size, embedding_size)
        self.fc2 = nn.Linear(embedding_size, 256 * (output_shape[1] // 16) * (output_shape[2] // 16))
        # Transposed convolutions upsample back to the image resolution
        self.conv1 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.conv2 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.conv3 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.conv4 = nn.ConvTranspose2d(32, output_shape[0], kernel_size=3, stride=2, padding=1, output_padding=1)
        # Batch-normalization layers
        self.bn1 = nn.BatchNorm2d(128)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(32)
        self.use_bn = use_bn

    def forward(self, h, s):
        # Concatenate deterministic and stochastic state
        x = torch.cat([h, s], dim=-1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = x.view(-1, 256, self.output_shape[1] // 16, self.output_shape[2] // 16)
        if self.use_bn:
            x = F.relu(self.bn1(self.conv1(x)))
            x = F.relu(self.bn2(self.conv2(x)))
            x = F.relu(self.bn3(self.conv3(x)))
        else:
            x = F.relu(self.conv1(x))
            x = F.relu(self.conv2(x))
            x = F.relu(self.conv3(x))
        x = self.conv4(x)
        return x
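The decoder can be checked the same way; under the assumed dimensions it maps an (h, s) pair back to a 3x128x128 image:
# Hypothetical round-trip shape check
decoder = DecoderCNN(hidden_size=200, state_size=30, embedding_size=2048,
                     output_shape=(3, 128, 128))
h = torch.randn(8, 200)
s = torch.randn(8, 30)
recon = decoder(h, s)
print(recon.shape)  # torch.Size([8, 3, 128, 128])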
4. Reward Model
The reward model predicts the reward that will be received in a given state:
class RewardModel(nn.Module):
    def __init__(self, hidden_dim, state_dim):
        super(RewardModel, self).__init__()
        self.fc1 = nn.Linear(hidden_dim + state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        # Predicts the scalar mean reward; with an assumed unit variance,
        # the MSE used in training below is the matching likelihood loss
        self.fc3 = nn.Linear(hidden_dim, 1)

    def forward(self, h, s):
        x = torch.cat([h, s], dim=-1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
III. The Complete RSSM Implementation
Putting the components above together into a full RSSM model:
class RSSM(nn.Module):
    def __init__(self, encoder, decoder, reward_model, dynamics_model,
                 hidden_dim, state_dim, action_dim, embedding_dim, device="cuda"):
        super(RSSM, self).__init__()
        # Model components
        self.dynamics = dynamics_model
        self.encoder = encoder
        self.decoder = decoder
        self.reward_model = reward_model
        # Dimension bookkeeping
        self.hidden_dim = hidden_dim
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.embedding_dim = embedding_dim
        # Move all components to the target device
        self.dynamics.to(device)
        self.encoder.to(device)
        self.decoder.to(device)
        self.reward_model.to(device)
        self.device = device

    def generate_rollout(self, actions, hiddens=None, states=None, obs=None, dones=None):
        # Initialize hidden and stochastic states to zeros if not provided
        if hiddens is None:
            hiddens = torch.zeros(actions.size(0), self.hidden_dim).to(self.device)
        if states is None:
            states = torch.zeros(actions.size(0), self.state_dim).to(self.device)
        # Unroll the dynamics model
        dynamics_result = self.dynamics(hiddens, states, actions, obs, dones)
        hiddens, prior_states, posterior_states, prior_means, prior_logvars, posterior_means, posterior_logvars = dynamics_result
        return hiddens, prior_states, posterior_states, prior_means, prior_logvars, posterior_means, posterior_logvars

    def train_mode(self):
        """Switch all components to training mode."""
        self.dynamics.train()
        self.encoder.train()
        self.decoder.train()
        self.reward_model.train()

    def eval_mode(self):
        """Switch all components to evaluation mode."""
        self.dynamics.eval()
        self.encoder.eval()
        self.decoder.eval()
        self.reward_model.eval()

    def encode(self, obs):
        """Encode an observation into an embedding."""
        return self.encoder(obs)

    def decode(self, h, s):
        """Decode a latent state back into an observation."""
        return self.decoder(h, s)

    def predict_reward(self, h, s):
        """Predict the reward for a latent state."""
        return self.reward_model(h, s)
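Wiring everything together might look like the following sketch (same assumed dimensions as in the earlier snippets; the device is set to CPU so it runs anywhere):
# Hypothetical assembly of the full model
hidden_dim, state_dim, action_dim, embedding_dim = 200, 30, 2, 2048
rssm = RSSM(
    encoder=EncoderCNN(in_channels=3, embedding_dim=embedding_dim),
    decoder=DecoderCNN(hidden_dim, state_dim, embedding_dim),
    reward_model=RewardModel(hidden_dim, state_dim),
    dynamics_model=DynamicsModel(hidden_dim, action_dim, state_dim, embedding_dim),
    hidden_dim=hidden_dim, state_dim=state_dim,
    action_dim=action_dim, embedding_dim=embedding_dim,
    device="cpu",
)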
IV. Training and Using RSSM
Training an RSSM typically involves several loss terms, including a reconstruction loss, a KL-divergence loss, and a reward-prediction loss:
def compute_loss(obs, reconstructed_obs, prior_means, prior_logvars,
                 posterior_means, posterior_logvars, rewards, predicted_rewards):
    # Reconstruction loss (mean squared error)
    reconstruction_loss = F.mse_loss(reconstructed_obs, obs)
    # KL-divergence loss (gap between the posterior and prior distributions)
    kl_loss = compute_kl_loss(prior_means, prior_logvars, posterior_means, posterior_logvars)
    # Reward-prediction loss
    reward_loss = F.mse_loss(predicted_rewards, rewards)
    # Total loss
    total_loss = reconstruction_loss + kl_loss + reward_loss
    return total_loss, reconstruction_loss, kl_loss, reward_loss
def compute_kl_loss(prior_means, prior_logvars, posterior_means, posterior_logvars):
    # Closed-form KL(posterior || prior) between two diagonal Gaussians,
    # regularizing the posterior toward the prior
    kl_div = 0.5 * torch.sum(
        prior_logvars - posterior_logvars +
        (torch.exp(posterior_logvars) + (posterior_means - prior_means)**2) / torch.exp(prior_logvars) - 1,
        dim=-1
    )
    return kl_div.mean()
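A single training step then ties the pieces together. The sketch below is a hypothetical outline (data loading is omitted; in practice a KL weight or a free-nats threshold, as used in PlaNet/Dreamer, is usually added on top of the plain sum above):
optimizer = torch.optim.Adam(rssm.parameters(), lr=1e-3)

def train_step(obs, actions, rewards):
    # obs: (B, T, C, H, W), actions: (B, T, A), rewards: (B, T, 1)
    B, T = obs.shape[:2]
    # Encode each frame, then unroll the dynamics with posterior inference
    obs_embed = rssm.encode(obs.reshape(B * T, *obs.shape[2:])).view(B, T, -1)
    h, _, post_s, p_mean, p_logvar, q_mean, q_logvar = rssm.generate_rollout(actions, obs=obs_embed)
    # Reconstruct observations and predict rewards from the posterior states
    recon = rssm.decode(h.reshape(B * T, -1), post_s.reshape(B * T, -1))
    pred_reward = rssm.predict_reward(h, post_s)
    total, _, _, _ = compute_loss(obs.reshape(B * T, *obs.shape[2:]), recon,
                                  p_mean, p_logvar, q_mean, q_logvar,
                                  rewards, pred_reward)
    optimizer.zero_grad()
    total.backward()
    optimizer.step()
    return total.item()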
V. Strengths and Outlook
RSSM's core strengths are its ability to handle high-dimensional inputs efficiently and its design principle of explicitly separating the deterministic from the uncertain. By compressing raw data such as images into a low-dimensional latent space, RSSM drastically reduces the cost of computation and planning.
On the application side, RSSM has become the backbone of many state-of-the-art reinforcement-learning frameworks, most famously the Dreamer family of algorithms:
- DreamerV1: demonstrated the feasibility of end-to-end actor-critic training entirely in latent space
- DreamerV2: introduced discrete stochastic states and matched model-free methods on Atari games
- DreamerV3: with techniques such as adaptive normalization, showed strong generality and robustness across more than 150 different tasks
Beyond that, RSSM is widely used in world models for autonomous driving, covering driving-scene generation, trajectory prediction, and planning and control, for example in projects such as GAIA-1 and DriveDreamer.
Through its clever hybrid of deterministic and stochastic state, RSSM achieves compact, efficient latent-space modeling of complex environment dynamics. It is not only a key technology driving model-based reinforcement learning forward, but also shows great potential in real-world applications such as autonomous driving, where accurate prediction and safe planning are essential.