File size: 3,763 Bytes
b8c24aa
3a82207
63b82b4
3592f5f
63b82b4
 
c8fdb3b
3a82207
4e81072
7dc3087
08c1bd3
64d8a64
63b82b4
64d8a64
 
63b82b4
64d8a64
63b82b4
08c1bd3
3592f5f
 
 
 
63b82b4
3592f5f
 
3a82207
 
3592f5f
 
 
3a82207
 
3592f5f
 
3a82207
3592f5f
 
63b82b4
3592f5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63b82b4
3592f5f
 
 
 
 
 
 
 
 
 
 
63b82b4
 
3592f5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import time
from functools import lru_cache
from threading import Thread

import gradio as gr
import torch
from transformers import (
    AutoModel,
    AutoTokenizer,
)
import spaces

# Pick the compute device once at import time: the first CUDA device when one
# is visible to torch, otherwise the CPU. The choice is echoed to stdout.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print(f"Using GPU: {torch.cuda.get_device_name(device)}")
else:
    print("Using CPU")


def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the positions selected by the mask.

    Args:
        model_output: Model output indexable by position; element 0 holds the
            per-token embeddings (presumably ``(batch, seq, hidden)`` --
            the mask must broadcast to it after ``unsqueeze(-1)``).
        attention_mask: Tensor matching the leading ``(batch, seq)`` dims;
            nonzero entries mark tokens to include.

    Returns:
        One averaged vector per sequence (dim 1 is reduced away). The token
        count is clamped to ``1e-9`` so an all-masked row yields zeros
        instead of dividing by zero.
    """
    hidden = model_output[0]
    # Per-token weights broadcast across the hidden dimension; cast to float
    # so the division below happens in at least float32.
    weights = attention_mask.unsqueeze(-1).expand_as(hidden).float()
    summed = (hidden * weights).sum(dim=1)
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / counts


def cls_pooling(model_output):
    """Return the embedding at sequence position 0 (the [CLS] slot).

    Element 0 of ``model_output`` is assumed to be the per-token embeddings
    with the sequence on dim 1; the result drops that dimension.
    """
    last_hidden = model_output[0]
    return last_hidden.select(dim=1, index=0)


@lru_cache(maxsize=2)
def _load_tokenizer_and_model(model_id, dtype):
    """Load (once) and memoize the tokenizer and model for ``model_id``.

    The cache is bounded so switching between checkpoints in the UI keeps
    at most two of them resident instead of growing without limit.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModel.from_pretrained(model_id, torch_dtype=dtype)
    model.eval()
    return tokenizer, model


@spaces.GPU
def get_embedding(text, use_mean_pooling, model_id):
    """Embed ``text`` with the checkpoint ``model_id``.

    Args:
        text: Input passed straight to the tokenizer (a string; a list of
            strings is tokenized as a padded batch).
        use_mean_pooling: True for mask-aware mean pooling over tokens,
            False to take the first-token ([CLS]) embedding.
        model_id: Hugging Face model identifier to load.

    Returns:
        A tensor with one pooled vector per input sequence, on ``device``.
        Inputs are truncated to 512 tokens.
    """
    # float16 only pays off on GPU; on the CPU fallback many kernels are slow
    # or unimplemented for half precision, so use float32 there.
    dtype = torch.float16 if device.type == "cuda" else torch.float32
    # Reuse the already-loaded weights instead of re-running from_pretrained
    # on every call (get_similarity calls this twice per request).
    # NOTE(review): under ZeroGPU, `spaces.GPU` may execute this in a separate
    # worker, making the cache per-worker -- confirm on the Space.
    tokenizer, model = _load_tokenizer_and_model(model_id, dtype)
    model = model.to(device)

    inputs = tokenizer(
        text, return_tensors="pt", padding=True, truncation=True, max_length=512
    )
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    with torch.no_grad():
        model_output = model(**inputs)
    if use_mean_pooling:
        return mean_pooling(model_output, inputs["attention_mask"])
    return cls_pooling(model_output)


def get_similarity(text1, text2, pooling_method, model_id):
    """Return the cosine similarity between the embeddings of two texts.

    Args:
        text1: First input text.
        text2: Second input text.
        pooling_method: UI label; exactly ``"Use Mean Pooling"`` selects mean
            pooling, any other value selects [CLS] pooling.
        model_id: Hugging Face model identifier used for both texts.

    Returns:
        The cosine similarity as a Python float.
    """
    use_mean_pooling = pooling_method == "Use Mean Pooling"
    # CLS pooling yields float16 embeddings on GPU; computing the cosine in
    # half precision quantizes the displayed score and lets the default
    # eps underflow, so promote to float32 first.
    embedding1 = get_embedding(text1, use_mean_pooling, model_id).float()
    embedding2 = get_embedding(text2, use_mean_pooling, model_id).float()
    return torch.nn.functional.cosine_similarity(embedding1, embedding2).item()


# Default checkpoint for the model picker and for every bundled example.
_DEFAULT_MODEL = "answerdotai/ModernBERT-large"

# Example rows: (text 1, text 2, pooling label, model id).
_EXAMPLES = [
    [
        "The quick brown fox jumps over the lazy dog",
        "A swift brown fox leaps above a sleeping canine",
        "Use Mean Pooling",
        _DEFAULT_MODEL,
    ],
    [
        "I love programming in Python",
        "I hate coding with Python",
        "Use Mean Pooling",
        _DEFAULT_MODEL,
    ],
    [
        "The weather is beautiful today",
        "Machine learning models are improving rapidly",
        "Use Mean Pooling",
        _DEFAULT_MODEL,
    ],
    [
        "def calculate_sum(a, b):\n    return a + b",
        "def add_numbers(x, y):\n    result = x + y\n    return result",
        "Use Mean Pooling",
        _DEFAULT_MODEL,
    ],
]

# Build the UI: two text boxes, a pooling selector (its labels are matched
# verbatim by get_similarity) and a model selector; output is the score.
demo = gr.Interface(
    fn=get_similarity,
    inputs=[
        gr.Textbox(lines=7, label="Text 1"),
        gr.Textbox(lines=7, label="Text 2"),
        gr.Dropdown(
            choices=["Use Mean Pooling", "Use CLS"],
            value="Use Mean Pooling",
            label="Pooling Method",
            info=(
                "Mean Pooling: Averages all token embeddings (better for semantic similarity)\n"
                "CLS Pooling: Uses only the [CLS] token embedding (faster, might miss context)"
            ),
        ),
        gr.Dropdown(
            choices=[
                "tasksource/ModernBERT-base-embed",
                "tasksource/ModernBERT-base-nli",
                "joe32140/ModernBERT-large-msmarco",
                "answerdotai/ModernBERT-large",
                "answerdotai/ModernBERT-base",
            ],
            value=_DEFAULT_MODEL,
            label="Model",
            info=(
                "Choose between the variants of ModernBERT \n"
                "Might take a few seconds to load the model"
            ),
        ),
    ],
    outputs=gr.Textbox(label="Similarity"),
    title="ModernBERT Similarity Demo",
    description=(
        "Compute the similarity between two texts using ModernBERT. "
        "Choose between different pooling strategies for embedding generation."
    ),
    examples=_EXAMPLES,
)
demo.launch(share=True)