Coverage for fastapi_restly / query / _impl.py: 94%
260 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-24 11:13 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-24 11:13 +0000
1import datetime as _dt
2import decimal as _decimal
3import functools
4import uuid as _uuid
5from collections import defaultdict
6from typing import Annotated, Any, Callable, Iterator, Optional, cast
8import pydantic
9import sqlalchemy
10from pydantic import Field
11from pydantic.fields import FieldInfo
12from sqlalchemy import ColumnElement, Select
13from sqlalchemy.orm import DeclarativeBase
14from sqlalchemy.orm.attributes import InstrumentedAttribute
15from sqlalchemy.orm.properties import ColumnProperty
16from starlette.datastructures import QueryParams
18from ..exc import BadQueryParam
19from ..schemas._base import IDRef, IDSchema
20from ._shared import _escape_like_value, _unwrap_optional_annotation
SchemaType = type[pydantic.BaseModel]

#: Page size used by list endpoints when the request omits ``page_size``.
#: ``None`` means there is no implicit cap: every matching row is returned and
#: ``page`` has no effect. Views override this through
#: :attr:`BaseRestView.default_page_size`.
DEFAULT_PAGE_SIZE: int | None = None

#: Largest ``page_size`` a list endpoint accepts. Larger values are rejected
#: with a 422 by the FastAPI Pydantic-Query validation layer. Views override
#: this through :attr:`BaseRestView.max_page_size`.
MAX_PAGE_SIZE = 1000

#: Query-parameter names the generated schema reserves for pagination and
#: sorting. A filter column literally named one of these would shadow them and
#: silently break the endpoint contract, so schema generation treats it as a
#: hard error.
_RESERVED_NAMES = frozenset(("page", "page_size", "sort"))

# Python types whose columns get the SQL ``<``/``<=``/``>``/``>=`` range
# operators. ``bool`` is deliberately left out: ordering booleans is rarely
# meaningful, and ``WHERE active >= true`` raises
# ``sqlalchemy.exc.ArgumentError`` at query time, which would otherwise reach
# the client as a 500.
_ORDERABLE_TYPES: tuple[type, ...] = (
    # numeric
    int,
    float,
    _decimal.Decimal,
    # temporal
    _dt.date,
    _dt.datetime,
    _dt.time,
    _dt.timedelta,
    # text
    str,
)
def _is_string_field(field: FieldInfo) -> bool:
    """Return True when *field*'s annotation, after ``_unwrap_optional_annotation``, is exactly ``str``.

    The identity check means ``str`` subclasses do not qualify.
    """
    unwrapped = _unwrap_optional_annotation(field.annotation)
    return unwrapped is str
def _is_idref_field(field: FieldInfo) -> bool:
    """Return True for a scalar ``IDRef[T]`` FK field (e.g. ``post_id: IDRef[Post]``).

    ``IDRef`` subclasses ``BaseModel``. Without this check it would be treated
    as a nested schema and only yield ``post_id.id``. That path never resolves,
    because ``post_id`` is a scalar column rather than a relationship, so the
    FK would end up with no filter param at all. Treated as a leaf, the field
    filters on its own public name.

    The check targets ``IDRef`` specifically, not ``IDSchema``: a nested
    *resource* schema that embeds its ``id`` also subclasses ``IDSchema`` and
    must keep its dotted traversal. ``list[IDRef[T]]`` (to-many) is
    unaffected, since a list annotation is never a class here.
    """
    target = _unwrap_optional_annotation(field.annotation)
    if not isinstance(target, type):
        return False
    return issubclass(target, IDRef)
def _supports_range_operators(field: FieldInfo) -> bool:
    """Decide whether *field* gets ``__gte``/``__lte``/``__gt``/``__lt`` params.

    Rules, applied to the annotation after ``_unwrap_optional_annotation``:

    - a non-class annotation gets range params (it is not inspected further);
    - ``bool`` and its subclasses never do;
    - a class does exactly when it subclasses one of ``_ORDERABLE_TYPES``;
    - everything else, ``uuid.UUID`` included, is equality-only.
    """
    target = _unwrap_optional_annotation(field.annotation)
    if not isinstance(target, type):
        return True
    # ``bool`` subclasses ``int``, so it must be rejected before the
    # orderable-type check below would accept it.
    if issubclass(target, bool):
        return False
    if issubclass(target, _ORDERABLE_TYPES):
        return True
    if issubclass(target, _uuid.UUID):
        # UUIDs have no meaningful ordering; equality filters only.
        return False
    return False
