Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import csv | |
import re | |
import unicodedata | |
from collections import defaultdict | |
from itertools import chain | |
import gradio as gr | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from bm25s.hf import BM25HF, TokenizerHF | |
from transformers import AutoModelForPreTraining, AutoTokenizer | |
# --- Entity vocabulary conventions -------------------------------------------
ALIAS_SEP = "|"
# Entity names starting with this prefix are treated as category entities.
CATEGORY_ENTITY_PREFIX = "Category:"
# Special entity tokens that are always kept when the vocabulary is replaced.
ENTITY_SPECIAL_TOKENS = ["[PAD]", "[UNK]", "[MASK]", "[MASK2]"]

# --- Input size limits ---------------------------------------------------------
MAX_TEXT_LENGTH = 800  # characters kept per input text
MAX_TEXT_FILE_LINES = 10  # rows read from an uploaded text file
MAX_ENTITY_FILE_LINES = 1000  # rows read from an uploaded entity CSV

# --- Hugging Face Hub locations ------------------------------------------------
repo_id = "studio-ousia/luxe"
revision = "ja-v0.3.2"
nayose_repo_id = "studio-ousia/luxe-nayose-bm25"

# Category names matching any of these regexes are excluded from the predicted
# topic categories. Each pattern is listed once; a repeated entry adds nothing
# because the patterns are combined with ``any(re.search(...))``.
ignore_category_patterns = [
    r"\d+年",
    r"楽曲 [ぁ-ん]",
    r"漫画作品 [ぁ-ん]",
    r"アニメ作品 [ぁ-ん]",
    r"の一覧",
    r"各国の",
    r"各年の",
]
def clean_default_entity_vocab(tokenizer):
    """Remove the ``ja:`` language tag from the keys of ``tokenizer.entity_vocab``.

    ``"ja:X"`` becomes ``"X"`` and ``"Category:ja:X"`` becomes ``"Category:X"``;
    every other key is kept unchanged. Entity ids are preserved. The tokenizer
    is updated in place (its ``entity_vocab`` attribute is rebound to a new
    dict) and nothing is returned.
    """

    def _without_language_tag(name):
        if name.startswith("ja:"):
            return name[len("ja:"):]
        if name.startswith("Category:ja:"):
            return "Category:" + name[len("Category:ja:"):]
        return name

    tokenizer.entity_vocab = {
        _without_language_tag(name): index for name, index in tokenizer.entity_vocab.items()
    }
def normalize_text(text: str) -> str:
    """Return *text* in Unicode NFKC form with surrounding whitespace removed."""
    nfkc_text = unicodedata.normalize("NFKC", text)
    return nfkc_text.strip()
def get_texts_from_file(file_path: str | None) -> list[str]:
    """Read input texts (one per row) from an uploaded file.

    At most ``MAX_TEXT_FILE_LINES`` rows are read and each text is truncated to
    ``MAX_TEXT_LENGTH`` characters. Blank rows are skipped.

    Args:
        file_path: Path of the uploaded file, or ``None`` when no file is set.

    Returns:
        The list of texts. An empty list is returned when ``file_path`` is
        ``None`` or when the file cannot be read (a Gradio warning is shown
        in that case).
    """
    texts: list[str] = []
    if file_path is None:
        return texts

    try:
        # "utf-8-sig" reads plain UTF-8 and also drops a leading BOM, which
        # would otherwise end up as the first character of the first text.
        with open(file_path, newline="", encoding="utf-8-sig") as f:
            reader = csv.DictReader(f, fieldnames=["text"])
            for i, row in enumerate(reader):
                if i >= MAX_TEXT_FILE_LINES:
                    gr.Info(f"{MAX_TEXT_FILE_LINES}行目までのデータを読み込みました。", duration=5)
                    break
                # The file is advertised as "one text per line", so a comma
                # inside a line must not cut the text short: the csv module
                # stores the fields after the first one under the ``None``
                # key, and they are joined back here.
                fields = [row["text"] or "", *(row.get(None) or [])]
                text = ",".join(fields)
                if text.strip() != "":
                    texts.append(text[:MAX_TEXT_LENGTH])
    except Exception as e:
        # Best-effort UI boundary: report the problem and return no texts
        # rather than a partially read file.
        gr.Warning("ファイルを正しく読み込めませんでした。", duration=5)
        print(e)
        texts = []

    return texts
