Source code for betty.app.extension

"""Provide Betty's extension API."""
from __future__ import annotations

import functools
from collections import defaultdict
from importlib.metadata import entry_points, EntryPoint
from pathlib import Path
from typing import Any, TypeVar, Iterable, TYPE_CHECKING, Generic, \
    Iterator, Sequence, Self

from betty.app.extension.requirement import Requirement, AllRequirements
from betty.asyncio import gather
from betty.config import ConfigurationT, Configurable
from betty.dispatch import Dispatcher, TargetedDispatcher
from betty.importlib import import_any
from betty.locale import Str

if TYPE_CHECKING:
    from betty.app import App


class ExtensionError(Exception):
    """
    The base class for all extension-related errors.

    This derives from :py:class:`Exception` rather than :py:class:`BaseException`, so that generic
    ``except Exception`` handlers catch it like any other application error. BaseException is reserved for
    interpreter-level exits such as KeyboardInterrupt and SystemExit.
    """
class ExtensionTypeError(ExtensionError, ValueError):
    """
    Raised when a value cannot be resolved to a valid extension type.
    """
class ExtensionTypeImportError(ExtensionTypeError, ImportError):
    """
    Raised when an alleged extension type cannot be imported.
    """

    def __init__(self, extension_type_name: str):
        message = f'Cannot find and import an extension with name "{extension_type_name}".'
        super().__init__(message)
class ExtensionTypeInvalidError(ExtensionTypeError, ImportError):
    """
    Raised for types that are not valid extension types.
    """

    def __init__(self, extension_type: type):
        invalid_type_name = f'{extension_type.__module__}.{extension_type.__name__}'
        extension_base_name = f'{Extension.__module__}.{Extension.__name__}'
        super().__init__(
            f'{invalid_type_name} is not an extension type class. Extension types must extend {extension_base_name}.'
        )
class CyclicDependencyError(ExtensionError, RuntimeError):
    """
    Raised when extension types depend on each other in a cycle.
    """

    def __init__(self, extension_types: Iterable[type[Extension]]):
        names = ', '.join(extension_type.name() for extension_type in extension_types)
        super().__init__(f'The following extensions have cyclic dependencies: {names}')
class Dependencies(AllRequirements):
    """
    Combine the enable requirements of every extension type that a dependent extension type depends on.
    """

    def __init__(self, dependent_type: type[Extension]):
        requirements = []
        for dependency_type in dependent_type.depends_on():
            try:
                requirement = dependency_type.enable_requirement()
            except RecursionError:
                # With the default enable_requirement(), a dependency cycle makes building these requirements
                # recurse back into this constructor until Python gives up.
                raise CyclicDependencyError([dependency_type])
            requirements.append(requirement)
        super().__init__(*requirements)
        self._dependent_type = dependent_type

    @classmethod
    def for_dependent(cls, dependent_type: type[Extension]) -> Self:
        """
        Build the dependency requirement for the given dependent extension type.
        """
        return cls(dependent_type)

    def summary(self) -> Str:
        def _dependency_labels(localizer: Any) -> str:
            return ', '.join(
                format_extension_type(extension_type).localize(localizer)
                for extension_type in self._dependent_type.depends_on()
            )

        return Str._(
            '{dependent_label} requires {dependency_labels}.',
            dependent_label=format_extension_type(self._dependent_type),
            dependency_labels=Str.call(_dependency_labels),
        )
