Spaces:
Running
Running
from __future__ import annotations | |
import hashlib | |
import json | |
import random | |
import sys | |
from functools import partial | |
from pathlib import Path | |
from typing import ( | |
TYPE_CHECKING, | |
Any, | |
Callable, | |
Dict, | |
List, | |
Literal, | |
MutableMapping, | |
Protocol, | |
Sequence, | |
TypedDict, | |
TypeVar, | |
Union, | |
overload, | |
runtime_checkable, | |
) | |
from typing_extensions import Concatenate, ParamSpec, TypeAlias | |
import narwhals.stable.v1 as nw | |
from narwhals.dependencies import is_pandas_dataframe as _is_pandas_dataframe | |
from narwhals.typing import IntoDataFrame | |
from ._importers import import_pyarrow_interchange | |
from .core import ( | |
DataFrameLike, | |
sanitize_geo_interface, | |
sanitize_narwhals_dataframe, | |
sanitize_pandas_dataframe, | |
to_eager_narwhals_dataframe, | |
) | |
from .plugin_registry import PluginRegistry | |
if sys.version_info >= (3, 13): | |
from typing import TypeIs | |
else: | |
from typing_extensions import TypeIs | |
if TYPE_CHECKING: | |
import pandas as pd | |
import pyarrow as pa | |
@runtime_checkable
class SupportsGeoInterface(Protocol):
    """
    Structural type for objects that expose a ``__geo_interface__`` attribute.

    Notes
    -----
    The protocol is marked ``runtime_checkable`` because this module tests
    membership with ``isinstance``; without the decorator those checks raise
    ``TypeError`` instead of returning a boolean.
    """

    # Mapping describing the geometry; this module reads its "type" and
    # "features" keys.
    __geo_interface__: MutableMapping
# Union of every input kind accepted by the functions in this module.
DataType: TypeAlias = Union[
    Dict[Any, Any],
    IntoDataFrame,
    SupportsGeoInterface,
    DataFrameLike,
]
TDataType = TypeVar("TDataType", bound=DataType)
TIntoDataFrame = TypeVar("TIntoDataFrame", bound=IntoDataFrame)

# Dict form of a data model: string keys mapping to a string, a dict, or a list of dicts.
VegaLiteDataDict: TypeAlias = Dict[
    str,
    Union[str, Dict[Any, Any], List[Dict[Any, Any]]],
]
# Return type of ``to_values``.
ToValuesReturnType: TypeAlias = Dict[
    str,
    Union[Dict[Any, Any], List[Dict[Any, Any]]],
]
# Return type of ``sample`` once data has been supplied.
SampleReturnType = Union[IntoDataFrame, Dict[str, Sequence], None]
def is_data_type(obj: Any) -> TypeIs[DataType]:
    """
    Report whether ``obj`` is one of the supported data inputs.

    An object qualifies when it is a pandas DataFrame, or an instance of
    ``dict``, ``DataFrameLike``, ``SupportsGeoInterface`` or ``nw.DataFrame``.
    """
    accepted = (dict, DataFrameLike, SupportsGeoInterface, nw.DataFrame)
    return _is_pandas_dataframe(obj) or isinstance(obj, accepted)
# ==============================================================================
# Data transformer registry
#
# A data transformer is a callable that takes a supported data type and returns
# a transformed dictionary version of it which is compatible with the VegaLite schema.
# The dict objects will be the Data portion of the VegaLite schema.
#
# Renderers only deal with the dict form of a
# VegaLite spec, after the Data model has been put into a schema compliant
# form.
# ==============================================================================
# Stands for the extra parameters a transformer accepts after its data argument.
P = ParamSpec("P")

# NOTE: `Any` required due to the complexity of existing signatures imported in `altair.vegalite.v5.data.py`
R = TypeVar("R", VegaLiteDataDict, Any)

