import gradio as gr # !python -c "import torch; assert torch.cuda.get_device_capability()[0] >= 8, 'Hardware not supported for Flash Attention'" import json import torch from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GemmaTokenizer, StoppingCriteria, StoppingCriteriaList, GenerationConfig # from google.colab import userdata import os model_id = "somosnlp/GemmaColRAC-AeroExpert" bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16 ) max_seq_length=400 # if torch.cuda.get_device_capability()[0] >= 8: # # print("Flash Attention") # attn_implementation="flash_attention_2" # else: # attn_implementation=None attn_implementation=None tokenizer = AutoTokenizer.from_pretrained(model_id, max_length = max_seq_length) model = AutoModelForCausalLM.from_pretrained(model_id, # quantization_config=bnb_config, device_map = {"":0}, attn_implementation = attn_implementation, # A100 o H100 ).eval() class ListOfTokensStoppingCriteria(StoppingCriteria): """ Clase para definir un criterio de parada basado en una lista de tokens específicos. """ def __init__(self, tokenizer, stop_tokens): self.tokenizer = tokenizer # Codifica cada token de parada y guarda sus IDs en una lista self.stop_token_ids_list = [tokenizer.encode(stop_token, add_special_tokens=False) for stop_token in stop_tokens] def __call__(self, input_ids, scores, **kwargs): # Verifica si los últimos tokens generados coinciden con alguno de los conjuntos de tokens de parada for stop_token_ids in self.stop_token_ids_list: len_stop_tokens = len(stop_token_ids) if len(input_ids[0]) >= len_stop_tokens: if input_ids[0, -len_stop_tokens:].tolist() == stop_token_ids: return True return False # Uso del criterio de parada personalizado stop_tokens = [""] # Lista de tokens de parada # Inicializa tu criterio de parada con el tokenizer y la lista de tokens de parada stopping_criteria = ListOfTokensStoppingCriteria(tokenizer, stop_tokens) # Añade tu criterio de parada a una StoppingCriteriaList stopping_criteria_list = StoppingCriteriaList([stopping_criteria]) def generate_text(prompt, max_length=2100): # prompt="""What were the main contributions of Eratosthenes to the development of mathematics in ancient Greece?""" prompt=prompt.replace("\n", "").replace("¿","").replace("?","") #EXAMPLE input_text = f'''system\nYou are a helpful AI assistant.\nResponde en formato json.\nEres un agente experto en la normativa aeronautica Colombiana.\nuser\n¿{prompt}?\nmodel\n''' inputs = tokenizer.encode(input_text, return_tensors="pt", add_special_tokens=False).to("cuda:0") max_new_tokens=max_length generation_config = GenerationConfig( max_new_tokens=max_new_tokens, temperature=0.15, #top_p=0.9, top_k=40, # 45 repetition_penalty=1., #1.1 do_sample=True, ) outputs = model.generate(generation_config=generation_config, input_ids=inputs, stopping_criteria=stopping_criteria_list,) return tokenizer.decode(outputs[0], skip_special_tokens=False) #True def mostrar_respuesta(pregunta): json_obj={} json_obj['respuesta']='Esperando' json_obj['pagina']='Esperando' json_obj['rac']='Esperando' if pregunta!="": try: res= generate_text(pregunta, max_length=500) # print(">> RES:",res) inicio_json = res.find('{') fin_json = res.rfind('}') + 1 json_str = res[inicio_json:fin_json] json_obj = json.loads(json_str) # print("json_obj:",json_obj) return json_obj["respuesta"], json_obj["pagina"], json_obj["rac"] except: return json_obj["respuesta"], json_obj["pagina"], json_obj["rac"] return json_obj["respuesta"], json_obj["pagina"], json_obj["rac"] # Ejemplos de preguntas ejemplos = [ ["¿Cuál fue la fecha de publicación del RAC 1 en el Diario Oficial?"], ["¿Qué se incorpora a los Reglamentos Aeronáuticos de Colombia?"], ["Cuál fue la fecha de publicación del RAC 1 en el Diario Oficial?"], ] iface = gr.Interface( fn=mostrar_respuesta, inputs=gr.Textbox(label="Pregunta"), outputs=[ gr.Textbox(label="Respuesta", lines=2), gr.Textbox(label="Pagina", lines=1), gr.Textbox(label="Rac", lines=1) ], title="Consultas Normativa Aeronáutica Colombiana", description="Introduce tu pregunta sobre la normativa aeronáutica colombiana para obtener una respuesta.", examples=ejemplos, ) iface.queue(max_size=14).launch(debug=True) # share=True,debug=True