rodrigomasini committed on
Commit ee40bdf · verified · 1 Parent(s): d5a3d59

Update helper.py

Files changed (1):
  1. helper.py +46 -152
helper.py CHANGED
@@ -1,200 +1,96 @@
  import os
- import subprocess
- from huggingface_hub import hf_hub_download, list_repo_files
  import gradio as gr
  from typing import Callable
  import base64
- import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
- from threading import Thread
- from transformers import TextIteratorStreamer
+ from openai import OpenAI


- def get_fn(model_path: str, **model_kwargs):
-     """Create a chat function with the specified model."""
-
-     # Initialize tokenizer and model
-     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-     quantization_config = BitsAndBytesConfig(load_in_8bit=True)
-
-     tokenizer = AutoTokenizer.from_pretrained(model_path)
-
-     # Simple flash-attention installation attempt
-     try:
-         subprocess.run(
-             'pip install flash-attn --no-build-isolation',
-             env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
-             shell=True,
-             check=True
-         )
-         # Try loading model with flash attention
-         model = AutoModelForCausalLM.from_pretrained(
-             model_path,
-             device_map="auto",
-             quantization_config=quantization_config,
-             attn_implementation="flash_attention_2",
-         )
-     except Exception as e:
-         print(f"Flash Attention failed, falling back to default attention: {str(e)}")
-         # Fallback to default attention implementation
-         model = AutoModelForCausalLM.from_pretrained(
-             model_path,
-             torch_dtype=torch.bfloat16,
-             device_map="auto",
-             quantization_config=quantization_config,
-         )
+ def get_fn(model_name: str, **model_kwargs):
+     """Create a chat function that uses the OpenAI-compatible endpoint."""
+
+     # Client for the hardcoded OpenAI-compatible endpoint
+     OPENAI_API_KEY = "-"
+     client = OpenAI(
+         base_url="http://192.222.58.60:8000/v1",
+         api_key="tela",
+     )

      def predict(
          message: str,
          history,
          system_prompt: str,
          temperature: float,
-         max_new_tokens: int,
-         top_k: int,
-         repetition_penalty: float,
-         top_p: float
+         top_p: float,
+         max_tokens: int,
      ):
          try:
-             # Format conversation with ChatML format
-             instruction = '<|im_start|>system\n' + system_prompt + '\n<|im_end|>\n'
+             # Build the OpenAI-style message list from the system prompt and chat history
+             messages = []
+             if system_prompt:
+                 messages.append({"role": "system", "content": system_prompt})
              for user_msg, assistant_msg in history:
-                 instruction += f'<|im_start|>user\n{user_msg}\n<|im_end|>\n<|im_start|>assistant\n{assistant_msg}\n<|im_end|>\n'
-             instruction += f'<|im_start|>user\n{message}\n<|im_end|>\n<|im_start|>assistant\n'
-
-             streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-             enc = tokenizer(instruction, return_tensors="pt", padding=True, truncation=True)
-             input_ids, attention_mask = enc.input_ids, enc.attention_mask
-
-             # Truncate if needed
-             if input_ids.shape[1] > 8192:  # Using n_ctx from original
-                 input_ids = input_ids[:, -8192:]
-                 attention_mask = attention_mask[:, -8192:]
-
-             generate_kwargs = dict(
-                 input_ids=input_ids.to(device),
-                 attention_mask=attention_mask.to(device),
-                 streamer=streamer,
-                 do_sample=True,
+                 messages.append({"role": "user", "content": user_msg})
+                 messages.append({"role": "assistant", "content": assistant_msg})
+             messages.append({"role": "user", "content": message})
+
+             response = client.chat.completions.create(
+                 model=model_name,
+                 messages=messages,
                  temperature=temperature,
-                 max_new_tokens=max_new_tokens,
-                 top_k=top_k,
-                 repetition_penalty=repetition_penalty,
-                 top_p=top_p
+                 top_p=top_p,
+                 max_tokens=max_tokens,
+                 n=1,
+                 stream=True,
+                 response_format={"type": "text"},
              )
-
-             t = Thread(target=model.generate, kwargs=generate_kwargs)
-             t.start()
-
-             response_text = ""
-             for new_token in streamer:
-                 if new_token in ["<|endoftext|>", "<|im_end|>"]:
-                     break
-                 response_text += new_token
-                 yield response_text.strip()
-
-             if not response_text.strip():
-                 yield "I apologize, but I was unable to generate a response. Please try again."
-
+             # With stream=True the response is an iterator of chunks; accumulate
+             # the deltas and yield the partial text so the UI updates as it arrives
+             assistant_message = ""
+             for chunk in response:
+                 delta = chunk.choices[0].delta.content or ""
+                 assistant_message += delta
+                 yield assistant_message.strip()
          except Exception as e:
              print(f"Error during generation: {str(e)}")
              yield f"An error occurred: {str(e)}"

      return predict

-
  def get_image_base64(url: str, ext: str):
      with open(url, "rb") as image_file:
          encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
      return "data:image/" + ext + ";base64," + encoded_string

-
  def handle_user_msg(message: str):
-     if type(message) is str:
+     if isinstance(message, str):
          return message
-     elif type(message) is dict:
-         if message["files"] is not None and len(message["files"]) > 0:
+     elif isinstance(message, dict):
+         if message.get("files"):
              ext = os.path.splitext(message["files"][-1])[1].strip(".")
              if ext.lower() in ["png", "jpg", "jpeg", "gif", "pdf"]:
                  encoded_str = get_image_base64(message["files"][-1], ext)
              else:
                  raise NotImplementedError(f"Not supported file type {ext}")
              content = [
-                 {"type": "text", "text": message["text"]},
-                 {
-                     "type": "image_url",
-                     "image_url": {
-                         "url": encoded_str,
-                     }
-                 },
-             ]
+                 {"type": "text", "text": message.get("text", "")},
+                 {
+                     "type": "image_url",
+                     "image_url": {
+                         "url": encoded_str,
+                     }
+                 },
+             ]
          else:
-             content = message["text"]
+             content = message.get("text", "")
          return content
      else:
          raise NotImplementedError

-
- def get_interface_args(pipeline):
-     if pipeline == "chat":
-         inputs = None
-         outputs = None
-
-         def preprocess(message, history):
-             messages = []
-             files = None
-             for user_msg, assistant_msg in history:
-                 if assistant_msg is not None:
-                     messages.append({"role": "user", "content": handle_user_msg(user_msg)})
-                     messages.append({"role": "assistant", "content": assistant_msg})
-                 else:
-                     files = user_msg
-             if type(message) is str and files is not None:
-                 message = {"text": message, "files": files}
-             elif type(message) is dict and files is not None:
-                 if message["files"] is None or len(message["files"]) == 0:
-                     message["files"] = files
-             messages.append({"role": "user", "content": handle_user_msg(message)})
-             return {"messages": messages}
-
-         postprocess = lambda x: x
-     else:
-         # Add other pipeline types when they will be needed
-         raise ValueError(f"Unsupported pipeline type: {pipeline}")
-     return inputs, outputs, preprocess, postprocess
-
-
- def get_pipeline(model_name):
-     # Determine the pipeline type based on the model name
-     # For simplicity, assuming all models are chat models at the moment
-     return "chat"
-
-
  def get_model_path(name: str = None, model_path: str = None) -> str:
-     """Get the local path to the model."""
+     """Get the model name to use with the endpoint."""
      if model_path:
          return model_path
-
      if name:
-         if "/" in name:
-             return name  # Return HF model ID directly
-         else:
-             # You could maintain a mapping of friendly names to HF model IDs
-             model_mapping = {
-                 # Add any default model mappings here
-                 "example-model": "organization/model-name"
-             }
-             if name not in model_mapping:
-                 raise ValueError(f"Unknown model name: {name}")
-             return model_mapping[name]
-
+         return name
      raise ValueError("Either name or model_path must be provided")

-
  def registry(name: str = None, model_path: str = None, **kwargs):
-     """Create a Gradio Interface with similar styling and parameters."""
-
-     model_path = get_model_path(name, model_path)
-     fn = get_fn(model_path, **kwargs)
+     """Create a Gradio ChatInterface."""
+     model_name = get_model_path(name, model_path)
+     fn = get_fn(model_name, **kwargs)

      interface = gr.ChatInterface(
          fn=fn,
@@ -206,10 +102,8 @@ def registry(name: str = None, model_path: str = None, **kwargs):
              ),
              gr.Slider(0, 1, 0.7, label="Temperature"),
              gr.Slider(128, 4096, 1024, label="Max new tokens"),
-             gr.Slider(1, 80, 40, label="Top K sampling"),
-             gr.Slider(0, 2, 1.1, label="Repetition penalty"),
              gr.Slider(0, 1, 0.95, label="Top P sampling"),
          ],
      )
-
-     return interface
+
+     return interface
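
For context, a minimal sketch of how the updated helper could be driven from an app entry point. The module name helper, the file name app.py, and the model identifier "my-org/my-model" are assumptions for illustration, not part of this commit; with the new get_model_path(), any name passed in is forwarded to the endpoint unchanged.

# app.py: hypothetical entry point placed next to helper.py; assumes the
# OpenAI-compatible server configured inside get_fn() is reachable.
import helper

# Placeholder model identifier: get_model_path() passes the name straight
# through, and get_fn() sends it as the `model` field of each request.
demo = helper.registry(name="my-org/my-model")

if __name__ == "__main__":
    # gr.ChatInterface accepts generator callbacks, so replies stream incrementally
    demo.launch()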