Source code for torch.quantization.quantize

from __future__ import absolute_import, division, print_function, unicode_literals

import copy
import itertools
import warnings

import torch
import torch.nn as nn
import torch.nn.intrinsic as nni
import torch.nn.quantized as nnq

from .default_mappings import (DEFAULT_DYNAMIC_MODULE_MAPPING,
                               DEFAULT_MODULE_MAPPING,
                               DEFAULT_QAT_MODULE_MAPPING,
                               DEFAULT_QCONFIG_PROPAGATE_WHITE_LIST)
from .stubs import DeQuantStub, QuantWrapper
from .qconfig import default_dynamic_qconfig, float16_dynamic_qconfig

def _propagate_qconfig_helper(module, qconfig_dict, white_list=None,
                              qconfig_parent=None, prefix=''):
    r"""This is a helper function for `propagate_qconfig_`

    Args:
        module: input module
        qconfig_dict: dictionary that maps from name of submodule to quantization
                     configuration
        white_list: list of quantizable modules
        qconfig_parent: quantization config of the parent module; we fall back
                       to this config when there is no config specified for the
                       current module
        prefix: corresponding prefix of the current module, used as key in
                qconfig_dict

    Return:
        None, module is modified inplace with qconfig attached
    """
    # TODO: Add test
    if white_list is None:
        white_list = DEFAULT_QCONFIG_PROPAGATE_WHITE_LIST

    module_qconfig = qconfig_dict.get(type(module), qconfig_parent)
    module_qconfig = qconfig_dict.get(prefix, module_qconfig)
    module_qconfig = getattr(module, 'qconfig', module_qconfig)

    if type(module) in white_list:
        module.qconfig = module_qconfig
    for name, child in module.named_children():
        module_prefix = prefix + '.' + name if prefix else name
        _propagate_qconfig_helper(child, qconfig_dict, white_list,
                                  module_qconfig, module_prefix)

# TODO(jerryzh): expose white_list
def propagate_qconfig_(module, qconfig_dict=None, white_list=None):
    r"""Propagate qconfig through the module hierarchy and assign `qconfig`
    attribute on each leaf module

    Args:
        module: input module
        qconfig_dict: dictionary that maps from name or type of submodule to
            quantization configuration, qconfig applies to all submodules of a
            given module unless qconfig for the submodules are specified (when
            the submodule already has qconfig attribute)

    Return:
        None, module is modified inplace with qconfig attached
    """
    if qconfig_dict is None:
        qconfig_dict = {}
    _propagate_qconfig_helper(module, qconfig_dict, white_list)

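# Usage sketch (editor's addition, not part of the original source): keys of
# `qconfig_dict` may be submodule types or dotted submodule names, and an
# entry for a child overrides the qconfig inherited from its parent. Assumes
# the public torch.quantization namespace, which re-exports this function.
def _example_propagate_qconfig():
    import torch.nn as nn
    import torch.quantization as tq

    model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
    tq.propagate_qconfig_(model, {nn.Linear: tq.default_qconfig})
    assert model[0].qconfig is tq.default_qconfig  # Linear matched by type
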
def _observer_forward_hook(self, input, output):
    r"""Forward hook that calls observer on the output
    """
    return self.activation_post_process(output)

def add_observer_(module, non_leaf_module_list=None, device=None):
    r"""Add observer for the leaf children of the module.

    This function inserts an observer module into every leaf child module that
    has a valid qconfig attribute.

    Args:
        module: input module with qconfig attributes for all the leaf modules
            that we want to quantize
        device: parent device, if any
        non_leaf_module_list: list of non-leaf modules we want to add observer

    Return:
        None, module is modified inplace with added observer modules and
        forward_hooks
    """
    # respect device affinity when adding observers
    if device is None:
        devices = get_unique_devices_(module)
        assert len(devices) <= 1, (
            "add_observer_ only works with cpu or single-device CUDA modules, "
            "but got devices {}".format(devices)
        )
        device = next(iter(devices)) if len(devices) > 0 else None

    for child in module.children():
        if type(child) == nnq.FloatFunctional or type(child) == nnq.QFunctional:
            if hasattr(child, 'qconfig') and child.qconfig is not None:
                activation = child.qconfig.activation()
                if device is not None:
                    activation.to(device)
                child.activation_post_process = activation
        elif non_leaf_module_list is not None and type(child) in non_leaf_module_list:
            if hasattr(child, 'qconfig') and child.qconfig is not None:
                child.add_module('activation_post_process', child.qconfig.activation())
                child.register_forward_hook(_observer_forward_hook)
        else:
            add_observer_(child, non_leaf_module_list, device)

    # Insert observers only for leaf nodes, note that this observer is for
    # the output of the module, for input QuantStub will observe them
    if hasattr(module, 'qconfig') and module.qconfig is not None and \
            len(module._modules) == 0 and not isinstance(module, torch.nn.Sequential):
        # observer and hook will be gone after we swap the module
        activation = module.qconfig.activation()
        if device is not None:
            activation.to(device)
        module.add_module('activation_post_process', activation)
        module.register_forward_hook(_observer_forward_hook)

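# Sketch of what `prepare` does internally (editor's addition): after qconfig
# propagation, each quantizable leaf gets an `activation_post_process` observer
# plus a forward hook that feeds the module's output to that observer.
def _example_add_observer():
    import torch.nn as nn
    import torch.quantization as tq

    m = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
    m.qconfig = tq.default_qconfig
    tq.propagate_qconfig_(m)     # push the qconfig down to the leaves
    tq.add_observer_(m)          # attach observers to qualifying leaves
    assert hasattr(m[0], 'activation_post_process')
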
def get_unique_devices_(module):
    return {p.device for p in module.parameters()} | \
        {p.device for p in module.buffers()}

def add_quant_dequant(module):
    r"""Wrap the leaf child module in QuantWrapper if it has a valid qconfig.
    Note that this function will modify the children of module inplace and it
    can return a new module which wraps the input module as well.

    Args:
        module: input module with qconfig attributes for all the leaf modules
            that we want to quantize

    Return:
        Either the inplace modified module with submodules wrapped in
        `QuantWrapper` based on qconfig or a new `QuantWrapper` module which
        wraps the input module, the latter case only happens when the input
        module is a leaf module and we want to quantize it.
    """
    if len(module._modules) == 0 and hasattr(module, 'qconfig') and module.qconfig:
        return QuantWrapper(module)

    for name, child in module.named_children():
        module._modules[name] = add_quant_dequant(child)
    return module

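# Usage sketch (editor's addition): a leaf module with a valid qconfig is
# returned wrapped in QuantWrapper, which adds QuantStub/DeQuantStub around it.
def _example_add_quant_dequant():
    import torch.nn as nn
    import torch.quantization as tq

    fc = nn.Linear(4, 4)
    fc.qconfig = tq.default_qconfig
    wrapped = tq.add_quant_dequant(fc)
    assert isinstance(wrapped, tq.QuantWrapper)
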
def prepare(model, inplace=False, white_list=DEFAULT_QCONFIG_PROPAGATE_WHITE_LIST,
            observer_non_leaf_module_list=None):
    r"""Prepares a copy of the model for quantization calibration or
    quantization-aware training.

    Quantization configuration should be assigned preemptively to individual
    submodules in the `.qconfig` attribute.

    Observer or fake-quant modules will be attached to the model, and the
    qconfig will be propagated.

    Args:
        model: input model to be modified in-place
        inplace: carry out model transformations in-place, the original module
            is mutated
        white_list: list of quantizable modules
        observer_non_leaf_module_list: list of non-leaf modules we want to add
            observer
    """
    if not inplace:
        model = copy.deepcopy(model)
    propagate_qconfig_(model, qconfig_dict=None, white_list=white_list)
    # sanity check common API misusage
    if not any(hasattr(m, 'qconfig') and m.qconfig for m in model.modules()):
        warnings.warn("None of the submodule got qconfig applied. Make sure you "
                      "passed correct configuration through `qconfig_dict` or "
                      "by assigning the `.qconfig` attribute directly on submodules")
    add_observer_(model, observer_non_leaf_module_list)
    return model

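# End-to-end sketch of the eager static flow (editor's addition): assign a
# qconfig, prepare, calibrate on representative data, then convert (defined
# below). Assumes a build with a quantized CPU backend such as fbgemm.
def _example_prepare_calibrate_convert():
    import torch
    import torch.nn as nn
    import torch.quantization as tq

    float_model = tq.QuantWrapper(nn.Linear(4, 4))  # adds quant/dequant stubs
    float_model.qconfig = tq.default_qconfig
    prepared = tq.prepare(float_model)              # attaches observers
    for _ in range(4):
        prepared(torch.randn(8, 4))                 # calibration passes
    return tq.convert(prepared)                     # swaps in quantized modules
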
def _remove_qconfig(module):
    r"""Clean up the qconfig left in the module so that a new qconfig can be
    propagated.

    Args:
        module: module to be cleaned up
    """
    for child in module.children():
        _remove_qconfig(child)

    if hasattr(module, "qconfig"):
        del module.qconfig

def quantize(model, run_fn, run_args, mapping=None, inplace=False):
    r"""Quantize the input float model with post training static quantization.

    It first prepares the model for calibration, then calls `run_fn` to run
    the calibration step, and finally converts the model to a quantized model.

    Args:
        model: input float model
        run_fn: a calibration function for calibrating the prepared model
        run_args: positional arguments for `run_fn`
        inplace: carry out model transformations in-place, the original module
            is mutated
        mapping: correspondence between original module types and quantized
            counterparts

    Return:
        Quantized model.
    """
    if mapping is None:
        mapping = DEFAULT_MODULE_MAPPING
    if not inplace:
        model = copy.deepcopy(model)
    model.eval()
    prepare(model, inplace=True)
    run_fn(model, run_args)
    convert(model, mapping, inplace=True)
    _remove_qconfig(model)
    return model

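# One-shot version of the flow above (editor's addition): `quantize` runs
# prepare, the calibration callable, and convert internally. Note that
# `run_args` is handed to `run_fn` as a single positional argument.
def _example_quantize():
    import torch
    import torch.nn as nn
    import torch.quantization as tq

    def calibrate(m, data):
        with torch.no_grad():
            for x in data:
                m(x)

    float_model = tq.QuantWrapper(nn.Linear(4, 4))
    float_model.qconfig = tq.default_qconfig
    data = [torch.randn(8, 4) for _ in range(4)]
    return tq.quantize(float_model, calibrate, data)
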
def quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8,
                     mapping=None, inplace=False):
    r"""Converts a float model to a dynamic (i.e. weights-only) quantized model.

    Replaces specified modules with dynamic weight-only quantized versions and
    outputs the quantized model.

    For the simplest usage provide the `dtype` argument, which can be float16
    or qint8. Weight-only quantization is by default performed for layers with
    large weights - i.e. Linear and RNN variants.

    Fine-grained control is possible with `qconfig_spec` and `mapping`, which
    act similarly to `quantize()`. If `qconfig_spec` is provided, the `dtype`
    argument is ignored.

    Args:
        model: input model
        qconfig_spec: Either:

            - A dictionary that maps from name or type of submodule to
              quantization configuration, qconfig applies to all submodules of
              a given module unless qconfig for the submodules are specified
              (when the submodule already has qconfig attribute). Entries in
              the dictionary need to be QConfigDynamic instances.

            - A set of types and/or submodule names to apply dynamic
              quantization to, in which case the `dtype` argument is used to
              specify the bit-width

        inplace: carry out model transformations in-place, the original module
            is mutated
        mapping: maps type of a submodule to a type of corresponding
            dynamically quantized version with which the submodule needs to be
            replaced
    """
    if qconfig_spec is None:
        if dtype == torch.qint8:
            qconfig_spec = {
                nn.Linear : default_dynamic_qconfig,
                nn.LSTM : default_dynamic_qconfig,
                nn.GRU : default_dynamic_qconfig,
                nn.LSTMCell : default_dynamic_qconfig,
                nn.RNNCell : default_dynamic_qconfig,
                nn.GRUCell : default_dynamic_qconfig,
            }
        elif dtype == torch.float16:
            qconfig_spec = {
                nn.Linear : float16_dynamic_qconfig,
                nn.LSTM : float16_dynamic_qconfig,
                nn.GRU : float16_dynamic_qconfig,
                nn.LSTMCell : float16_dynamic_qconfig,
                nn.RNNCell : float16_dynamic_qconfig,
                nn.GRUCell : float16_dynamic_qconfig,
            }
        else:
            raise ValueError(
                "Don't know how to quantize with default settings for {}. Provide full qconfig please".format(dtype))
    elif isinstance(qconfig_spec, set):
        if dtype is torch.qint8:
            default_qconfig = default_dynamic_qconfig
        elif dtype is torch.float16:
            default_qconfig = float16_dynamic_qconfig
        else:
            raise RuntimeError('Unknown dtype specified for quantize_dynamic: ', str(dtype))
        qconfig_spec = dict(zip(qconfig_spec, itertools.repeat(default_qconfig)))

    if mapping is None:
        mapping = DEFAULT_DYNAMIC_MODULE_MAPPING

    if not inplace:
        model = copy.deepcopy(model)
    model.eval()
    propagate_qconfig_(model, qconfig_spec)
    convert(model, mapping, inplace=True)
    _remove_qconfig(model)
    return model

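# Usage sketch (editor's addition): with the defaults, Linear and the RNN
# variants get qint8 weight-only quantization; a set of types restricts the
# scope, and dtype=torch.float16 selects fp16 weights instead.
def _example_quantize_dynamic():
    import torch
    import torch.nn as nn
    import torch.quantization as tq

    float_model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    dq_model = tq.quantize_dynamic(float_model, {nn.Linear}, dtype=torch.qint8)
    return dq_model(torch.randn(1, 4))  # weights quantized, activations float
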
def prepare_qat(model, mapping=None, inplace=False):
    r"""Prepares a copy of the model for quantization-aware training and
    converts the matched submodules to their fake-quantized (QAT) counterparts.

    Quantization configuration should be assigned preemptively to individual
    submodules in the `.qconfig` attribute.

    Args:
        model: input model to be modified in-place
        mapping: dictionary that maps float modules to the QAT modules they
            should be replaced with
        inplace: carry out model transformations in-place, the original module
            is mutated
    """
    if mapping is None:
        mapping = DEFAULT_QAT_MODULE_MAPPING
    model = prepare(model, inplace=inplace)
    convert(model, mapping, inplace=True)
    return model

def quantize_qat(model, run_fn, run_args, inplace=False):
    r"""Do quantization aware training and output a quantized model

    Args:
        model: input model
        run_fn: a function for evaluating the prepared model, can be a function
            that simply runs the prepared model or a training loop
        run_args: positional arguments for `run_fn`

    Return:
        Quantized model.
    """
    if not inplace:
        model = copy.deepcopy(model)
    model.train()
    prepare_qat(model, inplace=True)
    run_fn(model, run_args)
    convert(model, inplace=True)
    return model

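# QAT sketch (editor's addition): `prepare_qat` swaps in fake-quantized
# modules, `run_fn` stands in for a real training loop, and `convert`
# produces the final quantized model.
def _example_quantize_qat():
    import torch
    import torch.nn as nn
    import torch.quantization as tq

    def train_fn(m, data):
        for x in data:
            m(x)  # a real loop would also compute a loss and step an optimizer

    float_model = tq.QuantWrapper(nn.Linear(4, 4))
    float_model.qconfig = tq.default_qat_qconfig
    data = [torch.randn(8, 4) for _ in range(4)]
    return tq.quantize_qat(float_model, train_fn, data)
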
def convert(module, mapping=None, inplace=False):
    r"""Converts the float module with observers (where we can get quantization
    parameters) to a quantized module.

    Args:
        module: calibrated module with observers
        mapping: a dictionary that maps from float module type to quantized
            module type, can be overwritten to allow swapping user defined
            Modules
        inplace: carry out model transformations in-place, the original module
            is mutated
    """
    if mapping is None:
        mapping = DEFAULT_MODULE_MAPPING
    if not inplace:
        module = copy.deepcopy(module)
    reassign = {}
    # TODO(jerryzh): remove after deciding on the impl of intrinsic modules
    # This is required because intrinsic modules right now are implemented as
    # nn.Sequential and we don't want to swap their constituents
    SWAPPABLE_MODULES = (nni.ConvBn2d,
                         nni.ConvBnReLU2d,
                         nni.LinearReLU,
                         nni.BNReLU2d,
                         nni.BNReLU3d,
                         nni.ConvBn1d,
                         nni.ConvReLU1d,
                         nni.ConvBnReLU1d,
                         nni.ConvReLU2d,
                         nni.ConvReLU3d)

    for name, mod in module.named_children():
        if type(mod) not in SWAPPABLE_MODULES:
            convert(mod, mapping, inplace=True)
        reassign[name] = swap_module(mod, mapping)

    for key, value in reassign.items():
        module._modules[key] = value

    return module

def swap_module(mod, mapping):
    r"""Swaps the module if it has a quantized counterpart and it has an
    `observer` attached.

    Args:
        mod: input module
        mapping: a dictionary that maps from nn module to nnq module

    Return:
        The corresponding quantized module of `mod`
    """
    new_mod = mod
    # Always replace dequantstub with dequantize
    if hasattr(mod, 'qconfig') and mod.qconfig is not None or type(mod) == DeQuantStub:
        if type(mod) in mapping:
            # respect device affinity when swapping modules
            devices = get_unique_devices_(mod)
            assert len(devices) <= 1, (
                "swap_module only works with cpu or single-device CUDA modules, "
                "but got devices {}".format(devices)
            )
            device = next(iter(devices)) if len(devices) > 0 else None
            new_mod = mapping[type(mod)].from_float(mod)
            if device:
                new_mod.to(device)
    return new_mod

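# Sketch of a single manual swap (editor's addition); `convert` normally does
# this for every child. The dynamic mapping is used because it needs no
# calibrated observer; the mapping import mirrors the one at the top of this
# file.
def _example_swap_module():
    import torch.nn as nn
    import torch.quantization as tq
    from torch.quantization.default_mappings import DEFAULT_DYNAMIC_MODULE_MAPPING

    fc = nn.Linear(4, 4)
    fc.qconfig = tq.default_dynamic_qconfig
    return tq.swap_module(fc, DEFAULT_DYNAMIC_MODULE_MAPPING)  # dynamic Linear
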
def get_observer_dict(mod, target_dict, prefix=""):
    r"""Traverse the modules and save all observers into dict.
    This is mainly used for quantization accuracy debugging.

    Args:
        mod: the top module from which we want to save all observers
        prefix: the prefix for the current module
        target_dict: the dictionary used to save all the observers
    """
    def get_prefix(prefix):
        return prefix if prefix == "" else prefix + '.'

    if hasattr(mod, 'activation_post_process'):
        target_dict[get_prefix(prefix) + 'activation_post_process'] = mod.activation_post_process
    for name, child in mod.named_children():
        module_prefix = get_prefix(prefix) + name if prefix else name
        get_observer_dict(child, target_dict, module_prefix)

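# Debugging sketch (editor's addition): collect every observer of a prepared
# model into a dict keyed by dotted path, e.g. 'module.activation_post_process'.
def _example_get_observer_dict():
    import torch
    import torch.nn as nn
    import torch.quantization as tq

    m = tq.QuantWrapper(nn.Linear(4, 4))
    m.qconfig = tq.default_qconfig
    prepared = tq.prepare(m)
    prepared(torch.randn(1, 4))

    observers = {}
    tq.get_observer_dict(prepared, observers)
    return observers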
