Jingni (Janet) Cai commited on
Commit
4c37e0b
2 Parent(s): 489081f cfb18ac

Merge pull request #4 from jcai0o0/api-only-product

Browse files
Files changed (2) hide show
  1. app.py +30 -57
  2. requirements.txt +0 -4
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import gradio as gr
2
  from huggingface_hub import InferenceClient
3
- import torch
4
- from transformers import pipeline
5
  from prometheus_client import start_http_server, Counter, Summary
6
 
7
  from typing import Iterable
@@ -24,7 +24,7 @@ REQUEST_DURATION = Summary('app_request_duration_seconds', 'Time spent processin
24
  client = InferenceClient(model="mistralai/Mistral-Small-Instruct-2409",
25
  # token=HF_ACCESS
26
  )
27
- pipe = pipeline("text-generation", "microsoft/Phi-3-mini-4k-instruct", torch_dtype=torch.bfloat16, device_map="auto")
28
 
29
  # Global flag to handle cancellation
30
  stop_inference = False
@@ -49,60 +49,33 @@ def respond(
49
  if history is None:
50
  history = []
51
 
52
- if use_local_model:
53
- # local inference
54
- messages = [{"role": "system", "content": system_message}]
55
- for val in history:
56
- if val[0]:
57
- messages.append({"role": "user", "content": val[0]})
58
- if val[1]:
59
- messages.append({"role": "assistant", "content": val[1]})
60
- messages.append({"role": "user", "content": message})
61
-
62
- response = ""
63
- for output in pipe(
64
- messages,
65
- max_new_tokens=max_tokens,
66
- temperature=temperature,
67
- do_sample=True,
68
- top_p=top_p,
69
- ):
70
- if stop_inference:
71
- response = "Inference cancelled."
72
- yield history + [(message, response)]
73
- return
74
- token = output['generated_text'][-1]['content']
75
- response += token
76
- yield history + [(message, response)] # Yield history + new response
77
-
78
- else:
79
- # API-based inference
80
- messages = [{"role": "system", "content": system_message}]
81
- for val in history:
82
- if val[0]:
83
- messages.append({"role": "user", "content": val[0]})
84
- if val[1]:
85
- messages.append({"role": "assistant", "content": val[1]})
86
- messages.append({"role": "user", "content": message})
87
-
88
- response = ""
89
- for message_chunk in client.chat_completion(
90
- messages,
91
- max_tokens=max_tokens,
92
- stream=False,
93
- temperature=temperature,
94
- top_p=top_p,
95
- ):
96
- if stop_inference:
97
- response = "Inference cancelled."
98
- yield history + [(message, response)]
99
- return
100
- if stop_inference:
101
- response = "Inference cancelled."
102
- break
103
- token = message_chunk.choices[0].delta.content
104
- response += token
105
- yield history + [(message, response)] # Yield history + new response
106
  SUCCESSFUL_REQUESTS.inc() # Increment successful request counter
107
  except Exception as e:
108
  FAILED_REQUESTS.inc() # Increment failed request counter
 
1
  import gradio as gr
2
  from huggingface_hub import InferenceClient
3
+ # import torch
4
+ # from transformers import pipeline
5
  from prometheus_client import start_http_server, Counter, Summary
6
 
7
  from typing import Iterable
 
24
  client = InferenceClient(model="mistralai/Mistral-Small-Instruct-2409",
25
  # token=HF_ACCESS
26
  )
27
+ # pipe = pipeline("text-generation", "microsoft/Phi-3-mini-4k-instruct", torch_dtype=torch.bfloat16, device_map="auto")
28
 
29
  # Global flag to handle cancellation
30
  stop_inference = False
 
49
  if history is None:
50
  history = []
51
 
52
+ # API-based inference
53
+ messages = [{"role": "system", "content": system_message}]
54
+ for val in history:
55
+ if val[0]:
56
+ messages.append({"role": "user", "content": val[0]})
57
+ if val[1]:
58
+ messages.append({"role": "assistant", "content": val[1]})
59
+ messages.append({"role": "user", "content": message})
60
+
61
+ response = ""
62
+ for message_chunk in client.chat_completion(
63
+ messages,
64
+ max_tokens=max_tokens,
65
+ stream=False,
66
+ temperature=temperature,
67
+ top_p=top_p,
68
+ ):
69
+ if stop_inference:
70
+ response = "Inference cancelled."
71
+ yield history + [(message, response)]
72
+ return
73
+ if stop_inference:
74
+ response = "Inference cancelled."
75
+ break
76
+ token = message_chunk.choices[0].delta.content
77
+ response += token
78
+ yield history + [(message, response)] # Yield history + new response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  SUCCESSFUL_REQUESTS.inc() # Increment successful request counter
80
  except Exception as e:
81
  FAILED_REQUESTS.inc() # Increment failed request counter
requirements.txt CHANGED
@@ -1,8 +1,4 @@
1
- --extra-index-url https://download.pytorch.org/whl/cpu
2
  huggingface_hub==0.23.*
3
  gradio==4.43.0
4
- torch==2.4.*
5
- transformers==4.43.*
6
- accelerate==0.33.*
7
  python-dotenv==1.0.1
8
  prometheus_client==0.21.*
 
 
1
  huggingface_hub==0.23.*
2
  gradio==4.43.0
 
 
 
3
  python-dotenv==1.0.1
4
  prometheus_client==0.21.*