Woziii committed on
Commit
0db8079
1 Parent(s): 117e81a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -7
app.py CHANGED
@@ -32,10 +32,11 @@ def generate_text(input_text, temperature, top_p, top_k):
32
  global model, tokenizer
33
 
34
  inputs = tokenizer(input_text, return_tensors="pt")
 
35
 
36
  with torch.no_grad():
37
  outputs = model.generate(
38
- **inputs,
39
  max_new_tokens=50,
40
  temperature=temperature,
41
  top_p=top_p,
@@ -46,9 +47,11 @@ def generate_text(input_text, temperature, top_p, top_k):
46
 
47
  generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
48
 
49
- # Extraire les attentions et les logits
50
- attentions = outputs.attentions[-1][0][-1].numpy()
51
- logits = outputs.scores[-1][0]
 
 
52
 
53
  # Visualiser l'attention
54
  plt.figure(figsize=(10, 10))
@@ -58,12 +61,13 @@ def generate_text(input_text, temperature, top_p, top_k):
58
  plt.close()
59
 
60
  # Obtenir les mots les plus probables
61
- probs = torch.nn.functional.softmax(logits, dim=-1)
62
- top_probs, top_indices = torch.topk(probs, k=5)
63
- top_words = [tokenizer.decode([idx]) for idx in top_indices]
64
 
65
  return generated_text, attention_plot, top_words
66
 
 
67
  def reset():
68
  return "", 1.0, 1.0, 50, None, None, None
69
 
 
32
  global model, tokenizer
33
 
34
  inputs = tokenizer(input_text, return_tensors="pt")
35
+ input_ids = inputs["input_ids"]
36
 
37
  with torch.no_grad():
38
  outputs = model.generate(
39
+ input_ids,
40
  max_new_tokens=50,
41
  temperature=temperature,
42
  top_p=top_p,
 
47
 
48
  generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
49
 
50
+ # Obtenir les logits pour le dernier token généré
51
+ last_token_logits = model(outputs.sequences[:, -1:]).logits[:, -1, :]
52
+
53
+ # Extraire les attentions
54
+ attentions = outputs.attentions[-1][-1].mean(dim=0).numpy()
55
 
56
  # Visualiser l'attention
57
  plt.figure(figsize=(10, 10))
 
61
  plt.close()
62
 
63
  # Obtenir les mots les plus probables
64
+ probs = torch.nn.functional.softmax(last_token_logits, dim=-1)
65
+ top_probs, top_indices = torch.topk(probs[0], k=5)
66
+ top_words = [tokenizer.decode([idx.item()]) for idx in top_indices]
67
 
68
  return generated_text, attention_plot, top_words
69
 
70
+
71
  def reset():
72
  return "", 1.0, 1.0, 50, None, None, None
73