Skip to content

Commit c0f2859

Browse files
author
Peng Ren
committed
Fix pagination issue
1 parent 2b56215 commit c0f2859

2 files changed

Lines changed: 193 additions & 7 deletions

File tree

pymongosql/result_set.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def _ensure_results_available(self, count: int = 1) -> None:
9292
while len(self._cached_results) < count and self._cursor_id != 0:
9393
try:
9494
# Use getMore to fetch next batch
95-
if self._database and self._execution_plan.collection:
95+
if self._database is not None and self._execution_plan.collection:
9696
getmore_cmd = {
9797
"getMore": self._cursor_id,
9898
"collection": self._execution_plan.collection,
@@ -256,12 +256,16 @@ def fetchall(self) -> List[Sequence[Any]]:
256256
all_results = []
257257

258258
try:
259-
# Handle command result (db.command)
260-
if not self._cache_exhausted:
261-
# Results are already processed in constructor, just extend
262-
all_results.extend(self._cached_results)
263-
self._total_fetched += len(self._cached_results)
264-
self._cache_exhausted = True
259+
# Ensure all results are available in cache by requesting a very large number
260+
# This will trigger getMore calls until all data is exhausted
261+
if not self._cache_exhausted and self._cursor_id != 0:
262+
self._ensure_results_available(float("inf"))
263+
264+
# Now get everything from cache
265+
all_results.extend(self._cached_results)
266+
self._total_fetched += len(self._cached_results)
267+
self._cached_results.clear()
268+
self._cache_exhausted = True
265269

266270
except PyMongoError as e:
267271
self._errors.append({"error": str(e), "type": type(e).__name__})
Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
# -*- coding: utf-8 -*-
2+
from pymongosql.result_set import ResultSet
3+
from pymongosql.sql.builder import ExecutionPlan
4+
5+
6+
class TestResultSetPagination:
    """Tests for ResultSet cursor pagination driven by MongoDB getMore."""

    # Projections shared across the test methods below.
    PROJECTION_WITH_FIELDS = {"name": 1, "email": 1}
    PROJECTION_EMPTY = {}

    def _build_result_set(self, conn, command, projection):
        """Run *command* against the connection's database and wrap the raw
        reply in a ResultSet whose execution plan targets the ``users``
        collection with *projection*."""
        database = conn.database
        reply = database.command(command)
        plan = ExecutionPlan(collection="users", projection_stage=projection)
        return ResultSet(command_result=reply, execution_plan=plan, database=database)

    def test_pagination_cursor_id_zero(self, conn):
        """All rows arrive in firstBatch, so no getMore round-trips occur."""
        rs = self._build_result_set(conn, {"find": "users", "limit": 5}, self.PROJECTION_EMPTY)

        # A cursor id of 0 means the server closed the cursor immediately.
        assert rs._cursor_id == 0
        assert rs._cache_exhausted is False  # not exhausted yet, but no getMore needed

        assert len(rs.fetchall()) == 5

        # Draining the result set must flip the exhaustion flag.
        assert rs._cache_exhausted is True

    def test_pagination_multiple_batches(self, conn):
        """fetchmany spanning several server batches triggers getMore."""
        rs = self._build_result_set(conn, {"find": "users", "batchSize": 5}, self.PROJECTION_EMPTY)

        # With 22 total users and batchSize 5, firstBatch caches at most 5 rows.
        assert len(rs._cached_results) <= 5

        fetched = rs.fetchmany(10)
        assert len(fetched) == 10

        # Multiple batches must have been consumed to satisfy the request.
        assert rs._total_fetched >= 10

    def test_pagination_ensure_results_available(self, conn):
        """_ensure_results_available pulls additional batches on demand."""
        rs = self._build_result_set(conn, {"find": "users", "batchSize": 3}, self.PROJECTION_EMPTY)

        # Only the first batch (at most 3 rows) is cached up front.
        assert len(rs._cached_results) <= 3

        # Requesting 8 rows forces at least two getMore round-trips.
        rs._ensure_results_available(8)
        assert len(rs._cached_results) >= 8

        # Cursor id stays non-negative (0 once the server side is drained).
        assert rs._cursor_id >= 0

    def test_pagination_fetchone_triggers_getmore(self, conn):
        """Repeated fetchone calls transparently cross batch boundaries."""
        rs = self._build_result_set(conn, {"find": "users", "batchSize": 2}, self.PROJECTION_WITH_FIELDS)

        _ = rs._cursor_id
        collected = []

        # Ten single-row fetches should span several getMore calls.
        attempts = 0
        while attempts < 10:
            row = rs.fetchone()
            if row:
                collected.append(row)
            attempts += 1

        assert len(collected) == 10
        # rowcount reflects everything fetched so far.
        assert rs.rowcount >= 10

    def test_pagination_cache_exhausted_flag(self, conn):
        """_cache_exhausted flips to True only after a draining fetchall."""
        rs = self._build_result_set(conn, {"find": "users", "limit": 3}, self.PROJECTION_EMPTY)

        assert rs._cache_exhausted is False

        assert len(rs.fetchall()) == 3
        assert rs._cache_exhausted is True

        # A second fetchall on a drained result set yields nothing.
        assert rs.fetchall() == []

    def test_pagination_rowcount_tracking(self, conn):
        """rowcount grows monotonically as successive batches are consumed."""
        rs = self._build_result_set(conn, {"find": "users", "batchSize": 4}, self.PROJECTION_EMPTY)

        # Initially only the first batch is counted.
        assert rs.rowcount <= 4

        first = rs.fetchmany(8)
        assert rs.rowcount >= 8

        second = rs.fetchmany(5)
        assert rs.rowcount >= 13

        remainder = rs.fetchall()
        _ = rs.rowcount

        # Every one of the 22 seeded users must eventually be returned.
        assert len(first) + len(second) + len(remainder) == 22

    def test_pagination_with_projection(self, conn):
        """Field projection is preserved across getMore batches."""
        rs = self._build_result_set(
            conn,
            {"find": "users", "projection": {"name": 1, "email": 1}, "batchSize": 3},
            self.PROJECTION_WITH_FIELDS,
        )

        rows = rs.fetchall()

        # All 22 seeded users come back despite the small batch size.
        assert len(rows) == 22

        # Hoist column lookups out of the loop; description order is stable.
        columns = [entry[0] for entry in rs.description]
        name_idx = columns.index("name")
        email_idx = columns.index("email")
        for row in rows:
            # Each row carries exactly the two projected fields.
            assert len(row) == 2
            assert isinstance(row[name_idx], (str, type(None)))
            assert isinstance(row[email_idx], (str, type(None)))

    def test_pagination_fetchmany_across_batches(self, conn):
        """A single fetchmany may consume documents from several batches."""
        rs = self._build_result_set(conn, {"find": "users", "batchSize": 3}, self.PROJECTION_EMPTY)

        # Each 10-row request spans multiple 3-document batches.
        assert len(rs.fetchmany(10)) == 10
        assert len(rs.fetchmany(10)) == 10

        # Whatever is left (total > 20, depends on actual data size).
        assert len(rs.fetchmany(5)) > 0

0 commit comments

Comments
 (0)