|
import polars as pl |
|
import numpy as np |
|
import tensorflow as tf |
|
import pandas as pd |
|
import plotly.graph_objects as go |
|
|
|
|
|
class tokenizer:
    """Tokenizer mapping raw play features to/from integer token ids.

    Each ``*_index`` constructor argument is a path to a parquet lookup
    table; each table is converted to an ``{id: [value, ...]}`` dict and
    exposed through :attr:`index` under the feature name the model uses
    (e.g. ``"input_ids"``, ``"PlayType"``).
    """

    # Keys that carry no vocabulary and pass through encode/decode unchanged.
    PASSTHROUGH_KEYS = ("side_ids", "token_type_ids", "labels", "attention_mask", "ids")

    def __init__(self,
                 moves_index : str,
                 play_index : str,
                 positions_index : str,
                 scrimmage_index : str,
                 starts_index : str,
                 time_index : str,
                 window_size : int):
        # Attention window length used by prepare_for_call.
        self.window = window_size

        # Each lookup maps a token id to its decoded value list.
        self.moves_index = self._load_index(moves_index)
        self.play_index = self._load_index(play_index)
        self.positions_index = self._load_index(positions_index)
        self.scrimmage_index = self._load_index(scrimmage_index)
        self.starts_index = self._load_index(starts_index)
        self.time_index = self._load_index(time_index)

        # Offense/defense flag has a fixed two-entry vocabulary.
        self.offdef_index = {0: ["Def"],
                             1: ["Off"]}

        # Feature name (as used in model input dicts) -> lookup table.
        self.index = {"input_ids": self.moves_index,
                      "PlayType": self.play_index,
                      "position_ids": self.positions_index,
                      "scrim_ids": self.scrimmage_index,
                      "start_ids": self.starts_index,
                      "pos_ids": self.time_index,
                      "OffDef": self.offdef_index}

    def _load_index(self, path: str) -> dict:
        """Read one parquet lookup table and convert it to an id -> values dict."""
        return self.convert_index_to_dict(pl.read_parquet(path))

    def convert_index_to_dict(self, df: "pl.DataFrame") -> dict:
        """Convert a lookup DataFrame into ``{id: [val_1, ..., val_n]}``.

        The frame must contain exactly one column whose name contains
        ``"ID"``; every remaining column except ``"Cat"`` is treated as a
        value column, kept in order.

        Raises:
            ValueError: if the frame does not have exactly one ID column.
        """
        id_cols = [c for c in df.columns if "ID" in c]
        # Raise (not assert) so the check survives `python -O`.
        if len(id_cols) != 1:
            raise ValueError(f"expected exactly one ID column, found {id_cols}")

        val_cols = [c for c in df.columns if c not in id_cols + ["Cat"]]
        new_val_names = ["Val_" + str(i) for i in range(1, len(val_cols) + 1)]

        new_names = ["ID"] + new_val_names
        renaming = dict(zip(id_cols + val_cols, new_names))

        d = (df
             .drop("Cat")
             .rename(renaming)
             .select(new_names)
             .to_dict(as_series=False))

        # Re-assemble the column-wise dict into {id: [values...]}.
        return {d["ID"][i]: [d[k][i] for k in new_val_names]
                for i in range(len(d["ID"]))}

    def base_decode(self,
                    pad_element,
                    inputs : list,
                    index : dict,
                    first : bool):
        """Map each token id in *inputs* through *index*.

        Unknown ids decode to *pad_element*.  When *first* is True only the
        first decoded value is kept; otherwise the full value list.
        """
        if first:
            return [index[v][0] if v in index else pad_element for v in inputs]
        return [index[v] if v in index else pad_element for v in inputs]

    def decode(self,
               inputs : list,
               type : str):
        """Decode a sequence of token ids for feature *type*.

        (*type* shadows the builtin; the name is kept for caller
        compatibility.)
        """
        # 2-D features pad with a coordinate pair, scalar numeric features
        # with a sentinel number, categorical features with "[PAD]".
        if type in ["input_ids", "start_ids"]:
            padding = [-1000, -1000]
        elif type in ["scrim_ids", "pos_ids"]:
            padding = -1000
        else:
            padding = "[PAD]"

        # Coordinate features keep the full value pair; everything else
        # keeps only the first decoded value.
        first = type not in ["input_ids", "start_ids"]
        return self.base_decode(padding, inputs, index=self.index[type], first=first)

    def find_id_by_values(self,
                          input_dict : dict,
                          target_list : list):
        """Return the id whose value list matches *target_list* (set-wise).

        Returns None when no entry matches.
        """
        for key, values in input_dict.items():
            if set(target_list) == set(values):
                return key
        return None

    def base_encode(self,
                    inputs : list,
                    index : dict):
        """Encode each raw value in *inputs* to its token id (None if absent)."""
        return [self.find_id_by_values(index, [v]) for v in inputs]

    def encode(self,
               inputs : list,
               type : str):
        """Encode a sequence of raw values for feature *type*."""
        return self.base_encode(inputs, index=self.index[type])

    def decode_sequence(self,
                        input : dict):
        """Decode every decodable feature of a sequence dict; others pass through."""
        return {k: self.decode(v, k) if k not in self.PASSTHROUGH_KEYS else v
                for k, v in input.items()}

    def encode_sequence(self,
                        input : dict):
        """Encode every encodable feature of a sequence dict; others pass through."""
        return {k: self.encode(v, k) if k not in self.PASSTHROUGH_KEYS else v
                for k, v in input.items()}

    def truncate_to_time_t(self,
                           input : dict,
                           t : int):
        """Keep only the entries whose time index (pos_ids) is strictly below *t*."""
        keep = [p < t for p in input["pos_ids"]]
        return {k: [v[i] for i in range(len(v)) if keep[i]]
                for k, v in input.items()}

    def resize_window(self,
                      input : dict,
                      pos_id):
        """Return a copy with attention_mask zeroed for positions before *pos_id*."""
        out = input.copy()
        out["attention_mask"] = [0 if p < pos_id else 1 for p in out["pos_ids"]]
        return out

    def prepare_for_call(self,
                         input : dict):
        """Convert a feature dict to batched tf tensors, windowing if needed.

        Positions equal to 51 are ignored when locating the latest
        time-step.  # NOTE(review): 51 looks like the padding time id — confirm
        """
        resize_limit = max(v for v in np.array(input["pos_ids"]).flatten() if v != 51) - self.window
        if resize_limit > 0:
            input = self.resize_window(input, resize_limit)

        done = {k: tf.constant(v) for k, v in input.items()}
        # Add a leading batch axis for unbatched inputs.  (Fixed: expand the
        # already-built tensors rather than re-converting `input` values.)
        if len(done["pos_ids"].shape) == 1:
            done = {k: tf.expand_dims(v, axis=0) for k, v in done.items()}
        return done
