Spaces:
Runtime error
Runtime error
"""⭐ Text Classification with Optimum and ONNXRuntime

Streamlit application to classify text using multiple models.

Author:
- @ChainYo - https://github.com/ChainYo
"""
import plotly | |
import plotly.figure_factory as ff | |
import numpy as np | |
import pandas as pd | |
import streamlit as st | |
from pathlib import Path | |
from time import sleep | |
from typing import Dict, List, Union | |
from optimum.onnxruntime import ORTModelForSequenceClassification, ORTOptimizer, ORTQuantizer | |
from optimum.onnxruntime.configuration import OptimizationConfig, AutoQuantizationConfig | |
from optimum.pipelines import pipeline as ort_pipeline | |
from transformers import BertTokenizer, BertForSequenceClassification | |
from transformers import pipeline as pt_pipeline | |
from utils import calculate_inference_time | |
# Hugging Face Hub checkpoint shared by every pipeline variant.
HUB_MODEL_PATH = "yiyanghkust/finbert-tone"

# Local artifact layout:
#   models/model.onnx
#   models/optimized/model-optimized.onnx
#   models/quantized/model-quantized.onnx
BASE_PATH = Path("models")
ONNX_MODEL_PATH = BASE_PATH / "model.onnx"
OPTIMIZED_BASE_PATH = BASE_PATH / "optimized"
OPTIMIZED_MODEL_PATH = OPTIMIZED_BASE_PATH / "model-optimized.onnx"
QUANTIZED_BASE_PATH = BASE_PATH / "quantized"
QUANTIZED_MODEL_PATH = QUANTIZED_BASE_PATH / "model-quantized.onnx"

# Pipeline key (the names `load_pipeline` dispatches on) -> label displayed in the UI.
# Insertion order is the order in which models are loaded and benchmarked.
VAR2LABEL = dict(
    pt_pipeline="PyTorch",
    ort_pipeline="ONNXRuntime",
    ort_optimized_pipeline="ONNXRuntime (Optimized)",
    ort_quantized_pipeline="ONNXRuntime (Quantized)",
)

# Create the artifact directories if they are missing. BASE_PATH must come first:
# `mkdir` is called without `parents=True`, so the sub-directories need it to exist.
for _artifact_dir in (BASE_PATH, QUANTIZED_BASE_PATH, OPTIMIZED_BASE_PATH):
    _artifact_dir.mkdir(exist_ok=True)
del _artifact_dir
def get_timers(
    samples: Union[List[str], str], exp_number: int, only_mean: bool = False
) -> Dict[str, Union[float, List[float]]]:
    """
    Calculate inference time for each model for a given sample or list of samples.

    Every pipeline key of ``VAR2LABEL`` is loaded in turn with ``load_pipeline``
    (the loaded pipeline is stored in ``st.session_state["pipeline"]``), then the
    pipeline is called ``exp_number`` times on the whole list of samples.

    Parameters
    ----------
    samples : Union[List[str], str]
        Sample or list of samples to calculate inference time for. A single
        string is wrapped into a one-element list.
    exp_number : int
        Number of experiments (timed pipeline calls) to run per model. Must be >= 1.
    only_mean : bool, optional
        If True, keep only the mean inference time of each model; otherwise keep
        the list of every recorded time. Defaults to False.

    Returns
    -------
    Dict[str, Union[float, List[float]]]
        Mapping from each model's display label (a ``VAR2LABEL`` value) to either
        its mean inference time or the list of its recorded inference times.

    Raises
    ------
    ValueError
        If ``exp_number`` is lower than 1; checked before any model is loaded,
        since no timing (and no meaningful mean) could be produced.
    """
    if exp_number < 1:
        raise ValueError(f"exp_number must be at least 1, got {exp_number}.")
    if isinstance(samples, str):
        samples = [samples]

    timers: Dict[str, Union[float, List[float]]] = {}
    for model_name, label in VAR2LABEL.items():
        time_buffer: List[float] = []
        # Kept in the session state (not a local) so the last loaded pipeline stays available.
        st.session_state["pipeline"] = load_pipeline(model_name)
        for _ in range(exp_number):
            # NOTE(review): assumes `calculate_inference_time` appends one elapsed time
            # (in seconds) to `time_buffer` when the block exits — confirm in utils.
            with calculate_inference_time(time_buffer):
                st.session_state["pipeline"](samples)
        timers[label] = float(np.mean(time_buffer)) if only_mean else time_buffer
    return timers
