Spaces:
Sleeping
Sleeping
Fix
Browse files- .gitignore +2 -1
- app.py +3 -6
- meteocat_app.py +175 -0
.gitignore
CHANGED
@@ -1,2 +1,3 @@
|
|
1 |
**/venv
|
2 |
-
**/__pycache__
|
|
|
|
1 |
**/venv
|
2 |
+
**/__pycache__
|
3 |
+
**/.env
|
app.py
CHANGED
@@ -3,7 +3,7 @@ from dotenv import load_dotenv
|
|
3 |
import gradio as gr
|
4 |
from gradio.components import Textbox, Button, Slider, Image
|
5 |
from AinaTheme import AinaGradioTheme
|
6 |
-
from
|
7 |
import csv
|
8 |
|
9 |
load_dotenv()
|
@@ -33,7 +33,7 @@ def submit_input(input_, repetition_penalty, temperature):
|
|
33 |
És possible que no hagi trobat el lloc o la data.
|
34 |
Només puc respondre a preguntes sobre el temps a alguna localitat en concret.
|
35 |
""")
|
36 |
-
return
|
37 |
|
38 |
data_as_dict = csv_to_dict("./code2simbol.csv")
|
39 |
codes = outputs["context"]
|
@@ -203,8 +203,5 @@ with gr.Blocks(**AinaGradioTheme().get_kwargs()) as demo:
|
|
203 |
outputs=[output_answer, output_image, output_CCMA]
|
204 |
)
|
205 |
|
206 |
-
# clear_btn.click(fn=clean, inputs=[], outputs=[input_, output_answer, output_context, output_CCMA, repetition_penalty, temperature], queue=False)
|
207 |
-
# submit_btn.click(fn=submit_input, inputs=[input_, repetition_penalty, temperature], outputs=[output_answer, output_context, output_CCMA])
|
208 |
|
209 |
-
|
210 |
-
demo.launch(show_api=True, share=True, debug=True, max_threads=1)
|
|
|
3 |
import gradio as gr
|
4 |
from gradio.components import Textbox, Button, Slider, Image
|
5 |
from AinaTheme import AinaGradioTheme
|
6 |
+
from meteocat_app import generate
|
7 |
import csv
|
8 |
|
9 |
load_dotenv()
|
|
|
33 |
És possible que no hagi trobat el lloc o la data.
|
34 |
Només puc respondre a preguntes sobre el temps a alguna localitat en concret.
|
35 |
""")
|
36 |
+
return None, None, None
|
37 |
|
38 |
data_as_dict = csv_to_dict("./code2simbol.csv")
|
39 |
codes = outputs["context"]
|
|
|
203 |
outputs=[output_answer, output_image, output_CCMA]
|
204 |
)
|
205 |
|
|
|
|
|
206 |
|
207 |
+
demo.launch(show_api=True)
|
|
meteocat_app.py
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Aug 17 12:11:26 2023

