rodrigomasini committed
Commit 612a10c · verified · 1 Parent(s): 6bf705e

Update helper.py

Files changed (1)
  1. helper.py +18 -49
helper.py CHANGED
@@ -11,16 +11,14 @@ def get_fn(model_name: str, **model_kwargs) -> Callable:
     try:
         client = OpenAI(
             base_url="http://192.222.58.60:8000/v1",
-            api_key="tela",
+            api_key="tela",  # Replace with your actual API key or use environment variables
         )
     except Exception as e:
         print(f"The API or base URL were not defined: {str(e)}")
-        raise e  # It's better to raise the exception to prevent the app from running without a client
+        raise e  # Prevent the app from running without a client
 
     def predict(
-        message: str,
-        history: list,
-        system_prompt: str,
+        messages: list,  # Preprocessed messages from the preprocess function
         temperature: float,
         max_tokens: int,
         top_k: int,
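The new comment suggests loading the key from the environment rather than hard-coding it. A minimal sketch of that approach with the standard openai-python v1 client; the variable name OPENAI_API_KEY is an assumption, not something this commit introduces:

# Sketch of the environment-variable approach; OPENAI_API_KEY is an
# assumed variable name, with the hard-coded value kept as a fallback.
import os
from openai import OpenAI

client = OpenAI(
    base_url="http://192.222.58.60:8000/v1",
    api_key=os.environ.get("OPENAI_API_KEY", "tela"),
)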
@@ -28,20 +26,6 @@ def get_fn(model_name: str, **model_kwargs) -> Callable:
         top_p: float
     ) -> Generator[str, None, None]:
         try:
-            # Initialize the messages list with the system prompt
-            messages = [
-                {"role": "system", "content": system_prompt}
-            ]
-
-            # Append the conversation history
-            for user_msg, assistant_msg in history:
-                messages.append({"role": "user", "content": user_msg})
-                if assistant_msg:
-                    messages.append({"role": "assistant", "content": assistant_msg})
-
-            # Append the latest user message
-            messages.append({"role": "user", "content": message})
-
             # Call the OpenAI API with the formatted messages
             response = client.chat.completions.create(
                 model=model_name,
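The deleted block is the message-list construction that predict no longer does itself; per the new signature, it receives preprocessed messages directly, presumably built by the preprocess function returned by get_interface_args. Reconstructed from the deleted lines as a standalone helper; the name build_messages is hypothetical:

# Hypothetical helper mirroring the logic deleted above; history is a list
# of (user_msg, assistant_msg) pairs, as in the old predict() signature.
def build_messages(message: str, history: list, system_prompt: str) -> list:
    messages = [{"role": "system", "content": system_prompt}]
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:  # the last turn may not have a reply yet
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})
    return messages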
@@ -55,7 +39,7 @@ def get_fn(model_name: str, **model_kwargs) -> Callable:
                 # Ensure response_format is set correctly; typically it's a string like 'text'
                 response_format="text",
             )
-
+
             response_text = ""
             # Iterate over the streaming response
             for chunk in response:
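The loop body falls outside the diff context. For reference, the conventional openai-python streaming pattern that accumulates response_text looks like the sketch below; this is an assumption about the elided lines, not the file's actual code:

# Standard openai-python streaming pattern (assumed; the real loop body
# between these hunks is not shown in the diff).
def stream_text(response):
    response_text = ""
    for chunk in response:
        delta = chunk.choices[0].delta.content
        if delta:  # deltas can be None or empty strings
            response_text += delta
            yield response_text  # Gradio re-renders the growing reply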
@@ -68,46 +52,33 @@ def get_fn(model_name: str, **model_kwargs) -> Callable:
 
             if not response_text.strip():
                 yield "I apologize, but I was unable to generate a response. Please try again."
-
+
         except Exception as e:
             print(f"Error during generation: {str(e)}")
             yield f"An error occurred: {str(e)}"
-
-    return predict
-
 
+    return predict
 
-def get_image_base64(url: str, ext: str):
+def get_image_base64(url: str, ext: str) -> str:
     with open(url, "rb") as image_file:
         encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
-    return "data:image/" + ext + ";base64," + encoded_string
-
+    return f"data:image/{ext};base64,{encoded_string}"
 
-def handle_user_msg(message: str):
-    if type(message) is str:
+def handle_user_msg(message: str) -> str:
+    if isinstance(message, str):
         return message
-    elif type(message) is dict:
-        if message["files"] is not None and len(message["files"]) > 0:
-            ext = os.path.splitext(message["files"][-1])[1].strip(".")
-            if ext.lower() in ["png", "jpg", "jpeg", "gif", "pdf"]:
+    elif isinstance(message, dict):
+        if message.get("files"):
+            ext = os.path.splitext(message["files"][-1])[1].strip(".").lower()
+            if ext in ["png", "jpg", "jpeg", "gif", "pdf"]:
                 encoded_str = get_image_base64(message["files"][-1], ext)
+                return f"{message.get('text', '')}\n![Image]({encoded_str})"
             else:
-                raise NotImplementedError(f"Not supported file type {ext}")
-            content = [
-                {"type": "text", "text": message["text"]},
-                {
-                    "type": "image_url",
-                    "image_url": {
-                        "url": encoded_str,
-                    }
-                },
-            ]
+                raise NotImplementedError(f"Unsupported file type: {ext}")
         else:
-            content = message["text"]
-        return content
+            return message.get("text", "")
     else:
-        raise NotImplementedError
-
+        raise NotImplementedError("Unsupported message type")
 
 def get_interface_args(pipeline: str):
     if pipeline == "chat":
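Note the behavioral change in handle_user_msg: instead of returning an OpenAI-style content list with a separate image_url part, it now returns a single markdown string with the image inlined as a data URI. A hypothetical call with made-up inputs:

# Hypothetical example; the file path and text are made up.
from helper import handle_user_msg

msg = {"text": "What is in this picture?", "files": ["./photo.png"]}
print(handle_user_msg(msg))
# -> What is in this picture?
#    ![Image](data:image/png;base64,...)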
@@ -138,7 +109,6 @@ def get_interface_args(pipeline: str):
         raise ValueError(f"Unsupported pipeline type: {pipeline}")
     return inputs, outputs, preprocess, postprocess
 
-
 def registry(name: str = None, **kwargs) -> gr.ChatInterface:
     """Create a Gradio Interface with similar styling and parameters."""
 
@@ -187,7 +157,6 @@ def registry(name: str = None, **kwargs) -> gr.ChatInterface:
             gr.Slider(0.0, 2.0, value=1.1, label="Repetition penalty"),
             gr.Slider(0.0, 1.0, value=0.95, label="Top P sampling"),
         ],
-        # Optionally, you can customize other ChatInterface parameters here
     )
 
     return interface
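To exercise the updated module end to end, something like the following should work, assuming the OpenAI-compatible server at the base_url above is reachable; the model name is a placeholder:

# Hypothetical launch script; "my-model" is a placeholder model name.
from helper import registry

demo = registry(name="my-model")  # returns a gr.ChatInterface
demo.launch()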
 