def get_plot(timers: Dict[str, Union[float, List[float]]]) -> plotly.graph_objs.Figure:
    """
    Plot the inference time for each model.

    Two input shapes are supported, matching what ``get_timers`` returns:

    * a list of recorded times per model -> a distribution plot (histogram + rug);
    * a single (mean) time per model -> a bar chart. A dict of scalars cannot go
      through ``pd.DataFrame.from_dict`` without an index, and a single value per
      model is not a distribution anyway.

    Parameters
    ----------
    timers : Dict[str, Union[float, List[float]]]
        Dictionary of inference times for each model, keyed by display label.

    Returns
    -------
    plotly.graph_objs.Figure
        Figure ready to be rendered (e.g. with ``st.plotly_chart``). An empty
        figure is returned when ``timers`` is empty.
    """
    palette = ["#84353f", "#b4524b", "#f47e58", "#ffbe67"]
    if not timers:
        return plotly.graph_objs.Figure()

    # Cycle through the palette so more than four models still get a color each.
    colors = [palette[i % len(palette)] for i in range(len(timers))]

    if all(np.isscalar(value) for value in timers.values()):
        # `only_mean=True` case: one number per model.
        fig = plotly.graph_objs.Figure(
            plotly.graph_objs.Bar(x=list(timers.keys()), y=list(timers.values()), marker_color=colors)
        )
        fig.update_layout(
            title_text="Mean Inference Time", xaxis_title="Model", yaxis_title="Mean Inference Time (s)"
        )
        return fig

    data = pd.DataFrame.from_dict(timers, orient="columns")
    fig = ff.create_distplot(
        [data[col] for col in data.columns], list(data.columns), bin_size=0.001, colors=colors, show_curve=False
    )
    # create_distplot normalizes the histogram to a probability density by default,
    # so the y axis is a density, not a count of samples.
    fig.update_layout(
        title_text="Inference Time", xaxis_title="Inference Time (s)", yaxis_title="Probability Density"
    )
    return fig
def load_pipeline(pipeline_name: str):
    """
    Load a pipeline for a given model.

    Parameters
    ----------
    pipeline_name : str
        Name of the pipeline to load; one of the ``VAR2LABEL`` keys:
        ``"pt_pipeline"`` (PyTorch), ``"ort_pipeline"`` (ONNXRuntime),
        ``"ort_optimized_pipeline"`` or ``"ort_quantized_pipeline"``.

    Returns
    -------
    Pipeline
        A transformers ``"sentiment-analysis"`` pipeline for ``"pt_pipeline"``,
        otherwise an Optimum ``"text-classification"`` pipeline running on
        ONNXRuntime. All of them use ``st.session_state["tokenizer"]``.

    Raises
    ------
    ValueError
        If ``pipeline_name`` is not one of the supported names.

    Notes
    -----
    The optimized and quantized ONNX files are produced only when they do not
    exist yet. Both ``export`` calls receive ``ONNX_MODEL_PATH`` as the base ONNX
    *file*, so ``"ort_pipeline"`` must save its model as that file (not as a
    directory with that name).
    """
    if pipeline_name == "pt_pipeline":
        model = BertForSequenceClassification.from_pretrained(HUB_MODEL_PATH, num_labels=3)
        return pt_pipeline("sentiment-analysis", tokenizer=st.session_state["tokenizer"], model=model)

    if pipeline_name == "ort_pipeline":
        model = ORTModelForSequenceClassification.from_pretrained(HUB_MODEL_PATH, from_transformers=True)
        if not ONNX_MODEL_PATH.exists():
            # `save_pretrained` takes a directory: saving into ONNX_MODEL_PATH itself would
            # create a folder called "model.onnx". Save into its parent under the expected
            # file name so ONNX_MODEL_PATH is the exported ONNX file the exports below use.
            model.save_pretrained(ONNX_MODEL_PATH.parent, file_name=ONNX_MODEL_PATH.name)
    elif pipeline_name == "ort_optimized_pipeline":
        if not OPTIMIZED_MODEL_PATH.exists():
            optimization_config = OptimizationConfig(optimization_level=99)
            optimizer = ORTOptimizer.from_pretrained(HUB_MODEL_PATH, feature="sequence-classification")
            optimizer.export(ONNX_MODEL_PATH, OPTIMIZED_MODEL_PATH, optimization_config=optimization_config)
            # The config is needed next to the ONNX file to reload it with `from_pretrained`.
            optimizer.model.config.save_pretrained(OPTIMIZED_BASE_PATH)
        model = ORTModelForSequenceClassification.from_pretrained(
            OPTIMIZED_BASE_PATH, file_name=OPTIMIZED_MODEL_PATH.name
        )
    elif pipeline_name == "ort_quantized_pipeline":
        if not QUANTIZED_MODEL_PATH.exists():
            # Dynamic (is_static=False), per-tensor (per_channel=False) ARM64 preset.
            quantization_config = AutoQuantizationConfig.arm64(is_static=False, per_channel=False)
            quantizer = ORTQuantizer.from_pretrained(HUB_MODEL_PATH, feature="sequence-classification")
            quantizer.export(ONNX_MODEL_PATH, QUANTIZED_MODEL_PATH, quantization_config=quantization_config)
            # The config is needed next to the ONNX file to reload it with `from_pretrained`.
            quantizer.model.config.save_pretrained(QUANTIZED_BASE_PATH)
        model = ORTModelForSequenceClassification.from_pretrained(
            QUANTIZED_BASE_PATH, file_name=QUANTIZED_MODEL_PATH.name
        )
    else:
        raise ValueError(f"Unknown pipeline name: {pipeline_name!r}.")

    # Every ONNXRuntime variant shares the same task and tokenizer.
    return ort_pipeline("text-classification", tokenizer=st.session_state["tokenizer"], model=model)
