"""Provide Betty's data model API."""
from __future__ import annotations
import builtins
import functools
import weakref
from collections import defaultdict
from contextlib import contextmanager
from reprlib import recursive_repr
from typing import (
TypeVar,
Generic,
Iterable,
Any,
overload,
cast,
Iterator,
Callable,
Self,
TypeAlias,
TYPE_CHECKING,
)
from uuid import uuid4
from betty.classtools import repr_instance
from betty.importlib import import_any, fully_qualified_type_name
from betty.json.linked_data import LinkedDataDumpable, add_json_ld
from betty.json.schema import ref_json_schema
from betty.locale import Str
from betty.serde.dump import DictDump, Dump
from betty.string import camel_case_to_kebab_case, upper_camel_case_to_lower_camel_case
if TYPE_CHECKING:
from betty.app import App
# A generic type variable without a bound.
T = TypeVar("T")
[docs]
class GeneratedEntityId(str):
    """
    Generate a unique entity ID.

    Entities must have IDs for identification. However, not all entities can be provided with an ID that exists in the
    original data set (such as a third-party family tree loaded into Betty), so IDs can be generated.

    Instances are plain strings, so a generated ID can be told apart from any other ID by its type.
    """

    def __new__(cls, entity_id: str | None = None):
        # Any falsy value (``None`` as well as the empty string) is replaced by a fresh, random UUID4 string.
        value = entity_id if entity_id else str(uuid4())
        return super().__new__(cls, value)
[docs]
class Entity(LinkedDataDumpable):
    """
    The base class for entities.

    Every entity has an ID, which is either given explicitly or generated upon instantiation. An entity's type
    and its ID together form its ancestry ID, which is also what the entity's hash is computed from.
    """

    def __init__(
        self,
        id: str | None = None,
        *args: Any,
        **kwargs: Any,
    ):
        # Without an explicit ID, fall back to a generated one. Its type is what dump_linked_data() uses
        # later on to tell generated IDs from explicit ones.
        self._id = id if id is not None else GeneratedEntityId()
        super().__init__(*args, **kwargs)

    def __hash__(self) -> int:
        ancestry_id = self.ancestry_id
        return hash(ancestry_id)

    @classmethod
    def entity_type_label(cls) -> Str:
        """
        Get the human-readable label for this entity type.

        :raises NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError(repr(cls))

    @classmethod
    def entity_type_label_plural(cls) -> Str:
        """
        Get the human-readable plural label for this entity type.

        :raises NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError(repr(cls))

    @recursive_repr()
    def __repr__(self) -> str:
        return repr_instance(self, id=self._id)

    @property
    def type(self) -> builtins.type[Self]:
        """
        The entity's type, which is its class.
        """
        return self.__class__

    @property
    def id(self) -> str:
        """
        The entity's ID.
        """
        return self._id

    @property
    def ancestry_id(self) -> tuple[builtins.type[Self], str]:
        """
        The pair of the entity's type and ID.
        """
        return self.type, self.id

    @property
    def label(self) -> Str:
        """
        The entity's label, built from its type label and its ID.
        """
        type_label = self.entity_type_label()
        return Str._(
            "{entity_type} {entity_id}",
            entity_type=type_label,
            entity_id=self.id,
        )

    async def dump_linked_data(self, app: App) -> DictDump[Dump]:
        """
        Dump this entity to linked data.

        The dump always contains a ``$schema`` URL. The ``@id`` URL and the ``id`` are only added for
        entities whose ID was not generated.
        """
        dump = await super().dump_linked_data(app)
        entity_type_name = get_entity_type_name(self.type)
        schema_path = f"schema.json#/definitions/entity/{upper_camel_case_to_lower_camel_case(entity_type_name)}"
        dump["$schema"] = app.static_url_generator.generate(
            schema_path,
            absolute=True,
        )
        # Generated IDs are not exposed in the dump.
        if isinstance(self.id, GeneratedEntityId):
            return dump
        entity_path = f"/{camel_case_to_kebab_case(entity_type_name)}/{self.id}/index.json"
        dump["@id"] = app.static_url_generator.generate(
            entity_path,
            absolute=True,
        )
        dump["id"] = self.id
        return dump

    @classmethod
    async def linked_data_schema(cls, app: App) -> DictDump[Dump]:
        """
        Get the JSON schema for this entity type's linked data dumps.

        The schema describes an object that requires ``$schema``, optionally has a string ``id``, and
        allows no additional properties.
        """
        schema = await super().linked_data_schema(app)
        schema["type"] = "object"
        properties = {
            "$schema": ref_json_schema(schema),
            "id": {
                "type": "string",
            },
        }
        schema["properties"] = properties
        schema["required"] = [
            "$schema",
        ]
        schema["additionalProperties"] = False
        add_json_ld(schema)
        return schema
# The pair of an entity type and an entity ID, which has the same shape as what ``Entity.ancestry_id`` returns.
AncestryEntityId: TypeAlias = tuple[type[Entity], str]
[docs]
class UserFacingEntity:
    """
    A marker class without any behavior of its own.

    NOTE(review): presumably mixed into entity types that are presented to end users; confirm against callers.
    """
[docs]
class EntityTypeProvider:
    """
    Provide entity types.

    This base class does not provide any types itself: implementations must override
    :py:meth:`entity_types`.
    """

    async def entity_types(self) -> set[type[Entity]]:
        """
        Get the entity types this provider provides.

        :raises NotImplementedError: always, in this base implementation.
        """
        description = repr(self)
        raise NotImplementedError(description)
