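# Ticio: a Gradio app for Spanish-language legal research. A user question is
# embedded with snowflake-arctic-embed-xs, matching case-law fragments are
# retrieved from per-year pgvector collections (via vecs), and Llama 3.1 8B
# on Together AI drafts a sourced answer in Spanish.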
import os

import gradio as gr
import vecs
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer
from together import Together

# Credentials are read from a local .env file.
load_dotenv()

user = os.getenv("user")
password = os.getenv("password")
host = os.getenv("host")
port = os.getenv("port")
db_name = "postgres"
DB_CONNECTION = f"postgresql://{user}:{password}@{host}:{port}/{db_name}"

# The pgvector client, the embedding model, and the Together client are
# created once at import time and shared across requests.
vx = vecs.create_client(DB_CONNECTION)
model = SentenceTransformer("Snowflake/snowflake-arctic-embed-xs")
client = Together(api_key=os.getenv("TOGETHER_API_KEY"))
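# Expected .env entries (placeholder values, not real credentials):
#   user=postgres
#   password=<db password>
#   host=<db host>
#   port=5432
#   TOGETHER_API_KEY=<api key>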

def live_inference(messages, max_new_tokens=512):
    """Send the chat history to Together AI and return the model's reply."""
    response = client.chat.completions.create(
        model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
        messages=messages,
        max_tokens=max_new_tokens,
    )
    return response.choices[0].message.content

def query_db(query, limit=5, filters=None, measure="cosine_distance",
             include_value=False, include_metadata=False, table="2023"):
    """Run a similarity search against one per-year vecs collection."""
    # dimension=384 matches the output size of snowflake-arctic-embed-xs.
    query_embeds = vx.get_or_create_collection(name=table, dimension=384)
    ans = query_embeds.query(
        data=query,
        limit=limit,
        filters=filters or {},  # None instead of a mutable {} default
        measure=measure,
        include_value=include_value,
        include_metadata=include_metadata,
    )
    return ans
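# With include_value and include_metadata set, each result is a tuple of the
# form (id, distance, metadata); the helpers below index into it accordingly.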

def construct_result(ans):
    """Format the retrieved fragments as context lines for the system prompt."""
    # cosine_distance is smaller for closer matches, so sort ascending.
    ans.sort(key=sort_by_score)
    results = ""
    for _, _, metadata in ans:
        a, b = metadata.get("sentencia"), metadata.get("fragmento")
        results += f"En la sentencia {a}, se dijo {b}\n"
    return results

def sort_by_score(item):
    """Sort key: the distance value of a (id, distance, metadata) tuple."""
    return item[1]

def referencias(results):
    """List each cited sentencia once, for appending to the answer."""
    references = "Sentencias encontradas: \n"
    enlistadas = []
    for item in results:
        sentencia = item[2].get("sentencia")
        if sentencia not in enlistadas:
            references += sentencia + " "
            enlistadas.append(sentencia)
    return references

def inference(prompt):
    """Embed the question, search every year's collection, and ask the LLM."""
    encoded_prompt = model.encode(prompt)
    years = range(2019, 2025)  # collections are named "2019" through "2024"
    results = []
    for year in years:
        results.extend(query_db(encoded_prompt, include_metadata=True,
                                table=str(year), include_value=True, limit=5))
    results.sort(key=sort_by_score)  # closest matches (smallest distance) first
    researchAI = [
        {"role": "system", "content": f"""
                  Eres Ticio, un asistente de investigación jurídica. Tu deber es organizar el contenido de las sentencias de la jurisprudencia de acuerdo
                  a las necesidades del usuario. Debes responder solo en español. Debes responder solo en base a la información del contexto a continuación.
                  Siempre debes mencionar la fuente en tu escrito, debe tener un estilo formal y jurídico.
                  Contexto:
                  {construct_result(results)}
                  """
        },
        {"role": "user", "content": prompt},
    ]
    return live_inference(researchAI, max_new_tokens=1024) + "\n" + referencias(results)

theme = gr.themes.Base(
    primary_hue="red",
    secondary_hue="red",
    neutral_hue="neutral",
).set(
    button_primary_background_fill='#910A0A',
    button_primary_border_color='*primary_300',
    button_primary_text_color='*primary_50'
)

with gr.Blocks(theme=theme) as demo:
    output = gr.Textbox(label="Ticio", lines=15, show_label=True, show_copy_button=True)
    name = gr.Textbox(label="Name", show_label=False, container=True, placeholder="¿Qué quieres buscar?")
    greet_btn = gr.Button("Preguntar", variant="primary")
    greet_btn.click(fn=inference, inputs=name, outputs=output, api_name=False)

if __name__ == "__main__":
    demo.queue(default_concurrency_limit=60)
    demo.launch(show_api=False)
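# Run locally with `python <this file>.py` (assuming the .env above is in
# place and the per-year collections are populated); Gradio serves the UI at
# http://127.0.0.1:7860 by default.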