# app.py

import gradio as gr
from models import load_embedding_model, load_yi_coder_model
from pinecone_utils import connect_to_pinecone, vector_search  # this should now work correctly
from ui import build_interface
from config import SIMILARITY_THRESHOLD_DEFAULT, SYSTEM_PROMPT, MAX_LENGTH_DEFAULT
from decorators import gpu_decorator
import torch

from utils import process_tags_chat


# Load models
embedding_model = load_embedding_model()
tokenizer, yi_coder_model, yi_coder_device = load_yi_coder_model()

# Connect to Pinecone
index = connect_to_pinecone()

# Generate code with Yi-Coder
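# Note: @gpu_decorator presumably reserves GPU time for the wrapped call
# (duration likely in seconds); see decorators.py for the actual semantics.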
@gpu_decorator(duration=100)
def generate_code(system_prompt, user_prompt, max_length):
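    """Generate a completion with Yi-Coder.

    Builds a two-message chat (system + user), applies the tokenizer's chat
    template, generates up to `max_length` new tokens, and returns only the
    decoded continuation (the prompt tokens are stripped).
    """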
    device = yi_coder_device
    model = yi_coder_model
    tokenizer_ = tokenizer  # already loaded at module level

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt}
    ]
    
    # Apply the chat template and prepare the prompt text
    text = tokenizer_.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer_([text], return_tensors="pt").to(device)

    with torch.no_grad():
        generated_ids = model.generate(
            model_inputs.input_ids,
            max_new_tokens=max_length,
            eos_token_id=tokenizer_.eos_token_id  
        )

    # Keep only the newly generated tokens
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    response = tokenizer_.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response


# Combine vector search and Yi-Coder generation
@gpu_decorator(duration=100)
def combined_function(user_prompt, similarity_threshold, selected_option, system_prompt, max_length):
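    """Route the prompt to vector search, Yi-Coder, or both.

    - "Solo Búsqueda Vectorial": return the top Pinecone match.
    - "Solo Yi-Coder": generate a response with the model.
    - "Ambos (basado en umbral de similitud)": use the top match if its
      score meets `similarity_threshold`, otherwise fall back to Yi-Coder.

    Returns (response_text, image); `image` is always None here.
    """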
    if selected_option == "Solo Búsqueda Vectorial":
        # Run vector search
        search_results = vector_search(user_prompt, embedding_model, index)
        if search_results:
            # Use the top result
            content = search_results[0]['content']
            return content, None
        else:
            return "No se encontraron resultados en Pinecone.", None
    elif selected_option == "Solo Yi-Coder":
        # Generate a response with Yi-Coder
        yi_coder_response = generate_code(system_prompt, user_prompt, max_length)
        return yi_coder_response, None
    elif selected_option == "Ambos (basado en umbral de similitud)":
        # Run vector search
        search_results = vector_search(user_prompt, embedding_model, index)
        if search_results:
            top_result = search_results[0]
            if top_result['score'] >= similarity_threshold:
                content = top_result['content']
                return content, None
            else:
                yi_coder_response = generate_code(system_prompt, user_prompt, max_length)
                return yi_coder_response, None
        else:
            yi_coder_response = generate_code(system_prompt, user_prompt, max_length)
            return yi_coder_response, None
    else:
        return "Opci贸n no v谩lida.", None

# Input processing and image-update helpers
def process_input(message, history, selected_option, similarity_threshold, system_prompt, max_length):
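    """Handle one chat turn: compute the response, append the
    (message, response) pair to `history`, and return it twice
    (chatbot display + state) along with the image slot."""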
    response, image = combined_function(message, similarity_threshold, selected_option, system_prompt, max_length)
    history.append((message, response))
    return history, history, image

def update_image(message, history):
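    """Search Pinecone for `message` and return the image URL extracted
    by process_tags_chat, or None. `history` is unused but kept so the
    signature matches the Gradio callback."""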
    # Run vector search
    search_results = vector_search(message, embedding_model, index)

    # Process the results with process_tags_chat to extract an image URL
    full_response, image_url = process_tags_chat(search_results)

    return image_url if image_url else None


def send_preset_question(question, history, selected_option, similarity_threshold, system_prompt, max_length):
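    """Forward a preset question through the regular chat pipeline."""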
    return process_input(question, history, selected_option, similarity_threshold, system_prompt, max_length)

# Build and launch the interface
demo = build_interface(process_input, send_preset_question, update_image)

if __name__ == "__main__":
    demo.launch()