nsthorat-lilac's picture
Duplicate from lilacai/nikhil_staging
bfc0ec6
raw
history blame
4.03 kB
"""Interface for implementing a source."""
from typing import TYPE_CHECKING, Any, ClassVar, Iterable, Optional, Type, Union
import numpy as np
import pandas as pd
import pyarrow as pa
from pydantic import BaseModel
if TYPE_CHECKING:
from pydantic.typing import AbstractSetIntStr, MappingIntStrAny
from ..schema import (
Field,
ImageInfo,
Item,
Schema,
arrow_dtype_to_dtype,
arrow_schema_to_schema,
field,
)
class SourceSchema(BaseModel):
  """The schema of a source.

  This is the value returned by `Source.source_schema()`.
  """
  # Maps each top-level column name to the `Field` that describes it.
  fields: dict[str, Field]
  # Number of items in the source, used for progress reporting. Defaults to None when not provided.
  num_items: Optional[int] = None
class SourceProcessResult(BaseModel):
  """The result after processing all the shards of a source dataset."""
  # Paths of the files associated with the processed dataset.
  # NOTE(review): presumably the files written during processing -- confirm against the caller.
  filepaths: list[str]
  # Schema of the processed data.
  data_schema: Schema
  # Total number of items in the processed dataset.
  num_items: int
  # Optional image metadata; defaults to None when not provided.
  images: Optional[list[ImageInfo]] = None
class Source(BaseModel):
  """Base interface that every source implements.

  A source processes a set of shards and writes files. Concrete sources provide the class-level
  `name` and override `source_schema` and `process`; `setup` and `teardown` are optional hooks
  that do nothing by default.
  """
  # ClassVars are not serialized by pydantic, so `dict()` below adds the name back explicitly.
  name: ClassVar[str]

  def dict(
    self,
    *,
    include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None,
    exclude: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None,
    by_alias: bool = False,
    skip_defaults: Optional[bool] = None,
    exclude_unset: bool = False,
    exclude_defaults: bool = False,
    exclude_none: bool = False,
  ) -> dict[str, Any]:
    """Serialize the source like `BaseModel.dict`, then add a `source_name` entry.

    All keyword arguments are forwarded unchanged to the parent implementation. The
    `source_name` entry is always set to `self.name`, regardless of `include` / `exclude`.

    Returns
      The serialized fields plus `source_name`.
    """
    serialized = super().dict(
      include=include,
      exclude=exclude,
      by_alias=by_alias,
      skip_defaults=skip_defaults,
      exclude_unset=exclude_unset,
      exclude_defaults=exclude_defaults,
      exclude_none=exclude_none,
    )
    return {**serialized, 'source_name': self.name}

  class Config:
    underscore_attrs_are_private = True

    @staticmethod
    def schema_extra(schema: dict[str, Any], source: Type['Source']) -> None:
      """Add a required `source_name` property to the generated JSON schema.

      When the source class defines `name`, the property is an enum containing only that name;
      otherwise it is a plain string. The property is placed ahead of the existing properties
      and appended to the schema's `required` list (created if absent).

      Args:
        schema: The schema dict to modify in place.
        source: The source class the schema was generated for.
      """
      name_prop: dict[str, Any] = (
        {'enum': [source.name]} if hasattr(source, 'name') else {'type': 'string'})
      # Rebuild the mapping so that `source_name` comes first.
      schema['properties'] = {'source_name': name_prop, **schema['properties']}
      schema.setdefault('required', []).append('source_name')

  def source_schema(self) -> SourceSchema:
    """Return the source schema for this source.

    Subclasses must override this; the base implementation raises NotImplementedError.

    Returns
      A SourceSchema with
        fields: mapping top-level columns to fields that describes the schema of the source.
        num_items: the number of items in the source, used for progress.
    """
    raise NotImplementedError

  def setup(self) -> None:
    """Prepare the source for processing.

    This hook lets a source do its setup outside the constructor but before it is processed,
    which avoids potentially expensive computation when the pydantic model is deserialized.
    The default implementation does nothing.
    """

  def teardown(self) -> None:
    """Tear down the source after processing. The default implementation does nothing."""

  def process(self) -> Iterable[Item]:
    """Process the source upload request.

    Subclasses must override this; the base implementation raises NotImplementedError.

    Returns
      An iterable of the items produced by the source.
    """
    raise NotImplementedError
def schema_from_df(df: pd.DataFrame, index_colname: str) -> SourceSchema:
  """Build a SourceSchema describing the columns and the index of a dataframe.

  Args:
    df: The dataframe to describe.
    index_colname: The field name under which the dataframe's index is exposed. If a column
      with the same name exists, the index field replaces it.

  Returns
    A SourceSchema whose fields cover every column of `df` plus the index, and whose
    `num_items` is the number of rows in `df`.
  """
  np_index_dtype = df.index.dtype
  if np_index_dtype == np.dtype(object):
    # Pandas stores string indices with the generic object dtype; treat them as strings.
    np_index_dtype = np.dtype(str)
  index_dtype = arrow_dtype_to_dtype(pa.from_numpy_dtype(np_index_dtype))

  # The index is excluded from the arrow schema here because it is added separately below.
  arrow_schema = pa.Schema.from_pandas(df, preserve_index=False)
  column_schema = arrow_schema_to_schema(arrow_schema)

  all_fields = dict(column_schema.fields)
  all_fields[index_colname] = field(dtype=index_dtype)
  return SourceSchema(fields=all_fields, num_items=len(df))
def normalize_column_name(name: str) -> str:
  """Return the normalized form of a column name.

  Normalization is currently the identity: `name` is returned unchanged.
  """
  # NOTE: Replacing ' ', ':' and '.' with '_' is currently disabled:
  #   return name.replace(' ', '_').replace(':', '_').replace('.', '_')
  return name