Coverage for fastapi_restly / views / _base.py: 89%
682 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-24 11:13 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-24 11:13 +0000
1"""
2This module provides a framework for class-based views on SQLAlchemy models.
4View class:
5This class is used to create a collection of endpoints that share an
6APIRouter (created when calling `include_view()`) and dependencies
7as class attributes. It uses the same mechanics as the class based
8view decorator from fastapi-utils.
9(https://fastapi-utils.davidmontague.xyz/user-guide/class-based-views/)
11AsyncRestView:
12Provides default reading and writing functions on the database using
13SQLAlchemy models.
14"""
16import dataclasses
17import functools
18import inspect
19import types
20import warnings
21from enum import Enum
22from math import ceil
23from typing import (
24 Annotated,
25 Any,
26 Callable,
27 ClassVar,
28 Generic,
29 Iterable,
30 Iterator,
31 Protocol,
32 Sequence,
33 Union,
34 cast,
35 get_args,
36 get_origin,
37 get_type_hints,
38 overload,
39)
41import fastapi
42import pydantic
43from fastapi import BackgroundTasks, Request, Response, WebSocket
44from fastapi.params import Depends as _DependsMarker
45from pydantic import create_model
46from sqlalchemy import inspect as sa_inspect
47from sqlalchemy.orm import DeclarativeBase, selectinload
48from starlette.datastructures import QueryParams
49from typing_extensions import TypeVar
51from .._exception_handlers import register_default_exception_handlers
52from ..db._globals import _fr_globals
53from ..exc import RestlyMisuseWarning
54from ..objects import snapshot as _object_snapshot
55from ..query import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE, create_list_params_schema
56from ..schemas import BaseSchema, IDRef, IDSchema
57from ..schemas._base import (
58 _reject_buried_markers,
59 create_model_with_optional_fields,
60 create_model_without_read_only_fields,
61 get_writable_inputs,
62 is_readonly_field,
63 is_writeonly_field,
64)
65from ..schemas._generator import auto_generate_schema_for_view
66from ._openapi import _register_for_resource_ref
# Type parameters of ``BaseRestView`` (its ``Generic[...]`` base). Each one
# carries a PEP 696 ``default`` (supported by ``typing_extensions.TypeVar``),
# so a subclass may omit trailing parameters.
# ORM model class handled by a view.
ModelT = TypeVar("ModelT", bound=DeclarativeBase, default=DeclarativeBase)
# Read/response schema.
SchemaT = TypeVar("SchemaT", bound=pydantic.BaseModel, default=BaseSchema)
# Request body schema for create.
CreateSchemaT = TypeVar(
    "CreateSchemaT", bound=pydantic.BaseModel, default=pydantic.BaseModel
)
# Request body schema for update.
UpdateSchemaT = TypeVar(
    "UpdateSchemaT", bound=pydantic.BaseModel, default=pydantic.BaseModel
)
# Primary-key type; unconstrained, defaults to ``int``.
IdT = TypeVar("IdT", default=int)
@dataclasses.dataclass(frozen=True)
class ListingResult(Generic[ModelT]):
    """Result returned by ``get_many`` before HTTP response formatting.

    Frozen dataclass: instances are immutable once built. The response layer
    serializes each entry of ``objects`` and reports ``total_count`` as the
    listing total.
    """

    # ORM objects to serialize into the response items.
    objects: Sequence[ModelT]
    # Total number of matching rows (used for ``total`` / ``total_pages``).
    total_count: int
    # Listing query parameters associated with this result, if any.
    # NOTE(review): not read anywhere in this chunk — confirm its consumers.
    query_params: Any = None
class ViewRoute(str, Enum):
    """Generated CRUD routes that can be referenced by view options.

    Each value is the name of the route-shell method that implements the
    route, so listing a member in ``exclude_routes`` drops that endpoint.
    As a ``str`` enum, members also compare equal to their plain string
    values.
    """

    GET_MANY = "get_many_endpoint"
    GET_ONE = "get_one_endpoint"
    CREATE = "create_endpoint"
    UPDATE = "update_endpoint"
    DELETE = "delete_endpoint"
class ResponseShape(str, Enum):
    """The wire shape a route shell asks :meth:`BaseRestView.to_response` to
    produce.

    This is deliberately separate from write-action names such as
    ``"publish"``: a route shell picks exactly one of these three response
    shapes, while custom actions remain an open string namespace.
    """

    SINGLE = "single"  # one serialized object
    LISTING = "listing"  # a ListingResult -> array / paginated envelope
    EMPTY = "empty"  # 204 No Content
class Action:
    """Canonical CRUD action names passed to ``authorize`` / ``before_commit``
    / ``after_commit``.

    This is a constants class rather than an ``Enum`` because custom actions
    and mixins add their own names. Referring to these constants instead of
    string literals turns a misspelled name into an ``AttributeError`` at
    import time.
    """

    GET_MANY = "get_many"
    GET_ONE = "get_one"
    CREATE = "create"
    UPDATE = "update"
    DELETE = "delete"
130def _accepts_init_kwarg(model_cls: type, attr_name: str) -> bool:
131 """Return True if attr_name can be passed as a keyword argument to model_cls.__init__.
133 Non-dataclass models (DeclarativeBase subclasses using mapped_column) accept all
134 kwargs. Dataclass-based models may have fields with init=False, in which case
135 passing the attribute to __init__ raises TypeError.
136 """
137 if not dataclasses.is_dataclass(model_cls):
138 return True
139 dc_fields = {f.name: f for f in dataclasses.fields(model_cls)}
140 return attr_name not in dc_fields or dc_fields[attr_name].init
143def _requires_init_kwarg(model_cls: type, attr_name: str) -> bool:
144 if not dataclasses.is_dataclass(model_cls):
145 return False
146 dc_fields = {f.name: f for f in dataclasses.fields(model_cls)}
147 field = dc_fields.get(attr_name)
148 if field is None or not field.init:
149 return False
150 return (
151 field.default is dataclasses.MISSING
152 and field.default_factory is dataclasses.MISSING
153 )
@dataclasses.dataclass
class _CreatePlan:
    """How to build a new ORM object from an input schema.

    Produced by ``build_create_plan``.
    """

    # Keyword arguments for ``model_cls(**kwargs)``.
    kwargs: dict[str, Any]
    # Attribute values to ``setattr`` on the object after construction, for
    # fields that cannot (or should not) go through ``__init__``.
    post_assignments: dict[str, Any]
class _HasID(Protocol):
    """Anything with an ``id`` attribute. By framework convention, primary
    keys are named ``id``; ``IDBase`` formalizes this but isn't required."""

    # Only used with ``typing.cast`` so type checkers accept ``ref.id`` on a
    # generic ORM instance; there is no runtime check.
    id: Any
def _has_model_attr(model_cls: type[DeclarativeBase], attr_name: str) -> bool:
    """Return True when looking up ``attr_name`` on ``model_cls`` succeeds.

    Equivalent to ``hasattr``: only ``AttributeError`` counts as "missing".
    """
    try:
        getattr(model_cls, attr_name)
    except AttributeError:
        return False
    return True
def _get_relationship_property(
    model_cls: type[DeclarativeBase], relation_name: str
) -> Any | None:
    """Return the mapped relationship ``relation_name`` of ``model_cls``.

    Returns None when ``model_cls`` is not an inspectable (mapped) class, or
    when it has no relationship with that name.
    """
    # ``raiseerr=False`` makes ``inspect()`` return None for an unmapped class
    # instead of raising ``NoInspectionAvailable``. This replaces a blanket
    # ``except Exception`` that would also have hidden unrelated failures.
    mapper = sa_inspect(model_cls, raiseerr=False)
    if mapper is None:
        return None
    return mapper.relationships.get(relation_name)
def _get_unambiguous_local_fk_name(
    model_cls: type[DeclarativeBase], relation_name: str
) -> str | None:
    """Return the key of the single local FK column behind a many-to-one relationship.

    Returns None when ``relation_name`` is not a relationship of ``model_cls``
    or its direction is not many-to-one.

    Raises:
        ValueError: the many-to-one relationship is backed by zero or by
            several local columns, so no single FK can be inferred.
    """
    rel = _get_relationship_property(model_cls, relation_name)
    if rel is None or getattr(rel.direction, "name", None) != "MANYTOONE":
        return None

    columns = [*rel.local_columns]
    if len(columns) == 1:
        return columns[0].key

    found = ", ".join(col.key for col in columns) or "<none>"
    raise ValueError(
        f"Cannot infer a single local FK for relationship "
        f"{model_cls.__name__}.{relation_name}; found {found}. "
        "Use an explicit custom handler for this relationship."
    )
def _is_reference_schema_field(
    schema_cls: type[pydantic.BaseModel], field_name: str
) -> bool:
    """Return True if ``schema_cls.field_name`` is declared as an IDSchema/IDRef reference.

    Unknown field names answer False. Optional wrappers (``X | None``) are
    looked through by the annotation check.
    """
    try:
        field_info = schema_cls.model_fields[field_name]
    except KeyError:
        return False
    return _is_idschema_reference_annotation(field_info.annotation)
