|
import torch |
|
import pandas as pd |
|
import gradio as gr |
|
from datasets import load_dataset |
|
from sentence_transformers import SentenceTransformer, util, models |
|
|
|
|
|
# Question encoder: the fine-tuned transformer checkpoint followed by CLS-token pooling.
_q_transformer = models.Transformer(model_name_or_path="checkpoints/q_encoder", max_seq_length=512)
q_encoder = SentenceTransformer(modules=[
    _q_transformer,
    # Take the pooling width from the loaded checkpoint instead of hard-coding 768,
    # so a checkpoint with a different hidden size still builds a consistent pipeline.
    models.Pooling(word_embedding_dimension=_q_transformer.get_word_embedding_dimension(), pooling_mode='cls'),
])

# Precomputed article embeddings, mapped onto CPU whatever device they were saved from.
doc_embeddings = torch.load('checkpoints/doc_embeddings.pt', map_location=torch.device('cpu'))

# Article table (columns used below: 'article', 'article_no', 'code').
# NOTE(review): row i is assumed to correspond to row i of doc_embeddings — confirm.
# Dataset.to_pandas() converts via Arrow instead of iterating rows in Python.
docs = load_dataset("maastrichtlawtech/bsard", data_files="articles_fr.csv")['train'].to_pandas()
|
|
|
def search(question):
    """Return the five statutory articles most similar to ``question``.

    The question is embedded with ``q_encoder`` and scored by cosine
    similarity against the precomputed ``doc_embeddings``. Each hit's
    ``corpus_id`` is used as a row label into ``docs``.

    Args:
        question: Free-text legal question (the demo expects French).

    Returns:
        A list of exactly five strings, best match first, one per Gradio
        output textbox. Each string is the article text followed by a
        "- Art. <article_no>, <code>" citation line. If fewer than five
        hits exist, the list is padded with empty strings.
    """
    n_results = 5  # must match the number of output textboxes in the Interface
    q_emb = q_encoder.encode(question, convert_to_tensor=True)
    # Only the top hits are displayed, so don't ask semantic_search for more.
    hits = util.semantic_search(q_emb, doc_embeddings, top_k=n_results, score_function=util.cos_sim)[0]

    results = []
    for hit in hits[:n_results]:
        cid = hit['corpus_id']
        results.append(
            docs.loc[cid, 'article'] + '\n\n'
            + f"- Art. {docs.loc[cid, 'article_no']}, {docs.loc[cid, 'code']}"
        )
    # Keep one value per output box even if the corpus yields fewer hits.
    results += [''] * (n_results - len(results))
    # A list (not a set) preserves rank order and keeps duplicate texts,
    # so each textbox receives the correctly ranked result.
    return results
|
|
|
# --- Static page content for the demo ------------------------------------
title = "Belgian Legislation Search"

description = (
    "A biencoder model was trained to retrieve relevant statutory articles to legal issues. "
    "Ask it a question in French!"
)

article = """

The model will return the most semantically relevant laws from a corpus of 22,633 statutory articles collected from 32 Belgian codes.

"""

# Sample questions offered as clickable examples in the UI.
examples = [
    "Qu'est-ce que je risque si je viole le secret professionnel ?",
    "Mon employeur peut-il me licencier alors que je suis malade ?",
    "Mon voisin fait beaucoup de bruit, que faire ?",
]

# --- Interface wiring ------------------------------------------------------
# One plain textbox per ranked result returned by `search`.
result_boxes = ['textbox'] * 5

demo = gr.Interface(
    fn=search,
    inputs=['text'],
    outputs=result_boxes,
    allow_flagging="never",
    title=title,
    description=description,
    article=article,
    examples=examples,
)
demo.launch()
|
|