Samuel CHAINEAU commited on
Commit
382e94b
1 Parent(s): ce48a3d
Files changed (3) hide show
  1. .streamlit/config.toml +2 -0
  2. pages.py +3 -18
  3. tools.py +78 -0
.streamlit/config.toml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [theme]
2
+ base="dark"
pages.py CHANGED
@@ -2,7 +2,7 @@ import streamlit as st
2
  import plotly.express as px
3
  import pandas as pd
4
  import numpy as np
5
- from tools import generator
6
  from PIL import Image
7
 
8
  def set_app_title_and_logo():
@@ -63,25 +63,10 @@ def qb_gpt_page(ref_df, ref, tokenizer, model):
63
  step1_true = QB_gen.prepare_for_plot(decoded_true)
64
  plot_true = pd.DataFrame(step1_true)
65
 
66
- fig_gen = px.line(plot, x="input_ids_x", y="input_ids_y", animation_frame="pos_ids", color="OffDef", symbol="ids",
67
- text="position_ids", title="Generated players' trajectories Over Time", line_shape="linear",
68
- range_x=[0, 140], range_y=[0, 60], # Set X and Y axis ranges
69
- render_mode="svg") # Render mode for smoother lines
70
-
71
- # Customize the appearance of the plot
72
- fig_gen.update_traces(marker=dict(size=10), selector=dict(mode='lines'))
73
- fig_gen.update_layout(width=800, height=600)
74
  st.plotly_chart(fig_gen)
75
 
76
- fig_true = px.line(plot_true, x="input_ids_x", y="input_ids_y", animation_frame="pos_ids", color="OffDef", symbol="ids",
77
- text="position_ids", title="True players' trajectories Over Time",
78
- range_x=[0, 140], range_y=[0, 60], # Set X and Y axis ranges
79
- line_shape="linear", # Draw lines connecting points
80
- render_mode="svg") # Render mode for smoother lines
81
-
82
- # Customize the appearance of the plot
83
- fig_true.update_traces(marker=dict(size=10), selector=dict(mode='lines'))
84
- fig_true.update_layout(width=800, height=600)
85
  st.plotly_chart(fig_true)
86
 
87
 
 
2
  import plotly.express as px
3
  import pandas as pd
4
  import numpy as np
5
+ from tools import generator, get_plot
6
  from PIL import Image
7
 
8
  def set_app_title_and_logo():
 
63
  step1_true = QB_gen.prepare_for_plot(decoded_true)
64
  plot_true = pd.DataFrame(step1_true)
65
 
66
+ fig_gen = get_plot(plot, frames, "Generated")
 
 
 
 
 
 
 
67
  st.plotly_chart(fig_gen)
68
 
69
+ fig_true = get_plot(plot_true, frames, "True")
 
 
 
 
 
 
 
 
70
  st.plotly_chart(fig_true)
71
 
72
 
tools.py CHANGED
@@ -2,6 +2,8 @@ import polars as pl
2
  import numpy as np
3
  import tensorflow as tf
4
  import pandas as pd
 
 
5
 
6
  class tokenizer:
7
  def __init__(self,
@@ -373,3 +375,79 @@ class generator:
373
  cutted["ids"] = [[i for e in range(len(cutted["input_ids"][i]))] for i in range(len(cutted["input_ids"]))]
374
  merged = self.merge_cuts(cutted)
375
  return merged
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import numpy as np
3
  import tensorflow as tf
4
  import pandas as pd
5
+ import plotly.graph_objects as go
6
+
7
 
8
  class tokenizer:
9
  def __init__(self,
 
375
  cutted["ids"] = [[i for e in range(len(cutted["input_ids"][i]))] for i in range(len(cutted["input_ids"]))]
376
  merged = self.merge_cuts(cutted)
377
  return merged
378
+
379
+ def get_plot(df, n_frames, name):
380
+ fig = go.Figure(
381
+ layout=go.Layout(
382
+ updatemenus=[dict(type="buttons", direction="right", x=0.9, y=1.16), ],
383
+ xaxis=dict(range=[0, 120],
384
+ autorange=False, tickwidth=2,
385
+ title_text="X"),
386
+ yaxis=dict(range=[0, 60],
387
+ autorange=False,
388
+ title_text="Y")
389
+ ))
390
+
391
+ # Add traces
392
+ i = 1
393
+ frames = {i: [] for i in df["pos_ids"].unique() if i !=0}
394
+
395
+ for id in df["ids"].unique():
396
+ spec = df[df["ids"] == id].reset_index(drop = True)
397
+ fig.add_trace(
398
+ go.Scatter(x=spec.input_ids_x[:i],
399
+ y=spec.input_ids_y[:i],
400
+ name= spec.position_ids.unique()[0],
401
+ text= spec.position_ids.unique()[0],
402
+ visible=True,
403
+ line=dict(color="#f47738", dash="solid")))
404
+
405
+ for k in range(i, spec.shape[0]):
406
+ current_frame = spec["pos_ids"][k]
407
+ frames[current_frame].append(go.Scatter(x=spec.input_ids_x[:k], y=spec.input_ids_y[:k]))
408
+
409
+ frames = list(frames.values())
410
+ frames = [go.Frame(data = v) for v in frames]
411
+
412
+
413
+ # Animation
414
+ fig.update(frames=frames)
415
+
416
+ fig.update_xaxes(ticks="outside", tickwidth=2, tickcolor='white', ticklen=10)
417
+ fig.update_yaxes(ticks="outside", tickwidth=2, tickcolor='white', ticklen=1)
418
+ fig.update_layout(yaxis_tickformat=',')
419
+ fig.update_layout(legend=dict(x=0, y=1.1), legend_orientation="h")
420
+
421
+ # Buttons
422
+ fig.update_layout(title=f"{name} play",
423
+ xaxis_title="X",
424
+ yaxis_title="Y",
425
+ legend_title="Legend Title",
426
+ showlegend=False,
427
+ font=dict(
428
+ family="Arial",
429
+ size=14
430
+ ),
431
+ hovermode="x",
432
+ updatemenus=[
433
+ dict(
434
+ buttons=list(
435
+ [
436
+ dict(label="Play",
437
+ method="animate",
438
+ args=[None, {"frame": {"duration": n_frames}}])
439
+ ]
440
+ ),
441
+ type = "buttons",
442
+ direction="right",
443
+ pad={"r": 50, "t": 50},
444
+ showactive=False,
445
+ x=0.5,
446
+ yanchor="top")
447
+ ])
448
+
449
+ fig.update_layout(template='plotly_dark'
450
+ )
451
+
452
+ fig.update_layout(width=1200, height=600)
453
+ return fig