# New-Place / main.py
from fastapi import FastAPI, HTTPException
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
from os import getenv
from huggingface_hub import InferenceClient
import random
import nltk
import re
from word_forms.word_forms import get_word_forms
app = FastAPI()
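
# Download the Punkt model and load the English sentence tokenizer used to
# split user inputs and generated output into sentences.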
nltk.download('punkt')
tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
HF_TOKEN = getenv("HF_TOKEN")
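
# Request schema for /generate-response/. Both templates are expected to
# contain the literal placeholders {SystemPrompt} and {Prompt}, which are
# substituted below.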
class InputData(BaseModel):
    model: str
    system_prompt_template: str
    prompt_template: str
    end_token: str
    system_prompts: List[str]
    user_inputs: List[str]
    history: str = ""
    segment: bool = False
    max_sentences: Optional[int] = None
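
# Request schema for /check-word/: checks whether any morphological form of
# `word` occurs in `string`.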
class WordCheckData(BaseModel):
    string: str
    word: str
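
# Illustrative request body for /generate-response/ (hypothetical values, not
# defaults of this service):
# {
#   "model": "HuggingFaceH4/zephyr-7b-beta",
#   "system_prompt_template": "<|system|>\n{SystemPrompt}</s>\n",
#   "prompt_template": "<|user|>\n{Prompt}</s>\n<|assistant|>\n",
#   "end_token": "</s>",
#   "system_prompts": ["You are a helpful assistant."],
#   "user_inputs": ["Tell me about sentence tokenization."],
#   "segment": true,
#   "max_sentences": 3
# }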
@app.post("/generate-response/")
async def generate_response(data: InputData) -> Dict[str, Any]:
    # A non-zero max_sentences implies sentence segmentation; max_sentences == 0
    # only appends the prompts to the history and skips generation entirely.
    if data.max_sentences is not None and data.max_sentences != 0:
        data.segment = True
    elif data.max_sentences == 0:
        for user_input in data.user_inputs:
            data.history += data.prompt_template.replace("{Prompt}", user_input)
        return {
            "response": [],
            "sentence_count": None,
            "history": data.history + data.end_token
        }
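
    # Build the prompt: append each user input to the running history (one
    # sentence per line when segmenting), then prepend the rendered system prompts.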
    responses = []

    if data.segment:
        for user_input in data.user_inputs:
            user_sentences = tokenizer.tokenize(user_input)
            user_input_str = "\n".join(user_sentences)
            data.history += data.prompt_template.replace("{Prompt}", user_input_str) + "\n"
    else:
        for user_input in data.user_inputs:
            data.history += data.prompt_template.replace("{Prompt}", user_input) + "\n"

    inputs = ""
    for system_prompt in data.system_prompts:
        inputs += data.system_prompt_template.replace("{SystemPrompt}", system_prompt) + "\n"
    inputs += data.history
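
    # Query the Hugging Face Inference API with a random seed so that repeated
    # requests with the same prompt can yield different generations.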
    seed = random.randint(0, 2**32 - 1)

    try:
        client = InferenceClient(model=data.model, token=HF_TOKEN)
        response = client.text_generation(
            inputs,
            temperature=1.0,
            max_new_tokens=1000,
            seed=seed
        )
        response_str = str(response)

        if data.segment:
            # Split the generated text into sentences and truncate to max_sentences.
            ai_sentences = tokenizer.tokenize(response_str)
            if data.max_sentences is not None:
                ai_sentences = ai_sentences[:data.max_sentences]
            responses = ai_sentences
            sentence_count = len(ai_sentences)
        else:
            responses = [response_str]
            sentence_count = None

        data.history += response_str + "\n"

        return {
            "response": responses,
            "sentence_count": sentence_count,
            "history": data.history + data.end_token
        }
    except Exception as e:
        print(f"Model {data.model} failed with error: {e}")
        raise HTTPException(status_code=500, detail=f"Model {data.model} failed to generate response")

@app.post("/check-word/")
async def check_word(data: WordCheckData) -> Dict[str, Any]:
    input_string = data.string.lower()
    word = data.word.lower()
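
    # Collect every morphological form of the target word (word_forms groups
    # forms by part of speech) and flatten them into a single lookup set.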
    forms = get_word_forms(word)
    all_forms = set()
    for words in forms.values():
        all_forms.update(words)

    # Initialize found flag
    found = False

    # Split the input string into words
    input_words = input_string.split()

    # Loop through each word in the input string
    for input_word in input_words:
        # Strip the word to contain only alphabetic characters
        input_word = ''.join(filter(str.isalpha, input_word))
        # Check if the stripped word is equal to any of the forms
        if input_word in all_forms:
            found = True
            break  # Exit loop if word is found

    result = {
        "found": found
    }
    return result
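

# Hypothetical local smoke test, not part of the deployed Space: a minimal
# sketch assuming this file is run directly (`python main.py`) and that the
# nltk data and word_forms package above are installed. It exercises only the
# offline /check-word/ endpoint via FastAPI's TestClient; the example strings
# are illustrative.
if __name__ == "__main__":
    from fastapi.testclient import TestClient

    test_client = TestClient(app)
    demo = test_client.post(
        "/check-word/",
        json={"string": "She was running to the store.", "word": "run"},
    )
    print(demo.json())  # expected: {'found': True}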