class Dependents(Requirement):
    """
    Describe the enabled extensions that depend on a given extension.

    Instances are built by :py:meth:`Dependents.for_dependency`, which Extension.disable_requirement() uses by
    default.
    """

    def __init__(self, dependency: Extension, dependents: Sequence[Extension]):
        super().__init__()
        self._dependency = dependency
        self._dependents = dependents

    def summary(self) -> Str:
        return Str._(
            '{dependency_label} is required by {dependency_labels}.',
            dependency_label=format_extension_type(type(self._dependency)),
            # The keyword must match the ``{dependency_labels}`` placeholder in the message above. The message
            # itself is kept unchanged so existing translations of it keep working. The value lists the dependents.
            dependency_labels=Str.call(lambda localizer: ', '.join(
                format_extension_type(type(dependent)).localize(localizer)
                for dependent in self._dependents
            )),
        )

    def is_met(self) -> bool:
        # This class is never instantiated unless there is at least one enabled dependent, which means this requirement
        # is always met.
        # NOTE(review): for_dependency() below also builds an instance when no dependents are enabled; confirm this
        # contract with the callers.
        return True

    @classmethod
    def for_dependency(cls, dependency: Extension) -> Self:
        """
        Build the requirement from the discovered extension types that depend on the dependency's type and are
        enabled in the dependency's app.
        """
        dependents = [
            dependency.app.extensions[extension_type]
            for extension_type in discover_extension_types()
            if dependency.__class__ in extension_type.depends_on() and extension_type in dependency.app.extensions
        ]
        return cls(dependency, dependents)
class Extension:
    """
    Integrate optional functionality with the Betty app.
    """

    def __init__(self, app: App, *args: Any, **kwargs: Any):
        # Extension itself must not be instantiated; only its subclasses may be.
        assert type(self) is not Extension
        super().__init__(*args, **kwargs)
        self._app = app

    @classmethod
    def name(cls) -> str:
        """
        Return the extension type's name: its module path, a dot, and its class name.
        """
        return f'{cls.__module__}.{cls.__name__}'

    @classmethod
    def depends_on(cls) -> set[type[Extension]]:
        """
        Return the extension types this extension type depends on. By default, there are none.
        """
        return set()

    @classmethod
    def comes_after(cls) -> set[type[Extension]]:
        """
        Return the extension types this extension type is ordered after, if present. By default, there are none.
        """
        return set()

    @classmethod
    def comes_before(cls) -> set[type[Extension]]:
        """
        Return the extension types this extension type is ordered before, if present. By default, there are none.
        """
        return set()

    @classmethod
    def enable_requirement(cls) -> Requirement:
        """
        Define the requirement for this extension to be enabled.

        This defaults to the extension's dependencies.
        """
        return Dependencies.for_dependent(cls)

    def disable_requirement(self) -> Requirement:
        """
        Define the requirement for this extension to be disabled.

        This defaults to the extension's dependents.
        """
        return Dependents.for_dependency(self)

    @classmethod
    def assets_directory_path(cls) -> Path | None:
        """
        Return the path on disk where the extension's assets are located.

        This may be anywhere in your Python package. By default, the extension has no assets.
        """
        return None

    @property
    def app(self) -> App:
        """
        The app this extension was created for.
        """
        return self._app
# A type variable bound to Extension. Extensions.__getitem__() uses it so that looking up an extension type returns an
# instance typed as that same extension type.
ExtensionT = TypeVar('ExtensionT', bound=Extension)
class UserFacingExtension(Extension):
    """
    An extension with a human-readable label and description.

    format_extension_type() uses the label when formatting extension types of this kind.
    """

    @classmethod
    def label(cls) -> Str:
        """
        Return the human-readable label. Subclasses must override this.
        """
        raise NotImplementedError(repr(cls))

    @classmethod
    def description(cls) -> Str:
        """
        Return the human-readable description. Subclasses must override this.
        """
        raise NotImplementedError(repr(cls))
class Theme(UserFacingExtension):
    """
    A user-facing extension that acts as a theme.
    """
@functools.singledispatch
def get_extension_type(extension_type_definition: str | type[Extension] | Extension) -> type[Extension]:
    """
    Get the extension type for an extension, extension type, or extension type name.

    This generic implementation is only reached for definitions of unregistered types, and rejects them.

    :raises ExtensionTypeError: if the definition's type has no registered implementation.
    """
    message = f'Cannot get the extension type for "{extension_type_definition}".'
    raise ExtensionTypeError(message)
@get_extension_type.register(str)
def get_extension_type_by_name(extension_type_name: str) -> type[Extension]:
    """
    Get the extension type for an extension type name.

    :raises ExtensionTypeImportError: if nothing can be imported under the given name.
    """
    try:
        imported = import_any(extension_type_name)
    except ImportError:
        raise ExtensionTypeImportError(extension_type_name) from None
    # Delegate to whichever implementation matches the imported object, which validates it.
    return get_extension_type(imported)
