|
from fastapi import FastAPI |
|
from typing import List |
|
|
|
from pydantic import BaseModel |
|
|
|
import torch |
|
import transformers |
|
|
|
# ASGI application object; the route handlers in this module register on it.
app = FastAPI()
|
|
|
class HebrewText(BaseModel):
    """Request body for the diacritization endpoint."""

    # Input strings submitted by the client under the JSON key "text".
    text: List[str]
|
|
|
# Hugging Face Hub identifier of the checkpoint served by this endpoint.
_MODEL_NAME = "sadafwalliyani/D_Nikud_model"

# Process-wide cache for the loaded (tokenizer, model) pair so the checkpoint
# is read once instead of on every request.
_PIPELINE_CACHE = {}


def _get_tokenizer_and_model():
    """Return the cached ``(tokenizer, model)`` pair, loading it on first use."""
    if "pipeline" not in _PIPELINE_CACHE:
        tokenizer = transformers.AutoTokenizer.from_pretrained(_MODEL_NAME)
        # NOTE(review): ``AutoModel`` loads the bare backbone, while
        # ``generate`` normally requires a class with a language-model head
        # (e.g. ``AutoModelForSeq2SeqLM``). Kept as in the original because
        # the checkpoint's architecture is not visible here -- confirm.
        model = transformers.AutoModel.from_pretrained(_MODEL_NAME)
        _PIPELINE_CACHE["pipeline"] = (tokenizer, model)
    return _PIPELINE_CACHE["pipeline"]


def _generate_diacritized(tokenizer, model, texts):
    """Run beam-search generation over *texts* and decode every sequence.

    Args:
        tokenizer: Tokenizer matching ``model``.
        model: Loaded model exposing ``generate`` and ``device``.
        texts: Non-empty list of input strings, encoded as one padded batch.

    Returns:
        A list with one decoded string per generated sequence.
    """
    # Calling the tokenizer (rather than ``encode``) handles a batch of
    # strings and already returns tensors, so no ``torch.tensor`` re-wrap.
    encoded = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        return_attention_mask=True,
    ).to(model.device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        response = model.generate(
            encoded["input_ids"],
            # The mask lets generation ignore the padding added for batching.
            attention_mask=encoded["attention_mask"],
            max_length=100,
            num_beams=5,
            early_stopping=True,
            return_dict_in_generate=True,
            output_scores=True,
            output_hidden_states=False,
            use_cache=True,
        )

    # Drop pad/bos/eos markers so only the generated text is returned.
    return tokenizer.batch_decode(response.sequences, skip_special_tokens=True)


@app.post("/diacritize/")
async def diacritize_hebrew(hebrew_text: HebrewText):
    """Generate model output for the submitted Hebrew strings.

    Args:
        hebrew_text: Request body whose ``text`` field lists the input strings.

    Returns:
        A dict with ``"text"``, the decoded output for the first input (the
        original response key), and ``"texts"``, the decoded outputs for all
        inputs.

    Raises:
        HTTPException: With status 400 when ``text`` is an empty list.
    """
    # Function-scope imports keep this block self-contained; ``fastapi`` is
    # already a dependency of this module.
    from fastapi import HTTPException
    from fastapi.concurrency import run_in_threadpool

    texts = hebrew_text.text
    if not texts:
        raise HTTPException(
            status_code=400,
            detail="'text' must contain at least one string.",
        )

    # Loaded on the event-loop thread so concurrent first requests cannot
    # race to populate the cache; later calls are a dictionary lookup.
    tokenizer, model = _get_tokenizer_and_model()

    # Generation is blocking, CPU/GPU-bound work; run it off the event loop
    # so other requests keep being served.
    decoded = await run_in_threadpool(
        _generate_diacritized, tokenizer, model, texts
    )

    return {"text": decoded[0], "texts": decoded}