"""
Provide rendering utilities using `Jinja2 <https://jinja.palletsprojects.com>`_.
"""
from __future__ import annotations
import datetime
from collections import defaultdict
from collections.abc import MutableMapping, Iterator, Sequence
from pathlib import Path
from threading import Lock
from typing import Callable, Any, cast, TypeVar
import aiofiles
from aiofiles import os as aiofiles_os
from jinja2 import (
Environment as Jinja2Environment,
select_autoescape,
FileSystemLoader,
pass_context,
)
from jinja2.runtime import StrictUndefined, Context, DebugUndefined
from betty.app import App
from betty.app.extension import Extension
from betty.html import CssProvider, JsProvider
from betty.jinja2.filter import FILTERS
from betty.jinja2.test import TESTS
from betty.job import Context as JobContext
from betty.locale import Date, Localizer, DEFAULT_LOCALIZER
from betty.model import Entity, get_entity_type
from betty.model.ancestry import Citation
from betty.project import ProjectConfiguration
from betty.render import Renderer
from betty.serde.dump import Dumpable, DictDump, VoidableDump, Void, Dump
T = TypeVar("T")
[docs]
def context_app(context: Context) -> App:
    """
    Return the app attached to the environment that ``context`` belongs to.

    The environment is treated as this module's ``Environment``; the cast only
    informs static type checkers and performs no runtime check.
    """
    environment = cast(Environment, context.environment)
    return environment.app
[docs]
def context_job_context(context: Context) -> JobContext | None:
    """
    Return the ``job_context`` variable from the Jinja2 context, if there is one.

    ``None`` is returned when the variable is missing or is not a ``JobContext``.
    """
    candidate = context.resolve_or_missing("job_context")
    if isinstance(candidate, JobContext):
        return candidate
    return None
[docs]
def context_localizer(context: Context) -> Localizer:
    """
    Return the ``localizer`` variable from the Jinja2 context.

    :raises RuntimeError: if the variable is missing or is not a ``Localizer``.
    """
    candidate = context.resolve_or_missing("localizer")
    if not isinstance(candidate, Localizer):
        raise RuntimeError(
            "No `localizer` context variable exists in this Jinja2 template."
        )
    return candidate
class _Citer:
    """
    Collect cited citations and hand out a stable, 1-based number for each.
    """

    __slots__ = "_lock", "_cited"

    def __init__(self):
        self._cited: list[Citation] = []
        self._lock = Lock()

    def __iter__(self) -> enumerate[Citation]:
        """
        Iterate over ``(number, citation)`` pairs, numbering from 1.
        """
        return enumerate(self._cited, start=1)

    def __len__(self) -> int:
        """
        Return how many distinct citations have been cited.
        """
        return len(self._cited)

    def cite(self, citation: Citation) -> int:
        """
        Register ``citation`` if it was not cited before, and return its number.

        Citing the same citation again returns the number it was first given.
        The check and the registration happen under the instance's lock.
        """
        with self._lock:
            cited = self._cited
            if citation in cited:
                return cited.index(citation) + 1
            # A newly appended citation sits at the end, so its 1-based
            # number equals the new length of the list.
            cited.append(citation)
            return len(cited)
class _Breadcrumb(Dumpable):
    """
    A single breadcrumb: a label and the URL it points to.
    """

    def __init__(self, label: str, url: str):
        self._label = label
        self._url = url

    def dump(self) -> DictDump[Dump]:
        """
        Dump this breadcrumb as a ``ListItem`` mapping.

        The label is stored under ``name`` and the URL under ``item``.
        """
        item: DictDump[Dump] = {"@type": "ListItem"}
        item["name"] = self._label
        item["item"] = self._url
        return item
class _Breadcrumbs(Dumpable):
    """
    An ordered trail of breadcrumbs that dumps to a ``BreadcrumbList``.
    """

    def __init__(self):
        self._breadcrumbs: list[_Breadcrumb] = []

    def append(self, label: str, url: str) -> None:
        """
        Add a breadcrumb with the given label and URL to the end of the trail.
        """
        crumb = _Breadcrumb(label, url)
        self._breadcrumbs.append(crumb)

    def dump(self) -> VoidableDump:
        """
        Dump the trail, or return ``Void`` if no breadcrumbs were added.

        Each list element is the breadcrumb's own dump plus a 1-based
        ``position`` key.
        """
        if self._breadcrumbs:
            elements = []
            for position, breadcrumb in enumerate(self._breadcrumbs, start=1):
                element = {"position": position}
                element.update(breadcrumb.dump())
                elements.append(element)
            return {
                "@context": "https://schema.org",
                "@type": "BreadcrumbList",
                "itemListElement": elements,
            }
        return Void