213def _add_assignment(target: dict[str, Any], field_name: str | None, value: Any) -> None:
214 if field_name: 214 ↛ exitline 214 didn't return from function '_add_assignment' because the condition on line 214 was always true
215 target[field_name] = value
# Sentinel identity for a reference explicitly set to ``None``. It is distinct
# from the ``None`` that ``_reference_identity`` returns for "not comparable".
_EXPLICIT_NULL_REF = object()
def _reference_identity(value: Any) -> tuple[type[Any] | None, Any] | object | None:
    """Reduce a reference value to something that can be compared for equality.

    - ``None`` maps to the ``_EXPLICIT_NULL_REF`` sentinel.
    - An ORM instance maps to ``(its class, its id)``.
    - An ``IDSchema`` maps to ``(its SQL model annotation, its id)``.
    - Anything else maps to ``None``, meaning "not comparable".
    """
    if value is None:
        return _EXPLICIT_NULL_REF
    if isinstance(value, DeclarativeBase):
        return type(value), getattr(value, "id", None)
    if not isinstance(value, IDSchema):
        return None
    return value.get_sql_model_annotation(), value.id
def _reference_identity_detail(identity: object) -> Any:
    """Return the part of a reference identity worth showing in an error message.

    The explicit-null sentinel becomes ``None``. A ``(model, id)`` pair
    becomes its id. Anything else is returned unchanged.
    """
    if identity is _EXPLICIT_NULL_REF:
        return None
    is_pair = isinstance(identity, tuple) and len(identity) == 2
    return identity[1] if is_pair else identity
def validate_resolved_reference_consistency(
    model_cls: type[DeclarativeBase],
    schema_obj: pydantic.BaseModel,
    schema_cls: type[pydantic.BaseModel] | None = None,
    resolved: dict[str, Any] | None = None,
) -> None:
    """Check that an FK field and its relationship field name the same row.

    When a client sends both ``author_id`` and ``author`` (both declared as
    IDSchema/IDRef references, and ``author`` exists on the model), both
    values must identify the same row. Each side is taken from ``resolved``
    (the resolver's ``{field: ORM object}`` mapping) when present, otherwise
    from ``schema_obj``. Sides that cannot be compared are skipped.

    Raises:
        fastapi.HTTPException: 422 when the two references disagree.
    """
    schema_cls = schema_obj.__class__ if schema_cls is None else schema_cls
    lookup = resolved or {}

    def current_value(name: str) -> Any:
        # The resolved ORM object wins over the wire value still on the schema.
        return lookup.get(name, getattr(schema_obj, name, None))

    provided = schema_obj.model_fields_set
    for fk_field in provided:
        if not (
            fk_field.endswith("_id")
            and _is_reference_schema_field(schema_cls, fk_field)
        ):
            continue

        relation_field = fk_field[:-3]
        relation_is_reference = (
            relation_field in provided
            and _is_reference_schema_field(schema_cls, relation_field)
            and _has_model_attr(model_cls, relation_field)
        )
        if not relation_is_reference:
            continue

        fk_identity = _reference_identity(current_value(fk_field))
        relation_identity = _reference_identity(current_value(relation_field))
        if (
            fk_identity is None
            or relation_identity is None
            or fk_identity == relation_identity
        ):
            continue

        raise fastapi.HTTPException(
            status_code=422,
            detail=(
                f"Conflicting references for {fk_field} and {relation_field}: "
                f"{_reference_identity_detail(fk_identity)!r} != "
                f"{_reference_identity_detail(relation_identity)!r}"
            ),
        )
def iter_creatable_fields(
    schema_obj: pydantic.BaseModel, schema_cls: type[pydantic.BaseModel] | None = None
) -> Iterator[tuple[str, Any]]:
    """Yield ``(field_name, value)`` pairs used to construct a new ORM object.

    ``ReadOnly`` fields are skipped. Unlike :func:`get_writable_inputs`,
    fields the client did not send are included too, so schema-level
    defaults reach the new object. ``schema_cls`` defaults to the class of
    ``schema_obj``.
    """
    owner = schema_obj.__class__ if schema_cls is None else schema_cls
    for name, value in schema_obj:
        if not is_readonly_field(owner, name):
            yield name, value
def _add_resolved_reference_to_create_plan(
    plan: _CreatePlan,
    model_cls: type[DeclarativeBase],
    field_name: str,
    value: DeclarativeBase,
) -> None:
    """Record a resolved reference (an ORM object) in ``plan`` for object creation.

    A reference can be set through its FK column (``author_id``), its
    relationship (``author``), or both. The branches below put the required
    ``__init__`` arguments into ``plan.kwargs`` and defer the other side to
    ``plan.post_assignments`` so both end up consistent. Mutates ``plan``
    in place.

    Raises:
        ValueError: via ``_get_unambiguous_local_fk_name`` when a relationship
            field's local FK cannot be inferred.
    """
    ref = cast(_HasID, value)
    if field_name.endswith("_id"):
        # The field is the FK column; the relationship name is the same
        # field name without the ``_id`` suffix.
        fk_name = field_name
        relation_name = field_name[:-3]
        accepts_relation = _has_model_attr(
            model_cls, relation_name
        ) and _accepts_init_kwarg(model_cls, relation_name)

        # Dataclass model that requires both FK and relationship in __init__.
        if (
            _requires_init_kwarg(model_cls, fk_name)
            and accepts_relation
            and _requires_init_kwarg(model_cls, relation_name)
        ):
            plan.kwargs[fk_name] = ref.id
            plan.kwargs[relation_name] = value
            return

        # Only the relationship is required: pass it, set the FK afterwards.
        if accepts_relation and _requires_init_kwarg(model_cls, relation_name):
            plan.kwargs[relation_name] = value
            if _has_model_attr(model_cls, fk_name):
                plan.post_assignments[fk_name] = ref.id
            return

        # Prefer passing the FK to __init__ and setting the relationship later.
        if _accepts_init_kwarg(model_cls, fk_name):
            plan.kwargs[fk_name] = ref.id
            if _has_model_attr(model_cls, relation_name):
                plan.post_assignments[relation_name] = value
            return

        # FK is not an init argument but the relationship is.
        if accepts_relation:
            plan.kwargs[relation_name] = value
            plan.post_assignments[fk_name] = ref.id
            return

        # Neither fits __init__: assign whatever the model exposes afterwards.
        if _has_model_attr(model_cls, fk_name):
            plan.post_assignments[fk_name] = ref.id
        if _has_model_attr(model_cls, relation_name):
            plan.post_assignments[relation_name] = value
        return

    # The field is the relationship itself; infer its local FK column (may be
    # None when the relationship is not many-to-one).
    relation_name = field_name
    fk_name = _get_unambiguous_local_fk_name(model_cls, relation_name)

    if _has_model_attr(model_cls, relation_name) and _accepts_init_kwarg(
        model_cls, relation_name
    ):
        plan.kwargs[relation_name] = value
        # No-op when fk_name is None.
        _add_assignment(plan.post_assignments, fk_name, ref.id)
        return

    if fk_name and _accepts_init_kwarg(model_cls, fk_name):
        plan.kwargs[fk_name] = ref.id
        if _has_model_attr(model_cls, relation_name):
            plan.post_assignments[relation_name] = value
        return

    # Fallback: post-construction assignment only.
    if _has_model_attr(model_cls, relation_name):
        plan.post_assignments[relation_name] = value
    _add_assignment(plan.post_assignments, fk_name, ref.id)
def build_create_plan(
    model_cls: type[DeclarativeBase],
    schema_obj: pydantic.BaseModel,
    schema_cls: type[pydantic.BaseModel] | None = None,
    resolved: dict[str, Any] | None = None,
) -> _CreatePlan:
    """Split ``schema_obj`` fields into ``model_cls(**kwargs)`` arguments and
    post-construction assignments.

    Shared by the sync and async ``make_new_object``. ``resolved`` is the
    ``{field: object_or_list}`` mapping from the IDSchema resolver. A field
    present there uses the resolved ORM value instead of the wire-shaped
    ``IDRef`` still on ``schema_obj``.
    """
    schema_cls = schema_obj.__class__ if schema_cls is None else schema_cls
    lookup = resolved or {}
    plan = _CreatePlan(kwargs={}, post_assignments={})

    def place(name: str, payload: Any) -> None:
        # Constructor kwarg if allowed; otherwise set it after construction
        # when the model has the attribute; otherwise drop it.
        if _accepts_init_kwarg(model_cls, name):
            plan.kwargs[name] = payload
        elif _has_model_attr(model_cls, name):
            plan.post_assignments[name] = payload

    for field_name, value in iter_creatable_fields(schema_obj, schema_cls):
        value = lookup.get(field_name, value)
        if isinstance(value, IDSchema) and field_name.endswith("_id"):
            # Unresolved reference on an FK column: store just the id.
            place(field_name, value.id)
        elif isinstance(value, DeclarativeBase) and _is_reference_schema_field(
            schema_cls, field_name
        ):
            _add_resolved_reference_to_create_plan(plan, model_cls, field_name, value)
        else:
            place(field_name, value)
    return plan
