Woziii commited on
Commit
2deee43
1 Parent(s): cdbe4f9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -21
app.py CHANGED
@@ -7,6 +7,7 @@ import matplotlib.pyplot as plt
7
  import seaborn as sns
8
  import numpy as np
9
  import time
 
10
 
11
  # Authentification
12
  login(token=os.environ["HF_TOKEN"])
@@ -28,6 +29,23 @@ models = [
28
  "croissantllm/CroissantLLMBase"
29
  ]
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  # Variables globales
32
  model = None
33
  tokenizer = None
@@ -38,14 +56,33 @@ def load_model(model_name, progress=gr.Progress()):
38
  progress(0, desc="Chargement du tokenizer")
39
  tokenizer = AutoTokenizer.from_pretrained(model_name)
40
  progress(0.5, desc="Chargement du modèle")
41
- model = AutoModelForCausalLM.from_pretrained(
42
- model_name,
43
- torch_dtype=torch.float32,
44
- device_map="cpu",
45
- attn_implementation="eager"
46
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  if tokenizer.pad_token is None:
48
  tokenizer.pad_token = tokenizer.eos_token
 
49
  progress(1.0, desc="Modèle chargé")
50
  return f"Modèle {model_name} chargé avec succès."
51
  except Exception as e:
@@ -63,18 +100,24 @@ def analyze_next_token(input_text, temperature, top_p, top_k):
63
  if model is None or tokenizer is None:
64
  return "Veuillez d'abord charger un modèle.", None, None
65
 
66
- inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
 
 
 
 
 
67
 
68
  try:
69
  with torch.no_grad():
70
  outputs = model(**inputs)
71
 
72
  last_token_logits = outputs.logits[0, -1, :]
73
- probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
74
 
75
- top_k = 10
76
  top_probs, top_indices = torch.topk(probabilities, top_k)
77
  top_words = [ensure_token_display(tokenizer.decode([idx.item()])) for idx in top_indices]
 
78
  prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
79
 
80
  prob_text = "Prochains tokens les plus probables :\n\n"
@@ -94,17 +137,22 @@ def generate_text(input_text, temperature, top_p, top_k):
94
  if model is None or tokenizer is None:
95
  return "Veuillez d'abord charger un modèle."
96
 
97
- inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
 
 
 
 
 
98
 
99
  try:
100
- with torch.no_grad():
101
- outputs = model.generate(
102
- **inputs,
103
- max_new_tokens=1,
104
- temperature=temperature,
105
- top_p=top_p,
106
- top_k=top_k
107
- )
108
 
109
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
110
  return generated_text
@@ -139,7 +187,7 @@ def plot_attention(input_ids, last_token_logits):
139
  top_attention_scores, _ = torch.topk(attention_scores, top_k)
140
 
141
  fig, ax = plt.subplots(figsize=(14, 7))
142
- sns.heatmap(top_attention_scores.unsqueeze(0).numpy(), annot=True, cmap="YlOrRd", cbar=True, ax=ax, fmt='.2%')
143
  ax.set_xticklabels(input_tokens[-top_k:], rotation=45, ha="right", fontsize=10)
144
  ax.set_yticklabels(["Attention"], rotation=0, fontsize=10)
145
  ax.set_title("Scores d'attention pour les derniers tokens", fontsize=16)
@@ -158,7 +206,7 @@ def reset():
158
  return "", 1.0, 1.0, 50, None, None, None, None
159
 
160
  with gr.Blocks() as demo:
161
- gr.Markdown("# Analyse et génération de texte")
162
 
163
  with gr.Accordion("Sélection du modèle"):
164
  model_dropdown = gr.Dropdown(choices=models, label="Choisissez un modèle")
@@ -179,7 +227,7 @@ with gr.Blocks() as demo:
179
  attention_plot = gr.Plot(label="Visualisation de l'attention")
180
  prob_plot = gr.Plot(label="Probabilités des tokens suivants")
181
 
182
- generate_button = gr.Button("Générer le prochain mot")
183
  generated_text = gr.Textbox(label="Texte généré")
184
 
185
  reset_button = gr.Button("Réinitialiser")
 
7
  import seaborn as sns
8
  import numpy as np
9
  import time
10
+ from langdetect import detect
11
 
12
  # Authentification
13
  login(token=os.environ["HF_TOKEN"])
 
29
  "croissantllm/CroissantLLMBase"
30
  ]
31
 
32
+ # Dictionnaire des langues supportées par modèle
33
+ model_languages = {
34
+ "meta-llama/Llama-2-13b-hf": ["en"],
35
+ "meta-llama/Llama-2-7b-hf": ["en"],
36
+ "meta-llama/Llama-2-70b-hf": ["en"],
37
+ "meta-llama/Meta-Llama-3-8B": ["en"],
38
+ "meta-llama/Llama-3.2-3B": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
39
+ "meta-llama/Llama-3.1-8B": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
40
+ "mistralai/Mistral-7B-v0.1": ["en"],
41
+ "mistralai/Mixtral-8x7B-v0.1": ["en", "fr", "it", "de", "es"],
42
+ "mistralai/Mistral-7B-v0.3": ["en"],
43
+ "google/gemma-2-2b": ["en"],
44
+ "google/gemma-2-9b": ["en"],
45
+ "google/gemma-2-27b": ["en"],
46
+ "croissantllm/CroissantLLMBase": ["en", "fr"]
47
+ }
48
+
49
  # Variables globales
