"""Gradio demo that scores how similar two passages are.

Both passages are embedded with the ``intfloat/multilingual-e5-large`` model
(mean pooling over the last hidden state, padding excluded) and compared with
cosine similarity.
"""

import gradio as gr
import torch
from torch import Tensor
from torch.nn.functional import cosine_similarity
from transformers import AutoModel, AutoTokenizer
from transformers import pipeline  # noqa: F401  (not used in this file; retained from the original imports)

# E5 models are trained with a task prefix on every input text. The model card
# recommends "query: " for symmetric tasks such as semantic similarity;
# omitting the prefix degrades embedding quality.
E5_PREFIX = "query: "


def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Mean-pool token embeddings over the sequence axis, ignoring padding.

    Args:
        last_hidden_states: Token embeddings; dim 1 is the sequence axis.
        attention_mask: Mask matching the leading dims of
            ``last_hidden_states``; non-zero marks real tokens.

    Returns:
        One pooled embedding per batch row.
    """
    # Zero out padded positions so they do not contribute to the sum.
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    # Divide by the number of real tokens (not the padded length).
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]


def get_similarity(sentence1: str, sentence2: str) -> float:
    """Return the cosine similarity of two texts, rounded to 4 decimals.

    Args:
        sentence1: First text.
        sentence2: Second text.

    Returns:
        Cosine similarity between the two mean-pooled E5 embeddings.
    """
    # Plain concatenation (rather than an f-string) keeps a non-string input
    # an error instead of silently embedding the text "None".
    input_texts = [E5_PREFIX + sentence1, E5_PREFIX + sentence2]

    # Tokenize and compute embeddings.
    batch_dict = tokenizer(
        input_texts,
        max_length=512,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    # Inference only: skip autograd bookkeeping so no graph is kept per request.
    with torch.no_grad():
        outputs = model(**batch_dict)
        embeddings = average_pool(outputs.last_hidden_state, batch_dict["attention_mask"])
        similarity = cosine_similarity(embeddings[0].unsqueeze(0), embeddings[1].unsqueeze(0))

    return round(similarity.item(), 4)


# Loaded once at import time and shared by every request.
checkpoint = "intfloat/multilingual-e5-large"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModel.from_pretrained(checkpoint)

demo = gr.Blocks(theme="freddyaboulton/dracula_revamped")

with demo:
    gr.Markdown("# Sentence Similarity")
    gr.Markdown("### How to use:")
    gr.Markdown("- Enter Passage 1 and Passage 2, then press Submit")
    gr.Markdown("- Select an example, then press Submit")
    gr.Markdown("Model: https://huggingface.co/intfloat/multilingual-e5-large (Multilingual: 94 languages)")
    with gr.Row():
        p_txt1 = gr.Textbox(placeholder="Enter passage 1", label="Passage 1", lines=3, scale=2)
        p_txt2 = gr.Textbox(placeholder="Enter passage 2", label="Passage 2", lines=3, scale=2)
        o_txt = gr.Textbox(
            placeholder="Similarity score",
            lines=1,
            interactive=False,
            label="Similarity score (0-1)",
            scale=1,
        )
    submit = gr.Button("Submit")
    gr.Examples(
        [
            ["A big bus is running on the road in the city.", "There is a big bus running on the road."],
            ["A big bus is running on the road in the city.", "Two children in costumes are standing on the bed."],
            ["街中の道路を大きなバスが走っています。", "道路を大きなバスが走っています。"],
            ["街中の道路を大きなバスが走っています。", "ベッドの上で衣装を着た二人の子供が立っています。"],
            ["A big bus is running on the road in the city.", "道路を大きなバスが走っています。"],
        ],
        inputs=[p_txt1, p_txt2],
    )
    submit.click(
        get_similarity,
        [p_txt1, p_txt2],
        o_txt,
    )

# Only start the server when executed as a script, so importing this module
# (e.g. to reuse get_similarity) does not block on a running web app.
if __name__ == "__main__":
    demo.launch()