radinhas commited on
Commit
746ab1f
1 Parent(s): bcc577e

Update apis/chat_api.py

Browse files
Files changed (1) hide show
  1. apis/chat_api.py +7 -4
apis/chat_api.py CHANGED
@@ -4,7 +4,7 @@ import sys
4
  import os
5
  import io
6
  from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
7
- from transformers import pipeline
8
  import time
9
  import json
10
  from typing import List
@@ -170,11 +170,14 @@ class ChatAPIApp:
170
  "default": "t5-base",
171
  }
172
  if item.model in MODEL_MAP.keys():
173
- target_model = item.model
174
  else:
175
  target_model = "default"
176
-
177
- translator = pipeline("translation_"+item.from_language+"_to_"+item.to_language, model=target_model)
 
 
 
178
  result = translator(item.input_text)
179
 
180
  item_response = {
 
4
  import os
5
  import io
6
  from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
7
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
8
  import time
9
  import json
10
  from typing import List
 
170
  "default": "t5-base",
171
  }
172
  if item.model in MODEL_MAP.keys():
173
+ target_model = MODEL_MAP[item.model]
174
  else:
175
  target_model = "default"
176
+
177
+ real_name = MODEL_MAP[target_model]
178
+ read_model = AutoModelForSeq2SeqLM.from_pretrained(real_name)
179
+ tokenizer = AutoTokenizer.from_pretrained(real_name)
180
+ translator = pipeline("translation", model=read_model, tokenizer=tokenizer, src_lang=item.from_language, tgt_lang=item.to_language)
181
  result = translator(item.input_text)
182
 
183
  item_response = {