|
from fastapi import FastAPI, HTTPException |
|
from typing import Any, Dict, List, Optional |
|
from pydantic import BaseModel |
|
from os import getenv |
|
from huggingface_hub import InferenceClient |
|
import random |
|
import nltk |
|
import re |
|
from word_forms.word_forms import get_word_forms |
|
|
|
app = FastAPI()



# Fetch the Punkt sentence-boundary model at import time so the pickle
# loaded below is guaranteed to exist (no-op if already cached locally).
nltk.download('punkt')



# Punkt sentence tokenizer used to split user input and model output into
# individual sentences whenever `segment` mode is active.
tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')



# Hugging Face Inference API token read from the environment; may be None,
# in which case the InferenceClient runs unauthenticated (rate-limited).
HF_TOKEN = getenv("HF_TOKEN")
|
|
|
class InputData(BaseModel):
    """Request body for the /generate-response/ endpoint."""

    # Hugging Face model id (repo name) passed to InferenceClient.
    model: str

    # Template strings; each "{SystemPrompt}" placeholder is replaced with
    # the newline-joined `system_prompt` entries.
    system_prompt_template: List[str]

    # Template strings; each "{Prompt}" placeholder is replaced with the
    # (possibly sentence-segmented) user input.
    prompt_template: List[str]

    # Token appended to the end of the returned history string.
    end_token: str

    # System prompt lines substituted into `system_prompt_template`.
    system_prompt: List[str]

    # One or more user messages for this turn.
    user_input: List[str]

    # Running conversation transcript; rendered prompts and the model
    # reply are appended to it.
    history: str = ""

    # When True, inputs and the model reply are split one-sentence-per-line.
    segment: bool = False

    # Cap on sentences kept from the model reply. A non-zero value forces
    # `segment` on; exactly 0 skips inference and only updates the history.
    max_sentences: Optional[int] = None
|
|
|
class WordCheckData(BaseModel):
    """Request body for the /check-word/ endpoint."""

    # Text to search (matched case-insensitively).
    string: str

    # Word whose morphological forms are looked for inside `string`.
    word: str
|
|
|
@app.post("/generate-response/")
async def generate_response(data: InputData) -> Dict[str, Any]:
    """Generate a chat completion via the Hugging Face Inference API.

    Renders the prompt templates into ``data.history``, builds the model
    input from the system-prompt templates plus the history, calls
    ``data.model`` for a completion, and optionally splits the reply into
    sentences.

    Returns:
        A dict with:
        * ``response`` -- ``""`` for the history-only short-circuit
          (``max_sentences == 0``), otherwise
          ``{"New response": [sentences], "Sentence count": int}``.
        * ``history`` -- the accumulated transcript plus ``end_token``.

    Raises:
        HTTPException: 500 if the inference call fails for any reason.
    """
    # A non-zero sentence cap implies segmentation. A cap of exactly 0
    # short-circuits: record one history entry per prompt/user-input pair
    # and skip model inference entirely.
    if data.max_sentences is not None and data.max_sentences != 0:
        data.segment = True
    elif data.max_sentences == 0:
        data.history += "".join(
            prompt.replace("{Prompt}", user_input) + "\n"
            for prompt in data.prompt_template
            for user_input in data.user_input
        )
        return {
            "response": "",
            "history": data.history + data.end_token
        }

    # Join the user inputs, one sentence per line when segmenting.
    if data.segment:
        user_input_str = "".join(
            "\n".join(tokenizer.tokenize(user_input)) + "\n"
            for user_input in data.user_input
        )
    else:
        user_input_str = "\n".join(data.user_input)

    # Substitute the combined user text into every prompt template.
    for prompt in data.prompt_template:
        data.history += prompt.replace("{Prompt}", user_input_str) + "\n"

    # Model input = rendered system-prompt templates followed by the
    # full conversation history.
    system_prompt_str = "\n".join(data.system_prompt)
    inputs = "".join(
        template.replace("{SystemPrompt}", system_prompt_str) + "\n"
        for template in data.system_prompt_template
    )
    inputs += data.history

    # Fresh random seed per request so repeated calls can differ.
    seed = random.randint(0, 2**32 - 1)

    try:
        client = InferenceClient(model=data.model, token=HF_TOKEN)
        response = client.text_generation(
            inputs,
            temperature=1.0,
            max_new_tokens=1000,
            seed=seed
        )

        response_str = str(response)

        # Split the reply into sentences and apply the cap when segmenting;
        # otherwise treat the whole reply as a single "sentence" so both
        # paths share one response-building step below.
        if data.segment:
            ai_sentences = tokenizer.tokenize(response_str)
            if data.max_sentences is not None:
                ai_sentences = ai_sentences[:data.max_sentences]
        else:
            ai_sentences = [response_str]

        ai_response_str = "\n".join(ai_sentences)
        data.history += ai_response_str + "\n"

        return {
            "response": {
                "New response": ai_sentences,
                "Sentence count": len(ai_sentences),
            },
            "history": data.history + data.end_token,
        }

    except Exception as e:
        print(f"Model {data.model} failed with error: {e}")
        # Chain the original exception so the root cause is preserved in
        # tracebacks and error-handling middleware.
        raise HTTPException(
            status_code=500,
            detail=f"Model {data.model} failed to generate response"
        ) from e
|
|
|
@app.post("/check-word/")
async def check_word(data: WordCheckData) -> Dict[str, Any]:
    """Check whether any morphological form of ``data.word`` appears in
    ``data.string`` (case-insensitive, whole-word matches only).

    Returns:
        ``{"found": bool}``
    """
    input_string = data.string.lower()
    word = data.word.lower()

    # Collect every inflection/derivation of the word (noun, verb,
    # adjective, adverb buckets) into one flat set.
    forms = get_word_forms(word)
    all_forms = set()
    for words in forms.values():
        all_forms.update(words)

    # get_word_forms yields empty sets for out-of-vocabulary words; always
    # include the literal word so exact occurrences are still found.
    all_forms.add(word)

    # \b\w+\b extracts whole words, so e.g. "cat" does not match inside
    # "category".
    words_in_string = re.findall(r'\b\w+\b', input_string)

    found = not all_forms.isdisjoint(words_in_string)

    return {
        "found": found
    }
|
|