change way to make requests to wandb
Browse files- dashboard_utils/bubbles.py +141 -35
dashboard_utils/bubbles.py
CHANGED
@@ -1,7 +1,9 @@
|
|
1 |
import datetime
|
2 |
from concurrent.futures import as_completed
|
|
|
3 |
from urllib import parse
|
4 |
-
|
|
|
5 |
import pandas as pd
|
6 |
|
7 |
import streamlit as st
|
@@ -66,48 +68,152 @@ def get_profiles(usernames):
|
|
66 |
@st.cache(ttl=CACHE_TTL, show_spinner=False)
|
67 |
@simple_time_tracker(_log)
|
68 |
def get_serialized_data_points():
|
69 |
-
|
70 |
api = wandb.Api()
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
}
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
{
|
100 |
"batches": run_summary["_step"],
|
101 |
"runtime": run_summary["_runtime"],
|
102 |
"loss": run_summary["train/loss"],
|
|
|
|
|
103 |
"velocity": run_summary["_step"] / run_summary["_runtime"],
|
104 |
"date": datetime.datetime.utcfromtimestamp(timestamp),
|
105 |
}
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
return serialized_data_points, latest_timestamp
|
112 |
|
113 |
|
@@ -123,7 +229,7 @@ def get_serialized_data(serialized_data_points, latest_timestamp):
|
|
123 |
batches = 0
|
124 |
velocity = 0
|
125 |
for run in serialized_data_point["Runs"]:
|
126 |
-
if run["
|
127 |
run["date"] = run["date"].isoformat()
|
128 |
activeRuns.append(run)
|
129 |
loss += run["loss"]
|
|
|
1 |
import datetime
|
2 |
from concurrent.futures import as_completed
|
3 |
+
from requests.adapters import HTTPAdapter
|
4 |
from urllib import parse
|
5 |
+
import requests
|
6 |
+
import json
|
7 |
import pandas as pd
|
8 |
|
9 |
import streamlit as st
|
|
|
68 |
@st.cache(ttl=CACHE_TTL, show_spinner=False)
|
69 |
@simple_time_tracker(_log)
|
70 |
def get_serialized_data_points():
|
71 |
+
url = "https://api.wandb.ai/graphql"
|
72 |
api = wandb.Api()
|
73 |
+
|
74 |
+
# Get the run ids
|
75 |
+
json_query_run_names = {
|
76 |
+
"operationName":"WandbConfig",
|
77 |
+
"variables":{"limit":100000,"entityName":"learning-at-home","projectName":"dalle-hivemind-trainers","filters":"{\"$and\":[{\"$or\":[{\"$and\":[]}]},{\"$and\":[]},{\"$or\":[{\"$and\":[{\"$or\":[{\"$and\":[]}]},{\"$and\":[{\"name\":{\"$ne\":null}}]}]}]}]}","order":"-state"},
|
78 |
+
"query": """ query WandbConfig($projectName: String!, $entityName: String!, $filters: JSONString, $limit: Int = 100, $order: String) {
|
79 |
+
project(name: $projectName, entityName: $entityName) {
|
80 |
+
id
|
81 |
+
runs(filters: $filters, first: $limit, order: $order) {
|
82 |
+
edges {
|
83 |
+
node {
|
84 |
+
id
|
85 |
+
name
|
86 |
+
__typename
|
87 |
+
}
|
88 |
+
__typename
|
89 |
+
}
|
90 |
+
__typename
|
91 |
+
}
|
92 |
+
__typename
|
93 |
+
}
|
94 |
+
}
|
95 |
+
"""}
|
96 |
+
|
97 |
+
s = requests.Session()
|
98 |
+
s.mount(url, HTTPAdapter(max_retries=5))
|
99 |
+
|
100 |
+
resp = s.post(
|
101 |
+
headers={"User-Agent": api.user_agent, "Use-Admin-Privileges": "true", 'content-type': 'application/json'},
|
102 |
+
auth=("api", api.api_key),
|
103 |
+
url=url,
|
104 |
+
data=json.dumps(json_query_run_names)
|
105 |
+
)
|
106 |
+
json_metrics = resp.json()
|
107 |
+
run_names = [run['node']["name"] for run in json_metrics['data']['project']["runs"]['edges']]
|
108 |
+
|
109 |
+
# Get info of each run
|
110 |
+
with FuturesSession() as session:
|
111 |
+
futures = []
|
112 |
+
for run_name in run_names:
|
113 |
+
json_query_by_run = {
|
114 |
+
"operationName":"Run",
|
115 |
+
"variables":{"entityName":"learning-at-home","projectName":"dalle-hivemind-trainers", "runName":run_name},
|
116 |
+
"query":"""query Run($projectName: String!, $entityName: String, $runName: String!) {
|
117 |
+
project(name: $projectName, entityName: $entityName) {
|
118 |
+
id
|
119 |
+
name
|
120 |
+
createdAt
|
121 |
+
run(name: $runName) {
|
122 |
+
id
|
123 |
+
name
|
124 |
+
displayName
|
125 |
+
state
|
126 |
+
summaryMetrics
|
127 |
+
runInfo {
|
128 |
+
gpu
|
129 |
+
}
|
130 |
+
__typename
|
131 |
+
}
|
132 |
+
__typename
|
133 |
+
}
|
134 |
}
|
135 |
+
"""}
|
136 |
+
|
137 |
+
future = session.post(
|
138 |
+
headers={"User-Agent": api.user_agent, "Use-Admin-Privileges": "true", 'content-type': 'application/json'},
|
139 |
+
auth=("api", api.api_key),
|
140 |
+
url=url,
|
141 |
+
data=json.dumps(json_query_by_run)
|
142 |
+
)
|
143 |
+
futures.append(future)
|
144 |
+
|
145 |
+
serialized_data_points = {}
|
146 |
+
latest_timestamp = None
|
147 |
+
for future in as_completed(futures):
|
148 |
+
resp = future.result()
|
149 |
+
json_metrics = resp.json()
|
150 |
+
|
151 |
+
data = json_metrics.get("data", None)
|
152 |
+
if data is None:
|
153 |
+
continue
|
154 |
+
|
155 |
+
project = data.get("project", None)
|
156 |
+
if project is None:
|
157 |
+
continue
|
158 |
+
|
159 |
+
run = project.get("run", None)
|
160 |
+
if run is None:
|
161 |
+
continue
|
162 |
+
|
163 |
+
runInfo = run.get("runInfo", None)
|
164 |
+
if runInfo is None:
|
165 |
+
gpu_type = None
|
166 |
+
else:
|
167 |
+
gpu_type = runInfo.get("gpu", None)
|
168 |
+
|
169 |
+
summaryMetrics = run.get("summaryMetrics", None)
|
170 |
+
if summaryMetrics is not None:
|
171 |
+
run_summary = json.loads(summaryMetrics)
|
172 |
+
|
173 |
+
state = run.get("state", None)
|
174 |
+
if state is None:
|
175 |
+
continue
|
176 |
+
|
177 |
+
displayName = run.get("displayName", None)
|
178 |
+
if displayName is None:
|
179 |
+
continue
|
180 |
+
|
181 |
+
if displayName in serialized_data_points:
|
182 |
+
if "_timestamp" in run_summary and "_step" in run_summary:
|
183 |
+
timestamp = run_summary["_timestamp"]
|
184 |
+
serialized_data_points[displayName]["Runs"].append(
|
185 |
{
|
186 |
"batches": run_summary["_step"],
|
187 |
"runtime": run_summary["_runtime"],
|
188 |
"loss": run_summary["train/loss"],
|
189 |
+
"gpu_type": gpu_type,
|
190 |
+
"state": state,
|
191 |
"velocity": run_summary["_step"] / run_summary["_runtime"],
|
192 |
"date": datetime.datetime.utcfromtimestamp(timestamp),
|
193 |
}
|
194 |
+
)
|
195 |
+
if not latest_timestamp or timestamp > latest_timestamp:
|
196 |
+
latest_timestamp = timestamp
|
197 |
+
else:
|
198 |
+
if "_timestamp" in run_summary and "_step" in run_summary:
|
199 |
+
timestamp = run_summary["_timestamp"]
|
200 |
+
serialized_data_points[displayName] = {
|
201 |
+
"profileId": displayName,
|
202 |
+
"Runs": [
|
203 |
+
{
|
204 |
+
"batches": run_summary["_step"],
|
205 |
+
"gpu_type": gpu_type,
|
206 |
+
"state": state,
|
207 |
+
"runtime": run_summary["_runtime"],
|
208 |
+
"loss": run_summary["train/loss"],
|
209 |
+
"velocity": run_summary["_step"] / run_summary["_runtime"],
|
210 |
+
"date": datetime.datetime.utcfromtimestamp(timestamp),
|
211 |
+
}
|
212 |
+
],
|
213 |
+
}
|
214 |
+
if not latest_timestamp or timestamp > latest_timestamp:
|
215 |
+
latest_timestamp = timestamp
|
216 |
+
latest_timestamp = datetime.datetime.utcfromtimestamp(latest_timestamp)
|
217 |
return serialized_data_points, latest_timestamp
|
218 |
|
219 |
|
|
|
229 |
batches = 0
|
230 |
velocity = 0
|
231 |
for run in serialized_data_point["Runs"]:
|
232 |
+
if run["state"] == "running":
|
233 |
run["date"] = run["date"].isoformat()
|
234 |
activeRuns.append(run)
|
235 |
loss += run["loss"]
|