def create_list_params_schema(
    schema_cls: SchemaType,
    model: type[DeclarativeBase],
    *,
    default_page_size: int | None = DEFAULT_PAGE_SIZE,
    max_page_size: int = MAX_PAGE_SIZE,
) -> SchemaType:
    """
    Create a Pydantic model that describes and validates URL query parameters
    for list endpoints.

    The generated model accepts pagination (``page``, ``page_size``), sorting
    (``sort``), and one filter parameter per response-schema field that maps to
    a filterable column on ``model`` -- with optional ``__in``/``__ne``/``__gte``/
    ``__lte``/``__gt``/``__lt``/``__isnull``/``__contains``/``__icontains``
    suffixes. Fields that do not resolve to a column (relationship/collection
    fields, or reference traversals the request path would reject) get no filter
    params, so the generated schema -- and the OpenAPI it produces -- does not
    advertise filters for fields that are not filterable at all. (This validates
    column existence, the same check the request path makes; it does not promise
    every operator executes for exotic column types such as ``ARRAY``/``JSON``.)

    Range suffixes (``__gte``/``__lte``/``__gt``/``__lt``) are only emitted for
    fields accepted by ``_supports_range_operators``. Substring suffixes
    (``__contains``/``__icontains``) are only emitted for ``str`` fields.

    ``page`` and ``page_size`` are validated by Pydantic with bounds
    (``page >= 1``, ``1 <= page_size <= max_page_size``); out-of-range values
    produce a standard 422 response from FastAPI.

    Args:
        schema_cls: The response schema whose fields drive the available
            filter parameters.
        model: The SQLAlchemy model the list endpoint queries. Used to verify
            each field resolves to a filterable column; non-column fields are
            omitted from the generated params.
        default_page_size: Default value for the ``page_size`` parameter.
            ``None`` (the default) means "no implicit page size": omitting
            ``page_size`` returns every matching row and ``page`` is ignored.
        max_page_size: Upper bound (inclusive) for the ``page_size``
            parameter. Defaults to :data:`MAX_PAGE_SIZE`.

    Returns:
        A new Pydantic model class named ``"ListParams" + schema_cls.__name__``.

    Raises:
        ValueError: If a top-level field's public name collides with a reserved
            pagination/sort parameter, or (raised while iterating the schema)
            a public name contains ``__`` or ``.``.
    """
    fields: dict[str, Any] = {
        "page": (
            Annotated[
                int,
                Field(
                    ge=1,
                    description=(
                        "1-based page number. Only takes effect when "
                        "``page_size`` is also set."
                    ),
                ),
            ],
            1,
        ),
        "page_size": (
            Annotated[
                Optional[int],
                Field(
                    ge=1,
                    le=max_page_size,
                    description=(
                        f"Number of items per page (1–{max_page_size}). "
                        "Omit to return every matching row (no implicit cap)."
                    ),
                ),
            ],
            default_page_size,
        ),
        "sort": (
            Annotated[
                Optional[str],
                Field(
                    description=(
                        "Comma-separated list of fields to sort by. Prefix a "
                        "field with ``-`` for descending order. Example: "
                        "``-created_at,name``."
                    )
                ),
            ],
            None,
        ),
    }
    for name, field in _iter_fields_including_nested(schema_cls):
        if name in _RESERVED_NAMES:
            raise ValueError(
                f"List-params schema for {schema_cls.__name__!r} cannot expose "
                f"field {name!r}: it collides with a reserved pagination/sort "
                "parameter. Add a Pydantic alias to expose it as a filter."
            )

        # Only emit params for fields that resolve to a filterable column on
        # the model, using the same predicate the request path applies. A
        # relationship/collection field (e.g. ``books: list[BookRef]``) or a
        # reference traversal that does not resolve would otherwise advertise
        # filters in OpenAPI that always 400 at request time.
        try:
            _resolve_column(model, name, schema_cls)
        except BadQueryParam:
            continue

        fields.update(_filter_param_fields(name, field))

    schema_name = "ListParams" + schema_cls.__name__
    return pydantic.create_model(schema_name, **fields)  # type: ignore[call-overload]


