BlooMeteo / meteocat_app.py
PaulNdrei's picture
Fix
1999a4e
raw
history blame
No virus
6.24 kB
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
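"""
Command-line demo that answers Catalan weather questions: it classifies the
question's intent, extracts location/day/interval entities, looks up the
matching forecast record in MongoDB and asks the crodri/bloom1.3_meteo model
to phrase an answer from it.

Created on Thu Aug 17 12:11:26 2023
@author: crodrig1
"""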
from optparse import OptionParser
from pprint import pprint
import re, string
import sys, re, os
import warnings

from dotenv import load_dotenv
from pymongo import MongoClient
from pymongo.errors import PyMongoError
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Load variables from a local .env file (if present) into os.environ so that
# MONGO_URI can be configured there instead of in the shell environment.
load_dotenv()
# MongoDB connection string used by retrieveFor; None when the variable is unset.
MONGO_URI = os.environ.get("MONGO_URI")
# Silences every warning for the whole process (including third-party
# library warnings), not only those raised by this module.
warnings.filterwarnings("ignore")
tokenizer = AutoTokenizer.from_pretrained("crodri/bloom1.3_meteo")
from transformers import BitsAndBytesConfig
# 4-bit NF4 quantization settings.
# NOTE(review): this object is currently unused -- the quantization_config
# argument below is commented out, so the model is loaded unquantized
# despite the name model_4bit.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    llm_int8_enable_fp32_cpu_offload=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)
# Causal LM that writes the final answer, placed on the CPU.
model_4bit = AutoModelForCausalLM.from_pretrained(
    "crodri/bloom1.3_meteo",
    # NOTE(review): confirm from_pretrained actually makes use of model_type here.
    model_type="BloomForCausalLM",
    device_map="cpu",
    # verbose=False,
    # quantization_config=quantization_config,
    trust_remote_code=True)
# #tokenizer = AutoTokenizer.from_pretrained(model_id)
# tokenizer = AutoTokenizer.from_pretrained(model_4bit)
# Shared text-generation pipeline used by givePrediction. Sampling is enabled
# with top_k=10 and one sequence per call; the EOS token doubles as padding.
llm_pipeline = pipeline(
    "text-generation",
    model=model_4bit,
    tokenizer=tokenizer,
    use_cache=True,
    # NOTE(review): the model was already placed with device_map="cpu" above;
    # this device_map probably has no effect on an instantiated model -- confirm.
    device_map="auto",
    #max_length=800,
    do_sample=True,
    top_k=10,
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
)
def retrieveFor(result):
    """Fetch the stored forecast that matches the entities tagged in a question.

    Args:
        result: entity list produced by the token-classification ``pipe``;
            each item must have ``entity_group``, ``word`` and ``score`` keys.

    Returns:
        ``(record, interval)``. ``record`` is the first document in the
        ``aina.new_ccma_meteo`` collection whose ``location``/``day`` match the
        top-scoring ``location`` and ``day`` entities. ``interval`` is the
        first ``interval`` entity, or ``'tot el dia'`` when none was tagged.
        Returns ``None`` (after printing a message) when the location or day
        is missing, no record matches, or the MongoDB query fails.
    """
    not_found_msg = "No hem trobat el lloc o la data. Torna a provar"

    def retrievehighest(key, entities):
        """Return the stripped word of the best-scoring ``key`` entity, or None."""
        candidates = [x for x in entities if x["entity_group"] == key]
        if not candidates:
            return None
        return max(candidates, key=lambda x: x["score"])["word"].strip()

    def removeend(frase):
        """Remove question marks and commas that the NER span may have captured."""
        return re.sub(r"[?,]", "", frase)

    intervalo = [x["word"].strip() for x in result if x["entity_group"] == "interval"]
    location = retrievehighest("location", result)
    day = retrievehighest("day", result)
    # Resolve the entities before touching the database, so incomplete
    # questions never open a connection.
    if location is None or day is None:
        print(not_found_msg)
        return None
    query = {
        "location": removeend(location).strip(),
        "day": removeend(day).lower().strip(),
    }

    # The context manager closes the client after the lookup, so no
    # connection pool is left open per call.
    try:
        with MongoClient(MONGO_URI) as client:
            record = client["aina"]["new_ccma_meteo"].find_one(query)
    except PyMongoError:
        # Best effort: a database failure is reported the same way as a miss.
        record = None
    if record is None:
        print(not_found_msg)
        return None
    return (record, intervalo[0] if intervalo else 'tot el dia')
# Example value of a stored record's "context" field: "Day: dilluns | Location: Sant Salvador de Guardiola | mati: la nuvolositat anirà en augment | tarda: els núvols alts taparan el cel | nit: cel clar | temp: Lleugera pujada de les temperatures"
# Token-classification (NER) pipeline. With aggregation_strategy='first' it
# returns grouped entities; retrieveFor reads their entity_group/word/score.
pipe = pipeline("token-classification", model="crodri/ccma_ner",aggregation_strategy='first')
# Intent classifier. assistant() treats the 'weather_query', 'general_greet'
# and 'general_quirky' labels specially and rejects every other label.
intent = pipeline("text-classification", model="projecte-aina/roberta-large-ca-v2-massive")
def pipeIt(jresponse):
    """Return the stored forecast context string of a ``retrieveFor`` result.

    Args:
        jresponse: the ``(record, interval)`` pair returned by
            ``retrieveFor``; only ``record["context"]`` is used.

    Returns:
        The record's ``"context"`` value.

    Raises:
        KeyError: if the record has no ``"context"`` field.
    """
    record = jresponse[0]
    return record["context"]
#question = "Quin temps farà a la tarda a Algete dijous?"
def givePrediction(question, context, temperature, repetition,
                   max_new_tokens=100, generator=None):
    """Ask the LLM to answer ``question`` using the forecast ``context``.

    Args:
        question: the user's question, used as the prompt's instruction.
        context: forecast data for the prompt, rendered with ``str()``
            (``generate`` passes a dict with ``codis`` and ``interval``).
        temperature: sampling temperature forwarded to the generator.
        repetition: forwarded to the generator as ``repetition_penalty``.
        max_new_tokens: cap on generated tokens (default 100).
        generator: callable with the Hugging Face text-generation pipeline
            interface; defaults to the module-level ``llm_pipeline``.

    Returns:
        The text the model wrote after the ``### Answer`` header, cut at the
        next ``###`` header if the model starts another section.
    """
    if generator is None:
        generator = llm_pipeline
    prompt = (
        "### Instruction\n{instruction}\n\n"
        "### Context\n{context}\n\n"
        "### Answer\n"
    ).format(instruction=question, context=context)
    response = generator(
        prompt,
        temperature=temperature,
        repetition_penalty=repetition,
        max_new_tokens=max_new_tokens,
    )[0]["generated_text"]
    # The generated text is assumed to echo the prompt (the pipeline's default
    # full-text output), so the answer follows the first "### Answer\n".
    # Stop at the next "###": taking "the last ### section minus 8 chars"
    # returns fragments like "tion\n..." when the model opens a new section.
    _, found, tail = response.partition("### Answer\n")
    if not found:
        # Header not echoed verbatim: fall back to the legacy extraction.
        return response.split("###")[-1][8:]
    return tail.split("###", 1)[0]
def assistant(question):
    """Route ``question`` by intent and return its forecast record, if any.

    The intent classifier decides the path:

    * ``weather_query``: run NER on the question, look the entities up with
      ``retrieveFor`` and, on success, print the record's ``context`` and
      return the ``(record, interval)`` pair.
    * ``general_greet`` / ``general_quirky``: print a greeting prompt.
    * any other label: print the label and a "weather questions only" notice.

    Returns ``None`` on every path except a successful weather lookup.
    """
    label = intent(question)[0]['label']

    if label == 'weather_query':
        found = retrieveFor(pipe(question))
        if not found:
            return None
        print("Context: ", found[0]['context'])
        print()
        return found

    if label in ["general_greet", "general_quirky"]:
        print("Hola, quina es la teva consulta meteorològica?")
        return None

    print(label)
    print("Ho sento. Jo només puc respondre a preguntes sobre el temps a alguna localitat en concret ...")
    return None
def generate(question, temperature, repetition):
    """Answer ``question`` end to end and report both answers.

    Retrieves the forecast record through ``assistant``, then prompts the LLM
    with the record's ``codis`` and the detected interval. The CCMA reference
    answer (``record['response']``) and the LLM answer are printed.

    Returns:
        ``{"context", "ccma_response", "model_answer"}`` on success, or
        ``None`` (after printing "No response") when no record was found.
    """
    found = assistant(question)
    if not found:
        print("No response")
        return None

    record = found[0]
    context = {"codis": record['codis'], "interval": found[1]}
    ccma_response = record['response']
    answer = givePrediction(question, context, temperature, repetition)

    print("CCMA generated: ", ccma_response)
    print("=" * 16)
    print("LLM answer: ", answer)
    print()
    return {"context": context, "ccma_response": ccma_response, "model_answer": answer}
def main(argv=None):
    """Command-line entry point: parse options and answer one question.

    Options are ``-q/--question`` (default: a sample Catalan question),
    ``-t/--temperature`` and ``-r/--repetition`` (floats, default 1.0). The
    parsed options are printed and then passed to ``generate``.

    Args:
        argv: arguments to parse, without the program name. ``None`` (the
            default) makes optparse use ``sys.argv[1:]``.
    """
    parser = OptionParser()
    parser.add_option("-q", "--question", dest="question", type="string",
                      help="question to test", default="Quin temps farà a la tarda a Algete dijous?")
    parser.add_option("-t", "--temperature", dest="temperature", type="float",
                      help="temperature generation", default=1.0)
    parser.add_option("-r", "--repetition", dest="repetition", type="float",
                      help="repetition penalty", default=1.0)
    # parse_args(sys.argv) would treat the program name as a positional
    # argument; parse_args(None) already defaults to sys.argv[1:].
    options, _ = parser.parse_args(argv)
    print(options)
    generate(options.question, options.temperature, options.repetition)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, never on import.
    main()