set_grad_enabled
class torch.set_grad_enabled(mode)
Context manager that sets gradient calculation on or off.
set_grad_enabled enables or disables gradient calculation based on its argument mode. It can be used as a context manager or as a function.
This context manager is thread-local; it does not affect computation in other threads.
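The thread-local behavior described above can be checked with a small experiment. This is only a sketch: the helper name worker and the dictionary seen are made up for illustration, and torch.is_grad_enabled() is used to read the current mode in each thread.

```python
import threading
import torch

def worker(seen):
    # Hypothetical helper: disable gradients inside this thread only.
    torch.set_grad_enabled(False)
    seen["worker"] = torch.is_grad_enabled()

seen = {}
with torch.set_grad_enabled(True):
    t = threading.Thread(target=worker, args=(seen,))
    t.start()
    t.join()
    # The worker's change should not leak into the calling thread.
    seen["main"] = torch.is_grad_enabled()
```

After this runs, seen["worker"] should be False while seen["main"] stays True, since each thread keeps its own gradient mode.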
Parameters
mode (bool) – Flag indicating whether to enable gradients (True) or disable them (False). This can be used to conditionally enable gradients.
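One plausible use of the conditional form is sharing a single forward function between training and evaluation. The sketch below assumes a hypothetical function forward and toy tensors x and w; none of these names come from the library.

```python
import torch

def forward(x, w, is_train):
    # Hypothetical helper: gradients are tracked only when is_train is True.
    with torch.set_grad_enabled(is_train):
        return (x * w).sum()

w = torch.tensor([1.0, 2.0], requires_grad=True)
x = torch.tensor([3.0, 4.0])

train_out = forward(x, w, is_train=True)   # part of the autograd graph
eval_out = forward(x, w, is_train=False)   # detached from the graph
```

Here train_out.requires_grad is expected to be True and eval_out.requires_grad False, so the same code path serves both modes without a separate torch.no_grad() branch.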
Example:
>>> x = torch.tensor([1.], requires_grad=True)
>>> is_train = False
>>> with torch.set_grad_enabled(is_train):
...     y = x * 2
>>> y.requires_grad
False
>>> _ = torch.set_grad_enabled(True)
>>> y = x * 2
>>> y.requires_grad
True
>>> _ = torch.set_grad_enabled(False)
>>> y = x * 2
>>> y.requires_grad
False
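When set_grad_enabled is called as a function, the new mode stays in effect for the rest of the thread until it is changed again. If that is not wanted, one option is to record the current mode first and restore it afterwards. The following is a minimal sketch of that pattern, not a prescribed idiom; the variable name prev is illustrative.

```python
import torch

x = torch.tensor([1.0], requires_grad=True)

prev = torch.is_grad_enabled()    # remember the current mode
torch.set_grad_enabled(False)     # function form: stays in effect
try:
    y = x * 2                     # computed without gradient tracking
finally:
    torch.set_grad_enabled(prev)  # restore whatever was set before

z = x * 2                         # follows the restored mode
```

In most cases the context-manager form achieves the same result with less bookkeeping, because it restores the previous mode automatically on exit.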