# Ticio — Spanish legal-research assistant (RAG over jurisprudence embeddings).
# Gradio app hosted on Hugging Face Spaces.
from sentence_transformers import SentenceTransformer | |
from together import Together | |
from dotenv import load_dotenv | |
import vecs | |
import os | |
import gradio as gr | |
# Load credentials from a local .env file before anything reads the environment.
load_dotenv()

# Postgres connection pieces for the vecs (pgvector) client.
user = os.getenv("user")
password = os.getenv("password")
host = os.getenv("host")
port = os.getenv("port")
db_name = "postgres"

DB_CONNECTION = "postgresql://{}:{}@{}:{}/{}".format(user, password, host, port, db_name)

# Module-level singletons shared by the request handlers below:
# vector-store client, embedding model, and the Together AI chat client.
vx = vecs.create_client(DB_CONNECTION)
model = SentenceTransformer('Snowflake/snowflake-arctic-embed-xs')
client = Together(api_key=os.getenv('TOGETHER_API_KEY'))
def live_inference(messages, max_new_tokens = 512):
    """Run one chat completion on Together AI and return the reply text.

    messages: OpenAI-style list of {"role": ..., "content": ...} dicts.
    max_new_tokens: upper bound on the generated completion length.
    """
    completion = client.chat.completions.create(
        model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
        messages=messages,
        max_tokens=max_new_tokens,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
def query_db(query, limit = 5, filters = None, measure = "cosine_distance", include_value = False, include_metadata=False, table = "2023"):
    """Nearest-neighbour search in one year's vecs collection.

    query: embedding vector (or data vecs can embed) to search with.
    limit: maximum number of hits to return.
    filters: optional metadata filter dict; None means no filtering.
        Fix: the original used a mutable default `{}`, which is shared
        across calls; it is now created fresh per call.
    measure: distance measure passed through to vecs.
    include_value / include_metadata: control the shape of returned tuples.
    table: collection name — one collection per year of jurisprudence.
    Returns the raw hit list from vecs.
    """
    if filters is None:
        filters = {}
    # 384 is the output dimension of snowflake-arctic-embed-xs (see `model` above).
    query_embeds = vx.get_or_create_collection(name=table, dimension=384)
    ans = query_embeds.query(
        data=query,
        limit=limit,
        filters=filters,
        measure=measure,
        include_value=include_value,
        include_metadata=include_metadata,
    )
    return ans
def construct_result(ans):
    """Format query hits into Spanish context sentences, best score first.

    ans: list of (id, score, metadata) tuples, where metadata carries the
    'sentencia' (ruling id) and 'fragmento' (excerpt) keys. The list is
    sorted in place by score, descending, as before.
    Returns one "En la sentencia ..., se dijo ...\n" line per hit.
    """
    # Sort in place by similarity score (index 1), best match first.
    ans.sort(key=lambda hit: hit[1], reverse=True)
    # Join at C speed instead of quadratic `+=` string building.
    return "".join(
        f"En la sentencia {hit[2].get('sentencia')}, se dijo {hit[2].get('fragmento')}\n"
        for hit in ans
    )
def sort_by_score(item):
    """Sort key for query hits: the similarity score stored at index 1."""
    score = item[1]
    return score
def referencias(results):
    """Build the 'Sentencias encontradas' trailer, listing each ruling once.

    results: list of (id, score, metadata) tuples; metadata['sentencia']
    identifies the ruling. First occurrence order is preserved.
    Returns a single string: header line plus space-separated ruling ids.
    """
    references = 'Sentencias encontradas: \n'
    # Set membership is O(1) per check (the original scanned a list).
    enlistadas = set()
    for item in results:
        sentencia = item[2].get('sentencia')
        if sentencia not in enlistadas:
            references += sentencia + ' '
            enlistadas.add(sentencia)
    return references
def inference(prompt):
    """Answer a legal-research question via RAG over the yearly collections.

    Embeds the prompt, pools the top-5 hits from each year's collection
    (2019-2024), then asks the LLM to answer in Spanish grounded only in
    that context, appending the list of cited rulings.

    Fix: the Spanish system prompt contained mojibake (e.g. "investigaci贸n",
    "espa帽ol") from a UTF-8/GBK mix-up; the accented characters are restored
    so the model receives well-formed Spanish.
    """
    encoded_prompt1 = model.encode(prompt)
    years = range(2019, 2025)
    results = []
    # Query every yearly collection and pool the hits before ranking.
    for year in years:
        results.extend(query_db(encoded_prompt1, include_metadata = True, table = str(year), include_value=True, limit = 5))
    results.sort(key=sort_by_score, reverse=True)
    researchAI=[
        {"role": "system", "content": f"""
        Eres Ticio, un asistente de investigación jurídica. Tu deber es organizar el contenido de las sentencias de la jurisprudencia de acuerdo
        a las necesidades del usuario. Debes responder solo en español. Debes responder solo en base a la información del contexto a continuación.
        Siempre debes mencionar la fuente en tu escrito, debe tener un estilo formal y jurídico.
        Contexto:
        {construct_result(results)}
        """
        },
        {"role": "user", "content": prompt},
    ]
    return live_inference(researchAI, max_new_tokens=1024) + '\n' + referencias(results)
# Gradio theme: red accents on a neutral base, with the primary button
# recoloured to the dark-red brand colour #910A0A.
theme = gr.themes.Base(
    primary_hue="red",
    secondary_hue="red",
    neutral_hue="neutral",
).set(
    button_primary_background_fill='#910A0A',
    button_primary_border_color='*primary_300',
    button_primary_text_color='*primary_50'
)
# UI layout: a read-only answer box, a single question input, and a submit
# button wired to inference(). api_name=False keeps the endpoint out of the
# auto-generated API docs.
# Fix: the input placeholder contained mojibake ("驴Que quieres buscar?");
# restored to proper Spanish.
with gr.Blocks(theme=theme) as demo:
    output = gr.Textbox(label = "Ticio", lines = 15, show_label = True, show_copy_button= True)
    name = gr.Textbox(label="Name", show_label = False, container = True, placeholder = "¿Qué quieres buscar?")
    greet_btn = gr.Button("Preguntar", variant = "primary")
    greet_btn.click(fn=inference, inputs=name, outputs=output, api_name=False)
# Script entry point: enable request queueing (up to 60 concurrent jobs)
# and launch the app with the public API surface hidden.
if __name__ == "__main__":
    demo.queue(default_concurrency_limit=60)
    demo.launch(show_api=False)