Shortcuts

Source code for torch.quantization.fuse_modules

from __future__ import absolute_import, division, print_function, unicode_literals

import torch
import copy

import torch.nn.intrinsic.modules.fused as torch_fused

def fuse_conv_bn(conv, bn):
    r"""Given the conv and bn modules, fuses them and returns the fused module

    In eval mode the pair is handed to ``torch.nn.utils.fuse_conv_bn_eval``.
    In training mode the pair is wrapped in the matching intrinsic container
    (``ConvBn2d`` or ``ConvBn3d``).

    Args:
        conv: Module instance of type conv2d/conv3d
        bn: Spatial BN instance that needs to be fused with the conv

    Returns:
        The fused module.

    Raises:
        AssertionError: if ``conv`` and ``bn`` are not in the same mode, or,
            in training mode, if ``bn.num_features`` does not match
            ``conv.out_channels``, ``bn.affine`` is false, or
            ``bn.track_running_stats`` is false.
        NotImplementedError: in training mode, if ``conv`` is neither a
            Conv2d nor a Conv3d.

    Examples::

        >>> m1 = nn.Conv2d(10, 20, 3)
        >>> b1 = nn.BatchNorm2d(20)
        >>> m2 = fuse_conv_bn(m1, b1)
    """
    assert(conv.training == bn.training),\
        "Conv and BN both must be in the same mode (train or eval)."

    if not conv.training:
        return torch.nn.utils.fuse_conv_bn_eval(conv, bn)

    assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm'
    assert bn.affine, 'Only support fusing BatchNorm with affine set to True'
    assert bn.track_running_stats, 'Only support fusing BatchNorm with tracking_running_stats set to True'

    # isinstance (not an exact type match) keeps the original routing for
    # subclasses; the intrinsic containers do their own type validation.
    if isinstance(conv, torch.nn.Conv3d):
        return torch.nn.intrinsic.ConvBn3d(conv, bn)
    if isinstance(conv, torch.nn.Conv2d):
        return torch.nn.intrinsic.ConvBn2d(conv, bn)

    # Anything else (e.g. Conv1d) used to be passed to ConvBn2d by mistake;
    # report it the same way fuse_conv_bn_relu reports unsupported types.
    raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn)))

def fuse_conv_bn_relu(conv, bn, relu):
    r"""Given the conv, bn and relu modules, fuses them and returns the fused module

    In training mode the three modules are wrapped in ``ConvBnReLU2d`` or
    ``ConvBnReLU3d``. In eval mode ``conv`` and ``bn`` are first passed to
    ``torch.nn.utils.fusion.fuse_conv_bn_eval`` and the result is wrapped
    together with ``relu`` in ``ConvReLU1d``/``ConvReLU2d``/``ConvReLU3d``.

    Args:
        conv: Module instance of type conv2d/conv3d (conv1d is also accepted
            in eval mode)
        bn: Spatial BN instance that needs to be fused with the conv
        relu: ReLU instance that follows the BN

    Returns:
        The fused module.

    Raises:
        AssertionError: if the three modules are not all in the same mode,
            or, in training mode, if ``bn.num_features`` does not match
            ``conv.out_channels``, ``bn.affine`` is false, or
            ``bn.track_running_stats`` is false.
        NotImplementedError: if ``type(conv)`` has no fused counterpart for
            the current mode.

    Examples::

        >>> m1 = nn.Conv2d(10, 20, 3)
        >>> b1 = nn.BatchNorm2d(20)
        >>> r1 = nn.ReLU()
        >>> m2 = fuse_conv_bn_relu(m1, b1, r1)
    """
    assert(conv.training == bn.training == relu.training),\
        "Conv, BN and ReLU must all be in the same mode (train or eval)."

    if conv.training:
        map_to_fused_module_train = {
            torch.nn.Conv2d: torch_fused.ConvBnReLU2d,
            torch.nn.Conv3d: torch_fused.ConvBnReLU3d,
        }
        assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm'
        assert bn.affine, 'Only support fusing BatchNorm with affine set to True'
        assert bn.track_running_stats, 'Only support fusing BatchNorm with tracking_running_stats set to True'
        fused_module = map_to_fused_module_train.get(type(conv))
        if fused_module is None:
            raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn, relu)))
        return fused_module(conv, bn, relu)

    map_to_fused_module_eval = {
        torch.nn.Conv1d: torch_fused.ConvReLU1d,
        torch.nn.Conv2d: torch_fused.ConvReLU2d,
        torch.nn.Conv3d: torch_fused.ConvReLU3d,
    }
    # Use .get() so an unsupported conv type reaches the NotImplementedError
    # below instead of escaping as a bare KeyError.
    fused_module = map_to_fused_module_eval.get(type(conv))
    if fused_module is None:
        raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu)))
    return fused_module(torch.nn.utils.fusion.fuse_conv_bn_eval(conv, bn), relu)

# Generalization of getattr
def _get_module(model, submodule_key):
    tokens = submodule_key.split('.')
    cur_mod = model
    for s in tokens:
        cur_mod = getattr(cur_mod, s)
    return cur_mod

