Files
openapi-first/openapi_first/codegen_routes.py
2026-06-16 04:14:14 +05:30

313 lines
10 KiB
Python

"""
Route handler code generation from OpenAPI specifications.
This module generates Python route handler stubs from an OpenAPI 3.x
specification. Each resource (derived from the first path segment)
gets its own file under the output directory. Every OpenAPI operation
must define an ``operationId``, which becomes the handler function name.
Notes:
**Design constraints:**
- ``operationId`` is required on every operation (matching
``binder.bind_routes``).
- Handlers are stubs raising ``NotImplementedError``.
- Sub-resources (e.g. ``/pets/{id}/photo``) are grouped with their
parent resource (``pets``).
- Parameter types and defaults are inferred from the spec.
- ``response: Response`` is injected when the primary success code is
  201 or 204.
"""
from pathlib import Path
from typing import Any
from .loader import load_openapi
# HTTP verbs that may appear as operation keys on an OpenAPI Path Item Object.
# Every other key (``parameters``, ``summary``, ``description``, ``servers``,
# ``$ref``, ``x-*`` extensions) is metadata, not an operation.
_HTTP_METHODS: frozenset[str] = frozenset(
    {"get", "put", "post", "delete", "options", "head", "patch", "trace"}
)


def _inherit_path_parameters(
    operation: dict[str, Any],
    shared: list[Any],
) -> dict[str, Any]:
    """Return *operation* extended with the path-level parameters it lacks.

    A parameter declared on the operation overrides a path-level parameter
    with the same ``(name, in)`` pair. The input mapping is never mutated.
    """
    if not shared:
        return operation
    own = operation.get("parameters") or []
    overridden = set()
    for param in own:
        if isinstance(param, dict) and "name" in param:
            overridden.add((param["name"], param.get("in")))
    inherited = [
        param
        for param in shared
        if not (
            isinstance(param, dict)
            and (param.get("name"), param.get("in")) in overridden
        )
    ]
    if not inherited:
        return operation
    return {**operation, "parameters": [*inherited, *own]}


def generate_routes(
    spec_path: Path,
    output_dir: Path,
    *,
    use_models: bool = False,
    models_module: str = "models",
) -> list[Path]:
    """
    Generate route handler stubs from an OpenAPI specification.

    Creates one ``<resource>.py`` file per resource in *output_dir*.
    Resources are derived from the first non-parameter path segment
    (e.g. ``/pets`` and ``/pets/{id}`` both group under ``pets``).
    Paths made only of parameters (or ``/``) are skipped.

    Only HTTP-verb keys of a path item are treated as operations;
    path-level ``parameters`` are merged into each of its operations.

    Args:
        spec_path:
            Path to the OpenAPI specification file (YAML or JSON).
        output_dir:
            Directory where the generated route files are written.
            Created automatically if it does not exist.
        use_models:
            If ``True``, import Pydantic models from *models_module*
            for request-body schemas referenced via ``$ref``.
        models_module:
            Dotted Python module path from which to import models
            (e.g. ``"models"``, ``"app.models"``).

    Returns:
        list[Path]:
            Absolute paths of every generated route file, ordered by
            resource name.

    Raises:
        OpenAPISpecLoadError:
            If the spec cannot be loaded or validated.
        ValueError:
            If any operation is missing ``operationId``.
    """
    spec = load_openapi(spec_path)
    output_dir = Path(output_dir).resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    # Group operations by resource (first non-param path segment).
    resources: dict[str, list[tuple[str, str, dict]]] = {}
    for path, path_item in (spec.get("paths") or {}).items():
        segments = [s for s in path.split("/") if s and not s.startswith("{")]
        if not segments or not isinstance(path_item, dict):
            continue
        bucket = resources.setdefault(segments[0], [])
        shared_params = path_item.get("parameters") or []
        for http_method, operation in path_item.items():
            # Skips ``parameters``, ``summary``, ``servers``, ``x-*`` and any
            # other non-operation field of the path item.
            if http_method.lower() not in _HTTP_METHODS:
                continue
            if isinstance(operation, dict):
                operation = _inherit_path_parameters(operation, shared_params)
            bucket.append((path, http_method, operation))

    generated_files: list[Path] = []
    for resource in sorted(resources):
        operations = resources[resource]
        _validate_operations(resource, operations)
        file_path = output_dir / f"{resource}.py"
        content = _generate_resource_file(
            resource=resource,
            operations=operations,
            spec_path=str(spec_path),
            use_models=use_models,
            models_module=models_module,
        )
        file_path.write_text(content, encoding="utf-8")
        generated_files.append(file_path)
    return generated_files
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
_TYPE_MAP: dict[str, str] = {
"integer": "int",
"number": "float",
"boolean": "bool",
"string": "str",
"array": "list",
"object": "dict",
}
def _validate_operations(resource: str, operations: list[tuple[str, str, dict]]) -> None:
"""Ensure every operation has an operationId."""
for path, http_method, operation in operations:
if not operation.get("operationId"):
raise ValueError(
f"Missing operationId for {http_method.upper()} {path} "
f"in resource '{resource}'. "
"All operations must have an explicit operationId."
)
def _resolve_type(schema: dict[str, Any]) -> str:
    """Map an OpenAPI schema type to a Python type annotation.

    Args:
        schema:
            Parameter schema object; only its ``type`` key is read.

    Returns:
        str:
            The matching entry of ``_TYPE_MAP``, or ``"str"`` when the type
            is missing or unknown.
    """
    openapi_type = schema.get("type", "")
    if isinstance(openapi_type, (list, tuple)):
        # OpenAPI 3.1 allows a list of types, e.g. ["integer", "null"].
        # A list is unhashable and would make the dict lookup raise, so
        # resolve against the first non-null member instead.
        openapi_type = next(
            (t for t in openapi_type if isinstance(t, str) and t != "null"),
            "",
        )
    if not isinstance(openapi_type, str):
        return "str"
    return _TYPE_MAP.get(openapi_type, "str")
