Source code for machina.prepro.base

import numpy as np
import torch


class BasePrePro(object):
    """
    Preprocess for observations.

    When normalization is enabled, observations are standardized with
    exponentially weighted running estimates of their mean and variance
    and then clipped to a fixed symmetric range.

    Parameters
    ----------
    ob_space : gym.Space
        Observation space. Its ``shape`` determines the shape of the
        running statistics; it is only read when ``normalize_ob`` is True.
    normalize_ob : bool
        Whether observations are normalized. If False, observations are
        returned unchanged and no running statistics are allocated.

    Attributes
    ----------
    ob_rm : numpy.ndarray
        Running mean (only set when ``normalize_ob`` is True).
    ob_rv : numpy.ndarray
        Running variance (only set when ``normalize_ob`` is True).
    alpha : float
        Weight given to the newest observation when updating the running
        statistics (only set when ``normalize_ob`` is True).
    """

    # Normalized observations are clipped to [-_CLIP, _CLIP].
    _CLIP = 5
    # Added to the running standard deviation to avoid division by zero.
    _EPS = 1e-8

    def __init__(self, ob_space, normalize_ob=True):
        self.ob_space = ob_space
        self.normalize_ob = normalize_ob
        if self.normalize_ob:
            # Start from zero mean / unit variance so the first
            # observations pass through almost unscaled.
            self.ob_rm = np.zeros(self.ob_space.shape)
            self.ob_rv = np.ones(self.ob_space.shape)
            self.alpha = 0.001

    def update_ob_rms(self, ob):
        """
        Updating running mean and running variance.

        Both statistics are exponential moving averages weighted by
        ``self.alpha``.

        Parameters
        ----------
        ob : array-like
            New observation; must broadcast against ``self.ob_rm``.

        Raises
        ------
        AttributeError
            If the instance was created with ``normalize_ob=False``, since
            the running statistics do not exist in that case.
        """
        keep = 1 - self.alpha
        self.ob_rm = self.ob_rm * keep + self.alpha * ob
        # The squared deviation is deliberately taken against the mean
        # that was just updated on the line above.
        self.ob_rv = self.ob_rv * keep + \
            self.alpha * np.square(ob - self.ob_rm)

    def _normalize(self, ob):
        """Standardize ``ob`` with the running statistics and clip it."""
        scaled = (ob - self.ob_rm) / (np.sqrt(self.ob_rv) + self._EPS)
        return np.clip(scaled, -self._CLIP, self._CLIP)

    def prepro(self, ob):
        """
        Applying preprocess to observations.

        The running statistics are left untouched.

        Parameters
        ----------
        ob : array-like
            Observation to preprocess.

        Returns
        -------
        ob : array-like
            The normalized and clipped observation, or the input object
            itself when normalization is disabled.
        """
        if not self.normalize_ob:
            return ob
        return self._normalize(ob)

    def prepro_with_update(self, ob):
        """
        Applying preprocess to observations with update.

        The running statistics are updated with ``ob`` first, so the
        returned value is normalized with statistics that already include
        this observation.

        Parameters
        ----------
        ob : array-like
            Observation to preprocess.

        Returns
        -------
        ob : array-like
            The normalized and clipped observation, or the input object
            itself when normalization is disabled.
        """
        if not self.normalize_ob:
            return ob
        self.update_ob_rms(ob)
        return self._normalize(ob)