# Signature of a data transformer: data first, then ``P``'s parameters, returning ``R``.
DataTransformerType = Callable[Concatenate[DataType, P], R]
class DataTransformerRegistry(PluginRegistry[DataTransformerType, R]):
    """
    Plugin registry for data transformers.

    Adds a ``consolidate_datasets`` flag on top of ``PluginRegistry``. The flag
    is exposed as a property so it can be read as a ``bool`` and assigned
    directly (``registry.consolidate_datasets = False``).
    """

    # Class-level dict that the property reads and the setter mutates in place.
    _global_settings = {"consolidate_datasets": True}

    @property
    def consolidate_datasets(self) -> bool:
        """Current value of the ``consolidate_datasets`` setting."""
        return self._global_settings["consolidate_datasets"]

    @consolidate_datasets.setter
    def consolidate_datasets(self, value: bool) -> None:
        self._global_settings["consolidate_datasets"] = value
# ============================================================================== | |
class MaxRowsError(Exception):
    """Signal that a data model exceeds the permitted number of rows."""
@overload
def limit_rows(data: None = ..., max_rows: int | None = ...) -> partial: ...
@overload
def limit_rows(data: DataType, max_rows: int | None = ...) -> DataType: ...
def limit_rows(
    data: DataType | None = None, max_rows: int | None = 5000
) -> partial | DataType:
    """
    Raise MaxRowsError if the data model has more than max_rows.

    If max_rows is None, then do not perform any check.

    Parameters
    ----------
    data
        Data to check. When ``None``, no check is performed and a
        ``functools.partial`` of this function with ``max_rows`` bound is
        returned instead.
    max_rows
        Largest permitted length; ``None`` disables the check.

    Returns
    -------
    partial | DataType
        The partial described above; the input itself for dict and
        ``__geo_interface__`` inputs; otherwise the result of
        ``to_eager_narwhals_dataframe(data)``.

    Raises
    ------
    MaxRowsError
        If ``max_rows`` is not ``None`` and the measured length exceeds it.
    """
    if data is None:
        return partial(limit_rows, max_rows=max_rows)
    check_data_type(data)

    def raise_max_rows_error():
        msg = (
            "The number of rows in your dataset is greater "
            f"than the maximum allowed ({max_rows}).\n\n"
            "Try enabling the VegaFusion data transformer which "
            "raises this limit by pre-evaluating data\n"
            "transformations in Python.\n"
            " >> import altair as alt\n"
            ' >> alt.data_transformers.enable("vegafusion")\n\n'
            "Or, see https://altair-viz.github.io/user_guide/large_datasets.html "
            "for additional information\n"
            "on how to plot large datasets."
        )
        raise MaxRowsError(msg)

    if isinstance(data, SupportsGeoInterface):
        if data.__geo_interface__["type"] == "FeatureCollection":
            values = data.__geo_interface__["features"]
        else:
            # Not a collection: ``len`` is taken of the mapping itself.
            values = data.__geo_interface__
    elif isinstance(data, dict):
        if "values" in data:
            values = data["values"]
        else:
            # Nothing to measure, so the dict passes through unchecked.
            return data
    else:
        data = to_eager_narwhals_dataframe(data)
        values = data

    if max_rows is not None and len(values) > max_rows:
        raise_max_rows_error()

    return data
