Spaces:
Paused
Paused
from asyncio import sleep | |
from typing import Optional | |
from fastapi import FastAPI | |
from fastapi.encoders import jsonable_encoder | |
from fastapi.websockets import WebSocket, WebSocketDisconnect | |
from fastapi.responses import HTMLResponse, JSONResponse | |
from websockets import ConnectionClosed | |
from accelerator import Accelerator | |
from answerer import Answerer | |
from mapper import Mapper | |
# The sentence-transformer mapper is optional: if it cannot be loaded the
# server still starts and ``mapper`` stays None instead of being unbound.
mapper: Optional[Mapper] = None
try:
    mapper = Mapper("sentence-transformers/multi-qa-distilbert-cos-v1")
except Exception as e:  # best-effort load at startup; report and continue
    print(f"ERROR! cannot load Mapper model!\n{e}")

# Local RWKV generator; the websocket handlers fall back to it whenever no
# accelerator is connected or the accelerator returns nothing.
answerer = Answerer(
    model="RWKV-5-World-3B-v2-20231118-ctx16k",
    vocab="rwkv_vocab_v20230424",
    strategy="cpu bf16",
    ctx_limit=16*1024,
)

# Shared accelerator state queried by the websocket handlers.
accelerator = Accelerator()

app = FastAPI()
# Minimal browser client for the /answer websocket. The server replies once
# and then closes the socket, so the page opens a fresh connection for every
# question instead of reusing a single one.
HTML = """
<!DOCTYPE HTML>
<html>
    <body>
        <form action="" onsubmit="ask(event)">
            <textarea id="prompt"></textarea>
            <br>
            <input type="submit" value="SEND" />
        </form>
        <p id="output" style="white-space: pre-wrap"></p>
        <script>
            const promptInput = document.getElementById("prompt");
            const output = document.getElementById("output");
            // Connect back to whichever host served this page.
            const WS_URL = (location.protocol === "https:" ? "wss://" : "ws://") + location.host + "/answer";
            function ask(event) {
                // Cancel the default form submission (a page reload) on every path.
                event.preventDefault();
                const ws = new WebSocket(WS_URL);
                ws.onopen = () => ws.send(promptInput.value);
                ws.onmessage = (e) => answer(e.data);
                ws.onerror = () => answer("websocket is not connected!");
            }
            function answer(value) {
                // Show model output as plain text; it is never parsed as markup.
                output.textContent = value;
            }
        </script>
    </body>
</html>
"""
def index():
    """Return the bundled single-page client (the ``HTML`` string) as HTML."""
    # NOTE(review): no route decorator is visible for this handler; it is
    # presumably registered as ``@app.get("/")`` — confirm.
    page = HTMLResponse(HTML)
    return page
async def answer(ws: WebSocket):
    """Hand ``ws`` to the accelerator and keep this handler alive meanwhile.

    After ``accelerator.connect(ws)`` completes, the handler re-checks
    ``accelerator.connected()`` every 10 seconds and returns once it is False.
    """
    # NOTE(review): a later ``async def answer`` in this module rebinds the
    # module-level name ``answer``; consider giving this coroutine its own name.
    await accelerator.connect(ws)
    poll_seconds = 10
    while True:
        if not accelerator.connected():
            break
        await sleep(poll_seconds)
def map(query: Optional[str], items: Optional[list[str]]):
    """Run the sentence-transformer ``mapper`` on ``query`` and ``items``.

    Returns the mapper's result (scores) JSON-encoded. If the mapper model
    could not be loaded at startup, responds with HTTP 503 instead of failing
    with an unhandled error.
    """
    # NOTE(review): this name shadows the builtin ``map`` inside the module;
    # it is kept because callers/route registration may rely on it.
    # Startup leaves ``mapper`` unbound (or None) when loading fails, so look
    # it up defensively instead of letting a NameError escape.
    model = globals().get("mapper")
    if model is None:
        return JSONResponse(
            {"error": "mapper model is not loaded"}, status_code=503
        )
    scores = model(query, items)
    return JSONResponse(jsonable_encoder(scores))
async def handle_answerer_local(ws: WebSocket, input: str):
    """Answer ``input`` with the local ``answerer`` and send one text message.

    ``answerer(input, 128)`` is consumed as an async iterable and only the
    last item it yields is sent over ``ws``. If it yields nothing, an empty
    string is sent rather than failing on an unbound variable.
    """
    # NOTE(review): 128 is passed positionally to ``answerer``; presumably a
    # generation-length limit — confirm against the Answerer signature.
    last = ""
    async for chunk in answerer(input, 128):
        last = chunk  # keep only the most recent item
    await ws.send_text(last)
async def handle_answerer_accelerated(ws: WebSocket, input: str):
    """Answer ``input`` through the accelerator, falling back to local generation.

    A falsy result from ``accelerator.accelerate`` (e.g. None or "") routes
    the prompt to ``handle_answerer_local``; otherwise the result is sent
    over ``ws`` as text.
    """
    reply = await accelerator.accelerate(input)
    if not reply:
        await handle_answerer_local(ws, input)
        return
    await ws.send_text(reply)
async def answer(ws: WebSocket):
    """Answer a single prompt received over ``ws``, then close the socket.

    Accepts the connection, reads one text message, answers it through the
    accelerator when ``accelerator.connected()`` is True (locally otherwise),
    and closes. A client disconnect at any step ends the handler quietly.
    """
    # NOTE(review): no route decorator is visible; the bundled HTML client
    # connects to ``/answer`` — confirm this handler is registered there.
    await ws.accept()
    try:
        prompt = await ws.receive_text()
        if accelerator.connected():
            await handle_answerer_accelerated(ws, prompt)
        else:
            await handle_answerer_local(ws, prompt)
        # Close inside the try: the client may already have gone away.
        await ws.close()
    except (ConnectionClosed, WebSocketDisconnect):
        return