Source code for machina.pols.gaussian_pol

import numpy as np
import torch
import torch.nn as nn

from machina.pols import BasePol
from machina.pds.gaussian_pd import GaussianPd
from machina.utils import get_device


class GaussianPol(BasePol):
    """
    Policy with Gaussian distribution.

    Parameters
    ----------
    ob_space : gym.Space
        observation's space
    ac_space : gym.Space
        action's space
        This should be gym.spaces.Box
    net : torch.nn.Module
        Network producing the distribution parameters. It is called as
        ``net(obs)`` and must return ``(mean, log_std)``. In the recurrent
        case it is called as ``net(obs, hs, h_masks)``, must return
        ``(mean, log_std, hs)`` and must provide ``init_hs(batch_size)``.
    rnn : bool
        Selects the recurrent code path (the methods branch on ``self.rnn``,
        presumably set by ``BasePol`` from this flag).
    normalize_ac : bool
        If True, the output of network is spread for ac_space.
        In this situation the output of network is expected to be in -1~1.
    data_parallel : bool
        If True, network computation is executed in parallel.
    parallel_dim : int
        Dimension that is split in data parallel.
    """

    def __init__(self, ob_space, ac_space, net, rnn=False, normalize_ac=True,
                 data_parallel=False, parallel_dim=0):
        BasePol.__init__(self, ob_space, ac_space, net, rnn,
                         normalize_ac, data_parallel, parallel_dim)
        self.pd = GaussianPd()
        self.to(get_device())

    def forward(self, obs, hs=None, h_masks=None):
        """
        Sample an action from the Gaussian defined by the network output.

        Parameters
        ----------
        obs
            Observations. They are passed through ``_check_obs_shape``
            first. In the recurrent case the first two dimensions of the
            result are read as ``(time_seq, batch_size)``.
        hs : optional
            Recurrent state, indexed as ``hs[0]`` and ``hs[1]``. When None,
            the state stored on the policy (``self.hs``) is used, and it is
            created with ``net.init_hs(batch_size)`` if it does not exist.
            Ignored when ``self.rnn`` is False.
        h_masks : optional
            Mask handed to the recurrent network after being reshaped to
            ``(time_seq, batch_size, 1)``. When None, an all-zero mask of
            that shape is used. Ignored when ``self.rnn`` is False.

        Returns
        -------
        ac_real
            Result of ``convert_ac_for_real`` applied to the sampled action
            converted to a numpy array.
        ac
            The action sampled from ``self.pd``.
        pd_params : dict
            ``mean``, ``log_std`` (expanded to the shape of ``mean``) and
            ``hs`` (the new recurrent state, or the ``hs`` argument
            unchanged when ``self.rnn`` is False).
        """
        obs = self._check_obs_shape(obs)

        if self.rnn:
            time_seq, batch_size, *_ = obs.shape

            if hs is None:
                if self.hs is None:
                    self.hs = self.net.init_hs(batch_size)
                # NOTE(review): the nesting of this branch was reconstructed
                # from a flattened listing (sibling of the check above, not
                # nested inside it) - confirm against the upstream file.
                if self.dp_run:
                    self.hs = (self.hs[0].unsqueeze(0),
                               self.hs[1].unsqueeze(0))
                hs = self.hs

            if h_masks is None:
                h_masks = hs[0].new(time_seq, batch_size, 1).zero_()
            h_masks = h_masks.reshape(time_seq, batch_size, 1)

            if self.dp_run:
                mean, log_std, hs = self.dp_net(obs, hs, h_masks)
            else:
                mean, log_std, hs = self.net(obs, hs, h_masks)
            # Remember the new state so the next call can continue from it.
            self.hs = hs
        else:
            if self.dp_run:
                mean, log_std = self.dp_net(obs)
            else:
                mean, log_std = self.net(obs)

        log_std = log_std.expand_as(mean)
        ac = self.pd.sample(dict(mean=mean, log_std=log_std))
        ac_real = self.convert_ac_for_real(ac.detach().cpu().numpy())
        return ac_real, ac, dict(mean=mean, log_std=log_std, hs=hs)

    def deterministic_ac_real(self, obs, hs=None, h_masks=None):
        """
        action for deployment

        The mean of the distribution is returned instead of a sample. This
        method always calls ``self.net`` directly; ``self.dp_net`` is not
        used here.

        Parameters
        ----------
        obs
            Observations. They are passed through ``_check_obs_shape``
            first. In the recurrent case the first two dimensions of the
            result are read as ``(time_seq, batch_size)``.
        hs : optional
            Recurrent state. When None, the state stored on the policy
            (``self.hs``) is used, and it is created with
            ``net.init_hs(batch_size)`` if it does not exist.
            Ignored when ``self.rnn`` is False.
        h_masks : optional
            Mask handed to the recurrent network after being reshaped to
            ``(time_seq, batch_size, 1)``. When None, an all-zero mask of
            that shape is used. Ignored when ``self.rnn`` is False.

        Returns
        -------
        mean_real
            Result of ``convert_ac_for_real`` applied to the mean converted
            to a numpy array.
        mean
            The mean returned by the network.
        pd_params : dict
            ``mean``, ``log_std`` (as returned by the network, not
            expanded) and ``hs``.
        """
        obs = self._check_obs_shape(obs)

        if self.rnn:
            time_seq, batch_size, *_ = obs.shape

            if hs is None:
                if self.hs is None:
                    self.hs = self.net.init_hs(batch_size)
                hs = self.hs

            if h_masks is None:
                h_masks = hs[0].new(time_seq, batch_size, 1).zero_()
            h_masks = h_masks.reshape(time_seq, batch_size, 1)

            # log_std has to be bound here because it is part of the
            # returned dict; discarding it made the recurrent path fail
            # with UnboundLocalError at the return statement.
            mean, log_std, hs = self.net(obs, hs, h_masks)
            self.hs = hs
        else:
            mean, log_std = self.net(obs)

        mean_real = self.convert_ac_for_real(mean.detach().cpu().numpy())
        return mean_real, mean, dict(mean=mean, log_std=log_std, hs=hs)