Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@ func isComplexEnum(envel idltype.IdlType) bool {
return false
}

func isOptionalComplexEnum(ty idltype.IdlType) bool {
switch v := ty.(type) {
case *idltype.Option:
return isComplexEnum(v.Option)
case *idltype.COption:
return isComplexEnum(v.COption)
}
return false
}

func register_TypeName_as_ComplexEnum(name string) {
typeRegistryComplexEnum[name] = struct{}{}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
//nolint:all // Forked from anchor-go generator, maintaining original code structure
package generator

import (
"strings"
"testing"

"github.com/dave/jennifer/jen"
"github.com/gagliardetto/anchor-go/idl"
"github.com/gagliardetto/anchor-go/idl/idltype"
"github.com/stretchr/testify/assert"
)

// complexEnumGuard mirrors the condition used in gen_marshal_DefinedFieldsNamed
// and gen_unmarshal_DefinedFieldsNamed to decide whether a field is routed to
// the specialized enum encoder/parser or falls through to the generic
// Encode/Decode path.
func complexEnumGuard(ty idltype.IdlType) bool {
return isComplexEnum(ty) ||
(IsArray(ty) && isComplexEnum(ty.(*idltype.Array).Type)) ||
(IsVec(ty) && isComplexEnum(ty.(*idltype.Vec).Vec)) ||
isOptionalComplexEnum(ty)
}

func TestComplexEnumGuard_handlesOptionAndCOption(t *testing.T) {
const name = "Outcome"
register_TypeName_as_ComplexEnum(name)
t.Cleanup(func() { delete(typeRegistryComplexEnum, name) })

defined := &idltype.Defined{Name: name}

assert.True(t, complexEnumGuard(defined), "bare Defined")
assert.True(t, complexEnumGuard(&idltype.Option{Option: defined}), "Option<ComplexEnum>")
assert.True(t, complexEnumGuard(&idltype.COption{COption: defined}), "COption<ComplexEnum>")
}

// TestComplexEnumGuard_rejectsNonComplexOptionals ensures the guard does NOT
// fire for Option/COption wrapping a non-complex Defined or a primitive.
// A false positive here would cause the switch to enter the Option/COption case
// where .Option.(*idltype.Defined) would panic on a non-Defined inner type.
func TestComplexEnumGuard_rejectsNonComplexOptionals(t *testing.T) {
const complexName = "Outcome"
register_TypeName_as_ComplexEnum(complexName)
t.Cleanup(func() { delete(typeRegistryComplexEnum, complexName) })

nonComplex := &idltype.Defined{Name: "PlainStruct"}

assert.False(t, complexEnumGuard(&idltype.Option{Option: nonComplex}),
"Option<NonComplexDefined> must not trigger the complex-enum path")
assert.False(t, complexEnumGuard(&idltype.COption{COption: nonComplex}),
"COption<NonComplexDefined> must not trigger the complex-enum path")
assert.False(t, complexEnumGuard(&idltype.Option{Option: &idltype.U64{}}),
"Option<U64> must not trigger the complex-enum path")
assert.False(t, complexEnumGuard(&idltype.COption{COption: &idltype.U8{}}),
"COption<U8> must not trigger the complex-enum path")
assert.False(t, complexEnumGuard(&idltype.Option{Option: &idltype.Vec{Vec: &idltype.Defined{Name: complexName}}}),
"Option<Vec<ComplexEnum>> — nested containers not supported, must not match")
}

// TestComplexEnumCodegen_optionalComplexEnum runs the actual marshal/unmarshal
// generator with Option<ComplexEnum> and COption<ComplexEnum> fields and
// verifies the generated Go source uses the specialized enum encoder/parser
// instead of the generic Encode/Decode.
func TestComplexEnumCodegen_optionalComplexEnum(t *testing.T) {
const enumName = "Outcome"
register_TypeName_as_ComplexEnum(enumName)
t.Cleanup(func() { delete(typeRegistryComplexEnum, enumName) })

fields := idl.IdlDefinedFieldsNamed{
{Name: "id", Ty: &idltype.U64{}},
{Name: "verdict", Ty: &idltype.Option{Option: &idltype.Defined{Name: enumName}}},
{Name: "alt_verdict", Ty: &idltype.COption{COption: &idltype.Defined{Name: enumName}}},
{Name: "checksum", Ty: &idltype.U64{}},
}

marshalCode := gen_MarshalWithEncoder_struct(
&idl.Idl{}, false, "Report", "", fields, true,
)
unmarshalCode := gen_UnmarshalWithDecoder_struct(
&idl.Idl{}, false, "Report", "", fields,
)

f := jen.NewFile("fixture")
f.Add(marshalCode)
f.Add(unmarshalCode)
src := f.GoString()

// Specialized enum encoder/parser must appear.
assert.Contains(t, src, "EncodeOutcome",
"Option/COption<ComplexEnum> fields must call the specialized enum encoder")
assert.Contains(t, src, "DecodeOutcome",
"Option/COption<ComplexEnum> fields must call the specialized enum parser")

// Option flags must still be written/read.
assert.Contains(t, src, "WriteOption")
assert.Contains(t, src, "WriteCOption")
assert.Contains(t, src, "ReadOption")
assert.Contains(t, src, "ReadCOption")

// Only the two plain U64 fields (Id, Checksum) should use the generic
// encoder/decoder. If the enum fields also fall through, the count is 4.
assert.Equal(t, 2, strings.Count(src, ".Encode("),
"generic Encode must only be used for non-enum fields")
assert.Equal(t, 2, strings.Count(src, ".Decode("),
"generic Decode must only be used for non-enum fields")
}
63 changes: 62 additions & 1 deletion cmd/generate-bindings/solana/anchor-go/generator/marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ func gen_marshal_DefinedFieldsNamed(
body.Commentf("Serialize `%s`:", exportedArgName)
}

if isComplexEnum(field.Ty) || (IsArray(field.Ty) && isComplexEnum(field.Ty.(*idltype.Array).Type)) || (IsVec(field.Ty) && isComplexEnum(field.Ty.(*idltype.Vec).Vec)) {
if isComplexEnum(field.Ty) || (IsArray(field.Ty) && isComplexEnum(field.Ty.(*idltype.Array).Type)) || (IsVec(field.Ty) && isComplexEnum(field.Ty.(*idltype.Vec).Vec)) || isOptionalComplexEnum(field.Ty) {
switch field.Ty.(type) {
case *idltype.Defined:
enumTypeName := field.Ty.(*idltype.Defined).Name
Expand Down Expand Up @@ -260,6 +260,12 @@ func gen_marshal_DefinedFieldsNamed(
)
})
})
case *idltype.Option:
enumTypeName := field.Ty.(*idltype.Option).Option.(*idltype.Defined).Name
gen_marshal_optionalComplexEnum(body, "WriteOption", enumTypeName, field, checkNil, nameFormatter, encoderVariableName, returnNilErr, exportedArgName)
case *idltype.COption:
enumTypeName := field.Ty.(*idltype.COption).COption.(*idltype.Defined).Name
gen_marshal_optionalComplexEnum(body, "WriteCOption", enumTypeName, field, checkNil, nameFormatter, encoderVariableName, returnNilErr, exportedArgName)
}
} else {
if IsOption(field.Ty) || IsCOption(field.Ty) {
Expand Down Expand Up @@ -380,3 +386,58 @@ func gen_marshal_DefinedFieldsNamed(
}
}
}

