Source code for betty.model

"""Provide Betty's data model API."""

from __future__ import annotations

import builtins
import functools
import weakref
from collections import defaultdict
from contextlib import contextmanager
from reprlib import recursive_repr
from typing import (
    TypeVar,
    Generic,
    Iterable,
    Any,
    overload,
    cast,
    Iterator,
    Callable,
    Self,
    TypeAlias,
    TYPE_CHECKING,
)
from uuid import uuid4

from betty.classtools import repr_instance
from betty.importlib import import_any, fully_qualified_type_name
from betty.json.linked_data import LinkedDataDumpable, add_json_ld
from betty.json.schema import ref_json_schema
from betty.locale import Str
from betty.serde.dump import DictDump, Dump
from betty.string import camel_case_to_kebab_case, upper_camel_case_to_lower_camel_case

if TYPE_CHECKING:
    from betty.app import App


# A generic, unbounded type variable.
T = TypeVar("T")


class GeneratedEntityId(str):
    """
    An entity ID that was generated, rather than taken from the original data set.

    Entities must have IDs for identification. However, not all entities can be
    provided with an ID that exists in the original data set (such as a
    third-party family tree loaded into Betty), so IDs can be generated.

    Because this is a ``str`` subclass, code can use ``isinstance()`` to tell a
    generated ID apart from one that came from the data set.
    """

    def __new__(cls, entity_id: str | None = None):
        # Any falsy ID (None or the empty string) is replaced by a random UUID4.
        if not entity_id:
            entity_id = str(uuid4())
        return super().__new__(cls, entity_id)
class Entity(LinkedDataDumpable):
    """
    The base class for all entities in the data model.

    Every entity has an ID. When no ID is given, a :py:class:`GeneratedEntityId`
    is used instead.
    """

    def __init__(
        self,
        id: str | None = None,
        *args: Any,
        **kwargs: Any,
    ):
        self._id = id if id is not None else GeneratedEntityId()
        super().__init__(*args, **kwargs)

    def __hash__(self) -> int:
        # Entities hash by their (type, ID) pair. This class does not define
        # __eq__ itself.
        return hash(self.ancestry_id)

    @classmethod
    def entity_type_label(cls) -> Str:
        """
        Get the human-readable label for this entity type.

        Concrete entity types must override this.
        """
        raise NotImplementedError(repr(cls))

    @classmethod
    def entity_type_label_plural(cls) -> Str:
        """
        Get the plural human-readable label for this entity type.

        Concrete entity types must override this.
        """
        raise NotImplementedError(repr(cls))

    @recursive_repr()
    def __repr__(self) -> str:
        return repr_instance(self, id=self._id)

    @property
    def type(self) -> builtins.type[Self]:
        """
        The entity's type, i.e. its class.
        """
        return self.__class__

    @property
    def id(self) -> str:
        """
        The entity's ID.
        """
        return self._id

    @property
    def ancestry_id(self) -> tuple[builtins.type[Self], str]:
        """
        The (entity type, entity ID) pair that identifies this entity.
        """
        return self.type, self.id

    @property
    def label(self) -> Str:
        """
        A human-readable label: the entity type label followed by the entity ID.
        """
        return Str._(
            "{entity_type} {entity_id}",
            entity_type=self.entity_type_label(),
            entity_id=self.id,
        )

    async def dump_linked_data(self, app: App) -> DictDump[Dump]:
        """
        Dump this entity to JSON-LD.

        Always sets ``$schema``. ``@id`` and ``id`` are only set when the ID came
        from the data set, not when it is a :py:class:`GeneratedEntityId`.
        """
        dump = await super().dump_linked_data(app)
        entity_type_name = get_entity_type_name(self.type)
        schema_definition_name = upper_camel_case_to_lower_camel_case(entity_type_name)
        dump["$schema"] = app.static_url_generator.generate(
            f"schema.json#/definitions/entity/{schema_definition_name}",
            absolute=True,
        )
        if isinstance(self.id, GeneratedEntityId):
            return dump
        entity_type_path = camel_case_to_kebab_case(entity_type_name)
        dump["@id"] = app.static_url_generator.generate(
            f"/{entity_type_path}/{self.id}/index.json",
            absolute=True,
        )
        dump["id"] = self.id
        return dump

    @classmethod
    async def linked_data_schema(cls, app: App) -> DictDump[Dump]:
        """
        Get the JSON Schema for this entity type's JSON-LD dumps.
        """
        schema = await super().linked_data_schema(app)
        schema["type"] = "object"
        # ref_json_schema() receives the schema as it is at this point, matching
        # the original evaluation order.
        schema_reference = ref_json_schema(schema)
        schema["properties"] = {
            "$schema": schema_reference,
            "id": {"type": "string"},
        }
        schema["required"] = ["$schema"]
        schema["additionalProperties"] = False
        add_json_ld(schema)
        return schema
# An entity's entity type paired with its entity ID, as returned by Entity.ancestry_id.
AncestryEntityId: TypeAlias = tuple[type[Entity], str]
class UserFacingEntity:
    """
    Mark an entity type as user-facing.

    This class adds no attributes or behavior. It is presumably detected with
    ``isinstance()``/``issubclass()`` elsewhere; no such check is visible in this module.
    """
class EntityTypeProvider:
    """
    Provide entity types.

    Subclasses must override :py:meth:`entity_types`.
    """

    async def entity_types(self) -> set[type[Entity]]:
        """
        Get the entity types this provider provides.
        """
        raise NotImplementedError(repr(self))
