from __future__ import absolute_import, division, print_function, unicode_literals
import math
import torch
import torch.nn as nn
import torch.nn.intrinsic
import torch.nn.qat as nnqat
import torch.nn.functional as F
from torch.nn import init
from torch.nn.modules.utils import _pair
from torch.nn.parameter import Parameter
class _ConvBnNd(nn.modules.conv._ConvNd):
_version = 2
def __init__(self,
# ConvNd args
in_channels, out_channels, kernel_size, stride,
padding, dilation, transposed, output_padding,
groups,
bias,
padding_mode,
# BatchNormNd args
# num_features: out_channels
eps=1e-05, momentum=0.1,
# affine: True
# track_running_stats: True
# Args for this module
freeze_bn=False,
qconfig=None):
nn.modules.conv._ConvNd.__init__(self, in_channels, out_channels, kernel_size,
stride, padding, dilation, transposed,
output_padding, groups, False, padding_mode)
assert qconfig, 'qconfig must be provided for QAT module'
self.qconfig = qconfig
        # if the module is constructed in eval mode, always freeze BN stats
        self.freeze_bn = freeze_bn if self.training else True
self.bn = nn.BatchNorm2d(out_channels, eps, momentum, True, True)
self.activation_post_process = self.qconfig.activation()
self.weight_fake_quant = self.qconfig.weight()
if bias:
self.bias = Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.reset_bn_parameters()
# this needs to be called after reset_bn_parameters,
# as they modify the same state
if self.training:
if freeze_bn:
self.freeze_bn_stats()
else:
self.update_bn_stats()
else:
self.freeze_bn_stats()
def reset_running_stats(self):
self.bn.reset_running_stats()
def reset_bn_parameters(self):
self.bn.reset_running_stats()
init.uniform_(self.bn.weight)
init.zeros_(self.bn.bias)
        # note: below is actually for conv, not BN
if self.bias is not None:
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in)
init.uniform_(self.bias, -bound, bound)
def reset_parameters(self):
super(_ConvBnNd, self).reset_parameters()
def update_bn_stats(self):
self.freeze_bn = False
self.bn.training = True
return self
def freeze_bn_stats(self):
self.freeze_bn = True
self.bn.training = False
return self
    def _forward(self, input):
        # per-channel scale gamma / sqrt(running_var + eps); multiplying the
        # weight by it yields the BN-folded weight that inference would use
        running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
        scale_factor = self.bn.weight / running_std
        scaled_weight = self.weight_fake_quant(self.weight * scale_factor.reshape([-1, 1, 1, 1]))
        # this does not include the conv bias
        conv = self._conv_forward(input, scaled_weight)
        # undo the weight scaling so BN sees the unscaled conv output, then
        # add the conv bias and apply BN as usual
        conv_orig = conv / scale_factor.reshape([1, -1, 1, 1])
        if self.bias is not None:
            conv_orig = conv_orig + self.bias.reshape([1, -1, 1, 1])
        conv = self.bn(conv_orig)
        return conv
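    # Worked form of the folding above, per sec. 3.2.2 of the whitepaper cited
    # in ConvBn2d's docstring (a sketch, with sigma = sqrt(running_var + eps)):
    # at inference BN folds into the conv as w_fold = w * gamma / sigma and
    # b_fold = beta + gamma * (b - mu) / sigma. Fake-quantizing
    # w * scale_factor therefore simulates the quantized folded weight, while
    # dividing the conv output by scale_factor restores the unscaled
    # activations so self.bn can still update its batch statistics.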
def extra_repr(self):
# TODO(jerryzh): extend
return super(_ConvBnNd, self).extra_repr()
def forward(self, input):
return self.activation_post_process(self._forward(input))
def train(self, mode=True):
"""
Batchnorm's training behavior is using the self.training flag. Prevent
changing it if BN is frozen. This makes sure that calling `model.train()`
on a model with a frozen BN will behave properly.
"""
self.training = mode
if not self.freeze_bn:
for module in self.children():
module.train(mode)
return self
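    # Illustrative behavior of the override above (a sketch; `m` is a
    # hypothetical instance of a _ConvBnNd subclass):
    #
    #   >>> m.freeze_bn_stats()
    #   >>> m.train()
    #   >>> m.training, m.bn.training
    #   (True, False)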
# ===== Serialization version history =====
#
# Version 1/None
# self
# |--- weight : Tensor
# |--- bias : Tensor
# |--- gamma : Tensor
# |--- beta : Tensor
# |--- running_mean : Tensor
# |--- running_var : Tensor
# |--- num_batches_tracked : Tensor
#
# Version 2
# self
# |--- weight : Tensor
# |--- bias : Tensor
# |--- bn : Module
# |--- weight : Tensor (moved from v1.self.gamma)
# |--- bias : Tensor (moved from v1.self.beta)
# |--- running_mean : Tensor (moved from v1.self.running_mean)
# |--- running_var : Tensor (moved from v1.self.running_var)
# |--- num_batches_tracked : Tensor (moved from v1.self.num_batches_tracked)
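    # For example (a sketch, not from the original source): a v1 checkpoint
    # entry 'layer.gamma' is remapped by _load_from_state_dict below to the
    # v2 key 'layer.bn.weight' before the parent class consumes the dict.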
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
version = local_metadata.get('version', None)
if version is None or version == 1:
# BN related parameters and buffers were moved into the BN module for v2
v2_to_v1_names = {
'bn.weight': 'gamma',
'bn.bias': 'beta',
'bn.running_mean': 'running_mean',
'bn.running_var': 'running_var',
'bn.num_batches_tracked': 'num_batches_tracked',
}
for v2_name, v1_name in v2_to_v1_names.items():
if prefix + v1_name in state_dict:
state_dict[prefix + v2_name] = state_dict[prefix + v1_name]
state_dict.pop(prefix + v1_name)
elif strict:
missing_keys.append(prefix + v2_name)
super(_ConvBnNd, self)._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
@classmethod
def from_float(cls, mod, qconfig=None):
r"""Create a qat module from a float module or qparams_dict
Args: `mod` a float module, either produced by torch.quantization utilities
or directly from user
"""
assert type(mod) == cls._FLOAT_MODULE, 'qat.' + cls.__name__ + '.from_float only works for ' + \
cls._FLOAT_MODULE.__name__
if not qconfig:
assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
assert mod.qconfig, 'Input float module must have a valid qconfig'
qconfig = mod.qconfig
conv, bn = mod[0], mod[1]
qat_convbn = cls(conv.in_channels, conv.out_channels, conv.kernel_size,
conv.stride, conv.padding, conv.dilation,
conv.groups, conv.bias is not None,
conv.padding_mode,
bn.eps, bn.momentum,
False,
qconfig)
qat_convbn.weight = conv.weight
qat_convbn.bias = conv.bias
qat_convbn.bn.weight = bn.weight
qat_convbn.bn.bias = bn.bias
qat_convbn.bn.running_mean = bn.running_mean
qat_convbn.bn.running_var = bn.running_var
qat_convbn.bn.num_batches_tracked = bn.num_batches_tracked
return qat_convbn
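    # Illustrative usage sketch (assumptions: the fused float module and the
    # default QAT qconfig from torch.quantization are constructed as shown):
    #
    #   >>> import torch.quantization as tq
    #   >>> float_mod = torch.nn.intrinsic.ConvBn2d(
    #   ...     nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16))
    #   >>> float_mod.qconfig = tq.get_default_qat_qconfig('fbgemm')
    #   >>> qat_mod = ConvBn2d.from_float(float_mod)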
class ConvBn2d(_ConvBnNd, nn.Conv2d):
r"""
A ConvBn2d module is a module fused from Conv2d and BatchNorm2d,
attached with FakeQuantize modules for both output activation and weight,
used in quantization aware training.
We combined the interface of :class:`torch.nn.Conv2d` and
:class:`torch.nn.BatchNorm2d`.
Implementation details: https://arxiv.org/pdf/1806.08342.pdf section 3.2.2
Similar to :class:`torch.nn.Conv2d`, with FakeQuantize modules initialized
to default.
Attributes:
        freeze_bn: whether to freeze BN running statistics during training
activation_post_process: fake quant module for output activation
weight_fake_quant: fake quant module for weight
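
    Examples::

        >>> # a minimal sketch, assuming the default QAT qconfig
        >>> # (an assumption, not from the original source)
        >>> qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
        >>> m = ConvBn2d(3, 16, 3, qconfig=qconfig)
        >>> y = m(torch.randn(1, 3, 32, 32))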
"""
_FLOAT_MODULE = torch.nn.intrinsic.ConvBn2d
def __init__(self,
# ConvNd args
in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1,
bias=None,
padding_mode='zeros',
# BatchNorm2d args
# num_features: out_channels
eps=1e-05, momentum=0.1,
# affine: True
# track_running_stats: True
# Args for this module
freeze_bn=False,
qconfig=None):
kernel_size = _pair(kernel_size)
stride = _pair(stride)
padding = _pair(padding)
dilation = _pair(dilation)
_ConvBnNd.__init__(self, in_channels, out_channels, kernel_size, stride,
padding, dilation, False, _pair(0), groups, bias, padding_mode,
eps, momentum, freeze_bn, qconfig)
class ConvBnReLU2d(ConvBn2d):
r"""
A ConvBnReLU2d module is a module fused from Conv2d, BatchNorm2d and ReLU,
attached with FakeQuantize modules for both output activation and weight,
used in quantization aware training.
    We combined the interface of :class:`torch.nn.Conv2d`,
    :class:`torch.nn.BatchNorm2d`, and :class:`torch.nn.ReLU`.
Implementation details: https://arxiv.org/pdf/1806.08342.pdf
Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to
default.
Attributes:
        observer: fake quant module for output activation; it is called
            observer to align with the post-training flow
weight_fake_quant: fake quant module for weight
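
    Examples::

        >>> # a minimal sketch, assuming the default QAT qconfig
        >>> # (an assumption, not from the original source)
        >>> qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
        >>> m = ConvBnReLU2d(3, 16, 3, qconfig=qconfig)
        >>> y = m(torch.randn(1, 3, 32, 32))  # output passes through ReLU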
"""
_FLOAT_MODULE = torch.nn.intrinsic.ConvBnReLU2d
def __init__(self,
# Conv2d args
in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1,
bias=None,
padding_mode='zeros',
# BatchNorm2d args
# num_features: out_channels
eps=1e-05, momentum=0.1,
# affine: True
# track_running_stats: True
# Args for this module
freeze_bn=False,
qconfig=None):
super(ConvBnReLU2d, self).__init__(in_channels, out_channels, kernel_size, stride,
padding, dilation, groups, bias,
padding_mode, eps, momentum,
freeze_bn,
qconfig)
def forward(self, input):
return self.activation_post_process(F.relu(ConvBn2d._forward(self, input)))
@classmethod
def from_float(cls, mod, qconfig=None):
return super(ConvBnReLU2d, cls).from_float(mod, qconfig)
class ConvReLU2d(nnqat.Conv2d):
r"""
A ConvReLU2d module is a fused module of Conv2d and ReLU, attached with
FakeQuantize modules for both output activation and weight for
quantization aware training.
    We combined the interface of :class:`~torch.nn.Conv2d` and
    :class:`~torch.nn.ReLU`.
Attributes:
activation_post_process: fake quant module for output activation
weight_fake_quant: fake quant module for weight
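
    Examples::

        >>> # a minimal sketch, assuming the default QAT qconfig
        >>> # (an assumption, not from the original source)
        >>> qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
        >>> m = ConvReLU2d(3, 16, 3, qconfig=qconfig)
        >>> y = m(torch.randn(1, 3, 32, 32))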
"""
_FLOAT_MODULE = torch.nn.intrinsic.ConvReLU2d
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1,
bias=True, padding_mode='zeros',
qconfig=None):
super(ConvReLU2d, self).__init__(in_channels, out_channels, kernel_size,
stride=stride, padding=padding, dilation=dilation,
groups=groups, bias=bias, padding_mode=padding_mode,
qconfig=qconfig)
assert qconfig, 'qconfig must be provided for QAT module'
self.qconfig = qconfig
self.activation_post_process = self.qconfig.activation()
self.weight_fake_quant = self.qconfig.weight()
def forward(self, input):
return self.activation_post_process(F.relu(
self._conv_forward(input, self.weight_fake_quant(self.weight))))
@classmethod
def from_float(cls, mod, qconfig=None):
return super(ConvReLU2d, cls).from_float(mod, qconfig)
def update_bn_stats(mod):
    # exact type check (not isinstance) is deliberate, since ConvBnReLU2d
    # subclasses ConvBn2d
    if type(mod) in {ConvBnReLU2d, ConvBn2d}:
        mod.update_bn_stats()
def freeze_bn_stats(mod):
    if type(mod) in {ConvBnReLU2d, ConvBn2d}:
        mod.freeze_bn_stats()
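
# Illustrative usage sketch (a sketch, not from the original source): these
# helpers are written to be applied recursively via nn.Module.apply, e.g.
#
#   >>> qat_model.apply(update_bn_stats)  # hypothetical `qat_model`, early in QAT
#   >>> qat_model.apply(freeze_bn_stats)  # freeze BN stats near the end of QAT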