From 79fa6f9098be9a6e5b269cd3642694765b230ff1 Mon Sep 17 00:00:00 2001
From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com>
Date: Sat, 10 Aug 2024 10:03:02 +0300
Subject: [PATCH] Enforce sorting handle fetchable operators, add option to
 repartition based on row count estimates (#11875)

* Tmp

* Minor changes

* Minor changes

* Minor changes

* Implement top down recursion with delete check

* Minor changes

* Minor changes

* Address reviews

* Update comments

* Minor changes

* Make test deterministic

* Add fetch info to the statistics

* Enforce distribution use inexact count estimate also.

* Minor changes

* Minor changes

* Minor changes

* Do not add unnecessary hash partitioning

* Minor changes

* Add config option to use inexact row number estimates during planning

* Update config

* Minor changes

* Minor changes

* Final review

* Address reviews

* Add handling for sort removal with fetch

* Fix linter errors

* Minor changes

* Update config

* Cleanup stats under fetch

* Update SLT comment

---------

Co-authored-by: Mehmet Ozan Kabak <ozankabak@gmail.com>
---
 datafusion/common/src/config.rs               |   8 +
 datafusion/common/src/stats.rs                | 122 ++++++++++--
 datafusion/core/src/dataframe/mod.rs          |  12 +-
 datafusion/core/src/datasource/statistics.rs  |   2 +-
 .../enforce_distribution.rs                   | 148 ++++++++++++--
 .../src/physical_optimizer/enforce_sorting.rs | 180 ++++++++++++++++--
 .../src/physical_optimizer/sort_pushdown.rs   | 125 +++++++++---
 .../physical-plan/src/coalesce_batches.rs     |   6 +-
 .../physical-plan/src/execution_plan.rs       |   5 +
 datafusion/physical-plan/src/filter.rs        |   2 +-
 datafusion/physical-plan/src/limit.rs         | 153 +++------------
 datafusion/physical-plan/src/sorts/sort.rs    |   6 +-
 .../test_files/count_star_rule.slt            |   6 +-
 .../sqllogictest/test_files/group_by.slt      |  15 +-
 .../test_files/information_schema.slt         |   2 +
 datafusion/sqllogictest/test_files/limit.slt  |   9 +-
 datafusion/sqllogictest/test_files/order.slt  |  41 ++++
 .../test_files/sort_merge_join.slt            |  12 +-
 datafusion/sqllogictest/test_files/union.slt  |  30 ++-
 datafusion/sqllogictest/test_files/window.slt |  22 +--
 docs/source/user-guide/configs.md             |   1 +
 21 files changed, 643 insertions(+), 264 deletions(-)

diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs
index b5204b343f05..c48845c061e7 100644
--- a/datafusion/common/src/config.rs
+++ b/datafusion/common/src/config.rs
@@ -333,6 +333,14 @@ config_namespace! {
         /// Number of input rows partial aggregation partition should process, before
         /// aggregation ratio check and trying to switch to skipping aggregation mode
         pub skip_partial_aggregation_probe_rows_threshold: usize, default = 100_000
+
+        /// Should DataFusion use row number estimates at the input to decide
+        /// whether increasing parallelism is beneficial or not. By default,
+        /// only exact row numbers (not estimates) are used for this decision.
+        /// Setting this flag to `true` will likely produce better plans.
+        /// if the source of statistics is accurate.
+        /// We plan to make this the default in the future.
+        pub use_row_number_estimates_to_optimize_partitioning: bool, default = false
     }
 }
 
diff --git a/datafusion/common/src/stats.rs b/datafusion/common/src/stats.rs
index 6cefef8d0eb5..d6b5310581d7 100644
--- a/datafusion/common/src/stats.rs
+++ b/datafusion/common/src/stats.rs
@@ -19,9 +19,9 @@
 
 use std::fmt::{self, Debug, Display};
 
-use crate::ScalarValue;
+use crate::{Result, ScalarValue};
 
-use arrow_schema::Schema;
+use arrow_schema::{Schema, SchemaRef};
 
 /// Represents a value with a degree of certainty. `Precision` is used to
 /// propagate information the precision of statistical values.
