# Scraped file-view header: author Cahya Wirawan, commit 4964cc6 ("add indochat"),
# file size 5.42 kB.
import json
import os
import threading
import time

import torch
from fastapi import FastAPI, Form, WebSocket, WebSocketDisconnect
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse
# FastAPI application instance; every route in this module is registered on it.
app = FastAPI()
# Static chat page returned by GET "/". Its script opens a WebSocket, appends
# each received message to the #messages list (via createTextNode, so the text
# is not interpreted as HTML), and sends the input box's text on submit.
# NOTE(review): the socket URL is hard-coded to
# wss://cahya-indonesian-whisperer.hf.space/api/ws (a different Space, at path
# /api/ws), while this module registers its echo socket at "/ws" — confirm
# whether a parent app mounts this one under /api, or whether the URL is stale.
html = """
<!DOCTYPE html>
<html>
<head>
<title>Chat</title>
</head>
<body>
<h1>WebSocket Chat</h1>
<form action="" onsubmit="sendMessage(event)">
<input type="text" id="messageText" autocomplete="off"/>
<button>Send</button>
</form>
<ul id='messages'>
</ul>
<script>
// var ws = new WebSocket("ws://localhost:8000/api/ws");
var ws = new WebSocket("wss://cahya-indonesian-whisperer.hf.space/api/ws");
ws.onmessage = function(event) {
var messages = document.getElementById('messages')
var message = document.createElement('li')
var content = document.createTextNode(event.data)
message.appendChild(content)
messages.appendChild(message)
};
function sendMessage(event) {
var input = document.getElementById("messageText")
ws.send(input.value)
input.value = ''
event.preventDefault()
}
</script>
</body>
</html>
"""
@app.get("/")
async def get():
    """Serve the single-page WebSocket chat client held in ``html``."""
    response = HTMLResponse(content=html)
    return response
