"""Helpers for INT8 post-training quantization of the decoder of an OpenVINO
grammar-correction (seq2seq) pipeline with NNCF, plus utilities to measure
model footprint, inference time and word accuracy."""

import time
from pathlib import Path
from typing import Dict, List

import datasets
import nncf
import openvino as ov
from jiwer import wer, wer_standardize
from nncf.quantization.range_estimator import (
    RangeEstimatorParameters,
    StatisticsCollectorParameters,
    StatisticsType,
)
from optimum.intel import OVModelForSeq2SeqLM
from optimum.intel.openvino.quantization import InferRequestWrapper
from tqdm.auto import tqdm
from transformers import Pipeline, PreTrainedTokenizer, pipeline

CALIBRATION_DATASET_SIZE = 10


def collect_calibration_data(grammar_corrector_pipe_fp32: Pipeline, calibration_dataset_size: int) -> List[Dict]:
    calibration_data = []
    ov_decoder = grammar_corrector_pipe_fp32.model.decoder_with_past

    # Wrap the decoder's infer request so that every inference run below is
    # recorded into calibration_data
    ov_decoder.request = InferRequestWrapper(ov_decoder.request, calibration_data, apply_caching=True)

    # Run inference over the calibration subset to collect input data
    try:
        calibration_dataset = datasets.load_dataset("jfleg", split="validation")
        calibration_dataset = calibration_dataset.shuffle(seed=42)[:calibration_dataset_size]
        for data_item in tqdm(
            calibration_dataset["sentence"],
            total=calibration_dataset_size,
            desc="Collecting calibration data",
        ):
            grammar_corrector_pipe_fp32(data_item)
    finally:
        # Restore the original infer request
        ov_decoder.request = ov_decoder.request.request

    return calibration_data


def quantize(
    grammar_corrector_pipe_fp32: Pipeline,
    core: ov.Core,
    quantized_model_path: Path,
    calibration_dataset_size: int,
):
    if quantized_model_path.exists():
        print("Loading quantized model")
        quantized_model = core.read_model(model=quantized_model_path)
    else:
        calibration_data = collect_calibration_data(grammar_corrector_pipe_fp32, calibration_dataset_size)
        ov_decoder = grammar_corrector_pipe_fp32.model.decoder_with_past
        quantized_model = nncf.quantize(
            ov_decoder.model,
            calibration_dataset=nncf.Dataset(calibration_data),
            subset_size=len(calibration_data),
            model_type=nncf.ModelType.TRANSFORMER,
            advanced_parameters=nncf.AdvancedQuantizationParameters(
                # Disable bias correction because the model does not contain
                # quantizable operations with bias
                disable_bias_correction=True,
                activations_range_estimator_params=RangeEstimatorParameters(
                    # The quantile statistic is used because some activations
                    # contain outliers; this setting was found to be beneficial
                    # by the quantize_with_accuracy_control method
                    max=StatisticsCollectorParameters(StatisticsType.QUANTILE)
                ),
            ),
        )
        if not quantized_model_path.parent.exists():
            quantized_model_path.parent.mkdir(parents=True)
        ov.save_model(quantized_model, quantized_model_path)

    return quantized_model


def get_quantized_pipeline(
    grammar_corrector_pipe: Pipeline,
    grammar_corrector_tokenizer: PreTrainedTokenizer,
    core: ov.Core,
    grammar_corrector_dir: Path,
    quantized_model_path: Path,
    device: str,
    calibration_dataset_size: int = CALIBRATION_DATASET_SIZE,
):
    # Get the quantized OV model (runs quantization or loads a cached result)
    quantized_model = quantize(grammar_corrector_pipe, core, quantized_model_path, calibration_dataset_size)

    # Load the quantized decoder into a fresh grammar-correction pipeline
    grammar_corrector_model_int8 = OVModelForSeq2SeqLM.from_pretrained(grammar_corrector_dir, device=device)
    grammar_corrector_model_int8.decoder_with_past.model = quantized_model
    grammar_corrector_model_int8.decoder_with_past.request = None
    grammar_corrector_model_int8.decoder_with_past._compile()
    grammar_corrector_pipe_int8 = pipeline(
        "text2text-generation",
        model=grammar_corrector_model_int8,
        tokenizer=grammar_corrector_tokenizer,
    )

    return grammar_corrector_pipe_int8


def calculate_compression_rate(model_path_ov: Path, model_path_ov_int8: Path):
    model_size_fp32 = model_path_ov.with_suffix(".bin").stat().st_size / 1024
    model_size_int8 = model_path_ov_int8.with_suffix(".bin").stat().st_size / 1024
    print("Model footprint comparison:")
    print(f" * FP32 IR model size: {model_size_fp32:.2f} KB")
    print(f" * INT8 IR model size: {model_size_int8:.2f} KB")
    return model_size_fp32, model_size_int8


def calculate_inference_time_and_accuracy(grammar_corrector_pipe: Pipeline, test_subset_size: int):
    ground_truths = []
    predictions = []
    inference_time = []

    test_dataset = datasets.load_dataset("jfleg", split="test").shuffle(seed=42)[:test_subset_size]
    zipped_dataset = zip(test_dataset["sentence"], test_dataset["corrections"])

    for input_text, references in tqdm(zipped_dataset, total=test_subset_size, desc="Evaluation"):
        # For example, a sample pair may look like:
        # input_text: "For not use car . "
        # references: ["Not for use with a car . ", "Do not use in the car . ", "Car not for use . "]
        start_time = time.perf_counter()
        corrected_text = grammar_corrector_pipe(input_text)[0]["generated_text"]
        end_time = time.perf_counter()
        delta_time = end_time - start_time

        # Duplicate the prediction once per reference so that WER is computed
        # against every reference of each sample
        ground_truths.extend(references)
        predictions.extend([corrected_text] * len(references))
        inference_time.append(delta_time)

    word_accuracy = (
        1
        - wer(
            ground_truths,
            predictions,
            reference_transform=wer_standardize,
            hypothesis_transform=wer_standardize,
        )
    ) * 100
    sum_inference_time = sum(inference_time)
    return sum_inference_time, word_accuracy
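

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). The
# model ID, directory names and device below are assumptions: adjust them to
# your setup. "pszemraj/flan-t5-large-grammar-synthesis" is one publicly
# available grammar-correction seq2seq model, and the IR file name follows
# optimum-intel's default naming for exported seq2seq decoders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model_id = "pszemraj/flan-t5-large-grammar-synthesis"  # assumed model
    grammar_corrector_dir = Path("grammar_corrector_ov")  # assumed export dir
    quantized_model_path = Path("quantized") / "openvino_decoder_with_past_model.xml"
    device = "CPU"

    core = ov.Core()
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Export the FP32 model to OpenVINO IR and build the baseline pipeline
    model_fp32 = OVModelForSeq2SeqLM.from_pretrained(model_id, export=True, device=device)
    model_fp32.save_pretrained(grammar_corrector_dir)
    pipe_fp32 = pipeline("text2text-generation", model=model_fp32, tokenizer=tokenizer)

    # Quantize the decoder-with-past and assemble the INT8 pipeline
    pipe_int8 = get_quantized_pipeline(
        pipe_fp32, tokenizer, core, grammar_corrector_dir, quantized_model_path, device
    )

    # Compare footprint, inference time and accuracy of the two pipelines
    fp32_decoder_path = grammar_corrector_dir / "openvino_decoder_with_past_model.xml"
    calculate_compression_rate(fp32_decoder_path, quantized_model_path)
    for name, pipe in (("FP32", pipe_fp32), ("INT8", pipe_int8)):
        total_time, accuracy = calculate_inference_time_and_accuracy(pipe, test_subset_size=10)
        print(f"{name}: total inference time {total_time:.2f} s, word accuracy {accuracy:.2f}%")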