schroneko committed on
Commit
babe057
1 Parent(s): 3bc31bd
Files changed (7) hide show
  1. GLuCoSE-base-ja-v2.py +20 -0
  2. RoSEtta-base-ja.py +22 -0
  3. app.py +86 -0
  4. pyproject.toml +13 -0
  5. requirements.txt +5 -0
  6. ruri-large.py +26 -0
  7. uv.lock +0 -0
GLuCoSE-base-ja-v2.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# GLuCoSE-base-ja-v2.py
"""Smoke-test script: embed three sentences with GLuCoSE-base-ja-v2 and
print the shapes of the embedding matrix and the similarity matrix."""

from sentence_transformers import SentenceTransformer

# Fetch the checkpoint from the Hugging Face Hub.
encoder = SentenceTransformer("pkshatech/GLuCoSE-base-ja-v2")

# Sample inputs to embed.
texts = [
    "The weather is lovely today.",
    "It's so sunny outside!",
    "He drove to the stadium.",
]

vectors = encoder.encode(texts)
# Shape noted by the original author: [3, 768]
print(vectors.shape)

# Score every sentence against every sentence (including itself).
scores = encoder.similarity(vectors, vectors)
# Shape noted by the original author: [3, 3]
print(scores.shape)
RoSEtta-base-ja.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# RoSEtta-base-ja.py
"""Smoke-test script: embed three sentences with RoSEtta-base and print
the shapes of the embedding matrix and the similarity matrix."""

from sentence_transformers import SentenceTransformer

# Fetch the checkpoint from the Hugging Face Hub.
# The original author reported that, in their environment, loading this
# model without `trust_remote_code=True` raised an error, so the flag is
# passed explicitly.
# NOTE(review): this flag permits code shipped in the model repository to
# run locally -- confirm the repository is trusted.
encoder = SentenceTransformer("pkshatech/RoSEtta-base", trust_remote_code=True)

# Sample inputs to embed.
texts = [
    "The weather is lovely today.",
    "It's so sunny outside!",
    "He drove to the stadium.",
]

vectors = encoder.encode(texts)
# Shape noted by the original author: [3, 768]
print(vectors.shape)

# Score every sentence against every sentence (including itself).
scores = encoder.similarity(vectors, vectors)
# Shape noted by the original author: [3, 3]
print(scores.shape)
app.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+
3
+ import gradio as gr
4
+ import torch.nn.functional as F
5
+ from sentence_transformers import SentenceTransformer
6
+
7
+
8
# Maps each display name used by the demo to its Hugging Face Hub repository
# ID and the extra keyword arguments its constructor call needs.
_MODEL_SPECS = {
    "GLuCoSE-base-ja-v2": ("pkshatech/GLuCoSE-base-ja-v2", {}),
    "RoSEtta-base-ja": ("pkshatech/RoSEtta-base", {"trust_remote_code": True}),
    "ruri-large": ("cl-nagoya/ruri-large", {}),
}

# Models already constructed in this process, keyed by display name, so a
# repeated request does not rebuild the same model.
_LOADED_MODELS = {}


def load_model(model_name):
    """Return the SentenceTransformer registered under *model_name*.

    The first request for a name constructs the model; later requests for
    the same name return the same instance. Loaded models stay referenced
    for the lifetime of the process.

    Args:
        model_name: One of the display names in ``_MODEL_SPECS``.

    Returns:
        The ``SentenceTransformer`` instance for that name.

    Raises:
        ValueError: If *model_name* is not a known display name. (The
            previous implementation fell through and returned ``None``,
            which only failed later when the caller tried to use it.)
    """
    if model_name not in _MODEL_SPECS:
        known = ", ".join(sorted(_MODEL_SPECS))
        raise ValueError(f"Unknown model {model_name!r}; expected one of: {known}")

    if model_name not in _LOADED_MODELS:
        repo_id, extra_kwargs = _MODEL_SPECS[model_name]
        _LOADED_MODELS[model_name] = SentenceTransformer(repo_id, **extra_kwargs)
    return _LOADED_MODELS[model_name]