@@ -247,21 +247,96 @@ impl Statistics {
 
     /// If the exactness of a [`Statistics`] instance is lost, this function relaxes
     /// the exactness of all information by converting them [`Precision::Inexact`].
-    pub fn into_inexact(self) -> Self {
-        Statistics {
-            num_rows: self.num_rows.to_inexact(),
-            total_byte_size: self.total_byte_size.to_inexact(),
-            column_statistics: self
-                .column_statistics
-                .into_iter()
-                .map(|cs| ColumnStatistics {
-                    null_count: cs.null_count.to_inexact(),
-                    max_value: cs.max_value.to_inexact(),
-                    min_value: cs.min_value.to_inexact(),
-                    distinct_count: cs.distinct_count.to_inexact(),
-                })
-                .collect::<Vec<_>>(),
+    pub fn to_inexact(mut self) -> Self {
+        self.num_rows = self.num_rows.to_inexact();
+        self.total_byte_size = self.total_byte_size.to_inexact();
+        self.column_statistics = self
+            .column_statistics
+            .into_iter()
+            .map(|s| s.to_inexact())
+            .collect();
+        self
+    }
+
+    /// Calculates the statistics after `fetch` and `skip` operations apply.
+    /// Here, `self` denotes per-partition statistics. Use the `n_partitions`
+    /// parameter to compute global statistics in a multi-partition setting.
+    pub fn with_fetch(
+        mut self,
+        schema: SchemaRef,
+        fetch: Option<usize>,
+        skip: usize,
+        n_partitions: usize,
+    ) -> Result<Self> {
+        let fetch_val = fetch.unwrap_or(usize::MAX);
+
+        self.num_rows = match self {
+            Statistics {
+                num_rows: Precision::Exact(nr),
+                ..
+            }
+            | Statistics {
+                num_rows: Precision::Inexact(nr),
+                ..
+            } => {
+                // Here, the inexact case gives us an upper bound on the number of rows.
+                if nr <= skip {
+                    // All input data will be skipped:
+                    Precision::Exact(0)
+                } else if nr <= fetch_val && skip == 0 {
+                    // If the input does not reach the `fetch` globally, and `skip`
+                    // is zero (meaning the input and output are identical), return
+                    // input stats as is.
+                    // TODO: Can input stats still be used, but adjusted, when `skip`
+                    //       is non-zero?
+                    return Ok(self);
+                } else if nr - skip <= fetch_val {
+                    // After `skip` input rows are skipped, the remaining rows are
+                    // less than or equal to the `fetch` values, so `num_rows` must
+                    // equal the remaining rows.
+                    check_num_rows(
+                        (nr - skip).checked_mul(n_partitions),
+                        // We know that we have an estimate for the number of rows:
+                        self.num_rows.is_exact().unwrap(),
+                    )
+                } else {
+                    // At this point we know that we were given a `fetch` value
+                    // as the `None` case would go into the branch above. Since
+                    // the input has more rows than `fetch + skip`, the number
+                    // of rows will be the `fetch`, but we won't be able to
+                    // predict the other statistics.
+                    check_num_rows(
+                        fetch_val.checked_mul(n_partitions),
+                        // We know that we have an estimate for the number of rows:
+                        self.num_rows.is_exact().unwrap(),
+                    )
+                }
+            }
+            Statistics {
+                num_rows: Precision::Absent,
+                ..
+            } => check_num_rows(fetch.and_then(|v| v.checked_mul(n_partitions)), false),
+        };
+        self.column_statistics = Statistics::unknown_column(&schema);
+        self.total_byte_size = Precision::Absent;
+        Ok(self)
+    }
+}
+
+/// Creates an estimate of the number of rows in the output using the given
+/// optional value and exactness flag.
+fn check_num_rows(value: Option<usize>, is_exact: bool) -> Precision<usize> {
+    if let Some(value) = value {
+        if is_exact {
+            Precision::Exact(value)
+        } else {
+            // If the input stats are inexact, so are the output stats.
+            Precision::Inexact(value)
         }
+    } else {
+        // If the estimate is not available (e.g. due to an overflow), we can
+        // not produce a reliable estimate.
+        Precision::Absent
     }
 }
 
@@ -336,14 +411,25 @@ impl ColumnStatistics {
     }
 
     /// Returns a [`ColumnStatistics`] instance having all [`Precision::Absent`] parameters.
-    pub fn new_unknown() -> ColumnStatistics {
-        ColumnStatistics {
+    pub fn new_unknown() -> Self {
+        Self {
             null_count: Precision::Absent,
             max_value: Precision::Absent,
             min_value: Precision::Absent,
             distinct_count: Precision::Absent,
         }
     }
+
+    /// If the exactness of a [`ColumnStatistics`] instance is lost, this
+    /// function relaxes the exactness of all information by converting them
+    /// [`Precision::Inexact`].
+    pub fn to_inexact(mut self) -> Self {
+        self.null_count = self.null_count.to_inexact();
+        self.max_value = self.max_value.to_inexact();
+        self.min_value = self.min_value.to_inexact();
+        self.distinct_count = self.distinct_count.to_inexact();
+        self
+    }
 }
 
 #[cfg(test)]
diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs
index 5fa65cb0da42..25a8d1c87f00 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -3000,13 +3000,13 @@ mod tests {
             .await?
             .select_columns(&["c1", "c2", "c3"])?
             .filter(col("c2").eq(lit(3)).and(col("c1").eq(lit("a"))))?
-            .limit(0, Some(1))?
             .sort(vec![
                 // make the test deterministic
                 col("c1").sort(true, true),
                 col("c2").sort(true, true),
                 col("c3").sort(true, true),
             ])?
+            .limit(0, Some(1))?
             .with_column("sum", col("c2") + col("c3"))?;
 
         let df_sum_renamed = df
@@ -3022,11 +3022,11 @@ mod tests {
 
         assert_batches_sorted_eq!(
             [
-                "+-----+-----+----+-------+",
-                "| one | two | c3 | total |",
-                "+-----+-----+----+-------+",
-                "| a   | 3   | 13 | 16    |",
-                "+-----+-----+----+-------+"
+                "+-----+-----+-----+-------+",
+                "| one | two | c3  | total |",
+                "+-----+-----+-----+-------+",
+                "| a   | 3   | -72 | -69   |",
+                "+-----+-----+-----+-------+",
             ],
             &df_sum_renamed
         );
diff --git a/datafusion/core/src/datasource/statistics.rs b/datafusion/core/src/datasource/statistics.rs
index 9d031a6bbc85..669755877680 100644
--- a/datafusion/core/src/datasource/statistics.rs
+++ b/datafusion/core/src/datasource/statistics.rs
@@ -138,7 +138,7 @@ pub async fn get_statistics_with_limit(
         // If we still have files in the stream, it means that the limit kicked
         // in, and the statistic could have been different had we processed the
         // files in a different order.
-        statistics = statistics.into_inexact()
+        statistics = statistics.to_inexact()
     }
 
     Ok((result_files, statistics))
diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs
index 1f076e448e60..2ee5624c83dd 100644
--- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs
+++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs
@@ -44,6 +44,7 @@ use crate::physical_plan::windows::WindowAggExec;
 use crate::physical_plan::{Distribution, ExecutionPlan, Partitioning};
 
 use arrow::compute::SortOptions;
+use datafusion_common::stats::Precision;
 use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
 use datafusion_expr::logical_plan::JoinType;
 use datafusion_physical_expr::expressions::{Column, NoOp};
@@ -1031,6 +1032,105 @@ fn replace_order_preserving_variants(
     context.update_plan_from_children()
 }
 
+/// A struct to keep track of repartition requirements for each child node.
+struct RepartitionRequirementStatus {
+    /// The distribution requirement for the node.
+    requirement: Distribution,
+    /// Designates whether round robin partitioning is theoretically beneficial;
+    /// i.e. the operator can actually utilize parallelism.
+    roundrobin_beneficial: bool,
+    /// Designates whether round robin partitioning is beneficial according to
+    /// the statistical information we have on the number of rows.
+    roundrobin_beneficial_stats: bool,
+    /// Designates whether hash partitioning is necessary.
+    hash_necessary: bool,
+}
+
+/// Calculates the `RepartitionRequirementStatus` for each children to generate
+/// consistent and sensible (in terms of performance) distribution requirements.
+/// As an example, a hash join's left (build) child might produce
+///
+/// ```text
+/// RepartitionRequirementStatus {
+///     ..,
+///     hash_necessary: true
+/// }
+/// ```
+///
+/// while its right (probe) child might have very few rows and produce:
+///
+/// ```text
+/// RepartitionRequirementStatus {
+///     ..,
+///     hash_necessary: false
+/// }
+/// ```
+///
+/// These statuses are not consistent as all children should agree on hash
+/// partitioning. This function aligns the statuses to generate consistent
+/// hash partitions for each children. After alignment, the right child's
+/// status would turn into:
+///
+/// ```text
+/// RepartitionRequirementStatus {
+///     ..,
+///     hash_necessary: true
+/// }
+/// ```
+fn get_repartition_requirement_status(
+    plan: &Arc<dyn ExecutionPlan>,
+    batch_size: usize,
+    should_use_estimates: bool,
+) -> Result<Vec<RepartitionRequirementStatus>> {
+    let mut needs_alignment = false;
+    let children = plan.children();
+    let rr_beneficial = plan.benefits_from_input_partitioning();
+    let requirements = plan.required_input_distribution();
+    let mut repartition_status_flags = vec![];
+    for (child, requirement, roundrobin_beneficial) in
+        izip!(children.into_iter(), requirements, rr_beneficial)
+    {
+        // Decide whether adding a round robin is beneficial depending on
+        // the statistical information we have on the number of rows:
+        let roundrobin_beneficial_stats = match child.statistics()?.num_rows {
+            Precision::Exact(n_rows) => n_rows > batch_size,
+            Precision::Inexact(n_rows) => !should_use_estimates || (n_rows > batch_size),
+            Precision::Absent => true,
+        };
+        let is_hash = matches!(requirement, Distribution::HashPartitioned(_));
+        // Hash re-partitioning is necessary when the input has more than one
+        // partitions:
+        let multi_partitions = child.output_partitioning().partition_count() > 1;
+        let roundrobin_sensible = roundrobin_beneficial && roundrobin_beneficial_stats;
+        needs_alignment |= is_hash && (multi_partitions || roundrobin_sensible);
+        repartition_status_flags.push((
+            is_hash,
+            RepartitionRequirementStatus {
+                requirement,
+                roundrobin_beneficial,
+                roundrobin_beneficial_stats,
+                hash_necessary: is_hash && multi_partitions,
+            },
+        ));
+    }
+    // Align hash necessary flags for hash partitions to generate consistent
+    // hash partitions at each children:
+    if needs_alignment {
+        // When there is at least one hash requirement that is necessary or
+        // beneficial according to statistics, make all children require hash
+        // repartitioning:
+        for (is_hash, status) in &mut repartition_status_flags {
+            if *is_hash {
+                status.hash_necessary = true;
+            }
+        }
+    }
+    Ok(repartition_status_flags
+        .into_iter()
+        .map(|(_, status)| status)
+        .collect())
+}
+
 /// This function checks whether we need to add additional data exchange
 /// operators to satisfy distribution requirements. Since this function
 /// takes care of such requirements, we should avoid manually adding data
@@ -1050,6 +1150,9 @@ fn ensure_distribution(
     let enable_round_robin = config.optimizer.enable_round_robin_repartition;
     let repartition_file_scans = config.optimizer.repartition_file_scans;
     let batch_size = config.execution.batch_size;
+    let should_use_estimates = config
+        .execution
+        .use_row_number_estimates_to_optimize_partitioning;
     let is_unbounded = dist_context.plan.execution_mode().is_unbounded();
     // Use order preserving variants either of the conditions true
     // - it is desired according to config
@@ -1082,6 +1185,8 @@ fn ensure_distribution(
         }
     };
 
+    let repartition_status_flags =
+        get_repartition_requirement_status(&plan, batch_size, should_use_estimates)?;
     // This loop iterates over all the children to:
     // - Increase parallelism for every child if it is beneficial.
     // - Satisfy the distribution requirements of every child, if it is not
@@ -1089,33 +1194,32 @@ fn ensure_distribution(
     // We store the updated children in `new_children`.
     let children = izip!(
         children.into_iter(),
-        plan.required_input_distribution().iter(),
         plan.required_input_ordering().iter(),
-        plan.benefits_from_input_partitioning(),
-        plan.maintains_input_order()
+        plan.maintains_input_order(),
+        repartition_status_flags.into_iter()
     )
     .map(
-        |(mut child, requirement, required_input_ordering, would_benefit, maintains)| {
-            // Don't need to apply when the returned row count is not greater than batch size
-            let num_rows = child.plan.statistics()?.num_rows;
-            let repartition_beneficial_stats = if num_rows.is_exact().unwrap_or(false) {
-                num_rows
-                    .get_value()
-                    .map(|value| value > &batch_size)
-                    .unwrap() // safe to unwrap since is_exact() is true
-            } else {
-                true
-            };
-
+        |(
+            mut child,
+            required_input_ordering,
+            maintains,
+            RepartitionRequirementStatus {
+                requirement,
+                roundrobin_beneficial,
+                roundrobin_beneficial_stats,
+                hash_necessary,
+            },
+        )| {
             let add_roundrobin = enable_round_robin
                 // Operator benefits from partitioning (e.g. filter):
-                && (would_benefit && repartition_beneficial_stats)
+                && roundrobin_beneficial
+                && roundrobin_beneficial_stats
                 // Unless partitioning increases the partition count, it is not beneficial:
                 && child.plan.output_partitioning().partition_count() < target_partitions;
 
             // When `repartition_file_scans` is set, attempt to increase
             // parallelism at the source.
-            if repartition_file_scans && repartition_beneficial_stats {
+            if repartition_file_scans && roundrobin_beneficial_stats {
                 if let Some(new_child) =
                     child.plan.repartitioned(target_partitions, config)?
                 {
@@ -1124,7 +1228,7 @@ fn ensure_distribution(
             }
 
             // Satisfy the distribution requirement if it is unmet.
-            match requirement {
+            match &requirement {
                 Distribution::SinglePartition => {
                     child = add_spm_on_top(child);
                 }
@@ -1134,7 +1238,11 @@ fn ensure_distribution(
                         // to increase parallelism.
                         child = add_roundrobin_on_top(child, target_partitions)?;
                     }
-                    child = add_hash_on_top(child, exprs.to_vec(), target_partitions)?;
+                    // When inserting hash is necessary to satisy hash requirement, insert hash repartition.
+                    if hash_necessary {
+                        child =
+                            add_hash_on_top(child, exprs.to_vec(), target_partitions)?;
+                    }
                 }
                 Distribution::UnspecifiedDistribution => {
                     if add_roundrobin {
@@ -1731,6 +1839,8 @@ pub(crate) mod tests {
             config.optimizer.repartition_file_min_size = $REPARTITION_FILE_MIN_SIZE;
             config.optimizer.prefer_existing_sort = $PREFER_EXISTING_SORT;
             config.optimizer.prefer_existing_union = $PREFER_EXISTING_UNION;
+            // Use a small batch size, to trigger RoundRobin in tests
+            config.execution.batch_size = 1;
 
             // NOTE: These tests verify the joint `EnforceDistribution` + `EnforceSorting` cascade
             //       because they were written prior to the separation of `BasicEnforcement` into
diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs
index faf8d01a97fd..76df99b82c53 100644
--- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs
+++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs
@@ -61,7 +61,8 @@ use crate::physical_plan::{Distribution, ExecutionPlan, InputOrderMode};
 
 use datafusion_common::plan_err;
 use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
-use datafusion_physical_expr::{PhysicalSortExpr, PhysicalSortRequirement};
+use datafusion_physical_expr::{Partitioning, PhysicalSortExpr, PhysicalSortRequirement};
+use datafusion_physical_plan::limit::LocalLimitExec;
 use datafusion_physical_plan::repartition::RepartitionExec;
 use datafusion_physical_plan::sorts::partial_sort::PartialSortExec;
 use datafusion_physical_plan::ExecutionPlanProperties;
@@ -189,7 +190,7 @@ impl PhysicalOptimizerRule for EnforceSorting {
         // missed by the bottom-up traversal:
         let mut sort_pushdown = SortPushDown::new_default(updated_plan.plan);
         assign_initial_requirements(&mut sort_pushdown);
-        let adjusted = sort_pushdown.transform_down(pushdown_sorts)?.data;
+        let adjusted = pushdown_sorts(sort_pushdown)?;
 
         adjusted
             .plan
@@ -281,7 +282,7 @@ fn parallelize_sorts(
         // executors don't require single partition), then we can replace
         // the `CoalescePartitionsExec` + `SortExec` cascade with a `SortExec`
         // + `SortPreservingMergeExec` cascade to parallelize sorting.
-        requirements = remove_corresponding_coalesce_in_sub_plan(requirements)?;
+        requirements = remove_bottleneck_in_subplan(requirements)?;
         // We also need to remove the self node since `remove_corresponding_coalesce_in_sub_plan`
         // deals with the children and their children and so on.
         requirements = requirements.children.swap_remove(0);
@@ -299,7 +300,7 @@ fn parallelize_sorts(
     } else if is_coalesce_partitions(&requirements.plan) {
         // There is an unnecessary `CoalescePartitionsExec` in the plan.
         // This will handle the recursive `CoalescePartitionsExec` plans.
-        requirements = remove_corresponding_coalesce_in_sub_plan(requirements)?;
+        requirements = remove_bottleneck_in_subplan(requirements)?;
         // For the removal of self node which is also a `CoalescePartitionsExec`.
         requirements = requirements.children.swap_remove(0);
 
@@ -402,7 +403,12 @@ fn analyze_immediate_sort_removal(
             } else {
                 // Remove the sort:
                 node.children = node.children.swap_remove(0).children;
-                sort_input.clone()
+                if let Some(fetch) = sort_exec.fetch() {
+                    // If the sort has a fetch, we need to add a limit:
+                    Arc::new(LocalLimitExec::new(sort_input.clone(), fetch))
+                } else {
+                    sort_input.clone()
+                }
             };
             for child in node.children.iter_mut() {
                 child.data = false;
@@ -484,8 +490,11 @@ fn adjust_window_sort_removal(
     Ok(window_tree)
 }
 
-/// Removes the [`CoalescePartitionsExec`] from the plan in `node`.
-fn remove_corresponding_coalesce_in_sub_plan(
+/// Removes parallelization-reducing, avoidable [`CoalescePartitionsExec`]s from
+/// the plan in `node`. After the removal of such `CoalescePartitionsExec`s from
+/// the plan, some of the remaining `RepartitionExec`s might become unnecessary.
+/// Removes such `RepartitionExec`s from the plan as well.
+fn remove_bottleneck_in_subplan(
     mut requirements: PlanWithCorrespondingCoalescePartitions,
 ) -> Result<PlanWithCorrespondingCoalescePartitions> {
     let plan = &requirements.plan;
@@ -506,15 +515,27 @@ fn remove_corresponding_coalesce_in_sub_plan(
             .into_iter()
             .map(|node| {
                 if node.data {
-                    remove_corresponding_coalesce_in_sub_plan(node)
+                    remove_bottleneck_in_subplan(node)
                 } else {
                     Ok(node)
                 }
             })
             .collect::<Result<_>>()?;
     }
-
-    requirements.update_plan_from_children()
+    let mut new_reqs = requirements.update_plan_from_children()?;
+    if let Some(repartition) = new_reqs.plan.as_any().downcast_ref::<RepartitionExec>() {
+        let input_partitioning = repartition.input().output_partitioning();
+        // We can remove this repartitioning operator if it is now a no-op:
+        let mut can_remove = input_partitioning.eq(repartition.partitioning());
+        // We can also remove it if we ended up with an ineffective RR:
+        if let Partitioning::RoundRobinBatch(n_out) = repartition.partitioning() {
+            can_remove |= *n_out == input_partitioning.partition_count();
+        }
+        if can_remove {
+            new_reqs = new_reqs.children.swap_remove(0)
+        }
+    }
+    Ok(new_reqs)
 }
 
 /// Updates child to remove the unnecessary sort below it.
@@ -540,8 +561,11 @@ fn remove_corresponding_sort_from_sub_plan(
     requires_single_partition: bool,
 ) -> Result<PlanWithCorrespondingSort> {
     // A `SortExec` is always at the bottom of the tree.
-    if is_sort(&node.plan) {
-        node = node.children.swap_remove(0);
+    if let Some(sort_exec) = node.plan.as_any().downcast_ref::<SortExec>() {
+        // Do not remove sorts with fetch:
+        if sort_exec.fetch().is_none() {
+            node = node.children.swap_remove(0);
+        }
     } else {
         let mut any_connection = false;
         let required_dist = node.plan.required_input_distribution();
@@ -632,8 +656,9 @@ mod tests {
     use datafusion_common::Result;
     use datafusion_expr::JoinType;
     use datafusion_physical_expr::expressions::{col, Column, NotExpr};
-
     use datafusion_physical_optimizer::PhysicalOptimizerRule;
+    use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
+
     use rstest::rstest;
 
     fn create_test_schema() -> Result<SchemaRef> {
@@ -716,10 +741,7 @@ mod tests {
 
                 let mut sort_pushdown = SortPushDown::new_default(updated_plan.plan);
                 assign_initial_requirements(&mut sort_pushdown);
-                sort_pushdown
-                    .transform_down(pushdown_sorts)
-                    .data()
-                    .and_then(check_integrity)?;
+                check_integrity(pushdown_sorts(sort_pushdown)?)?;
                 // TODO: End state payloads will be checked here.
             }
 
@@ -1049,6 +1071,130 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn test_remove_unnecessary_sort6() -> Result<()> {
+        let schema = create_test_schema()?;
+        let source = memory_exec(&schema);
+        let input = Arc::new(
+            SortExec::new(vec![sort_expr("non_nullable_col", &schema)], source)
+                .with_fetch(Some(2)),
+        );
+        let physical_plan = sort_exec(
+            vec![
+                sort_expr("non_nullable_col", &schema),
+                sort_expr("nullable_col", &schema),
+            ],
+            input,
+        );
+
+        let expected_input = [
+            "SortExec: expr=[non_nullable_col@1 ASC,nullable_col@0 ASC], preserve_partitioning=[false]",
+            "  SortExec: TopK(fetch=2), expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]",
+            "    MemoryExec: partitions=1, partition_sizes=[0]",
+        ];
+        let expected_optimized = [
+            "SortExec: TopK(fetch=2), expr=[non_nullable_col@1 ASC,nullable_col@0 ASC], preserve_partitioning=[false]",
+            "  MemoryExec: partitions=1, partition_sizes=[0]",
+        ];
+        assert_optimized!(expected_input, expected_optimized, physical_plan, true);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_remove_unnecessary_sort7() -> Result<()> {
+        let schema = create_test_schema()?;
+        let source = memory_exec(&schema);
+        let input = Arc::new(SortExec::new(
+            vec![
+                sort_expr("non_nullable_col", &schema),
+                sort_expr("nullable_col", &schema),
+            ],
+            source,
+        ));
+
+        let physical_plan = Arc::new(
+            SortExec::new(vec![sort_expr("non_nullable_col", &schema)], input)
+                .with_fetch(Some(2)),
+        ) as Arc<dyn ExecutionPlan>;
+
+        let expected_input = [
+            "SortExec: TopK(fetch=2), expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]",
+            "  SortExec: expr=[non_nullable_col@1 ASC,nullable_col@0 ASC], preserve_partitioning=[false]",
+            "    MemoryExec: partitions=1, partition_sizes=[0]",
+        ];
+        let expected_optimized = [
+            "LocalLimitExec: fetch=2",
+            "  SortExec: expr=[non_nullable_col@1 ASC,nullable_col@0 ASC], preserve_partitioning=[false]",
+            "    MemoryExec: partitions=1, partition_sizes=[0]",
+        ];
+        assert_optimized!(expected_input, expected_optimized, physical_plan, true);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_remove_unnecessary_sort8() -> Result<()> {
+        let schema = create_test_schema()?;
+        let source = memory_exec(&schema);
+        let input = Arc::new(SortExec::new(
+            vec![sort_expr("non_nullable_col", &schema)],
+            source,
+        ));
+        let limit = Arc::new(LocalLimitExec::new(input, 2));
+        let physical_plan = sort_exec(
+            vec![
+                sort_expr("non_nullable_col", &schema),
+                sort_expr("nullable_col", &schema),
+            ],
+            limit,
+        );
+
+        let expected_input = [
+            "SortExec: expr=[non_nullable_col@1 ASC,nullable_col@0 ASC], preserve_partitioning=[false]",
+            "  LocalLimitExec: fetch=2",
+            "    SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]",
+            "      MemoryExec: partitions=1, partition_sizes=[0]",
+        ];
+        let expected_optimized = [
+            "LocalLimitExec: fetch=2",
+            "  SortExec: TopK(fetch=2), expr=[non_nullable_col@1 ASC,nullable_col@0 ASC], preserve_partitioning=[false]",
+            "    MemoryExec: partitions=1, partition_sizes=[0]",
+        ];
+        assert_optimized!(expected_input, expected_optimized, physical_plan, true);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_do_not_pushdown_through_limit() -> Result<()> {
+        let schema = create_test_schema()?;
+        let source = memory_exec(&schema);
+        // let input = sort_exec(vec![sort_expr("non_nullable_col", &schema)], source);
+        let input = Arc::new(SortExec::new(
+            vec![sort_expr("non_nullable_col", &schema)],
+            source,
+        ));
+        let limit = Arc::new(GlobalLimitExec::new(input, 0, Some(5))) as _;
+        let physical_plan = sort_exec(vec![sort_expr("nullable_col", &schema)], limit);
+
+        let expected_input = [
+            "SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]",
+            "  GlobalLimitExec: skip=0, fetch=5",
+            "    SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]",
+            "      MemoryExec: partitions=1, partition_sizes=[0]",
+        ];
+        let expected_optimized = [
+            "SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]",
+            "  GlobalLimitExec: skip=0, fetch=5",
+            "    SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]",
+            "      MemoryExec: partitions=1, partition_sizes=[0]",
+        ];
+        assert_optimized!(expected_input, expected_optimized, physical_plan, true);
+
+        Ok(())
+    }
+
     #[tokio::test]
     async fn test_remove_unnecessary_spm1() -> Result<()> {
         let schema = create_test_schema()?;
diff --git a/datafusion/core/src/physical_optimizer/sort_pushdown.rs b/datafusion/core/src/physical_optimizer/sort_pushdown.rs
index 3577e109b069..17d63a06a6f8 100644
--- a/datafusion/core/src/physical_optimizer/sort_pushdown.rs
+++ b/datafusion/core/src/physical_optimizer/sort_pushdown.rs
@@ -15,12 +15,11 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use std::fmt::Debug;
 use std::sync::Arc;
 
-use super::utils::add_sort_above;
-use crate::physical_optimizer::utils::{
-    is_limit, is_sort_preserving_merge, is_union, is_window,
-};
+use super::utils::{add_sort_above, is_sort};
+use crate::physical_optimizer::utils::{is_sort_preserving_merge, is_union, is_window};
 use crate::physical_plan::filter::FilterExec;
 use crate::physical_plan::joins::utils::calculate_join_output_ordering;
 use crate::physical_plan::joins::{HashJoinExec, SortMergeJoinExec};
@@ -30,7 +29,7 @@ use crate::physical_plan::sorts::sort::SortExec;
 use crate::physical_plan::tree_node::PlanContext;
 use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties};
 
-use datafusion_common::tree_node::Transformed;
+use datafusion_common::tree_node::{ConcreteTreeNode, Transformed, TreeNodeRecursion};
 use datafusion_common::{plan_err, JoinSide, Result};
 use datafusion_expr::JoinType;
 use datafusion_physical_expr::expressions::Column;
@@ -41,38 +40,63 @@ use datafusion_physical_expr::{
 /// This is a "data class" we use within the [`EnforceSorting`] rule to push
 /// down [`SortExec`] in the plan. In some cases, we can reduce the total
 /// computational cost by pushing down `SortExec`s through some executors. The
-/// object carries the parent required ordering as its data.
+/// object carries the parent required ordering and the (optional) `fetch` value
+/// of the parent node as its data.
 ///
 /// [`EnforceSorting`]: crate::physical_optimizer::enforce_sorting::EnforceSorting
-pub type SortPushDown = PlanContext<Option<Vec<PhysicalSortRequirement>>>;
+#[derive(Default, Clone)]
+pub struct ParentRequirements {
+    ordering_requirement: Option<Vec<PhysicalSortRequirement>>,
+    fetch: Option<usize>,
+}
+
+pub type SortPushDown = PlanContext<ParentRequirements>;
 
 /// Assigns the ordering requirement of the root node to the its children.
 pub fn assign_initial_requirements(node: &mut SortPushDown) {
     let reqs = node.plan.required_input_ordering();
     for (child, requirement) in node.children.iter_mut().zip(reqs) {
-        child.data = requirement;
+        child.data = ParentRequirements {
+            ordering_requirement: requirement,
+            fetch: None,
+        };
+    }
+}
+
+pub(crate) fn pushdown_sorts(sort_pushdown: SortPushDown) -> Result<SortPushDown> {
+    let mut new_node = pushdown_sorts_helper(sort_pushdown)?;
+    while new_node.tnr == TreeNodeRecursion::Stop {
+        new_node = pushdown_sorts_helper(new_node.data)?;
     }
+    let (new_node, children) = new_node.data.take_children();
+    let new_children = children
+        .into_iter()
+        .map(pushdown_sorts)
+        .collect::<Result<_>>()?;
+    new_node.with_new_children(new_children)
 }
 
-pub(crate) fn pushdown_sorts(
+fn pushdown_sorts_helper(
     mut requirements: SortPushDown,
 ) -> Result<Transformed<SortPushDown>> {
     let plan = &requirements.plan;
-    let parent_reqs = requirements.data.as_deref().unwrap_or(&[]);
+    let parent_reqs = requirements
+        .data
+        .ordering_requirement
+        .as_deref()
+        .unwrap_or(&[]);
     let satisfy_parent = plan
         .equivalence_properties()
         .ordering_satisfy_requirement(parent_reqs);
-
-    if let Some(sort_exec) = plan.as_any().downcast_ref::<SortExec>() {
+    if is_sort(plan) {
         let required_ordering = plan
             .output_ordering()
             .map(PhysicalSortRequirement::from_sort_exprs)
             .unwrap_or_default();
-
         if !satisfy_parent {
             // Make sure this `SortExec` satisfies parent requirements:
-            let fetch = sort_exec.fetch();
-            let sort_reqs = requirements.data.unwrap_or_default();
+            let sort_reqs = requirements.data.ordering_requirement.unwrap_or_default();
+            let fetch = requirements.data.fetch;
             requirements = requirements.children.swap_remove(0);
             requirements = add_sort_above(requirements, sort_reqs, fetch);
         };
@@ -82,12 +106,24 @@ pub(crate) fn pushdown_sorts(
         if let Some(adjusted) =
             pushdown_requirement_to_children(&child.plan, &required_ordering)?
         {
+            let fetch = child.plan.fetch();
             for (grand_child, order) in child.children.iter_mut().zip(adjusted) {
-                grand_child.data = order;
+                grand_child.data = ParentRequirements {
+                    ordering_requirement: order,
+                    fetch,
+                };
             }
             // Can push down requirements
-            child.data = None;
-            return Ok(Transformed::yes(child));
+            child.data = ParentRequirements {
+                ordering_requirement: Some(required_ordering),
+                fetch,
+            };
+
+            return Ok(Transformed {
+                data: child,
+                transformed: true,
+                tnr: TreeNodeRecursion::Stop,
+            });
         } else {
             // Can not push down requirements
             requirements.children = vec![child];
@@ -97,19 +133,24 @@ pub(crate) fn pushdown_sorts(
         // For non-sort operators, immediately return if parent requirements are met:
         let reqs = plan.required_input_ordering();
         for (child, order) in requirements.children.iter_mut().zip(reqs) {
-            child.data = order;
+            child.data.ordering_requirement = order;
         }
     } else if let Some(adjusted) = pushdown_requirement_to_children(plan, parent_reqs)? {
         // Can not satisfy the parent requirements, check whether we can push
         // requirements down:
         for (child, order) in requirements.children.iter_mut().zip(adjusted) {
-            child.data = order;
+            child.data.ordering_requirement = order;
         }
-        requirements.data = None;
+        requirements.data.ordering_requirement = None;
     } else {
         // Can not push down requirements, add new `SortExec`:
-        let sort_reqs = requirements.data.clone().unwrap_or_default();
-        requirements = add_sort_above(requirements, sort_reqs, None);
+        let sort_reqs = requirements
+            .data
+            .ordering_requirement
+            .clone()
+            .unwrap_or_default();
+        let fetch = requirements.data.fetch;
+        requirements = add_sort_above(requirements, sort_reqs, fetch);
         assign_initial_requirements(&mut requirements);
     }
     Ok(Transformed::yes(requirements))
@@ -132,6 +173,43 @@ fn pushdown_requirement_to_children(
             RequirementsCompatibility::Compatible(adjusted) => Ok(Some(vec![adjusted])),
             RequirementsCompatibility::NonCompatible => Ok(None),
         }
+    } else if let Some(sort_exec) = plan.as_any().downcast_ref::<SortExec>() {
+        let sort_req = PhysicalSortRequirement::from_sort_exprs(
+            sort_exec.properties().output_ordering().unwrap_or(&[]),
+        );
+        if sort_exec
+            .properties()
+            .eq_properties
+            .requirements_compatible(parent_required, &sort_req)
+        {
+            debug_assert!(!parent_required.is_empty());
+            Ok(Some(vec![Some(parent_required.to_vec())]))
+        } else {
+            Ok(None)
+        }
+    } else if plan.fetch().is_some()
+        && plan.supports_limit_pushdown()
+        && plan
+            .maintains_input_order()
+            .iter()
+            .all(|maintain| *maintain)
+    {
+        let output_req = PhysicalSortRequirement::from_sort_exprs(
+            plan.properties().output_ordering().unwrap_or(&[]),
+        );
+        // Push down through operator with fetch when:
+        // - requirement is aligned with output ordering
+        // - it preserves ordering during execution
+        if plan
+            .properties()
+            .eq_properties
+            .requirements_compatible(parent_required, &output_req)
+        {
+            let req = (!parent_required.is_empty()).then(|| parent_required.to_vec());
+            Ok(Some(vec![req]))
+        } else {
+            Ok(None)
+        }
     } else if is_union(plan) {
         // UnionExec does not have real sort requirements for its input. Here we change the adjusted_request_ordering to UnionExec's output ordering and
         // propagate the sort requirements down to correct the unnecessary descendant SortExec under the UnionExec
@@ -174,7 +252,6 @@ fn pushdown_requirement_to_children(
         || plan.as_any().is::<FilterExec>()
         // TODO: Add support for Projection push down
         || plan.as_any().is::<ProjectionExec>()
-        || is_limit(plan)
         || plan.as_any().is::<HashJoinExec>()
         || pushdown_would_violate_requirements(parent_required, plan.as_ref())
     {
diff --git a/datafusion/physical-plan/src/coalesce_batches.rs b/datafusion/physical-plan/src/coalesce_batches.rs
index de42a55ad350..13c10c535c08 100644
--- a/datafusion/physical-plan/src/coalesce_batches.rs
+++ b/datafusion/physical-plan/src/coalesce_batches.rs
@@ -212,7 +212,7 @@ impl ExecutionPlan for CoalesceBatchesExec {
     }
 
     fn statistics(&self) -> Result<Statistics> {
-        self.input.statistics()
+        Statistics::with_fetch(self.input.statistics()?, self.schema(), self.fetch, 0, 1)
     }
 
     fn with_fetch(&self, limit: Option<usize>) -> Option<Arc<dyn ExecutionPlan>> {
@@ -224,6 +224,10 @@ impl ExecutionPlan for CoalesceBatchesExec {
             cache: self.cache.clone(),
         }))
     }
+
+    fn fetch(&self) -> Option<usize> {
+        self.fetch
+    }
 }
 
 /// Stream for [`CoalesceBatchesExec`]. See [`CoalesceBatchesExec`] for more details.
diff --git a/datafusion/physical-plan/src/execution_plan.rs b/datafusion/physical-plan/src/execution_plan.rs
index 5a3fc086c1f8..a6a15e46860c 100644
--- a/datafusion/physical-plan/src/execution_plan.rs
+++ b/datafusion/physical-plan/src/execution_plan.rs
@@ -399,6 +399,11 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync {
     fn with_fetch(&self, _limit: Option<usize>) -> Option<Arc<dyn ExecutionPlan>> {
         None
     }
+
+    /// Gets the fetch count for the operator, `None` means there is no fetch.
+    fn fetch(&self) -> Option<usize> {
+        None
+    }
 }
 
 /// Extension trait provides an easy API to fetch various properties of
diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs
index 69bcfefcd476..fa9108057cfe 100644
--- a/datafusion/physical-plan/src/filter.rs
+++ b/datafusion/physical-plan/src/filter.rs
@@ -126,7 +126,7 @@ impl FilterExec {
         let schema = input.schema();
         if !check_support(predicate, &schema) {
             let selectivity = default_selectivity as f64 / 100.0;
-            let mut stats = input_stats.into_inexact();
+            let mut stats = input_stats.to_inexact();
             stats.num_rows = stats.num_rows.with_estimated_selectivity(selectivity);
             stats.total_byte_size = stats
                 .total_byte_size
diff --git a/datafusion/physical-plan/src/limit.rs b/datafusion/physical-plan/src/limit.rs
index f3dad6afabde..360e942226d2 100644
--- a/datafusion/physical-plan/src/limit.rs
+++ b/datafusion/physical-plan/src/limit.rs
@@ -31,7 +31,6 @@ use crate::{DisplayFormatType, Distribution, ExecutionPlan, Partitioning};
 
 use arrow::datatypes::SchemaRef;
 use arrow::record_batch::RecordBatch;
-use datafusion_common::stats::Precision;
 use datafusion_common::{internal_err, Result};
 use datafusion_execution::TaskContext;
 
@@ -185,80 +184,21 @@ impl ExecutionPlan for GlobalLimitExec {
     }
 
     fn statistics(&self) -> Result<Statistics> {
-        let input_stats = self.input.statistics()?;
-        let skip = self.skip;
-        let col_stats = Statistics::unknown_column(&self.schema());
-        let fetch = self.fetch.unwrap_or(usize::MAX);
-
-        let mut fetched_row_number_stats = Statistics {
-            num_rows: Precision::Exact(fetch),
-            column_statistics: col_stats.clone(),
-            total_byte_size: Precision::Absent,
-        };
+        Statistics::with_fetch(
+            self.input.statistics()?,
+            self.schema(),
+            self.fetch,
+            self.skip,
+            1,
+        )
+    }
 
-        let stats = match input_stats {
-            Statistics {
-                num_rows: Precision::Exact(nr),
-                ..
-            }
-            | Statistics {
-                num_rows: Precision::Inexact(nr),
-                ..
-            } => {
-                if nr <= skip {
-                    // if all input data will be skipped, return 0
-                    let mut skip_all_rows_stats = Statistics {
-                        num_rows: Precision::Exact(0),
-                        column_statistics: col_stats,
-                        total_byte_size: Precision::Absent,
-                    };
-                    if !input_stats.num_rows.is_exact().unwrap_or(false) {
-                        // The input stats are inexact, so the output stats must be too.
-                        skip_all_rows_stats = skip_all_rows_stats.into_inexact();
-                    }
-                    skip_all_rows_stats
-                } else if nr <= fetch && self.skip == 0 {
-                    // if the input does not reach the "fetch" globally, and "skip" is zero
-                    // (meaning the input and output are identical), return input stats.
-                    // Can input_stats still be used, but adjusted, in the "skip != 0" case?
-                    input_stats
-                } else if nr - skip <= fetch {
-                    // after "skip" input rows are skipped, the remaining rows are less than or equal to the
-                    // "fetch" values, so `num_rows` must equal the remaining rows
-                    let remaining_rows: usize = nr - skip;
-                    let mut skip_some_rows_stats = Statistics {
-                        num_rows: Precision::Exact(remaining_rows),
-                        column_statistics: col_stats,
-                        total_byte_size: Precision::Absent,
-                    };
-                    if !input_stats.num_rows.is_exact().unwrap_or(false) {
-                        // The input stats are inexact, so the output stats must be too.
-                        skip_some_rows_stats = skip_some_rows_stats.into_inexact();
-                    }
-                    skip_some_rows_stats
-                } else {
-                    // if the input is greater than "fetch+skip", the num_rows will be the "fetch",
-                    // but we won't be able to predict the other statistics
-                    if !input_stats.num_rows.is_exact().unwrap_or(false)
-                        || self.fetch.is_none()
-                    {
-                        // If the input stats are inexact, the output stats must be too.
-                        // If the fetch value is `usize::MAX` because no LIMIT was specified,
-                        // we also can't represent it as an exact value.
-                        fetched_row_number_stats =
-                            fetched_row_number_stats.into_inexact();
-                    }
-                    fetched_row_number_stats
-                }
-            }
-            _ => {
-                // The result output `num_rows` will always be no greater than the limit number.
-                // Should `num_rows` be marked as `Absent` here when the `fetch` value is large,
-                // as the actual `num_rows` may be far away from the `fetch` value?
-                fetched_row_number_stats.into_inexact()
-            }
-        };
-        Ok(stats)
+    fn fetch(&self) -> Option<usize> {
+        self.fetch
+    }
+
+    fn supports_limit_pushdown(&self) -> bool {
+        true
     }
 }
 
@@ -380,53 +320,21 @@ impl ExecutionPlan for LocalLimitExec {
     }
 
     fn statistics(&self) -> Result<Statistics> {
-        let input_stats = self.input.statistics()?;
-        let col_stats = Statistics::unknown_column(&self.schema());
-        let stats = match input_stats {
-            // if the input does not reach the limit globally, return input stats
-            Statistics {
-                num_rows: Precision::Exact(nr),
-                ..
-            }
-            | Statistics {
-                num_rows: Precision::Inexact(nr),
-                ..
-            } if nr <= self.fetch => input_stats,
-            // if the input is greater than the limit, the num_row will be greater
-            // than the limit because the partitions will be limited separately
-            // the statistic
-            Statistics {
-                num_rows: Precision::Exact(nr),
-                ..
-            } if nr > self.fetch => Statistics {
-                num_rows: Precision::Exact(self.fetch),
-                // this is not actually exact, but will be when GlobalLimit is applied
-                // TODO stats: find a more explicit way to vehiculate this information
-                column_statistics: col_stats,
-                total_byte_size: Precision::Absent,
-            },
-            Statistics {
-                num_rows: Precision::Inexact(nr),
-                ..
-            } if nr > self.fetch => Statistics {
-                num_rows: Precision::Inexact(self.fetch),
-                // this is not actually exact, but will be when GlobalLimit is applied
-                // TODO stats: find a more explicit way to vehiculate this information
-                column_statistics: col_stats,
-                total_byte_size: Precision::Absent,
-            },
-            _ => Statistics {
-                // the result output row number will always be no greater than the limit number
-                num_rows: Precision::Inexact(
-                    self.fetch
-                        * self.properties().output_partitioning().partition_count(),
-                ),
-
-                column_statistics: col_stats,
-                total_byte_size: Precision::Absent,
-            },
-        };
-        Ok(stats)
+        Statistics::with_fetch(
+            self.input.statistics()?,
+            self.schema(),
+            Some(self.fetch),
+            0,
+            1,
+        )
+    }
+
+    fn fetch(&self) -> Option<usize> {
+        Some(self.fetch)
+    }
+
+    fn supports_limit_pushdown(&self) -> bool {
+        true
     }
 }
 
@@ -565,6 +473,7 @@ mod tests {
     use crate::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
     use arrow_array::RecordBatchOptions;
     use arrow_schema::Schema;
+    use datafusion_common::stats::Precision;
     use datafusion_physical_expr::expressions::col;
     use datafusion_physical_expr::PhysicalExpr;
 
@@ -794,7 +703,7 @@ mod tests {
 
         let row_count =
             row_number_inexact_statistics_for_global_limit(400, Some(10)).await?;
-        assert_eq!(row_count, Precision::Inexact(0));
+        assert_eq!(row_count, Precision::Exact(0));
 
         let row_count =
             row_number_inexact_statistics_for_global_limit(398, Some(10)).await?;
diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs
index eb77d7716848..e7e1c5481f80 100644
--- a/datafusion/physical-plan/src/sorts/sort.rs
+++ b/datafusion/physical-plan/src/sorts/sort.rs
@@ -921,7 +921,7 @@ impl ExecutionPlan for SortExec {
     }
 
     fn statistics(&self) -> Result<Statistics> {
-        self.input.statistics()
+        Statistics::with_fetch(self.input.statistics()?, self.schema(), self.fetch, 0, 1)
     }
 
     fn with_fetch(&self, limit: Option<usize>) -> Option<Arc<dyn ExecutionPlan>> {
@@ -934,6 +934,10 @@ impl ExecutionPlan for SortExec {
             cache: self.cache.clone(),
         }))
     }
+
+    fn fetch(&self) -> Option<usize> {
+        self.fetch
+    }
 }
 
 #[cfg(test)]
diff --git a/datafusion/sqllogictest/test_files/count_star_rule.slt b/datafusion/sqllogictest/test_files/count_star_rule.slt
index 99d358ad17f0..b552e6053769 100644
--- a/datafusion/sqllogictest/test_files/count_star_rule.slt
+++ b/datafusion/sqllogictest/test_files/count_star_rule.slt
@@ -86,10 +86,8 @@ logical_plan
 physical_plan
 01)ProjectionExec: expr=[a@0 as a, count() PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as count_a]
 02)--WindowAggExec: wdw=[count() PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "count() PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]
-03)----SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]
-04)------CoalesceBatchesExec: target_batch_size=8192
-05)--------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1
-06)----------MemoryExec: partitions=1, partition_sizes=[1]
+03)----SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[false]
+04)------MemoryExec: partitions=1, partition_sizes=[1]
 
 query II
 SELECT a, COUNT() OVER (PARTITION BY a) AS count_a FROM t1 ORDER BY a;
diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt
index bd096f61fb5d..a4a886c75a77 100644
--- a/datafusion/sqllogictest/test_files/group_by.slt
+++ b/datafusion/sqllogictest/test_files/group_by.slt
@@ -2020,15 +2020,12 @@ physical_plan
 05)--------CoalesceBatchesExec: target_batch_size=8192
 06)----------RepartitionExec: partitioning=Hash([col0@0, col1@1, col2@2], 4), input_partitions=4
 07)------------AggregateExec: mode=Partial, gby=[col0@0 as col0, col1@1 as col1, col2@2 as col2], aggr=[last_value(r.col1) ORDER BY [r.col0 ASC NULLS LAST]]
-08)--------------ProjectionExec: expr=[col0@2 as col0, col1@3 as col1, col2@4 as col2, col0@0 as col0, col1@1 as col1]
-09)----------------CoalesceBatchesExec: target_batch_size=8192
-10)------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col0@0, col0@0)]
-11)--------------------CoalesceBatchesExec: target_batch_size=8192
-12)----------------------RepartitionExec: partitioning=Hash([col0@0], 4), input_partitions=1
-13)------------------------MemoryExec: partitions=1, partition_sizes=[3]
-14)--------------------CoalesceBatchesExec: target_batch_size=8192
-15)----------------------RepartitionExec: partitioning=Hash([col0@0], 4), input_partitions=1
-16)------------------------MemoryExec: partitions=1, partition_sizes=[3]
+08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+09)----------------ProjectionExec: expr=[col0@2 as col0, col1@3 as col1, col2@4 as col2, col0@0 as col0, col1@1 as col1]
+10)------------------CoalesceBatchesExec: target_batch_size=8192
+11)--------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col0@0, col0@0)]
+12)----------------------MemoryExec: partitions=1, partition_sizes=[3]
+13)----------------------MemoryExec: partitions=1, partition_sizes=[3]
 
 # Columns in the table are a,b,c,d. Source is CsvExec which is ordered by
 # a,b,c column. Column a has cardinality 2, column b has cardinality 4.
diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt
index 0cbbbf3c608c..ff793a72fd8a 100644
--- a/datafusion/sqllogictest/test_files/information_schema.slt
+++ b/datafusion/sqllogictest/test_files/information_schema.slt
@@ -215,6 +215,7 @@ datafusion.execution.sort_spill_reservation_bytes 10485760
 datafusion.execution.split_file_groups_by_statistics false
 datafusion.execution.target_partitions 7
 datafusion.execution.time_zone +00:00
+datafusion.execution.use_row_number_estimates_to_optimize_partitioning false
 datafusion.explain.logical_plan_only false
 datafusion.explain.physical_plan_only false
 datafusion.explain.show_schema false
@@ -304,6 +305,7 @@ datafusion.execution.sort_spill_reservation_bytes 10485760 Specifies the reserve
 datafusion.execution.split_file_groups_by_statistics false Attempt to eliminate sorts by packing & sorting files with non-overlapping statistics into the same file groups. Currently experimental
 datafusion.execution.target_partitions 7 Number of partitions for query execution. Increasing partitions can increase concurrency. Defaults to the number of CPU cores on the system
 datafusion.execution.time_zone +00:00 The default time zone Some functions, e.g. `EXTRACT(HOUR from SOME_TIME)`, shift the underlying datetime according to this time zone, and then extract the hour
+datafusion.execution.use_row_number_estimates_to_optimize_partitioning false Should DataFusion use row number estimates at the input to decide whether increasing parallelism is beneficial or not. By default, only exact row numbers (not estimates) are used for this decision. Setting this flag to `true` will likely produce better plans. if the source of statistics is accurate. We plan to make this the default in the future.
 datafusion.explain.logical_plan_only false When set to true, the explain statement will only print logical plans
 datafusion.explain.physical_plan_only false When set to true, the explain statement will only print physical plans
 datafusion.explain.show_schema false When set to true, the explain statement will print schema information
diff --git a/datafusion/sqllogictest/test_files/limit.slt b/datafusion/sqllogictest/test_files/limit.slt
index dc3d444854c4..4cdd40ac8c34 100644
--- a/datafusion/sqllogictest/test_files/limit.slt
+++ b/datafusion/sqllogictest/test_files/limit.slt
@@ -390,8 +390,8 @@ SELECT ROW_NUMBER() OVER (PARTITION BY t1.column1) FROM t t1, t t2, t t3;
 statement ok
 set datafusion.explain.show_sizes = false;
 
-# verify that there are multiple partitions in the input (i.e. MemoryExec says
-# there are 4 partitions) so that this tests multi-partition limit.
+# verify that there are multiple partitions in the input so that this tests
+# multi-partition limit.
 query TT
 EXPLAIN SELECT DISTINCT i FROM t1000;
 ----
@@ -402,8 +402,9 @@ physical_plan
 01)AggregateExec: mode=FinalPartitioned, gby=[i@0 as i], aggr=[]
 02)--CoalesceBatchesExec: target_batch_size=8192
 03)----RepartitionExec: partitioning=Hash([i@0], 4), input_partitions=4
-04)------AggregateExec: mode=Partial, gby=[i@0 as i], aggr=[]
-05)--------MemoryExec: partitions=4
+04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+05)--------AggregateExec: mode=Partial, gby=[i@0 as i], aggr=[]
+06)----------MemoryExec: partitions=1
 
 statement ok
 set datafusion.explain.show_sizes = true;