@get_extension_type.register(type)
def get_extension_type_by_type(extension_type: type) -> type[Extension]:
    """
    Get the extension type for an extension type.

    :raises ExtensionTypeInvalidError: if the type does not extend Extension.
    """
    if not issubclass(extension_type, Extension):
        raise ExtensionTypeInvalidError(extension_type)
    return extension_type
@get_extension_type.register(Extension)
def get_extension_type_by_extension(extension: Extension) -> type[Extension]:
    """
    Get the extension type for an extension.
    """
    extension_class = type(extension)
    return get_extension_type(extension_class)
def format_extension_type(extension_type: type[Extension]) -> Str:
    """
    Format an extension type to a human-readable label.

    User-facing extension types are formatted as their localized label followed by their name in parentheses.
    All other extension types are formatted as their plain name.
    """
    if not issubclass(extension_type, UserFacingExtension):
        return Str.plain(extension_type.name())

    def _localize(localizer: Any) -> str:
        return f'{extension_type.label().localize(localizer)} ({extension_type.name()})'

    return Str.call(_localize)
class ConfigurableExtension(Extension, Generic[ConfigurationT], Configurable[ConfigurationT]):
    """
    An extension that holds a configuration.
    """

    def __init__(self, *args: Any, configuration: ConfigurationT | None = None, **kwargs: Any):
        """
        :param configuration: The configuration to use. If omitted or None, :py:meth:`default_configuration` is used.
        """
        # ConfigurableExtension itself must not be instantiated; only its subclasses may be.
        assert type(self) is not ConfigurableExtension
        super().__init__(*args, **kwargs)
        # Fall back to the default only when no configuration was given at all. A truthiness check would also
        # replace configuration objects that happen to be falsy, e.g. through a __len__() of 0.
        self._configuration = self.default_configuration() if configuration is None else configuration

    @classmethod
    def default_configuration(cls) -> ConfigurationT:
        """
        Build this extension's default configuration. Subclasses must override this.
        """
        raise NotImplementedError(repr(cls))
class Extensions:
    """
    Provide access to a collection of extensions, grouped in batches.

    This base class only declares the interface: every method raises :py:class:`NotImplementedError`.
    """

    def __getitem__(self, extension_type: type[ExtensionT] | str) -> ExtensionT:
        """
        Get the extension of the given type, or of the type with the given name.
        """
        raise NotImplementedError(repr(self))

    def __iter__(self) -> Iterator[Iterator[Extension]]:
        """
        Iterate over the batches of extensions, each itself an iterator of extensions.
        """
        raise NotImplementedError(repr(self))

    def flatten(self) -> Iterator[Extension]:
        """
        Iterate over all extensions, regardless of their batch.
        """
        raise NotImplementedError(repr(self))

    def __contains__(self, extension_type: type[Extension] | str | Any) -> bool:
        """
        Check whether an extension of the given type, or of the type with the given name, is present.
        """
        raise NotImplementedError(repr(self))
class ListExtensions(Extensions):
    """
    Extensions stored as a list of batches, each batch being a list of extensions.
    """

    def __init__(self, extensions: list[list[Extension]]):
        super().__init__()
        self._extensions = extensions

    def __getitem__(self, extension_type: type[ExtensionT] | str) -> ExtensionT:
        """
        Get the extension whose exact type is the given type, or the type imported from the given name.

        :raises KeyError: if no such extension is present.
        """
        if isinstance(extension_type, str):
            extension_type = import_any(extension_type)
        missing = object()
        found = next(
            (candidate for candidate in self.flatten() if type(candidate) is extension_type),
            missing,
        )
        if found is missing:
            raise KeyError(f'Unknown extension of type "{extension_type}"')
        return found  # type: ignore[return-value]

    def __iter__(self) -> Iterator[Iterator[Extension]]:
        # Hand out a fresh generator per batch rather than the stored lists, so calling code is discouraged from
        # storing the result.
        for batch in self._extensions:
            yield (extension for extension in batch)

    def flatten(self) -> Iterator[Extension]:
        for batch in self:
            for extension in batch:
                yield extension

    def __contains__(self, extension_type: type[Extension] | str) -> bool:
        """
        Check whether an extension of exactly the given type is present.

        Names that cannot be imported are reported as absent.
        """
        if isinstance(extension_type, str):
            try:
                extension_type = import_any(extension_type)
            except ImportError:
                return False
        return any(type(extension) is extension_type for extension in self.flatten())
