|
|
|
|
|
|
|
|
|
|
|
|
|
import typing as tp |
|
|
|
import treetable as tt |
|
|
|
from .._base_explorers import BaseExplorer |
|
|
|
|
|
class LMExplorer(BaseExplorer):
    """Explorer reporting train/valid metrics for language-model runs.

    Besides the per-stage values provided by the base explorer, the 'valid'
    group is augmented with the best value observed so far for each tracked
    validation metric (currently only perplexity, where lower is better).
    """

    eval_metrics: tp.List[str] = []

    def stages(self) -> tp.List[str]:
        """Return the names of the stages this explorer reports on."""
        stage_names = ['train', 'valid']
        return stage_names

    def get_grid_metrics(self):
        """Return the metrics that should be displayed in the tracking table."""
        # Each spec is the positional argument tuple passed to `tt.leaf`;
        # leaves without a format string ('epoch', 'ping') pass only the name.
        train_specs = [
            ('epoch',),
            ('duration', '.1f'),
            ('ping',),
            ('ce', '.4f'),
            ('ppl', '.3f'),
        ]
        valid_specs = [
            ('ce', '.4f'),
            ('ppl', '.3f'),
            ('best_ppl', '.3f'),
        ]

        def make_group(group_name, specs):
            # Right-align every column of the group.
            leaves = [tt.leaf(*spec) for spec in specs]
            return tt.group(group_name, leaves, align='>')

        train_group = make_group('train', train_specs)
        valid_group = make_group('valid', valid_specs)
        return [train_group, valid_group]

    def process_sheep(self, sheep, history):
        """Extend the base explorer's summary with best-so-far validation metrics.

        Args:
            sheep: The run being summarised; forwarded untouched to the base class.
            history: Sequence of per-epoch metric mappings. Only each entry's
                'valid' sub-mapping is inspected here.

        Returns:
            The mapping produced by ``BaseExplorer.process_sheep``. When it has a
            'valid' entry, that entry is updated in place with ``best_<metric>``
            keys. A tracked metric that never appears in any 'valid' entry keeps
            its initial sentinel of +inf (lower-is-better) or -inf.
        """
        parts = super().process_sheep(sheep, history)

        # Direction in which each tracked validation metric improves.
        track_by = {'ppl': 'lower'}

        # Start from the worst possible value so any observation improves on it.
        best_metrics = {}
        for name, direction in track_by.items():
            best_metrics[name] = float('inf') if direction == 'lower' else -float('inf')

        for entry in history:
            if 'valid' not in entry:
                continue
            valid_metrics = entry['valid']
            for name, direction in track_by.items():
                if name not in valid_metrics:
                    continue
                value = valid_metrics[name]
                if direction == 'lower':
                    improved = value < best_metrics[name]
                else:
                    improved = value > best_metrics[name]
                if improved:
                    best_metrics[name] = value

        if 'valid' in parts:
            best_entries = {f'best_{name}': value for name, value in best_metrics.items()}
            parts['valid'].update(best_entries)
        return parts
|
|
|
|
|
class GenerationEvalExplorer(BaseExplorer):
    """Explorer reporting metrics for runs that only have an 'evaluate' stage."""

    eval_metrics: tp.List[str] = []

    def stages(self) -> tp.List[str]:
        """Return the single stage name this explorer reports on."""
        only_stage = 'evaluate'
        return [only_stage]

    def get_grid_metrics(self):
        """Return the metrics that should be displayed in the tracking table."""
        # Positional arguments for `tt.leaf`, in display order; 'ping' is shown
        # without an explicit format string.
        leaf_specs = (
            ('epoch', '.3f'),
            ('duration', '.1f'),
            ('ping',),
            ('ce', '.4f'),
            ('ppl', '.3f'),
            ('fad', '.3f'),
            ('kld', '.3f'),
            ('text_consistency', '.3f'),
            ('chroma_cosine', '.3f'),
        )
        leaves = [tt.leaf(*spec) for spec in leaf_specs]
        # A single right-aligned group holds every evaluation column.
        evaluate_group = tt.group('evaluate', leaves, align='>')
        return [evaluate_group]
|
|