File size: 3,856 Bytes
5b29d9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import datasets
import faiss
import numpy as np
import streamlit as st
import torch
from datasets import Dataset
from transformers import FeatureExtractionPipeline, pipeline


@st.cache_resource
def load_encoder_pipeline(encoder_path: str) -> FeatureExtractionPipeline:
    """訓練済みの教師なしSimCSEのエンコーダを読み込む"""
    encoder_pipeline = pipeline("feature-extraction", model=encoder_path)
    return encoder_pipeline


@st.cache_resource
def load_dataset(dataset_dir: str) -> Dataset:
    """文埋め込み適用済みのデータセットを読み込み、Faissのインデックスを構築"""
    # ディスクに保存されたデータセットを読み込む
    dataset = datasets.load_from_disk(dataset_dir)

    # データセットの"embeddings"フィールドの値からFaissのインデックスを構築する
    emb_dim = len(dataset[0]["embeddings"])
    index = faiss.IndexFlatIP(emb_dim)
    dataset.add_faiss_index("embeddings", custom_index=index)

    return dataset


def embed_text(
    text: str, encoder_pipeline: FeatureExtractionPipeline
) -> np.ndarray:
    """教師なしSimCSEのエンコーダを用いてテキストの埋め込みを計算"""
    with torch.inference_mode():
        # encoder_pipelineが返すTensorのsizeは(1, トークン数, 埋め込みの次元数)
        encoded_text = encoder_pipeline(text, return_tensors="pt")[0][0]

    # ベクトルをNumPyのarrayに変換
    emb = encoded_text.cpu().numpy().astype(np.float32)
    # ベクトルのノルムが1になるように正規化
    emb = emb / np.linalg.norm(emb)
    return emb


def search_similar_texts(
    query_text: str,
    dataset: Dataset,
    encoder_pipeline: FeatureExtractionPipeline,
    k: int = 5,
) -> list[dict[str, float | str]]:
    """モデルとデータセットを用いてクエリの類似文検索を実行"""
    # クエリに対して類似テキストをk件取得する
    scores, retrieved_examples = dataset.get_nearest_examples(
        "embeddings", embed_text(query_text, encoder_pipeline), k=k
    )
    titles = retrieved_examples["title"]
    texts = retrieved_examples["text"]

    # 検索された類似テキストをdictのlistにして返す
    results = [
        {"score": score, "title": title, "text": text}
        for score, title, text in zip(scores, titles, texts)
    ]
    return results


# 訓練済みの教師なしSimCSEのモデルを読み込む
encoder_pipeline = load_encoder_pipeline("outputs_unsup_simcse/encoder")

# 文埋め込み適用済みのデータセットを読み込む
dataset = load_dataset("outputs_unsup_simcse/embedded_paragraphs")

# デモページのタイトルを表示する
st.title(":mag: Wikipedia Paragraph Search")

# デモページのフォームを表示する
with st.form("input_form"):
    # クエリの入力欄を表示し、入力された値を受け取る
    query_text = st.text_input(
        "クエリを入力:", value="日本語は、主に日本で話されている言語である。", max_chars=150
    )
    # 検索する段落数のスライダーを表示し、設定された値を受け取る
    k = st.slider("検索する段落数:", min_value=1, max_value=100, value=10)
    # 検索を実行するボタンを表示し、押下されたらTrueを受け取る
    is_submitted = st.form_submit_button("Search")

# 検索結果を表示する
if is_submitted and len(query_text) > 0:
    # クエリに対して類似文検索を実行し、検索結果を受け取る
    serach_results = search_similar_texts(
        query_text, dataset, encoder_pipeline, k=k
    )
    # 検索結果を表示する
    st.subheader("検索結果")
    st.dataframe(serach_results, use_container_width=True)
    st.caption("セルのダブルクリックで全体が表示されます")