From b40666d9a96bee13c1834c9efe0450a19c2553e8 Mon Sep 17 00:00:00 2001 From: Anoop Narang Date: Wed, 8 Apr 2026 17:57:50 +0530 Subject: [PATCH 1/7] fix(rule): preserve hidden distance for sort Keep _distance in an inner projection when ORDER BY uses a vector\ndistance expression that is not part of the final select list.\n\nThis fixes split-provider execution for queries like SELECT id ORDER\nBY l2_distance(vector, ARRAY[...]) LIMIT k while preserving the final\noutput schema. Add an execution test for the direct ORDER BY shape to\ncover the production case. --- src/rule.rs | 88 ++++++++++++++++++++++++++++++++++++---------- tests/execution.rs | 11 ++++++ 2 files changed, 81 insertions(+), 18 deletions(-) diff --git a/src/rule.rs b/src/rule.rs index e0b8a7c..e8eca08 100644 --- a/src/rule.rs +++ b/src/rule.rs @@ -18,9 +18,10 @@ // // Replacement: // -// Sort(fetch=k) ← kept (sort order) -// Projection([col(a), col(b), col("_distance").alias("dist")]) -// USearchNode ← executes ANN +// Projection([final output cols]) +// Sort(fetch=k) +// Projection([final output cols + optional hidden _distance]) +// USearchNode use std::collections::HashMap; use std::sync::Arc; @@ -151,25 +152,33 @@ impl USearchRule { node: Arc::new(node) as Arc, })); - // Build Projection over USearchNode matching the original output schema. + // Build the final user-visible projection over USearchNode output. let dist_alias_str = dist_alias.as_deref().unwrap_or("_distance"); - let new_proj_exprs = if proj_exprs_slice.is_empty() { + let final_proj_exprs = if proj_exprs_slice.is_empty() { passthrough_projection(&vsn_df_schema, &table_ref) } else { remap_projections(proj_exprs_slice, dist_alias_str, &table_ref) }; - let new_proj = Projection::try_new(new_proj_exprs, node_plan).ok()?; - - // Keep the Sort node so DataFusion handles ordering by _distance / dist. - // USearch returns results in arbitrary (internal) order when the underlying - // data is fetched from the TableProvider. - Some(LogicalPlan::Sort( - datafusion::logical_expr::logical_plan::Sort { - expr: sort.expr.clone(), - input: Arc::new(LogicalPlan::Projection(new_proj)), - fetch: sort.fetch, - }, - )) + let remapped_sort_exprs = remap_sort_exprs(&sort.expr, dist_alias.as_deref()); + let needs_hidden_distance = remapped_sort_exprs.iter().any(|e| { + matches!(&e.expr, Expr::Column(c) if c.relation.is_none() && c.name == "_distance") + }) && !projection_exposes_name(&final_proj_exprs, "_distance"); + + let mut sort_input_exprs = final_proj_exprs.clone(); + if needs_hidden_distance { + sort_input_exprs.push(col("_distance")); + } + + let sort_input = Projection::try_new(sort_input_exprs, node_plan).ok()?; + let sorted = LogicalPlan::Sort(datafusion::logical_expr::logical_plan::Sort { + expr: remapped_sort_exprs, + input: Arc::new(LogicalPlan::Projection(sort_input)), + fetch: sort.fetch, + }); + + let outer_proj_exprs = build_outer_projection(&final_proj_exprs); + let outer_proj = Projection::try_new(outer_proj_exprs, Arc::new(sorted)).ok()?; + Some(LogicalPlan::Projection(outer_proj)) } } @@ -283,7 +292,11 @@ fn dist_type_matches_metric(dist_type: &DistanceType, metric: MetricKind) -> boo } fn is_distance_expr(expr: &Expr) -> bool { - matches!(expr, Expr::ScalarFunction(sf) if is_dist_udf_name(sf.func.name())) + let inner = match expr { + Expr::Alias(a) => a.expr.as_ref(), + other => other, + }; + matches!(inner, Expr::ScalarFunction(sf) if is_dist_udf_name(sf.func.name())) } fn try_extract_distance(expr: &Expr) -> Option<(String, String, Vec)> { @@ -322,6 +335,45 @@ fn remap_projections( .collect() } +fn remap_sort_exprs( + sort_exprs: &[datafusion::logical_expr::SortExpr], + dist_alias_name: Option<&str>, +) -> Vec { + sort_exprs + .iter() + .map(|sort_expr| { + let remapped_expr = match &sort_expr.expr { + Expr::Column(c) if Some(c.name.as_str()) == dist_alias_name => col(c.name.as_str()), + expr if is_distance_expr(expr) => col("_distance"), + other => other.clone(), + }; + datafusion::logical_expr::SortExpr { + expr: remapped_expr, + asc: sort_expr.asc, + nulls_first: sort_expr.nulls_first, + } + }) + .collect() +} + +fn projection_exposes_name(exprs: &[Expr], name: &str) -> bool { + exprs.iter().any(|expr| match expr { + Expr::Alias(a) => a.name == name, + Expr::Column(c) => c.name == name, + _ => false, + }) +} + +fn build_outer_projection(exprs: &[Expr]) -> Vec { + exprs.iter() + .filter_map(|expr| match expr { + Expr::Alias(a) => Some(col(a.name.as_str())), + Expr::Column(c) => Some(Expr::Column(c.clone())), + _ => None, + }) + .collect() +} + /// Build a passthrough Projection for SELECT * queries (no original Projection node). /// Projects only the original table columns (not `_distance`) so the output schema /// matches the original Sort schema. The Sort re-evaluates the distance UDF expression diff --git a/tests/execution.rs b/tests/execution.rs index 5fa7c33..38724b9 100644 --- a/tests/execution.rs +++ b/tests/execution.rs @@ -530,6 +530,17 @@ async fn exec_split_provider_select_specific_columns() { assert_eq!(ids.len(), 2, "expected 2 results; got {ids:?}"); } +/// SELECT specific columns without projecting the distance expression. +/// This is the production shape behind `vector_distance(...)`. +#[tokio::test] +async fn exec_split_provider_order_by_udf_direct() { + let ctx = make_split_provider_ctx("items::vector").await; + let sql = format!("SELECT id FROM items ORDER BY l2_distance(vector, {Q}) ASC LIMIT 2"); + let ids = collect_ids(&ctx, &sql).await; + assert_eq!(ids[0], 1, "closest must be row 1\nids: {ids:?}"); + assert_eq!(ids.len(), 2, "expected 2 results; got {ids:?}"); +} + /// SELECT * with distance UDF — should fall back to UDF brute-force /// (since vector column is not in lookup provider schema). #[tokio::test] From da9fb6db5e0f46bd5c59491043bf62c71e643e12 Mon Sep 17 00:00:00 2001 From: Anoop Narang Date: Wed, 8 Apr 2026 18:01:43 +0530 Subject: [PATCH 2/7] style(rule): format hidden distance rewrite --- src/rule.rs | 9 +++++---- tests/execution.rs | 3 ++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/rule.rs b/src/rule.rs index e8eca08..eaa54dd 100644 --- a/src/rule.rs +++ b/src/rule.rs @@ -160,9 +160,9 @@ impl USearchRule { remap_projections(proj_exprs_slice, dist_alias_str, &table_ref) }; let remapped_sort_exprs = remap_sort_exprs(&sort.expr, dist_alias.as_deref()); - let needs_hidden_distance = remapped_sort_exprs.iter().any(|e| { - matches!(&e.expr, Expr::Column(c) if c.relation.is_none() && c.name == "_distance") - }) && !projection_exposes_name(&final_proj_exprs, "_distance"); + let needs_hidden_distance = remapped_sort_exprs.iter().any( + |e| matches!(&e.expr, Expr::Column(c) if c.relation.is_none() && c.name == "_distance"), + ) && !projection_exposes_name(&final_proj_exprs, "_distance"); let mut sort_input_exprs = final_proj_exprs.clone(); if needs_hidden_distance { @@ -365,7 +365,8 @@ fn projection_exposes_name(exprs: &[Expr], name: &str) -> bool { } fn build_outer_projection(exprs: &[Expr]) -> Vec { - exprs.iter() + exprs + .iter() .filter_map(|expr| match expr { Expr::Alias(a) => Some(col(a.name.as_str())), Expr::Column(c) => Some(Expr::Column(c.clone())), diff --git a/tests/execution.rs b/tests/execution.rs index 38724b9..92bd933 100644 --- a/tests/execution.rs +++ b/tests/execution.rs @@ -531,7 +531,8 @@ async fn exec_split_provider_select_specific_columns() { } /// SELECT specific columns without projecting the distance expression. -/// This is the production shape behind `vector_distance(...)`. +/// This matches the split-provider direct ORDER BY shape used by callers that +/// rewrite higher-level search helpers into the low-level distance UDF. #[tokio::test] async fn exec_split_provider_order_by_udf_direct() { let ctx = make_split_provider_ctx("items::vector").await; From 8e53743e73c54fc4f7606463c27e1de18d00e4c0 Mon Sep 17 00:00:00 2001 From: Anoop Narang Date: Wed, 8 Apr 2026 18:16:18 +0530 Subject: [PATCH 3/7] test(rule): cover computed sort projections --- src/rule.rs | 8 ++--- tests/execution.rs | 82 +++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 85 insertions(+), 5 deletions(-) diff --git a/src/rule.rs b/src/rule.rs index eaa54dd..7896516 100644 --- a/src/rule.rs +++ b/src/rule.rs @@ -367,10 +367,10 @@ fn projection_exposes_name(exprs: &[Expr], name: &str) -> bool { fn build_outer_projection(exprs: &[Expr]) -> Vec { exprs .iter() - .filter_map(|expr| match expr { - Expr::Alias(a) => Some(col(a.name.as_str())), - Expr::Column(c) => Some(Expr::Column(c.clone())), - _ => None, + .map(|expr| match expr { + Expr::Alias(a) => col(a.name.as_str()), + Expr::Column(c) => Expr::Column(c.clone()), + other => col(other.schema_name().to_string()), }) .collect() } diff --git a/tests/execution.rs b/tests/execution.rs index 92bd933..8fc47fb 100644 --- a/tests/execution.rs +++ b/tests/execution.rs @@ -19,7 +19,9 @@ use std::sync::Arc; use arrow_array::builder::{FixedSizeListBuilder, Float32Builder}; -use arrow_array::{FixedSizeListArray, Float32Array, RecordBatch, StringArray, UInt64Array}; +use arrow_array::{ + FixedSizeListArray, Float32Array, Int64Array, RecordBatch, StringArray, UInt64Array, +}; use arrow_schema::{DataType, Field, Schema}; use datafusion::execution::session_state::SessionStateBuilder; use datafusion::prelude::SessionContext; @@ -152,6 +154,60 @@ async fn collect_ids(ctx: &SessionContext, sql: &str) -> Vec { ids } +/// Collect a named integer column from a query result. +async fn collect_i64_column(ctx: &SessionContext, sql: &str, column_name: &str) -> Vec { + let df = ctx + .sql(sql) + .await + .unwrap_or_else(|e| panic!("sql() failed: {e}\nSQL: {sql}")); + let batches = df + .collect() + .await + .unwrap_or_else(|e| panic!("collect() failed: {e}\nSQL: {sql}")); + + let mut values: Vec = vec![]; + for batch in &batches { + let col_idx = batch + .schema() + .index_of(column_name) + .unwrap_or_else(|e| panic!("no '{column_name}' column in result: {e}\nSQL: {sql}")); + let column = batch.column(col_idx); + if let Some(arr) = column.as_any().downcast_ref::() { + values.extend(arr.values().iter().map(|v| *v as i64)); + } else if let Some(arr) = column.as_any().downcast_ref::() { + values.extend(arr.values()); + } else { + panic!("column '{column_name}' not Int64/UInt64\nSQL: {sql}"); + } + } + values +} + +/// Collect the first integer column from a query result. +async fn collect_first_i64_column(ctx: &SessionContext, sql: &str) -> Vec { + let df = ctx + .sql(sql) + .await + .unwrap_or_else(|e| panic!("sql() failed: {e}\nSQL: {sql}")); + let batches = df + .collect() + .await + .unwrap_or_else(|e| panic!("collect() failed: {e}\nSQL: {sql}")); + + let mut values: Vec = vec![]; + for batch in &batches { + let column = batch.column(0); + if let Some(arr) = column.as_any().downcast_ref::() { + values.extend(arr.values().iter().map(|v| *v as i64)); + } else if let Some(arr) = column.as_any().downcast_ref::() { + values.extend(arr.values()); + } else { + panic!("first result column not Int64/UInt64\nSQL: {sql}"); + } + } + values +} + const Q: &str = "ARRAY[1.0::float, 0.0::float, 0.0::float, 0.0::float]"; // ═══════════════════════════════════════════════════════════════════════════════ @@ -542,6 +598,30 @@ async fn exec_split_provider_order_by_udf_direct() { assert_eq!(ids.len(), 2, "expected 2 results; got {ids:?}"); } +/// Direct ORDER BY UDF with an aliased computed projection must preserve the +/// computed output through the rewrite. +#[tokio::test] +async fn exec_split_provider_order_by_udf_with_computed_alias() { + let ctx = make_split_provider_ctx("items::vector").await; + let sql = format!( + "SELECT CAST(id + 1 AS BIGINT) AS id_plus FROM items ORDER BY l2_distance(vector, {Q}) ASC LIMIT 2" + ); + let values = collect_i64_column(&ctx, &sql, "id_plus").await; + assert_eq!(values, vec![2, 3], "unexpected computed values: {values:?}"); +} + +/// Direct ORDER BY UDF with an unaliased computed projection relies on the +/// outer projection rebuilding by schema name rather than by raw expression. +#[tokio::test] +async fn exec_split_provider_order_by_udf_with_computed_expr() { + let ctx = make_split_provider_ctx("items::vector").await; + let sql = format!( + "SELECT CAST(id + 1 AS BIGINT) FROM items ORDER BY l2_distance(vector, {Q}) ASC LIMIT 2" + ); + let values = collect_first_i64_column(&ctx, &sql).await; + assert_eq!(values, vec![2, 3], "unexpected computed values: {values:?}"); +} + /// SELECT * with distance UDF — should fall back to UDF brute-force /// (since vector column is not in lookup provider schema). #[tokio::test] From b75d92c27aac2424935aee196d8fdbec00c37901 Mon Sep 17 00:00:00 2001 From: Anoop Narang Date: Thu, 16 Apr 2026 16:06:50 +0530 Subject: [PATCH 4/7] refactor: make usearch_search, attach_distances, provider_key_col_idx pub(crate) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit These helpers are needed by the new vector_search_vector UDTF to reuse the same HNSW search → fetch → attach pattern as the ORDER BY path. --- src/planner.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/planner.rs b/src/planner.rs index b863e53..05b4dfa 100644 --- a/src/planner.rs +++ b/src/planner.rs @@ -616,7 +616,7 @@ async fn adaptive_filtered_execute( /// Call `index.search` with the native scalar type appropriate for the column. /// Converts the usearch error into a `DataFusionError::Execution`. -fn usearch_search( +pub(crate) fn usearch_search( index: &usearch::Index, query_f64: &[f64], k: usize, @@ -797,7 +797,7 @@ fn compute_raw_distance_f64(v: &[f64], q: &[f64], dist_type: &DistanceType) -> f /// Extract the distance from a single row of a vector column. /// /// Index of the key column in the lookup provider schema. -fn provider_key_col_idx(registered: &crate::registry::RegisteredTable) -> Result { +pub(crate) fn provider_key_col_idx(registered: &crate::registry::RegisteredTable) -> Result { registered .lookup_provider .schema() @@ -813,7 +813,7 @@ fn provider_key_col_idx(registered: &crate::registry::RegisteredTable) -> Result // ── Distance attachment ─────────────────────────────────────────────────────── /// Append a `_distance: Float32` column to each batch. -fn attach_distances( +pub(crate) fn attach_distances( batches: Vec, key_col_idx: usize, key_to_dist: &HashMap, From e7344eb2752119cb18e6ce89240188c954021ae7 Mon Sep 17 00:00:00 2001 From: Anoop Narang Date: Thu, 16 Apr 2026 16:14:38 +0530 Subject: [PATCH 5/7] feat: rename vector_usearch to vector_search_vector, return full rows MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the old vector_usearch UDTF that returned only (key, _distance) with vector_search_vector that returns all table columns plus _distance. New signature: vector_search_vector('conn.schema.table', 'column', ARRAY[...], k) The UDTF reuses usearch_search, attach_distances, and provider_key_col_idx from the planner module to follow the same HNSW search → fetch_by_keys → attach_distances pattern as the ORDER BY execution path. --- src/lib.rs | 11 ++- src/udf.rs | 3 +- src/udtf.rs | 212 ++++++++++++++++++++++++++++++---------------------- 3 files changed, 130 insertions(+), 96 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index d287a77..702e10a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -80,7 +80,7 @@ pub use registry::{ }; pub use rule::USearchRule; pub use udf::{cosine_distance_udf, l2_distance_udf, negative_dot_product_udf}; -pub use udtf::USearchUDTF; +pub use udtf::VectorSearchVectorUDTF; #[cfg(feature = "parquet-provider")] pub use parquet_provider::ParquetLookupProvider; @@ -99,7 +99,8 @@ use datafusion::prelude::SessionContext; /// - `l2_distance(col, query)` — squared Euclidean distance (L2sq) /// - `cosine_distance(col, query)` — cosine distance /// - `negative_dot_product(col, query)` — negated inner product -/// - `vector_usearch(table, query, k)` — explicit ANN table function +/// - `vector_search_vector('conn.schema.table', 'column', ARRAY[...], k)` +/// — explicit ANN table function returning full rows + `_distance` /// (cache-only for async-backed resolvers; does not trigger async loads) /// - [`USearchRule`] — optimizer rewrite rule /// @@ -110,10 +111,8 @@ pub fn register_all(ctx: &SessionContext, registry: Arc ctx.register_udf(ScalarUDF::new_from_impl(cosine_distance_udf())); ctx.register_udf(ScalarUDF::new_from_impl(negative_dot_product_udf())); ctx.register_udtf( - "vector_usearch", - // `vector_usearch()` is synchronous and therefore cache-only for - // async-backed resolvers. - Arc::new(USearchUDTF::new(registry.clone())), + "vector_search_vector", + Arc::new(VectorSearchVectorUDTF::new(registry.clone())), ); ctx.add_optimizer_rule(Arc::new(USearchRule::new(registry))); Ok(()) diff --git a/src/udf.rs b/src/udf.rs index 32652fe..b714498 100644 --- a/src/udf.rs +++ b/src/udf.rs @@ -3,8 +3,7 @@ // Each takes (vector_col: FixedSizeList, query: Array/Scalar) and // returns a Float32 distance per row. // -// These are identical to the vector_search UDFs but kept in this module so -// vector_usearch is fully self-contained (no dependency on vector_search). +// These are kept in this module alongside the UDTF and optimizer rule. use std::any::Any; use std::hash::{Hash, Hasher}; diff --git a/src/udtf.rs b/src/udtf.rs index fd4da2b..c97c523 100644 --- a/src/udtf.rs +++ b/src/udtf.rs @@ -1,106 +1,90 @@ -// udtf.rs — USearchUDTF: explicit SQL table function interface. +// udtf.rs — vector_search_vector: explicit SQL table function for ANN search. // // Usage: -// SELECT key, _distance FROM vector_usearch('table', ARRAY[...], k) -// SELECT key, _distance FROM vector_usearch('table', ARRAY[...], k, ef_search) +// SELECT * FROM vector_search_vector('conn.schema.table', 'column', ARRAY[...], k) // -// Returns two columns only: `key: UInt64` and `_distance: Float32`. -// To get full row data, JOIN the result against the data table: -// -// SELECT d.id, d.name, vs._distance -// FROM vector_usearch('items', ARRAY[...], 10) vs -// JOIN items d ON d.id = vs.key -// ORDER BY vs._distance +// Returns all table columns plus `_distance: Float32`. +// Requires a vector index on the specified column. use std::any::Any; +use std::collections::HashMap; use std::fmt; use std::sync::Arc; -use arrow_array::{Array, Float32Array, UInt64Array}; -use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use arrow_array::Array; +use arrow_schema::SchemaRef; use async_trait::async_trait; use datafusion::arrow::record_batch::RecordBatch; use datafusion::catalog::{Session, TableFunctionImpl, TableProvider}; use datafusion::common::Result; use datafusion::error::DataFusionError; -use datafusion::execution::{SendableRecordBatchStream, TaskContext}; +use datafusion::execution::TaskContext; use datafusion::logical_expr::{Expr, TableType}; use datafusion::physical_expr::EquivalenceProperties; use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion::physical_plan::memory::MemoryStream; use datafusion::physical_plan::{ DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, + SendableRecordBatchStream, }; use datafusion::scalar::ScalarValue; -use crate::registry::VectorIndexResolver; +use crate::planner::{attach_distances, provider_key_col_idx, usearch_search}; +use crate::registry::{RegisteredTable, VectorIndexResolver}; // ── UDTF ───────────────────────────────────────────────────────────────────── -/// Table function: vector_usearch(table_name, query_vec, k [, ef_search]) +/// Table function: vector_search_vector('conn.schema.table', 'column', ARRAY[...], k) /// -/// Returns `(key: UInt64, _distance: Float32)`. Join with your data table on -/// the key column to retrieve full rows. +/// Returns all table columns plus `_distance: Float32`. /// -/// This entry point is synchronous. For async-backed [`VectorIndexResolver`] -/// implementations, it only works when the target index is already loaded in -/// the local cache. `vector_usearch()` does not call `prepare()` and cannot -/// trigger async index loads. -pub struct USearchUDTF { +/// This entry point is synchronous. It calls `resolve()` (cache-only) on the +/// registry, so the index must already be loaded (e.g. via `refresh_for_tables` +/// before planning). If the index is not cached, the function returns an error. +pub struct VectorSearchVectorUDTF { registry: Arc, } -impl fmt::Debug for USearchProvider { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "USearchProvider({})", self.table_name) - } -} - -impl fmt::Debug for USearchUDTF { +impl fmt::Debug for VectorSearchVectorUDTF { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "USearchUDTF") + write!(f, "VectorSearchVectorUDTF") } } -impl USearchUDTF { +impl VectorSearchVectorUDTF { pub fn new(registry: Arc) -> Self { Self { registry } } } -impl TableFunctionImpl for USearchUDTF { +impl TableFunctionImpl for VectorSearchVectorUDTF { fn call(&self, exprs: &[Expr]) -> Result> { - if exprs.len() < 3 { - return Err(DataFusionError::Execution( - "vector_usearch requires at least 3 args: (table, query_vec, k)".into(), + if exprs.len() != 4 { + return Err(DataFusionError::Plan( + "vector_search_vector requires 4 arguments: \ + vector_search_vector('conn.schema.table', 'column', ARRAY[...], k)" + .into(), )); } - let table_name = extract_string_literal(&exprs[0])?; - let query_vec = extract_f32_vec(&exprs[1])?; - let k = extract_usize_literal(&exprs[2])?; + let table_ref = extract_string_literal(&exprs[0])?; + let column = extract_string_literal(&exprs[1])?; + let query_vec = extract_f32_vec(&exprs[2])?; + let k = extract_usize_literal(&exprs[3])?; - // Optional ef_search — used as a hint for the search expansion width. - // NOTE: changing ef_search on a shared Arc affects all concurrent - // queries. For production use, maintain separate index instances per - // query, or set ef_search at load time. - let _ef_search: Option = if exprs.len() > 3 { - Some(extract_usize_literal(&exprs[3])?) - } else { - None - }; + // Build the registry key: "conn::schema::table::column" + let (conn, schema, table) = parse_dot_table_ref(&table_ref)?; + let reg_key = format!("{conn}::{schema}::{table}::{column}"); - let registered = self.registry.resolve(&table_name).ok_or_else(|| { + let registered = self.registry.resolve(®_key).ok_or_else(|| { DataFusionError::Execution(format!( - "vector_usearch: table '{table_name}' is not loaded locally. \ -This synchronous path only checks the local cache and cannot trigger async \ -loads. Use the optimizer/planner vector query path or pre-load the index first." + "vector_search_vector: no loaded vector index for '{reg_key}'. \ + Ensure the table is synced and has a vector index on column '{column}'." )) })?; - Ok(Arc::new(USearchProvider { - index: registered.index.clone(), - table_name, + Ok(Arc::new(VectorSearchVectorProvider { + registered, query_vec, k, })) @@ -110,33 +94,32 @@ loads. Use the optimizer/planner vector query path or pre-load the index first." // ── TableProvider ───────────────────────────────────────────────────────────── /// TableProvider returned by the UDTF. Executes a USearch ANN query in scan(), -/// returning only `(key: UInt64, _distance: Float32)`. -struct USearchProvider { - index: Arc, - table_name: String, +/// fetches full rows via the lookup provider, and appends `_distance`. +struct VectorSearchVectorProvider { + registered: Arc, query_vec: Vec, k: usize, } -fn udtf_schema() -> SchemaRef { - Arc::new(Schema::new(vec![ - Field::new("key", DataType::UInt64, false), - Field::new("_distance", DataType::Float32, true), - ])) +impl fmt::Debug for VectorSearchVectorProvider { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "VectorSearchVectorProvider(k={})", self.k) + } } #[async_trait] -impl TableProvider for USearchProvider { +impl TableProvider for VectorSearchVectorProvider { fn as_any(&self) -> &dyn Any { self } fn schema(&self) -> SchemaRef { - udtf_schema() + // RegisteredTable.schema already includes all data columns + _distance + self.registered.schema.clone() } fn table_type(&self) -> TableType { - TableType::Base + TableType::Temporary } async fn scan( @@ -145,33 +128,71 @@ impl TableProvider for USearchProvider { projection: Option<&Vec>, _filters: &[Expr], _limit: Option, - ) -> Result> { - let matches = self - .index - .search(&self.query_vec, self.k) - .map_err(|e| DataFusionError::Execution(format!("USearch search error: {e}")))?; - - let schema = udtf_schema(); - - let keys = UInt64Array::from(matches.keys.clone()); - let dists = Float32Array::from(matches.distances.clone()); - - let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(keys), Arc::new(dists)]) - .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?; + ) -> Result> { + // 1. HNSW search + let query_f64: Vec = self.query_vec.iter().map(|&v| v as f64).collect(); + let matches = usearch_search( + &self.registered.index, + &query_f64, + self.k, + self.registered.scalar_kind, + )?; + + if matches.keys.is_empty() { + let schema = match projection { + Some(indices) => Arc::new( + self.registered + .schema + .project(indices) + .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?, + ), + None => self.registered.schema.clone(), + }; + return Ok(Arc::new(BatchExec::new(schema, vec![]))); + } - // Apply column projection so DataFusion's JOIN column indices are correct. + // 2. Build key → distance map + let key_to_dist: HashMap = matches + .keys + .iter() + .zip(matches.distances.iter()) + .map(|(&k, &d)| (k, d)) + .collect(); + + // 3. Fetch full rows from lookup provider + let data_batches = self + .registered + .lookup_provider + .fetch_by_keys(&matches.keys, &self.registered.key_col, None) + .await?; + + // 4. Attach _distance column + let key_col_idx = provider_key_col_idx(&self.registered)?; + let result_batches = attach_distances( + data_batches, + key_col_idx, + &key_to_dist, + &self.registered.schema, + )?; + + // 5. Apply projection if needed let (proj_schema, proj_batches) = if let Some(indices) = projection { let ps = Arc::new( - schema + self.registered + .schema .project(indices) .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?, ); - let pb = batch - .project(indices) - .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?; - (ps, vec![pb]) + let pb: Vec = result_batches + .into_iter() + .map(|b| { + b.project(indices) + .map_err(|e| DataFusionError::ArrowError(Box::new(e), None)) + }) + .collect::>()?; + (ps, pb) } else { - (schema, vec![batch]) + (self.registered.schema.clone(), result_batches) }; Ok(Arc::new(BatchExec::new(proj_schema, proj_batches))) @@ -249,9 +270,24 @@ impl ExecutionPlan for BatchExec { } } -// ── Literal extraction helpers ──────────────────────────────────────────────── +// ── Helpers ────────────────────────────────────────────────────────────────── + +/// Parse a dot-separated table reference: "conn.schema.table" → ("conn", "schema", "table") +fn parse_dot_table_ref(s: &str) -> Result<(String, String, String)> { + let parts: Vec<&str> = s.splitn(3, '.').collect(); + if parts.len() != 3 { + return Err(DataFusionError::Plan(format!( + "Expected 'connection.schema.table', got '{s}'" + ))); + } + Ok(( + parts[0].to_string(), + parts[1].to_string(), + parts[2].to_string(), + )) +} -// DataFusion 51: Expr::Literal is a 2-tuple (ScalarValue, Option). +// DataFusion 51+: Expr::Literal is a 2-tuple (ScalarValue, Option). fn extract_string_literal(expr: &Expr) -> Result { match expr { From d51d8ba6427a71b4ec9c4adef1cde73e3820cd02 Mon Sep 17 00:00:00 2001 From: Anoop Narang Date: Thu, 16 Apr 2026 18:18:30 +0530 Subject: [PATCH 6/7] docs: update README for vector_search_vector UDTF Update UDTF section to reflect the new vector_search_vector signature and full-row return schema. Update module structure reference. --- README.md | 32 +++++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index bc929fb..f3b2da1 100644 --- a/README.md +++ b/README.md @@ -144,17 +144,31 @@ See [Adaptive filtering](#adaptive-filtering) for details on how the execution p ### UDTF path -For runtime query vectors, complex joins, or explicit over-fetch control: +`vector_search_vector` provides an explicit table function for ANN search, returning all table columns plus `_distance`: ```sql -SELECT vs.key, vs._distance, d.title -FROM vector_usearch('my_table', ARRAY[0.1, 0.2, ...], 20) vs -JOIN my_table d ON d.id = vs.key -ORDER BY vs._distance ASC -LIMIT 10 +SELECT id, title, _distance +FROM vector_search_vector('conn.schema.table', 'column', ARRAY[0.1, 0.2, ...], 10) +ORDER BY _distance ASC ``` -The UDTF always calls `index.search()` directly — no filter absorption. Apply `WHERE` on the outer query to post-filter. +| Argument | Type | Description | +|---|---|---| +| table | string literal | Dot-separated table reference: `'conn.schema.table'` | +| column | string literal | Vector column with a registered index | +| query | `ARRAY[...]` literal | Query vector | +| k | integer | Number of nearest neighbors to return | + +The UDTF calls `resolve()` (sync, cache-only) on the registry — the index must already be loaded before the query is planned. It always calls `index.search()` directly — no filter absorption. Apply `WHERE` on the outer query to post-filter. + +```sql +-- With filtering, aggregation, etc. +SELECT category, COUNT(*) AS cnt, AVG(_distance) AS avg_dist +FROM vector_search_vector('conn.schema.table', 'embedding', ARRAY[...], 50) +WHERE category = 'nlp' +GROUP BY category +ORDER BY avg_dist +``` ### Tuning @@ -205,7 +219,7 @@ src/ rule.rs — USearchRule: optimizer rewrite rule planner.rs — USearchExecPlanner, USearchExec: physical execution udf.rs — l2_distance, cosine_distance, negative_dot_product scalar UDFs - udtf.rs — vector_usearch table function + udtf.rs — vector_search_vector table function lookup.rs — PointLookupProvider trait + HashKeyProvider keys.rs — DatasetLayout, pack_key/unpack_key key encoding @@ -292,6 +306,6 @@ Tests cover optimizer rule matching/rejection, end-to-end execution through both | Limitation | Notes | |---|---| | Stacked `Filter` nodes | Only one `Filter -> TableScan` layer is absorbed. `Filter -> Filter -> TableScan` falls back to exact execution. DataFusion typically combines multiple WHERE conditions into a single Filter, so this rarely occurs. | -| Runtime query vectors | The query vector must be a compile-time literal (`ARRAY[0.1, ...]`). Column references or subquery results are not rewritten. Use the UDTF path for runtime vectors. | +| Runtime query vectors | The query vector must be a compile-time literal (`ARRAY[0.1, ...]`). Column references or subquery results are not rewritten. Use `vector_search_vector` for explicit ANN queries. | | `ef_search` per-query | `expansion_search` is global to the index instance. Per-query adjustment is not supported. | | No DELETE / compaction | USearch soft-deletes entries but requires a full rebuild to reclaim space. | From 7822f7de0e4d68edc56247fd74c98d87e9962bbe Mon Sep 17 00:00:00 2001 From: Anoop Narang Date: Thu, 16 Apr 2026 18:33:08 +0530 Subject: [PATCH 7/7] fix: use f64 precision for UDTF query vectors, add UDTF tests Parse query vectors as f64 to match the optimizer path's precision, avoiding silent accuracy loss for F64-quantized indexes. Add 5 tests for vector_search_vector: basic happy path, projection pushdown, bad table ref error, registry miss error, k > dataset size. --- src/udtf.rs | 35 +++++++++--------- tests/execution.rs | 90 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 107 insertions(+), 18 deletions(-) diff --git a/src/udtf.rs b/src/udtf.rs index c97c523..7e2b1a2 100644 --- a/src/udtf.rs +++ b/src/udtf.rs @@ -69,7 +69,7 @@ impl TableFunctionImpl for VectorSearchVectorUDTF { let table_ref = extract_string_literal(&exprs[0])?; let column = extract_string_literal(&exprs[1])?; - let query_vec = extract_f32_vec(&exprs[2])?; + let query_vec = extract_f64_vec(&exprs[2])?; let k = extract_usize_literal(&exprs[3])?; // Build the registry key: "conn::schema::table::column" @@ -97,7 +97,7 @@ impl TableFunctionImpl for VectorSearchVectorUDTF { /// fetches full rows via the lookup provider, and appends `_distance`. struct VectorSearchVectorProvider { registered: Arc, - query_vec: Vec, + query_vec: Vec, k: usize, } @@ -130,10 +130,9 @@ impl TableProvider for VectorSearchVectorProvider { _limit: Option, ) -> Result> { // 1. HNSW search - let query_f64: Vec = self.query_vec.iter().map(|&v| v as f64).collect(); let matches = usearch_search( &self.registered.index, - &query_f64, + &self.query_vec, self.k, self.registered.scalar_kind, )?; @@ -311,7 +310,7 @@ fn extract_usize_literal(expr: &Expr) -> Result { } } -fn extract_f32_vec(expr: &Expr) -> Result> { +fn extract_f64_vec(expr: &Expr) -> Result> { use arrow_array::{Float32Array, Float64Array}; match expr { @@ -320,11 +319,11 @@ fn extract_f32_vec(expr: &Expr) -> Result> { return Err(DataFusionError::Execution("Empty query vector".into())); } let inner = arr.value(0); - if let Some(f32a) = inner.as_any().downcast_ref::() { - return Ok(f32a.values().to_vec()); - } if let Some(f64a) = inner.as_any().downcast_ref::() { - return Ok(f64a.values().iter().map(|&v| v as f32).collect()); + return Ok(f64a.values().to_vec()); + } + if let Some(f32a) = inner.as_any().downcast_ref::() { + return Ok(f32a.values().iter().map(|&v| v as f64).collect()); } Err(DataFusionError::Execution( "FixedSizeList inner is not Float32/Float64".into(), @@ -335,11 +334,11 @@ fn extract_f32_vec(expr: &Expr) -> Result> { return Err(DataFusionError::Execution("Empty query vector".into())); } let inner = arr.value(0); - if let Some(f32a) = inner.as_any().downcast_ref::() { - return Ok(f32a.values().to_vec()); - } if let Some(f64a) = inner.as_any().downcast_ref::() { - return Ok(f64a.values().iter().map(|&v| v as f32).collect()); + return Ok(f64a.values().to_vec()); + } + if let Some(f32a) = inner.as_any().downcast_ref::() { + return Ok(f32a.values().iter().map(|&v| v as f64).collect()); } Err(DataFusionError::Execution( "List scalar inner is not Float32/Float64".into(), @@ -349,10 +348,10 @@ fn extract_f32_vec(expr: &Expr) -> Result> { let mut result = Vec::with_capacity(sf.args.len()); for arg in &sf.args { match arg { - Expr::Literal(ScalarValue::Float64(Some(v)), _) => result.push(*v as f32), - Expr::Literal(ScalarValue::Float32(Some(v)), _) => result.push(*v), - Expr::Literal(ScalarValue::Int64(Some(v)), _) => result.push(*v as f32), - Expr::Literal(ScalarValue::Int32(Some(v)), _) => result.push(*v as f32), + Expr::Literal(ScalarValue::Float64(Some(v)), _) => result.push(*v), + Expr::Literal(ScalarValue::Float32(Some(v)), _) => result.push(*v as f64), + Expr::Literal(ScalarValue::Int64(Some(v)), _) => result.push(*v as f64), + Expr::Literal(ScalarValue::Int32(Some(v)), _) => result.push(*v as f64), other => { return Err(DataFusionError::Execution(format!( "Non-literal in ARRAY[...]: {other:?}" @@ -363,7 +362,7 @@ fn extract_f32_vec(expr: &Expr) -> Result> { Ok(result) } other => Err(DataFusionError::Execution(format!( - "Cannot extract f32 vector from: {other:?}" + "Cannot extract f64 vector from: {other:?}" ))), } } diff --git a/tests/execution.rs b/tests/execution.rs index 8fc47fb..41b175d 100644 --- a/tests/execution.rs +++ b/tests/execution.rs @@ -937,3 +937,93 @@ async fn udf_dimension_mismatch_select_star() { "expected dimension mismatch error, got: {msg}" ); } + +// ═══════════════════════════════════════════════════════════════════════════════ +// vector_search_vector UDTF tests +// ═══════════════════════════════════════════════════════════════════════════════ + +/// Basic happy path: returns correct rows with _distance column. +#[tokio::test] +async fn udtf_vector_search_vector_basic() { + let ctx = make_exec_ctx("conn::schema::items::vector").await; + let sql = "SELECT id, label, _distance FROM vector_search_vector('conn.schema.items', 'vector', ARRAY[1.0::float, 0.0::float, 0.0::float, 0.0::float], 3) ORDER BY _distance ASC"; + let df = ctx.sql(sql).await.expect("sql"); + let batches = df.collect().await.expect("collect"); + + let total: usize = batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total, 3, "expected 3 results"); + + // First result should be row 1 (exact match, distance 0) + let ids = batches[0] + .column(0) + .as_any() + .downcast_ref::() + .expect("id col"); + assert_eq!(ids.value(0), 1, "closest must be row 1"); + + let dists = batches[0] + .column(2) + .as_any() + .downcast_ref::() + .expect("_distance col"); + assert!( + (dists.value(0) - 0.0).abs() < 1e-6, + "row 1 distance must be 0.0, got {}", + dists.value(0) + ); +} + +/// Projection pushdown: only requested columns are returned. +#[tokio::test] +async fn udtf_vector_search_vector_projection() { + let ctx = make_exec_ctx("conn::schema::items::vector").await; + let sql = "SELECT id, _distance FROM vector_search_vector('conn.schema.items', 'vector', ARRAY[1.0::float, 0.0::float, 0.0::float, 0.0::float], 2)"; + let df = ctx.sql(sql).await.expect("sql"); + let batches = df.collect().await.expect("collect"); + assert_eq!( + batches[0].num_columns(), + 2, + "expected 2 columns (id, _distance), got {}", + batches[0].num_columns() + ); + let schema = batches[0].schema(); + let col_names: Vec<&str> = schema.fields().iter().map(|f| f.name().as_str()).collect(); + assert_eq!(col_names, vec!["id", "_distance"]); +} + +/// parse_dot_table_ref error: fewer than 3 parts. +#[tokio::test] +async fn udtf_vector_search_vector_bad_table_ref() { + let ctx = make_exec_ctx("conn::schema::items::vector").await; + let sql = "SELECT * FROM vector_search_vector('items', 'vector', ARRAY[1.0::float, 0.0::float, 0.0::float, 0.0::float], 3)"; + let err = ctx.sql(sql).await.unwrap_err(); + let msg = err.to_string(); + assert!( + msg.contains("connection.schema.table"), + "expected table ref format error, got: {msg}" + ); +} + +/// Registry miss: column not in registry returns clear error. +#[tokio::test] +async fn udtf_vector_search_vector_registry_miss() { + let ctx = make_exec_ctx("conn::schema::items::vector").await; + let sql = "SELECT * FROM vector_search_vector('conn.schema.items', 'nonexistent', ARRAY[1.0::float, 0.0::float, 0.0::float, 0.0::float], 3)"; + let err = ctx.sql(sql).await.unwrap_err(); + let msg = err.to_string(); + assert!( + msg.contains("no loaded vector index"), + "expected registry miss error, got: {msg}" + ); +} + +/// Empty result: search with k larger than dataset returns all rows. +#[tokio::test] +async fn udtf_vector_search_vector_k_larger_than_dataset() { + let ctx = make_exec_ctx("conn::schema::items::vector").await; + let sql = "SELECT id, _distance FROM vector_search_vector('conn.schema.items', 'vector', ARRAY[1.0::float, 0.0::float, 0.0::float, 0.0::float], 100)"; + let df = ctx.sql(sql).await.expect("sql"); + let batches = df.collect().await.expect("collect"); + let total: usize = batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total, 4, "expected all 4 rows when k > dataset size"); +}