Clemspace's picture
Initial model upload
cb9e677
raw
history blame
No virus
2.01 kB
import contextlib
import dataclasses
import datetime
import logging
import time
from typing import Optional, Protocol
import torch
# Module-level logger used by the helpers in this file; registered under the fixed name "utils".
logger = logging.getLogger("utils")
@dataclasses.dataclass
class TrainState:
    """Mutable bookkeeping for a training loop: step count, timing and token totals.

    Call `start_step` at the beginning of every step and `end_step` once the
    step is done; the throughput and ETA properties are derived from the
    counters those two methods maintain. All durations come from `time.time()`
    and are therefore in seconds.
    """

    # Total number of steps planned for the run (used only by `eta`).
    max_steps: int
    # Number of steps started so far.
    step: int = 0
    # Sum of all measured step durations, in seconds.
    elapsed_time: float = 0.0
    # Sum of the token counts passed to `end_step`.
    n_seen_tokens: int = 0
    # Duration of the most recently finished step, in seconds.
    this_step_time: float = 0.0
    # Wall-clock timestamp from which the current step is being timed.
    begin_step_time: float = 0.0
    # Evaluation metrics; never written by this class, left for the caller to fill in.
    this_eval_perplexity: Optional[float] = None
    this_eval_loss: Optional[float] = None
    # Token count of the most recently finished step. Declared as a field (kept
    # out of __init__ so the constructor signature is unchanged) so that `wps`
    # can be read before the first `end_step` without raising AttributeError.
    this_step_tokens: int = dataclasses.field(default=0, init=False)

    def start_step(self) -> None:
        """Advance the step counter and start timing the new step."""
        self.step += 1
        self.begin_step_time = time.time()

    def end_step(self, n_batch_tokens: int) -> None:
        """Record the duration and token count of the step that just finished.

        Args:
            n_batch_tokens: Number of tokens processed during this step.
        """
        # Read the clock once so that no time slips between "end of this step"
        # and the restart of the timer below.
        now = time.time()
        self.this_step_time = now - self.begin_step_time
        self.this_step_tokens = n_batch_tokens

        self.elapsed_time += self.this_step_time
        self.n_seen_tokens += self.this_step_tokens

        # Restart the timer so time spent before the next `start_step` is not
        # attributed to the step that just ended.
        self.begin_step_time = now

    @property
    def wps(self) -> float:
        """Tokens per second for the most recent step (0.0 if no time was measured)."""
        if self.this_step_time == 0:
            return 0.0
        return self.this_step_tokens / self.this_step_time

    @property
    def avg_wps(self) -> float:
        """Tokens per second averaged over all finished steps (0.0 if no time was measured)."""
        if self.elapsed_time == 0:
            return 0.0
        return self.n_seen_tokens / self.elapsed_time

    @property
    def eta(self) -> float:
        """Estimated seconds remaining, from the average step time so far.

        Returns 0.0 before the first step has started, because there is no
        average to extrapolate from yet.
        """
        if self.step == 0:
            return 0.0
        steps_left = self.max_steps - self.step
        avg_time_per_step = self.elapsed_time / self.step
        return steps_left * avg_time_per_step
def set_random_seed(seed: int) -> None:
    """Seed PyTorch's random number generators so runs are repeatable.

    Args:
        seed: Value handed to both `torch.manual_seed` and
            `torch.cuda.manual_seed`.
    """
    # Same two calls, in the same order, as a straight-line version would make.
    for seeder in (torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
class Closable(Protocol):
    """Structural type for any object that exposes a `close()` method."""

    def close(self):
        """Release whatever the object holds; the protocol itself does nothing."""
        ...
@contextlib.contextmanager
def logged_closing(thing: Closable, name: str):
    """Context manager that closes `thing` on exit and logs around the close.

    An info record is logged before `thing.close()` is called and another after
    it returns, so a close that hangs at exit time can be spotted in the logs.
    If `close()` raises an `Exception`, an error record is logged and the
    exception is re-raised; the "Closed" record is not emitted in that case.

    Args:
        thing: Object to close when the `with` block exits.
        name: Label used for `thing` in the log messages.

    Yields:
        None. The managed object is not handed back to the caller.
    """
    try:
        # Tag the object so it can be recognised as managed by this helper.
        setattr(thing, "wrapped_by_closing", True)
        yield
    finally:
        logger.info(f"Closing: {name}")
        try:
            thing.close()
        except Exception:
            logger.error(f"Error while closing {name}!")
            raise
        else:
            # Only reached when close() returned normally.
            logger.info(f"Closed: {name}")
def now_as_str() -> str:
    """Return the current local time formatted as ``YYYY-MM-DD_HH-MM-SS``."""
    current = datetime.datetime.now()
    return current.strftime("%Y-%m-%d_%H-%M-%S")