"""FastAPI service exposing two endpoints:

* ``/generate-response/`` — text generation through the Hugging Face
  Inference API, with optional sentence segmentation/truncation.
* ``/check-word/`` — checks whether any morphological form of a word
  occurs in a string.
"""

import logging
import random
import re
from os import getenv
from typing import Any, Dict, List, Optional

import nltk
from fastapi import FastAPI, HTTPException
from huggingface_hub import InferenceClient
from pydantic import BaseModel
from word_forms.word_forms import get_word_forms

logger = logging.getLogger(__name__)

app = FastAPI()

# Punkt sentence tokenizer: used to segment both user input and model output.
nltk.download('punkt')
tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')

# Hugging Face API token, taken from the environment (may be None).
HF_TOKEN = getenv("HF_TOKEN")


class InputData(BaseModel):
    """Request payload for ``/generate-response/``."""

    model: str  # Hugging Face model id
    system_prompt_template: List[str]  # templates containing "{SystemPrompt}"
    prompt_template: List[str]  # templates containing "{Prompt}"
    end_token: str  # appended to the returned history
    system_prompt: List[str]
    user_input: List[str]
    history: str = ""  # prior conversation, prepended to the generation input
    segment: bool = False  # split input/output into sentences
    max_sentences: Optional[int] = None  # truncate output; 0 skips generation


class WordCheckData(BaseModel):
    """Request payload for ``/check-word/``."""

    string: str
    word: str


def _build_history(data: InputData) -> None:
    """Fill the prompt templates with the (optionally sentence-segmented)
    user input and append them to ``data.history`` in place."""
    if data.segment:
        user_input_str = ""
        for user_input in data.user_input:
            user_sentences = tokenizer.tokenize(user_input)
            user_input_str += "\n".join(user_sentences) + "\n"
    else:
        user_input_str = "\n".join(data.user_input)

    for prompt in data.prompt_template:
        data.history += prompt.replace("{Prompt}", user_input_str) + "\n"


@app.post("/generate-response/")
async def generate_response(data: InputData) -> Dict[str, Any]:
    """Generate a model response for the templated prompt.

    Behavior:
    * ``max_sentences`` set and non-zero forces sentence segmentation.
    * ``max_sentences == 0`` records the prompts in the history but skips
      generation entirely, returning an empty response.

    Returns:
        ``{"response": ..., "history": ...}`` where ``response`` is either
        ``""`` (no-generation shortcut) or a dict with the segmented
        sentences and their count.

    Raises:
        HTTPException: 500 if the inference call fails for any reason.
    """
    if data.max_sentences is not None and data.max_sentences != 0:
        data.segment = True
    elif data.max_sentences == 0:
        # Shortcut: log every prompt/input combination into the history
        # without calling the model at all.
        for prompt in data.prompt_template:
            for user_input in data.user_input:
                data.history += prompt.replace("{Prompt}", user_input) + "\n"
        return {
            "response": "",
            "history": data.history + data.end_token
        }

    _build_history(data)

    # Assemble the full generation input: system prompt(s) first, then the
    # accumulated conversation history.
    inputs = ""
    for system_prompt in data.system_prompt_template:
        inputs += system_prompt.replace(
            "{SystemPrompt}", "\n".join(data.system_prompt)
        ) + "\n"
    inputs += data.history

    # Fresh seed per request so identical prompts can yield different outputs.
    seed = random.randint(0, 2**32 - 1)

    try:
        client = InferenceClient(model=data.model, token=HF_TOKEN)
        response = client.text_generation(
            inputs,
            temperature=1.0,
            max_new_tokens=1000,
            seed=seed
        )
        response_str = str(response)

        # Normalize the output into a list of sentences so the response
        # dict below never references an undefined variable.
        if data.segment:
            sentences = tokenizer.tokenize(response_str)
            if data.max_sentences is not None:
                sentences = sentences[:data.max_sentences]
        else:
            sentences = [response_str]

        ai_response_str = "\n".join(sentences)
        data.history += ai_response_str + "\n"

        cleaned_response = {
            "New response": sentences,
            "Sentence count": len(sentences)
        }

        return {
            "response": cleaned_response,
            "history": data.history + data.end_token
        }
    except Exception as e:
        # Preserve the cause chain so the original failure is debuggable.
        logger.exception("Model %s failed with error: %s", data.model, e)
        raise HTTPException(
            status_code=500,
            detail=f"Model {data.model} failed to generate response"
        ) from e


@app.post("/check-word/")
async def check_word(data: WordCheckData) -> Dict[str, Any]:
    """Report whether any morphological form of ``word`` appears in ``string``.

    Comparison is case-insensitive; word boundaries are determined by a
    ``\\b\\w+\\b`` regex scan of the input string.
    """
    input_string = data.string.lower()
    word = data.word.lower()

    # Union of every derived form of the word across all parts of speech.
    forms = get_word_forms(word)
    all_forms = set()
    for words in forms.values():
        all_forms.update(words)

    words_in_string = re.findall(r'\b\w+\b', input_string)
    found = any(token in all_forms for token in words_in_string)

    return {"found": found}