
Commit d321b39

perf: Use count_rows on fragment to reduce lance scans with limit pushdowns only (#5120)
## Changes Made

When reading Lance with a limit pushdown, we can limit the number of fragments scanned by inspecting the row count of each fragment. This reduces the number of scan tasks created and executed.
1 parent 6a8767c commit d321b39
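
For a sense of the user-facing effect, here is a minimal sketch of the behavior this change targets. It assumes a local multi-fragment Lance dataset; the path, limit value, and column name are illustrative, and the `Num Scan Tasks` string is the same one the tests below look for in the explain output.

```python
import io

import daft

# Read a Lance dataset made up of multiple fragments (path is illustrative).
df = daft.read_lance("/tmp/example.lance")

# A limit with no filters: the scan should stop creating tasks once enough
# fragments are queued to cover the requested rows.
df = df.limit(100).select("vector")

# Inspect the physical plan to see how many scan tasks were created.
buf = io.StringIO()
df.explain(True, file=buf)
print(buf.getvalue())  # look for "Num Scan Tasks = ..."
```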

2 files changed: +98, -6 lines


`daft/io/lance/lance_scan.py` (39 additions & 1 deletion)
```diff
@@ -164,10 +164,48 @@ def to_scan_tasks(self, pushdowns: PyPushdowns) -> Iterator[ScanTask]:
                 pushdowns=pushdowns,
                 stats=None,
             )
+        # Check if there is a limit pushdown and no filters
+        elif pushdowns.limit is not None and self._pushed_filters is None:
+            yield from self._create_scan_tasks_with_limit_and_no_filters(pushdowns, required_columns)
         else:
-            # Regular scan without count pushdown
             yield from self._create_regular_scan_tasks(pushdowns, required_columns)
 
+    def _create_scan_tasks_with_limit_and_no_filters(
+        self, pushdowns: PyPushdowns, required_columns: Optional[list[str]]
+    ) -> Iterator[ScanTask]:
+        """Create scan tasks optimized for limit pushdown with no filters."""
+        assert self._pushed_filters is None, "Expected no filters when creating scan tasks with limit and no filters"
+        assert pushdowns.limit is not None, "Expected a limit when creating scan tasks with limit and no filters"
+
+        fragments = self._ds.get_fragments()
+        remaining_limit = pushdowns.limit
+
+        for fragment in fragments:
+            if remaining_limit <= 0:
+                # No more rows needed, stop creating scan tasks
+                break
+
+            # Calculate effective rows using fragment.count_rows()
+            # This is not expensive: with no filters, count_rows simply returns physical_rows - num_deletions
+            # https://github.com/lancedb/lance/blob/v0.34.0/rust/lance/src/dataset/fragment.rs#L1049-L1055
+            effective_rows = fragment.count_rows()
+
+            if effective_rows > 0:
+                # Determine how many rows this fragment should contribute
+                rows_to_scan = min(remaining_limit, effective_rows)
+                remaining_limit -= rows_to_scan
+
+                yield ScanTask.python_factory_func_scan_task(
+                    module=_lancedb_table_factory_function.__module__,
+                    func_name=_lancedb_table_factory_function.__name__,
+                    func_args=(self._ds, [fragment.fragment_id], required_columns, None, rows_to_scan),
+                    schema=self.schema()._schema,
+                    num_rows=rows_to_scan,
+                    size_bytes=None,
+                    pushdowns=pushdowns,
+                    stats=None,
+                )
+
     def _create_regular_scan_tasks(
         self, pushdowns: PyPushdowns, required_columns: Optional[list[str]]
     ) -> Iterator[ScanTask]:
```
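
For intuition, the added method is a greedy loop over fragments. Below is a minimal standalone sketch of the same idea against the Lance Python API; `get_fragments()`, `fragment_id`, and `count_rows()` are the same calls the patch uses, while `plan_fragment_scans` and the dataset path are hypothetical names for illustration.

```python
import lance

def plan_fragment_scans(dataset_path: str, limit: int) -> list[tuple[int, int]]:
    """Return (fragment_id, rows_to_scan) pairs that together cover `limit` rows."""
    ds = lance.dataset(dataset_path)
    plan = []
    remaining = limit
    for fragment in ds.get_fragments():
        if remaining <= 0:
            break
        # With no filter, count_rows() is a metadata lookup
        # (physical rows minus deletions), not a data scan.
        effective_rows = fragment.count_rows()
        if effective_rows > 0:
            rows_to_scan = min(remaining, effective_rows)
            remaining -= rows_to_scan
            plan.append((fragment.fragment_id, rows_to_scan))
    return plan

# On a dataset with 1000-row fragments, plan_fragment_scans(path, 1001)
# would return two entries: 1000 rows from the first fragment, 1 from the second.
```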

`tests/io/lancedb/test_lancedb_reads.py` (59 additions & 5 deletions)
```diff
@@ -39,11 +39,65 @@ def test_lancedb_read_filter(lance_dataset_path):
     assert df.to_pydict() == {"vector": data["vector"][:1]}
 
 
-def test_lancedb_read_limit(lance_dataset_path):
-    df = daft.read_lance(lance_dataset_path)
-    df = df.limit(1)
-    df = df.select("vector")
-    assert df.to_pydict() == {"vector": data["vector"][:1]}
+@pytest.fixture(scope="function")
+def large_lance_dataset_path(tmp_path_factory):
+    """Create a large Lance dataset with multiple fragments for testing limit operations."""
+    tmp_dir = tmp_path_factory.mktemp("large_lance")
+
+    # Create 10 fragments of 1000 rows each (10,000 total rows)
+    for frag_idx in range(10):
+        # Generate data for this fragment
+        vectors = [[float(i * 0.1 + frag_idx * 1000), float(i * 0.2 + frag_idx * 1000)] for i in range(1000)]
+        big_ints = [i + frag_idx * 1000 for i in range(1000)]
+
+        fragment_data = {"vector": vectors, "big_int": big_ints}
+
+        # Write fragment (first write creates the dataset, subsequent writes append)
+        mode = "append" if frag_idx > 0 else None
+        lance.write_dataset(pa.Table.from_pydict(fragment_data), tmp_dir, mode=mode)
+
+    yield str(tmp_dir)
+
+
+@pytest.mark.parametrize(
+    "limit_size,expected_scan_tasks",
+    [
+        # Small limits
+        (1000, 1),
+        (1001, 2),
+        # Big limits
+        (9000, 9),
+        (9001, 10),
+        (10000, 10),
+    ],
+)
+def test_lancedb_read_limit_large_dataset(large_lance_dataset_path, limit_size, expected_scan_tasks):
+    """Test the limit operation on a large Lance dataset with multiple fragments."""
+    import io
+
+    df = daft.read_lance(large_lance_dataset_path)
+
+    # Apply the parametrized limit
+    df = df.limit(limit_size)
+    df = df.select("vector", "big_int")
+
+    # Capture the explain output
+    string_io = io.StringIO()
+    df.explain(True, file=string_io)
+    explain_output = string_io.getvalue()
+
+    # Assert that we have the expected number of scan tasks
+    assert f"Num Scan Tasks = {expected_scan_tasks}" in explain_output
+
+    result = df.to_pydict()
+
+    # Verify we got the expected number of rows
+    assert len(result["vector"]) == limit_size
+    assert len(result["big_int"]) == limit_size
+
+    # Verify the data is ordered correctly (should get the first N rows)
+    expected_big_ints = list(range(limit_size))
+    assert result["big_int"] == expected_big_ints
 
 
 def test_lancedb_with_version(lance_dataset_path):
```
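
The parametrized cases map directly onto the fixture's layout: 10 fragments of 1000 rows each, so a limit of 1001 needs 2 scan tasks and 9001 needs all 10. As a sanity check that the append-per-iteration writes really do produce one fragment each, a sketch (assuming `path` is what the fixture yielded):

```python
import lance

ds = lance.dataset(path)  # path yielded by large_lance_dataset_path
fragments = ds.get_fragments()
assert len(fragments) == 10
# Each fragment holds 1000 live rows, so a limit of 9001 must touch all 10.
assert all(f.count_rows() == 1000 for f in fragments)
```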
