diff --git a/src/expr/src/relation/mod.rs b/src/expr/src/relation/mod.rs index 0c202ddaf711d..be8505830e3a3 100644 --- a/src/expr/src/relation/mod.rs +++ b/src/expr/src/relation/mod.rs @@ -3035,31 +3035,40 @@ impl RowSetFinishing { let limit = self.limit.unwrap_or(usize::MAX); - // Count how many rows we'd expand into, returning early from the whole function - // if we don't have enough memory to expand the result, or break early from the - // iteration once we pass our limit. - let mut num_rows = 0; - let mut num_bytes: usize = 0; - for (row, count) in &rows[offset_nth_row..] { - num_rows += count.get(); - num_bytes = num_bytes.saturating_add(count.get().saturating_mul(row.byte_len())); - - // Check that result fits into max_result_size. - if num_bytes > max_result_size { - return Err(format!( - "result exceeds max size of {}", - ByteSize::b(u64::cast_from(max_result_size)) - )); - } + // The code below is logically equivalent to: + // + // let mut total = 0; + // for (_, count) in &rows[offset_nth_row..] { + // total += count.get(); + // } + // let return_size = std::cmp::min(total, limit); + // + // but it breaks early if the limit is reached, instead of scanning the entire slice. + let return_row_count = rows[offset_nth_row..] + .iter() + .try_fold(0, |sum: usize, (_, count)| { + let new_sum = sum.saturating_add(count.get()); + if new_sum > limit { + None + } else { + Some(new_sum) + } + }) + .unwrap_or(limit); - // Stop iterating if we've passed limit. - if num_rows > limit { - break; - } + // Check that the minimum possible number of bytes the Vec below will allocate + // (i.e., if zero rows spill to heap) stays within the byte limit. We still have to check + // each row below because they could spill to heap and end up using more memory.
+ const MINIMUM_ROW_BYTES: usize = std::mem::size_of::<Row>(); + let bytes_to_be_allocated = MINIMUM_ROW_BYTES.saturating_mul(return_row_count); + if bytes_to_be_allocated > max_result_size { + return Err(format!( + "result exceeds max size of {}", + ByteSize::b(u64::cast_from(max_result_size)) + )); } - let return_size = std::cmp::min(num_rows, limit); - let mut ret = Vec::with_capacity(return_size); + let mut ret = Vec::with_capacity(return_row_count); let mut remaining = limit; let mut row_buf = Row::default(); let mut datum_vec = mz_repr::DatumVec::new(); diff --git a/test/sqllogictest/vars.slt b/test/sqllogictest/vars.slt index 9b645c22db7e9..f39e9b04090ee 100644 --- a/test/sqllogictest/vars.slt +++ b/test/sqllogictest/vars.slt @@ -434,6 +434,19 @@ SELECT generate_series(1, 2) query error db error: ERROR: result exceeds max size of 100 B SELECT generate_series(1, 10) +# Regression for #22724 +# Ensure duplicate rows don't overcount bytes in the presence of LIMIT. +query T +SELECT x FROM (VALUES ('{"row": 1}')) AS a (x), generate_series(1, 50000) LIMIT 1 +---- +{"row": 1} + +# Ensure that a large ordering key but small projection does not count against the result size limit. +query I +select 1 from (select array_agg(generate_series) x from generate_series(1, 1000000)) order by x limit 1 +---- +1 + statement ok RESET max_query_result_size