# Type variables shared by the generic entity, collection, and association APIs below.
# An entity type.
EntityT = TypeVar("EntityT", bound=Entity)
# A second, independent entity type.
EntityU = TypeVar("EntityU", bound=Entity)
# The element type of an entity collection.
TargetT = TypeVar("TargetT")
# The owner side of an entity type association.
OwnerT = TypeVar("OwnerT")
# The associate side of an entity type association.
AssociateT = TypeVar("AssociateT")
# A second, independent associate type.
AssociateU = TypeVar("AssociateU")
# The left and right associate types; not referenced in the visible part of this module.
LeftAssociateT = TypeVar("LeftAssociateT")
RightAssociateT = TypeVar("RightAssociateT")
def get_entity_type_name(entity_type_definition: type[Entity] | Entity) -> str:
    """
    Get the entity type name for an entity or entity type.

    Entity types from the ``betty.model.ancestry`` module (or its submodules) are
    named by their bare class name. All other entity types are named
    ``<module>.<class>``.

    :param entity_type_definition: An entity type, or an entity whose type to name.
    :return: The entity type name.
    """
    if isinstance(entity_type_definition, Entity):
        entity_type = entity_type_definition.type
    else:
        entity_type = entity_type_definition
    module_name = entity_type.__module__
    # Match on a module-path boundary. A bare startswith("betty.model.ancestry")
    # would also match unrelated modules such as "betty.model.ancestry_extra".
    # Their types would then get bare names that get_entity_type() looks up
    # inside betty.model.ancestry instead.
    if module_name == "betty.model.ancestry" or module_name.startswith(
        "betty.model.ancestry."
    ):
        return entity_type.__name__
    return f"{module_name}.{entity_type.__name__}"
def get_entity_type(entity_type_name: str) -> type[Entity]:
    """
    Get the entity type for an entity type name.

    The name is first imported exactly as given. If that fails, it is treated as a
    class name within ``betty.model.ancestry``, the inverse of the short names that
    get_entity_type_name() produces for that module.

    :param entity_type_name: A fully qualified or ``betty.model.ancestry``-relative name.
    :return: Whatever the successful import returns.
    :raises EntityTypeImportError: if neither import succeeds.
    """
    candidate_names = (entity_type_name, f"betty.model.ancestry.{entity_type_name}")
    for candidate_name in candidate_names:
        try:
            return import_any(candidate_name)  # type: ignore[no-any-return]
        except ImportError:
            continue
    raise EntityTypeImportError(entity_type_name) from None
class EntityTypeError(ValueError):
    """
    The base class for errors about entity types.
    """
class EntityTypeImportError(EntityTypeError, ImportError):
    """
    Raised when an alleged entity type cannot be imported.
    """

    def __init__(self, entity_type_name: str):
        message = f'Cannot find and import an entity with name "{entity_type_name}".'
        super().__init__(message)
class EntityTypeInvalidError(EntityTypeError, ImportError):
    """
    Raised for types that are not valid entity types.
    """

    def __init__(self, entity_type: type):
        invalid_type_name = f"{entity_type.__module__}.{entity_type.__name__}"
        entity_base_name = f"{Entity.__module__}.{Entity.__name__}"
        super().__init__(
            f"{invalid_type_name} is not an entity type class. Entity types must extend {entity_base_name} directly."
        )
class EntityCollection(Generic[TargetT]):
    """
    The abstract base class for collections of entities.

    Subclasses provide storage by implementing add(), remove(), clear(), and the
    container protocol. ``_on_add()`` and ``_on_remove()`` are hooks that
    subclasses may override; here they do nothing.
    """

    __slots__ = ()

    def __init__(self):
        super().__init__()

    def _on_add(self, *entities: TargetT & Entity) -> None:
        pass

    def _on_remove(self, *entities: TargetT & Entity) -> None:
        pass

    @property
    def view(self) -> list[TargetT & Entity]:
        """
        A new list of the collection's entities, in iteration order.
        """
        return [*self]

    def add(self, *entities: TargetT & Entity) -> None:
        """
        Add the given entities to the collection.
        """
        raise NotImplementedError(repr(self))

    def remove(self, *entities: TargetT & Entity) -> None:
        """
        Remove the given entities from the collection.
        """
        raise NotImplementedError(repr(self))

    def replace(self, *entities: TargetT & Entity) -> None:
        """
        Replace the collection's contents with the given entities.

        Entities not among ``entities`` are removed first; then ``entities`` are added.
        """
        obsolete_entities = [entity for entity in self if entity not in entities]
        self.remove(*obsolete_entities)
        self.add(*entities)

    def clear(self) -> None:
        """
        Remove all entities from the collection.
        """
        raise NotImplementedError(repr(self))

    def __iter__(self) -> Iterator[TargetT & Entity]:
        raise NotImplementedError(repr(self))

    def __len__(self) -> int:
        raise NotImplementedError(repr(self))

    @overload
    def __getitem__(self, index: int) -> TargetT & Entity:
        ...

    @overload
    def __getitem__(self, indices: slice) -> list[TargetT & Entity]:
        ...

    def __getitem__(
        self, key: int | slice
    ) -> TargetT & Entity | list[TargetT & Entity]:
        raise NotImplementedError(repr(self))

    def __delitem__(self, key: TargetT & Entity) -> None:
        raise NotImplementedError(repr(self))

    def __contains__(self, value: Any) -> bool:
        raise NotImplementedError(repr(self))

    def _known(self, *entities: TargetT & Entity) -> Iterable[TargetT & Entity]:
        # Yield each given entity that is already in this collection, once.
        yielded: list[TargetT & Entity] = []
        for candidate in entities:
            if candidate not in self or candidate in yielded:
                continue
            yield candidate
            yielded.append(candidate)

    def _unknown(self, *entities: TargetT & Entity) -> Iterable[TargetT & Entity]:
        # Yield each given entity that is not yet in this collection, once.
        yielded: list[TargetT & Entity] = []
        for candidate in entities:
            if candidate in self or candidate in yielded:
                continue
            yield candidate
            yielded.append(candidate)
EntityCollectionT = TypeVar("EntityCollectionT", bound=EntityCollection[EntityT])


class _EntityTypeAssociation(Generic[OwnerT, AssociateT]):
    """
    The base class for associations from an owner entity type to an associate type.

    The association is exposed on owners as ``owner_attr_name``. Its per-owner state
    is stored on each owner under ``_<owner_attr_name>``.
    """

    def __init__(
        self,
        owner_type: type[OwnerT],
        owner_attr_name: str,
        associate_type_name: str,
    ):
        self._owner_type = owner_type
        self._owner_attr_name = owner_attr_name
        self._owner_private_attr_name = f"_{owner_attr_name}"
        self._associate_type_name = associate_type_name
        # Resolved lazily by the associate_type property.
        self._associate_type: type[AssociateT] | None = None

    def __hash__(self) -> int:
        identity = (
            self._owner_type,
            self._owner_attr_name,
            self._associate_type_name,
        )
        return hash(identity)

    def __repr__(self) -> str:
        return repr_instance(
            self,
            owner_type=self._owner_type,
            owner_attr_name=self._owner_attr_name,
            associate_type_name=self._associate_type_name,
        )

    @property
    def owner_type(self) -> type[OwnerT]:
        """
        The type that owns this association.
        """
        return self._owner_type

    @property
    def owner_attr_name(self) -> str:
        """
        The name of the owner attribute that exposes this association.
        """
        return self._owner_attr_name

    @property
    def associate_type(self) -> type[AssociateT]:
        """
        The associate type, imported by name on first access and cached afterwards.
        """
        if self._associate_type is None:
            self._associate_type = import_any(self._associate_type_name)
        return self._associate_type

    def register(  # type: ignore[misc]
        self: ToAny[OwnerT, AssociateT],
    ) -> None:
        """
        Register this association, and make owners initialize it on construction.

        The owner type's ``__init__`` is replaced by a wrapper that calls
        :py:meth:`initialize` before running the original ``__init__``.
        """
        EntityTypeAssociationRegistry._register(self)
        association = self
        wrapped_init = self._owner_type.__init__

        @functools.wraps(wrapped_init)
        def _initializing_init(owner: OwnerT & Entity, *args: Any, **kwargs: Any) -> None:
            association.initialize(owner)
            wrapped_init(owner, *args, **kwargs)

        self._owner_type.__init__ = _initializing_init  # type: ignore[assignment, method-assign]

    def initialize(self, owner: OwnerT & Entity) -> None:
        """
        Set up this association's state on a new owner.
        """
        raise NotImplementedError(repr(self))

    def finalize(self, owner: OwnerT & Entity) -> None:
        """
        Clear this association on the owner, then remove its state from the owner.
        """
        self.delete(owner)
        delattr(owner, self._owner_private_attr_name)

    def delete(self, owner: OwnerT & Entity) -> None:
        """
        Remove all associates from the owner.
        """
        raise NotImplementedError(repr(self))

    def associate(self, owner: OwnerT & Entity, associate: AssociateT & Entity) -> None:
        """
        Associate an associate with the owner.
        """
        raise NotImplementedError(repr(self))

    def disassociate(
        self, owner: OwnerT & Entity, associate: AssociateT & Entity
    ) -> None:
        """
        Disassociate an associate from the owner.
        """
        raise NotImplementedError(repr(self))
class BidirectionalEntityTypeAssociation(
    Generic[OwnerT, AssociateT], _EntityTypeAssociation[OwnerT, AssociateT]
):
    """
    An association with a counterpart association on the associate type.

    The counterpart is exposed on associates as ``associate_attr_name`` and is
    returned by :py:meth:`inverse`.
    """

    def __init__(
        self,
        owner_type: type[OwnerT],
        owner_attr_name: str,
        associate_type_name: str,
        associate_attr_name: str,
    ):
        super().__init__(owner_type, owner_attr_name, associate_type_name)
        self._associate_attr_name = associate_attr_name

    def __hash__(self) -> int:
        identity = (
            self._owner_type,
            self._owner_attr_name,
            self._associate_type_name,
            self._associate_attr_name,
        )
        return hash(identity)

    def __repr__(self) -> str:
        return repr_instance(
            self,
            owner_type=self._owner_type,
            owner_attr_name=self._owner_attr_name,
            associate_type_name=self._associate_type_name,
            associate_attr_name=self._associate_attr_name,
        )

    @property
    def associate_attr_name(self) -> str:
        """
        The name of the associate attribute that exposes the counterpart association.
        """
        return self._associate_attr_name

    def inverse(self) -> BidirectionalEntityTypeAssociation[AssociateT, OwnerT]:
        """
        Look up the counterpart association in the association registry.
        """
        counterpart = EntityTypeAssociationRegistry.get_association(
            self.associate_type, self.associate_attr_name
        )
        assert isinstance(counterpart, BidirectionalEntityTypeAssociation)
        return counterpart
class ToOneEntityTypeAssociation(
    Generic[OwnerT, AssociateT], _EntityTypeAssociation[OwnerT, AssociateT]
):
    """
    An association from an owner to at most one associate.

    The associate, or ``None``, is stored directly in the owner's private attribute.
    """

    def register(self) -> None:
        """
        Register the association, and expose it on the owner type as a property.
        """
        super().register()
        accessor = property(fget=self.get, fset=self.set, fdel=self.delete)
        setattr(self.owner_type, self.owner_attr_name, accessor)

    def initialize(self, owner: OwnerT & Entity) -> None:
        """
        Start the owner without an associate.
        """
        setattr(owner, self._owner_private_attr_name, None)

    def get(self, owner: OwnerT & Entity) -> AssociateT & Entity | None:
        """
        Get the owner's associate, or ``None``.
        """
        return getattr(owner, self._owner_private_attr_name)  # type: ignore[no-any-return]

    def set(
        self, owner: OwnerT & Entity, associate: AssociateT & Entity | None
    ) -> None:
        """
        Set (or, with ``None``, unset) the owner's associate.
        """
        setattr(owner, self._owner_private_attr_name, associate)

    def delete(self, owner: OwnerT & Entity) -> None:
        """
        Unset the owner's associate.
        """
        self.set(owner, None)

    def associate(self, owner: OwnerT & Entity, associate: AssociateT & Entity) -> None:
        """
        Make ``associate`` the owner's associate.
        """
        self.set(owner, associate)

    def disassociate(
        self, owner: OwnerT & Entity, associate: AssociateT & Entity
    ) -> None:
        """
        Unset the owner's associate, but only if it currently is ``associate``.
        """
        if not associate == self.get(owner):
            return
        self.delete(owner)