@overload
def sample(
    data: None = ..., n: int | None = ..., frac: float | None = ...
) -> partial: ...
@overload
def sample(
    data: TIntoDataFrame, n: int | None = ..., frac: float | None = ...
) -> TIntoDataFrame: ...
@overload
def sample(
    data: DataType, n: int | None = ..., frac: float | None = ...
) -> SampleReturnType: ...
def sample(
    data: DataType | None = None,
    n: int | None = None,
    frac: float | None = None,
) -> partial | SampleReturnType:
    """
    Reduce the size of the data model by sampling without replacement.

    Parameters
    ----------
    data
        Data to sample. When ``None``, a ``functools.partial`` of this function
        with ``n`` and ``frac`` bound is returned instead.
    n
        Number of items to draw. A falsy value (``None`` or ``0``) means the
        count is derived from ``frac``, except for pandas input where ``n`` and
        ``frac`` are passed straight to ``DataFrame.sample``.
    frac
        Fraction of the items to draw; only consulted when ``n`` is falsy.

    Returns
    -------
    partial | SampleReturnType
        The partial described above; a sampled pandas DataFrame; a
        ``{"values": [...]}`` dict; ``None`` for a dict without a ``"values"``
        key; or the native form of the sampled Narwhals DataFrame.

    Raises
    ------
    ValueError
        If both ``n`` and ``frac`` are unset for dict or non-pandas DataFrame input.
    """
    if data is None:
        return partial(sample, n=n, frac=frac)
    check_data_type(data)
    if _is_pandas_dataframe(data):
        return data.sample(n=n, frac=frac)
    elif isinstance(data, dict):
        if "values" in data:
            values = data["values"]
            if not n:
                if frac is None:
                    msg = "frac cannot be None if n is None and data is a dictionary"
                    raise ValueError(msg)
                n = int(frac * len(values))
            values = random.sample(values, n)
            return {"values": values}
        else:
            # Maybe this should raise an error or return something useful?
            return None
    data = nw.from_native(data, eager_only=True)
    if not n:
        if frac is None:
            msg = "frac cannot be None if n is None with this data input type"
            raise ValueError(msg)
        n = int(frac * len(data))
    # Draw distinct row positions, then select those rows.
    indices = random.sample(range(len(data)), n)
    return nw.to_native(data[indices])
_FormatType = Literal["csv", "json"] | |
class _FormatDict(TypedDict): | |
type: _FormatType | |
class _ToFormatReturnUrlDict(TypedDict): | |
url: str | |
format: _FormatDict | |
@overload
def to_json(
    data: None = ...,
    prefix: str = ...,
    extension: str = ...,
    filename: str = ...,
    urlpath: str = ...,
) -> partial: ...
@overload
def to_json(
    data: DataType,
    prefix: str = ...,
    extension: str = ...,
    filename: str = ...,
    urlpath: str = ...,
) -> _ToFormatReturnUrlDict: ...
def to_json(
    data: DataType | None = None,
    prefix: str = "altair-data",
    extension: str = "json",
    filename: str = "{prefix}-{hash}.{extension}",
    urlpath: str = "",
) -> partial | _ToFormatReturnUrlDict:
    """
    Write the data model to a .json file and return a url based data model.

    Parameters
    ----------
    data
        Data to serialise. When ``None``, a ``functools.partial`` of this
        function with the remaining options bound is returned instead.
    prefix, extension
        Values substituted into ``filename``.
    filename
        Template for the output file name; the default uses the ``{prefix}``,
        ``{hash}`` and ``{extension}`` fields.
    urlpath
        Path placed in front of the file name in the returned ``url``.

    Returns
    -------
    partial | _ToFormatReturnUrlDict
        The partial described above, or a dict with ``url`` and ``format``
        entries where ``format`` is ``{"type": "json"}``.
    """
    kwds = _to_text_kwds(prefix, extension, filename, urlpath)
    if data is None:
        return partial(to_json, **kwds)
    else:
        data_str = _data_to_json_string(data)
        return _to_text(data_str, **kwds, format=_FormatDict(type="json"))
@overload
def to_csv(
    data: None = ...,
    prefix: str = ...,
    extension: str = ...,
    filename: str = ...,
    urlpath: str = ...,
) -> partial: ...
@overload
def to_csv(
    data: dict | pd.DataFrame | DataFrameLike,
    prefix: str = ...,
    extension: str = ...,
    filename: str = ...,
    urlpath: str = ...,
) -> _ToFormatReturnUrlDict: ...
def to_csv(
    data: dict | pd.DataFrame | DataFrameLike | None = None,
    prefix: str = "altair-data",
    extension: str = "csv",
    filename: str = "{prefix}-{hash}.{extension}",
    urlpath: str = "",
) -> partial | _ToFormatReturnUrlDict:
    """
    Write the data model to a .csv file and return a url based data model.

    Parameters
    ----------
    data
        Data to serialise. When ``None``, a ``functools.partial`` of this
        function with the remaining options bound is returned instead.
    prefix, extension
        Values substituted into ``filename``.
    filename
        Template for the output file name; the default uses the ``{prefix}``,
        ``{hash}`` and ``{extension}`` fields.
    urlpath
        Path placed in front of the file name in the returned ``url``.

    Returns
    -------
    partial | _ToFormatReturnUrlDict
        The partial described above, or a dict with ``url`` and ``format``
        entries where ``format`` is ``{"type": "csv"}``.
    """
    kwds = _to_text_kwds(prefix, extension, filename, urlpath)
    if data is None:
        return partial(to_csv, **kwds)
    else:
        data_str = _data_to_csv_string(data)
        return _to_text(data_str, **kwds, format=_FormatDict(type="csv"))
