# Source code for torch.autograd.grad_mode
import torch
import functools
import inspect
class _DecoratorContextManager:
"""Allow a context manager to be used as a decorator"""
def __call__(self, func):
if inspect.isgeneratorfunction(func):
return self._wrap_generator(func)
@functools.wraps(func)
def decorate_context(*args, **kwargs):
with self:
return func(*args, **kwargs)
return decorate_context
def _wrap_generator(self, func):
"""Wrap each generator invocation with the context manager"""
@functools.wraps(func)
def generator_context(*args, **kwargs):
gen = func(*args, **kwargs)
while True:
try:
with self:
x = next(gen)
yield x
except StopIteration:
break
return generator_context
class no_grad(_DecoratorContextManager):
    r"""Context-manager that disables gradient calculation.

    Turning gradient calculation off is useful for inference, when you are
    certain that :meth:`Tensor.backward()` will not be called. It reduces
    memory consumption for computations that would otherwise have
    `requires_grad=True`.

    While this mode is active, the result of every computation has
    `requires_grad=False`, even when the inputs have `requires_grad=True`.

    This context manager is thread local; it will not affect computation
    in other threads.

    Also functions as a decorator. (Make sure to instantiate with
    parentheses.)

    Example::

        >>> x = torch.tensor([1.], requires_grad=True)
        >>> with torch.no_grad():
        ...     y = x * 2
        >>> y.requires_grad
        False
        >>> @torch.no_grad()
        ... def doubler(x):
        ...     return x * 2
        >>> z = doubler(x)
        >>> z.requires_grad
        False
    """

    def __enter__(self):
        # Record the mode in force on entry so that __exit__ can restore it.
        grad_was_enabled = torch.is_grad_enabled()
        self.prev = grad_was_enabled
        torch._C.set_grad_enabled(False)

    def __exit__(self, *exc_details):
        # Put back whatever mode was active before the block was entered.
        torch.set_grad_enabled(self.prev)
class enable_grad(_DecoratorContextManager):
    r"""Context-manager that enables gradient calculation.

    Turns gradient calculation back on if it has been disabled via
    :class:`~no_grad` or :class:`~set_grad_enabled`.

    This context manager is thread local; it will not affect computation
    in other threads.

    Also functions as a decorator. (Make sure to instantiate with
    parentheses.)

    Example::

        >>> x = torch.tensor([1.], requires_grad=True)
        >>> with torch.no_grad():
        ...     with torch.enable_grad():
        ...         y = x * 2
        >>> y.requires_grad
        True
        >>> y.backward()
        >>> x.grad
        tensor([2.])
        >>> @torch.enable_grad()
        ... def doubler(x):
        ...     return x * 2
        >>> with torch.no_grad():
        ...     z = doubler(x)
        >>> z.requires_grad
        True
    """

    def __enter__(self):
        # Record the mode in force on entry so that __exit__ can restore it.
        grad_was_enabled = torch.is_grad_enabled()
        self.prev = grad_was_enabled
        torch._C.set_grad_enabled(True)

    def __exit__(self, *exc_details):
        # Put back whatever mode was active before the block was entered.
        torch.set_grad_enabled(self.prev)
class set_grad_enabled(object):
    r"""Context-manager that sets gradient calculation to on or off.

    ``set_grad_enabled`` enables or disables grads according to its argument
    :attr:`mode`. It can be used as a context-manager or as a function.

    This context manager is thread local; it will not affect computation
    in other threads.

    Arguments:
        mode (bool): Flag whether to enable grad (``True``), or disable
                     (``False``). This can be used to conditionally enable
                     gradients.

    Example::

        >>> x = torch.tensor([1.], requires_grad=True)
        >>> is_train = False
        >>> with torch.set_grad_enabled(is_train):
        ...     y = x * 2
        >>> y.requires_grad
        False
        >>> _ = torch.set_grad_enabled(True)
        >>> y = x * 2
        >>> y.requires_grad
        True
        >>> _ = torch.set_grad_enabled(False)
        >>> y = x * 2
        >>> y.requires_grad
        False
    """

    def __init__(self, mode):
        # The switch is applied at construction time; that is what makes a
        # bare ``set_grad_enabled(mode)`` call usable as a plain function.
        grad_was_enabled = torch.is_grad_enabled()
        self.prev = grad_was_enabled
        torch._C.set_grad_enabled(mode)

    def __enter__(self):
        # Nothing left to do here: __init__ has already applied the mode.
        return None

    def __exit__(self, *exc_details):
        # Restore the mode that was active when this object was created.
        torch.set_grad_enabled(self.prev)