Spaces:
Paused
Paused
Carlos Rosas
committed on
Update app.py
Browse files
app.py
CHANGED
@@ -54,23 +54,23 @@ def hybrid_search(text):
|
|
54 |
title = row['section']
|
55 |
content = row['text']
|
56 |
|
57 |
-
document.append(f"<|source_id_start|>{hash_id}<|source_id_end
|
58 |
-
|
|
|
59 |
|
60 |
-
document = "\n
|
61 |
document_html = '<div id="source_listing">' + "".join(document_html) + "</div>"
|
62 |
return document, document_html
|
63 |
|
64 |
class CassandreChatBot:
|
65 |
-
def __init__(self, system_prompt="Tu es
|
66 |
self.system_prompt = system_prompt
|
67 |
|
68 |
def predict(self, user_message):
|
69 |
fiches, fiches_html = hybrid_search(user_message)
|
70 |
|
71 |
-
detailed_prompt = f"""<|query_start|>{user_message}<|query_end|>\n
|
72 |
|
73 |
-
# Convert inputs to tensor
|
74 |
input_ids = tokenizer.encode(detailed_prompt, return_tensors="pt").to(device)
|
75 |
attention_mask = torch.ones_like(input_ids)
|
76 |
|
@@ -88,7 +88,7 @@ class CassandreChatBot:
|
|
88 |
eos_token_id=tokenizer.eos_token_id
|
89 |
)
|
90 |
|
91 |
-
#
|
92 |
generated_text = tokenizer.decode(output[0][len(input_ids[0]):])
|
93 |
|
94 |
generated_text = '<h2 style="text-align:center">Réponse</h3>\n<div class="generation">' + format_references(generated_text) + "</div>"
|
|
|
54 |
title = row['section']
|
55 |
content = row['text']
|
56 |
|
57 |
+
document.append(f"<|source_start|><|source_id_start|>{hash_id}<|source_id_end|>{title}\n{content}<|source_end|>")
|
58 |
+
|
59 |
+
document_html.append(f'<div class="source" id="{hash_id}"><p><b>{hash_id}</b> : {title}<br>{content}</div>')
|
60 |
|
61 |
+
document = "\n".join(document)
|
62 |
document_html = '<div id="source_listing">' + "".join(document_html) + "</div>"
|
63 |
return document, document_html
|
64 |
|
65 |
class CassandreChatBot:
|
66 |
+
def __init__(self, system_prompt="Tu es un asistant de recherche qui donne des responses sourcées"):
|
67 |
self.system_prompt = system_prompt
|
68 |
|
69 |
def predict(self, user_message):
|
70 |
fiches, fiches_html = hybrid_search(user_message)
|
71 |
|
72 |
+
detailed_prompt = f"""<|query_start|>{user_message}<|query_end|>\n{fiches}\n<|source_analysis_start|>"""
|
73 |
|
|
|
74 |
input_ids = tokenizer.encode(detailed_prompt, return_tensors="pt").to(device)
|
75 |
attention_mask = torch.ones_like(input_ids)
|
76 |
|
|
|
88 |
eos_token_id=tokenizer.eos_token_id
|
89 |
)
|
90 |
|
91 |
+
# Decode only the new tokens
|
92 |
generated_text = tokenizer.decode(output[0][len(input_ids[0]):])
|
93 |
|
94 |
generated_text = '<h2 style="text-align:center">Réponse</h3>\n<div class="generation">' + format_references(generated_text) + "</div>"
|