omkar56 commited on
Commit
e9448a1
1 Parent(s): 58ef054

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +19 -5
main.py CHANGED
@@ -22,13 +22,17 @@ def format_prompt(message, history):
22
  @app.post("/api/v1/generate_text")
23
  async def generate_text(request: Request, prompt: str = Body()):
24
  history = [] # You might need to handle this based on your actual usage
25
-
26
  temperature = request.headers.get("temperature", 0.9)
 
27
  top_p = request.headers.get("top_p", 0.95)
 
28
  repetition_penalty = request.headers.get("repetition_penalty", 1.0)
 
29
 
30
  formatted_prompt = format_prompt(prompt, history)
31
- response = client.text_generation(
 
32
  formatted_prompt,
33
  temperature=temperature,
34
  max_new_tokens=512,
@@ -36,6 +40,16 @@ async def generate_text(request: Request, prompt: str = Body()):
36
  repetition_penalty=repetition_penalty,
37
  do_sample=True,
38
  seed=random.randint(0, 10**7),
39
- )[0]
40
-
41
- return response
 
 
 
 
 
 
 
 
 
 
 
22
  @app.post("/api/v1/generate_text")
23
  async def generate_text(request: Request, prompt: str = Body()):
24
  history = [] # You might need to handle this based on your actual usage
25
+ print(f"prompt + {prompt}")
26
  temperature = request.headers.get("temperature", 0.9)
27
+ print(f"temperature + {temperature}")
28
  top_p = request.headers.get("top_p", 0.95)
29
+ print(f"top_p + {top_p}")
30
  repetition_penalty = request.headers.get("repetition_penalty", 1.0)
31
+ print(f"repetition_penalty + {repetition_penalty}")
32
 
33
  formatted_prompt = format_prompt(prompt, history)
34
+ print(f"formatted_prompt + {formatted_prompt}")
35
+ stream = client.text_generation(
36
  formatted_prompt,
37
  temperature=temperature,
38
  max_new_tokens=512,
 
40
  repetition_penalty=repetition_penalty,
41
  do_sample=True,
42
  seed=random.randint(0, 10**7),
43
+ stream=True,
44
+ details=True,
45
+ return_full_text=False
46
+ )
47
+ output = ""
48
+
49
+ for response in stream:
50
+ output += response.token.text
51
+ yield output
52
+ print(f"output + {output}")
53
+ return output
54
+
55
+ # return response