Shortcuts

Source code for torch.nn.quantized.modules


import torch
from torch.nn.modules.pooling import MaxPool2d

from .activation import ReLU, ReLU6, Hardswish, ELU
from .batchnorm import BatchNorm2d, BatchNorm3d
from .normalization import LayerNorm, GroupNorm, InstanceNorm1d, \
    InstanceNorm2d, InstanceNorm3d
from .conv import Conv1d, Conv2d, Conv3d
from .linear import Linear

from .functional_modules import FloatFunctional, QFunctional


class Quantize(torch.nn.Module):
    r"""Quantizes an incoming tensor.

    Args:
     `scale`: scale of the output Quantized Tensor
     `zero_point`: zero_point of the output Quantized Tensor
     `dtype`: data type of the output Quantized Tensor

    Attributes:
      `scale`, `zero_point`, `dtype`

    Examples::

        >>> t = torch.tensor([[1., -1.], [1., -1.]])
        >>> scale, zero_point, dtype = 1.0, 2, torch.qint8
        >>> qm = Quantize(scale, zero_point, dtype)
        >>> qt = qm(t)
        >>> print(qt)
        tensor([[ 1., -1.],
                [ 1., -1.]], size=(2, 2), dtype=torch.qint8, scale=1.0, zero_point=2)
    """

    def __init__(self, scale, zero_point, dtype):
        super(Quantize, self).__init__()
        # Stored as one-element buffers (not plain attributes) so they are part
        # of the module's state_dict and move with `.to(device)`.
        self.register_buffer('scale', torch.tensor([scale]))
        self.register_buffer('zero_point', torch.tensor([zero_point], dtype=torch.long))
        self.dtype = dtype

    def forward(self, X):
        """Quantize ``X`` per-tensor with the stored scale, zero point and dtype.

        NOTE(review): ``X`` is presumably a floating-point tensor, as required
        by ``torch.quantize_per_tensor`` — confirm at call sites.
        """
        # The buffers are one-element tensors; the op expects Python scalars.
        return torch.quantize_per_tensor(X, float(self.scale), int(self.zero_point), self.dtype)

    @staticmethod
    def from_float(mod):
        """Build a ``Quantize`` from a float module carrying an observer.

        ``mod.activation_post_process`` must provide ``calculate_qparams()``
        (returning scale and zero-point tensors) and a ``dtype`` attribute.

        Raises:
            AssertionError: if ``mod`` has no ``activation_post_process``.
        """
        assert hasattr(mod, 'activation_post_process'), \
            'Quantize.from_float expects a module with an activation_post_process observer'
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        return Quantize(scale.float().item(), zero_point.long().item(),
                        mod.activation_post_process.dtype)

    def extra_repr(self):
        return 'scale={}, zero_point={}, dtype={}'.format(self.scale, self.zero_point, self.dtype)
class DeQuantize(torch.nn.Module):
    r"""Dequantizes an incoming tensor.

    Examples::

        >>> x = torch.tensor([[1., -1.], [1., -1.]])
        >>> scale, zero_point, dtype = 1.0, 2, torch.qint8
        >>> qm = Quantize(scale, zero_point, dtype)
        >>> quantized_input = qm(x)
        >>> dqm = DeQuantize()
        >>> dequantized = dqm(quantized_input)
        >>> print(dequantized)
        tensor([[ 1., -1.],
                [ 1., -1.]], dtype=torch.float32)
    """

    def __init__(self):
        super(DeQuantize, self).__init__()

    def forward(self, Xq):
        """Return ``Xq.dequantize()``.

        NOTE(review): ``Xq`` is presumably a quantized tensor — confirm at
        call sites.
        """
        return Xq.dequantize()

    @staticmethod
    def from_float(mod):
        """Return a fresh ``DeQuantize``.

        ``mod`` is accepted for interface parity with other ``from_float``
        constructors and is not inspected.
        """
        return DeQuantize()
# Public API of this package. Kept one entry per line: on a single line the
# "Wrapper modules" comment would swallow the closing bracket.
__all__ = [
    'BatchNorm2d',
    'BatchNorm3d',
    'Conv1d',
    'Conv2d',
    'Conv3d',
    'DeQuantize',
    'Linear',
    'MaxPool2d',
    'Quantize',
    'ReLU',
    'ReLU6',
    'Hardswish',
    'ELU',
    'LayerNorm',
    'GroupNorm',
    'InstanceNorm1d',
    'InstanceNorm2d',
    'InstanceNorm3d',
    # Wrapper modules
    'FloatFunctional',
    'QFunctional',
]

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources