File size: 5,924 Bytes
db5855f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import datasets
import nncf
import openvino as ov
import time

from contextlib import contextmanager
from jiwer import wer, wer_standardize
from nncf.quantization.range_estimator import (
    RangeEstimatorParameters,
    StatisticsCollectorParameters,
    StatisticsType,
)
from optimum.intel import OVModelForSeq2SeqLM
from optimum.intel.openvino.quantization import InferRequestWrapper
from pathlib import Path
from tqdm.auto import tqdm
from typing import List, Dict
from transformers import Pipeline, pipeline, PreTrainedTokenizer

# Default value of the ``calibration_dataset_size`` argument of
# ``get_quantized_pipeline``: how many dataset sentences are requested
# when collecting calibration data.
CALIBRATION_DATASET_SIZE = 10


def collect_calibration_data(grammar_corrector_pipe_fp32: Pipeline, calibration_dataset_size: int) -> List[Dict]:
    """Run the pipeline on validation sentences while capturing decoder inputs.

    The inference request of ``model.decoder_with_past`` is temporarily
    replaced by an ``InferRequestWrapper`` that is handed the list returned
    by this function. The pipeline is then run on sentences taken from the
    shuffled (``seed=42``) "jfleg" validation split, and the original request
    is put back afterwards, even if inference fails.

    Args:
        grammar_corrector_pipe_fp32: Pipeline whose ``model.decoder_with_past``
            is instrumented for the duration of the call.
        calibration_dataset_size: Upper bound on the number of sentences the
            pipeline is run on.

    Returns:
        The list that was passed to the wrapper for data collection.
    """
    calibration_data = []
    ov_decoder = grammar_corrector_pipe_fp32.model.decoder_with_past

    # Keep a handle on the real request so it can be restored as-is, without
    # relying on how the wrapper stores it internally.
    original_request = ov_decoder.request

    # Wrap decoder inference for data collection
    ov_decoder.request = InferRequestWrapper(original_request, calibration_data, apply_caching=True)

    # Run inference for data collection
    try:
        calibration_dataset = datasets.load_dataset("jfleg", split="validation")
        sentences = calibration_dataset.shuffle(seed=42)[:calibration_dataset_size]["sentence"]
        # Size the progress bar from what the slice actually produced: the
        # split may hold fewer rows than were requested.
        for sentence in tqdm(
            sentences,
            total=len(sentences),
            desc="Collecting calibration data",
        ):
            grammar_corrector_pipe_fp32(sentence)
    finally:
        ov_decoder.request = original_request

    return calibration_data


def quantize(
    grammar_corrector_pipe_fp32: Pipeline,
    core: ov.Core,
    quantized_model_path: Path,
    calibration_dataset_size: int,
):
    """Return the quantized decoder-with-past model, creating it if needed.

    If ``quantized_model_path`` already exists, the model is read from it with
    ``core.read_model`` and nothing else is done. Otherwise calibration data
    is collected by running ``grammar_corrector_pipe_fp32``, the decoder model
    is quantized with ``nncf.quantize``, and the result is saved to
    ``quantized_model_path`` (parent directories are created as required).

    Args:
        grammar_corrector_pipe_fp32: Pipeline providing
            ``model.decoder_with_past`` and used to gather calibration data.
        core: OpenVINO core used to read a previously saved model.
        quantized_model_path: Location of the saved quantized model.
        calibration_dataset_size: Number of sentences requested for
            calibration; ignored when the saved model already exists.

    Returns:
        The model returned by ``core.read_model`` or ``nncf.quantize``.
    """
    if quantized_model_path.exists():
        print("Loading quantized model")
        return core.read_model(model=quantized_model_path)

    calibration_data = collect_calibration_data(grammar_corrector_pipe_fp32, calibration_dataset_size)
    ov_decoder = grammar_corrector_pipe_fp32.model.decoder_with_past
    quantized_model = nncf.quantize(
        ov_decoder.model,
        calibration_dataset=nncf.Dataset(calibration_data),
        subset_size=len(calibration_data),
        model_type=nncf.ModelType.TRANSFORMER,
        advanced_parameters=nncf.AdvancedQuantizationParameters(
            # Disable bias correction because the model does not contain quantizable operations with bias
            disable_bias_correction=True,
            activations_range_estimator_params=RangeEstimatorParameters(
                # Quantile statistic is employed due to outliers in some activations
                # This parameter was found useful by quantize_with_accuracy_control method
                max=StatisticsCollectorParameters(StatisticsType.QUANTILE)
            ),
        ),
    )

    # exist_ok avoids the check-then-create race of a separate exists() test.
    quantized_model_path.parent.mkdir(parents=True, exist_ok=True)
    ov.save_model(quantized_model, quantized_model_path)

    return quantized_model


