diff --git a/api_v2/views/mixins/eager_loading_mixin.py b/api_v2/views/mixins/eager_loading_mixin.py index 84ffb1c7..25dabdd8 100644 --- a/api_v2/views/mixins/eager_loading_mixin.py +++ b/api_v2/views/mixins/eager_loading_mixin.py @@ -2,15 +2,15 @@ class EagerLoadingMixin: """ Mixin to apply eager loading optimisations to a ViewSet. - Dynamically applies `selected_related()` for ForeignKey fields and - `prefetch_related()` from ManyToMany/reverse relationships. This improves - query efficiency and prevents N+1 problems + Handles the running of `select_related()` (for ForeignKey fields) and + `prefetch_related()` (for ManyToMany/reverse relationships) queryset methods + to allow developers to solve N+1 problems on Open5e endpoints. ## Usage 1. Make sure your ViewSet inherits from `EagerLoadingMixin` before its base class (ie. ReadOnlyModelViewSet). 2. Re-define `select_related_fields` and `prefetch_related_fields` lists on - the child ViewSet to specify relationships to optimise. + the child ViewSet to specify relationships to select related / pre-fetch. ## Usage Example ``` @@ -28,24 +28,69 @@ class CreatureViewSet(EagerLoadingMixin, viewsets.ReadOnlyModelViewSet): prefetch_related_fields = [] def get_queryset(self): + """ + Builds the queryset with optimised eager loading based on the requested and excluded fields. + """ queryset = super().get_queryset() - requested_fields = self.request.query_params.get('fields', '').split(',') - filtered_select_fields = self.filter_fields(self.select_related_fields, requested_fields) - filtered_prefetch_fields = self.filter_fields(self.prefetch_related_fields, requested_fields) + + # Check fields included or excluded via query parameter. We use this data + # so that we only eagerly load fields actually returned by the API. 
+ requested_fields = self.parse_requested_fields() + excluded_fields = self.parse_excluded_fields() + + filtered_select_fields = self.filter_fields(self.select_related_fields, requested_fields, excluded_fields) + filtered_prefetch_fields = self.filter_fields(self.prefetch_related_fields, requested_fields, excluded_fields) return queryset \ .select_related(*filtered_select_fields) \ .prefetch_related(*filtered_prefetch_fields) + + def parse_requested_fields(self): + """ + Parses the 'fields' query param into a list of requested field paths. + """ + requested_fields = self.request.query_params.get('fields', '') + requested_fields = requested_fields.split(',') + requested_fields = [field for field in requested_fields if field] + return requested_fields + + def parse_excluded_fields(self): + """ + Parses 'exclude' query params into a flat list of field paths for use in eager loading + """ + excluded_fields = [] + for key, value in self.request.query_params.items(): + if key == 'exclude': + excluded_fields += value.split(',') + elif key.endswith('__exclude'): + prefix = key.removesuffix('__exclude') + excluded_fields += [f'{prefix}__{field}' for field in value.split(',')] + return excluded_fields - def filter_fields(self, related_fields, requested_fields): + def filter_fields(self, related_fields, requested_fields=None, excluded_fields=None): """ - Filters'related_fields' according to whether they are included in - 'requested_fields'. Used to remove fields from eager loading if they are - not requested (and thus not returned by API), avoiding unnecessary DB calls + Filters 'related_fields' according to whether they are included in + 'requested_fields' or 'excluded_fields'. 
Used to remove fields from eager + loading if they are not returned by the API call, to avoid unnecessary DB calls """ - if not any(requested_fields): - return related_fields - return [ - related_field for related_field in related_fields - if any(related_field == req or related_field.startswith(req + '__') for req in requested_fields) - ] \ No newline at end of file + # avoids mutable default argument issues: set to empty list if param missing + requested_fields = requested_fields or [] + excluded_fields = excluded_fields or [] + + def field_matches(field, targets): + # Returns True if 'field' equals any 'target', or is a child path of one + return any(field == target or field.startswith(target + '__') for target in targets) + + if requested_fields: + related_fields = [ + related_field for related_field in related_fields + if field_matches(related_field, requested_fields) + ] + + if excluded_fields: + related_fields = [ + related_field for related_field in related_fields + if not field_matches(related_field, excluded_fields) + ] + + return related_fields \ No newline at end of file diff --git a/api_v2/views/mixins/exclude_fields_mixin.py b/api_v2/views/mixins/exclude_fields_mixin.py index 3d0805f6..57508fe3 100644 --- a/api_v2/views/mixins/exclude_fields_mixin.py +++ b/api_v2/views/mixins/exclude_fields_mixin.py @@ -1,4 +1,23 @@ class ExcludeFieldsMixin: + """ + This Mixin supports dynamically excluding fields returned by the serializers of + ViewSets that inherit from it, via the `?exclude` query parameter. + + Syntactically similar to the default `?fields` DRF query parameter. Nested + fields are similarly excluded via the '__' operator (see Examples). + + ## Usage + 1. Make sure your ViewSet inherits from `ExcludeFieldsMixin` before its base + class (e.g. ReadOnlyModelViewSet). + 2. Pass exclude params in the request query string to remove fields from the response. 
+ + # Exclude top-level fields + GET /v2/creatures/?exclude=traits,actions + + # Exclude nested fields + GET /v2/creatures/?actions__exclude=attacks + """ + def get_serializer_class(self): # Handle other mixins that might also override get_serializer_class @@ -7,25 +26,41 @@ def get_serializer_class(self): else: serializer_class = getattr(self, 'serializer_class') - # just return the regular serializer if there is no request + # Return base serializer if there is no request. This stops calculation of + # excluded fields for nested serializers, avoiding unnecessary computation. if not hasattr(self, 'request') or not hasattr(self.request, 'query_params'): return serializer_class - exclude_fields = self.request.query_params.get('exclude', '').split(',') + # Iterates over params, scans for any 'exclude' or '__exclude' keys + # and builds a dict mapping API field paths to lists of fields to remove from each + # e.g. '?exclude=id&document__exclude=permalink' becomes: + # { '': ['id'], 'document': ['permalink'] } + fields_to_exclude = {} + for key, value in self.request.query_params.items(): + if key == 'exclude': + fields_to_exclude[''] = value.split(',') + elif key.endswith('__exclude'): + fields_to_exclude[key.removesuffix('__exclude')] = value.split(',') - if not exclude_fields: + if not fields_to_exclude: return serializer_class - - # create a new serializer with 'exclude_fields' removed and return it + + # Walks the serializer tree removing fields at each level flagged for removal. + # 'path' tracks where we are in the tree, which we use as a key to index into + # 'fields_to_exclude' to check which fields to remove. We then recurse into + # nested serializers and apply the same logic. 
+ def strip_excluded_fields(fields, path=''): + for excluded_field in fields_to_exclude.get(path, []): + fields.pop(excluded_field, None) + for field_name, field in fields.items(): + nested_serializer = getattr(field, 'child', field) + if hasattr(nested_serializer, 'fields'): + nested_path = f'{path}__{field_name}' if path else field_name + strip_excluded_fields(nested_serializer.fields, nested_path) + class DynamicSerializer(serializer_class): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + strip_excluded_fields(self.fields) - excluded_fields = [] - for field in exclude_fields: - if field in self.fields: - self.fields.pop(field) - excluded_fields.append(field) - - return DynamicSerializer - + return DynamicSerializer \ No newline at end of file