torch.jit.script
torch.jit.script(obj, optimize=None, _frames_up=0, _rcb=None)

Scripting a function or nn.Module will inspect the source code, compile it as TorchScript code using the TorchScript compiler, and return a ScriptModule or ScriptFunction. TorchScript itself is a subset of the Python language, so not all features in Python work, but we provide enough functionality to compute on tensors and do control-dependent operations. For a complete guide, see the TorchScript Language Reference.

torch.jit.script can be used as a function for modules and functions, and as a decorator @torch.jit.script for TorchScript Classes and functions (a minimal class sketch follows the Returns section below).

Parameters
    obj (callable, class, or nn.Module) – The nn.Module, function, or class type to compile.

Returns
    If obj is an nn.Module, script returns a ScriptModule object. The returned ScriptModule will have the same set of sub-modules and parameters as the original nn.Module. If obj is a standalone function, a ScriptFunction will be returned.
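The decorator form also applies to TorchScript Classes. As a minimal sketch of that usage (the Pair and use_pair names below are illustrative, not part of the API), a scripted class can be constructed and used from other TorchScript code:

import torch

# Illustrative names only; any TorchScript class works the same way
@torch.jit.script
class Pair:
    def __init__(self, first: torch.Tensor, second: torch.Tensor):
        self.first = first
        self.second = second

    def sum(self) -> torch.Tensor:
        return self.first + self.second

@torch.jit.script
def use_pair(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # A scripted class can be instantiated and used inside scripted code
    p = Pair(x, y)
    return p.sum()

print(use_pair(torch.ones(2), torch.ones(2)))

Note that a TorchScript class is expected to assign all of its attributes in __init__ so their types can be inferred.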
Scripting a function

The @torch.jit.script decorator will construct a ScriptFunction by compiling the body of the function.

Example (scripting a function):
import torch

@torch.jit.script
def foo(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r

print(type(foo))  # torch.jit.ScriptFunction

# See the compiled graph as Python code
print(foo.code)

# Call the function using the TorchScript interpreter
foo(torch.ones(2, 2), torch.ones(2, 2))
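The decorator is not required; torch.jit.script can also be called on an existing function. A brief sketch (bar is an illustrative name, not from the example above):

import torch

def bar(x, y):
    # An ordinary Python function, defined without the decorator
    return 2 * x + y

# Compile by passing the function to torch.jit.script; bar itself is unchanged
scripted_bar = torch.jit.script(bar)
print(scripted_bar(torch.ones(2, 2), torch.ones(2, 2)))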
Scripting an nn.Module

Scripting an nn.Module by default will compile the forward method and recursively compile any methods, submodules, and functions called by forward. If an nn.Module only uses features supported in TorchScript, no changes to the original module code should be necessary. script will construct a ScriptModule that has copies of the attributes, parameters, and methods of the original module.

Example (scripting a simple module with a Parameter):
import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        # This parameter will be copied to the new ScriptModule
        self.weight = torch.nn.Parameter(torch.rand(N, M))
        # When this submodule is used, it will be compiled
        self.linear = torch.nn.Linear(N, M)

    def forward(self, input):
        output = self.weight.mv(input)
        # This calls the `forward` method of the `nn.Linear` module, which will
        # cause the `self.linear` submodule to be compiled to a `ScriptModule` here
        output = self.linear(output)
        return output

scripted_module = torch.jit.script(MyModule(2, 3))
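As with scripted functions, the compiled module can be called like the original and its TorchScript inspected through the code property. A short usage sketch (the input size of 3 matches the weight created by MyModule(2, 3) above):

# Run the compiled module and print the TorchScript for its forward method
print(scripted_module(torch.randn(3)))
print(scripted_module.code)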
Example (scripting a module with traced submodules):
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        # torch.jit.trace produces ScriptModules for conv1 and conv2
        self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
        self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))

    def forward(self, input):
        input = F.relu(self.conv1(input))
        input = F.relu(self.conv2(input))
        return input

scripted_module = torch.jit.script(MyModule())
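The mixed module runs like any other compiled module; a brief sketch using an input shaped like the example inputs passed to torch.jit.trace above:

# The traced convolutions and the scripted forward execute together
output = scripted_module(torch.rand(1, 1, 16, 16))
print(output.shape)  # the two 5x5 convolutions reduce 16x16 to 8x8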
To compile a method other than forward (and recursively compile anything it calls), add the @torch.jit.export decorator to the method. To opt out of compilation use @torch.jit.ignore or @torch.jit.unused (a sketch of @torch.jit.unused follows the example below).

Example (an exported and ignored method in a module):
import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()

    @torch.jit.export
    def some_entry_point(self, input):
        return input + 10

    @torch.jit.ignore
    def python_only_fn(self, input):
        # This function won't be compiled, so any
        # Python APIs can be used
        import pdb
        pdb.set_trace()

    def forward(self, input):
        if self.training:
            self.python_only_fn(input)
        return input * 99

scripted_module = torch.jit.script(MyModule())
print(scripted_module.some_entry_point(torch.randn(2, 2)))
print(scripted_module(torch.randn(2, 2)))
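The example above does not show @torch.jit.unused. As a rough sketch (the module and attribute names are illustrative), it keeps an uncompilable method on the module and replaces its body with code that raises an error if it is ever reached from compiled code, so the rest of the module can still be compiled:

import torch
import torch.nn as nn

class MyUnusedModule(nn.Module):
    def __init__(self, use_python_branch):
        super(MyUnusedModule, self).__init__()
        self.use_python_branch = use_python_branch

    @torch.jit.unused
    def python_branch(self, input):
        # Relies on Python features TorchScript can't compile; the compiled
        # module raises an error if this method is ever called
        import pdb
        pdb.set_trace()
        return input + 10

    def forward(self, input):
        if self.use_python_branch:
            return self.python_branch(input)
        else:
            return input + 10

scripted = torch.jit.script(MyUnusedModule(use_python_branch=False))
print(scripted(torch.randn(2, 2)))  # fine: the unused branch is never taken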