rodrigomasini commited on
Commit
460a4a6
·
verified ·
1 Parent(s): 69e2d88

Update helper.py

Browse files
Files changed (1) hide show
  1. helper.py +20 -15
helper.py CHANGED
@@ -4,21 +4,29 @@ from typing import Callable, Generator
4
  import base64
5
  from openai import OpenAI
6
 
 
 
 
 
 
 
 
 
7
  def get_fn(model_name: str, **model_kwargs) -> Callable:
8
  """Create a chat function with the specified model."""
9
 
10
  # Instantiate an OpenAI client for a custom endpoint
11
  try:
12
  client = OpenAI(
13
- base_url="http://192.222.58.60:8000/v1",
14
- api_key="tela", # Replace with your actual API key or use environment variables
15
  )
16
  except Exception as e:
17
  print(f"The API or base URL were not defined: {str(e)}")
18
- raise e # Prevent the app from running without a client
19
 
20
  def predict(
21
- messages: list, # Preprocessed messages from preprocess function
22
  temperature: float,
23
  max_tokens: int,
24
  top_p: float
@@ -32,12 +40,10 @@ def get_fn(model_name: str, **model_kwargs) -> Callable:
32
  max_tokens=max_tokens,
33
  top_p=top_p,
34
  stream=True,
35
- # Ensure response_format is set correctly; typically it's a string like 'text'
36
  response_format={"type": "text"},
37
  )
38
 
39
  response_text = ""
40
- # Iterate over the streaming response
41
  for chunk in response:
42
  if len(chunk.choices[0].delta.content) > 0:
43
  content = chunk.choices[0].delta.content
@@ -100,28 +106,27 @@ def get_interface_args(pipeline: str):
100
  postprocess = lambda x: x # No additional postprocessing needed
101
 
102
  else:
103
- # Add other pipeline types when they are needed
104
  raise ValueError(f"Unsupported pipeline type: {pipeline}")
105
  return inputs, outputs, preprocess, postprocess
106
 
107
  def registry(name: str = None, **kwargs) -> gr.ChatInterface:
108
  """Create a Gradio Interface with similar styling and parameters."""
109
 
110
- # Retrieve preprocess and postprocess functions
111
  _, _, preprocess, postprocess = get_interface_args("chat")
112
 
113
- # Get the predict function
114
  predict_fn = get_fn(model_name=name, **kwargs)
115
 
116
- # Define a wrapper function that integrates preprocessing and postprocessing
117
  def wrapper(message, history, system_prompt, temperature, max_tokens, top_p):
118
- # Preprocess the inputs
119
  preprocessed = preprocess(message, history)
120
 
121
- # Extract the preprocessed messages
122
  messages = preprocessed["messages"]
123
 
124
- # Call the predict function and generate the response
125
  response_generator = predict_fn(
126
  messages=messages,
127
  temperature=temperature,
@@ -129,13 +134,13 @@ def registry(name: str = None, **kwargs) -> gr.ChatInterface:
129
  top_p=top_p
130
  )
131
 
132
- # Collect the generated response
133
  response = ""
134
  for partial_response in response_generator:
135
  response = partial_response # Gradio will handle streaming
136
  yield response
137
 
138
- # Create the Gradio ChatInterface with the wrapper function
139
  interface = gr.ChatInterface(
140
  fn=wrapper,
141
  additional_inputs_accordion=gr.Accordion("⚙️ Parameters", open=False),
 
4
  import base64
5
  from openai import OpenAI
6
 
7
+ from dotenv import load_dotenv, find_dotenv
8
+ _ = load_dotenv(find_dotenv())
9
+
10
+ END_POINT = os.environ.get("ENDPOINT")
11
+ SECRET_KEY = os.environ.get("SECRETKEY")
12
+ USERS = os.environ.get("USERS")
13
+ PWD = os.environ.get("PWD")
14
+
15
  def get_fn(model_name: str, **model_kwargs) -> Callable:
16
  """Create a chat function with the specified model."""
17
 
18
  # Instantiate an OpenAI client for a custom endpoint
19
  try:
20
  client = OpenAI(
21
+ base_url=END_POINT,
22
+ api_key=SECRET_KEY,
23
  )
24
  except Exception as e:
25
  print(f"The API or base URL were not defined: {str(e)}")
26
+ raise e
27
 
28
  def predict(
29
+ messages: list,
30
  temperature: float,
31
  max_tokens: int,
32
  top_p: float
 
40
  max_tokens=max_tokens,
41
  top_p=top_p,
42
  stream=True,
 
43
  response_format={"type": "text"},
44
  )
45
 
46
  response_text = ""
 
47
  for chunk in response:
48
  if len(chunk.choices[0].delta.content) > 0:
49
  content = chunk.choices[0].delta.content
 
106
  postprocess = lambda x: x # No additional postprocessing needed
107
 
108
  else:
 
109
  raise ValueError(f"Unsupported pipeline type: {pipeline}")
110
  return inputs, outputs, preprocess, postprocess
111
 
112
  def registry(name: str = None, **kwargs) -> gr.ChatInterface:
113
  """Create a Gradio Interface with similar styling and parameters."""
114
 
115
+ # Retrieving preprocess and postprocess functions
116
  _, _, preprocess, postprocess = get_interface_args("chat")
117
 
118
+ # Getting the predict function
119
  predict_fn = get_fn(model_name=name, **kwargs)
120
 
121
+ # Defining a wrapper function that integrates preprocessing and postprocessing
122
  def wrapper(message, history, system_prompt, temperature, max_tokens, top_p):
123
+ # Preprocessing the inputs
124
  preprocessed = preprocess(message, history)
125
 
126
+ # Extracting the preprocessed messages
127
  messages = preprocessed["messages"]
128
 
129
+ # Calling the predict function and generate the response
130
  response_generator = predict_fn(
131
  messages=messages,
132
  temperature=temperature,
 
134
  top_p=top_p
135
  )
136
 
137
+ # Collecting the generated response
138
  response = ""
139
  for partial_response in response_generator:
140
  response = partial_response # Gradio will handle streaming
141
  yield response
142
 
143
+ # Creating the Gradio ChatInterface with the wrapper function
144
  interface = gr.ChatInterface(
145
  fn=wrapper,
146
  additional_inputs_accordion=gr.Accordion("⚙️ Parameters", open=False),