Coverage for fastapi_restly / _pytest_fixtures.py: 82%

147 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-24 11:13 +0000

1from __future__ import annotations 

2 

3import traceback 

4from contextlib import asynccontextmanager 

5from pathlib import Path 

6from typing import TYPE_CHECKING, AsyncIterator, Iterator 

7from unittest.mock import MagicMock, patch 

8 

9import alembic 

10import alembic.command 

11import alembic.config 

12import pytest 

13from fastapi import FastAPI 

14from sqlalchemy.ext.asyncio import AsyncConnection 

15from sqlalchemy.ext.asyncio import AsyncSession as SA_AsyncSession 

16from sqlalchemy.orm import Session as SA_Session 

17 

18from .db import activate_savepoint_only_mode 

19from .db._globals import _fr_globals, _get_restly_context 

20from .db._session import _clear_uncommitted 

21 

22if TYPE_CHECKING: 

23 from .testing._client import RestlyTestClient 

24 

25try: 

26 import pytest_asyncio 

27except ModuleNotFoundError as exc: 

28 if exc.name != "pytest_asyncio": 

29 raise 

30 pytest_asyncio = None 

31 

32_TESTING_EXTRA_MESSAGE = ( 

33 "fastapi_restly.pytest_fixtures requires optional testing dependencies. " 

34 'Install them with: pip install "fastapi-restly[testing]"' 

35) 

36 

37 

38@pytest.fixture(scope="session") 

39def restly_project_root() -> Path: 

40 """Return the project root directory.""" 

41 # Try to find the project root by looking for pyproject.toml 

42 current = Path.cwd() 

43 while current != current.parent: 

44 if (current / "pyproject.toml").exists(): 

45 return current 

46 current = current.parent 

47 raise Exception("Could not find a pyproject.toml to establish project root") 

48 

49 

50def _run_alembic_upgrade(project_root: Path) -> None: 

51 # Only run alembic migrations if the alembic directory exists 

52 alembic_dir = project_root / "alembic" 

53 if not alembic_dir.exists(): 

54 return # Skip if no alembic directory 

55 

56 # restly_project_root owns discovery; this helper only builds Alembic config. 

57 alembic_cfg = alembic.config.Config(project_root / "alembic.ini") 

58 alembic_cfg.set_main_option("script_location", str(alembic_dir)) 

59 try: 

60 alembic.command.upgrade(alembic_cfg, "head") 

61 except Exception as exc: 

62 tb = traceback.format_exc() 

63 pytest.exit( 

64 f"Alembic migrations failed: {exc}\n\nTraceback:\n{tb}", returncode=1 

65 ) 

66 

67 

68def _activate_savepoint_only_mode_sessions() -> None: 

69 # Only run if database connections are set up 

70 if not _fr_globals.async_make_session and not _fr_globals.make_session: 

71 return # Skip if no database connections 

72 

73 if _fr_globals.async_make_session: 73 ↛ 75line 73 didn't jump to line 75 because the condition on line 73 was always true

74 activate_savepoint_only_mode(_fr_globals.async_make_session) 

75 if _fr_globals.make_session: 75 ↛ exitline 75 didn't return from function '_activate_savepoint_only_mode_sessions' because the condition on line 75 was always true

76 activate_savepoint_only_mode(_fr_globals.make_session) 

77 

78 

79@pytest.fixture 

80def _shared_connection(): 

81 # Sync tests need a sync sessionmaker, but async-only projects should still 

82 # be able to use the restly_async_session fixture without one. 

83 if not _fr_globals.make_session: 

84 yield None 

85 return 

86 

87 engine = _fr_globals.make_session.kw["bind"] 

88 with engine.connect() as conn: 

89 yield conn 

90 

91 

92if pytest_asyncio is None: 92 ↛ 94line 92 didn't jump to line 94 because the condition on line 92 was never true

93 

94 @pytest.fixture 

95 def restly_async_session(_shared_connection) -> None: # pyright: ignore[reportRedeclaration] 

96 # The else-branch defines the real async fixture; this stub only 

97 # runs when the optional ``pytest_asyncio`` extra isn't installed. 

98 # Pyright cannot model mutually exclusive module-level branches. 

99 raise ModuleNotFoundError(_TESTING_EXTRA_MESSAGE, name="pytest_asyncio") 

100 

101else: 

102 

103 @pytest_asyncio.fixture 

104 async def restly_async_session( 

105 _shared_connection, 

106 ) -> AsyncIterator[SA_AsyncSession]: 

107 """ 

108 Pytest fixture providing a database session with savepoint-based isolation. 

109 

110 Each test runs inside a savepoint. At the end of the test, the savepoint is 

111 rolled back, leaving the database clean for the next test. 

112 

113 NOTE: Calling session.rollback() inside a test rolls back to the last savepoint 

114 (created by each patched commit()), NOT to the start of the test. This differs 

115 from production behavior. To undo all changes in a test, use session.rollback() 

116 after each commit(), but be aware that data added before the last commit() is 

117 still visible. 

118 """ 

119 # Only run if database connections are set up 

120 if not _fr_globals.async_make_session: 120 ↛ 121line 120 didn't jump to line 121 because the condition on line 120 was never true

121 pytest.skip("Database connection not set up") 

122 

123 async_engine = _fr_globals.async_make_session.kw["bind"] 

124 

125 @asynccontextmanager 

126 async def get_bound_async_connection(): 

127 if _shared_connection is None: 127 ↛ 132line 127 didn't jump to line 132 because the condition on line 127 was always true

128 async with async_engine.connect() as async_conn: 

129 yield async_conn 

130 return 