def get_token_spans(tokenizer, text: str) -> list[tuple[int, int]]:
    """Map every token of *text* to its ``(start, end)`` character span.

    Tokens are located left to right, each search starting where the previous
    token ended. A leading ``"##"`` sub-word marker is removed before the
    search.

    A token whose surface cannot be found in the remaining text (for example
    an unknown-token placeholder such as ``"[UNK]"``) receives a zero-width
    span at the current position instead of raising ``ValueError``, so one
    unusual character no longer aborts the whole prediction.

    Args:
        tokenizer: Object exposing ``tokenize(text) -> list[str]``.
        text: The text that was tokenized.

    Returns:
        One span per token, with an extra ``(0, 0)`` in front and an extra
        ``(end, end)`` at the back to account for ``[CLS]`` and ``[SEP]``.
    """
    token_spans = []
    end = 0
    for token in tokenizer.tokenize(text):
        surface = token.removeprefix("##")
        start = text.find(surface, end)
        if start < 0:
            # Surface not present in the text: keep the cursor where it is.
            start = end
        else:
            end = start + len(surface)
        token_spans.append((start, end))

    return [(0, 0)] + token_spans + [(end, end)]  # count for "[CLS]" and "[SEP]"
def get_predicted_entity_spans(
    ner_logits: torch.Tensor, token_spans: list[tuple[int, int]], entity_span_sensitivity: float = 1.0
) -> list[tuple[int, int]]:
    """Greedily select non-overlapping entity spans from span logits.

    ``ner_logits[i, j]`` scores the span that starts at token ``i`` and ends
    at token ``j``. Only the upper triangle (``i <= j``) is considered.
    Candidates are visited from the most to the least probable and kept when
    they do not overlap an already accepted span. The scan stops at the first
    candidate whose probability is below ``10 ** -entity_span_sensitivity``.

    Zero-width spans (such as the ``[CLS]``/``[SEP]`` positions, whose token
    span is empty) are never returned. Previously such a span slipped through
    when it was the first candidate, because the non-emptiness check only ran
    while comparing against already accepted spans.

    Args:
        ner_logits: Square, unbatched ``(length, length)`` tensor of logits.
        token_spans: Character span of each token position.
        entity_span_sensitivity: Higher values lower the probability
            threshold; values ``<= 0`` disable span prediction.

    Returns:
        The accepted character spans, sorted in ascending order.
    """
    length = ner_logits.size(-1)
    assert ner_logits.size() == (length, length)  # not batched

    predicted_entity_spans: list[tuple[int, int]] = []
    if entity_span_sensitivity <= 0.0:
        return predicted_entity_spans

    ner_probs = torch.sigmoid(ner_logits).triu()
    probs_sorted, sort_idxs = ner_probs.flatten().sort(descending=True)
    threshold = 10.0 ** (-1.0 * entity_span_sensitivity)

    for p, i in zip(probs_sorted, sort_idxs.tolist()):
        if p < threshold:
            break

        # Recover (row, column) = (start token, end token) from the flat index.
        start_idx, end_idx = divmod(i, length)
        start = token_spans[start_idx][0]
        end = token_spans[end_idx][1]
        if start >= end:
            continue  # empty span: nothing to highlight or link

        overlaps_accepted = any(
            not (end <= ex_start or ex_end <= start) for ex_start, ex_end in predicted_entity_spans
        )
        if not overlaps_accepted:
            predicted_entity_spans.append((start, end))

    return sorted(predicted_entity_spans)
