jxtan committed on
Commit b805057
1 Parent(s): 66b08d7

Added Translation Endpoint

Dockerfile CHANGED
@@ -1,6 +1,36 @@
 FROM pytorch/pytorch:2.1.2-cuda12.1-cudnn8-runtime
 ENV DEBIAN_FRONTEND=noninteractive
 
+RUN apt-get update && \
+    apt-get upgrade -y && \
+    apt-get install -y --no-install-recommends \
+    git \
+    git-lfs \
+    wget \
+    curl \
+    # python build dependencies \
+    build-essential \
+    libssl-dev \
+    zlib1g-dev \
+    libbz2-dev \
+    libreadline-dev \
+    libsqlite3-dev \
+    libncursesw5-dev \
+    xz-utils \
+    tk-dev \
+    libxml2-dev \
+    libxmlsec1-dev \
+    libffi-dev \
+    liblzma-dev \
+    # gradio dependencies \
+    ffmpeg
+
+# fairseq2 dependencies
+RUN apt-get install -y --no-install-recommends \
+    libsndfile-dev
+
+RUN apt-get clean && rm -rf /var/lib/apt/lists/*
+
 RUN useradd -m -u 1000 user
 USER user
 ENV HOME=/home/user \
@@ -15,5 +45,10 @@ RUN pip install -r ${HOME}/app/requirements.txt
 # RUN mkdir content
 # ADD --chown=user https://<SOME_ASSET_URL> content/<SOME_ASSET_NAME>
 
+# SeamlessCommunication requirements
+RUN pip install -r ${HOME}/app/seamless_requirements.txt && \
+    pip install fairseq2 --pre --extra-index-url https://fair.pkg.atmeta.com/fairseq2/pt2.1.0/cu121 && \
+    pip install ${HOME}/app/whl/seamless_communication-1.0.0-py3-none-any.whl
+
 # Start the FastAPI app on port 7860, the default port expected by Spaces
 CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
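Not part of the commit, but a quick smoke test after rebuilding the image is to run the container with port 7860 published and confirm the FastAPI docs page responds. A minimal sketch in Python; the localhost URL assumes the container was started locally with the port mapped:

import requests

# Assumes the rebuilt image is running locally with -p 7860:7860;
# app.py serves the Swagger UI at "/" (docs_url="/").
resp = requests.get("http://localhost:7860/", timeout=10)
assert resp.status_code == 200
print("container is serving the API")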
README.md CHANGED
@@ -22,4 +22,6 @@ Users should be able to call the task and get back in the standard format
 	"model": "BAAI/bge-base-en-v1.5",
 	"inputs: ["This is one text", "This is second text"],
 	"parameters": {}
-}
+}
+
+TODO: Models are cached in volume directory
app.py CHANGED
@@ -1,6 +1,6 @@
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi import FastAPI
-import sentence_embeddings
+from tasks import sentence_embeddings
 
 app = FastAPI(docs_url="/", redoc_url=None)
 
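This hunk only fixes the import path; where translation.router gets registered is not visible in the diff. A hypothetical sketch of the usual FastAPI wiring, assuming the include_router calls live further down in app.py:

# Hypothetical — these lines are not shown in this commit's hunks.
from tasks import sentence_embeddings, translation

app.include_router(sentence_embeddings.router)
app.include_router(translation.router)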
config.py CHANGED
@@ -1,6 +1,19 @@
+import torch
 import os
 import dotenv
 
 dotenv.load_dotenv()
 
-TEST_MODE = (os.getenv('TEST_MODE', 'False') == "True")
+TEST_MODE = (os.getenv('TEST_MODE', 'False') == "True")
+
+if torch.cuda.is_available():
+    device = torch.device("cuda:0")
+    dtype = torch.float16
+else:
+    device = torch.device("cpu")
+    dtype = torch.float32
+
+from datetime import datetime
+
+def log(data: dict):
+    print(f"{datetime.now().isoformat()}: {data}")
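config.py now owns device/dtype selection (fp16 on GPU, fp32 on CPU) and the log helper that previously lived in logger.py. A minimal usage sketch, assuming a module imports these shared globals:

import torch
from config import device, dtype, log

# Tensors and models placed via the shared globals land on GPU in fp16
# when CUDA is available, otherwise on CPU in fp32.
x = torch.ones(2, 3, device=device, dtype=dtype)
log({"task": "demo", "dtype": str(x.dtype), "device": str(x.device)})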
logger.py DELETED
@@ -1,4 +0,0 @@
-from datetime import datetime
-
-def log(data: dict):
-    print(f"{datetime.now().isoformat()}: {data}")
requirements.txt CHANGED
@@ -2,4 +2,5 @@ transformers
 torch
 fastapi
 uvicorn
+pydantic
 python-dotenv
seamless_requirements.txt ADDED
@@ -0,0 +1,2 @@
+omegaconf==2.3.0
+fasttext==0.9.2
tasks/pose_estimation.py ADDED
File without changes
tasks/sentence_embeddings.py ADDED
@@ -0,0 +1,83 @@
+from typing import Optional
+from fastapi import APIRouter
+from pydantic import BaseModel
+from transformers import AutoTokenizer, AutoModel
+import torch
+from datetime import datetime
+from config import TEST_MODE, device, log
+
+router = APIRouter()
+
+class SentenceEmbeddingsInput(BaseModel):
+    inputs: list[str]
+    model: str
+    parameters: dict
+
+class SentenceEmbeddingsOutput(BaseModel):
+    embeddings: Optional[list[list[float]]] = None
+    error: Optional[str] = None
+
+@router.post('/sentence-embeddings')
+def sentence_embeddings(inputs: SentenceEmbeddingsInput):
+    start_time = datetime.now()
+    fn = sentence_embeddings_mapping.get(inputs.model)
+    if not fn:
+        return SentenceEmbeddingsOutput(
+            error=f'No sentence embeddings model found for {inputs.model}'
+        )
+
+    try:
+        embeddings = fn(inputs.inputs, inputs.parameters)
+
+        log({
+            "task": "sentence_embeddings",
+            "model": inputs.model,
+            "start_time": start_time.isoformat(),
+            "time_taken": (datetime.now() - start_time).total_seconds(),
+            "inputs": inputs.inputs,
+            "outputs": embeddings,
+            "parameters": inputs.parameters,
+        })
+        loaded_models_last_updated[inputs.model] = datetime.now()
+        return SentenceEmbeddingsOutput(
+            embeddings=embeddings
+        )
+    except Exception as e:
+        return SentenceEmbeddingsOutput(
+            error=str(e)
+        )
+
+def generic_sentence_embeddings(model_name: str):
+    global loaded_models
+
+    def process_texts(texts: list[str], parameters: dict):
+        if TEST_MODE:
+            return [[0.1, 0.2]] * len(texts)
+
+        if model_name in loaded_models:
+            tokenizer, model = loaded_models[model_name]
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(model_name)
+            model = AutoModel.from_pretrained(model_name).to(device)
+            loaded_models[model_name] = (tokenizer, model)
+
+        # Tokenize sentences
+        encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt').to(device)
+        with torch.no_grad():
+            model_output = model(**encoded_input)
+            sentence_embeddings = model_output[0][:, 0]  # [CLS] token embedding
+
+        # Normalize embeddings
+        sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
+        return sentence_embeddings.tolist()
+
+    return process_texts
+
+# Model cache; a polling job could use these timestamps to unload idle models
+loaded_models = {}
+loaded_models_last_updated = {}
+
+sentence_embeddings_mapping = {
+    'BAAI/bge-base-en-v1.5': generic_sentence_embeddings('BAAI/bge-base-en-v1.5'),
+    'BAAI/bge-large-en-v1.5': generic_sentence_embeddings('BAAI/bge-large-en-v1.5'),
+}
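A request sketch against the new router, matching the README's standard format; the localhost URL assumes the container is running locally (on a deployed Space the base URL differs):

import requests

payload = {
    "model": "BAAI/bge-base-en-v1.5",
    "inputs": ["This is one text", "This is second text"],
    "parameters": {},
}
resp = requests.post("http://localhost:7860/sentence-embeddings", json=payload, timeout=60)
# Expected shape: {"embeddings": [[...], [...]], "error": null}
print(resp.json())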
tasks/translation.py ADDED
@@ -0,0 +1,135 @@
+from fastapi import APIRouter
+from pydantic import BaseModel
+from typing import Optional
+from config import TEST_MODE, device, dtype, log
+from fairseq2.data.text.text_tokenizer import TextTokenEncoder
+from seamless_communication.inference import Translator
+import spacy
+import re
+from datetime import datetime
+
+router = APIRouter()
+
+class TranslateInput(BaseModel):
+    inputs: list[str]
+    model: str
+    src_lang: str
+    dst_lang: str
+
+
+class TranslateOutput(BaseModel):
+    src_lang: str
+    dst_lang: str
+    translations: Optional[list[str]] = None
+    error: Optional[str] = None
+
+
+@router.post('/t2tt')
+def t2tt(inputs: TranslateInput) -> TranslateOutput:
+    start_time = datetime.now()
+    fn = t2tt_mapping.get(inputs.model)
+    if not fn:
+        return TranslateOutput(
+            src_lang=inputs.src_lang,
+            dst_lang=inputs.dst_lang,
+            error=f'No translation model found for {inputs.model}'
+        )
+
+    try:
+        translations = fn(inputs.inputs, inputs.src_lang, inputs.dst_lang)
+        log({
+            "task": "translation",
+            "model": inputs.model,
+            "start_time": start_time.isoformat(),
+            "time_taken": (datetime.now() - start_time).total_seconds(),
+            "inputs": inputs.inputs,
+            "outputs": translations,
+            "parameters": {
+                "src_lang": inputs.src_lang,
+                "dst_lang": inputs.dst_lang,
+            },
+        })
+        loaded_models_last_updated[inputs.model] = datetime.now()
+        return TranslateOutput(**translations)
+    except Exception as e:
+        return TranslateOutput(
+            src_lang=inputs.src_lang,
+            dst_lang=inputs.dst_lang,
+            error=str(e)
+        )
+
+cmn_nlp = spacy.load("zh_core_web_sm")
+xx_nlp = spacy.load("xx_sent_ud_sm")
+unk_re = re.compile(r"\s?<unk>|\s?⁇")
+
+def seamless_t2tt(inputs: list[str], src_lang: str, dst_lang: str = 'eng'):
+    if TEST_MODE:
+        return {
+            "src_lang": src_lang,
+            "dst_lang": dst_lang,
+            "translations": None,
+            "error": None
+        }
+
+    # Load model (cached in loaded_models after the first request)
+    if 'facebook/seamless-m4t-v2-large' in loaded_models:
+        translator = loaded_models['facebook/seamless-m4t-v2-large']
+    else:
+        translator = Translator(
+            model_name_or_card="seamlessM4T_v2_large",
+            vocoder_name_or_card="vocoder_v2",
+            device=device,
+            dtype=dtype,
+            apply_mintox=False,
+        )
+        loaded_models['facebook/seamless-m4t-v2-large'] = translator
+
+
+    def sent_tokenize(text, lang) -> list[str]:
+        if lang == 'cmn':
+            return [str(t) for t in cmn_nlp(text).sents]
+        return [str(t) for t in xx_nlp(text).sents]
+
+
+    def tokenize_and_translate(token_encoder: TextTokenEncoder, text: str, src_lang: str, dst_lang: str) -> str:
+        # Split text into paragraphs on blank lines, join wrapped lines with spaces, then sentence-tokenize
+        lines = [sent_tokenize(line.replace("\n", " "), src_lang) for line in text.split('\n\n') if line]
+        lines = [item for sublist in lines for item in sublist if item]
+
+        # Tokenize and translate
+        input_tokens = translator.collate([token_encoder(line) for line in lines])
+        translations = [
+            unk_re.sub("", str(t))
+            for t in translator.predict(
+                input=input_tokens,
+                task_str="T2TT",
+                src_lang=src_lang,
+                tgt_lang=dst_lang,
+            )[0]
+        ]
+        return " ".join(translations)
+
+    translations = None
+    token_encoder = translator.text_tokenizer.create_encoder(
+        task="translation", lang=src_lang, mode="source", device=translator.device
+    )
+    try:
+        translations = [tokenize_and_translate(token_encoder, text, src_lang, dst_lang) for text in inputs]
+    except Exception as e:
+        print(f"Error translating text: {e}")
+
+    return {
+        "src_lang": src_lang,
+        "dst_lang": dst_lang,
+        "translations": translations,
+        "error": None if translations else "Failed to translate text"
+    }
+
+
+# Model cache; a polling job could use these timestamps to unload idle models
+loaded_models = {}
+loaded_models_last_updated = {}
+
+t2tt_mapping = {
+    'facebook/seamless-m4t-v2-large': seamless_t2tt,
+}
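Note that this module calls spacy.load("zh_core_web_sm") and spacy.load("xx_sent_ud_sm") at import time, and neither spacy nor those two models appears in the requirements shown in this commit, so they presumably come from elsewhere in the image. A request sketch for the new endpoint; the localhost URL is an assumption, and language codes follow the SeamlessM4T convention (e.g. "cmn", "eng"):

import requests

payload = {
    "model": "facebook/seamless-m4t-v2-large",
    "inputs": ["你好，世界。"],
    "src_lang": "cmn",
    "dst_lang": "eng",
}
resp = requests.post("http://localhost:7860/t2tt", json=payload, timeout=300)
# Expected shape: {"src_lang": "cmn", "dst_lang": "eng", "translations": ["..."], "error": null}
print(resp.json())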