func gen_marshal_optionalComplexEnum(
body *Group,
optionalityWriterName string,
enumTypeName string,
field idl.IdlField,
checkNil bool,
nameFormatter func(field idl.IdlField) *Statement,
encoderVariableName string,
returnNilErr bool,
exportedArgName string,
) {
errReturn := func(wrapped Code) *Statement {
return ReturnFunc(func(returnBody *Group) {
if returnNilErr {
returnBody.Nil()
}
returnBody.Add(wrapped)
})
}
optionalityErr := func() *Statement {
return errReturn(
Qual(PkgAnchorGoErrors, "NewOption").Call(
Lit(exportedArgName),
Qual("fmt", "Errorf").Call(Lit("error while encoding optionality: %w"), Err()),
),
)
}
fieldErr := func() *Statement {
return errReturn(
Qual(PkgAnchorGoErrors, "NewField").Call(Lit(exportedArgName), Err()),
)
}

if checkNil {
body.BlockFunc(func(optGroup *Group) {
optGroup.If(nameFormatter(field).Op("==").Nil()).Block(
Err().Op("=").Id(encoderVariableName).Dot(optionalityWriterName).Call(False()),
If(Err().Op("!=").Nil()).Block(optionalityErr()),
).Else().Block(
Err().Op("=").Id(encoderVariableName).Dot(optionalityWriterName).Call(True()),
If(Err().Op("!=").Nil()).Block(optionalityErr()),
Err().Op("=").Id(formatEnumEncoderName(enumTypeName)).Call(Id(encoderVariableName), nameFormatter(field)),
If(Err().Op("!=").Nil()).Block(fieldErr()),
)
})
} else {
body.BlockFunc(func(optGroup *Group) {
optGroup.Err().Op("=").Id(encoderVariableName).Dot(optionalityWriterName).Call(True())
optGroup.If(Err().Op("!=").Nil()).Block(optionalityErr())
optGroup.Err().Op("=").Id(formatEnumEncoderName(enumTypeName)).Call(Id(encoderVariableName), nameFormatter(field))
optGroup.If(Err().Op("!=").Nil()).Block(fieldErr())
})
}
}
46 changes: 43 additions & 3 deletions cmd/generate-bindings/solana/anchor-go/generator/unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,9 +238,7 @@ func gen_unmarshal_DefinedFieldsNamed(
body.Commentf("Deserialize `%s`:", exportedArgName)
}