def get_topk_entities_from_texts(
    models,
    texts: str | list[str],
    k: int = 5,
    entity_span_sensitivity: float = 1.0,
    nayose_coef: float = 1.0,
    entity_replaced_counts: int = 0,
) -> tuple[list[str], list[list[tuple[int, int]]], list[list[str]], list[list[str]], list[list[list[str]]]]:
    """Predict entity spans and top-k entities/categories for each text.

    For every text the model is run twice: once to score candidate entity
    spans, and once more with the selected spans to obtain topic-entity,
    topic-category and per-span entity logits.

    Args:
        models: Tuple ``(model, tokenizer, bm25_tokenizer, bm25_retriever)``.
        texts: A single text or a list of texts.
        k: Number of entities/categories to return per prediction.
        entity_span_sensitivity: Forwarded to ``get_predicted_entity_spans``.
        nayose_coef: Weight of the BM25 string-match scores added to the
            per-span entity logits; ``0`` disables them.
        entity_replaced_counts: How many times the entity vocabulary has been
            replaced. BM25 scores are only used while this is ``0``.
            NOTE(review): presumably because the BM25 index was built for the
            default vocabulary - confirm.

    Returns:
        ``(texts, entity_spans, topic_entities, topic_categories,
        span_entities)``, one item per input text. ``texts`` holds the
        *normalized* texts, because the returned character spans index into
        the normalized text rather than the raw input.
    """
    gr.Info("LUXEによる予測を実行しています。", duration=5)

    if isinstance(texts, str):
        texts = [texts]

    model, tokenizer, bm25_tokenizer, bm25_retriever = models
    num_normal_entities = model.config.num_normal_entities

    # Ids below ``num_normal_entities`` are normal entities; the remaining ids
    # are categories, re-based to start at 0 to match the category logits.
    id2normal_entity = {}
    id2category_entity = {}
    ignore_category_entity_ids = []
    for entity, entity_id in tokenizer.entity_vocab.items():
        if entity_id < num_normal_entities:
            id2normal_entity[entity_id] = entity
            continue
        category_id = entity_id - num_normal_entities
        id2category_entity[category_id] = entity
        if any(re.search(pattern, entity) for pattern in ignore_category_patterns):
            ignore_category_entity_ids.append(category_id)

    # ``torch.topk`` needs a non-negative int; UI number inputs may send floats.
    k = max(int(k), 0)
    entity_k = min(k, len(id2normal_entity))
    category_k = min(k, len(id2category_entity))

    normalized_texts: list[str] = []
    batch_entity_spans: list[list[tuple[int, int]]] = []
    topk_normal_entities: list[list[str]] = []
    topk_category_entities: list[list[str]] = []
    topk_span_entities: list[list[list[str]]] = []

    # Inference only: skip autograd bookkeeping for the forward passes.
    with torch.no_grad():
        for text in texts:
            text = normalize_text(text)
            normalized_texts.append(text)

            # Pass 1: score candidate spans.
            tokenized_examples = tokenizer(text, return_tensors="pt")
            model_outputs = model(**tokenized_examples)
            token_spans = get_token_spans(tokenizer, text)
            entity_spans = get_predicted_entity_spans(
                model_outputs.ner_logits[0], token_spans, entity_span_sensitivity
            )
            batch_entity_spans.append(entity_spans)

            # Pass 2: predict entities given the selected spans.
            tokenized_examples = tokenizer(text, entity_spans=entity_spans or None, return_tensors="pt")
            model_outputs = model(**tokenized_examples)

            if model_outputs.topic_entity_logits is not None:
                _, topk_normal_entity_ids = model_outputs.topic_entity_logits[0].topk(entity_k)
                topk_normal_entities.append([id2normal_entity[id_] for id_ in topk_normal_entity_ids.tolist()])
            else:
                topk_normal_entities.append([])

            if model_outputs.topic_category_logits is not None:
                model_outputs.topic_category_logits[:, ignore_category_entity_ids] = float("-inf")
                _, topk_category_entity_ids = model_outputs.topic_category_logits[0].topk(category_k)
                topk_category_entities.append(
                    [id2category_entity[id_] for id_ in topk_category_entity_ids.tolist()]
                )
            else:
                topk_category_entities.append([])

            if model_outputs.entity_logits is not None:
                # Restrict span predictions to normal entities. Using the
                # configured count (instead of a fixed 500000) keeps the ids
                # inside ``id2normal_entity`` after the vocabulary is replaced.
                span_entity_logits = model_outputs.entity_logits[0, :, :num_normal_entities]

                if nayose_coef > 0.0 and entity_replaced_counts == 0 and entity_spans:
                    nayose_queries = ["ja:" + text[start:end] for start, end in entity_spans]
                    nayose_query_tokens = bm25_tokenizer.tokenize(nayose_queries)
                    nayose_scores = torch.vstack(
                        [torch.from_numpy(bm25_retriever.get_scores(tokens)) for tokens in nayose_query_tokens]
                    )
                    # NOTE(review): assumes BM25 document i corresponds to
                    # entity id i (as the original element-wise addition
                    # did); only the shared leading columns are combined.
                    width = min(span_entity_logits.size(-1), nayose_scores.size(-1))
                    span_entity_logits[:, :width] += nayose_coef * nayose_scores[:, :width]

                _, topk_span_entity_ids = span_entity_logits.topk(entity_k)
                topk_span_entities.append(
                    [[id2normal_entity[id_] for id_ in ids] for ids in topk_span_entity_ids.tolist()]
                )
            else:
                topk_span_entities.append([])

    return normalized_texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities
def get_new_entity_text_pairs_from_file(file_path: str | None) -> list[list[str]]:
    """Read ``[entity, description]`` pairs from an uploaded two-column CSV.

    At most ``MAX_ENTITY_FILE_LINES`` rows are read. Both columns are passed
    through ``normalize_text``. Rows where either value is empty or missing
    are skipped, so a single short row no longer causes the whole file to be
    rejected.

    Args:
        file_path: Path of the uploaded CSV, or ``None`` when no file is set.

    Returns:
        The list of ``[entity, text]`` pairs. An empty list is returned when
        ``file_path`` is ``None`` or when the file cannot be read (a Gradio
        warning is shown in that case).
    """
    new_entity_text_pairs: list[list[str]] = []
    if file_path is None:
        return new_entity_text_pairs

    try:
        # "utf-8-sig" reads plain UTF-8 and also drops a leading BOM (common
        # in spreadsheet exports), which would otherwise become part of the
        # first entity name.
        with open(file_path, newline="", encoding="utf-8-sig") as f:
            reader = csv.DictReader(f, fieldnames=["entity", "text"])
            for i, row in enumerate(reader):
                if i >= MAX_ENTITY_FILE_LINES:
                    gr.Info(f"{MAX_ENTITY_FILE_LINES}行目までのデータを読み込みました。", duration=5)
                    break
                # A row with a single column yields ``None`` for "text".
                entity = normalize_text(row["entity"] or "")
                text = normalize_text(row["text"] or "")
                if entity != "" and text != "":
                    new_entity_text_pairs.append([entity, text])
    except Exception as e:
        # Best-effort UI boundary: report the problem and return no pairs
        # rather than a partially read file.
        gr.Warning("ファイルを正しく読み込めませんでした。", duration=5)
        print(e)
        new_entity_text_pairs = []

    return new_entity_text_pairs
