File size: 3,470 Bytes
3fe712e
003aab4
3fe712e
 
57da3c7
 
bbf7982
57da3c7
 
 
3fe712e
57da3c7
 
 
 
3fe712e
 
 
 
 
57da3c7
 
 
003aab4
3fe712e
 
 
 
 
 
 
 
 
 
 
 
 
003aab4
85fc09e
003aab4
 
 
 
85fc09e
003aab4
57da3c7
003aab4
85fc09e
42e0859
bbf7982
 
 
 
003aab4
bbf7982
d74f4a3
bbf7982
003aab4
 
bbf7982
 
 
 
 
7949d6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bbf7982
5bf98dd
a19aaf0
5bf98dd
 
a19aaf0
5bf98dd
a19aaf0
7949d6d
a19aaf0
 
7949d6d
bbf7982
 
 
003aab4
 
 
57da3c7
 
003aab4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#load package
from fastapi import FastAPI
from pydantic import BaseModel
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer
)
from typing import List, Tuple
from threading import Thread
import os
from pydantic import BaseModel
import logging
import uvicorn


# Configurer les répertoires de cache
os.environ['TRANSFORMERS_CACHE'] = '/app/.cache'
os.environ['HF_HOME'] = '/app/.cache'
# Charger le modèle et le tokenizer
model = AutoModelForCausalLM.from_pretrained("THUDM/longwriter-glm4-9b", trust_remote_code=True, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained("THUDM/longwriter-glm4-9b", trust_remote_code=True)

 
#Additional information
 
Informations = """ 
-text : Texte à resumé 

output:
- Text summary : texte resumé
"""

app =FastAPI(
    title='Text Summary',
    description =Informations
)  
default_prompt = """Bonjour,

En tant qu’expert en gestion des plaintes réseaux, rédige un descriptif clair de la plainte ci-dessous. Résume la situation en 4 ou 5 phrases concises, en mettant l'accent sur :
1. **Informations Client** : Indique des détails pertinents sur le client.
2. **Dates et Délais** : Mentionne les dates clés et les délais (prise en charge, résolution, etc.).
3. **Contexte et Détails** : Inclut les éléments essentiels de la plainte (titre, détails, états d’avancement, qualification, fichiers joints).

Ajoute une recommandation importante pour éviter le mécontentement du client, par exemple, en cas de service non fourni malgré le paiement. Adapte le ton pour qu'il soit humain et engageant.

Merci !

"""
#class to define the input text 
logging.basicConfig(level=logging.INFO)
logger =logging.getLogger(__name__)
# Définir le modèle de requête
class PredictionRequest(BaseModel):
    text: str = None  # Texte personnalisé ajouté par l'utilisateur
    # max_length: int = 2000  # Limite la longueur maximale du texte généré

@app.post("/predict/")
async def predict(request: PredictionRequest):
    # Construire le prompt final
    if request.text:
        prompt = default_prompt + "\n\n" + request.text
    else:
        prompt = default_prompt
    # Assurez-vous que le pad_token est défini
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Définir une longueur maximale arbitraire pour la tokenization
    max_length = 1024  # Vous pouvez ajuster cette valeur selon vos besoins

    # Tokenize l'entrée sans troncation automatique
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=False,
        max_length=None  # Pas de longueur maximale pour la tokenization
    )
    
    # Tronquer manuellement si nécessaire
    if inputs.input_ids.shape[1] > max_length:
        inputs.input_ids = inputs.input_ids[:, :max_length]
        inputs.attention_mask = inputs.attention_mask[:, :max_length]

    input_ids = inputs.input_ids.to(model.device)
    attention_mask = inputs.attention_mask.to(model.device)

    # Générez le texte en passant l'attention mask
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=3000,  # Longueur maximale pour la génération
        do_sample=True
    )
   
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    return {"generated_text": generated_text}

if __name__ == "__main__":
    uvicorn.run("app:app",reload=True)