def _filter_param_fields(name: str, field: FieldInfo) -> dict[str, Any]:
    """Return the filter query-param definitions for one filterable field.

    Keys are public param names (``name``, ``name__in``, ...). Values are
    ``(annotation, default)`` tuples for :func:`pydantic.create_model`.
    Insertion order is the order the params appear in the generated schema.
    """

    # Filter params are typed ``Optional[list[str]]`` instead of the column's
    # true type so FastAPI/Starlette preserve repeated query parameters as a
    # list; ``_parse_value`` performs the field-type coercion downstream.
    def multi(description: str) -> tuple[Any, None]:
        return (Annotated[Optional[list[str]], Field(description=description)], None)

    params: dict[str, Any] = {
        name: multi(
            f"Filter by ``{name}``. Comma-separated values are OR-combined "
            "(SQL ``IN``). Repeat the parameter to AND multiple predicates."
        ),
        f"{name}__in": multi(
            f"Filter by ``{name}`` with explicit SQL ``IN`` semantics. "
            "Provide comma-separated values."
        ),
        f"{name}__ne": multi(
            f"Exclude rows where ``{name}`` matches. Comma-separated values "
            "are AND-combined (SQL ``NOT IN``)."
        ),
        # ``__isnull`` stays a scalar bool because repeating it makes no sense.
        f"{name}__isnull": (
            Annotated[
                Optional[bool],
                Field(
                    description=(
                        f"``true`` matches rows where ``{name}`` IS NULL; "
                        "``false`` matches IS NOT NULL."
                    )
                ),
            ],
            None,
        ),
    }

    if _supports_range_operators(field):
        for suffix, sql in (
            ("__gte", ">="),
            ("__lte", "<="),
            ("__gt", ">"),
            ("__lt", "<"),
        ):
            params[f"{name}{suffix}"] = multi(f"``{name} {sql} value``.")

    if _is_string_field(field):
        for suffix, sensitivity in (
            ("__contains", "Case-sensitive"),
            ("__icontains", "Case-insensitive"),
        ):
            params[f"{name}{suffix}"] = multi(
                f"{sensitivity} substring search on ``{name}``. Repeat the "
                "parameter to AND multiple terms; whitespace inside one value "
                "is also AND-split as a convenience."
            )

    return params
def apply_list_params(
    params: pydantic.BaseModel | QueryParams,
    select_query: Select[Any],
    model: type[DeclarativeBase],
    schema_cls: SchemaType,
) -> Select[Any]:
    """
    Filter, sort and paginate ``select_query`` from list-endpoint query params.

    The steps run in that order: filtering, then sorting, then pagination.

    ``params`` is normally an instance of the schema built by
    :func:`create_list_params_schema`. The generated FastAPI endpoints always
    pass such a validated instance, so pagination and filter bounds have
    already been checked by the time this runs.

    A raw :class:`~starlette.datastructures.QueryParams` is accepted too, for
    callers that assemble query parameters programmatically. **Raw inputs
    bypass schema validation.** The caller must verify ``page``/``page_size``
    ranges and any per-view bound (``max_page_size``); this function only does
    the minimum coercion needed to build the SQL clauses.

    Examples::

        # Pagination
        page=2&page_size=50

        # Sorting
        sort=name,-created_at

        # Filtering
        name=Bob&status=active&created_at__gte=2024-01-01

        # Contains (string fields)
        name__contains=John&email__icontains=example
    """
    normalized = _coerce_to_query_params(params)
    filtered = _apply_filtering(normalized, select_query, model, schema_cls)
    ordered = _apply_sorting(normalized, filtered, model, schema_cls)
    return _apply_pagination(normalized, ordered)
