diff --git a/openapi_first/templates/crud_app/main.py b/openapi_first/templates/crud_app/main.py index f07cc9c..497989c 100644 --- a/openapi_first/templates/crud_app/main.py +++ b/openapi_first/templates/crud_app/main.py @@ -26,6 +26,7 @@ Example: """ from openapi_first.app import OpenAPIFirstApp +from starlette_csrf import CSRFMiddleware import routes app = OpenAPIFirstApp( @@ -33,3 +34,13 @@ app = OpenAPIFirstApp( routes_module=routes, title="CRUD Example Service", ) + +app.add_middleware( + CSRFMiddleware, + secret="change-me-in-production", + cookie_name="csrftoken", + header_name="x-csrftoken", + cookie_secure=False, + cookie_httponly=False, + cookie_samesite="lax", +) diff --git a/openapi_first/templates/crud_app/test_crud_app.py b/openapi_first/templates/crud_app/test_crud_app.py index 32439ab..4c38a23 100644 --- a/openapi_first/templates/crud_app/test_crud_app.py +++ b/openapi_first/templates/crud_app/test_crud_app.py @@ -34,6 +34,14 @@ client = OpenAPIClient( client=client, ) +# Bootstrap CSRF token via a safe GET request +_ = client.list_items() +_CSRF_TOKEN = client.client.cookies.get("csrftoken") + + +def _csrf_headers() -> dict: + return {"X-CSRFToken": _CSRF_TOKEN} if _CSRF_TOKEN else {} + def test_list_items_initial(): """Initial items should be present.""" @@ -70,7 +78,8 @@ def test_create_item(): } response = client.create_item( - body=payload + body=payload, + headers=_csrf_headers(), ) assert response.status_code == 201 @@ -95,6 +104,7 @@ def test_update_item(): response = client.update_item( path_params={"item_id": 1}, body=payload, + headers=_csrf_headers(), ) assert response.status_code == 200 @@ -115,7 +125,8 @@ def test_update_item(): def test_delete_item(): """Deleting an item should remove it from the store.""" response = client.delete_item( - path_params={"item_id": 2} + path_params={"item_id": 2}, + headers=_csrf_headers(), ) assert response.status_code == 204 diff --git a/openapi_first/templates/health_app/main.py b/openapi_first/templates/health_app/main.py index af92485..66a8e83 100644 --- a/openapi_first/templates/health_app/main.py +++ b/openapi_first/templates/health_app/main.py @@ -27,6 +27,7 @@ Example: from openapi_first.app import OpenAPIFirstApp +from starlette_csrf import CSRFMiddleware import routes app = OpenAPIFirstApp( @@ -34,3 +35,13 @@ app = OpenAPIFirstApp( routes_module=routes, title="Health Check Service", ) + +app.add_middleware( + CSRFMiddleware, + secret="change-me-in-production", + cookie_name="csrftoken", + header_name="x-csrftoken", + cookie_secure=False, + cookie_httponly=False, + cookie_samesite="lax", +) diff --git a/openapi_first/templates/model_app/main.py b/openapi_first/templates/model_app/main.py index 2cf86ca..d0468da 100644 --- a/openapi_first/templates/model_app/main.py +++ b/openapi_first/templates/model_app/main.py @@ -26,6 +26,7 @@ Example: """ from openapi_first.app import OpenAPIFirstApp +from starlette_csrf import CSRFMiddleware import routes app = OpenAPIFirstApp( @@ -33,3 +34,13 @@ app = OpenAPIFirstApp( routes_module=routes, title="Model CRUD Example Service", ) + +app.add_middleware( + CSRFMiddleware, + secret="change-me-in-production", + cookie_name="csrftoken", + header_name="x-csrftoken", + cookie_secure=False, + cookie_httponly=False, + cookie_samesite="lax", +) diff --git a/openapi_first/templates/model_app/test_model_app.py b/openapi_first/templates/model_app/test_model_app.py index bc4fa6d..b5299b9 100644 --- a/openapi_first/templates/model_app/test_model_app.py +++ b/openapi_first/templates/model_app/test_model_app.py @@ -33,6 +33,14 @@ client = OpenAPIClient( client=client, ) +# Bootstrap CSRF token via a safe GET request +_ = client.list_items() +_CSRF_TOKEN = client.client.cookies.get("csrftoken") + + +def _csrf_headers() -> dict: + return {"X-CSRFToken": _CSRF_TOKEN} if _CSRF_TOKEN else {} + def test_list_items_initial(): """Initial items should be present.""" @@ -69,7 +77,8 @@ def test_create_item(): } response = client.create_item( - body=payload + body=payload, + headers=_csrf_headers(), ) assert response.status_code == 201 @@ -94,6 +103,7 @@ def test_update_item(): response = client.update_item( path_params={"item_id": 1}, body=payload, + headers=_csrf_headers(), ) assert response.status_code == 200 @@ -114,7 +124,8 @@ def test_update_item(): def test_delete_item(): """Deleting an item should remove it from the store.""" response = client.delete_item( - path_params={"item_id": 2} + path_params={"item_id": 2}, + headers=_csrf_headers(), ) assert response.status_code == 204 diff --git a/openapi_first/templates/vet_app/data.py b/openapi_first/templates/vet_app/data.py index 97ca9b3..5fc25b1 100644 --- a/openapi_first/templates/vet_app/data.py +++ b/openapi_first/templates/vet_app/data.py @@ -52,7 +52,7 @@ def create_parent(payload: ParentCreate) -> Parent: now = _now() parent = Parent( id=_parents_next_id, - **payload.model_dump(exclude={"id"}), + **payload.model_dump(exclude={"id", "metadata"}), metadata={"createdOn": now, "updatedOn": now} if payload.metadata else None, ) _parents[_parents_next_id] = parent @@ -101,7 +101,7 @@ def create_vet(payload: VetCreate) -> Vet: now = _now() vet = Vet( id=_vets_next_id, - **payload.model_dump(exclude={"id"}), + **payload.model_dump(exclude={"id", "metadata"}), metadata={"createdOn": now, "updatedOn": now} if payload.metadata else None, ) _vets[_vets_next_id] = vet @@ -150,7 +150,7 @@ def create_treatment(payload: TreatmentCreate) -> Treatment: now = _now() treatment = Treatment( id=_treatments_next_id, - **payload.model_dump(exclude={"id"}), + **payload.model_dump(exclude={"id", "metadata"}), metadata={"createdOn": now, "updatedOn": now} if payload.metadata else None, ) _treatments[_treatments_next_id] = treatment @@ -199,7 +199,7 @@ def create_pet(payload: PetCreate) -> Pet: now = _now() pet = Pet( id=_pets_next_id, - **payload.model_dump(exclude={"id"}), + **payload.model_dump(exclude={"id", "metadata"}), metadata={"createdOn": now, "updatedOn": now} if payload.metadata else None, ) _pets[_pets_next_id] = pet @@ -248,7 +248,7 @@ def create_appointment(payload: AppointmentCreate) -> Appointment: now = _now() appointment = Appointment( id=_appointments_next_id, - **payload.model_dump(exclude={"id"}), + **payload.model_dump(exclude={"id", "metadata"}), metadata={"createdOn": now, "updatedOn": now} if payload.metadata else None, ) _appointments[_appointments_next_id] = appointment diff --git a/openapi_first/templates/vet_app/main.py b/openapi_first/templates/vet_app/main.py index 9f1515c..4c6a7d7 100644 --- a/openapi_first/templates/vet_app/main.py +++ b/openapi_first/templates/vet_app/main.py @@ -25,7 +25,10 @@ Example: uvicorn main:app """ +from starlette.middleware.cors import CORSMiddleware + from openapi_first.app import OpenAPIFirstApp +from starlette_csrf import CSRFMiddleware import routes app = OpenAPIFirstApp( @@ -33,3 +36,21 @@ app = OpenAPIFirstApp( routes_module=routes, title="Veterinary Clinic Service", ) + +app.add_middleware( + CSRFMiddleware, + secret="change-me-in-production", + cookie_name="csrftoken", + header_name="x-csrftoken", + cookie_secure=False, + cookie_httponly=False, + cookie_samesite="lax", +) + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) diff --git a/openapi_first/templates/vet_app/test_vet_app.py b/openapi_first/templates/vet_app/test_vet_app.py index e7392ea..77c4dba 100644 --- a/openapi_first/templates/vet_app/test_vet_app.py +++ b/openapi_first/templates/vet_app/test_vet_app.py @@ -20,10 +20,18 @@ client = OpenAPIClient( client=test_client, ) +# Bootstrap CSRF token via a safe GET request +_ = client.list_parents(query={"limit": 1}) +_CSRF_TOKEN = client.client.cookies.get("csrftoken") + + +def _csrf_headers() -> dict: + return {"X-CSRFToken": _CSRF_TOKEN} if _CSRF_TOKEN else {} + def test_list_parents(): """List parents returns paginated response.""" - response = client.list_parents(query_params={"limit": 10, "offset": 0}) + response = client.list_parents(query={"limit": 10, "offset": 0}) assert response.status_code == 200 data = response.json() assert "total" in data @@ -33,7 +41,7 @@ def test_list_parents(): def test_create_parent(): """Creating a parent returns 201 with the created entity.""" payload = {"name": "Alice", "email": "alice@example.com"} - response = client.create_parent(body=payload) + response = client.create_parent(body=payload, headers=_csrf_headers()) assert response.status_code == 201 parent = response.json() assert parent["name"] == "Alice" @@ -42,7 +50,7 @@ def test_create_parent(): def test_get_parent(): """Get parent by ID returns the entity.""" - parent = client.create_parent(body={"name": "Bob", "email": "bob@example.com"}).json() + parent = client.create_parent(body={"name": "Bob", "email": "bob@example.com"}, headers=_csrf_headers()).json() response = client.get_parent(path_params={"id": parent["id"]}) assert response.status_code == 200 assert response.json()["name"] == "Bob" @@ -50,23 +58,23 @@ def test_get_parent(): def test_update_parent(): """Update parent replaces its values.""" - parent = client.create_parent(body={"name": "Carol", "email": "carol@example.com"}).json() + parent = client.create_parent(body={"name": "Carol", "email": "carol@example.com"}, headers=_csrf_headers()).json() payload = {"name": "Carol Smith", "email": "carol.smith@example.com"} - response = client.update_parent(path_params={"id": parent["id"]}, body=payload) + response = client.update_parent(path_params={"id": parent["id"]}, body=payload, headers=_csrf_headers()) assert response.status_code == 200 assert response.json()["name"] == "Carol Smith" def test_delete_parent(): """Delete parent returns 204 and removes the entity.""" - parent = client.create_parent(body={"name": "Dave", "email": "dave@example.com"}).json() - response = client.delete_parent(path_params={"id": parent["id"]}) + parent = client.create_parent(body={"name": "Dave", "email": "dave@example.com"}, headers=_csrf_headers()).json() + response = client.delete_parent(path_params={"id": parent["id"]}, headers=_csrf_headers()) assert response.status_code == 204 def test_list_vets(): """List vets returns paginated response.""" - response = client.list_vets(query_params={"limit": 10, "offset": 0}) + response = client.list_vets(query={"limit": 10, "offset": 0}) assert response.status_code == 200 data = response.json() assert "total" in data @@ -76,7 +84,7 @@ def test_list_vets(): def test_create_vet(): """Creating a vet returns 201.""" payload = {"name": "Dr. Smith", "specialty": "Surgery", "email": "smith@clinic.com"} - response = client.create_vet(body=payload) + response = client.create_vet(body=payload, headers=_csrf_headers()) assert response.status_code == 201 assert response.json()["name"] == "Dr. Smith" @@ -91,33 +99,34 @@ def test_list_treatments(): def test_create_treatment(): """Creating a treatment returns 201.""" payload = {"label": "Vaccination", "description": "Annual vaccination"} - response = client.create_treatment(body=payload) + response = client.create_treatment(body=payload, headers=_csrf_headers()) assert response.status_code == 201 assert response.json()["label"] == "Vaccination" def test_create_pet(): """Creating a pet links FK references.""" - parent = client.create_parent(body={"name": "Owner", "email": "owner@example.com"}).json() + parent = client.create_parent(body={"name": "Owner", "email": "owner@example.com"}, headers=_csrf_headers()).json() payload = {"name": "Fido", "species": "dog", "parent_ids": [parent["id"]]} - response = client.create_pet(body=payload) + response = client.create_pet(body=payload, headers=_csrf_headers()) assert response.status_code == 201 assert response.json()["name"] == "Fido" def test_upload_pet_photo(): """Upload pet photo returns 200.""" - pet = client.create_pet(body={"name": "PhotoPet", "species": "cat", "parent_ids": []}).json() - response = client.upload_pet_photo( - path_params={"id": pet["id"]}, - body={}, + pet = client.create_pet(body={"name": "PhotoPet", "species": "cat", "parent_ids": []}, headers=_csrf_headers()).json() + response = client.client.post( + f"http://testserver/pets/{pet['id']}", + files={"file": ("test.jpg", b"fake-image-data", "image/jpeg")}, + headers=_csrf_headers(), ) assert response.status_code == 200 def test_list_appointments(): """List appointments returns paginated response with filter params.""" - response = client.list_appointments(query_params={"limit": 10, "offset": 0}) + response = client.list_appointments(query={"limit": 10, "offset": 0}) assert response.status_code == 200 data = response.json() assert "total" in data @@ -126,10 +135,10 @@ def test_list_appointments(): def test_full_appointment_lifecycle(): """Create a parent, vet, treatment, pet, then an appointment.""" - parent = client.create_parent(body={"name": "Eve", "email": "eve@example.com"}).json() - vet = client.create_vet(body={"name": "Dr. Jones", "specialty": "Dentistry", "email": "jones@clinic.com"}).json() - treatment = client.create_treatment(body={"label": "Cleaning", "description": "Teeth cleaning"}).json() - pet = client.create_pet(body={"name": "Max", "species": "dog", "parent_ids": [parent["id"]]}).json() + parent = client.create_parent(body={"name": "Eve", "email": "eve@example.com"}, headers=_csrf_headers()).json() + vet = client.create_vet(body={"name": "Dr. Jones", "specialty": "Dentistry", "email": "jones@clinic.com"}, headers=_csrf_headers()).json() + treatment = client.create_treatment(body={"label": "Cleaning", "description": "Teeth cleaning"}, headers=_csrf_headers()).json() + pet = client.create_pet(body={"name": "Max", "species": "dog", "parent_ids": [parent["id"]]}, headers=_csrf_headers()).json() payload = { "date": "2025-06-01T10:00:00", @@ -137,7 +146,7 @@ def test_full_appointment_lifecycle(): "vet_id": vet["id"], "treatment_id": treatment["id"], } - response = client.create_appointment(body=payload) + response = client.create_appointment(body=payload, headers=_csrf_headers()) assert response.status_code == 201 appointment = response.json() assert appointment["pet_id"] == pet["id"] @@ -147,5 +156,5 @@ def test_full_appointment_lifecycle(): assert get_resp.status_code == 200 # Delete it - del_resp = client.delete_appointment(path_params={"id": appointment["id"]}) + del_resp = client.delete_appointment(path_params={"id": appointment["id"]}, headers=_csrf_headers()) assert del_resp.status_code == 204 diff --git a/pyproject.toml b/pyproject.toml index 8c8b9cc..fbcb5d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,9 @@ dependencies = [ # Code generation "datamodel-code-generator>=0.25.0", + + # CSRF protection (scaffolded apps) + "starlette-csrf>=3.0.0", ] [project.scripts]