# Type variables that are bound to Entity.
EntityT = TypeVar("EntityT", bound=Entity)
EntityU = TypeVar("EntityU", bound=Entity)
# Unbounded type variables. Throughout this module's annotations they are combined with Entity using the
# ``SomeT & Entity`` notation. Annotations are not evaluated at runtime here, because of
# ``from __future__ import annotations``.
TargetT = TypeVar("TargetT")
OwnerT = TypeVar("OwnerT")
AssociateT = TypeVar("AssociateT")
AssociateU = TypeVar("AssociateU")
LeftAssociateT = TypeVar("LeftAssociateT")
RightAssociateT = TypeVar("RightAssociateT")
[docs]
def get_entity_type_name(entity_type_definition: type[Entity] | Entity) -> str:
    """
    Get the entity type name for an entity or entity type.

    Types whose module name starts with ``betty.model.ancestry`` are named after their class only. All other
    types are named ``<module name>.<class name>``.
    """
    entity_type = (
        entity_type_definition.type
        if isinstance(entity_type_definition, Entity)
        else entity_type_definition
    )
    module_name = entity_type.__module__
    if not module_name.startswith("betty.model.ancestry"):
        return f"{module_name}.{entity_type.__name__}"
    return entity_type.__name__
[docs]
def get_entity_type(entity_type_name: str) -> type[Entity]:
    """
    Get the entity type for an entity type name.

    The name is first imported as is. If that fails, it is imported again with a ``betty.model.ancestry.``
    prefix.

    :raises EntityTypeImportError: if neither import succeeds.
    """
    try:
        entity_type = import_any(entity_type_name)
    except ImportError:
        prefixed_entity_type_name = f"betty.model.ancestry.{entity_type_name}"
        try:
            entity_type = import_any(prefixed_entity_type_name)
        except ImportError:
            # Hide the underlying import errors from the traceback.
            raise EntityTypeImportError(entity_type_name) from None
    return entity_type  # type: ignore[no-any-return]
[docs]
class EntityTypeError(ValueError):
    """
    The base class for errors about entity types.
    """
[docs]
class EntityTypeImportError(EntityTypeError, ImportError):
    """
    Raised when an alleged entity type cannot be imported.
    """

    def __init__(self, entity_type_name: str):
        message = f'Cannot find and import an entity with name "{entity_type_name}".'
        super().__init__(message)
[docs]
class EntityTypeInvalidError(EntityTypeError, ImportError):
    """
    Raised for types that are not valid entity types.
    """

    def __init__(self, entity_type: type):
        message = f"{entity_type.__module__}.{entity_type.__name__} is not an entity type class. Entity types must extend {Entity.__module__}.{Entity.__name__} directly."
        super().__init__(message)
[docs]
class EntityCollection(Generic[TargetT]):
    """
    Define the interface shared by entity collections.

    This base class implements :py:attr:`view`, :py:meth:`replace`, and the helpers that filter and
    de-duplicate entities. All other operations raise :py:class:`NotImplementedError` here and must be
    implemented by subclasses.
    """

    __slots__ = ()

    def __init__(self):
        super().__init__()

    def _on_add(self, *entities: TargetT & Entity) -> None:
        """
        React to entities having been added. This does nothing by default.
        """

    def _on_remove(self, *entities: TargetT & Entity) -> None:
        """
        React to entities having been removed. This does nothing by default.
        """

    @property
    def view(self) -> list[TargetT & Entity]:
        """
        A new list of the entities that are currently in the collection.
        """
        return [*self]

    def add(self, *entities: TargetT & Entity) -> None:
        """
        Add the given entities.
        """
        raise NotImplementedError(repr(self))

    def remove(self, *entities: TargetT & Entity) -> None:
        """
        Remove the given entities.
        """
        raise NotImplementedError(repr(self))

    def replace(self, *entities: TargetT & Entity) -> None:
        """
        Replace the collection's contents with the given entities.

        Entities that are in the collection but not among the given ones are removed first, after which the
        given entities are added.
        """
        # Collect the obsolete entities in full before removing any, so the collection is not changed while
        # it is being iterated over.
        obsolete = [entity for entity in self if entity not in entities]
        self.remove(*obsolete)
        self.add(*entities)

    def clear(self) -> None:
        """
        Remove all entities.
        """
        raise NotImplementedError(repr(self))

    def __iter__(self) -> Iterator[TargetT & Entity]:
        raise NotImplementedError(repr(self))

    def __len__(self) -> int:
        raise NotImplementedError(repr(self))

    @overload
    def __getitem__(self, index: int) -> TargetT & Entity:
        ...

    @overload
    def __getitem__(self, indices: slice) -> list[TargetT & Entity]:
        ...

    def __getitem__(
        self, key: int | slice
    ) -> TargetT & Entity | list[TargetT & Entity]:
        raise NotImplementedError(repr(self))

    def __delitem__(self, key: TargetT & Entity) -> None:
        raise NotImplementedError(repr(self))

    def __contains__(self, value: Any) -> bool:
        raise NotImplementedError(repr(self))

    def _known(self, *entities: TargetT & Entity) -> Iterable[TargetT & Entity]:
        """
        Yield those given entities that are in the collection, each one once, in the given order.
        """
        yielded = []
        for entity in entities:
            if entity not in self or entity in yielded:
                continue
            yield entity
            yielded.append(entity)

    def _unknown(self, *entities: TargetT & Entity) -> Iterable[TargetT & Entity]:
        """
        Yield those given entities that are not in the collection, each one once, in the given order.
        """
        yielded = []
        for entity in entities:
            if entity in self or entity in yielded:
                continue
            yield entity
            yielded.append(entity)