def _render_page_header() -> None:
    """Render the page configuration, title, social badges and the "Details" expander."""
    # `set_page_config` has to stay the first Streamlit call of the script.
    st.set_page_config(page_title="Optimum Text Classification", page_icon="⭐")
    st.title("⭐ Optimum Text Classification")
    st.subheader("Classify financial news tone with 🤗 Optimum and ONNXRuntime")

    badges = """
[![GitHub](https://img.shields.io/badge/-%23121011.svg?style=for-the-badge&logo=github&logoColor=white)](https://github.com/ChainYo)
[![HuggingFace](https://img.shields.io/badge/-yellow.svg?style=for-the-badge&logo=)](https://huggingface.co/ChainYo)
[![LinkedIn](https://img.shields.io/badge/-%230077B5.svg?style=for-the-badge&logo=linkedin&logoColor=white)](https://www.linkedin.com/in/thomas-chaigneau-dev/)
[![Discord](https://img.shields.io/badge/Chainyo%233610-%237289DA.svg?style=for-the-badge&logo=discord&logoColor=white)](https://discord.gg/)
"""
    st.markdown(badges)

    details = """
This app is a **demo** of the [🤗 Optimum Text Classification](https://huggingface.co/docs/optimum/onnxruntime/modeling_ort#optimum-inference-with-onnx-runtime) pipeline.
We aim to compare the original pipeline with the ONNXRuntime pipeline.
We use the [Finbert-Tone](https://huggingface.co/yiyanghkust/finbert-tone) model to classify financial news tone for the demo.
You can enter multiple sentences to classify them by separating them with a `; (semicolon)`.
"""
    with st.expander("⭐ Details", expanded=True):
        st.markdown(details)


_render_page_header()
# One-time (per session) preparation: tokenizer + every pipeline variant, so the
# ONNX / optimized / quantized artifacts are built before the first benchmark.
if "init_models" not in st.session_state:
    st.session_state["init_models"] = True

if st.session_state["init_models"]:
    with st.spinner(text="Loading files and models..."):
        loading_logs = st.empty()
        with loading_logs.container():
            # Parent directory first: `mkdir` is called without `parents=True`.
            for artifact_dir in (BASE_PATH, QUANTIZED_BASE_PATH, OPTIMIZED_BASE_PATH):
                artifact_dir.mkdir(exist_ok=True)

            if "tokenizer" not in st.session_state:
                tokenizer = BertTokenizer.from_pretrained(HUB_MODEL_PATH)
                st.session_state["tokenizer"] = tokenizer
            st.text("✅ Tokenizer loaded.")

            if "pipeline" not in st.session_state:
                # Only the last loaded pipeline is kept; loading each one creates its files.
                for pipeline in VAR2LABEL:
                    st.session_state["pipeline"] = load_pipeline(pipeline)
            st.text("✅ Models ready.")

        # Leave the progress messages on screen briefly before the success banner.
        sleep(2)
        loading_logs.success("🎉 Everything is ready!")
    st.session_state["init_models"] = False
# Last measured timers, kept across Streamlit reruns.
if "inference_timers" not in st.session_state:
    st.session_state["inference_timers"] = {}

exp_number = st.slider("The number of experiments per model.", min_value=10, max_value=300, value=150)
get_only_mean = st.checkbox("Get only the mean of the inference time for each model.", value=False)
input_text = st.text_area(
    "Enter text to classify",
    "there is a shortage of capital, and we need extra financing; growth is strong and we have plenty of liquidity; there are doubts about our finances; profits are flat"
)
run_inference = st.button("🚀 Run inference")

if run_inference:
    # Sentences are separated by ";". Strip the whitespace around each one and drop
    # empty fragments (empty input, trailing ";", ";;") so blank strings are never
    # sent to the models.
    sentences = [sentence for sentence in (part.strip() for part in input_text.split(";")) if sentence]
    if not sentences:
        st.warning("Please enter at least one sentence to classify.")
    else:
        st.text("🔎 Running inference...")
        st.session_state["inference_timers"] = get_timers(
            samples=sentences, exp_number=exp_number, only_mean=get_only_mean
        )
        st.plotly_chart(get_plot(st.session_state["inference_timers"]), use_container_width=True)