if isComplexEnum(field.Ty) || (IsArray(field.Ty) && isComplexEnum(field.Ty.(*idltype.Array).Type)) || (IsVec(field.Ty) && isComplexEnum(field.Ty.(*idltype.Vec).Vec)) {
// TODO: this assumes this cannot be an option;
// - check whether this is an option?
if isComplexEnum(field.Ty) || (IsArray(field.Ty) && isComplexEnum(field.Ty.(*idltype.Array).Type)) || (IsVec(field.Ty) && isComplexEnum(field.Ty.(*idltype.Vec).Vec)) || isOptionalComplexEnum(field.Ty) {
switch field.Ty.(type) {
case *idltype.Defined:
enumName := field.Ty.(*idltype.Defined).Name
Expand Down Expand Up @@ -325,6 +323,12 @@ func gen_unmarshal_DefinedFieldsNamed(
)
})
})
case *idltype.Option:
enumTypeName := field.Ty.(*idltype.Option).Option.(*idltype.Defined).Name
gen_unmarshal_optionalComplexEnum(body, "ReadOption", enumTypeName, exportedArgName)
case *idltype.COption:
enumTypeName := field.Ty.(*idltype.COption).COption.(*idltype.Defined).Name
gen_unmarshal_optionalComplexEnum(body, "ReadCOption", enumTypeName, exportedArgName)
}
} else {
if IsOption(field.Ty) || IsCOption(field.Ty) {
Expand Down Expand Up @@ -376,3 +380,39 @@ func gen_unmarshal_DefinedFieldsNamed(
}
}
}

func gen_unmarshal_optionalComplexEnum(
body *Group,
optionalityReaderName string,
enumTypeName string,
exportedArgName string,
) {
body.BlockFunc(func(optGroup *Group) {
optGroup.List(Id("ok"), Err()).Op(":=").Id("decoder").Dot(optionalityReaderName).Call()
optGroup.If(Err().Op("!=").Nil()).Block(
Return(
Qual(PkgAnchorGoErrors, "NewOption").Call(
Lit(exportedArgName),
Qual("fmt", "Errorf").Call(
Lit("error while reading optionality: %w"),
Err(),
),
),
),
)
optGroup.If(Id("ok")).Block(
List(
Id("obj").Dot(exportedArgName),
Err(),
).Op("=").Id(formatEnumParserName(enumTypeName)).Call(Id("decoder")),
If(Err().Op("!=").Nil()).Block(
Return(
Qual(PkgAnchorGoErrors, "NewField").Call(
Lit(exportedArgName),
Err(),
),
),
),
)
})
}