# A type variable for entity collection types.
EntityCollectionT = TypeVar("EntityCollectionT", bound=EntityCollection[EntityT])
class _EntityTypeAssociation(Generic[OwnerT, AssociateT]):
    """
    Define an association between an owner type's attribute and an associate type.

    The associate type is given by name, and is only imported the first time :py:attr:`associate_type` is
    accessed. An owner's associates are stored on the owner, in a private attribute named after the public
    attribute with a leading underscore.
    """

    def __init__(
        self,
        owner_type: type[OwnerT],
        owner_attr_name: str,
        associate_type_name: str,
    ):
        self._owner_type = owner_type
        self._owner_attr_name = owner_attr_name
        self._owner_private_attr_name = f"_{owner_attr_name}"
        self._associate_type_name = associate_type_name
        # Imported lazily by the associate_type property.
        self._associate_type: type[AssociateT] | None = None

    def __hash__(self) -> int:
        identity = (
            self._owner_type,
            self._owner_attr_name,
            self._associate_type_name,
        )
        return hash(identity)

    def __repr__(self) -> str:
        return repr_instance(
            self,
            owner_type=self._owner_type,
            owner_attr_name=self._owner_attr_name,
            associate_type_name=self._associate_type_name,
        )

    @property
    def owner_type(self) -> type[OwnerT]:
        """
        The type that owns this association.
        """
        return self._owner_type

    @property
    def owner_attr_name(self) -> str:
        """
        The name of the owner type's attribute that exposes this association.
        """
        return self._owner_attr_name

    @property
    def associate_type(self) -> type[AssociateT]:
        """
        The associate type, which is imported by name on first access.
        """
        if self._associate_type is not None:
            return self._associate_type
        self._associate_type = import_any(self._associate_type_name)
        return self._associate_type

    def register(  # type: ignore[misc]
        self: ToAny[OwnerT, AssociateT],
    ) -> None:
        """
        Register this association.

        This adds the association to the registry, and wraps the owner type's ``__init__()`` so that this
        association is initialized for every new owner before the original initializer runs.
        """
        EntityTypeAssociationRegistry._register(self)
        wrapped_init = self._owner_type.__init__

        @functools.wraps(wrapped_init)
        def _init(owner: OwnerT & Entity, *args: Any, **kwargs: Any) -> None:
            self.initialize(owner)
            wrapped_init(owner, *args, **kwargs)

        self._owner_type.__init__ = _init  # type: ignore[assignment, method-assign]

    def initialize(self, owner: OwnerT & Entity) -> None:
        """
        Initialize this association for the given owner.
        """
        raise NotImplementedError(repr(self))

    def finalize(self, owner: OwnerT & Entity) -> None:
        """
        Delete the owner's associates, then remove the private attribute that stored them.
        """
        self.delete(owner)
        delattr(owner, self._owner_private_attr_name)

    def delete(self, owner: OwnerT & Entity) -> None:
        """
        Delete the given owner's associates.
        """
        raise NotImplementedError(repr(self))

    def associate(self, owner: OwnerT & Entity, associate: AssociateT & Entity) -> None:
        """
        Associate the given associate with the given owner.
        """
        raise NotImplementedError(repr(self))

    def disassociate(
        self, owner: OwnerT & Entity, associate: AssociateT & Entity
    ) -> None:
        """
        Disassociate the given associate from the given owner.
        """
        raise NotImplementedError(repr(self))
[docs]
class BidirectionalEntityTypeAssociation(
    Generic[OwnerT, AssociateT], _EntityTypeAssociation[OwnerT, AssociateT]
):
    """
    Define an association that has an inverse on the associate type.

    In addition to what every association defines, this one knows the name of the associate type's attribute
    that points back to the owner.
    """

    def __init__(
        self,
        owner_type: type[OwnerT],
        owner_attr_name: str,
        associate_type_name: str,
        associate_attr_name: str,
    ):
        super().__init__(
            owner_type,
            owner_attr_name,
            associate_type_name,
        )
        self._associate_attr_name = associate_attr_name

    def __hash__(self) -> int:
        identity = (
            self._owner_type,
            self._owner_attr_name,
            self._associate_type_name,
            self._associate_attr_name,
        )
        return hash(identity)

    def __repr__(self) -> str:
        return repr_instance(
            self,
            owner_type=self._owner_type,
            owner_attr_name=self._owner_attr_name,
            associate_type_name=self._associate_type_name,
            associate_attr_name=self._associate_attr_name,
        )

    @property
    def associate_attr_name(self) -> str:
        """
        The name of the associate type's attribute that exposes the inverse association.
        """
        return self._associate_attr_name

    def inverse(self) -> BidirectionalEntityTypeAssociation[AssociateT, OwnerT]:
        """
        Get the inverse association, which is looked up in the registry by the associate type and the
        associate attribute name.
        """
        inverse_association = EntityTypeAssociationRegistry.get_association(
            self.associate_type, self.associate_attr_name
        )
        assert isinstance(inverse_association, BidirectionalEntityTypeAssociation)
        return inverse_association
[docs]
class ToOneEntityTypeAssociation(
    Generic[OwnerT, AssociateT], _EntityTypeAssociation[OwnerT, AssociateT]
):
    """
    Define an association in which an owner has at most one associate.

    The associate, or ``None``, is stored directly in the owner's private attribute.
    """

    def register(self) -> None:
        """
        Register this association, and expose it on the owner type as a property.
        """
        super().register()
        accessor = property(
            fget=self.get,
            fset=self.set,
            fdel=self.delete,
        )
        setattr(self.owner_type, self.owner_attr_name, accessor)

    def initialize(self, owner: OwnerT & Entity) -> None:
        """
        Initialize the owner without an associate.
        """
        setattr(owner, self._owner_private_attr_name, None)

    def get(self, owner: OwnerT & Entity) -> AssociateT & Entity | None:
        """
        Get the owner's associate, if it has one.
        """
        return getattr(owner, self._owner_private_attr_name)  # type: ignore[no-any-return]

    def set(
        self, owner: OwnerT & Entity, associate: AssociateT & Entity | None
    ) -> None:
        """
        Set the owner's associate.
        """
        setattr(owner, self._owner_private_attr_name, associate)

    def delete(self, owner: OwnerT & Entity) -> None:
        """
        Unset the owner's associate.
        """
        self.set(owner, None)

    def associate(self, owner: OwnerT & Entity, associate: AssociateT & Entity) -> None:
        """
        Make the given associate the owner's associate.
        """
        self.set(owner, associate)

    def disassociate(
        self, owner: OwnerT & Entity, associate: AssociateT & Entity
    ) -> None:
        """
        Unset the owner's associate, but only if it equals the given associate.
        """
        if associate != self.get(owner):
            return
        self.delete(owner)
