machina package

Subpackages

Submodules

machina.logger module

machina.logger.tabulate(tabular_data, headers=[], tablefmt='simple', floatfmt='g', numalign='decimal', stralign='left', missingval='')

Format a fixed width table for pretty printing.

>>> print(tabulate([[1, 2.34], [-56, "8.999"], ["2", "10001"]]))
---  ---------
  1      2.34
-56      8.999
  2  10001
---  ---------

The first required argument (tabular_data) can be a list of lists (or another iterable of iterables), a list of named tuples, a dictionary of iterables, a two-dimensional NumPy array, a NumPy record array, or a pandas DataFrame.

To print nice column headers, supply the second argument (headers):

  • headers can be an explicit list of column headers
  • if headers="firstrow", then the first row of data is used
  • if headers="keys", then dictionary keys or column indices are used

Otherwise a headerless table is produced.

If there are fewer headers than columns, the headers are assumed to name the last columns. This is consistent with the plain-text format of R and pandas DataFrames.

>>> print(tabulate([["sex","age"],["Alice","F",24],["Bob","M",19]],
...       headers="firstrow"))
       sex      age
-----  -----  -----
Alice  F         24
Bob    M         19
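
The short-header rule can be sketched as follows. `pad_headers` is a hypothetical helper written for illustration. It is not a function in machina.logger:

```python
def pad_headers(headers, ncols):
    """Align a short header list with the last columns (hypothetical helper).

    Missing leading headers become empty strings, following the R/pandas
    convention described above. Lists that are already long enough are
    returned unchanged in this sketch.
    """
    headers = list(headers)
    missing = ncols - len(headers)
    if missing <= 0:
        return headers
    return [""] * missing + headers


# The "firstrow" example has 2 header names for 3 data columns.
print(pad_headers(["sex", "age"], 3))  # ['', 'sex', 'age']
```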

tabulate tries to detect column types automatically, and aligns the values properly. By default it aligns decimal points of the numbers (or flushes integer numbers to the right), and flushes everything else to the left. Possible column alignments (numalign, stralign) are: “right”, “center”, “left”, “decimal” (only for numalign), and None (to disable alignment).
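
The default decimal alignment can be sketched as follows. The sketch assumes every value renders as a plain decimal string. `align_decimal` is hypothetical and is not how machina.logger implements numalign='decimal':

```python
def align_decimal(values):
    """Pad rendered numbers so their decimal points line up (sketch).

    An integer is treated as if its decimal point sat just after its last
    digit, so it lines up with the integer parts of the other values.
    """
    strs = [str(v) for v in values]
    heads = [s.split(".")[0] for s in strs]            # digits before the point
    tails = [s[len(h):] for s, h in zip(strs, heads)]  # point plus fraction, or ""
    head_w = max(len(h) for h in heads)
    tail_w = max(len(t) for t in tails)
    return [h.rjust(head_w) + t.ljust(tail_w) for h, t in zip(heads, tails)]


for cell in align_decimal([2.34, 8.999, 10001]):
    print(repr(cell))
# '    2.34 '
# '    8.999'
# '10001    '
```

Apart from trailing spaces, these cells match the second column of the first tabulate example.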

floatfmt is a format specification used for columns which contain numeric data with a decimal point.
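
floatfmt defaults to 'g' in the signature above. Assume the spec is passed to Python's built-in format(). That also explains why the string "451.0" appears as 451 in the tablefmt examples in this section:

```python
values = [41.9999, float("451.0")]

# Default floatfmt='g': general format, trailing zeros dropped.
g_cells = [format(v, "g") for v in values]
# A fixed-point spec such as floatfmt='.2f' keeps two decimals.
f_cells = [format(v, ".2f") for v in values]

print(g_cells)  # ['41.9999', '451']
print(f_cells)  # ['42.00', '451.00']
```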

None values are replaced with the missingval string (empty by default):

>>> print(tabulate([["spam", 1, None],
...                 ["eggs", 42, 3.14],
...                 ["other", None, 2.7]], missingval="?"))
-----  --  ----
spam    1  ?
eggs   42  3.14
other   ?  2.7
-----  --  ----

Various plain-text table formats (tablefmt) are supported: ‘plain’, ‘simple’, ‘grid’, ‘pipe’, ‘orgtbl’, ‘rst’, ‘mediawiki’, and ‘latex’. The variable tabulate_formats lists the currently supported formats.

The “plain” format draws no pseudographics; it separates columns with two spaces:

>>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]],
...                 ["strings", "numbers"], "plain"))
strings      numbers
spam         41.9999
eggs        451
>>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], tablefmt="plain"))
spam   41.9999
eggs  451

“simple” format is like Pandoc simple_tables:

>>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]],
...                 ["strings", "numbers"], "simple"))
strings      numbers
---------  ---------
spam         41.9999
eggs        451
>>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], tablefmt="simple"))
----  --------
spam   41.9999
eggs  451
----  --------

“grid” is similar to tables produced by Emacs table.el package or Pandoc grid_tables:

>>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]],
...                ["strings", "numbers"], "grid"))
+-----------+-----------+
| strings   |   numbers |
+===========+===========+
| spam      |   41.9999 |
+-----------+-----------+
| eggs      |  451      |
+-----------+-----------+
>>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], tablefmt="grid"))
+------+----------+
| spam |  41.9999 |
+------+----------+
| eggs | 451      |
+------+----------+

“pipe” is like tables in PHP Markdown Extra extension or Pandoc pipe_tables:

>>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]],
...                ["strings", "numbers"], "pipe"))
| strings   |   numbers |
|:----------|----------:|
| spam      |   41.9999 |
| eggs      |  451      |
>>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], tablefmt="pipe"))
|:-----|---------:|
| spam |  41.9999 |
| eggs | 451      |

“orgtbl” is like tables in Emacs org-mode and orgtbl-mode. It differs slightly from the “pipe” format: it does not use colons to mark column alignment, and it uses a “+” sign at line intersections:

>>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]],
...                ["strings", "numbers"], "orgtbl"))
| strings   |   numbers |
|-----------+-----------|
| spam      |   41.9999 |
| eggs      |  451      |
>>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], tablefmt="orgtbl"))
| spam |  41.9999 |
| eggs | 451      |
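
The difference between the “pipe” and “orgtbl” header rules can be sketched as below. `rule_line` is a hypothetical helper. Widths exclude the single space of padding on each side of a cell, and only left and right alignment are handled:

```python
def rule_line(widths, aligns, fmt):
    """Build the rule under the header row for 'pipe' or 'orgtbl' (sketch).

    widths: content width of each column, without the one-space padding.
    aligns: 'left' for left-aligned columns; anything else is treated
            as right-aligned in this sketch.
    """
    segments = []
    for width, align in zip(widths, aligns):
        span = width + 2  # the rule also covers the padding spaces
        if fmt == "pipe":
            # pipe marks alignment with a colon at the aligned edge.
            if align == "left":
                segments.append(":" + "-" * (span - 1))
            else:
                segments.append("-" * (span - 1) + ":")
        else:
            segments.append("-" * span)
    # orgtbl puts '+' at column intersections; pipe keeps '|'.
    joiner = "|" if fmt == "pipe" else "+"
    return "|" + joiner.join(segments) + "|"


print(rule_line([9, 9], ["left", "right"], "pipe"))
print(rule_line([9, 9], ["left", "right"], "orgtbl"))
```

With widths of 9 this reproduces the header rules of the “pipe” and “orgtbl” examples with headers.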

“rst” is like the simple table format of reStructuredText; note that reStructuredText also accepts “grid” tables:

>>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]],
...                ["strings", "numbers"], "rst"))
=========  =========
strings      numbers
=========  =========
spam         41.9999
eggs        451
=========  =========
>>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], tablefmt="rst"))
====  ========
spam   41.9999
eggs  451
====  ========

“mediawiki” produces a table markup used in Wikipedia and on other MediaWiki-based sites:

>>> print(tabulate([["strings", "numbers"], ["spam", 41.9999], ["eggs", "451.0"]],
...                headers="firstrow", tablefmt="mediawiki"))
{| class="wikitable" style="text-align: left;"
|+ <!-- caption -->
|-
! strings   !! align="right"|   numbers
|-
| spam      || align="right"|   41.9999
|-
| eggs      || align="right"|  451
|}

“latex” produces a tabular environment of LaTeX document markup:

>>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], tablefmt="latex"))
\begin{tabular}{lr}
\hline
 spam &  41.9999 \\
 eggs & 451      \\
\hline
\end{tabular}

machina.logger.simple_separated_format(separator)

Construct a simple TableFormat with columns separated by a separator.

>>> tsv = simple_separated_format("\t")
>>> tabulate([["foo", 1], ["spam", 23]], tablefmt=tsv) == 'foo \t 1\nspam\t23'
True
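
A separator-only table can be sketched as follows. The sketch assumes strings are left-aligned and numbers right-aligned, as in the doctest above, and produces no header or rule lines. `separated` is a hypothetical helper, not the TableFormat that simple_separated_format builds:

```python
def separated(rows, sep):
    """Pad every column to its widest cell, then join cells with sep (sketch)."""
    widths = [max(len(str(v)) for v in col) for col in zip(*rows)]
    lines = []
    for row in rows:
        cells = []
        for value, width in zip(row, widths):
            text = str(value)
            # Numbers flush right, everything else flush left.
            if isinstance(value, (int, float)):
                cells.append(text.rjust(width))
            else:
                cells.append(text.ljust(width))
        lines.append(sep.join(cells))
    return "\n".join(lines)


print(repr(separated([["foo", 1], ["spam", 23]], "\t")))  # 'foo \t 1\nspam\t23'
```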

machina.loss_functional module

This module provides loss functions; algorithms are written by combining them.

machina.loss_functional.ag(pol, qf, batch, sampling=1)

DDPG-style action gradient.

Parameters:
  • pol (Pol) –
  • qf (SAVfunction) –
  • batch (dict of torch.Tensor) –
  • sampling (int) – Number of samples used to estimate the expectation.
Returns:

pol_loss

Return type:

torch.Tensor
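
A DDPG-style actor loss is the negative critic value of the policy's own action, averaged over states. The sketch below makes three substitutions:

- plain Python callables stand in for the Pol and SAVfunction objects;
- a list of scalar states replaces a batch of tensors;
- the sampling parameter is ignored.

It shows the quantity being minimized, not machina's implementation:

```python
def ddpg_actor_loss(qf, pol, states):
    """Return -mean_s Q(s, pi(s)) (sketch with scalar states and actions).

    With autograd tensors, the gradient of this value flows through the
    action pi(s) into the policy parameters; plain floats only show the value.
    """
    return -sum(qf(s, pol(s)) for s in states) / len(states)


def toy_qf(state, action):
    # Hypothetical critic that prefers actions equal to the state.
    return -(action - state) ** 2


def matching_pol(state):
    return state


def zero_pol(state):
    return 0.0


states = [1.0, 2.0]
print(ddpg_actor_loss(toy_qf, zero_pol, states))  # 2.5
```

matching_pol reaches the minimum loss of 0 under this toy critic.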

machina.loss_functional.bellman(qf, targ_qf, targ_pol, batch, gamma, continuous=True, deterministic=True, sampling=1, reduction='elementwise_mean')

Bellman loss: the mean squared error between the left-hand and right-hand sides of the Bellman equation.

Parameters:
  • qf (SAVfunction) –
  • targ_qf (SAVfunction) –
  • targ_pol (Pol) –
  • batch (dict of torch.Tensor) –
  • gamma (float) –
  • continuous (bool) – Whether the action space is continuous.
  • sampling (int) – Number of samples used to estimate the expectation.
  • reduction (str) – Accepts only 'elementwise_mean', 'sum', or 'none'; the loss shape follows PyTorch's conventions.
Returns:

bellman_loss

Return type:

torch.Tensor
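
The one-step Bellman target and its squared error can be sketched with plain lists. The sketch assumes the batch provides rewards and done flags, and that the target networks' values at the next state are already computed. The argument names are illustrative, not machina's batch keys:

```python
def bellman_mse(q_values, rewards, next_target_q, dones, gamma):
    """Mean squared Bellman error (sketch).

    Target: r + gamma * (1 - done) * Q_targ(s', pi_targ(s')). The target is
    treated as a constant, as it would be when computed from target networks.
    """
    targets = [r + gamma * (1.0 - d) * nq
               for r, nq, d in zip(rewards, next_target_q, dones)]
    squared = [(q - t) ** 2 for q, t in zip(q_values, targets)]
    # Averaging corresponds to a mean reduction; 'sum' would skip the division.
    return sum(squared) / len(squared)


loss = bellman_mse(q_values=[1.0, 0.5], rewards=[1.0, 0.0],
                   next_target_q=[2.0, 3.0], dones=[0.0, 1.0], gamma=0.5)
print(loss)  # 0.625
```

The second transition is terminal, so its target is just its reward (0.0).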

machina.loss_functional.clipped_double_bellman(qf, targ_qf1, targ_qf2, batch, gamma, loss_type='bce')

Bellman loss for Clipped Double DQN: either the mean squared error or the binary cross entropy between the left-hand and right-hand sides of the Bellman equation.

Parameters:
  • qf (SAVfunction) –
  • targ_qf1 (SAVfunction) –
  • targ_qf2 (SAVfunction) –
  • batch (dict of torch.Tensor) –
  • gamma (float) –
  • loss_type (str) – Accepts only 'bce' or 'mse'; the loss shape follows PyTorch's conventions.
Returns:

ret

Return type:

torch.Tensor
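
Clipped double Q-learning builds the target from the minimum of two target estimates, which counters overestimation. A sketch with plain lists follows. The BCE branch assumes Q-values and targets lie in [0, 1], for example sigmoid outputs with 0/1 rewards. That is our reading of loss_type='bce', not something stated above:

```python
import math


def clipped_double_targets(rewards, next_q1, next_q2, dones, gamma):
    """r + gamma * (1 - done) * min(Q_targ1, Q_targ2) for each transition."""
    return [r + gamma * (1.0 - d) * min(a, b)
            for r, a, b, d in zip(rewards, next_q1, next_q2, dones)]


def td_loss(q_values, targets, loss_type="bce", eps=1e-7):
    """Mean MSE or BCE between predictions and targets (sketch)."""
    n = len(targets)
    if loss_type == "mse":
        return sum((q - t) ** 2 for q, t in zip(q_values, targets)) / n
    if loss_type == "bce":
        total = 0.0
        for q, t in zip(q_values, targets):
            q = min(max(q, eps), 1.0 - eps)  # keep both logs finite
            total -= t * math.log(q) + (1.0 - t) * math.log(1.0 - q)
        return total / n
    raise ValueError("loss_type must be 'bce' or 'mse'")


targets = clipped_double_targets([0.0], [0.8], [0.6], [0.0], gamma=0.5)
print(targets)                       # [0.3]
print(td_loss([0.5], [0.5], "bce"))  # ln 2, about 0.693
```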

machina.loss_functional.monte_carlo(vf, batch, clip_param=0.2, clip=False)

Monte Carlo loss for the state-value (V) function.

Parameters:
  • vf (SVfunction) –
  • batch (dict of torch.Tensor) –
  • clip_param (float) –
  • clip (bool) –
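
The docstring leaves the loss itself unspecified. One common form, assumed here, is half the mean squared error between value predictions and Monte Carlo returns. With clip=True it applies PPO-style value clipping:

- new predictions are kept within clip_param of the previous ones;
- the larger of the clipped and unclipped errors is used.

`old_values` is a hypothetical input standing in for those previous predictions:

```python
def monte_carlo_value_loss(values, returns, old_values=None,
                           clip_param=0.2, clip=False):
    """0.5 * mean squared error to the returns, optionally clipped (sketch)."""
    losses = []
    for i, (v, ret) in enumerate(zip(values, returns)):
        err = (v - ret) ** 2
        if clip:
            old = old_values[i]
            # Limit how far the new prediction may move from the old one.
            v_clipped = old + max(-clip_param, min(clip_param, v - old))
            # Pessimistic choice: keep the larger of the two errors.
            err = max(err, (v_clipped - ret) ** 2)
        losses.append(err)
    return 0.5 * sum(losses) / len(losses)


print(monte_carlo_value_loss([1.0], [2.0]))  # 0.5
print(monte_carlo_value_loss([1.0], [2.0], old_values=[0.0], clip=True))
```

In the clipped call the prediction may only move from 0.0 to 0.2, so the larger error (0.2 - 2.0)² is used.
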

machina.loss_functional.pg(pol, batch, ent_beta=0)

Policy Gradient.

Parameters:
  • pol (Pol) –
  • batch (dict of torch.Tensor) –
  • ent_beta (float) – entropy coefficient
Returns:

pol_loss

Return type:

torch.Tensor
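
A vanilla policy-gradient loss can be sketched as the negative mean of log-probability times advantage, minus an entropy bonus weighted by ent_beta. The sketch takes precomputed log-probabilities, advantages, and entropies as plain lists. These inputs are assumptions about what the batch provides, not machina's batch keys:

```python
def pg_loss(log_probs, advantages, entropies, ent_beta=0.0):
    """-mean(log pi(a|s) * A) - ent_beta * mean(entropy) (sketch)."""
    n = len(log_probs)
    # Minimizing this raises the probability of positive-advantage actions.
    pg_term = -sum(lp * adv for lp, adv in zip(log_probs, advantages)) / n
    # Subtracting the entropy bonus makes higher entropy lower the loss.
    ent_term = sum(entropies) / n
    return pg_term - ent_beta * ent_term


print(pg_loss([-1.0, -2.0], [1.0, -1.0], [1.0, 1.0]))  # -0.5
```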

machina.loss_functional.pg_clip(pol, batch, clip_param, ent_beta)

Policy Gradient with clipping.

Parameters:
  • pol (Pol) –
  • batch (dict of torch.Tensor) –
  • clip_param (float) –
  • ent_beta (float) – entropy coefficient
Returns:

pol_loss

Return type:

torch.Tensor
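
The clip_param argument suggests the standard PPO clipped surrogate objective. The sketch below implements that objective under this assumption, using precomputed log-probabilities under the new and old policies and an entropy bonus weighted by ent_beta. It is not taken from machina's source:

```python
import math


def clipped_surrogate_loss(new_log_probs, old_log_probs, advantages,
                           entropies, clip_param=0.2, ent_beta=0.0):
    """Negative PPO clipped surrogate minus an entropy bonus (sketch)."""
    n = len(advantages)
    total = 0.0
    for new_lp, old_lp, adv in zip(new_log_probs, old_log_probs, advantages):
        ratio = math.exp(new_lp - old_lp)  # pi_new(a|s) / pi_old(a|s)
        clipped = min(max(ratio, 1.0 - clip_param), 1.0 + clip_param)
        # Pessimistic bound: keep the smaller of the two surrogates.
        total += min(ratio * adv, clipped * adv)
    return -total / n - ent_beta * sum(entropies) / n


# The ratio is 2; with a positive advantage the clip caps it at 1.2.
print(clipped_surrogate_loss([math.log(2.0)], [0.0], [1.0], [0.0]))
```

With a negative advantage the unclipped term is the smaller one, so large harmful updates are still fully penalized.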

machina.loss_functional.pg_kl(pol, batch, kl_beta, ent_beta=0)

Policy Gradient with a KL divergence penalty.

Parameters:
  • pol (Pol) –
  • batch (dict of torch.Tensor) –
  • kl_beta (float) – KL divergence coefficient
  • ent_beta (float) – entropy coefficient
Returns:

pol_loss

Return type:

torch.Tensor
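
One common way to restrict the KL divergence is to add it to the loss as a penalty weighted by kl_beta. The sketch below assumes:

- categorical action distributions;
- the direction KL(old || new);
- precomputed probability ratios.

machina's exact form may differ:

```python
import math


def categorical_kl(p, q):
    """KL(p || q) for two discrete distributions given as probability lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def kl_penalized_pg_loss(ratios, advantages, old_dists, new_dists,
                         kl_beta, entropies=None, ent_beta=0.0):
    """-mean(ratio * A) + kl_beta * mean(KL) - ent_beta * mean(entropy)."""
    n = len(ratios)
    surrogate = sum(r * a for r, a in zip(ratios, advantages)) / n
    mean_kl = sum(categorical_kl(p, q) for p, q in zip(old_dists, new_dists)) / n
    mean_ent = sum(entropies) / n if entropies else 0.0
    return -surrogate + kl_beta * mean_kl - ent_beta * mean_ent


print(kl_penalized_pg_loss([1.0], [1.0], [[0.5, 0.5]], [[0.9, 0.1]], kl_beta=1.0))
```

Moving the new distribution away from the old one raises the loss by kl_beta times the divergence, here ln(5/3).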

machina.loss_functional.sac(pol, qfs, targ_qfs, log_alpha, batch, gamma, sampling=1, reparam=True, normalize=False, eps=1e-06)

Losses for Soft Actor-Critic (SAC).

Parameters:
  • pol (Pol) –
  • qfs (list of SAVfunction) –
  • targ_qfs (list of SAVfunction) –
  • log_alpha (torch.Tensor) –
  • batch (dict of torch.Tensor) –
  • gamma (float) –
  • sampling (int) – Number of samples used to estimate the expectation.
  • reparam (bool) – Whether to use the reparameterization trick.
  • normalize (bool) – If True, normalize value of log likelihood.
  • eps (float) –
Returns:

pol_loss, qf_loss, alpha_loss

Return type:

torch.Tensor, torch.Tensor, torch.Tensor
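
The quantities behind the three returned losses can be sketched per transition with plain floats, following the standard SAC formulation:

- alpha = exp(log_alpha);
- the soft Bellman target uses the minimum over the target Q-functions;
- the policy loss is alpha times log pi, minus the minimum Q.

target_entropy is an assumed hyperparameter that does not appear in the signature above. It is included only to illustrate the usual temperature loss:

```python
import math


def soft_q_target(reward, done, next_target_qs, next_log_prob, log_alpha, gamma):
    """r + gamma * (1 - done) * (min_i Q_targ_i(s', a') - alpha * log pi(a'|s'))."""
    alpha = math.exp(log_alpha)
    soft_value = min(next_target_qs) - alpha * next_log_prob
    return reward + gamma * (1.0 - done) * soft_value


def sac_policy_loss(qs, log_prob, log_alpha):
    """alpha * log pi(a|s) - min_i Q_i(s, a) for an action sampled from pi."""
    return math.exp(log_alpha) * log_prob - min(qs)


def sac_alpha_loss(log_alpha, log_prob, target_entropy):
    """Standard temperature loss; target_entropy is an assumed hyperparameter."""
    return -log_alpha * (log_prob + target_entropy)


print(soft_q_target(1.0, 0.0, [2.0, 3.0], -1.0, log_alpha=0.0, gamma=0.5))  # 2.5
print(sac_policy_loss([2.0, 3.0], -1.0, log_alpha=0.0))                      # -3.0
```

Under this formulation, each Q-function's loss is the squared error between its prediction and soft_q_target.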

machina.utils module

machina.utils.cpu_mode()
machina.utils.detach_tensor_dict(d)
machina.utils.get_device()
machina.utils.measure(name)
machina.utils.set_device(device)