# Generalization of setattr
def _set_module(model, submodule_key, module):
    tokens = submodule_key.split('.')
    sub_tokens = tokens[:-1]
    cur_mod = model
    for s in sub_tokens:
        cur_mod = getattr(cur_mod, s)

    setattr(cur_mod, tokens[-1], module)

def fuse_known_modules(mod_list):
    r"""Returns a list of modules that fuses the operations specified
     in the input module list.

    Fuses only the following sequence of modules:
    conv, bn
    conv, bn, relu
    conv, relu
    linear, relu
    bn, relu
    The sequence is matched on the exact types of the modules in ``mod_list``.
    For these sequences, the first element in the output module list performs
    the fused operation. The rest of the elements are set to nn.Identity()

    Args:
        mod_list: Modules to fuse, in the order they are applied

    Returns:
        A list of the same length as ``mod_list``: the fused module followed
        by nn.Identity() placeholders.

    Raises:
        NotImplementedError: if the sequence of types is not listed above.
    """
    nn = torch.nn
    nni = torch.nn.intrinsic

    OP_LIST_TO_FUSER_METHOD = {
        (nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn,
        (nn.Conv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
        (nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn,
        (nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu,
        (nn.Conv3d, nn.BatchNorm3d): fuse_conv_bn,
        (nn.Conv3d, nn.BatchNorm3d, nn.ReLU): fuse_conv_bn_relu,
        (nn.Conv1d, nn.ReLU): nni.ConvReLU1d,
        (nn.Conv2d, nn.ReLU): nni.ConvReLU2d,
        (nn.Conv3d, nn.ReLU): nni.ConvReLU3d,
        (nn.Linear, nn.ReLU): nni.LinearReLU,
        (nn.BatchNorm2d, nn.ReLU): nni.BNReLU2d,
        (nn.BatchNorm3d, nn.ReLU): nni.BNReLU3d,
    }

    types = tuple(type(m) for m in mod_list)
    fuser_method = OP_LIST_TO_FUSER_METHOD.get(types, None)
    if fuser_method is None:
        raise NotImplementedError("Cannot fuse modules: {}".format(types))

    fused_list = [fuser_method(*mod_list)]

    # Pad with identities so every original slot still has a module; each
    # placeholder copies the train/eval flag of the first input module.
    for _ in range(len(mod_list) - 1):
        placeholder = torch.nn.Identity()
        placeholder.training = mod_list[0].training
        fused_list.append(placeholder)

    return fused_list

def _fuse_modules(model, modules_to_fuse, fuser_func=fuse_known_modules):
    r"""Fuses one group of named submodules of ``model`` in place.

    Args:
        model: Model containing the modules to be fused
        modules_to_fuse: Dotted names of the submodules forming one group
        fuser_func: Function that takes the list of looked-up modules and
            returns the list of replacement modules, indexed the same way
    """
    # Look up the current modules by name
    originals = [_get_module(model, name) for name in modules_to_fuse]

    # Fuse list of modules
    replacements = fuser_func(originals)

    # Write each replacement back under the name of the module it replaces
    for position, name in enumerate(modules_to_fuse):
        _set_module(model, name, replacements[position])

def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules):
    r"""Fuses a list of modules into a single module

    Fuses only the following sequence of modules:
    conv, bn
    conv, bn, relu
    conv, relu
    linear, relu
    bn, relu
    With the default ``fuser_func``, any other sequence raises
    NotImplementedError.
    For these sequences, replaces the first item in the list
    with the fused module, replacing the rest of the modules
    with identity.

    Arguments:
        model: Model containing the modules to be fused
        modules_to_fuse: list of list of module names to fuse. Can also be a list
                         of strings if there is only a single list of modules to fuse.
                         An empty list fuses nothing.
        inplace: bool specifying if fusion happens in place on the model, by default
                 a new model is returned
        fuser_func: Function that takes in a list of modules and outputs a list of fused modules
                    of the same length. For example,
                    fuser_func([convModule, BNModule]) returns the list [ConvBNModule, nn.Identity()]
                    Defaults to torch.quantization.fuse_known_modules

    Returns:
        model with fused modules. A new copy is created if inplace=False.

    Examples::

            >>> m = myModel()
            >>> # m is a module containing the sub-modules below
            >>> modules_to_fuse = [ ['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']]
            >>> fused_m = torch.quantization.fuse_modules(m, modules_to_fuse)
            >>> output = fused_m(input)

            >>> m = myModel()
            >>> # Alternately provide a single list of modules to fuse
            >>> modules_to_fuse = ['conv1', 'bn1', 'relu1']
            >>> fused_m = torch.quantization.fuse_modules(m, modules_to_fuse)
            >>> output = fused_m(input)

    """
    if not inplace:
        model = copy.deepcopy(model)

    # all() over an empty list is True, which would send an empty request down
    # the single-group path and ask the fuser to fuse zero modules; treat it
    # as "nothing to do" instead.
    if not modules_to_fuse:
        return model

    if all(isinstance(module_element, str) for module_element in modules_to_fuse):
        # Handle case of modules_to_fuse being a list
        _fuse_modules(model, modules_to_fuse, fuser_func)
    else:
        # Handle case of modules_to_fuse being a list of lists
        for module_list in modules_to_fuse:
            _fuse_modules(model, module_list, fuser_func)
    return model

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources