diff --git a/Cargo.lock b/Cargo.lock index 012573deb452d..331683ebbe15c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2090,6 +2090,7 @@ dependencies = [ "parquet", "rand 0.9.4", "tempfile", + "tokio", "url", ] diff --git a/datafusion/execution/Cargo.toml b/datafusion/execution/Cargo.toml index 06c84d8acb493..482c9fff17a1b 100644 --- a/datafusion/execution/Cargo.toml +++ b/datafusion/execution/Cargo.toml @@ -66,6 +66,7 @@ parking_lot = { workspace = true } parquet = { workspace = true, optional = true } rand = { workspace = true } tempfile = { workspace = true } +tokio = { workspace = true, features = ["time"] } url = { workspace = true } [dev-dependencies] diff --git a/datafusion/execution/src/memory_pool/mod.rs b/datafusion/execution/src/memory_pool/mod.rs index 829e313d2381e..b04fd95d410a1 100644 --- a/datafusion/execution/src/memory_pool/mod.rs +++ b/datafusion/execution/src/memory_pool/mod.rs @@ -20,10 +20,13 @@ use datafusion_common::{Result, internal_datafusion_err}; use std::fmt::Display; +use std::future::Future; use std::hash::{Hash, Hasher}; +use std::pin::Pin; use std::{cmp::Ordering, sync::Arc, sync::atomic}; mod pool; +mod reclaimer; #[cfg(feature = "arrow_buffer_pool")] pub mod arrow; @@ -36,6 +39,7 @@ pub use datafusion_common::{ human_readable_count, human_readable_duration, human_readable_size, units, }; pub use pool::*; +pub use reclaimer::{MemoryReclaimer, reclaimer_state}; /// Tracks and potentially limits memory use across operators during execution. /// @@ -209,6 +213,17 @@ pub trait MemoryPool: Send + Sync + std::fmt::Debug + Display { /// On error the `allocation` will not be increased in size fn try_grow(&self, reservation: &MemoryReservation, additional: usize) -> Result<()>; + /// Async variant of [`Self::try_grow`]. Default delegates to the + /// sync version; reclaim-aware pools (e.g. [`TrackConsumersPool`]) + /// override to invoke registered [`MemoryReclaimer`]s on OOM. + fn try_grow_async<'a>( + &'a self, + reservation: &'a MemoryReservation, + additional: usize, + ) -> Pin> + Send + 'a>> { + Box::pin(async move { self.try_grow(reservation, additional) }) + } + /// Return the total amount of memory reserved fn reserved(&self) -> usize; @@ -249,6 +264,9 @@ pub struct MemoryConsumer { name: String, can_spill: bool, id: usize, + /// Reclaimer collected by reclaim-aware pools at register time. Not + /// part of consumer identity (excluded from `Eq`/`Hash`). + reclaimer: Option>, } impl PartialEq for MemoryConsumer { @@ -287,20 +305,39 @@ impl MemoryConsumer { name: name.into(), can_spill: false, id: Self::new_unique_id(), + reclaimer: None, } } - /// Returns a clone of this [`MemoryConsumer`] with a new unique id, - /// which can be registered with a [`MemoryPool`], - /// This new consumer is separate from the original. + /// Clone this [`MemoryConsumer`] with a new unique id. + /// + /// Drops any attached reclaimer: it is bound to the original operator's + /// state and would target the wrong owner under a new id (and bypass + /// the id-keyed requestor-self-skip in `try_grow_async`). pub fn clone_with_new_id(&self) -> Self { Self { name: self.name.clone(), can_spill: self.can_spill, id: Self::new_unique_id(), + reclaimer: None, + } + } + + /// Attach a [`MemoryReclaimer`] and mark this consumer spill-capable. + /// Pools without reclaim support ignore the reclaimer. + pub fn with_reclaimer(self, reclaimer: Arc) -> Self { + Self { + can_spill: true, + reclaimer: Some(reclaimer), + ..self } } + /// Returns the attached [`MemoryReclaimer`], if any. + pub fn reclaimer(&self) -> Option<&Arc> { + self.reclaimer.as_ref() + } + /// Return the unique id of this [`MemoryConsumer`] pub fn id(&self) -> usize { self.id @@ -461,6 +498,17 @@ impl MemoryReservation { Ok(()) } + /// Async variant of [`Self::try_grow`]. On a reclaim-aware pool, + /// triggers registered [`MemoryReclaimer`]s before surfacing OOM. + pub async fn try_grow_async(&self, capacity: usize) -> Result<()> { + self.registration + .pool + .try_grow_async(self, capacity) + .await?; + self.size.fetch_add(capacity, atomic::Ordering::Relaxed); + Ok(()) + } + /// Splits off `capacity` bytes from this [`MemoryReservation`] /// into a new [`MemoryReservation`] with the same /// [`MemoryConsumer`]. diff --git a/datafusion/execution/src/memory_pool/pool.rs b/datafusion/execution/src/memory_pool/pool.rs index aac95b9d6a81f..4da8ec48988ad 100644 --- a/datafusion/execution/src/memory_pool/pool.rs +++ b/datafusion/execution/src/memory_pool/pool.rs @@ -15,19 +15,37 @@ // specific language governing permissions and limitations // under the License. +use crate::memory_pool::reclaimer::reclaimer_state; use crate::memory_pool::{ - MemoryConsumer, MemoryLimit, MemoryPool, MemoryReservation, human_readable_size, + MemoryConsumer, MemoryLimit, MemoryPool, MemoryReclaimer, MemoryReservation, + human_readable_size, }; use datafusion_common::HashMap; use datafusion_common::{DataFusionError, Result, resources_datafusion_err}; use log::debug; -use parking_lot::Mutex; +use parking_lot::{Mutex, RwLock}; use std::fmt::{Display, Formatter}; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::time::Duration; use std::{ num::NonZeroUsize, - sync::atomic::{AtomicUsize, Ordering}, + sync::atomic::{AtomicU8, AtomicUsize, Ordering}, }; +/// How long [`TrackConsumersPool::try_grow_async`] waits for an +/// in-flight sibling to finish reclaiming before retrying. Kept short +/// so we don't stall the requestor longer than the typical reclaim +/// (mpsc send + spill commit). +const RECLAIM_RETRY_SLEEP: Duration = Duration::from_millis(50); + +/// Maximum number of times [`TrackConsumersPool::try_grow_async`] +/// retries the candidate walk while siblings are still in-flight. +/// Bounds the total wait at `MAX_RECLAIM_RETRIES * RECLAIM_RETRY_SLEEP` +/// so a livelock surfaces as OOM rather than a hang. +const MAX_RECLAIM_RETRIES: usize = 3; + /// A [`MemoryPool`] that enforces no limit #[derive(Debug, Default)] pub struct UnboundedMemoryPool { @@ -324,6 +342,51 @@ struct TrackedConsumer { can_spill: bool, reserved: AtomicUsize, peak: AtomicUsize, + reclaimer: Option>, + /// Tri-state eligibility flag for [`reclaimer`], encoded per + /// [`reclaimer_state`]. The pool flips `AVAILABLE` ↔ `IN_FLIGHT` + /// for dedup; the reclaimer's owner may sticky-set `DISABLED` once + /// it can no longer free memory. Shared `Arc` so the reclaimer + /// side and the pool see the same cell. `None` reclaimer ⇒ flag + /// is unused but still allocated. + reclaimer_state: Arc, +} + +/// RAII guard for the [`IN_FLIGHT`] slot of a [`TrackedConsumer`]'s +/// `reclaimer_state` flag. `Drop` only restores `AVAILABLE` if the +/// state is still `IN_FLIGHT` — leaves a sticky `DISABLED` alone. +/// +/// [`IN_FLIGHT`]: reclaimer_state::IN_FLIGHT +struct ReclaimerStateGuard { + flag: Arc, +} + +impl Drop for ReclaimerStateGuard { + fn drop(&mut self) { + let _ = self.flag.compare_exchange( + reclaimer_state::IN_FLIGHT, + reclaimer_state::AVAILABLE, + Ordering::AcqRel, + Ordering::Relaxed, + ); + } +} + +impl ReclaimerStateGuard { + /// Try to transition the flag from `AVAILABLE` to `IN_FLIGHT`. + /// Fails on contention or on a sticky `DISABLED`. + fn try_acquire(flag: &Arc) -> Option { + flag.compare_exchange( + reclaimer_state::AVAILABLE, + reclaimer_state::IN_FLIGHT, + Ordering::AcqRel, + Ordering::Relaxed, + ) + .ok() + .map(|_| Self { + flag: Arc::clone(flag), + }) + } } impl TrackedConsumer { @@ -339,9 +402,29 @@ impl TrackedConsumer { /// Grows the tracked consumer's reserved size, /// should be called after the pool has successfully performed the grow(). + /// + /// Uses the value `reserved` definitely held immediately after this + /// thread's `fetch_add` as the peak candidate, then bumps `peak` via a + /// monotone-max CAS loop. This avoids the race in the previous + /// `peak.fetch_max(self.reserved())` form, where a concurrent `shrink` + /// between the load of `reserved` and the max-write to `peak` could + /// record a peak below the true high-water mark. fn grow(&self, additional: usize) { - self.reserved.fetch_add(additional, Ordering::Relaxed); - self.peak.fetch_max(self.reserved(), Ordering::Relaxed); + let prev = self.reserved.fetch_add(additional, Ordering::Relaxed); + let new = prev + additional; + + let mut peak = self.peak.load(Ordering::Relaxed); + while peak < new { + match self.peak.compare_exchange_weak( + peak, + new, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_) => break, + Err(actual) => peak = actual, + } + } } /// Reduce the tracked consumer's reserved size, @@ -407,8 +490,19 @@ pub struct TrackConsumersPool { inner: I, /// The amount of consumers to report(ordered top to bottom by reservation size) top: NonZeroUsize, - /// Maps consumer_id --> TrackedConsumer - tracked_consumers: Mutex>, + /// Cap on the number of reclaim candidates considered per + /// [`try_grow_async`] call. Bounds reclaim work when many consumers + /// are registered. Defaults to 4; override with + /// [`Self::with_reclaim_candidate_limit`]. + reclaim_candidate_limit: NonZeroUsize, + /// Maps consumer_id --> TrackedConsumer. + /// + /// Protected by an [`RwLock`] rather than a [`Mutex`]: registration + /// (insert) and unregistration (remove) take the write lock; grow, + /// shrink, try_grow, metrics, and report_top take the read lock and run + /// concurrently. The per-consumer [`AtomicUsize`] fields are mutated + /// under the shared read lock — see [`TrackedConsumer::grow`]. + tracked_consumers: RwLock>, } impl Display for TrackConsumersPool { @@ -464,10 +558,18 @@ impl TrackConsumersPool { Self { inner, top, + reclaim_candidate_limit: NonZeroUsize::new(4).unwrap(), tracked_consumers: Default::default(), } } + /// Override the cap on reclaim candidates considered per + /// [`try_grow_async`] call (default `4`). + pub fn with_reclaim_candidate_limit(mut self, n: NonZeroUsize) -> Self { + self.reclaim_candidate_limit = n; + self + } + /// Returns a reference to the wrapped inner [`MemoryPool`]. pub fn inner(&self) -> &I { &self.inner @@ -476,7 +578,7 @@ impl TrackConsumersPool { /// Returns a snapshot of all currently tracked consumers. pub fn metrics(&self) -> Vec { self.tracked_consumers - .lock() + .read() .values() .map(Into::into) .collect() @@ -486,7 +588,7 @@ impl TrackConsumersPool { pub fn report_top(&self, top: usize) -> String { let mut consumers = self .tracked_consumers - .lock() + .read() .iter() .map(|(consumer_id, tracked_consumer)| { ( @@ -525,7 +627,17 @@ impl MemoryPool for TrackConsumersPool { fn register(&self, consumer: &MemoryConsumer) { self.inner.register(consumer); - let mut guard = self.tracked_consumers.lock(); + let reclaimer = consumer.reclaimer().cloned(); + // Reuse the reclaimer's own flag when it provides one — that + // way the reclaimer side can sticky-set `DISABLED` and the + // pool sees it on the next filter pass. Otherwise allocate a + // fresh `AVAILABLE` flag for in-flight dedup only. + let state = reclaimer + .as_ref() + .and_then(|r| r.reclaimer_state()) + .unwrap_or_else(|| Arc::new(AtomicU8::new(reclaimer_state::AVAILABLE))); + + let mut guard = self.tracked_consumers.write(); let existing = guard.insert( consumer.id(), TrackedConsumer { @@ -533,6 +645,8 @@ impl MemoryPool for TrackConsumersPool { can_spill: consumer.can_spill(), reserved: Default::default(), peak: Default::default(), + reclaimer, + reclaimer_state: state, }, ); @@ -544,27 +658,29 @@ impl MemoryPool for TrackConsumersPool { fn unregister(&self, consumer: &MemoryConsumer) { self.inner.unregister(consumer); - self.tracked_consumers.lock().remove(&consumer.id()); + self.tracked_consumers.write().remove(&consumer.id()); } fn grow(&self, reservation: &MemoryReservation, additional: usize) { self.inner.grow(reservation, additional); - self.tracked_consumers - .lock() - .entry(reservation.consumer().id()) - .and_modify(|tracked_consumer| { - tracked_consumer.grow(additional); - }); + if let Some(tracked) = self + .tracked_consumers + .read() + .get(&reservation.consumer().id()) + { + tracked.grow(additional); + } } fn shrink(&self, reservation: &MemoryReservation, shrink: usize) { self.inner.shrink(reservation, shrink); - self.tracked_consumers - .lock() - .entry(reservation.consumer().id()) - .and_modify(|tracked_consumer| { - tracked_consumer.shrink(shrink); - }); + if let Some(tracked) = self + .tracked_consumers + .read() + .get(&reservation.consumer().id()) + { + tracked.shrink(shrink); + } } fn try_grow(&self, reservation: &MemoryReservation, additional: usize) -> Result<()> { @@ -584,15 +700,170 @@ impl MemoryPool for TrackConsumersPool { _ => e, })?; - self.tracked_consumers - .lock() - .entry(reservation.consumer().id()) - .and_modify(|tracked_consumer| { - tracked_consumer.grow(additional); - }); + if let Some(tracked) = self + .tracked_consumers + .read() + .get(&reservation.consumer().id()) + { + tracked.grow(additional); + } Ok(()) } + fn try_grow_async<'a>( + &'a self, + reservation: &'a MemoryReservation, + additional: usize, + ) -> Pin> + Send + 'a>> { + Box::pin(async move { + // Fast path. + let initial_err = match self.try_grow(reservation, additional) { + Ok(()) => return Ok(()), + Err(e) => e, + }; + + // Mark the requestor as IN_FLIGHT for the duration of this + // walk. Without this, a victim's reclaim handler that + // recursively triggers `pool.reclaim` (e.g. a merge stream + // started inside an `ExternalSorter` spill) could pick the + // requestor as its own victim, send it a reclaim oneshot, + // and deadlock — the requestor is blocked here at + // `reclaimer.reclaim().await` and can't drain its own + // reclaim channel. Sticky-disabled or already-in-flight + // requestors aren't acquired; the walk proceeds without + // protection (the candidate filter still rejects the + // requestor by id). + let requestor_id = reservation.consumer().id(); + let _self_guard = self + .tracked_consumers + .read() + .get(&requestor_id) + .and_then(|tc| ReclaimerStateGuard::try_acquire(&tc.reclaimer_state)); + + let mut retries: usize = 0; + loop { + // Snapshot reclaimers. Only consumers strictly larger than + // the requestor are eligible: smaller-or-equal siblings would + // free less than the requestor itself can, so the requestor + // should self-spill instead. This rule also breaks the + // mutual-reclaim cycle (A targets B while B targets A) — at + // most one side of any pair can hold strictly more memory, + // so the other side has no candidates and surfaces an error + // for the caller's self-spill fallback. Filter out anyone + // whose `reclaimer_state` flag is not `AVAILABLE` (in-flight or + // sticky-disabled). Also count IN_FLIGHT siblings so we know + // whether to wait briefly for them to finish before giving up. + // Drop the read guard before awaiting any reclaim. + let requestor_reserved = { + let guard = self.tracked_consumers.read(); + guard + .get(&requestor_id) + .map(|tc| tc.reserved()) + .unwrap_or(0) + }; + let mut in_flight_seen: usize = 0; + let mut candidates: Vec<( + usize, + Arc, + Arc, + )> = { + let guard = self.tracked_consumers.read(); + guard + .iter() + .filter_map(|(cid, tc)| { + if *cid == requestor_id { + return None; + } + // Track in-flight siblings (any size) so we can + // decide whether a retry has any chance of helping. + let state = tc.reclaimer_state.load(Ordering::Acquire); + if state == reclaimer_state::IN_FLIGHT { + in_flight_seen += 1; + } + let reclaimer = tc.reclaimer.as_ref()?; + if tc.reserved() <= requestor_reserved { + return None; + } + if state != reclaimer_state::AVAILABLE { + return None; + } + Some(( + tc.reserved(), + Arc::clone(reclaimer), + Arc::clone(&tc.reclaimer_state), + )) + }) + .collect() + }; + // Order: priority desc, then reservation size desc. + candidates.sort_by(|(lr, l, _), (rr, r, _)| { + r.priority().cmp(&l.priority()).then_with(|| rr.cmp(lr)) + }); + // Cap reclaim work — only consider the top-ranked candidates. + candidates.truncate(self.reclaim_candidate_limit.get()); + + // For each candidate: try to claim its in-flight slot + // (skip on contention or sticky-disabled so we work on a + // different victim rather than serializing behind a + // sibling's reclaim); re-check `try_grow` before reclaiming + // in case a sibling already freed enough; reclaim; retry + // `try_grow`. The retry path goes through `self.try_grow`, + // which already updates the tracked consumer's atomic + // reservation — no manual accounting needed here. + for (_, reclaimer, flag) in candidates { + let _g = match ReclaimerStateGuard::try_acquire(&flag) { + Some(g) => g, + None => continue, + }; + if self.try_grow(reservation, additional).is_ok() { + return Ok(()); + } + if let Err(e) = reclaimer.reclaim(additional).await { + debug!("memory reclaimer returned error: {e}"); + continue; + } + if self.try_grow(reservation, additional).is_ok() { + return Ok(()); + } + } + + // Walk produced nothing usable. If other consumers are + // currently reclaiming for someone else, their freed bytes + // may land in the pool shortly — wait briefly and retry + // before falling through to OOM. Bounded so we don't stall + // forever on a livelock. + if in_flight_seen > 0 && retries < MAX_RECLAIM_RETRIES { + retries += 1; + tokio::time::sleep(RECLAIM_RETRY_SLEEP).await; + // Quick fast-path retry: an in-flight sibling may have + // freed bytes during the sleep. + if self.try_grow(reservation, additional).is_ok() { + return Ok(()); + } + continue; + } + break; + } + + // Fall through to the inner pool's own reclaim path, if any. + // The default impl just re-runs `inner.try_grow`, which + // bypasses `TrackConsumersPool::try_grow`, so the + // consumer-side update is still required. + self.inner + .try_grow_async(reservation, additional) + .await + .map_err(|_| initial_err)?; + if let Some(tracked) = self + .tracked_consumers + .read() + .get(&reservation.consumer().id()) + { + tracked.grow(additional); + } + Ok(()) + }) + } + fn reserved(&self) -> usize { self.inner.reserved() } @@ -1046,4 +1317,126 @@ mod tests { "TrackConsumersPool Display" ); } + + /// N threads each call `grow(STEP)` then `shrink(STEP)` once on the same + /// consumer. Final `reserved == 0`. Peak hit at least once and at most + /// `THREADS * STEP` — validates that `fetch_add` on `reserved` is correct + /// under concurrent readers of the `RwLock`-protected map. + #[test] + fn test_tracked_consumer_concurrent_grow() { + const THREADS: usize = 16; + const STEP: usize = 7; + + let tracked = Arc::new(TrackConsumersPool::new( + UnboundedMemoryPool::default(), + NonZeroUsize::new(5).unwrap(), + )); + let tracked_clone = Arc::clone(&tracked); + let pool: Arc = tracked_clone; + let r = Arc::new(MemoryConsumer::new("c").register(&pool)); + + std::thread::scope(|s| { + for _ in 0..THREADS { + let r = Arc::clone(&r); + s.spawn(move || { + let local = r.new_empty(); + local.grow(STEP); + local.shrink(STEP); + }); + } + }); + + let metrics = tracked.metrics(); + let entry = metrics.iter().find(|m| m.name == "c").unwrap(); + assert_eq!(entry.reserved, 0); + assert!(entry.peak >= STEP); + assert!(entry.peak <= THREADS * STEP); + } + + /// N threads run interleaved `grow`/`shrink` pairs on the same consumer. + /// Final `reserved` must be 0; `peak` must be at least `STEP` (any grow + /// records its own bump) and at most `THREADS * STEP`. Validates the + /// monotone-max CAS on `peak`, fixing today's `fetch_max(self.reserved())` + /// race where an intervening shrink could drop `reserved` below the value + /// used to bump `peak`. + #[test] + fn test_tracked_consumer_concurrent_peak_monotone() { + const THREADS: usize = 16; + const ITERS: usize = 10_000; + const STEP: usize = 3; + + let tracked = Arc::new(TrackConsumersPool::new( + UnboundedMemoryPool::default(), + NonZeroUsize::new(5).unwrap(), + )); + let tracked_clone = Arc::clone(&tracked); + let pool: Arc = tracked_clone; + let r = Arc::new(MemoryConsumer::new("c").register(&pool)); + + std::thread::scope(|s| { + for _ in 0..THREADS { + let r = Arc::clone(&r); + s.spawn(move || { + let local = r.new_empty(); + for _ in 0..ITERS { + local.grow(STEP); + local.shrink(STEP); + } + }); + } + }); + + let entry = tracked + .metrics() + .into_iter() + .find(|m| m.name == "c") + .unwrap(); + assert_eq!(entry.reserved, 0, "all grows undone by shrinks"); + assert!(entry.peak >= STEP); + assert!(entry.peak <= THREADS * STEP); + } + + /// One thread loops register/unregister, another loops grow/shrink on a + /// stable consumer. Verifies no panics or deadlocks across the `RwLock` + /// boundary, and that the stable consumer's accounting is preserved + /// when a writer briefly takes the exclusive lock. + #[test] + fn test_tracked_consumers_pool_register_grow_concurrent() { + const ITERS: usize = 1_000; + + let tracked = Arc::new(TrackConsumersPool::new( + UnboundedMemoryPool::default(), + NonZeroUsize::new(5).unwrap(), + )); + let tracked_clone = Arc::clone(&tracked); + let pool: Arc = tracked_clone; + + let r = Arc::new(MemoryConsumer::new("stable").register(&pool)); + + std::thread::scope(|s| { + let pool_w = Arc::clone(&pool); + s.spawn(move || { + for i in 0..ITERS { + let _churn = + MemoryConsumer::new(format!("churn-{i}")).register(&pool_w); + } + }); + + let r_inner = Arc::clone(&r); + s.spawn(move || { + let local = r_inner.new_empty(); + for _ in 0..ITERS { + local.grow(5); + local.shrink(5); + } + }); + }); + + let metrics = tracked.metrics(); + let stable = metrics.iter().find(|m| m.name == "stable").unwrap(); + assert_eq!(stable.reserved, 0); + assert!(stable.peak >= 5); + assert!(metrics.iter().all(|m| !m.name.starts_with("churn-"))); + drop(r); + } } diff --git a/datafusion/execution/src/memory_pool/reclaimer.rs b/datafusion/execution/src/memory_pool/reclaimer.rs new file mode 100644 index 0000000000000..135853bff5237 --- /dev/null +++ b/datafusion/execution/src/memory_pool/reclaimer.rs @@ -0,0 +1,88 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Operator hook used by a [`MemoryPool`] to free memory before an +//! allocation fails. +//! +//! [`MemoryPool`]: super::MemoryPool + +use datafusion_common::Result; +use std::fmt::Debug; +use std::sync::Arc; +use std::sync::atomic::AtomicU8; + +/// Encoded values stored in the [`reclaimer_state`] tri-state. +/// +/// [`reclaimer_state`]: MemoryReclaimer::reclaimer_state +pub mod reclaimer_state { + /// Reclaimer is idle and may be selected as a victim. + pub const AVAILABLE: u8 = 0; + /// A pool task is currently driving `reclaim` on this reclaimer. + pub const IN_FLIGHT: u8 = 1; + /// Reclaimer has been retired (e.g. operator entered a phase where + /// it can no longer free memory). Sticky — never returns to + /// `AVAILABLE`. + pub const DISABLED: u8 = 2; +} + +/// Hook attached to a [`MemoryConsumer`] via +/// [`MemoryConsumer::with_reclaimer`]. On +/// [`MemoryPool::try_grow_async`] failure the pool walks registered +/// reclaimers in descending [`Self::priority`] and asks each to free bytes. +/// +/// Implementations MUST: +/// +/// - Spill data **before** shrinking the reservation, so reported bytes +/// are recoverable. +/// - Release bytes via [`MemoryReservation::shrink`] / +/// [`MemoryReservation::free`] and return at most `target`. +/// - Not call `try_grow*` on the pool — risks reentrancy/deadlock. +/// - Not capture `Arc` / `Arc` +/// (creates a cycle that blocks `unregister`); use channels or `Weak`. +/// +/// [`MemoryConsumer`]: super::MemoryConsumer +/// [`MemoryConsumer::with_reclaimer`]: super::MemoryConsumer::with_reclaimer +/// [`MemoryPool::try_grow_async`]: super::MemoryPool::try_grow_async +/// [`MemoryReservation::shrink`]: super::MemoryReservation::shrink +/// [`MemoryReservation::free`]: super::MemoryReservation::free +#[async_trait::async_trait] +pub trait MemoryReclaimer: Send + Sync + Debug { + /// Upper bound on bytes this reclaimer can free. `None` = unknown. + fn reclaimable_bytes(&self) -> Option { + None + } + + /// Free up to `target` bytes; return the amount actually released. + /// See trait-level contract. + async fn reclaim(&self, target: usize) -> Result; + + /// Higher priorities reclaim first. Negative = last resort. + fn priority(&self) -> i32 { + 0 + } + + /// Optional shared tri-state flag controlling whether the pool + /// currently considers this reclaimer eligible. Values are defined + /// in [`reclaimer_state`]. Returning `Some(arc)` lets the + /// reclaimer's owner flip itself to `DISABLED` once it can no + /// longer free memory (e.g., on entering a merge phase), which + /// the pool observes immediately. Returning `None` lets the pool + /// allocate its own private flag — used only for in-flight dedup. + fn reclaimer_state(&self) -> Option> { + None + } +} diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 6c02af8dec6d3..677c3ae99be37 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -65,7 +65,9 @@ use datafusion_common::{ unwrap_or_internal_err, }; use datafusion_execution::TaskContext; -use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; +use datafusion_execution::memory_pool::{ + MemoryConsumer, MemoryReclaimer, MemoryReservation, reclaimer_state, +}; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_physical_expr::LexOrdering; use datafusion_physical_expr::PhysicalExpr; @@ -74,6 +76,34 @@ use datafusion_physical_expr::expressions::{DynamicFilterPhysicalExpr, lit}; use futures::{StreamExt, TryStreamExt}; use log::{debug, trace}; +/// Reclaimer for an [`ExternalSorter`] partition. Hands a oneshot off to +/// the partition's stream loop (the sorter's sole owner), which spills and +/// replies with the freed byte count. +#[derive(Debug)] +struct ExternalSorterReclaimer { + tx: tokio::sync::mpsc::Sender>, + /// Shared with the pool's `TrackedConsumer` entry. Stream loop + /// flips it to `DISABLED` on merge entry so the pool stops + /// targeting this consumer. + reclaimer_state: Arc, +} + +#[async_trait::async_trait] +impl MemoryReclaimer for ExternalSorterReclaimer { + async fn reclaim(&self, _target: usize) -> Result { + let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); + // Stream loop terminated, or response dropped: report 0. + if self.tx.send(resp_tx).await.is_err() { + return Ok(0); + } + Ok(resp_rx.await.unwrap_or(0)) + } + + fn reclaimer_state(&self) -> Option> { + Some(Arc::clone(&self.reclaimer_state)) + } +} + struct ExternalSorterMetrics { /// metrics baseline: BaselineMetrics, @@ -279,11 +309,16 @@ impl ExternalSorter { spill_compression: SpillCompression, metrics: &ExecutionPlanMetricsSet, runtime: Arc, + // Reclaimer attached to this partition's `MemoryConsumer`. + reclaimer: Option>, ) -> Result { let metrics = ExternalSorterMetrics::new(metrics, partition_id); - let reservation = MemoryConsumer::new(format!("ExternalSorter[{partition_id}]")) - .with_can_spill(true) - .register(&runtime.memory_pool); + let mut consumer = MemoryConsumer::new(format!("ExternalSorter[{partition_id}]")) + .with_can_spill(true); + if let Some(r) = reclaimer { + consumer = consumer.with_reclaimer(r); + } + let reservation = consumer.register(&runtime.memory_pool); let merge_reservation = MemoryConsumer::new(format!("ExternalSorterMerge[{partition_id}]")) @@ -492,6 +527,10 @@ impl ExternalSorter { while let Some(batch) = sorted_stream.next().await { let batch = batch?; let sorted_size = get_reserved_bytes_for_record_batch(&batch)?; + // Sync `try_grow`, not `try_grow_async`: we are already in the + // spill path (freeing memory). A recursive reclaim here can + // close a cycle between two sorters that are each waiting on + // the other's spill to complete. if self.reservation.try_grow(sorted_size).is_err() { // Although the reservation is not enough, the batch is // already in memory, so it's okay to combine it with previously @@ -736,14 +775,16 @@ impl ExternalSorter { ) -> Result<()> { let size = get_reserved_bytes_for_record_batch(input)?; - match self.reservation.try_grow(size) { + match self.reservation.try_grow_async(size).await { Ok(_) => Ok(()), Err(e) => { if self.in_mem_batches.is_empty() { return Err(Self::err_with_oom_context(e)); } - // Spill and try again. + // Sibling reclaim was already attempted by `try_grow_async` + // (which skips this consumer). Spill our own buffer, retry + // sync — siblings won't free more on a second pass. self.sort_and_spill_in_mem_batches().await?; self.reservation .try_grow(size) @@ -1246,6 +1287,18 @@ impl ExecutionPlan for SortExec { ))) } (false, None) => { + // Spill-request channel; drained by the stream loop below. + let (reclaim_tx, mut reclaim_rx) = + tokio::sync::mpsc::channel::>(4); + let state = Arc::new(std::sync::atomic::AtomicU8::new( + reclaimer_state::AVAILABLE, + )); + let reclaimer: Arc = + Arc::new(ExternalSorterReclaimer { + tx: reclaim_tx, + reclaimer_state: Arc::clone(&state), + }); + let mut sorter = ExternalSorter::new( partition, input.schema(), @@ -1256,14 +1309,61 @@ impl ExecutionPlan for SortExec { context.session_config().spill_compression(), &self.metrics_set, context.runtime_env(), + Some(reclaimer), )?; Ok(Box::pin(RecordBatchStreamAdapter::new( self.schema(), futures::stream::once(async move { - while let Some(batch) = input.next().await { - let batch = batch?; - sorter.insert_batch(batch).await?; + // State machine: spill or insert, never both. The + // freed-byte reply is sent only after the spill + // completes, so the pool sees recoverable bytes. + // `biased` ensures spill wins over insert under + // pressure. + // + // Cancellation: selecting reclaim drops the in-flight + // `input.next()`. Safe for cancellation-safe inputs + // (channel receivers, e.g. RepartitionExec); other + // inputs could drop a batch here. + loop { + tokio::select! { + biased; + Some(resp_tx) = reclaim_rx.recv() => { + // A reclaim can be dequeued just after a + // prior spill drained `in_mem_batches` + // (sibling sent during the spill's awaits; + // pool's zero-byte filter can transiently + // miss us via split reservations). Nothing + // local to free — reply 0 and keep going. + if sorter.in_mem_batches.is_empty() { + let _ = resp_tx.send(0); + continue; + } + let before = sorter.used(); + sorter.sort_and_spill_in_mem_batches().await?; + let after = sorter.used(); + let _ = resp_tx + .send(before.saturating_sub(after)); + } + next = input.next() => match next { + Some(batch) => { + sorter.insert_batch(batch?).await?; + } + None => break, + } + } } + // Sticky-disable so concurrent `try_grow_async` + // callers stop targeting this consumer once we + // enter the merge phase. Set before dropping + // the receiver to close any window where the + // pool would observe `AVAILABLE` after the + // channel is gone (and hence get `Ok(0)` from a + // wasted `reclaim`). + state.store( + reclaimer_state::DISABLED, + std::sync::atomic::Ordering::Release, + ); + drop(reclaim_rx); sorter.sort().await }) .try_flatten(), @@ -2766,6 +2866,7 @@ mod tests { SpillCompression::Uncompressed, &metrics_set, Arc::clone(&runtime), + None, )?; // Insert enough data to force spilling.