# Data-access objects (DAOs) for domain objects.
#
# Provides the DomainDAO protocol plus two implementations: a purely
# in-memory store (InMemDomainDAO) and a store that mirrors itself to a JSON
# file under MAP_BIN (CacheDomainDAO).
from __future__ import annotations

import json
import os
import uuid
from typing import Dict, List, Protocol, Set, Type, TypeVar

from domain.domain_protocol import DomainProtocol

DomainT = TypeVar('DomainT', bound=DomainProtocol)

# "<parent of the directory containing this file>/.bin/maps"; relative
# CacheDomainDAO save paths are resolved against this directory.
MAP_BIN = os.path.join(os.path.dirname(os.path.abspath(os.path.dirname(__file__))), ".bin", "maps")


def _index_new_objs(existing: Dict[str, DomainT], domain_objs: List[DomainT]) -> Dict[str, DomainT]:
    """Return ``domain_objs`` keyed by ``.id`` after checking for id clashes.

    Nothing is mutated; the caller merges the returned mapping into its store.

    Raises:
        ValueError: If two incoming objects share an id, or if any incoming
            id is already a key of ``existing``.
    """
    new_id_to_domain_obj = {domain_obj.id: domain_obj for domain_obj in domain_objs}
    # A dict collapses repeated keys, so a shorter dict means repeated ids.
    if len(new_id_to_domain_obj) != len(domain_objs):
        raise ValueError("Duplicate IDs exist within incoming domain_objs")
    if duplicate_ids := new_id_to_domain_obj.keys() & existing.keys():
        raise ValueError(f"Duplicate ids exist in DB: {duplicate_ids}")
    return new_id_to_domain_obj


def _lookup(store: Dict[str, DomainT], domain_id: str) -> DomainT:
    """Return ``store[domain_id]``.

    Uses a key lookup rather than a truthiness test so that a stored object
    which happens to evaluate falsy is still returned.

    Raises:
        ValueError: If ``domain_id`` is not a key of ``store``.
    """
    try:
        return store[domain_id]
    except KeyError:
        raise ValueError(f"Domain obj with id {domain_id} not found") from None


class DomainDAO(Protocol[DomainT]):
    """Structural interface shared by the domain-object stores in this module."""

    def insert(self, domain_objs: List[DomainT]):
        """Add ``domain_objs`` to the store."""
        ...

    def read_by_id(self, domain_id: str) -> DomainT:
        """Return the stored object whose id is ``domain_id``."""
        ...

    def read_all(self) -> Set[DomainT]:
        """Return every stored object."""
        ...


class InMemDomainDAO(DomainDAO[DomainT]):
    """Dictionary-backed store with optional line-per-object file persistence."""

    # Maps each object's ``.id`` to the object itself.
    _id_to_domain_obj: Dict[str, DomainT]

    def __init__(self):
        """Create an empty store."""
        self._id_to_domain_obj = {}

    def insert(self, domain_objs: List[DomainT]):
        """Add ``domain_objs`` keyed by their ``.id``.

        The insert is all-or-nothing: validation happens before any object is
        stored.

        Raises:
            ValueError: If ids repeat within ``domain_objs`` or clash with an
                id already in the store.
        """
        self._id_to_domain_obj.update(_index_new_objs(self._id_to_domain_obj, domain_objs))

    def read_by_id(self, domain_id: str) -> DomainT:
        """Return the object stored under ``domain_id``.

        Raises:
            ValueError: If no object with that id is stored.
        """
        return _lookup(self._id_to_domain_obj, domain_id)

    def read_all(self) -> Set[DomainT]:
        """Return all stored objects as a set (objects must be hashable)."""
        return set(self._id_to_domain_obj.values())

    @classmethod
    def load_from_file(cls, file_path: str, domain_cls: Type[DomainT]) -> InMemDomainDAO[DomainT]:
        """Build a store from a file holding one serialized object per line.

        Each non-blank line is passed to ``domain_cls.from_json``. Blank lines
        are skipped so that a file written for an empty store, or one carrying
        stray trailing newlines, still loads.

        Raises:
            ValueError: If ``file_path`` is not an existing file, or if the
                loaded objects contain duplicate ids.
        """
        if not os.path.isfile(file_path):
            raise ValueError(f"File not found: {file_path}")
        with open(file_path, 'r') as f:
            domain_objs = [domain_cls.from_json(line) for line in f if line.strip()]
        dao = cls()
        dao.insert(domain_objs)
        return dao

    def save_to_file(self, file_path: str):
        """Write every stored object's ``to_json()`` to ``file_path``, one per line.

        Missing parent directories are created. An empty store produces an
        empty file.
        NOTE(review): the line-based format assumes ``to_json()`` returns a
        single-line string -- confirm against the domain classes.
        """
        parent_dir = os.path.dirname(file_path)
        # dirname is "" for a bare filename, and os.makedirs("") raises.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        # Serialize before opening so a failing to_json() cannot truncate
        # an existing file.
        domain_jsons = [domain_obj.to_json() for domain_obj in self._id_to_domain_obj.values()]
        with open(file_path, 'w') as f:
            if domain_jsons:
                f.write('\n'.join(domain_jsons) + '\n')