[docs]
class EntityContexts:
    """
    Map entity types to the entity that is the current context for that type.

    Instances are looked up by entity type (or entity type name) and yield
    ``None`` for any type that has no context entity.
    """

    def __init__(self, *entities: Entity) -> None:
        """
        :param entities: The context entities. Each is stored under its ``type``;
            if several entities share a type, the last one wins.
        """
        self._contexts: dict[type[Entity], Entity | None] = defaultdict(lambda: None)
        for entity in entities:
            self._contexts[entity.type] = entity

    def __getitem__(
        self, entity_type_or_type_name: type[Entity] | str
    ) -> Entity | None:
        """
        Return the context entity for the given entity type, or ``None``.

        :param entity_type_or_type_name: An entity type, or a name that is
            resolved to one through ``get_entity_type()``.
        """
        if isinstance(entity_type_or_type_name, str):
            entity_type = get_entity_type(entity_type_or_type_name)
        else:
            entity_type = entity_type_or_type_name
        # ``.get()`` bypasses the defaultdict factory, so looking up a type
        # without a context no longer inserts a ``None`` entry into the
        # mapping. Lookups are therefore read-only, which matters because an
        # instance may be shared (e.g. as a template global).
        return self._contexts.get(entity_type)

    def __call__(self, *entities: Entity) -> EntityContexts:
        """
        Return a new ``EntityContexts`` holding only the given entities.

        Contexts held by this instance are not carried over, and this instance
        is left unchanged.
        """
        return EntityContexts(*entities)
[docs]
class Jinja2Provider:
    """
    Base class for objects that contribute to Jinja2 rendering.

    Every hook returns a new, empty dictionary by default; subclasses override
    the ones they need.
    """

    @property
    def globals(self) -> dict[str, Any]:
        """
        The global variables to provide, keyed by name.
        """
        return dict()

    @property
    def filters(self) -> dict[str, Callable[..., Any]]:
        """
        The filters to provide, keyed by name.
        """
        return dict()

    @property
    def tests(self) -> dict[str, Callable[..., bool]]:
        """
        The tests to provide, keyed by name.
        """
        return dict()

    def new_context_vars(self) -> dict[str, Any]:
        """
        The variables to add to a newly created context, keyed by name.
        """
        return dict()
[docs]
class Environment(Jinja2Environment):
    """
    The Jinja2 environment for an app.

    Templates are loaded from the ``templates`` directory of each of the app's
    asset paths. The environment is asynchronous, auto-escapes ``html.j2``
    templates, and routes translations through the ``localizer`` context
    variable.
    """

    globals: dict[str, Any]
    filters: dict[str, Callable[..., Any]]
    tests: dict[str, Callable[..., bool]]

    def __init__(self, app: App):
        debug = app.project.configuration.debug
        super().__init__(
            loader=FileSystemLoader(
                [str(assets_path / "templates") for assets_path, _ in app.assets.paths]
            ),
            auto_reload=debug,
            enable_async=True,
            # Debug mode tolerates undefined variables; otherwise they are errors.
            undefined=DebugUndefined if debug else StrictUndefined,
            autoescape=select_autoescape(["html.j2"]),
            trim_blocks=True,
            lstrip_blocks=True,
            extensions=["jinja2.ext.do", "jinja2.ext.i18n"],
        )
        self._context_class: type[Context] | None = None
        self.app = app
        if debug:
            self.add_extension("jinja2.ext.debug")
        self._init_i18n()
        self._init_globals()
        self.filters.update(FILTERS)
        self.tests.update(TESTS)
        # Extensions are applied last, so their contributions replace built-in
        # globals, filters and tests of the same name.
        self._init_extensions()

    def _init_i18n(self) -> None:
        """
        Install the context-aware gettext callables and enable trimmed translations.
        """
        gettext_callables = {
            "gettext": self._gettext,
            "ngettext": self._ngettext,
            "pgettext": self._pgettext,
            "npgettext": self._npgettext,
        }
        self.install_gettext_callables(**gettext_callables)  # type: ignore[attr-defined]
        self.policies["ext.i18n.trimmed"] = True

    @property
    def context_class(self) -> type[Context]:  # type: ignore[override]
        """
        The context class, built on first access and reused afterwards.

        New contexts receive a ``citer`` and ``breadcrumbs`` variable, plus any
        variables from the app's ``Jinja2Provider`` extensions, unless the
        context's parent already defines a variable of that name.
        """
        cached = self._context_class
        if cached is not None:
            return cached

        providers: Sequence[Jinja2Provider & Extension] = [
            candidate
            for candidate in self.app.extensions.flatten()
            if isinstance(candidate, Jinja2Provider)
        ]

        class _Context(Context):
            def __init__(
                self,
                environment: Environment,
                parent: dict[str, Any],
                name: str | None,
                blocks: dict[str, Callable[[Context], Iterator[str]]],
                globals: MutableMapping[str, Any] | None = None,
            ):
                # Only create the per-context helpers when the parent does not
                # already carry them.
                for key, factory in (
                    ("citer", _Citer),
                    ("breadcrumbs", _Breadcrumbs),
                ):
                    if key not in parent:
                        parent[key] = factory()
                for provider in providers:
                    for key, value in provider.new_context_vars().items():
                        parent.setdefault(key, value)
                super().__init__(environment, parent, name, blocks, globals)

        self._context_class = _Context
        return _Context

    @pass_context
    def _gettext(self, context: Context, message: str) -> str:
        localizer = context_localizer(context)
        return localizer.gettext(message)

    @pass_context
    def _ngettext(
        self, context: Context, message_singular: str, message_plural: str, n: int
    ) -> str:
        localizer = context_localizer(context)
        return localizer.ngettext(message_singular, message_plural, n)

    @pass_context
    def _pgettext(self, context: Context, gettext_context: str, message: str) -> str:
        localizer = context_localizer(context)
        return localizer.pgettext(gettext_context, message)

    @pass_context
    def _npgettext(
        self,
        context: Context,
        gettext_context: str,
        message_singular: str,
        message_plural: str,
        n: int,
    ) -> str:
        localizer = context_localizer(context)
        return localizer.npgettext(
            gettext_context, message_singular, message_plural, n
        )

    def _init_globals(self) -> None:
        """
        Set the built-in global variables.

        ``today`` is computed once, when the environment is created.
        """
        self.globals["app"] = self.app
        now = datetime.date.today()
        self.globals["today"] = Date(now.year, now.month, now.day)
        # Ideally we would use the Dispatcher for this. However, it is asynchronous only.
        public_css_paths = []
        for extension in self.app.extensions.flatten():
            if isinstance(extension, CssProvider):
                public_css_paths.extend(extension.public_css_paths)
        self.globals["public_css_paths"] = public_css_paths
        public_js_paths = []
        for extension in self.app.extensions.flatten():
            if isinstance(extension, JsProvider):
                public_js_paths.extend(extension.public_js_paths)
        self.globals["public_js_paths"] = public_js_paths
        self.globals["entity_contexts"] = EntityContexts()
        self.globals["localizer"] = DEFAULT_LOCALIZER

    def _init_extensions(self) -> None:
        """
        Merge the globals, filters and tests of every ``Jinja2Provider`` extension.
        """
        providers = (
            candidate
            for candidate in self.app.extensions.flatten()
            if isinstance(candidate, Jinja2Provider)
        )
        for provider in providers:
            self.globals.update(provider.globals)
            self.filters.update(provider.filters)
            self.tests.update(provider.tests)