def _to_text(
    data: str,
    prefix: str,
    extension: str,
    filename: str,
    urlpath: str,
    format: _FormatDict,
) -> _ToFormatReturnUrlDict:
    """
    Write ``data`` to a file named after its hash and describe it as a url.

    Parameters
    ----------
    data
        Text to write, encoded as UTF-8.
    prefix, extension
        Values substituted into ``filename``.
    filename
        Template formatted with ``prefix``, ``hash`` and ``extension``; the
        result is the path that is written.
    urlpath
        Path joined in front of the file name to build the returned ``url``.
    format
        Stored unchanged under the ``format`` key of the result.

    Returns
    -------
    _ToFormatReturnUrlDict
        Dict with the ``url`` of the written file and the given ``format``.
    """
    data_hash = _compute_data_hash(data)
    filename = filename.format(prefix=prefix, hash=data_hash, extension=extension)
    Path(filename).write_text(data, encoding="utf-8")
    # ``as_posix`` keeps forward slashes in the url on every platform;
    # ``str(Path(...))`` would emit backslashes on Windows.
    url = Path(urlpath, filename).as_posix()
    return _ToFormatReturnUrlDict({"url": url, "format": format})
def _to_text_kwds(prefix: str, extension: str, filename: str, urlpath: str, /) -> dict[str, str]: # fmt: skip | |
return {"prefix": prefix, "extension": extension, "filename": filename, "urlpath": urlpath} # fmt: skip | |
def to_values(data: DataType) -> ToValuesReturnType:
    """
    Replace a DataFrame by a data model with values.

    A dict input is returned as-is provided it has a ``"values"`` key; every
    other supported input is converted to ``{"values": ...}``.

    Raises
    ------
    KeyError
        If a dict input has no ``"values"`` key.
    ValueError
        If the input matches none of the handled kinds.
    """
    check_data_type(data)
    # ``strict=False`` hands ``data`` back untouched when it is not a Narwhals object.
    native = nw.to_native(data, strict=False)

    if isinstance(native, SupportsGeoInterface):
        return {"values": _from_geo_interface(native)}

    if _is_pandas_dataframe(native):
        sanitized = sanitize_pandas_dataframe(native)
        return {"values": sanitized.to_dict(orient="records")}

    if isinstance(native, dict):
        if "values" in native:
            return native
        msg = "values expected in data dict, but not present."
        raise KeyError(msg)

    if isinstance(data, nw.DataFrame):
        frame = sanitize_narwhals_dataframe(data)
        return {"values": frame.rows(named=True)}

    # Should never reach this state as tested by check_data_type
    msg = f"Unrecognized data type: {type(data)}"
    raise ValueError(msg)
def check_data_type(data: DataType) -> None:
    """Raise ``TypeError`` unless ``data`` satisfies ``is_data_type``."""
    if is_data_type(data):
        return
    msg = f"Expected dict, DataFrame or a __geo_interface__ attribute, got: {type(data)}"
    raise TypeError(msg)
# ==============================================================================
# Private utilities
# ==============================================================================
def _compute_data_hash(data_str: str) -> str: | |
return hashlib.sha256(data_str.encode()).hexdigest()[:32] | |
def _from_geo_interface(data: SupportsGeoInterface | Any) -> dict[str, Any]:
    """
    Sanitize a ``__geo_interface__`` w/ pre-sanitize step for ``pandas`` if needed.

    Notes
    -----
    Split out to resolve typing issues related to:
    - Intersection types
    - ``typing.TypeGuard``
    - ``pd.DataFrame.__getattr__``
    """
    # pandas input is sanitized first; anything else is used as given.
    source = sanitize_pandas_dataframe(data) if _is_pandas_dataframe(data) else data
    return sanitize_geo_interface(source.__geo_interface__)
