#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Aug 17 12:11:26 2023
@author: crodrig1
"""
from optparse import OptionParser
import sys, re, os
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig
from pymongo import MongoClient
import torch
import warnings
import string  # 're' is already imported above
from dotenv import load_dotenv
load_dotenv()
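# MongoDB connection string, read from the environment (e.g. from the .env file loaded above)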
MONGO_URI = os.environ.get("MONGO_URI")
model_path = "crodri/RAGGFlor1.3"
warnings.filterwarnings("ignore")
tokenizer = AutoTokenizer.from_pretrained(model_path)
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
llm_int8_enable_fp32_cpu_offload=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
)
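# Load the model with automatic device placement. Note that the quantization_config
# argument below is commented out, so despite the variable name the model is loaded
# at its default precision rather than in 4-bit.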
model_4bit = AutoModelForCausalLM.from_pretrained(
model_path,
model_type="BloomForCausalLM",
device_map="auto",
#verbose=False,
#quantization_config=quantization_config,
trust_remote_code=True)
# tokenizer = AutoTokenizer.from_pretrained(model_id)
# tokenizer = AutoTokenizer.from_pretrained(model_4bit)
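# Text-generation pipeline used to produce the final Catalan forecast. Sampling is
# enabled (do_sample, top_k) and max_length caps the total sequence length; the
# per-call max_new_tokens limit is set later in givePrediction().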
llm_pipeline = pipeline(
"text-generation",
model=model_4bit,
tokenizer=tokenizer,
use_cache=True,
device_map="auto",
max_length=400,
do_sample=True,
top_k=30,
num_return_sequences=1,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.eos_token_id,
)
# NER pipeline: extracts location / day / interval entities from the question
pipe = pipeline("token-classification", model="crodri/ccma_ner", aggregation_strategy="first")
# Intent classifier: decides whether the question is a weather query
intent = pipeline("text-classification", model="projecte-aina/roberta-large-ca-v2-massive")
def retrievehighest(key, result):
    """Return the highest-scoring entity of the given group, or [] if none was found."""
    try:
        candidates = [x for x in result if (x["entity_group"] == key)]
        topone = max(candidates, key=lambda x: x['score'])
        return topone['word'].strip()
    except ValueError:
        # max() on an empty list: no entity of this group was detected
        return []
def retrieveFor(result):
    try:
        removeend = lambda frase: re.sub("[?,]", "", frase)
        location = removeend(retrievehighest("location", result))
        day = removeend(retrievehighest("day", result))
    except TypeError:
        # retrievehighest returned [] for location or day, so re.sub failed
        print("No hem trobat el lloc o la data. Torna a provar")
        return None
    intervalo = [x["word"].strip() for x in result if (x["entity_group"] == "interval")]
    try:
        client = MongoClient(MONGO_URI)
        db = client['aina']
        collection = db['new_ccma_meteo']
        record = collection.find_one({"location": location.strip(), "day": day.lower().strip()})
        if record:
            return (record, intervalo[0] if intervalo else 'tot el dia')
        else:
            print("No hem trobat el lloc o la data. Torna a provar")
            return None
    except Exception:
        print("Error al buscar en la base de dades.")
        return None
#context": "Day: dilluns | Location: Sant Salvador de Guardiola | mati: la nuvolositat anirà en augment | tarda: els núvols alts taparan el cel | nit: cel clar | temp: Lleugera pujada de les temperatures"
def pipeIt(jresponse):
    """Return the pre-built context string from the retrieved record."""
    regex = re.compile('[%s]' % re.escape(string.punctuation))  # note: currently unused
    d = jresponse[0]
    #i = jresponse[-1]
    #i = regex.sub('', i)
    #context = i +" del "+ db["day"]+" a "+db["location"]+" al mati "+db["mati"]+", "+"a la tarda "+db["tarde"]+", a la nit "+db["nit"] +", i "+db["temperature"]
    #context = d["day"]+" a "+d["location"]+": al mati "+d["mati"]+", "+"a la tarda "+d["tarde"]+", a la nit "+d["nit"] +", i "+d["temperature"]
    return d["context"]
#question = "Quin temps farà a la tarda a Algete dijous?"
def givePrediction(question, context, temperature, repetition):
    prompt = ("You are an expert meteorologist. Using the information in the 'context', write in Catalan "
              "a weather forecast that satisfies the user's request about the weather at a specific "
              "location or time. Do not mention that you are using a 'context', and only provide the "
              "information that is relevant to answer the question. Do not give any expanded response, "
              "explanation, context or any other information, and be as accurate as possible. ")
    query = "### Instruction\n{instruction}\n\n### Context\n{context}\n\n### Answer\n"
    response = llm_pipeline(prompt + query.format(instruction=question, context=context),
                            temperature=temperature,
                            repetition_penalty=repetition,
                            max_new_tokens=40
                            )[0]["generated_text"]
    # Keep only the text after the last "###" header; [8:] strips the leading " Answer\n"
    answer = response.split("###")[-1][8:]
    return answer
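# Route the incoming question: weather queries go through NER and the MongoDB lookup,
# greetings get a canned reply, and anything else is politely rejected.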
def assistant(question):
is_intent = intent(question)[0]
if is_intent['label'] == 'weather_query':
result = pipe(question)
jresponse = retrieveFor(result)
if jresponse:
# context = jresponse[0]['context']#pipeIt(jresponse)
# #jresponse[0]['context'] = context
# print("Context: ",context)
# print()
return jresponse
elif is_intent['label'] in ["general_greet","general_quirky"]:
print("Hola, quina és la teva consulta meteorològica?")
#sys.exit(0)
else:
print(is_intent['label'])
print("Ho sento. Jo només puc respondre a preguntes sobre el temps a alguna localitat en concret ...")
#sys.exit(0)
return None
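# Full pipeline for one question: retrieve the matching record, build the context,
# generate the forecast with the LLM, and return codes, context, the CCMA reference
# text and the model answer.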
def generate(question, temperature, repetition):
jresponse = assistant(question)
if jresponse:
codes = jresponse[0]['codis']
interval = jresponse[1]
codes_interval = {"codis": codes, "interval": interval}
context = jresponse[0]['context']
ccma_response = jresponse[0]['response']
        answer = givePrediction(question, context, temperature, repetition)
print("Codes: ",codes_interval)
print()
print("Context: ",context)
print()
print("CCMA generated: ",ccma_response)
print("="*16)
print("LLM answer: ",answer)
print()
return {"codes":codes_interval, "context": context, "ccma_response": ccma_response, "model_answer": answer}
else:
print("No response")
return None
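# Command-line entry point: -q/--question, -t/--temperature (sampling temperature),
# -r/--repetition (repetition penalty).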
def main():
parser = OptionParser()
parser.add_option("-q", "--question", dest="question", type="string",
help="question to test", default="Quin temps farà a la tarda a Algete dijous?")
parser.add_option("-t", "--temperature", dest="temperature", type="float",
help="temperature generation", default=0.94)
parser.add_option("-r", "--repetition", dest="repetition", type="float",
help="repetition penalty", default=0.5)
    (options, args) = parser.parse_args()
print(options)
#question = options.question
#print(question)
    answer = generate(options.question, options.temperature, options.repetition)
#print(answer)
if __name__ == "__main__":
main()
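# Example usage (a sketch; it assumes MONGO_URI points at a MongoDB instance whose
# 'aina.new_ccma_meteo' collection is populated, and that a GPU is available for
# device_map="auto"):
#
#   python meteocat_app.py -q "Quin temps farà a la tarda a Algete dijous?" -t 0.94 -r 0.5
#
# Programmatic use from another script (importing this module loads the model and
# pipelines at import time):
#
#   from meteocat_app import generate
#   result = generate("Quin temps farà a la tarda a Algete dijous?", 0.94, 0.5)
#   if result:
#       print(result["model_answer"])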