import base64 import re import json import pandas as pd import gradio as gr import pyterrier as pt pt.init() import pyt_splade from pyterrier_gradio import Demo, MarkdownFile, interface, df2code, code2md, EX_Q, EX_D factory_max = pyt_splade.SpladeFactory(agg='max') factory_sum = pyt_splade.SpladeFactory(agg='sum') COLAB_NAME = 'pyterrier_splade.ipynb' COLAB_INSTALL = ''' !pip install -q git+https://github.com/naver/splade !pip install -q git+https://github.com/seanmacavaney/pyt_splade@misc '''.strip() def generate_vis(df, mode='Document'): if len(df) == 0: return '' result = [] if mode == 'Document': max_score = max(max(t.values()) for t in df['toks']) for row in df.itertuples(index=False): if mode == 'Query': tok_scores = {m.group(2): float(m.group(1)) for m in re.finditer(r'#combine:0=([0-9.]+)\((#base64\([^)]+\)|[^)]+)\)', row.query)} for key, value in list(tok_scores.items()): if key.startswith('#base64('): b64 = re.search('#base64\(([^)]+)\)', key).group(1) del tok_scores[key] key = base64.b64decode(b64).decode() tok_scores[key] = value max_score = max(tok_scores.values()) orig_tokens = factory_max.tokenizer.tokenize(row.query_0) id = row.qid else: tok_scores = row.toks orig_tokens = factory_max.tokenizer.tokenize(row.text) id = row.docno def toks2span(toks): return ' '.join(f'{t}' for t in toks) orig_tokens_set = set(orig_tokens) exp_tokens = [t for t, v in sorted(tok_scores.items(), key=lambda x: (-x[1], x[0])) if t not in orig_tokens_set] result.append(f'''