def build_create_kwargs(
    model_cls: type[DeclarativeBase],
    schema_obj: pydantic.BaseModel,
    schema_cls: type[pydantic.BaseModel] | None = None,
    resolved: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Return only the constructor kwargs computed by :func:`build_create_plan`.

    Post-construction assignments are discarded. Use ``build_create_plan``
    directly when they are needed.
    """
    plan = build_create_plan(model_cls, schema_obj, schema_cls, resolved)
    return plan.kwargs
def apply_create_assignments(obj: DeclarativeBase, assignments: dict[str, Any]) -> None:
    """Set every ``assignments`` entry as an attribute of ``obj``, in dict order."""
    for name in assignments:
        setattr(obj, name, assignments[name])
def _apply_resolved_reference_update(
    obj: DeclarativeBase, field_name: str, value: DeclarativeBase
) -> None:
    """Write a resolved reference onto ``obj``, keeping FK and relationship in sync.

    For an ``*_id`` field, the FK gets ``value.id`` and the matching
    relationship attribute (if any) gets ``value``. For a relationship field,
    the relationship (if present) gets ``value`` and its inferred local FK
    (if any) gets ``value.id``.
    """
    if not field_name.endswith("_id"):
        if hasattr(obj, field_name):
            setattr(obj, field_name, value)
        fk_name = _get_unambiguous_local_fk_name(type(obj), field_name)
        if fk_name:
            setattr(obj, fk_name, cast(_HasID, value).id)
        return

    setattr(obj, field_name, cast(_HasID, value).id)
    relation_name = field_name.removesuffix("_id")
    if hasattr(obj, relation_name):
        setattr(obj, relation_name, value)
def apply_update_to_object(
    obj: DeclarativeBase,
    schema_obj: pydantic.BaseModel,
    schema_cls: type[pydantic.BaseModel] | None = None,
    resolved: dict[str, Any] | None = None,
) -> None:
    """Copy the writable inputs of ``schema_obj`` onto ``obj`` in place.

    Shared by the sync and async ``update_object``. ``resolved`` is the
    ``{field: object_or_list}`` mapping from the IDSchema resolver. A field
    present there uses the resolved ORM value instead of the wire-shaped
    ``IDRef`` still on ``schema_obj``.
    """
    lookup = resolved or {}
    reference_schema = schema_cls or schema_obj.__class__
    writable = get_writable_inputs(schema_obj, schema_cls)
    for field_name, raw_value in writable.items():
        value = lookup.get(field_name, raw_value)
        if isinstance(value, IDSchema) and field_name.endswith("_id"):
            # Unresolved reference on an FK column: store just the id.
            setattr(obj, field_name, value.id)
        elif isinstance(value, DeclarativeBase) and _is_reference_schema_field(
            reference_schema, field_name
        ):
            _apply_resolved_reference_update(obj, field_name, value)
        else:
            setattr(obj, field_name, value)
481def _unwrap_optional_annotation(annotation: Any) -> Any:
482 origin = get_origin(annotation)
483 if origin not in (types.UnionType, Union, None):
484 return annotation
486 if origin is None:
487 return annotation
489 non_none_args = [arg for arg in get_args(annotation) if arg is not type(None)]
490 if len(non_none_args) == 1: 490 ↛ 492line 490 didn't jump to line 492 because the condition on line 490 was always true
491 return non_none_args[0]
492 return annotation
def _is_idschema_reference_annotation(annotation: Any) -> bool:
    """Return True for ``IDSchema`` / ``IDRef`` or a parametrization of either.

    ``Optional`` wrappers are removed first. A class counts only if it
    subclasses ``IDSchema`` and its pydantic generic origin is ``IDSchema``
    or ``IDRef``.
    """
    target = _unwrap_optional_annotation(annotation)
    if target in (IDSchema, IDRef):
        return True
    if not inspect.isclass(target):
        return False
    try:
        is_idschema_subclass = issubclass(target, IDSchema)
    except TypeError:
        return False
    if not is_idschema_subclass:
        return False
    generic_meta = getattr(target, "__pydantic_generic_metadata__", {})
    return generic_meta.get("origin") in (IDSchema, IDRef)
def _serialize_idschema_value(annotation: Any, value: Any) -> Any:
    """Shape a reference value for a response field declared as ``annotation``.

    ``None`` stays ``None``. Otherwise the id is taken from ``value.id`` (or
    ``value`` itself when it has no ``id``). The result is the bare id for
    ``IDRef`` types, an unvalidated instance for ``IDSchema`` types, or a
    ``{"id": ...}`` dict as a last resort.
    """
    if value is None:
        return None
    id_value = getattr(value, "id", value)
    if inspect.isclass(annotation):
        if issubclass(annotation, IDRef):
            return id_value
        if issubclass(annotation, IDSchema):
            return annotation.model_construct(id=id_value)
    return {"id": id_value}
def _serialize_response_value(annotation: Any, value: Any) -> Any:
    """Convert reference-typed values to their wire shape.

    Handles a single reference field and a ``list[...]`` of references. Any
    other annotation, or a list value that is not a ``Sequence``, is returned
    unchanged.
    """
    target = _unwrap_optional_annotation(annotation)
    if _is_idschema_reference_annotation(target):
        return _serialize_idschema_value(target, value)

    if get_origin(target) is not list:
        return value

    type_args = get_args(target)
    item_annotation = type_args[0] if type_args else Any
    if _is_idschema_reference_annotation(item_annotation) and isinstance(
        value, Sequence
    ):
        return [_serialize_idschema_value(item_annotation, item) for item in value]
    return value
def _get_nested_schema_annotation(annotation: Any) -> type[pydantic.BaseModel] | None:
    """Find the pydantic model class inside ``annotation``, if any.

    Looks through ``Optional`` and (recursively) ``list[...]``. Returns None
    when no model class is found.
    """
    target = _unwrap_optional_annotation(annotation)

    try:
        is_model = inspect.isclass(target) and issubclass(target, pydantic.BaseModel)
    except TypeError:
        is_model = False
    if is_model:
        return target

    if get_origin(target) is not list:
        return None
    item_args = get_args(target)
    return _get_nested_schema_annotation(item_args[0]) if item_args else None
class _OmitWriteOnlyMixin(pydantic.BaseModel):
    """Mixin that removes WriteOnly fields from any subclass when it is created.

    ``_create_response_validation_schema`` combines it with a schema to get a
    response-only variant, so validating an ORM-derived payload does not
    require values for write-only fields.
    """

    @classmethod
    def __pydantic_init_subclass__(cls, **kwargs: Any) -> None:
        # Pydantic hook that runs after a subclass has been fully built.
        super().__pydantic_init_subclass__(**kwargs)

        # Collect the names first so the mapping is not mutated while it is
        # being iterated.
        writeonly_fields = [
            name for name in cls.model_fields if is_writeonly_field(cls, name)
        ]
        # NOTE(review): assumes ``cls.model_fields`` is this subclass's own
        # mutable field mapping, so deleting here changes the fields the
        # rebuilt model validates — re-check on pydantic upgrades.
        for name in writeonly_fields:
            del cls.model_fields[name]

        # Force pydantic to regenerate the model schema from the trimmed fields.
        cls.model_rebuild(force=True)
@functools.cache
def _create_response_validation_schema(
    schema_cls: type[pydantic.BaseModel],
) -> type[pydantic.BaseModel]:
    """Return a variant of ``schema_cls`` suitable for validating responses.

    Schemas without WriteOnly fields are returned as-is. Otherwise a
    ``Response<Name>`` subclass mixing in ``_OmitWriteOnlyMixin`` is built.
    The result is cached per schema class, so each variant is built once.
    """
    has_writeonly = any(
        is_writeonly_field(schema_cls, name) for name in schema_cls.model_fields
    )
    if not has_writeonly:
        return schema_cls

    namespace = {
        "__module__": schema_cls.__module__,
        "__doc__": (schema_cls.__doc__ or "")
        + "\nWrite-only fields have been removed for response validation.",
    }
    return type(
        f"Response{schema_cls.__name__}",
        (_OmitWriteOnlyMixin, schema_cls),
        namespace,
    )
def _build_relationship_loader_options(
    model_cls: type[DeclarativeBase],
    schema_cls: type[pydantic.BaseModel],
    seen: set[tuple[type[DeclarativeBase], type[pydantic.BaseModel]]] | None = None,
) -> list[Any]:
    """Build ``selectinload`` options for every relationship ``schema_cls`` exposes.

    Recurses into nested schemas to chain child loaders. ``seen`` holds the
    ``(model, schema)`` pairs on the current path. Each recursion receives a
    new set rather than a mutated one, so cycles stop while sibling branches
    stay independent.
    """
    visited = set() if seen is None else seen
    key = (model_cls, schema_cls)
    if key in visited:
        return []
    path_seen = visited | {key}

    mapper = sa_inspect(model_cls)
    loaders: list[Any] = []
    for field_name, field_info in schema_cls.model_fields.items():
        if field_name not in mapper.relationships:
            continue

        rel = mapper.relationships[field_name]
        loader = selectinload(getattr(model_cls, field_name))
        nested = _get_nested_schema_annotation(field_info.annotation)
        if nested is not None:
            children = _build_relationship_loader_options(
                rel.mapper.class_, nested, path_seen
            )
            if children:
                loader = loader.options(*children)
        loaders.append(loader)
    return loaders
class View:
    """
    Class-based view primitive for FastAPI.

    Group related endpoints on a class, share dependencies and metadata via
    class attributes, and let subclasses override individual handlers. Routes
    are bound at :func:`include_view` time, not at class-definition time, so
    subclassing works the way Python developers expect: override a method on
    a subclass and the override is what runs.

    Most users will subclass :class:`RestView` or :class:`AsyncRestView`,
    which extend ``View`` with CRUD scaffolding. Use ``View`` directly for
    grouped non-CRUD endpoints (auth flows, custom RPC routes, etc.).
    """

    # Class-level settings shared by every endpoint of the view.
    # NOTE(review): how these reach the router (presumably APIRouter
    # arguments) is decided by the registration code, which is not shown
    # here — confirm there.
    # Declared without a default, so presumably every concrete view sets it.
    prefix: ClassVar[str]
    tags: ClassVar[Any] = None
    dependencies: ClassVar[Any] = None
    # Shared class-level dict; subclasses should rebind it, not mutate it.
    responses: ClassVar[dict[int | str, dict[str, Any]]] = {}

    @classmethod
    def before_include_view(cls) -> None:
        """Extension hook for subclasses; does nothing by default.

        NOTE(review): presumably called while the view is being included into
        a router — confirm against ``_init_view_cls_and_add_to_router``.
        """
        pass
# A View *class* (not an instance); lets include_view return the exact class
# type it was given.
V = TypeVar("V", bound=type[View])
@overload
def include_view(
    parent_router: fastapi.APIRouter | fastapi.FastAPI, view_cls: V
) -> V: ...
@overload
def include_view(
    parent_router: fastapi.APIRouter | fastapi.FastAPI,
) -> Callable[[V], V]: ...


def include_view(
    parent_router: fastapi.APIRouter | fastapi.FastAPI, view_cls: V | None = None
) -> V | Callable[[V], V]:
    """
    Register a View class's routes on a FastAPI app or APIRouter.

    The recommended form is a direct call from the code that composes your
    app or routers::

        include_view(app, MyView)

    In small apps it also works as a class decorator::

        @include_view(app)
        class MyView(AsyncRestView):
            ...

    Both forms hand back the same class object that was passed in.
    """

    def class_decorator(view_cls: V) -> V:
        _init_view_cls_and_add_to_router(view_cls, parent_router)
        return view_cls

    if view_cls is None:
        return class_decorator
    return class_decorator(view_cls)
def route(path: str, **api_route_kwargs: Any) -> Callable[..., Any]:
    """Decorator to mark a View method as an endpoint.

    ``path`` and ``api_route_kwargs`` are stored on the decorated function as
    ``_api_route_args`` and later passed to ``APIRouter.add_api_route()`` when
    the view is registered with ``include_view()``. See for example:
    https://fastapi.tiangolo.com/reference/apirouter/#fastapi.APIRouter.get

    The function itself is returned (no wrapper), so it stays a plain method
    on the view class.
    """

    def store_args_decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        # Routes are bound at include_view() time, so only record the args here.
        func._api_route_args = (path, api_route_kwargs)  # type: ignore[attr-defined]
        return func

    return store_args_decorator


def _route_with_defaults(
    path: str, method: str, status_code: int, api_route_kwargs: dict[str, Any]
) -> Callable[..., Any]:
    """Fill in per-verb defaults for ``methods`` / ``status_code`` and delegate to :func:`route`.

    Values the caller passed explicitly take precedence over the defaults.
    """
    api_route_kwargs.setdefault("methods", [method])
    api_route_kwargs.setdefault("status_code", status_code)
    return route(path, **api_route_kwargs)


def get(path: str, **api_route_kwargs: Any) -> Callable[..., Any]:
    """Decorator to mark a View method as a GET endpoint.

    Equivalent to::

        @route(path, methods=["GET"], status_code=200, ... )
    """
    return _route_with_defaults(path, "GET", 200, api_route_kwargs)


def post(path: str, **api_route_kwargs: Any) -> Callable[..., Any]:
    """Decorator to mark a View method as a POST endpoint.

    Equivalent to::

        @route(path, methods=["POST"], status_code=201, ... )
    """
    return _route_with_defaults(path, "POST", 201, api_route_kwargs)


def put(path: str, **api_route_kwargs: Any) -> Callable[..., Any]:
    """Decorator to mark a View method as a PUT endpoint.

    Equivalent to::

        @route(path, methods=["PUT"], status_code=200, ... )
    """
    return _route_with_defaults(path, "PUT", 200, api_route_kwargs)


def patch(path: str, **api_route_kwargs: Any) -> Callable[..., Any]:
    """Decorator to mark a View method as a PATCH endpoint.

    Equivalent to::

        @route(path, methods=["PATCH"], status_code=200, ... )
    """
    return _route_with_defaults(path, "PATCH", 200, api_route_kwargs)


def delete(path: str, **api_route_kwargs: Any) -> Callable[..., Any]:
    """Decorator to mark a View method as a DELETE endpoint.

    Equivalent to::

        @route(path, methods=["DELETE"], status_code=204, ... )
    """
    return _route_with_defaults(path, "DELETE", 204, api_route_kwargs)
766class BaseRestView(View, Generic[ModelT, SchemaT, CreateSchemaT, UpdateSchemaT, IdT]):
767 """
768 Base class for RestView implementations.
770 This class contains the common functionality shared between AsyncRestView
771 and RestView, including schema definitions, model configuration, and
772 common CRUD operation logic.
773 """
775 responses: ClassVar[dict[int | str, dict[str, Any]]] = {
776 404: {"description": "Not found"}
777 }
779 schema: ClassVar[type[pydantic.BaseModel]]
780 # If 'schema_create' is not defined it will be created from 'schema'
781 # using `create_model_without_read_only_fields()`.
782 schema_create: ClassVar[type[pydantic.BaseModel]]
783 schema_update: ClassVar[type[pydantic.BaseModel]]
784 model: ClassVar[type[DeclarativeBase]]
785 id_type: ClassVar[type[Any]] = int
786 include_pagination_metadata: ClassVar[bool] = (
787 False # Set True to include count/total in list responses
788 )
789 exclude_routes: ClassVar[Iterable[str | ViewRoute]] = ()
790 #: Extra query-parameter keys to allow on the listing endpoint beyond those
791 #: derived from the response schema. Use this when a view consumes a custom
792 #: parameter (e.g. ``?include_deleted=true`` on a soft-delete mixin). Without
793 #: this, the strict unknown-key guard rejects the request with 422.
794 extra_query_params: ClassVar[Iterable[str]] = ()
795 #: Default ``page_size`` for list endpoints. ``None`` means "no implicit
796 #: cap" (the framework default). Override per-view.
797 default_page_size: ClassVar[int | None] = DEFAULT_PAGE_SIZE
798 #: Maximum ``page_size`` accepted on list endpoints. Above this returns 422.
799 max_page_size: ClassVar[int] = MAX_PAGE_SIZE
800 listing_param_schema: ClassVar[type[pydantic.BaseModel]]
801 pagination_response_schema: ClassVar[type[pydantic.BaseModel]]
803 request: fastapi.Request
805 def get_relationship_loader_options(self) -> list[Any]:
806 return _build_relationship_loader_options(self.model, self.schema)
808 def _reject_unknown_query_params(self) -> None:
809 """Reject any query-string key that isn't part of ``listing_param_schema``.
811 FastAPI flattens ``Annotated[listing_param_schema, Query()]`` into named
812 query parameters; unknown keys are silently ignored at that layer,
813 which would let typoed filters or unsupported operators (e.g.
814 ``active__gte=true`` on a boolean column where the schema does not
815 emit a range operator) widen the result set without telling the
816 caller. We treat unknown keys as a validation error instead, mirroring
817 FastAPI's 422 envelope shape so the response is consistent with
818 bound-violation errors.
820 No-op when there's no live request (programmatic ``view.listing(...)``
821 calls outside an HTTP request) — there's no URL surface to validate
822 and the in-process caller is responsible for what they pass.
823 """
824 request = getattr(self, "request", None)
825 if request is None:
826 return
827 listing_schema = getattr(self, "listing_param_schema", None)
828 if listing_schema is None: 828 ↛ 829line 828 didn't jump to line 829 because the condition on line 828 was never true
829 return
830 allowed = set(listing_schema.model_fields) | set(self.extra_query_params)
831 sent = set(request.query_params.keys())
832 unknown = sent - allowed
833 if not unknown:
834 return
835 detail = [
836 {
837 "type": "extra_forbidden",
838 "loc": ["query", key],
839 "msg": f"Unknown query parameter {key!r}",
840 "input": request.query_params.get(key),
841 }
842 for key in sorted(unknown)
843 ]
844 raise fastapi.HTTPException(status_code=422, detail=detail)
846 def to_response_schema(self, obj: ModelT | SchemaT) -> SchemaT:
847 """Serialize an ORM object to the configured response schema.
849 WriteOnly fields are stripped from responses by ``exclude=True`` on the
850 marker itself (recursively, at serialization time), so a pre-built schema
851 instance is safe to return as-is. The ORM path below still validates
852 through the WriteOnly-omitting response schema, so a read schema that
853 declares a WriteOnly field the ORM object doesn't carry (e.g. ``password``
854 backed by a ``password_hash`` column) doesn't fail response validation.
855 """
856 if isinstance(obj, self.schema):
857 return cast(SchemaT, obj)
859 # Build a payload using schema field names. Alias rendering happens
860 # when FastAPI serializes the response model.
861 payload: dict[str, Any] = {}
862 for field_name, field_info in self.schema.model_fields.items():
863 if is_writeonly_field(self.schema, field_name):
864 continue
865 if hasattr(obj, field_name): 865 ↛ 870line 865 didn't jump to line 870 because the condition on line 865 was always true
866 value = getattr(obj, field_name)
867 payload[field_name] = _serialize_response_value(
868 field_info.annotation, value
869 )
870 elif field_info.alias and hasattr(obj, field_info.alias):
871 payload[field_name] = getattr(obj, field_info.alias)
873 response_schema = _create_response_validation_schema(self.schema)
874 return cast(
875 SchemaT,
876 response_schema.model_validate(payload, by_alias=False, by_name=True),
877 )
@staticmethod
def _to_query_params(query_params: Any) -> QueryParams:
    """Normalize ``query_params`` into a Starlette ``QueryParams`` mapping.

    Accepted inputs:

    * ``QueryParams`` -- returned unchanged.
    * a pydantic model -- dumped in JSON mode by alias, ``None`` fields
      dropped, remaining values ``str()``-ed.
    * a ``dict`` -- ``None`` values dropped, remaining values ``str()``-ed.
    * anything else -- handed to the ``QueryParams`` constructor as-is.

    ``None`` is dropped from dicts to match the pydantic path
    (``exclude_none=True``). Otherwise ``{"page_size": None}`` would turn
    into the string ``"None"`` and break callers that ``int()`` the value,
    such as the pagination helper.
    """
    if isinstance(query_params, QueryParams):
        return query_params
    if isinstance(query_params, pydantic.BaseModel):
        dumped = query_params.model_dump(
            exclude_none=True, by_alias=True, mode="json"
        )
        return QueryParams({k: str(v) for k, v in dumped.items()})
    if isinstance(query_params, dict):
        return QueryParams(
            {k: str(v) for k, v in query_params.items() if v is not None}
        )
    return QueryParams(query_params)
@classmethod
def _create_pagination_response_schema(
    cls, response_schema: type[pydantic.BaseModel]
) -> type[pydantic.BaseModel]:
    """Build the ``<ViewName>PaginatedResponse`` envelope model.

    ``items`` (a sequence of ``response_schema``) and ``total`` are required.
    ``page``, ``page_size`` and ``total_pages`` are optional ints that default
    to ``None``, which is what the listing helper leaves them as when no page
    size applies.
    """
    optional_int = (int | None, None)
    model_name = f"{cls.__name__}PaginatedResponse"
    return create_model(
        model_name,
        items=(Sequence[response_schema], ...),
        total=(int, ...),
        page=optional_int,
        page_size=optional_int,
        total_pages=optional_int,
    )
def to_paginated_listing_response(
    self, query_params: Any, listing_result: ListingResult[Any]
) -> dict[str, Any]:
    """Wrap a listing result in the paginated envelope payload.

    ``items`` (each object run through ``to_response_schema``) and ``total``
    are always filled in. ``page``, ``page_size`` and ``total_pages`` stay
    ``None`` when the client sent no ``page_size`` and the view has no
    ``default_page_size``. Otherwise:

    * ``page`` comes from the request and defaults to 1;
    * ``page_size`` is the requested size or the view default;
    * ``total_pages`` is ``ceil(total / page_size)``, or 0 for a
      non-positive page size.
    """
    params = self._to_query_params(query_params)
    total = listing_result.total_count
    payload: dict[str, Any] = {
        "items": [self.to_response_schema(obj) for obj in listing_result.objects],
        "total": total,
        "page": None,
        "page_size": None,
        "total_pages": None,
    }

    requested_size = params.get("page_size")
    no_size_applies = requested_size is None and self.default_page_size is None
    if no_size_applies:
        # No implicit cap and no client request: omit page metadata.
        return payload

    page = int(params.get("page", "1"))
    if requested_size is None:
        # default_page_size is known to be non-None at this point.
        page_size = cast(int, self.default_page_size)
    else:
        page_size = int(requested_size)

    payload.update(
        page=page,
        page_size=page_size,
        total_pages=ceil(total / page_size) if page_size > 0 else 0,
    )
    return payload
def to_listing_response(
    self, query_params: Any, listing_result: ListingResult[ModelT]
) -> Any:
    """Render a listing for the wire.

    When ``include_pagination_metadata`` is enabled, this returns the
    paginated envelope payload. Otherwise it returns a plain list with every
    object serialized through ``to_response_schema``.
    """
    if self.include_pagination_metadata:
        return self.to_paginated_listing_response(query_params, listing_result)
    return [self.to_response_schema(obj) for obj in listing_result.objects]
def to_response(
    self, obj_or_list: Any, shape: ResponseShape = ResponseShape.SINGLE
) -> Any:
    """Route-shell response boundary.

    ``shape`` picks the wire form, not the write-action name:

    * ``LISTING`` -- ``obj_or_list`` is a listing result; its
      ``query_params`` drive the listing response.
    * ``EMPTY`` -- a bare 204 ``Response``; ``obj_or_list`` is ignored.
    * anything else (default ``SINGLE``) -- one object via
      ``to_response_schema``.

    Override this for envelopes or shape-wide status behavior. Per-endpoint
    projections belong in the route shell instead.
    """
    if shape is ResponseShape.LISTING:
        return self.to_listing_response(obj_or_list.query_params, obj_or_list)
    if shape is ResponseShape.EMPTY:
        return fastapi.Response(status_code=204)
    return self.to_response_schema(obj_or_list)
def snapshot(self, obj: Any) -> dict[str, Any]:
    """Capture ``obj``'s already-loaded column values as a frozen mapping.

    The result is passed to ``before_commit`` / ``after_commit`` as ``old``
    so hooks can detect what changed. Override to capture something else,
    for example a relationship's prior state. The default simply delegates
    to :func:`fastapi_restly.snapshot`.
    """
    captured = _object_snapshot(obj)
    return captured
@classmethod
def before_include_view(cls):
    """
    Apply type annotations needed for FastAPI, before creating an APIRouter from
    this view and registering it.

    In order, this:

    1. derives every schema the class does not define in its own ``__dict__``
       (``schema``, ``listing_param_schema``, ``schema_create``,
       ``schema_update``);
    2. builds ``pagination_response_schema`` when
       ``include_pagination_metadata`` is set;
    3. rewrites the signatures of whichever ``*_endpoint`` route shells exist
       on the class;
    4. strips the routes named in ``exclude_routes``.

    This function can be overridden to further tweak the endpoints before they
    are added to FastAPI.

    Raises:
        ValueError: if ``model`` is missing but is needed to auto-generate
            ``schema`` or ``listing_param_schema``.
        AttributeError: (from ``_exclude_routes``) if an ``exclude_routes``
            entry is not a route anywhere in the class lineage.
    """
    # Auto-generate schema if none is provided. Each of these guards
    # checks ``cls.__dict__`` — not ``hasattr`` — so a subclass that
    # changes ``schema``/``default_page_size``/``max_page_size`` regenerates
    # the derived schemas instead of silently inheriting the parent's.
    # NOTE(review): the same rule means a subclass that inherits an explicit
    # ``schema`` from its parent without redeclaring it gets a freshly
    # auto-generated schema from ``model`` -- confirm that is intended.
    if "schema" not in cls.__dict__:
        if not hasattr(cls, "model"):
            raise ValueError(
                f"'{cls.__name__}.model' must be specified to auto-generate schema"
            )
        cls.schema = cast(
            type[SchemaT], auto_generate_schema_for_view(cls, cls.model)
        )

    if "listing_param_schema" not in cls.__dict__:
        if not hasattr(cls, "model"):
            raise ValueError(
                f"'{cls.__name__}.model' must be specified: it is needed to "
                "generate list query parameters."
            )
        cls.listing_param_schema = create_list_params_schema(
            cls.schema,
            cls.model,
            default_page_size=cls.default_page_size,
            max_page_size=cls.max_page_size,
        )
    # Create/update input schemas are derived from the (possibly
    # just-generated) read schema: ReadOnly fields removed / all fields optional.
    if "schema_create" not in cls.__dict__:
        cls.schema_create = cast(
            type[CreateSchemaT], create_model_without_read_only_fields(cls.schema)
        )
    if "schema_update" not in cls.__dict__:
        cls.schema_update = cast(
            type[UpdateSchemaT], create_model_with_optional_fields(cls.schema)
        )

    # WriteOnly fields are excluded from responses by ``exclude=True`` on the
    # marker (recursively, and from the OpenAPI response schema -- FastAPI's
    # serialization-mode schema drops them), so the response_model can be the
    # full schema.
    response_schema = cls.schema

    # Listing returns a bare sequence unless pagination metadata is enabled,
    # in which case it returns the paginated envelope model.
    listing_response_annotation: Any = Sequence[response_schema]
    if cls.include_pagination_metadata:
        cls.pagination_response_schema = cls._create_pagination_response_schema(
            response_schema
        )
        listing_response_annotation = cls.pagination_response_schema

    # The ``*_endpoint`` route shells are defined on AsyncRestView/RestView
    # subclasses and may be excluded by ``exclude_routes``, so they aren't
    # visible on BaseRestView. ``getattr`` keeps pyright happy without
    # falsely advertising them on the base class. Only shells that exist are
    # annotated.
    if (ep := getattr(cls, "get_many_endpoint", None)) is not None:
        _annotate(
            ep,
            return_annotation=listing_response_annotation,
            query_params=Annotated[cls.listing_param_schema, fastapi.Query()],
        )
    if (ep := getattr(cls, "get_one_endpoint", None)) is not None:
        _annotate(ep, return_annotation=response_schema, id=cls.id_type)
    if (ep := getattr(cls, "create_endpoint", None)) is not None:
        _annotate(
            ep, return_annotation=response_schema, schema_obj=cls.schema_create
        )
    if (ep := getattr(cls, "update_endpoint", None)) is not None:
        _annotate(
            ep,
            return_annotation=response_schema,
            schema_obj=cls.schema_update,
            id=cls.id_type,
        )
    if (ep := getattr(cls, "delete_endpoint", None)) is not None:
        _annotate(ep, return_annotation=fastapi.Response, id=cls.id_type)
    _exclude_routes(cls)
def _exclude_routes(cls: type[BaseRestView[Any, Any, Any, Any, Any]]):
    """Un-route every endpoint named in ``cls.exclude_routes``.

    ``@route`` records its arguments on the method as ``_api_route_args``.
    The router is built only from methods that carry that attribute, so
    deleting it removes the route. Entries may be ``ViewRoute`` members or
    plain method names.

    An entry that is not a live route on ``cls`` is tolerated when the name
    is a route somewhere in the lineage. This happens when a subclass
    inherits ``exclude_routes`` from a parent that already stripped the
    route, so the subclass never received a routable copy.

    Raises:
        AttributeError: if an entry is no route anywhere in the lineage
            (a typo, or the business-verb name instead of the ``*_endpoint``
            route name).
    """
    for excluded in cls.exclude_routes:
        if isinstance(excluded, ViewRoute):
            method_name = excluded.value
        else:
            method_name = excluded
        view_func = getattr(cls, method_name, None)
        is_live_route = view_func is not None and hasattr(
            view_func, "_api_route_args"
        )
        if is_live_route:
            del view_func._api_route_args
        elif not _is_route_name_in_lineage(cls, method_name):
            raise AttributeError(f"{method_name!r} is not a route on {cls.__name__}")
def _is_route_name_in_lineage(
    cls: type[BaseRestView[Any, Any, Any, Any, Any]], method_name: str
) -> bool:
    """Report whether ``method_name`` is a routable endpoint on any class in
    ``cls``'s MRO, i.e. a real route name that may just be already-excluded
    on ``cls`` itself.

    Only each class's own ``__dict__`` is consulted, not inherited lookups.
    """
    for klass in cls.mro():
        candidate = klass.__dict__.get(method_name)
        if hasattr(candidate, "_api_route_args"):
            return True
    return False
def _init_view_cls_and_add_to_router(
    view_cls: type[View], parent_router: fastapi.APIRouter | fastapi.FastAPI
):
    """
    Prepare ``view_cls`` for FastAPI and mount it on ``parent_router``.

    FastAPI drives request parsing and response rendering from annotations,
    e.g. ``def my_endpoint(foo: FooRead) -> FooRead``. Making class-based
    views work therefore means rewriting the annotations of (often inherited)
    methods for each concrete view class. Most of the work here is that
    rewriting.

    The class-level preparation runs only on the first call for a given View
    class: copying parent endpoints, renaming, annotating, schema generation,
    and the dataclass-style ``__init__``. Every call builds a new
    ``APIRouter`` for the view. One view can therefore be mounted on several
    *different* parents, e.g. a public and an admin app, or ``/v1`` and
    ``/v2`` sub-apps.

    Mounting the same view twice on the *same* parent duplicates its routes,
    because each call still runs ``include_router``. Don't register a view
    more than once per parent. (Tracked: bug for a same-router idempotency
    guard.)
    """
    _prepare_view_class(view_cls)
    view_router = _init_api_router(view_cls)
    _register_for_resource_ref(parent_router, view_cls)
    parent_router.include_router(view_router)
    if not isinstance(parent_router, fastapi.FastAPI):
        # Default exception handlers are only installed on FastAPI apps,
        # never on nested APIRouter parents.
        return
    # Fallback for users who skip ``fr.configure(app=...)``. The registration
    # is idempotent, so running it once per mounted view is harmless.
    register_default_exception_handlers(parent_router)
1121#: Bare business-verb method names. A ``@route``-decorated method must not be
1122#: named like one of these: it would shadow the verb (which the ``handle_<verb>``
1123#: handlers call) and collide with the ``<verb>_endpoint`` route shell at the
1124#: same path. Override the bare verb *without* a decorator for domain logic; use
1125#: ``<verb>_endpoint`` or a distinct name for a custom route.
1126_BARE_VERB_NAMES = frozenset({"get_many", "get_one", "create", "update", "delete"})
def _reject_bare_verb_route_names(view_cls: type[View]) -> None:
    """Refuse a route method defined on ``view_cls`` under a bare verb name.

    Only the class's own ``__dict__`` is inspected. The first offending
    attribute (in definition order) triggers the error.

    Raises:
        TypeError: naming the offending method and the suggested rename.
    """
    offenders = (
        attr_name
        for attr_name, member in view_cls.__dict__.items()
        if attr_name in _BARE_VERB_NAMES and hasattr(member, "_api_route_args")
    )
    name = next(offenders, None)
    if name is None:
        return
    raise TypeError(
        f"{view_cls.__name__}.{name}() is a route method named like the "
        f"business verb '{name}'. A route by that name shadows the verb "
        f"and collides with the '{name}_endpoint' route shell. Rename it "
        f"to '{name}_endpoint' (to replace the shell) or give the custom "
        f"action its own name."
    )
1141#: The five wire-tier route shells generated by RestView / AsyncRestView.
1142_SHELL_NAMES = frozenset(
1143 {
1144 "get_many_endpoint",
1145 "get_one_endpoint",
1146 "create_endpoint",
1147 "update_endpoint",
1148 "delete_endpoint",
1149 }
1150)
def _warn_on_misuse(view_cls: type[View]) -> None:
    """Opt-in registration-time lint (``fr.configure(warn_on_misuse=True)``).

    Flags the three dominant misuse patterns with the idiomatic fix named in
    each message. Heuristic, best-effort, and advisory: every pattern it flags
    has a legitimate use, so it warns (:class:`RestlyMisuseWarning`) rather
    than rejects. Must run *before* parent endpoints are copied into the
    subclass, while ``__dict__`` still holds only what the user wrote; only
    the registered class is linted, not user-defined intermediate bases.

    Checks, all against the class's own ``__dict__``:

    1. (CRUD views only) each route shell ``*_endpoint`` defined on the class;
       one warning per shell, in sorted order.
    2. Every plain/static/class method whose *source text* contains
       ``.commit(`` but neither ``write_action`` nor ``handle_``. This is a
       substring match, so comments and docstrings count too; methods whose
       source cannot be retrieved are skipped.
    3. (bare ``View`` only) at least three own routes whose HTTP methods cover
       both GET and POST plus at least one of PATCH/PUT/DELETE. A route with
       no ``methods`` kwarg counts as GET.

    NOTE(review): ``stacklevel=5`` presumably points the warning at the user's
    ``include_view()`` call site (through ``_prepare_view_class`` and
    ``_init_view_cls_and_add_to_router``); re-check it if that call chain
    changes.
    """
    own = view_cls.__dict__
    name = view_cls.__name__
    is_crud_view = issubclass(view_cls, BaseRestView)

    # 1. Route-shell override where a business-verb override was likely meant.
    if is_crud_view:
        for shell in sorted(_SHELL_NAMES & own.keys()):
            verb = shell.removesuffix("_endpoint")
            warnings.warn(
                f"{name} overrides the route shell '{shell}' (the wire tier). "
                f"Override a shell only to change the HTTP contract. For "
                f"domain logic override the bare verb '{verb}'; for "
                f"orchestration 'handle_{verb}'; for the response shape "
                f"'to_response'.",
                RestlyMisuseWarning,
                stacklevel=5,
            )

    # 2. Manual session.commit() in a view method. The framework owns the
    # commit; methods that go through write_action / handle_<verb> are exempt.
    for attr, value in own.items():
        # Unwrap staticmethod/classmethod objects to the underlying function.
        func = getattr(value, "__func__", value)
        if not isinstance(func, types.FunctionType):
            continue
        try:
            source = inspect.getsource(func)
        except (OSError, TypeError):
            # Source unavailable (e.g. defined interactively); skip silently.
            continue
        if (
            ".commit(" in source
            and "write_action" not in source
            and "handle_" not in source
        ):
            warnings.warn(
                f"{name}.{attr} calls session.commit() directly. The framework "
                f"owns the commit: reuse handle_<verb>(), or bracket the "
                f"mutation with write_action('<action>', ...) so authorize / "
                f"before_commit / after_commit run.",
                RestlyMisuseWarning,
                stacklevel=5,
            )

    # 3. A CRUD route set hand-rolled on a bare View.
    if not is_crud_view:
        http_methods: set[str] = set()
        n_routes = 0
        for value in own.values():
            route_args = getattr(value, "_api_route_args", None)
            if route_args is None:
                continue
            # ``_api_route_args`` is (path, route_kwargs) as stored by @route.
            _path, route_kwargs = route_args
            n_routes += 1
            http_methods.update(
                method.upper() for method in route_kwargs.get("methods", ["GET"])
            )
        if (
            n_routes >= 3
            and {"GET", "POST"} <= http_methods
            and http_methods & {"PATCH", "PUT", "DELETE"}
        ):
            warnings.warn(
                f"{name} hand-rolls a CRUD route set on a bare View. RestView / "
                f"AsyncRestView generate list/create/get/update/delete from "
                f"`model` + `schema`; subclass one and override the bare verbs "
                f"(create/update/delete), build_query, or authorize for custom "
                f"behavior.",
                RestlyMisuseWarning,
                stacklevel=5,
            )
def _reject_buried_markers_in_view_schemas(view_cls: type[View]) -> None:
    """Re-run the buried-marker check on the schemas this view actually uses.

    ``BaseSchema.__pydantic_init_subclass__`` already rejects a buried
    ReadOnly/WriteOnly marker when the schema class is defined. A view,
    however, may use a read/create/update schema that does not subclass
    ``BaseSchema``. This backstop covers those at view registration.

    ``schema``, ``schema_create`` and ``schema_update`` are checked in that
    order. Missing (``None``) attributes are skipped, and a schema shared by
    several attributes is checked only once.
    """
    already_checked: set[type] = set()
    candidates = (
        getattr(view_cls, attr, None)
        for attr in ("schema", "schema_create", "schema_update")
    )
    for schema in candidates:
        if schema is not None and schema not in already_checked:
            already_checked.add(schema)
            _reject_buried_markers(schema)
def _prepare_view_class(view_cls: type[View]) -> None:
    """Run the one-time class-level setup for a View.

    Guarded by the ``_fr_initialised`` marker (stored in ``__dict__`` so it is
    not inherited from a parent class that was registered separately). Calling
    this multiple times is a no-op after the first run.

    The order of the steps below is significant; each comment says why a step
    sits where it does.
    """
    if view_cls.__dict__.get("_fr_initialised", False):
        return
    if _fr_globals.warn_on_misuse:
        # Lint first, while __dict__ still holds only user-written members.
        _warn_on_misuse(view_cls)
    # Give this class its own copies of inherited routes, so annotating them
    # below cannot leak into the parent or sibling views.
    _copy_all_parent_class_endpoints_into_this_subclass(view_cls)
    # Checks view_cls.__dict__, which now also contains the copied routes.
    _reject_bare_verb_route_names(view_cls)
    # Rename each route to "<lowercased view name>_<name>" and make FastAPI
    # inject ``self`` via Depends(view_cls).
    _init_all_endpoints(view_cls)
    # Derive schemas, annotate the route shells, apply exclude_routes.
    view_cls.before_include_view()
    # After before_include_view(), so auto-generated schemas are checked too.
    _reject_buried_markers_in_view_schemas(view_cls)
    # Build the DI-aware __init__/__signature__ that Depends(view_cls) uses.
    _init_class_based_view(view_cls)
    view_cls._fr_initialised = True  # type: ignore[attr-defined]
def _copy_all_parent_class_endpoints_into_this_subclass(view_cls: type[View]):
    """
    Give ``view_cls`` its own copy of every ``@route`` endpoint it inherits.

    Without a copy, ``FooView.get()`` resolves through the MRO to
    ``AsyncRestView.get()`` (implicit delegation). Annotating it with
    ``-> FooRead`` would then rewrite the shared parent function, for
    ``AsyncRestView`` and every other subclass alike. Each inherited endpoint
    is therefore wrapped in a fresh function and set directly on
    ``view_cls``. Endpoints that ``view_cls`` defines itself are left alone.
    """
    for attr_name, inherited in _get_all_parent_endpoints(view_cls).items():
        # Key off the attribute name (e.g. "get_many_endpoint"), which is
        # stable across copies. The function's own __name__ may already have
        # been mangled by a parent's registration
        # (e.g. "parentview_get_many_endpoint").
        if attr_name in view_cls.__dict__:
            continue  # overridden directly on view_cls

        clone = _make_copy(inherited, view_cls)
        if getattr(inherited, "__module__", "").startswith("fastapi_restly."):
            # Framework shells carry override-redirect docstrings meant for
            # help()/source readers. FastAPI would publish them as the OpenAPI
            # operation description, so blank them on the copy. User-defined
            # endpoints keep their docstrings.
            clone.__doc__ = None
        # Resetting the name lets the later rename produce "<view>_<name>"
        # even when the source was a parent's renamed copy.
        clone.__name__ = attr_name
        # Explicit qualname, purely to make debugging output readable.
        clone.__qualname__ = f"{view_cls.__name__}_{attr_name}_wrapper"
        setattr(view_cls, attr_name, clone)
def _make_copy(endpoint: Callable, view_cls: type[View]) -> Callable:
    """
    Return a new function that forwards to ``endpoint``, acting as its copy.

    ``functools.wraps`` carries over the name, docstring and attribute
    ``__dict__``. ``__annotations__`` is then replaced with a fresh dict, so
    later edits on the copy leave ``endpoint`` untouched. A coroutine
    function gets an ``async`` wrapper so the copy stays awaitable.
    ``view_cls`` is accepted but not used.

    The wrapper is built in this helper rather than inline in the caller's
    loop on purpose. A closure captures the *variable*, not its value, and a
    Python for-loop has no scope of its own, so every wrapper would end up
    calling the last endpoint.
    https://eev.ee/blog/2011/04/24/gotcha-python-scoping-closures/
    """
    if not inspect.iscoroutinefunction(endpoint):

        @functools.wraps(endpoint)
        def _sync_wrapper(self, *args, **kwargs):
            return endpoint(self, *args, **kwargs)

        clone: Callable = _sync_wrapper
    else:

        @functools.wraps(endpoint)
        async def _async_wrapper(self, *args, **kwargs):
            return await endpoint(self, *args, **kwargs)

        clone = _async_wrapper

    clone.__annotations__ = dict(endpoint.__annotations__)
    return clone
def _init_all_endpoints(view_cls: type[View]):
    """
    Rename each route defined on ``view_cls`` and wire up its ``self``.

    Every ``@route`` member of the class ``__dict__`` gets the lowercased
    class name as a prefix, e.g. ``FooView.create`` becomes
    ``fooview_create``, so route names are unique across views. Its
    ``self`` parameter is then rewritten by ``_annotate_self``.
    """
    name_prefix = f"{view_cls.__name__.lower()}_"
    routes = [
        member
        for member in view_cls.__dict__.values()
        if hasattr(member, "_api_route_args")
    ]
    for endpoint in routes:
        endpoint.__name__ = name_prefix + endpoint.__name__
        _annotate_self(view_cls, endpoint)
1354def _annotate(func: Callable, return_annotation: Any = None, **param_annotations):
1355 """
1356 Annotate a function by setting func.__signature__ explicitly.
1357 """
1358 sig = inspect.signature(func)
1359 new_params = []
1360 for param in sig.parameters.values():
1361 if param.name in param_annotations:
1362 annotation = param_annotations[param.name]
1363 new_param = param.replace(annotation=annotation)
1364 new_params.append(new_param)
1365 else:
1366 new_params.append(param)
1367 func.__signature__ = sig.replace( # type: ignore[attr-defined]
1368 parameters=new_params, return_annotation=return_annotation
1369 )
def _get_all_parent_endpoints(view_cls: type[View]) -> dict[str, Callable]:
    """Collect the ``@route`` endpoints ``view_cls`` inherits, keyed by
    attribute name.

    Ancestors are walked in MRO order, skipping ``view_cls`` itself. The
    first (most-derived) class that defines a name decides it. If that
    definition is not a route, the name is shadowed and any base-class route
    under the same name is ignored. Registered intermediate parents hold
    their own copies of base endpoints in ``__dict__``, and keying by name
    keeps those from producing duplicate logical endpoints.
    """
    found: dict[str, Callable] = {}
    claimed: set[str] = set()
    for ancestor in view_cls.mro()[1:]:
        for attr_name, member in ancestor.__dict__.items():
            if attr_name not in claimed:
                claimed.add(attr_name)
                if hasattr(member, "_api_route_args"):
                    found[attr_name] = member
    return found
def _init_api_router(view_cls: type[View]) -> fastapi.APIRouter:
    """Build an ``APIRouter`` for ``view_cls`` and register its routes on it.

    The router prefix concatenates every ``prefix`` declared directly on a
    class in the hierarchy, base-most first. Tags come from
    ``_get_router_tags``, and ``responses``/``dependencies`` come from the
    view. Each ``@route`` member of the class's own ``__dict__`` is then
    added via ``_add_api_route``.
    """
    prefix_parts = [
        klass.__dict__["prefix"]
        for klass in reversed(view_cls.mro())
        if "prefix" in klass.__dict__
    ]
    prefix = "".join(prefix_parts)
    router = fastapi.APIRouter(
        prefix=prefix,
        tags=_get_router_tags(view_cls, prefix),
        responses=view_cls.responses,
        dependencies=view_cls.dependencies,
    )

    routed_members = [
        member
        for member in view_cls.__dict__.values()
        if hasattr(member, "_api_route_args")
    ]
    for endpoint in routed_members:
        path, route_kwargs = endpoint._api_route_args
        _add_api_route(router, view_cls, path, endpoint, route_kwargs)

    return router
def _get_router_tags(view_cls: type[View], prefix: str) -> list[str | Enum]:
    """Return the OpenAPI tags for the view's router.

    Explicit ``view_cls.tags`` win, copied into a list. Otherwise there is a
    single tag derived from the prefix, with the class name as the fallback.
    """
    explicit = view_cls.tags
    if explicit is None:
        derived = _derive_tag_from_prefix(prefix)
        return [derived if derived else view_cls.__name__]
    return list(explicit)
1425def _derive_tag_from_prefix(prefix: str) -> str | None:
1426 segments = [segment for segment in prefix.strip("/").split("/") if segment]
1427 if not segments:
1428 return None
1429 return segments[-1].replace("-", " ").replace("_", " ").title()
def _add_api_route(
    api_router: fastapi.APIRouter,
    view_cls: type[View],
    path: str,
    endpoint: Callable,
    route_kwargs: dict[str, Any],
) -> None:
    """Register ``endpoint`` on ``api_router``.

    Normally the route is added once at ``path``. For a REST view's
    collection routes (see ``_should_add_collection_route_alias``) it is
    added twice: at ``""``, which appears in the schema, and at ``"/"``,
    which is hidden. This makes the collection reachable with and without a
    trailing slash.
    """
    if not _should_add_collection_route_alias(view_cls, path, endpoint):
        api_router.add_api_route(path, endpoint, **route_kwargs)
        return

    api_router.add_api_route("", endpoint, **route_kwargs)
    api_router.add_api_route(
        "/", endpoint, **{**route_kwargs, "include_in_schema": False}
    )
def _should_add_collection_route_alias(
    view_cls: type[View], path: str, endpoint: Callable
) -> bool:
    """Tell whether ``endpoint`` is a REST-view collection route at ``"/"``.

    All three conditions must hold:

    * ``view_cls`` is a ``BaseRestView`` subclass;
    * ``path`` is exactly ``"/"``;
    * the (already renamed) endpoint name ends in ``get_many_endpoint`` or
      ``create_endpoint``.
    """
    return (
        issubclass(view_cls, BaseRestView)
        and path == "/"
        and endpoint.__name__.endswith(("get_many_endpoint", "create_endpoint"))
    )
def _annotate_self(view_cls: type[View], endpoint: Callable) -> None:
    """
    Rewrite ``endpoint.__signature__`` so FastAPI supplies ``self``.

    The first parameter (``self``) gets ``Depends(view_cls)`` as its default,
    so FastAPI instantiates the view before each call and passes it in. All
    remaining parameters become keyword-only. The explicit ``__signature__``
    overrides any other inspection of ``endpoint``.

    Note: Copied (MIT license) and adjusted from: https://github.com/dmontagu/fastapi-utils/blob/master/fastapi_utils/cbv.py
    """
    sig = inspect.signature(endpoint)
    params = tuple(sig.parameters.values())
    injected_self = params[0].replace(default=fastapi.Depends(view_cls))
    # Once ``self`` has a default, later positional parameters without
    # defaults would make the signature invalid; keyword-only avoids that.
    keyword_only = [
        param.replace(kind=inspect.Parameter.KEYWORD_ONLY) for param in params[1:]
    ]
    endpoint.__signature__ = sig.replace(  # type: ignore[attr-defined]
        parameters=[injected_self, *keyword_only]
    )
# Bare-typed annotations FastAPI special-cases for parameter injection
# (no ``Depends(...)`` marker required). Treated alongside ``Depends``-
# marked annotations as DI-wired class attributes; everything else is
# left as plain typing. ``_init_class_based_view`` tests annotations
# against this tuple with ``issubclass``, so subclasses of these types
# qualify as well.
_FASTAPI_SPECIAL_INJECTABLE: tuple[type, ...] = (
    Request,
    Response,
    BackgroundTasks,
    WebSocket,
)
def _init_class_based_view(view_cls: type[View]) -> None:
    """
    Note: Copied (MIT license) and adjusted from: https://github.com/dmontagu/fastapi-utils/blob/master/fastapi_utils/cbv.py

    Idempotently modifies the provided `cls`, performing the following modifications:
    * The `__init__` function is updated to set any class-annotated dependencies as instance attributes
    * The `__signature__` attribute is updated to indicate to FastAPI what arguments should be passed to the initializer

    The class signature keeps the original ``__init__`` parameters (minus
    ``self``, ``*args`` and ``**kwargs``) and appends one keyword-only
    parameter per DI-wired annotation. Its default is the class attribute of
    the same name when one exists. The new ``__init__`` pops those keyword
    arguments, stores them on the instance, then calls the original
    ``__init__`` with whatever is left. It raises ``KeyError`` if a
    dependency keyword is missing, which matters only for manual
    construction, since FastAPI passes every parameter in ``__signature__``.
    """
    # NOTE(review): this guard uses getattr, so the flag set on an
    # already-prepared parent class is *inherited*. A subclass prepared after
    # its parent returns here immediately and keeps the parent's
    # ``__init__``/``__signature__``, so Depends annotations the subclass adds
    # are never wired. ``_prepare_view_class`` checks ``__dict__`` for its own
    # marker for exactly this reason -- confirm whether this guard should too.
    # (A fix would also have to stop the subclass's new_init from popping
    # dependencies that the inherited parent new_init pops again.)
    if getattr(view_cls, "__class_based_view", False):
        return  # Already initialized
    old_init: Callable[..., Any] = view_cls.__init__
    old_signature = inspect.signature(old_init)
    old_parameters = list(old_signature.parameters.values())[1:]  # drop `self`
    # *args/**kwargs cannot be expressed as FastAPI parameters; drop them.
    new_parameters = [
        x
        for x in old_parameters
        if x.kind
        not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    ]
    # Marker-based DI with MRO-aware shadowing: walk the MRO from the
    # base classes upward and pick, for each name, an annotation that
    # either carries a ``Depends(...)`` marker or names one of FastAPI's
    # bare-injectable special types (``Request`` / ``Response`` etc.).
    # A *plain* annotation on a more-derived class (e.g. a mixin
    # declaring ``session: AsyncSession`` for static-typing purposes)
    # does NOT shadow a marker-bearing annotation from a base — the
    # framework prefers wiring fidelity over the most-derived hint.
    # Without this rule, any plain annotation a mixin adds would
    # silently break dependency injection.
    di_annotations: dict[str, Any] = {}
    for cls in reversed(view_cls.__mro__):
        try:
            cls_hints = get_type_hints(cls, include_extras=True)
        except Exception:
            # NOTE(review): a class whose hints cannot be resolved (e.g. an
            # unresolvable forward reference) silently contributes no DI
            # annotations here.
            continue
        for name, annotation in cls_hints.items():
            if get_origin(annotation) is ClassVar:
                continue
            # ``__metadata__`` exists only on Annotated[...] types.
            metadata = getattr(annotation, "__metadata__", ())
            has_depends_marker = any(isinstance(m, _DependsMarker) for m in metadata)
            # Strip Annotated[...] to test the underlying type.
            underlying = (
                annotation
                if get_origin(annotation) is not Annotated
                else (get_args(annotation)[0] if get_args(annotation) else annotation)
            )
            is_special_type = inspect.isclass(underlying) and issubclass(
                underlying, _FASTAPI_SPECIAL_INJECTABLE
            )
            if has_depends_marker or is_special_type:
                # Marker-bearing annotation wins, regardless of MRO position.
                di_annotations[name] = annotation
            # Plain annotations are silently ignored — they neither set
            # nor clear an entry in di_annotations.

    dependency_names: list[str] = []
    for name, annotation in di_annotations.items():
        dependency_names.append(name)
        default_value = getattr(view_cls, name, inspect.Parameter.empty)
        new_parameters.append(
            inspect.Parameter(
                name=name,
                kind=inspect.Parameter.KEYWORD_ONLY,
                default=default_value,
                annotation=annotation,
            )
        )
    new_signature = old_signature.replace(parameters=new_parameters)

    def new_init(self: Any, *args: Any, **kwargs: Any) -> None:
        # Move each injected dependency onto the instance; the remaining
        # kwargs go to the original __init__.
        for dep_name in dependency_names:
            dep_value = kwargs.pop(dep_name)
            setattr(self, dep_name, dep_value)
        old_init(self, *args, **kwargs)

    setattr(view_cls, "__signature__", new_signature)
    setattr(view_cls, "__init__", new_init)
    setattr(view_cls, "__class_based_view", True)