Marroco93 committed
Commit bc5e3f5
1 Parent(s): 0f34bf3
Files changed (2)
  1. main.py +15 -10
  2. requirements.txt +2 -1
main.py CHANGED
@@ -5,11 +5,13 @@ from huggingface_hub import InferenceClient
 import uvicorn
 from typing import Generator
 import json  # Make sure this line is at the top of the file
+import torch
+
 
 app = FastAPI()
 
 # Initialize the InferenceClient with your model
-client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
+client = InferenceClient("meta-llama/Llama-2-7b-chat")
 
 class Item(BaseModel):
     prompt: str
@@ -21,12 +23,16 @@ class Item(BaseModel):
     repetition_penalty: float = 1.0
 
 def format_prompt(message, history):
-    prompt = "<s>"
+    # Simple structure: alternating lines of dialogue, no special tokens unless specified by the model documentation
+    conversation = ""
     for user_prompt, bot_response in history:
-        prompt += f"[INST] {user_prompt} [/INST]"
-        prompt += f" {bot_response}</s> "
-    prompt += f"[INST] {message} [/INST]"
-    return prompt
+        conversation += f"User: {user_prompt}\nBot: {bot_response}\n"
+    conversation += f"User: {message}"
+    return conversation
+
+
+
+# No changes needed in the format_prompt function unless the new model requires different prompt formatting
 
 def generate_stream(item: Item) -> Generator[bytes, None, None]:
     formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
@@ -41,17 +47,16 @@ def generate_stream(item: Item) -> Generator[bytes, None, None]:
 
     # Stream the response from the InferenceClient
     for response in client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True):
-        # This assumes 'details=True' gives you a structure where you can access the text like this
+        # Check if the 'details' flag and response structure are the same for the new model
         chunk = {
            "text": response.token.text,
-            "complete": response.generated_text is not None  # Adjust based on how you detect completion
+            "complete": response.generated_text is not None
        }
        yield json.dumps(chunk).encode("utf-8") + b"\n"
 
 @app.post("/generate/")
 async def generate_text(item: Item):
-    # Stream response back to the client
     return StreamingResponse(generate_stream(item), media_type="application/x-ndjson")
 
 if __name__ == "__main__":
-    uvicorn.run(app, host="0.0.0.0", port=8000)
+    uvicorn.run(app, host="0.0.0.0", port=8000)
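The central change in this commit swaps the Mixtral-style [INST] ... [/INST] prompt for a plain User:/Bot: dialogue string. As an illustration only (not part of the commit), the snippet below copies the revised format_prompt from the diff above and prints what it returns for a made-up one-turn history:

# Illustration only: the dialogue string produced by the new format_prompt.
# The function body is copied from main.py above; the sample history and
# message are invented for this example.
def format_prompt(message, history):
    conversation = ""
    for user_prompt, bot_response in history:
        conversation += f"User: {user_prompt}\nBot: {bot_response}\n"
    conversation += f"User: {message}"
    return conversation

history = [("Hi", "Hello! How can I help?")]
print(format_prompt("Summarize NDJSON in one line.", history))
# Prints:
# User: Hi
# Bot: Hello! How can I help?
# User: Summarize NDJSON in one line.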
requirements.txt CHANGED
@@ -1,4 +1,5 @@
 fastapi
 uvicorn
 huggingface_hub
-pydantic
+pydantic
+torch==2.0.0
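With the dependencies above installed and main.py running, the /generate/ endpoint streams newline-delimited JSON. A minimal client-side sketch of consuming that stream follows; the host, port, and payload values are assumptions for illustration, the full set of Item fields is not visible in this diff, and the requests library used here is not listed in requirements.txt:

# Minimal sketch: read the application/x-ndjson stream from /generate/.
# Assumes the FastAPI app above is running on localhost:8000; the payload
# fields mirror what generate_stream references (prompt, system_prompt,
# history), but Item may define other fields not shown in the diff.
import json
import requests

payload = {
    "prompt": "Explain streaming responses in one sentence.",
    "system_prompt": "You are a concise assistant.",
    "history": [],  # list of (user, bot) turns, as format_prompt expects
}

with requests.post("http://localhost:8000/generate/", json=payload, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():  # one JSON object per line
        if not line:
            continue
        chunk = json.loads(line)
        print(chunk["text"], end="", flush=True)
        if chunk["complete"]:
            break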