class ToManyEntityTypeAssociation(
    Generic[OwnerT, AssociateT], _EntityTypeAssociation[OwnerT, AssociateT]
):
    """
    An association from an owner to any number of associates.

    The associates are kept in an :py:class:`EntityCollection` that the subclass's
    ``initialize()`` stores in the owner's private attribute.
    """

    def register(self) -> None:
        """
        Register the association, and expose it on the owner type as a property.
        """
        super().register()
        accessor = property(fget=self.get, fset=self.set, fdel=self.delete)
        setattr(self.owner_type, self.owner_attr_name, accessor)

    def get(self, owner: OwnerT & Entity) -> EntityCollection[AssociateT & Entity]:
        """
        Get the owner's associate collection.
        """
        collection = getattr(owner, self._owner_private_attr_name)
        return cast(EntityCollection["AssociateT & Entity"], collection)

    def set(
        self, owner: OwnerT & Entity, entities: Iterable[AssociateT & Entity]
    ) -> None:
        """
        Replace the owner's associates with ``entities``.
        """
        collection = self.get(owner)
        collection.replace(*entities)

    def delete(self, owner: OwnerT & Entity) -> None:
        """
        Remove all of the owner's associates.
        """
        collection = self.get(owner)
        collection.clear()

    def associate(self, owner: OwnerT & Entity, associate: AssociateT & Entity) -> None:
        """
        Add ``associate`` to the owner's associates.
        """
        collection = self.get(owner)
        collection.add(associate)

    def disassociate(
        self, owner: OwnerT & Entity, associate: AssociateT & Entity
    ) -> None:
        """
        Remove ``associate`` from the owner's associates.
        """
        collection = self.get(owner)
        collection.remove(associate)
class BidirectionalToOneEntityTypeAssociation(
    Generic[OwnerT, AssociateT],
    ToOneEntityTypeAssociation[OwnerT, AssociateT],
    BidirectionalEntityTypeAssociation[OwnerT, AssociateT],
):
    """
    A to-one association that keeps its inverse association in sync.
    """

    def set(
        self, owner: OwnerT & Entity, associate: AssociateT & Entity | None
    ) -> None:
        """
        Set the owner's associate, and update the inverse association.

        The previous associate (if any) is disassociated from the owner, and the
        new associate (if any) is associated with the owner. Setting the current
        associate again does nothing.
        """
        previous = self.get(owner)
        if previous == associate:
            return
        super().set(owner, associate)
        # After the early return, at least one of previous/associate is not None.
        inverse = self.inverse()
        if previous is not None:
            inverse.disassociate(previous, owner)
        if associate is not None:
            inverse.associate(associate, owner)
class BidirectionalToManyEntityTypeAssociation(
    Generic[OwnerT, AssociateT],
    ToManyEntityTypeAssociation[OwnerT, AssociateT],
    BidirectionalEntityTypeAssociation[OwnerT, AssociateT],
):
    """
    A to-many association that keeps its inverse association in sync.
    """

    def initialize(self, owner: OwnerT & Entity) -> None:
        """
        Give the owner an associate collection that updates the inverse association.
        """
        associates = _BidirectionalAssociateCollection(owner, self)
        setattr(owner, self._owner_private_attr_name, associates)
class ToOne(
    Generic[OwnerT, AssociateT], ToOneEntityTypeAssociation[OwnerT, AssociateT]
):
    """
    A unidirectional to-one association.
    """
class OneToOne(
    Generic[OwnerT, AssociateT],
    BidirectionalToOneEntityTypeAssociation[OwnerT, AssociateT],
):
    """
    A bidirectional one-to-one association.
    """
class ManyToOne(
    Generic[OwnerT, AssociateT],
    BidirectionalToOneEntityTypeAssociation[OwnerT, AssociateT],
):
    """
    A bidirectional many-to-one association.
    """
class ToMany(
    Generic[OwnerT, AssociateT], ToManyEntityTypeAssociation[OwnerT, AssociateT]
):
    """
    A unidirectional to-many association.
    """

    def initialize(self, owner: OwnerT & Entity) -> None:
        """
        Give the owner an empty collection of the associate type.
        """
        associates = SingleTypeEntityCollection[AssociateT](self.associate_type)
        setattr(owner, self._owner_private_attr_name, associates)
class OneToMany(
    Generic[OwnerT, AssociateT],
    BidirectionalToManyEntityTypeAssociation[OwnerT, AssociateT],
):
    """
    A bidirectional one-to-many association.
    """
class ManyToMany(
    Generic[OwnerT, AssociateT],
    BidirectionalToManyEntityTypeAssociation[OwnerT, AssociateT],
):
    """
    A bidirectional many-to-many association.
    """
# Any entity type association: either a to-one or a to-many association.
ToAny: TypeAlias = (
    ToOneEntityTypeAssociation[OwnerT, AssociateT]
    | ToManyEntityTypeAssociation[OwnerT, AssociateT]
)
def to_one(
    owner_attr_name: str,
    associate_type_name: str,
) -> Callable[[type[OwnerT]], type[OwnerT]]:
    """
    Add a unidirectional to-one association to an entity or entity mixin.

    :param owner_attr_name: The owner attribute that will expose the associate.
    :param associate_type_name: The importable name of the associate type.
    :return: A class decorator that registers the association and returns the class.
    """

    def _register_to_one(owner_type: type[OwnerT]) -> type[OwnerT]:
        association = ToOne(owner_type, owner_attr_name, associate_type_name)
        association.register()
        return owner_type

    return _register_to_one
