Coverage for fastapi_restly / schemas / _generator.py: 92%

158 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-24 11:13 +0000

1""" 

2Schema generation utilities for auto-generating Pydantic schemas from SQLAlchemy models. 

3""" 

4 

5import enum 

6import inspect 

7import types 

8from datetime import date, datetime, time 

9from decimal import Decimal 

10from typing import Any, Union, get_args 

11from uuid import UUID 

12 

13import pydantic 

14from pydantic import Field 

15from sqlalchemy import inspect as sa_inspect 

16from sqlalchemy.orm import DeclarativeBase, Mapped, RelationshipProperty 

17 

18from ._base import BaseSchema, IDSchema, ReadOnly, TimestampsSchemaMixin 

19 

20 

def get_sqlalchemy_field_type(field: Any) -> Any:
    """
    Return the type information carried by a SQLAlchemy field.

    The field's ``type`` attribute is preferred. Failing that, its
    ``__origin__`` (the unsubscripted generic) is used. ``typing.Any`` is the
    fallback when neither attribute exists.

    Args:
        field: A SQLAlchemy Mapped field (or any object exposing ``type`` /
            ``__origin__``)

    Returns:
        The discovered type annotation, or ``Any``
    """
    for attribute_name in ("type", "__origin__"):
        if hasattr(field, attribute_name):
            return getattr(field, attribute_name)
    # Nothing usable was found on the field.
    return Any

39 

40 

def is_relationship_field(field: Any) -> bool:
    """
    Tell whether ``field`` represents a SQLAlchemy relationship.

    Two shapes are recognised: a ``RelationshipProperty`` itself, or an
    object that exposes one through its ``property`` attribute.

    Args:
        field: The object to inspect

    Returns:
        True for a relationship, False otherwise
    """
    is_direct_property = isinstance(field, RelationshipProperty)
    # ``or`` short-circuits, so ``property`` is only looked up when needed.
    return is_direct_property or isinstance(
        getattr(field, "property", None), RelationshipProperty
    )

54 

55 

def get_relationship_target_model(field: Any) -> type[DeclarativeBase] | None:
    """
    Resolve the model class a relationship field points at.

    Resolution order:

    1. The relationship property's ``mapper.class_`` (the property is either
       ``field`` itself or ``field.property``).
    2. Otherwise, the field's ``type`` attribute: the first type argument of a
       ``list[Model]`` annotation, or the annotation itself when it is a
       ``DeclarativeBase`` subclass.

    Args:
        field: A SQLAlchemy relationship field

    Returns:
        The target model class, or None if ``field`` is not a relationship or
        no target could be determined
    """
    if not is_relationship_field(field):
        return None

    rel_prop = (
        field
        if isinstance(field, RelationshipProperty)
        else getattr(field, "property", None)
    )
    if (
        rel_prop is not None
        and hasattr(rel_prop, "mapper")
        and hasattr(rel_prop.mapper, "class_")
    ):
        return rel_prop.mapper.class_

    # Fall back to inspecting the annotated type.
    if not hasattr(field, "type"):
        return None

    annotated = field.type
    if hasattr(annotated, "__origin__") and annotated.__origin__ is list:
        # list[Model]: the element type is the target.
        element_types = get_args(annotated)
        return element_types[0] if element_types else None
    if inspect.isclass(annotated) and issubclass(annotated, DeclarativeBase):
        return annotated
    return None

93 

94 