def _coerce_to_query_params(params: pydantic.BaseModel | QueryParams) -> QueryParams:
    """Normalise a validated Pydantic model, or raw QueryParams, to QueryParams.

    A model is dumped by alias in JSON mode with ``None`` values dropped. Every
    remaining value is stringified. A list value (e.g. a repeated
    ``name__contains``) becomes one ``(key, value)`` pair per element, so
    ``QueryParams.multi_items()`` later yields the original repeated values.
    Any other input is passed straight to the ``QueryParams`` constructor.
    """
    if isinstance(params, QueryParams):
        return params
    if not isinstance(params, pydantic.BaseModel):
        return QueryParams(params)

    dumped = params.model_dump(exclude_none=True, by_alias=True, mode="json")
    pairs: list[tuple[str, str]] = [
        (key, str(element))
        for key, value in dumped.items()
        for element in (value if isinstance(value, list) else [value])
    ]
    return QueryParams(pairs)
def _apply_pagination(
    query_params: QueryParams, select_query: Select[Any]
) -> Select[Any]:
    """Add LIMIT/OFFSET for ``page``/``page_size``.

    Without a ``page_size`` the query is returned untouched, and ``page`` is
    then never read. A missing or zero ``page`` counts as page 1. No range
    checking happens here; that is the schema's job.
    """
    size = _get_int(query_params, "page_size")
    if size is None:
        return select_query
    page_number = _get_int(query_params, "page") or 1
    skipped_rows = (page_number - 1) * size
    return select_query.limit(size).offset(skipped_rows)
def _get_int(query_params: QueryParams, param_name: str) -> Optional[int]:
    """Read ``param_name`` from ``query_params`` as an ``int``.

    Returns:
        The parsed integer, or ``None`` when the parameter is absent or empty.

    Raises:
        BadQueryParam: If the value is present but ``int()`` rejects it. The
            original ``ValueError`` is kept as ``__cause__``.
    """
    value = query_params.get(param_name)
    if not value:
        return None
    try:
        return int(value)
    except ValueError as exc:
        raise BadQueryParam(
            f"Invalid value for URL query parameter {param_name}: "
            f"{value} is not an integer"
        ) from exc
def _apply_sorting(
    query_params: QueryParams,
    select_query: Select[Any],
    model: type[DeclarativeBase],
    schema_cls: SchemaType,
) -> Select[Any]:
    """Apply the comma-separated ``sort`` parameter to ``select_query``.

    Each key is a public (possibly dotted) field path; a leading ``-`` means
    descending. Dotted keys join the relationships they traverse. Without
    ``sort``, the query is ordered by ``model.id`` when the model has one.

    Raises:
        BadQueryParam: If a sort key does not resolve to an exposed column
            (propagated from ``_resolve_column``).
    """
    id_column = getattr(model, "id", None)
    sort_string = query_params.get("sort")
    if not sort_string:
        if id_column is not None:
            return select_query.order_by(id_column)
        return select_query

    sorted_on_pk = False
    # Each relationship is joined at most once, in first-seen order. Without
    # this, two keys traversing the same relation (``sort=author.name,-author.id``)
    # would call ``select_query.join`` on it twice.
    # NOTE(review): relationships already joined by ``_apply_filtering`` are
    # not visible here, so filtering and sorting on the same relation still
    # joins it twice -- confirm whether that is tolerated.
    joined: set[InstrumentedAttribute[Any]] = set()
    for column_name in sort_string.split(","):
        order = sqlalchemy.asc
        if column_name.startswith("-"):
            order = sqlalchemy.desc
            column_name = column_name[1:]
        joins, column = _resolve_column(model, column_name, schema_cls)
        for join in joins:
            if join in joined:
                continue
            joined.add(join)
            select_query = select_query.join(join)
        select_query = select_query.order_by(order(column))
        if column is id_column:
            sorted_on_pk = True
    # Append the primary key (the conventional ``id``) as a final tiebreaker so
    # pagination stays deterministic when the user sorts on a non-unique column
    # -- without it, equal-valued rows can be skipped or repeated across pages.
    # Skipped when the user already sorts by the PK. Models without a single
    # ``id`` attribute get no tiebreaker, matching the no-sort path.
    if id_column is not None and not sorted_on_pk:
        select_query = select_query.order_by(id_column)
    return select_query
