singletongue committed on
Commit 5b29d9a
1 Parent(s): 27212ed
app.py ADDED
@@ -0,0 +1,97 @@
+ import datasets
+ import faiss
+ import numpy as np
+ import streamlit as st
+ import torch
+ from datasets import Dataset
+ from transformers import FeatureExtractionPipeline, pipeline
+
+
+ @st.cache_resource
+ def load_encoder_pipeline(encoder_path: str) -> FeatureExtractionPipeline:
+     """Load the trained unsupervised SimCSE encoder."""
+     encoder_pipeline = pipeline("feature-extraction", model=encoder_path)
+     return encoder_pipeline
+
+
+ @st.cache_resource
+ def load_dataset(dataset_dir: str) -> Dataset:
+     """Load the dataset with precomputed sentence embeddings and build a Faiss index."""
+     # Load the dataset saved on disk
+     dataset = datasets.load_from_disk(dataset_dir)
+
+     # Build a Faiss index from the values in the dataset's "embeddings" field
+     emb_dim = len(dataset[0]["embeddings"])
+     index = faiss.IndexFlatIP(emb_dim)
+     dataset.add_faiss_index("embeddings", custom_index=index)
+
+     return dataset
+
+
+ def embed_text(
+     text: str, encoder_pipeline: FeatureExtractionPipeline
+ ) -> np.ndarray:
+     """Compute a text embedding with the unsupervised SimCSE encoder."""
+     with torch.inference_mode():
+         # The Tensor returned by encoder_pipeline has size (1, number of tokens, embedding dimension)
+         encoded_text = encoder_pipeline(text, return_tensors="pt")[0][0]
+
+     # Convert the vector to a NumPy array
+     emb = encoded_text.cpu().numpy().astype(np.float32)
+     # Normalize the vector to unit norm
+     emb = emb / np.linalg.norm(emb)
+     return emb
+
+
+ def search_similar_texts(
+     query_text: str,
+     dataset: Dataset,
+     encoder_pipeline: FeatureExtractionPipeline,
+     k: int = 5,
+ ) -> list[dict[str, float | str]]:
+     """Run similar-text search for a query using the model and the dataset."""
+     # Retrieve the k texts most similar to the query
+     scores, retrieved_examples = dataset.get_nearest_examples(
+         "embeddings", embed_text(query_text, encoder_pipeline), k=k
+     )
+     titles = retrieved_examples["title"]
+     texts = retrieved_examples["text"]
+
+     # Return the retrieved similar texts as a list of dicts
+     results = [
+         {"score": score, "title": title, "text": text}
+         for score, title, text in zip(scores, titles, texts)
+     ]
+     return results
+
+
+ # Load the trained unsupervised SimCSE model
+ encoder_pipeline = load_encoder_pipeline("outputs_unsup_simcse/encoder")
+
+ # Load the dataset with precomputed sentence embeddings
+ dataset = load_dataset("outputs_unsup_simcse/embedded_paragraphs")
+
+ # Show the demo page title
+ st.title(":mag: Wikipedia Paragraph Search")
+
+ # Show the demo page form
+ with st.form("input_form"):
+     # Show the query input box and receive the entered value
+     query_text = st.text_input(
+         "Enter a query:", value="日本語は、主に日本で話されている言語である。", max_chars=150
+     )
+     # Show a slider for the number of paragraphs to retrieve and receive its value
+     k = st.slider("Number of paragraphs to retrieve:", min_value=1, max_value=100, value=10)
+     # Show the search button; returns True when it is pressed
+     is_submitted = st.form_submit_button("Search")
+
+ # Show the search results
+ if is_submitted and len(query_text) > 0:
+     # Run similar-text search for the query and receive the results
+     search_results = search_similar_texts(
+         query_text, dataset, encoder_pipeline, k=k
+     )
+     # Show the search results
+     st.subheader("Search results")
+     st.dataframe(search_results, use_container_width=True)
+     st.caption("Double-click a cell to show its full contents")
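The outputs_unsup_simcse/embedded_paragraphs directory that load_dataset reads is assumed to have been produced in an earlier step of the pipeline. A minimal sketch of how such a dataset could be built with the same encoder; the paragraph source llm-book/jawiki-paragraphs and the batch size are illustrative assumptions, not part of this commit:

import datasets
import numpy as np
import torch
from transformers import pipeline

# Hypothetical reconstruction: embed every paragraph with the committed
# encoder and save the result with save_to_disk (a full run over ~9.7M
# paragraphs would take a long time; this only shows the shape of the step).
encoder_pipeline = pipeline("feature-extraction", model="outputs_unsup_simcse/encoder")

def embed_batch(batch: dict) -> dict:
    embs = []
    with torch.inference_mode():
        for text in batch["text"]:
            # Take the [CLS] vector, as embed_text() in app.py does
            emb = encoder_pipeline(text, return_tensors="pt")[0][0]
            emb = emb.cpu().numpy().astype(np.float32)
            # Normalize to unit norm so inner product equals cosine similarity
            embs.append(emb / np.linalg.norm(emb))
    return {"embeddings": embs}

paragraphs = datasets.load_dataset("llm-book/jawiki-paragraphs", split="train")
embedded = paragraphs.map(embed_batch, batched=True, batch_size=32)
embedded.save_to_disk("outputs_unsup_simcse/embedded_paragraphs")

Because both the stored vectors and the query embedding are unit-normalized, the inner-product scores returned by faiss.IndexFlatIP in app.py are cosine similarities.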
outputs_unsup_simcse/embedded_paragraphs/data-00000-of-00010.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:57b90263da12e6f9eaa44d91172dd1f5f015ef6c0c54d61e7d54bccc6b79b759
+ size 458351816
outputs_unsup_simcse/embedded_paragraphs/data-00001-of-00010.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b3fde1ef4827099de7d8689bdbf02e3180c5227f0ca2b03e2c24da46bacbb49d
+ size 458002304
outputs_unsup_simcse/embedded_paragraphs/data-00002-of-00010.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:91f72a3f71b068f9008a289e4a361cba0a880bb25aa4da5453bb2463d3b3f454
+ size 456771176
outputs_unsup_simcse/embedded_paragraphs/data-00003-of-00010.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5807ae91181b76a6836b65cf5c1092314cc126935be10bafc7e85b79500bc76a
+ size 457297584
outputs_unsup_simcse/embedded_paragraphs/data-00004-of-00010.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:305256fabd2246f74dcd4a980d9ab6c3dced5327e64f7f992f2ee0eebb8a8d18
+ size 456882896
outputs_unsup_simcse/embedded_paragraphs/data-00005-of-00010.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:76cb3d35b85b02119e0d5de32782f1b53ce20e166b312f028288b95fdce6e2e5
+ size 456954640
outputs_unsup_simcse/embedded_paragraphs/data-00006-of-00010.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3936501890f32e56a54c5ec091891421e1511e6b0c3d43d7a5511c326182998f
+ size 458542088
outputs_unsup_simcse/embedded_paragraphs/data-00007-of-00010.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:81167192200edbb2e2b947d5a6bd0437ddf42b01679bf8f34e3b5067f86ed53a
+ size 457251296
outputs_unsup_simcse/embedded_paragraphs/data-00008-of-00010.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7a2bf1d222cf15cb91d8789b7a1bbf17e349c39292303d70a0cdc2d29966d29f
+ size 458474520
outputs_unsup_simcse/embedded_paragraphs/data-00009-of-00010.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:027b919374c985209a89cd99a849de70fb75ffcf3c5c4b610cac21d938c59d3e
+ size 458407928
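Note that the ten data-*-of-00010.arrow entries above are Git LFS pointer files: the repository stores only the SHA-256 digest and byte size of each roughly 460 MB Arrow shard, and the shard contents are fetched from LFS storage on checkout (e.g. via a git lfs pull).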
outputs_unsup_simcse/embedded_paragraphs/dataset_info.json ADDED
@@ -0,0 +1,83 @@
+ {
+   "builder_name": "jawiki-paragraphs",
+   "citation": "",
+   "config_name": "default",
+   "dataset_size": 4417130987,
+   "description": "\u66f8\u7c4d\u300e\u5927\u898f\u6a21\u8a00\u8a9e\u30e2\u30c7\u30eb\u5165\u9580\u300f\u3067\u4f7f\u7528\u3059\u308b Wikipedia \u6bb5\u843d\u306e\u30c7\u30fc\u30bf\u30bb\u30c3\u30c8\u3067\u3059\u3002GitHub \u30ea\u30dd\u30b8\u30c8\u30ea singletongue/wikipedia-utils \u3067\u516c\u958b\u3055\u308c\u3066\u3044\u308b\u30c7\u30fc\u30bf\u30bb\u30c3\u30c8\u3092\u5229\u7528\u3057\u3066\u3044\u307e\u3059\u3002",
+   "download_checksums": {
+     "https://github.com/singletongue/wikipedia-utils/releases/download/2023-04-03/paragraphs-jawiki-20230403.json.gz": {
+       "num_bytes": 1489512230,
+       "checksum": null
+     }
+   },
+   "download_size": 1489512230,
+   "features": {
+     "id": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "pageid": {
+       "dtype": "int64",
+       "_type": "Value"
+     },
+     "revid": {
+       "dtype": "int64",
+       "_type": "Value"
+     },
+     "paragraph_index": {
+       "dtype": "int64",
+       "_type": "Value"
+     },
+     "title": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "section": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "text": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "html_tag": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "embeddings": {
+       "feature": {
+         "dtype": "float32",
+         "_type": "Value"
+       },
+       "_type": "Sequence"
+     }
+   },
+   "homepage": "https://github.com/singletongue/wikipedia-utils",
+   "license": "\u672c\u30c7\u30fc\u30bf\u30bb\u30c3\u30c8\u3067\u4f7f\u7528\u3057\u3066\u3044\u308b Wikipedia \u306e\u30b3\u30f3\u30c6\u30f3\u30c4\u306f\u3001\u30af\u30ea\u30a8\u30a4\u30c6\u30a3\u30d6\u30fb\u30b3\u30e2\u30f3\u30ba\u8868\u793a\u30fb\u7d99\u627f\u30e9\u30a4\u30bb\u30f3\u30b9 3.0 (CC BY-SA 3.0) \u304a\u3088\u3073 GNU \u81ea\u7531\u6587\u66f8\u30e9\u30a4\u30bb\u30f3\u30b9 (GFDL) \u306e\u4e0b\u306b\u914d\u5e03\u3055\u308c\u3066\u3044\u308b\u3082\u306e\u3067\u3059\u3002",
+   "size_in_bytes": 5906643217,
+   "splits": {
+     "train": {
+       "name": "train",
+       "num_bytes": 4417130987,
+       "num_examples": 9668476,
+       "shard_lengths": [
+         984321,
+         1031799,
+         1101914,
+         1132906,
+         1123001,
+         1143878,
+         1138063,
+         1139173,
+         873421
+       ],
+       "dataset_name": "jawiki-paragraphs"
+     }
+   },
+   "version": {
+     "version_str": "1.0.0",
+     "major": 1,
+     "minor": 0,
+     "patch": 0
+   }
+ }
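A quick cross-check of this metadata against the code in app.py (an illustrative snippet, assuming the LFS shards have been downloaded): the embeddings feature is a sequence of float32 values whose length should match the encoder's hidden_size of 768 in the config below.

import datasets

# Illustrative sanity check of the saved dataset's shape
dataset = datasets.load_from_disk("outputs_unsup_simcse/embedded_paragraphs")
print(dataset.num_rows)                # 9668476, as in dataset_info.json
print(len(dataset[0]["embeddings"]))   # expected: 768 (BERT hidden_size)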
outputs_unsup_simcse/embedded_paragraphs/state.json ADDED
@@ -0,0 +1,40 @@
+ {
+   "_data_files": [
+     {
+       "filename": "data-00000-of-00010.arrow"
+     },
+     {
+       "filename": "data-00001-of-00010.arrow"
+     },
+     {
+       "filename": "data-00002-of-00010.arrow"
+     },
+     {
+       "filename": "data-00003-of-00010.arrow"
+     },
+     {
+       "filename": "data-00004-of-00010.arrow"
+     },
+     {
+       "filename": "data-00005-of-00010.arrow"
+     },
+     {
+       "filename": "data-00006-of-00010.arrow"
+     },
+     {
+       "filename": "data-00007-of-00010.arrow"
+     },
+     {
+       "filename": "data-00008-of-00010.arrow"
+     },
+     {
+       "filename": "data-00009-of-00010.arrow"
+     }
+   ],
+   "_fingerprint": "8ff2a1214e978197",
+   "_format_columns": null,
+   "_format_kwargs": {},
+   "_format_type": null,
+   "_output_all_columns": false,
+   "_split": "train"
+ }
outputs_unsup_simcse/encoder/config.json ADDED
@@ -0,0 +1,25 @@
+ {
+   "_name_or_path": "cl-tohoku/bert-base-japanese-v3",
+   "architectures": [
+     "BertModel"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "classifier_dropout": null,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 768,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "layer_norm_eps": 1e-12,
+   "max_position_embeddings": 512,
+   "model_type": "bert",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 12,
+   "pad_token_id": 0,
+   "position_embedding_type": "absolute",
+   "torch_dtype": "float32",
+   "transformers_version": "4.30.2",
+   "type_vocab_size": 2,
+   "use_cache": true,
+   "vocab_size": 32768
+ }
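The config shows that the encoder is a plain BertModel fine-tuned from cl-tohoku/bert-base-japanese-v3, so it can also be loaded directly, without the feature-extraction pipeline wrapper. A minimal sketch:

from transformers import AutoModel

# Illustrative: load the committed encoder as a bare BertModel
encoder = AutoModel.from_pretrained("outputs_unsup_simcse/encoder")
print(type(encoder).__name__)      # BertModel
print(encoder.config.hidden_size)  # 768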
outputs_unsup_simcse/encoder/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:aca39ff56e5bdf8e331de99f48bc049bd2763b327f64457aa98c79bc8e98367e
+ size 444899885
outputs_unsup_simcse/encoder/special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "cls_token": "[CLS]",
+   "mask_token": "[MASK]",
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "unk_token": "[UNK]"
+ }
outputs_unsup_simcse/encoder/tokenizer_config.json ADDED
@@ -0,0 +1,21 @@
+ {
+   "clean_up_tokenization_spaces": true,
+   "cls_token": "[CLS]",
+   "do_lower_case": false,
+   "do_subword_tokenize": true,
+   "do_word_tokenize": true,
+   "jumanpp_kwargs": null,
+   "mask_token": "[MASK]",
+   "mecab_kwargs": {
+     "mecab_dic": "unidic_lite"
+   },
+   "model_max_length": 512,
+   "never_split": null,
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "subword_tokenizer_type": "wordpiece",
+   "sudachi_kwargs": null,
+   "tokenizer_class": "BertJapaneseTokenizer",
+   "unk_token": "[UNK]",
+   "word_tokenizer_type": "mecab"
+ }
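The tokenizer is a BertJapaneseTokenizer that runs MeCab word segmentation (with the unidic_lite dictionary) followed by WordPiece. MeCab support needs the extra packages (fugashi, unidic-lite) that the transformers[ja] extra in requirements.txt pulls in. An illustrative direct use of the tokenizer:

from transformers import AutoTokenizer

# Illustrative: AutoTokenizer resolves to BertJapaneseTokenizer here
tokenizer = AutoTokenizer.from_pretrained("outputs_unsup_simcse/encoder")
print(tokenizer.tokenize("日本語は、主に日本で話されている言語である。"))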
outputs_unsup_simcse/encoder/vocab.txt ADDED
The diff for this file is too large to render.
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ datasets
+ faiss-cpu
+ numpy
+ torch
+ transformers[ja,torch]
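With these dependencies installed (pip install -r requirements.txt) and the LFS files pulled, the demo can be launched locally with the standard Streamlit entry point, streamlit run app.py; the run command itself is not part of this commit.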