|
|
|
class generator:
    """Autoregressive trajectory generator wrapping a QBGPT model.

    Samples next-move tokens from the model's logits with a
    temperature-annealed top-n scheme and provides helpers to turn the
    generated token sequences into plottable coordinate dicts.
    """

    def __init__(self,
                 model,
                 tokenizer,
                 temp,
                 n_select):
        # Underlying model producing next-token logits.
        self.QBGPT = model
        self.tokenizer = tokenizer

        # Base sampling temperature and number of top logits kept.
        self.temperature = temp
        self.n_select = n_select

    def get_unique_lists(self,
                         l_of_ls : list):
        """Deduplicate a list of lists, keeping first-occurrence order.

        (The previous set-based version returned results in arbitrary
        order; dict.fromkeys makes the output deterministic.)
        """
        seen = dict.fromkeys(tuple(inner) for inner in l_of_ls)
        return [list(t) for t in seen]

    def cut(self, l, ref):
        """Split *l* into segments; a False flag in *ref* opens a new segment.

        ref[i] == True means l[i] continues the current segment; False
        means the current segment is closed and l[i] starts the next one.
        Returns [] for an empty *l*.
        """
        segments = []
        current = []
        for item, keep in zip(l, ref):
            if keep:
                current.append(item)
            else:
                segments.append(current)
                current = [item]
        if l:
            segments.append(current)
        return segments

    def get_last_preds(self,
                       logits,
                       input : dict):
        """Keep only the predictions belonging to the most recent time-step."""
        last_step = max(input["pos_ids"])  # hoisted out of the loop
        flags = [p == last_step for p in input["pos_ids"]]
        return np.array([logits[i] for i in range(len(logits)) if flags[i]])

    def get_logits(self,
                   input : dict):
        """Run the model on a tokenized input dict and return raw logits."""
        x = self.tokenizer.prepare_for_call(input)
        return self.QBGPT(x)

    def convert_to_preds(self,
                         logits):
        """Drop the leading batch axis from the model output."""
        return tf.squeeze(logits, axis=0)

    def set_temperature(self,
                        x):
        """Anneal the sampling temperature as the time-step *x* grows."""
        if x < 5:
            return self.temperature
        if x < 10:
            return self.temperature / 2
        if x < 20:
            return self.temperature / 5
        return 1.0

    def select_and_temp(self,
                        tensor,
                        n,
                        temp):
        """Return softmax probabilities and indices of the top-*n* logits at *temp*."""
        probas = tf.nn.softmax(tf.sort(tensor / temp, axis=-1)[:, :, -n:], axis=2)
        indices = tf.argsort(tensor, axis=-1)[:, :, -n:]
        return probas, indices

    def draw_random(self,
                    probas):
        """Draw one one-hot sample per position from the top-n probabilities."""
        drawn = np.vstack([np.random.multinomial(1, p.numpy(), 1) for p in probas[0]])
        drawn = tf.expand_dims(drawn, axis=0)
        return tf.cast(drawn, dtype="int32")

    def get_indices(self,
                    drawn,
                    ind):
        """Map one-hot draws back to vocabulary indices."""
        return tf.reduce_sum(drawn * ind, axis=2)

    def process_logits(self,
                       logits,
                       temp,
                       n):
        """Sample token ids from logits via temperature-scaled top-n drawing."""
        probas, indices = self.select_and_temp(logits, n, temp)
        drawn = self.draw_random(probas)
        return self.get_indices(drawn, indices)

    def generate(self,
                 input : dict):
        """Sample next-move token(s) for the latest time-step of *input*."""
        logits = self.get_logits(input)
        temperature_parameter = self.set_temperature(max(input["pos_ids"]))
        processed_logits = self.process_logits(logits, n=self.n_select, temp=temperature_parameter)
        preds = self.convert_to_preds(processed_logits)
        return self.get_last_preds(preds, input)

    def slice_inputs(self,
                     input : dict):
        """Split every feature list into per-player segments.

        A new segment starts wherever pos_ids stops increasing (each
        player's trajectory is stored as an increasing run of time ids).
        """
        pos = input["pos_ids"]
        flags = [True] + [pos[i + 1] > pos[i] for i in range(len(pos) - 1)]
        return {k: self.cut(v, flags) for k, v in input.items()}

    def continue_by_token(self,
                          arr,
                          token : str):
        """Extend one feature array by a single time-step.

        input_ids are returned untouched (the prediction is appended
        separately); pos_ids get the next time index; token_type_ids get a
        1; every other feature repeats its last value.
        """
        if token == "input_ids":
            return arr
        if token == "pos_ids":
            return np.concatenate([arr, np.array([max(arr) + 1])])
        elif token == "token_type_ids":
            return np.concatenate([arr, np.array([1])])
        else:
            return np.concatenate([arr, [arr[-1]]])

    def append_prediction(self,
                          arr,
                          pred):
        """Append one predicted token to a single feature array."""
        return np.concatenate([arr, [pred]])

    def append_predictions(self,
                           d : dict,
                           preds):
        """Append the i-th prediction to the i-th player's input_ids segment."""
        new = d.copy()
        new["input_ids"] = [self.append_prediction(new["input_ids"][i], preds[i])
                            for i in range(len(preds))]
        return new

    def merge_cuts(self,
                   input : dict):
        """Concatenate per-player segments back into flat arrays."""
        return {k: np.concatenate(v) for k, v in input.items()}

    def update_inputs(self,
                      input,
                      preds):
        """Append *preds* and advance every feature by one time-step."""
        sliced = self.slice_inputs(input)
        appended = self.append_predictions(sliced, preds)
        continued = {k: [self.continue_by_token(e, k) for e in v] for k, v in appended.items()}
        return self.merge_cuts(continued)

    def generate_sequence(self,
                          input,
                          t):
        """Autoregressively generate *t* additional time-steps."""
        new_input = input.copy()
        for _ in range(t):
            generated = self.generate(new_input)
            new_input = self.update_inputs(new_input, generated)
        return new_input

    def convert_list(self,
                     d,
                     keep_original):
        """Split the (x, y) pair columns into separate scalar columns.

        When *keep_original* is False the original pair columns are dropped.
        """
        new_df = d.copy()
        new_df["start_ids_x"] = [v[0] for v in new_df["start_ids"]]
        new_df["start_ids_y"] = [v[1] for v in new_df["start_ids"]]
        new_df["input_ids_x"] = [v[0] for v in new_df["input_ids"]]
        new_df["input_ids_y"] = [v[1] for v in new_df["input_ids"]]
        if keep_original:
            return new_df
        return {k: v for k, v in new_df.items() if k not in ["start_ids", "input_ids"]}

    def remove_pad(self,
                   seq):
        """Drop rows whose start OR move coordinates are the -1000 padding.

        Fixed: the previous version re-filtered the original frame on the
        second condition, silently discarding the start_ids_x filter.
        """
        df = pd.DataFrame(seq)
        filtered = df[df["start_ids_x"] != -1000]
        filtered = filtered[filtered["input_ids_x"] != -1000].reset_index(drop=True)
        return filtered.to_dict(orient="list")

    def _compute_true_sequence(self,
                               scrimmage_line,
                               start : list,
                               moves : list):
        """Convert relative moves into absolute field coordinates.

        26.5 appears to be the field mid-line y offset  # NOTE(review): confirm units
        """
        scrimmage = np.array([scrimmage_line, 26.5])

        # Each move is stored relative to the start position.
        updated_moves = np.array([np.array(start) + np.array(v) for v in moves])

        appended = np.concatenate([np.expand_dims(start, axis=0), updated_moves])
        return appended + scrimmage

    def compute_true_sequence(self,
                              scrims,
                              starts,
                              moves):
        """Absolute coordinates for one player's trajectory."""
        return self._compute_true_sequence(np.unique(scrims)[0],
                                           self.get_unique_lists(starts)[0],
                                           moves)

    def _resize_variable(self,
                         x,
                         ref: str):
        """Prepend a starting value so lengths match the start-position row."""
        if ref in ["pos_ids", "token_type_ids"]:
            return np.concatenate([[0], x])
        elif ref in ["input_ids", "start_ids"]:
            return np.vstack([self.get_unique_lists(x)[0], x])
        else:
            return np.concatenate([np.unique(x), x])

    def prepare_for_plot(self,
                         seq):
        """Convert a generated sequence dict into a flat, plot-ready dict."""
        sequence = self.convert_list(seq.copy(), keep_original=True)
        sequence = self.remove_pad(sequence)
        cutted = self.slice_inputs(sequence)
        moves_updated = [self.compute_true_sequence(cutted["scrim_ids"][i],
                                                    cutted["start_ids"][i],
                                                    cutted["input_ids"][i])
                         for i in range(len(cutted["input_ids"]))]
        cutted["input_ids"] = moves_updated
        cutted = {k: [self._resize_variable(e, k) if k != "input_ids" else e for e in v]
                  for k, v in cutted.items()}
        # Tag every row with its player segment index.
        cutted["ids"] = [[i] * len(cutted["input_ids"][i]) for i in range(len(cutted["input_ids"]))]
        merged = self.merge_cuts(cutted)
        converted = self.convert_list(merged, keep_original=False)
        return {k: v for k, v in converted.items() if k != "labels"}

    def insert_ids(self,
                   input):
        """Tag each row with its player segment index under key "ids"."""
        cutted = self.slice_inputs(input)
        cutted["ids"] = [[i] * len(cutted["input_ids"][i]) for i in range(len(cutted["input_ids"]))]
        return self.merge_cuts(cutted)
