File size: 2,545 Bytes
81a216e
96dc6e1
ef18a5d
 
a8aa145
c7f3620
84f3ea8
9a6620f
71f5570
1e28d04
26790f6
9a6620f
9fe8e1f
589d913
6f601b3
1e28d04
2bc4710
1e28d04
 
 
 
aeb5107
71f5570
1e28d04
9a6620f
 
a3ccca5
 
 
 
 
 
6942628
8b0801b
05dc8f9
 
a3ccca5
 
a05f54e
a3ccca5
 
 
 
7cab91c
9e7b790
a3ccca5
 
9e7b790
 
 
 
 
a3ccca5
 
 
9e7b790
 
 
 
a3ccca5
 
 
 
 
 
bc2d927
 
 
26790f6
aeb5107
 
71f5570
589d913
 
aeb5107
d7562ab
96dc6e1
ef18a5d
 
a3ccca5
96dc6e1
84f3ea8
324ca9f
 
ca588ec
96dc6e1
 
 
 
7be15a8
96dc6e1
1e28d04
 
 
3518fdf
a8aa145
 
7be15a8
 
70982cd
a8aa145
aeb5107
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from asyncio import sleep
from typing import Optional
from fastapi import FastAPI
from fastapi.encoders import jsonable_encoder
from fastapi.websockets import WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse, JSONResponse
from websockets import ConnectionClosed

from accelerator import Accelerator
from answerer import Answerer
from mapper import Mapper

# Sentence-embedding model behind the /map endpoint.  Loading is best-effort:
# on failure the error is printed and the app still starts.  `mapper` is then
# bound to None so the name always exists and callers can detect the missing
# model instead of hitting a NameError.
try:
  mapper = Mapper("sentence-transformers/multi-qa-distilbert-cos-v1")
except Exception as e:
  print(f"ERROR! cannot load Mapper model!\n{e}")
  mapper = None

# In-process answer-generation model (RWKV, per the model name), used by
# handle_answerer_local.  ctx_limit = 16*1024 matches the "ctx16k" in the
# model name.  NOTE(review): the exact meaning of `strategy` is defined by
# answerer.py -- confirm there.
answerer = Answerer(
  model="RWKV-5-World-3B-v2-20231118-ctx16k",
  vocab="rwkv_vocab_v20230424",
  strategy="cpu bf16",
  ctx_limit=16*1024,
)

# Shared Accelerator: /accelerate attaches a websocket to it via connect(),
# and /answer offloads prompts to it via accelerate() while connected().
accelerator = Accelerator()

app = FastAPI()

# Minimal browser client served at "/".  It sends the textarea contents to the
# /answer websocket and shows the reply as plain text.
HTML = """
<!DOCTYPE HTML>

<html>

<body>
  <form action="" onsubmit="ask(event)">
    <textarea id="prompt"></textarea>
    <br>
    <input type="submit" value="SEND" />
  </form>

  <p id="output" style="white-space: pre-wrap"></p>
  <script>
    const prompt = document.getElementById("prompt");
    const output = document.getElementById("output");

    // Connect back to whichever host served this page instead of a
    // hard-coded deployment URL.
    const answerUrl = (location.protocol === "https:" ? "wss://" : "ws://") + location.host + "/answer";

    // The server answers one question per connection and then closes it,
    // so every question gets its own socket.
    let ws = null;

    function ask(event) {
      // Suppress the native form submission (a page reload) on every path.
      event.preventDefault();

      // Abandon a still-pending previous question so its late reply cannot
      // overwrite the new one.
      if (ws) ws.close();

      const question = prompt.value;
      const socket = new WebSocket(answerUrl);
      ws = socket;
      socket.onopen = () => socket.send(question);
      socket.onmessage = (e) => answer(e.data);
      socket.onerror = () => answer("websocket is not connected!");
    }

    function answer(value) {
      // Render as plain text: the reply is model-generated and must not be
      // parsed as HTML markup.
      output.textContent = value;
    }
  </script>
</body>

</html>
"""

@app.get("/")
def index():
  """Serve the bundled single-page client (the HTML module constant)."""
  page = HTMLResponse(content=HTML)
  return page

@app.websocket("/accelerate")
async def accelerate(ws: WebSocket):
  """Attach an accelerator client's websocket and keep this handler alive.

  The function is named ``accelerate`` rather than ``answer`` so it does not
  share a module-level name (and a route name) with the /answer handler.
  Previously that sharing shadowed this definition and made a route lookup
  by the name "answer" resolve to /accelerate.
  """
  await accelerator.connect(ws)
  # Returning from a websocket endpoint ends the connection, so poll every
  # 10 s until the accelerator reports it is no longer connected.
  # NOTE(review): presumably Accelerator exchanges work over `ws` while this
  # coroutine stays alive -- confirm in accelerator.py.
  while accelerator.connected():
    await sleep(10)

@app.post("/map")
def map(query: Optional[str], items: Optional[list[str]]):
  """Score ``items`` against ``query`` with the sentence-embedding mapper.

  Returns the mapper's scores as JSON.  If the mapper model failed to load at
  start-up, responds with HTTP 503 and an error message instead of failing
  with an unhandled exception (HTTP 500).
  """
  # Model loading at import time is best-effort, so the global may be unbound
  # or None here.
  try:
    model = mapper
  except NameError:
    model = None
  if model is None:
    return JSONResponse({"error": "mapper model is not loaded"}, status_code=503)

  scores = model(query, items)
  return JSONResponse(jsonable_encoder(scores))

async def handle_answerer_local(ws: WebSocket, input: str):
  """Answer ``input`` with the in-process model and send the result on ``ws``.

  ``answerer(input, 128)`` is consumed as an async iterator, and only its final
  item is sent.  Presumably each item is the answer generated so far; TODO
  confirm in answerer.py.  If the iterator yields nothing, an empty string is
  sent instead of raising UnboundLocalError.
  """
  last = ""
  # The loop target receives each item in turn; only the last one is kept.
  async for last in answerer(input, 128):
    pass
  await ws.send_text(last)

async def handle_answerer_accelerated(ws: WebSocket, input: str):
  """Offload ``input`` to the connected accelerator and send its reply.

  If ``accelerator.accelerate`` returns a falsy result, the prompt is answered
  in-process via ``handle_answerer_local`` instead.
  """
  remote = await accelerator.accelerate(input)
  if not remote:
    # Nothing usable came back from the accelerator: compute it locally.
    await handle_answerer_local(ws, input)
    return
  await ws.send_text(remote)

@app.websocket("/answer")
async def answer(ws: WebSocket):
  """Answer exactly one prompt per websocket connection.

  Accepts the socket and reads one text message. The reply is produced by the
  accelerator when one is connected, otherwise by the local model. The socket
  is then closed. If the client goes away mid-exchange, the handler just
  returns.
  """
  await ws.accept()

  try:
    prompt = await ws.receive_text()
    handler = handle_answerer_accelerated if accelerator.connected() else handle_answerer_local
    await handler(ws, prompt)
  except (ConnectionClosed, WebSocketDisconnect):
    # Peer disconnected; there is nobody left to close the socket for.
    return

  await ws.close()