Spaces:
Running
Running
Zekun Wu
commited on
Commit
·
b607f53
1
Parent(s):
35b059b
update
Browse files
app.py
CHANGED
@@ -16,20 +16,19 @@ class ContentFormatter:
|
|
16 |
return json.dumps(data)
|
17 |
|
18 |
class AzureAgent:
|
19 |
-
def __init__(self, api_key, azure_uri, deployment_name
|
20 |
self.azure_uri = azure_uri
|
21 |
self.headers = {
|
22 |
'Authorization': f"Bearer {api_key}",
|
23 |
'Content-Type': 'application/json'
|
24 |
}
|
25 |
self.deployment_name = deployment_name
|
26 |
-
self.api_version = api_version
|
27 |
self.chat_formatter = ContentFormatter
|
28 |
|
29 |
def invoke(self, text, **kwargs):
|
30 |
body = self.chat_formatter.chat_completions(text, {**kwargs})
|
31 |
conn = http.client.HTTPSConnection(self.azure_uri)
|
32 |
-
conn.request("POST", f'/
|
33 |
response = conn.getresponse()
|
34 |
data = response.read()
|
35 |
conn.close()
|
@@ -68,7 +67,7 @@ model_type = st.sidebar.radio("Select the type of agent", ('AzureAgent', 'GPTAge
|
|
68 |
api_key = st.sidebar.text_input("API Key", type="password")
|
69 |
endpoint_url = st.sidebar.text_input("Endpoint URL")
|
70 |
deployment_name = st.sidebar.text_input("Model Name")
|
71 |
-
api_version = st.sidebar.text_input("API Version", '
|
72 |
|
73 |
# Model invocation parameters
|
74 |
temperature = st.sidebar.slider("Temperature", min_value=0.0, max_value=1.0, value=0.5, step=0.01)
|
@@ -86,7 +85,7 @@ if uploaded_file is not None:
|
|
86 |
# Process data button
|
87 |
if st.button('Process Data'):
|
88 |
if model_type == 'AzureAgent':
|
89 |
-
agent = AzureAgent(api_key, endpoint_url, deployment_name
|
90 |
else:
|
91 |
agent = GPTAgent(api_key, endpoint_url, deployment_name, api_version)
|
92 |
|
|
|
16 |
return json.dumps(data)
|
17 |
|
18 |
class AzureAgent:
|
19 |
+
def __init__(self, api_key, azure_uri, deployment_name):
|
20 |
self.azure_uri = azure_uri
|
21 |
self.headers = {
|
22 |
'Authorization': f"Bearer {api_key}",
|
23 |
'Content-Type': 'application/json'
|
24 |
}
|
25 |
self.deployment_name = deployment_name
|
|
|
26 |
self.chat_formatter = ContentFormatter
|
27 |
|
28 |
def invoke(self, text, **kwargs):
|
29 |
body = self.chat_formatter.chat_completions(text, {**kwargs})
|
30 |
conn = http.client.HTTPSConnection(self.azure_uri)
|
31 |
+
conn.request("POST", f'/v1/chat/completions', body=body, headers=self.headers)
|
32 |
response = conn.getresponse()
|
33 |
data = response.read()
|
34 |
conn.close()
|
|
|
67 |
api_key = st.sidebar.text_input("API Key", type="password")
|
68 |
endpoint_url = st.sidebar.text_input("Endpoint URL")
|
69 |
deployment_name = st.sidebar.text_input("Model Name")
|
70 |
+
api_version = st.sidebar.text_input("API Version", '2024-02-15-preview') # Default API version
|
71 |
|
72 |
# Model invocation parameters
|
73 |
temperature = st.sidebar.slider("Temperature", min_value=0.0, max_value=1.0, value=0.5, step=0.01)
|
|
|
85 |
# Process data button
|
86 |
if st.button('Process Data'):
|
87 |
if model_type == 'AzureAgent':
|
88 |
+
agent = AzureAgent(api_key, endpoint_url, deployment_name)
|
89 |
else:
|
90 |
agent = GPTAgent(api_key, endpoint_url, deployment_name, api_version)
|
91 |
|