def one_to_one(
    owner_attr_name: str,
    associate_type_name: str,
    associate_attr_name: str,
) -> Callable[[type[OwnerT]], type[OwnerT]]:
    """
    Add a bidirectional one-to-one association to an entity or entity mixin.

    :param owner_attr_name: The owner attribute that will expose the associate.
    :param associate_type_name: The importable name of the associate type.
    :param associate_attr_name: The associate attribute that exposes the inverse association.
    :return: A class decorator that registers the association and returns the class.
    """

    def _register_one_to_one(owner_type: type[OwnerT]) -> type[OwnerT]:
        association = OneToOne(
            owner_type,
            owner_attr_name,
            associate_type_name,
            associate_attr_name,
        )
        association.register()
        return owner_type

    return _register_one_to_one
def many_to_one(
    owner_attr_name: str,
    associate_type_name: str,
    associate_attr_name: str,
) -> Callable[[type[OwnerT]], type[OwnerT]]:
    """
    Add a bidirectional many-to-one association to an entity or entity mixin.

    :param owner_attr_name: The owner attribute that will expose the associate.
    :param associate_type_name: The importable name of the associate type.
    :param associate_attr_name: The associate attribute that exposes the inverse association.
    :return: A class decorator that registers the association and returns the class.
    """

    def _register_many_to_one(owner_type: type[OwnerT]) -> type[OwnerT]:
        association = ManyToOne(
            owner_type,
            owner_attr_name,
            associate_type_name,
            associate_attr_name,
        )
        association.register()
        return owner_type

    return _register_many_to_one
def to_many(
    owner_attr_name: str,
    associate_type_name: str,
) -> Callable[[type[OwnerT]], type[OwnerT]]:
    """
    Add a unidirectional to-many association to an entity or entity mixin.

    :param owner_attr_name: The owner attribute that will expose the associates.
    :param associate_type_name: The importable name of the associate type.
    :return: A class decorator that registers the association and returns the class.
    """

    def _register_to_many(owner_type: type[OwnerT]) -> type[OwnerT]:
        association = ToMany(owner_type, owner_attr_name, associate_type_name)
        association.register()
        return owner_type

    return _register_to_many
def one_to_many(
    owner_attr_name: str,
    associate_type_name: str,
    associate_attr_name: str,
) -> Callable[[type[OwnerT]], type[OwnerT]]:
    """
    Add a bidirectional one-to-many association to an entity or entity mixin.

    :param owner_attr_name: The owner attribute that will expose the associates.
    :param associate_type_name: The importable name of the associate type.
    :param associate_attr_name: The associate attribute that exposes the inverse association.
    :return: A class decorator that registers the association and returns the class.
    """

    def _register_one_to_many(owner_type: type[OwnerT]) -> type[OwnerT]:
        association = OneToMany(
            owner_type,
            owner_attr_name,
            associate_type_name,
            associate_attr_name,
        )
        association.register()
        return owner_type

    return _register_one_to_many
def many_to_many(
    owner_attr_name: str,
    associate_type_name: str,
    associate_attr_name: str,
) -> Callable[[type[OwnerT]], type[OwnerT]]:
    """
    Add a bidirectional many-to-many association to an entity or entity mixin.

    :param owner_attr_name: The owner attribute that will expose the associates.
    :param associate_type_name: The importable name of the associate type.
    :param associate_attr_name: The associate attribute that exposes the inverse association.
    :return: A class decorator that registers the association and returns the class.
    """

    def _register_many_to_many(owner_type: type[OwnerT]) -> type[OwnerT]:
        association = ManyToMany(
            owner_type,
            owner_attr_name,
            associate_type_name,
            associate_attr_name,
        )
        association.register()
        return owner_type

    return _register_many_to_many
def many_to_one_to_many(
    left_associate_type_name: str,
    left_associate_attr_name: str,
    left_owner_attr_name: str,
    right_owner_attr_name: str,
    right_associate_type_name: str,
    right_associate_attr_name: str,
) -> Callable[[type[OwnerT]], type[OwnerT]]:
    """
    Add a bidirectional many-to-one-to-many association to an entity or entity mixin.

    The decorated type becomes the middle of two many-to-one associations: one to
    the left associate type and one to the right associate type.

    :return: A class decorator that registers both associations and returns the class.
    """

    def _register_many_to_one_to_many(owner_type: type[OwnerT]) -> type[OwnerT]:
        sides = (
            (left_owner_attr_name, left_associate_type_name, left_associate_attr_name),
            (right_owner_attr_name, right_associate_type_name, right_associate_attr_name),
        )
        # Register the left side first, then the right side.
        for owner_attr_name, associate_type_name, associate_attr_name in sides:
            ManyToOne(
                owner_type,
                owner_attr_name,
                associate_type_name,
                associate_attr_name,
            ).register()
        return owner_type

    return _register_many_to_one_to_many
