Coverage for fastapi_restly/_pytest_fixtures.py: 82%
147 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-24 11:13 +0000
1from __future__ import annotations
3import traceback
4from contextlib import asynccontextmanager
5from pathlib import Path
6from typing import TYPE_CHECKING, AsyncIterator, Iterator
7from unittest.mock import MagicMock, patch
9import alembic
10import alembic.command
11import alembic.config
12import pytest
13from fastapi import FastAPI
14from sqlalchemy.ext.asyncio import AsyncConnection
15from sqlalchemy.ext.asyncio import AsyncSession as SA_AsyncSession
16from sqlalchemy.orm import Session as SA_Session
18from .db import activate_savepoint_only_mode
19from .db._globals import _fr_globals, _get_restly_context
20from .db._session import _clear_uncommitted
22if TYPE_CHECKING:
23 from .testing._client import RestlyTestClient
# pytest_asyncio belongs to the optional "testing" extra. Fall back to None
# only when that exact module is missing; a ModuleNotFoundError for any other
# module (e.g. a missing dependency of an installed pytest_asyncio) is re-raised.
try:
    import pytest_asyncio
except ModuleNotFoundError as missing:
    if missing.name == "pytest_asyncio":
        pytest_asyncio = None
    else:
        raise

# Error text used when a fixture needs a package from the "testing" extra.
_TESTING_EXTRA_MESSAGE = (
    "fastapi_restly.pytest_fixtures requires optional testing dependencies. "
    'Install them with: pip install "fastapi-restly[testing]"'
)
@pytest.fixture(scope="session")
def restly_project_root() -> Path:
    """Return the project root directory.

    Searches the current working directory and then each of its ancestors,
    nearest first, for a ``pyproject.toml``. Returns the first directory that
    contains one. The filesystem root itself is part of the search.

    Raises:
        FileNotFoundError: if no ``pyproject.toml`` exists in the current
            working directory or any of its ancestors. This is a subclass of
            ``Exception``, so handlers written for the old bare ``Exception``
            still catch it.
    """
    start = Path.cwd()  # always absolute, so ``parents`` reaches the root
    # ``parents`` ends with the filesystem root, so the root is checked too.
    # The former ``while current != current.parent`` loop stopped one short.
    for candidate in (start, *start.parents):
        if (candidate / "pyproject.toml").exists():
            return candidate
    raise FileNotFoundError(
        "Could not find a pyproject.toml to establish project root"
    )
def _run_alembic_upgrade(project_root: Path) -> None:
    """Upgrade the database to Alembic ``head`` using *project_root*'s config.

    Does nothing when *project_root* has no ``alembic`` directory. The config
    is read from ``<project_root>/alembic.ini`` with ``script_location`` forced
    to that ``alembic`` directory. If the upgrade raises, the whole pytest run
    is aborted via ``pytest.exit`` (return code 1), and the message includes
    the formatted traceback.
    """
    script_dir = project_root / "alembic"
    if script_dir.exists():
        # Root discovery belongs to restly_project_root; this helper only
        # assembles the Alembic configuration and runs the upgrade.
        cfg = alembic.config.Config(project_root / "alembic.ini")
        cfg.set_main_option("script_location", str(script_dir))
        try:
            alembic.command.upgrade(cfg, "head")
        except Exception as err:
            details = traceback.format_exc()
            pytest.exit(
                f"Alembic migrations failed: {err}\n\nTraceback:\n{details}",
                returncode=1,
            )
def _activate_savepoint_only_mode_sessions() -> None:
    """Put every configured sessionmaker into savepoint-only mode.

    The async sessionmaker is handled before the sync one. A factory that is
    not configured (falsy) is skipped, so this is a no-op when neither is set.
    """
    for attr_name in ("async_make_session", "make_session"):
        factory = getattr(_fr_globals, attr_name)
        if factory:
            activate_savepoint_only_mode(factory)
@pytest.fixture
def _shared_connection():
    """Yield the sync connection that the session fixtures of a test share.

    Yields ``None`` when no sync sessionmaker is configured. Sync tests need
    one, but async-only projects should still be able to use
    ``restly_async_session`` without it.
    """
    sync_factory = _fr_globals.make_session
    if sync_factory:
        engine = sync_factory.kw["bind"]
        with engine.connect() as connection:
            yield connection
    else:
        yield None