def _iter_fields_including_nested(
    schema_cls: SchemaType,
    prefix: str = "",
    *,
    _ancestors: tuple[SchemaType, ...] = (),
) -> Iterator[tuple[str, FieldInfo]]:
    """Yield ``(dotted_public_path, field)`` for every leaf field of ``schema_cls``.

    Fields whose (unwrapped) annotation is a Pydantic model are descended into
    and contribute ``parent.child`` paths. ``IDRef[T]`` fields are yielded as
    leaves instead (see ``_is_idref_field``).

    A nested schema that is already on the current traversal path
    (self-referential or mutually recursive schemas such as
    ``parent: Optional[Node]``) is skipped rather than recursed into. Without
    that guard the recursion never terminates.

    ``_ancestors`` is internal recursion state; callers never pass it.

    Raises:
        ValueError: If a public name contains ``__`` (reserved for operator
            suffixes) or ``.`` (reserved for relation traversal).
    """
    lineage = (*_ancestors, schema_cls)
    for name, field in schema_cls.model_fields.items():
        public_name = field.alias or name
        # Each segment of the public dotted path becomes part of the URL
        # grammar. ``__`` is reserved for operator suffixes (``__gte``,
        # ``__contains``, ...) and ``.`` is reserved for relation traversal,
        # so a segment containing either character would create an
        # ambiguous URL key. Reject at schema-generation time so the
        # collision surfaces during view registration, not at request time.
        if "__" in public_name:
            raise ValueError(
                f"List-params schema for {schema_cls.__name__!r} cannot "
                f"expose field {public_name!r}: ``__`` is reserved for "
                "operator suffixes. Choose a different Pydantic alias."
            )
        if "." in public_name:
            raise ValueError(
                f"List-params schema for {schema_cls.__name__!r} cannot "
                f"expose field {public_name!r}: ``.`` is reserved for "
                "relation traversal. Choose a different Pydantic alias."
            )
        full_name = f"{prefix}.{public_name}" if prefix else public_name
        nested = _get_nested_schema(field)
        if nested is None or _is_idref_field(field):
            yield full_name, field
        elif nested not in lineage:
            yield from _iter_fields_including_nested(
                nested, full_name, _ancestors=lineage
            )
        # else: cyclic reference back to an enclosing schema -- skipped.
def _resolve_field_name(schema_cls: SchemaType, public_name: str) -> str | None:
    """Map a public URL field name to the schema's Python field name.

    A field's public name is its alias when one is declared, otherwise its
    Python name. An alias match wins over a same-named Python field. An
    aliased field is *only* reachable through its alias: its Python name is
    never part of the public URL contract, even with ``populate_by_name=True``
    (which only affects how Pydantic parses input bodies, not the generated
    list-params query schema). Returns ``None`` when nothing matches.
    """
    model_fields = schema_cls.model_fields
    aliased_match = next(
        (
            python_name
            for python_name, info in model_fields.items()
            if info.alias == public_name
        ),
        None,
    )
    if aliased_match is not None:
        return aliased_match

    info = model_fields.get(public_name)
    if info is not None and info.alias is None:
        return public_name
    return None
def _resolve_column(
    model: type[DeclarativeBase], column_path: str, schema_cls: SchemaType
) -> tuple[list[InstrumentedAttribute[Any]], InstrumentedAttribute[Any]]:
    """Resolve a (possibly dotted) public column path to a SQLAlchemy column.

    Returns the relationship attributes that must be joined, in traversal
    order, together with the final column attribute.

    Resolution is strict. Every segment must match the schema's public name
    (alias when set, Python field name otherwise). Every non-final segment
    must be a mapped relationship, and the final segment must be a mapped
    column (``ColumnProperty``). Falling back to a raw model attribute lookup
    would let URLs reach columns the schema did not expose -- for example, a
    Python field name on an aliased schema field -- and silently bypass the
    public-name contract.

    Raises:
        BadQueryParam: If any segment fails to resolve.
    """

    def invalid() -> BadQueryParam:
        return BadQueryParam(f"Invalid attribute in URL query: {column_path}")

    *relation_segments, leaf_segment = column_path.split(".")
    joins: list[InstrumentedAttribute[Any]] = []
    owner_model = model
    owner_schema: SchemaType | None = schema_cls

    for segment in relation_segments:
        if owner_schema is None:
            raise invalid()
        relation_field = _resolve_field_name(owner_schema, segment)
        if relation_field is None:
            raise invalid()
        relationship = getattr(owner_model, relation_field, None)
        is_relationship = isinstance(relationship, InstrumentedAttribute) and hasattr(
            relationship.property, "mapper"
        )
        if not is_relationship:
            raise invalid()
        joins.append(relationship)
        owner_model = relationship.property.mapper.class_
        owner_schema = _get_nested_schema(owner_schema.model_fields[relation_field])

    if owner_schema is None:
        raise invalid()
    leaf_field = _resolve_field_name(owner_schema, leaf_segment)
    if leaf_field is None:
        raise invalid()
    column = getattr(owner_model, leaf_field, None)
    is_column = isinstance(column, InstrumentedAttribute) and isinstance(
        column.property, ColumnProperty
    )
    if not is_column:
        raise invalid()
    return joins, cast(InstrumentedAttribute[Any], column)