[docs]
class ToManyEntityTypeAssociation(
    Generic[OwnerT, AssociateT], _EntityTypeAssociation[OwnerT, AssociateT]
):
    """
    Define an association in which an owner has a collection of associates.

    The collection is read from the owner's private attribute. Creating it is left to the subclasses'
    ``initialize()``.
    """

    def register(self) -> None:
        """
        Register this association, and expose it on the owner type as a property.
        """
        super().register()
        accessor = property(
            fget=self.get,
            fset=self.set,
            fdel=self.delete,
        )
        setattr(self.owner_type, self.owner_attr_name, accessor)

    def get(self, owner: OwnerT & Entity) -> EntityCollection[AssociateT & Entity]:
        """
        Get the owner's collection of associates.
        """
        collection = getattr(owner, self._owner_private_attr_name)
        return cast(EntityCollection["AssociateT & Entity"], collection)

    def set(
        self, owner: OwnerT & Entity, entities: Iterable[AssociateT & Entity]
    ) -> None:
        """
        Replace the owner's associates with the given entities.
        """
        collection = self.get(owner)
        collection.replace(*entities)

    def delete(self, owner: OwnerT & Entity) -> None:
        """
        Remove all of the owner's associates.
        """
        collection = self.get(owner)
        collection.clear()

    def associate(self, owner: OwnerT & Entity, associate: AssociateT & Entity) -> None:
        """
        Add the given associate to the owner's associates.
        """
        collection = self.get(owner)
        collection.add(associate)

    def disassociate(
        self, owner: OwnerT & Entity, associate: AssociateT & Entity
    ) -> None:
        """
        Remove the given associate from the owner's associates.
        """
        collection = self.get(owner)
        collection.remove(associate)
[docs]
class BidirectionalToOneEntityTypeAssociation(
    Generic[OwnerT, AssociateT],
    ToOneEntityTypeAssociation[OwnerT, AssociateT],
    BidirectionalEntityTypeAssociation[OwnerT, AssociateT],
):
    """
    Define a to-one association that keeps its inverse association in sync.
    """

    def set(
        self, owner: OwnerT & Entity, associate: AssociateT & Entity | None
    ) -> None:
        """
        Set the owner's associate, and update the inverse association accordingly.

        Nothing happens if the given associate equals the current one. Otherwise the owner is disassociated
        from the previous associate, if there was one, and associated with the new one, if there is one.
        """
        former_associate = self.get(owner)
        # Returning early here is also what stops both sides of the association from updating each other
        # endlessly.
        if former_associate == associate:
            return
        super().set(owner, associate)
        if former_associate is not None:
            self.inverse().disassociate(former_associate, owner)
        if associate is not None:
            self.inverse().associate(associate, owner)
[docs]
class BidirectionalToManyEntityTypeAssociation(
    Generic[OwnerT, AssociateT],
    ToManyEntityTypeAssociation[OwnerT, AssociateT],
    BidirectionalEntityTypeAssociation[OwnerT, AssociateT],
):
    """
    Define a to-many association whose collections keep the inverse association in sync.
    """

    def initialize(self, owner: OwnerT & Entity) -> None:
        """
        Give the owner an empty collection of associates that is bound to the owner and this association.
        """
        associates = _BidirectionalAssociateCollection(
            owner,
            self,
        )
        setattr(owner, self._owner_private_attr_name, associates)
[docs]
class ToOne(
    Generic[OwnerT, AssociateT], ToOneEntityTypeAssociation[OwnerT, AssociateT]
):
    """
    A unidirectional to-one entity type association.

    This adds nothing to its base class.
    """
[docs]
class OneToOne(
    Generic[OwnerT, AssociateT],
    BidirectionalToOneEntityTypeAssociation[OwnerT, AssociateT],
):
    """
    A bidirectional one-to-one entity type association.

    This adds nothing to its base class.
    """
[docs]
class ManyToOne(
    Generic[OwnerT, AssociateT],
    BidirectionalToOneEntityTypeAssociation[OwnerT, AssociateT],
):
    """
    A bidirectional many-to-one entity type association.

    This adds nothing to its base class.
    """
[docs]
class ToMany(
    Generic[OwnerT, AssociateT], ToManyEntityTypeAssociation[OwnerT, AssociateT]
):
    """
    A unidirectional to-many entity type association.
    """

    def initialize(self, owner: OwnerT & Entity) -> None:
        """
        Give the owner an empty collection for entities of the associate type.
        """
        associates = SingleTypeEntityCollection[AssociateT](self.associate_type)
        setattr(owner, self._owner_private_attr_name, associates)
[docs]
class OneToMany(
    Generic[OwnerT, AssociateT],
    BidirectionalToManyEntityTypeAssociation[OwnerT, AssociateT],
):
    """
    A bidirectional one-to-many entity type association.

    This adds nothing to its base class.
    """
[docs]
class ManyToMany(
    Generic[OwnerT, AssociateT],
    BidirectionalToManyEntityTypeAssociation[OwnerT, AssociateT],
):
    """
    A bidirectional many-to-many entity type association.

    This adds nothing to its base class.
    """