diff --git a/datafusion/sqllogictest/test_files/order.slt b/datafusion/sqllogictest/test_files/order.slt
index 3382d5ddabda..569602166b38 100644
--- a/datafusion/sqllogictest/test_files/order.slt
+++ b/datafusion/sqllogictest/test_files/order.slt
@@ -1148,3 +1148,44 @@ SELECT (SELECT c from ordered_table ORDER BY c LIMIT 1) UNION ALL (SELECT 23 as
 ----
 0
 23
+
+statement ok
+set datafusion.execution.use_row_number_estimates_to_optimize_partitioning = true;
+
+# Do not increase the number of partitions after fetch one, as this will be unnecessary.
+query TT
+EXPLAIN SELECT a + b as sum1 FROM (SELECT a, b
+  FROM ordered_table
+  ORDER BY a ASC LIMIT 1
+);
+----
+logical_plan
+01)Projection: ordered_table.a + ordered_table.b AS sum1
+02)--Limit: skip=0, fetch=1
+03)----Sort: ordered_table.a ASC NULLS LAST, fetch=1
+04)------TableScan: ordered_table projection=[a, b]
+physical_plan
+01)ProjectionExec: expr=[a@0 + b@1 as sum1]
+02)--SortExec: TopK(fetch=1), expr=[a@0 ASC NULLS LAST], preserve_partitioning=[false]
+03)----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b], has_header=true
+
+statement ok
+set datafusion.execution.use_row_number_estimates_to_optimize_partitioning = false;
+
+# Here, we have multiple partitions after fetch one, since the row count estimate is not exact.
+query TT
+EXPLAIN SELECT a + b as sum1 FROM (SELECT a, b
+  FROM ordered_table
+  ORDER BY a ASC LIMIT 1
+);
+----
+logical_plan
+01)Projection: ordered_table.a + ordered_table.b AS sum1
+02)--Limit: skip=0, fetch=1
+03)----Sort: ordered_table.a ASC NULLS LAST, fetch=1
+04)------TableScan: ordered_table projection=[a, b]
+physical_plan
+01)ProjectionExec: expr=[a@0 + b@1 as sum1]
+02)--RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
+03)----SortExec: TopK(fetch=1), expr=[a@0 ASC NULLS LAST], preserve_partitioning=[false]
+04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b], has_header=true
diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt
index 6e7b50973cde..ea3088e69674 100644
--- a/datafusion/sqllogictest/test_files/sort_merge_join.slt
+++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt
@@ -38,14 +38,10 @@ logical_plan
 03)--TableScan: t2 projection=[a, b]
 physical_plan
 01)SortMergeJoin: join_type=Inner, on=[(a@0, a@0)], filter=CAST(b@1 AS Int64) * 50 <= CAST(b@0 AS Int64)
