# New-Place / main.py
# Hugging Face Space file (uploaded by oflakne26, commit f319ec8 verified, 3.49 kB).
# NOTE(review): the lines above were page-scrape metadata ("raw / history / blame"),
# not Python — converted to comments so the module parses.
from fastapi import FastAPI, HTTPException
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
from os import getenv
from huggingface_hub import InferenceClient
import random
import nltk
import re
from word_forms.word_forms import get_word_forms
app = FastAPI()
nltk.download('punkt')
tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
HF_TOKEN = getenv("HF_TOKEN")
class InputData(BaseModel):
model: str
system_prompt_template: List[str]
prompt_template: List[str]
end_token: str
system_prompt: List[str]
user_input: List[str]
history: str = ""
segment: bool = False
max_sentences: Optional[int] = None
class WordCheckData(BaseModel):
string: str
word: str
@app.post("/generate-response/")
async def generate_response(data: InputData) -> Dict[str, Any]:
if data.max_sentences is not None and data.max_sentences != 0:
data.segment = True
elif data.max_sentences == 0:
for prompt in data.prompt_template:
for user_input in data.user_input:
data.history += prompt.replace("{Prompt}", user_input) + "\n"
return {
"response": "",
"history": data.history + data.end_token
}
user_input_str = ""
if data.segment:
for user_input in data.user_input:
user_sentences = tokenizer.tokenize(user_input)
user_input_str += "\n".join(user_sentences) + "\n"
else:
user_input_str = "\n".join(data.user_input)
for prompt in data.prompt_template:
data.history += prompt.replace("{Prompt}", user_input_str) + "\n"
inputs = ""
for system_prompt in data.system_prompt_template:
inputs += system_prompt.replace("{SystemPrompt}", "\n".join(data.system_prompt)) + "\n"
inputs += data.history
seed = random.randint(0, 2**32 - 1)
try:
client = InferenceClient(model=data.model, token=HF_TOKEN)
response = client.text_generation(
inputs,
temperature=1.0,
max_new_tokens=1000,
seed=seed
)
response_str = str(response)
if data.segment:
ai_sentences = tokenizer.tokenize(response_str)
if data.max_sentences is not None:
ai_sentences = ai_sentences[:data.max_sentences]
ai_response_str = "\n".join(ai_sentences)
else:
ai_response_str = response_str
data.history += ai_response_str + "\n"
cleaned_response = {
"New response": ai_sentences if data.segment else [response_str],
"Sentence count": len(ai_sentences) if data.segment else 1
}
return {
"response": cleaned_response,
"history": data.history + data.end_token
}
except Exception as e:
print(f"Model {data.model} failed with error: {e}")
raise HTTPException(status_code=500, detail=f"Model {data.model} failed to generate response")
@app.post("/check-word/")
async def check_word(data: WordCheckData) -> Dict[str, Any]:
input_string = data.string.lower()
word = data.word.lower()
forms = get_word_forms(word)
all_forms = set()
for words in forms.values():
all_forms.update(words)
words_in_string = re.findall(r'\b\w+\b', input_string)
found = any(word_in_string in all_forms for word_in_string in words_in_string)
result = {
"found": found
}
return result