from __future__ import annotations import hashlib import json import random import sys from functools import partial from pathlib import Path from typing import ( TYPE_CHECKING, Any, Callable, Dict, List, Literal, MutableMapping, Protocol, Sequence, TypedDict, TypeVar, Union, overload, runtime_checkable, ) from typing_extensions import Concatenate, ParamSpec, TypeAlias import narwhals.stable.v1 as nw from narwhals.dependencies import is_pandas_dataframe as _is_pandas_dataframe from narwhals.typing import IntoDataFrame from ._importers import import_pyarrow_interchange from .core import ( DataFrameLike, sanitize_geo_interface, sanitize_narwhals_dataframe, sanitize_pandas_dataframe, to_eager_narwhals_dataframe, ) from .plugin_registry import PluginRegistry if sys.version_info >= (3, 13): from typing import TypeIs else: from typing_extensions import TypeIs if TYPE_CHECKING: import pandas as pd import pyarrow as pa @runtime_checkable class SupportsGeoInterface(Protocol): __geo_interface__: MutableMapping DataType: TypeAlias = Union[ Dict[Any, Any], IntoDataFrame, SupportsGeoInterface, DataFrameLike ] TDataType = TypeVar("TDataType", bound=DataType) TIntoDataFrame = TypeVar("TIntoDataFrame", bound=IntoDataFrame) VegaLiteDataDict: TypeAlias = Dict[ str, Union[str, Dict[Any, Any], List[Dict[Any, Any]]] ] ToValuesReturnType: TypeAlias = Dict[str, Union[Dict[Any, Any], List[Dict[Any, Any]]]] SampleReturnType = Union[IntoDataFrame, Dict[str, Sequence], None] def is_data_type(obj: Any) -> TypeIs[DataType]: return _is_pandas_dataframe(obj) or isinstance( obj, (dict, DataFrameLike, SupportsGeoInterface, nw.DataFrame) ) # ============================================================================== # Data transformer registry # # A data transformer is a callable that takes a supported data type and returns # a transformed dictionary version of it which is compatible with the VegaLite schema. # The dict objects will be the Data portion of the VegaLite schema. # # Renderers only deal with the dict form of a # VegaLite spec, after the Data model has been put into a schema compliant # form. # ============================================================================== P = ParamSpec("P") # NOTE: `Any` required due to the complexity of existing signatures imported in `altair.vegalite.v5.data.py` R = TypeVar("R", VegaLiteDataDict, Any) DataTransformerType = Callable[Concatenate[DataType, P], R] class DataTransformerRegistry(PluginRegistry[DataTransformerType, R]): _global_settings = {"consolidate_datasets": True} @property def consolidate_datasets(self) -> bool: return self._global_settings["consolidate_datasets"] @consolidate_datasets.setter def consolidate_datasets(self, value: bool) -> None: self._global_settings["consolidate_datasets"] = value # ============================================================================== class MaxRowsError(Exception): """Raised when a data model has too many rows.""" @overload def limit_rows(data: None = ..., max_rows: int | None = ...) -> partial: ... @overload def limit_rows(data: DataType, max_rows: int | None = ...) -> DataType: ... def limit_rows( data: DataType | None = None, max_rows: int | None = 5000 ) -> partial | DataType: """ Raise MaxRowsError if the data model has more than max_rows. If max_rows is None, then do not perform any check. """ if data is None: return partial(limit_rows, max_rows=max_rows) check_data_type(data) def raise_max_rows_error(): msg = ( "The number of rows in your dataset is greater " f"than the maximum allowed ({max_rows}).\n\n" "Try enabling the VegaFusion data transformer which " "raises this limit by pre-evaluating data\n" "transformations in Python.\n" " >> import altair as alt\n" ' >> alt.data_transformers.enable("vegafusion")\n\n' "Or, see https://altair-viz.github.io/user_guide/large_datasets.html " "for additional information\n" "on how to plot large datasets." ) raise MaxRowsError(msg) if isinstance(data, SupportsGeoInterface): if data.__geo_interface__["type"] == "FeatureCollection": values = data.__geo_interface__["features"] else: values = data.__geo_interface__ elif isinstance(data, dict): if "values" in data: values = data["values"] else: return data else: data = to_eager_narwhals_dataframe(data) values = data if max_rows is not None and len(values) > max_rows: raise_max_rows_error() return data @overload def sample( data: None = ..., n: int | None = ..., frac: float | None = ... ) -> partial: ... @overload def sample( data: TIntoDataFrame, n: int | None = ..., frac: float | None = ... ) -> TIntoDataFrame: ... @overload def sample( data: DataType, n: int | None = ..., frac: float | None = ... ) -> SampleReturnType: ... def sample( data: DataType | None = None, n: int | None = None, frac: float | None = None, ) -> partial | SampleReturnType: """Reduce the size of the data model by sampling without replacement.""" if data is None: return partial(sample, n=n, frac=frac) check_data_type(data) if _is_pandas_dataframe(data): return data.sample(n=n, frac=frac) elif isinstance(data, dict): if "values" in data: values = data["values"] if not n: if frac is None: msg = "frac cannot be None if n is None and data is a dictionary" raise ValueError(msg) n = int(frac * len(values)) values = random.sample(values, n) return {"values": values} else: # Maybe this should raise an error or return something useful? return None data = nw.from_native(data, eager_only=True) if not n: if frac is None: msg = "frac cannot be None if n is None with this data input type" raise ValueError(msg) n = int(frac * len(data)) indices = random.sample(range(len(data)), n) return nw.to_native(data[indices]) _FormatType = Literal["csv", "json"] class _FormatDict(TypedDict): type: _FormatType class _ToFormatReturnUrlDict(TypedDict): url: str format: _FormatDict @overload def to_json( data: None = ..., prefix: str = ..., extension: str = ..., filename: str = ..., urlpath: str = ..., ) -> partial: ... @overload def to_json( data: DataType, prefix: str = ..., extension: str = ..., filename: str = ..., urlpath: str = ..., ) -> _ToFormatReturnUrlDict: ... def to_json( data: DataType | None = None, prefix: str = "altair-data", extension: str = "json", filename: str = "{prefix}-{hash}.{extension}", urlpath: str = "", ) -> partial | _ToFormatReturnUrlDict: """Write the data model to a .json file and return a url based data model.""" kwds = _to_text_kwds(prefix, extension, filename, urlpath) if data is None: return partial(to_json, **kwds) else: data_str = _data_to_json_string(data) return _to_text(data_str, **kwds, format=_FormatDict(type="json")) @overload def to_csv( data: None = ..., prefix: str = ..., extension: str = ..., filename: str = ..., urlpath: str = ..., ) -> partial: ... @overload def to_csv( data: dict | pd.DataFrame | DataFrameLike, prefix: str = ..., extension: str = ..., filename: str = ..., urlpath: str = ..., ) -> _ToFormatReturnUrlDict: ... def to_csv( data: dict | pd.DataFrame | DataFrameLike | None = None, prefix: str = "altair-data", extension: str = "csv", filename: str = "{prefix}-{hash}.{extension}", urlpath: str = "", ) -> partial | _ToFormatReturnUrlDict: """Write the data model to a .csv file and return a url based data model.""" kwds = _to_text_kwds(prefix, extension, filename, urlpath) if data is None: return partial(to_csv, **kwds) else: data_str = _data_to_csv_string(data) return _to_text(data_str, **kwds, format=_FormatDict(type="csv")) def _to_text( data: str, prefix: str, extension: str, filename: str, urlpath: str, format: _FormatDict, ) -> _ToFormatReturnUrlDict: data_hash = _compute_data_hash(data) filename = filename.format(prefix=prefix, hash=data_hash, extension=extension) Path(filename).write_text(data, encoding="utf-8") url = str(Path(urlpath, filename)) return _ToFormatReturnUrlDict({"url": url, "format": format}) def _to_text_kwds(prefix: str, extension: str, filename: str, urlpath: str, /) -> dict[str, str]: # fmt: skip return {"prefix": prefix, "extension": extension, "filename": filename, "urlpath": urlpath} # fmt: skip def to_values(data: DataType) -> ToValuesReturnType: """Replace a DataFrame by a data model with values.""" check_data_type(data) # `strict=False` passes `data` through as-is if it is not a Narwhals object. data_native = nw.to_native(data, strict=False) if isinstance(data_native, SupportsGeoInterface): return {"values": _from_geo_interface(data_native)} elif _is_pandas_dataframe(data_native): data_native = sanitize_pandas_dataframe(data_native) return {"values": data_native.to_dict(orient="records")} elif isinstance(data_native, dict): if "values" not in data_native: msg = "values expected in data dict, but not present." raise KeyError(msg) return data_native elif isinstance(data, nw.DataFrame): data = sanitize_narwhals_dataframe(data) return {"values": data.rows(named=True)} else: # Should never reach this state as tested by check_data_type msg = f"Unrecognized data type: {type(data)}" raise ValueError(msg) def check_data_type(data: DataType) -> None: if not is_data_type(data): msg = f"Expected dict, DataFrame or a __geo_interface__ attribute, got: {type(data)}" raise TypeError(msg) # ============================================================================== # Private utilities # ============================================================================== def _compute_data_hash(data_str: str) -> str: return hashlib.sha256(data_str.encode()).hexdigest()[:32] def _from_geo_interface(data: SupportsGeoInterface | Any) -> dict[str, Any]: """ Santize a ``__geo_interface__`` w/ pre-santize step for ``pandas`` if needed. Notes ----- Split out to resolve typing issues related to: - Intersection types - ``typing.TypeGuard`` - ``pd.DataFrame.__getattr__`` """ if _is_pandas_dataframe(data): data = sanitize_pandas_dataframe(data) return sanitize_geo_interface(data.__geo_interface__) def _data_to_json_string(data: DataType) -> str: """Return a JSON string representation of the input data.""" check_data_type(data) if isinstance(data, SupportsGeoInterface): return json.dumps(_from_geo_interface(data)) elif _is_pandas_dataframe(data): data = sanitize_pandas_dataframe(data) return data.to_json(orient="records", double_precision=15) elif isinstance(data, dict): if "values" not in data: msg = "values expected in data dict, but not present." raise KeyError(msg) return json.dumps(data["values"], sort_keys=True) try: data_nw = nw.from_native(data, eager_only=True) except TypeError as exc: msg = "to_json only works with data expressed as a DataFrame or as a dict" raise NotImplementedError(msg) from exc data_nw = sanitize_narwhals_dataframe(data_nw) return json.dumps(data_nw.rows(named=True)) def _data_to_csv_string(data: DataType) -> str: """Return a CSV string representation of the input data.""" check_data_type(data) if isinstance(data, SupportsGeoInterface): msg = ( f"to_csv does not yet work with data that " f"is of type {type(SupportsGeoInterface).__name__!r}.\n" f"See https://github.com/vega/altair/issues/3441" ) raise NotImplementedError(msg) elif _is_pandas_dataframe(data): data = sanitize_pandas_dataframe(data) return data.to_csv(index=False) elif isinstance(data, dict): if "values" not in data: msg = "values expected in data dict, but not present" raise KeyError(msg) try: import pandas as pd except ImportError as exc: msg = "pandas is required to convert a dict to a CSV string" raise ImportError(msg) from exc return pd.DataFrame.from_dict(data["values"]).to_csv(index=False) try: data_nw = nw.from_native(data, eager_only=True) except TypeError as exc: msg = "to_csv only works with data expressed as a DataFrame or as a dict" raise NotImplementedError(msg) from exc return data_nw.write_csv() def arrow_table_from_dfi_dataframe(dfi_df: DataFrameLike) -> pa.Table: """Convert a DataFrame Interchange Protocol compatible object to an Arrow Table.""" import pyarrow as pa # First check if the dataframe object has a method to convert to arrow. # Give this preference over the pyarrow from_dataframe function since the object # has more control over the conversion, and may have broader compatibility. # This is the case for Polars, which supports Date32 columns in direct conversion # while pyarrow does not yet support this type in from_dataframe for convert_method_name in ("arrow", "to_arrow", "to_arrow_table", "to_pyarrow"): convert_method = getattr(dfi_df, convert_method_name, None) if callable(convert_method): result = convert_method() if isinstance(result, pa.Table): return result pi = import_pyarrow_interchange() return pi.from_dataframe(dfi_df)