Ashrafb committed
Commit b1accca
1 Parent(s): ad20f87

Update main.py

Files changed (1)
main.py +28 -49
main.py CHANGED
@@ -1,20 +1,16 @@
-from fastapi import FastAPI, File, UploadFile, Form, Request
-from fastapi.responses import HTMLResponse, FileResponse
-from fastapi.staticfiles import StaticFiles
+from fastapi import FastAPI, Request
 from fastapi.templating import Jinja2Templates
 from huggingface_hub import InferenceClient
-import random
 
-API_URL = "https://api-inference.huggingface.co/models/"
+app = FastAPI()
+templates = Jinja2Templates(directory="templates")
 
 client = InferenceClient(
     "mistralai/Mistral-7B-Instruct-v0.1"
 )
 
-app = FastAPI()
-
 
-def format_prompt(message, history):
+async def format_prompt(message, history):
     prompt = "<s>"
     for user_prompt, bot_response in history:
         prompt += f"[INST] {user_prompt} [/INST]"
@@ -23,56 +19,39 @@ def format_prompt(message, history):
     return prompt
 
 
-def generate(prompt, history, temperature=0.9, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
+async def generate(
+    prompt: str,
+    temperature: float = 0.9,
+    max_new_tokens: int = 256,
+    top_p: float = 0.95,
+    repetition_penalty: float = 1.0,
+):
     temperature = float(temperature)
     if temperature < 1e-2:
         temperature = 1e-2
     top_p = float(top_p)
 
-    generate_kwargs = dict(
-        temperature=temperature,
-        max_new_tokens=max_new_tokens,
-        top_p=top_p,
-        repetition_penalty=repetition_penalty,
-        do_sample=True,
-        seed=random.randint(0, 10**7),
-    )
-
-    formatted_prompt = format_prompt(prompt, history)
-
-    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-    output = ""
-    word = ""
-
-    for response in stream:
-        token_text = response.token.text.strip()
-        if token_text != "":
-            # Decode the token text to handle encoded characters
-            decoded_text = token_text.encode("utf-8", "backslashreplace").decode("utf-8")
-
-            # Add the decoded letter to the current word
-            word += decoded_text
-
-            # If the token is a space or the end of the stream, add the word to the output and reset the word
-            if token_text == " " or response.is_end_of_stream:
-                output += word + " "
-                word = ""
-
-    return output
+    generate_kwargs = {
+        "temperature": temperature,
+        "max_new_tokens": max_new_tokens,
+        "top_p": top_p,
+        "repetition_penalty": repetition_penalty,
+        "do_sample": True,
+        "seed": 42,
+    }
+
+    formatted_prompt = await format_prompt(prompt, [])
+
+    response = client.text_generation(formatted_prompt, **generate_kwargs, stream=False, details=False, return_full_text=True)
+    return response
 
 
+@app.get("/")
+async def index(request: Request):
+    return templates.TemplateResponse("index.html", {"request": request})
+
+
 @app.post("/generate/")
-async def generate_chat(request: Request, prompt: str = Form(...), history: str = Form(...), temperature: float = Form(0.9), max_new_tokens: int = Form(512), top_p: float = Form(0.95), repetition_penalty: float = Form(1.0)):
-    history = eval(history)  # Convert history string back to list
-    response = generate(prompt, history, temperature, max_new_tokens, top_p, repetition_penalty)
-
-    # Concatenate the generated response strings into a single coherent response
-    coherent_response = " ".join(response)
-
-    return {"response": coherent_response}
-
-app.mount("/", StaticFiles(directory="static", html=True), name="static")
-
-@app.get("/")
-def index() -> FileResponse:
-    return FileResponse(path="/app/static/index.html", media_type="text/html")
+async def chatbot_response(prompt: str):
+    response = await generate(prompt)
+    return {"response": response}