# """
# Author: Amir Hossein Kargaran
# Date: August, 2023
# Description: This code applies LIME (Local Interpretable Model-Agnostic Explanations) on language identification models.
# MIT License
# Some part of the code is adopted from here: https://gist.github.com/ageitgey/60a8b556a9047a4ca91d6034376e5980
# """
import os
import re
from io import BytesIO

import gradio as gr
import lime.lime_text
import numpy as np
from fasttext.FastText import _FastText
from huggingface_hub import hf_hub_download
from PIL import Image
from selenium import webdriver
from selenium.common.exceptions import WebDriverException

# Map each model choice to its (Hugging Face repo id, filename) pair.
model_paths = {
    "OpenLID": ["laurievb/OpenLID", 'model.bin'],
    "GlotLID": ["cis-lmu/glotlid", 'model.bin'],
    "NLLB": ["facebook/fasttext-language-identification", 'model.bin'],
}

# Cache of already-loaded fastText classifiers, keyed by model choice.
cached_classifiers = {}


def load_classifier(model_choice):
    """Return the fastText classifier for *model_choice*, loading and caching it on first use.

    Downloads the model file from the Hugging Face Hub if it is not already
    cached locally, then wraps it in a fastText classifier.
    """
    if model_choice in cached_classifiers:
        return cached_classifiers[model_choice]
    repo_id, filename = model_paths[model_choice]
    model_path = hf_hub_download(repo_id=repo_id, filename=filename)
    classifier = _FastText(model_path)
    cached_classifiers[model_choice] = classifier
    return classifier


# Warm the cache at startup so the first user request does not pay the download cost.
for model_choice in model_paths:
    load_classifier(model_choice)


def remove_label_prefix(item):
    """Strip the fastText '__label__' prefix from a single label string."""
    return item.replace('__label__', '')


def remove_label_prefix_list(input_list):
    """Strip the '__label__' prefix from a flat or one-level-nested list of labels."""
    if isinstance(input_list[0], list):
        return [[remove_label_prefix(item) for item in inner_list] for inner_list in input_list]
    return [remove_label_prefix(item) for item in input_list]


def tokenize_string(sentence, n=None):
    """Tokenize *sentence*: whitespace split when n is None, else character n-grams of length n."""
    if n is None:
        return sentence.split()
    return [sentence[i:i + n] for i in range(len(sentence) - n + 1)]


def fasttext_prediction_in_sklearn_format(classifier, texts, num_class):
    """Predict label probabilities for *texts* in the (n_samples, n_classes) shape LIME expects.

    fastText returns labels in descending-probability order, so each row's
    probabilities are reordered by the alphabetical sort of its labels to match
    the sorted ``class_names`` given to LimeTextExplainer.
    *num_class* is unused but kept for interface compatibility with callers.
    """
    res = []
    labels, probabilities = classifier.predict(texts, -1)
    labels = remove_label_prefix_list(labels)
    # fix: dropped the unused third zip over `texts` from the original loop
    for label, probs in zip(labels, probabilities):
        order = np.argsort(np.array(label))
        res.append(probs[order])
    return np.array(res)


def generate_explanation_html(input_sentence, explainer, classifier, num_class):
    """Run LIME on *input_sentence* and save the explanation as an HTML file.

    Returns the filename of the written HTML report.
    """
    preprocessed_sentence = input_sentence
    exp = explainer.explain_instance(
        preprocessed_sentence,
        classifier_fn=lambda x: fasttext_prediction_in_sklearn_format(classifier, x, num_class),
        top_labels=2,
        num_features=20,
    )
    output_html_filename = "explanation.html"
    exp.save_to_file(output_html_filename)
    return output_html_filename


def take_screenshot(local_html_path):
    """Render *local_html_path* in headless Chrome and return the page as a PIL image.

    Returns a 1x1 blank image if the browser cannot render the page.
    """
    options = webdriver.ChromeOptions()
    options.add_argument('--headless')
    options.add_argument('--no-sandbox')
    options.add_argument('--disable-dev-shm-usage')
    # fix: initialize wd so the finally block cannot hit a NameError when
    # webdriver.Chrome() itself raises before wd is ever assigned
    wd = None
    try:
        local_html_path = os.path.abspath(local_html_path)
        wd = webdriver.Chrome(options=options)
        wd.set_window_size(1366, 728)
        wd.get('file://' + local_html_path)
        wd.implicitly_wait(10)
        screenshot = wd.get_screenshot_as_png()
    except WebDriverException:
        return Image.new('RGB', (1, 1))
    finally:
        if wd:
            wd.quit()
    return Image.open(BytesIO(screenshot))


def merge_function(input_sentence, selected_model):
    """Gradio handler: explain *input_sentence* with the chosen model.

    Returns a (PIL image, html filename) pair: a screenshot of the LIME
    explanation and the explanation HTML file itself.
    """
    input_sentence = input_sentence.replace('\n', ' ')

    # Load (or fetch from cache) the selected fastText language-ID model.
    classifier = load_classifier(selected_model)
    class_names = remove_label_prefix_list(classifier.labels)
    class_names = np.sort(class_names)
    num_class = len(class_names)

    # LIME splits the text with our tokenizer; bow=False keeps token positions.
    explainer = lime.lime_text.LimeTextExplainer(
        split_expression=tokenize_string,
        bow=False,
        class_names=class_names,
    )

    output_html_filename = generate_explanation_html(input_sentence, explainer, classifier, num_class)
    im = take_screenshot(output_html_filename)
    return im, output_html_filename


# Define the Gradio interface
input_text = gr.Textbox(
    label="Input Text",
    value="J'ai visited la beautiful beach avec mes amis for a relaxing journée under the sun.",
)
model_choice = gr.Radio(choices=["GlotLID", "OpenLID", "NLLB"], label="Select Model", value='GlotLID')
# fix: gr.outputs.File is the removed legacy API; the modern component is gr.File
output_explanation = gr.File(label="Explanation HTML")

iface = gr.Interface(
    merge_function,
    inputs=[input_text, model_choice],
    outputs=[
        gr.Image(type="pil", height=364, width=683, label="Explanation Image"),
        output_explanation,
    ],
    title="LIME LID",
    description="This code applies LIME (Local Interpretable Model-Agnostic Explanations) on fasttext language identification.",
    allow_flagging='never',
    theme=gr.themes.Soft(),
)

iface.launch()