Marroco93 committed
Commit 4cc4589 · 1 Parent(s): 02ccf98
Files changed (1):
  1. main.py +9 -2
main.py CHANGED
@@ -33,13 +33,20 @@ def format_prompt(current_prompt, history):
 
 def generate_stream(item: Item) -> Generator[bytes, None, None]:
     formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
+    # Estimate token count for the formatted_prompt
+    input_token_count = len(formatted_prompt.split())  # Simple whitespace tokenization, adjust if necessary
+
+    # Ensure total token count doesn't exceed the maximum limit
+    max_tokens_allowed = 32768
+    max_new_tokens_adjusted = max(1, min(item.max_new_tokens, max_tokens_allowed - input_token_count))
+
     generate_kwargs = {
         "temperature": item.temperature,
-        "max_new_tokens": item.max_new_tokens,
+        "max_new_tokens": max_new_tokens_adjusted,
         "top_p": item.top_p,
         "repetition_penalty": item.repetition_penalty,
         "do_sample": True,
-        "seed": 42,  # Adjust or omit the seed as needed
+        "seed": 42,
     }
 
     # Stream the response from the InferenceClient
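For reference, the token-budget clamp introduced in this commit can be exercised in isolation. The sketch below mirrors the names and the 32768 limit from the diff and keeps the same rough whitespace-based token estimate; the helper name clamp_max_new_tokens and the sample prompt are illustrative assumptions, not part of main.py.

# Minimal sketch (not part of main.py) of the token-budget clamp this commit adds.
# The 32768 limit and whitespace-based estimate mirror the diff; the helper name
# clamp_max_new_tokens is hypothetical.

def clamp_max_new_tokens(formatted_prompt: str,
                         requested_max_new_tokens: int,
                         max_tokens_allowed: int = 32768) -> int:
    # Rough input-token estimate via whitespace splitting, as in the commit.
    input_token_count = len(formatted_prompt.split())
    # Cap new tokens so prompt + generation stays within the limit,
    # but never request fewer than 1 new token.
    return max(1, min(requested_max_new_tokens,
                      max_tokens_allowed - input_token_count))

# Example: a short prompt leaves ample headroom, so the requested value passes through.
print(clamp_max_new_tokens("You are a helpful assistant, summarize this text.", 1024))  # -> 1024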