def replace_entities(
    models, new_entity_text_pairs: list[tuple[str, str]], entity_replaced_counts: int, preserve_default_entities: bool
) -> int:
    """Rebuild the entity vocabulary of the model and tokenizer in place.

    Each new entity gets an embedding computed from its description text: the
    model's first-position hidden state passed through
    ``model.entity_predictions.transform``. Embeddings collected under the
    same entity name are averaged. Special tokens are always kept; the other
    existing entities are kept only when ``preserve_default_entities`` is
    true. In the new vocabulary normal entities come first, followed by
    category entities (names starting with ``CATEGORY_ENTITY_PREFIX``).

    Rows whose entity or description is blank or ``None`` (for example the
    empty rows of an untouched table) are ignored. When no usable row remains
    nothing is modified.

    Args:
        models: Tuple ``(model, tokenizer, bm25_tokenizer, bm25_retriever)``.
        new_entity_text_pairs: Rows of ``(entity, description)``.
        entity_replaced_counts: Number of replacements performed so far.
        preserve_default_entities: Keep the current non-special entities.

    Returns:
        ``entity_replaced_counts + 1`` after a replacement, or the unchanged
        ``entity_replaced_counts`` when there was nothing to add.
    """
    # Keep only rows with both a non-blank entity name and description.
    cleaned_pairs: list[tuple[str, str]] = []
    for row in new_entity_text_pairs:
        cells = ["" if cell is None else str(cell).strip() for cell in row]
        if len(cells) >= 2 and cells[0] and cells[1]:
            cleaned_pairs.append((cells[0], cells[1]))

    if not cleaned_pairs:
        return entity_replaced_counts

    gr.Info("LUXEのモデルとトークナイザのエンティティ語彙を更新しています。完了までお待ちください。", duration=5)

    model, tokenizer, bm25_tokenizer, bm25_retriever = models

    normal_entity_embeddings = defaultdict(list)  # entity -> list of embeddings
    category_entity_embeddings = defaultdict(list)  # entity -> list of embeddings
    normal_entity_counts = {}  # entity -> count (int)
    category_entity_counts = {}  # entity -> count (int)

    # 1) Carry over existing entities, visiting them in ascending id order.
    current_weights = model.luke.entity_embeddings.entity_embeddings.weight.data
    current_counts = model.config.entity_counts
    for entity, entity_id in sorted(tokenizer.entity_vocab.items(), key=lambda x: x[1]):
        if not (entity in ENTITY_SPECIAL_TOKENS or preserve_default_entities):
            continue
        entity_embedding = current_weights[entity_id]
        count = current_counts[entity_id] if current_counts is not None else 1
        if entity.startswith(CATEGORY_ENTITY_PREFIX):
            category_entity_embeddings[entity].append(entity_embedding)
            category_entity_counts[entity] = count
        else:
            normal_entity_embeddings[entity].append(entity_embedding)
            normal_entity_counts[entity] = count

    # 2) Embed the new entities from their descriptions. No gradients are
    #    needed, and disabling them avoids keeping one autograd graph per
    #    description alive in the lists above.
    with torch.no_grad():
        for entity, text in cleaned_pairs:
            tokenized_inputs = tokenizer(text[:MAX_TEXT_LENGTH], return_tensors="pt")
            model_outputs = model(**tokenized_inputs)
            entity_embedding = model.entity_predictions.transform(model_outputs.last_hidden_state[:, 0])[0]
            if entity.startswith(CATEGORY_ENTITY_PREFIX):
                category_entity_embeddings[entity].append(entity_embedding)
                category_entity_counts.setdefault(entity, 1)
            else:
                normal_entity_embeddings[entity].append(entity_embedding)
                normal_entity_counts.setdefault(entity, 1)

    num_normal_entities = len(normal_entity_embeddings)
    num_category_entities = len(category_entity_embeddings)

    # 3) Average the embeddings per entity; normal entities receive the lower
    #    ids and category entities follow.
    entity_embeddings = {
        entity: sum(embeddings) / len(embeddings)
        for entity, embeddings in chain(normal_entity_embeddings.items(), category_entity_embeddings.items())
    }
    entity_vocab = {entity: entity_id for entity_id, entity in enumerate(entity_embeddings.keys())}
    entity_counts = [
        category_entity_counts[entity] if entity.startswith(CATEGORY_ENTITY_PREFIX) else normal_entity_counts[entity]
        for entity in entity_vocab.keys()
    ]

    # 4) Point the tokenizer at the new vocabulary.
    tokenizer.entity_vocab = entity_vocab
    tokenizer.entity_pad_token_id = entity_vocab["[PAD]"]
    tokenizer.entity_unk_token_id = entity_vocab["[UNK]"]
    tokenizer.entity_mask_token_id = entity_vocab["[MASK]"]
    tokenizer.entity_mask2_token_id = entity_vocab["[MASK2]"]

    # 5) Swap in the new embedding table and a matching decoder/bias.
    entity_embeddings_tensor = torch.vstack(list(entity_embeddings.values()))
    if model.config.normalize_entity_embeddings:
        entity_embeddings_tensor = F.normalize(entity_embeddings_tensor)

    entity_vocab_size, entity_emb_size = entity_embeddings_tensor.size()

    entity_embeddings_module = nn.Embedding(
        entity_vocab_size,
        entity_emb_size,
        padding_idx=tokenizer.entity_pad_token_id,
        device=model.luke.entity_embeddings.entity_embeddings.weight.device,
        dtype=model.luke.entity_embeddings.entity_embeddings.weight.dtype,
    )
    entity_embeddings_module.weight.data = entity_embeddings_tensor.data
    model.luke.entity_embeddings.entity_embeddings = entity_embeddings_module

    entity_decoder_module = nn.Linear(entity_emb_size, entity_vocab_size, bias=False)
    model.entity_predictions.decoder = entity_decoder_module
    model.entity_predictions.bias = nn.Parameter(torch.zeros(entity_vocab_size))
    model.tie_weights()

    # 6) Recompute log-probabilities from the counts, normalized separately
    #    over normal entities and over category entities.
    if model.config.entity_counts is not None:
        total_normal_entity_count = sum(entity_counts[:num_normal_entities])
        total_category_entity_count = sum(entity_counts[num_normal_entities:])
        entity_counts_tensor = torch.tensor(entity_counts, dtype=model.dtype, device=model.device)
        total_entity_counts = torch.tensor(
            [total_normal_entity_count] * num_normal_entities + [total_category_entity_count] * num_category_entities,
            dtype=model.dtype,
            device=model.device,
        )
        entity_log_probs = torch.log(entity_counts_tensor / total_entity_counts)
        model.entity_log_probs = entity_log_probs

    # 7) Keep the config consistent with the new vocabulary.
    model.config.entity_vocab_size = entity_vocab_size
    model.config.num_normal_entities = num_normal_entities
    model.config.num_category_entities = num_category_entities
    if model.config.entity_counts is not None:
        model.config.entity_counts = entity_counts

    gr.Info("LUXEのモデルとトークナイザのエンティティ語彙の更新が完了しました。", duration=5)

    return entity_replaced_counts + 1