def _data_to_json_string(data: DataType) -> str:
    """
    Return a JSON string representation of the input data.

    Raises
    ------
    KeyError
        If a dict input has no ``"values"`` key.
    NotImplementedError
        If ``nw.from_native`` rejects the input with ``TypeError``.
    """
    check_data_type(data)

    if isinstance(data, SupportsGeoInterface):
        geo = _from_geo_interface(data)
        return json.dumps(geo)

    if _is_pandas_dataframe(data):
        sanitized = sanitize_pandas_dataframe(data)
        return sanitized.to_json(orient="records", double_precision=15)

    if isinstance(data, dict):
        if "values" in data:
            return json.dumps(data["values"], sort_keys=True)
        msg = "values expected in data dict, but not present."
        raise KeyError(msg)

    try:
        frame = nw.from_native(data, eager_only=True)
    except TypeError as exc:
        msg = "to_json only works with data expressed as a DataFrame or as a dict"
        raise NotImplementedError(msg) from exc

    rows = sanitize_narwhals_dataframe(frame).rows(named=True)
    return json.dumps(rows)
def _data_to_csv_string(data: DataType) -> str:
    """
    Return a CSV string representation of the input data.

    Raises
    ------
    NotImplementedError
        If ``data`` exposes ``__geo_interface__``, or if ``nw.from_native``
        rejects the input with ``TypeError``.
    KeyError
        If a dict input has no ``"values"`` key.
    ImportError
        If a dict is given and pandas cannot be imported.
    """
    check_data_type(data)
    if isinstance(data, SupportsGeoInterface):
        # Report the type of the offending object; ``type(SupportsGeoInterface)``
        # would name the protocol's metaclass instead.
        msg = (
            f"to_csv does not yet work with data that "
            f"is of type {type(data).__name__!r}.\n"
            f"See https://github.com/vega/altair/issues/3441"
        )
        raise NotImplementedError(msg)
    elif _is_pandas_dataframe(data):
        data = sanitize_pandas_dataframe(data)
        return data.to_csv(index=False)
    elif isinstance(data, dict):
        if "values" not in data:
            msg = "values expected in data dict, but not present"
            raise KeyError(msg)
        # pandas is only needed for dict input, so it is imported on demand.
        try:
            import pandas as pd
        except ImportError as exc:
            msg = "pandas is required to convert a dict to a CSV string"
            raise ImportError(msg) from exc
        return pd.DataFrame.from_dict(data["values"]).to_csv(index=False)
    try:
        data_nw = nw.from_native(data, eager_only=True)
    except TypeError as exc:
        msg = "to_csv only works with data expressed as a DataFrame or as a dict"
        raise NotImplementedError(msg) from exc
    return data_nw.write_csv()
def arrow_table_from_dfi_dataframe(dfi_df: DataFrameLike) -> pa.Table:
    """
    Convert a DataFrame Interchange Protocol compatible object to an Arrow Table.

    The object's own conversion methods are tried first; the pyarrow
    interchange ``from_dataframe`` function is the fallback.
    """
    import pyarrow as pa

    # An object's own conversion method is preferred over pyarrow's
    # from_dataframe: the object has more control over the conversion and may
    # have broader compatibility. This is the case for Polars, which supports
    # Date32 columns in direct conversion while pyarrow does not yet support
    # this type in from_dataframe.
    candidate_names = ("arrow", "to_arrow", "to_arrow_table", "to_pyarrow")
    for name in candidate_names:
        converter = getattr(dfi_df, name, None)
        if not callable(converter):
            continue
        table = converter()
        # Only accept the result when it really is an Arrow table.
        if isinstance(table, pa.Table):
            return table

    return import_pyarrow_interchange().from_dataframe(dfi_df)