davda54 committed on
Commit
f8494a3
·
1 Parent(s): fe94d2e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -2
app.py CHANGED
@@ -73,6 +73,18 @@ class BatchStreamer(TextIteratorStreamer):
73
  self.on_finalized_text(printable_text, stream_end=True)
74
 
75
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  def translate(source, source_language, target_language):
77
  if source_language == target_language:
78
  yield source.strip()
@@ -97,10 +109,17 @@ def translate(source, source_language, target_language):
97
  input_ids=source_subwords,
98
  attention_mask=(source_subwords != pad_index).long(),
99
  max_new_tokens = 512-1,
 
 
 
 
 
 
 
100
  # num_beams=4,
101
  # early_stopping=True,
102
- do_sample=False,
103
- use_cache=True
104
  )
105
  t = Thread(target=generate, args=(model,), kwargs=generate_kwargs)
106
  t.start()
 
73
  self.on_finalized_text(printable_text, stream_end=True)
74
 
75
 
76
class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
    """Additively penalize the logits of tokens already present in ``input_ids``.

    The penalty applied to each previously generated token is scaled by the
    model's output-layer bias, interpreted as a token log-prior: tokens whose
    bias is close to the maximum (common tokens) receive almost no penalty,
    while rare tokens are penalized more strongly. The penalty tensor is
    non-positive, so repeated tokens can only lose score.
    """

    def __init__(self, penalty: float, model):
        """Precompute the per-token penalty vector.

        Args:
            penalty: Non-negative scale factor for the repetition penalty
                (0.0 disables it).
            model: The generation model. NOTE(review): assumes
                ``model.classifier.nonlinearity[-1]`` is the final output
                Linear layer whose bias acts as token log-priors — confirm
                against the model definition.
        """
        last_bias = model.classifier.nonlinearity[-1].bias.data
        # Explicit dim: calling log_softmax without `dim` is deprecated in
        # PyTorch (implicit dim inference); for a 1-D bias the result is the
        # same, but the warning is avoided.
        last_bias = torch.nn.functional.log_softmax(last_bias, dim=-1)
        # Shift so the most likely token gets zero penalty; all entries <= 0.
        self.penalty = penalty * (last_bias - last_bias.max())

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Apply the penalty in place to `scores` at previously seen token ids.

        Args:
            input_ids: (batch, seq_len) previously generated token ids.
            scores: (batch, vocab) next-token logits; modified in place.

        Returns:
            The (mutated) `scores` tensor.
        """
        # Add the per-token penalty to all logits, gather only the positions
        # of already-generated tokens, and scatter those penalized values
        # back — untouched tokens keep their original scores.
        penalized_score = torch.gather(scores + self.penalty.unsqueeze(0).to(input_ids.device), 1, input_ids)
        scores.scatter_(1, input_ids, penalized_score)
        return scores
86
+
87
+
88
  def translate(source, source_language, target_language):
89
  if source_language == target_language:
90
  yield source.strip()
 
109
  input_ids=source_subwords,
110
  attention_mask=(source_subwords != pad_index).long(),
111
  max_new_tokens = 512-1,
112
+ top_k=64,
113
+ top_p=0.95,
114
+ do_sample=True,
115
+ temperature=0.3,
116
+ num_beams=1,
117
+ use_cache=True,
118
+ logits_processor=[RepetitionPenaltyLogitsProcessor(1.0, model)]
119
  # num_beams=4,
120
  # early_stopping=True,
121
+ #do_sample=False,
122
+ #use_cache=True
123
  )
124
  t = Thread(target=generate, args=(model,), kwargs=generate_kwargs)
125
  t.start()