Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,7 @@ func (c client) validateStruct(ctx context.Context, obj any) error {
// Insert implements inserting an object or slice of objects in the database.
// Passed object must be a pointer to a struct with appropriate dgraph tags.
func (c client) Insert(ctx context.Context, obj any) error {
obj = UnwrapSchema(obj)
// Validate struct before insertion
if err := c.validateStruct(ctx, obj); err != nil {
return err
Expand All @@ -503,6 +504,7 @@ func (c client) Insert(ctx context.Context, obj any) error {
//
// Deprecated: InsertRaw is now identical to Insert. Use Insert instead.
func (c client) InsertRaw(ctx context.Context, obj any) error {
obj = UnwrapSchema(obj)
// Validate struct before insertion
if err := c.validateStruct(ctx, obj); err != nil {
return err
Expand All @@ -518,6 +520,7 @@ func (c client) InsertRaw(ctx context.Context, obj any) error {
// to be used for upserting. If none are specified, the first predicate with the `upsert` tag
// will be used.
func (c client) Upsert(ctx context.Context, obj any, predicates ...string) error {
obj = UnwrapSchema(obj)
// Validate struct before upsert
if err := c.validateStruct(ctx, obj); err != nil {
return err
Expand All @@ -531,6 +534,7 @@ func (c client) Upsert(ctx context.Context, obj any, predicates ...string) error
// Update implements updating an existing object in the database.
// Passed object must be a pointer to a struct.
func (c client) Update(ctx context.Context, obj any) error {
obj = UnwrapSchema(obj)
// Validate struct before update
if err := c.validateStruct(ctx, obj); err != nil {
return err
Expand All @@ -557,6 +561,7 @@ func (c client) Delete(ctx context.Context, uids []string) error {
// Get implements retrieving a single object by its UID.
// Passed object must be a pointer to a struct.
func (c client) Get(ctx context.Context, obj any, uid string) error {
obj = UnwrapSchema(obj)
err := checkPointer(obj)
if err != nil {
return err
Expand All @@ -575,6 +580,7 @@ func (c client) Get(ctx context.Context, obj any, uid string) error {
// Returns a *dg.Query that can be further refined with filters, pagination, etc.
// The returned query will be limited to the maximum number of edges specified in the options.
func (c client) Query(ctx context.Context, model any) *dg.Query {
model = UnwrapSchema(model)
client, err := c.pool.get()
if err != nil {
return nil
Expand All @@ -590,6 +596,9 @@ func (c client) Query(ctx context.Context, model any) *dg.Query {
// If any object contains SimString fields tagged `dgraph:"embedding"`, the
// corresponding shadow float32vector predicates (<field>__vec) are also registered.
func (c client) UpdateSchema(ctx context.Context, obj ...any) error {
for i := range obj {
obj[i] = UnwrapSchema(obj[i])
}
dgClient, err := c.pool.get()
if err != nil {
c.logger.Error(err, "Failed to get client from pool")
Expand Down
71 changes: 71 additions & 0 deletions record.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package modusgraph

import "reflect"

// Schema identifies a value as a record of a generated schema-defining type.
// modusgraph-gen-emitted schema structs implement this via a generated
// SchemaTypeName() method that returns the canonical entity name
// (e.g. "Studio"). The interface is intentionally minimal — a single method
// returning a useful piece of metadata.
//
// Plain user structs (not emitted by modusgraph-gen) do not implement Schema
// and are unaffected by the modusgraph.Client routing it enables; they pass
// through to the existing reflection-based dgman pipeline exactly as before.
type Schema interface {
SchemaTypeName() string
}

// UnwrapSchema returns the schema-defining record contained in obj. If obj
// is nil, it is returned as-is. If obj is already a Schema, it is returned
// as-is. If obj exposes an Unwrap() method whose return value satisfies
// Schema, that return is substituted. Otherwise obj is returned unchanged.
//
// This is the bridge between modusgraph-gen-emitted wrapper types and the
// rest of modusgraph.Client. It is purely additive: types that don't
// implement Schema and don't have an Unwrap() method (i.e. existing
// modusgraph users' plain structs) pass through untouched.
//
// Note on errors.Unwrap overlap: Go's errors package uses Unwrap() error
// as the standard "give me the wrapped thing" method. UnwrapSchema's
// secondary check (the returned value must itself implement Schema) means
// an error wrapper is not mistaken for a modusgraph wrapper — the
// reflection probe finds Unwrap(), calls it, gets an error, fails the
// Schema check, and returns the original obj.
func UnwrapSchema(obj any) any {
if obj == nil {
return obj
}
if _, ok := obj.(Schema); ok {
return obj
}
v := reflect.ValueOf(obj)
if !v.IsValid() {
return obj
}
// A typed nil pointer has a valid method set, but invoking Unwrap on a nil
// receiver would panic if the method dereferences it. Leave it untouched.
if v.Kind() == reflect.Pointer && v.IsNil() {
return obj
}
m := v.MethodByName("Unwrap")
Comment thread
cubic-dev-ai[bot] marked this conversation as resolved.
if !m.IsValid() && v.Kind() != reflect.Pointer {
// Unwrap may be declared with a pointer receiver while obj was passed by
// value; a value's method set excludes pointer-receiver methods, so look
// it up on an addressable copy.
pv := reflect.New(v.Type())
pv.Elem().Set(v)
m = pv.MethodByName("Unwrap")
}
if !m.IsValid() {
return obj
}
mt := m.Type()
if mt.NumIn() != 0 || mt.NumOut() != 1 {
return obj
}
inner := m.Call(nil)[0].Interface()
if _, ok := inner.(Schema); ok {
Comment thread
cubic-dev-ai[bot] marked this conversation as resolved.
return inner
}
return obj
}
49 changes: 49 additions & 0 deletions record_example_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc.
* SPDX-License-Identifier: Apache-2.0
*/

package modusgraph_test

import (
"context"
"fmt"

mg "github.com/matthewmcneely/modusgraph"
)

// Actor is a schema-defining record. Implementing mg.Schema (a single
// SchemaTypeName method) marks it as a generated schema type; code generators
// such as modusgraph-gen emit this method.
type Actor struct {
UID string `json:"uid,omitempty"`
DType []string `json:"dgraph.type,omitempty"`
Name string `json:"name,omitempty" dgraph:"index=exact"`
}

func (a *Actor) SchemaTypeName() string { return "Actor" }

// ActorBuilder is a wrapper around Actor — the shape a generated fluent builder
// or domain wrapper takes. Exposing Unwrap lets the modusgraph client route the
// wrapper to its backing record, so the wrapper can be passed straight to
// Insert/Update/Get without the caller reaching for the inner value.
type ActorBuilder struct{ actor *Actor }

func (b *ActorBuilder) Unwrap() *Actor { return b.actor }

// ExampleSchema shows the wrapper pattern: the client unwraps an ActorBuilder
// to its Actor before persisting, so generated wrapper types work transparently
// while plain structs are unaffected.
func ExampleSchema() {
client, _ := mg.NewClient("dgraph://localhost:9080")
defer client.Close()

ctx := context.Background()
builder := &ActorBuilder{actor: &Actor{Name: "Sigourney Weaver"}}

// Insert the wrapper; the client unwraps it to the Actor record.
if err := client.Insert(ctx, builder); err != nil {
panic(err)
}
fmt.Println(builder.actor.Name)
}
51 changes: 51 additions & 0 deletions record_integration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc.
* SPDX-License-Identifier: Apache-2.0
*/

package modusgraph_test

import (
"context"
"testing"

mg "github.com/matthewmcneely/modusgraph"
"github.com/stretchr/testify/require"
)

// studioRecord is a schema-defining record (implements mg.Schema). studioWrapper
// wraps it and exposes Unwrap, exactly as a modusgraph-gen wrapper would.
type studioRecord struct {
UID string `json:"uid,omitempty"`
DType []string `json:"dgraph.type,omitempty"`
Name string `json:"name,omitempty" dgraph:"index=exact"`
}

func (s *studioRecord) SchemaTypeName() string { return "studioRecord" }

type studioWrapper struct{ inner *studioRecord }

func (w *studioWrapper) Unwrap() *studioRecord { return w.inner }

// TestClientUnwrapsWrapperThroughRealMutation exercises the real client path,
// not UnwrapSchema in isolation: it inserts a wrapper and reads it back. If a
// mutation method stopped calling UnwrapSchema, the wrapper (which has no usable
// dgraph fields of its own) would not persist Name and the inner UID would stay
// empty — so this test fails on that regression.
func TestClientUnwrapsWrapperThroughRealMutation(t *testing.T) {
client, err := mg.NewClient("file://"+GetTempDir(t), mg.WithAutoSchema(true))
require.NoError(t, err)
defer client.Close()

ctx := context.Background()
inner := &studioRecord{Name: "Acme"}
wrapper := &studioWrapper{inner: inner}

require.NoError(t, client.Insert(ctx, wrapper))
require.NotEmpty(t, inner.UID,
"Insert did not route the wrapper to its inner record")

var got studioRecord
require.NoError(t, client.Get(ctx, &got, inner.UID))
require.Equal(t, "Acme", got.Name)
}
139 changes: 139 additions & 0 deletions record_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
package modusgraph

import (
"errors"
"testing"
)

type fakeRecord struct{ name string }

func (f *fakeRecord) SchemaTypeName() string { return f.name }

type fakeWrapper struct{ inner *fakeRecord }

func (w *fakeWrapper) Unwrap() *fakeRecord { return w.inner }

type fakeNonSchema struct{ X string }

func TestUnwrapSchema_PassthroughForPlainStruct(t *testing.T) {
in := &fakeNonSchema{X: "hi"}
out := UnwrapSchema(in)
if out != any(in) {
t.Fatalf("expected passthrough, got %T", out)
}
}

func TestUnwrapSchema_PassthroughForSchemaStruct(t *testing.T) {
in := &fakeRecord{name: "Studio"}
out := UnwrapSchema(in)
if out != any(in) {
t.Fatalf("expected passthrough for direct Schema, got %T", out)
}
}

func TestUnwrapSchema_UnwrapsWrapper(t *testing.T) {
inner := &fakeRecord{name: "Studio"}
w := &fakeWrapper{inner: inner}
out := UnwrapSchema(w)
if out != any(inner) {
t.Fatalf("expected unwrapped inner, got %T (%v)", out, out)
}
}

func TestUnwrapSchema_IgnoresErrorsUnwrap(t *testing.T) {
// errors.New("x") has no Unwrap; wrap one to get something with Unwrap() error.
inner := errors.New("inner")
outer := &wrappedErr{err: inner}
out := UnwrapSchema(outer)
if out != any(outer) {
t.Fatalf("expected passthrough for error wrapper, got %T", out)
}
}

type wrappedErr struct{ err error }

func (w *wrappedErr) Error() string { return w.err.Error() }
func (w *wrappedErr) Unwrap() error { return w.err }

func TestUnwrapSchema_NilInput(t *testing.T) {
if out := UnwrapSchema(nil); out != nil {
t.Fatalf("expected nil for nil input, got %v", out)
}
}

func TestUnwrapSchema_TypedNilPointerDoesNotPanic(t *testing.T) {
// fakeWrapper.Unwrap dereferences its receiver, so invoking it on a typed
// nil pointer would panic. UnwrapSchema must return the value untouched.
var w *fakeWrapper
out := UnwrapSchema(w)
if out != any(w) {
t.Fatalf("expected typed nil pointer passthrough, got %T (%v)", out, out)
}
}

func TestUnwrapSchema_PointerReceiverUnwrapOnValue(t *testing.T) {
// fakeWrapper.Unwrap has a pointer receiver. Passing the wrapper by value
// must still unwrap: a value's method set excludes pointer-receiver methods,
// so UnwrapSchema looks Unwrap up on an addressable copy.
inner := &fakeRecord{name: "Studio"}
w := fakeWrapper{inner: inner}
out := UnwrapSchema(w)
if out != any(inner) {
t.Fatalf("expected unwrapped inner from value wrapper, got %T (%v)", out, out)
}
}

// recordingClient is the minimal surface needed to verify that wrappers
// passed to the Client interface get unwrapped before reaching internal
// reflection. It records whatever it received and returns nil. Each method
// applies obj = UnwrapSchema(obj) at the top, mirroring the patch landing
// in this task.
type recordingClient struct {
seen []any
}

func (c *recordingClient) capture(obj any) any {

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2: Tests use a local recordingClient mock to simulate unwrapping instead of exercising real client methods, so regressions in the actual integration points (7 client mutation/query methods) may go undetected.

Prompt for AI agents
Check if this issue is valid — if so, understand the root cause and fix it. At record_test.go, line 73:

<comment>Tests use a local `recordingClient` mock to simulate unwrapping instead of exercising real client methods, so regressions in the actual integration points (7 client mutation/query methods) may go undetected.</comment>

<file context>
@@ -0,0 +1,117 @@
+	seen []any
+}
+
+func (c *recordingClient) capture(obj any) any {
+	obj = UnwrapSchema(obj)
+	c.seen = append(c.seen, obj)
</file context>

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 9bc4289. Added TestClientUnwrapsWrapperThroughRealMutation, which inserts a wrapper through the real client and reads it back via Get, so a mutation method dropping its UnwrapSchema call is caught — the prior tests exercised UnwrapSchema only through a local mock. (Issue identified by cubic.)

obj = UnwrapSchema(obj)
c.seen = append(c.seen, obj)
return obj
}

func TestUnwrapSchema_CaptureForwardsInner(t *testing.T) {
inner := &fakeRecord{name: "Studio"}
w := &fakeWrapper{inner: inner}
c := &recordingClient{}
got := c.capture(w)
if got != any(inner) {
t.Fatalf("expected inner record, got %T (%v)", got, got)
}
if len(c.seen) != 1 || c.seen[0] != any(inner) {
t.Fatalf("expected recording to hold inner record, got %v", c.seen)
}
}

func TestUnwrapSchema_CapturePassthroughForPlain(t *testing.T) {
plain := &fakeNonSchema{X: "y"}
c := &recordingClient{}
got := c.capture(plain)
if got != any(plain) {
t.Fatalf("expected plain struct passthrough, got %T", got)
}
}

func TestUnwrapSchema_VariadicUnwrapsEachElement(t *testing.T) {
innerA := &fakeRecord{name: "Studio"}
innerB := &fakeRecord{name: "Film"}
templates := []any{
&fakeWrapper{inner: innerA},
innerB, // already a Schema; passthrough
}
for i, obj := range templates {
templates[i] = UnwrapSchema(obj)
}
if templates[0] != any(innerA) {
t.Fatalf("template[0]: expected innerA, got %T", templates[0])
}
if templates[1] != any(innerB) {
t.Fatalf("template[1]: expected innerB (passthrough), got %T", templates[1])
}
}
Loading