Skip to content

[SPARK-57688][SQL] Add spark.sql.execution.bypassPartialAggregation to skip partial agg#56777

Open
xumingming wants to merge 3 commits into
apache:masterfrom
xumingming:bypass-partial-agg
Open

[SPARK-57688][SQL] Add spark.sql.execution.bypassPartialAggregation to skip partial agg#56777
xumingming wants to merge 3 commits into
apache:masterfrom
xumingming:bypass-partial-agg

Conversation

@xumingming

Copy link
Copy Markdown

What changes were proposed in this pull request?

Adds a new SQL config spark.sql.execution.bypassPartialAggregation (default false). When set to true, planAggregateWithoutDistinct skips the pre-shuffle Partial-mode aggregation and runs a single Complete-mode aggregation after the shuffle instead. This can improve performance when group cardinality is high and the pre-shuffle reduction ratio is low.

The bypass is suppressed when a session_window grouping key is present, since MergingSessionsExec must be inserted in the Partial+Merge+Final path to correctly merge overlapping sessions.

The config has no effect on queries containing DISTINCT aggregate functions, where the partial aggregation phases are required for correctness and are always applied.

Why are the changes needed?

The standard two-phase aggregation plan (Partial → shuffle → Final) assumes that pre-shuffle partial aggregation meaningfully reduces data volume. This assumption breaks down in two scenarios.

Scenario 1: High group cardinality. When group cardinality is high relative to partition size, every input row maps to a distinct key, so the partial aggregation produces one output row per input row and adds CPU and memory overhead with zero shuffle benefit.

SELECT user_id, SUM(amount), COUNT(order_id), AVG(price)
FROM orders
GROUP BY user_id   – high-cardinality key: millions of distinct users

On a table with 500M rows and 200M distinct user_id values, the pre-shuffle HashAggregateExec in Partial mode churns through the full dataset, spills when the hash map overflows, and still emits ~200M rows into the shuffle. The partial phase wastes wall-clock time and memory without reducing shuffle write volume.

Scenario 2: Skewed input data. Even when partial aggregation can reduce data volume on average, skewed input partitions can make it harmful. If one partition contains a disproportionate share of rows for a small number of keys, the partial HashAggregateExec on that partition must hold a large hash map in memory, triggering spills. The skewed partition becomes the bottleneck and dominates wall-clock time — worse than if the data had been shuffled first and aggregated on already-partitioned, evenly distributed data.

SELECT country_code, SUM(revenue)
FROM orders
GROUP BY country_code   – a few dominant countries hold 80% of rows

Does this PR introduce any user-facing change?

No.

How was this patch tested?

Added Unit Test.

Was this patch authored or co-authored using generative AI tooling?

No.

"When false (default), uses a two-phase Partial+Final aggregation across a shuffle. " +
"This setting has no effect on queries containing DISTINCT aggregate functions, where " +
"the partial aggregation phases are required for correctness and are always applied.")
.version("3.3.1")

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
.version("3.3.1")
.version("4.3.0")

.booleanConf
.createWithDefault(true)

val BYPASS_PARTIAL_AGGREGATION = buildConf("spark.sql.execution.bypassPartialAggregation")

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SparkConfigBindingPolicySuite requires every new config to declare a policy, please make sure to add withBindingPolicy.

"the partial aggregation phases are required for correctness and are always applied.")
.version("3.3.1")
.booleanConf
.createWithDefault(false)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be internal?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually want make it public, so users could utilize it to optimize performance, how do you think?

s"Expected:\n${expected.mkString("\n")}\nActual:\n${actual.mkString("\n")}")
}
}
}

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test gap: no test with AQE enabled.

Also, no TypedImperativeAggregate bypass test.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean no test with AQE disabled right?(AQE is enabled by default). I will make all tests run with and without AQE enabled.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @cloud-fan @viirya @ueshin for AggUtils/AQE interaction

