Source code for machina.utils

import contextlib

import torch
import torch.autograd as autograd

from machina import logger

_DEVICE = torch.device('cpu')


[docs]def set_device(device): global _DEVICE _DEVICE = device
[docs]def get_device(): return _DEVICE
[docs]@contextlib.contextmanager def cpu_mode(): global _DEVICE tmp = _DEVICE _DEVICE = torch.device('cpu') yield _DEVICE = tmp
[docs]@contextlib.contextmanager def measure(name): import time s = time.time() yield e = time.time() logger.log("{}: {:.4f}sec".format(name, e-s))
[docs]def detach_tensor_dict(d): _d = dict() for key in d.keys(): if d[key] is None: continue if isinstance(d[key], tuple): _d[key] = (d[key][0].detach(), d[key][1].detach()) continue _d[key] = d[key].detach() return _d