#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Aug 17 12:11:26 2023
@author: crodrig1
"""
from optparse import OptionParser
import sys, re, os
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig
from pymongo import MongoClient
import torch
import warnings
import re, string
from dotenv import load_dotenv
# Pull MONGO_URI from a local .env file; the weather records live in MongoDB.
load_dotenv()
MONGO_URI = os.environ.get("MONGO_URI")
# Fine-tuned causal LM used to phrase the forecast answer in Catalan.
model_path = "crodri/RAGGFlor1.3"
warnings.filterwarnings("ignore")
tokenizer = AutoTokenizer.from_pretrained(model_path)
# 4-bit quantization settings. NOTE(review): currently UNUSED — the
# quantization_config argument below is commented out; kept for re-enabling.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    llm_int8_enable_fp32_cpu_offload=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)
model_4bit = AutoModelForCausalLM.from_pretrained(
    model_path,
    model_type="BloomForCausalLM",
    device_map="auto",
    #verbose=False,
    #quantization_config=quantization_config,
    trust_remote_code=True)
# tokenizer = AutoTokenizer.from_pretrained(model_id)
# tokenizer = AutoTokenizer.from_pretrained(model_4bit)
# Generation pipeline over the model above; eos token doubles as pad token.
llm_pipeline = pipeline(
    "text-generation",
    model=model_4bit,
    tokenizer=tokenizer,
    use_cache=True,
    device_map="auto",
    max_length=400,
    do_sample=True,
    top_k=30,
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
)
# NER pipeline extracting location / day / interval entities from the question.
pipe = pipeline("token-classification", model="crodri/ccma_ner",aggregation_strategy='first')
# Intent classifier used to gate weather queries vs. greetings/other requests.
intent = pipeline("text-classification", model="projecte-aina/roberta-large-ca-v2-massive")
def retrievehighest(key, result):
    """Return the highest-scoring entity word for one entity group.

    key: entity_group label to filter on (e.g. "location", "day").
    result: list of NER dicts carrying 'entity_group', 'score' and 'word'.

    Returns the stripped 'word' of the best-scoring candidate, or None when
    the group was not detected.  (The original returned [] on failure — the
    wrong type for a "best word" result; None is the conventional sentinel
    and is equivalent for the only caller, which feeds the value to re.sub()
    and catches the resulting TypeError either way.)
    """
    candidates = [x for x in result if x["entity_group"] == key]
    if not candidates:
        # max() on an empty sequence raises ValueError; short-circuit instead.
        return None
    topone = max(candidates, key=lambda x: x['score'])
    return topone['word'].strip()
def retrieveFor(result):
    """Resolve the NER entities to a stored weather record.

    result: list of NER entity dicts (entity_group / word / score).

    Returns (record, interval) where record is the MongoDB document for the
    detected location+day and interval is the first detected day-part
    (falling back to 'tot el dia'); returns None when an entity is missing,
    no record matches, or the database lookup fails.
    """
    try:
        removeend = lambda frase: re.sub("[?,]", "", frase)
        location = removeend(retrievehighest("location", result))
        day = removeend(retrievehighest("day", result))
    except TypeError:
        # retrievehighest() returned its missing-entity sentinel, which
        # re.sub() rejects — the question lacked a location or a day.
        print("No hem trobat el lloc o la data. Torna a provar")
        return None
    intervalo = [x["word"].strip() for x in result if (x["entity_group"] == "interval")]
    client = None
    try:
        client = MongoClient(MONGO_URI)
        db = client['aina']
        collection = db['new_ccma_meteo']
        record = collection.find_one({"location": location.strip(), "day": day.lower().strip()})
        if record:
            return (record, intervalo[0] if intervalo else 'tot el dia')
        print("No hem trobat el lloc o la data. Torna a provar")
        return None
    except Exception:
        # Was a bare `except:` — narrowed so SystemExit/KeyboardInterrupt
        # are no longer swallowed.
        print("Error al buscar en la base de dades.")
        return None
    finally:
        # Close the connection instead of leaking one client per query.
        if client is not None:
            client.close()
# Example record context:
# "Day: dilluns | Location: Sant Salvador de Guardiola | mati: la nuvolositat anirà en augment | tarda: els núvols alts taparan el cel | nit: cel clar | temp: Lleugera pujada de les temperatures"
def pipeIt(jresponse):
    """Return the pre-built 'context' string of a retrieveFor() result.

    jresponse: tuple/list whose first element is the Mongo weather record;
    the record already carries a formatted 'context' field (example above),
    so no string assembly is needed here.  (Removed: unused `regex` local
    and dead commented-out context-building code.)
    """
    return jresponse[0]["context"]
def givePrediction(question, context, temperature, repetition):
    """Ask the LLM to phrase a Catalan forecast for `question` given `context`.

    question: user question, e.g. "Quin temps farà a la tarda a Algete dijous?"
    context: pre-formatted weather context string from the database record.
    temperature / repetition: passed through to the generation pipeline.

    Returns only the text the model generated after the '### Answer' marker.
    """
    prompt = "You are an expert meteorologist. Using the information in the 'context', write in Catalan a weather forecast that satisfies to the user's request about the weather at an specific location or time. Do not mention that you are using a 'context', and only provide the information that is relevant to answer the question. Do not give any expanded response, explanation, context or any other information and be as accurate as possible. "
    # Plain .format() template — the original's f-string contained only
    # escaped braces, so the f-prefix was a no-op.
    query = "### Instruction\n{instruction}\n\n### Context\n{context}\n\n### Answer\n"
    response = llm_pipeline(prompt + query.format(instruction=question, context=context),
                            temperature=temperature,
                            repetition_penalty=repetition,
                            max_new_tokens=40
                            )[0]["generated_text"]
    # Keep the text after the last "###"; [8:] drops the " Answer\n" that
    # follows it in the template (exactly 8 characters).
    answer = response.split("###")[-1][8:]
    return answer
def assistant(question):
    """Route a user utterance: answer weather queries, deflect everything else.

    question: raw user text.

    Returns the (record, interval) tuple from retrieveFor() for a resolvable
    weather question; otherwise prints a canned Catalan reply (greeting or
    "weather questions only") and returns None.
    """
    is_intent = intent(question)[0]
    if is_intent['label'] == 'weather_query':
        result = pipe(question)
        jresponse = retrieveFor(result)
        if jresponse:
            return jresponse
        # retrieveFor() already printed the failure message; fall through.
    elif is_intent['label'] in ["general_greet", "general_quirky"]:
        print("Hola, quina és la teva consulta meteorològica?")
    else:
        print(is_intent['label'])
        print("Ho sento. Jo només puc respondre a preguntes sobre el temps a alguna localitat en concret ...")
    return None
def generate(question, temperature, repetition):
    """Run the full pipeline for one question, print a report, return results.

    Returns a dict with the symbol codes + interval, the database context,
    the CCMA reference text and the LLM answer — or None when the question
    could not be resolved to a weather record.
    """
    jresponse = assistant(question)
    if not jresponse:
        print("No response")
        return None
    record, interval = jresponse
    codes_interval = {"codis": record['codis'], "interval": interval}
    context = record['context']
    ccma_response = record['response']
    answer = givePrediction(question, context, temperature, repetition)
    print("Codes: ", codes_interval)
    print()
    print("Context: ", context)
    print()
    print("CCMA generated: ", ccma_response)
    print("=" * 16)
    print("LLM answer: ", answer)
    print()
    return {"codes": codes_interval, "context": context, "ccma_response": ccma_response, "model_answer": answer}
def main():
    """CLI entry point: parse -q/-t/-r options and run generate().

    Returns the result dict from generate() (None on failure) — the
    original dropped it into an unused local.
    """
    parser = OptionParser()
    parser.add_option("-q", "--question", dest="question", type="string",
                      help="question to test", default="Quin temps farà a la tarda a Algete dijous?")
    parser.add_option("-t", "--temperature", dest="temperature", type="float",
                      help="temperature generation", default=0.94)
    parser.add_option("-r", "--repetition", dest="repetition", type="float",
                      help="repetition penalty", default=0.5)
    (options, args) = parser.parse_args(sys.argv)
    print(options)
    return generate(options.question, options.temperature, options.repetition)

if __name__ == "__main__":
    main()