# NOTE(review): the original first lines ("Spaces:", "Sleeping", "Sleeping")
# appear to be extraction/clipboard artifacts, not code — preserved here as a
# comment so the module remains importable.
from __future__ import annotations | |
from typing import List, Type, Protocol, TypeVar, Dict, Set | |
import os | |
import json | |
import uuid | |
from domain.domain_protocol import DomainProtocol | |
DomainT = TypeVar('DomainT', bound=DomainProtocol) | |
MAP_BIN = os.path.join(os.path.dirname(os.path.abspath(os.path.dirname(__file__))), ".bin", "maps") | |
class DomainDAO(Protocol[DomainT]): | |
def insert(self, domain_objs: List[DomainT]): | |
... | |
def read_by_id(self, domain_id: str) -> DomainT: | |
... | |
def read_all(self) -> Set[DomainT]: | |
... | |
class InMemDomainDAO(DomainDAO[DomainT]): | |
_id_to_domain_obj: Dict[str, DomainT] | |
def __init__(self): | |
self._id_to_domain_obj = {} | |
def insert(self, domain_objs: List[DomainT]): | |
new_id_to_domain_obj = {domain_obj.id: domain_obj for domain_obj in domain_objs} | |
if len(new_id_to_domain_obj) != len(domain_objs): | |
raise ValueError("Duplicate IDs exist within incoming domain_objs") | |
if duplicate_ids := set(new_id_to_domain_obj.keys()) & set(self._id_to_domain_obj.keys()): | |
raise ValueError(f"Duplicate ids exist in DB: {duplicate_ids}") | |
self._id_to_domain_obj.update(new_id_to_domain_obj) | |
def read_by_id(self, domain_id: str) -> DomainT: | |
if domain_obj := self._id_to_domain_obj.get(domain_id): | |
return domain_obj | |
raise ValueError(f"Domain obj with id {domain_id} not found") | |
def read_all(self) -> Set[DomainT]: | |
return set(self._id_to_domain_obj.values()) | |
def load_from_file(cls, file_path: str, domain_cls: Type[DomainT]) -> InMemDomainDAO[DomainT]: | |
if not os.path.isfile(file_path): | |
raise ValueError(f"File not found: {file_path}") | |
with open(file_path, 'r') as f: | |
domain_objs = [domain_cls.from_json(line) for line in f] | |
dao = cls() | |
dao.insert(domain_objs) | |
return dao | |
def save_to_file(self, file_path: str): | |
os.makedirs(os.path.dirname(file_path), exist_ok=True) | |
domain_jsons = [domain_obj.to_json() for domain_obj in self._id_to_domain_obj.values()] | |
with open(file_path, 'w') as f: | |
f.write('\n'.join(domain_jsons) + '\n') | |
class CacheDomainDAO(DomainDAO[DomainT]): | |
_id_to_domain_obj: Dict[str, DomainT] | |
_save_path: str | |
def __init__(self, save_path: str, domain_cls: Type[DomainT]): | |
self._id_to_domain_obj = {} | |
self._save_path = os.path.join(MAP_BIN, save_path) | |
self._load_cache(domain_cls) | |
def __enter__(self): | |
return self | |
def __call__(self, element: DomainT) -> DomainT: | |
self.insert([element]) | |
return element | |
def __exit__(self, exc_type, exc_val, exc_tb): | |
self._save_cache() | |
def set(self, element: DomainT) -> uuid.UUID: | |
id = uuid.uuid4() | |
self._id_to_domain_obj[str(id)] = element | |
self._save_cache() | |
return id | |
def _save_cache(self): | |
os.makedirs(MAP_BIN, exist_ok=True) | |
cache = {} | |
if os.path.isfile(self._save_path): | |
with open(self._save_path, 'r') as f: | |
cache = json.load(f) | |
domain_json_map = { | |
id: domain_obj.to_json() | |
for id, domain_obj in self._id_to_domain_obj.items() | |
} | |
cache.update(domain_json_map) | |
with open(self._save_path, 'w') as f: | |
json.dump(cache, f, indent=4) | |
def _load_cache(self, domain_cls: Type[DomainT]): | |
if not os.path.isfile(self._save_path): | |
return | |
with open(self._save_path, 'r') as f: | |
domain_json_map = json.load(f) | |
for id, domain_json in domain_json_map.items(): | |
self._id_to_domain_obj[id] = domain_cls.from_json(domain_json) | |
def read_by_id(self, domain_id: str) -> DomainT: | |
if domain_obj := self._id_to_domain_obj.get(domain_id): | |
return domain_obj | |
raise ValueError(f"Domain obj with id {domain_id} not found") | |
def read_all(self) -> Set[DomainT]: | |
return set(self._id_to_domain_obj.values()) | |
def insert(self, domain_objs: List[DomainT]): | |
new_id_to_domain_obj = {domain_obj.id: domain_obj for domain_obj in domain_objs} | |
if len(new_id_to_domain_obj) != len(domain_objs): | |
raise ValueError("Duplicate IDs exist within incoming domain_objs") | |
if duplicate_ids := set(new_id_to_domain_obj.keys()) & set(self._id_to_domain_obj.keys()): | |
raise ValueError(f"Duplicate ids exist in DB: {duplicate_ids}") | |
self._id_to_domain_obj.update(new_id_to_domain_obj) | |
self._save_cache() | |