Marroco93 committed
Commit 9441c54
1 Parent(s): d0c61b6
Files changed (1)
  1. main.py +21 -31
main.py CHANGED
@@ -3,11 +3,11 @@ from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 from huggingface_hub import InferenceClient
 import uvicorn
-import json  # Make sure to import json
-
+from typing import Generator
 
 app = FastAPI()
 
+# Initialize the InferenceClient with your model
 client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
 
 class Item(BaseModel):
@@ -15,7 +15,7 @@ class Item(BaseModel):
     history: list
     system_prompt: str
     temperature: float = 0.0
-    max_new_tokens: int = 1048
+    max_new_tokens: int = 9000
     top_p: float = 0.15
     repetition_penalty: float = 1.0
 
@@ -27,40 +27,30 @@ def format_prompt(message, history):
     prompt += f"[INST] {message} [/INST]"
     return prompt
 
-import json  # Import the JSON module
-
-def generate(item: Item):
-    temperature = float(item.temperature)
-    if temperature < 1e-2:
-        temperature = 1e-2
-    top_p = float(item.top_p)
-
-    generate_kwargs = dict(
-        temperature=temperature,
-        max_new_tokens=item.max_new_tokens,
-        top_p=top_p,
-        repetition_penalty=item.repetition_penalty,
-        do_sample=True,
-        seed=42,
-    )
-
+def generate_stream(item: Item) -> Generator[bytes, None, None]:
     formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
-    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-
-    # Convert stream to a list to check if it's the last element
-    responses = list(stream)
-    for i, response in enumerate(responses):
-        # Prepare the chunk as a JSON object
+    generate_kwargs = {
+        "temperature": item.temperature,
+        "max_new_tokens": item.max_new_tokens,
+        "top_p": item.top_p,
+        "repetition_penalty": item.repetition_penalty,
+        "do_sample": True,
+        "seed": 42,  # Adjust or omit the seed as needed
+    }
+
+    # Stream the response from the InferenceClient
+    for response in client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True):
+        # This assumes 'details=True' gives you a structure where you can access the text like this
         chunk = {
             "text": response.token.text,
-            "complete": i == len(responses) - 1  # True if this is the last chunk
+            "complete": response.generated_text is not None  # Adjust based on how you detect completion
         }
-        # Yield the JSON-encoded string with a newline to separate chunks
         yield json.dumps(chunk).encode("utf-8") + b"\n"
 
 @app.post("/generate/")
 async def generate_text(item: Item):
-    return StreamingResponse(generate(item), media_type="application/x-ndjson")
-
-
+    # Stream response back to the client
+    return StreamingResponse(generate_stream(item), media_type="application/x-ndjson")
 
+if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", port=8000)
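
The endpoint now streams newline-delimited JSON (media type application/x-ndjson), so a caller can decode the response one line at a time. The following is a minimal client sketch, not part of this commit: it assumes the app is running locally via the uvicorn.run call added above, that json remains importable in main.py (the generator still calls json.dumps), and it uses the third-party requests library; the payload values are placeholders.

import json

import requests

# Hypothetical client for the /generate/ endpoint shown in this commit.
# Assumes the server was started with: uvicorn.run(app, host="0.0.0.0", port=8000)
payload = {
    "prompt": "Explain streaming responses in one sentence.",  # placeholder values
    "history": [],
    "system_prompt": "You are a helpful assistant.",
    "temperature": 0.7,
}

with requests.post("http://localhost:8000/generate/", json=payload, stream=True) as resp:
    resp.raise_for_status()
    # Each NDJSON line is one {"text": ..., "complete": ...} chunk yielded by generate_stream.
    for line in resp.iter_lines():
        if not line:
            continue
        chunk = json.loads(line)
        print(chunk["text"], end="", flush=True)
        if chunk["complete"]:
            break

Each decoded chunk carries the token text plus a complete flag, which the new code derives from response.generated_text, so the client loop can stop as soon as the final chunk arrives.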