@@ -9,6 +9,7 @@ public sealed class NativeImportGenerator : IIncrementalGenerator
99
1010 static NativeImportGenerator ( )
1111 {
12+ // NOTE: This set is unnecessary (but just in case)...
1213 ContextualKeywordsThatNeedEscaping = new HashSet < string > ( StringComparer . Ordinal )
1314 {
1415 // Query keywords
@@ -18,7 +19,7 @@ static NativeImportGenerator()
1819 "async" , "await" ,
1920 // Other contextual keywords
2021 "when" , "yield" , "partial" , "file" , "required" , "init" , "set" , "get" , "add" , "remove" ,
21- "nameof" , "var" , "dynamic"
22+ "nameof" , "var" , "dynamic" , "field"
2223 } ;
2324 }
2425
@@ -38,8 +39,8 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
3839 var ( compilation , propDecls ) = source ;
3940 if ( propDecls . IsDefaultOrEmpty ) return ;
4041
41- var nativeImportAttr = compilation . GetTypeByMetadataName ( typeof ( NativeInvoke . NativeImportAttribute ) . FullName ! ) ;
42- var nativeImportMethodAttr = compilation . GetTypeByMetadataName ( typeof ( NativeInvoke . NativeImportMethodAttribute ) . FullName ! ) ;
42+ var nativeImportAttr = compilation . GetTypeByMetadataName ( typeof ( NativeImportAttribute ) . FullName ! ) ;
43+ var nativeImportMethodAttr = compilation . GetTypeByMetadataName ( typeof ( NativeImportMethodAttribute ) . FullName ! ) ;
4344 if ( nativeImportAttr is null || nativeImportMethodAttr is null ) return ; // Ensure we have our attributes
4445
4546 foreach ( var propDecl in propDecls )
@@ -123,30 +124,14 @@ private static void GenerateForProperty(
123124 NativeImportAttribute nativeImportAttr ;
124125 {
125126 var temp = new NativeImportAttribute ( string . Empty ) ; // Temporary instance to get the default values
126- var enforceBlittable = ( bool ) ( pAttr . NamedArguments
127- . FirstOrDefault ( static kv => kv . Key == nameof ( NativeImportAttribute . EnforceBlittable ) )
128- . Value . Value ?? temp . EnforceBlittable ) ;
129- var explicitOnly = ( bool ) ( pAttr . NamedArguments
130- . FirstOrDefault ( static kv => kv . Key == nameof ( NativeImportAttribute . ExplicitOnly ) )
131- . Value . Value ?? temp . ExplicitOnly ) ;
132- var inherited = ( bool ) ( pAttr . NamedArguments
133- . FirstOrDefault ( static kv => kv . Key == nameof ( NativeImportAttribute . Inherited ) )
134- . Value . Value ?? temp . Inherited ) ;
135- var lazy = ( bool ) ( pAttr . NamedArguments
136- . FirstOrDefault ( static kv => kv . Key == nameof ( NativeImportAttribute . Lazy ) )
137- . Value . Value ?? temp . Lazy ) ;
138- var defaultCc = ( CallingConvention ) ( pAttr . NamedArguments
139- . FirstOrDefault ( static kv => kv . Key == nameof ( NativeImportAttribute . CallingConvention ) )
140- . Value . Value ?? temp . CallingConvention ) ; // Fallback to platform-default
141- var suppressGCTransition = ( bool ) ( pAttr . NamedArguments
142- . FirstOrDefault ( static kv => kv . Key == nameof ( NativeImportAttribute . SuppressGCTransition ) )
143- . Value . Value ?? temp . SuppressGCTransition ) ;
144- var symbolPrefix = ( pAttr . NamedArguments
145- . FirstOrDefault ( static kv => kv . Key == nameof ( NativeImportAttribute . SymbolPrefix ) )
146- . Value . Value ?? temp . SymbolPrefix ) as string ;
147- var symbolSuffix = ( pAttr . NamedArguments
148- . FirstOrDefault ( static kv => kv . Key == nameof ( NativeImportAttribute . SymbolSuffix ) )
149- . Value . Value ?? temp . SymbolSuffix ) as string ;
127+ var enforceBlittable = SafeGetNamedArgument ( pAttr , nameof ( NativeImportAttribute . EnforceBlittable ) , spc , prop . Locations [ 0 ] , temp . EnforceBlittable ) ;
128+ var explicitOnly = SafeGetNamedArgument ( pAttr , nameof ( NativeImportAttribute . ExplicitOnly ) , spc , prop . Locations [ 0 ] , temp . ExplicitOnly ) ;
129+ var inherited = SafeGetNamedArgument ( pAttr , nameof ( NativeImportAttribute . Inherited ) , spc , prop . Locations [ 0 ] , temp . Inherited ) ;
130+ var lazy = SafeGetNamedArgument ( pAttr , nameof ( NativeImportAttribute . Lazy ) , spc , prop . Locations [ 0 ] , temp . Lazy ) ;
131+ var defaultCc = SafeGetNamedArgument ( pAttr , nameof ( NativeImportAttribute . CallingConvention ) , spc , prop . Locations [ 0 ] , temp . CallingConvention ) ;
132+ var suppressGCTransition = SafeGetNamedArgument ( pAttr , nameof ( NativeImportAttribute . SuppressGCTransition ) , spc , prop . Locations [ 0 ] , temp . SuppressGCTransition ) ;
133+ var symbolPrefix = SafeGetNamedArgument ( pAttr , nameof ( NativeImportAttribute . SymbolPrefix ) , spc , prop . Locations [ 0 ] , temp . SymbolPrefix ) ;
134+ var symbolSuffix = SafeGetNamedArgument ( pAttr , nameof ( NativeImportAttribute . SymbolSuffix ) , spc , prop . Locations [ 0 ] , temp . SymbolSuffix ) ;
150135 nativeImportAttr = new NativeImportAttribute ( libraryName ! )
151136 {
152137 EnforceBlittable = enforceBlittable ,
@@ -205,9 +190,9 @@ private static void GenerateForProperty(
205190 }
206191 else
207192 {
208- // Treat both null and empty/whitespace strings as explicit exclusion
193+ // Treat both null and empty strings as explicit exclusion
209194 entryPoint = argValue as string ; // may be null
210- shouldInclude = ! string . IsNullOrWhiteSpace ( entryPoint ) ;
195+ shouldInclude = ! string . IsNullOrEmpty ( entryPoint ) ;
211196 }
212197 }
213198 }
@@ -219,15 +204,10 @@ private static void GenerateForProperty(
219204 // Reconstruct the method attribute and create method data
220205 var name = $ "{ method . Name } _{ Guid . NewGuid ( ) : N} "; // Append a Guid to prevent name collisions for overloaded functions
221206 entryPoint = shouldInclude ? ResolveMethodEntryPoint ( entryPoint , method . Name , nativeImportAttr . SymbolPrefix , nativeImportAttr . SymbolSuffix ) : string . Empty ;
222- var cc = ( mAttr ? . NamedArguments
223- . FirstOrDefault ( static kv => kv . Key == nameof ( NativeImportMethodAttribute . CallingConvention ) )
224- . Value . Value ?? null ) as CallingConvention ? ;
225- var suppressGCTransition = ( mAttr ? . NamedArguments
226- . FirstOrDefault ( static kv => kv . Key == nameof ( NativeImportMethodAttribute . SuppressGCTransition ) )
227- . Value . Value ?? null ) as bool ? ;
228- var enforceBlittable = ( mAttr ? . NamedArguments
229- . FirstOrDefault ( static kv => kv . Key == nameof ( NativeImportMethodAttribute . EnforceBlittable ) )
230- . Value . Value ?? null ) as bool ? ;
207+ // Reminder: nativeImportAttr.* are the fallback defaults (resolved later), but here we use nullable to mean "not specified"
208+ var cc = SafeGetNamedArgument < CallingConvention ? > ( mAttr , nameof ( NativeImportMethodAttribute . CallingConvention ) , spc , method . Locations [ 0 ] , null ) ;
209+ var suppressGCTransition = SafeGetNamedArgument < bool ? > ( mAttr , nameof ( NativeImportMethodAttribute . SuppressGCTransition ) , spc , method . Locations [ 0 ] , null ) ;
210+ var enforceBlittable = SafeGetNamedArgument < bool ? > ( mAttr , nameof ( NativeImportMethodAttribute . EnforceBlittable ) , spc , method . Locations [ 0 ] , null ) ;
231211 var attr = ordinal . HasValue
232212 ? new NativeImportMethodAttribute ( ordinal . Value )
233213 : new NativeImportMethodAttribute ( entryPoint ) ;
@@ -681,6 +661,90 @@ private static bool ShouldEscapeIdentifier(string identifier)
681661 // These are keywords that are only reserved in specific contexts
682662 return ContextualKeywordsThatNeedEscaping . Contains ( identifier ) ;
683663 }
664+
665+ private static string GetFriendlyTypeName ( Type type )
666+ {
667+ // Nullable`1 -> T?
668+ if ( type . IsGenericType && type . GetGenericTypeDefinition ( ) == typeof ( Nullable < > ) )
669+ {
670+ var underlyingType = type . GetGenericArguments ( ) [ 0 ] ;
671+ return $ "{ underlyingType . Name } ?";
672+ }
673+ return type . Name ;
674+ }
675+
676+ private static T ? SafeGetNamedArgument < T > ( AttributeData ? attribute , string parameterName , SourceProductionContext context , Location location , T ? defaultValue = default )
677+ {
678+ // Safe type casting with validation
679+ if ( attribute ? . NamedArguments == null ) return defaultValue ;
680+
681+ var namedArg = attribute . NamedArguments . FirstOrDefault ( kv => kv . Key == parameterName ) ;
682+ if ( namedArg . Key == null ) return defaultValue ;
683+
684+ if ( namedArg . Value . Value == null ) return defaultValue ;
685+
686+ // Handle string
687+ if ( typeof ( T ) == typeof ( string ) && namedArg . Value . Value is string stringValue )
688+ {
689+ return ( T ? ) ( object ? ) stringValue ;
690+ }
691+ // Handle boolean
692+ if ( typeof ( T ) == typeof ( bool ) && namedArg . Value . Value is bool boolValue )
693+ {
694+ return ( T ? ) ( object ? ) boolValue ;
695+ }
696+ // Handle nullable boolean
697+ if ( typeof ( T ) == typeof ( bool ? ) && namedArg . Value . Value is bool boolValueNullable )
698+ {
699+ return ( T ? ) ( object ? ) boolValueNullable ;
700+ }
701+ // Handle primitive types (int, uint, long, ulong, short, ushort, byte, sbyte, float, double, decimal)
702+ if ( typeof ( T ) . IsPrimitive && namedArg . Value . Value . GetType ( ) . IsPrimitive )
703+ {
704+ return ( T ? ) ( object ? ) namedArg . Value . Value ;
705+ }
706+ // Handle nullable primitive types
707+ if ( typeof ( T ) . IsGenericType && typeof ( T ) . GetGenericTypeDefinition ( ) == typeof ( Nullable < > ) &&
708+ typeof ( T ) . GetGenericArguments ( ) [ 0 ] . IsPrimitive && namedArg . Value . Value . GetType ( ) . IsPrimitive )
709+ {
710+ return ( T ? ) ( object ? ) namedArg . Value . Value ;
711+ }
712+ // Handle enum types (including CallingConvention)
713+ if ( typeof ( T ) . IsEnum && namedArg . Value . Value . GetType ( ) . IsEnum )
714+ {
715+ return ( T ? ) ( object ? ) namedArg . Value . Value ;
716+ }
717+ // Handle nullable enum types
718+ if ( typeof ( T ) . IsGenericType && typeof ( T ) . GetGenericTypeDefinition ( ) == typeof ( Nullable < > ) &&
719+ typeof ( T ) . GetGenericArguments ( ) [ 0 ] . IsEnum && namedArg . Value . Value . GetType ( ) . IsEnum )
720+ {
721+ return ( T ? ) ( object ? ) namedArg . Value . Value ;
722+ }
723+ // Handle enum from integer conversion
724+ if ( typeof ( T ) . IsEnum && namedArg . Value . Value is int enumIntValue && Enum . IsDefined ( typeof ( T ) , enumIntValue ) )
725+ {
726+ return ( T ? ) Enum . ToObject ( typeof ( T ) , enumIntValue ) ;
727+ }
728+ // Handle nullable enum from integer conversion
729+ if ( typeof ( T ) . IsGenericType && typeof ( T ) . GetGenericTypeDefinition ( ) == typeof ( Nullable < > ) &&
730+ typeof ( T ) . GetGenericArguments ( ) [ 0 ] . IsEnum && namedArg . Value . Value is int enumIntValueNullable &&
731+ Enum . IsDefined ( typeof ( T ) . GetGenericArguments ( ) [ 0 ] , enumIntValueNullable ) )
732+ {
733+ var enumType = typeof ( T ) . GetGenericArguments ( ) [ 0 ] ;
734+ var enumValue = Enum . ToObject ( enumType , enumIntValueNullable ) ;
735+ return ( T ? ) ( object ? ) enumValue ;
736+ }
737+
738+ // Report warning for invalid type
739+ var expectedTypeName = GetFriendlyTypeName ( typeof ( T ) ) ;
740+ context . ReportDiagnostic ( Diagnostic . Create (
741+ Diagnostics . InvalidAttributeArgument ,
742+ location ,
743+ parameterName ,
744+ expectedTypeName ) ) ;
745+
746+ return defaultValue ;
747+ }
684748}
685749
686750internal static partial class Extensions
0 commit comments