[docs]
class Jinja2Renderer(Renderer):
    """
    Render ``.j2`` files as Jinja2 templates.
    """

    def __init__(self, environment: Environment, configuration: ProjectConfiguration):
        """
        :param environment: The Jinja2 environment to render templates with.
        :param configuration: The project configuration. Its ``www_directory_path``
            and ``locales`` are used to compute the ``page_resource`` variable.
        """
        self._environment = environment
        self._configuration = configuration

    @property
    def file_extensions(self) -> set[str]:
        """
        The file extensions this renderer handles.
        """
        return {".j2"}

    def _page_resource(self, destination_file_path: Path) -> str | None:
        """
        Return the page resource for a destination file, or ``None``.

        The resource is the destination path relative to the www directory,
        joined with forward slashes. For multilingual projects a leading
        locale alias is stripped. ``None`` means the destination lies outside
        the www directory.
        """
        try:
            relative_path = destination_file_path.relative_to(
                self._configuration.www_directory_path
            )
        except ValueError:
            # The file is not inside the www directory, so it has no page resource.
            return None
        parts = relative_path.parts
        locales = self._configuration.locales
        if parts and locales.multilingual:
            if parts[0] in (locale.alias for locale in locales.values()):
                parts = parts[1:]
        return "/".join(parts)

    async def render_file(
        self,
        file_path: Path,
        *,
        job_context: JobContext | None = None,
        localizer: Localizer | None = None,
    ) -> Path:
        """
        Render a template file and replace it with the rendered result.

        The result is written next to the template, under the template's name
        without its final extension. The template file is then removed.

        :param file_path: The template file to render.
        :param job_context: Exposed to the template as ``job_context`` if given.
        :param localizer: Exposed to the template as ``localizer`` if given.
        :return: The path of the rendered file.
        """
        destination_file_path = file_path.parent / file_path.stem
        data: dict[str, Any] = {}
        if job_context is not None:
            data["job_context"] = job_context
        if localizer is not None:
            data["localizer"] = localizer
        page_resource = self._page_resource(destination_file_path)
        if page_resource is not None:
            data["page_resource"] = page_resource
        # Read the template as UTF-8 explicitly, matching the encoding the
        # output is written in, rather than the platform's default encoding.
        async with aiofiles.open(file_path, encoding="utf-8") as f:
            template_source = await f.read()
        template = self._environment.from_string(
            template_source, self._environment.globals
        )
        rendered = await template.render_async(data)
        async with aiofiles.open(destination_file_path, "w", encoding="utf-8") as f:
            await f.write(rendered)
        await aiofiles_os.remove(file_path)
        return destination_file_path