# Source code for torch.nn.qat.modules.conv
from __future__ import absolute_import, division, print_function, unicode_literals
import torch.nn as nn
from torch.nn.intrinsic import ConvReLU2d
class Conv2d(nn.Conv2d):
    r"""
    A Conv2d module attached with FakeQuantize modules for both output
    activation and weight, used for quantization aware training.

    We adopt the same interface as `torch.nn.Conv2d`, please see
    https://pytorch.org/docs/stable/nn.html?highlight=conv2d#torch.nn.Conv2d
    for documentation.

    Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        activation_post_process: fake quant module for output activation
        weight_fake_quant: fake quant module for weight
    """
    # The exact float module type that ``from_float`` accepts.
    _FLOAT_MODULE = nn.Conv2d

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1,
                 bias=True, padding_mode='zeros', qconfig=None):
        super(Conv2d, self).__init__(in_channels, out_channels, kernel_size,
                                     stride=stride, padding=padding, dilation=dilation,
                                     groups=groups, bias=bias, padding_mode=padding_mode)
        assert qconfig, 'qconfig must be provided for QAT module'
        self.qconfig = qconfig
        # The qconfig fields are factories; calling them here gives this
        # module its own fake-quant instances.
        self.activation_post_process = qconfig.activation()
        self.weight_fake_quant = qconfig.weight()

    def forward(self, input):
        # Fake-quantize the weight, run the convolution with it, then
        # fake-quantize the output activation.
        return self.activation_post_process(
            self._conv_forward(input, self.weight_fake_quant(self.weight)))

    @classmethod
    def from_float(cls, mod, qconfig=None):
        r"""Create a QAT module from a float module.

        Args:
            mod: a float module of exactly type ``cls._FLOAT_MODULE``, either
                produced by torch.quantization utilities or directly from the
                user.
            qconfig: optional qconfig to use. When given, it takes precedence
                over ``mod.qconfig``. When omitted, ``mod`` must carry a
                truthy ``qconfig`` attribute.

        Returns:
            A new ``cls`` instance with ``mod``'s convolution hyper-parameters
            that shares ``mod``'s ``weight`` and ``bias`` parameters.

        Raises:
            AssertionError: if ``mod`` is not exactly ``cls._FLOAT_MODULE``,
                or if no qconfig is passed and ``mod`` has no valid qconfig.
        """
        # Exact type match rather than isinstance: this QAT class itself
        # subclasses the float module, so isinstance would also accept
        # already-converted modules.
        assert type(mod) == cls._FLOAT_MODULE, 'qat.' + cls.__name__ + '.from_float only works for ' + \
            cls._FLOAT_MODULE.__name__
        if not qconfig:
            assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
            assert mod.qconfig, 'Input float module must have a valid qconfig'
            # Use the qconfig that was just validated on the module passed in,
            # before any fused-module unwrapping below. An explicitly passed
            # qconfig is never overwritten.
            qconfig = mod.qconfig
        if type(mod) == ConvReLU2d:
            # For a fused ConvReLU2d, the conv hyper-parameters and parameters
            # are taken from its first submodule.
            mod = mod[0]
        qat_conv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
                       stride=mod.stride, padding=mod.padding, dilation=mod.dilation,
                       groups=mod.groups, bias=mod.bias is not None,
                       padding_mode=mod.padding_mode, qconfig=qconfig)
        # Re-use the float module's Parameter objects (no copy), so training
        # the QAT module updates the same tensors.
        qat_conv.weight = mod.weight
        qat_conv.bias = mod.bias
        return qat_conv