Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 62 additions & 17 deletions api_v2/views/mixins/eager_loading_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@ class EagerLoadingMixin:
"""
Mixin to apply eager loading optimisations to a ViewSet.

Dynamically applies `selected_related()` for ForeignKey fields and
`prefetch_related()` from ManyToMany/reverse relationships. This improves
query efficiency and prevents N+1 problems
Handles the running of `select_related()` (for ForeignKey fields) and
`prefetch_related()` (from ManyToMany/reverse relationships) queryset methods
to allow developers to solve N+1 problems on Open5e endpoints.

## Usage
1. Make sure your ViewSet inherits from `EagerLoadingMixin` before its base
class (ie. ReadOnlyModelViewSet).
2. Re-define `select_related_fields` and `prefetch_related_fields` lists on
the child ViewSet to specify relationships to optimise.
the child ViewSet to specify relationships to select related / pre-fetch.

## Usage Example
```
Expand All @@ -28,24 +28,69 @@ class CreatureViewSet(EagerLoadingMixin, viewsets.ReadOnlyModelViewSet):
prefetch_related_fields = []

def get_queryset(self):
"""
Builds the queryset with optimised eager loading based on the requested and excluded fields.
"""
queryset = super().get_queryset()
requested_fields = self.request.query_params.get('fields', '').split(',')
filtered_select_fields = self.filter_fields(self.select_related_fields, requested_fields)
filtered_prefetch_fields = self.filter_fields(self.prefetch_related_fields, requested_fields)

# Check fields included or excluded via query parameter. We use this data
# so that we only eagerly load fields actually returned by the API.
requested_fields = self.parse_requested_fields()
excluded_fields = self.parse_excluded_fields()

filtered_select_fields = self.filter_fields(self.select_related_fields, requested_fields, excluded_fields)
filtered_prefetch_fields = self.filter_fields(self.prefetch_related_fields, requested_fields, excluded_fields)

return queryset \
.select_related(*filtered_select_fields) \
.prefetch_related(*filtered_prefetch_fields)

def parse_requested_fields(self):
"""
Parses the 'fields' query param into a list of requested field paths.
"""
requested_fields = self.request.query_params.get('fields', '')
requested_fields = requested_fields.split(',')
requested_fields = [field for field in requested_fields if field]
return requested_fields

def parse_excluded_fields(self):
"""
Parses 'exclude' query params into a flat list of field paths for use in eager loading
"""
excluded_fields = []
for key, value in self.request.query_params.items():
if key == 'exclude':
excluded_fields += value.split(',')
elif key.endswith('__exclude'):
prefix = key.removesuffix('__exclude')
excluded_fields += [f'{prefix}__{field}' for field in value.split(',')]
return excluded_fields

def filter_fields(self, related_fields, requested_fields):
def filter_fields(self, related_fields, requested_fields=None, excluded_fields=None):
"""
Filters'related_fields' according to whether they are included in
'requested_fields'. Used to remove fields from eager loading if they are
not requested (and thus not returned by API), avoiding unnecessary DB calls
Filters 'related_fields' according to whether they are included in
'requested_fields' or 'excluded_fields'. Used to remove fields from eager
loading if they are not returned by API call to avoid unnecessary DB calls
"""
if not any(requested_fields):
return related_fields
return [
related_field for related_field in related_fields
if any(related_field == req or related_field.startswith(req + '__') for req in requested_fields)
]
# avoids mutable default argument issues: set to empty list if param missing
requested_fields = requested_fields or []
excluded_fields = excluded_fields or []

def field_matches(field, targets):
# Returns True if 'field' equals any 'target', or is a child path of one
return any(field == target or field.startswith(target + '__') for target in targets)

if requested_fields:
related_fields = [
related_field for related_field in related_fields
if field_matches(related_field, requested_fields)
]

if excluded_fields:
related_fields = [
related_field for related_field in related_fields
if not field_matches(related_field, excluded_fields)
]

return related_fields
61 changes: 48 additions & 13 deletions api_v2/views/mixins/exclude_fields_mixin.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,23 @@
class ExcludeFieldsMixin:
"""
This Mixin supports dynamically excluding returned fields of serializers that
inherit from it via the `?exclude` query parameter.

Syntactically similar to the default `?field` DRF query parameter. Nested
fields are similarly excluded via the '__' operator (see Examples).

## Usage
1. Make sure your ViewSet inherits from `ExcludeFieldsMixin` before its base
class (ie. ReadOnlyModelViewSet).
2. Pass exclude params in the request query string to remove fields from the response.

# Exclude top-level fields
GET /v2/creatures/?exclude=traits,actions

# Exclude nested fields
GET /v2/creatures/?actions__exclude=attacks
"""

def get_serializer_class(self):

# Handle other mixins that might also override get_serializer_class
Expand All @@ -7,25 +26,41 @@ def get_serializer_class(self):
else:
serializer_class = getattr(self, 'serializer_class')

# just return the regular serializer if there is no request
# Return base serializer if there is no request. This stops calculation of
# excluded fields for nested serializers, avoiding unnecessary computing.
if not hasattr(self, 'request') or not hasattr(self.request, 'query_params'):
return serializer_class

exclude_fields = self.request.query_params.get('exclude', '').split(',')
# Iterates over params, scans for any 'exclude' or '<field>_exclude' keys
# and builds a dict mapping API field paths to lists of field to remove from each
# e.g. '?exclude=id&document__exclude=permalink' becomes:
# { '': ['id'], 'document': ['permalink'] }
fields_to_exclude = {}
for key, value in self.request.query_params.items():
if key == 'exclude':
fields_to_exclude[''] = value.split(',')
elif key.endswith('__exclude'):
fields_to_exclude[key.removesuffix('__exclude')] = value.split(',')

if not exclude_fields:
if not fields_to_exclude:
return serializer_class

# create a new serializer with 'exclude_fields' removed and return it

# Walks the serializer tree removing fields at each level flagged for removal.
# 'path' tracks where we are in the tree, which we use as a key to index into
# 'fields_to_exclude' to check which fields to remove. We then recurse into
# nested serializers and apply the same logic.
def strip_excluded_fields(fields, path=''):
for excluded_field in fields_to_exclude.get(path, []):
fields.pop(excluded_field, None)
for field_name, field in fields.items():
nested_serializer = getattr(field, 'child', field)
if hasattr(nested_serializer, 'fields'):
nested_path = f'{path}__{field_name}' if path else field_name
strip_excluded_fields(nested_serializer.fields, nested_path)

class DynamicSerializer(serializer_class):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
strip_excluded_fields(self.fields)

excluded_fields = []
for field in exclude_fields:
if field in self.fields:
self.fields.pop(field)
excluded_fields.append(field)

return DynamicSerializer

return DynamicSerializer