machina.vfuncs.state_action_vfuncs package

Submodules

machina.vfuncs.state_action_vfuncs.base module

class machina.vfuncs.state_action_vfuncs.base.BaseSAVfunc(ob_space, ac_space, net, rnn=False, data_parallel=False, parallel_dim=0)

Bases: torch.nn.modules.module.Module

Base class for state-action value functions. It takes observations and actions and outputs values. A Q-function is one example.

Parameters:
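  • ob_space (gym.Space) – Observation space.
  • ac_space (gym.Space) – Action space.
  • net (torch.nn.Module) – Network that computes the values.
  • rnn (bool) – If True, net is treated as a recurrent network.
  • data_parallel (bool) – If True, network computation is executed in parallel.
  • parallel_dim (int) – Dimension along which inputs are split for data parallelism.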
reset()

Reset the RNN's hidden state.
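BaseSAVfunc wraps a user-supplied net that maps observations and actions to values. The following is a minimal PyTorch sketch of such a network. The class name HypotheticalQNet, the concatenation of inputs, the net(obs, acs) call signature, and the (batch,) output shape are all assumptions, since this page does not specify the exact interface machina expects from net.

```python
import torch
import torch.nn as nn


class HypotheticalQNet(nn.Module):
    """Toy Q-network (hypothetical): maps (observation, action) pairs to scalar values."""

    def __init__(self, ob_dim: int, ac_dim: int, hidden: int = 64):
        super().__init__()
        self.body = nn.Sequential(
            nn.Linear(ob_dim + ac_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, obs: torch.Tensor, acs: torch.Tensor) -> torch.Tensor:
        # Concatenate observation and action features, then drop the trailing
        # singleton dimension so the result has shape (batch,).
        return self.body(torch.cat([obs, acs], dim=-1)).squeeze(-1)


# Usage: a batch of 8 observations (dim 3) paired with 8 actions (dim 2).
net = HypotheticalQNet(ob_dim=3, ac_dim=2)
obs = torch.randn(8, 3)
acs = torch.randn(8, 2)
qs = net(obs, acs)
print(qs.shape)
```

Keeping the observation/action fusion inside net leaves the value-function wrapper agnostic to the network architecture.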

machina.vfuncs.state_action_vfuncs.cem_state_action_vfunc module

Deterministic state-action value function with the cross-entropy method (CEM).

class machina.vfuncs.state_action_vfuncs.cem_state_action_vfunc.CEMDeterministicSAVfunc(ob_space, ac_space, net, rnn=False, data_parallel=False, parallel_dim=0, num_sampling=64, num_best_sampling=6, num_iter=2, multivari=True, delta=0.0001)

Bases: machina.vfuncs.state_action_vfuncs.deterministic_state_action_vfunc.DeterministicSAVfunc

Deterministic state-action value function with the cross-entropy method (CEM).

Parameters:
  • ob_space (gym.Space) – Observation space.
  • ac_space (gym.Space) – Action space.
  • net (torch.nn.Module) – Network that computes the values.
  • rnn (bool) – If True, net is treated as a recurrent network.
  • data_parallel (bool) – If True, network computation is executed in parallel.
  • parallel_dim (int) – Dimension along which inputs are split for data parallelism.
  • num_sampling (int) – Number of samples drawn from the Gaussian in CEM.
  • num_best_sampling (int) – Number of best samples used to fit the Gaussian in CEM.
  • num_iter (int) – Number of CEM iterations.
  • delta (float) – Coefficient used to make the covariance matrix positive definite.

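max(obs)

Compute the maximum of the Q-function and the action that attains it (the argmax) for the given observations.

Parameters:
  • obs (torch.Tensor) –
Returns:
max_qs, max_acs: the maximum Q-values and the corresponding actions.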
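The CEM parameters above describe an iterative search: sample actions from a Gaussian, keep the best ones, and refit the Gaussian to them. The following is a minimal NumPy sketch of that loop for a single observation. The function name cem_argmax, the q_fn(obs_batch, acs_batch) interface, the initial N(0, I) distribution, the full covariance (the multivari flag is not modeled), and the tracking of the best sample seen are all assumptions. This is not machina's implementation.

```python
import numpy as np


def cem_argmax(q_fn, ob, ac_dim, num_sampling=64, num_best_sampling=6,
               num_iter=2, delta=1e-4, rng=None):
    """Approximate (max_a Q(ob, a), argmax_a Q(ob, a)) with the cross-entropy method.

    q_fn(obs_batch, acs_batch) must return an array of shape (batch,).
    """
    rng = np.random.default_rng() if rng is None else rng
    mean = np.zeros(ac_dim)  # assumed initial Gaussian: N(0, I)
    cov = np.eye(ac_dim)
    best_q, best_ac = -np.inf, None
    for _ in range(num_iter):
        # 1. Sample num_sampling candidate actions from the current Gaussian.
        acs = rng.multivariate_normal(mean, cov, size=num_sampling)
        obs = np.repeat(ob[None, :], num_sampling, axis=0)
        qs = q_fn(obs, acs)
        # 2. Keep the num_best_sampling highest-valued actions (the elites).
        elite_idx = np.argsort(qs)[-num_best_sampling:]
        elites = acs[elite_idx]
        # Remember the best action evaluated so far (last elite = highest Q).
        if qs[elite_idx[-1]] > best_q:
            best_q, best_ac = qs[elite_idx[-1]], acs[elite_idx[-1]]
        # 3. Refit the Gaussian to the elites. Adding delta * I keeps the
        #    covariance positive definite even with few elites.
        mean = elites.mean(axis=0)
        centered = elites - mean
        cov = centered.T @ centered / num_best_sampling + delta * np.eye(ac_dim)
    return best_q, best_ac


# Usage: a toy Q-function whose maximum (value 0) is at action 0.5.
toy_q = lambda obs, acs: -np.sum((acs - 0.5) ** 2, axis=1)
q, a = cem_argmax(toy_q, ob=np.zeros(3), ac_dim=1, rng=np.random.default_rng(0))
print(q, a)
```

Returning both the value and the action mirrors the max_qs, max_acs pair returned by max.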

machina.vfuncs.state_action_vfuncs.deterministic_state_action_vfunc module

Deterministic state-action value function.

class machina.vfuncs.state_action_vfuncs.deterministic_state_action_vfunc.DeterministicSAVfunc(ob_space, ac_space, net, rnn=False, data_parallel=False, parallel_dim=0)

Bases: machina.vfuncs.state_action_vfuncs.base.BaseSAVfunc

Deterministic version of the state-action value function.

Parameters:
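  • ob_space (gym.Space) – Observation space.
  • ac_space (gym.Space) – Action space.
  • net (torch.nn.Module) – Network that computes the values.
  • rnn (bool) – If True, net is treated as a recurrent network.
  • data_parallel (bool) – If True, network computation is executed in parallel.
  • parallel_dim (int) – Dimension along which inputs are split for data parallelism.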
forward(obs, acs, hs=None, h_masks=None)

Compute values for the given observations and actions.
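When data_parallel is True, the inputs to the network are split across devices along parallel_dim, similar to the dim argument of torch.nn.DataParallel. The following minimal illustration shows what splitting along that dimension does, using torch.chunk. The helper split_for_devices is hypothetical. The time-major (time, batch, feature) layout for recurrent inputs is an assumption and is not stated on this page.

```python
import torch


def split_for_devices(x: torch.Tensor, num_devices: int, parallel_dim: int = 0):
    """Hypothetical helper: split x into per-device chunks along parallel_dim."""
    return torch.chunk(x, num_devices, dim=parallel_dim)


# Feed-forward batch of shape (batch, ob_dim): splitting dim 0 divides the batch.
flat_obs = torch.randn(8, 3)
flat_chunks = split_for_devices(flat_obs, 2, parallel_dim=0)
print([tuple(c.shape) for c in flat_chunks])

# Sequence batch, assumed time-major (time, batch, ob_dim): splitting dim 1
# divides the batch while keeping every sequence whole on one device.
seq_obs = torch.randn(5, 8, 3)
seq_chunks = split_for_devices(seq_obs, 2, parallel_dim=1)
print([tuple(c.shape) for c in seq_chunks])
```

Choosing the batch dimension matters for recurrent networks. Splitting along the time dimension would instead cut each sequence apart, so hidden states could not be carried across steps.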