# Source code for torch.jit._script
"""TorchScript
This module contains functionality to support the JIT's scripting frontend, notably:
- torch.jit.script
This is not intended to be imported directly; please use the exposed
functionalities in `torch.jit`.
"""
import torch
import functools
import collections
from torch.jit._recursive import ScriptMethodStub
from torch.nn import Module
from torch.jit._state import _enabled
from torch._six import with_metaclass
# `Attribute` bundles a value with an explicit type.
#
# * When `_enabled` is truthy it is a namedtuple with the fields `value` and
#   `type`, so instances can be recognised with `isinstance(x, Attribute)`.
# * Otherwise it is a plain function that hands the value straight back and
#   ignores the type.
if not _enabled:

    def Attribute(value, type):
        """Return ``value`` unchanged; ``type`` is accepted but not used."""
        return value

else:
    Attribute = collections.namedtuple("Attribute", ["value", "type"])
# For each user-defined class that subclasses ScriptModule, this meta-class:
# (1) finds all the methods annotated with @script_method in a ScriptModule and
# removes them from the class attributes
# (2) puts a wrapper around the class's __init__ method to recursively compile
# all of the script_methods with the module after the original __init__ has
# run. This has to occur after the user-defined __init__ so that submodules and
# parameters are initialized _before_ the script compiler resolves references to
# `self.param` or `self.module`.
class ScriptMeta(type):
    """Metaclass that gathers script-method stubs and defers their compilation.

    When a class using this metaclass is created, ``__init__`` below:

    * builds ``cls._methods`` from the ``_methods`` tables of the bases
      (visited in reverse order, so an earlier base overrides a later one)
      and then from every ``ScriptMethodStub`` found in the class body; each
      such stub is removed from the class and stored under
      ``stub.original_method.__name__``;
    * builds ``cls._constants_set`` from the class's ``__constants__`` and the
      ``_constants_set`` of every base;
    * unless ``_disable_script_meta`` is truthy on the class, replaces
      ``cls.__init__`` with a wrapper that first runs the original
      ``__init__`` and then builds the scripted module.
    """

    def __init__(cls, name, bases, attrs):  # noqa: B902
        # Seed the tables with this class's own declarations, then merge in
        # what the base classes have already collected.
        cls._methods = {}
        cls._constants_set = set(getattr(cls, "__constants__", ()))
        for parent in reversed(bases):
            cls._methods.update(getattr(parent, "_methods", {}).items())
            inherited_constants = getattr(parent, "_constants_set", set())
            cls._constants_set = cls._constants_set.union(inherited_constants)

        # Move the stubs declared directly on this class out of the class
        # namespace and into the method table.
        for attr_name, attr_value in sorted(attrs.items()):
            if not isinstance(attr_value, ScriptMethodStub):
                continue
            delattr(cls, attr_name)
            cls._methods[attr_value.original_method.__name__] = attr_value

        # Built-in ScriptModule types set `_disable_script_meta` and are left
        # alone; only user classes inheriting from ScriptModule get the
        # compiling __init__ wrapper.
        if not getattr(cls, "_disable_script_meta", False):
            user_init = getattr(cls, "__init__", lambda self: None)

            @functools.wraps(user_init)
            def init_then_script(self, *args, **kwargs):
                user_init(self, *args, **kwargs)

                # Wrappers installed on base classes also run (through
                # super().__init__ chains); only the wrapper belonging to the
                # instance's exact class goes on to compile.
                if type(self) is not cls:
                    return

                def make_stubs(module):
                    # Stubs of the module's class, ordered by method name.
                    stub_table = type(module)._methods
                    return [stub_table[key] for key in sorted(stub_table)]

                scripted = torch.jit._recursive.create_script_module(
                    self, make_stubs
                )
                # Written through __dict__ directly rather than with a normal
                # attribute assignment.
                self.__dict__["_actual_script_module"] = scripted

                # Delete the Python attributes that now shadow the
                # ScriptModule ones, so that __getattr__ and __setattr__ will
                # properly find the scripted versions.
                concrete_type = scripted._concrete_type
                for attribute_name in concrete_type.get_attributes():
                    delattr(self, attribute_name)
                for submodule_name, _ in concrete_type.get_modules():
                    delattr(self, submodule_name)
                for container_name in ("_parameters", "_buffers", "_modules"):
                    delattr(self, container_name)

            cls.__init__ = init_then_script

        return super(ScriptMeta, cls).__init__(name, bases, attrs)