# Build the Gradio UI. Models are loaded once, when the Blocks are constructed.
with gr.Blocks() as demo:
    model = AutoModelForPreTraining.from_pretrained(repo_id, revision=revision, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(repo_id, revision=revision, trust_remote_code=True)

    # BM25 retriever used for string matching; both parts come from the repo
    # named by the module-level ``nayose_repo_id`` constant.
    bm25_tokenizer = TokenizerHF(lower=True, splitter=tokenizer.tokenize, stopwords=None, stemmer=None)
    bm25_tokenizer.load_vocab_from_hub(nayose_repo_id)
    bm25_retriever = BM25HF.load_from_hub(nayose_repo_id)

    clean_default_entity_vocab(tokenizer)

    # Hint: when gr.State receives a callable, it treats it as a function that
    # produces the initial state value and calls it with no arguments, so
    # passing the model or the tokenizer on its own raises an error. Passing a
    # tuple holding all the models (which is not callable) avoids that error.
    # cf. https://www.gradio.app/docs/gradio/state#param-state-value
    models = gr.State((model, tokenizer, bm25_tokenizer, bm25_retriever))

    # Inputs and hyperparameters.
    texts_input = gr.State([])
    entity_replaced_counts = gr.State(0)
    topk = gr.State(5)
    entity_span_sensitivity = gr.State(1.0)
    nayose_coef = gr.State(1.0)

    # Prediction results consumed by ``render_topk_entities``.
    texts = gr.State([])
    batch_entity_spans = gr.State([])
    topk_normal_entities = gr.State([])
    topk_category_entities = gr.State([])
    topk_span_entities = gr.State([])

    gr.Markdown("# 📝 LUXE Demo (β版)")
    gr.Markdown(
        """Studio Ousia で開発中の次世代知識強化言語モデル **LUXE** の動作デモです。
入力されたテキストに対して、テキスト中に出現するエンティティ(事物)と、テキスト全体の主題となるエンティティおよびカテゴリを予測します。
デフォルトのLUXEは、エンティティおよびカテゴリとして、それぞれ日本語 Wikipedia における被リンク数上位50万件および10万件の項目を使用しています。
予測対象のエンティティを任意のものに置き換えて推論を行うことも可能です(下記「LUXE のエンティティ語彙を置き換える」を参照してください)。""",
        line_breaks=True,
    )

    gr.Markdown("## 入力テキスト")

    with gr.Tab(label="直接入力"):
        text_input = gr.Textbox(label=f"入力テキスト(最大{MAX_TEXT_LENGTH}文字)", max_length=MAX_TEXT_LENGTH)
        text_submit_button = gr.Button(value="予測実行", variant="huggingface")

    with gr.Tab(label="ファイルアップロード"):
        gr.Markdown(
            f"""1行1事例のテキストファイル(最大{MAX_TEXT_FILE_LINES}行)をアップロードできます。
アップロードされたテキストのそれぞれに対して推論が実行されます。""",
            line_breaks=True,
        )
        texts_file = gr.File(label="入力テキストファイル")
        texts_submit_button = gr.Button(value="予測実行", variant="huggingface")

    # All three triggers run the same prediction and fill the same result states.
    text_input.submit(
        fn=get_topk_entities_from_texts,
        inputs=[models, text_input, topk, entity_span_sensitivity, nayose_coef, entity_replaced_counts],
        outputs=[texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
    )
    text_submit_button.click(
        fn=get_topk_entities_from_texts,
        inputs=[models, text_input, topk, entity_span_sensitivity, nayose_coef, entity_replaced_counts],
        outputs=[texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
    )
    texts_file.change(fn=get_texts_from_file, inputs=texts_file, outputs=texts_input)
    texts_submit_button.click(
        fn=get_topk_entities_from_texts,
        inputs=[models, texts_input, topk, entity_span_sensitivity, nayose_coef, entity_replaced_counts],
        outputs=[texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
    )

    gr.Markdown("---")

    with gr.Accordion(label="ハイパーパラメータ", open=False):
        # precision=0 makes the component return an int, and minimum=1 rules
        # out non-positive values, so the value is always usable as a top-k.
        topk_input = gr.Number(
            5, label="予測するエンティティの件数 (Top K)", interactive=True, precision=0, minimum=1
        )
        entity_span_sensitivity_input = gr.Slider(
            minimum=0.0, maximum=5.0, value=1.0, step=0.1, label="エンティティ検出の積極度", interactive=True
        )
        nayose_coef_input = gr.Slider(
            minimum=0.0, maximum=2.0, value=1.0, step=0.1, label="文字列一致の優先度", interactive=True
        )

    # Mirror the widget values into the states read by the prediction function.
    topk_input.change(fn=lambda val: val, inputs=topk_input, outputs=topk)
    entity_span_sensitivity_input.change(
        fn=lambda val: val, inputs=entity_span_sensitivity_input, outputs=entity_span_sensitivity
    )
    nayose_coef_input.change(fn=lambda val: val, inputs=nayose_coef_input, outputs=nayose_coef)

    with gr.Accordion(label="LUXE のエンティティ語彙を置き換える", open=False):
        gr.Markdown(
            """LUXE のモデルとトークナイザのエンティティ語彙を任意のエンティティ集合に置き換えます。
エンティティとともに与えられるエンティティの説明文から、エンティティの埋め込みが計算され、LUXE の推論に利用されます。""",
            line_breaks=True,
        )
        gr.Markdown(
            f"「エンティティ」と「エンティティの説明文」の2列からなる CSV ファイル(最大{MAX_ENTITY_FILE_LINES}行)をアップロードできます。"
        )
        new_entity_text_pairs_file = gr.File(label="エンティティと説明文の CSV ファイル", height="128px")
        gr.Markdown("CSV ファイルから読み込まれた項目が以下の表に表示されます。表の内容を直接編集することも可能です。")
        new_entity_text_pairs_input = gr.Dataframe(
            headers=["entity", "text"],
            col_count=(2, "fixed"),
            type="array",
            label="エンティティと説明文",
            interactive=True,
        )
        preserve_default_entities_checkbox = gr.Checkbox(label="既存のエンティティを保持する", value=True)
        replace_entity_button = gr.Button(value="エンティティ語彙を置き換える")
        gr.Markdown("LUXE のモデルのエンティティ語彙は、デモページの再読み込み時にリセットされます。")

    new_entity_text_pairs_file.change(
        fn=get_new_entity_text_pairs_from_file, inputs=new_entity_text_pairs_file, outputs=new_entity_text_pairs_input
    )
    replace_entity_button.click(
        fn=replace_entities,
        inputs=[models, new_entity_text_pairs_input, entity_replaced_counts, preserve_default_entities_checkbox],
        outputs=entity_replaced_counts,
    )

    gr.Markdown("---")
    gr.Markdown("## 予測されたエンティティとカテゴリ")

    # Register the function as a dynamic render block: without the decorator
    # it is never invoked and the prediction results are never displayed.
    @gr.render(
        inputs=[texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities]
    )
    def render_topk_entities(
        texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities
    ):
        """Render, for every text, its highlighted spans and predicted entities."""
        for text, entity_spans, normal_entities, category_entities, span_entities in zip(
            texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities
        ):
            # Split the text into (substring, label) chunks; label None means
            # plain text between two entity spans.
            highlighted_text_value = []
            cur = 0
            for start, end in entity_spans:
                if cur < start:
                    highlighted_text_value.append((text[cur:start], None))
                highlighted_text_value.append((text[start:end], "Entity"))
                cur = end
            if cur < len(text):
                highlighted_text_value.append((text[cur:], None))

            gr.HighlightedText(
                value=highlighted_text_value,
                color_map={"Entity": "green"},
                combine_adjacent=False,
                label="予測されたエンティティのスパン",
            )

            if normal_entities:
                gr.Dataset(
                    label="テキスト全体に関連するエンティティ",
                    components=["text"],
                    samples=[[entity] for entity in normal_entities],
                )

            if category_entities:
                gr.Dataset(
                    label="テキスト全体に関連するカテゴリ",
                    components=["text"],
                    samples=[[entity] for entity in category_entities],
                )

            # Expanded by default only when a single text was submitted.
            with gr.Accordion(label="テキスト中のスパンに対応するエンティティ", open=len(texts) == 1):
                span_texts = [text[start:end] for start, end in entity_spans]
                for span_text, entities in zip(span_texts, span_entities):
                    gr.Dataset(
                        label=f"「{span_text}」に対応するエンティティ",
                        components=["text"],
                        samples=[[entity] for entity in entities],
                    )


demo.launch()