Update helper.py
helper.py CHANGED
@@ -11,16 +11,14 @@ def get_fn(model_name: str, **model_kwargs) -> Callable:
     try:
         client = OpenAI(
             base_url="http://192.222.58.60:8000/v1",
-            api_key="tela",
+            api_key="tela", # Replace with your actual API key or use environment variables
         )
     except Exception as e:
         print(f"The API or base URL were not defined: {str(e)}")
-        raise e #
+        raise e # Prevent the app from running without a client

     def predict(
-        message: str,
-        history: list,
-        system_prompt: str,
+        messages: list, # Preprocessed messages from preprocess function
         temperature: float,
         max_tokens: int,
         top_k: int,
@@ -28,20 +26,6 @@ def get_fn(model_name: str, **model_kwargs) -> Callable:
         top_p: float
     ) -> Generator[str, None, None]:
         try:
-            # Initialize the messages list with the system prompt
-            messages = [
-                {"role": "system", "content": system_prompt}
-            ]
-
-            # Append the conversation history
-            for user_msg, assistant_msg in history:
-                messages.append({"role": "user", "content": user_msg})
-                if assistant_msg:
-                    messages.append({"role": "assistant", "content": assistant_msg})
-
-            # Append the latest user message
-            messages.append({"role": "user", "content": message})
-
             # Call the OpenAI API with the formatted messages
             response = client.chat.completions.create(
                 model=model_name,
@@ -55,7 +39,7 @@ def get_fn(model_name: str, **model_kwargs) -> Callable:
                 # Ensure response_format is set correctly; typically it's a string like 'text'
                 response_format="text",
             )
-
+
             response_text = ""
             # Iterate over the streaming response
             for chunk in response:
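The lines between this hunk and the next (old 62-67, new 46-51) are unchanged context that the diff omits. For orientation, the loop opened by "for chunk in response:" would typically accumulate streamed deltas as in the sketch below, an assumption based on the surrounding code (OpenAI Python client streaming inside a Gradio generator) rather than the file's actual elided lines:

    for chunk in response:
        # Each streamed chunk carries an incremental delta; content may be None
        delta = chunk.choices[0].delta.content
        if delta:
            response_text += delta
            # Yield the running text so Gradio renders the reply incrementally
            yield response_text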
@@ -68,46 +52,33 @@ def get_fn(model_name: str, **model_kwargs) -> Callable:

             if not response_text.strip():
                 yield "I apologize, but I was unable to generate a response. Please try again."
-
+
         except Exception as e:
             print(f"Error during generation: {str(e)}")
             yield f"An error occurred: {str(e)}"
-
-    return predict
-

+    return predict

-def get_image_base64(url: str, ext: str):
+def get_image_base64(url: str, ext: str) -> str:
     with open(url, "rb") as image_file:
         encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
-    return "data:image/
-
+    return f"data:image/{ext};base64,{encoded_string}"

-def handle_user_msg(message: str):
-    if
+def handle_user_msg(message: str) -> str:
+    if isinstance(message, str):
         return message
-    elif
-        if message
-            ext = os.path.splitext(message["files"][-1])[1].strip(".")
-            if ext
+    elif isinstance(message, dict):
+        if message.get("files"):
+            ext = os.path.splitext(message["files"][-1])[1].strip(".").lower()
+            if ext in ["png", "jpg", "jpeg", "gif", "pdf"]:
                 encoded_str = get_image_base64(message["files"][-1], ext)
+                return f"{message.get('text', '')}\n"
             else:
-                raise NotImplementedError(f"
-            content = [
-                {"type": "text", "text": message["text"]},
-                {
-                    "type": "image_url",
-                    "image_url": {
-                        "url": encoded_str,
-                    }
-                },
-            ]
+                raise NotImplementedError(f"Unsupported file type: {ext}")
         else:
-            content = message["text"]
-            return content
+            return message.get("text", "")
     else:
-        raise NotImplementedError
-
+        raise NotImplementedError("Unsupported message type")

 def get_interface_args(pipeline: str):
     if pipeline == "chat":
@@ -138,7 +109,6 @@ def get_interface_args(pipeline: str):
         raise ValueError(f"Unsupported pipeline type: {pipeline}")
     return inputs, outputs, preprocess, postprocess

-
 def registry(name: str = None, **kwargs) -> gr.ChatInterface:
     """Create a Gradio Interface with similar styling and parameters."""

@@ -187,7 +157,6 @@ def registry(name: str = None, **kwargs) -> gr.ChatInterface:
             gr.Slider(0.0, 2.0, value=1.1, label="Repetition penalty"),
             gr.Slider(0.0, 1.0, value=0.95, label="Top P sampling"),
         ],
-        # Optionally, you can customize other ChatInterface parameters here
     )

     return interface
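The main change in predict() above is that message assembly moved out of the function: the new "messages: list" parameter arrives already formatted, per its comment "Preprocessed messages from preprocess function". A minimal sketch of that preprocess step, assuming it rebuilds exactly what the deleted inline code built; the function name and signature are inferred from the comment and from get_interface_args returning a preprocess callable, so treat them as assumptions:

    def preprocess(message: str, history: list, system_prompt: str) -> list:
        # Seed the conversation with the system prompt, as the removed code did
        messages = [{"role": "system", "content": system_prompt}]
        # Replay the stored (user, assistant) turns from the Gradio history
        for user_msg, assistant_msg in history:
            messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
        # End with the latest user message
        messages.append({"role": "user", "content": message})
        return messages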
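For reference, a hypothetical call into the repaired helpers; the dict mirrors Gradio's multimodal textbox payload, and "photo.png" is an illustrative path, not a file from this commit:

    # A plain string passes straight through
    handle_user_msg("Hello!")  # -> "Hello!"

    # A dict with files is routed through get_image_base64, which now returns
    # a complete data URI such as "data:image/png;base64,<encoded bytes>"
    handle_user_msg({"text": "Describe this image", "files": ["photo.png"]})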