omkar56 committed on
Commit
944b573
1 Parent(s): 75e758a

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +41 -0
main.py CHANGED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Request and Body are used by the generate_text route below but were
# missing from the original import — without them the module fails with
# NameError at import time.
from fastapi import FastAPI, Request, Body
from huggingface_hub import InferenceClient
import random

# Base inference endpoint, kept for reference; InferenceClient resolves
# the model URL itself from the model id passed below.
API_URL = "https://api-inference.huggingface.co/models/"

# Shared client for the hosted Mistral-7B-Instruct model.
client = InferenceClient(
    "mistralai/Mistral-7B-Instruct-v0.1"
)

app = FastAPI()
13
def format_prompt(message, history):
    """Build a Mistral-instruct prompt from prior turns plus the new message.

    Each (user, bot) pair in *history* is emitted as an ``[INST] ... [/INST]``
    instruction followed by the bot reply and an ``</s>`` end marker; the new
    *message* is appended as a final, unanswered instruction.
    """
    pieces = ["<s>"]
    for user_turn, bot_turn in history:
        pieces.append(f"[INST] {user_turn} [/INST]")
        pieces.append(f" {bot_turn}</s> ")
    pieces.append(f"[INST] {message} [/INST]")
    return "".join(pieces)
20
+
21
+
22
# Route path must start with "/" — FastAPI/Starlette rejects relative paths.
@app.post("/api/v1/generate_text")
async def generate_text(request: Request, prompt: str = Body()):
    """Generate a completion for *prompt* with Mistral-7B-Instruct.

    Sampling parameters may be overridden per-request via the
    ``temperature``, ``top_p`` and ``repetition_penalty`` headers.
    Returns the generated text as a plain string.
    """
    history = []  # Stateless endpoint: no conversation history is kept yet.

    # HTTP header values arrive as strings (or are absent) — coerce to
    # float so the inference client receives numeric sampling parameters.
    temperature = float(request.headers.get("temperature", 0.9))
    top_p = float(request.headers.get("top_p", 0.95))
    repetition_penalty = float(request.headers.get("repetition_penalty", 1.0))

    formatted_prompt = format_prompt(prompt, history)
    # With streaming and details both off, text_generation() returns the
    # generated text as a str; the original `[0]` indexed the first
    # character and `.token.text` then raised AttributeError.
    response = client.text_generation(
        formatted_prompt,
        temperature=temperature,
        max_new_tokens=512,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=random.randint(0, 10**7),  # fresh seed per request for varied sampling
    )
    return response