bambadij commited on
Commit
42e0859
·
1 Parent(s): 85fc09e
Files changed (1) hide show
  1. app.py +38 -75
app.py CHANGED
@@ -1,5 +1,5 @@
1
  #load package
2
- from fastapi import FastAPI
3
  from pydantic import BaseModel
4
  import torch
5
  from transformers import (
@@ -43,88 +43,51 @@ app =FastAPI(
43
  logging.basicConfig(level=logging.INFO)
44
  logger =logging.getLogger(__name__)
45
 
46
- class StopOnTokens(StoppingCriteria):
47
- def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
48
- stop_ids = model.config.eos_token_id
49
- for stop_id in stop_ids:
50
- if input_ids[0][-1] == stop_id:
51
- return True
52
- return False
53
-
54
- default_prompt = """Bonjour,
55
-
56
- En tant qu’expert en gestion des plaintes réseaux, rédige un descriptif clair de la plainte ci-dessous. Résume la situation en 4 ou 5 phrases concises, en mettant l'accent sur :
57
- 1. **Informations Client** : Indique des détails pertinents sur le client.
58
- 2. **Dates et Délais** : Mentionne les dates clés et les délais (prise en charge, résolution, etc.).
59
- 3. **Contexte et Détails** : Inclut les éléments essentiels de la plainte (titre, détails, états d’avancement, qualification, fichiers joints).
60
-
61
- Ajoute une recommandation importante pour éviter le mécontentement du client, par exemple, en cas de service non fourni malgré le paiement. Adapte le ton pour qu'il soit humain et engageant.
62
-
63
- Merci !
64
-
65
- """
66
- class PredictionRequest(BaseModel):
67
- history: list
68
- prompt: str
69
- max_length: int = 128000
70
- top_p: float = 0.8
71
- temperature: float = 0.6
72
-
73
- class PredictionResponse(BaseModel):
74
- history: list
75
-
76
  @app.get("/")
77
  async def home():
78
  return 'STN BIG DATA'
79
- @app.post("/predict", response_model=PredictionResponse)
80
- async def predict(request: PredictionRequest):
81
- stop = StopOnTokens()
82
- messages = []
83
- query = ""
84
-
85
- # Préparer les messages sans inclure le prompt par défaut dans l'historique renvoyé
86
- for idx, (user_msg, model_msg) in enumerate(request.history):
87
- if idx == 0 and request.prompt: # Ignorer le prompt par défaut dans l'historique
88
- continue
89
- if idx == len(request.history) - 1 and not model_msg:
90
- query = user_msg
91
- break
92
- if user_msg:
93
- messages.append({"role": "user", "content": user_msg})
94
- if model_msg:
95
- messages.append({"role": "assistant", "content": model_msg})
96
 
97
- # Inclure le prompt uniquement pour la génération, mais pas dans l'historique
98
- model_inputs = tokenizer.build_chat_input(query, history=messages, role='user').input_ids.to(
99
- next(model.parameters()).device
100
- )
101
 
102
- eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
103
- tokenizer.get_command("<|observation|>")]
104
 
105
- generate_kwargs = {
106
- "input_ids": model_inputs,
107
- "max_new_tokens": request.max_length,
108
- "do_sample": True,
109
- "top_p": request.top_p,
110
- "temperature": request.temperature,
111
- "stopping_criteria": StoppingCriteriaList([stop]),
112
- "repetition_penalty": 1,
113
- "eos_token_id": eos_token_id,
114
- }
115
 
116
- # Générer le texte
117
- output = model.generate(**generate_kwargs)
118
- generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
119
 
120
- # Ajouter le texte généré à l'historique
121
- request.history[-1][1] += generated_text
122
 
123
- # Retourner l'historique sans le prompt
124
- return PredictionResponse(history=request.history)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
  if __name__ == "__main__":
127
- uvicorn.run("app:app",reload=True)
128
-
129
-
130
-
 
1
  #load package
2
+ from fastapi import FastAPI,HTTPException
3
  from pydantic import BaseModel
4
  import torch
5
  from transformers import (
 
43
  logging.basicConfig(level=logging.INFO)
44
  logger =logging.getLogger(__name__)
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  @app.get("/")
47
  async def home():
48
  return 'STN BIG DATA'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
+ # Charger le modèle et le tokenizer
51
+ model_name = "THUDM/longwriter-glm4-9b"
52
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
53
+ model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, device_map="auto")
54
 
55
+ # Prompt par défaut
56
+ default_prompt = """Vous êtes un assistant expert en résumé de plaintes. Votre tâche est de résumer la plainte fournie de manière concise et professionnelle, en incluant les points clés suivants :
57
 
58
+ 1. Le problème principal
59
+ 2. Les détails pertinents
60
+ 3. L'impact sur le plaignant
61
+ 4. Toute action ou résolution demandée
 
 
 
 
 
 
62
 
63
+ Résumez la plainte suivante en 3-4 phrases :
 
 
64
 
65
+ """
 
66
 
67
+ class ComplaintInput(BaseModel):
68
+ text: str
69
+
70
+ @app.post("/summarize_complaint")
71
+ async def summarize_complaint(input: ComplaintInput):
72
+ try:
73
+ full_prompt = default_prompt + input.text
74
+ inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
75
+
76
+ with torch.no_grad():
77
+ outputs = model.generate(
78
+ **inputs,
79
+ max_new_tokens=150,
80
+ num_return_sequences=1,
81
+ no_repeat_ngram_size=2,
82
+ temperature=0.7
83
+ )
84
+
85
+ summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
86
+ # Enlever le prompt initial de la sortie
87
+ summary = summary.replace(full_prompt, "").strip()
88
+ return {"summary": summary}
89
+ except Exception as e:
90
+ raise HTTPException(status_code=500, detail=str(e))
91
 
92
  if __name__ == "__main__":
93
+ uvicorn.run("app:app",reload=True)