class CacheDomainDAO(DomainDAO[DomainT]):
    """Store that mirrors its contents to a JSON file after every mutation.

    The file is a JSON object mapping id to the value returned by the
    object's ``to_json()``. It is read once at construction and rewritten by
    ``insert``, ``set`` and on leaving a ``with`` block.
    """

    # Maps id (object ``.id`` for ``insert``, random UUID string for ``set``)
    # to the stored object.
    _id_to_domain_obj: Dict[str, DomainT]
    # Location of the backing JSON file.
    _save_path: str

    def __init__(self, save_path: str, domain_cls: Type[DomainT]):
        """Open the cache at ``save_path`` (joined onto ``MAP_BIN``) and load it.

        Args:
            save_path: File path; resolved with ``os.path.join(MAP_BIN, save_path)``.
            domain_cls: Class whose ``from_json`` rebuilds each cached entry.
        """
        self._id_to_domain_obj = {}
        self._save_path = os.path.join(MAP_BIN, save_path)
        self._load_cache(domain_cls)

    def __enter__(self):
        """Return this DAO for use as a context manager."""
        return self

    def __call__(self, element: DomainT) -> DomainT:
        """Insert ``element`` under its own ``.id`` and return it unchanged."""
        self.insert([element])
        return element

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Persist the cache on leaving the ``with`` block; exceptions propagate."""
        self._save_cache()

    def set(self, element: DomainT) -> uuid.UUID:
        """Store ``element`` under a freshly generated UUID and persist.

        Unlike ``insert``, the key is a new ``uuid.uuid4()`` (as a string),
        not ``element.id``.

        Returns:
            The generated UUID.
        """
        new_id = uuid.uuid4()
        self._id_to_domain_obj[str(new_id)] = element
        self._save_cache()
        return new_id

    def _save_cache(self):
        """Merge the in-memory objects into the backing JSON file.

        Entries already in the file are kept unless overwritten by an
        in-memory entry with the same id.
        """
        # Create the file's own directory (not just MAP_BIN) so save paths
        # that contain sub-directories work.
        os.makedirs(os.path.dirname(self._save_path), exist_ok=True)
        cache = {}
        if os.path.isfile(self._save_path):
            with open(self._save_path, 'r') as f:
                cache = json.load(f)
        cache.update({
            obj_id: domain_obj.to_json()
            for obj_id, domain_obj in self._id_to_domain_obj.items()
        })
        with open(self._save_path, 'w') as f:
            json.dump(cache, f, indent=4)

    def _load_cache(self, domain_cls: Type[DomainT]):
        """Populate the store from the backing file; do nothing if it is absent."""
        if not os.path.isfile(self._save_path):
            return
        with open(self._save_path, 'r') as f:
            domain_json_map = json.load(f)
        for obj_id, domain_json in domain_json_map.items():
            self._id_to_domain_obj[obj_id] = domain_cls.from_json(domain_json)

    def read_by_id(self, domain_id: str) -> DomainT:
        """Return the object stored under ``domain_id``.

        Raises:
            ValueError: If no object with that id is stored.
        """
        return _lookup(self._id_to_domain_obj, domain_id)

    def read_all(self) -> Set[DomainT]:
        """Return all stored objects as a set (objects must be hashable)."""
        return set(self._id_to_domain_obj.values())

    def insert(self, domain_objs: List[DomainT]):
        """Add ``domain_objs`` keyed by their ``.id`` and persist the cache.

        Validation happens before any object is stored or written.

        Raises:
            ValueError: If ids repeat within ``domain_objs`` or clash with an
                id already in the store.
        """
        self._id_to_domain_obj.update(_index_new_objs(self._id_to_domain_obj, domain_objs))
        self._save_cache()