Start-GPT commited on
Commit
0f15686
1 Parent(s): 95167de

Create server/transformer_formatter.py

Browse files
Files changed (1) hide show
  1. server/transformer_formatter.py +138 -0
server/transformer_formatter.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Iterable, Tuple
2
+ from functools import partial
3
+ import numpy as np
4
+ import torch
5
+ import json
6
+
7
+ from utils.token_processing import fix_byte_spaces
8
+ from utils.gen_utils import map_nlist
9
+
10
+
11
+ def round_return_value(attentions, ndigits=5):
12
+ """Rounding must happen right before it's passed back to the frontend because there is a little numerical error that's introduced converting back to lists
13
+
14
+ attentions: {
15
+ 'aa': {
16
+ left
17
+ right
18
+ att
19
+ }
20
+ }
21
+
22
+ """
23
+ rounder = partial(round, ndigits=ndigits)
24
+ nested_rounder = partial(map_nlist, rounder)
25
+ new_out = attentions # Modify values to save memory
26
+ new_out["aa"]["att"] = nested_rounder(attentions["aa"]["att"])
27
+
28
+ return new_out
29
+
30
+ def flatten_batch(x: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]:
31
+ """Remove the batch dimension of every tensor inside the Iterable container `x`"""
32
+ return tuple([x_.squeeze(0) for x_ in x])
33
+
34
+ def squeeze_contexts(x: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]:
35
+ """Combine the last two dimensions of the context."""
36
+ shape = x[0].shape
37
+ new_shape = shape[:-2] + (-1,)
38
+ return tuple([x_.view(new_shape) for x_ in x])
39
+
40
+ def add_blank(xs: Tuple[torch.tensor]) -> Tuple[torch.Tensor]:
41
+ """The embeddings have n_layers + 1, indicating the final output embedding."""
42
+
43
+ return (torch.zeros_like(xs[0]),) + xs
44
+
45
+ class TransformerOutputFormatter:
46
+ def __init__(
47
+ self,
48
+ sentence: str,
49
+ tokens: List[str],
50
+ special_tokens_mask: List[int],
51
+ att: Tuple[torch.Tensor],
52
+ topk_words: List[List[str]],
53
+ topk_probs: List[List[float]],
54
+ model_config
55
+ ):
56
+ assert len(tokens) > 0, "Cannot have an empty token output!"
57
+
58
+ modified_att = flatten_batch(att)
59
+
60
+ self.sentence = sentence
61
+ self.tokens = tokens
62
+ self.special_tokens_mask = special_tokens_mask
63
+ self.attentions = modified_att
64
+ self.topk_words = topk_words
65
+ self.topk_probs = topk_probs
66
+ self.model_config = model_config
67
+
68
+ try:
69
+ # GPT vals
70
+ self.n_layer = self.model_config.n_layer
71
+ self.n_head = self.model_config.n_head
72
+ self.hidden_dim = self.model_config.n_embd
73
+ except AttributeError:
74
+ try:
75
+ # BERT vals
76
+ self.n_layer = self.model_config.num_hidden_layers
77
+ self.n_head = self.model_config.num_attention_heads
78
+ self.hidden_dim = self.model_config.hidden_size
79
+ except AttributeError: raise
80
+
81
+
82
+ self.__len = len(tokens)# Get the number of tokens in the input
83
+ assert self.__len == self.attentions[0].shape[-1], "Attentions don't represent the passed tokens!"
84
+
85
+ def to_json(self, layer:int, ndigits=5):
86
+ """The original API expects the following response:
87
+ aa: {
88
+ att: number[][][]
89
+ left: List[str]
90
+ right: List[str]
91
+ }
92
+ """
93
+ # Convert the embeddings, attentions, and contexts into list. Perform rounding
94
+
95
+ rounder = partial(round, ndigits=ndigits)
96
+ nested_rounder = partial(map_nlist, rounder)
97
+
98
+ def tolist(tens): return [t.tolist() for t in tens]
99
+
100
+ def to_resp(tok: str, topk_words, topk_probs):
101
+ return {
102
+ "text": tok,
103
+ "topk_words": topk_words,
104
+ "topk_probs": nested_rounder(topk_probs)
105
+ }
106
+
107
+ side_info = [to_resp(t, w, p) for t,w,p in zip( self.tokens,
108
+ self.topk_words,
109
+ self.topk_probs)]
110
+
111
+ out = {"aa": {
112
+ "att": nested_rounder(tolist(self.attentions[layer])),
113
+ "left": side_info,
114
+ "right": side_info
115
+ }}
116
+
117
+ return out
118
+
119
+ def display_tokens(self, tokens):
120
+ return fix_byte_spaces(tokens)
121
+
122
+ def __repr__(self):
123
+ lim = 50
124
+ if len(self.sentence) > lim: s = self.sentence[:lim - 3] + "..."
125
+ else: s = self.sentence[:lim]
126
+
127
+ return f"TransformerOutput({s})"
128
+
129
+ def __len__(self):
130
+ return self.__len
131
+
132
+ def to_numpy(x):
133
+ """Embeddings, contexts, and attentions are stored as torch.Tensors in a tuple. Convert this to a numpy array
134
+ for storage in hdf5"""
135
+ return np.array([x_.detach().numpy() for x_ in x])
136
+
137
+ def to_searchable(t: Tuple[torch.Tensor]):
138
+ return t.detach().numpy().astype(np.float32)