Skip to content

Commit

Permalink
ReadMe update and QATv2 Refactoring (#339)
Browse files Browse the repository at this point in the history
* ReadMe update and QATv2 Refactoring

* Distributed mode improvements, quantile function replacement

* Revert local_rank change for pre_qat

* Add optimize_ddp flag to test()

* QATv2.md update
  • Loading branch information
oguzhanbsolak authored Dec 6, 2024
1 parent 1a45d1f commit 0fa1fb1
Show file tree
Hide file tree
Showing 9 changed files with 116 additions and 65 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# ADI MAX78000/MAX78002 Model Training and Synthesis

August 27, 2024
November 7, 2024

**Note: This branch requires PyTorch 2. Please see the archive-1.8 branch for PyTorch 1.8 support. [KNOWN_ISSUES](KNOWN_ISSUES.txt) contains a list of known issues.**

Expand Down Expand Up @@ -1636,7 +1636,7 @@ Quantization-aware training can be <u>disabled</u> by specifying `--qat-policy N

The proper choice of `start_epoch` is important for achieving good results, and the default policy’s `start_epoch` may be much too small. As a rule of thumb, set `start_epoch` to a very high value (e.g., 1000) to begin, and then observe where in the training process the model stops learning. This epoch can be used as `start_epoch`, and the final network metrics (after an additional number of epochs) should be close to the non-QAT metrics. *Additionally, ensure that the learning rate after the `start_epoch` epoch is relatively small.*

For more information, please also see [Quantization](#quantization).
For more information, please also see [Quantization](#quantization) and [QATv2](https://github.com/analogdevicesinc/ai8x-training/blob/develop/docs/QATv2.md).

#### Batch Normalization

Expand Down
Binary file modified README.pdf
Binary file not shown.
63 changes: 23 additions & 40 deletions ai8x.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from torch.autograd import Function
from torch.fx import symbolic_trace

from tqdm import tqdm

import devices

dev = None
Expand Down Expand Up @@ -435,45 +437,6 @@ def forward(self, _, x): # pylint: disable=arguments-differ
return x


def interp(x, xp, fp, method='linear'):
"""
Simple PyTorch implementation of `np.interp`.
1D data only, length must be 2 or greater.
`method` must be "linear" or "lower".
"""
# Find the index
n = len(xp) - 1
if n == 0:
return fp[0]
if x == 1.:
return fp[-1]
i = torch.clip(torch.searchsorted(xp, x, side='right').unsqueeze(0), 1, n) - 1
# Calculate fractional index
if method == 'linear':
g = x * n - i
else:
assert method == 'lower'
g = .0
# Interpolate result
return fp[i] + g * (fp[i + 1] - fp[i])


def quantile(x, q, method='linear'):
"""
Ersatz quantile function in PyTorch that works with torch.compile().
1D data only, len(x) must be 2 or greater.
`method` must be "linear" or "lower".
"""
x = x.flatten()
n = len(x)
return interp(
q,
torch.linspace(1 / (2 * n), (2 * n - 1) / (2 * n), n, device=x.device),
torch.sort(x)[0],
method,
).squeeze(0)


class OutputShiftLimit(nn.Module):
"""
Calculate the clamped output shift when adjusting during quantization-aware training.
Expand All @@ -484,7 +447,7 @@ def __init__(self, shift_quantile=1.0):

def forward(self, x, _): # pylint: disable=arguments-differ
"""Forward prop"""
limit = quantile(x.abs(), self.shift_quantile)
limit = torch.quantile(x.abs(), self.shift_quantile)
return -(1./limit).log2().floor().clamp(min=-15., max=15.)


Expand Down Expand Up @@ -2265,6 +2228,26 @@ def apply_scales(model):
requires_grad=False)


@torch.no_grad()
def stat_collect(train_loader, model, args):
"""Collect statistics for quantization aware training"""
model.eval()
for inputs, _ in tqdm(train_loader):
inputs = inputs.to(args.device)
model(inputs)


def pre_qat(model, train_loader, args, qat_policy):
"""
Prepare the model for quantization aware training
"""
init_hist(model)
stat_collect(train_loader, model, args)
init_threshold(model, qat_policy["outlier_removal_z_score"])
release_hist(model)
apply_scales(model)


def init_hist(model):
"""
Place forward hooks to collect histograms of activations
Expand Down
Binary file added docs/QATv2-Adds.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/QATv2-Apply Scales.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/QATv2-Concats.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/QATv2-Layer Sharing.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
56 changes: 56 additions & 0 deletions docs/QATv2.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Quantization Aware Training (QAT)

This document aims to explain Quantization Aware Training framework for MAX7800x series microcontrollers. QAT for MAX7800x consists of four main stages:
Activation Statistics Collection, Activation Threshold Determination, Scale Adjustments, and Weights Quantization.

## Activation Statistics Collection

To train a quantization-aware model, the first step is to collect activation statistics. The activation statistics are collected by running the model on the training dataset. The training script includes the activation statistics collection step(stat_collect() function). The activation statistics are the histogram of the activations for each layer.

## Activation Threshold Determination

The collected statistics are use to determine the activation thresholds. To do this, first, an outlier removal step based on z-score is applied to the activation statistics. The default z-score is 8.0, and it can be changed by defining a z-score on the qat policy file. Then, an iterative algorithm [1] that minimizes the quantization error by adjusting the threshold to determine the full activation range. This algorithm finds a balance point in the tradeoff between range and resolution. Scales are calculated as powers of two, making the scaling-down operation more computationally efficient by defining them as bit shift operations at the edge hardware.

## Scale Adjustments

To implement the threshold-based quantization, the scales of the layers are adjusted. The scales are adjusted based on the type of operation that is performed on the layers. The scale adjustments are made for residual additions, concatenations, and layer sharing. Figure 1. shows the scale adjustments for residual additions. In the figure, Layer1 and Layer2 are layers that are added together. The scale of the residual addition is selected as the scale of the layers that are connected to the residual addition.

<img src="QATv2-Adds.png" style="zoom: 50%;" />

Figure 1. Scale Adjustments for Residual Additions

Figure 2. shows the scale adjustments for concatenations. In the figure, Layer1 and Layer2 are layers that are concatenated. The maximum scale of the layers is selected as the scale for the concatenated layer.

<img src="QATv2-Concats.png" style="zoom: 50%;" />

Figure 2. Scale Adjustments for Concatenations

Figure 3. shows the scale adjustments for layer sharing. In the figure, Layer1, Layer2 and Layer3 are layers that share weights. The maximum scale of the layers is selected as the scale for the shared layer.

<img src="QATv2-Layer Sharing.png" style="zoom: 50%;" />

Figure 3. Scale Adjustments for Layer Sharing

Figure 4. provides a simplified diagram showing how the scaling-down and scale carry-over operations are implemented. In the diagram, Layer1 and Layer2 represent linear layers with weights w1 and w2, and biases b1 and b2. S1 and S2 represent the activation scales, which are calculated as previously described. As shown, the output of Layer1 is scaled down using the S1 threshold, and the scale carry-over operation is achieved by adjusting
Layer2’s scale and dividing its biases accordingly.

<img src="QATv2-Apply Scales.png" style="zoom: 50%;" />

Figure 4. Scaling-down and Scale Carry Over Diagram

## Weights Quantization

After determining the activation thresholds and scales, the next step is to quantize the weights. The weights are quantized using the QAT framework, which is based on the method proposed by Jacob et al. [2]. While training the model, weights and biasses are fake quantized to integers. The fake quantization is done by quantizing the weights and biases to integers and then dequantizing them back to floating-point numbers.

## Deploying the Quantized Model

The output shifts from the weights quantization are merged with the scale shifts from the activation quantization to form the final shifts of the quantized model. When the model is deployed, the final layer's scale should be restored to the original scale by multiplying the outputs with the final layer's scale. In the auto-generated C code, the cnn_unload() function is responsible for restoring the final layer's scale. If the cnn_unload() function is not used, the final layer's scale should be restored manually by multiplying the outputs with the final layer's scale. The final layer's scale values can be found at the cnn.c file in the comments section.




## References

[1] [Habi, Hai Victor, et al. "Hptq: Hardware-friendly post training quantization." arXiv preprint arXiv:2109.09113 (2021).](https://arxiv.org/abs/2109.09113)

[2] [Jacob, B., Kligys, S., Chen, B., Zhu, M., Tang, M., Howard, A., ... & Kalenichenko, D. (2018). Quantization and training of neural networks for efficient integer-arithmetic-only inference. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 2704-2713).](https://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf)
58 changes: 35 additions & 23 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
from pytorch_metric_learning.utils.inference import CustomKNN
from torchmetrics.detection import MeanAveragePrecision
from tqdm import tqdm

import ai8x
import ai8x_nas
Expand Down Expand Up @@ -608,15 +607,10 @@ def flush(self):

# Fuse the BN parameters into conv layers before Quantization Aware Training (QAT)
ai8x.fuse_bn_layers(model)
ai8x.init_hist(model)

msglogger.info('Collecting statistics for quantization aware training (QAT)...')
stat_collect(train_loader, model, args)

ai8x.init_threshold(model, qat_policy["outlier_removal_z_score"])
ai8x.release_hist(model)

ai8x.apply_scales(model)
ai8x.pre_qat(model, train_loader, args, qat_policy)

# Update the optimizer to reflect fused batchnorm layers
optimizer = ai8x.update_optimizer(model, optimizer)
Expand Down Expand Up @@ -646,6 +640,12 @@ def flush(self):
torch._dynamo.reset() # pylint: disable=protected-access
model = torch.compile(model, mode=args.compiler_mode,
backend=args.compiler_backend)

# TODO: Optimize DDP is currently not supported with QAT.
# Once pytorch supports DDP with higher order ops,
# we can enable optimize DDP with QAT.
# https://github.com/pytorch/pytorch/issues/104674.
torch._dynamo.config.optimize_ddp = False # pylint: disable=protected-access
msglogger.info(
'torch.compile() successful, mode=%s, cache limit=%d',
args.compiler_mode,
Expand Down Expand Up @@ -740,7 +740,7 @@ def flush(self):
if not args.dr:
test(test_loader, model, criterion, [pylogger], args=args, mode="ckpt")
test(test_loader, model, criterion, [pylogger], args=args, mode="best",
ckpt_name=checkpoint_name)
ckpt_name=checkpoint_name, local_rank=local_rank)

if args.copy_output_folder and local_rank <= 0:
msglogger.info('Copying output folder to: %s', args.copy_output_folder)
Expand Down Expand Up @@ -850,15 +850,6 @@ def create_nas_kd_policy(model, compression_scheduler, epoch, next_state_start_e
' | '.join([f'{val:.2f}' for val in dlw]))


@torch.no_grad()
def stat_collect(train_loader, model, args):
"""Collect statistics for quantization aware training"""
model.eval()
for inputs, _ in tqdm(train_loader):
inputs = inputs.to(args.device)
model(inputs)


def train(train_loader, model, criterion, optimizer, epoch,
compression_scheduler, loggers, args, loss_optimizer=None):
"""Training loop for one epoch."""
Expand Down Expand Up @@ -1082,19 +1073,40 @@ def validate(val_loader, model, criterion, loggers, args, epoch=-1, tflogger=Non
return _validate(val_loader, model, criterion, loggers, args, epoch, tflogger)


def test(test_loader, model, criterion, loggers, args, mode='ckpt', ckpt_name=None):
def test(test_loader, model, criterion, loggers, args, mode='ckpt', ckpt_name=None, local_rank=0):
"""Model Test"""
assert msglogger is not None
if mode == 'ckpt':
msglogger.info('--- test (ckpt) ---------------------')
top1, top5, vloss, mAP = _validate(test_loader, model, criterion, loggers, args)
else:
msglogger.info('--- test (best) ---------------------')
if ckpt_name is None:
best_ckpt_path = os.path.join(msglogger.logdir, 'best.pth.tar')
else:
best_ckpt_path = os.path.join(msglogger.logdir, ckpt_name + "_best.pth.tar")
model = apputils.load_lean_checkpoint(model, best_ckpt_path)
model, dynamo, ddp = model_wrapper.unwrap(model)
if local_rank <= 0:
if ckpt_name is None:
best_ckpt_path = os.path.join(msglogger.logdir, 'best.pth.tar')
else:
best_ckpt_path = os.path.join(msglogger.logdir, ckpt_name + "_best.pth.tar")
model = apputils.load_lean_checkpoint(model, best_ckpt_path)

if ddp:
model = DistributedDataParallel(
model,
device_ids=[local_rank] if args.device == 'cuda' else None,
output_device=local_rank if args.device == 'cuda' else None,
)

if dynamo:
torch._dynamo.reset() # pylint: disable=protected-access
model = torch.compile(model, mode=args.compiler_mode,
backend=args.compiler_backend)
torch._dynamo.config.optimize_ddp = False # pylint: disable=protected-access
msglogger.info(
'torch.compile() successful, mode=%s, cache limit=%d',
args.compiler_mode,
torch._dynamo.config.cache_size_limit, # pylint: disable=protected-access
)

top1, top5, vloss, mAP = _validate(test_loader, model, criterion, loggers, args)

return top1, top5, vloss, mAP
Expand Down

0 comments on commit 0fa1fb1

Please sign in to comment.