def get_model_fields(model_cls: type[DeclarativeBase]) -> dict[str, Any]:
    """
    Extract field information from a SQLAlchemy model.

    Only public (non-underscore) attributes annotated as ``Mapped[...]`` on
    the model or any class in its MRO are included. When the same name is
    annotated on several classes, the most-derived annotation wins.

    Args:
        model_cls: A SQLAlchemy model class

    Returns:
        Dictionary mapping field names to a metadata dict with the keys:

        - ``type``: the annotated type; for optional unions, ``None`` is
          removed and only the first remaining member is kept
        - ``is_relationship``: whether the mapper has a relationship of that
          name
        - ``target_model``: the relationship mapper's class, else ``None``
        - ``is_optional``: True for ``X | None`` / ``Optional[X]``
          annotations, for every relationship, and for columns with a
          ``default`` or ``server_default``
        - ``default``: the column's ``default`` (or ``server_default``), else
          ``None``
    """
    fields: dict[str, Any] = {}

    mapper = sa_inspect(model_cls)

    # ``mro()`` lists the most-derived class first. Using ``setdefault``
    # (rather than ``update``) keeps a subclass's re-annotation of an
    # inherited attribute from being overwritten by the base class's one.
    # Key order is the same as with ``update``.
    all_annotations: dict[str, Any] = {}
    for cls in model_cls.mro():
        for name, annotation in getattr(cls, "__annotations__", {}).items():
            all_annotations.setdefault(name, annotation)

    for name, field_type in all_annotations.items():
        if name.startswith("_"):
            continue

        # Only ``Mapped[...]`` annotations describe mapped attributes.
        if getattr(field_type, "__origin__", None) is not Mapped:
            continue

        # Extract the actual type from Mapped[Type]
        args = get_args(field_type)
        if not args:
            continue

        actual_type = args[0]
        relationship = mapper.relationships.get(name)

        rel_mapper = (
            getattr(relationship, "mapper", None) if relationship is not None else None
        )
        field_info: dict[str, Any] = {
            "type": actual_type,
            "is_relationship": relationship is not None,
            "target_model": (rel_mapper.class_ if rel_mapper is not None else None),
            "is_optional": False,
            "default": None,
        }

        # Unwrap ``X | None`` (types.UnionType) and ``Optional[X]`` /
        # ``Union[X, None]`` (typing.Union) the same way.
        if (
            isinstance(actual_type, types.UnionType)
            or getattr(actual_type, "__origin__", None) is Union
        ):
            union_args = get_args(actual_type)
            if type(None) in union_args:
                field_info["is_optional"] = True
                non_none_types = [arg for arg in union_args if arg is not type(None)]
                if non_none_types:
                    field_info["type"] = non_none_types[0]

        if relationship is not None:
            # Relationship fields are response-oriented in generated schemas.
            # Keep them optional so create/update inputs can rely on FK columns.
            field_info["is_optional"] = True
        elif name in mapper.columns:
            column = mapper.columns[name]
            if column.default is not None or column.server_default is not None:
                field_info["default"] = column.default or column.server_default
                field_info["is_optional"] = True

        fields[name] = field_info

    return fields

175 

176 

def create_schema_from_model(
    model_cls: type[DeclarativeBase],
    *,
    schema_name: str | None = None,
    include_relationships: bool = True,
    include_readonly_fields: bool = True,
) -> type[BaseSchema]:
    """
    Auto-generate a Pydantic schema from a SQLAlchemy model.

    Relationship fields become nested schemas, generated recursively from the
    relationship's target model with that model's own relationships omitted.
    Self-referential relationships are skipped entirely.

    Args:
        model_cls: The SQLAlchemy model class
        schema_name: Optional name for the generated schema class; defaults to
            ``f"{model_cls.__name__}Read"``
        include_relationships: Whether to include relationship fields
        include_readonly_fields: Whether to mark ``id``, ``created_at`` and
            ``updated_at`` as ``ReadOnly``. When False these fields are still
            included, just without the ``ReadOnly`` marker.

    Returns:
        A Pydantic schema class

    Raises:
        TypeError: If a non-relationship field has a type that
            ``convert_sqlalchemy_type_to_pydantic`` does not support.
    """
    if schema_name is None:
        schema_name = f"{model_cls.__name__}Read"

    model_fields = get_model_fields(model_cls)

    # Base classes, most specific first; BaseSchema always comes last.
    bases: list[type] = []
    if "created_at" in model_fields and "updated_at" in model_fields:
        bases.append(TimestampsSchemaMixin)
    if "id" in model_fields:
        bases.append(IDSchema)
    bases.append(BaseSchema)

    field_definitions: dict[str, Any] = {}

    for field_name, field_info in model_fields.items():
        is_relationship = field_info["is_relationship"]
        if is_relationship and not include_relationships:
            continue

        is_optional = field_info["is_optional"]
        target_model = field_info["target_model"] if is_relationship else None

        if target_model:
            # Skip self-referential relationship to avoid infinite recursion.
            if target_model is model_cls:
                continue

            # The nested schema is built from the target model directly. The
            # annotated type is deliberately NOT run through
            # convert_sqlalchemy_type_to_pydantic: relationship annotations are
            # often forward references (e.g. Mapped["Parent"]), which it does
            # not recognise and would reject with TypeError.
            nested_schema = create_schema_from_model(
                target_model,
                include_relationships=False,  # Avoid circular references
                include_readonly_fields=False,
            )
            collection_origin = getattr(field_info["type"], "__origin__", None)
            if collection_origin in (list, set, frozenset):
                # Many relationship: expose any collection as a list.
                pydantic_type = list[nested_schema]
            else:
                # One relationship
                pydantic_type = nested_schema
            if is_optional:
                pydantic_type = pydantic_type | None
        else:
            pydantic_type = convert_sqlalchemy_type_to_pydantic(
                field_info["type"], is_optional
            )

        if include_readonly_fields and field_name in ("id", "created_at", "updated_at"):
            pydantic_type = ReadOnly[pydantic_type]

        # SQLAlchemy defaults are not copied because they're not
        # JSON-serializable: optional fields default to None, the rest are
        # required.
        field_definitions[field_name] = (
            pydantic_type,
            Field(default=None) if is_optional else ...,
        )

    schema_cls = pydantic.create_model(  # type: ignore[call-overload]
        schema_name,
        __doc__=f"Auto-generated schema for {model_cls.__name__}",
        __base__=tuple(bases),
        **field_definitions,
    )

    return schema_cls

