Marina Pliusnina committed
Commit c774338 · 1 Parent(s): c8bd9ca

adding number of chunks and context
app.py CHANGED

@@ -37,13 +37,14 @@ def generate(prompt, model_parameters):
     )
 
 
-def submit_input(input_, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, num_beams, temperature):
+def submit_input(input_, num_chunks, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, num_beams, temperature):
     if input_.strip() == "":
         gr.Warning("Not possible to inference an empty input")
         return None
 
 
     model_parameters = {
+        "NUM_CHUNKS": num_chunks,
         "MAX_NEW_TOKENS": max_new_tokens,
         "REPETITION_PENALTY": repetition_penalty,
         "TOP_K": top_k,

@@ -109,6 +110,13 @@ def gradio_app():
 
         with gr.Row(variant="panel"):
             with gr.Accordion("Model parameters", open=False, visible=SHOW_MODEL_PARAMETERS_IN_UI):
+                num_chunks = Slider(
+                    minimum=1,
+                    maximum=6,
+                    step=1,
+                    value=4,
+                    label="Number of chunks"
+                )
                 max_new_tokens = Slider(
                     minimum=50,
                     maximum=1000,

@@ -154,7 +162,7 @@ def gradio_app():
                     label="Temperature"
                 )
 
-            parameters_compontents = [max_new_tokens, repetition_penalty, top_k, top_p, do_sample, num_beams, temperature]
+            parameters_compontents = [num_chunks, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, num_beams, temperature]
 
         with gr.Column(variant="panel"):
             output = Textbox(
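Each component in `parameters_compontents` is delivered to `submit_input` positionally, so `num_chunks` must sit first in the list to line up with the new signature. Below is a minimal, self-contained sketch of that wiring with the parameter list trimmed to two sliders; the `.click` binding and the handler body are assumptions for illustration, since app.py's actual event binding is outside this diff:

```python
import gradio as gr

def submit_input(input_, num_chunks, max_new_tokens):
    # Stand-in for app.py's submit_input, trimmed to two parameters:
    # slider values arrive positionally, in component-list order.
    if input_.strip() == "":
        gr.Warning("Not possible to inference an empty input")
        return None
    return f"num_chunks={num_chunks}, max_new_tokens={max_new_tokens}"

with gr.Blocks() as demo:
    input_ = gr.Textbox(label="Input")
    num_chunks = gr.Slider(minimum=1, maximum=6, step=1, value=4, label="Number of chunks")
    max_new_tokens = gr.Slider(minimum=50, maximum=1000, value=200, label="Max new tokens")
    output = gr.Textbox(label="Output")
    parameters_compontents = [num_chunks, max_new_tokens]  # order must match the signature
    gr.Button("Submit").click(
        fn=submit_input,
        inputs=[input_] + parameters_compontents,
        outputs=output,
    )

demo.launch()
```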
rag.py CHANGED

@@ -24,19 +24,11 @@ class RAG:
 
         logging.info("RAG loaded!")
 
-    def get_context(self, instruction, number_of_contexts=
-
-        context = ""
-
+    def get_context(self, instruction, number_of_contexts=4):
 
         documentos = self.vectore_store.similarity_search_with_score(instruction, k=number_of_contexts)
 
-
-        for doc in documentos:
-
-            context += doc[0].page_content
-
-        return context
+        return documentos
 
     def predict(self, instruction, context, model_parameters):
 

@@ -61,14 +53,30 @@ class RAG:
         response = requests.post(self.model_name, headers=headers, json=payload)
 
         return response.json()[0]["generated_text"].split("###")[-1][8:-1]
+
+    def beautiful_context(self, docs):
+
+        text_context = ""
+
+        full_context = ""
+
+        for doc in docs:
+            text_context += doc[0].page_content
+            full_context += doc[0].page_content + "\n"
+            full_context += doc[0].metadata["Títol de la norma"] + "\n\n"
+
+        return text_context, full_context
 
     def get_response(self, prompt: str, model_parameters: dict) -> str:
 
-        context = self.get_context(prompt)
+        docs = self.get_context(prompt, model_parameters["NUM_CHUNKS"])
+        text_context, full_context = self.beautiful_context(docs)
+
+        del model_parameters["NUM_CHUNKS"]
 
-        response = self.predict(prompt, context, model_parameters)
+        response = self.predict(prompt, text_context, model_parameters)
 
         if not response:
             return self.NO_ANSWER_MESSAGE
 
-        return response
+        return response, full_context
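With this change, `get_response` consumes `NUM_CHUNKS` (deleting it from `model_parameters` before the remaining keys are forwarded to the model) and returns a `(response, full_context)` tuple rather than a bare string, so the `-> str` annotation is now stale and callers must unpack two values. A self-contained sketch of what the new `beautiful_context` helper produces, using fabricated `(document, score)` pairs shaped like `similarity_search_with_score` output (the contents and norm titles are invented):

```python
from types import SimpleNamespace

def beautiful_context(docs):
    # Mirrors the new helper: text_context feeds the model, while
    # full_context appends each chunk's source title for display.
    text_context = ""
    full_context = ""
    for doc in docs:
        text_context += doc[0].page_content
        full_context += doc[0].page_content + "\n"
        full_context += doc[0].metadata["Títol de la norma"] + "\n\n"
    return text_context, full_context

# Fabricated (document, score) pairs in the shape LangChain's
# similarity_search_with_score returns; contents are invented.
docs = [
    (SimpleNamespace(page_content="Article 1. ...",
                     metadata={"Títol de la norma": "Llei 1/1998"}), 0.21),
    (SimpleNamespace(page_content="Article 7. ...",
                     metadata={"Títol de la norma": "Decret 2/2010"}), 0.35),
]

text_context, full_context = beautiful_context(docs)
print(text_context)   # concatenated chunks, passed on to predict()
print(full_context)   # chunks interleaved with source titles, returned to the UI
```

Note that `del model_parameters["NUM_CHUNKS"]` mutates the caller's dict in place, so a dict reused across requests has to re-add the key before each call.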