Start-GPT commited on
Commit
df92393
·
verified ·
1 Parent(s): b71e56d

Create main.py

Browse files
Files changed (1) hide show
  1. server/main.py +174 -0
server/main.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import numpy as np
3
+ import connexion
4
+ from flask_cors import CORS
5
+ from flask import render_template, redirect, send_from_directory
6
+
7
+ import utils.path_fixes as pf
8
+ from utils.f import ifnone
9
+
10
+ from model_api import get_details
11
+
12
+ app = connexion.FlaskApp(__name__, static_folder="client/dist", specification_dir=".")
13
+ flask_app = app.app
14
+ CORS(flask_app)
15
+
16
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
17
+ parser.add_argument("--debug", action="store_true", help=" Debug mode")
18
+ parser.add_argument("--port", default=5051, help="Port to run the app. ")
19
+
20
+ # Flask main routes
21
+ @app.route("/")
22
+ def hello_world():
23
+ return redirect("client/exBERT.html")
24
+
25
+ # send everything from client as static content
26
+ @app.route("/client/<path:path>")
27
+ def send_static_client(path):
28
+ """ serves all files from ./client/ to ``/client/<path:path>``
29
+ :param path: path from api call
30
+ """
31
+ return send_from_directory(str(pf.CLIENT_DIST), path)
32
+
33
+ # ======================================================================
34
+ ## CONNEXION API ##
35
+ # ======================================================================
36
+ def get_model_details(**request):
37
+ """Get important information about a model, like the number of layers and heads
38
+
39
+ Args:
40
+ request['model']: The model name
41
+ Returns:
42
+ {
43
+ status: 200,
44
+ payload: {
45
+ nlayers (int)
46
+ nheads (int)
47
+ }
48
+ }
49
+ """
50
+ mname = request['model']
51
+ deets = get_details(mname)
52
+
53
+ info = deets.config
54
+ nlayers = info.num_hidden_layers
55
+ nheads = info.num_attention_heads
56
+
57
+ payload_out = {
58
+ "nlayers": nlayers,
59
+ "nheads": nheads,
60
+ }
61
+
62
+ return {
63
+ "status": 200,
64
+ "payload": payload_out,
65
+ }
66
+
67
+ def get_attentions_and_preds(**request):
68
+ """For a sentence, at a layer, get the attentions and predictions
69
+
70
+ Args:
71
+ request['model']: Model name
72
+ request['sentence']: Sentence to get the attentions for
73
+ request['layer']: Which layer to extract from
74
+ Returns:
75
+ {
76
+ status: 200
77
+ payload: {
78
+ aa: {
79
+ att: Array((nheads, ntoks, ntoks))
80
+ left: [{
81
+ text (str),
82
+ topk_words (List[str]),
83
+ topk_probs (List[float])
84
+ }, ...]
85
+ right: [{
86
+ text (str),
87
+ topk_words (List[str]),
88
+ topk_probs (List[float])
89
+ }, ...]
90
+ }
91
+ }
92
+ }
93
+ """
94
+ model = request["model"]
95
+ details = get_details(model)
96
+
97
+ sentence = request["sentence"]
98
+ layer = int(request["layer"])
99
+
100
+ deets = details.from_sentence(sentence)
101
+
102
+ payload_out = deets.to_json(layer)
103
+
104
+ return {
105
+ "status": 200,
106
+ "payload": payload_out
107
+ }
108
+
109
+ def update_masked_attention(**request):
110
+ """From tokens and indices of what should be masked, get the attentions and predictions
111
+
112
+ payload = request['payload']
113
+ Args:
114
+ payload['model'] (str): Model name
115
+ payload['tokens'] (List[str]): Tokens to pass through the model
116
+ payload['sentence'] (str): Original sentence the tokens came from
117
+ payload['mask'] (List[int]): Which indices to mask
118
+ payload['layer'] (int): Which layer to extract information from
119
+ Returns:
120
+ {
121
+ status: 200
122
+ payload: {
123
+ aa: {
124
+ att: Array((nheads, ntoks, ntoks))
125
+ left: [{
126
+ text (str),
127
+ topk_words (List[str]),
128
+ topk_probs (List[float])
129
+ }, ...]
130
+ right: [{
131
+ text (str),
132
+ topk_words (List[str]),
133
+ topk_probs (List[float])
134
+ }, ...]
135
+ }
136
+ }
137
+ }
138
+ """
139
+ payload = request["payload"]
140
+
141
+ model = payload['model']
142
+ details = get_details(model)
143
+
144
+ tokens = payload["tokens"]
145
+ sentence = payload["sentence"]
146
+ mask = payload["mask"]
147
+ layer = int(payload["layer"])
148
+
149
+ MASK = details.tok.mask_token
150
+ mask_tokens = lambda toks, maskinds: [
151
+ t if i not in maskinds else ifnone(MASK, t) for (i, t) in enumerate(toks)
152
+ ]
153
+
154
+ token_inputs = mask_tokens(tokens, mask)
155
+
156
+ deets = details.from_tokens(token_inputs, sentence)
157
+ payload_out = deets.to_json(layer)
158
+
159
+ return {
160
+ "status": 200,
161
+ "payload": payload_out,
162
+ }
163
+
164
+ app.add_api("swagger.yaml")
165
+
166
+ # Setup code
167
+ if __name__ != "__main__":
168
+ print("SETTING UP ENDPOINTS")
169
+
170
+ # Then deploy app
171
+ else:
172
+ args, _ = parser.parse_known_args()
173
+ print("Initiating app")
174
+ app.run(port=args.port, use_reloader=False, debug=args.debug)