-02)--SortExec: expr=[a@0 ASC], preserve_partitioning=[true]
-03)----CoalesceBatchesExec: target_batch_size=8192
-04)------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1
-05)--------MemoryExec: partitions=1, partition_sizes=[1]
-06)--SortExec: expr=[a@0 ASC], preserve_partitioning=[true]
-07)----CoalesceBatchesExec: target_batch_size=8192
-08)------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1
-09)--------MemoryExec: partitions=1, partition_sizes=[1]
+02)--SortExec: expr=[a@0 ASC], preserve_partitioning=[false]
+03)----MemoryExec: partitions=1, partition_sizes=[1]
+04)--SortExec: expr=[a@0 ASC], preserve_partitioning=[false]
+05)----MemoryExec: partitions=1, partition_sizes=[1]
 
 # inner join with join filter
 query TITI rowsort
diff --git a/datafusion/sqllogictest/test_files/union.slt b/datafusion/sqllogictest/test_files/union.slt
index aedbee35400c..476ebe7ebebe 100644
--- a/datafusion/sqllogictest/test_files/union.slt
+++ b/datafusion/sqllogictest/test_files/union.slt
@@ -563,15 +563,12 @@ logical_plan
 physical_plan
 01)UnionExec
 02)--ProjectionExec: expr=[Int64(1)@0 as a]
