machina.vfuncs.state_vfuncs package

Submodules

machina.vfuncs.state_vfuncs.base module

class machina.vfuncs.state_vfuncs.base.BaseSVfunc(ob_space, net, rnn=False, data_parallel=False, parallel_dim=0)[source]

Bases: torch.nn.modules.module.Module

Base class for state value functions. It takes observations and outputs their values, e.g. a V function.

Parameters:
  • ob_space (gym.Space) –
  • net (torch.nn.Module) –
  • rnn (bool) –
  • data_parallel (bool) – If True, network computation is executed in parallel.
  • parallel_dim (int) – Dimension along which the input is split for data-parallel execution.

reset()[source]

Resets the RNN's hidden state.
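
As a rough illustration of what a state value function wrapper does, the sketch below wraps a small network and maps a batch of observations to one value per observation. It does not use machina: `TinyVNet` and `SketchSVfunc` are hypothetical names, and the hidden-state handling in `reset` is an assumption about what resetting an RNN state involves, not the library's implementation.

```python
import torch
import torch.nn as nn


class TinyVNet(nn.Module):
    """Hypothetical network: observation vector -> scalar value."""

    def __init__(self, ob_dim, hidden=16):
        super().__init__()
        self.body = nn.Sequential(
            nn.Linear(ob_dim, hidden),
            nn.Tanh(),
            nn.Linear(hidden, 1),
        )

    def forward(self, obs):
        return self.body(obs)


class SketchSVfunc(nn.Module):
    """Hypothetical stand-in for a state value function wrapper."""

    def __init__(self, net, rnn=False):
        super().__init__()
        self.net = net
        self.rnn = rnn
        self.hs = None  # hidden state, only meaningful when rnn=True

    def reset(self):
        # Assumption: resetting simply forgets the stored hidden state.
        self.hs = None

    def forward(self, obs):
        # One value per observation: (batch, 1) -> (batch,)
        return self.net(obs).reshape(-1)


torch.manual_seed(0)
vf = SketchSVfunc(TinyVNet(ob_dim=4))
obs = torch.randn(8, 4)  # batch of 8 observations with 4 features each
values = vf(obs)         # tensor with one value per observation
vf.reset()               # clears the (unused here) hidden state
```

Flattening the output to a 1-D tensor is a design choice of this sketch; it keeps the values aligned with per-step quantities such as rewards or returns.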

machina.vfuncs.state_vfuncs.deterministic_state_vfunc module

Deterministic State Value function

class machina.vfuncs.state_vfuncs.deterministic_state_vfunc.DeterministicSVfunc(ob_space, net, rnn=False, data_parallel=False, parallel_dim=0)[source]

Bases: machina.vfuncs.state_vfuncs.base.BaseSVfunc

Deterministic version of the state value function.

Parameters:
  • ob_space (gym.Space) –
  • net (torch.nn.Module) –
  • rnn (bool) –
  • data_parallel (bool) – If True, network computation is executed in parallel.
  • parallel_dim (int) – Dimension along which the input is split for data-parallel execution.

forward(obs, hs=None, h_masks=None)[source]

Calculates values for the given observations.

class machina.vfuncs.state_vfuncs.deterministic_state_vfunc.NormalizedDeterministicSVfunc(ob_space, net)[source]

Bases: machina.vfuncs.state_vfuncs.deterministic_state_vfunc.DeterministicSVfunc

forward(obs, hs=None, h_masks=None)[source]

Calculates values for the given observations.

set_mean(mean)[source]
set_std(std)[source]
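
The source gives no description for `set_mean` and `set_std`. One plausible reading, stated here as an assumption rather than as the library's documented behavior, is that the network predicts a normalized value and the stored statistics map it back to the original scale. The sketch below illustrates that idea in plain PyTorch; `SketchNormalizedSVfunc` and its buffer names are hypothetical.

```python
import torch
import torch.nn as nn


class SketchNormalizedSVfunc(nn.Module):
    """Hypothetical sketch: rescales the network output by stored statistics."""

    def __init__(self, net):
        super().__init__()
        self.net = net
        # Buffers follow the module across devices but are not trained.
        self.register_buffer("x_mean", torch.zeros(1))
        self.register_buffer("x_std", torch.ones(1))

    def set_mean(self, mean):
        # Overwrite the stored mean in place.
        self.x_mean.copy_(torch.as_tensor(mean, dtype=self.x_mean.dtype).reshape(1))

    def set_std(self, std):
        # Overwrite the stored standard deviation in place.
        self.x_std.copy_(torch.as_tensor(std, dtype=self.x_std.dtype).reshape(1))

    def forward(self, obs):
        # Assumption: the network output is a normalized value.
        normalized = self.net(obs).reshape(-1)
        return normalized * self.x_std + self.x_mean


torch.manual_seed(0)
vf = SketchNormalizedSVfunc(nn.Linear(3, 1))
obs = torch.ones(5, 3)

raw = vf(obs)        # mean 0 and std 1, so this equals the network output
vf.set_mean(10.0)
vf.set_std(2.0)
rescaled = vf(obs)   # network output scaled by 2 and shifted by 10
```

Under this reading, the statistics would typically be refreshed from observed returns before each update, so the network only has to fit targets of roughly unit scale.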