class EntityTypeAssociationRegistry:
    """
    Keep track of all registered entity type associations.
    """

    _associations = set[ToAny[Any, Any]]()

    @classmethod
    def get_all_associations(cls, owner: type | object) -> set[ToAny[Any, Any]]:
        """
        Get all associations owned by a type, or by an object's type, including
        those registered on its base classes.
        """
        owner_type = owner if isinstance(owner, type) else type(owner)
        lineage = owner_type.__mro__
        return {
            association
            for association in cls._associations
            if association.owner_type in lineage
        }

    @classmethod
    def get_association(
        cls, owner: type[OwnerT] | OwnerT & Entity, owner_attr_name: str
    ) -> ToAny[OwnerT, Any]:
        """
        Get the association exposed on an owner as ``owner_attr_name``.

        :raises ValueError: if no such association exists.
        """
        match = next(
            (
                candidate
                for candidate in cls.get_all_associations(owner)
                if candidate.owner_attr_name == owner_attr_name
            ),
            None,
        )
        if match is not None:
            return match
        owner_type = owner if isinstance(owner, type) else owner.__class__
        raise ValueError(
            f"No association exists for {fully_qualified_type_name(owner_type)}.{owner_attr_name}."
        )

    @classmethod
    def get_associates(
        cls, owner: EntityT, association: ToAny[EntityT, AssociateT]
    ) -> Iterable[AssociateT]:
        """
        Yield the owner's current associates for an association.

        A to-one association yields zero or one associate.
        """
        associates: AssociateT | None | Iterable[AssociateT] = getattr(
            owner, f"_{association.owner_attr_name}"
        )
        if not isinstance(association, ToOneEntityTypeAssociation):
            yield from cast(Iterable[AssociateT], associates)
        elif associates is not None:
            yield cast(AssociateT, associates)

    @classmethod
    def _register(cls, association: ToAny[Any, Any]) -> None:
        cls._associations.add(association)

    @classmethod
    def initialize(cls, *owners: Entity) -> None:
        """
        Initialize every association on each of the given owners.
        """
        for owner in owners:
            associations = cls.get_all_associations(owner)
            for association in associations:
                association.initialize(owner)

    @classmethod
    def finalize(cls, *owners: Entity) -> None:
        """
        Finalize every association on each of the given owners.
        """
        for owner in owners:
            associations = cls.get_all_associations(owner)
            for association in associations:
                association.finalize(owner)
class SingleTypeEntityCollection(Generic[TargetT], EntityCollection[TargetT]):
    """
    An ordered collection of unique entities of a single entity type.

    Entities can be retrieved by index, by slice, or by entity ID, and deleted by
    entity or by entity ID.
    """

    __slots__ = "_entities", "_target_type"

    def __init__(
        self,
        target_type: type[TargetT],
    ):
        super().__init__()
        self._entities: list[TargetT & Entity] = []
        self._target_type = target_type

    @recursive_repr()
    def __repr__(self) -> str:
        return repr_instance(self, target_type=self._target_type, length=len(self))

    def add(self, *entities: TargetT & Entity) -> None:
        """
        Add the entities that are not in the collection yet, and notify _on_add().
        """
        added_entities = [*self._unknown(*entities)]
        self._entities.extend(added_entities)
        if added_entities:
            self._on_add(*added_entities)

    def remove(self, *entities: TargetT & Entity) -> None:
        """
        Remove the entities that are in the collection, and notify _on_remove().
        """
        removed_entities = [*self._known(*entities)]
        for entity in removed_entities:
            self._entities.remove(entity)
        if removed_entities:
            self._on_remove(*removed_entities)

    def clear(self) -> None:
        """
        Remove all entities.
        """
        self.remove(*self)

    def __iter__(self) -> Iterator[TargetT & Entity]:
        return iter(self._entities)

    def __len__(self) -> int:
        return len(self._entities)

    @overload
    def __getitem__(self, index: int) -> TargetT & Entity:
        pass

    @overload
    def __getitem__(self, indices: slice) -> list[TargetT & Entity]:
        pass

    @overload
    def __getitem__(self, entity_id: str) -> TargetT & Entity:
        pass

    def __getitem__(
        self, key: int | slice | str
    ) -> TargetT & Entity | list[TargetT & Entity]:
        if isinstance(key, int):
            return self._getitem_by_index(key)
        if isinstance(key, slice):
            return self._getitem_by_indices(key)
        return self._getitem_by_entity_id(key)

    def _getitem_by_index(self, index: int) -> TargetT & Entity:
        return self._entities[index]

    def _getitem_by_indices(self, indices: slice) -> list[TargetT & Entity]:
        # Slicing a list already returns a new list, so there is no need to copy
        # the whole collection (through ``view``) first.
        return self._entities[indices]

    def _getitem_by_entity_id(self, entity_id: str) -> TargetT & Entity:
        for entity in self._entities:
            if entity_id == entity.id:
                return entity
        raise KeyError(
            f'Cannot find a {self._target_type} entity with ID "{entity_id}".'
        )

    def __delitem__(self, key: str | TargetT & Entity) -> None:
        if isinstance(key, self._target_type):
            return self._delitem_by_entity(cast("TargetT & Entity", key))
        if isinstance(key, str):
            return self._delitem_by_entity_id(key)
        raise TypeError(f"Cannot find entities by {repr(key)}.")

    def _delitem_by_entity(self, entity: TargetT & Entity) -> None:
        self.remove(entity)

    def _delitem_by_entity_id(self, entity_id: str) -> None:
        # Deleting an unknown ID is a no-op.
        for entity in self._entities:
            if entity_id == entity.id:
                self.remove(entity)
                return

    def __contains__(self, value: Any) -> bool:
        if isinstance(value, self._target_type):
            return self._contains_by_entity(cast("TargetT & Entity", value))
        if isinstance(value, str):
            return self._contains_by_entity_id(value)
        return False

    def _contains_by_entity(self, other_entity: TargetT & Entity) -> bool:
        # Containment by entity is identity-based.
        return any(entity is other_entity for entity in self._entities)

    def _contains_by_entity_id(self, entity_id: str) -> bool:
        return any(entity.id == entity_id for entity in self._entities)