-03)----AggregateExec: mode=FinalPartitioned, gby=[Int64(1)@0 as Int64(1)], aggr=[], ordering_mode=Sorted
-04)------CoalesceBatchesExec: target_batch_size=2
-05)--------RepartitionExec: partitioning=Hash([Int64(1)@0], 4), input_partitions=1
-06)----------AggregateExec: mode=Partial, gby=[1 as Int64(1)], aggr=[], ordering_mode=Sorted
-07)------------PlaceholderRowExec
-08)--ProjectionExec: expr=[2 as a]
-09)----PlaceholderRowExec
-10)--ProjectionExec: expr=[3 as a]
-11)----PlaceholderRowExec
+03)----AggregateExec: mode=SinglePartitioned, gby=[1 as Int64(1)], aggr=[], ordering_mode=Sorted
+04)------PlaceholderRowExec
+05)--ProjectionExec: expr=[2 as a]
+06)----PlaceholderRowExec
+07)--ProjectionExec: expr=[3 as a]
+08)----PlaceholderRowExec
 
 # test UNION ALL aliases correctly with aliased subquery
 query TT
@@ -594,15 +591,12 @@ logical_plan
 physical_plan
 01)UnionExec
 02)--ProjectionExec: expr=[count(*)@1 as count, n@0 as n]
