Spaces:
Paused
Paused
Update backend/main.py
Browse files- backend/main.py +23 -9
backend/main.py
CHANGED
@@ -133,12 +133,12 @@ static_files = {
|
|
133 |
},
|
134 |
}
|
135 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
136 |
-
processor = AutoProcessor.from_pretrained("facebook/seamless-m4t-v2-large", force_download=True)
|
137 |
#cache_dir="/.cache"
|
138 |
|
139 |
# PM - hardcoding temporarily as my GPU doesnt have enough vram
|
140 |
# model = SeamlessM4Tv2Model.from_pretrained("facebook/seamless-m4t-v2-large").to("cpu")
|
141 |
-
model = SeamlessM4Tv2Model.from_pretrained("facebook/seamless-m4t-v2-large", force_download=True).to(device)
|
142 |
|
143 |
|
144 |
bytes_data = bytearray()
|
@@ -148,6 +148,18 @@ vocoder_name = "vocoder_v2" if model_name == "seamlessM4T_v2_large" else "vocode
|
|
148 |
clients = {}
|
149 |
rooms = {}
|
150 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
|
152 |
def get_collection_users():
|
153 |
return app.database["user_records"]
|
@@ -297,16 +309,18 @@ async def incoming_audio(sid, data, call_id):
|
|
297 |
tgt_sid = next(id for id in rooms[call_id] if id != sid)
|
298 |
tgt_lang = clients[tgt_sid].target_language
|
299 |
# following example from https://github.com/facebookresearch/seamless_communication/blob/main/docs/m4t/README.md#transformers-usage
|
300 |
-
output_tokens = processor(audios=resampled_audio, src_lang=src_lang, return_tensors="pt")
|
301 |
-
model_output = model.generate(**output_tokens, tgt_lang=src_lang, generate_speech=False)[0].tolist()[0]
|
302 |
-
asr_text = processor.decode(model_output, skip_special_tokens=True)
|
|
|
303 |
print(f"ASR TEXT = {asr_text}")
|
304 |
# ASR TEXT => ORIGINAL TEXT
|
305 |
|
306 |
-
t2t_tokens = processor(text=asr_text, src_lang=src_lang, tgt_lang=tgt_lang, return_tensors="pt")
|
307 |
-
print(f"FIRST TYPE = {type(output_tokens)}, SECOND TYPE = {type(t2t_tokens)}")
|
308 |
-
translated_data = model.generate(**t2t_tokens, tgt_lang=tgt_lang, generate_speech=False)[0].tolist()[0]
|
309 |
-
translated_text = processor.decode(translated_data, skip_special_tokens=True)
|
|
|
310 |
print(f"TRANSLATED TEXT = {translated_text}")
|
311 |
|
312 |
# BO -> send translated_text to mongodb as caption record update based on call_id
|
|
|
133 |
},
|
134 |
}
|
135 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
136 |
+
# processor = AutoProcessor.from_pretrained("facebook/seamless-m4t-v2-large", force_download=True)
|
137 |
#cache_dir="/.cache"
|
138 |
|
139 |
# PM - hardcoding temporarily as my GPU doesnt have enough vram
|
140 |
# model = SeamlessM4Tv2Model.from_pretrained("facebook/seamless-m4t-v2-large").to("cpu")
|
141 |
+
# model = SeamlessM4Tv2Model.from_pretrained("facebook/seamless-m4t-v2-large", force_download=True).to(device)
|
142 |
|
143 |
|
144 |
bytes_data = bytearray()
|
|
|
148 |
clients = {}
|
149 |
rooms = {}
|
150 |
|
151 |
+
import torch
|
152 |
+
from transformers import pipeline
|
153 |
+
translator = pipeline("automatic-speech-recognition",
|
154 |
+
"facebook/seamless-m4t-v2-large",
|
155 |
+
torch_dtype=torch.float32,
|
156 |
+
device="cpu")
|
157 |
+
|
158 |
+
converter = pipeline("translation",
|
159 |
+
"facebook/seamless-m4t-v2-large",
|
160 |
+
torch_dtype=torch.float32,
|
161 |
+
device="cpu")
|
162 |
+
|
163 |
|
164 |
def get_collection_users():
|
165 |
return app.database["user_records"]
|
|
|
309 |
tgt_sid = next(id for id in rooms[call_id] if id != sid)
|
310 |
tgt_lang = clients[tgt_sid].target_language
|
311 |
# following example from https://github.com/facebookresearch/seamless_communication/blob/main/docs/m4t/README.md#transformers-usage
|
312 |
+
# output_tokens = processor(audios=resampled_audio, src_lang=src_lang, return_tensors="pt")
|
313 |
+
# model_output = model.generate(**output_tokens, tgt_lang=src_lang, generate_speech=False)[0].tolist()[0]
|
314 |
+
# asr_text = processor.decode(model_output, skip_special_tokens=True)
|
315 |
+
asr_text = translator(resampled_audio, generate_kwargs={"tgt_lang": src_lang})['text']
|
316 |
print(f"ASR TEXT = {asr_text}")
|
317 |
# ASR TEXT => ORIGINAL TEXT
|
318 |
|
319 |
+
# t2t_tokens = processor(text=asr_text, src_lang=src_lang, tgt_lang=tgt_lang, return_tensors="pt")
|
320 |
+
# print(f"FIRST TYPE = {type(output_tokens)}, SECOND TYPE = {type(t2t_tokens)}")
|
321 |
+
# translated_data = model.generate(**t2t_tokens, tgt_lang=tgt_lang, generate_speech=False)[0].tolist()[0]
|
322 |
+
# translated_text = processor.decode(translated_data, skip_special_tokens=True)
|
323 |
+
translated_text = converter(asr_text, src_lang=src_lang, tgt_lang=tgt_lang)
|
324 |
print(f"TRANSLATED TEXT = {translated_text}")
|
325 |
|
326 |
# BO -> send translated_text to mongodb as caption record update based on call_id
|