Source code for machina.traj.traj_functional

"""
These are functions which is applied to trajectory.
"""

import torch

from machina import loss_functional as lf


[docs]def update_pris(traj, td_loss, indices, alpha=0.6, epsilon=1e-6): pris = (torch.abs(td_loss) + epsilon) ** alpha traj.data_map['pris'][indices] = pris.detach() return traj