|
| 1 | +import uuid |
| 2 | +from typing import TypeVar, Generic, List, Any, Callable, get_args |
| 3 | +from pydantic import Field |
| 4 | +from pydantic_core import core_schema |
| 5 | +from ninja.filter_schema import FilterSchema as NinjaFilterSchema # Use alias to avoid name clash |
| 6 | + |
| 7 | +# Define a TypeVar for the list items |
| 8 | +T = TypeVar('T') |
| 9 | + |
| 10 | +class QueryStringList(Generic[T]): |
| 11 | + """ |
| 12 | + A Pydantic type that parses a query string parameter into a list of a specified type. |
| 13 | + It handles both comma-separated values in a single parameter |
| 14 | + (e.g., "?tags=movie,tv") and repeated parameters (e.g., "?id=1&id=2"). |
| 15 | + """ |
| 16 | + |
| 17 | + # Store the target item type (e.g., str, int, uuid.UUID) |
| 18 | + item_type: type |
| 19 | + |
| 20 | + @classmethod |
| 21 | + def _parse_value(cls, value: Any) -> List[T]: |
| 22 | + """ |
| 23 | + Parses the input value (string or list) into a list of the target item type. |
| 24 | + """ |
| 25 | + items_to_validate: List[str] = [] |
| 26 | + |
| 27 | + if isinstance(value, list): |
| 28 | + # Input is already a list (from repeated query parameters) |
| 29 | + items_to_validate = [str(item) for item in value] # Ensure items are strings before validation |
| 30 | + elif isinstance(value, str): |
| 31 | + # Input is a string (comma-separated or single value) |
| 32 | + items_to_validate = [item.strip() for item in value.split(',') if item.strip()] |
| 33 | + else: |
| 34 | + # Let Pydantic's validation handle other incorrect types later if necessary, |
| 35 | + # but ideally, web frameworks pass strings or lists for query params. |
| 36 | + # We could raise a ValueError here for stricter input checking. |
| 37 | + raise ValueError(f"Expected list or comma-separated str, got {type(value).__name__}") |
| 38 | + |
| 39 | + |
| 40 | + validated_items: List[T] = [] |
| 41 | + for item_str in items_to_validate: |
| 42 | + try: |
| 43 | + # Attempt to convert/validate the item using the target item_type's constructor |
| 44 | + # This works well for types like int, float, str, uuid.UUID |
| 45 | + validated_item = cls.item_type(item_str) |
| 46 | + validated_items.append(validated_item) |
| 47 | + except (ValueError, TypeError) as e: |
| 48 | + # If conversion fails, raise a ValueError that Pydantic can catch |
| 49 | + raise ValueError(f"Invalid item '{item_str}' for type {cls.item_type.__name__}: {e}") from e |
| 50 | + |
| 51 | + return validated_items |
| 52 | + |
| 53 | + @classmethod |
| 54 | + def __get_pydantic_core_schema__( |
| 55 | + cls, source: type[Any], handler: Callable[[Any], core_schema.CoreSchema] |
| 56 | + ) -> core_schema.CoreSchema: |
| 57 | + """ |
| 58 | + Hook into Pydantic's schema generation process. |
| 59 | + """ |
| 60 | + # 1. Extract the target item type T from QueryStringList[T] |
| 61 | + args = get_args(source) |
| 62 | + if not args or len(args) != 1: |
| 63 | + raise TypeError(f"{cls.__name__} requires exactly one type argument, e.g., {cls.__name__}[str]") |
| 64 | + |
| 65 | + # Store the item type on the class for the parser function to use. |
| 66 | + # Note: This approach assumes the class object is stable during validation. |
| 67 | + # Pydantic generally creates validators per schema field instance. |
| 68 | + cls.item_type = args[0] |
| 69 | + |
| 70 | + # 2. Define the core schema: |
| 71 | + # It should accept either a string or a list as input. |
| 72 | + # Then, it uses our custom parsing function `_parse_value`. |
| 73 | + |
| 74 | + # Schema that accepts raw input (list or string) |
| 75 | + # input_schema = core_schema.union_schema( |
| 76 | + # [ |
| 77 | + # core_schema.list_schema(core_schema.any_schema()), # Accepts lists (repeated params) |
| 78 | + # core_schema.str_schema() # Accepts strings (comma-separated or single) |
| 79 | + # ], |
| 80 | + # custom_error_type='value_error', # Generic error type if neither matches |
| 81 | + # custom_error_message='Input must be a query string or a list of query parameters' |
| 82 | + # ) |
| 83 | + |
| 84 | + # Chain the input schema with our custom parser/validator function |
| 85 | + # `no_info_plain_validator_function` takes a function that accepts only the value. |
| 86 | + # final_schema = core_schema.no_info_plain_validator_function( |
| 87 | + # cls._parse_value, # Our function that handles parsing and type conversion |
| 88 | + # # We could potentially chain input_schema first, but letting _parse_value |
| 89 | + # # handle both list and str input directly is simpler here. |
| 90 | + # ) |
| 91 | + |
| 92 | + # Pydantic V2 often uses function validators directly like this. |
| 93 | + # Let's simplify to just the plain validator assuming it gets the raw query param value. |
| 94 | + |
| 95 | + # Simplified schema: Directly use our parser function. |
| 96 | + # Pydantic will pass the raw query parameter value (str or list) to it. |
| 97 | + return core_schema.no_info_plain_validator_function(cls._parse_value) |
| 98 | + |
| 99 | +# --- Example Usage --- |
| 100 | + |
| 101 | +# Assume you have Django models like this (simplified): |
| 102 | +# class Tag(models.Model): |
| 103 | +# name = models.CharField(max_length=50, unique=True) |
| 104 | +# |
| 105 | +# class Like(models.Model): |
| 106 | +# user = models.ForeignKey(User, on_delete=models.CASCADE) |
| 107 | +# post = models.ForeignKey('Post', on_delete=models.CASCADE) |
| 108 | +# |
| 109 | +# class Post(models.Model): |
| 110 | +# title = models.CharField(max_length=100) |
| 111 | +# tags = models.ManyToManyField(Tag, related_name='posts') |
| 112 | +# likes = models.ManyToManyField(User, through=Like, related_name='liked_posts') |
| 113 | + |
| 114 | + |
| 115 | +# Define your FilterSchema using the custom QueryStringList type |
| 116 | +# class PostFilterSchema(NinjaFilterSchema): |
| 117 | +# # Handles ?tags=movie,tv OR ?tags=movie&tags=tv |
| 118 | +# # Maps to queryset filter: Post.objects.filter(tags__name__in=['movie', 'tv']) |
| 119 | +# tags: QueryStringList[str] | None = Field(None, q='tags__name__in') |
| 120 | +# |
| 121 | +# # Handles ?liked_by=uuid1,uuid2 OR ?liked_by=uuid1&liked_by=uuid2 |
| 122 | +# # Maps to queryset filter: Post.objects.filter(likes__user__in=[UUID('uuid1'), UUID('uuid2')]) |
| 123 | +# liked_by: QueryStringList[uuid.UUID] | None = Field(None, q='likes__user__in') # 'likes' is the ManyToManyField name from Post to User (through Like) |
| 124 | + |
| 125 | +# Example Ninja API endpoint |
| 126 | +# from ninja import NinjaAPI |
| 127 | +# from django.shortcuts import get_list_or_404 |
| 128 | +# from .models import Post # Assuming models.py is in the same directory |
| 129 | + |
| 130 | +# api = NinjaAPI() |
| 131 | + |
| 132 | +# @api.get("/posts") |
| 133 | +# def list_posts(request, filters: PostFilterSchema): |
| 134 | +# queryset = Post.objects.all() |
| 135 | +# filtered_queryset = filters.filter(queryset) # Apply the filters |
| 136 | +# |
| 137 | +# # You might want to apply distinct() if filtering on ManyToMany fields causes duplicates |
| 138 | +# posts = get_list_or_404(filtered_queryset.distinct()) |
| 139 | +# |
| 140 | +# # Replace with your actual Post schema/serialization |
| 141 | +# return [{"id": post.id, "title": post.title} for post in posts] |
0 commit comments