from typing import List, Union

from pytorch_ie.annotations import LabeledMultiSpan, LabeledSpan


def labeled_span_to_id(span: Union[LabeledSpan, LabeledMultiSpan]) -> str:
    """Serialize a labeled (multi-)span into a compact string id.

    Formats:
        ``span-{start}-{end}-{label}`` for a ``LabeledSpan``;
        ``multispan-{start1}-{end1}-...-{startN}-{endN}-{label}`` for a
        ``LabeledMultiSpan``. A multispan without slices is written as
        ``multispan--{label}``.

    Use ``labeled_span_from_id`` to parse the id back into an annotation.

    Args:
        span: The annotation to serialize.

    Returns:
        The string id.

    Raises:
        ValueError: If ``span`` is neither a ``LabeledSpan`` nor a
            ``LabeledMultiSpan``.
    """
    if isinstance(span, LabeledSpan):
        # {type indicator}-{start}-{end}-{label}
        return f"span-{span.start}-{span.end}-{span.label}"
    elif isinstance(span, LabeledMultiSpan):
        # {type indicator}-({start}-{end})*-{label}
        starts_ends = "-".join(f"{start}-{end}" for start, end in span.slices)
        return f"multispan-{starts_ends}-{span.label}"
    else:
        raise ValueError(f"Unsupported span type: {type(span)}")


def _labeled_multi_span_from_parts(span_id: str, parts: List[str]) -> LabeledMultiSpan:
    """Build a LabeledMultiSpan from the "-"-separated parts after the "multispan" prefix."""
    if len(parts) >= 2 and parts[0] == "":
        # labeled_span_to_id writes "multispan--{label}" when there are no slices;
        # everything after the empty marker is the label.
        return LabeledMultiSpan((), "-".join(parts[1:]))

    slices = []
    i = 0
    # Consume (start, end) pairs while both parts are integers, always leaving
    # at least one part over for the label (which may itself contain "-").
    while i + 2 < len(parts) and parts[i].isdecimal() and parts[i + 1].isdecimal():
        slices.append((int(parts[i]), int(parts[i + 1])))
        i += 2
    if not slices:
        raise ValueError(f"Malformed span id: {span_id}")
    return LabeledMultiSpan(tuple(slices), "-".join(parts[i:]))


def labeled_span_from_id(span_id: str) -> Union[LabeledSpan, LabeledMultiSpan]:
    """Parse an id produced by ``labeled_span_to_id`` back into an annotation.

    Labels may contain "-". For ``span`` ids, the label is everything after
    the end offset. For ``multispan`` ids, leading pairs of integers are read
    as (start, end) slices and the remainder is the label. A multispan label
    whose leading "-"-separated pieces are themselves integer pairs is
    ambiguous, and those pieces are read as offsets.

    Args:
        span_id: The string id to parse.

    Returns:
        The reconstructed ``LabeledSpan`` or ``LabeledMultiSpan``.

    Raises:
        ValueError: If the type indicator is unknown or the id is malformed.
    """
    kind, _, rest = span_id.partition("-")
    if kind == "span":
        # maxsplit=2 keeps any "-" inside the label intact
        fields = rest.split("-", 2)
        if len(fields) != 3:
            raise ValueError(f"Malformed span id: {span_id}")
        start, end, label = fields
        return LabeledSpan(int(start), int(end), label)
    elif kind == "multispan":
        return _labeled_multi_span_from_parts(span_id, rest.split("-"))
    else:
        raise ValueError(f"Unsupported span id: {span_id}")