def _apply_filtering(
    query_params: QueryParams,
    select_query: Select[Any],
    model: type[DeclarativeBase],
    schema_cls: SchemaType,
) -> Select[Any]:
    """Apply ``key=value`` and ``key__op=value`` filters to ``select_query``.

    Multiple filters on the same column are AND-combined. Comma-separated
    values within one parameter are OR-combined for ``eq`` (the default),
    mapped to SQL ``IN`` for ``in``, and AND-combined for ``ne`` (so
    ``status__ne=a,b`` means NOT IN (a, b)). For ``contains``/``icontains``
    values are split on whitespace and AND-combined. ``isnull`` takes a
    boolean-like value. Reserved pagination/sort keys are ignored.

    Raises:
        BadQueryParam: For unknown fields, unparseable values, an invalid
            ``__isnull`` value or an unsupported operator.
    """
    filters: dict[InstrumentedAttribute[Any], list[ColumnElement[Any]]] = defaultdict(
        list
    )
    # An insertion-ordered dict is used as an ordered set, not a ``set``. A
    # multi-hop path like ``a.b.c`` resolves to joins
    # ``[<a relationship>, <b relationship>]``, which must be emitted
    # parent-first; a set iterates in hash order, which need not match.
    joins: dict[InstrumentedAttribute[Any], None] = {}

    for key, raw_value in query_params.multi_items():
        if key in _RESERVED_NAMES:
            continue

        if "__" in key:
            column_name, op = key.split("__", 1)
        else:
            column_name, op = key, "eq"

        column_joins, column = _resolve_column(model, column_name, schema_cls)
        joins.update(dict.fromkeys(column_joins))
        parser = functools.partial(_parse_value, schema_cls, column_name)

        if op == "isnull":
            try:
                value = pydantic.TypeAdapter(bool).validate_python(raw_value)
            except pydantic.ValidationError as exc:
                raise BadQueryParam(
                    f"Invalid value for URL query parameter {key}"
                ) from exc
            filters[column].append(column.is_(None) if value else column.isnot(None))
            continue

        clause = _build_clause(column, raw_value, op, parser)
        if clause is not None:
            filters[column].append(clause)

    for join in joins:
        select_query = select_query.join(join)

    for column, clauses in filters.items():
        and_clause = clauses[0] if len(clauses) == 1 else sqlalchemy.and_(*clauses)
        select_query = select_query.where(and_clause)
    return select_query
def _build_clause(
    column: InstrumentedAttribute[Any],
    raw_value: str,
    op: str,
    parser: Callable[[str], Any],
) -> ColumnElement[Any] | None:
    """Turn one raw parameter value into a single clause according to ``op``.

    - ``contains``/``icontains``: whitespace-separated terms, AND-combined;
      returns ``None`` when the value has no terms.
    - ``in``: comma-separated values fed to ``column.in_``.
    - ``ne``: comma-separated values, AND-combined (NOT IN semantics).
    - anything else: comma-separated values, OR-combined.
    """
    if op in {"contains", "icontains"}:
        terms = raw_value.split()
        if not terms:
            return None
        term_clauses = [_make_where_clause(column, term, op, parser) for term in terms]
        if len(term_clauses) == 1:
            return term_clauses[0]
        return sqlalchemy.and_(*term_clauses)

    # ``str.split`` with an explicit separator always yields at least one item.
    candidates = raw_value.split(",")
    if op == "in":
        return column.in_([parser(candidate) for candidate in candidates])

    clauses = [_make_where_clause(column, candidate, op, parser) for candidate in candidates]
    if len(clauses) == 1:
        return clauses[0]
    # ``ne`` with several values means NOT IN (...): AND-combine, not OR.
    combine = sqlalchemy.and_ if op == "ne" else sqlalchemy.or_
    return combine(*clauses)