15
+
16
+
17
def get_similarities(model_name, sentences):
    """Embed *sentences* with the chosen model and score every pair.

    Args:
        model_name: Display name passed through to ``load_model``.
        sentences: Sequence of strings to embed.

    Returns:
        The pairwise similarity scores, moved to CPU and converted with
        ``.numpy()``.
    """
    encoder = load_model(model_name)

    if model_name == "ruri-large":
        # For this model, even positions are tagged as queries and odd
        # positions as passages before encoding.
        prefixed = []
        for position, text in enumerate(sentences):
            prefix = "文章: " if position % 2 else "クエリ: "
            prefixed.append(prefix + text)
        sentences = prefixed

    vectors = encoder.encode(sentences, convert_to_tensor=True)

    if model_name == "GLuCoSE-base-ja-v2" or model_name == "RoSEtta-base-ja":
        scores = encoder.similarity(vectors, vectors)
    else:
        # Any other name (ruri-large in this demo): broadcast the embeddings
        # against themselves and take the cosine along the feature axis.
        as_row = vectors.unsqueeze(0)
        as_column = vectors.unsqueeze(1)
        scores = F.cosine_similarity(as_row, as_column, dim=2)

    return scores.cpu().numpy()
36
+
37
+
38
def format_similarities(similarities):
    """Render rows of numbers as text.

    Each value is formatted with four decimal places, values within a row
    are separated by a single space, and rows are separated by newlines.
    """
    rendered_rows = []
    for row in similarities:
        cells = [format(value, ".4f") for value in row]
        rendered_rows.append(" ".join(cells))
    return "\n".join(rendered_rows)
40
+
41
+
42
def process_input(model_name, input_text):
    """Turn the textbox contents into a formatted similarity matrix.

    The text is split on newlines; each line is stripped and blank lines
    are dropped. The remaining sentences are scored with
    ``get_similarities`` and rendered with ``format_similarities``.

    Args:
        model_name: Display name of the model to use.
        input_text: Raw text, one sentence per line. ``None`` is treated
            like an empty string.

    Returns:
        The formatted matrix, or an empty string when the input contains
        no non-blank lines.
    """
    sentences = [
        stripped
        for stripped in (line.strip() for line in (input_text or "").split("\n"))
        if stripped
    ]

    # Nothing to compare: return early instead of loading a model and
    # passing an empty batch to the similarity code.
    if not sentences:
        return ""

    similarities = get_similarities(model_name, sentences)
    return format_similarities(similarities)
46
+
47
+
48
# Display names offered in the dropdown; the first entry is preselected.
models = ["GLuCoSE-base-ja-v2", "RoSEtta-base-ja", "ruri-large"]

# Example textbox contents, one sentence per line.
_ENGLISH_EXAMPLE = "The weather is lovely today.\nIt's so sunny outside!\nHe drove to the stadium."
_RURI_EXAMPLE = "瑠璃色はどんな色?\n瑠璃色(るりいろ)は、紫みを帯びた濃い青。名は、半貴石の瑠璃(ラピスラズリ、英: lapis lazuli)による。JIS慣用色名では「こい紫みの青」(略号 dp-pB)と定義している[1][2]。\nワシやタカのように、鋭いくちばしと爪を持った大型の鳥類を総称して「何類」というでしょう?\nワシ、タカ、ハゲワシ、ハヤブサ、コンドル、フクロウが代表的である。これらの猛禽類はリンネ前後の時代(17~18世紀)には鷲類・鷹類・隼類及び梟類に分類された。ちなみにリンネは狩りをする鳥を単一の目(もく)にまとめ、vultur(コンドル、ハゲワシ)、falco(ワシ、タカ、ハヤブサなど)、strix(フクロウ)、lanius(モズ)の4属を含めている。"

