pauri32 committed on
Commit
71c4861
1 Parent(s): 99b5b37

Update app/model/model.py

Browse files
Files changed (1) hide show
  1. app/model/model.py +6 -2
app/model/model.py CHANGED
@@ -28,9 +28,11 @@ class LLM:
28
  tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
29
  if tokenizer.pad_token_id is None:
30
  tokenizer.pad_token_id = tokenizer.eos_token_id
 
31
  return model, tokenizer
32
 
33
  def language_detection(self, input_text):
 
34
  # Prompt with one shot for each language
35
  prompt = f"""Identify the language of the following sentences. Options: 'english', 'español', 'française' .
36
  * <Identity theft is not a joke, millions of families suffer every year>(english)
@@ -39,9 +41,11 @@ class LLM:
39
  * <{input_text}>"""
40
  # Generation and extraction of the language tag
41
  answer_ids = self.model.generate(**self.tokenizer([prompt], return_tensors="pt"), max_new_tokens=10)
42
- answer = self.tokenizer.batch_decode(answer_ids, skip_special_tokens=False)[0].split(prompt)[1]
 
 
43
  pattern = r'\b(?:' + '|'.join(map(re.escape, self.lang_codes.keys())) + r')\b'
44
- lang = re.search(pattern, answer, flags=re.IGNORECASE)
45
  # Returns tag identified or 'unk' if none is detected
46
  return self.lang_codes[lang.group()] if lang else self.lang_codes["unknown"]
47
 
 
28
  tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
29
  if tokenizer.pad_token_id is None:
30
  tokenizer.pad_token_id = tokenizer.eos_token_id
31
+ print("Model and tokenizer loaded.")
32
  return model, tokenizer
33
 
34
  def language_detection(self, input_text):
35
+ print(f"### Input text\n{input_text}")
36
  # Prompt with one shot for each language
37
  prompt = f"""Identify the language of the following sentences. Options: 'english', 'español', 'française' .
38
  * <Identity theft is not a joke, millions of families suffer every year>(english)
 
41
  * <{input_text}>"""
42
  # Generation and extraction of the language tag
43
  answer_ids = self.model.generate(**self.tokenizer([prompt], return_tensors="pt"), max_new_tokens=10)
44
+ answer = self.tokenizer.batch_decode(answer_ids, skip_special_tokens=False)[0]
45
+ print(answer)
46
+ generation = answer.split(prompt)[1]
47
  pattern = r'\b(?:' + '|'.join(map(re.escape, self.lang_codes.keys())) + r')\b'
48
+ lang = re.search(pattern, generation, flags=re.IGNORECASE)
49
  # Returns tag identified or 'unk' if none is detected
50
  return self.lang_codes[lang.group()] if lang else self.lang_codes["unknown"]
51