Marroco93 committed on
Commit d0435f3
1 Parent(s): 22a4b4f

no message

Files changed (1)
  1. main.py +11 -30
main.py CHANGED
@@ -20,49 +20,34 @@ nltk.data.path.append(os.getenv('NLTK_DATA'))
 app = FastAPI()
 
 # Initialize the InferenceClient with your model
-client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.2")
-
-
-
-# summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")
-
+client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
 
 class Item(BaseModel):
     prompt: str
     history: list
     system_prompt: str
     temperature: float = 0.8
-    max_new_tokens: int = 12000
+    max_new_tokens: int = 9000
     top_p: float = 0.15
     repetition_penalty: float = 1.0
 
-def format_prompt(current_prompt, history):
-    formatted_history = "<s>"
-    for entry in history:
-        if entry["role"] == "user":
-            formatted_history += f"[USER] {entry['content']} [/USER]"
-        elif entry["role"] == "assistant":
-            formatted_history += f"[ASSISTANT] {entry['content']} [/ASSISTANT]"
-    formatted_history += f"[USER] {current_prompt} [/USER]</s>"
-    return formatted_history
-
+def format_prompt(message, history):
+    prompt = "<s>"
+    for user_prompt, bot_response in history:
+        prompt += f"[INST] {user_prompt} [/INST]"
+        prompt += f" {bot_response}</s> "
+    prompt += f"[INST] {message} [/INST]"
+    return prompt
 
 def generate_stream(item: Item) -> Generator[bytes, None, None]:
     formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
-    # Estimate token count for the formatted_prompt
-    input_token_count = len(nltk.word_tokenize(formatted_prompt))  # NLTK tokenization
-
-    # Ensure total token count doesn't exceed the maximum limit
-    max_tokens_allowed = 32768
-    max_new_tokens_adjusted = max(1, min(item.max_new_tokens, max_tokens_allowed - input_token_count))
-
     generate_kwargs = {
         "temperature": item.temperature,
-        "max_new_tokens": max_new_tokens_adjusted,
+        "max_new_tokens": item.max_new_tokens,
         "top_p": item.top_p,
         "repetition_penalty": item.repetition_penalty,
         "do_sample": True,
-        "seed": 42,
+        "seed": 42,  # Adjust or omit the seed as needed
     }
 
     # Stream the response from the InferenceClient
@@ -74,10 +59,6 @@ def generate_stream(item: Item) -> Generator[bytes, None, None]:
         }
         yield json.dumps(chunk).encode("utf-8") + b"\n"
 
-
-class SummarizeRequest(BaseModel):
-    text: str
-
 @app.post("/generate/")
 async def generate_text(item: Item):
     # Stream response back to the client
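For reference, this standalone sketch copies the committed format_prompt and feeds it a made-up history to show the Mixtral-style [INST] prompt it builds. It assumes history now arrives as a list of (user, assistant) pairs, which is what the new loop unpacks, rather than the role/content dicts the removed version iterated over; the sample messages are illustrative only.

def format_prompt(message, history):
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt

# Hypothetical single-turn history in the (user, assistant) tuple form the new code expects
history = [("Hello", "Hi, how can I help?")]
print(format_prompt("You are a helpful assistant, What is streaming?", history))
# -> <s>[INST] Hello [/INST] Hi, how can I help?</s> [INST] You are a helpful assistant, What is streaming? [/INST]

Note that generate_stream still builds the message as f"{item.system_prompt}, {item.prompt}", so the system prompt is folded into the same [INST] block as the user prompt.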
 
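Since generate_stream yields json.dumps(chunk).encode("utf-8") + b"\n" for each chunk, the /generate/ endpoint effectively streams newline-delimited JSON. A minimal client sketch follows, assuming the app is served locally on port 8000 and using the requests library; the URL, port, and sample payload values are assumptions, while the field names come from the Item model in the diff above.

import json
import requests

payload = {
    "prompt": "Explain what this endpoint does.",
    "history": [],  # list of (user, assistant) pairs, per the new format_prompt
    "system_prompt": "You are a concise assistant.",
    # temperature, max_new_tokens, top_p, repetition_penalty fall back to the Item defaults
}

with requests.post("http://localhost:8000/generate/", json=payload, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        if line:  # skip keep-alive blank lines
            chunk = json.loads(line)  # one JSON object per streamed line
            print(chunk)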