# NOTE(review): the nesting below was reconstructed from a listing whose
# indentation had been stripped -- confirm the layout against the real file.
with gr.Blocks() as demo:
    gr.Markdown("# Sentence Similarity Demo")

    with gr.Row():
        # Input side: model choice, sentences, and the trigger button.
        with gr.Column():
            model_dropdown = gr.Dropdown(
                choices=models,
                label="Select Model",
                value=models[0],
            )
            input_text = gr.Textbox(lines=5, label="Input Sentences (one per line)")
            submit_btn = gr.Button(value="Calculate Similarities")

        # Output side: the formatted matrix.
        with gr.Column():
            output_text = gr.Textbox(label="Similarity Matrix", lines=10)

    # Clicking the button runs process_input on the two inputs and writes
    # its return value into the output textbox.
    submit_btn.click(
        process_input,
        inputs=[model_dropdown, input_text],
        outputs=output_text,
    )

    # One example per model; each fills both the dropdown and the textbox.
    gr.Examples(
        examples=[
            ["GLuCoSE-base-ja-v2", _ENGLISH_EXAMPLE],
            ["RoSEtta-base-ja", _ENGLISH_EXAMPLE],
            ["ruri-large", _RURI_EXAMPLE],
        ],
        inputs=[model_dropdown, input_text],
    )

demo.launch()
pyproject.toml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "playground-embedding"
3
+ version = "0.1.0"
4
+ description = "Add your description here"
5
+ readme = "README.md"
6
+ requires-python = ">=3.12"
7
+ dependencies = [
8
+ "fugashi>=1.3.2",
9
+ "gradio>=4.42.0",
10
+ "sentence-transformers>=3.0.1",
11
+ "sentencepiece>=0.2.0",
12
+ "unidic-lite>=1.0.8",
13
+ ]
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ fugashi>=1.3.2
2
+ gradio>=4.42.0
3
+ sentence-transformers>=3.0.1
4
+ sentencepiece>=0.2.0
5
+ unidic-lite>=1.0.8
ruri-large.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# ruri-large.py
"""Smoke-test script: embed two query/passage pairs with ruri-large and
print the embedding size and the pairwise cosine-similarity matrix."""

import torch.nn.functional as F
from sentence_transformers import SentenceTransformer

# Fetch the checkpoint from the Hugging Face Hub.
encoder = SentenceTransformer("cl-nagoya/ruri-large")

# Each text carries a role prefix: "クエリ: " marks a query-side text and
# "文章: " marks a passage-side text.
texts = [
    "クエリ: 瑠璃色はどんな色?",
    "文章: 瑠璃色(るりいろ)は、紫みを帯びた濃い青。名は、半貴石の瑠璃(ラピスラズリ、英: lapis lazuli)による。JIS慣用色名では「こい紫みの青」(略号 dp-pB)と定義している[1][2]。",
    "クエリ: ワシやタカのように、鋭いくちばしと爪を持った大型の鳥類を総称して「何類」というでしょう?",
    "文章: ワシ、タカ、ハゲワシ、ハヤブサ、コンドル、フクロウが代表的である。これらの猛禽類はリンネ前後の時代(17~18世紀)には鷲類・鷹類・隼類及び梟類に分類された。ちなみにリンネは狩りをする鳥を単一の目(もく)にまとめ、vultur(コンドル、ハゲワシ)、falco(ワシ、タカ、ハヤブサなど)、strix(フクロウ)、lanius(モズ)の4属を含めている。",
]

vectors = encoder.encode(texts, convert_to_tensor=True)
# Size noted by the original author: [4, 1024]
print(vectors.size())

# Broadcast the embeddings against themselves and take the cosine along
# the feature axis to get an all-pairs score matrix.
as_row = vectors.unsqueeze(0)
as_column = vectors.unsqueeze(1)
scores = F.cosine_similarity(as_row, as_column, dim=2)
print(scores)
# Output recorded by the original author:
# [[1.0000, 0.9429, 0.6565, 0.6997],
#  [0.9429, 1.0000, 0.6579, 0.6768],
#  [0.6565, 0.6579, 1.0000, 0.8933],
#  [0.6997, 0.6768, 0.8933, 1.0000]]
uv.lock ADDED
The diff for this file is too large to render. See raw diff