# Any entity type association, which is either a to-one or a to-many association.
ToAny: TypeAlias = (
ToOneEntityTypeAssociation[OwnerT, AssociateT]
| ToManyEntityTypeAssociation[OwnerT, AssociateT]
)
[docs]
def to_one(
    owner_attr_name: str,
    associate_type_name: str,
) -> Callable[[type[OwnerT]], type[OwnerT]]:
    """
    Add a unidirectional to-one association to an entity or entity mixin.

    :param owner_attr_name: the name of the decorated type's attribute that exposes the association.
    :param associate_type_name: the name of the associate type.
    :return: a class decorator that registers the association and returns the decorated type.
    """

    def _decorator(owner_type: type[OwnerT]) -> type[OwnerT]:
        association = ToOne(
            owner_type,
            owner_attr_name,
            associate_type_name,
        )
        association.register()
        return owner_type

    return _decorator
[docs]
def one_to_one(
    owner_attr_name: str,
    associate_type_name: str,
    associate_attr_name: str,
) -> Callable[[type[OwnerT]], type[OwnerT]]:
    """
    Add a bidirectional one-to-one association to an entity or entity mixin.

    :param owner_attr_name: the name of the decorated type's attribute that exposes the association.
    :param associate_type_name: the name of the associate type.
    :param associate_attr_name: the name of the associate type's attribute for the inverse association.
    :return: a class decorator that registers the association and returns the decorated type.
    """

    def _decorator(owner_type: type[OwnerT]) -> type[OwnerT]:
        association = OneToOne(
            owner_type,
            owner_attr_name,
            associate_type_name,
            associate_attr_name,
        )
        association.register()
        return owner_type

    return _decorator
[docs]
def many_to_one(
    owner_attr_name: str,
    associate_type_name: str,
    associate_attr_name: str,
) -> Callable[[type[OwnerT]], type[OwnerT]]:
    """
    Add a bidirectional many-to-one association to an entity or entity mixin.

    :param owner_attr_name: the name of the decorated type's attribute that exposes the association.
    :param associate_type_name: the name of the associate type.
    :param associate_attr_name: the name of the associate type's attribute for the inverse association.
    :return: a class decorator that registers the association and returns the decorated type.
    """

    def _decorator(owner_type: type[OwnerT]) -> type[OwnerT]:
        association = ManyToOne(
            owner_type,
            owner_attr_name,
            associate_type_name,
            associate_attr_name,
        )
        association.register()
        return owner_type

    return _decorator
[docs]
def to_many(
    owner_attr_name: str,
    associate_type_name: str,
) -> Callable[[type[OwnerT]], type[OwnerT]]:
    """
    Add a unidirectional to-many association to an entity or entity mixin.

    :param owner_attr_name: the name of the decorated type's attribute that exposes the association.
    :param associate_type_name: the name of the associate type.
    :return: a class decorator that registers the association and returns the decorated type.
    """

    def _decorator(owner_type: type[OwnerT]) -> type[OwnerT]:
        association = ToMany(
            owner_type,
            owner_attr_name,
            associate_type_name,
        )
        association.register()
        return owner_type

    return _decorator
[docs]
def one_to_many(
    owner_attr_name: str,
    associate_type_name: str,
    associate_attr_name: str,
) -> Callable[[type[OwnerT]], type[OwnerT]]:
    """
    Add a bidirectional one-to-many association to an entity or entity mixin.

    :param owner_attr_name: the name of the decorated type's attribute that exposes the association.
    :param associate_type_name: the name of the associate type.
    :param associate_attr_name: the name of the associate type's attribute for the inverse association.
    :return: a class decorator that registers the association and returns the decorated type.
    """

    def _decorator(owner_type: type[OwnerT]) -> type[OwnerT]:
        association = OneToMany(
            owner_type,
            owner_attr_name,
            associate_type_name,
            associate_attr_name,
        )
        association.register()
        return owner_type

    return _decorator
[docs]
def many_to_many(
    owner_attr_name: str,
    associate_type_name: str,
    associate_attr_name: str,
) -> Callable[[type[OwnerT]], type[OwnerT]]:
    """
    Add a bidirectional many-to-many association to an entity or entity mixin.

    :param owner_attr_name: the name of the decorated type's attribute that exposes the association.
    :param associate_type_name: the name of the associate type.
    :param associate_attr_name: the name of the associate type's attribute for the inverse association.
    :return: a class decorator that registers the association and returns the decorated type.
    """

    def _decorator(owner_type: type[OwnerT]) -> type[OwnerT]:
        association = ManyToMany(
            owner_type,
            owner_attr_name,
            associate_type_name,
            associate_attr_name,
        )
        association.register()
        return owner_type

    return _decorator
[docs]
def many_to_one_to_many(
    left_associate_type_name: str,
    left_associate_attr_name: str,
    left_owner_attr_name: str,
    right_owner_attr_name: str,
    right_associate_type_name: str,
    right_associate_attr_name: str,
) -> Callable[[type[OwnerT]], type[OwnerT]]:
    """
    Add a bidirectional many-to-one-to-many association to an entity or entity mixin.

    The decorated type gets two many-to-one associations: one to the left associate type, and one to the
    right associate type. They are registered in that order.

    :return: a class decorator that registers both associations and returns the decorated type.
    """

    def _decorator(owner_type: type[OwnerT]) -> type[OwnerT]:
        left_association = ManyToOne(
            owner_type,
            left_owner_attr_name,
            left_associate_type_name,
            left_associate_attr_name,
        )
        left_association.register()
        right_association = ManyToOne(
            owner_type,
            right_owner_attr_name,
            right_associate_type_name,
            right_associate_attr_name,
        )
        right_association.register()
        return owner_type

    return _decorator