def _get_request_body_schema(operation: dict[str, Any]) -> str | None:
"""Extract the ``$ref`` schema name from a request body, if any."""
request_body = operation.get("requestBody")
if not request_body:
return None
content = request_body.get("content", {})
for media_info in content.values():
schema = media_info.get("schema", {})
ref = schema.get("$ref", "")
if ref:
return ref.rsplit("/", 1)[-1]
return None
def _get_success_status(operation: dict[str, Any]) -> str:
"""Return the primary success status code (200, 201, 204, or default)."""
responses = operation.get("responses", {})
for code in ("201", "200", "204"):
if code in responses:
return code
if "default" in responses:
return "default"
return "200"
def _needs_any(operations: list[tuple[str, str, dict]]) -> bool:
    """Check if any operation uses a type that requires `from typing import Any`.

    Args:
        operations:
            ``(path, http_method, operation)`` triples of one resource.

    Returns:
        bool:
            ``True`` as soon as one parameter schema has a missing or
            unrecognised ``type`` (one that is not a key of ``_TYPE_MAP``).
    """
    for _, _, op in operations:
        for param in op.get("parameters") or []:
            declared = param.get("schema", {}).get("type", "")
            if isinstance(declared, (list, tuple)):
                # OpenAPI 3.1 type list, e.g. ["string", "null"]: judge the
                # non-null members; an all-null/empty list counts as unknown.
                candidates = [t for t in declared if t != "null"] or [""]
            else:
                candidates = [declared]
            # The isinstance guard keeps unhashable values away from the
            # dict membership test, which would otherwise raise TypeError.
            if any(
                not isinstance(t, str) or t not in _TYPE_MAP for t in candidates
            ):
                return True
    return False
