|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import base64 |
3 | 4 | import dataclasses |
4 | 5 | import json |
5 | 6 | import logging |
| 7 | +import warnings |
6 | 8 | from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence, Union, cast |
7 | 9 | from urllib.parse import parse_qs |
8 | 10 |
|
|
25 | 27 | RequestValidationError, |
26 | 28 | ResponseValidationError, |
27 | 29 | ) |
28 | | -from aws_lambda_powertools.event_handler.openapi.params import Param |
| 30 | +from aws_lambda_powertools.event_handler.openapi.params import Param, UploadFile |
29 | 31 | from aws_lambda_powertools.event_handler.openapi.types import UnionType |
30 | 32 |
|
31 | 33 | if TYPE_CHECKING: |
|
44 | 46 | CONTENT_DISPOSITION_NAME_PARAM = "name=" |
45 | 47 | APPLICATION_JSON_CONTENT_TYPE = "application/json" |
46 | 48 | APPLICATION_FORM_CONTENT_TYPE = "application/x-www-form-urlencoded" |
| 49 | +MULTIPART_FORM_DATA_CONTENT_TYPE = "multipart/form-data" |
47 | 50 |
|
48 | 51 |
|
49 | 52 | class OpenAPIRequestValidationMiddleware(BaseMiddlewareHandler): |
@@ -141,14 +144,18 @@ def _get_body(self, app: EventHandlerInstance) -> dict[str, Any]: |
141 | 144 | elif content_type.startswith(APPLICATION_FORM_CONTENT_TYPE): |
142 | 145 | return self._parse_form_data(app) |
143 | 146 |
|
| 147 | + # Handle multipart/form-data (file uploads) |
| 148 | + elif content_type.startswith(MULTIPART_FORM_DATA_CONTENT_TYPE): |
| 149 | + return self._parse_multipart_data(app, content_type) |
| 150 | + |
144 | 151 | else: |
145 | 152 | raise RequestUnsupportedContentType( |
146 | | - "Only JSON body or Form() are supported", |
| 153 | + "Unsupported content type", |
147 | 154 | errors=[ |
148 | 155 | { |
149 | 156 | "type": "unsupported_content_type", |
150 | 157 | "loc": ("body",), |
151 | | - "msg": "Only JSON body or Form() are supported", |
| 158 | + "msg": f"Unsupported content type: {content_type}", |
152 | 159 | "input": {}, |
153 | 160 | "ctx": {}, |
154 | 161 | }, |
@@ -195,6 +202,49 @@ def _parse_form_data(self, app: EventHandlerInstance) -> dict[str, Any]: |
195 | 202 | ], |
196 | 203 | ) from e |
197 | 204 |
|
def _parse_multipart_data(self, app: EventHandlerInstance, content_type: str) -> dict[str, Any]:
    """
    Parse a multipart/form-data request body (file uploads) into a field mapping.

    Parameters
    ----------
    app : EventHandlerInstance
        Resolver whose ``current_event`` supplies ``body`` and ``is_base64_encoded``.
    content_type : str
        Full Content-Type header value; it must carry the ``boundary`` parameter.

    Returns
    -------
    dict[str, Any]
        The mapping produced by ``_parse_multipart_body`` for the decoded body.

    Raises
    ------
    ValueError
        If the content-type header has no (or an empty) boundary parameter.
    RequestValidationError
        If the body cannot be decoded or parsed.
    """
    # Checked outside the try block so that only this explicit error keeps
    # propagating as a plain ValueError, preserving the existing contract.
    boundary = _extract_multipart_boundary(content_type)
    if not boundary:
        raise ValueError("Missing boundary in multipart/form-data content-type header")

    try:
        raw_body = app.current_event.body or ""
        if app.current_event.is_base64_encoded:
            body_bytes = base64.b64decode(raw_body)
        else:
            warnings.warn(
                "Received multipart/form-data without base64 encoding. "
                "Binary file uploads may be corrupted. "
                "If using API Gateway REST API (v1), configure Binary Media Types "
                "to include 'multipart/form-data'. "
                "See: https://docs.aws.amazon.com/apigateway/latest/developerguide/"
                "api-gateway-payload-encodings.html",
                stacklevel=2,
            )
            # Use latin-1 to preserve all byte values (0-255) since the body
            # may contain raw binary data that isn't valid UTF-8
            body_bytes = raw_body.encode("latin-1")

        return _parse_multipart_body(body_bytes, boundary)

    except Exception as e:
        # binascii.Error (bad base64), UnicodeDecodeError (non-UTF-8 part) and
        # UnicodeEncodeError (latin-1 encode) are all ValueError subclasses.
        # Re-raising ValueError here would let malformed client input escape
        # as a bare exception, so every failure in this block is reported as
        # a validation error instead.
        raise RequestValidationError(
            [
                {
                    "type": "multipart_invalid",
                    "loc": ("body",),
                    "msg": "Multipart form data parsing error",
                    "input": {},
                    "ctx": {"error": str(e)},
                },
            ],
        ) from e
