diff --git a/docs/specs/2026-05-21-query-edge-filter-design.md b/docs/specs/2026-05-21-query-edge-filter-design.md new file mode 100644 index 0000000..f6ddcb5 --- /dev/null +++ b/docs/specs/2026-05-21-query-edge-filter-design.md @@ -0,0 +1,198 @@ +--- +date: 2026-05-21 +topic: query-edge-filter +status: implemented +--- + +# Edge-Predicate Filtering for Generated Query Builders + +## Goal + +Let a generated `Query` filter root records by a scalar predicate of a _neighbouring_ node +reached over an edge — "people who have a dog named Fido" — as a first-class, generated method: + +```go +client.Person.Query(ctx).WhereDogs(`eq(name, "Fido")`).Nodes() +``` + +Today `Query.Filter` (and the `typed.Query[T].Filter` it delegates to) only constrains the +root node's _own_ predicates: the filter string lands in dgraph's root `@filter`, which has no +syntax for an edge target's scalar value. There is no way, short of hand-written DQL through +`Client.QueryRaw`, to express "root has an edge whose target matches X." + +## Non-Goals + +- **A typed predicate DSL.** `WhereDogs(filter string, params ...any)` takes a dgraph `@filter` + string, exactly like the existing `Filter`. A type-safe + `WhereDogs(func(c *DogCriteria){ c.NameEq("Fido") })` face is future work; it would layer over the + same `WhereEdge` substrate this spec introduces. +- **Multi-hop filters** (root → edge → edge). The filter string constrains the _immediate_ edge + target's own predicates. +- **Changing `Filter`, `Nodes`, `First`, `IterNodes`, or CRUD.** + +## Why This Approach + +**dgman emits one query block.** A `typed.Query[T]` wraps a single `*dg.Query`, which dgman renders +as one root `@filter` over an `expand` body (`query.go:generateQuery`). dgman exposes no way to +attach a `@filter` to an edge sub-block. So edge filtering cannot be a new dgman builder call — it +needs a genuinely separate execution path. 
+ +**Server-side semi-join in one request.** A query carrying edge constraints runs as a single +multi-block DQL request: + +1. **Var block** — + `mgMatched as var(func: …) @cascade { … one filtered block per edge constraint }`. `@cascade` + drops any node with an empty block, so the `as` variable binds exactly the roots that satisfy + every constraint. The variable lives on the server. +2. **Data block** — the existing `*dg.Query`, intersected with `uid(mgMatched)`, carrying the caller's + `Filter`, ordering, pagination, and dgman's normal projection. +3. **Count block** (only for `NodesAndCount`) — `count(uid)` over `uid(mgMatched)` with the caller's + `Filter` re-applied, so the total matches the rows the data block would return without + pagination. + +The matched UIDs are never returned to the client or inlined into a `uid(0x1, 0x2, …)` literal. +Memory and DQL size stay bounded by the query, not by how many roots match. This is the same shape +dgman's own `NodesAndCount` uses internally (`filtered as var(...)` feeding `func: uid(filtered)`). + +**On the rejected alternative.** An earlier draft of this design ran the semi-join client-side — a +pre-pass returned the matching UIDs, which the main query then inlined as `uid(0x1, 0x2, …)` — and +rejected a `QueryRaw` two-block query on the grounds that it would force re-implementing the result +projection (`expand` drops managed reverse edges, `reverse_test.go`). That reasoning assumed the +data block's body would be _hand-written_. It need not be: rendering the request with +`dg.NewQueryBlock(varBlock, dataBlock).String()` lets dgman generate the data block's projection +exactly as it does for a normal query, so reverse-edge-aware expansion is preserved. The request is +then executed with `Client.QueryRaw`, and the data block is decoded through the typed +predicate-remap path (see `multi_query.go`). 
This is strictly better than the client-side pre-pass: +the UIDs stay on the server (bounded memory and DQL size), and for the single-shot terminals +(`Nodes`, `First`, `NodesAndCount`) the whole semi-join runs in one read-only transaction, closing +the second-read consistency window the client-side form had. + +**`IterNodes`.** Each page is its own request that re-resolves the var block server-side and pages +the data block with `first`/`offset`. Memory stays bounded regardless of result size — the property +the streaming terminal advertises — at the cost of re-running the `@cascade` match per page, and of +reading each page from a fresh snapshot rather than one transaction. For the unbounded-result case +this terminal exists to serve, bounded memory is the property that matters. + +## Design + +### `typed.Query[T]` — the `WhereEdge` substrate + +`Query[T]` carries `conn`/`ctx` (to run the request) and an `edges` slice; `customRootExpr` records +a caller's `UID`/`RootFunc` narrowing so the var block can root at it: + +```go +type Query[T any] struct { + q *dg.Query + conn modusgraph.Client + ctx context.Context + limit int + offset int + edges []edgeFilter + filters []filterFrag + customRootExpr string +} + +type edgeFilter struct { + predicate string + filter string + params []any +} +``` + +New builder, accumulating (each call ANDs another constraint): + +```go +func (qb *Query[T]) WhereEdge(predicate, filter string, params ...any) *Query[T] +``` + +The terminals (`Nodes`, `First`, `IterNodes`, `NodesAndCount`) check for edge constraints. With +none, they run the plain dgman query unchanged. With constraints, they call `runEdge`, which +assembles the var/data/(count) blocks, renders them with `dg.NewQueryBlock`, runs the request via +`QueryRaw`, and decodes. `runEdge` pushes the data-block filter onto `qb.q` last-write-wins and +never mutates the accumulated filters, so `IterNodes` can call it once per page. 
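The accumulation of edge constraints can be illustrated with a small standalone sketch. It is not the code in `typed/query.go`: `edgeConstraint`, `shift`, and `matchBody` are hypothetical stand-ins, and the sketch assumes only what this spec states, namely that each constraint numbers its own placeholders from `$1`, that the constraints are joined into one var-block body with aliases `mg_e0`, `mg_e1`, …, and that the placeholders are renumbered against the concatenated params slice.

```go
package main

import (
	"fmt"
	"regexp"
	"strconv"
	"strings"
)

// edgeConstraint is a hypothetical stand-in for the spec's edgeFilter: one
// edge predicate, a filter string numbered from $1, and that filter's params.
type edgeConstraint struct {
	predicate string
	filter    string
	params    []any
}

// placeholder matches positional parameters such as $1 or $12.
var placeholder = regexp.MustCompile(`\$(\d+)`)

// shift renumbers every $N in filter to $(N+offset). It is a naive textual
// rewrite: a "$1" inside a quoted literal would be shifted as well.
func shift(filter string, offset int) string {
	return placeholder.ReplaceAllStringFunc(filter, func(m string) string {
		n, _ := strconv.Atoi(m[1:]) // the regexp guarantees digits after '$'
		return "$" + strconv.Itoa(n+offset)
	})
}

// matchBody joins the constraints into one var-block body. Each edge block
// gets its own alias so two constraints on the same predicate do not collide,
// and the params are concatenated in the same order as the placeholders.
func matchBody(edges []edgeConstraint) (string, []any) {
	var b strings.Builder
	var params []any
	b.WriteString("uid\n")
	for i, e := range edges {
		fmt.Fprintf(&b, "mg_e%d : %s @filter(%s) { uid }\n",
			i, e.predicate, shift(e.filter, len(params)))
		params = append(params, e.params...)
	}
	return b.String(), params
}

func main() {
	body, params := matchBody([]edgeConstraint{
		{"pets", "eq(name, $1)", []any{"Fido"}},
		{"pets", "ge(age, $1) AND le(age, $2)", []any{2, 9}},
	})
	fmt.Print(body)        // second block is renumbered to $2 and $3
	fmt.Println(params...) // Fido 2 9
}
```

Because the sketch only reads the accumulated constraints and returns a fresh body and params slice, it can be re-run for every page, which is the property `IterNodes` relies on.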
+ +### Server-side var DQL + +For `WhereEdge("pets", "eq(name, $1)", "Fido")` over `Owner`, with a root filter +`Filter("eq(name, \"Alice\")")` and `NodesAndCount`: + +```dql +{ + mgMatched as var(func: type(Owner)) @filter(has(dgraph.type)) @cascade { + uid + mg_e0 : pets @filter(eq(name, "Fido")) { uid } + } + mgData(func: type(Owner)) @filter(has(dgraph.type) AND (eq(name, "Alice")) AND uid(mgMatched)) { + uid + expand(_all_) { … } + } + mgCount(func: uid(mgMatched)) @filter(has(dgraph.type) AND (eq(name, "Alice"))) { + count(uid) + } +} +``` + +The var block is built by reconfiguring a fresh `conn.Query(ctx, &T{})` with +`As(mgMatched).Var().Cascade().Query(body, params...)`; when the caller narrowed the root, the var +block roots at `customRootExpr` so the match is the intersection of the caller's root and the edge +constraints, not an overwrite. Every edge block is aliased `mg_e0`, `mg_e1`, … so two constraints on +the same predicate do not collide as duplicate fields. Each edge filter is written numbering its +params from `$1`; `shiftPlaceholders` renumbers them against the concatenated params slice before +they are joined into one body. + +### Generated face — `Query.Where` + +`wrapper_query.go.tmpl` emits one thin method per edge field, delegating to the substrate — the same +pattern `Filter`/`Cascade` already use: + +```go +func (q *OwnerQuery) WherePets(filter string, params ...any) *OwnerQuery { + q.typed.WhereEdge("pets", filter, params...) + return q +} +``` + +The method name is `Where` + the field's accessor name; the predicate string is the field's resolved +dgraph predicate. Generated for every edge field (multi, singular, and reverse). No parser changes — +`model.Field` already carries `IsEdge`/`Predicate`. + +## Error handling + +`WhereEdge` never executes — it only appends. 
The request error (malformed filter, transport +failure) surfaces from the terminal: `Nodes`/`First`/`NodesAndCount` return it; `IterNodes` yields +one `(nil, err)` and stops. A var block matching zero roots is not an error — `uid(mgMatched)` of an empty +var yields no rows, so the terminal returns an empty result. + +## Testing + +- **`typed/query_test.go`** — `owner`/`pet` test types (an edge pair). Behavioral tests against the + file engine: `WhereEdge` filters by edge target; no match yields empty; `$N` params bind; + `WhereEdge` composes with a root `Filter`; a `UID` root is preserved (intersection, not + overwrite); two `WhereEdge` calls AND; `First`, `IterNodes`, and `NodesAndCount` honor edge + constraints (the count reflects the full match, independent of `Limit`). +- **`typed/query_internal_test.go`** — white-box assertion that the rendered DQL is a server-side + var (`mgMatched as var(`, `uid(mgMatched)`) carrying no inlined `uid(0x…)` literal list. +- **`generator_test.go`** — a two-type edge schema asserts `Where*` is generated and delegates + to `typed.WhereEdge`, and that an edgeless type gets no `Where*` method. +- **`wrapper_query_e2e_test.go`** — `client.Director.Query(ctx).WhereFilms(...)` end-to-end against + the file-backed client. + +## Migration / blast radius + +- **Modified:** `typed/query.go` (struct fields, `WhereEdge`, the + `runEdge`/`edgeBlocks`/`edgeVarBlock`/`edgeCountBlock`/`edgeMatchBody`/`shiftPlaceholders` + helpers, edge-aware terminals, doc comments); `typed/client.go` (`Query` passes `conn`/`ctx`); + `wrapper_query.go.tmpl` (generated `Where*`). +- **Regenerated:** the `movies` fixture — every `*_query_gen.go` for an entity with edges gains + `Where*` methods. +- **New tests** in `typed/query_test.go`, `typed/query_internal_test.go`, `generator_test.go`, + `wrapper_query_e2e_test.go`. +- No change to `Filter`, `Nodes`, `First`, `IterNodes`, CRUD, or any other generated artifact. 
The + var-block path is inert unless `WhereEdge` is called. + +## Open decisions + +None. The string-filter API (over a typed DSL) and one-hop depth were settled before implementation; +the typed predicate DSL is recorded above as future work. The semi-join runs server-side via a var +block rather than the client-side pre-pass an earlier draft proposed, so matched UIDs stay off the +client — see _Why This Approach_. diff --git a/typed/client.go b/typed/client.go new file mode 100644 index 0000000..939986a --- /dev/null +++ b/typed/client.go @@ -0,0 +1,83 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed + +import ( + "context" + "iter" + + "github.com/matthewmcneely/modusgraph" +) + +// Client provides type-safe CRUD and query operations over records of type T. +// T is the schema struct (for example schema.Actor); modusgraph reflects over +// the struct's dgraph/json tags, so T needs no constraint. +type Client[T any] struct { + conn modusgraph.Client +} + +// NewClient binds a Client[T] to conn. +func NewClient[T any](conn modusgraph.Client) *Client[T] { + return &Client[T]{conn: conn} +} + +// Get loads the T with the given UID. +func (c *Client[T]) Get(ctx context.Context, uid string) (rec *T, err error) { + ctx, span := currentTracer().StartSpan(ctx, "get", entityName[T]()) + defer func() { span.End(err) }() + var out T + if err = c.conn.Get(ctx, &out, uid); err != nil { + return nil, err + } + return &out, nil +} + +// Add inserts a new T. modusgraph writes the assigned UID back into rec. +func (c *Client[T]) Add(ctx context.Context, rec *T) (err error) { + ctx, span := currentTracer().StartSpan(ctx, "add", entityName[T]()) + defer func() { span.End(err) }() + return c.conn.Insert(ctx, rec) +} + +// Update modifies an existing T (must have its UID set). 
+func (c *Client[T]) Update(ctx context.Context, rec *T) (err error) { + ctx, span := currentTracer().StartSpan(ctx, "update", entityName[T]()) + defer func() { span.End(err) }() + return c.conn.Update(ctx, rec) +} + +// Upsert inserts or updates rec, matching against predicates. With no +// predicates, the first field tagged dgraph:"upsert" is used. +func (c *Client[T]) Upsert(ctx context.Context, rec *T, predicates ...string) (err error) { + ctx, span := currentTracer().StartSpan(ctx, "upsert", entityName[T]()) + defer func() { span.End(err) }() + return c.conn.Upsert(ctx, rec, predicates...) +} + +// Delete removes the T with the given UID. +func (c *Client[T]) Delete(ctx context.Context, uid string) (err error) { + ctx, span := currentTracer().StartSpan(ctx, "delete", entityName[T]()) + defer func() { span.End(err) }() + return c.conn.Delete(ctx, []string{uid}) +} + +// Query returns a typed query builder for T. conn and ctx are carried so the +// builder can run the WhereEdge var-block request (see Query.WhereEdge) if needed. +func (c *Client[T]) Query(ctx context.Context) *Query[T] { + var z T + return &Query[T]{q: c.conn.Query(ctx, &z), conn: c.conn, ctx: ctx} +} + +// defaultPageSize is the page size IterNodes uses to page through results. +const defaultPageSize = 50 + +// Iter returns an iterator over every T, paging transparently so large result +// sets are not materialized at once. It yields each record in turn; on error +// it yields a final (nil, err) and stops. All pages execute against one +// read-only transaction, so the iteration reads a single consistent snapshot. +func (c *Client[T]) Iter(ctx context.Context) iter.Seq2[*T, error] { + return c.Query(ctx).IterNodes() +} diff --git a/typed/client_test.go b/typed/client_test.go new file mode 100644 index 0000000..6fa2b1d --- /dev/null +++ b/typed/client_test.go @@ -0,0 +1,209 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +package typed_test + +import ( + "context" + "testing" + + "github.com/matthewmcneely/modusgraph" + "github.com/matthewmcneely/modusgraph/typed" +) + +// widget is a minimal schema struct used to exercise the typed package. +type widget struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + Name string `json:"name,omitempty" dgraph:"index=exact"` + Qty int `json:"qty,omitempty" dgraph:"index=int"` +} + +// owner and pet exercise Query.WhereEdge: owner has an outbound "pets" edge to +// pet, and pet's Name carries an index so eq(name, ...) resolves inside an edge +// filter. The pair is the typed-package analogue of the Person/Dog example in +// docs/specs/2026-05-21-query-edge-filter-design.md. +type owner struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + Name string `json:"name,omitempty" dgraph:"index=exact"` + Pets []*pet `json:"pets,omitempty"` +} + +type pet struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + Name string `json:"name,omitempty" dgraph:"index=exact"` +} + +// newConn builds a local file-backed modusgraph client for a test. 
+func newConn(t *testing.T) modusgraph.Client { + t.Helper() + conn, err := modusgraph.NewClient("file://"+t.TempDir(), modusgraph.WithAutoSchema(true)) + if err != nil { + t.Fatalf("modusgraph.NewClient: %v", err) + } + t.Cleanup(conn.Close) + return conn +} + +func TestClient_AddPopulatesUIDAndGetReadsBack(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + w := &widget{Name: "sprocket", Qty: 3} + if err := c.Add(ctx, w); err != nil { + t.Fatalf("Add: %v", err) + } + if w.UID == "" { + t.Fatal("Add did not populate UID on the passed struct") + } + + got, err := c.Get(ctx, w.UID) + if err != nil { + t.Fatalf("Get: %v", err) + } + if got.Name != "sprocket" || got.Qty != 3 { + t.Fatalf("Get returned %+v, want Name=sprocket Qty=3", got) + } +} + +func TestClient_Update(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + w := &widget{Name: "gear", Qty: 1} + if err := c.Add(ctx, w); err != nil { + t.Fatalf("Add: %v", err) + } + w.Qty = 99 + if err := c.Update(ctx, w); err != nil { + t.Fatalf("Update: %v", err) + } + + got, err := c.Get(ctx, w.UID) + if err != nil { + t.Fatalf("Get: %v", err) + } + if got.Qty != 99 { + t.Fatalf("Update did not persist; Qty = %d, want 99", got.Qty) + } +} + +func TestClient_Delete(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + w := &widget{Name: "bolt"} + if err := c.Add(ctx, w); err != nil { + t.Fatalf("Add: %v", err) + } + if err := c.Delete(ctx, w.UID); err != nil { + t.Fatalf("Delete: %v", err) + } + if _, err := c.Get(ctx, w.UID); err == nil { + t.Fatal("Get after Delete returned no error; expected not-found") + } +} + +func TestClient_IterPagesThroughAllRecords(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // 125 is deliberately larger than the package's 50-record page size, so + // a correct Iter must fetch more than one page. 
+ const n = 125 + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + + seen := 0 + for w, err := range c.Iter(ctx) { + if err != nil { + t.Fatalf("Iter yielded error: %v", err) + } + if w == nil { + t.Fatal("Iter yielded a nil widget") + } + seen++ + } + if seen != n { + t.Fatalf("Iter yielded %d records, want %d", seen, n) + } +} + +// gadget is a dedicated upsert struct. It must not be the shared widget, because +// widget is used in tests that insert many records with duplicate Name values; +// adding a "upsert" directive to widget.Name would cause those inserts to +// collide and break unrelated tests. +type gadget struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + Label string `json:"label,omitempty" dgraph:"index=exact upsert"` + Stock int `json:"stock,omitempty" dgraph:"index=int"` +} + +func TestClient_Upsert(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[gadget](newConn(t)) + + // First call — creates the record. + g := &gadget{Label: "sprocket", Stock: 10} + if err := c.Upsert(ctx, g, "label"); err != nil { + t.Fatalf("Upsert (create): %v", err) + } + if g.UID == "" { + t.Fatal("Upsert (create) did not populate UID") + } + + // Second call — same Label value, different Stock. Must UPDATE, not insert. + g2 := &gadget{Label: "sprocket", Stock: 99} + if err := c.Upsert(ctx, g2, "label"); err != nil { + t.Fatalf("Upsert (update): %v", err) + } + + // Exactly one record must exist and it must carry the updated Stock. 
+ nodes, err := c.Query(ctx).Nodes() + if err != nil { + t.Fatalf("Query after Upsert: %v", err) + } + if len(nodes) != 1 { + t.Fatalf("got %d gadgets after two upserts on the same label, want 1", len(nodes)) + } + if nodes[0].Stock != 99 { + t.Fatalf("upserted gadget Stock = %d, want 99", nodes[0].Stock) + } +} + +func TestClient_IterStopsOnConsumerBreak(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + const n = 125 + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + + seen := 0 + for w, err := range c.Iter(ctx) { + if err != nil { + t.Fatalf("Iter yielded error: %v", err) + } + if w == nil { + t.Fatal("Iter yielded a nil widget") + } + seen++ + if seen == 10 { + break + } + } + if seen != 10 { + t.Fatalf("Iter yielded %d records after break at 10, want 10", seen) + } +} diff --git a/typed/doc.go b/typed/doc.go new file mode 100644 index 0000000..1596c86 --- /dev/null +++ b/typed/doc.go @@ -0,0 +1,73 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +// Package typed binds a Go type to the otherwise any-typed modusgraph.Client, +// giving you generic, type-safe CRUD and a fluent query builder with no +// per-entity code generation. It is the handwritten substrate that +// modusgraph-gen's generated clients compose over, and it is useful on its own +// wherever you want compile-time types over modusgraph. +// +// # Why +// +// modusgraph.Client is value-oriented: its methods take and return any, and a +// query is assembled by hand from dgman primitives and decoded into a slice you +// declare at the call site. That works, but every call site repeats the same +// shape — declare the destination, build the query, decode, check the type. +// Package typed lifts that shape into the type system once. 
+// +// Without the typed layer, a "first matching record" lookup carries the type on +// every line and decodes by hand: +// +// var out []Person +// q := client.Query(ctx, &Person{}). +// Filter("eq(name, $1)", "Alice"). +// First(1) +// if err := q.Nodes(&out); err != nil { +// return nil, err +// } +// var person *Person +// if len(out) > 0 { +// person = &out[0] +// } +// +// With it, the type is declared once — when the client is constructed — and the +// terminal returns exactly what you asked for: +// +// people := typed.NewClient[Person](client) +// person, err := people.Query(ctx). +// Filter("eq(name, $1)", "Alice"). +// First() +// // person is *Person; nil when nothing matched. +// +// # The query builder +// +// Query[T] is a fluent builder. Builder methods (Filter, OrderAsc, Limit, +// WhereEdge, and the rest) return *Query[T] for chaining; terminals (Nodes, +// First, NodesAndCount, IterNodes) execute and decode typed results. The +// builder delegates the actual querying, parameter binding, and injection-safe +// $N substitution to dgman — it adds the type binding and the fragment +// composition dgman does not provide: +// +// - Accumulated Filter fragments AND together, each fragment parenthesized so +// a fragment containing OR keeps its precedence. +// - OrGroup ORs several sub-scopes into one parenthesized group. +// - WhereEdge constrains T by a predicate of a neighbouring node reached over +// an edge, resolved server-side in the same request and intersected with any root you set. +// - IterNodes streams arbitrarily large result sets one page at a time over a +// single read-only snapshot. +// +// # Composing larger requests +// +// MultiQuery batches N same-type blocks into one Dgraph round-trip, keyed by +// block name. The filter subpackage builds parameterised @filter expressions +// (the substrate behind generated By and Or combinators), and the search +// subpackage merges ordered result sets by ID. 
+// +// # Tracing +// +// Every terminal opens a span through a process-wide tracer that is a no-op by +// default. Install one with SetTracer to emit spans without the typed package +// depending on any telemetry library. +package typed diff --git a/typed/example_test.go b/typed/example_test.go new file mode 100644 index 0000000..c17d3d1 --- /dev/null +++ b/typed/example_test.go @@ -0,0 +1,162 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed_test + +import ( + "context" + "fmt" + + "github.com/matthewmcneely/modusgraph" + "github.com/matthewmcneely/modusgraph/typed" +) + +// Person is the schema struct the examples bind a typed client to. modusgraph +// reflects over the dgraph/json tags, so the type needs no special interface. +type Person struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + Name string `json:"name,omitempty" dgraph:"index=exact"` + Age int `json:"age,omitempty" dgraph:"index=int"` + Friends []*Person `json:"friends,omitempty"` +} + +// ExampleClient shows the core lift package typed provides: declare the type +// once at construction, then Add, Get, and Query in terms of *Person rather +// than any. +func ExampleClient() { + conn, _ := modusgraph.NewClient("dgraph://localhost:9080") + defer conn.Close() + + people := typed.NewClient[Person](conn) + ctx := context.Background() + + alice := &Person{Name: "Alice", Age: 30} + if err := people.Add(ctx, alice); err != nil { // Add writes the new UID back into alice. + panic(err) + } + + got, err := people.Get(ctx, alice.UID) // got is *Person, not any. + if err != nil { + panic(err) + } + fmt.Println(got.Name) +} + +// ExampleClient_query builds a filtered, ordered, paged query. The terminal +// returns []Person directly — no destination slice to declare, no decode step. 
+func ExampleClient_query() { + conn, _ := modusgraph.NewClient("dgraph://localhost:9080") + defer conn.Close() + people := typed.NewClient[Person](conn) + ctx := context.Background() + + adults, err := people.Query(ctx). + Filter("ge(age, $1)", 18). + OrderAsc("name"). + Limit(50). + Nodes() + if err != nil { + panic(err) + } + fmt.Println(len(adults)) +} + +// ExampleQuery_First returns a single record or nil, replacing the +// declare-slice-then-index-element-zero idiom of the untyped client. +func ExampleQuery_First() { + conn, _ := modusgraph.NewClient("dgraph://localhost:9080") + defer conn.Close() + people := typed.NewClient[Person](conn) + ctx := context.Background() + + person, err := people.Query(ctx). + Filter("eq(name, $1)", "Alice"). + First() + if err != nil { + panic(err) + } + if person == nil { + fmt.Println("not found") + return + } + fmt.Println(person.Name) +} + +// ExampleQuery_OrGroup ANDs a scalar filter with an OR of two sub-scopes: +// age >= 18 AND (name == "Alice" OR name == "Bob"). Each sub-scope is a +// detached Query whose filter is captured, not executed. +func ExampleQuery_OrGroup() { + conn, _ := modusgraph.NewClient("dgraph://localhost:9080") + defer conn.Close() + people := typed.NewClient[Person](conn) + ctx := context.Background() + + got, err := people.Query(ctx). + Filter("ge(age, $1)", 18). + OrGroup( + typed.NewDetachedQuery[Person]().Filter(`eq(name, "Alice")`), + typed.NewDetachedQuery[Person]().Filter(`eq(name, "Bob")`), + ). + Nodes() + if err != nil { + panic(err) + } + fmt.Println(len(got)) +} + +// ExampleQuery_WhereEdge constrains people by a scalar of a neighbour reached +// over the "friends" edge — something a root filter cannot express. The builder +// resolves it server-side in the same request and intersects with any root you set. 
+func ExampleQuery_WhereEdge() { + conn, _ := modusgraph.NewClient("dgraph://localhost:9080") + defer conn.Close() + people := typed.NewClient[Person](conn) + ctx := context.Background() + + // Everyone who has a friend named "Alice". + got, err := people.Query(ctx). + WhereEdge("friends", `eq(name, $1)`, "Alice"). + Nodes() + if err != nil { + panic(err) + } + fmt.Println(len(got)) +} + +// ExampleClient_iter streams a large result set one page at a time over a +// single consistent snapshot, so the whole set is never held in memory at once. +func ExampleClient_iter() { + conn, _ := modusgraph.NewClient("dgraph://localhost:9080") + defer conn.Close() + people := typed.NewClient[Person](conn) + ctx := context.Background() + + for person, err := range people.Query(ctx).OrderAsc("name").IterNodes() { + if err != nil { + panic(err) + } + fmt.Println(person.Name) + } +} + +// ExampleMultiQuery batches several same-type queries into one Dgraph +// round-trip, keyed by block name. +func ExampleMultiQuery() { + conn, _ := modusgraph.NewClient("dgraph://localhost:9080") + defer conn.Close() + people := typed.NewClient[Person](conn) + ctx := context.Background() + + mq := typed.NewMultiQuery[Person](conn). + Add("adults", people.Query(ctx).Filter("ge(age, $1)", 18)). + Add("named_alice", people.Query(ctx).Filter(`eq(name, "Alice")`)) + + results, err := mq.Execute(ctx) // one round-trip + if err != nil { + panic(err) + } + fmt.Println(len(results["adults"]), len(results["named_alice"])) +} diff --git a/typed/filter/example_test.go b/typed/filter/example_test.go new file mode 100644 index 0000000..6c602f9 --- /dev/null +++ b/typed/filter/example_test.go @@ -0,0 +1,30 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package filter_test + +import ( + "fmt" + + "github.com/matthewmcneely/modusgraph/typed/filter" +) + +// ExampleBuilder composes a parameterised @filter expression. 
Terms within a +// group join with OR; groups join with AND; required terms form their own +// group. Build returns the expression and the positional params that +// typed.Query[T].Filter consumes — the values never get interpolated into the +// string, so the expression is safe against injection. +func ExampleBuilder() { + var b filter.Builder + b.RequiredEq("archiveStatus", "Active") + b.EqGroupString("name", []filter.String{{Value: "Alice"}, {Value: "Bob"}}) + + expr, params := b.Build() + fmt.Println(expr) + fmt.Println(params...) + // Output: + // eq(archiveStatus, $1) AND (eq(name, $2) OR eq(name, $3)) + // Active Alice Bob +} diff --git a/typed/filter/filter.go b/typed/filter/filter.go new file mode 100644 index 0000000..d67f118 --- /dev/null +++ b/typed/filter/filter.go @@ -0,0 +1,118 @@ +// Package filter provides typed values and a parameterised expression builder +// for composing dgraph @filter clauses on generated Query types. +// +// Generated By methods accept []UUID or []String and feed them into +// Builder.EqGroupUUID / Builder.EqGroupString. Consumers can also build +// custom expressions directly with Builder for cases the generator does not +// cover (multi-predicate joins, non-equality operators, domain defaults). +package filter + +import ( + "fmt" + "strings" +) + +// UUID is one UUID-valued filter term, optionally negated. A leading "!" in +// the parsed source negates the term ("!abc" becomes {Negated: true, Value: "abc"}). +type UUID struct { + Negated bool + Value string +} + +// String is one string-valued filter term, optionally negated. +type String struct { + Negated bool + Value string +} + +// ParseUUID parses "value" or "!value" into a UUID. +func ParseUUID(s string) UUID { + neg, v := parseNegation(s) + return UUID{Negated: neg, Value: v} +} + +// ParseString parses "value" or "!value" into a String. 
+func ParseString(s string) String { + neg, v := parseNegation(s) + return String{Negated: neg, Value: v} +} + +func parseNegation(s string) (bool, string) { + if strings.HasPrefix(s, "!") { + return true, s[1:] + } + return false, s +} + +// term is one predicate-agnostic value used by Builder. +type term struct { + value string + negated bool +} + +// Builder composes parameterised DQL @filter expressions. Terms within an +// EqGroup join with OR; groups join with AND. Required terms become their own +// single-term group. The output is the (expression, positional params) pair +// that typed.Query[T].Filter consumes. +type Builder struct { + groups []string + params []any +} + +func (b *Builder) param(v any) string { + b.params = append(b.params, v) + return fmt.Sprintf("$%d", len(b.params)) +} + +// EqGroupUUID adds an OR-group of eq(predicate, value) terms for one +// UUID-typed predicate. An empty terms slice is a no-op. +func (b *Builder) EqGroupUUID(predicate string, terms []UUID) { + if len(terms) == 0 { + return + } + tg := make([]term, 0, len(terms)) + for _, t := range terms { + tg = append(tg, term{value: t.Value, negated: t.Negated}) + } + b.addEqGroup(predicate, tg) +} + +// EqGroupString adds an OR-group of eq(predicate, value) terms for one +// string-typed predicate. +func (b *Builder) EqGroupString(predicate string, terms []String) { + if len(terms) == 0 { + return + } + tg := make([]term, 0, len(terms)) + for _, t := range terms { + tg = append(tg, term{value: t.Value, negated: t.Negated}) + } + b.addEqGroup(predicate, tg) +} + +func (b *Builder) addEqGroup(predicate string, terms []term) { + parts := make([]string, 0, len(terms)) + for _, t := range terms { + eq := fmt.Sprintf("eq(%s, %s)", predicate, b.param(t.value)) + if t.negated { + eq = "NOT " + eq + } + parts = append(parts, eq) + } + b.groups = append(b.groups, "("+strings.Join(parts, " OR ")+")") +} + +// RequiredEq adds a single mandatory eq(predicate, value) term (its own group). 
+func (b *Builder) RequiredEq(predicate, value string) { + b.groups = append(b.groups, fmt.Sprintf("eq(%s, %s)", predicate, b.param(value))) +} + +// Build returns the combined DQL filter expression and its parameters. When +// no groups were added it returns ("", nil) — callers should skip the +// .Filter() call entirely in that case. +func (b *Builder) Build() (string, []any) { + if len(b.groups) == 0 { + return "", nil + } + return strings.Join(b.groups, " AND "), b.params +} diff --git a/typed/filter/filter_test.go b/typed/filter/filter_test.go new file mode 100644 index 0000000..10fdf2b --- /dev/null +++ b/typed/filter/filter_test.go @@ -0,0 +1,130 @@ +package filter_test + +import ( + "testing" + + "github.com/matthewmcneely/modusgraph/typed/filter" +) + +func TestParseUUID(t *testing.T) { + tests := []struct { + name string + in string + want filter.UUID + }{ + {"plain", "abc", filter.UUID{Value: "abc"}}, + {"negated", "!abc", filter.UUID{Negated: true, Value: "abc"}}, + {"empty", "", filter.UUID{}}, + {"just bang", "!", filter.UUID{Negated: true}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := filter.ParseUUID(tt.in) + if got != tt.want { + t.Errorf("ParseUUID(%q) = %+v, want %+v", tt.in, got, tt.want) + } + }) + } +} + +func TestParseString(t *testing.T) { + got := filter.ParseString("!hello") + want := filter.String{Negated: true, Value: "hello"} + if got != want { + t.Errorf("ParseString = %+v, want %+v", got, want) + } +} + +func TestBuilder_Empty(t *testing.T) { + var b filter.Builder + expr, params := b.Build() + if expr != "" || params != nil { + t.Errorf("empty Build = (%q, %v), want (\"\", nil)", expr, params) + } +} + +func TestBuilder_EqGroupUUID_SingleTerm(t *testing.T) { + var b filter.Builder + b.EqGroupUUID("id", []filter.UUID{{Value: "u1"}}) + expr, params := b.Build() + want := "(eq(id, $1))" + if expr != want { + t.Errorf("expr = %q, want %q", expr, want) + } + if len(params) != 1 || params[0] != "u1" { + 
t.Errorf("params = %v, want [u1]", params) + } +} + +func TestBuilder_EqGroupUUID_MultipleTermsJoinWithOR(t *testing.T) { + var b filter.Builder + b.EqGroupUUID("id", []filter.UUID{{Value: "u1"}, {Value: "u2"}, {Negated: true, Value: "u3"}}) + expr, params := b.Build() + want := "(eq(id, $1) OR eq(id, $2) OR NOT eq(id, $3))" + if expr != want { + t.Errorf("expr = %q, want %q", expr, want) + } + if len(params) != 3 { + t.Errorf("len(params) = %d, want 3", len(params)) + } +} + +func TestBuilder_EqGroupString_NoTermsIsNoop(t *testing.T) { + var b filter.Builder + b.EqGroupString("name", nil) + expr, _ := b.Build() + if expr != "" { + t.Errorf("empty EqGroupString should be no-op, got expr=%q", expr) + } +} + +func TestBuilder_MultipleGroupsJoinWithAND(t *testing.T) { + var b filter.Builder + b.EqGroupUUID("id", []filter.UUID{{Value: "u1"}}) + b.EqGroupString("name", []filter.String{{Value: "Alice"}}) + expr, params := b.Build() + want := "(eq(id, $1)) AND (eq(name, $2))" + if expr != want { + t.Errorf("expr = %q, want %q", expr, want) + } + if len(params) != 2 || params[0] != "u1" || params[1] != "Alice" { + t.Errorf("params = %v, want [u1 Alice]", params) + } +} + +func TestBuilder_RequiredEqIsOwnGroup(t *testing.T) { + var b filter.Builder + b.RequiredEq("archiveStatus", "Active") + b.EqGroupUUID("id", []filter.UUID{{Value: "u1"}}) + expr, params := b.Build() + want := "eq(archiveStatus, $1) AND (eq(id, $2))" + if expr != want { + t.Errorf("expr = %q, want %q", expr, want) + } + if len(params) != 2 { + t.Errorf("len(params) = %d, want 2", len(params)) + } +} + +func TestBuilder_PositionalParamsAreSequential(t *testing.T) { + var b filter.Builder + b.EqGroupUUID("id", []filter.UUID{{Value: "a"}, {Value: "b"}}) + b.EqGroupString("name", []filter.String{{Value: "c"}}) + expr, params := b.Build() + // Assert the exact expression: placeholders must be numbered $1..$N in + // emission order and bound to the matching params. 
A substring check would + // pass even if the numbering were scrambled (e.g. "$3 ... $1 ... $2"). + const want = "(eq(id, $1) OR eq(id, $2)) AND (eq(name, $3))" + if expr != want { + t.Errorf("expr = %q, want %q", expr, want) + } + wantParams := []any{"a", "b", "c"} + if len(params) != len(wantParams) { + t.Fatalf("params = %v, want %v", params, wantParams) + } + for i, p := range wantParams { + if params[i] != p { + t.Errorf("param[%d] = %v, want %v", i, params[i], p) + } + } +} diff --git a/typed/filter/fulltext.go b/typed/filter/fulltext.go new file mode 100644 index 0000000..a025ef0 --- /dev/null +++ b/typed/filter/fulltext.go @@ -0,0 +1,21 @@ +package filter + +import "fmt" + +// AnyOfText adds a fulltext OR-match group: anyoftext(predicate, term). +// An empty term is a no-op. +func (b *Builder) AnyOfText(predicate, term string) { + if term == "" { + return + } + b.groups = append(b.groups, fmt.Sprintf("anyoftext(%s, %s)", predicate, b.param(term))) +} + +// AllOfText adds a fulltext AND-match group: alloftext(predicate, term). +// An empty term is a no-op. 
+func (b *Builder) AllOfText(predicate, term string) { + if term == "" { + return + } + b.groups = append(b.groups, fmt.Sprintf("alloftext(%s, %s)", predicate, b.param(term))) +} diff --git a/typed/filter/fulltext_test.go b/typed/filter/fulltext_test.go new file mode 100644 index 0000000..1d71e0b --- /dev/null +++ b/typed/filter/fulltext_test.go @@ -0,0 +1,41 @@ +package filter_test + +import ( + "strings" + "testing" + + "github.com/matthewmcneely/modusgraph/typed/filter" +) + +func TestAnyOfTextEmitsFilterAndBindsParam(t *testing.T) { + b := &filter.Builder{} + b.AnyOfText("resourceName", "honda civic") + expr, params := b.Build() + if !strings.Contains(expr, "anyoftext(resourceName, $1)") { + t.Fatalf("expected anyoftext(resourceName, $1) in expr, got %q", expr) + } + if len(params) != 1 || params[0] != "honda civic" { + t.Fatalf("expected params [\"honda civic\"], got %v", params) + } +} + +func TestAllOfTextEmitsFilterAndBindsParam(t *testing.T) { + b := &filter.Builder{} + b.AllOfText("description", "engine block") + expr, params := b.Build() + if !strings.Contains(expr, "alloftext(description, $1)") { + t.Fatalf("expected alloftext(description, $1) in expr, got %q", expr) + } + if len(params) != 1 || params[0] != "engine block" { + t.Fatalf("expected params [\"engine block\"], got %v", params) + } +} + +func TestAnyOfTextEmptyTermIsNoop(t *testing.T) { + b := &filter.Builder{} + b.AnyOfText("resourceName", "") + expr, params := b.Build() + if expr != "" || params != nil { + t.Fatalf("expected empty expr/params for empty term, got %q / %v", expr, params) + } +} diff --git a/typed/multi_query.go b/typed/multi_query.go new file mode 100644 index 0000000..a07d279 --- /dev/null +++ b/typed/multi_query.go @@ -0,0 +1,284 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. 
 + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed + +import ( + "context" + "encoding/json" + "fmt" + "reflect" + "strings" + + dg "github.com/dolan-in/dgman/v2" + "github.com/matthewmcneely/modusgraph" +) + +// MultiQuery batches N homogeneous-type Query[T] blocks into a single +// Dgraph multi-block request. All blocks return rows of the same T; the +// per-block result is keyed by the block name supplied at Add. +// +// Dgraph executes the blocks concurrently on the server side; the entire +// batch costs one gRPC round-trip. +type MultiQuery[T any] struct { + conn modusgraph.Client + names []string + blocks map[string]*Query[T] +} + +// NewMultiQuery constructs a MultiQuery bound to conn. +func NewMultiQuery[T any](conn modusgraph.Client) *MultiQuery[T] { + return &MultiQuery[T]{ + conn: conn, + blocks: make(map[string]*Query[T]), + } +} + +// Add registers a named block. Names must be unique within one MultiQuery, and +// each *Query[T] may be added only once: Execute names the block's underlying +// dgman query in place, so registering the same Query pointer under two names +// would make both blocks render with whichever name was applied last. Both +// conditions are programming errors: Add panics rather than returning an error. +func (mq *MultiQuery[T]) Add(name string, q *Query[T]) *MultiQuery[T] { + if _, exists := mq.blocks[name]; exists { + panic(fmt.Sprintf("multi_query: duplicate block name %q", name)) + } + for existingName, existing := range mq.blocks { + if existing == q { + panic(fmt.Sprintf("multi_query: Query already added as %q; build a separate Query per block", existingName)) + } + } + mq.names = append(mq.names, name) + mq.blocks[name] = q + return mq +} + +// BlockNames returns the registered block names in insertion order. 
+func (mq *MultiQuery[T]) BlockNames() []string { + out := make([]string, len(mq.names)) + copy(out, mq.names) + return out +} + +// Execute runs every registered block in a single Dgraph round-trip and +// returns the per-block results, keyed by the block name supplied at Add. +// A block that matched no rows appears as an empty (non-nil) slice in the +// result map; the key is always present. +// +// Execute rejects blocks that carry WhereEdge constraints: each such query +// runs as its own multi-block request (a server-side var block feeding the data +// block), which cannot be folded into this batch. Run such queries individually +// with Query.Nodes. +// +// Dgraph keys response JSON by predicate name (e.g. resourceName), but Go +// structs typically use their json tag (e.g. name). Execute remaps the keys +// per T's tags before decoding — recursing into nested edge structs — so a +// schema that uses `dgraph:"predicate=..."` with a divergent `json:"..."` +// decodes correctly at every depth, matching dgman's QueryBlock.Scan path. +func (mq *MultiQuery[T]) Execute(ctx context.Context) (map[string][]T, error) { + if len(mq.names) == 0 { + return map[string][]T{}, nil + } + + rawBlocks := make([]*dg.Query, 0, len(mq.names)) + for _, name := range mq.names { + block := mq.blocks[name] + if len(block.edges) != 0 { + return nil, fmt.Errorf( + "multi_query: block %q carries WhereEdge constraints; "+ + "MultiQuery cannot batch edge-filtered blocks", name) + } + // Name the underlying dgman query so blocks do not collide on the + // default "data" name and so the response JSON keys are predictable. 
+ block.q.Name(name) + rawBlocks = append(rawBlocks, block.q) + } + + dql := dg.NewQueryBlock(rawBlocks...).String() + raw, err := mq.conn.QueryRaw(ctx, dql, nil) + if err != nil { + return nil, fmt.Errorf("multi_query: dgraph: %w", err) + } + + var perBlockRaw map[string]json.RawMessage + if err := json.Unmarshal(raw, &perBlockRaw); err != nil { + return nil, fmt.Errorf("multi_query: decoding response: %w", err) + } + + rowType := reflect.TypeFor[T]() + + out := make(map[string][]T, len(mq.names)) + for _, name := range mq.names { + body, ok := perBlockRaw[name] + if !ok { + out[name] = []T{} + continue + } + remapped, err := remapPredicateKeys(body, rowType) + if err != nil { + return nil, fmt.Errorf("multi_query: remapping block %q: %w", name, err) + } + body = remapped + var rows []T + if err := json.Unmarshal(body, &rows); err != nil { + return nil, fmt.Errorf("multi_query: decoding block %q: %w", name, err) + } + if rows == nil { + rows = []T{} + } + out[name] = rows + } + return out, nil +} + +// buildPredicateToJSONMap returns a map from dgraph predicate name → JSON tag +// name for fields on T where the two differ. Mirrors dgman's unexported helper +// of the same name; we need our own because the multi-block response from +// QueryRaw bypasses dgman's scan path. 
+func buildPredicateToJSONMap(t reflect.Type) map[string]string { + t = getElemType(t) + if t == nil || t.Kind() != reflect.Struct { + return nil + } + result := make(map[string]string) + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + jsonTag := field.Tag.Get("json") + if jsonTag == "" || jsonTag == "-" { + continue + } + jsonName := strings.Split(jsonTag, ",")[0] + if jsonName == "" { + continue + } + dgraphTag := field.Tag.Get("dgraph") + if dgraphTag == "" { + continue + } + var predName string + for part := range strings.FieldsSeq(dgraphTag) { + if p, ok := strings.CutPrefix(part, "predicate="); ok { + predName = p + break + } + } + if predName == "" || predName == jsonName { + continue + } + if predName == "uid" || predName == "dgraph.type" { + continue + } + result[predName] = jsonName + } + return result +} + +// remapPredicateKeys rewrites dgraph predicate names to JSON-tag names +// throughout a block body, descending into nested edge structs so a renamed +// predicate (dgraph `predicate=` diverging from `json=`) on a nested edge still +// decodes. It reproduces dgman's QueryBlock.Scan remap, which QueryRaw bypasses. +// The walk is type-driven by dstType (the row type T, or a nested field type), +// so only declared edge fields are recursed into; scalars and unrecognized keys +// pass through untouched. The top-level structural error is returned (so a +// malformed block surfaces its root cause); nested remaps are best-effort, since +// a type mismatch there surfaces with full context at the caller's typed decode. 
+func remapPredicateKeys(data json.RawMessage, dstType reflect.Type) (json.RawMessage, error) { + dstType = getElemType(dstType) + if dstType == nil || dstType.Kind() != reflect.Struct { + return data, nil + } + switch firstNonSpace(data) { + case '[': + return remapArray(data, dstType) + case '{': + return remapObject(data, dstType) + default: + return data, nil + } +} + +// remapArray applies the per-object remap to every element of a JSON array whose +// elements decode into dstType. +func remapArray(data json.RawMessage, dstType reflect.Type) (json.RawMessage, error) { + var arr []json.RawMessage + if err := json.Unmarshal(data, &arr); err != nil { + return nil, err + } + for i, item := range arr { + if remapped, err := remapPredicateKeys(item, dstType); err == nil { + arr[i] = remapped + } + } + return json.Marshal(arr) +} + +// remapObject renames dstType's renamed-predicate keys in one JSON object and +// recurses into its edge fields (struct- or slice-of-struct-typed). +func remapObject(data json.RawMessage, dstType reflect.Type) (json.RawMessage, error) { + var obj map[string]json.RawMessage + if err := json.Unmarshal(data, &obj); err != nil { + return nil, err + } + predMap := buildPredicateToJSONMap(dstType) + fieldTypes := buildFieldTypeMap(dstType) + remapped := make(map[string]json.RawMessage, len(obj)) + for key, val := range obj { + newKey := key + if mapped, ok := predMap[key]; ok { + newKey = mapped + } + if ft, ok := fieldTypes[newKey]; ok && getElemType(ft).Kind() == reflect.Struct { + if rv, err := remapPredicateKeys(val, ft); err == nil { + val = rv + } + } + remapped[newKey] = val + } + return json.Marshal(remapped) +} + +// buildFieldTypeMap maps each of t's JSON-tag names to its field type, so the +// remap can decide which keys are edges worth recursing into. 
+func buildFieldTypeMap(t reflect.Type) map[string]reflect.Type { + t = getElemType(t) + if t == nil || t.Kind() != reflect.Struct { + return nil + } + result := make(map[string]reflect.Type, t.NumField()) + for i := range t.NumField() { + field := t.Field(i) + jsonTag := field.Tag.Get("json") + if jsonTag == "" || jsonTag == "-" { + continue + } + jsonName := strings.Split(jsonTag, ",")[0] + if jsonName == "" { + continue + } + result[jsonName] = field.Type + } + return result +} + +// getElemType unwraps pointer, slice, and array types to their base element +// type, so an edge field declared as *T, []T, or []*T resolves to T. +func getElemType(t reflect.Type) reflect.Type { + for t != nil && (t.Kind() == reflect.Pointer || t.Kind() == reflect.Slice || t.Kind() == reflect.Array) { + t = t.Elem() + } + return t +} + +// firstNonSpace returns the first non-whitespace byte of b, or 0 if none. +func firstNonSpace(b []byte) byte { + for _, c := range b { + switch c { + case ' ', '\t', '\n', '\r': + continue + default: + return c + } + } + return 0 +} diff --git a/typed/multi_query_test.go b/typed/multi_query_test.go new file mode 100644 index 0000000..0d4b53f --- /dev/null +++ b/typed/multi_query_test.go @@ -0,0 +1,190 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +package typed_test + +import ( + "context" + "testing" + + "github.com/matthewmcneely/modusgraph/typed" +) + +func TestMultiQueryAddAccumulatesBlocks(t *testing.T) { + conn := newConn(t) + mq := typed.NewMultiQuery[widget](conn) + q1 := typed.NewClient[widget](conn).Query(context.Background()) + q2 := typed.NewClient[widget](conn).Query(context.Background()) + mq.Add("byName", q1) + mq.Add("byQty", q2) + got := mq.BlockNames() + if len(got) != 2 || got[0] != "byName" || got[1] != "byQty" { + t.Fatalf("BlockNames = %v, want [byName, byQty]", got) + } +} + +func TestMultiQueryAddRejectsDuplicateName(t *testing.T) { + conn := newConn(t) + mq := typed.NewMultiQuery[widget](conn) + q := typed.NewClient[widget](conn).Query(context.Background()) + mq.Add("byName", q) + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic on duplicate block name") + } + }() + mq.Add("byName", q) +} + +func TestMultiQueryAddRejectsSameQueryTwice(t *testing.T) { + conn := newConn(t) + mq := typed.NewMultiQuery[widget](conn) + q := typed.NewClient[widget](conn).Query(context.Background()) + mq.Add("first", q) + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic when the same Query is added under two names") + } + }() + // Execute names the block's underlying query in place, so reusing one Query + // pointer would corrupt block composition; Add must reject it up front. 
+ mq.Add("second", q) +} + +func TestMultiQueryExecuteReturnsPerBlockResults(t *testing.T) { + ctx := context.Background() + conn := newConn(t) + c := typed.NewClient[widget](conn) + + for _, w := range []*widget{ + {Name: "sprocket", Qty: 1}, + {Name: "gear", Qty: 5}, + {Name: "bolt", Qty: 10}, + } { + if err := c.Add(ctx, w); err != nil { + t.Fatalf("Add %s: %v", w.Name, err) + } + } + + mq := typed.NewMultiQuery[widget](conn) + mq.Add("all", c.Query(ctx)) + mq.Add("filtered", c.Query(ctx).Filter("eq(name, $1)", "gear")) + + results, err := mq.Execute(ctx) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if got := len(results["all"]); got != 3 { + t.Fatalf("results[all] has %d rows, want 3", got) + } + if got := len(results["filtered"]); got != 1 { + t.Fatalf("results[filtered] has %d rows, want 1", got) + } + if results["filtered"][0].Name != "gear" { + t.Fatalf("results[filtered][0].Name = %q, want gear", results["filtered"][0].Name) + } +} + +func TestMultiQueryExecuteEmptyReturnsEmptyMap(t *testing.T) { + conn := newConn(t) + mq := typed.NewMultiQuery[widget](conn) + results, err := mq.Execute(context.Background()) + if err != nil { + t.Fatalf("Execute on empty MultiQuery: %v", err) + } + if len(results) != 0 { + t.Fatalf("expected empty map, got %v", results) + } +} + +// renamed exercises the predicate-vs-json-tag remap. Dgraph returns the +// "thingName" key (the predicate name) but the struct's JSON tag is +// "name"; MultiQuery.Execute must remap before unmarshaling so Name +// populates. 
+type renamed struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + Name string `json:"name,omitempty" dgraph:"predicate=thingName index=hash,fulltext"` + Qty int `json:"qty,omitempty" dgraph:"index=int"` +} + +func TestMultiQueryExecuteRemapsPredicateKeys(t *testing.T) { + ctx := context.Background() + conn := newConn(t) + c := typed.NewClient[renamed](conn) + + for _, w := range []*renamed{ + {Name: "alpha", Qty: 1}, + {Name: "beta", Qty: 2}, + } { + if err := c.Add(ctx, w); err != nil { + t.Fatalf("Add %s: %v", w.Name, err) + } + } + + mq := typed.NewMultiQuery[renamed](conn) + mq.Add("all", c.Query(ctx)) + results, err := mq.Execute(ctx) + if err != nil { + t.Fatalf("Execute: %v", err) + } + rows := results["all"] + if len(rows) != 2 { + t.Fatalf("rows = %d, want 2", len(rows)) + } + for _, r := range rows { + if r.Name == "" { + t.Fatalf("Name not populated; multi-block response was not remapped from predicate key: %+v", r) + } + } +} + +// nestedChild sits on a nested edge and carries a renamed predicate: dgraph +// names it "childLabel" but the struct tags it json:"label". +type nestedChild struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + Label string `json:"label,omitempty" dgraph:"predicate=childLabel index=exact"` +} + +// nestedParent has no renamed top-level predicate, only a renamed one on its +// nested child. A top-level-only remap therefore skips the block entirely and +// leaves Children[i].Label empty; the recursive remap descends the edge. 
+type nestedParent struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + Name string `json:"name,omitempty" dgraph:"index=exact"` + Children []*nestedChild `json:"children,omitempty"` +} + +func TestMultiQueryExecuteRemapsNestedPredicateKeys(t *testing.T) { + ctx := context.Background() + conn := newConn(t) + c := typed.NewClient[nestedParent](conn) + + p := &nestedParent{Name: "root", Children: []*nestedChild{{Label: "leaf"}}} + if err := c.Add(ctx, p); err != nil { + t.Fatalf("Add: %v", err) + } + + mq := typed.NewMultiQuery[nestedParent](conn) + mq.Add("all", c.Query(ctx)) + results, err := mq.Execute(ctx) + if err != nil { + t.Fatalf("Execute: %v", err) + } + rows := results["all"] + if len(rows) != 1 { + t.Fatalf("rows = %d, want 1", len(rows)) + } + if len(rows[0].Children) != 1 { + t.Fatalf("Children = %d, want 1", len(rows[0].Children)) + } + // With the old top-level-only remap this is empty: the nested "childLabel" + // predicate key is never rewritten to the struct's json:"label". + if got := rows[0].Children[0].Label; got != "leaf" { + t.Fatalf("nested Children[0].Label = %q, want leaf; nested predicate key not remapped", got) + } +} diff --git a/typed/option.go b/typed/option.go new file mode 100644 index 0000000..d944483 --- /dev/null +++ b/typed/option.go @@ -0,0 +1,17 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed + +// Option configures a *T. Generated With constructors return an Option; +// generated New/Wrap constructors apply them via Apply. +type Option[T any] func(*T) + +// Apply applies opts to target in declaration order. 
+func Apply[T any](target *T, opts ...Option[T]) { + for _, opt := range opts { + opt(target) + } +} diff --git a/typed/option_test.go b/typed/option_test.go new file mode 100644 index 0000000..7c1f378 --- /dev/null +++ b/typed/option_test.go @@ -0,0 +1,37 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed_test + +import ( + "strings" + "testing" + + "github.com/matthewmcneely/modusgraph/typed" +) + +func TestApply_RunsOptionsInOrder(t *testing.T) { + type rec struct{ trail []string } + r := &rec{} + + typed.Apply(r, + func(x *rec) { x.trail = append(x.trail, "a") }, + func(x *rec) { x.trail = append(x.trail, "b") }, + func(x *rec) { x.trail = append(x.trail, "c") }, + ) + + if got := strings.Join(r.trail, ""); got != "abc" { + t.Fatalf("Apply ran options as %q, want %q", got, "abc") + } +} + +func TestApply_NoOptionsIsNoop(t *testing.T) { + type rec struct{ n int } + r := &rec{n: 7} + typed.Apply(r) + if r.n != 7 { + t.Fatalf("Apply with no options mutated target: n = %d, want 7", r.n) + } +} diff --git a/typed/query.go b/typed/query.go new file mode 100644 index 0000000..7be8dab --- /dev/null +++ b/typed/query.go @@ -0,0 +1,664 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed + +import ( + "context" + "encoding/json" + "fmt" + "iter" + "reflect" + "strconv" + "strings" + + dg "github.com/dolan-in/dgman/v2" + "github.com/matthewmcneely/modusgraph" +) + +// Block names and the query-variable name used by the WhereEdge server-side var +// query. The var block binds matched root UIDs; the data and count blocks +// consume uid(edgeVarName), so the UIDs never leave the server. +const ( + edgeVarName = "mgMatched" + edgeDataBlock = "mgData" + edgeCountBlock = "mgCount" +) + +// Query is a fluent, type-safe query builder over records of type T. 
Builder +// methods return *Query[T] for chaining, except As, Var, and GroupBy, which +// change the result shape and transition to *RawQuery; terminal methods +// (Nodes, First, IterNodes, NodesAndCount) execute the query and decode typed +// results. +// +// A Query is single-use. Builder methods mutate the underlying query in place +// and return the same *Query, so a Query value should be built as one chain +// and handed to a single terminal. It is not safe to save a Query to a +// variable and branch it into independent queries: every branch shares — and +// keeps mutating — the same underlying query. +// +// Repeated builder calls do not all behave the same way. Limit, Offset, After, +// Cascade, Name, RootFunc, and Vars overwrite: the last call wins. Filter, +// OrderAsc, OrderDesc, and WhereEdge accumulate: each call adds to the query. +// Accumulated Filter fragments AND together (see CombinedFilter, OrGroup). +// +// Limit and Offset additionally record the bounds that IterNodes pages +// within — a Limit caps the rows it streams, an Offset is its start. +type Query[T any] struct { + q *dg.Query + conn modusgraph.Client // runs the WhereEdge multi-block request; set by Client.Query + ctx context.Context // carried for the WhereEdge request + limit int // caller-set row cap; 0 = unbounded + offset int // caller-set starting offset; 0 = none + edges []edgeFilter // accumulated WhereEdge constraints; empty = none + filters []filterFrag // accumulated @filter fragments, ANDed; empty = none + + // customRootExpr is the caller's root narrowing (set by UID or RootFunc), or + // "" if none. The WhereEdge var block roots at it, so the matched UIDs are the + // intersection of the caller's root and the edge constraints rather than + // overwriting the caller's root (see edgeVarBlock). + customRootExpr string + + // varsFuncDef and varsMap hold GraphQL named variables set via Vars. 
The + // WhereEdge path renders its own multi-block request, so runEdge forwards + // these to the QueryBlock and QueryRaw; without that they would ride only on + // qb.q and be dropped when the request is composed and run as raw DQL. + varsFuncDef string + varsMap map[string]string +} + +// edgeFilter is one accumulated WhereEdge constraint: a dgraph @filter +// expression scoped to an outbound edge predicate of T. +type edgeFilter struct { + predicate string + filter string + params []any +} + +// filterFrag is one accumulated @filter fragment. Fragments join with AND. +type filterFrag struct { + expr string + params []any +} + +// NewDetachedQuery returns a Query[T] with no connection, used only to +// accumulate a filter expression: its By/Filter calls record fragments +// that CombinedFilter reads back. It must not be executed (it has no terminal +// path) and exists as the capture target behind the generated Or and +// WhereBy combinators. +func NewDetachedQuery[T any]() *Query[T] { + return &Query[T]{} +} + +// Filter adds a dgraph @filter expression. params bind to placeholders. +// Repeated calls accumulate: every fragment ANDs together. +func (qb *Query[T]) Filter(filter string, params ...any) *Query[T] { + qb.addFilter(filter, params) + return qb +} + +// addFilter accumulates one @filter fragment. Fragments AND together: the +// effective filter is every fragment joined with AND, each fragment's $N +// placeholders shifted to stay bound to its own params. dgman's own Filter is +// last-write-wins, so the full combined expression is re-pushed on every call. +// A detached query (nil q — used to capture a sub-scope's filter for OrGroup or +// WhereBy) accumulates with no dgman query to push to; CombinedFilter +// reads the fragments back. 
+func (qb *Query[T]) addFilter(expr string, params []any) { + if expr == "" { + return + } + qb.filters = append(qb.filters, filterFrag{expr: expr, params: params}) + if qb.q != nil { + combined, cp := combineAnd(qb.filters) + qb.q.Filter(combined, cp...) + } +} + +// combineAnd joins fragments with AND, renumbering each fragment's ordinal +// placeholders against the concatenated params slice. Each fragment is wrapped +// in its own parentheses so a fragment that itself contains OR keeps its +// intended precedence: without the parens, "a OR b" ANDed with "c" would parse +// as "a OR (b AND c)" because dgraph binds AND tighter than OR. +func combineAnd(frags []filterFrag) (string, []any) { + parts := make([]string, 0, len(frags)) + var params []any + for _, f := range frags { + if f.expr == "" { + continue + } + parts = append(parts, "("+shiftPlaceholders(f.expr, len(params))+")") + params = append(params, f.params...) + } + if len(parts) == 0 { + return "", nil + } + return strings.Join(parts, " AND "), params +} + +// CombinedFilter returns the AND-combined accumulated @filter expression and +// its params, or ("", nil) when no filter was set. It is the substrate behind +// the generated Or and WhereBy combinators: they run a sub-scope's +// By/Filter calls against a detached query, then fold the captured +// expression into a parent OR group or edge constraint. +func (qb *Query[T]) CombinedFilter() (string, []any) { + return combineAnd(qb.filters) +} + +// OrGroup adds one @filter group that ORs the combined filter of each sub. +// Each sub is a detached Query[T] whose By/Filter calls have been +// accumulated; their combined (AND) expressions are parenthesized, joined with +// OR, and the whole OR group ANDs with the receiver's other filters. Subs with +// an empty filter are skipped; an all-empty OrGroup is a no-op. It is the +// substrate behind the generated Query.Or combinator. 
+func (qb *Query[T]) OrGroup(subs ...*Query[T]) *Query[T] { + parts := make([]string, 0, len(subs)) + var params []any + for _, s := range subs { + e, p := s.CombinedFilter() + if e == "" { + continue + } + parts = append(parts, "("+shiftPlaceholders(e, len(params))+")") + params = append(params, p...) + } + if len(parts) == 0 { + return qb + } + qb.addFilter("("+strings.Join(parts, " OR ")+")", params) + return qb +} + +// OrderAsc orders results ascending by clause. +func (qb *Query[T]) OrderAsc(clause string) *Query[T] { + qb.q.OrderAsc(clause) + return qb +} + +// OrderDesc orders results descending by clause. +func (qb *Query[T]) OrderDesc(clause string) *Query[T] { + qb.q.OrderDesc(clause) + return qb +} + +// Limit caps the number of results. dgman names this First; it is renamed +// here so it does not collide with the First terminal. +func (qb *Query[T]) Limit(n int) *Query[T] { + qb.limit = n + qb.q.First(n) + return qb +} + +// Offset skips the first n results. +func (qb *Query[T]) Offset(n int) *Query[T] { + qb.offset = n + qb.q.Offset(n) + return qb +} + +// After returns results with UID greater than uid (cursor pagination). +func (qb *Query[T]) After(uid string) *Query[T] { + qb.q.After(uid) + return qb +} + +// Cascade drops nodes missing any of the given predicates (all, if none given). +func (qb *Query[T]) Cascade(predicates ...string) *Query[T] { + qb.q.Cascade(predicates...) + return qb +} + +// RootFunc overrides the query root function. dgman's default root function +// is type(); RootFunc replaces it with an expression such as +// eq(name, "Alice") or has(email). Repeated calls overwrite. +func (qb *Query[T]) RootFunc(rootFunc string) *Query[T] { + qb.customRootExpr = rootFunc + qb.q.RootFunc(rootFunc) + return qb +} + +// Name sets the query block name. It defaults to "data"; dgman uses the name +// to both generate and decode the query, so a renamed block still decodes +// into []T. Repeated calls overwrite. 
+func (qb *Query[T]) Name(queryName string) *Query[T] { + qb.q.Name(queryName) + return qb +} + +// Vars supplies GraphQL variables for a parameterized query: funcDef is the +// query function definition (for example "getByName($n: string)") and vars +// binds each variable. The query then executes via dgraph's QueryWithVars +// path. Repeated calls overwrite. +func (qb *Query[T]) Vars(funcDef string, vars map[string]string) *Query[T] { + qb.varsFuncDef = funcDef + qb.varsMap = vars + qb.q.Vars(funcDef, vars) + return qb +} + +// WhereEdge constrains results to records that have at least one `predicate` +// edge whose target node satisfies the dgraph @filter expression. params bind +// to $N placeholders within filter, exactly as Filter binds them. +// +// Where Filter constrains T's own scalar predicates, WhereEdge constrains a +// neighbouring node reached over an edge. dgraph's root @filter cannot express +// that, so a query carrying WhereEdge constraints executes as one request built +// around a server-side var block: an @cascade block binds the UIDs of roots +// that satisfy every constraint, and the data block roots at uid(...) of that +// var — keeping ordering, pagination, and result projection on the normal path +// while the matched UIDs never leave the server. See +// docs/specs/2026-05-21-query-edge-filter-design.md. +// +// WhereEdge accumulates: multiple calls AND together (a record must satisfy +// every edge constraint). It is the substrate behind the generated +// Query.Where methods. +func (qb *Query[T]) WhereEdge(predicate, filter string, params ...any) *Query[T] { + qb.edges = append(qb.edges, edgeFilter{predicate: predicate, filter: filter, params: params}) + return qb +} + +// WhereAnyOfText adds an @filter(anyoftext(predicate, $1)) clause. It +// accumulates and ANDs with other filters like Filter. 
+func (qb *Query[T]) WhereAnyOfText(predicate, term string) *Query[T] { + qb.addFilter(fmt.Sprintf("anyoftext(%s, $1)", predicate), []any{term}) + return qb +} + +// WhereAllOfText adds an @filter(alloftext(predicate, $1)) clause. It +// accumulates and ANDs with other filters like Filter. +func (qb *Query[T]) WhereAllOfText(predicate, term string) *Query[T] { + qb.addFilter(fmt.Sprintf("alloftext(%s, $1)", predicate), []any{term}) + return qb +} + +// As names the query block as a dgraph query variable. dgraph requires such a +// variable be consumed by another block, which a single-block typed query +// cannot do, so As transitions out of the typed query: it returns a *RawQuery, +// which exposes no node terminal. +func (qb *Query[T]) As(varName string) *RawQuery { + qb.q.As(varName) + return &RawQuery{q: qb.q} +} + +// Var marks the query block as a dgraph var block. A var block computes query +// variables and returns no data of its own, so Var transitions out of the +// typed query: it returns a *RawQuery, which exposes no node terminal. +func (qb *Query[T]) Var() *RawQuery { + qb.q.Var() + return &RawQuery{q: qb.q} +} + +// GroupBy adds an @groupby(predicate) aggregation. A grouped query returns +// aggregation groups rather than a slice of T, so GroupBy transitions out of +// the typed query: it returns a *RawQuery, which exposes no node terminal. +func (qb *Query[T]) GroupBy(predicate string) *RawQuery { + qb.q.GroupBy(predicate) + return &RawQuery{q: qb.q} +} + +// Nodes executes the query and returns all matching records. 
+func (qb *Query[T]) Nodes() (out []T, err error) {
+	_, span := currentTracer().StartSpan(qb.ctx, "query", entityName[T]())
+	defer func() { span.End(err) }()
+	if len(qb.edges) > 0 {
+		out, _, err = qb.runEdge(false)
+		return out, err
+	}
+	if err = qb.q.Nodes(&out); err != nil {
+		return nil, err
+	}
+	return out, nil
+}
+
+// First executes the query with an implicit Limit(1) and returns the first
+// record, or (nil, nil) if the query matched no rows.
+func (qb *Query[T]) First() (rec *T, err error) {
+	_, span := currentTracer().StartSpan(qb.ctx, "query", entityName[T]())
+	defer func() { span.End(err) }()
+	var out []T
+	if len(qb.edges) > 0 {
+		qb.q.First(1)
+		out, _, err = qb.runEdge(false)
+	} else {
+		err = qb.q.First(1).Nodes(&out)
+	}
+	if err != nil {
+		return nil, err
+	}
+	if len(out) == 0 {
+		return nil, nil
+	}
+	return &out[0], nil
+}
+
+// IterNodes executes the query and returns an iterator over matching records,
+// paging transparently so a large result set is never materialized at once.
+//
+// IterNodes is a terminal operation: it drives Offset/Limit internally as it
+// pages and leaves the builder spent; do not call another terminal on the
+// same Query afterward. A Limit set on the query caps the total number of
+// rows streamed; an Offset is the starting point.
+//
+// With no WhereEdge constraints, every page executes against one read-only
+// transaction, so the iteration reads a single consistent snapshot: a
+// concurrent writer cannot make it skip or repeat rows. With WhereEdge
+// constraints, each page is its own request that re-resolves the server-side
+// match var, keeping memory bounded at the cost of reading each page from a
+// fresh snapshot. On error it yields a final (nil, err) and stops.
+func (qb *Query[T]) IterNodes() iter.Seq2[*T, error] {
+	return func(yield func(*T, error) bool) {
+		_, span := currentTracer().StartSpan(qb.ctx, "query", entityName[T]())
+		var ferr error
+		defer func() { span.End(ferr) }()
+		remaining := qb.limit // 0 = unbounded
+		for off := qb.offset; ; off += defaultPageSize {
+			size := defaultPageSize
+			if remaining > 0 && remaining < size {
+				size = remaining // shrink the last page so it can't overshoot the cap
+			}
+			var page []T
+			var err error
+			if len(qb.edges) > 0 {
+				// Each page re-resolves the WhereEdge var server-side, so no page
+				// materializes the full matched-UID set.
+				qb.q.Offset(off).First(size)
+				page, _, err = qb.runEdge(false)
+			} else {
+				err = qb.q.Offset(off).First(size).Nodes(&page)
+			}
+			if err != nil {
+				ferr = err
+				yield(nil, err)
+				return
+			}
+			for i := range page {
+				if !yield(&page[i], nil) {
+					return // consumer broke out
+				}
+			}
+			if remaining > 0 {
+				if remaining -= len(page); remaining <= 0 {
+					return // hit the caller's Limit
+				}
+			}
+			if len(page) < size {
+				return // result set exhausted
+			}
+		}
+	}
+}
+
+// Raw returns the underlying dgman query for operations Query does not wrap
+// (for example the raw-selection Query method). Raw does not carry WhereEdge
+// constraints; those are resolved only when a terminal runs.
+func (qb *Query[T]) Raw() *dg.Query {
+	return qb.q
+}
+
+// UID roots the query at a specific node UID. Results still decode into []T.
+func (qb *Query[T]) UID(uid string) *Query[T] {
+	qb.customRootExpr = "uid(" + uid + ")"
+	qb.q.UID(uid)
+	return qb
+}
+
+// All sets the edge-traversal depth for this query, overriding the client's
+// default maxEdgeTraversal. Use a small depth to stay under Dgraph's 4MB gRPC
+// limit on highly-connected entities.
+func (qb *Query[T]) All(depth int) *Query[T] {
+	qb.q.All(depth)
+	return qb
+}
+
+// NodesAndCount executes the query and returns the matching records together
+// with the total count (useful for pagination totals).
+// With edge constraints it runs as a single WhereEdge request, like Nodes.
+func (qb *Query[T]) NodesAndCount() (out []T, count int, err error) {
+	_, span := currentTracer().StartSpan(qb.ctx, "query", entityName[T]())
+	defer func() { span.End(err) }()
+	if len(qb.edges) > 0 {
+		return qb.runEdge(true)
+	}
+	count, err = qb.q.NodesAndCount(&out)
+	if err != nil {
+		return nil, 0, err
+	}
+	return out, count, nil
+}
+
+// String renders the generated DQL without executing it. WhereEdge constraints
+// are not reflected; they are resolved only when a terminal runs.
+func (qb *Query[T]) String() string {
+	return qb.q.String()
+}
+
+// FormatBlock renders the query as a single DQL block named name, without
+// executing it. The returned text is suitable for inclusion inside a wrapping
+// "{ ... }" multi-block request; it does not include outer braces.
+//
+// FormatBlock is the substrate behind MultiQuery; external callers can use it
+// to compose typed queries into larger hand-written DQL requests.
+//
+// Filter parameters are inlined at Filter-call time (dgman renders $N
+// placeholders into the filter string immediately), so the returned block
+// carries no unresolved variables. A Query with WhereEdge constraints returns
+// an error: they render as a multi-block request, which one block cannot hold.
+func (qb *Query[T]) FormatBlock(name string) (string, error) {
+	if len(qb.edges) != 0 {
+		return "", fmt.Errorf("typed: FormatBlock cannot render a Query carrying WhereEdge constraints")
+	}
+	qb.q.Name(name)
+	wrapped := dg.NewQueryBlock(qb.q).String()
+	// QueryBlock.String() wraps the block in "{\n ... }"; strip the wrapper so
+	// the caller can compose blocks inside their own braces.
+	inner := strings.TrimPrefix(wrapped, "{\n")
+	inner = strings.TrimSuffix(inner, "}")
+	return inner, nil
+}
+
+// RawQuery is a query whose result is not a slice of T, produced by the
+// shape-changing builders Query.As, Query.Var, and Query.GroupBy.
+// A RawQuery deliberately exposes no typed node terminal: its result must be
+// decoded by the caller through the underlying dgman query, obtained via Raw.
+type RawQuery struct {
+	q *dg.Query
+}
+
+// Raw returns the underlying dgman query, for the caller to execute and decode.
+func (r *RawQuery) Raw() *dg.Query {
+	return r.q
+}
+
+// String returns the generated DQL.
+func (r *RawQuery) String() string {
+	return r.q.String()
+}
+
+// As names the block as a dgraph query variable. See Query.As.
+func (r *RawQuery) As(varName string) *RawQuery {
+	r.q.As(varName)
+	return r
+}
+
+// Var marks the block as a dgraph var block. See Query.Var.
+func (r *RawQuery) Var() *RawQuery {
+	r.q.Var()
+	return r
+}
+
+// GroupBy adds an @groupby(predicate) aggregation. See Query.GroupBy.
+func (r *RawQuery) GroupBy(predicate string) *RawQuery {
+	r.q.GroupBy(predicate)
+	return r
+}
+
+// runEdge executes a WhereEdge query as a single server-side request: a var
+// block binds the matched root UIDs, and the data block (plus a count block when
+// withCount) consumes uid(mgMatched). The matched UIDs stay on the server: they
+// are never materialized on the client or inlined into a uid(...) literal, so
+// memory and DQL size stay bounded regardless of how many roots match.
+//
+// runEdge is idempotent in qb: edgeBlocks pushes the data-block filter
+// last-write-wins onto qb.q and never mutates the accumulated filters, so
+// IterNodes can call runEdge once per page (each page re-resolves the var
+// server-side).
+func (qb *Query[T]) runEdge(withCount bool) (rows []T, count int, err error) {
+	block := dg.NewQueryBlock(qb.edgeBlocks(withCount)...)
+	// Forward any GraphQL named variables set via Vars: dgman renders the
+	// query declaration (the funcDef passed to Vars) only when the QueryBlock
+	// carries them, and QueryRaw binds them at execution.
+	if qb.varsMap != nil {
+		block.Vars(qb.varsFuncDef, qb.varsMap)
+	}
+	raw, err := qb.conn.QueryRaw(qb.ctx, block.String(), qb.varsMap)
+	if err != nil {
+		return nil, 0, fmt.Errorf("typed: WhereEdge query: %w", err)
+	}
+	var perBlock map[string]json.RawMessage
+	if err := json.Unmarshal(raw, &perBlock); err != nil {
+		return nil, 0, fmt.Errorf("typed: decoding WhereEdge response: %w", err)
+	}
+	if body, ok := perBlock[edgeDataBlock]; ok {
+		remapped, rerr := remapPredicateKeys(body, reflect.TypeFor[T]())
+		if rerr != nil {
+			return nil, 0, fmt.Errorf("typed: remapping WhereEdge rows: %w", rerr)
+		}
+		if err := json.Unmarshal(remapped, &rows); err != nil {
+			return nil, 0, fmt.Errorf("typed: decoding WhereEdge rows: %w", err)
+		}
+	}
+	if withCount {
+		count, err = decodeCount(perBlock[edgeCountBlock])
+		if err != nil {
+			return nil, 0, err
+		}
+	}
+	return rows, count, nil
+}
+
+// edgeBlocks assembles the var block, the data block, and (when withCount) the
+// count block for a WhereEdge query. The matched UIDs are captured in the var
+// block and consumed by uid(mgMatched), never inlined as a literal list.
+//
+// It pushes the data-block filter last-write-wins onto qb.q and leaves the
+// accumulated qb.filters untouched, so it is safe to call once per IterNodes
+// page. The caller's @filter is captured before the uid() term is appended, so
+// the count block re-applies the same user filter without it.
+func (qb *Query[T]) edgeBlocks(withCount bool) []*dg.Query {
+	userExpr, userParams := combineAnd(qb.filters)
+
+	dataExpr := "uid(" + edgeVarName + ")"
+	if userExpr != "" {
+		dataExpr = "(" + userExpr + ") AND uid(" + edgeVarName + ")"
+	}
+	qb.q.Filter(dataExpr, userParams...).Name(edgeDataBlock)
+
+	blocks := []*dg.Query{qb.edgeVarBlock(), qb.q}
+	if withCount {
+		blocks = append(blocks, qb.edgeCountBlock(userExpr, userParams))
+	}
+	return blocks
+}
+
+// edgeVarBlock builds the var block that binds mgMatched to the roots surviving
+// @cascade over every WhereEdge constraint. It roots at the caller's narrowing
+// (UID/RootFunc) when present, so mgMatched is the intersection of the caller's
+// root and the edge constraints rather than discarding the caller's root.
+func (qb *Query[T]) edgeVarBlock() *dg.Query {
+	var z T
+	body, params := qb.edgeMatchBody()
+	v := qb.conn.Query(qb.ctx, &z)
+	if qb.customRootExpr != "" {
+		v.RootFunc(qb.customRootExpr)
+	}
+	v.As(edgeVarName).Var().Cascade().Query(body, params...)
+	return v
+}
+
+// edgeCountBlock builds the count block: count(uid) over uid(mgMatched) with the
+// caller's @filter re-applied, so the total matches the rows the data block
+// would return without pagination.
+func (qb *Query[T]) edgeCountBlock(userExpr string, userParams []any) *dg.Query {
+	var z T
+	c := qb.conn.Query(qb.ctx, &z)
+	c.RootFunc("uid(" + edgeVarName + ")")
+	if userExpr != "" {
+		c.Filter(userExpr, userParams...)
+	}
+	c.Query("{ count(uid) }").Name(edgeCountBlock)
+	return c
+}
+
+// decodeCount reads the count(uid) aggregation from a count block body of the
+// form [{"count": N}].
+func decodeCount(body json.RawMessage) (int, error) {
+	if len(body) == 0 {
+		return 0, nil
+	}
+	var rows []struct {
+		Count int `json:"count"`
+	}
+	if err := json.Unmarshal(body, &rows); err != nil {
+		return 0, fmt.Errorf("typed: decoding WhereEdge count: %w", err)
+	}
+	if len(rows) == 0 {
+		return 0, nil
+	}
+	return rows[0].Count, nil
+}
+
+// edgeMatchBody renders the selection set for the var block: uid plus one
+// aliased, filtered block per WhereEdge constraint. The caller adds a bare
+// @cascade, which then drops any node with an empty block, so a survivor
+// satisfies every constraint. Blocks are aliased mg_e0, mg_e1, ... so two
+// constraints on the same predicate do not collide as duplicate fields. Each
+// fragment's $N placeholders are shifted to stay bound to its own params once
+// every fragment's params are concatenated into one slice.
+func (qb *Query[T]) edgeMatchBody() (body string, params []any) {
+	var b strings.Builder
+	b.WriteString("{\n\tuid\n")
+	for i, e := range qb.edges {
+		b.WriteString("\tmg_e")
+		b.WriteString(strconv.Itoa(i))
+		b.WriteString(" : ")
+		b.WriteString(e.predicate)
+		b.WriteString(" @filter(")
+		b.WriteString(shiftPlaceholders(e.filter, len(params)))
+		b.WriteString(") { uid }\n")
+		params = append(params, e.params...)
+	}
+	b.WriteString("}")
+	return b.String(), params
+}
+
+// shiftPlaceholders rewrites dgman ordinal placeholders ($1, $2, ...) in expr,
+// adding delta to each index. WhereEdge filters are written independently, each
+// numbering its params from $1; concatenating them into one var-block body
+// needs every fragment renumbered against the combined params slice. A '$' not
+// followed by a digit is left as-is, matching dgman's parseQueryWithParams.
+func shiftPlaceholders(expr string, delta int) string {
+	if delta == 0 || !strings.ContainsRune(expr, '$') {
+		return expr
+	}
+	var b strings.Builder
+	for i := 0; i < len(expr); i++ {
+		if expr[i] != '$' {
+			b.WriteByte(expr[i])
+			continue
+		}
+		j := i + 1
+		for j < len(expr) && expr[j] >= '0' && expr[j] <= '9' {
+			j++
+		}
+		if j == i+1 { // '$' not followed by digits; leave verbatim
+			b.WriteByte('$')
+			continue
+		}
+		n, _ := strconv.Atoi(expr[i+1 : j])
+		b.WriteByte('$')
+		b.WriteString(strconv.Itoa(n + delta))
+		i = j - 1
+	}
+	return b.String()
+}
diff --git a/typed/query_internal_test.go b/typed/query_internal_test.go
new file mode 100644
index 0000000..026a668
--- /dev/null
+++ b/typed/query_internal_test.go
@@ -0,0 +1,80 @@
+/*
+ * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc.
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package typed
+
+import (
+	"context"
+	"reflect"
+	"strings"
+	"testing"
+
+	dg "github.com/dolan-in/dgman/v2"
+	"github.com/matthewmcneely/modusgraph"
+)
+
+// TestRemapPredicateKeys_SurfacesMalformedBlock guards the swallowed-error fix:
+// MultiQuery.Execute used to discard a remap error and fall through to a generic
+// downstream decode failure. remapPredicateKeys must now return the structural
+// error so Execute can wrap it as "remapping block %q".
+func TestRemapPredicateKeys_SurfacesMalformedBlock(t *testing.T) {
+	// A block body that opens as an array but is not valid JSON cannot be
+	// remapped; the error must propagate rather than be swallowed.
+	if _, err := remapPredicateKeys([]byte("[not valid json"), reflect.TypeFor[ivOwner]()); err == nil {
+		t.Fatal("remapPredicateKeys swallowed a malformed-array error; it must surface it")
+	}
+	// A well-formed array remaps without error.
+	if _, err := remapPredicateKeys([]byte(`[{"name":"x"}]`), reflect.TypeFor[ivOwner]()); err != nil {
+		t.Fatalf("remapPredicateKeys errored on valid input: %v", err)
+	}
+}
+
+type ivPet struct {
+	UID   string   `json:"uid,omitempty"`
+	DType []string `json:"dgraph.type,omitempty"`
+	Name  string   `json:"name,omitempty" dgraph:"index=exact"`
+}
+
+type ivOwner struct {
+	UID   string   `json:"uid,omitempty"`
+	DType []string `json:"dgraph.type,omitempty"`
+	Name  string   `json:"name,omitempty" dgraph:"index=exact"`
+	Pets  []*ivPet `json:"pets,omitempty"`
+}
+
+// TestEdgeBlocksRenderServerSideVar asserts a WhereEdge query renders as a
+// server-side var block consumed via uid(mgMatched). The matched roots are
+// bound on the server and never inlined into a uid(0x1, 0x2, ...) list, which
+// is what keeps WhereEdge bounded regardless of how many roots match, the
+// concern that motivated replacing the eager client-side pre-pass.
+func TestEdgeBlocksRenderServerSideVar(t *testing.T) {
+	conn, err := modusgraph.NewClient("file://"+t.TempDir(), modusgraph.WithAutoSchema(true))
+	if err != nil {
+		t.Fatalf("NewClient: %v", err)
+	}
+	t.Cleanup(conn.Close)
+
+	qb := NewClient[ivOwner](conn).Query(context.Background()).
+		Filter(`eq(name, "Alice")`).
+		WhereEdge("pets", `eq(name, "Fido")`)
+
+	dql := dg.NewQueryBlock(qb.edgeBlocks(true)...).String()
+
+	for _, want := range []string{
+		"mgMatched as var(", // matched roots bound server-side
+		"@cascade",          // edge constraint enforced by cascade
+		"uid(mgMatched)",    // data and count blocks consume the var
+		"count(uid)",        // count block present
+	} {
+		if !strings.Contains(dql, want) {
+			t.Errorf("rendered DQL missing %q:\n%s", want, dql)
+		}
+	}
+	// The owner UIDs must never be inlined into the query; that was the
+	// unbounded behavior. A server-side var carries no uid(0x...) literal list.
+ if strings.Contains(dql, "uid(0x") { + t.Errorf("rendered DQL inlines UID literals (unbounded):\n%s", dql) + } +} diff --git a/typed/query_test.go b/typed/query_test.go new file mode 100644 index 0000000..fed4fc5 --- /dev/null +++ b/typed/query_test.go @@ -0,0 +1,1400 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed_test + +import ( + "context" + "strings" + "testing" + + dg "github.com/dolan-in/dgman/v2" + "github.com/go-logr/logr/funcr" + "github.com/matthewmcneely/modusgraph" + "github.com/matthewmcneely/modusgraph/typed" +) + +// newCountingConn builds a file-backed modusgraph client exactly like newConn, +// but wires in a logr.Logger that counts dgman query executions. dgman logs +// every executed query at verbosity 3 with the message "execute query"; the +// returned *int is incremented once per such log line. +// +// dgman's logger is process-global, and modusgraph allows only one live +// file-backed engine per process (see modusgraph.ErrSingletonOnly). Each call +// uses a fresh t.TempDir() URI for data isolation. Tests that use +// newCountingConn must NOT call t.Parallel(): a second live client would hit +// the engine singleton, and parallel tests would also corrupt the shared +// query count. +func newCountingConn(t *testing.T, count *int) modusgraph.Client { + t.Helper() + logger := funcr.New(func(_, args string) { + // funcr renders the message into args as `"msg"="execute query"`. + // Match that exact pair so unrelated dgman/pool log lines (which log + // other messages, e.g. "executeQuery" for query blocks) are ignored. 
+ if strings.Contains(args, `"msg"="execute query"`) { + *count++ + } + }, funcr.Options{Verbosity: 3}) + conn, err := modusgraph.NewClient("file://"+t.TempDir(), + modusgraph.WithAutoSchema(true), modusgraph.WithLogger(logger)) + if err != nil { + t.Fatalf("modusgraph.NewClient: %v", err) + } + t.Cleanup(conn.Close) + return conn +} + +func TestQuery_NodesReturnsAll(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + for _, n := range []string{"a", "b", "c"} { + if err := c.Add(ctx, &widget{Name: n}); err != nil { + t.Fatalf("Add %s: %v", n, err) + } + } + + got, err := c.Query(ctx).Nodes() + if err != nil { + t.Fatalf("Nodes: %v", err) + } + if len(got) != 3 { + t.Fatalf("Nodes returned %d records, want 3", len(got)) + } +} + +func TestQuery_LimitCapsResults(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + for i := range 5 { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + + got, err := c.Query(ctx).Limit(2).Nodes() + if err != nil { + t.Fatalf("Nodes: %v", err) + } + if len(got) != 2 { + t.Fatalf("Limit(2) returned %d records, want 2", len(got)) + } +} + +func TestQuery_FirstReturnsAMatch(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + if err := c.Add(ctx, &widget{Name: "only", Qty: 7}); err != nil { + t.Fatalf("Add: %v", err) + } + + got, err := c.Query(ctx).First() + if err != nil { + t.Fatalf("First: %v", err) + } + if got == nil || got.Name != "only" { + t.Fatalf("First returned %+v, want Name=only", got) + } +} + +func TestQuery_FirstNoMatchReturnsNilNil(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + got, err := c.Query(ctx).First() + if err != nil { + t.Fatalf("First on empty: unexpected error %v", err) + } + if got != nil { + t.Fatalf("First on empty returned %+v, want nil", got) + } +} + +func 
TestQuery_BuilderChainCompilesAndRuns(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + if err := c.Add(ctx, &widget{Name: "x", Qty: 1}); err != nil { + t.Fatalf("Add: %v", err) + } + + // Every builder method must return *Query[widget] so the chain stays typed. + _, err := c.Query(ctx). + OrderAsc("qty"). + Offset(0). + Limit(10). + Cascade(). + Nodes() + if err != nil { + t.Fatalf("builder chain Nodes: %v", err) + } +} + +func TestQuery_RawExposesUnderlyingBuilder(t *testing.T) { + c := typed.NewClient[widget](newConn(t)) + if c.Query(context.Background()).Raw() == nil { + t.Fatal("Raw() returned nil; expected the underlying *dg.Query") + } +} + +func TestQuery_Filter(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // Insert three widgets with distinct names. + for _, name := range []string{"alpha", "beta", "gamma"} { + if err := c.Add(ctx, &widget{Name: name}); err != nil { + t.Fatalf("Add %s: %v", name, err) + } + } + + // Filter to exactly those whose name equals "beta" (index=exact allows eq()). + got, err := c.Query(ctx).Filter(`eq(name, "beta")`).Nodes() + if err != nil { + t.Fatalf("Filter Nodes: %v", err) + } + if len(got) != 1 { + t.Fatalf("Filter returned %d records, want 1", len(got)) + } + if got[0].Name != "beta" { + t.Fatalf("Filter returned Name=%q, want beta", got[0].Name) + } +} + +func TestQuery_FilterAccumulatesWithAnd(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // Three widgets; only "beta"/9 satisfies BOTH name=="beta" and qty>=5. + for _, w := range []widget{ + {Name: "alpha", Qty: 9}, + {Name: "beta", Qty: 9}, + {Name: "beta", Qty: 1}, + } { + if err := c.Add(ctx, &w); err != nil { + t.Fatalf("Add %+v: %v", w, err) + } + } + + // Two Filter calls must AND together, not overwrite. With last-write-wins + // only ge(qty, 5) survives and this returns the two qty>=5 rows instead of + // the single AND match. 
+ got, err := c.Query(ctx). + Filter(`eq(name, "beta")`). + Filter(`ge(qty, "5")`). + Nodes() + if err != nil { + t.Fatalf("Filter Nodes: %v", err) + } + if len(got) != 1 { + t.Fatalf("two ANDed Filters returned %d records, want 1 (name==beta AND qty>=5)", len(got)) + } + if got[0].Name != "beta" || got[0].Qty != 9 { + t.Fatalf("got %+v, want Name=beta Qty=9", got[0]) + } +} + +func TestQuery_CombinedFilterShiftsPlaceholders(t *testing.T) { + q := typed.NewDetachedQuery[widget]() + if expr, params := q.CombinedFilter(); expr != "" || params != nil { + t.Fatalf("empty CombinedFilter = (%q, %v), want (\"\", nil)", expr, params) + } + q.Filter("eq(name, $1)", "a") + q.Filter("eq(qty, $1)", 7) + expr, params := q.CombinedFilter() + const want = "(eq(name, $1)) AND (eq(qty, $2))" + if expr != want { + t.Fatalf("CombinedFilter expr = %q, want %q", expr, want) + } + if len(params) != 2 || params[0] != "a" || params[1] != 7 { + t.Fatalf("CombinedFilter params = %v, want [a 7]", params) + } +} + +// TestQuery_CombinedFilterParenthesizesFragments pins the precedence guarantee: +// a fragment that contains OR must stay grouped when it is ANDed with another +// fragment. Without per-fragment parentheses the expression would render as +// "a OR b AND c", which dgraph parses as "a OR (b AND c)" — silently widening +// the result set. 
+func TestQuery_CombinedFilterParenthesizesFragments(t *testing.T) { + q := typed.NewDetachedQuery[widget]() + q.Filter(`eq(name, "alpha") OR eq(name, "beta")`) + q.Filter(`ge(qty, "5")`) + expr, _ := q.CombinedFilter() + const want = `(eq(name, "alpha") OR eq(name, "beta")) AND (ge(qty, "5"))` + if expr != want { + t.Fatalf("CombinedFilter precedence: expr = %q, want %q", expr, want) + } +} + +func TestQuery_OrGroup(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + for _, w := range []widget{ + {Name: "alpha", Qty: 9}, + {Name: "beta", Qty: 9}, + {Name: "gamma", Qty: 1}, + } { + if err := c.Add(ctx, &w); err != nil { + t.Fatalf("Add %+v: %v", w, err) + } + } + + // name == "alpha" OR name == "gamma": two of three rows. + got, err := c.Query(ctx).OrGroup( + typed.NewDetachedQuery[widget]().Filter(`eq(name, "alpha")`), + typed.NewDetachedQuery[widget]().Filter(`eq(name, "gamma")`), + ).Nodes() + if err != nil { + t.Fatalf("OrGroup Nodes: %v", err) + } + if len(got) != 2 { + t.Fatalf("OrGroup(alpha, gamma) returned %d rows, want 2", len(got)) + } + + // AND-of-OR: qty>=5 AND (name==alpha OR name==gamma) → only alpha/9. + got, err = c.Query(ctx). + Filter(`ge(qty, "5")`). + OrGroup( + typed.NewDetachedQuery[widget]().Filter(`eq(name, "alpha")`), + typed.NewDetachedQuery[widget]().Filter(`eq(name, "gamma")`), + ).Nodes() + if err != nil { + t.Fatalf("AND-of-OR Nodes: %v", err) + } + if len(got) != 1 || got[0].Name != "alpha" { + t.Fatalf("qty>=5 AND (alpha OR gamma) returned %+v, want [alpha/9]", got) + } +} + +func TestQuery_OrderAscDesc(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // Insert widgets with distinct Qty values in non-sorted order so a + // stable natural ordering cannot hide a missing sort. 
+ qtys := []int{30, 10, 50, 20, 40} + for i, q := range qtys { + if err := c.Add(ctx, &widget{Name: "w", Qty: q}); err != nil { + t.Fatalf("Add widget[%d]: %v", i, err) + } + } + + // Ascending. + asc, err := c.Query(ctx).OrderAsc("qty").Nodes() + if err != nil { + t.Fatalf("OrderAsc Nodes: %v", err) + } + if len(asc) != len(qtys) { + t.Fatalf("OrderAsc returned %d records, want %d", len(asc), len(qtys)) + } + for i := range len(asc) - 1 { + if asc[i].Qty > asc[i+1].Qty { + t.Fatalf("OrderAsc: asc[%d].Qty=%d > asc[%d].Qty=%d; not ascending", + i, asc[i].Qty, i+1, asc[i+1].Qty) + } + } + + // Descending. + desc, err := c.Query(ctx).OrderDesc("qty").Nodes() + if err != nil { + t.Fatalf("OrderDesc Nodes: %v", err) + } + if len(desc) != len(qtys) { + t.Fatalf("OrderDesc returned %d records, want %d", len(desc), len(qtys)) + } + for i := range len(desc) - 1 { + if desc[i].Qty < desc[i+1].Qty { + t.Fatalf("OrderDesc: desc[%d].Qty=%d < desc[%d].Qty=%d; not descending", + i, desc[i].Qty, i+1, desc[i+1].Qty) + } + } +} + +func TestQuery_OffsetSkipsResults(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // Five widgets with distinct, deliberately unsorted Qty values. + qtys := []int{40, 10, 50, 20, 30} + for i, q := range qtys { + if err := c.Add(ctx, &widget{Name: "w", Qty: q}); err != nil { + t.Fatalf("Add widget[%d]: %v", i, err) + } + } + + // Ordering ascending by qty gives 10,20,30,40,50; Offset(2) drops the + // first two, so 3 rows remain and the first is the 3rd-smallest (30). 
+ got, err := c.Query(ctx).OrderAsc("qty").Offset(2).Nodes() + if err != nil { + t.Fatalf("Offset Nodes: %v", err) + } + if len(got) != 3 { + t.Fatalf("OrderAsc.Offset(2) returned %d records, want 3", len(got)) + } + if got[0].Qty != 30 { + t.Fatalf("first row after Offset(2) has Qty=%d, want 30 (3rd-smallest)", got[0].Qty) + } +} + +func TestQuery_AfterCursor(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + for i := range 5 { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + + // First pass: grab all rows so we can pick a non-last cursor UID. + all, err := c.Query(ctx).Nodes() + if err != nil { + t.Fatalf("first Nodes: %v", err) + } + if len(all) < 3 { + t.Fatalf("expected at least 3 widgets, got %d", len(all)) + } + cursor := all[1].UID // a non-last row + + // After(cursor) uses default UID ordering to skip past the cursor node. + got, err := c.Query(ctx).After(cursor).Nodes() + if err != nil { + t.Fatalf("After Nodes: %v", err) + } + if len(got) == 0 { + t.Fatal("After(cursor) returned no rows; expected the rows past the cursor") + } + for _, w := range got { + if w.UID <= cursor { + t.Fatalf("After(%s) returned UID %s, which is not strictly greater than the cursor", + cursor, w.UID) + } + } +} + +func TestQuery_CascadeDropsIncompleteNodes(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // Widgets with Qty > 0 carry a qty predicate. Widgets with Qty left 0 + // have it omitted entirely (json tag is omitempty), so they have no qty + // predicate at all. 
+ withQty := []int{5, 9, 13} + for _, q := range withQty { + if err := c.Add(ctx, &widget{Name: "has-qty", Qty: q}); err != nil { + t.Fatalf("Add qty=%d: %v", q, err) + } + } + for i := range 4 { + if err := c.Add(ctx, &widget{Name: "no-qty"}); err != nil { + t.Fatalf("Add no-qty[%d]: %v", i, err) + } + } + + // @cascade(qty) drops any node that lacks the qty predicate. + got, err := c.Query(ctx).Cascade("qty").Nodes() + if err != nil { + t.Fatalf("Cascade Nodes: %v", err) + } + if len(got) != len(withQty) { + t.Fatalf("Cascade(qty) returned %d records, want %d (only the qty-bearing widgets)", + len(got), len(withQty)) + } + for _, w := range got { + if w.Qty == 0 { + t.Fatalf("Cascade(qty) returned a widget with Qty=0 (no qty predicate): %+v", w) + } + } +} + +func TestQuery_FilterOrderLimitOffsetCombined(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // A known set: five "keep" widgets plus a "drop" widget the filter excludes. + for _, q := range []int{50, 20, 40, 10, 30} { + if err := c.Add(ctx, &widget{Name: "keep", Qty: q}); err != nil { + t.Fatalf("Add keep qty=%d: %v", q, err) + } + } + if err := c.Add(ctx, &widget{Name: "drop", Qty: 99}); err != nil { + t.Fatalf("Add drop: %v", err) + } + + // Filter to name=keep -> qtys {10,20,30,40,50}; OrderAsc -> sorted; + // Offset(1) drops 10; Limit(2) keeps {20,30}. + got, err := c.Query(ctx). + Filter(`eq(name, "keep")`). + OrderAsc("qty"). + Offset(1). + Limit(2). 
+ Nodes() + if err != nil { + t.Fatalf("combined chain Nodes: %v", err) + } + if len(got) != 2 { + t.Fatalf("combined chain returned %d records, want 2", len(got)) + } + if got[0].Qty != 20 || got[1].Qty != 30 { + t.Fatalf("combined chain window = [%d, %d], want [20, 30]", got[0].Qty, got[1].Qty) + } +} + +func TestQuery_FirstOnMultipleRows(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + for _, q := range []int{30, 10, 20} { + if err := c.Add(ctx, &widget{Name: "w", Qty: q}); err != nil { + t.Fatalf("Add qty=%d: %v", q, err) + } + } + + // First on an ascending-by-qty query yields exactly the smallest row. + got, err := c.Query(ctx).OrderAsc("qty").First() + if err != nil { + t.Fatalf("First: %v", err) + } + if got == nil { + t.Fatal("First returned nil on a non-empty result set") + } + if got.Qty != 10 { + t.Fatalf("First on OrderAsc(qty) returned Qty=%d, want 10 (smallest)", got.Qty) + } +} + +func TestQuery_NodesEmptyResult(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) // fresh client, no inserts + + got, err := c.Query(ctx).Nodes() + if err != nil { + t.Fatalf("Nodes on empty client: unexpected error %v", err) + } + if len(got) != 0 { + t.Fatalf("Nodes on empty client returned %d records, want 0", len(got)) + } +} + +func TestQuery_OrderAccumulates(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // OrderAsc and OrderDesc accumulate: both clauses must survive on the + // same query. dgman renders them as "orderasc:"/"orderdesc:" in the + // generated query string. 
+ q := c.Query(ctx).OrderAsc("name").OrderDesc("qty") + s := q.Raw().String() + if !strings.Contains(s, "orderasc: name") { + t.Fatalf("query string missing ascending name order; got:\n%s", s) + } + if !strings.Contains(s, "orderdesc: qty") { + t.Fatalf("query string missing descending qty order; got:\n%s", s) + } +} + +func TestQuery_CascadeOverwrites(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // Cascade overwrites: the second call wins, the first predicate is gone. + // dgman renders predicates as @cascade(pred1,pred2,...) with no spaces. + q := c.Query(ctx).Cascade("name").Cascade("qty") + s := q.Raw().String() + if !strings.Contains(s, "@cascade(qty)") { + t.Fatalf("second Cascade(qty) not rendered in query string; got:\n%s", s) + } + if strings.Contains(s, "@cascade(name)") { + t.Fatalf("first Cascade(name) still present after overwrite; got:\n%s", s) + } +} + +func TestQuery_TerminalRunsTwice(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + for _, n := range []string{"a", "b", "c"} { + if err := c.Add(ctx, &widget{Name: n}); err != nil { + t.Fatalf("Add %s: %v", n, err) + } + } + + // A terminal is re-runnable: calling Nodes twice on the same builder + // succeeds both times and yields equal-length results. + q := c.Query(ctx) + first, err := q.Nodes() + if err != nil { + t.Fatalf("first Nodes: %v", err) + } + second, err := q.Nodes() + if err != nil { + t.Fatalf("second Nodes: %v", err) + } + if len(first) != len(second) { + t.Fatalf("Nodes run twice returned %d then %d records; want equal lengths", + len(first), len(second)) + } +} + +func TestQuery_BuilderAliasesAndAccumulates(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // (i) Filter accumulates: after two Filter calls both survive, ANDed. 
+ q := c.Query(ctx) + q.Filter(`eq(name, "alpha")`) + q.Filter(`eq(name, "beta")`) + s := q.Raw().String() + if !strings.Contains(s, `eq(name, "alpha")`) { + t.Fatalf("Filter A dropped; want both fragments present in:\n%s", s) + } + if !strings.Contains(s, `eq(name, "beta")`) { + t.Fatalf("Filter B dropped; want both fragments present in:\n%s", s) + } + if !strings.Contains(s, " AND ") { + t.Fatalf("accumulated filters not ANDed; got:\n%s", s) + } + + // (ii) The builder aliases: a saved reference and further mutation observe + // the same underlying query. ref and q point at the same *Query, so a + // mutation through one is visible through the other. This documents the + // single-use footgun: you cannot branch a saved builder. + ref := q + if ref != q { + t.Fatal("builder reference is not identical to the original *Query") + } + q.OrderAsc("name") + if ref.Raw().String() != q.Raw().String() { + t.Fatal("mutating q did not affect ref; builder is expected to alias a shared query") + } + if !strings.Contains(ref.Raw().String(), "orderasc: name") { + t.Fatalf("order applied via q not visible through ref; got:\n%s", ref.Raw().String()) + } +} + +func TestQuery_RawRoundTrips(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + if err := c.Add(ctx, &widget{Name: "raw-target", Qty: 7}); err != nil { + t.Fatalf("Add: %v", err) + } + + // Take the raw *dg.Query, apply a dgman-only builder method directly, + // then execute via the raw query's own Nodes(&dst). 
+ var raw *dg.Query = c.Query(ctx).Raw() + raw.OrderAsc("qty") + + var dst []widget + if err := raw.Nodes(&dst); err != nil { + t.Fatalf("raw query Nodes: %v", err) + } + if len(dst) != 1 { + t.Fatalf("raw query returned %d records, want 1", len(dst)) + } + if dst[0].Name != "raw-target" || dst[0].Qty != 7 { + t.Fatalf("raw query returned %+v, want Name=raw-target Qty=7", dst[0]) + } +} + +func TestQuery_SingleQueryPerTerminal(t *testing.T) { + // Uses the global dgman logger; must not run in parallel. + ctx := context.Background() + // queriesExecuted is incremented by newCountingConn's logger each time + // dgman runs a query, so it reflects real database round-trips. + var queriesExecuted int + c := typed.NewClient[widget](newCountingConn(t, &queriesExecuted)) + + for i := range 2 { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + + // Building the chain runs no queries: builder methods only mutate the AST. + before := queriesExecuted + q := c.Query(ctx).Filter(`eq(name, "w")`).OrderAsc("qty").Limit(10) + if queriesExecuted != before { + t.Fatalf("builder methods executed %d queries, want 0", queriesExecuted-before) + } + + // The Nodes terminal runs exactly one query. + if _, err := q.Nodes(); err != nil { + t.Fatalf("Nodes: %v", err) + } + if got := queriesExecuted - before; got != 1 { + t.Fatalf("Nodes executed %d queries, want exactly 1", got) + } + + // A fresh builder's First terminal also runs exactly one query. 
+ before = queriesExecuted + if _, err := c.Query(ctx).First(); err != nil { + t.Fatalf("First: %v", err) + } + if got := queriesExecuted - before; got != 1 { + t.Fatalf("First executed %d queries, want exactly 1", got) + } +} + +func TestIterNodes_StreamsAll(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + const n = 125 // > defaultPageSize (50): forces multiple pages + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + seen := 0 + for w, err := range c.Query(ctx).IterNodes() { + if err != nil { + t.Fatalf("IterNodes yielded error: %v", err) + } + if w == nil { + t.Fatal("IterNodes yielded a nil widget") + } + seen++ + } + if seen != n { + t.Fatalf("IterNodes streamed %d records, want %d", seen, n) + } +} + +func TestIterNodes_StopsOnConsumerBreak(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + const n = 125 + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + seen := 0 + for _, err := range c.Query(ctx).IterNodes() { + if err != nil { + t.Fatalf("IterNodes yielded error: %v", err) + } + seen++ + if seen == 10 { + break + } + } + if seen != 10 { + t.Fatalf("IterNodes yielded %d records after break at 10, want 10", seen) + } +} + +func TestIterNodes_EmptyResult(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + seen := 0 + for _, err := range c.Query(ctx).IterNodes() { + if err != nil { + t.Fatalf("IterNodes over empty set yielded error: %v", err) + } + seen++ + } + if seen != 0 { + t.Fatalf("IterNodes over empty set yielded %d records, want 0", seen) + } +} + +func TestIterNodes_RespectsLimit(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + const n = 100 + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + 
t.Fatalf("Add %d: %v", i, err) + } + } + seen := 0 + for _, err := range c.Query(ctx).Limit(30).IterNodes() { + if err != nil { + t.Fatalf("IterNodes yielded error: %v", err) + } + seen++ + } + if seen != 30 { + t.Fatalf("Limit(30).IterNodes() streamed %d records, want 30", seen) + } +} + +func TestIterNodes_LimitExceedsResultSet(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + const n = 30 + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + seen := 0 + for _, err := range c.Query(ctx).Limit(500).IterNodes() { + if err != nil { + t.Fatalf("IterNodes yielded error: %v", err) + } + seen++ + } + if seen != n { + t.Fatalf("Limit(500).IterNodes() over %d records streamed %d, want %d", n, seen, n) + } +} + +func TestIterNodes_RespectsOffset(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // Qty values start at 1 (not 0) so omitempty never suppresses the field, + // keeping OrderAsc("qty") a true total order over all records. 
+ const n = 10 + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i + 1}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + got := make([]int, 0, n-3) // offset 3 of n records + for w, err := range c.Query(ctx).OrderAsc("qty").Offset(3).IterNodes() { + if err != nil { + t.Fatalf("IterNodes yielded error: %v", err) + } + got = append(got, w.Qty) + } + if len(got) != 7 { + t.Fatalf("Offset(3).IterNodes() streamed %d records, want 7", len(got)) + } + for i, q := range got { + if q != i+4 { // Qty=1..10; offset 3 skips 1,2,3 → starts at 4 + t.Fatalf("Offset(3).IterNodes()[%d] Qty = %d, want %d", i, q, i+4) + } + } +} + +func TestIterNodes_RespectsOffsetAndLimit(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // Qty values start at 1 so omitempty never suppresses the field and + // OrderAsc("qty") is a strict total order across all 200 records. + const n = 200 + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i + 1}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + got := make([]int, 0, 120) // Limit(120) + for w, err := range c.Query(ctx).OrderAsc("qty").Offset(60).Limit(120).IterNodes() { + if err != nil { + t.Fatalf("IterNodes yielded error: %v", err) + } + got = append(got, w.Qty) + } + if len(got) != 120 { + t.Fatalf("Offset(60).Limit(120).IterNodes() streamed %d records, want 120", len(got)) + } + for i, q := range got { + if q != i+61 { // Qty=1..200; offset 60 skips 1..60 → starts at 61 + t.Fatalf("result[%d] Qty = %d, want %d", i, q, i+61) + } + } +} + +func TestIterNodes_OneQueryPerPage(t *testing.T) { + ctx := context.Background() + var queriesExecuted int + c := typed.NewClient[widget](newCountingConn(t, &queriesExecuted)) + const n = 125 // ceil(125/50) = 3 page queries + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + // Obtaining the iterator runs no query — IterNodes is lazy. 
Measure the + // delta around the build, not the absolute count, so the assertion holds + // regardless of how many queries the seeding above happened to run. + before := queriesExecuted + seq := c.Query(ctx).IterNodes() + if delta := queriesExecuted - before; delta != 0 { + t.Fatalf("building the IterNodes iterator executed %d queries, want 0", delta) + } + seen := 0 + for _, err := range seq { + if err != nil { + t.Fatalf("IterNodes yielded error: %v", err) + } + seen++ + } + if seen != n { + t.Fatalf("IterNodes streamed %d records, want %d", seen, n) + } + if delta := queriesExecuted - before; delta != 3 { // ceil(125/50) = 3 pages + t.Fatalf("IterNodes over %d records ran %d queries, want 3", n, delta) + } +} + +func TestIterNodes_YieldsErrorAndStops(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + if err := c.Add(ctx, &widget{Name: "w", Qty: 1}); err != nil { + t.Fatalf("Add: %v", err) + } + // A syntactically invalid @filter (unbalanced parenthesis) makes the page + // query fail at execution; IterNodes must yield one (nil, err) and stop. + gotErr := false + for w, err := range c.Query(ctx).Filter(`eq(name, "w"`).IterNodes() { + if err != nil { + gotErr = true + if w != nil { + t.Fatalf("error yield carried a non-nil widget: %+v", w) + } + break + } + t.Fatal("IterNodes over a malformed query yielded a record before erroring") + } + if !gotErr { + t.Fatal("IterNodes over a malformed query did not yield an error") + } +} + +func TestQuery_LimitOffsetStillDriveNodes(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // Qty values start at 1 so omitempty never suppresses the field and + // OrderAsc("qty") is a strict total order across all records. 
+ const n = 10 + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i + 1}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + // Regression: Limit/Offset now also set Query struct fields; confirm they + // still drive the Nodes terminal. + got, err := c.Query(ctx).OrderAsc("qty").Offset(2).Limit(3).Nodes() + if err != nil { + t.Fatalf("Nodes: %v", err) + } + if len(got) != 3 { + t.Fatalf("Offset(2).Limit(3).Nodes() returned %d records, want 3", len(got)) + } + for i, w := range got { + if w.Qty != i+3 { // Qty=1..10; offset 2 skips 1,2 → starts at 3 + t.Fatalf("Nodes()[%d] Qty = %d, want %d", i, w.Qty, i+3) + } + } +} + +func TestQuery_RootFuncOverridesRoot(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + for _, n := range []string{"a", "b", "c"} { + if err := c.Add(ctx, &widget{Name: n}); err != nil { + t.Fatalf("Add %s: %v", n, err) + } + } + // RootFunc replaces the default type(widget) root with an eq() lookup; + // the query still decodes into []widget. + got, err := c.Query(ctx).RootFunc(`eq(name, "b")`).Nodes() + if err != nil { + t.Fatalf("Nodes: %v", err) + } + if len(got) != 1 { + t.Fatalf(`RootFunc(eq(name,"b")).Nodes() returned %d records, want 1`, len(got)) + } + if got[0].Name != "b" { + t.Fatalf("RootFunc lookup returned %q, want \"b\"", got[0].Name) + } +} + +func TestQuery_RootFuncRendersAndOverwrites(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // RootFunc renders into the (func: ...) position and overwrites: the + // second call wins. 
+ q := c.Query(ctx).RootFunc(`eq(name, "x")`).RootFunc(`eq(name, "y")`) + s := q.Raw().String() + if !strings.Contains(s, `func: eq(name, "y")`) { + t.Fatalf("second RootFunc not rendered; got:\n%s", s) + } + if strings.Contains(s, `eq(name, "x")`) { + t.Fatalf("first RootFunc still present after overwrite; got:\n%s", s) + } +} + +func TestQuery_NameDecodesAfterRename(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + for _, n := range []string{"a", "b", "c"} { + if err := c.Add(ctx, &widget{Name: n}); err != nil { + t.Fatalf("Add %s: %v", n, err) + } + } + // Name renames the query block. dgman uses the name symmetrically to + // generate and decode, so a renamed block still decodes into []widget. + got, err := c.Query(ctx).Name("widgets").Nodes() + if err != nil { + t.Fatalf("Nodes: %v", err) + } + if len(got) != 3 { + t.Fatalf(`Name("widgets").Nodes() returned %d records, want 3`, len(got)) + } +} + +func TestQuery_NameRendersAndOverwrites(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // Name renders as the block name and overwrites: the second call wins. + q := c.Query(ctx).Name("first").Name("second") + s := q.Raw().String() + if !strings.Contains(s, "second(func:") { + t.Fatalf("second Name not rendered as block name; got:\n%s", s) + } + if strings.Contains(s, "first(func:") { + t.Fatalf("first Name still present after overwrite; got:\n%s", s) + } +} + +func TestQuery_AsRendersAndOverwrites(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // As transitions to *RawQuery, prefixes the block with "NAME as " (NAME + // being the argument passed to As), and overwrites: the second call wins.
+ q := c.Query(ctx).As("first").As("second") + if q == nil { + t.Fatal("As() returned nil *RawQuery") + } + s := q.String() + if !strings.Contains(s, "second as ") { + t.Fatalf("second As not rendered; got:\n%s", s) + } + if strings.Contains(s, "first as ") { + t.Fatalf("first As still present after overwrite; got:\n%s", s) + } +} + +func TestQuery_VarsRendersQueryPrefix(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // Vars renders a "query " prefix on the generated DQL. + q := c.Query(ctx).Vars("getByName($n: string)", map[string]string{"$n": "b"}) + s := q.Raw().String() + if !strings.Contains(s, "query getByName($n: string)") { + t.Fatalf("Vars did not render the query-definition prefix; got:\n%s", s) + } +} + +func TestQuery_VarsParameterizedQuery(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + for _, n := range []string{"a", "b", "c"} { + if err := c.Add(ctx, &widget{Name: n}); err != nil { + t.Fatalf("Add %s: %v", n, err) + } + } + // Vars supplies a GraphQL variable bound into the root function; the + // query executes via dgraph's QueryWithVars path. + got, err := c.Query(ctx). + Vars("getByName($n: string)", map[string]string{"$n": "b"}). + RootFunc("eq(name, $n)"). + Nodes() + if err != nil { + t.Fatalf("Vars query Nodes: %v", err) + } + if len(got) != 1 || got[0].Name != "b" { + t.Fatalf(`Vars parameterized query returned %+v, want one widget named "b"`, got) + } +} + +func TestQuery_VarReturnsRawQuery(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // Var transitions to *RawQuery and emits a var block: dgman renders the + // block name as "var". 
+ rq := c.Query(ctx).Var() + if rq == nil { + t.Fatal("Var() returned nil *RawQuery") + } + s := rq.String() + if !strings.Contains(s, "var(func:") { + t.Fatalf("Var() did not render a var block; got:\n%s", s) + } +} + +func TestQuery_GroupByReturnsRawQuery(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // GroupBy transitions to *RawQuery and emits an @groupby clause. + rq := c.Query(ctx).GroupBy("name") + if rq == nil { + t.Fatal("GroupBy() returned nil *RawQuery") + } + s := rq.String() + if !strings.Contains(s, "@groupby(name)") { + t.Fatalf(`GroupBy("name") did not render an @groupby clause; got:\n%s`, s) + } +} + +func TestRawQuery_RawExposesUnderlyingQuery(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + rq := c.Query(ctx).Var() + // Raw returns the underlying *dg.Query; String mirrors Raw().String(). + var raw *dg.Query = rq.Raw() + if raw == nil { + t.Fatal("RawQuery.Raw() returned nil") + } + if rq.String() != raw.String() { + t.Fatalf("RawQuery.String() and Raw().String() differ:\n%s\n---\n%s", + rq.String(), raw.String()) + } +} + +func TestRawQuery_GroupByThenVarChains(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // RawQuery re-exposes Var and GroupBy so the canonical .GroupBy(...).Var() + // composition still chains; both clauses survive. + s := c.Query(ctx).GroupBy("name").Var().String() + if !strings.Contains(s, "@groupby(name)") { + t.Fatalf("@groupby clause missing after GroupBy().Var(); got:\n%s", s) + } + if !strings.Contains(s, "var(func:") { + t.Fatalf("var block missing after GroupBy().Var(); got:\n%s", s) + } +} + +func TestRawQuery_CarriesEarlierBuilders(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // Builders applied on *Query[T] before the GroupBy transition survive + // into the *RawQuery — the two share one underlying *dg.Query. 
+ s := c.Query(ctx).Filter(`eq(name, "z")`).GroupBy("name").String() + if !strings.Contains(s, `eq(name, "z")`) { + t.Fatalf("Filter set before GroupBy did not survive the transition; got:\n%s", s) + } + if !strings.Contains(s, "@groupby(name)") { + t.Fatalf("@groupby clause missing; got:\n%s", s) + } +} + +// seedOwners inserts owner/pet pairs over conn for the WhereEdge tests. Each +// map entry is one owner owning one pet of the given name; the pet is inserted +// first so the owner's edge links an already-persisted node. It returns an +// owner client bound to conn. +func seedOwners( + ctx context.Context, t *testing.T, conn modusgraph.Client, ownerToPet map[string]string, +) *typed.Client[owner] { + t.Helper() + pets := typed.NewClient[pet](conn) + owners := typed.NewClient[owner](conn) + for ownerName, petName := range ownerToPet { + p := &pet{Name: petName} + if err := pets.Add(ctx, p); err != nil { + t.Fatalf("Add pet %q: %v", petName, err) + } + if err := owners.Add(ctx, &owner{Name: ownerName, Pets: []*pet{p}}); err != nil { + t.Fatalf("Add owner %q: %v", ownerName, err) + } + } + return owners +} + +func TestQuery_WhereEdgeFiltersByEdgeTarget(t *testing.T) { + ctx := context.Background() + owners := seedOwners(ctx, t, newConn(t), map[string]string{ + "Alice": "Fido", + "Bob": "Rex", + "Carol": "Fido", + }) + + // WhereEdge constrains owners by a scalar of the pet reached over the + // "pets" edge — something a root Filter cannot express. 
+ got, err := owners.Query(ctx).WhereEdge("pets", `eq(name, "Fido")`).Nodes() + if err != nil { + t.Fatalf("WhereEdge Nodes: %v", err) + } + if len(got) != 2 { + t.Fatalf("WhereEdge(pets, name=Fido) returned %d owners, want 2 (Alice, Carol)", len(got)) + } + for _, o := range got { + if o.Name != "Alice" && o.Name != "Carol" { + t.Fatalf("WhereEdge returned %q, want only Fido owners (Alice, Carol)", o.Name) + } + } +} + +func TestQuery_WhereEdgeNoMatchReturnsEmpty(t *testing.T) { + ctx := context.Background() + owners := seedOwners(ctx, t, newConn(t), map[string]string{"Alice": "Fido", "Bob": "Rex"}) + + // No pet is named Nemo: the @cascade var block binds zero roots, so the + // single multi-block request returns an empty result, not an error. + got, err := owners.Query(ctx).WhereEdge("pets", `eq(name, "Nemo")`).Nodes() + if err != nil { + t.Fatalf("WhereEdge Nodes: unexpected error %v", err) + } + if len(got) != 0 { + t.Fatalf("WhereEdge for an unowned pet name returned %d owners, want 0", len(got)) + } +} + +func TestQuery_WhereEdgeBindsParams(t *testing.T) { + ctx := context.Background() + owners := seedOwners(ctx, t, newConn(t), map[string]string{"Alice": "Fido", "Bob": "Rex"}) + + // The $1 placeholder in a WhereEdge filter binds exactly as it does for Filter. + got, err := owners.Query(ctx).WhereEdge("pets", "eq(name, $1)", "Rex").Nodes() + if err != nil { + t.Fatalf("WhereEdge Nodes: %v", err) + } + if len(got) != 1 || got[0].Name != "Bob" { + t.Fatalf("WhereEdge(pets, name=$1, Rex) returned %+v, want [Bob]", got) + } +} + +func TestQuery_WhereEdgeCombinesWithFilter(t *testing.T) { + ctx := context.Background() + // Alice and Carol both own a Fido; a root Filter on name narrows to Alice. + owners := seedOwners(ctx, t, newConn(t), map[string]string{ + "Alice": "Fido", + "Bob": "Rex", + "Carol": "Fido", + }) + + got, err := owners.Query(ctx). + Filter(`eq(name, "Alice")`). + WhereEdge("pets", `eq(name, "Fido")`).
+ Nodes() + if err != nil { + t.Fatalf("Filter+WhereEdge Nodes: %v", err) + } + if len(got) != 1 || got[0].Name != "Alice" { + t.Fatalf("Filter(name=Alice)+WhereEdge(pets,name=Fido) returned %+v, want [Alice]", got) + } +} + +func TestQuery_WhereEdgePreservesUIDRoot(t *testing.T) { + ctx := context.Background() + conn := newConn(t) + pets := typed.NewClient[pet](conn) + owners := typed.NewClient[owner](conn) + + fido := &pet{Name: "Fido"} + if err := pets.Add(ctx, fido); err != nil { + t.Fatalf("Add pet: %v", err) + } + alice := &owner{Name: "Alice", Pets: []*pet{fido}} + carol := &owner{Name: "Carol", Pets: []*pet{fido}} + for _, o := range []*owner{alice, carol} { + if err := owners.Add(ctx, o); err != nil { + t.Fatalf("Add owner %q: %v", o.Name, err) + } + } + + // Both Alice and Carol own Fido, so the WhereEdge var block matches both. + // Rooting the query at Alice's UID must survive that match: the var block + // roots at the caller's UID, so mgMatched is the intersection (just Alice), + // not every Fido owner. + got, err := owners.Query(ctx). + UID(alice.UID). + WhereEdge("pets", `eq(name, "Fido")`). + Nodes() + if err != nil { + t.Fatalf("UID+WhereEdge Nodes: %v", err) + } + if len(got) != 1 || got[0].Name != "Alice" { + t.Fatalf("UID(Alice)+WhereEdge(pets,name=Fido) returned %+v, want [Alice]", got) + } +} + +func TestQuery_WhereEdgeMultipleConstraintsAnd(t *testing.T) { + ctx := context.Background() + conn := newConn(t) + pets := typed.NewClient[pet](conn) + owners := typed.NewClient[owner](conn) + + // Alice owns both Fido and Rex; Bob owns only Fido. 
+ fido, rex := &pet{Name: "Fido"}, &pet{Name: "Rex"} + for _, p := range []*pet{fido, rex} { + if err := pets.Add(ctx, p); err != nil { + t.Fatalf("Add pet %q: %v", p.Name, err) + } + } + if err := owners.Add(ctx, &owner{Name: "Alice", Pets: []*pet{fido, rex}}); err != nil { + t.Fatalf("Add Alice: %v", err) + } + if err := owners.Add(ctx, &owner{Name: "Bob", Pets: []*pet{fido}}); err != nil { + t.Fatalf("Add Bob: %v", err) + } + + // Two WhereEdge calls AND together: only an owner of BOTH pets survives. + got, err := owners.Query(ctx). + WhereEdge("pets", `eq(name, "Fido")`). + WhereEdge("pets", `eq(name, "Rex")`). + Nodes() + if err != nil { + t.Fatalf("two-WhereEdge Nodes: %v", err) + } + if len(got) != 1 || got[0].Name != "Alice" { + t.Fatalf("WhereEdge(Fido) AND WhereEdge(Rex) returned %+v, want [Alice]", got) + } +} + +func TestQuery_WhereEdgeFirst(t *testing.T) { + ctx := context.Background() + owners := seedOwners(ctx, t, newConn(t), map[string]string{"Alice": "Fido", "Bob": "Rex"}) + + // First applies the edge constraint too: it returns the Rex owner, never a Fido one. + got, err := owners.Query(ctx).WhereEdge("pets", `eq(name, "Rex")`).First() + if err != nil { + t.Fatalf("WhereEdge First: %v", err) + } + if got == nil || got.Name != "Bob" { + t.Fatalf("WhereEdge(pets,name=Rex).First() = %+v, want Bob", got) + } + + // First with an edge constraint nothing satisfies is (nil, nil). + none, err := owners.Query(ctx).WhereEdge("pets", `eq(name, "Nemo")`).First() + if err != nil { + t.Fatalf("WhereEdge First no-match: unexpected error %v", err) + } + if none != nil { + t.Fatalf("WhereEdge First with no match = %+v, want nil", none) + } +} + +func TestQuery_WhereEdgeNodesAndCount(t *testing.T) { + ctx := context.Background() + owners := seedOwners(ctx, t, newConn(t), map[string]string{ + "Alice": "Fido", + "Bob": "Rex", + "Carol": "Fido", + "Dave": "Fido", + }) + + // Three owners have a Fido.
Limit caps the returned rows to 2, but the count + // reflects the full matched set: the count block runs count(uid) over + // uid(mgMatched), independent of the data block's pagination. + rows, count, err := owners.Query(ctx). + WhereEdge("pets", `eq(name, "Fido")`). + Limit(2). + NodesAndCount() + if err != nil { + t.Fatalf("WhereEdge NodesAndCount: %v", err) + } + if count != 3 { + t.Fatalf("WhereEdge NodesAndCount count = %d, want 3 (Alice, Carol, Dave)", count) + } + if len(rows) != 2 { + t.Fatalf("WhereEdge NodesAndCount returned %d rows, want 2 (Limit)", len(rows)) + } +} + +func TestQuery_WhereEdgeIterNodes(t *testing.T) { + ctx := context.Background() + owners := seedOwners(ctx, t, newConn(t), map[string]string{ + "Alice": "Fido", + "Bob": "Rex", + "Carol": "Fido", + }) + + seen := 0 + for o, err := range owners.Query(ctx).WhereEdge("pets", `eq(name, "Fido")`).IterNodes() { + if err != nil { + t.Fatalf("WhereEdge IterNodes yielded error: %v", err) + } + if o.Name != "Alice" && o.Name != "Carol" { + t.Fatalf("WhereEdge IterNodes yielded %q, want a Fido owner", o.Name) + } + seen++ + } + if seen != 2 { + t.Fatalf("WhereEdge IterNodes streamed %d owners, want 2", seen) + } +} + +func TestQuery_WhereEdgeForwardsVars(t *testing.T) { + ctx := context.Background() + owners := seedOwners(ctx, t, newConn(t), map[string]string{ + "Alice": "Fido", + "Bob": "Fido", + }) + + // Vars binds a GraphQL variable into the root function, combined with a + // WhereEdge constraint. The WhereEdge path renders its own multi-block + // request via QueryRaw, so it must forward the variable; otherwise $n is + // unbound and the query errors. Both own a Fido; $n narrows to Alice. + got, err := owners.Query(ctx). + Vars("byName($n: string)", map[string]string{"$n": "Alice"}). + RootFunc("eq(name, $n)"). + WhereEdge("pets", `eq(name, "Fido")`). 
+ Nodes() + if err != nil { + t.Fatalf("Vars+WhereEdge Nodes: %v", err) + } + if len(got) != 1 || got[0].Name != "Alice" { + t.Fatalf("Vars($n=Alice)+WhereEdge(pets,Fido) returned %+v, want [Alice]", got) + } +} + +func TestQuery_UIDRootsAtNode(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + w := &widget{Name: "sprocket", Qty: 3} + if err := c.Add(ctx, w); err != nil { + t.Fatalf("Add: %v", err) + } + + got, err := c.Query(ctx).UID(w.UID).Nodes() + if err != nil { + t.Fatalf("Nodes: %v", err) + } + if len(got) != 1 || got[0].Name != "sprocket" { + t.Fatalf("UID query returned %+v, want one widget named sprocket", got) + } +} + +func TestQuery_NodesAndCountReturnsTotal(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + for i := 0; i < 3; i++ { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add: %v", err) + } + } + + nodes, count, err := c.Query(ctx).NodesAndCount() + if err != nil { + t.Fatalf("NodesAndCount: %v", err) + } + if count != 3 || len(nodes) != 3 { + t.Fatalf("got count=%d len=%d, want 3 and 3", count, len(nodes)) + } +} + +func TestQuery_AllSetsTraversalDepth(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + if err := c.Add(ctx, &widget{Name: "deep", Qty: 1}); err != nil { + t.Fatalf("Add: %v", err) + } + + // All(1) overrides the default traversal depth for this query; the call + // must chain and the query must still execute and decode. 
+ got, err := c.Query(ctx).All(1).Nodes() + if err != nil { + t.Fatalf("Nodes with All(1): %v", err) + } + if len(got) != 1 { + t.Fatalf("got %d widgets, want 1", len(got)) + } +} + +func TestQuery_StringRendersDQL(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + dql := c.Query(ctx).Filter("eq(name, $1)", "sprocket").String() + if !strings.Contains(dql, "widget") { + t.Fatalf("String() = %q, want it to mention the widget type", dql) + } +} diff --git a/typed/search/merge.go b/typed/search/merge.go new file mode 100644 index 0000000..2546274 --- /dev/null +++ b/typed/search/merge.go @@ -0,0 +1,27 @@ +// Package search provides helpers for assembling fulltext / ranked search +// results across multiple typed query blocks. +package search + +// MergeByID concatenates inputs into a single slice while preserving +// first-seen order and dropping any subsequent occurrence of an ID already +// emitted. The id function extracts a comparable identifier from each row. +// +// MergeByID is intended for use after typed.MultiQuery.Execute, when +// consumers want a single ranked slice from N per-field result sets: +// inputs[0] takes priority, inputs[1] fills in next, etc. A nil result +// indicates no rows survived (the inputs were all empty). 
+func MergeByID[T any](id func(T) string, inputs ...[]T) []T { + seen := make(map[string]struct{}) + var out []T + for _, in := range inputs { + for _, row := range in { + k := id(row) + if _, dup := seen[k]; dup { + continue + } + seen[k] = struct{}{} + out = append(out, row) + } + } + return out +} diff --git a/typed/search/merge_test.go b/typed/search/merge_test.go new file mode 100644 index 0000000..e4e8583 --- /dev/null +++ b/typed/search/merge_test.go @@ -0,0 +1,86 @@ +package search_test + +import ( + "reflect" + "testing" + + "github.com/matthewmcneely/modusgraph/typed/search" +) + +type rec struct { + ID string + Tag string +} + +func id(r rec) string { return r.ID } + +func TestMergeByID(t *testing.T) { + cases := []struct { + name string + inputs [][]rec + want []rec + }{ + { + name: "empty inputs returns nil", + inputs: nil, + want: nil, + }, + { + name: "single empty slice returns nil", + inputs: [][]rec{{}}, + want: nil, + }, + { + name: "single slice returns it as-is", + inputs: [][]rec{{ + {ID: "a", Tag: "1"}, + {ID: "b", Tag: "1"}, + }}, + want: []rec{ + {ID: "a", Tag: "1"}, + {ID: "b", Tag: "1"}, + }, + }, + { + name: "two slices merge in priority order", + inputs: [][]rec{ + {{ID: "a", Tag: "name"}}, + {{ID: "b", Tag: "desc"}}, + }, + want: []rec{ + {ID: "a", Tag: "name"}, + {ID: "b", Tag: "desc"}, + }, + }, + { + name: "duplicate ID keeps first-seen entry", + inputs: [][]rec{ + {{ID: "a", Tag: "name"}}, + {{ID: "a", Tag: "desc"}, {ID: "b", Tag: "desc"}}, + }, + want: []rec{ + {ID: "a", Tag: "name"}, + {ID: "b", Tag: "desc"}, + }, + }, + { + name: "intra-slice duplicates dedup too", + inputs: [][]rec{ + {{ID: "a", Tag: "1"}, {ID: "a", Tag: "2"}, {ID: "b", Tag: "1"}}, + }, + want: []rec{ + {ID: "a", Tag: "1"}, + {ID: "b", Tag: "1"}, + }, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + got := search.MergeByID(id, c.inputs...) 
+ if !reflect.DeepEqual(got, c.want) { + t.Fatalf("got %v, want %v", got, c.want) + } + }) + } +} diff --git a/typed/tracing.go b/typed/tracing.go new file mode 100644 index 0000000..0c043d5 --- /dev/null +++ b/typed/tracing.go @@ -0,0 +1,71 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed + +import ( + "context" + "reflect" + "sync/atomic" +) + +// Span is a tracing span for a single database operation. End is called once, +// with the operation's final error (nil on success). +type Span interface { + End(err error) +} + +// Tracer starts a Span around a typed-layer database operation. The typed +// client calls the installed Tracer for every DB call; the default is a no-op, +// so the typed package itself carries no tracing dependency. Install a real +// tracer — for example github.com/mlwelles/modusgraph-telemetry's OpenTelemetry +// tracer — with SetTracer. +type Tracer interface { + // StartSpan begins a span for operation op (for example "get") on the named + // collection, returning a context carrying the span and the Span itself. + StartSpan(ctx context.Context, op, collection string) (context.Context, Span) +} + +type noopSpan struct{} + +func (noopSpan) End(error) {} + +type noopTracer struct{} + +func (noopTracer) StartSpan(ctx context.Context, _, _ string) (context.Context, Span) { + return ctx, noopSpan{} +} + +// tracerHolder is the process-wide tracer the typed package uses, held in an +// atomic.Pointer so SetTracer and the per-operation reads in currentTracer are +// data-race free. It is a no-op until a host installs one via SetTracer. +var tracerHolder atomic.Pointer[Tracer] + +// currentTracer returns the installed tracer, or the no-op tracer if SetTracer +// has not run. Every terminal reads through here, so the load stays on the hot +// path; atomic.Pointer makes it lock-free. 
+func currentTracer() Tracer {
+	if p := tracerHolder.Load(); p != nil {
+		return *p
+	}
+	return noopTracer{}
+}
+
+// SetTracer installs the process-wide tracer for typed-layer DB spans. Passing
+// nil restores the no-op tracer. It is safe to call concurrently with active
+// queries: in-flight terminals keep the tracer they already loaded, and later
+// terminals observe the new one.
+func SetTracer(t Tracer) {
+	if t == nil {
+		t = noopTracer{}
+	}
+	tracerHolder.Store(&t)
+}
+
+// entityName returns the unqualified Go type name of T (for example "Resource"),
+// used as the db.collection.name span attribute.
+func entityName[T any]() string {
+	return reflect.TypeFor[T]().Name()
+}
diff --git a/typed/tracing_test.go b/typed/tracing_test.go
new file mode 100644
index 0000000..5e24875
--- /dev/null
+++ b/typed/tracing_test.go
@@ -0,0 +1,71 @@
+/*
+ * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc.
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package typed
+
+import (
+	"context"
+	"testing"
+)
+
+func TestSetTracer_InstallsAndResets(t *testing.T) {
+	t.Cleanup(func() { SetTracer(nil) })
+
+	rec := &recordingTracer{}
+	SetTracer(rec)
+
+	_, span := currentTracer().StartSpan(context.Background(), "get", "Widget")
+	span.End(nil)
+
+	if rec.op != "get" || rec.collection != "Widget" {
+		t.Fatalf("installed tracer not invoked: %+v", rec)
+	}
+	if !rec.ended {
+		t.Fatal("span.End was not called")
+	}
+
+	// nil restores the no-op tracer, which must not panic.
+	SetTracer(nil)
+	_, span = currentTracer().StartSpan(context.Background(), "x", "Y")
+	span.End(nil)
+}
+
+// TestSetTracer_ConcurrentWithReads asserts SetTracer is safe to call while
+// other goroutines read the tracer. Under -race, a plain package var would be
+// flagged as a data race here; the atomic.Pointer holder is not.
+func TestSetTracer_ConcurrentWithReads(t *testing.T) {
+	t.Cleanup(func() { SetTracer(nil) })
+
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+		for range 1000 {
+			_, span := currentTracer().StartSpan(context.Background(), "get", "Widget")
+			span.End(nil)
+		}
+	}()
+	for i := range 1000 {
+		if i%2 == 0 {
+			SetTracer(&recordingTracer{})
+		} else {
+			SetTracer(nil)
+		}
+	}
+	<-done
+}
+
+type recordingTracer struct {
+	op, collection string
+	ended          bool
+}
+
+func (r *recordingTracer) StartSpan(ctx context.Context, op, collection string) (context.Context, Span) {
+	r.op, r.collection = op, collection
+	return ctx, &recordingSpan{r}
+}
+
+type recordingSpan struct{ r *recordingTracer }
+
+func (s *recordingSpan) End(error) { s.r.ended = true }
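A host-side tracer that satisfies the `Tracer` and `Span` interfaces added in `typed/tracing.go` might look like the sketch below. It is illustrative only: `collectingTracer`, `collectingSpan`, `spanRecord`, and `demo` are hypothetical names that do not appear in this change, and the two interfaces are copied locally so the example compiles on its own. A real host would presumably hand the tracer to `typed.SetTracer` rather than call it directly.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"sync"
	"time"
)

// Span and Tracer are local copies of the interfaces from typed/tracing.go,
// repeated here only so the sketch is self-contained.
type Span interface {
	End(err error)
}

type Tracer interface {
	StartSpan(ctx context.Context, op, collection string) (context.Context, Span)
}

// spanRecord is a hypothetical record of one finished span.
type spanRecord struct {
	Op, Collection string
	Elapsed        time.Duration
	Err            error
}

// collectingTracer is a hypothetical Tracer that keeps finished spans in
// memory. A mutex guards the slice in case spans end on several goroutines.
type collectingTracer struct {
	mu      sync.Mutex
	records []spanRecord
}

// Compile-time check that collectingTracer satisfies Tracer.
var _ Tracer = (*collectingTracer)(nil)

func (c *collectingTracer) StartSpan(ctx context.Context, op, collection string) (context.Context, Span) {
	return ctx, &collectingSpan{tracer: c, op: op, collection: collection, start: time.Now()}
}

// Records returns a copy of the spans finished so far.
func (c *collectingTracer) Records() []spanRecord {
	c.mu.Lock()
	defer c.mu.Unlock()
	return append([]spanRecord(nil), c.records...)
}

type collectingSpan struct {
	tracer         *collectingTracer
	op, collection string
	start          time.Time
}

// End appends the finished span together with the operation's final error.
func (s *collectingSpan) End(err error) {
	s.tracer.mu.Lock()
	defer s.tracer.mu.Unlock()
	s.tracer.records = append(s.tracer.records, spanRecord{
		Op: s.op, Collection: s.collection, Elapsed: time.Since(s.start), Err: err,
	})
}

// demo runs one successful and one failing span and returns what was recorded.
func demo() []spanRecord {
	tr := &collectingTracer{}
	_, ok := tr.StartSpan(context.Background(), "get", "Widget")
	ok.End(nil)
	_, bad := tr.StartSpan(context.Background(), "update", "Widget")
	bad.End(errors.New("conflict"))
	return tr.Records()
}

func main() {
	for _, r := range demo() {
		fmt.Printf("%s %s err=%v\n", r.Op, r.Collection, r.Err)
	}
}
```

Because `End` receives the operation's final error, a tracer like this can separate failed operations from successful ones without inspecting the context.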