|
import time |
|
|
|
from transformers import T5ForConditionalGeneration, T5Tokenizer, GenerationConfig |
|
import gradio as gr |
|
|
|
# Hugging Face model id: MADLAD-400, the 3B-parameter multilingual
# machine-translation checkpoint (T5 architecture).
MODEL_NAME = "jbochi/madlad400-3b-mt"



print(f"Loading {MODEL_NAME} tokenizer...")

tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)

print(f"Loading {MODEL_NAME} model...")

# device_map="auto" lets accelerate place the weights on whatever
# device(s) are available (GPU if present, otherwise CPU).
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, device_map="auto")
|
|
|
|
|
def inference(input_text, target_language, max_length):
    """Translate *input_text* into *target_language* with the MADLAD-400 model.

    Parameters
    ----------
    input_text : str
        Source text to translate.
    target_language : str
        Target language code (e.g. "en", "es") — MADLAD-400 selects the
        target language via a "<2xx>" task-prefix token.
    max_length : int
        Maximum number of tokens to generate.

    Returns
    -------
    dict
        Keys: 'result' (translated text), 'inference_time' (elapsed
        seconds), 'input_token_ids', 'output_token_ids'.
    """
    # NOTE: the original declared `global model, tokenizer`, but both names
    # are only read here, never rebound, so the statement was redundant.
    # perf_counter() is monotonic, so the reported duration cannot be
    # skewed by wall-clock adjustments mid-inference (time.time() can).
    start_time = time.perf_counter()
    input_ids = tokenizer(
        f"<2{target_language}> {input_text}", return_tensors="pt"
    ).input_ids
    outputs = model.generate(
        input_ids=input_ids.to(model.device),
        generation_config=GenerationConfig(max_length=max_length),
    )
    translation = tokenizer.decode(outputs[0], skip_special_tokens=True)
    elapsed = time.perf_counter() - start_time
    return {
        'result': translation,
        'inference_time': elapsed,
        'input_token_ids': input_ids[0].tolist(),
        'output_token_ids': outputs[0].tolist(),
    }
|
|
|
|
|
def run():
    """Build and launch the Gradio translation demo UI."""
    # The low vocabulary ids include the "<2xx>" language-prefix tokens;
    # strip the "<2" / ">" delimiters to recover the bare language codes.
    lang_codes = []
    for token_id in range(500):
        token = tokenizer.decode(token_id)
        if token.startswith("<2"):
            lang_codes.append(token[2:-1])

    text_input = gr.components.Textbox(lines=5, label="Input text")
    language_input = gr.components.Dropdown(
        lang_codes, value="en", label="Target Language"
    )
    length_input = gr.components.Slider(
        minimum=5,
        maximum=500,
        value=200,
        label="Max length",
    )

    demo_status = "Demo is running on CPU"
    interface = gr.Interface(
        inference,
        [text_input, language_input, length_input],
        gr.components.JSON(),
        title=f"{MODEL_NAME} demo",
        description=f"Details: https://huggingface.co/{MODEL_NAME}. {demo_status}",
        examples=[
            ["I'm a mad lad!", "es", 50],
            ["千里之行,始於足下", "en", 50],
        ],
    )
    interface.launch()
|
|
|
|
|
# Launch the demo only when executed as a script, not when imported.
if __name__ == "__main__":

    run()
|
|