DHEIVER committed
Commit 55cb274 · verified · 1 Parent(s): 55bc620

Update app.py

Files changed (1):
  1. app.py (+44 -24)
app.py CHANGED
@@ -26,7 +26,7 @@ list_llm = [
     "mosaicml/mpt-7b-instruct"
 ]
 
-list_llm_simple = [os.path.basename(llm) for llm in list_llm]
+list_llm_simple = [name.split("/")[-1] for name in list_llm]
 
 # Função para carregar documentos PDF
 def load_doc(list_file_path, chunk_size, chunk_overlap):
@@ -56,12 +56,21 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, pr
     tokenizer = AutoTokenizer.from_pretrained(llm_model)
 
     progress(0.4, desc="Inicializando pipeline...")
+
+    # Define a tarefa correta para cada modelo
+    task = "text2text-generation" if "flan-t5" in llm_model.lower() else "text-generation"
+
+    # Configuração específica para dispositivos
+    device = 0 if torch.cuda.is_available() else -1
+    if "phi-2" in llm_model.lower() and device == 0:
+        device = "cuda"
+
     pipeline_obj = pipeline(
-        "text-generation",
+        task,
         model=llm_model,
         tokenizer=tokenizer,
         torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
-        device=0 if torch.cuda.is_available() else -1,
+        device=device,
         max_new_tokens=max_tokens,
         do_sample=True,
         top_k=top_k,
@@ -87,8 +96,8 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, pr
 # Interface Gradio
 def demo():
     with gr.Blocks(theme=gr.themes.Soft()) as demo:
-        vector_db = gr.State(None)  # Inicializa com None
-        qa_chain = gr.State(None)  # Inicializa com None
+        vector_db = gr.State(None)
+        qa_chain = gr.State(None)
 
         gr.Markdown("## 🤖 Chatbot para PDFs com Modelos Gratuitos")
 
@@ -102,7 +111,7 @@ def demo():
             process_status = gr.Textbox(label="Status do Processamento", interactive=False)
 
         with gr.Tab("🧠 Modelo"):
-            model_selector = gr.Dropdown(list_llm_simple, label="Selecione o Modelo", value=list_llm_simple[0])
+            model_selector = gr.Dropdown(list_llm_simple, label="Selecione o Modelo", value=list_llm_simple[1])
             temperature = gr.Slider(0, 1, value=0.7, label="Criatividade")
             load_model_btn = gr.Button("Carregar Modelo")
             model_status = gr.Textbox(label="Status do Modelo", interactive=False)
@@ -114,10 +123,13 @@ def demo():
 
         # Eventos
        def process_documents(files, cs, co):
-            file_paths = [f.name for f in files]
-            splits = load_doc(file_paths, cs, co)
-            db = create_db(splits, "docs")
-            return db, "Documentos processados!"
+            try:
+                file_paths = [f.name for f in files]
+                splits = load_doc(file_paths, cs, co)
+                db = create_db(splits, "docs")
+                return db, "Documentos processados!"
+            except Exception as e:
+                return None, f"Erro: {str(e)}"
 
         process_btn.click(
             process_documents,
@@ -126,10 +138,15 @@ def demo():
         )
 
         def load_model(model, temp, vector_db_state):
-            if vector_db_state is None:
-                return None, "Por favor, processe os documentos primeiro."
-            qa = initialize_llmchain(list_llm[list_llm_simple.index(model)], temp, 512, 3, vector_db_state)
-            return qa, "Modelo carregado!"
+            try:
+                if vector_db_state is None:
+                    raise ValueError("Processe os documentos primeiro.")
+
+                model_name = list_llm[list_llm_simple.index(model)]
+                qa = initialize_llmchain(model_name, temp, 512, 3, vector_db_state)
+                return qa, "Modelo carregado!"
+            except Exception as e:
+                return None, f"Erro: {str(e)}"
 
         load_model_btn.click(
             load_model,
@@ -138,17 +155,20 @@ def demo():
         )
 
         def respond(message, chat_history):
-            if qa_chain.value is None:
-                return "Por favor, carregue um modelo primeiro.", chat_history
-
-            result = qa_chain.value({"question": message, "chat_history": chat_history})
-            response = result["answer"]
-
-            sources = "\n".join([f"📄 Página {doc.metadata['page']+1}: {doc.page_content[:50]}..."
-                                 for doc in result.get("source_documents", [])[:2]])
+            if not qa_chain.value:
+                return "Erro: Modelo não carregado ou documentos não processados!", chat_history
 
-            chat_history.append((message, f"{response}\n\n🔍 Fontes:\n{sources}"))
-            return "", chat_history
+            try:
+                result = qa_chain.value({"question": message, "chat_history": chat_history})
+                response = result["answer"]
+
+                sources = "\n".join([f"📄 Página {doc.metadata['page']+1}: {doc.page_content[:50]}..."
+                                     for doc in result.get("source_documents", [])[:2]])
+
+                chat_history.append((message, f"{response}\n\n🔍 Fontes:\n{sources}"))
+                return "", chat_history
+            except Exception as e:
+                return f"Erro na geração: {str(e)}", chat_history
 
         msg.submit(respond, [msg, chatbot], [msg, chatbot])
         clear_btn.click(lambda: [], outputs=[chatbot])
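
For reference, a minimal sketch (not part of the commit) of how the task and device selection added to initialize_llmchain resolves for different model IDs. The model names "google/flan-t5-large" and "microsoft/phi-2" below are illustrative assumptions, not necessarily the exact entries of list_llm.

# Sketch only: mirrors the selection logic introduced in this commit.
def resolve_pipeline_config(llm_model: str, cuda_available: bool):
    # Seq2seq checkpoints such as flan-t5 need the "text2text-generation" task;
    # decoder-only checkpoints use "text-generation".
    task = "text2text-generation" if "flan-t5" in llm_model.lower() else "text-generation"
    # transformers pipelines accept a GPU index, -1 for CPU, or a device string.
    device = 0 if cuda_available else -1
    if "phi-2" in llm_model.lower() and device == 0:
        device = "cuda"
    return task, device

print(resolve_pipeline_config("google/flan-t5-large", cuda_available=False))    # ('text2text-generation', -1)
print(resolve_pipeline_config("microsoft/phi-2", cuda_available=True))          # ('text-generation', 'cuda')
print(resolve_pipeline_config("mosaicml/mpt-7b-instruct", cuda_available=True)) # ('text-generation', 0)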