[docs]
class EntityTypeAssociationRegistry:
    """
    Keep track of all registered entity type associations.

    The registry is a single set that is stored on the class and shared by all callers.
    """

    _associations = set[ToAny[Any, Any]]()

    @classmethod
    def get_all_associations(cls, owner: type | object) -> set[ToAny[Any, Any]]:
        """
        Get all associations for an owner or owner type.

        This includes the associations that were registered for any of the classes in the owner type's
        method resolution order.
        """
        if isinstance(owner, type):
            owner_type = owner
        else:
            owner_type = type(owner)
        owner_type_mro = owner_type.__mro__
        return {
            association
            for association in cls._associations
            if association.owner_type in owner_type_mro
        }

    @classmethod
    def get_association(
        cls, owner: type[OwnerT] | OwnerT & Entity, owner_attr_name: str
    ) -> ToAny[OwnerT, Any]:
        """
        Get the association for an owner or owner type, and one of its attributes.

        :raises ValueError: if no association exists for the given attribute name.
        """
        for candidate in cls.get_all_associations(owner):
            if candidate.owner_attr_name == owner_attr_name:
                return candidate
        owner_type = owner if isinstance(owner, type) else owner.__class__
        raise ValueError(
            f"No association exists for {fully_qualified_type_name(owner_type)}.{owner_attr_name}."
        )

    @classmethod
    def get_associates(
        cls, owner: EntityT, association: ToAny[EntityT, AssociateT]
    ) -> Iterable[AssociateT]:
        """
        Yield the owner's associates for the given association.

        The associates are read from the owner's private attribute directly. A to-one association yields its
        associate if it has one, and any other association yields everything its attribute contains.
        """
        associates: AssociateT | None | Iterable[AssociateT] = getattr(
            owner, f"_{association.owner_attr_name}"
        )
        if not isinstance(association, ToOneEntityTypeAssociation):
            yield from cast(Iterable[AssociateT], associates)
            return
        if associates is not None:
            yield cast(AssociateT, associates)

    @classmethod
    def _register(cls, association: ToAny[Any, Any]) -> None:
        """
        Add an association to the registry.
        """
        cls._associations.add(association)

    @classmethod
    def initialize(cls, *owners: Entity) -> None:
        """
        Initialize all associations of each of the given owners.
        """
        for owner in owners:
            owner_associations = cls.get_all_associations(owner)
            for owner_association in owner_associations:
                owner_association.initialize(owner)

    @classmethod
    def finalize(cls, *owners: Entity) -> None:
        """
        Finalize all associations of each of the given owners.
        """
        for owner in owners:
            owner_associations = cls.get_all_associations(owner)
            for owner_association in owner_associations:
                owner_association.finalize(owner)