| 247 | + |
198 | 248 |
|
199 | 249 | class OpenAPIResponseValidationMiddleware(BaseMiddlewareHandler): |
200 | 250 | """ |
@@ -398,7 +448,12 @@ def _request_body_to_args( |
398 | 448 | continue |
399 | 449 |
|
400 | 450 | value = _normalize_field_value(value=value, field_info=field.field_info) |
401 | | - values[field.name] = _validate_field(field=field, value=value, loc=loc, existing_errors=errors) |
| 451 | + |
| 452 | + # UploadFile objects bypass Pydantic validation — they're already constructed |
| 453 | + if isinstance(value, UploadFile): |
| 454 | + values[field.name] = value |
| 455 | + else: |
| 456 | + values[field.name] = _validate_field(field=field, value=value, loc=loc, existing_errors=errors) |
402 | 457 |
|
403 | 458 | return values, errors |
404 | 459 |
|
@@ -474,6 +529,10 @@ def _is_or_contains_sequence(annotation: Any) -> bool: |
474 | 529 |
|
475 | 530 | def _normalize_field_value(value: Any, field_info: FieldInfo) -> Any: |
476 | 531 | """Normalize field value, converting lists to single values for non-sequence fields.""" |
| 532 | + # When annotation is bytes but value is UploadFile, extract raw content |
| 533 | + if isinstance(value, UploadFile) and field_info.annotation is bytes: |
| 534 | + return value.content |
| 535 | + |
477 | 536 | if _is_or_contains_sequence(field_info.annotation): |
478 | 537 | return value |
479 | 538 | elif isinstance(value, list) and value: |
@@ -587,3 +646,106 @@ def _get_param_value( |
587 | 646 | value = input_dict.get(field_name) |
588 | 647 |
|
589 | 648 | return value |
| 649 | + |
| 650 | + |
| 651 | +def _extract_multipart_boundary(content_type: str) -> str | None: |
| 652 | + """Extract the boundary string from a multipart/form-data content-type header.""" |
| 653 | + for segment in content_type.split(";"): |
| 654 | + stripped = segment.strip() |
| 655 | + if stripped.startswith("boundary="): |
| 656 | + boundary = stripped[len("boundary=") :] |
| 657 | + # Remove optional quotes around boundary |
| 658 | + if boundary.startswith('"') and boundary.endswith('"'): |
| 659 | + boundary = boundary[1:-1] |
| 660 | + return boundary |
| 661 | + return None |
| 662 | + |
| 663 | + |
def _parse_multipart_body(body: bytes, boundary: str) -> dict[str, Any]:
    """
    Parse a multipart/form-data body into a dict of field names to values.

    Parts whose Content-Disposition carries a ``filename`` parameter are wrapped
    in ``UploadFile``; all other parts are decoded as UTF-8 strings. Multiple
    values for the same field name are collected into a list, in order of
    appearance. Parts without a header/body separator or without a ``name``
    parameter are ignored, as is everything after the closing delimiter.

    Parameters
    ----------
    body : bytes
        Raw request body.
    boundary : str
        Boundary value from the Content-Type header, without the leading dashes.

    Returns
    -------
    dict[str, Any]
        Field name mapped to a ``str``, an ``UploadFile``, or a list of those.

    Raises
    ------
    UnicodeDecodeError
        If part headers or a non-file field value are not valid UTF-8.
    """
    # A delimiter line is CRLF + "--" + boundary. Anchoring on the CRLF keeps
    # "--<boundary>" text that appears mid-line inside a file from being
    # mistaken for a delimiter and truncating the upload.
    delimiter = b"\r\n--" + boundary.encode()

    # The first delimiter has no preceding CRLF when there is no preamble, so
    # one is prepended to let every delimiter match the same pattern.
    segments = (b"\r\n" + body).split(delimiter)
    last_index = len(segments) - 1

    result: dict[str, Any] = {}

    # segments[0] is the preamble (text before the first delimiter) and is skipped
    for index, segment in enumerate(segments[1:], start=1):
        # "--" right after the boundary marks the closing delimiter; whatever
        # follows is the epilogue and must not be parsed as a part
        if segment.startswith(b"--"):
            break

        chunk = segment

        # Strip the CRLF that terminates the delimiter line
        if chunk.startswith(b"\r\n"):
            chunk = chunk[2:]

        # The CRLF before the next delimiter was consumed by the split. Only a
        # final part with no closing delimiter can still carry a trailing CRLF.
        if index == last_index and chunk.endswith(b"\r\n"):
            chunk = chunk[:-2]

        # Split headers from body at the double CRLF
        header_end = chunk.find(b"\r\n\r\n")
        if header_end == -1:
            continue

        header_section = chunk[:header_end].decode("utf-8")
        body_section = chunk[header_end + 4 :]

        # Parse Content-Disposition to get the field name and optional filename
        field_name = None
        filename = None
        content_type_header = None

        for header_line in header_section.split("\r\n"):
            header_lower = header_line.lower()
            if header_lower.startswith("content-disposition:"):
                field_name = _extract_header_param(header_line, "name")
                filename = _extract_header_param(header_line, "filename")
            elif header_lower.startswith("content-type:"):
                content_type_header = header_line.split(":", 1)[1].strip()

        if field_name is None:
            continue

        # If it has a filename, it's a file upload — wrap as UploadFile
        # Otherwise it's a regular form field — decode to string
        if filename is not None:
            value: Any = UploadFile(content=body_section, filename=filename, content_type=content_type_header)
        else:
            value = body_section.decode("utf-8")

        # Collect multiple values for same field name into a list
        if field_name in result:
            existing = result[field_name]
            if isinstance(existing, list):
                existing.append(value)
            else:
                result[field_name] = [existing, value]
        else:
            result[field_name] = value

    return result
| 739 | + |
| 740 | + |
| 741 | +def _extract_header_param(header_line: str, param_name: str) -> str | None: |
| 742 | + """Extract a parameter value from a header line (e.g., name="file" from Content-Disposition).""" |
| 743 | + search = f'{param_name}="' |
| 744 | + idx = header_line.find(search) |
| 745 | + if idx == -1: |
| 746 | + return None |
| 747 | + start = idx + len(search) |
| 748 | + end = header_line.find('"', start) |
| 749 | + if end == -1: |
| 750 | + return None |
| 751 | + return header_line[start:end] |
0 commit comments