Merged
32 changes: 23 additions & 9 deletions README.md
@@ -144,17 +144,31 @@ See [Adaptive filtering](#adaptive-filtering) for details on how the execution p

### UDTF path

For runtime query vectors, complex joins, or explicit over-fetch control:
`vector_search_vector` provides an explicit table function for ANN search, returning all table columns plus `_distance`:

```sql
SELECT vs.key, vs._distance, d.title
FROM vector_usearch('my_table', ARRAY[0.1, 0.2, ...], 20) vs
JOIN my_table d ON d.id = vs.key
ORDER BY vs._distance ASC
LIMIT 10
SELECT id, title, _distance
FROM vector_search_vector('conn.schema.table', 'column', ARRAY[0.1, 0.2, ...], 10)
ORDER BY _distance ASC
```

The UDTF always calls `index.search()` directly — no filter absorption. Apply `WHERE` on the outer query to post-filter.
| Argument | Type | Description |
|---|---|---|
| table | string literal | Dot-separated table reference: `'conn.schema.table'` |
| column | string literal | Vector column with a registered index |
| query | `ARRAY[...]` literal | Query vector |
| k | integer | Number of nearest neighbors to return |

The UDTF calls `resolve()` (sync, cache-only) on the registry — the index must already be loaded before the query is planned. It always calls `index.search()` directly — no filter absorption. Apply `WHERE` on the outer query to post-filter.
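
The cache-only behaviour described above can be pictured with a small standalone sketch. Everything here (`CacheOnlyRegistry`, `LoadedIndex`, `preload`, `plan_search`, and the `table.column` cache key) is hypothetical and uses only the standard library; it is not the crate's actual `VectorIndexResolver` API. It only illustrates why a synchronous lookup fails when the index has not been loaded beforehand.

```rust
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

/// Hypothetical stand-in for a loaded ANN index.
#[derive(Debug)]
struct LoadedIndex {
    dims: usize,
}

/// Hypothetical registry whose lookup is synchronous and cache-only.
#[derive(Default)]
struct CacheOnlyRegistry {
    cache: RwLock<HashMap<String, Arc<LoadedIndex>>>,
}

impl CacheOnlyRegistry {
    /// Put an index into the cache ahead of planning. In a real system this
    /// is where an asynchronous load would deposit its result.
    fn preload(&self, key: &str, index: LoadedIndex) {
        self.cache
            .write()
            .unwrap()
            .insert(key.to_string(), Arc::new(index));
    }

    /// Synchronous lookup: returns whatever is cached and never starts a load.
    fn resolve(&self, key: &str) -> Option<Arc<LoadedIndex>> {
        self.cache.read().unwrap().get(key).cloned()
    }
}

/// Planning-time helper: a cache miss becomes an error instead of a load.
/// The "table.column" key format is an assumption made for this sketch.
fn plan_search(
    registry: &CacheOnlyRegistry,
    table: &str,
    column: &str,
) -> Result<Arc<LoadedIndex>, String> {
    let key = format!("{}.{}", table, column);
    registry
        .resolve(&key)
        .ok_or_else(|| format!("index for {} is not loaded", key))
}
```

In this sketch, calling `plan_search` before `preload` returns an error, and calling it afterwards returns the cached index. This mirrors the requirement that the index be loaded before the query is planned.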

```sql
-- With filtering, aggregation, etc.
SELECT category, COUNT(*) AS cnt, AVG(_distance) AS avg_dist
FROM vector_search_vector('conn.schema.table', 'embedding', ARRAY[...], 50)
WHERE category = 'nlp'
GROUP BY category
ORDER BY avg_dist
```
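
Because the `WHERE` clause runs after the k nearest neighbors have been selected, a selective filter can leave fewer than k rows. The sketch below illustrates that ordering with a brute-force stand-in for `index.search()`; the `Hit` struct and the `top_k` and `search_then_filter` functions are hypothetical names introduced for illustration, not part of the crate.

```rust
/// Hypothetical row shape: key, a filterable attribute, and a distance.
#[derive(Debug, Clone, PartialEq)]
struct Hit {
    id: u64,
    category: &'static str,
    distance: f32,
}

/// Brute-force stand-in for an ANN search: the k rows with the smallest
/// distance, in ascending order.
fn top_k(rows: &[Hit], k: usize) -> Vec<Hit> {
    let mut sorted = rows.to_vec();
    sorted.sort_by(|a, b| a.distance.total_cmp(&b.distance));
    sorted.truncate(k);
    sorted
}

/// Post-filter: the predicate is applied only to the k candidates that the
/// search already returned, so the result may contain fewer than k rows.
fn search_then_filter(rows: &[Hit], k: usize, category: &str) -> Vec<Hit> {
    top_k(rows, k)
        .into_iter()
        .filter(|h| h.category == category)
        .collect()
}
```

With four rows at distances 0.1 (`nlp`), 0.2 (`cv`), 0.3 (`cv`), and 0.4 (`nlp`), `search_then_filter(&rows, 3, "nlp")` keeps only the first row, while raising k to 4 also recovers the second `nlp` row. For the SQL above, this suggests choosing a k larger than the number of rows ultimately needed when the outer filter is selective.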

### Tuning

@@ -205,7 +219,7 @@ src/
rule.rs — USearchRule: optimizer rewrite rule
planner.rs — USearchExecPlanner, USearchExec: physical execution
udf.rs — l2_distance, cosine_distance, negative_dot_product scalar UDFs
udtf.rs — vector_usearch table function
udtf.rs — vector_search_vector table function
lookup.rs — PointLookupProvider trait + HashKeyProvider
keys.rs — DatasetLayout, pack_key/unpack_key key encoding

@@ -292,6 +306,6 @@ Tests cover optimizer rule matching/rejection, end-to-end execution through both
| Limitation | Notes |
|---|---|
| Stacked `Filter` nodes | Only one `Filter -> TableScan` layer is absorbed. `Filter -> Filter -> TableScan` falls back to exact execution. DataFusion typically combines multiple WHERE conditions into a single Filter, so this rarely occurs. |
| Runtime query vectors | The query vector must be a compile-time literal (`ARRAY[0.1, ...]`). Column references or subquery results are not rewritten. Use the UDTF path for runtime vectors. |
| Runtime query vectors | The query vector must be a compile-time literal (`ARRAY[0.1, ...]`). Column references or subquery results are not rewritten. Use `vector_search_vector` for explicit ANN queries. |
| `ef_search` per-query | `expansion_search` is global to the index instance. Per-query adjustment is not supported. |
| No DELETE / compaction | USearch soft-deletes entries but requires a full rebuild to reclaim space. |
11 changes: 5 additions & 6 deletions src/lib.rs
@@ -80,7 +80,7 @@ pub use registry::{
};
pub use rule::USearchRule;
pub use udf::{cosine_distance_udf, l2_distance_udf, negative_dot_product_udf};
pub use udtf::USearchUDTF;
pub use udtf::VectorSearchVectorUDTF;

#[cfg(feature = "parquet-provider")]
pub use parquet_provider::ParquetLookupProvider;
@@ -99,7 +99,8 @@ use datafusion::prelude::SessionContext;
/// - `l2_distance(col, query)` — squared Euclidean distance (L2sq)
/// - `cosine_distance(col, query)` — cosine distance
/// - `negative_dot_product(col, query)` — negated inner product
/// - `vector_usearch(table, query, k)` — explicit ANN table function
/// - `vector_search_vector('conn.schema.table', 'column', ARRAY[...], k)`
/// — explicit ANN table function returning full rows + `_distance`
/// (cache-only for async-backed resolvers; does not trigger async loads)
/// - [`USearchRule`] — optimizer rewrite rule
///
@@ -110,10 +111,8 @@ pub fn register_all(ctx: &SessionContext, registry: Arc<dyn VectorIndexResolver>
ctx.register_udf(ScalarUDF::new_from_impl(cosine_distance_udf()));
ctx.register_udf(ScalarUDF::new_from_impl(negative_dot_product_udf()));
ctx.register_udtf(
"vector_usearch",
// `vector_usearch()` is synchronous and therefore cache-only for
// async-backed resolvers.
Arc::new(USearchUDTF::new(registry.clone())),
"vector_search_vector",
Arc::new(VectorSearchVectorUDTF::new(registry.clone())),
);
ctx.add_optimizer_rule(Arc::new(USearchRule::new(registry)));
Ok(())
6 changes: 3 additions & 3 deletions src/planner.rs
@@ -616,7 +616,7 @@ async fn adaptive_filtered_execute(

/// Call `index.search` with the native scalar type appropriate for the column.
/// Converts the usearch error into a `DataFusionError::Execution`.
fn usearch_search(
pub(crate) fn usearch_search(
index: &usearch::Index,
query_f64: &[f64],
k: usize,
@@ -797,7 +797,7 @@ fn compute_raw_distance_f64(v: &[f64], q: &[f64], dist_type: &DistanceType) -> f
/// Extract the distance from a single row of a vector column.
///
/// Index of the key column in the lookup provider schema.
fn provider_key_col_idx(registered: &crate::registry::RegisteredTable) -> Result<usize> {
pub(crate) fn provider_key_col_idx(registered: &crate::registry::RegisteredTable) -> Result<usize> {
registered
.lookup_provider
.schema()
@@ -813,7 +813,7 @@ fn provider_key_col_idx(registered: &crate::registry::RegisteredTable) -> Result
// ── Distance attachment ───────────────────────────────────────────────────────

/// Append a `_distance: Float32` column to each batch.
fn attach_distances(
pub(crate) fn attach_distances(
batches: Vec<RecordBatch>,
key_col_idx: usize,
key_to_dist: &HashMap<u64, f32>,
89 changes: 71 additions & 18 deletions src/rule.rs
@@ -18,9 +18,10 @@
//
// Replacement:
//
// Sort(fetch=k) ← kept (sort order)
// Projection([col(a), col(b), col("_distance").alias("dist")])
// USearchNode ← executes ANN
// Projection([final output cols])
// Sort(fetch=k)
// Projection([final output cols + optional hidden _distance])
// USearchNode

use std::collections::HashMap;
use std::sync::Arc;
@@ -151,25 +152,33 @@ impl USearchRule {
node: Arc::new(node) as Arc<dyn UserDefinedLogicalNode>,
}));

// Build Projection over USearchNode matching the original output schema.
// Build the final user-visible projection over USearchNode output.
let dist_alias_str = dist_alias.as_deref().unwrap_or("_distance");
let new_proj_exprs = if proj_exprs_slice.is_empty() {
let final_proj_exprs = if proj_exprs_slice.is_empty() {
passthrough_projection(&vsn_df_schema, &table_ref)
} else {
remap_projections(proj_exprs_slice, dist_alias_str, &table_ref)
};
let new_proj = Projection::try_new(new_proj_exprs, node_plan).ok()?;

// Keep the Sort node so DataFusion handles ordering by _distance / dist.
// USearch returns results in arbitrary (internal) order when the underlying
// data is fetched from the TableProvider.
Some(LogicalPlan::Sort(
datafusion::logical_expr::logical_plan::Sort {
expr: sort.expr.clone(),
input: Arc::new(LogicalPlan::Projection(new_proj)),
fetch: sort.fetch,
},
))
let remapped_sort_exprs = remap_sort_exprs(&sort.expr, dist_alias.as_deref());
let needs_hidden_distance = remapped_sort_exprs.iter().any(
|e| matches!(&e.expr, Expr::Column(c) if c.relation.is_none() && c.name == "_distance"),
) && !projection_exposes_name(&final_proj_exprs, "_distance");

let mut sort_input_exprs = final_proj_exprs.clone();
if needs_hidden_distance {
sort_input_exprs.push(col("_distance"));
}

let sort_input = Projection::try_new(sort_input_exprs, node_plan).ok()?;
let sorted = LogicalPlan::Sort(datafusion::logical_expr::logical_plan::Sort {
expr: remapped_sort_exprs,
input: Arc::new(LogicalPlan::Projection(sort_input)),
fetch: sort.fetch,
});

let outer_proj_exprs = build_outer_projection(&final_proj_exprs);
let outer_proj = Projection::try_new(outer_proj_exprs, Arc::new(sorted)).ok()?;
Some(LogicalPlan::Projection(outer_proj))
}
}

@@ -283,7 +292,11 @@ fn dist_type_matches_metric(dist_type: &DistanceType, metric: MetricKind) -> boo
}

fn is_distance_expr(expr: &Expr) -> bool {
matches!(expr, Expr::ScalarFunction(sf) if is_dist_udf_name(sf.func.name()))
let inner = match expr {
Expr::Alias(a) => a.expr.as_ref(),
other => other,
};
matches!(inner, Expr::ScalarFunction(sf) if is_dist_udf_name(sf.func.name()))
}

fn try_extract_distance(expr: &Expr) -> Option<(String, String, Vec<f64>)> {
@@ -322,6 +335,46 @@ fn remap_projections(
.collect()
}

fn remap_sort_exprs(
sort_exprs: &[datafusion::logical_expr::SortExpr],
dist_alias_name: Option<&str>,
) -> Vec<datafusion::logical_expr::SortExpr> {
sort_exprs
.iter()
.map(|sort_expr| {
let remapped_expr = match &sort_expr.expr {
Expr::Column(c) if Some(c.name.as_str()) == dist_alias_name => col(c.name.as_str()),
expr if is_distance_expr(expr) => col("_distance"),
other => other.clone(),
};
datafusion::logical_expr::SortExpr {
expr: remapped_expr,
asc: sort_expr.asc,
nulls_first: sort_expr.nulls_first,
}
})
.collect()
}

fn projection_exposes_name(exprs: &[Expr], name: &str) -> bool {
exprs.iter().any(|expr| match expr {
Expr::Alias(a) => a.name == name,
Expr::Column(c) => c.name == name,
_ => false,
})
}

fn build_outer_projection(exprs: &[Expr]) -> Vec<Expr> {
exprs
.iter()
.map(|expr| match expr {
Expr::Alias(a) => col(a.name.as_str()),
Expr::Column(c) => Expr::Column(c.clone()),
other => col(other.schema_name().to_string()),
})
.collect()
}

/// Build a passthrough Projection for SELECT * queries (no original Projection node).
/// Projects only the original table columns (not `_distance`) so the output schema
/// matches the original Sort schema. The Sort re-evaluates the distance UDF expression
3 changes: 1 addition & 2 deletions src/udf.rs
@@ -3,8 +3,7 @@
// Each takes (vector_col: FixedSizeList<Float32>, query: Array/Scalar) and
// returns a Float32 distance per row.
//
// These are identical to the vector_search UDFs but kept in this module so
// vector_usearch is fully self-contained (no dependency on vector_search).
// These are kept in this module alongside the UDTF and optimizer rule.

use std::any::Any;
use std::hash::{Hash, Hasher};