[docs]
class SingleTypeEntityCollection(Generic[TargetT], EntityCollection[TargetT]):
__slots__ = "_entities", "_target_type"
def __init__(
self,
target_type: type[TargetT],
):
super().__init__()
self._entities: list[TargetT & Entity] = []
self._target_type = target_type
@recursive_repr()
def __repr__(self) -> str:
return repr_instance(self, target_type=self._target_type, length=len(self))
[docs]
def add(self, *entities: TargetT & Entity) -> None:
added_entities = [*self._unknown(*entities)]
for entity in added_entities:
self._entities.append(entity)
if added_entities:
self._on_add(*added_entities)
[docs]
def remove(self, *entities: TargetT & Entity) -> None:
removed_entities = [*self._known(*entities)]
for entity in removed_entities:
self._entities.remove(entity)
if removed_entities:
self._on_remove(*removed_entities)
[docs]
def clear(self) -> None:
self.remove(*self)
def __iter__(self) -> Iterator[TargetT & Entity]:
return self._entities.__iter__()
def __len__(self) -> int:
return len(self._entities)
@overload
def __getitem__(self, index: int) -> TargetT & Entity:
pass
@overload
def __getitem__(self, indices: slice) -> list[TargetT & Entity]:
pass
@overload
def __getitem__(self, entity_id: str) -> TargetT & Entity:
pass
def __getitem__(
self, key: int | slice | str
) -> TargetT & Entity | list[TargetT & Entity]:
if isinstance(key, int):
return self._getitem_by_index(key)
if isinstance(key, slice):
return self._getitem_by_indices(key)
return self._getitem_by_entity_id(key)
def _getitem_by_index(self, index: int) -> TargetT & Entity:
return self._entities[index]
def _getitem_by_indices(self, indices: slice) -> list[TargetT & Entity]:
return self.view[indices]
def _getitem_by_entity_id(self, entity_id: str) -> TargetT & Entity:
for entity in self._entities:
if entity_id == entity.id:
return entity
raise KeyError(
f'Cannot find a {self._target_type} entity with ID "{entity_id}".'
)
def __delitem__(self, key: str | TargetT & Entity) -> None:
if isinstance(key, self._target_type):
return self._delitem_by_entity(cast("TargetT & Entity", key))
if isinstance(key, str):
return self._delitem_by_entity_id(key)
raise TypeError(f"Cannot find entities by {repr(key)}.")
def _delitem_by_entity(self, entity: TargetT & Entity) -> None:
self.remove(entity)
def _delitem_by_entity_id(self, entity_id: str) -> None:
for entity in self._entities:
if entity_id == entity.id:
self.remove(entity)
return
def __contains__(self, value: Any) -> bool:
if isinstance(value, self._target_type):
return self._contains_by_entity(cast("TargetT & Entity", value))
if isinstance(value, str):
return self._contains_by_entity_id(value)
return False
def _contains_by_entity(self, other_entity: TargetT & Entity) -> bool:
for entity in self._entities:
if other_entity is entity:
return True
return False
def _contains_by_entity_id(self, entity_id: str) -> bool:
for entity in self._entities:
if entity.id == entity_id:
return True
return False
# A type variable for single-type entity collection types.
SingleTypeEntityCollectionT = TypeVar(
"SingleTypeEntityCollectionT", bound=SingleTypeEntityCollection[AssociateT]
)
[docs]
class MultipleTypesEntityCollection(Generic[TargetT], EntityCollection[TargetT]):
__slots__ = "_collections"
def __init__(self):
super().__init__()
self._collections: dict[type[Entity], SingleTypeEntityCollection[Entity]] = {}
@recursive_repr()
def __repr__(self) -> str:
return repr_instance(
self,
entity_types=", ".join(map(get_entity_type_name, self._collections.keys())),
length=len(self),
)
def _get_collection(
self, entity_type: type[EntityT]
) -> SingleTypeEntityCollection[EntityT]:
assert issubclass(entity_type, Entity), f"{entity_type} is not an entity type."
try:
return cast(
SingleTypeEntityCollection[EntityT], self._collections[entity_type]
)
except KeyError:
self._collections[entity_type] = SingleTypeEntityCollection(entity_type)
return cast(
SingleTypeEntityCollection[EntityT], self._collections[entity_type]
)
@overload
def __getitem__(self, index: int) -> TargetT & Entity:
pass
@overload
def __getitem__(self, indices: slice) -> list[TargetT & Entity]:
pass
@overload
def __getitem__(self, entity_type_name: str) -> SingleTypeEntityCollection[Entity]:
pass
@overload
def __getitem__(
self, entity_type: type[EntityT]
) -> SingleTypeEntityCollection[EntityT]:
pass
def __getitem__(
self,
key: int | slice | str | type[EntityT],
) -> (
TargetT & Entity
| SingleTypeEntityCollection[Entity]
| SingleTypeEntityCollection[EntityT]
| list[TargetT & Entity]
):
if isinstance(key, int):
return self._getitem_by_index(key)
if isinstance(key, slice):
return self._getitem_by_indices(key)
if isinstance(key, str):
return self._getitem_by_entity_type_name(key)
return self._getitem_by_entity_type(key)
def _getitem_by_entity_type(
self, entity_type: type[EntityT]
) -> SingleTypeEntityCollection[EntityT]:
return self._get_collection(entity_type)
def _getitem_by_entity_type_name(
self, entity_type_name: str
) -> SingleTypeEntityCollection[Entity]:
return self._get_collection(
get_entity_type(entity_type_name),
)
def _getitem_by_index(self, index: int) -> TargetT & Entity:
return self.view[index]
def _getitem_by_indices(self, indices: slice) -> list[TargetT & Entity]:
return self.view[indices]
def __delitem__(self, key: str | type[TargetT & Entity] | TargetT & Entity) -> None:
if isinstance(key, type):
return self._delitem_by_type(
key,
)
if isinstance(key, Entity):
return self._delitem_by_entity(
key, # type: ignore[arg-type]
)
return self._delitem_by_entity_type_name(key)
def _delitem_by_type(self, entity_type: type[TargetT & Entity]) -> None:
removed_entities = [*self._get_collection(entity_type)]
self._get_collection(entity_type).clear()
if removed_entities:
self._on_remove(*removed_entities)
def _delitem_by_entity(self, entity: TargetT & Entity) -> None:
self.remove(entity)
def _delitem_by_entity_type_name(self, entity_type_name: str) -> None:
self._delitem_by_type(
get_entity_type(entity_type_name), # type: ignore[arg-type]
)
def __iter__(self) -> Iterator[TargetT & Entity]:
for collection in self._collections.values():
for entity in collection:
yield cast("TargetT & Entity", entity)
def __len__(self) -> int:
return sum(map(len, self._collections.values()))
def __contains__(self, value: Any) -> bool:
if isinstance(value, Entity):
return self._contains_by_entity(value)
return False
def _contains_by_entity(self, other_entity: Any) -> bool:
for entity in self:
if other_entity is entity:
return True
return False
[docs]
def add(self, *entities: TargetT & Entity) -> None:
added_entities = [*self._unknown(*entities)]
for entity in added_entities:
self[entity.type].add(entity)
if added_entities:
self._on_add(*added_entities)
[docs]
def remove(self, *entities: TargetT & Entity) -> None:
removed_entities = [*self._known(*entities)]
for entity in removed_entities:
self[entity.type].remove(entity)
if removed_entities:
self._on_remove(*removed_entities)
[docs]
def clear(self) -> None:
removed_entities = (*self,)
for collection in self._collections.values():
collection.clear()
if removed_entities:
self._on_remove(*removed_entities)
class _BidirectionalAssociateCollection(
    Generic[AssociateT, OwnerT], SingleTypeEntityCollection[AssociateT]
):
    """
    Collect an owner's associates, and keep the inverse association in sync.

    Whenever associates are added to or removed from the collection, the owner is associated with or
    disassociated from each of them through the inverse association. The owner is referenced weakly.
    """

    __slots__ = "__owner", "_association"

    def __init__(
        self,
        owner: OwnerT & Entity,
        association: BidirectionalEntityTypeAssociation[OwnerT, AssociateT],
    ):
        super().__init__(association.associate_type)
        self._association = association
        self.__owner = weakref.ref(owner)

    @property
    def _owner(self) -> OwnerT & Entity:
        """
        The collection's owner.

        :raises RuntimeError: if the owner has been garbage-collected.
        """
        referent = self.__owner()
        if referent is not None:
            return referent
        raise RuntimeError(
            "This associate collection's owner no longer exists in memory."
        )

    def _on_add(self, *entities: AssociateT & Entity) -> None:
        super()._on_add(*entities)
        for added_associate in entities:
            self._association.inverse().associate(added_associate, self._owner)

    def _on_remove(self, *entities: AssociateT & Entity) -> None:
        super()._on_remove(*entities)
        for removed_associate in entities:
            self._association.inverse().disassociate(removed_associate, self._owner)