131 

132 async_conn = AsyncConnection( 

133 async_engine, sync_connection=_shared_connection 

134 ) 

135 async with async_conn: 

136 yield async_conn 

137 

138 async with get_bound_async_connection() as async_conn: 

139 async with _fr_globals.async_make_session(bind=async_conn) as sess: 

140 

141 class AsyncSessionContext: 

142 def __init__(self, *, flush_on_success: bool) -> None: 

143 self.flush_on_success = flush_on_success 

144 

145 async def __aenter__(self): 

146 await sess.begin_nested() 

147 return sess 

148 

149 async def __aexit__(self, exc_type, exc_value, tb): 

150 if self.flush_on_success and exc_type is None: 

151 await sess.flush() 

152 return False # re-raise any exception 

153 

154 mock_sessionmaker = MagicMock() 

155 mock_sessionmaker.side_effect = lambda *args, **kwargs: ( 

156 AsyncSessionContext(flush_on_success=False) 

157 ) 

158 # session.begin() is used as a context manager (async with 

159 # session.begin():). Return the same isolated session and flush 

160 # pending changes after successful explicit transaction blocks. 

161 mock_sessionmaker.begin.side_effect = lambda *args, **kwargs: ( 

162 AsyncSessionContext(flush_on_success=True) 

163 ) 

164 

165 async def passthrough_exit(self, exc_type, exc_value, traceback): 

166 await sess.flush() 

167 return False # re-raise any exception 

168 

169 async def patched_commit(self): 

170 await sess.flush() 

171 await sess.begin_nested() 

172 # Treat the savepoint as this fixture's commit boundary. 

173 # Clear the pending-change flag set by flush; a write that 

174 # never calls commit() still leaves the flag set and warns. 

175 _clear_uncommitted(getattr(sess, "sync_session", sess)) 

176 

177 globals_obj = _get_restly_context() 

178 original_async_make_session = globals_obj.async_make_session 

179 globals_obj.async_make_session = mock_sessionmaker 

180 try: 

181 with ( 

182 patch.object(SA_AsyncSession, "__aexit__", passthrough_exit), 

183 patch.object(SA_AsyncSession, "commit", patched_commit), 

184 ): 

185 yield sess 

186 finally: 

187 globals_obj.async_make_session = original_async_make_session 

188 

189 

190@pytest.fixture 

191def restly_session(_shared_connection) -> Iterator[SA_Session]: 

192 """ 

193 Pytest fixture providing a database session with savepoint-based isolation. 

194 

195 Each test runs inside a savepoint. At the end of the test, the savepoint is 

196 rolled back, leaving the database clean for the next test. 

197 

198 NOTE: Calling session.rollback() inside a test rolls back to the last savepoint 

199 (created by each patched commit()), NOT to the start of the test. This differs 

200 from production behavior. To undo all changes in a test, use session.rollback() 

201 after each commit(), but be aware that data added before the last commit() is 

202 still visible. 

203 """ 

204 # Only run if database connections are set up 

205 if not _fr_globals.make_session: 205 ↛ 206line 205 didn't jump to line 206 because the condition on line 205 was never true

206 pytest.skip("Database connection not set up") 

207 

208 with _fr_globals.make_session(bind=_shared_connection) as sess: 

209 

210 def begin_nested(): 

211 sess.begin_nested() 

212 return sess 

213 

214 mock_sessionmaker = MagicMock() 

215 mock_sessionmaker.side_effect = begin_nested 

216 # session.begin() is used as a context manager (with session.begin():) 

217 # We need it to also return our savepoint session so explicit transaction 

218 # blocks work correctly with our isolation mechanism 

219 mock_sessionmaker.begin.return_value.__enter__.side_effect = begin_nested 

220 

221 def exit_nested(exc_type, exc_value, tb): 

222 if exc_type is None: 222 ↛ 224line 222 didn't jump to line 224 because the condition on line 222 was always true

223 sess.flush() 

224 return False # re-raise any exception 

225 

226 def passthrough_exit(self, exc_type, exc_value, traceback): 

227 sess.flush() 

228 return False # re-raise any exception 

229 

230 def patched_commit(self): 

231 sess.flush() 

232 sess.begin_nested() 

233 # Mimic after_commit (see the async fixture for the full rationale): 

234 # clear the uncommitted-changes flag so the request-end check does not 

235 # false-warn under savepoint mode. 

236 _clear_uncommitted(getattr(sess, "sync_session", sess)) 

237 

238 globals_obj = _get_restly_context() 

239 original_make_session = globals_obj.make_session 

240 globals_obj.make_session = mock_sessionmaker 

241 try: 

242 with ( 

243 patch.object(SA_Session, "__exit__", passthrough_exit), 

244 patch.object(SA_Session, "commit", patched_commit), 

245 ): 

246 mock_sessionmaker.begin.return_value.__exit__.side_effect = exit_nested 

247 yield sess 

248 finally: 

249 globals_obj.make_session = original_make_session 

250 

251 

252@pytest.fixture 

253def restly_app() -> FastAPI: 

254 """Create a FastAPI app instance for testing.""" 

255 return FastAPI() 

256 

257 

258@pytest.fixture 

259def restly_client(restly_app) -> RestlyTestClient: 

260 """Create a RestlyTestClient instance for testing.""" 

261 try: 

262 from .testing._client import RestlyTestClient 

263 except ModuleNotFoundError as exc: 

264 if exc.name == "httpx": 

265 raise ModuleNotFoundError(_TESTING_EXTRA_MESSAGE, name="httpx") from exc 

266 raise 

267 

268 return RestlyTestClient(restly_app)