H2OTest / llm_studio /src /loggers.py
elineve's picture
Upload 301 files
07423df
import dataclasses
import logging
import os
from typing import Any, Dict, List, Optional
import numpy as np
from sqlitedict import SqliteDict
__all__ = ["Loggers"]
from llm_studio.src.utils.plot_utils import PLOT_ENCODINGS
logger = logging.getLogger(__name__)
def get_cfg(cfg: Any) -> Dict:
    """Flatten a configuration object into a plain dictionary.

    The config's attributes are visited in the order reported by
    ``cfg._get_order``. Nested dataclass values are flattened recursively and
    merged into the same dictionary, so a nested key overwrites an earlier key
    of the same name.

    Args:
        cfg: configuration object exposing ``get_annotations``,
            ``_get_order`` and ``_get_visibility``

    Returns:
        Dict mapping setting names to their values
    """
    annotations = cfg.get_annotations()
    attributes = cfg.__dict__
    ordered = {
        name: attributes[name] for name in cfg._get_order(warn_if_unset=False)
    }

    flat: Dict = {}
    for name, value in ordered.items():
        # Skip private settings and those with a negative visibility.
        if name.startswith("_") or cfg._get_visibility(name) < 0:
            continue
        # Settings with "api" anywhere in their name are left out.
        if "api" in name:
            continue
        if dataclasses.is_dataclass(value):
            # Nested config: merge its flattened settings into the result.
            flat.update(get_cfg(cfg=value))
            continue
        # Values annotated as float are coerced (e.g. an int 1 becomes 1.0).
        flat[name] = float(value) if annotations[name] == float else value
    return flat
class NeptuneLogger:
    """Logger that sends values to a Neptune run."""

    def __init__(self, cfg: Any):
        """Open a Neptune run and store the flattened config under "cfg".

        Args:
            cfg: experiment configuration; reads ``cfg.logging._neptune_debug``,
                ``cfg.logging.neptune_project`` and ``cfg.experiment_name``
        """
        # Imported lazily so neptune is only needed when this logger is built.
        import neptune as neptune
        from neptune.utils import stringify_unsupported

        run_mode = "debug" if cfg.logging._neptune_debug else "async"

        run_kwargs = dict(
            project=cfg.logging.neptune_project,
            # The API token is taken from the environment, empty if unset.
            api_token=os.getenv("NEPTUNE_API_TOKEN", ""),
            name=cfg.experiment_name,
            mode=run_mode,
            capture_stdout=False,
            capture_stderr=False,
            source_files=[],
        )
        self.logger = neptune.init_run(**run_kwargs)

        flat_cfg = get_cfg(cfg)
        self.logger["cfg"] = stringify_unsupported(flat_cfg)

    def log(self, subset: str, name: str, value: Any, step: Optional[int] = None):
        """Append a value to the "<subset>/<name>" field of the run.

        Args:
            subset: first part of the field path
            name: second part of the field path
            value: value to append
            step: optional step passed along with the value
        """
        field = self.logger[f"{subset}/{name}"]
        field.append(value, step=step)
class LocalLogger:
    """Logger that stores values in a ``charts.db`` SqliteDict file."""

    def __init__(self, cfg: Any):
        """Set the database path and write the flattened config under "cfg".

        Args:
            cfg: experiment configuration; ``cfg.output_directory`` is the
                directory that holds ``charts.db``
        """
        # Only let error-level records from the sqlitedict logger through.
        logging.getLogger("sqlitedict").setLevel(logging.ERROR)

        self.logs = f"{cfg.output_directory}/charts.db"

        params = get_cfg(cfg)
        with SqliteDict(self.logs) as db:
            db["cfg"] = params
            db.commit()

    def log(self, subset: str, name: str, value: Any, step: Optional[int] = None):
        """Store a value under ``subset`` / ``name`` in the database.

        For subsets listed in ``PLOT_ENCODINGS`` the value replaces whatever
        was stored under ``name``. For every other subset the value is
        converted to a float (NaN becomes ``None``) and appended, together
        with ``step``, to the "steps" / "values" lists kept for ``name``.

        Args:
            subset: top-level database key
            name: key inside the subset's dictionary
            value: value to store
            step: step recorded alongside the value (unused for plot subsets)
        """
        if subset in PLOT_ENCODINGS:
            with SqliteDict(self.logs) as db:
                entries = db[subset] if subset in db else dict()
                entries[name] = value
                # Write the whole dict back so the change is persisted.
                db[subset] = entries
                db.commit()
            return

        # NaN is stored as None, see https://github.com/h2oai/wave/issues/447
        value = None if np.isnan(value) else float(value)

        with SqliteDict(self.logs) as db:
            entries = db[subset] if subset in db else dict()
            series = entries.setdefault(name, {"steps": [], "values": []})
            series["steps"].append(step)
            series["values"].append(value)
            # Write the whole dict back so the change is persisted.
            db[subset] = entries
            db.commit()
class DummyLogger:
    """No-op logger: accepts the same calls as the other loggers, does nothing."""

    def __init__(self, cfg: Optional[Any] = None):
        """Accept an optional config and ignore it.

        Args:
            cfg: configuration, unused
        """
        pass

    def log(self, subset: str, name: str, value: Any, step: Optional[int] = None):
        """Discard the value.

        Args:
            subset: unused
            name: unused
            value: unused
            step: unused
        """
        return None
class MainLogger:
"""Main logger"""
def __init__(self, cfg: Any):
self.loggers = {
"local": LocalLogger(cfg),
"external": Loggers.get(cfg.logging.logger),
}
try:
self.loggers["external"] = self.loggers["external"](cfg)
except Exception as e:
logger.warning(
f"Error when initializing logger. "
f"Disabling custom logging functionality. "
f"Please ensure logger configuration is correct and "
f"you have a stable Internet connection: {e}"
)
self.loggers["external"] = DummyLogger(cfg)
def reset_external(self):
self.loggers["external"] = DummyLogger()
def log(self, subset: str, name: str, value: str | float, step: float = None):
for k, logger in self.loggers.items():
if "validation_predictions" in name and k == "external":
continue
if subset == "internal" and not isinstance(logger, LocalLogger):
continue
logger.log(subset=subset, name=name, value=value, step=step)
class Loggers:
    """Loggers factory."""

    # Registry of selectable external loggers, keyed by display name.
    _loggers = {"None": DummyLogger, "Neptune": NeptuneLogger}

    @classmethod
    def names(cls) -> List[str]:
        """Return the registered logger names in sorted order."""
        return sorted(cls._loggers)

    @classmethod
    def get(cls, name: str) -> Any:
        """Look up a logger class by name.

        Args:
            name: loggers name

        Returns:
            The logger class registered under ``name``, or ``DummyLogger``
            when the name is not registered
        """
        if name in cls._loggers:
            return cls._loggers[name]
        return DummyLogger