Update helper.py

helper.py (CHANGED)
@@ -4,21 +4,29 @@ from typing import Callable, Generator
 import base64
 from openai import OpenAI
 
+from dotenv import load_dotenv, find_dotenv
+_ = load_dotenv(find_dotenv())
+
+END_POINT = os.environ.get("ENDPOINT")
+SECRET_KEY = os.environ.get("SECRETKEY")
+USERS = os.environ.get("USERS")
+PWD = os.environ.get("PWD")
+
 def get_fn(model_name: str, **model_kwargs) -> Callable:
     """Create a chat function with the specified model."""
 
     # Instantiate an OpenAI client for a custom endpoint
     try:
         client = OpenAI(
-            base_url=
-            api_key=
+            base_url=END_POINT,
+            api_key=SECRET_KEY,
         )
     except Exception as e:
         print(f"The API or base URL were not defined: {str(e)}")
-        raise e
+        raise e
 
     def predict(
-        messages: list,
+        messages: list,
         temperature: float,
         max_tokens: int,
         top_p: float
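A note on the new module-level block: it reads its four keys via `os.environ.get`, so `helper.py` needs `import os` somewhere above this hunk (not visible in the context shown). Also, `PWD` doubles as the POSIX shell's working-directory variable, so a missing `.env` entry can be silently masked by the shell's value. A minimal check sketch; the key names come from the diff, everything else is an illustrative placeholder:

```python
# env_check.py - sketch that verifies the variables helper.py now reads.
# Key names (ENDPOINT, SECRETKEY, USERS, PWD) are taken from the diff;
# the example .env contents below are placeholders only:
#
#   ENDPOINT=https://example.com/v1   # OpenAI-compatible base URL
#   SECRETKEY=sk-...                  # API key
#   USERS=alice                       # login user
#   PWD=change-me                     # login password
import os

from dotenv import find_dotenv, load_dotenv

load_dotenv(find_dotenv())  # searches upward from the CWD for a .env file

for key in ("ENDPOINT", "SECRETKEY", "USERS", "PWD"):
    print(f"{key}: {'set' if os.environ.get(key) else 'MISSING'}")
```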
@@ -32,12 +40,10 @@ def get_fn(model_name: str, **model_kwargs) -> Callable:
             max_tokens=max_tokens,
             top_p=top_p,
             stream=True,
-            # Ensure response_format is set correctly; typically it's a string like 'text'
             response_format={"type": "text"},
         )
 
         response_text = ""
-        # Iterate over the streaming response
         for chunk in response:
             if len(chunk.choices[0].delta.content) > 0:
                 content = chunk.choices[0].delta.content
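One caveat about the loop above: in the OpenAI v1 Python SDK, `delta.content` is `Optional[str]`. The first chunk of a stream usually carries only the role and the final chunk carries `None`, so `len(chunk.choices[0].delta.content)` can raise `TypeError` mid-stream. A None-safe variant, sketched under the same setup as `predict`:

```python
# Sketch of a None-safe streaming loop; `response` is the streaming
# chat-completions call made in predict() above.
response_text = ""
for chunk in response:
    if not chunk.choices:  # some providers send usage-only chunks
        continue
    content = chunk.choices[0].delta.content
    if content:  # None on role-only and final chunks, "" on empty deltas
        response_text += content
        yield response_text
```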
@@ -100,28 +106,27 @@ def get_interface_args(pipeline: str):
         postprocess = lambda x: x  # No additional postprocessing needed
 
     else:
-        # Add other pipeline types when they are needed
         raise ValueError(f"Unsupported pipeline type: {pipeline}")
     return inputs, outputs, preprocess, postprocess
 
 def registry(name: str = None, **kwargs) -> gr.ChatInterface:
     """Create a Gradio Interface with similar styling and parameters."""
 
-    #
+    # Retrieving preprocess and postprocess functions
     _, _, preprocess, postprocess = get_interface_args("chat")
 
-    #
+    # Getting the predict function
     predict_fn = get_fn(model_name=name, **kwargs)
 
-    #
+    # Defining a wrapper function that integrates preprocessing and postprocessing
     def wrapper(message, history, system_prompt, temperature, max_tokens, top_p):
-        #
+        # Preprocessing the inputs
         preprocessed = preprocess(message, history)
 
-        #
+        # Extracting the preprocessed messages
         messages = preprocessed["messages"]
 
-        #
+        # Calling the predict function and generate the response
         response_generator = predict_fn(
             messages=messages,
             temperature=temperature,
@@ -129,13 +134,13 @@ def registry(name: str = None, **kwargs) -> gr.ChatInterface:
             top_p=top_p
         )
 
-        #
+        # Collecting the generated response
         response = ""
         for partial_response in response_generator:
             response = partial_response  # Gradio will handle streaming
             yield response
 
-        #
+        # Creating the Gradio ChatInterface with the wrapper function
     interface = gr.ChatInterface(
         fn=wrapper,
         additional_inputs_accordion=gr.Accordion("⚙️ Parameters", open=False),
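For context on how the exported pieces fit together, a hypothetical `app.py` using this module; the model name and launch options are assumptions, since only `helper.py` appears in this commit, and the guess that `USERS`/`PWD` feed Gradio's auth is not confirmed by the diff:

```python
# app.py - hypothetical usage sketch; only helper.py appears in this commit.
import helper

# `name` is forwarded to get_fn as model_name; "my-model" is a placeholder.
demo = helper.registry(name="my-model")

if __name__ == "__main__":
    # Assumption: the new USERS/PWD env vars are meant for Gradio's
    # built-in authentication; the diff does not show where they are used.
    demo.launch(auth=(helper.USERS, helper.PWD))
```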