@HyukjinKwon HyukjinKwon left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 blocking, 0 non-blocking, 0 nits.
Well-scoped feature (session_window and DISTINCT correctly excluded) and correct for plain-column grouping, but the bypass plan is invalid for expression grouping keys. (uros-b's existing config + test-gap comments are valid and not repeated here.)

Correctness (1)

  • AggUtils.scala:146: Some(groupingAttributes) over the raw child references synthetic grouping attributes the child doesn't output (e.g. GROUP BY v % 10); should be Some(groupingExpressions) — see inline

Verification

Traced the rewrite Partial+shuffle+Final → shuffle+Complete: equivalent for plain-attribute grouping (empty/single/many rows), and session_window (gated by !hasSessionWindow) and DISTINCT (different planner) are excluded. The one non-equivalent input is a non-attribute grouping key: groupingAttributes = groupingExpressions.map(_.toAttribute), and SparkStrategies wraps v % 10 as Alias(v % 10, "k"), so .toAttribute is not in the raw child's output → HashPartitioning over a missing attribute → bind failure at execution. Output schema is preserved (resultExpressions unchanged).

val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete))
val completeAggregateAttributes = completeAggregateExpressions.map(_.resultAttribute)
val completeAggregate = createAggregate(
requiredChildDistributionExpressions = Some(groupingAttributes),

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

requiredChildDistributionExpressions = Some(groupingAttributes) is applied over the raw child here, but groupingAttributes = groupingExpressions.map(_.toAttribute). For any grouping key that isn't a plain child attribute — e.g. GROUP BY v % 10SparkStrategies wraps it as Alias(v % 10, "k") (SparkStrategies.scala ~L708), so .toAttribute is a synthetic AttributeReference that is not in the raw child's output. The resulting ClusteredDistributionHashPartitioning(thatAttr) then fails to bind against the child at execution.

The normal Final path (L186) can use Some(groupingAttributes) only because its child is the Partial agg that produces those attributes; the bypass's child is the raw input, so it should distribute on the expressions themselves:

Suggested change
requiredChildDistributionExpressions = Some(groupingAttributes),
requiredChildDistributionExpressions = Some(groupingExpressions),

This is currently untested: the first test groups by (v % 10).as("k") but only checks executedPlan structure (no execution), and the SUM/COUNT/AVG tests group by a materialized plain column k. An executing GROUP BY <expression> test (collect + checkAnswer under the config) would catch it. (Side note: with the fix, the bypass evaluates the grouping expression twice — in the shuffle partitioning and in the Complete agg — vs once in the two-phase path; fine for deterministic keys, worth a thought for nondeterministic ones.)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. "requiredChildDistributionExpressions = Some(groupingAttributes)" is a good catch, I will make the change. Under the hood, even if we don't make the change, current code would produce the right result because the PullOutGroupingExpressions rule extracts the non-attribute grouping expression before the plan enters planAggregateWithoutDistinct. But the change you suggest is indeed great to make the code more readable, reasonable, will make the change.
  2. For the nondeterministic concerns, PullOutNondeterministic pulls the nondeterministic grouping expressions into a upstream Project, so it will not be evaluated multiple times.

…tion to skip pre-shuffle partial agg

Adds a new SQL config spark.sql.execution.bypassPartialAggregation
(default false). When set to true, planAggregateWithoutDistinct skips
the pre-shuffle Partial-mode aggregation and runs a single Complete-mode
aggregation after the shuffle instead. This can improve performance when
group cardinality is high and the pre-shuffle reduction ratio is low.

The bypass is suppressed when a session_window grouping key is present,
since MergingSessionsExec must be inserted in the Partial+Merge+Final
path to correctly merge overlapping sessions.

The config has no effect on queries containing DISTINCT aggregate
functions, where the partial aggregation phases are required for
correctness and are always applied.

@HyukjinKwon HyukjinKwon left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

0 addressed, 1 remaining, 3 new. (3 new = 0 newly introduced, 3 late catches — my own misses; same commit, no code changed between rounds.)

0 blocking, 2 non-blocking, 2 nits. Well-scoped feature; my prior blocking finding was overstated — corrected below.

Correctness (1)

  • AggUtils.scala:146 (prior thread, remaining): correction — my earlier "GROUP BY v % 10 bind-fails" claim was wrong; PullOutGroupingExpressions pulls that key into a child Project. Some(groupingExpressions) (which you agreed to) still helps for a narrow case — see Verification.

Design / architecture (1)

  • AggUtils.scala:142: bypass also fires for global aggregation (no grouping keys) — all rows shuffle to one partition with no pre-agg, zero benefit — see inline

Nits: 2 minor items (see inline comments).

Verification

Re-traced the Partial+shuffle+Final → shuffle+Complete rewrite: row-equivalent for empty/single/many rows, NULL keys, and duplicates; session_window (!hasSessionWindow) and DISTINCT (separate planner) are excluded. On the distribution key: PullOutGroupingExpressions (Optimizer.scala:341; comment L343-344 "the grouping keys can only be attribute and literal") pulls complex keys like v % 10 into a child Project, so Some(groupingAttributes) binds fine there — my prior round was wrong about that. It differs from Some(groupingExpressions) only for foldable / childless keys (a constant literal, spark_partition_id()), which aren't pulled out and aren't in the raw child's output → bind failure under bypass=true; hence the agreed change is still worth keeping.

// when a session_window grouping key is present so that the normal Partial+Merge+Final path
// runs and MergingSessionsExec is correctly inserted.
val hasSessionWindow = groupingExpressions.exists(_.metadata.contains(SessionWindow.marker))
if (child.conf.bypassPartialAggregation && !hasSessionWindow) {

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Non-blocking (perf): this gate also fires for global aggregation (groupingExpressions.isEmpty). There requiredChildDistributionExpressions = Some(groupingAttributes) is Some(Nil)AllTuples, so all raw rows shuffle to a single partition with no pre-aggregation. For a cardinality-1 global agg that's a pure regression with zero upside, and a user who enables this session-wide for high-cardinality grouped queries silently pessimizes any global aggs in the same session. Consider also requiring grouping keys:

Suggested change
if (child.conf.bypassPartialAggregation && !hasSessionWindow) {
if (child.conf.bypassPartialAggregation && groupingExpressions.nonEmpty && !hasSessionWindow) {

// One event for key "b" stands alone.
val df = Seq(
("2016-03-27 19:39:34", 1, "a"),
("2016-03-27 19:39:39", 2, "a"), // within 10s of the first "a" — same session

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: em-dash (non-ASCII) in a // comment — CLAUDE.md/scalastyle flag non-ASCII in comments.

Suggested change
("2016-03-27 19:39:39", 2, "a"), // within 10s of the first "a" same session
("2016-03-27 19:39:39", 2, "a"), // within 10s of the first "a" - same session

val df = Seq(
("2016-03-27 19:39:34", 1, "a"),
("2016-03-27 19:39:39", 2, "a"), // within 10s of the first "a" — same session
("2016-03-27 19:39:56", 3, "a"), // > 10s gap — separate session

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: em-dash (non-ASCII) in a // comment.

Suggested change
("2016-03-27 19:39:56", 3, "a"), // > 10s gap separate session
("2016-03-27 19:39:56", 3, "a"), // > 10s gap - separate session

…or better diagnostics

Switch scalar-aggregate tests (SUM, COUNT, AVG, session_window) to use
checkAnswer instead of raw actual.toSeq == expected.toSeq, providing
better error messages when comparisons fail by pinpointing the
mismatched row and column.

Keep manual zip-and-sort for the collect_list test since checkAnswer
does not sort nested arrays — collect_list output order within groups
is non-deterministic between Partial+Final and Complete aggregation
paths.

Also replace non-ASCII em-dashes with ASCII equivalents (--, -, :) in
test names and comments to satisfy scalastyle.
…tions

Global aggregations (no GROUP BY) always produce a single output row, so
the pre-shuffle partial aggregation achieves the maximum possible
reduction ratio. Bypassing it would shuffle all raw rows to a single
partition with no benefit — strictly worse than Partial+Final.

Extract hasGroupingKeys = groupingExpressions.nonEmpty and add it to the
bypass gate alongside hasSessionWindow, so the bypass only fires when
there are grouping keys to hash-partition on.

Add a test verifying that global aggregations continue to produce
Partial+Final plans even with bypassPartialAggregation=true.
@xumingming xumingming force-pushed the bypass-partial-agg branch from fbcf3c9 to c8a214a Compare June 26, 2026 06:29
@xumingming

Copy link
Copy Markdown
Author

@uros-b @HyukjinKwon Thanks for the review, made the following changes:

  • Fixed the version, SparkConfigBindingPolicySuite of the new config
  • Added test for TypedImperativeAggregate
  • Make all the test go through both AQE enabled/disabled.
  • Only bypass partial agg when it is not global agg.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants