Coverage for fastapi_restly / db / _session.py: 95%

164 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-24 11:13 +0000

1import warnings 

2from collections.abc import AsyncIterator, Callable, Iterator 

3from inspect import signature 

4from typing import Annotated, Any, cast 

5 

6from fastapi import Depends, FastAPI 

7from sqlalchemy import Engine, MetaData, create_engine, event 

8from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine 

9from sqlalchemy.ext.asyncio import AsyncSession as SA_AsyncSession 

10from sqlalchemy.orm import DeclarativeBase, sessionmaker 

11from sqlalchemy.orm import Session as SA_Session 

12 

13from .._exception_handlers import register_default_exception_handlers 

14from ..exc import RestlyConfigurationError, RestlyUncommittedChangesWarning 

15from ._globals import _fr_globals 

16 

# orjson is optional. When it is importable, the engines created in this module
# are handed orjson-based JSON hooks; otherwise both hooks stay ``None`` and are
# passed through to SQLAlchemy as-is.
try:
    import orjson
except ImportError:
    json_serializer = None
    json_deserializer = None
else:

    def orjson_serializer(obj):
        """Serialize ``obj`` with orjson and return the result as ``str``.

        ``orjson.dumps`` produces ``bytes``; the result is decoded because the
        value is used as an engine ``json_serializer``.
        """
        flags = orjson.OPT_NAIVE_UTC | orjson.OPT_NON_STR_KEYS
        encoded = orjson.dumps(obj, option=flags)
        return encoded.decode()

    json_serializer = orjson_serializer
    json_deserializer = orjson.loads

31 

32 

def _setup_async_database_connection(
    async_database_url: str | None = None,
    *,
    async_engine: AsyncEngine | None = None,
    async_make_session: async_sessionmaker[Any] | None = None,
) -> async_sessionmaker[Any]:
    """Resolve the async session factory and store it in ``_fr_globals``.

    Precedence: an explicit ``async_make_session`` is used as-is; otherwise a
    factory is built around ``async_engine``; otherwise an engine is first
    created from ``async_database_url`` (with this module's JSON hooks).

    A factory built here uses ``autoflush=False`` and
    ``expire_on_commit=False``. A caller-supplied factory whose ``kw`` does not
    disable ``expire_on_commit`` triggers a warning, because the default for
    that option is treated as ``True``.

    Returns:
        The session factory that was registered.
    """
    # Explicit ``is None`` checks (matching ``_setup_database_connection``):
    # only a missing argument should trigger construction, never the
    # truthiness of a supplied object.
    if async_make_session is None:
        if async_engine is None:
            async_engine = create_async_engine(
                async_database_url,  # type: ignore[arg-type]
                json_serializer=json_serializer,
                json_deserializer=json_deserializer,
            )
        async_make_session = async_sessionmaker(
            bind=async_engine, autoflush=False, expire_on_commit=False
        )

    # ``kw`` holds the keyword arguments the factory passes to each session;
    # factories without that attribute are not inspected.
    factory_kw = getattr(async_make_session, "kw", None)
    if factory_kw is not None and factory_kw.get("expire_on_commit", True):
        warnings.warn(
            "The async session factory passed to fr.configure() has "
            "expire_on_commit=True. Restly's write handlers commit before "
            "building the response, so committed ORM attributes will expire "
            "and the async serializer will trigger a lazy reload outside the "
            "async context (MissingGreenlet). Pass expire_on_commit=False to "
            "your async_sessionmaker.",
            # Frames: warn -> this helper -> configure() -> caller's code.
            stacklevel=3,
        )

    _fr_globals.async_database_url = async_database_url
    _fr_globals.async_make_session = async_make_session
    return async_make_session

65 

66 

def _setup_database_connection(
    database_url: str | None = None,
    *,
    engine: Engine | None = None,
    make_session: sessionmaker[Any] | None = None,
) -> sessionmaker[Any]:
    """Resolve the sync session factory and store it in ``_fr_globals``.

    An explicit ``make_session`` wins; failing that, a factory is built around
    ``engine``; failing that, an engine is created from ``database_url`` with
    this module's JSON hooks. Factories built here use
    ``expire_on_commit=False``.

    Returns:
        The session factory that was registered.
    """
    session_factory = make_session
    if session_factory is None:
        bound_engine = engine
        if bound_engine is None:
            bound_engine = create_engine(
                database_url,  # type: ignore[arg-type]
                json_serializer=json_serializer,
                json_deserializer=json_deserializer,
            )
        session_factory = sessionmaker(bind=bound_engine, expire_on_commit=False)

    _fr_globals.database_url = database_url
    _fr_globals.make_session = session_factory
    return session_factory

