from fastapi import FastAPI, HTTPException
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
from os import getenv
from huggingface_hub import InferenceClient
import random
import nltk
import re
from word_forms.word_forms import get_word_forms
# FastAPI application exposing /generate-response/ and /check-word/.
app = FastAPI()
# Fetch the Punkt tokenizer data at import time (no-op if already cached).
nltk.download('punkt')
# Sentence splitter used to segment both user input and model output.
tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
# Hugging Face API token from the environment; may be None if unset.
HF_TOKEN = getenv("HF_TOKEN")
class InputData(BaseModel):
    """Request body for the /generate-response/ endpoint."""

    model: str  # Hugging Face model id passed to InferenceClient
    system_prompt_template: List[str]  # templates with a "{SystemPrompt}" placeholder
    prompt_template: List[str]  # templates with a "{Prompt}" placeholder
    end_token: str  # appended to the returned history string
    system_prompt: List[str]  # lines joined with "\n" and substituted into system templates
    user_input: List[str]  # user messages substituted into prompt templates
    history: str = ""  # running conversation transcript carried between calls
    segment: bool = False  # split text into sentences (forced on when max_sentences is set and non-zero)
    max_sentences: Optional[int] = None  # cap on returned sentences; 0 means "record input, skip the model call"
class WordCheckData(BaseModel):
    """Request body for the /check-word/ endpoint."""

    string: str  # text to search (matched case-insensitively)
    word: str  # word whose morphological forms are looked up
@app.post("/generate-response/")
async def generate_response(data: InputData) -> Dict[str, Any]:
    """Render the prompt templates, query the HF Inference API, and return
    the generated text together with the updated conversation history.

    ``max_sentences`` semantics:
      * non-zero  -> segmentation is forced on and the model output is
                     truncated to that many sentences;
      * 0         -> no model call: the rendered user prompts are appended
                     to the history and an empty response is returned.

    Raises:
        HTTPException: 500 when the inference call fails for any reason.
    """
    # A non-zero sentence cap only makes sense on sentence-segmented output.
    if data.max_sentences is not None and data.max_sentences != 0:
        data.segment = True
    elif data.max_sentences == 0:
        # Short-circuit: record the prompts in history without calling the model.
        for prompt in data.prompt_template:
            for user_input in data.user_input:
                data.history += prompt.replace("{Prompt}", user_input) + "\n"
        return {
            "response": "",
            "history": data.history + data.end_token
        }
    user_input_str = ""
    if data.segment:
        # One sentence per line so the model sees explicit sentence boundaries.
        for user_input in data.user_input:
            user_sentences = tokenizer.tokenize(user_input)
            user_input_str += "\n".join(user_sentences) + "\n"
    else:
        user_input_str = "\n".join(data.user_input)
    for prompt in data.prompt_template:
        data.history += prompt.replace("{Prompt}", user_input_str) + "\n"
    # Prepend the rendered system prompt(s) to the accumulated history.
    inputs = ""
    for system_prompt in data.system_prompt_template:
        inputs += system_prompt.replace("{SystemPrompt}", "\n".join(data.system_prompt)) + "\n"
    inputs += data.history
    # Fresh random seed per request so repeated calls vary their sampling.
    seed = random.randint(0, 2**32 - 1)
    try:
        client = InferenceClient(model=data.model, token=HF_TOKEN)
        response = client.text_generation(
            inputs,
            temperature=1.0,
            max_new_tokens=1000,
            seed=seed
        )
        response_str = str(response)
        if data.segment:
            ai_sentences = tokenizer.tokenize(response_str)
            if data.max_sentences is not None:
                # Enforce the requested sentence cap on the model output.
                ai_sentences = ai_sentences[:data.max_sentences]
            ai_response_str = "\n".join(ai_sentences)
        else:
            ai_response_str = response_str
        data.history += ai_response_str + "\n"
        cleaned_response = {
            "New response": ai_sentences if data.segment else [response_str],
            "Sentence count": len(ai_sentences) if data.segment else 1
        }
        return {
            "response": cleaned_response,
            "history": data.history + data.end_token
        }
    except Exception as e:
        print(f"Model {data.model} failed with error: {e}")
        # Chain the original exception so the real cause survives in logs/tracebacks.
        raise HTTPException(status_code=500, detail=f"Model {data.model} failed to generate response") from e
@app.post("/check-word/")
async def check_word(data: WordCheckData) -> Dict[str, Any]:
    """Report whether any morphological form of ``data.word`` appears in
    ``data.string`` (case-insensitive, whole-word matches only)."""
    haystack = data.string.lower()
    target = data.word.lower()
    # Flatten every form category (nouns, verbs, adjectives, ...) into one lookup set.
    all_forms = set().union(*get_word_forms(target).values())
    # Whole-word tokens only, so "cat" does not match inside "catalog".
    tokens = re.findall(r'\b\w+\b', haystack)
    found = any(token in all_forms for token in tokens)
    return {
        "found": found
    }