Spaces:
Running
Running
File size: 4,618 Bytes
04ffec9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
from abc import ABC
from typing import Any, Dict, List, Literal, TypedDict, Union, cast
from pydantic import BaseModel, PrivateAttr
class BaseSerialized(TypedDict):
    """Fields common to every serialized payload shape in this module."""

    # Format marker; every producer in this module writes the value 1.
    lc: int
    # Identifier segments: module path parts plus a name, or a single secret id.
    id: List[str]
class SerializedConstructor(BaseSerialized):
    """Payload that describes an object through its constructor arguments."""

    # Discriminator tag identifying this payload shape.
    type: Literal["constructor"]
    # Keyword arguments recorded for the serialized object.
    kwargs: Dict[str, Any]
class SerializedSecret(BaseSerialized):
    """Placeholder payload that stands in for a secret value."""

    # Discriminator tag identifying this payload shape.
    type: Literal["secret"]
class SerializedNotImplemented(BaseSerialized):
    """Payload emitted for objects that do not support serialization."""

    # Discriminator tag identifying this payload shape.
    type: Literal["not_implemented"]
class Serializable(BaseModel, ABC):
    """Serializable base class."""

    # NOTE(review): the class docstring above is intentionally left unchanged;
    # pydantic may expose a model's docstring in generated schemas — confirm
    # before rewording it.

    @property
    def lc_serializable(self) -> bool:
        """Tell whether ``to_json`` may emit a constructor payload.

        The base implementation answers ``False``, in which case ``to_json``
        falls back to the "not_implemented" payload.
        """
        return False

    @property
    def lc_namespace(self) -> List[str]:
        """Return the dotted module path of the concrete class, split on ".".

        eg. ["langchain", "llms", "openai"]
        """
        module_path = self.__class__.__module__
        return module_path.split(".")

    @property
    def lc_secrets(self) -> Dict[str, str]:
        """Map constructor argument names to the ids of the secrets they hold.

        eg. {"openai_api_key": "OPENAI_API_KEY"}
        """
        return {}

    @property
    def lc_attributes(self) -> Dict:
        """Return extra entries to merge into the serialized kwargs.

        The result is a mapping (not a list); ``to_json`` merges it into the
        kwargs as-is. These attributes must be accepted by the constructor.
        """
        return {}

    class Config:
        extra = "ignore"

    # Keyword arguments the instance was constructed with (set in __init__).
    _lc_kwargs = PrivateAttr(default_factory=dict)

    def __init__(self, **kwargs: Any) -> None:
        """Initialise the model, then remember the raw constructor kwargs."""
        super().__init__(**kwargs)
        self._lc_kwargs = kwargs

    def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
        """Serialize this object to a constructor payload.

        Returns:
            A "not_implemented" payload when ``lc_serializable`` is falsy;
            otherwise a "constructor" payload whose kwargs have every value
            named in the collected secrets map replaced by a secret marker.
        """
        if not self.lc_serializable:
            return self.to_json_not_implemented()

        secret_ids: Dict[str, str] = {}

        # Start from the constructor kwargs, preferring the current attribute
        # value when an attribute of the same name exists, and dropping any
        # field flagged in the exclude map.
        kwargs_out: Dict[str, Any] = {}
        for name, initial in self._lc_kwargs.items():
            if (self.__exclude_fields__ or {}).get(name, False):  # type: ignore
                continue
            kwargs_out[name] = getattr(self, name, initial)

        # Collect lc_secrets / lc_attributes as seen from the instance itself
        # (the None entry) and then from each class in the MRO, stopping once
        # Serializable is reached.
        for klass in [None, *self.__class__.mro()]:
            if klass is Serializable:
                break
            if klass is None:
                bound = cast(Serializable, self)
            else:
                # super(klass, self) resolves the properties starting after
                # klass in the MRO.
                bound = cast(Serializable, super(klass, self))
            secret_ids.update(bound.lc_secrets)
            kwargs_out.update(bound.lc_attributes)

        # Include every secret, even when it was not passed as a kwarg, since
        # it may have been supplied through an environment variable instead.
        for field_name in secret_ids:
            current_value = getattr(self, field_name, None) or kwargs_out.get(
                field_name
            )
            if current_value is not None:
                kwargs_out[field_name] = current_value

        identifier = [*self.lc_namespace, self.__class__.__name__]
        if secret_ids:
            kwargs_out = _replace_secrets(kwargs_out, secret_ids)
        return {
            "lc": 1,
            "type": "constructor",
            "id": identifier,
            "kwargs": kwargs_out,
        }

    def to_json_not_implemented(self) -> SerializedNotImplemented:
        """Return the "not_implemented" payload for this instance."""
        # Inside the method body this name resolves to the module-level helper.
        return to_json_not_implemented(self)
def _replace_secrets(
    root: Dict[Any, Any], secrets_map: Dict[str, str]
) -> Dict[Any, Any]:
    """Return a copy of ``root`` with secret values swapped for markers.

    Args:
        root: Mapping of kwargs to scan. It is not mutated; the top level and
            every nested dict walked along a secret path are shallow-copied.
        secrets_map: Maps a dotted key path (e.g. ``"a.b.key"``) to the secret
            id that should stand in for the value found at that path.

    Returns:
        The copied mapping, where each value located at a path from
        ``secrets_map`` is replaced by
        ``{"lc": 1, "type": "secret", "id": [secret_id]}``.
        Paths that cannot be fully resolved are skipped.
    """
    result = root.copy()
    for path, secret_id in secrets_map.items():
        *parents, leaf = path.split(".")
        current = result
        for part in parents:
            child = current.get(part)
            if not isinstance(child, dict):
                # Missing or non-dict intermediate: the path does not exist.
                break
            # Copy before descending so the caller's nested dicts stay intact.
            current[part] = child.copy()
            current = current[part]
        else:
            # Only reached when every parent segment resolved. Without this
            # guard, a broken path would replace a same-named key at a
            # shallower level.
            if leaf in current:
                current[leaf] = {
                    "lc": 1,
                    "type": "secret",
                    "id": [secret_id],
                }
    return result
def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
"""Serialize a "not implemented" object.
Args:
obj: object to serialize
Returns:
SerializedNotImplemented
"""
_id: List[str] = []
try:
if hasattr(obj, "__name__"):
_id = [*obj.__module__.split("."), obj.__name__]
elif hasattr(obj, "__class__"):
_id = [*obj.__class__.__module__.split("."), obj.__class__.__name__]
except Exception:
pass
return {
"lc": 1,
"type": "not_implemented",
"id": _id,
}
|