# ---------------------------------------------------------------------------
# File / function generation
# ---------------------------------------------------------------------------
def _generate_resource_file(
    resource: str,
    operations: list[tuple[str, str, dict]],
    spec_path: str,
    use_models: bool,
    models_module: str,
) -> str:
    """Assemble the full Python source for a single resource file.

    Args:
        resource:
            Resource name, used in the generated module docstring.
        operations:
            ``(path, http_method, operation)`` triples; one handler is
            emitted per triple, in the given order.
        spec_path:
            Spec location recorded in the generated module docstring.
        use_models:
            If ``True``, import the request-body models referenced via
            ``$ref`` and use them as payload annotations.
        models_module:
            Module the models are imported from.

    Returns:
        str:
            The generated source, terminated by a newline.
    """
    op_ids = [op.get("operationId", "") for _, _, op in operations]
    # The path lands inside a regular (non-raw) docstring, where the
    # backslashes of a Windows path would be read as escape sequences
    # (``\U...`` is even a SyntaxError). Double them so the docstring
    # still displays the original path.
    escaped_spec_path = spec_path.replace("\\", "\\\\")
    lines: list[str] = [
        f'"""Route handlers for the {resource} resource.',
        "",
        f"Generated from OpenAPI spec: {escaped_spec_path}",
        f'Bind via operationIds: {", ".join(op_ids)}',
        '"""',
        "",
        "from fastapi import Response, HTTPException",
    ]

    # Conditional typing import
    if _needs_any(operations):
        lines.append("from typing import Any")

    # Conditional model imports
    if use_models:
        schemas_needed = {
            name
            for _, _, op in operations
            if (name := _get_request_body_schema(op))
        }
        if schemas_needed:
            lines.append(
                f"from {models_module} import {', '.join(sorted(schemas_needed))}"
            )

    lines.append("")
    # A blank line before and after each handler yields the two blank
    # lines PEP 8 expects between top-level definitions.
    for path, http_method, operation in operations:
        lines.append("")
        lines.extend(_generate_handler(path, http_method, operation, use_models))
        lines.append("")
    return "\n".join(lines)
def _generate_handler(
    path: str,
    http_method: str,
    operation: dict[str, Any],
    use_models: bool,
) -> list[str]:
    """Build the source lines for a single handler function.

    The signature is ordered so that it is always valid Python: path
    parameters, required query parameters, ``payload``, ``response``,
    and finally the query parameters that carry a default value.

    Args:
        path:
            URL template of the operation; used for the fallback summary.
        http_method:
            HTTP verb of the operation; used for the fallback summary.
        operation:
            OpenAPI operation object.
        use_models:
            If ``True`` and the request body references a schema via
            ``$ref``, annotate ``payload`` with that schema name instead
            of ``dict``.

    Returns:
        list[str]:
            Source lines of a stub that raises ``NotImplementedError``.
    """
    operation_id = operation.get("operationId", "")
    summary = operation.get("summary", f"{http_method.upper()} {path}")

    path_params: list[str] = []
    required_query: list[str] = []
    defaulted_query: list[str] = []
    doc_params: list[tuple[str, str, str]] = []

    for param in operation.get("parameters", []):
        param_in = param.get("in", "")
        if param_in not in ("path", "query"):
            # header / cookie params could be extended here
            continue
        name = param.get("name", "")
        schema = param.get("schema", {})
        param_type = _resolve_type(schema)
        description = param.get("description", schema.get("x-description", ""))
        doc_params.append((name, param_type, description))

        if param_in == "path":
            path_params.append(f"{name}: {param_type}")
            continue
        default_raw = schema.get("default")
        if default_raw is not None:
            defaulted_query.append(f"{name}: {param_type} = {default_raw!r}")
        elif param.get("required", False):
            required_query.append(f"{name}: {param_type}")
        else:
            defaulted_query.append(f"{name}: {param_type} = None")

    # Request body
    payload_type: str | None = None
    if operation.get("requestBody"):
        schema_name = _get_request_body_schema(operation)
        payload_type = schema_name if use_models and schema_name else "dict"

    # Parameters without a default must precede those with one, otherwise
    # the generated ``def`` is a SyntaxError.
    params: list[str] = [*path_params, *required_query]
    if payload_type:
        params.append(f"payload: {payload_type}")
    # Inject Response for non-200 success codes
    if _get_success_status(operation) in ("201", "204"):
        params.append("response: Response")
    params.extend(defaulted_query)

    # Build function source
    lines: list[str] = [
        f"def {operation_id}({', '.join(params)}):",
        f'    """{summary}',
    ]
    if doc_params:
        lines += ["", "    Parameters", "    ----------"]
        for pname, ptype, desc in doc_params:
            lines.append(f"    {pname} : {ptype}")
            if desc:
                lines.append(f"        {desc}")
    if payload_type:
        lines += ["", f"    payload : {payload_type}", "        Request body."]
    lines += ['    """', "    raise NotImplementedError"]
    return lines