50
  model = None
51
  tokenizer = None
 
56
  progress(0, desc="Chargement du tokenizer")
57
  tokenizer = AutoTokenizer.from_pretrained(model_name)
58
  progress(0.5, desc="Chargement du modèle")
59
+
60
+ # Configurations spécifiques par modèle
61
+ if "mixtral" in model_name.lower():
62
+ model = AutoModelForCausalLM.from_pretrained(
63
+ model_name,
64
+ torch_dtype=torch.float16,
65
+ device_map="auto",
66
+ attn_implementation="flash_attention_2",
67
+ load_in_8bit=True
68
+ )
69
+ elif "llama" in model_name.lower() or "mistral" in model_name.lower():
70
+ model = AutoModelForCausalLM.from_pretrained(
71
+ model_name,
72
+ torch_dtype=torch.float16,
73
+ device_map="auto",
74
+ attn_implementation="flash_attention_2"
75
+ )
76
+ else:
77
+ model = AutoModelForCausalLM.from_pretrained(
78
+ model_name,
79
+ torch_dtype=torch.float16,
80
+ device_map="auto"
81
+ )
82
+
83
  if tokenizer.pad_token is None:
84
  tokenizer.pad_token = tokenizer.eos_token
85
+
86
  progress(1.0, desc="Modèle chargé")
87
  return f"Modèle {model_name} chargé avec succès."
88
  except Exception as e:
 
100
  if model is None or tokenizer is None:
101
  return "Veuillez d'abord charger un modèle.", None, None
102
 
103
+ # Détection de la langue
104
+ detected_lang = detect(input_text)
105
+ if detected_lang not in model_languages.get(model.config._name_or_path, []):
106
+ return f"Langue détectée ({detected_lang}) non supportée par ce modèle.", None, None
107
+
108
+ inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
109
 
110
  try:
111
  with torch.no_grad():
112
  outputs = model(**inputs)
113
 
114
  last_token_logits = outputs.logits[0, -1, :]
115
+ probabilities = torch.nn.functional.softmax(last_token_logits / temperature, dim=-1)
116
 
117
+ top_k = min(top_k, probabilities.size(-1))
118
  top_probs, top_indices = torch.topk(probabilities, top_k)
119
  top_words = [ensure_token_display(tokenizer.decode([idx.item()])) for idx in top_indices]
120
+
121
  prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
122
 
123
  prob_text = "Prochains tokens les plus probables :\n\n"
 
137
  if model is None or tokenizer is None:
138
  return "Veuillez d'abord charger un modèle."
139
 
140
+ # Détection de la langue
141
+ detected_lang = detect(input_text)
142
+ if detected_lang not in model_languages.get(model.config._name_or_path, []):
143
+ return f"Langue détectée ({detected_lang}) non supportée par ce modèle."
144
+
145
+ inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
146
 
147
  try:
148
+ outputs = model.generate(
149
+ **inputs,
150
+ max_new_tokens=50,
151
+ do_sample=True,
152
+ temperature=temperature,
153
+ top_p=top_p,
154
+ top_k=top_k
155
+ )
156
 
157
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
158
  return generated_text
 
187
  top_attention_scores, _ = torch.topk(attention_scores, top_k)
188
 
189
  fig, ax = plt.subplots(figsize=(14, 7))
190
+ sns.heatmap(top_attention_scores.unsqueeze(0).cpu().numpy(), annot=True, cmap="YlOrRd", cbar=True, ax=ax, fmt='.2%')
191
  ax.set_xticklabels(input_tokens[-top_k:], rotation=45, ha="right", fontsize=10)
192
  ax.set_yticklabels(["Attention"], rotation=0, fontsize=10)
193
  ax.set_title("Scores d'attention pour les derniers tokens", fontsize=16)
 
206
  return "", 1.0, 1.0, 50, None, None, None, None
207
 
208
  with gr.Blocks() as demo:
209
+ gr.Markdown("# Analyse et génération de texte avec LLM")
210
 
211
  with gr.Accordion("Sélection du modèle"):
212
  model_dropdown = gr.Dropdown(choices=models, label="Choisissez un modèle")
 
227
  attention_plot = gr.Plot(label="Visualisation de l'attention")
228
  prob_plot = gr.Plot(label="Probabilités des tokens suivants")
229
 
230
+ generate_button = gr.Button("Générer la suite du texte")
231
  generated_text = gr.Textbox(label="Texte généré")
232
 
233
  reset_button = gr.Button("Réinitialiser")