davda54 committed on
Commit
54c1f0c
2 Parent(s): 3e448cc 4bae60c

Merge branch 'main' of https://huggingface.co/spaces/ltg/no-en-translation

Browse files
Files changed (1) hide show
  1. app.py +67 -28
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import torch
2
- from transformers import AutoTokenizer, TextIteratorStreamer
3
- # from modeling_nort5 import NorT5ForConditionalGeneration
4
  from threading import Thread
 
5
 
6
 
7
  print(f"Starting to load the model to memory")
@@ -10,9 +10,10 @@ tokenizer = AutoTokenizer.from_pretrained("nort5_en-no_base")
10
  cls_index = tokenizer.convert_tokens_to_ids("[CLS]")
11
  sep_index = tokenizer.convert_tokens_to_ids("[SEP]")
12
  eos_index = tokenizer.convert_tokens_to_ids("[EOS]")
13
- eng_index = tokenizer.convert_tokens_to_ids(">>ENG<<")
14
- nob_index = tokenizer.convert_tokens_to_ids(">>NOB<<")
15
- nno_index = tokenizer.convert_tokens_to_ids(">>NNO<<")
 
16
 
17
  model = AutoModelForSeq2SeqLM.from_pretrained("nort5_en-no_base", trust_remote_code=True)
18
 
@@ -25,14 +26,6 @@ model.eval()
25
  print(f"Sucessfully loaded the model to the memory")
26
 
27
 
