# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/utils.ipynb (unless otherwise specified). | |
__all__ = ['generate_TS_df', 'normalize_columns', 'remove_constant_columns', 'ReferenceArtifact', 'PrintLayer', | |
'get_wandb_artifacts', 'get_pickle_artifact'] | |
# Cell | |
from .imports import *
from fastcore.all import *
import wandb
import hashlib
import pickle
import pandas as pd
import numpy as np
#import tensorflow as tf
import torch.nn as nn
from fastai.basics import *
# Cell | |
def generate_TS_df(rows, cols):
    """Generate a DataFrame containing a multivariate time series.

    Each column represents a variable and each row a time point (sample).
    The timestamp index starts at `pd.Timestamp.now()` with an even spacing
    of 1 second between samples.

    Note: the previous `np.arange(start, stop, step)` implementation was
    off by one — the stop bound is exclusive, so it produced `rows - 1`
    samples. `pd.date_range(periods=rows)` yields exactly `rows` rows and
    also handles `rows == 1` correctly.
    """
    index = pd.date_range(start=pd.Timestamp.now(), periods=rows, freq='s')
    data = np.random.randn(rows, cols)
    return pd.DataFrame(data, index=index)
# Cell | |
def normalize_columns(df:pd.DataFrame):
    """Scale every column of `df` to zero mean and unit standard deviation.

    A small epsilon is added to the standard deviation so that constant
    columns do not cause a division by zero.
    """
    centered = df - df.mean()
    return centered / (df.std() + 1e-7)
# Cell | |
def remove_constant_columns(df:pd.DataFrame):
    "Drop every column of `df` whose values are all identical to its first entry."
    varies = df.ne(df.iloc[0]).any()
    return df.loc[:, varies]
# Cell | |
class ReferenceArtifact(wandb.Artifact):
    """Artifact holding a single `file://` reference to a pickled Python object.

    The object passed to the constructor is pickled, hashed (content-addressed
    file name) and stored under `folder` (default:
    `Path.home()/default_storage_path`). The artifact records a reference to
    that file, plus the hash and the object's type in its metadata, so the
    object can be recovered later (see `to_obj`).
    """
    default_storage_path = Path('data/wandb_artifacts/')  # * this path is relative to Path.home()

    def __init__(self, obj, name, type='object', folder=None, **kwargs):
        super().__init__(type=type, name=name, **kwargs)
        # Deterministic, content-based hash of the pickled payload.
        # The builtin `hash()` on bytes is randomized per process
        # (PYTHONHASHSEED), which would give the same object a different
        # file name on every run, defeating content addressing.
        data = pickle.dumps(obj)
        hash_code = hashlib.sha1(data).hexdigest()
        folder = Path(ifnone(folder, Path.home()/self.default_storage_path))
        folder.mkdir(parents=True, exist_ok=True)  # storage dir may not exist yet
        with open(folder/hash_code, 'wb') as f:
            f.write(data)  # write the already-pickled bytes (no second pickling pass)
        self.add_reference(f'file://{folder}/{hash_code}')
        if self.metadata is None:
            self.metadata = dict()
        self.metadata['ref'] = dict()
        self.metadata['ref']['hash'] = hash_code
        self.metadata['ref']['type'] = str(obj.__class__)
# Cell | |
def to_obj(self:wandb.apis.public.Artifact):
    """Download the files of a saved ReferenceArtifact and get the referenced object.

    The artifact must come from a call to `run.use_artifact` with a proper
    wandb run. Returns `None` (after printing an error) when the artifact was
    not created through `ReferenceArtifact`.
    """
    ref = self.metadata.get('ref')
    if ref is None:
        print(f'ERROR:{self} does not come from a saved ReferenceArtifact')
        return None
    # `ReferenceArtifact.__init__` stores files under
    # Path.home()/default_storage_path, so look there first (the bare
    # relative path would resolve against the cwd and rarely exist).
    original_path = Path.home()/ReferenceArtifact.default_storage_path/ref['hash']
    path = original_path if original_path.exists() else Path(self.download()).ls()[0]
    # NOTE(review): pickle.load of a downloaded file executes arbitrary code
    # on unpickling — only use with artifacts from trusted sources.
    with open(path, 'rb') as f:
        obj = pickle.load(f)
    return obj
# Cell | |
import torch.nn as nn | |
class PrintLayer(nn.Module):
    """Identity module that prints the shape of its input.

    Useful as a debugging probe inside an `nn.Sequential` pipeline: it has
    no parameters and returns its input unchanged.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Debug hook: report the tensor shape, then pass the input through untouched.
        print(x.shape)
        return x
# Cell | |
def export_and_get(self:Learner, keep_exported_file=False):
    """
    Export the learner into an auxiliary file, load it and return it back.

    `Learner.export(fname)` writes relative to `self.path`, so the exported
    file must also be loaded from (and removed at) `self.path/'aux.pkl'`.
    The previous version passed the bare name to `load_learner`/`unlink`,
    which resolved against the current working directory and broke whenever
    `self.path` differed from it.
    """
    aux_path = Path(self.path)/'aux.pkl'
    self.export(fname='aux.pkl')
    aux_learn = load_learner(aux_path)
    if not keep_exported_file: aux_path.unlink()
    return aux_learn
# Cell | |
def get_wandb_artifacts(project_path, type=None, name=None, last_version=True):
    """
    Get the artifacts logged in a wandb project.
    Input:
        - `project_path` (str): entity/project_name
        - `type` (str): whether to return only one type of artifacts
        - `name` (str): Leave none to have all artifact names
        - `last_version`: whether to return only the last version of each artifact or not
    Output: List of artifacts
    """
    public_api = wandb.Api()
    if type is not None:
        # Restrict the search to a single artifact type; otherwise walk all
        # types registered in the project.
        types = [public_api.artifact_type(type, project_path)]
    else:
        types = public_api.artifact_types(project_path)
    res = L()  # fastcore list; converted to a plain list on return
    for kind in types:
        for collection in kind.collections():
            # `name is None` means "all collections"; otherwise match exactly.
            if name is None or name == collection.name:
                versions = public_api.artifact_versions(
                    kind.type,
                    "/".join([kind.entity, kind.project, collection.name]),
                    per_page=1,  # small page keeps the latest-only fetch cheap
                )
                # NOTE(review): assumes the versions iterator yields the most
                # recent version first — confirm against the wandb Api docs.
                if last_version: res += next(versions)
                else: res += L(versions)
    return list(res)
# Cell | |
def get_pickle_artifact(filename):
    "Unpickle and return the object stored in `filename`."
    with open(filename, "rb") as file_obj:
        obj = pickle.load(file_obj)
    return obj