davda54 committed on
Commit
398f6f3
1 Parent(s): 562a084

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -11
app.py CHANGED
@@ -4,23 +4,25 @@ from transformers import AutoTokenizer, TextIteratorStreamer
4
  from threading import Thread
5
 
6
 
7
- # print(f"Starting to load the model to memory")
8
 
9
- # tokenizer = AutoTokenizer.from_pretrained("nort5_en-no_base")
10
- # cls_index = tokenizer.convert_tokens_to_ids("[CLS]")
11
- # sep_index = tokenizer.convert_tokens_to_ids("[SEP]")
12
- # user_index = tokenizer.convert_tokens_to_ids("[USER]")
13
- # assistent_index = tokenizer.convert_tokens_to_ids("[ASSISTENT]")
 
 
14
 
15
- # model = NorT5ForConditionalGeneration.from_pretrained("nort5_en-no_base", ignore_mismatched_sizes=True)
16
 
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
  print(f"SYSTEM: Running on {device}", flush=True)
19
 
20
- # model = model.to(device)
21
- # model.eval()
22
 
23
- # print(f"Sucessfully loaded the model to the memory")
24
 
25
 
26
  INITIAL_PROMPT = "Du er NorT5, en språkmodell laget ved Universitetet i Oslo. Du er en hjelpsom og ufarlig assistent som er glade for å hjelpe brukeren med enhver forespørsel."
@@ -37,13 +39,33 @@ LANGUAGES = [
37
  "🇳🇴 Norwegian (Nynorsk)"
38
  ]
39
 
 
 
 
 
 
 
40
 
41
  def set_default_target():
42
  return "*Translating...*"
43
 
44
 
45
  def translate(source, source_language, target_language):
46
- return "This is a fake translation"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
 
49
  def switch_inputs(source, target, source_language, target_language):
 
4
  from threading import Thread
5
 
6
 
7
# Load the tokenizer and the seq2seq model once, when the module is executed.
#
# `AutoModelForSeq2SeqLM` is imported here because the only `transformers`
# import visible for this file brings in AutoTokenizer and TextIteratorStreamer;
# without it the `from_pretrained` call below raises NameError.
# NOTE(review): drop this line if the name is already imported at the top of the file.
from transformers import AutoModelForSeq2SeqLM

print("Starting to load the model to memory")

tokenizer = AutoTokenizer.from_pretrained("nort5_en-no_base")

# Vocabulary ids of the special tokens. translate() frames its input with
# [CLS] ... [SEP] and the >>XXX<< language tags.
cls_index = tokenizer.convert_tokens_to_ids("[CLS]")
sep_index = tokenizer.convert_tokens_to_ids("[SEP]")
eos_index = tokenizer.convert_tokens_to_ids("[EOS]")
eng_index = tokenizer.convert_tokens_to_ids(">>ENG<<")
nob_index = tokenizer.convert_tokens_to_ids(">>NOB<<")
nno_index = tokenizer.convert_tokens_to_ids(">>NNO<<")

# NOTE(review): trust_remote_code=True executes Python code shipped with the
# checkpoint; only keep it for checkpoints you control.
model = AutoModelForSeq2SeqLM.from_pretrained("nort5_en-no_base", trust_remote_code=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"SYSTEM: Running on {device}", flush=True)

model = model.to(device)
model.eval()  # inference only: switch off dropout and other train-time behaviour

print("Sucessfully loaded the model to the memory")
26
 
27
 
28
  INITIAL_PROMPT = "Du er NorT5, en språkmodell laget ved Universitetet i Oslo. Du er en hjelpsom og ufarlig assistent som er glade for å hjelpe brukeren med enhver forespørsel."
 
39
  "🇳🇴 Norwegian (Nynorsk)"
40
  ]
41
 
42
# Maps each language label to the vocabulary id of its language-tag token
# (">>ENG<<", ">>NOB<<", ">>NNO<<"). translate() looks languages up by label,
# so the keys presumably have to match the entries of LANGUAGES exactly
# (including the flag emoji) -- verify when adding a language.
LANGUAGE_IDS = {
    "🇬🇧 English": eng_index,
    "🇳🇴 Norwegian (Bokmål)": nob_index,
    "🇳🇴 Norwegian (Nynorsk)": nno_index,
}
47
+
48
 
49
  def set_default_target():
50
  return "*Translating...*"
51
 
52
 
53
  def translate(source, source_language, target_language):
54
+ if source_language == target_language:
55
+ return source
56
+
57
+ source_subwords = tokenizer(source).input_ids
58
+ source_subwords = [cls_index, LANGUAGE_IDS[target_language], LANGUAGE_IDS[source_language]] + source_subwords + [sep_index]
59
+ source_subwords = torch.tensor([source_subwords[:512]])
60
+
61
+ predictions = model.generate(
62
+ input_ids=source_subwords,
63
+ max_new_tokens = 512-1,
64
+ do_sample=False
65
+ )
66
+ predictions = [tokenizer.decode(p, skip_special_tokens=True) for p in predictions.tolist()]
67
+
68
+ return predictions
69
 
70
 
71
  def switch_inputs(source, target, source_language, target_language):