@author: crodrig1
"""
from optparse import OptionParser
import sys, re, os
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from pymongo import MongoClient
from pprint import pprint
import torch
import warnings
import re, string
from dotenv import load_dotenv

# Load environment variables from a local .env file (provides MONGO_URI).
load_dotenv()

# Connection string for the MongoDB instance holding the CCMA weather records.
MONGO_URI = os.environ.get("MONGO_URI")


# transformers/pymongo emit noisy warnings during model load; silence globally.
warnings.filterwarnings("ignore")
|
24 |
+
|
25 |
+
# Tokenizer for the BLOOM 1.3B model fine-tuned on Meteocat weather data.
tokenizer = AutoTokenizer.from_pretrained("crodri/bloom1.3_meteo")


from transformers import BitsAndBytesConfig

# 4-bit quantization settings. NOTE(review): currently unused — the
# quantization_config argument is commented out in the model load below.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    llm_int8_enable_fp32_cpu_offload=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)
# Causal LM fine-tuned for weather answers; loaded unquantized on CPU.
model_4bit = AutoModelForCausalLM.from_pretrained(
    "crodri/bloom1.3_meteo",
    model_type="BloomForCausalLM",
    device_map="cpu",
    # verbose=False,
    # quantization_config=quantization_config,
    trust_remote_code=True)
# #tokenizer = AutoTokenizer.from_pretrained(model_id)

# tokenizer = AutoTokenizer.from_pretrained(model_4bit)

# Text-generation pipeline used by givePrediction() to answer questions.
llm_pipeline = pipeline(
    "text-generation",
    model=model_4bit,
    tokenizer=tokenizer,
    use_cache=True,
    device_map="auto",
    #max_length=800,
    do_sample=True,
    top_k=10,
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
)
|
61 |
+
|
62 |
+
|
63 |
+
def retrieveFor(result):
    """Look up the CCMA weather record matching the NER entities in *result*.

    result: output of the token-classification pipeline — a list of dicts
    with "entity_group", "word" and "score" keys.

    Returns a tuple (record, interval) where record is the MongoDB document
    and interval is the detected part of day ("tot el dia" when none was
    found), or None when the location/day could not be resolved or no
    matching record exists.
    """
    def retrievehighest(key, result):
        # Best-scoring entity of the given group; [] when the group is absent
        # (max() on an empty sequence raises ValueError).
        try:
            candidates = [x for x in result if (x["entity_group"] == key)]
            topone = max(candidates, key=lambda x: x['score'])
            return topone['word'].strip()
        except ValueError:
            return []

    def removeend(frase):
        # Strip question marks and commas so words match the stored DB values.
        frase = re.sub(r"\?", "", frase)
        frase = re.sub(",", "", frase)
        return frase

    intervalo = [x["word"].strip() for x in result if (x["entity_group"] == "interval")]
    client = MongoClient(MONGO_URI)
    db = client['aina']
    collection = db['new_ccma_meteo']
    try:
        location = removeend(retrievehighest("location", result))
        day = removeend(retrievehighest("day", result))
    except TypeError:
        # retrievehighest returned [] (entity missing): re.sub on a list
        # raises TypeError, meaning we could not resolve location or day.
        print("No hem trobat el lloc o la data. Torna a provar")
        return None
    record = collection.find({"location": location.strip(), "day": day.lower().strip()})
    try:
        j = record.next()
        if intervalo:
            return (j, intervalo[0])
        return (j, 'tot el dia')
    except StopIteration:
        # Cursor was empty: no document matches this location/day pair.
        print("No hem trobat el lloc o la data. Torna a provar")
        return None
|
95 |
+
#context": "Day: dilluns | Location: Sant Salvador de Guardiola | mati: la nuvolositat anirà en augment | tarda: els núvols alts taparan el cel | nit: cel clar | temp: Lleugera pujada de les temperatures"
|
96 |
+
|
97 |
+
# NER pipeline extracting location / day / interval entities from the question.
pipe = pipeline("token-classification", model="crodri/ccma_ner", aggregation_strategy='first')

# Intent classifier deciding whether the question is a weather query.
intent = pipeline("text-classification", model="projecte-aina/roberta-large-ca-v2-massive")
|
100 |
+
|
101 |
+
|
102 |
+
def pipeIt(jresponse):
    """Return the pre-built context string from a lookup result.

    jresponse is the (record, interval) tuple returned by retrieveFor();
    only the record's "context" field is used.

    Fix: the original compiled a punctuation-stripping regex on every call
    and never used it (leftover from commented-out code) — removed along
    with the dead code.
    """
    return jresponse[0]["context"]
|
110 |
+
|
111 |
+
#question = "Quin temps farà a la tarda a Algete dijous?"
|
112 |
+
|
113 |
+
def givePrediction(question, context, temperature, repetition):
    """Run the LLM pipeline on an instruction/context prompt and return the answer text."""
    # Prompt layout expected by the fine-tuned model.
    template = "### Instruction\n{instruction}\n\n### Context\n{context}\n\n### Answer\n"
    prompt = template.format(instruction=question, context=context)
    generated = llm_pipeline(
        prompt,
        temperature=temperature,
        repetition_penalty=repetition,
        max_new_tokens=100,
    )[0]["generated_text"]
    # Keep only the text after the final "###" marker, dropping " Answer\n" (8 chars).
    return generated.split("###")[-1][8:]
|
119 |
+
|
120 |
+
def assistant(question):
    """Classify the question's intent; for weather queries run NER + DB lookup.

    Returns the (record, interval) tuple from retrieveFor() on success,
    otherwise prints a message and returns None.
    """
    label = intent(question)[0]['label']
    if label == 'weather_query':
        entities = pipe(question)
        lookup = retrieveFor(entities)
        if lookup:
            print("Context: ", lookup[0]['context'])
            print()
            return lookup
    elif label in ["general_greet", "general_quirky"]:
        print("Hola, quina es la teva consulta meteorològica?")
        #sys.exit(0)
    else:
        print(label)
        print("Ho sento. Jo només puc respondre a preguntes sobre el temps a alguna localitat en concret ...")
        #sys.exit(0)
    return None
|
139 |
+
def generate(question, temperature, repetition):
    """Answer a weather question end to end.

    Returns a dict with the lookup context, the reference CCMA response and
    the model's generated answer, or None when no record was found.
    """
    lookup = assistant(question)
    if not lookup:
        print("No response")
        return None
    record = lookup[0]
    context = {"codis": record['codis'], "interval": lookup[1]}
    # context = lookup[0]['context']
    ccma_response = record['response']
    answer = givePrediction(question, context, temperature, repetition)
    print("CCMA generated: ", ccma_response)
    print("=" * 16)
    print("LLM answer: ", answer)
    print()

    return {"context": context, "ccma_response": ccma_response, "model_answer": answer}
|
158 |
+
|
159 |
+
def main():
    """CLI entry point: parse options and run one generation.

    Fix: replaced the long-deprecated optparse with argparse (same -q/-t/-r
    flags and defaults); dropped the unused positional-args tuple and the
    unused `answer` binding.
    """
    import argparse  # local import: keeps the module import block untouched

    parser = argparse.ArgumentParser(description="Meteocat weather assistant")
    parser.add_argument("-q", "--question", dest="question", type=str,
                        help="question to test", default="Quin temps farà a la tarda a Algete dijous?")
    parser.add_argument("-t", "--temperature", dest="temperature", type=float,
                        help="temperature generation", default=1.0)
    parser.add_argument("-r", "--repetition", dest="repetition", type=float,
                        help="repetition penalty", default=1.0)
    options = parser.parse_args(sys.argv[1:])
    print(options)
    generate(options.question, options.temperature, options.repetition)


if __name__ == "__main__":
    main()
|