[docs]
class AliasedEntity(Generic[EntityT]):
    """
    Wrap an entity and give it an alternative ID.

    The alias reports the original entity's type, but has an ID of its own, which is generated if none is
    given.
    """

    def __init__(self, original_entity: EntityT, aliased_entity_id: str | None = None):
        self._entity = original_entity
        self._id = (
            aliased_entity_id if aliased_entity_id is not None else GeneratedEntityId()
        )

    def __repr__(self) -> str:
        return repr_instance(self, id=self.id)

    @property
    def type(self) -> builtins.type[Entity]:
        """
        The original entity's type.
        """
        return self._entity.type

    @property
    def id(self) -> str:
        """
        The alias' own ID, which is not the original entity's ID.
        """
        return self._id

    def unalias(self) -> EntityT:
        """
        Get the original entity.
        """
        return self._entity
# Either an entity itself, or an alias that wraps an entity.
AliasableEntity: TypeAlias = EntityT | AliasedEntity[EntityT]
[docs]
def unalias(entity: AliasableEntity[EntityT]) -> EntityT:
    """
    Unalias a potentially aliased entity.

    Entities that are not aliased are returned as they are.
    """
    if not isinstance(entity, AliasedEntity):
        return entity
    return entity.unalias()
# The entities that were added to a graph builder, keyed by entity type first, and by entity ID second.
_EntityGraphBuilderEntities: TypeAlias = dict[
type[Entity], dict[str, AliasableEntity[Entity]]
]
# The associations that were added to a graph builder, as nested dictionaries.
_EntityGraphBuilderAssociations: TypeAlias = dict[
type[Entity], # The owner entity type.
dict[
str, # The owner attribute name.
dict[str, list[AncestryEntityId]], # The owner ID. # The associate IDs.
],
]
class _EntityGraphBuilder:
    """
    Assemble a graph of entities from entities and associations that were recorded by ID.

    Entities may be aliased while they are being recorded. Building the graph unaliases them, initializes
    their associations, and then sets the recorded associations. A builder can build only once.
    """

    def __init__(self):
        self._entities: _EntityGraphBuilderEntities = defaultdict(dict)
        # Owner type -> owner attribute name -> owner ID -> the associates' ancestry IDs.
        self._associations: _EntityGraphBuilderAssociations = defaultdict(
            lambda: defaultdict(lambda: defaultdict(list))
        )
        self._built = False

    def _assert_unbuilt(self) -> None:
        """
        Assert that the graph has not been built yet.

        :raises RuntimeError: if the graph has been built already.
        """
        if self._built:
            raise RuntimeError("This entity graph has been built already.")

    def _iter(self) -> Iterator[AliasableEntity[Entity]]:
        """
        Yield all recorded, and potentially aliased, entities.
        """
        for entities_by_id in self._entities.values():
            yield from entities_by_id.values()

    def _build_associations(self) -> None:
        """
        Set every recorded association on its unaliased owner, using the unaliased associates.
        """
        for owner_type, owner_attrs in self._associations.items():
            for owner_attr_name, owner_associations in owner_attrs.items():
                association = EntityTypeAssociationRegistry.get_association(
                    owner_type, owner_attr_name
                )
                for owner_id, associate_ancestry_ids in owner_associations.items():
                    associates = [
                        unalias(self._entities[associate_type][associate_id])
                        for associate_type, associate_id in associate_ancestry_ids
                    ]
                    owner = unalias(self._entities[owner_type][owner_id])
                    if isinstance(association, ToOneEntityTypeAssociation):
                        # A to-one association takes a single associate. If more than one was recorded
                        # for the same owner, only the first one is used.
                        association.set(owner, associates[0])
                    else:
                        association.set(owner, associates)

    def build(self) -> Iterator[Entity]:
        """
        Build the entity graph.

        The graph is built immediately when this method is called, not when the returned iterator is
        consumed. That way the builder is marked as built right away, and a second call fails no matter
        whether the first call's result was ever iterated over.

        :return: an iterator over the unaliased entities.
        :raises RuntimeError: if the graph has been built already.
        """
        self._assert_unbuilt()
        self._built = True
        unaliased_entities = [unalias(entity) for entity in self._iter()]
        EntityTypeAssociationRegistry.initialize(*unaliased_entities)
        self._build_associations()
        return iter(unaliased_entities)
[docs]
class EntityGraphBuilder(_EntityGraphBuilder):
    """
    Record entities and their associations, so an entity graph can be built from them.

    Nothing can be recorded anymore once the graph has been built.
    """

    def add_entity(self, *entities: AliasableEntity[Entity]) -> None:
        """
        Record the given, potentially aliased, entities by their types and IDs.

        An entity replaces any entity of the same type and with the same ID that was recorded before.

        :raises RuntimeError: if the graph has been built already.
        """
        self._assert_unbuilt()
        for entity in entities:
            entities_by_id = self._entities[entity.type]
            entities_by_id[entity.id] = entity

    def add_association(
        self,
        owner_type: type[Entity],
        owner_id: str,
        owner_attr_name: str,
        associate_type: type[Entity],
        associate_id: str,
    ) -> None:
        """
        Record an association between an owner and an associate, both identified by their types and IDs.

        :raises RuntimeError: if the graph has been built already.
        """
        self._assert_unbuilt()
        associate_ancestry_id = (associate_type, associate_id)
        self._associations[owner_type][owner_attr_name][owner_id].append(
            associate_ancestry_id
        )
[docs]
@contextmanager
def record_added(
    entities: EntityCollection[EntityT],
) -> Iterator[MultipleTypesEntityCollection[EntityT]]:
    """
    Record all entities that are added to a collection.

    The yielded collection is empty for as long as the context is active. It is populated once the context
    exits without an error, with those entities that are in the given collection by then but were not when
    the context was entered.
    """
    before = [*entities]
    added = MultipleTypesEntityCollection[EntityT]()
    yield added
    newly_added = [entity for entity in entities if entity not in before]
    added.add(*newly_added)