@app.get("/env")
async def env():
    """Return an HTML listing of the process environment variables (debug aid).

    This module attaches no access control to the route, and the module itself
    reads HF_AUTH_TOKEN from the environment. For that reason, values of
    variables whose names contain TOKEN, SECRET, KEY, PASSWORD or AUTH are
    replaced with "[redacted]". Names and values are HTML-escaped so arbitrary
    characters render literally instead of being parsed as markup.
    """
    # Imported locally: the module-level name ``html`` is the chat page string.
    from html import escape

    sensitive_markers = ("TOKEN", "SECRET", "KEY", "PASSWORD", "AUTH")
    rows = ["<h3>Environment Variables</h3>"]
    for name, value in os.environ.items():
        if any(marker in name.upper() for marker in sensitive_markers):
            value = "[redacted]"
        rows.append(f"{escape(name)}: {escape(value)}<br>")
    # join() instead of repeated += keeps construction linear.
    return HTMLResponse("".join(rows))
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Echo every text frame back, prefixed with "Message text was: ".

    The loop runs until the client disconnects. The WebSocketDisconnect raised
    by ``receive_text`` is treated as the normal end of the session rather than
    being allowed to escape the handler as an error.
    """
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            await websocket.send_text(f"Message text was: {data}")
    except WebSocketDisconnect:
        # Client closed the connection; there is no per-connection state to clean up.
        return
# Serializes seeding + generation across requests. Generation runs in a
# worker thread (see indochat). Without this lock, two overlapping requests
# could interleave their set_seed()/generate() calls, and a given seed would
# no longer reproduce the same output.
_GENERATION_LOCK = threading.Lock()


def _resolve_repetition_penalty(repetition_penalty: float, temperature: float) -> float:
    """Return *repetition_penalty*, or derive one from *temperature* when it is 0.0.

    0.0 acts as an "auto" sentinel. The penalty is then interpolated linearly
    from 1.5 at temperature 0.0 down to 1.05 at temperature 1.0. It keeps
    falling for higher temperatures and is floored at 0.8.
    """
    if repetition_penalty != 0.0:
        return repetition_penalty
    min_penalty = 1.05
    max_penalty = 1.5
    return max(min_penalty + (1.0 - temperature) * (max_penalty - min_penalty), 0.8)


def _generate_locked(input_ids, seed: int, **generate_kwargs):
    """Seed, run ``model.generate`` and decode the first sequence, under the lock.

    Returns ``(decoded_text, seconds)``. *seconds* covers generation plus
    decoding only, not time spent waiting for the lock.
    """
    # NOTE(review): set_seed is not in the visible imports; presumably it is
    # transformers.set_seed, which seeds process-global RNGs — confirm.
    with _GENERATION_LOCK:
        set_seed(seed)
        time_start = time.perf_counter()  # monotonic clock, meant for durations
        sample_outputs = model.generate(input_ids, **generate_kwargs)
        result = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
        return result, time.perf_counter() - time_start


@app.post("/api/indochat/v1")
async def indochat(
    text: str = Form(default="", description="The Prompt"),
    max_length: int = Form(default=250, description="Maximal length of the generated text"),
    do_sample: bool = Form(default=True, description="Whether to use sampling; use greedy decoding otherwise"),
    top_k: int = Form(default=50, description="The number of highest probability vocabulary tokens to keep "
                                              "for top-k-filtering"),
    top_p: float = Form(default=0.95, description="If set to float < 1, only the most probable tokens with "
                                                  "probabilities that add up to top_p or higher are kept "
                                                  "for generation"),
    temperature: float = Form(default=1.0, description="The Temperature of the softmax distribution"),
    penalty_alpha: float = Form(default=0.6, description="Penalty alpha"),
    repetition_penalty: float = Form(default=1.0, description="Repetition penalty"),
    seed: int = Form(default=42, description="Random Seed"),
    max_time: float = Form(default=60.0, description="Maximal time in seconds to generate the text")
):
    """Generate an assistant reply to *text* with the module-level model.

    The prompt is wrapped as ``"User: {text}\\nAssistant: "``. The form fields
    are forwarded to ``model.generate``. A ``repetition_penalty`` of 0.0
    selects a temperature-derived penalty.

    Returns a dict with:
      * "generated_text": the full decoded sequence. For this decoder-only
        model it begins with the prompt.
      * "processing_time": seconds spent generating and decoding.
    """
    repetition_penalty = _resolve_repetition_penalty(repetition_penalty, temperature)
    prompt = f"User: {text}\nAssistant: "
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)
    model.eval()
    print("Generating text...")
    print(f"max_length: {max_length}, do_sample: {do_sample}, top_k: {top_k}, top_p: {top_p}, "
          f"temperature: {temperature}, repetition_penalty: {repetition_penalty}, penalty_alpha: {penalty_alpha}")
    # generate() is blocking and may run for up to max_time seconds. Calling it
    # directly in this coroutine would freeze the event loop (and every other
    # request, including the /ws socket). Offload it to the threadpool instead.
    result, time_diff = await run_in_threadpool(
        _generate_locked,
        input_ids,
        seed,
        penalty_alpha=penalty_alpha,
        do_sample=do_sample,
        min_length=200,
        max_length=max_length,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        num_return_sequences=1,
        max_time=max_time,
    )
    print(f"result:\n{result}")
    return {"generated_text": result, "processing_time": time_diff}
def get_text_generator(model_name: str, device: str = "cpu"):
    """Load the tokenizer and GPT-2 LM-head model for *model_name*.

    Args:
        model_name: Identifier or path handed to ``from_pretrained``.
        device: Torch device string the model is moved to (e.g. "cpu", "cuda").

    Returns:
        A ``(model, tokenizer)`` tuple.

    If the ``HF_AUTH_TOKEN`` environment variable is set, its value is
    forwarded as ``use_auth_token``. Otherwise ``False`` is passed.
    """
    # NOTE(review): AutoTokenizer and GPT2LMHeadModel are not in the visible
    # import block — presumably from transformers; confirm they are imported.
    hf_auth_token = os.getenv("HF_AUTH_TOKEN", False)
    # Log only whether a token is configured. Printing its value would leak
    # the credential into the process logs.
    print(f"hf_auth_token: {'set' if hf_auth_token else 'not set'}")
    print(f"Loading model with device: {device}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_auth_token)
    # pad_token_id is set to the EOS id; presumably the tokenizer defines no
    # dedicated pad token — confirm.
    model = GPT2LMHeadModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id,
                                            use_auth_token=hf_auth_token)
    model.to(device)
    print("Model loaded")
    return model, tokenizer
def get_config(path: str = "config.json"):
    """Load and return the JSON configuration stored at *path*.

    Args:
        path: Config file location. A relative path is resolved against the
            current working directory. Defaults to "config.json".

    Returns:
        The parsed JSON document. At startup this module reads its
        "model_name" key.
    """
    # ``with`` closes the handle deterministically (a bare open() leaks it until
    # garbage collection). JSON text is UTF-8 by spec, so don't depend on the
    # platform's locale encoding.
    with open(path, "r", encoding="utf-8") as config_file:
        return json.load(config_file)
# Startup: read the configuration and load the model once, at import time.
# The request handlers above then reuse the shared model, tokenizer and device.
config = get_config()
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model, tokenizer = get_text_generator(model_name=config["model_name"], device=device)