if pytest_asyncio is None:

    @pytest.fixture
    def restly_async_session(_shared_connection) -> None:  # pyright: ignore[reportRedeclaration]
        """Placeholder used when the optional ``pytest_asyncio`` extra is missing.

        Requesting it raises ``ModuleNotFoundError`` with install instructions
        for the ``testing`` extra.
        """
        # The else-branch defines the real async fixture; this stub only
        # runs when the optional ``pytest_asyncio`` extra isn't installed.
        # Pyright cannot model mutually exclusive module-level branches.
        raise ModuleNotFoundError(_TESTING_EXTRA_MESSAGE, name="pytest_asyncio")

else:

    @pytest_asyncio.fixture
    async def restly_async_session(
        _shared_connection,
    ) -> AsyncIterator[SA_AsyncSession]:
        """
        Pytest fixture providing an async database session with savepoint-based
        isolation.

        Each test runs inside a savepoint. At the end of the test, the savepoint is
        rolled back, leaving the database clean for the next test.

        While the fixture is active:
        - the registered async sessionmaker is replaced by a mock whose
          ``async with make_session()`` and ``async with make_session.begin()``
          blocks open a savepoint and yield this same session;
        - ``AsyncSession.commit()`` is patched to flush and open a new savepoint;
        - ``AsyncSession.__aexit__`` is patched to flush (on a clean exit only)
          instead of closing.

        NOTE: Calling session.rollback() inside a test rolls back to the last
        savepoint (created by each patched commit()), NOT to the start of the
        test. This differs from production behavior: data written before the
        most recent commit() stays visible after a rollback() for the rest of the
        test, and is only discarded when the fixture tears down.
        """
        # Only run if database connections are set up
        if not _fr_globals.async_make_session:
            pytest.skip("Database connection not set up")

        async_engine = _fr_globals.async_make_session.kw["bind"]

        @asynccontextmanager
        async def get_bound_async_connection():
            if _shared_connection is None:
                # No sync sessionmaker: open a dedicated async connection.
                async with async_engine.connect() as async_conn:
                    yield async_conn
                return

            # Wrap the shared sync connection so the async session runs on the
            # same connection that ``restly_session`` binds to.
            async_conn = AsyncConnection(
                async_engine, sync_connection=_shared_connection
            )
            async with async_conn:
                yield async_conn

        async with get_bound_async_connection() as async_conn:
            async with _fr_globals.async_make_session(bind=async_conn) as sess:

                class AsyncSessionContext:
                    """Async context manager that opens a savepoint on ``sess``."""

                    def __init__(self, *, flush_on_success: bool) -> None:
                        self.flush_on_success = flush_on_success

                    async def __aenter__(self):
                        await sess.begin_nested()
                        return sess

                    async def __aexit__(self, exc_type, exc_value, tb):
                        if self.flush_on_success and exc_type is None:
                            await sess.flush()
                        return False  # re-raise any exception

                mock_sessionmaker = MagicMock()
                # ``async with make_session() as s:`` yields the isolated session.
                mock_sessionmaker.side_effect = lambda *args, **kwargs: (
                    AsyncSessionContext(flush_on_success=False)
                )
                # session.begin() is used as a context manager (async with
                # session.begin():). Return the same isolated session and flush
                # pending changes after successful explicit transaction blocks.
                mock_sessionmaker.begin.side_effect = lambda *args, **kwargs: (
                    AsyncSessionContext(flush_on_success=True)
                )

                async def passthrough_exit(self, exc_type, exc_value, tb):
                    # Keep the shared session open. Flush only on a clean exit:
                    # flushing while an exception propagates (e.g. an
                    # IntegrityError raised by flush) would fail again on the
                    # already-broken session and mask the original error.
                    if exc_type is None:
                        await sess.flush()
                    return False  # re-raise any exception

                async def patched_commit(self):
                    await sess.flush()
                    await sess.begin_nested()
                    # Treat the savepoint as this fixture's commit boundary.
                    # Clear the pending-change flag set by flush; a write that
                    # never calls commit() still leaves the flag set and warns.
                    _clear_uncommitted(getattr(sess, "sync_session", sess))

                globals_obj = _get_restly_context()
                original_async_make_session = globals_obj.async_make_session
                globals_obj.async_make_session = mock_sessionmaker
                try:
                    with (
                        patch.object(SA_AsyncSession, "__aexit__", passthrough_exit),
                        patch.object(SA_AsyncSession, "commit", patched_commit),
                    ):
                        yield sess
                finally:
                    # Restore the real factory even if the test failed.
                    globals_obj.async_make_session = original_async_make_session