294 

295 

296def convert_sqlalchemy_type_to_pydantic( 

297 sqlalchemy_type: Any, is_optional: bool = False 

298) -> Any: 

299 """ 

300 Convert a SQLAlchemy type to a Pydantic-compatible type. 

301 

302 Args: 

303 sqlalchemy_type: The SQLAlchemy type 

304 is_optional: Whether the field is optional 

305 

306 Returns: 

307 A Pydantic-compatible type 

308 """ 

309 type_name = getattr(sqlalchemy_type, "__name__", str(sqlalchemy_type)) 

310 

311 if sqlalchemy_type is Any: 

312 pydantic_type = Any 

313 elif sqlalchemy_type in ( 

314 str, 

315 int, 

316 float, 

317 bool, 

318 dict, 

319 list, 

320 datetime, 

321 date, 

322 time, 

323 UUID, 

324 Decimal, 

325 ): 

326 pydantic_type = sqlalchemy_type 

327 elif isinstance(sqlalchemy_type, type) and issubclass(sqlalchemy_type, enum.Enum): 

328 pydantic_type = sqlalchemy_type 

329 elif isinstance(sqlalchemy_type, type) and issubclass( 

330 sqlalchemy_type, DeclarativeBase 

331 ): 

332 # Relationship targets are replaced with nested schemas later. 

333 pydantic_type = sqlalchemy_type 

334 elif getattr(sqlalchemy_type, "__origin__", None) is not None: 

335 # Preserve parameterized container types like dict[str, Any] or list[int]. 

336 pydantic_type = sqlalchemy_type 

337 elif type_name in {"Text", "String"}: 

338 pydantic_type = str 

339 elif type_name in {"Integer"}: 

340 pydantic_type = int 

341 elif type_name in {"Float"}: 

342 pydantic_type = float 

343 elif type_name in {"Boolean"}: 

344 pydantic_type = bool 

345 elif type_name in {"DateTime"}: 

346 pydantic_type = datetime 

347 elif type_name in {"Date"}: 

348 pydantic_type = date 

349 elif type_name in {"Time"}: 

350 pydantic_type = time 

351 else: 

352 raise TypeError( 

353 f"Unsupported field type for auto-generated schema: {sqlalchemy_type!r}" 

354 ) 

355 

356 # Handle optional types 

357 if is_optional: 

358 pydantic_type = pydantic_type | None 

359 

360 return pydantic_type 

361 

362 

def auto_generate_schema_for_view(
    view_cls: type, model_cls: type[DeclarativeBase], schema_name: str | None = None
) -> type[BaseSchema]:
    """
    Build the default schema for a view that did not declare one.

    Delegates to ``create_schema_from_model`` with relationship fields left
    out.

    Args:
        view_cls: The view class (not currently used by this function)
        model_cls: The SQLAlchemy model class
        schema_name: Name for the generated schema; ``f"{Model}Read"`` when
            omitted

    Returns:
        A Pydantic schema class
    """
    resolved_name = (
        f"{model_cls.__name__}Read" if schema_name is None else schema_name
    )
    return create_schema_from_model(
        model_cls,
        schema_name=resolved_name,
        include_relationships=False,
    )