omkar56 committed
Commit 3d03ce4
1 Parent(s): a438652

Update main.py

Files changed (1):
  1. main.py +9 -6
main.py CHANGED
@@ -4,17 +4,19 @@ from typing import Optional
 from huggingface_hub import InferenceClient
 import random
 
-API_URL = "https://api-inference.huggingface.co/models/"
-API_KEY = "abcd12345" # Replace with your actual API key
 
-client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.1")
+API_URL = os.environ.get("API_URL")
+API_KEY = os.environ.get("API_KEY")
+MODEL_NAME = os.environ.get("MODEL_NAME")
+
+client = InferenceClient(MODEL_NAME)
 app = FastAPI()
 
 security = APIKeyHeader(name="api_key", auto_error=False)
 
 def get_api_key(api_key: Optional[str] = Depends(security)):
     if api_key is None or api_key != API_KEY:
-        raise HTTPException(status_code=401, detail="Unauthorized access")
+        raise HTTPException(status_code=401, error="Unauthorized access")
     return api_key
 
 def format_prompt(message, history):
@@ -35,15 +37,16 @@ def generate_text(
     sys_prompt = body.get("sysPrompt", "")
     temperature = body.get("temperature", 0.5)
     top_p = body.get("top_p", 0.95)
+    max_new_tokens = body.get("max_new_tokens",512)
     repetition_penalty = body.get("repetition_penalty", 1.0)
-
+    print(f"temperature + {temperature}")
     history = [] # You might need to handle this based on your actual usage
     formatted_prompt = format_prompt(prompt, history)
 
     stream = client.text_generation(
         formatted_prompt,
         temperature=temperature,
-        max_new_tokens=512,
+        max_new_tokens=max_new_tokens,
         top_p=top_p,
         repetition_penalty=repetition_penalty,
         do_sample=True,
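
For reference, below is a minimal sketch of how the updated service might be configured and called after this commit. It is an assumption-based example, not part of the change: the route path (/generate), the port, and the "prompt" body key are not visible in these hunks and are placeholders; only the api_key header name and the body keys read above (sysPrompt, temperature, top_p, max_new_tokens, repetition_penalty) come from the diff.

# Hedged usage sketch (not part of the commit).
# The server process is expected to have these set before startup, e.g.:
#   export API_URL=https://api-inference.huggingface.co/models/
#   export API_KEY=<your key>
#   export MODEL_NAME=mistralai/Mistral-7B-Instruct-v0.1
import os
import requests  # assumes the requests package is installed on the client side

# Hypothetical route and port; the handler's decorator is outside these hunks.
url = "http://localhost:8000/generate"

response = requests.post(
    url,
    headers={"api_key": os.environ["API_KEY"]},  # header name matches APIKeyHeader(name="api_key")
    json={
        "prompt": "Hello!",         # assumed body key; not visible in the hunks
        "sysPrompt": "",
        "temperature": 0.5,
        "top_p": 0.95,
        "max_new_tokens": 256,      # newly configurable per request in this commit
        "repetition_penalty": 1.0,
    },
)
print(response.status_code, response.text)

Reading API_URL, API_KEY, and MODEL_NAME from the environment keeps the key and model choice out of the source file, which is the main effect of this commit alongside the per-request max_new_tokens setting.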