Ashrafb committed on
Commit
536a5e8
1 Parent(s): 8d3ef40

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +10 -7
main.py CHANGED
@@ -30,7 +30,8 @@ def format_prompt(message, history):
30
  prompt += f"[INST] {message} [/INST]"
31
  return prompt
32
 
33
- async def generate(prompt, history, temperature=0.9, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
 
34
  temperature = float(temperature)
35
  if temperature < 1e-2:
36
  temperature = 1e-2
@@ -47,19 +48,21 @@ async def generate(prompt, history, temperature=0.9, max_new_tokens=512, top_p=0
47
 
48
  formatted_prompt = format_prompt(prompt, history)
49
 
50
- async for response in client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False):
51
- yield response.token.text
52
 
 
 
 
 
53
 
54
 
55
 
56
@app.post("/generate/")
async def generate_chat(request: Request, prompt: str = Form(...), history: str = Form(...), temperature: float = Form(0.9), max_new_tokens: int = Form(512), top_p: float = Form(0.95), repetition_penalty: float = Form(1.0)):
    """Stream the generated chat reply back to the client as NDJSON lines.

    Form fields: the user prompt, the serialized chat history, and the
    sampling parameters, all forwarded unchanged to `generate`.
    Returns a StreamingResponse whose chunks are JSON objects of the form
    {"response_chunk": <text>}, one per line.
    """
    import ast
    import json

    # SECURITY: the original used eval() on client-supplied form data, which
    # executes arbitrary code. literal_eval only parses Python literals.
    history = ast.literal_eval(history)  # Convert history string back to a list

    async def stream_chunks():
        # StreamingResponse can only emit str/bytes chunks; the original
        # yielded raw dicts, which fail when the response tries to encode
        # them. Serialize each chunk as one JSON line instead.
        async for response_chunk in generate(prompt, history, temperature, max_new_tokens, top_p, repetition_penalty):
            yield json.dumps({"response_chunk": response_chunk}) + "\n"

    return StreamingResponse(stream_chunks(), media_type="application/x-ndjson")
63
 
64
  app.mount("/", StaticFiles(directory="static", html=True), name="static")
65
 
 
30
  prompt += f"[INST] {message} [/INST]"
31
  return prompt
32
 
33
+ def generate(prompt, history, temperature=0.9, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
34
+
35
  temperature = float(temperature)
36
  if temperature < 1e-2:
37
  temperature = 1e-2
 
48
 
49
  formatted_prompt = format_prompt(prompt, history)
50
 
51
+ stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
52
+ output = ""
53
 
54
+ for response in stream:
55
+ output += response.token.text
56
+ yield output
57
+ return output
58
 
59
 
60
 
61
@app.post("/generate/")
async def generate_chat(request: Request, prompt: str = Form(...), history: str = Form(...), temperature: float = Form(0.9), max_new_tokens: int = Form(512), top_p: float = Form(0.95), repetition_penalty: float = Form(1.0)):
    """Return the complete generated chat reply as a JSON object.

    Form fields: the user prompt, the serialized chat history, and the
    sampling parameters, all forwarded unchanged to `generate`.
    Returns {"response": <full generated text>}.
    """
    import ast

    # SECURITY: eval() on client input executes arbitrary code;
    # literal_eval safely parses the serialized history list instead.
    history = ast.literal_eval(history)  # Convert history string back to a list

    # `generate` is a generator that yields the cumulative output after each
    # token; the original put the generator object itself into the JSON dict,
    # which is not serializable. Drain it and keep the final, complete text.
    response = ""
    for partial in generate(prompt, history, temperature, max_new_tokens, top_p, repetition_penalty):
        response = partial
    return {"response": response}
 
 
66
 
67
  app.mount("/", StaticFiles(directory="static", html=True), name="static")
68