bambadij committed
Commit 85fc09e · 1 Parent(s): da80063

change model to llm write

Files changed (1): app.py (+32 -30)
app.py CHANGED
@@ -64,27 +64,29 @@ Merci !
 
 """
 class PredictionRequest(BaseModel):
-    history: List[Tuple[str, str]] = []
-    prompt: str = ""
+    history: list
+    prompt: str
     max_length: int = 128000
     top_p: float = 0.8
     temperature: float = 0.6
-@app.post("/generate/")
-async def predict(request: PredictionRequest):
-    history = default_prompt
-    prompt = request.prompt
-    max_length = request.max_length
-    top_p = request.top_p
-    temperature = request.temperature
 
+class PredictionResponse(BaseModel):
+    history: list
+
+@app.get("/")
+async def home():
+    return 'STN BIG DATA'
+@app.post("/predict", response_model=PredictionResponse)
+async def predict(request: PredictionRequest):
     stop = StopOnTokens()
     messages = []
-    if prompt:
-        messages.append({"role": "system", "content": prompt})
-    for idx, (user_msg, model_msg) in enumerate(history):
-        if prompt and idx == 0:
+    query = ""
+
+    # Prepare the messages without including the default prompt in the returned history
+    for idx, (user_msg, model_msg) in enumerate(request.history):
+        if idx == 0 and request.prompt:  # Skip the default prompt in the history
             continue
-        if idx == len(history) - 1 and not model_msg:
+        if idx == len(request.history) - 1 and not model_msg:
             query = user_msg
             break
         if user_msg:
@@ -92,35 +94,35 @@ async def predict(request: PredictionRequest):
         if model_msg:
             messages.append({"role": "assistant", "content": model_msg})
 
+    # Include the prompt only for generation, but not in the history
     model_inputs = tokenizer.build_chat_input(query, history=messages, role='user').input_ids.to(
-        next(model.parameters()).device)
-    streamer = TextIteratorStreamer(tokenizer, timeout=600, skip_prompt=True)
+        next(model.parameters()).device
+    )
+
     eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
                     tokenizer.get_command("<|observation|>")]
+
     generate_kwargs = {
         "input_ids": model_inputs,
-        "streamer": streamer,
-        "max_new_tokens": max_length,
+        "max_new_tokens": request.max_length,
         "do_sample": True,
-        "top_p": top_p,
-        "temperature": temperature,
+        "top_p": request.top_p,
+        "temperature": request.temperature,
         "stopping_criteria": StoppingCriteriaList([stop]),
         "repetition_penalty": 1,
         "eos_token_id": eos_token_id,
     }
 
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
+    # Generate the text
+    output = model.generate(**generate_kwargs)
+    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
+
+    # Append the generated text to the history
+    request.history[-1][1] += generated_text
 
-    generated_text = ""
-    for new_token in streamer:
-        if new_token and '<|user|>' in new_token:
-            new_token = new_token.split('<|user|>')[0]
-        if new_token:
-            generated_text += new_token
-            history[-1][1] = generated_text
+    # Return the history without the prompt
+    return PredictionResponse(history=request.history)
 
-    return {"history": history}
 if __name__ == "__main__":
     uvicorn.run("app:app",reload=True)
 
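
For reference, a minimal client sketch for the reworked endpoint, assuming the app is served locally by uvicorn on its default port; the URL, the sample messages, and the exact history layout (a list of [user, assistant] pairs whose last entry has an empty assistant slot) are assumptions inferred from how predict() walks request.history.

import requests

payload = {
    # idx 0 is skipped by predict() when "prompt" is set, so the first entry
    # only reserves that slot; the last entry's empty second element receives
    # the generated reply.
    "history": [
        ["", ""],
        ["Can you summarize this support ticket for me?", ""],
    ],
    "prompt": "You are a support assistant.",  # hypothetical system prompt
    "max_length": 1024,
    "top_p": 0.8,
    "temperature": 0.6,
}

# Generous timeout: the endpoint now blocks until model.generate() finishes
# instead of streaming tokens.
resp = requests.post("http://127.0.0.1:8000/predict", json=payload, timeout=600)
resp.raise_for_status()
print(resp.json()["history"][-1][1])

Since the streaming path (TextIteratorStreamer plus a background Thread) was dropped, the whole reply now arrives in a single response; the 600-second client budget above simply mirrors the timeout the old streamer used.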