diff --git a/client.go b/client.go index be9813b..14e38ee 100644 --- a/client.go +++ b/client.go @@ -486,6 +486,7 @@ func (c client) validateStruct(ctx context.Context, obj any) error { // Insert implements inserting an object or slice of objects in the database. // Passed object must be a pointer to a struct with appropriate dgraph tags. func (c client) Insert(ctx context.Context, obj any) error { + obj = UnwrapSchema(obj) // Validate struct before insertion if err := c.validateStruct(ctx, obj); err != nil { return err @@ -503,6 +504,7 @@ func (c client) Insert(ctx context.Context, obj any) error { // // Deprecated: InsertRaw is now identical to Insert. Use Insert instead. func (c client) InsertRaw(ctx context.Context, obj any) error { + obj = UnwrapSchema(obj) // Validate struct before insertion if err := c.validateStruct(ctx, obj); err != nil { return err @@ -518,6 +520,7 @@ func (c client) InsertRaw(ctx context.Context, obj any) error { // to be used for upserting. If none are specified, the first predicate with the `upsert` tag // will be used. func (c client) Upsert(ctx context.Context, obj any, predicates ...string) error { + obj = UnwrapSchema(obj) // Validate struct before upsert if err := c.validateStruct(ctx, obj); err != nil { return err @@ -531,6 +534,7 @@ func (c client) Upsert(ctx context.Context, obj any, predicates ...string) error // Update implements updating an existing object in the database. // Passed object must be a pointer to a struct. func (c client) Update(ctx context.Context, obj any) error { + obj = UnwrapSchema(obj) // Validate struct before update if err := c.validateStruct(ctx, obj); err != nil { return err @@ -557,6 +561,7 @@ func (c client) Delete(ctx context.Context, uids []string) error { // Get implements retrieving a single object by its UID. // Passed object must be a pointer to a struct. func (c client) Get(ctx context.Context, obj any, uid string) error { + obj = UnwrapSchema(obj) err := checkPointer(obj) if err != nil { return err @@ -575,6 +580,7 @@ func (c client) Get(ctx context.Context, obj any, uid string) error { // Returns a *dg.Query that can be further refined with filters, pagination, etc. // The returned query will be limited to the maximum number of edges specified in the options. func (c client) Query(ctx context.Context, model any) *dg.Query { + model = UnwrapSchema(model) client, err := c.pool.get() if err != nil { return nil @@ -590,6 +596,9 @@ func (c client) Query(ctx context.Context, model any) *dg.Query { // If any object contains SimString fields tagged `dgraph:"embedding"`, the // corresponding shadow float32vector predicates (__vec) are also registered. func (c client) UpdateSchema(ctx context.Context, obj ...any) error { + for i := range obj { + obj[i] = UnwrapSchema(obj[i]) + } dgClient, err := c.pool.get() if err != nil { c.logger.Error(err, "Failed to get client from pool") diff --git a/record.go b/record.go new file mode 100644 index 0000000..bfcc762 --- /dev/null +++ b/record.go @@ -0,0 +1,71 @@ +package modusgraph + +import "reflect" + +// Schema identifies a value as a record of a generated schema-defining type. +// modusgraph-gen-emitted schema structs implement this via a generated +// SchemaTypeName() method that returns the canonical entity name +// (e.g. "Studio"). The interface is intentionally minimal — a single method +// returning a useful piece of metadata. +// +// Plain user structs (not emitted by modusgraph-gen) do not implement Schema +// and are unaffected by the modusgraph.Client routing it enables; they pass +// through to the existing reflection-based dgman pipeline exactly as before. +type Schema interface { + SchemaTypeName() string +} + +// UnwrapSchema returns the schema-defining record contained in obj. If obj +// is nil, it is returned as-is. If obj is already a Schema, it is returned +// as-is. If obj exposes an Unwrap() method whose return value satisfies +// Schema, that return is substituted. Otherwise obj is returned unchanged. +// +// This is the bridge between modusgraph-gen-emitted wrapper types and the +// rest of modusgraph.Client. It is purely additive: types that don't +// implement Schema and don't have an Unwrap() method (i.e. existing +// modusgraph users' plain structs) pass through untouched. +// +// Note on errors.Unwrap overlap: Go's errors package uses Unwrap() error +// as the standard "give me the wrapped thing" method. UnwrapSchema's +// secondary check (the returned value must itself implement Schema) means +// an error wrapper is not mistaken for a modusgraph wrapper — the +// reflection probe finds Unwrap(), calls it, gets an error, fails the +// Schema check, and returns the original obj. +func UnwrapSchema(obj any) any { + if obj == nil { + return obj + } + if _, ok := obj.(Schema); ok { + return obj + } + v := reflect.ValueOf(obj) + if !v.IsValid() { + return obj + } + // A typed nil pointer has a valid method set, but invoking Unwrap on a nil + // receiver would panic if the method dereferences it. Leave it untouched. + if v.Kind() == reflect.Pointer && v.IsNil() { + return obj + } + m := v.MethodByName("Unwrap") + if !m.IsValid() && v.Kind() != reflect.Pointer { + // Unwrap may be declared with a pointer receiver while obj was passed by + // value; a value's method set excludes pointer-receiver methods, so look + // it up on an addressable copy. + pv := reflect.New(v.Type()) + pv.Elem().Set(v) + m = pv.MethodByName("Unwrap") + } + if !m.IsValid() { + return obj + } + mt := m.Type() + if mt.NumIn() != 0 || mt.NumOut() != 1 { + return obj + } + inner := m.Call(nil)[0].Interface() + if _, ok := inner.(Schema); ok { + return inner + } + return obj +} diff --git a/record_example_test.go b/record_example_test.go new file mode 100644 index 0000000..534a844 --- /dev/null +++ b/record_example_test.go @@ -0,0 +1,49 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package modusgraph_test + +import ( + "context" + "fmt" + + mg "github.com/matthewmcneely/modusgraph" +) + +// Actor is a schema-defining record. Implementing mg.Schema (a single +// SchemaTypeName method) marks it as a generated schema type; code generators +// such as modusgraph-gen emit this method. +type Actor struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + Name string `json:"name,omitempty" dgraph:"index=exact"` +} + +func (a *Actor) SchemaTypeName() string { return "Actor" } + +// ActorBuilder is a wrapper around Actor — the shape a generated fluent builder +// or domain wrapper takes. Exposing Unwrap lets the modusgraph client route the +// wrapper to its backing record, so the wrapper can be passed straight to +// Insert/Update/Get without the caller reaching for the inner value. +type ActorBuilder struct{ actor *Actor } + +func (b *ActorBuilder) Unwrap() *Actor { return b.actor } + +// ExampleSchema shows the wrapper pattern: the client unwraps an ActorBuilder +// to its Actor before persisting, so generated wrapper types work transparently +// while plain structs are unaffected. +func ExampleSchema() { + client, _ := mg.NewClient("dgraph://localhost:9080") + defer client.Close() + + ctx := context.Background() + builder := &ActorBuilder{actor: &Actor{Name: "Sigourney Weaver"}} + + // Insert the wrapper; the client unwraps it to the Actor record. + if err := client.Insert(ctx, builder); err != nil { + panic(err) + } + fmt.Println(builder.actor.Name) +} diff --git a/record_integration_test.go b/record_integration_test.go new file mode 100644 index 0000000..41e9a13 --- /dev/null +++ b/record_integration_test.go @@ -0,0 +1,51 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package modusgraph_test + +import ( + "context" + "testing" + + mg "github.com/matthewmcneely/modusgraph" + "github.com/stretchr/testify/require" +) + +// studioRecord is a schema-defining record (implements mg.Schema). studioWrapper +// wraps it and exposes Unwrap, exactly as a modusgraph-gen wrapper would. +type studioRecord struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + Name string `json:"name,omitempty" dgraph:"index=exact"` +} + +func (s *studioRecord) SchemaTypeName() string { return "studioRecord" } + +type studioWrapper struct{ inner *studioRecord } + +func (w *studioWrapper) Unwrap() *studioRecord { return w.inner } + +// TestClientUnwrapsWrapperThroughRealMutation exercises the real client path, +// not UnwrapSchema in isolation: it inserts a wrapper and reads it back. If a +// mutation method stopped calling UnwrapSchema, the wrapper (which has no usable +// dgraph fields of its own) would not persist Name and the inner UID would stay +// empty — so this test fails on that regression. +func TestClientUnwrapsWrapperThroughRealMutation(t *testing.T) { + client, err := mg.NewClient("file://"+GetTempDir(t), mg.WithAutoSchema(true)) + require.NoError(t, err) + defer client.Close() + + ctx := context.Background() + inner := &studioRecord{Name: "Acme"} + wrapper := &studioWrapper{inner: inner} + + require.NoError(t, client.Insert(ctx, wrapper)) + require.NotEmpty(t, inner.UID, + "Insert did not route the wrapper to its inner record") + + var got studioRecord + require.NoError(t, client.Get(ctx, &got, inner.UID)) + require.Equal(t, "Acme", got.Name) +} diff --git a/record_test.go b/record_test.go new file mode 100644 index 0000000..caf5083 --- /dev/null +++ b/record_test.go @@ -0,0 +1,139 @@ +package modusgraph + +import ( + "errors" + "testing" +) + +type fakeRecord struct{ name string } + +func (f *fakeRecord) SchemaTypeName() string { return f.name } + +type fakeWrapper struct{ inner *fakeRecord } + +func (w *fakeWrapper) Unwrap() *fakeRecord { return w.inner } + +type fakeNonSchema struct{ X string } + +func TestUnwrapSchema_PassthroughForPlainStruct(t *testing.T) { + in := &fakeNonSchema{X: "hi"} + out := UnwrapSchema(in) + if out != any(in) { + t.Fatalf("expected passthrough, got %T", out) + } +} + +func TestUnwrapSchema_PassthroughForSchemaStruct(t *testing.T) { + in := &fakeRecord{name: "Studio"} + out := UnwrapSchema(in) + if out != any(in) { + t.Fatalf("expected passthrough for direct Schema, got %T", out) + } +} + +func TestUnwrapSchema_UnwrapsWrapper(t *testing.T) { + inner := &fakeRecord{name: "Studio"} + w := &fakeWrapper{inner: inner} + out := UnwrapSchema(w) + if out != any(inner) { + t.Fatalf("expected unwrapped inner, got %T (%v)", out, out) + } +} + +func TestUnwrapSchema_IgnoresErrorsUnwrap(t *testing.T) { + // errors.New("x") has no Unwrap; wrap one to get something with Unwrap() error. + inner := errors.New("inner") + outer := &wrappedErr{err: inner} + out := UnwrapSchema(outer) + if out != any(outer) { + t.Fatalf("expected passthrough for error wrapper, got %T", out) + } +} + +type wrappedErr struct{ err error } + +func (w *wrappedErr) Error() string { return w.err.Error() } +func (w *wrappedErr) Unwrap() error { return w.err } + +func TestUnwrapSchema_NilInput(t *testing.T) { + if out := UnwrapSchema(nil); out != nil { + t.Fatalf("expected nil for nil input, got %v", out) + } +} + +func TestUnwrapSchema_TypedNilPointerDoesNotPanic(t *testing.T) { + // fakeWrapper.Unwrap dereferences its receiver, so invoking it on a typed + // nil pointer would panic. UnwrapSchema must return the value untouched. + var w *fakeWrapper + out := UnwrapSchema(w) + if out != any(w) { + t.Fatalf("expected typed nil pointer passthrough, got %T (%v)", out, out) + } +} + +func TestUnwrapSchema_PointerReceiverUnwrapOnValue(t *testing.T) { + // fakeWrapper.Unwrap has a pointer receiver. Passing the wrapper by value + // must still unwrap: a value's method set excludes pointer-receiver methods, + // so UnwrapSchema looks Unwrap up on an addressable copy. + inner := &fakeRecord{name: "Studio"} + w := fakeWrapper{inner: inner} + out := UnwrapSchema(w) + if out != any(inner) { + t.Fatalf("expected unwrapped inner from value wrapper, got %T (%v)", out, out) + } +} + +// recordingClient is the minimal surface needed to verify that wrappers +// passed to the Client interface get unwrapped before reaching internal +// reflection. It records whatever it received and returns nil. Each method +// applies obj = UnwrapSchema(obj) at the top, mirroring the patch landing +// in this task. +type recordingClient struct { + seen []any +} + +func (c *recordingClient) capture(obj any) any { + obj = UnwrapSchema(obj) + c.seen = append(c.seen, obj) + return obj +} + +func TestUnwrapSchema_CaptureForwardsInner(t *testing.T) { + inner := &fakeRecord{name: "Studio"} + w := &fakeWrapper{inner: inner} + c := &recordingClient{} + got := c.capture(w) + if got != any(inner) { + t.Fatalf("expected inner record, got %T (%v)", got, got) + } + if len(c.seen) != 1 || c.seen[0] != any(inner) { + t.Fatalf("expected recording to hold inner record, got %v", c.seen) + } +} + +func TestUnwrapSchema_CapturePassthroughForPlain(t *testing.T) { + plain := &fakeNonSchema{X: "y"} + c := &recordingClient{} + got := c.capture(plain) + if got != any(plain) { + t.Fatalf("expected plain struct passthrough, got %T", got) + } +} + +func TestUnwrapSchema_VariadicUnwrapsEachElement(t *testing.T) { + innerA := &fakeRecord{name: "Studio"} + innerB := &fakeRecord{name: "Film"} + templates := []any{ + &fakeWrapper{inner: innerA}, + innerB, // already a Schema; passthrough + } + for i, obj := range templates { + templates[i] = UnwrapSchema(obj) + } + if templates[0] != any(innerA) { + t.Fatalf("template[0]: expected innerA, got %T", templates[0]) + } + if templates[1] != any(innerB) { + t.Fatalf("template[1]: expected innerB (passthrough), got %T", templates[1]) + } +}