class _CachedForward(object):
    """Non-data descriptor that resolves ``forward`` via the owner's ``__getattr__``.

    Installed as a class attribute named ``forward``. Reading that attribute
    on an instance returns ``instance.__getattr__("forward")``.

    The previous implementation called ``self.__getattr__`` where ``self`` is
    the descriptor itself. This class defines no ``__getattr__``, so that call
    always raised ``AttributeError``, and the lookup only succeeded because
    Python then falls back to the owner's ``__getattr__``. Delegating to
    ``obj`` directly produces the same result without relying on that
    fallback.
    """

    def __get__(self, obj, cls):
        """Return ``obj.__getattr__("forward")``.

        Args:
            obj: the instance the attribute is read from, or ``None`` when it
                is read from the class.
            cls: the owner class.

        Raises:
            AttributeError: on class-level access (``obj is None``), as
                before, so ``hasattr(cls, "forward")`` stays false; also
                whatever ``AttributeError`` the owner's ``__getattr__``
                raises (or the one raised if the owner has none).
        """
        if obj is None:
            raise AttributeError(
                "type object {!r} has no attribute 'forward'".format(
                    getattr(cls, "__name__", cls)
                )
            )
        return obj.__getattr__("forward")
if _enabled:
    # nn.Module defines `forward` as a method. If we did nothing, looking up
    # `forward` on a ScriptModule would find that method and __getattr__ would
    # never be called; we'd get nn.Module.forward, which always throws an
    # exception. ScriptModule therefore binds the name `forward` to the
    # `_CachedForward` non-data descriptor, so the lookup ends up in
    # `ScriptModule.__getattr__` and reaches the scripted method.
    class ScriptModule(with_metaclass(ScriptMeta, Module)):
        """
        ``ScriptModule``s wrap a C++ ``torch::jit::Module``. ``ScriptModule``s
        contain methods, attributes, parameters, and
        constants. These can be accessed the same as on a normal ``nn.Module``.
        """

        def __init__(self):
            super(ScriptModule, self).__init__()

        forward = _CachedForward()

        def __getattr__(self, attr):
            """Resolve ``attr`` on the backing script module once it exists.

            Before ``_actual_script_module`` has been stored in the instance
            ``__dict__``, the lookup is handled by the parent class.
            """
            if "_actual_script_module" in self.__dict__:
                return getattr(self._actual_script_module, attr)
            return super(ScriptModule, self).__getattr__(attr)

        def __setattr__(self, attr, value):
            """Set ``attr``, forwarding to the backing script module once it exists.

            Before ``_actual_script_module`` exists, a value wrapped in
            ``Attribute`` is unwrapped: its ``type`` is recorded in the
            class's ``__annotations__`` and its ``value`` is assigned through
            the parent class.
            """
            if "_actual_script_module" in self.__dict__:
                setattr(self._actual_script_module, attr, value)
                return None

            # Unwrap torch.jit.Attribute into a regular setattr + recording
            # the provided type in __annotations__.
            #
            # This ensures that if we use the attr again in `__init__`, it
            # will look like the actual value, not an instance of Attribute.
            if isinstance(value, Attribute):
                # NB: Ensure that we set __annotations__ on the specific
                # class in question, and not on a superclass (which would
                # be wrong).
                # See also https://github.com/pytorch/pytorch/issues/39463
                owner = self.__class__
                if "__annotations__" not in owner.__dict__:
                    owner.__annotations__ = {}
                owner.__annotations__[attr] = value.type
                value = value.value
            return super(ScriptModule, self).__setattr__(attr, value)

        def define(self, src):
            """Add the method described by the source string ``src``.

            After initialization this defers to the backing script module's
            ``define`` and returns its result. During ``__init__`` the source
            is parsed and recorded as a ``ScriptMethodStub`` in
            ``self._methods`` to be compiled later; nothing is returned.
            """
            if "_actual_script_module" in self.__dict__:
                # If we have completed initialization, just defer to the
                # backing RecursiveScriptModule to eagerly compile the
                # provided source.
                return self._actual_script_module.define(src)

            # Otherwise, we are still in the object's __init__.
            # In that case, add `src` as a stub to be compiled.
            #
            # We use frames_up=1 to get to the proper surrounding scope. The
            # stack will look like:
            # 0. createResolutionCallback
            # 1. define()
            # 2. surrounding scope.
            #
            # createResolutionCallback internally adds 1 to get us to our
            # frame, then we add 1 to get to the proper surrounding scope.
            # This call must therefore stay directly inside define().
            rcb = torch._jit_internal.createResolutionCallbackFromFrame(frames_up=1)
            ast = torch._C._parse_source_def(src)
            # NOTE(review): no instance-level `_methods` is assigned in the
            # visible code, so this presumably writes into the table that
            # ScriptMeta stored on the class -- confirm that is intended.
            self._methods[ast.name().name] = ScriptMethodStub(rcb, ast, None)

        def _replicate_for_data_parallel(self):
            """Delegate replication to the backing script module."""
            return self._actual_script_module._replicate_for_data_parallel()


else:
    # TODO MAKE SURE THAT DISABLING WORKS
    # Used when `_enabled` is falsy: a plain nn.Module subclass whose
    # __init__ only chains to the parent initializer.
    class ScriptModule(torch.nn.Module):
        def __init__(self):
            super(ScriptModule, self).__init__()