# A type variable for single-type entity collections.
SingleTypeEntityCollectionT = TypeVar(
    "SingleTypeEntityCollectionT", bound=SingleTypeEntityCollection[AssociateT]
)
class MultipleTypesEntityCollection(Generic[TargetT], EntityCollection[TargetT]):
    """
    A collection of entities of any number of entity types.

    Entities are kept in one :py:class:`SingleTypeEntityCollection` per entity
    type. Index this collection with an entity type or entity type name to get
    that per-type collection, or with an int or slice to get entities.
    """

    __slots__ = "_collections"

    def __init__(self):
        super().__init__()
        self._collections: dict[type[Entity], SingleTypeEntityCollection[Entity]] = {}

    @recursive_repr()
    def __repr__(self) -> str:
        entity_type_names = [
            get_entity_type_name(entity_type) for entity_type in self._collections
        ]
        return repr_instance(
            self,
            entity_types=", ".join(entity_type_names),
            length=len(self),
        )

    def _get_collection(
        self, entity_type: type[EntityT]
    ) -> SingleTypeEntityCollection[EntityT]:
        # Per-type collections are created on first access.
        assert issubclass(entity_type, Entity), f"{entity_type} is not an entity type."
        try:
            collection = self._collections[entity_type]
        except KeyError:
            collection = self._collections[entity_type] = SingleTypeEntityCollection(
                entity_type
            )
        return cast(SingleTypeEntityCollection[EntityT], collection)

    @overload
    def __getitem__(self, index: int) -> TargetT & Entity:
        ...

    @overload
    def __getitem__(self, indices: slice) -> list[TargetT & Entity]:
        ...

    @overload
    def __getitem__(self, entity_type_name: str) -> SingleTypeEntityCollection[Entity]:
        ...

    @overload
    def __getitem__(
        self, entity_type: type[EntityT]
    ) -> SingleTypeEntityCollection[EntityT]:
        ...

    def __getitem__(
        self,
        key: int | slice | str | type[EntityT],
    ) -> (
        TargetT & Entity
        | SingleTypeEntityCollection[Entity]
        | SingleTypeEntityCollection[EntityT]
        | list[TargetT & Entity]
    ):
        if isinstance(key, int):
            return self._getitem_by_index(key)
        if isinstance(key, slice):
            return self._getitem_by_indices(key)
        if isinstance(key, str):
            return self._getitem_by_entity_type_name(key)
        return self._getitem_by_entity_type(key)

    def _getitem_by_entity_type(
        self, entity_type: type[EntityT]
    ) -> SingleTypeEntityCollection[EntityT]:
        return self._get_collection(entity_type)

    def _getitem_by_entity_type_name(
        self, entity_type_name: str
    ) -> SingleTypeEntityCollection[Entity]:
        entity_type = get_entity_type(entity_type_name)
        return self._get_collection(entity_type)

    def _getitem_by_index(self, index: int) -> TargetT & Entity:
        return self.view[index]

    def _getitem_by_indices(self, indices: slice) -> list[TargetT & Entity]:
        return self.view[indices]

    def __delitem__(self, key: str | type[TargetT & Entity] | TargetT & Entity) -> None:
        if isinstance(key, type):
            return self._delitem_by_type(key)
        if isinstance(key, Entity):
            return self._delitem_by_entity(key)  # type: ignore[arg-type]
        return self._delitem_by_entity_type_name(key)

    def _delitem_by_type(self, entity_type: type[TargetT & Entity]) -> None:
        collection = self._get_collection(entity_type)
        removed_entities = [*collection]
        collection.clear()
        if removed_entities:
            self._on_remove(*removed_entities)

    def _delitem_by_entity(self, entity: TargetT & Entity) -> None:
        self.remove(entity)

    def _delitem_by_entity_type_name(self, entity_type_name: str) -> None:
        entity_type = get_entity_type(entity_type_name)
        self._delitem_by_type(entity_type)  # type: ignore[arg-type]

    def __iter__(self) -> Iterator[TargetT & Entity]:
        for collection in self._collections.values():
            yield from cast("Iterable[TargetT & Entity]", collection)

    def __len__(self) -> int:
        return sum(len(collection) for collection in self._collections.values())

    def __contains__(self, value: Any) -> bool:
        return isinstance(value, Entity) and self._contains_by_entity(value)

    def _contains_by_entity(self, other_entity: Any) -> bool:
        # Containment is identity-based, across all per-type collections.
        return any(entity is other_entity for entity in self)

    def add(self, *entities: TargetT & Entity) -> None:
        """
        Add the entities that are not in the collection yet, each to its
        per-type collection, and notify _on_add().
        """
        added_entities = [*self._unknown(*entities)]
        for entity in added_entities:
            self[entity.type].add(entity)
        if added_entities:
            self._on_add(*added_entities)

    def remove(self, *entities: TargetT & Entity) -> None:
        """
        Remove the entities that are in the collection, and notify _on_remove().
        """
        removed_entities = [*self._known(*entities)]
        for entity in removed_entities:
            self[entity.type].remove(entity)
        if removed_entities:
            self._on_remove(*removed_entities)

    def clear(self) -> None:
        """
        Remove all entities of all types, and notify _on_remove() once.
        """
        removed_entities = tuple(self)
        for collection in self._collections.values():
            collection.clear()
        if removed_entities:
            self._on_remove(*removed_entities)
class _BidirectionalAssociateCollection(
    Generic[AssociateT, OwnerT], SingleTypeEntityCollection[AssociateT]
):
    """
    The associate collection of a bidirectional to-many association.

    Adding or removing associates also associates or disassociates the owner
    through the inverse association. The owner is only referenced weakly
    (presumably to avoid an owner <-> collection reference cycle).
    """

    __slots__ = "__owner", "_association"

    def __init__(
        self,
        owner: OwnerT & Entity,
        association: BidirectionalEntityTypeAssociation[OwnerT, AssociateT],
    ):
        super().__init__(association.associate_type)
        self._association = association
        self.__owner = weakref.ref(owner)

    @property
    def _owner(self) -> OwnerT & Entity:
        owner = self.__owner()
        if owner is not None:
            return owner
        raise RuntimeError(
            "This associate collection's owner no longer exists in memory."
        )

    def _on_add(self, *entities: AssociateT & Entity) -> None:
        super()._on_add(*entities)
        if not entities:
            return
        inverse = self._association.inverse()
        owner = self._owner
        for associate in entities:
            inverse.associate(associate, owner)

    def _on_remove(self, *entities: AssociateT & Entity) -> None:
        super()._on_remove(*entities)
        if not entities:
            return
        inverse = self._association.inverse()
        owner = self._owner
        for associate in entities:
            inverse.disassociate(associate, owner)