|
|
|
def get_plot(df, n_frames, name):
    """Build an animated plotly figure of player trajectories.

    Args:
        df: pandas DataFrame with one row per player per time-step; uses
            columns "ids", "pos_ids", "input_ids_x", "input_ids_y",
            "position_ids" (as produced by generator.prepare_for_plot).
        n_frames: frame duration (ms) for the Play-button animation.
        name: play name displayed in the figure title.

    Returns:
        A plotly go.Figure with one trace per player and one animation
        frame per non-zero time-step.
    """
    fig = go.Figure(
        layout=go.Layout(
            updatemenus=[dict(type="buttons", direction="right", x=0.9, y=1.16), ],
            xaxis=dict(range=[0, 120],
                       autorange=False, tickwidth=2,
                       title_text="X"),
            yaxis=dict(range=[0, 60],
                       autorange=False,
                       title_text="Y")
        ))

    # Number of points shown in the initial static trace.
    start_len = 1
    # One frame per non-zero time-step; each collects one scatter per player.
    frames = {step: [] for step in df["pos_ids"].unique() if step != 0}

    # `trace_id` avoids shadowing the builtin `id`.
    for trace_id in df["ids"].unique():
        spec = df[df["ids"] == trace_id].reset_index(drop=True)
        fig.add_trace(
            go.Scatter(x=spec.input_ids_x[:start_len],
                       y=spec.input_ids_y[:start_len],
                       name=spec.position_ids.unique()[0],
                       text=spec.position_ids.unique()[0],
                       visible=True,
                       line=dict(color="#f47738", dash="solid")))

        # Extend this player's trajectory by one point per frame.
        for k in range(start_len, spec.shape[0]):
            current_frame = spec["pos_ids"][k]
            frames[current_frame].append(go.Scatter(x=spec.input_ids_x[:k], y=spec.input_ids_y[:k]))

    frames = [go.Frame(data=v) for v in frames.values()]

    fig.update(frames=frames)

    fig.update_xaxes(ticks="outside", tickwidth=2, tickcolor='white', ticklen=10)
    fig.update_yaxes(ticks="outside", tickwidth=2, tickcolor='white', ticklen=1)
    fig.update_layout(yaxis_tickformat=',')
    fig.update_layout(legend=dict(x=0, y=1.1), legend_orientation="h")

    fig.update_layout(title=f"{name} play",
                      xaxis_title="X",
                      yaxis_title="Y",
                      legend_title="Legend Title",
                      showlegend=False,
                      font=dict(
                          family="Arial",
                          size=14
                      ),
                      hovermode="x",
                      updatemenus=[
                          dict(
                              buttons=list(
                                  [
                                      dict(label="Play",
                                           method="animate",
                                           args=[None, {"frame": {"duration": n_frames}}])
                                  ]
                              ),
                              type="buttons",
                              direction="right",
                              pad={"r": 50, "t": 50},
                              showactive=False,
                              x=0.5,
                              yanchor="top")
                      ])

    fig.update_layout(template='plotly_dark')

    fig.update_layout(width=1200, height=600)
    return fig