85 

86 

def configure(
    app: FastAPI | None = None,
    *,
    async_database_url: str | None = None,
    async_engine: AsyncEngine | None = None,
    async_make_session: async_sessionmaker[Any] | None = None,
    database_url: str | None = None,
    engine: Engine | None = None,
    make_session: sessionmaker[Any] | None = None,
    session_generator: Callable[[], AsyncIterator[SA_AsyncSession]] | None = None,
    sync_session_generator: Callable[[], Iterator[SA_Session]] | None = None,
    warn_on_misuse: bool | None = None,
    warn_on_uncommitted: bool | None = None,
    install_default_exception_handlers: bool = True,
) -> None:
    """Configure FastAPI-Restly. Call once at startup.

    **Database access.** Pass the async parameters (``async_database_url``,
    ``async_engine``, or ``async_make_session``) to enable async support, the
    sync parameters (``database_url``, ``engine``, or ``make_session``) for
    sync support, or both if your application uses both.

    **Custom session construction.** Use ``session_generator`` /
    ``sync_session_generator`` (or ``engine`` / ``make_session``) to construct
    sessions your way -- a custom engine, isolation level, ``search_path``,
    logging, an existing ``sessionmaker``. A custom generator's job is to
    **construct, yield, and clean up** (close / roll back on the way out); it
    must **not** commit. Customizing how a session is built never takes the
    commit away from Restly.

    **Who commits.** Restly owns the commit. Every write -- the CRUD handlers
    (``handle_create`` / ``handle_update`` / ``handle_delete``) and
    ``write_action`` -- runs ``before_commit`` -> commit -> ``after_commit``
    around your domain logic; the commit is the framework's single
    responsibility. A custom (non-CRUD) write route either brackets its
    mutation with ``write_action(...)`` (recommended) or commits the session
    itself with ``await self.session.commit()``.

    **Uncommitted-changes warning.** By default Restly warns
    (:class:`RestlyUncommittedChangesWarning`) when a request finishes with
    uncommitted changes still in the session -- the tell of a custom write
    route that forgot to commit. This applies to every session source,
    built-in or custom. A route that intentionally leaves a flush uncommitted
    (a validate-then-rollback dry run) should suppress the warning for just
    that request with ``session.info["_fr_suppress_uncommitted"] = True``.
    ``warn_on_uncommitted=False`` turns the check off globally; that is rarely
    the right response to the warning -- prefer fixing the missing commit or
    the per-route suppression.

    **Misuse warnings.** Pass ``warn_on_misuse=True`` to enable opt-in
    registration-time misuse warnings (:class:`RestlyMisuseWarning`): when a
    view class is registered via ``include_view``, the framework flags
    route-shell overrides, direct ``session.commit()`` calls in view methods,
    and CRUD route sets hand-rolled on a bare ``View``. Off by default;
    intended for development, templates, and CI. Enable it before registering
    views.

    **Exception handlers.** Pass your :class:`FastAPI` ``app`` to install
    fastapi-restly's default exception handlers (currently: a translator that
    turns SQLAlchemy :class:`~sqlalchemy.exc.IntegrityError` into HTTP 409
    Conflict). Set ``install_default_exception_handlers=False`` to opt out. If
    you do not pass ``app`` here, the handlers are registered the first time a
    view is mounted via :func:`fastapi_restly.include_view` instead.

    Raises:
        TypeError: If the call supplies nothing to configure.
    """
    wants_async = (
        async_database_url is not None
        or async_engine is not None
        or async_make_session is not None
    )
    wants_sync = (
        database_url is not None or engine is not None or make_session is not None
    )

    # A call that would change nothing is almost certainly a mistake; reject it.
    has_work = (
        wants_async
        or wants_sync
        or session_generator is not None
        or sync_session_generator is not None
        or warn_on_misuse is not None
        or warn_on_uncommitted is not None
        or (app is not None and install_default_exception_handlers)
    )
    if not has_work:
        raise TypeError("fr.configure() requires at least one setup argument.")

    # Warning switches: ``None`` means "leave the current setting alone".
    if warn_on_misuse is not None:
        _fr_globals.warn_on_misuse = warn_on_misuse
    if warn_on_uncommitted is not None:
        _fr_globals.warn_on_uncommitted = warn_on_uncommitted

    if wants_async:
        _setup_async_database_connection(
            async_database_url=async_database_url,
            async_engine=async_engine,
            async_make_session=async_make_session,
        )
    if wants_sync:
        _setup_database_connection(
            database_url=database_url, engine=engine, make_session=make_session
        )

    if session_generator is not None:
        _fr_globals.session_generator = session_generator
    if sync_session_generator is not None:
        _fr_globals.sync_session_generator = sync_session_generator

    if app is not None and install_default_exception_handlers:
        register_default_exception_handlers(app)

189 

190 

def activate_savepoint_only_mode(
    make_session: async_sessionmaker[Any] | sessionmaker[Any],
) -> None:
    """
    Intended for use in tests. Switches the session factory to savepoint-only
    mode so that no test data is ever committed to the database: a test rolls
    back by closing its session, leaving the database clean for the next test.

    Two things are changed: ``engine.connect`` is wrapped so the outer
    transaction is begun before the Session gets the connection, and the
    factory is configured with ``join_transaction_mode="create_savepoint"``.
    https://docs.sqlalchemy.org/en/20/orm/session_transaction.html#session-external-transaction

    Calling this again on an already-activated engine is a no-op.
    """
    sync_engine = _get_sync_engine(make_session)
    current_connect = sync_engine.connect

    # The wrapper carries an ``_original_connect`` marker; if it is present,
    # this engine has already been activated.
    if hasattr(current_connect, "_original_connect"):
        return

    def _begin_on_connect():
        conn = current_connect()
        conn.begin()
        return conn

    # setattr() rather than plain attribute assignment to silence pyright.
    setattr(_begin_on_connect, "_original_connect", current_connect)

    sync_engine.connect = _begin_on_connect
    make_session.configure(join_transaction_mode="create_savepoint")

221 

222 

def deactivate_savepoint_only_mode(
    make_session: async_sessionmaker[Any] | sessionmaker[Any],
) -> None:
    """
    Undo `activate_savepoint_only_mode`.

    Puts back the ``engine.connect`` saved during activation (when there is
    one) and reconfigures the factory with ``join_transaction_mode=None``.
    Safe to call on a factory that was never activated: nothing is restored.
    """
    sync_engine = _get_sync_engine(make_session)
    wrapper = cast(Any, sync_engine.connect)
    try:
        saved_connect = wrapper._original_connect
    except AttributeError:
        # Never activated: there is no wrapper to remove.
        pass
    else:
        sync_engine.connect = saved_connect

    # NOTE(review): this stores ``None`` rather than removing the setting --
    # confirm that SQLAlchemy treats ``None`` as "use the default mode".
    make_session.configure(join_transaction_mode=None)

238 

239 

def get_async_engine() -> AsyncEngine:
    """Return the engine bound to the registered async session factory.

    Raises:
        RestlyConfigurationError: If no async session factory is registered.
    """
    factory = _fr_globals.async_make_session
    if factory is None:
        raise RestlyConfigurationError(
            "Call fr.configure() before using get_async_engine()."
        )
    # The engine is read from the factory's stored ``bind`` keyword.
    return factory.kw["bind"]

247 

248 

def get_engine() -> Engine:
    """Return the engine bound to the registered sync session factory.

    Raises:
        RestlyConfigurationError: If no sync session factory is registered.
    """
    factory = _fr_globals.make_session
    if factory is None:
        raise RestlyConfigurationError("Call fr.configure() before using get_engine().")
    # The engine is read from the factory's stored ``bind`` keyword.
    return factory.kw["bind"]

254 

255 

def _resolve_metadata(base_or_metadata: type[DeclarativeBase] | MetaData) -> MetaData:
    """Return the ``MetaData`` for a declarative base or a ``MetaData``.

    A ``MetaData`` is returned unchanged; for anything else its ``metadata``
    attribute is used, provided that attribute is a ``MetaData``.

    Raises:
        TypeError: If no ``MetaData`` can be obtained from the argument.
    """
    if isinstance(base_or_metadata, MetaData):
        candidate = base_or_metadata
    else:
        candidate = getattr(base_or_metadata, "metadata", None)

    if not isinstance(candidate, MetaData):
        raise TypeError(
            "create_all() expects a DeclarativeBase subclass or a MetaData; got "
            f"{base_or_metadata!r}"
        )
    return candidate

266 

267 

def create_all(base_or_metadata: type[DeclarativeBase] | MetaData) -> None:
    """Create every table of ``base_or_metadata`` on the configured sync engine.

    A dev/demo shortcut for ``metadata.create_all(engine)``, so a quickstart
    can build its schema without touching the raw engine::

        fr.db.create_all(Base)  # or fr.db.create_all(Base.metadata)

    The argument may be a ``DeclarativeBase`` subclass (its ``.metadata`` is
    used) or a ``MetaData``. :func:`configure` must have been called first.
    Use Alembic migrations in production.
    """
    metadata = _resolve_metadata(base_or_metadata)
    metadata.create_all(get_engine())

281 

282 

async def async_create_all(base_or_metadata: type[DeclarativeBase] | MetaData) -> None:
    """Async counterpart of :func:`create_all`, using the configured async engine.

    Usage::

        await fr.db.async_create_all(Base)
    """
    metadata = _resolve_metadata(base_or_metadata)
    async_engine = get_async_engine()
    # ``metadata.create_all`` is a sync API, so it is dispatched through
    # ``run_sync`` on a connection opened by ``engine.begin()``.
    async with async_engine.begin() as connection:
        await connection.run_sync(metadata.create_all)

294 

295 

def _get_sync_engine(
    make_session: async_sessionmaker[Any] | sessionmaker[Any],
) -> Engine:
    """Return the sync ``Engine`` behind a session factory.

    The factory's ``bind`` is used; when it is an ``AsyncEngine``, its
    ``sync_engine`` is returned instead.
    """
    bound = make_session.kw["bind"]
    return bound.sync_engine if isinstance(bound, AsyncEngine) else bound

303 

304 

def _should_warn_uncommitted() -> bool:
    """Report whether the uncommitted-changes check is switched on.

    The check is governed solely by the ``warn_on_uncommitted`` setting.
    Restly owns the commit, so changes still pending when a request ends
    point to a custom write route that never committed.
    """
    enabled = _fr_globals.warn_on_uncommitted
    return enabled

311 

312 

313def _mark_uncommitted(session: SA_Session, flush_context: Any = None) -> None: 

314 session.info["_fr_uncommitted"] = True 

315 

316 

317def _clear_uncommitted(session: SA_Session, *args: Any) -> None: 

318 session.info.pop("_fr_uncommitted", None) 

319 

320 

def _arm_uncommitted_warning(session: SA_AsyncSession | SA_Session) -> None:
    """Attach the listeners that track uncommitted work on ``session``.

    An ``after_flush`` listener sets the flag; ``after_commit`` and
    ``after_rollback`` listeners clear it. For an async session the listeners
    go on its ``sync_session``: that is where ORM events fire, and its
    ``info`` is shared. Does nothing when the check is disabled.
    """
    if not _should_warn_uncommitted():
        return

    target = getattr(session, "sync_session", session)
    listeners = (
        ("after_flush", _mark_uncommitted),
        ("after_commit", _clear_uncommitted),
        ("after_rollback", _clear_uncommitted),
    )
    try:
        for event_name, handler in listeners:
            event.listen(target, event_name, handler)
    except Exception:
        # Best-effort dev aid: unusual sessions (test stubs, or session types
        # without ORM flush events) simply opt out. Never break a request.
        pass

337 

338 

def _warn_if_uncommitted(session: SA_AsyncSession | SA_Session) -> None:
    """Emit ``RestlyUncommittedChangesWarning`` if work is about to be lost.

    "Work" means changes that were flushed but never committed (the
    ``_fr_uncommitted`` flag) or added but never flushed (``new`` / ``dirty``
    / ``deleted``) -- all about to be rolled back. Nothing is emitted when the
    check is disabled, when ``session.info["_fr_suppress_uncommitted"]`` is
    truthy, or when the session cannot be inspected. Called only on the
    success path; an endpoint that raised never reaches this point.
    """
    if not _should_warn_uncommitted():
        return

    target = getattr(session, "sync_session", session)
    try:
        if target.info.get("_fr_suppress_uncommitted"):
            return
        has_pending = bool(
            target.info.get("_fr_uncommitted")
            or target.new
            or target.dirty
            or target.deleted
        )
    except Exception:
        # Unusual session object -> opt out silently.
        return

    if not has_pending:
        return

    warnings.warn(
        "Request finished with uncommitted changes in the database session; "
        "they will be rolled back when the session closes. A custom write "
        "route must commit its changes -- bracket the mutation with "
        "write_action(...) (the framework then commits), or reuse "
        "handle_<verb>(). Only if the rollback is intentional (e.g. a "
        "validate-then-rollback dry run), suppress the warning for that "
        'route with session.info["_fr_suppress_uncommitted"] = True.',
        RestlyUncommittedChangesWarning,
        stacklevel=2,
    )