def get_quantized_pipeline(
    grammar_corrector_pipe: Pipeline,
    grammar_corrector_tokenizer: PreTrainedTokenizer,
    core: ov.Core,
    grammar_corrector_dir: Path,
    quantized_model_path: Path,
    device: str,
    calibration_dataset_size=CALIBRATION_DATASET_SIZE,
):
    """Build a text2text-generation pipeline that uses the quantized decoder.

    Args:
        grammar_corrector_pipe: Existing pipeline, forwarded to ``quantize``.
        grammar_corrector_tokenizer: Tokenizer given to the new pipeline.
        core: OpenVINO core, forwarded to ``quantize``.
        grammar_corrector_dir: Directory the ``OVModelForSeq2SeqLM`` is loaded
            from.
        quantized_model_path: Location of the quantized model, forwarded to
            ``quantize``.
        device: Device passed to ``OVModelForSeq2SeqLM.from_pretrained``.
        calibration_dataset_size: Number of calibration sentences, forwarded
            to ``quantize``.

    Returns:
        The pipeline created around the model whose ``decoder_with_past`` was
        replaced by the quantized model.
    """
    # Obtain the quantized decoder model first.
    int8_decoder_model = quantize(grammar_corrector_pipe, core, quantized_model_path, calibration_dataset_size)

    # Load a fresh seq2seq model and swap the quantized decoder into it.
    grammar_corrector_model_int8 = OVModelForSeq2SeqLM.from_pretrained(grammar_corrector_dir, device=device)
    past_decoder = grammar_corrector_model_int8.decoder_with_past
    past_decoder.model = int8_decoder_model
    # Clear the existing request before compiling the replaced model.
    past_decoder.request = None
    past_decoder._compile()

    return pipeline(
        "text2text-generation",
        model=grammar_corrector_model_int8,
        tokenizer=grammar_corrector_tokenizer,
    )


def calculate_compression_rate(model_path_ov, model_path_ov_int8):
    """Print and return the ``.bin`` file sizes of two models in KB.

    For each path the suffix is replaced by ``.bin`` and the size of that
    file is read from the filesystem and divided by 1024.

    Args:
        model_path_ov: Path of the model reported as "FP32".
        model_path_ov_int8: Path of the model reported as "INT8".

    Returns:
        A ``(fp32_size_kb, int8_size_kb)`` tuple of floats.
    """
    fp32_kb, int8_kb = (
        ir_path.with_suffix(".bin").stat().st_size / 1024
        for ir_path in (model_path_ov, model_path_ov_int8)
    )

    report = (
        "Model footprint comparison:",
        f"    * FP32 IR model size: {fp32_kb:.2f} KB",
        f"    * INT8 IR model size: {int8_kb:.2f} KB",
    )
    for line in report:
        print(line)

    return fp32_kb, int8_kb


def calculate_inference_time_and_accuracy(grammar_corrector_pipe: Pipeline, test_subset_size: int):
    """Time the pipeline on a test subset and score its output against references.

    Sentences come from the shuffled (``seed=42``) "jfleg" test split, cut to
    ``test_subset_size`` rows. Only the pipeline call itself is timed. Each
    generated text is compared against every reference correction of its
    sentence.

    Args:
        grammar_corrector_pipe: Pipeline called with one sentence at a time.
        test_subset_size: Number of rows requested from the test split.

    Returns:
        A tuple ``(total_inference_seconds, word_accuracy)`` where
        ``word_accuracy`` is ``(1 - WER) * 100``.
    """
    subset = datasets.load_dataset("jfleg", split="test").shuffle(seed=42)[:test_subset_size]

    reference_texts = []
    hypothesis_texts = []
    latencies = []

    sentence_pairs = zip(subset["sentence"], subset["corrections"])
    for sentence, corrections in tqdm(sentence_pairs, total=test_subset_size, desc="Evaluation"):
        # A pair may look like this:
        #   sentence:    "For not use car . "
        #   corrections: ["Not for use with a car . ", "Do not use in the car . ", "Car not for use . "]

        started = time.perf_counter()
        generated = grammar_corrector_pipe(sentence)[0]["generated_text"]
        elapsed = time.perf_counter() - started

        # Repeat the single prediction once per reference so both lists align.
        reference_texts += corrections
        hypothesis_texts += [generated] * len(corrections)
        latencies.append(elapsed)

    error_rate = wer(
        reference_texts,
        hypothesis_texts,
        reference_transform=wer_standardize,
        hypothesis_transform=wer_standardize,
    )
    word_accuracy = (1 - error_rate) * 100
    return sum(latencies), word_accuracy