-03)----AggregateExec: mode=FinalPartitioned, gby=[n@0 as n], aggr=[count(*)], ordering_mode=Sorted
-04)------CoalesceBatchesExec: target_batch_size=2
-05)--------RepartitionExec: partitioning=Hash([n@0], 4), input_partitions=1
-06)----------AggregateExec: mode=Partial, gby=[n@0 as n], aggr=[count(*)], ordering_mode=Sorted
-07)------------ProjectionExec: expr=[5 as n]
-08)--------------PlaceholderRowExec
-09)--ProjectionExec: expr=[1 as count, max(Int64(10))@0 as n]
-10)----AggregateExec: mode=Single, gby=[], aggr=[max(Int64(10))]
-11)------PlaceholderRowExec
+03)----AggregateExec: mode=SinglePartitioned, gby=[n@0 as n], aggr=[count(*)], ordering_mode=Sorted
+04)------ProjectionExec: expr=[5 as n]
+05)--------PlaceholderRowExec
+06)--ProjectionExec: expr=[1 as count, max(Int64(10))@0 as n]
+07)----AggregateExec: mode=Single, gby=[], aggr=[max(Int64(10))]
+08)------PlaceholderRowExec
 
 
 # Test issue: https://github.com/apache/datafusion/issues/11409
diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt
index 4f4b9749c561..dfc882667617 100644
--- a/datafusion/sqllogictest/test_files/window.slt
+++ b/datafusion/sqllogictest/test_files/window.slt
@@ -1777,17 +1777,17 @@ physical_plan
 02)--AggregateExec: mode=Final, gby=[], aggr=[count(*)]
 03)----CoalescePartitionsExec
 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(*)]