class ExtensionDispatcher(Dispatcher):
    """
    Dispatch calls to the extensions that are instances of a given target type.
    """

    def __init__(self, extensions: Extensions):
        self._extensions = extensions

    def dispatch(self, target_type: type[Any]) -> TargetedDispatcher:
        """
        Build an async callable that calls the target type's single public method on every matching extension.

        Extensions are processed batch by batch. Within a batch, the calls are awaited together through gather().
        The results of all batches are returned as one list, in batch order.

        :raises ValueError: if the target type does not expose exactly one public (non-underscore) name.
        """
        public_names = [name for name in dir(target_type) if not name.startswith('_')]
        if len(public_names) != 1:
            raise ValueError(f"A dispatch's target type must have a single method to dispatch to, but {target_type} has {len(public_names)}.")
        [method_name] = public_names

        async def _dispatch(*args: Any, **kwargs: Any) -> list[Any]:
            results: list[Any] = []
            for batch in self._extensions:
                batch_results = await gather(*(
                    getattr(extension, method_name)(*args, **kwargs)
                    for extension in batch
                    if isinstance(extension, target_type)
                ))
                results.extend(batch_results)
            return results

        return _dispatch
# Maps each extension type to the extension types it depends on or is ordered after, as built by
# build_extension_type_graph().
ExtensionTypeGraph = dict[type[Extension], set[type[Extension]]]
def build_extension_type_graph(extension_types: Iterable[type[Extension]]) -> ExtensionTypeGraph:
    """
    Build a dependency graph of the given extension types.

    Each key is an extension type. Its value is the set of extension types it depends on or must come after.
    Dependencies are followed transitively, so the graph may contain types that were not passed in. Optional
    ordering from ``comes_before()`` and ``comes_after()`` is only applied between types already in the graph.
    """
    # The types are iterated twice below. Materialize them first so that a one-shot iterable, such as a generator,
    # is not exhausted by the first pass, which would silently drop all optional ordering.
    extension_types = list(extension_types)
    extension_types_graph: ExtensionTypeGraph = defaultdict(set)
    # Add dependencies to the extension graph.
    for extension_type in extension_types:
        _extend_extension_type_graph(extension_types_graph, extension_type)
    # Now all dependencies have been collected, extend the graph with optional extension orders.
    for extension_type in extension_types:
        for before in extension_type.comes_before():
            if before in extension_types_graph:
                extension_types_graph[before].add(extension_type)
        for after in extension_type.comes_after():
            if after in extension_types_graph:
                extension_types_graph[extension_type].add(after)

    return extension_types_graph


def _extend_extension_type_graph(graph: ExtensionTypeGraph, extension_type: type[Extension]) -> None:
    """
    Add an extension type and, recursively, its not-yet-seen dependencies to the graph.
    """
    dependencies = extension_type.depends_on()
    # Ensure each extension type appears in the graph, even if they're isolated.
    graph.setdefault(extension_type, set())
    for dependency in dependencies:
        # Only recurse into unseen dependencies, which also stops the recursion on dependency cycles.
        seen_dependency = dependency in graph
        graph[extension_type].add(dependency)
        if not seen_dependency:
            _extend_extension_type_graph(graph, dependency)
def discover_extension_types() -> set[type[Extension]]:
    """
    Gather the available extension types.

    Each entry point in the ``betty.extensions`` group names an extension type, which is imported with import_any().
    """
    betty_entry_points: Sequence[EntryPoint]
    betty_entry_points = entry_points(  # type: ignore[assignment, unused-ignore]
        group='betty.extensions',  # type: ignore[call-arg, unused-ignore]
    )
    extension_types: set[type[Extension]] = set()
    for betty_entry_point in betty_entry_points:
        extension_types.add(import_any(betty_entry_point.value))
    return extension_types