change model to llm write
app.py CHANGED
@@ -64,27 +64,29 @@ Merci !
 
 """
 class PredictionRequest(BaseModel):
-    history:
-    prompt: str
+    history: list
+    prompt: str
     max_length: int = 128000
     top_p: float = 0.8
     temperature: float = 0.6
-@app.post("/generate/")
-async def predict(request: PredictionRequest):
-    history = default_prompt
-    prompt = request.prompt
-    max_length = request.max_length
-    top_p = request.top_p
-    temperature = request.temperature
 
+class PredictionResponse(BaseModel):
+    history: list
+
+@app.get("/")
+async def home():
+    return 'STN BIG DATA'
+@app.post("/predict", response_model=PredictionResponse)
+async def predict(request: PredictionRequest):
     stop = StopOnTokens()
     messages = []
-
-
-
-
+    query = ""
+
+    # Prepare the messages without including the default prompt in the returned history
+    for idx, (user_msg, model_msg) in enumerate(request.history):
+        if idx == 0 and request.prompt:  # skip the default prompt in the history
             continue
-        if idx == len(history) - 1 and not model_msg:
+        if idx == len(request.history) - 1 and not model_msg:
             query = user_msg
             break
         if user_msg:
@@ -92,35 +94,35 @@ async def predict(request: PredictionRequest):
         if model_msg:
             messages.append({"role": "assistant", "content": model_msg})
 
+    # Include the prompt only for generation, not in the history
     model_inputs = tokenizer.build_chat_input(query, history=messages, role='user').input_ids.to(
-        next(model.parameters()).device
-
+        next(model.parameters()).device
+    )
+
     eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
                     tokenizer.get_command("<|observation|>")]
+
     generate_kwargs = {
         "input_ids": model_inputs,
-        "
-        "max_new_tokens": max_length,
+        "max_new_tokens": request.max_length,
        "do_sample": True,
-        "top_p": top_p,
-        "temperature": temperature,
+        "top_p": request.top_p,
+        "temperature": request.temperature,
         "stopping_criteria": StoppingCriteriaList([stop]),
         "repetition_penalty": 1,
         "eos_token_id": eos_token_id,
     }
 
-
-
+    # Generate the text
+    output = model.generate(**generate_kwargs)
+    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
+
+    # Append the generated text to the history
+    request.history[-1][1] += generated_text
 
-
-
-            if new_token and '<|user|>' in new_token:
-                new_token = new_token.split('<|user|>')[0]
-            if new_token:
-                generated_text += new_token
-                history[-1][1] = generated_text
+    # Return the history without the prompt
+    return PredictionResponse(history=request.history)
 
-    return {"history": history}
 if __name__ == "__main__":
     uvicorn.run("app:app",reload=True)
 
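For quick manual testing of the new route, a minimal client sketch is shown below. It assumes the app is served with uvicorn's defaults (http://127.0.0.1:8000); the URL, message text, and parameter values are illustrative and not part of this commit.

# Minimal client sketch for the new /predict endpoint (assumes uvicorn's
# default host/port; payload values are only examples).
import requests

payload = {
    # history is a list of [user_message, model_message] pairs; the last pair
    # holds the pending user turn with an empty model slot that the server
    # appends the generated text to and returns.
    "history": [["Hello, what can you do?", ""]],
    "prompt": "",          # when non-empty, the first history entry is skipped while building the chat input
    "max_length": 1024,    # forwarded to generate() as max_new_tokens
    "top_p": 0.8,
    "temperature": 0.6,
}

resp = requests.post("http://127.0.0.1:8000/predict", json=payload, timeout=600)
resp.raise_for_status()
print(resp.json()["history"][-1][1])  # model reply filled in by the server

The response model only echoes history, so the default prompt itself is never returned to the client, matching the comments added in this change.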