class AliasedEntity(Generic[EntityT]):
    """
    Wrap an entity under another ID.

    The alias reports the original entity's type and its own ID. Use
    :py:meth:`unalias` to get the original entity back.
    """

    def __init__(self, original_entity: EntityT, aliased_entity_id: str | None = None):
        self._entity = original_entity
        self._id = (
            aliased_entity_id
            if aliased_entity_id is not None
            else GeneratedEntityId()
        )

    def __repr__(self) -> str:
        return repr_instance(self, id=self.id)

    @property
    def type(self) -> builtins.type[Entity]:
        """
        The original entity's type.
        """
        return self._entity.type

    @property
    def id(self) -> str:
        """
        The alias ID.
        """
        return self._id

    def unalias(self) -> EntityT:
        """
        Get the original entity.
        """
        return self._entity
# Either an entity, or an AliasedEntity that wraps one; see unalias().
AliasableEntity: TypeAlias = EntityT | AliasedEntity[EntityT]
def unalias(entity: AliasableEntity[EntityT]) -> EntityT:
    """
    Unalias a potentially aliased entity.

    :return: The original entity if ``entity`` is an alias, otherwise ``entity`` itself.
    """
    return entity.unalias() if isinstance(entity, AliasedEntity) else entity
# Entities to build, keyed by entity type, then by (possibly alias) entity ID.
_EntityGraphBuilderEntities: TypeAlias = dict[
    type[Entity], dict[str, AliasableEntity[Entity]]
]
# Associations to build.
_EntityGraphBuilderAssociations: TypeAlias = dict[
    type[Entity],  # The owner entity type.
    dict[
        str,  # The owner attribute name.
        dict[str, list[AncestryEntityId]],  # The owner ID -> the associate ancestry IDs.
    ],
]


class _EntityGraphBuilder:
    """
    Collect entities and the associations between them, and build them into a graph.

    A builder can only be built once.
    """

    def __init__(self):
        self._entities: _EntityGraphBuilderEntities = defaultdict(dict)
        self._associations: _EntityGraphBuilderAssociations = defaultdict(
            lambda: defaultdict(lambda: defaultdict(list))
        )
        self._built = False

    def _assert_unbuilt(self) -> None:
        if self._built:
            raise RuntimeError("This entity graph has been built already.")

    def _iter(self) -> Iterator[AliasableEntity[Entity]]:
        for entities_by_id in self._entities.values():
            yield from entities_by_id.values()

    def _build_associations(self) -> None:
        for owner_type, owner_attrs in self._associations.items():
            for owner_attr_name, owner_associations in owner_attrs.items():
                association = EntityTypeAssociationRegistry.get_association(
                    owner_type, owner_attr_name
                )
                for owner_id, associate_ancestry_ids in owner_associations.items():
                    associates = [
                        unalias(self._entities[associate_type][associate_id])
                        for associate_type, associate_id in associate_ancestry_ids
                    ]
                    owner = unalias(self._entities[owner_type][owner_id])
                    if isinstance(association, ToOneEntityTypeAssociation):
                        # Lists are only created by recording an association, so
                        # there is always at least one associate. The first one
                        # recorded is used.
                        association.set(owner, associates[0])
                    else:
                        association.set(owner, associates)

    def build(self) -> Iterator[Entity]:
        """
        Build the entity graph.

        All entities are unaliased, their associations are (re)initialized, and the
        recorded associations are set. This happens when build() is called, not when
        the returned iterator is first consumed. As a result, building twice, or
        adding to a built graph, fails immediately.

        :return: An iterator over the built, unaliased entities.
        :raises RuntimeError: if this graph has been built already.
        """
        self._assert_unbuilt()
        self._built = True
        unaliased_entities = [unalias(entity) for entity in self._iter()]
        EntityTypeAssociationRegistry.initialize(*unaliased_entities)
        self._build_associations()
        return iter(unaliased_entities)
class EntityGraphBuilder(_EntityGraphBuilder):
    """
    Build an entity graph from entities and associations that are added one by one.

    Associations are recorded by (entity type, entity ID), so they can be added
    before or after the entities they refer to, as long as everything is added
    before the graph is built.
    """

    def add_entity(self, *entities: AliasableEntity[Entity]) -> None:
        """
        Add (possibly aliased) entities to the graph.

        An entity replaces any earlier entity with the same type and ID.

        :raises RuntimeError: if the graph has been built already.
        """
        self._assert_unbuilt()
        for entity in entities:
            entities_of_type = self._entities[entity.type]
            entities_of_type[entity.id] = entity

    def add_association(
        self,
        owner_type: type[Entity],
        owner_id: str,
        owner_attr_name: str,
        associate_type: type[Entity],
        associate_id: str,
    ) -> None:
        """
        Record an association between two entities, by their types and IDs.

        :raises RuntimeError: if the graph has been built already.
        """
        self._assert_unbuilt()
        associate_ids = self._associations[owner_type][owner_attr_name][owner_id]
        associate_ids.append((associate_type, associate_id))
@contextmanager
def record_added(
    entities: EntityCollection[EntityT],
) -> Iterator[MultipleTypesEntityCollection[EntityT]]:
    """
    Record all entities that are added to a collection.

    Yields an initially empty collection. When the ``with`` block exits without an
    exception, every entity that is in ``entities`` but was not in it on entry is
    added to the yielded collection.
    """
    entities_before = [*entities]
    added = MultipleTypesEntityCollection[EntityT]()
    yield added
    new_entities = [entity for entity in entities if entity not in entities_before]
    added.add(*new_entities)