|
|
|
|
|
|
|
|
|
|
|
|
|
import treetable as tt |
|
|
|
from .._base_explorers import BaseExplorer |
|
|
|
|
|
class CompressionExplorer(BaseExplorer):
    """Explorer describing how compression XPs are laid out in the grid table.

    It declares the stages tracked for every XP and the ``treetable`` columns
    (meta information and per-stage metrics) used to render them.
    """

    # Metric names shown under the "evaluate" group; read through ``self`` so a
    # subclass can override the list.
    eval_metrics = ["sisnr", "visqol"]

    def stages(self):
        """Return the ordered stage names: train, valid, then evaluate."""
        ordered_stages = ["train", "valid", "evaluate"]
        return ordered_stages

    def get_grid_meta(self):
        """Return the meta-information columns displayed for each XP/job.

        Columns, in order: ``index`` and ``sig`` (both with ``align=">"``),
        ``name`` (wrapped at 140) and ``state``.
        """
        right_align = ">"
        index_col = tt.leaf("index", align=right_align)
        name_col = tt.leaf("name", wrap=140)
        state_col = tt.leaf("state")
        sig_col = tt.leaf("sig", align=right_align)
        return [index_col, name_col, state_col, sig_col]

    def get_grid_metrics(self):
        """Return the per-stage metric groups shown in the tracking table.

        Three groups are produced, all with ``align=">"``:
        ``train`` and ``valid`` hold a fixed set of metric columns, while
        ``evaluate`` holds one ``.3f`` column per name in ``self.eval_metrics``.
        """
        right_align = ">"

        # "epoch" is declared without an explicit format, unlike the others.
        train_leaves = [tt.leaf("epoch")]
        train_leaves += [
            tt.leaf(key, fmt)
            for key, fmt in (("bandwidth", ".2f"), ("adv", ".4f"), ("d_loss", ".4f"))
        ]
        train_group = tt.group("train", train_leaves, align=right_align)

        valid_leaves = [
            tt.leaf(key, fmt)
            for key, fmt in (
                ("bandwidth", ".2f"),
                ("adv", ".4f"),
                ("msspec", ".4f"),
                ("sisnr", ".2f"),
            )
        ]
        valid_group = tt.group("valid", valid_leaves, align=right_align)

        evaluate_leaves = []
        for metric in self.eval_metrics:
            evaluate_leaves.append(tt.leaf(metric, ".3f"))
        evaluate_group = tt.group("evaluate", evaluate_leaves, align=right_align)

        return [train_group, valid_group, evaluate_group]
|
|