-05)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=2
-06)----------ProjectionExec: expr=[]
-07)------------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[]
-08)--------------CoalesceBatchesExec: target_batch_size=4096
-09)----------------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2
-10)------------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[]
-11)--------------------ProjectionExec: expr=[c1@0 as c1]
-12)----------------------CoalesceBatchesExec: target_batch_size=4096
-13)------------------------FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434
-14)--------------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
-15)----------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c13], has_header=true
+05)--------ProjectionExec: expr=[]
+06)----------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[]
+07)------------CoalesceBatchesExec: target_batch_size=4096
+08)--------------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2
+09)----------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[]
+10)------------------ProjectionExec: expr=[c1@0 as c1]
+11)--------------------CoalesceBatchesExec: target_batch_size=4096
+12)----------------------FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434
+13)------------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
+14)--------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c13], has_header=true
+
 
 query I
 SELECT count(*) as global_count FROM
diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md
index e0c8391a259a..6f315f539b11 100644
--- a/docs/source/user-guide/configs.md
+++ b/docs/source/user-guide/configs.md
@@ -91,6 +91,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus
 | datafusion.execution.keep_partition_by_columns                          | false                     | Should DataFusion keep the columns used for partition_by in the output RecordBatches                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    |
 | datafusion.execution.skip_partial_aggregation_probe_ratio_threshold     | 0.8                       | Aggregation ratio (number of distinct groups / number of input rows) threshold for skipping partial aggregation. If the value is greater then partial aggregation will skip aggregation for further input                                                                                                                                                                                                                                                                                                                                                                                               |
 | datafusion.execution.skip_partial_aggregation_probe_rows_threshold      | 100000                    | Number of input rows partial aggregation partition should process, before aggregation ratio check and trying to switch to skipping aggregation mode                                                                                                                                                                                                                                                                                                                                                                                                                                                     |
+| datafusion.execution.use_row_number_estimates_to_optimize_partitioning  | false                     | Should DataFusion use row number estimates at the input to decide whether increasing parallelism is beneficial or not. By default, only exact row numbers (not estimates) are used for this decision. Setting this flag to `true` will likely produce better plans. if the source of statistics is accurate. We plan to make this the default in the future.                                                                                                                                                                                                                                            |
 | datafusion.optimizer.enable_distinct_aggregation_soft_limit             | true                      | When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read.                                                                                                                                                                                                                                                                                                                                                                           |
 | datafusion.optimizer.enable_round_robin_repartition                     | true                      | When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores                                                                                                                                                                                                                                                                                                                                                                                                                                                             |
 | datafusion.optimizer.enable_topk_aggregation                            | true                      | When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               |