def _parse_value(schema_cls: SchemaType, column_name: str, value: str) -> Any:
    """Coerce a raw string filter value to the Python type of its schema field.

    ``column_name`` is a public (possibly dotted) path. Each relation segment
    is resolved strictly through ``_resolve_field_name``, consistent with
    ``_resolve_column``; Python names of aliased fields are not accepted.
    Coercion goes through the schema's own validator (``validate_assignment``
    on a ``model_construct()`` instance). An ``IDSchema`` result (e.g. from an
    ``IDRef[T]`` FK field) is unwrapped to its scalar ``id``, because the
    reference wrapper cannot be used as a SQL bind value.

    NOTE(review): ``validate_assignment`` presumably refuses frozen models or
    frozen fields -- confirm filtering works for ``frozen=True`` schemas.

    Raises:
        BadQueryParam: If the path does not resolve or validation fails; the
            underlying error is kept as ``__cause__``.
    """
    if "." in column_name:
        relation, _, column_part = column_name.partition(".")
        relation_field_name = _resolve_field_name(schema_cls, relation)
        nested = (
            _get_nested_schema(schema_cls.model_fields[relation_field_name])
            if relation_field_name is not None
            else None
        )
        if nested is None:
            raise BadQueryParam(f"Invalid attribute in URL query: {column_name}")
        return _parse_value(nested, column_part, value)

    field_name = _resolve_field_name(schema_cls, column_name)
    if field_name is None:
        raise BadQueryParam(f"Invalid attribute in URL query: {column_name}")

    try:
        obj = schema_cls.__pydantic_validator__.validate_assignment(
            schema_cls.model_construct(), field_name, value
        )
        result = getattr(obj, field_name)
        # An IDRef[T] FK field validates to an IDRef object; the SQL bind value
        # is its scalar id, not the reference wrapper (which cannot bind).
        if isinstance(result, IDSchema):
            return result.id
        return result
    except Exception as exc:
        # Deliberately broad: any failure while coercing a client-supplied
        # value is reported as a bad query parameter rather than a 500.
        raise BadQueryParam(f"Invalid attribute in URL query: {column_name}") from exc
def _get_nested_schema(field: FieldInfo | None) -> SchemaType | None:
    """Return the Pydantic model class a field is annotated with, if any.

    The annotation is first unwrapped with ``_unwrap_optional_annotation``.
    Returns ``None`` for a missing field or a non-model annotation.
    """
    if field is not None:
        candidate = _unwrap_optional_annotation(field.annotation)
        if isinstance(candidate, type) and issubclass(candidate, pydantic.BaseModel):
            return candidate
    return None
def _make_where_clause(
    column: InstrumentedAttribute[Any],
    filter_value: str,
    op: str,
    parser: Callable[[str], Any],
) -> ColumnElement[Any]:
    """Build a single predicate ``column <op> filter_value``.

    ``contains``/``icontains`` build an escaped ``LIKE``/``ILIKE`` substring
    pattern from the raw value and never call ``parser``. The comparison
    operators (``eq``, ``ne``, ``gt``, ``gte``, ``lt``, ``lte``) coerce the
    value with ``parser`` first.

    Raises:
        BadQueryParam: For any other operator.
    """
    if op in ("contains", "icontains"):
        pattern = f"%{_escape_like_value(filter_value)}%"
        matcher = column.like if op == "contains" else column.ilike
        return matcher(pattern, escape="\\")

    comparisons: dict[str, Callable[[Any, Any], ColumnElement[Any]]] = {
        "gte": lambda col, val: col >= val,
        "lte": lambda col, val: col <= val,
        "gt": lambda col, val: col > val,
        "lt": lambda col, val: col < val,
        "ne": lambda col, val: col != val,
        "eq": lambda col, val: col == val,
    }
    compare = comparisons.get(op)
    if compare is None:
        raise BadQueryParam(f"Unsupported filter operator: {op!r}")
    return compare(column, parser(filter_value))