28
- INITIAL_PROMPT = "Du er NorT5, en språkmodell laget ved Universitetet i Oslo. Du er en hjelpsom og ufarlig assistent som er glade for å hjelpe brukeren med enhver forespørsel."
29
- TEMPERATURE = 0.7
30
- SAMPLE = True
31
- BEAMS = 1
32
- PENALTY = 1.2
33
- TOP_K = 64
34
- TOP_P = 0.95
35
-
36
  LANGUAGES = [
37
  "🇬🇧 English",
38
  "🇳🇴 Norwegian (Bokmål)",
@@ -42,10 +35,44 @@ LANGUAGES = [
42
  LANGUAGE_IDS = {
43
  "🇬🇧 English": eng_index,
44
  "🇳🇴 Norwegian (Bokmål)": nob_index,
45
- "🇳🇴 Norwegian (Nynorsk)", nno_index
46
  }
47
 
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
def set_default_target():
    """Return the placeholder text shown in the target box while a translation runs."""
    placeholder = "*Translating...*"
    return placeholder
51
 
@@ -54,33 +81,45 @@ def translate(source, source_language, target_language):
54
  if source_language == target_language:
55
  return source
56
 
 
57
  source_subwords = tokenizer(source).input_ids
58
- source_subwords = [cls_index, LANGUAGE_IDS[target_language], LANGUAGE_IDS[source_language]] + source_subwords + [sep_index]
59
- source_subwords = torch.tensor([source_subwords[:512]])
 
 
 
 
60
 
61
- predictions = model.generate(
 
 
 
 
 
 
62
  input_ids=source_subwords,
 
63
  max_new_tokens = 512-1,
64
- do_sample=False
 
 
 
65
  )
66
- predictions = [tokenizer.decode(p, skip_special_tokens=True) for p in predictions.tolist()]
 
67
 
68
- return predictions
 
 
69
 
70
 
71
def switch_inputs(source, target, source_language, target_language):
    """Swap the source/target texts together with their language selections.

    Returns a 4-tuple (new_source, new_target, new_source_language,
    new_target_language) with the two sides exchanged.
    """
    swapped_texts = (target, source)
    swapped_languages = (target_language, source_language)
    return swapped_texts + swapped_languages
73
 
74
 
75
- import gradio as gr
76
-
77
  with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
78
 
79
  gr.Markdown("# Norwegian-English translation")
80
- # gr.HTML('<img src="https://huggingface.co/ltg/norbert3-base/resolve/main/norbert.png" width=6.75%>')
81
- # gr.Checkbox(label="I want to publish all my conversations", value=True)
82
-
83
- # chatbot = gr.Chatbot(value=[[None, "Hei, hva kan jeg gjøre for deg? 😊"]])
84
 
85
  with gr.Row():
86
  with gr.Column(scale=7, variant="panel"):
@@ -116,8 +155,8 @@ with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
116
  return {
117
  source: gr.update(interactive=True),
118
  submit: gr.update(interactive=True),
119
- source_language: gr.update(interactive=False),
120
- target_language: gr.update(interactive=False)
121
  }
122
 
123
 
 
1
  import torch
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer
 
3
  from threading import Thread
4
+ import gradio as gr
5
 
6
 
7
  print(f"Starting to load the model to memory")
 
10
  cls_index = tokenizer.convert_tokens_to_ids("[CLS]")
11
  sep_index = tokenizer.convert_tokens_to_ids("[SEP]")
12
  eos_index = tokenizer.convert_tokens_to_ids("[EOS]")
13
+ pad_index = tokenizer.convert_tokens_to_ids("[PAD]")
14
+ eng_index = tokenizer.convert_tokens_to_ids(">>eng<<")
15
+ nob_index = tokenizer.convert_tokens_to_ids(">>nob<<")
16
+ nno_index = tokenizer.convert_tokens_to_ids(">>nno<<")
17
 
18
  model = AutoModelForSeq2SeqLM.from_pretrained("nort5_en-no_base", trust_remote_code=True)
19
 
 
26
  print(f"Sucessfully loaded the model to the memory")
27
 
28
 
 
 
 
 
 
 
 
 
29
  LANGUAGES = [
30
  "🇬🇧 English",
31
  "🇳🇴 Norwegian (Bokmål)",
 
35
  LANGUAGE_IDS = {
36
  "🇬🇧 English": eng_index,
37
  "🇳🇴 Norwegian (Bokmål)": nob_index,
38
+ "🇳🇴 Norwegian (Nynorsk)": nno_index
39
  }
40
 
41
 
42
class BatchStreamer(TextIteratorStreamer):
    """Iterator streamer that handles a whole batch of sequences at once.

    The stock ``TextIteratorStreamer`` only supports batch size 1; this
    subclass keeps one token cache per batch row and, on every step, emits
    the decoded rows joined by newlines so each input paragraph streams as
    its own output line.
    """

    def put(self, value):
        """Accumulate the newest token id(s) of every batch row, then decode
        and emit the complete text produced so far.

        ``value`` is a tensor with one entry (or one row of ids) per batch
        element — assumes shape (batch,) or (batch, n); TODO confirm against
        the generate() call site.
        """
        # Lazily create one cache list per batch row on the first call.
        if len(self.token_cache) == 0:
            self.token_cache = [[] for _ in range(value.size(0))]

        value = value.tolist()

        # Append the new token(s) to each row's cache.
        for row_cache, new_ids in zip(self.token_cache, value):
            row_cache += [new_ids] if isinstance(new_ids, int) else new_ids

        # Re-decode every row from scratch and join the rows with newlines.
        # Use the tokenizer stored by the TextStreamer base class rather
        # than the module-level global, so the streamer is self-contained.
        paragraphs = [self.tokenizer.decode(c, **self.decode_kwargs).strip() for c in self.token_cache]
        self.on_finalized_text('\n'.join(paragraphs))

    def end(self):
        """Flush any cached tokens and signal the end of the stream."""
        if len(self.token_cache) > 0:
            paragraphs = [self.tokenizer.decode(c, **self.decode_kwargs).strip() for c in self.token_cache]
            printable_text = '\n'.join(paragraphs)
            self.token_cache = []
            self.print_len = 0
        else:
            printable_text = ""

        self.next_tokens_are_prompt = True
        self.on_finalized_text(printable_text, stream_end=True)
75
+
76
  def set_default_target():
77
  return "*Translating...*"
78
 
 
81
  if source_language == target_language:
82
  return source
83
 
84
+ source = [s.strip() for s in source.split('\n')]
85
  source_subwords = tokenizer(source).input_ids
86
+ source_subwords = [[cls_index, LANGUAGE_IDS[target_language], LANGUAGE_IDS[source_language]] + s + [sep_index] for s in source_subwords]
87
+ source_subwords = [torch.tensor(s) for s in source_subwords]
88
+ source_subwords = torch.nn.utils.rnn.pad_sequence(source_subwords, batch_first=True, padding_value=pad_index)
89
+ source_subwords = source_subwords[:, :512].to(device)
90
+
91
+ streamer = BatchStreamer(tokenizer, timeout=60.0, skip_special_tokens=True)
92
 
93
+ def generate(model, **kwargs):
94
+ with torch.inference_mode():
95
+ with torch.autocast(enabled=device != "cpu", device_type=device, dtype=torch.bfloat16):
96
+ return model.generate(**kwargs)
97
+
98
+ generate_kwargs = dict(
99
+ streamer=streamer,
100
  input_ids=source_subwords,
101
+ attention_mask=(source_subwords != pad_index).long(),
102
  max_new_tokens = 512-1,
103
+ # num_beams=4,
104
+ # early_stopping=True,
105
+ do_sample=False,
106
+ use_cache=True
107
  )
108
+ t = Thread(target=generate, args=(model,), kwargs=generate_kwargs)
109
+ t.start()
110
 
111
+ for new_text in streamer:
112
+ yield new_text.strip()
113
+ return new_text.strip()
114
 
115
 
116
  def switch_inputs(source, target, source_language, target_language):
117
  return target, source, target_language, source_language
118
 
119
 
 
 
120
  with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
121
 
122
  gr.Markdown("# Norwegian-English translation")
 
 
 
 
123
 
124
  with gr.Row():
125
  with gr.Column(scale=7, variant="panel"):
 
155
  return {
156
  source: gr.update(interactive=True),
157
  submit: gr.update(interactive=True),
158
+ source_language: gr.update(interactive=True),
159
+ target_language: gr.update(interactive=True)
160
  }
161
 
162