diff --git a/spec/GraphQLQueryComplexity.spec.js b/spec/GraphQLQueryComplexity.spec.js index 976cc761f4..2cf279d608 100644 --- a/spec/GraphQLQueryComplexity.spec.js +++ b/spec/GraphQLQueryComplexity.spec.js @@ -178,4 +178,32 @@ describe('graphql query complexity', () => { expect(result.errors).toBeUndefined(); }); }); + + describe('fragment fan-out', () => { + it('should reject query with exponential fragment fan-out efficiently', async () => { + await setupGraphQL({ + requestComplexity: { graphQLFields: 100 }, + }); + // Binary fan-out: each fragment spreads the next one twice. + // Without fix: 2^(levels-1) field visits = 2^25 ≈ 33M (hangs event loop). + // With fix (memoization): O(levels) traversal, same field count, instant rejection. + const levels = 26; + let query = 'query Q { ...F0 }\n'; + for (let i = 0; i < levels; i++) { + if (i === levels - 1) { + query += `fragment F${i} on Query { __typename }\n`; + } else { + query += `fragment F${i} on Query { ...F${i + 1} ...F${i + 1} }\n`; + } + } + const start = Date.now(); + const result = await graphqlRequest(query); + const elapsed = Date.now() - start; + // Must complete in under 5 seconds (without fix it would take seconds or hang) + expect(elapsed).toBeLessThan(5000); + // Field count is 2^(levels-1) = 16777216, which exceeds the limit of 100 + expect(result.errors).toBeDefined(); + expect(result.errors[0].message).toMatch(/Number of GraphQL fields .* exceeds maximum allowed/); + }); + }); }); diff --git a/src/GraphQL/helpers/queryComplexity.js b/src/GraphQL/helpers/queryComplexity.js index 0057e6438a..cd20424b8e 100644 --- a/src/GraphQL/helpers/queryComplexity.js +++ b/src/GraphQL/helpers/queryComplexity.js @@ -1,14 +1,22 @@ import { GraphQLError } from 'graphql'; import logger from '../../logger'; -function calculateQueryComplexity(operation, fragments) { +function calculateQueryComplexity(operation, fragments, limits = {}) { let maxDepth = 0; let totalFields = 0; + const fragmentCache = new Map(); + const { maxDepth: allowedMaxDepth, maxFields: allowedMaxFields } = limits; function visitSelectionSet(selectionSet, depth, visitedFragments) { if (!selectionSet) { return; } + if ( + (allowedMaxFields !== undefined && allowedMaxFields !== -1 && totalFields > allowedMaxFields) || + (allowedMaxDepth !== undefined && allowedMaxDepth !== -1 && maxDepth > allowedMaxDepth) + ) { + return; + } for (const selection of selectionSet.selections) { if (selection.kind === 'Field') { totalFields++; @@ -23,14 +31,36 @@ function calculateQueryComplexity(operation, fragments) { visitSelectionSet(selection.selectionSet, depth, visitedFragments); } else if (selection.kind === 'FragmentSpread') { const name = selection.name.value; + if (fragmentCache.has(name)) { + const cached = fragmentCache.get(name); + totalFields += cached.fields; + const adjustedDepth = depth + cached.maxDepthDelta; + if (adjustedDepth > maxDepth) { + maxDepth = adjustedDepth; + } + continue; + } if (visitedFragments.has(name)) { continue; } const fragment = fragments[name]; if (fragment) { - const branchVisited = new Set(visitedFragments); - branchVisited.add(name); - visitSelectionSet(fragment.selectionSet, depth, branchVisited); + if ( + (allowedMaxFields !== undefined && allowedMaxFields !== -1 && totalFields > allowedMaxFields) || + (allowedMaxDepth !== undefined && allowedMaxDepth !== -1 && maxDepth > allowedMaxDepth) + ) { + continue; + } + visitedFragments.add(name); + const savedFields = totalFields; + const savedMaxDepth = maxDepth; + maxDepth = depth; + visitSelectionSet(fragment.selectionSet, depth, visitedFragments); + const fieldsContribution = totalFields - savedFields; + const maxDepthDelta = maxDepth - depth; + fragmentCache.set(name, { fields: fieldsContribution, maxDepthDelta }); + maxDepth = Math.max(savedMaxDepth, maxDepth); + visitedFragments.delete(name); } } } @@ -69,7 +99,8 @@ function createComplexityValidationPlugin(getConfig) { const { depth, fields } = calculateQueryComplexity( requestContext.operation, - fragments + fragments, + { maxDepth: graphQLDepth, maxFields: graphQLFields } ); if (graphQLDepth !== -1 && depth > graphQLDepth) {