mateoluksenberg committed
Commit 1f1f572 · verified · 1 Parent(s): c8f3971

Update app.py

Files changed (1):
  app.py +50 -183
app.py CHANGED
@@ -1,60 +1,32 @@
 import torch
 from PIL import Image
-import gradio as gr
-import spaces
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 import os
-from threading import Thread
-
 import pymupdf
 import docx
 from pptx import Presentation
-
 from fastapi import FastAPI, File, UploadFile, HTTPException
-from fastapi.responses import HTMLResponse
+from typing import List, Dict
 
 app = FastAPI()
 
-@app.post("/test/")
-async def test_endpoint(message: dict):
-    if "text" not in message:
-        raise HTTPException(status_code=400, detail="Missing 'text' in request body")
-
-    response = {"message": f"Received your message: {message['text']}"}
-    return response
-
-
+# Model and tokenizer initialization
 MODEL_LIST = ["nikravan/glm-4vq"]
-
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 MODEL_ID = MODEL_LIST[0]
 MODEL_NAME = "GLM-4vq"
 
-TITLE = "<h1>AI CHAT DOCS</h1>"
-
-DESCRIPTION = f"""
-<center>
-<p>
-<br>
-USANDO MODELO: <a href="https://hf.co/nikravan/glm-4vq">{MODEL_NAME}</a>
-</center>"""
-
-CSS = """
-h1 {
-  text-align: center;
-  display: block;
-}
-"""
-
-
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
-
-
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_ID,
+    torch_dtype=torch.bfloat16,
+    low_cpu_mem_usage=True,
+    trust_remote_code=True
+)
 
 def extract_text(path):
     return open(path, 'r').read()
 
-
 def extract_pdf(path):
     doc = pymupdf.open(path)
     text = ""
@@ -62,15 +34,10 @@ def extract_pdf(path):
         text += page.get_text()
     return text
 
-
 def extract_docx(path):
     doc = docx.Document(path)
-    data = []
-    for paragraph in doc.paragraphs:
-        data.append(paragraph.text)
-    content = '\n\n'.join(data)
-    return content
-
+    data = [paragraph.text for paragraph in doc.paragraphs]
+    return '\n\n'.join(data)
 
 def extract_pptx(path):
     prs = Presentation(path)
@@ -81,49 +48,44 @@ def extract_pptx(path):
                 text += shape.text + "\n"
     return text
 
-
 def mode_load(path):
-    choice = ""
-    file_type = path.split(".")[-1]
-    print(file_type)
-    if file_type in ["pdf", "txt", "py", "docx", "pptx", "json", "cpp", "md"]:
-        if file_type.endswith("pdf"):
+    file_type = path.split(".")[-1].lower()
+    if file_type in ["pdf", "txt", "py", "docx", "pptx"]:
+        if file_type == "pdf":
             content = extract_pdf(path)
-        elif file_type.endswith("docx"):
+        elif file_type == "docx":
            content = extract_docx(path)
-        elif file_type.endswith("pptx"):
+        elif file_type == "pptx":
             content = extract_pptx(path)
         else:
             content = extract_text(path)
-        choice = "doc"
-        print(content[:100])
-        return choice, content[:5000]
-
-
+        return "doc", content[:5000]
     elif file_type in ["png", "jpg", "jpeg", "bmp", "tiff", "webp"]:
         content = Image.open(path).convert('RGB')
-        choice = "image"
-        return choice, content
-
+        return "image", content
     else:
-        raise gr.Error("Oops, unsupported files.")
-
+        raise HTTPException(status_code=400, detail="Unsupported file type")
 
-@spaces.GPU()
-def stream_chat(message, history: list, temperature: float, max_length: int, top_p: float, top_k: int, penalty: float):
-
-    model = AutoModelForCausalLM.from_pretrained(
-        MODEL_ID,
-        torch_dtype=torch.bfloat16,
-        low_cpu_mem_usage=True,
-        trust_remote_code=True
-    )
-
-    print(f'message is - {message}')
-    print(f'history is - {history}')
+@app.post("/test/")
+async def test_endpoint(message: Dict[str, str]):
+    if "text" not in message:
+        raise HTTPException(status_code=400, detail="Missing 'text' in request body")
+
+    response = {"message": f"Received your message: {message['text']}"}
+    return response
+
+@app.post("/chat/")
+async def chat_endpoint(
+    message: Dict[str, str],
+    history: List[Dict[str, str]] = [],
+    temperature: float = 0.8,
+    max_length: int = 4096,
+    top_p: float = 1.0,
+    top_k: int = 10,
+    penalty: float = 1.0
+):
     conversation = []
-    prompt_files = []
-    if message["files"]:
+    if "files" in message and message["files"]:
         choice, contents = mode_load(message["files"][-1])
         if choice == "image":
             conversation.append({"role": "user", "image": contents, "content": message['text']})
@@ -132,35 +94,26 @@ def stream_chat(message, history: list, temperature: float, max_length: int, top
             conversation.append({"role": "user", "content": format_msg})
     else:
         if len(history) == 0:
-            # raise gr.Error("Please upload an image first.")
-            contents = None
             conversation.append({"role": "user", "content": message['text']})
         else:
-            # image = Image.open(history[0][0][0])
             for prompt, answer in history:
                 if answer is None:
-                    prompt_files.append(prompt[0])
                     conversation.extend([{"role": "user", "content": ""}, {"role": "assistant", "content": ""}])
                 else:
                     conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
-            if len(prompt_files) > 0:
-                choice, contents = mode_load(prompt_files[-1])
-            else:
-                choice = ""
-                conversation.append({"role": "user", "image": "", "content": message['text']})
-
-
-            if choice == "image":
-                conversation.append({"role": "user", "image": contents, "content": message['text']})
-            elif choice == "doc":
-                format_msg = contents + "\n\n\n" + "{} files uploaded.\n" + message['text']
-                conversation.append({"role": "user", "content": format_msg})
-    print(f"Conversation is -\n{conversation}")
+            if len(history) > 0:
+                choice, contents = mode_load(history[-1][0])
+                if choice == "image":
+                    conversation.append({"role": "user", "image": contents, "content": message['text']})
+                elif choice == "doc":
+                    format_msg = contents + "\n\n\n" + "{} files uploaded.\n" + message['text']
+                    conversation.append({"role": "user", "content": format_msg})
+            else:
+                conversation.append({"role": "user", "content": message['text']})
 
-    input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True,
-                                              return_tensors="pt", return_dict=True).to(model.device)
+    input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
     streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
-
+
     generate_kwargs = dict(
         max_length=max_length,
         streamer=streamer,
@@ -168,97 +121,11 @@ def stream_chat(message, history: list, temperature: float, max_length: int, top
         top_p=top_p,
         top_k=top_k,
         temperature=temperature,
-        repetition_penalty=penalty,
-        eos_token_id=[151329, 151336, 151338],
+        repetition_penalty=penalty
     )
-    gen_kwargs = {**input_ids, **generate_kwargs}
-
+
     with torch.no_grad():
-        thread = Thread(target=model.generate, kwargs=gen_kwargs)
-        thread.start()
         buffer = ""
         for new_text in streamer:
             buffer += new_text
-            yield buffer
-
-
-chatbot = gr.Chatbot(
-    #rtl=True,
-)
-chat_input = gr.MultimodalTextbox(
-    interactive=True,
-    placeholder="Enter message or upload a file ...",
-    show_label=False,
-    #rtl=True,
-
-
-
-)
-
-EXAMPLES = [
-    [{"text": "Resumir Documento"}],
-    [{"text": "Explicar la Imagen"}],
-    [{"text": "¿De qué es la foto?", "files": ["perro.jpg"]}],
-    [{"text": "Quiero armar un JSON, solo el JSON sin texto, que contenga los datos de la primera mitad de la tabla de la imagen (las primeras 10 jurisdicciones 901-910). Ten en cuenta que los valores numéricos son decimales de cuatro dígitos. La tabla contiene las siguientes columnas: Codigo, Nombre, Fecha Inicio, Fecha Cese, Coeficiente Ingresos, Coeficiente Gastos y Coeficiente Unificado. La tabla puede contener valores vacíos, en ese caso dejarlos como null. Cada fila de la tabla representa una jurisdicción con sus respectivos valores.", }]
-]
-
-with gr.Blocks(css=CSS, theme="soft", fill_height=True) as demo:
-    gr.HTML(TITLE)
-    gr.HTML(DESCRIPTION)
-    gr.ChatInterface(
-        fn=stream_chat,
-        multimodal=True,
-
-
-        textbox=chat_input,
-        chatbot=chatbot,
-        fill_height=True,
-        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
-        additional_inputs=[
-            gr.Slider(
-                minimum=0,
-                maximum=1,
-                step=0.1,
-                value=0.8,
-                label="Temperature",
-                render=False,
-            ),
-            gr.Slider(
-                minimum=1024,
-                maximum=8192,
-                step=1,
-                value=4096,
-                label="Max Length",
-                render=False,
-            ),
-            gr.Slider(
-                minimum=0.0,
-                maximum=1.0,
-                step=0.1,
-                value=1.0,
-                label="top_p",
-                render=False,
-            ),
-            gr.Slider(
-                minimum=1,
-                maximum=20,
-                step=1,
-                value=10,
-                label="top_k",
-                render=False,
-            ),
-            gr.Slider(
-                minimum=0.0,
-                maximum=2.0,
-                step=0.1,
-                value=1.0,
-                label="Repetition penalty",
-                render=False,
-            ),
-        ],
-    ),
-    gr.Examples(EXAMPLES, [chat_input])
-
-if __name__ == "__main__":
-
-    demo.queue(api_open=False).launch(show_api=False, share=False, )#server_name="0.0.0.0", )
+        return {"response": buffer}
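For reference, a minimal sketch of how the two new endpoints could be called once this app.py is served. It assumes the app is run locally with uvicorn (for example uvicorn app:app --port 8000); the host, port, payloads, and the exact body/query mapping FastAPI derives from the endpoint signatures are illustrative assumptions, not part of this commit.

import requests

BASE_URL = "http://localhost:8000"  # assumed local uvicorn host/port, not part of the commit

# Echo check against the new /test/ endpoint; its single dict parameter is read from the JSON body.
r = requests.post(f"{BASE_URL}/test/", json={"text": "hello"})
print(r.json())  # expected: {"message": "Received your message: hello"}

# Text-only request against the new /chat/ endpoint. Here `message` and `history`
# are assumed to be embedded JSON body fields, with the scalar generation
# parameters (temperature, max_length, top_p, top_k, penalty) passed as query
# parameters, mirroring the defaults declared in chat_endpoint.
r = requests.post(
    f"{BASE_URL}/chat/",
    params={"temperature": 0.8, "max_length": 4096},
    json={"message": {"text": "Resumir Documento"}, "history": []},
)
print(r.json())  # expected shape: {"response": "..."}

Unlike the removed Gradio stream_chat, which yielded partial output as it was generated, chat_endpoint accumulates the streamed text into buffer and returns it as a single JSON response.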