|
|
|
|
|
""" |
|
Non signal processing related utilities. |
|
""" |
|
|
|
import inspect |
|
import typing as tp |
|
import sys |
|
import time |
|
|
|
|
|
def simple_repr(obj, attrs: tp.Optional[tp.Sequence[str]] = None, |
|
overrides: dict = {}): |
|
""" |
|
Return a simple representation string for `obj`. |
|
If `attrs` is not None, it should be a list of attributes to include. |
|
""" |
|
params = inspect.signature(obj.__class__).parameters |
|
attrs_repr = [] |
|
if attrs is None: |
|
attrs = list(params.keys()) |
|
for attr in attrs: |
|
display = False |
|
if attr in overrides: |
|
value = overrides[attr] |
|
elif hasattr(obj, attr): |
|
value = getattr(obj, attr) |
|
else: |
|
continue |
|
if attr in params: |
|
param = params[attr] |
|
if param.default is inspect._empty or value != param.default: |
|
display = True |
|
else: |
|
display = True |
|
|
|
if display: |
|
attrs_repr.append(f"{attr}={value}") |
|
return f"{obj.__class__.__name__}({','.join(attrs_repr)})" |
|
|
|
|
|
class MarkdownTable: |
|
""" |
|
Simple MarkdownTable generator. The column titles should be large enough |
|
for the lines content. This will right align everything. |
|
|
|
>>> import io # we use io purely for test purposes, default is sys.stdout. |
|
>>> file = io.StringIO() |
|
>>> table = MarkdownTable(["Item Name", "Price"], file=file) |
|
>>> table.header(); table.line(["Honey", "5"]); table.line(["Car", "5,000"]) |
|
>>> print(file.getvalue().strip()) # Strip for test purposes |
|
| Item Name | Price | |
|
|-----------|-------| |
|
| Honey | 5 | |
|
| Car | 5,000 | |
|
""" |
|
def __init__(self, columns, file=sys.stdout): |
|
self.columns = columns |
|
self.file = file |
|
|
|
def _writeln(self, line): |
|
self.file.write("|" + "|".join(line) + "|\n") |
|
|
|
def header(self): |
|
self._writeln(f" {col} " for col in self.columns) |
|
self._writeln("-" * (len(col) + 2) for col in self.columns) |
|
|
|
def line(self, line): |
|
out = [] |
|
for val, col in zip(line, self.columns): |
|
val = format(val, '>' + str(len(col))) |
|
out.append(" " + val + " ") |
|
self._writeln(out) |
|
|
|
|
|
class Chrono: |
|
""" |
|
Measures ellapsed time, calling `torch.cuda.synchronize` if necessary. |
|
`Chrono` instances can be used as context managers (e.g. with `with`). |
|
Upon exit of the block, you can access the duration of the block in seconds |
|
with the `duration` attribute. |
|
|
|
>>> with Chrono() as chrono: |
|
... _ = sum(range(10_000)) |
|
... |
|
>>> print(chrono.duration < 10) # Should be true unless on a really slow computer. |
|
True |
|
""" |
|
def __init__(self): |
|
self.duration = None |
|
|
|
def __enter__(self): |
|
self._begin = time.time() |
|
return self |
|
|
|
def __exit__(self, exc_type, exc_value, exc_tracebck): |
|
import torch |
|
if torch.cuda.is_available(): |
|
torch.cuda.synchronize() |
|
self.duration = time.time() - self._begin |
|
|