@pytest.fixture
def restly_session(_shared_connection) -> Iterator[SA_Session]:
    """
    Pytest fixture providing a database session with savepoint-based isolation.

    Each test runs inside a savepoint. At the end of the test, the savepoint is
    rolled back, leaving the database clean for the next test.

    While the fixture is active:
    - the registered sessionmaker is replaced by a mock whose calls (and
      ``with make_session.begin()`` blocks) open a savepoint and return this
      same session;
    - ``Session.commit()`` is patched to flush and open a new savepoint;
    - ``Session.__exit__`` is patched to flush (on a clean exit only) instead
      of closing.

    NOTE: Calling session.rollback() inside a test rolls back to the last
    savepoint (created by each patched commit()), NOT to the start of the test.
    This differs from production behavior: data written before the most recent
    commit() stays visible after a rollback() for the rest of the test, and is
    only discarded when the fixture tears down.
    """
    # Only run if database connections are set up
    if not _fr_globals.make_session:
        pytest.skip("Database connection not set up")

    with _fr_globals.make_session(bind=_shared_connection) as sess:

        def begin_nested():
            # Open a savepoint and hand back the shared session, so every
            # "new" session the code under test creates is this one.
            sess.begin_nested()
            return sess

        def exit_nested(exc_type, exc_value, tb):
            if exc_type is None:
                sess.flush()
            return False  # re-raise any exception

        mock_sessionmaker = MagicMock()
        mock_sessionmaker.side_effect = begin_nested
        # session.begin() is used as a context manager (with session.begin():).
        # It must also yield the savepoint session so explicit transaction
        # blocks work with this isolation mechanism.
        mock_sessionmaker.begin.return_value.__enter__.side_effect = begin_nested
        mock_sessionmaker.begin.return_value.__exit__.side_effect = exit_nested

        def passthrough_exit(self, exc_type, exc_value, tb):
            # Keep the shared session open. Flush only on a clean exit, like
            # exit_nested: flushing while an exception propagates (e.g. an
            # IntegrityError raised by flush) would fail again on the
            # already-broken session and mask the original error.
            if exc_type is None:
                sess.flush()
            return False  # re-raise any exception

        def patched_commit(self):
            sess.flush()
            sess.begin_nested()
            # Treat the savepoint as this fixture's commit boundary: clear the
            # uncommitted-changes flag set by flush so the request-end check
            # does not false-warn under savepoint mode.
            _clear_uncommitted(getattr(sess, "sync_session", sess))

        globals_obj = _get_restly_context()
        original_make_session = globals_obj.make_session
        globals_obj.make_session = mock_sessionmaker
        try:
            with (
                patch.object(SA_Session, "__exit__", passthrough_exit),
                patch.object(SA_Session, "commit", patched_commit),
            ):
                yield sess
        finally:
            # Restore the real factory even if the test failed.
            globals_obj.make_session = original_make_session
@pytest.fixture
def restly_app() -> FastAPI:
    """Provide a new, empty FastAPI application for each test."""
    app = FastAPI()
    return app
@pytest.fixture
def restly_client(restly_app) -> RestlyTestClient:
    """Provide a RestlyTestClient wrapping the ``restly_app`` fixture.

    The client class is imported lazily. If ``httpx`` is missing, the error is
    re-raised as ``ModuleNotFoundError`` with instructions to install the
    ``testing`` extra. Any other missing module propagates unchanged.
    """
    try:
        from .testing._client import RestlyTestClient as client_cls
    except ModuleNotFoundError as err:
        if err.name != "httpx":
            raise
        raise ModuleNotFoundError(_TESTING_EXTRA_MESSAGE, name="httpx") from err

    return client_cls(restly_app)