371 

372 

async def _async_generate_session() -> AsyncIterator[SA_AsyncSession]:
    """FastAPI dependency for async database session.

    Sessions come from the configured custom ``session_generator`` when there
    is one, otherwise from the configured async session factory. Each session
    is armed for the uncommitted-changes check before it is yielded and
    checked after the consumer resumes this generator normally.

    Raises:
        RestlyConfigurationError: If neither a custom generator nor an async
            session factory has been configured.
    """
    if _fr_globals.session_generator is not None:
        custom_sessions = _fr_globals.session_generator()
        try:
            async for session in custom_sessions:
                _arm_uncommitted_warning(session)
                yield session
                _warn_if_uncommitted(session)
        finally:
            # If the request failed, the exception left the loop body without
            # exhausting the custom generator. Close it explicitly so its
            # cleanup runs now instead of whenever it is garbage-collected.
            # Closing an already-finished generator is a no-op; plain async
            # iterators without ``aclose`` are left alone.
            aclose = getattr(custom_sessions, "aclose", None)
            if aclose is not None:
                await aclose()
        return
    if _fr_globals.async_make_session is None:
        raise RestlyConfigurationError(
            "Call fr.configure() before using AsyncSessionDep."
        )

    # FastAPI does not support contextmanagers as dependency directly,
    # but it does support generators. Restly owns the commit (the handle
    # design runs it inside ``handle_<verb>`` / ``write_action``), so this
    # dependency only manages the session lifecycle: the context manager rolls
    # back and closes on the way out, and any change a custom route flushed but
    # never committed is discarded (and warned about).
    async with _fr_globals.async_make_session() as session:
        _arm_uncommitted_warning(session)
        yield session
        _warn_if_uncommitted(session)

396 

397 

def _session_dependency(dependency: Callable[..., Any]) -> Any:
    """Wrap ``dependency`` in ``Depends``, using function scope when available.

    ``scope="function"`` is passed only if the installed ``Depends`` accepts a
    ``scope`` parameter; otherwise ``Depends`` is called with the dependency
    alone.
    """
    depends = cast(Callable[..., Any], Depends)
    supports_scope = "scope" in signature(Depends).parameters
    extra_kwargs = {"scope": "function"} if supports_scope else {}
    return depends(dependency, **extra_kwargs)

403 

404 

# Annotated alias for route parameters: the type is the SQLAlchemy async
# session, supplied by the ``_async_generate_session`` dependency.
AsyncSessionDep = Annotated[
    SA_AsyncSession,
    _session_dependency(_async_generate_session),
]

408 

409 

def _generate_session() -> Iterator[SA_Session]:
    """FastAPI dependency for sync database session.

    Sessions come from the configured custom ``sync_session_generator`` when
    there is one, otherwise from the configured sync session factory. Each
    session is armed for the uncommitted-changes check before it is yielded
    and checked after the consumer resumes this generator normally.

    Raises:
        RestlyConfigurationError: If neither a custom generator nor a sync
            session factory has been configured.
    """
    if _fr_globals.sync_session_generator is not None:
        custom_sessions = _fr_globals.sync_session_generator()
        try:
            for session in custom_sessions:
                _arm_uncommitted_warning(session)
                yield session
                _warn_if_uncommitted(session)
        finally:
            # If the request failed, the exception left the loop body without
            # exhausting the custom generator. Close it explicitly so its
            # cleanup runs now instead of whenever it is garbage-collected.
            # Closing an already-finished generator is a no-op; plain
            # iterators without ``close`` are left alone.
            close = getattr(custom_sessions, "close", None)
            if close is not None:
                close()
        return
    if _fr_globals.make_session is None:
        raise RestlyConfigurationError("Call fr.configure() before using SessionDep.")

    # The session's context manager handles cleanup on the way out; this
    # dependency never commits.
    with _fr_globals.make_session() as session:
        _arm_uncommitted_warning(session)
        yield session
        _warn_if_uncommitted(session)

425 

426 

# Annotated alias for route parameters: the type is the SQLAlchemy sync
# session, supplied by the ``_generate_session`` dependency.
SessionDep = Annotated[
    SA_Session,
    _session_dependency(_generate_session),
]