diff --git a/xls/common/attribute_data.cc b/xls/common/attribute_data.cc index 609b58c635..6405c5e750 100644 --- a/xls/common/attribute_data.cc +++ b/xls/common/attribute_data.cc @@ -40,6 +40,8 @@ std::string AttributeKindToString(AttributeKind kind) { return "channel_strictness"; case AttributeKind::kFuzzTest: return "fuzz_test"; + case AttributeKind::kFuzzDomain: + return "fuzz_domain"; } } diff --git a/xls/common/attribute_data.h b/xls/common/attribute_data.h index f931f27eab..4686615270 100644 --- a/xls/common/attribute_data.h +++ b/xls/common/attribute_data.h @@ -34,6 +34,7 @@ enum class AttributeKind : uint8_t { kQuickcheck, kChannelStrictness, kFuzzTest, + kFuzzDomain, }; // Converts an AttributeKind to a string, e.g., "fuzz_test" diff --git a/xls/dslx/bytecode/builtins.cc b/xls/dslx/bytecode/builtins.cc index c6ec7e0ae1..b16b223ddd 100644 --- a/xls/dslx/bytecode/builtins.cc +++ b/xls/dslx/bytecode/builtins.cc @@ -156,23 +156,7 @@ absl::Status BuiltinRangeInternal(InterpreterStack& stack) { const InterpValue& end) -> absl::StatusOr { XLS_RET_CHECK(start.IsBits()); XLS_RET_CHECK(end.IsBits()); - XLS_ASSIGN_OR_RETURN(InterpValue start_ge_end, start.Ge(end)); - if (start_ge_end.IsTrue()) { - return InterpValue::MakeRange({}); - } - - std::vector elements; - InterpValue cur = start; - XLS_ASSIGN_OR_RETURN(InterpValue done, cur.Ge(end)); - XLS_ASSIGN_OR_RETURN(int64_t cur_bits, cur.GetBitCount()); - InterpValue one(cur.IsSigned() ? InterpValue::MakeSBits(cur_bits, 1) - : InterpValue::MakeUBits(cur_bits, 1)); - while (done.IsFalse()) { - elements.push_back(cur); - XLS_ASSIGN_OR_RETURN(cur, cur.Add(one)); - XLS_ASSIGN_OR_RETURN(done, cur.Ge(end)); - } - return InterpValue::MakeRange(elements); + return InterpValue::MakeSymbolicRange(start, end); }, stack); } diff --git a/xls/dslx/bytecode/bytecode_interpreter_test.cc b/xls/dslx/bytecode/bytecode_interpreter_test.cc index 94b7a8e659..e44f04d619 100644 --- a/xls/dslx/bytecode/bytecode_interpreter_test.cc +++ b/xls/dslx/bytecode/bytecode_interpreter_test.cc @@ -2880,5 +2880,25 @@ fn main() -> u32[3] { InterpValue::MakeU32(15)})); } +TEST_F(BytecodeInterpreterTest, SymbolicRangeHuge) { + if (kDefaultTypeInferenceVersion == TypeInferenceVersion::kVersion2) { + // The range() builtin is deprecated and no longer supported in TIv2. + return; + } + constexpr std::string_view kProgram = R"( +fn main() -> u32[4294967295] { + range(u32:0, u32:4294967295) +} +)"; + XLS_ASSERT_OK_AND_ASSIGN(InterpValue value, Interpret(kProgram, "main")); + EXPECT_TRUE(value.is_range()); + XLS_ASSERT_OK_AND_ASSIGN(int64_t len, value.GetLength()); + EXPECT_EQ(len, 4294967295LL); + + XLS_ASSERT_OK_AND_ASSIGN(InterpValue elem, value.Index(123456789)); + XLS_ASSERT_OK_AND_ASSIGN(uint64_t val, elem.GetBitValueUnsigned()); + EXPECT_EQ(val, 123456789); +} + } // namespace } // namespace xls::dslx diff --git a/xls/dslx/frontend/ast.h b/xls/dslx/frontend/ast.h index 62e63215e4..e502b0a3b9 100644 --- a/xls/dslx/frontend/ast.h +++ b/xls/dslx/frontend/ast.h @@ -3334,6 +3334,11 @@ class StructDefBase : public AstNode { const std::vector& members() const { return members_; } std::vector& mutable_members() { return struct_members_; } + void AddMember(StructMemberNode* member) { + members_.push_back(member); + struct_members_.push_back(member->ToStructMemberStruct()); + members_by_name_[member->name()] = member; + } bool is_public() const { return public_; } const Span& span() const { return span_; } @@ -3417,9 +3422,22 @@ class StructDef : public StructDefBase { return extern_type_name_; } + void set_is_domain_struct(bool v) { is_domain_struct_ = v; } + bool is_domain_struct() const { return is_domain_struct_; } + + void set_is_populated(bool v) { is_populated_ = v; } + bool is_populated() const { return is_populated_; } + + void set_original_struct(StructDef* s) { original_struct_ = s; } + StructDef* original_struct() const { return original_struct_; } + private: // The external verilog type name std::optional extern_type_name_; + + bool is_domain_struct_ = false; + bool is_populated_ = false; + StructDef* original_struct_ = nullptr; }; // Represents a proc declared with struct-like syntax, with the functions in an diff --git a/xls/dslx/frontend/parser.cc b/xls/dslx/frontend/parser.cc index a7aa721a8a..7528bb58d1 100644 --- a/xls/dslx/frontend/parser.cc +++ b/xls/dslx/frontend/parser.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -225,6 +226,7 @@ absl::StatusOr ParseAttributeKind(Token token, {"test_proc", AttributeKind::kTestProc}, {"quickcheck", AttributeKind::kQuickcheck}, {"fuzz_test", AttributeKind::kFuzzTest}, + {"fuzz_domain", AttributeKind::kFuzzDomain}, {"channel_strictness", AttributeKind::kChannelStrictness}}; const auto it = map->find(token.GetStringValue()); @@ -605,7 +607,7 @@ absl::StatusOr> Parser::ParseModule( TypeAlias * type_alias, ParseTypeAlias(*module_member_start_pos, is_public, *bindings)); XLS_RETURN_IF_ERROR( - ApplyTypeAttributes(type_alias, pending_attributes)); + ApplyTypeAttributes(type_alias, pending_attributes, *bindings)); XLS_RETURN_IF_ERROR(module_->AddTop(type_alias, make_collision_error)); break; } @@ -614,7 +616,7 @@ absl::StatusOr> Parser::ParseModule( StructDef * struct_def, ParseStruct(*module_member_start_pos, is_public, *bindings)); XLS_RETURN_IF_ERROR( - ApplyTypeAttributes(struct_def, pending_attributes)); + ApplyTypeAttributes(struct_def, pending_attributes, *bindings)); XLS_RETURN_IF_ERROR(module_->AddTop(struct_def, make_collision_error)); break; } @@ -622,7 +624,8 @@ absl::StatusOr> Parser::ParseModule( XLS_ASSIGN_OR_RETURN( EnumDef * enum_def, ParseEnumDef(*module_member_start_pos, is_public, *bindings)); - XLS_RETURN_IF_ERROR(ApplyTypeAttributes(enum_def, pending_attributes)); + XLS_RETURN_IF_ERROR( + ApplyTypeAttributes(enum_def, pending_attributes, *bindings)); XLS_RETURN_IF_ERROR(module_->AddTop(enum_def, make_collision_error)); break; } @@ -949,6 +952,7 @@ absl::Status Parser::UnsupportedAttributeError(const Attribute& attribute) { Span span = *attribute.GetSpan(); switch (attribute.attribute_kind()) { case AttributeKind::kDerive: + case AttributeKind::kFuzzDomain: return ParseErrorStatus( span, absl::StrCat(attribute.ToString(), " is only valid on a struct.")); @@ -1143,7 +1147,8 @@ absl::StatusOr Parser::ParseFuzzTestDomains( template absl::Status Parser::ApplyTypeAttributes(T* node, - std::vector attributes) { + std::vector attributes, + Bindings& bindings) { for (Attribute* next : attributes) { switch (next->attribute_kind()) { case AttributeKind::kDerive: { @@ -1180,6 +1185,45 @@ absl::Status Parser::ApplyTypeAttributes(T* node, break; } + case AttributeKind::kFuzzDomain: { + if constexpr (!std::is_same_v) { + return UnsupportedAttributeError(*next); + } else { + if (next->args().size() != 1 || + !std::holds_alternative( + next->args()[0])) { + return ParseErrorStatus( + *next->GetSpan(), + "fuzz_domain attribute requires a string argument."); + } + std::string domain_name = + std::get(next->args()[0]) + .text; + + StructDef* struct_def = node; + Span span = *next->GetSpan(); + NameDef* domain_name_def = + module_->Make(span, domain_name, /*definer=*/nullptr); + std::vector parametric_bindings; + std::vector members; + StructDef* domain_struct = module_->Make( + span, domain_name_def, std::move(parametric_bindings), + std::move(members), struct_def->is_public()); + domain_name_def->set_definer(domain_struct); + + domain_struct->set_is_domain_struct(true); + domain_struct->set_original_struct(struct_def); + + bindings.Add(domain_name, domain_struct); + + auto make_collision_error = + absl::bind_front(&MakeModuleTopCollisionError, file_table()); + XLS_RETURN_IF_ERROR( + module_->AddTop(domain_struct, make_collision_error)); + } + break; + } + default: return UnsupportedAttributeError(*next); } diff --git a/xls/dslx/frontend/parser.h b/xls/dslx/frontend/parser.h index 6fe45a0098..ad33d43739 100644 --- a/xls/dslx/frontend/parser.h +++ b/xls/dslx/frontend/parser.h @@ -674,7 +674,8 @@ class Parser : public TokenParser { const std::vector& attributes); template - absl::Status ApplyTypeAttributes(T* node, std::vector attributes); + absl::Status ApplyTypeAttributes(T* node, std::vector attributes, + Bindings& bindings); absl::StatusOr GetQuickCheckTestCases( const Attribute& attribute); diff --git a/xls/dslx/frontend/parser_test.cc b/xls/dslx/frontend/parser_test.cc index 504c71ea98..7fab7b251f 100644 --- a/xls/dslx/frontend/parser_test.cc +++ b/xls/dslx/frontend/parser_test.cc @@ -4964,4 +4964,27 @@ fn f(x: u32, y: u8) {} EXPECT_TRUE(second_member_tuple->members().empty()); } +TEST_F(ParserTest, FuzzDomainAttribute) { + std::string_view program = R"( +#[fuzz_domain("MyDomain")] +struct MyStruct { + x: u32, +} +)"; + auto module = ExpectParsesSuccessfully(program); + ASSERT_NE(module, nullptr); + + // Verify MyStruct exists. + XLS_ASSERT_OK_AND_ASSIGN(StructDef * my_struct, + module->GetMemberOrError("MyStruct")); + EXPECT_FALSE(my_struct->is_domain_struct()); + + // Verify MyDomain exists and is linked. + XLS_ASSERT_OK_AND_ASSIGN(StructDef * my_domain, + module->GetMemberOrError("MyDomain")); + EXPECT_TRUE(my_domain->is_domain_struct()); + EXPECT_FALSE(my_domain->is_populated()); + EXPECT_EQ(my_domain->original_struct(), my_struct); +} + } // namespace xls::dslx diff --git a/xls/dslx/interp_value.cc b/xls/dslx/interp_value.cc index 962fabfdf4..e6f7307b78 100644 --- a/xls/dslx/interp_value.cc +++ b/xls/dslx/interp_value.cc @@ -229,6 +229,14 @@ std::string InterpValue::ToStringInternal(bool humanize, return InterpValueBitsToString(*this, format, /*include_type_prefix=*/!humanize); case InterpValueTag::kArray: + if (is_range() && + std::holds_alternative>(payload_)) { + const RangeData& range = + *std::get>(payload_); + return absl::StrFormat( + "[%s..%s%s]", range.start.ToString(humanize, format), + range.inclusive ? "=" : "", range.end.ToString(humanize, format)); + } return absl::StrFormat("[%s]", make_guts()); case InterpValueTag::kTuple: return absl::StrFormat("(%s)", make_guts()); @@ -499,7 +507,36 @@ bool InterpValue::Eq(const InterpValue& other) const { if (!other.IsArray()) { return false; } - return values_equal(); + bool self_range = is_range(); + bool other_range = other.is_range(); + + if (self_range && other_range) { + bool self_symbolic = + std::holds_alternative>(payload_); + bool other_symbolic = + std::holds_alternative>(other.payload_); + if (self_symbolic && other_symbolic) { + return *std::get>(payload_) == + *std::get>(other.payload_); + } + } + + if (!self_range && !other_range) { + return values_equal(); + } + + // Element-by-element comparison fallback. + int64_t self_len = GetLength().value(); + int64_t other_len = other.GetLength().value(); + if (self_len != other_len) { + return false; + } + for (int64_t i = 0; i < self_len; ++i) { + if (Index(i).value() != other.Index(i).value()) { + return false; + } + } + return true; } case InterpValueTag::kTuple: { if (!other.IsTuple()) { @@ -741,6 +778,20 @@ absl::StatusOr InterpValue::Slice( } absl::StatusOr InterpValue::Index(int64_t index) const { + if (is_range() && + std::holds_alternative>(payload_)) { + const RangeData& range = *std::get>(payload_); + XLS_ASSIGN_OR_RETURN(int64_t len, GetLength()); + if (index >= len) { + return absl::InvalidArgumentError(absl::StrFormat( + "Index out of bounds; index: %d >= %d elements; lhs: %s", index, len, + ToString())); + } + XLS_ASSIGN_OR_RETURN(Bits start_bits, range.start.GetBits()); + InterpValue index_value = InterpValue::MakeBits( + range.start.IsSBits(), UBits(index, start_bits.bit_count())); + return range.start.Add(index_value); + } XLS_ASSIGN_OR_RETURN(const std::vector* lhs, GetValues()); if (lhs->size() <= index) { return absl::InvalidArgumentError(absl::StrFormat( @@ -752,15 +803,9 @@ absl::StatusOr InterpValue::Index(int64_t index) const { absl::StatusOr InterpValue::Index(const InterpValue& other) const { XLS_RET_CHECK(other.IsUBits()); - XLS_ASSIGN_OR_RETURN(const std::vector* lhs, GetValues()); XLS_ASSIGN_OR_RETURN(Bits rhs, other.GetBits()); XLS_ASSIGN_OR_RETURN(uint64_t index, rhs.ToUint64()); - if (lhs->size() <= index) { - return absl::InvalidArgumentError(absl::StrFormat( - "Index out of bounds; index: %d >= %d elements; lhs: %s", index, - lhs->size(), ToString())); - } - return (*lhs)[index]; + return Index(static_cast(index)); } absl::StatusOr InterpValue::Update( @@ -779,6 +824,9 @@ absl::StatusOr InterpValue::Update( return absl::InvalidArgumentError(absl::StrFormat( "Update of non-array element: %s", element->ToString())); } + if (element->is_range()) { + element->GetValuesOrDie(); // Forces expansion of the symbolic range + } std::vector& values = std::get>(element->payload_); XLS_ASSIGN_OR_RETURN(Bits index_bits, i.GetBits()); @@ -1218,4 +1266,88 @@ bool InterpValue::operator>=(const InterpValue& rhs) const { return !(*this < rhs); } +absl::StatusOr> InterpValue::ExpandRange() const { + XLS_RET_CHECK(is_range()); + const RangeData& range = *std::get>(payload_); + std::vector elements; + InterpValue cur = range.start; + XLS_ASSIGN_OR_RETURN(int64_t len, GetLength()); + elements.reserve(len); + + XLS_ASSIGN_OR_RETURN(int64_t cur_bits, cur.GetBitCount()); + InterpValue one(cur.IsSigned() ? InterpValue::MakeSBits(cur_bits, 1) + : InterpValue::MakeUBits(cur_bits, 1)); + + for (int64_t i = 0; i < len; ++i) { + elements.push_back(cur); + XLS_ASSIGN_OR_RETURN(cur, cur.Add(one)); + } + return elements; +} + +/* static */ InterpValue InterpValue::MakeSymbolicRange(InterpValue start, + InterpValue end, + bool inclusive) { + return InterpValue(InterpValueTag::kArray, + std::make_shared(RangeData{ + std::move(start), std::move(end), inclusive}), + /*is_range=*/true); +} + +absl::StatusOr InterpValue::GetLength() const { + if (is_range() && + std::holds_alternative>(payload_)) { + const RangeData& range = *std::get>(payload_); + XLS_ASSIGN_OR_RETURN(Bits start_bits, range.start.GetBits()); + XLS_ASSIGN_OR_RETURN(Bits end_bits, range.end.GetBits()); + + int64_t len = 0; + if (range.start.IsSigned()) { + XLS_ASSIGN_OR_RETURN(int64_t start_val, start_bits.ToInt64()); + XLS_ASSIGN_OR_RETURN(int64_t end_val, end_bits.ToInt64()); + if (start_val > end_val) { + return 0; + } + len = end_val - start_val; + } else { + XLS_ASSIGN_OR_RETURN(uint64_t start_val, start_bits.ToUint64()); + XLS_ASSIGN_OR_RETURN(uint64_t end_val, end_bits.ToUint64()); + if (start_val > end_val) { + return 0; + } + len = end_val - start_val; + } + + if (range.inclusive) { + len += 1; + } + return len; + } + if (IsTuple() || IsArray()) { + return GetValuesOrDie().size(); + } + return absl::InvalidArgumentError("Invalid tag for length query: " + + TagToString(tag_)); +} + +absl::StatusOr*> InterpValue::GetValues() const { + if (is_range() && + std::holds_alternative>(payload_)) { + XLS_ASSIGN_OR_RETURN(std::vector elements, ExpandRange()); + payload_ = std::move(elements); + } + if (!std::holds_alternative>(payload_)) { + return absl::InvalidArgumentError("Value does not hold element values"); + } + return &std::get>(payload_); +} + +const std::vector& InterpValue::GetValuesOrDie() const { + if (is_range() && + std::holds_alternative>(payload_)) { + payload_ = ExpandRange().value(); + } + return std::get>(payload_); +} + } // namespace xls::dslx diff --git a/xls/dslx/interp_value.h b/xls/dslx/interp_value.h index 9450b4f3f2..71afedfc62 100644 --- a/xls/dslx/interp_value.h +++ b/xls/dslx/interp_value.h @@ -31,6 +31,7 @@ #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" +#include "xls/common/status/status_macros.h" #include "xls/dslx/channel_direction.h" #include "xls/dslx/dslx_builtins.h" #include "xls/dslx/frontend/ast.h" @@ -41,6 +42,8 @@ namespace xls::dslx { +struct RangeData; + // Tags a value to denote its payload. // // Note this goes beyond InterpValue::Payload annotating things like whether the @@ -279,6 +282,8 @@ class InterpValue { } static absl::StatusOr MakeRange( std::vector elements); + static InterpValue MakeSymbolicRange(InterpValue start, InterpValue end, + bool inclusive = false); static absl::StatusOr MakeBits(InterpValueTag tag, Bits bits) { if (tag != InterpValueTag::kUBits && tag != InterpValueTag::kSBits) { @@ -385,13 +390,7 @@ class InterpValue { bool FitsInNBitsUnsigned(int64_t n) const; bool FitsInNBitsSigned(int64_t n) const; - absl::StatusOr GetLength() const { - if (IsTuple() || IsArray()) { - return GetValuesOrDie().size(); - } - return absl::InvalidArgumentError("Invalid tag for length query: " + - TagToString(tag_)); - } + absl::StatusOr GetLength() const; absl::StatusOr ZeroExt(int64_t new_bit_count) const; absl::StatusOr SignExt(int64_t new_bit_count) const; @@ -476,15 +475,16 @@ class InterpValue { InterpValueTag tag() const { return tag_; } - absl::StatusOr*> GetValues() const { - if (!std::holds_alternative>(payload_)) { - return absl::InvalidArgumentError("Value does not hold element values"); + absl::StatusOr*> GetValues() const; + const std::vector& GetValuesOrDie() const; + std::optional> GetRangeData() const { + if (is_range() && + std::holds_alternative>(payload_)) { + return std::get>(payload_); } - return &std::get>(payload_); - } - const std::vector& GetValuesOrDie() const { - return std::get>(payload_); + return std::nullopt; } + absl::StatusOr GetFunction() const { if (!std::holds_alternative(payload_)) { return absl::InvalidArgumentError( @@ -598,6 +598,7 @@ class InterpValue { // // TODO(leary): 2020-02-10 When all Python bindings are eliminated we can more // easily make an interpreter scoped lifetime that InterpValues can live in. + struct TypeReference { const TypeAnnotation* annotation; std::string string; @@ -606,7 +607,8 @@ class InterpValue { using Payload = std::variant, FnData, std::shared_ptr, ChannelReference, - TypeReference, std::shared_ptr>; + TypeReference, std::shared_ptr, + std::shared_ptr>; InterpValue(InterpValueTag tag, Payload payload, bool is_range = false) : tag_(tag), payload_(std::move(payload)), is_range_(is_range) {} @@ -618,8 +620,10 @@ class InterpValue { const InterpValue& rhs, CompareF ucmp, CompareF scmp); + absl::StatusOr> ExpandRange() const; + InterpValueTag tag_; - Payload payload_; + mutable Payload payload_; bool is_range_; }; @@ -628,6 +632,17 @@ H AbslHashValue(H state, const InterpValue::UserFnData& v) { return H::combine(std::move(state), v.module, v.function); } +struct RangeData { + InterpValue start; + InterpValue end; + bool inclusive; + + bool operator==(const RangeData& other) const { + return start == other.start && end == other.end && + inclusive == other.inclusive; + } +}; + } // namespace xls::dslx #endif // XLS_DSLX_INTERP_VALUE_H_ diff --git a/xls/dslx/ir_convert/function_converter_fuzztest_test.cc b/xls/dslx/ir_convert/function_converter_fuzztest_test.cc index b68fb93e94..e656367312 100644 --- a/xls/dslx/ir_convert/function_converter_fuzztest_test.cc +++ b/xls/dslx/ir_convert/function_converter_fuzztest_test.cc @@ -947,7 +947,8 @@ fn f(x: u32) -> u32 { x } "test_module.x", "test_module", &import_data), absl_testing::StatusIs( absl::StatusCode::kInvalidArgument, - testing::HasSubstr("Unsupported fuzz test domain"))); + testing::HasSubstr("Expected range or set domain for scalar " + "parameter; got type ubits"))); } TEST(FunctionConverterFuzzTestTest, EmptyArrayDomain) { @@ -1397,5 +1398,65 @@ fn f(o: Outer) -> u32 { o.x } EXPECT_EQ(a_domain.range().max().bits().data(), std::string{'\004'}); } +TEST(FunctionConverterFuzzTestTest, DerivedStructDomainSuccess) { + ImportData import_data = CreateImportDataForTest(); + XLS_ASSERT_OK_AND_ASSIGN( + TypecheckedModule tm, + ParseAndTypecheck(R"( +#[fuzz_domain("MyStructDomain")] +struct MyStruct { + x: u32, +} + +fn create_f_domain() -> MyStructDomain { + MyStructDomain { + x: u32:0..10, + } +} + +#[fuzz_test(domains=`create_f_domain()`)] +fn f(s: MyStruct) {} +)", + "test_module.x", "test_module", &import_data)); + + XLS_ASSERT_OK_AND_ASSIGN(FuzzTestFunction * ft, + tm.module->GetMemberOrError("f")); + ASSERT_NE(ft, nullptr); + + Function* f = &ft->fn(); + + const ConvertOptions convert_options; + PackageConversionData package = MakeConversionData("test_module_package"); + PackageData package_data{&package}; + FunctionConverter converter(package_data, tm.module, &import_data, + convert_options, /*proc_data=*/nullptr, + /*channel_scope=*/nullptr, + /*is_top=*/true); + XLS_ASSERT_OK( + converter.HandleFunction(f, tm.type_info, /*parametric_env=*/nullptr)); + + auto* ir_fn = + package_data.conversion_info->package->functions().front().get(); + + absl::Span attributes = ir_fn->attributes(); + const AttributeData::Argument& arg = attributes[0].args()[0]; + const auto& skv = std::get(arg); + + xls::PackageInterfaceProto::Function function_proto; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(skv.second, &function_proto)); + ASSERT_EQ(function_proto.parameter_domains_size(), 1); + const auto& domain = function_proto.parameter_domains(0); + + ASSERT_TRUE(domain.has_tuple()); + ASSERT_EQ(domain.tuple().elements_size(), 1); + + const auto& x_domain = domain.tuple().elements(0); + ASSERT_TRUE(x_domain.has_range()); + EXPECT_EQ(x_domain.range().min().bits().data(), + std::string("\000\000\000\000", 4)); + EXPECT_EQ(x_domain.range().max().bits().data(), + std::string("\t\000\000\000", 4)); +} + } // namespace } // namespace xls::dslx diff --git a/xls/dslx/ir_convert/fuzz_test_converter.cc b/xls/dslx/ir_convert/fuzz_test_converter.cc index e956d5f2bf..a747557c60 100644 --- a/xls/dslx/ir_convert/fuzz_test_converter.cc +++ b/xls/dslx/ir_convert/fuzz_test_converter.cc @@ -84,13 +84,15 @@ absl::Status FuzzTestConverter::LowerTuple( return absl::OkStatus(); } auto* tuple_proto = proto.mutable_tuple(); - const TupleType* tuple_type = nullptr; - if (param_type != nullptr && param_type->IsTuple()) { - tuple_type = ¶m_type->AsTuple(); - } for (size_t i = 0; i < elements.size(); ++i) { - const Type* member_type = - (tuple_type != nullptr) ? tuple_type->members()[i].get() : nullptr; + const Type* member_type = nullptr; + if (param_type != nullptr) { + if (param_type->IsTuple()) { + member_type = ¶m_type->AsTuple().GetMemberType(i); + } else if (param_type->IsStruct()) { + member_type = ¶m_type->AsStruct().GetMemberType(i); + } + } const InterpValue& element = elements[i]; XLS_RETURN_IF_ERROR( LowerConstant(member_type, element, *tuple_proto->add_elements())); @@ -166,9 +168,23 @@ absl::Status FuzzTestConverter::LowerConstant( const Type* param_type, const InterpValue& val, PackageInterfaceProto::FuzzTestDomain& proto) { if (val.is_range()) { - // InterpValues that originated as ranges are stored as an array of - // elements, so we need to get the first and last entries in the array - // (the min and max of the range). + std::optional> range_data = val.GetRangeData(); + if (range_data.has_value()) { + const RangeData& range = **range_data; + XLS_ASSIGN_OR_RETURN(int64_t len, val.GetLength()); + if (len == 0) { + return absl::InvalidArgumentError( + "Empty ranges are unsupported as fuzztest domains"); + } + InterpValue max_val = range.end; + if (!range.inclusive) { + std::optional dec = range.end.Decrement(); + XLS_RET_CHECK(dec.has_value()); + max_val = *dec; + } + return LowerRange(range.start, max_val, proto); + } + // Fallback for expanded ranges. XLS_ASSIGN_OR_RETURN(const std::vector* elements, val.GetValues()); if (elements->empty()) { @@ -293,7 +309,6 @@ FuzzTestConverter::LowerFuzzTestDomains(const Function* node) { google::protobuf::TextFormat::Printer printer; printer.SetSingleLineMode(true); XLS_RET_CHECK(printer.PrintToString(temp_func, &proto_str)); - std::vector args; args.push_back( AttributeData::StringKeyValueArgument{.first = "domains", diff --git a/xls/dslx/type_system_v2/BUILD b/xls/dslx/type_system_v2/BUILD index e012b10faa..672c8aa8d7 100644 --- a/xls/dslx/type_system_v2/BUILD +++ b/xls/dslx/type_system_v2/BUILD @@ -547,6 +547,7 @@ cc_library( "//xls/ir:format_strings", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", + "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", diff --git a/xls/dslx/type_system_v2/populate_table_visitor.cc b/xls/dslx/type_system_v2/populate_table_visitor.cc index a1d659743f..0ad95d60df 100644 --- a/xls/dslx/type_system_v2/populate_table_visitor.cc +++ b/xls/dslx/type_system_v2/populate_table_visitor.cc @@ -27,6 +27,7 @@ #include "absl/algorithm/container.h" #include "absl/base/casts.h" +#include "absl/cleanup/cleanup.h" #include "absl/container/btree_set.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" @@ -92,7 +93,17 @@ class PopulateInferenceTableVisitor : public PopulateTableVisitor, typecheck_imported_module_(std::move(typecheck_imported_module)) {} absl::Status PopulateFromModule(const Module* module) override { - return module->Accept(this); + XLS_RETURN_IF_ERROR(module->Accept(this)); + for (const ModuleMember& member : module->top()) { + if (std::holds_alternative(member)) { + StructDef* struct_def = std::get(member); + if (struct_def->is_domain_struct() && !struct_def->is_populated()) { + XLS_RETURN_IF_ERROR(MaybePopulateDomainStruct(struct_def)); + XLS_RETURN_IF_ERROR(struct_def->Accept(this)); + } + } + } + return absl::OkStatus(); } absl::Status PopulateFromInvocation(const Invocation* invocation) override { @@ -1424,6 +1435,8 @@ class PopulateInferenceTableVisitor : public PopulateTableVisitor, // annotation for each member and cache it for use by type unification of // instances of this struct. absl::Status HandleStructDef(const StructDef* node) override { + XLS_RETURN_IF_ERROR( + MaybePopulateDomainStruct(const_cast(node))); return HandleStructDefBaseInternal(node); } @@ -1974,6 +1987,87 @@ class PopulateInferenceTableVisitor : public PopulateTableVisitor, } private: + absl::Status MaybePopulateDomainStruct(StructDef* struct_def) { + if (struct_def == nullptr || !struct_def->is_domain_struct() || + struct_def->is_populated()) { + return absl::OkStatus(); + } + + StructDef* original = struct_def->original_struct(); + XLS_RET_CHECK(original != nullptr); + + for (const StructMemberNode* member : original->members()) { + TypeAnnotation* member_type = member->type(); + XLS_ASSIGN_OR_RETURN(std::optional struct_ref, + GetStructOrProcRef(member_type, import_data_)); + + TypeAnnotation* domain_member_type = nullptr; + if (struct_ref.has_value()) { + const StructDef* nested_struct = + dynamic_cast(struct_ref->def); + if (nested_struct != nullptr) { + std::optional fuzz_domain_attr; + for (Attribute* attr : nested_struct->attributes()) { + if (attr->attribute_kind() == AttributeKind::kFuzzDomain) { + fuzz_domain_attr = attr; + break; + } + } + if (fuzz_domain_attr.has_value()) { + std::string domain_name = + std::get( + (*fuzz_domain_attr)->args()[0]) + .text; + Module* nested_module = nested_struct->owner(); + XLS_ASSIGN_OR_RETURN( + StructDef * nested_domain_struct, + nested_module->GetMemberOrError(domain_name)); + + XLS_RETURN_IF_ERROR( + MaybePopulateDomainStruct(nested_domain_struct)); + + if (auto* type_ref_type = + dynamic_cast(member_type)) { + TypeRef* type_ref = type_ref_type->type_ref(); + TypeDefinition def = type_ref->type_definition(); + if (std::holds_alternative(def)) { + const ColonRef* colon_ref = std::get(def); + ColonRef* domain_colon_ref = module_.Make( + colon_ref->span(), colon_ref->subject(), domain_name); + domain_member_type = module_.Make( + member->span(), + module_.Make(member->span(), + TypeDefinition(domain_colon_ref)), + std::vector(), std::nullopt); + } else if (std::holds_alternative(def)) { + domain_member_type = module_.Make( + member->span(), + module_.Make(member->span(), + TypeDefinition(nested_domain_struct)), + std::vector(), std::nullopt); + } + } + } + } + } + + if (domain_member_type == nullptr) { + domain_member_type = module_.Make( + member->span(), std::vector()); + } + + NameDef* member_name_def = module_.Make( + member->span(), member->name(), /*definer=*/nullptr); + StructMemberNode* domain_member = module_.Make( + member->span(), member_name_def, member->colon_span(), + domain_member_type); + member_name_def->set_definer(domain_member); + + struct_def->AddMember(domain_member); + } + struct_def->set_is_populated(true); + return absl::OkStatus(); + } // Determines the target of the given `ColonRef` that is already known to be // referencing a member with the name `attribute` of the given `struct_def`. // Associates the target node with the `ColonRef` in the `InferenceTable` for @@ -2272,6 +2366,19 @@ class PopulateInferenceTableVisitor : public PopulateTableVisitor, } const StructDefBase* struct_def = struct_or_proc_ref->def; + + bool old_in_fuzz_test_domain = in_fuzz_test_domain_; + const StructDef* concrete_struct_def = + dynamic_cast(struct_def); + if (concrete_struct_def != nullptr && + concrete_struct_def->is_domain_struct()) { + in_fuzz_test_domain_ = true; + } + absl::Cleanup restore_in_fuzz_test_domain = [this, + old_in_fuzz_test_domain] { + in_fuzz_test_domain_ = old_in_fuzz_test_domain; + }; + const NameRef* type_variable = *table_.GetTypeVariable(node); if (source.has_value()) { XLS_RETURN_IF_ERROR(table_.SetTypeVariable(*source, type_variable)); diff --git a/xls/dslx/type_system_v2/typecheck_module_v2_array_tuple_test.cc b/xls/dslx/type_system_v2/typecheck_module_v2_array_tuple_test.cc index ca952074dc..9efd1ab057 100644 --- a/xls/dslx/type_system_v2/typecheck_module_v2_array_tuple_test.cc +++ b/xls/dslx/type_system_v2/typecheck_module_v2_array_tuple_test.cc @@ -1579,8 +1579,8 @@ TEST(TypecheckV2Test, FuzzTestDomainArrayMismatch) { #[fuzz_test(domains=`[u32:0, 16384]`)] fn f(x: u8) {} )", - TypecheckFails(HasSubstr("Fuzz test domain `[u32:0, 16384]` is not " - "compatible with parameter `x: u8`"))); + TypecheckFails(HasSubstr("Fuzz test domain bit count (32) does not " + "match parameter bit count (8)"))); } TEST(TypecheckV2Test, FuzzTestDomainsEmptyTupleAlwaysMatches) { diff --git a/xls/dslx/type_system_v2/typecheck_module_v2_function_test.cc b/xls/dslx/type_system_v2/typecheck_module_v2_function_test.cc index 7e8717f42f..915a65f013 100644 --- a/xls/dslx/type_system_v2/typecheck_module_v2_function_test.cc +++ b/xls/dslx/type_system_v2/typecheck_module_v2_function_test.cc @@ -1191,7 +1191,7 @@ TEST(TypecheckV2Test, FuzzTestDomainNotSupported) { #[fuzz_test(domains=`u8:0`)] fn f(x: u8) {} )", - TypecheckFails(HasSubstr("Unsupported fuzz test domain `u8:0`"))); + TypecheckFails(HasSubstr("Expected range or set domain"))); } TEST(TypecheckV2Test, FuzzTestConstRange) { diff --git a/xls/dslx/type_system_v2/typecheck_module_v2_struct_test.cc b/xls/dslx/type_system_v2/typecheck_module_v2_struct_test.cc index 2d4b33b654..2e104f68d9 100644 --- a/xls/dslx/type_system_v2/typecheck_module_v2_struct_test.cc +++ b/xls/dslx/type_system_v2/typecheck_module_v2_struct_test.cc @@ -2437,5 +2437,50 @@ fn f(s: S) {} TypecheckSucceeds(::testing::_)); } +TEST(TypecheckV2StructTest, FuzzTestDerivedStructDomainSuccess) { + EXPECT_THAT(R"( +#[fuzz_domain("MyStructDomain")] +struct MyStruct { + x: u32, +} + +fn create_f_domain() -> MyStructDomain { + MyStructDomain { + x: u32:0..10, + } +} + +#[fuzz_test(domains=`create_f_domain()`)] +fn f(s: MyStruct) {} +)", + TypecheckSucceeds(::testing::_)); +} + +TEST(TypecheckV2StructTest, FuzzTestDerivedStructDomainNestedSuccess) { + EXPECT_THAT(R"( +#[fuzz_domain("InnerDomain")] +struct Inner { + y: u32, +} + +#[fuzz_domain("OuterDomain")] +struct Outer { + x: Inner, +} + +fn create_f_domain() -> OuterDomain { + OuterDomain { + x: InnerDomain { + y: u32:0..10, + }, + } +} + +#[fuzz_test(domains=`create_f_domain()`)] +fn f(s: Outer) {} +)", + TypecheckSucceeds(::testing::_)); +} + } // namespace } // namespace xls::dslx diff --git a/xls/dslx/type_system_v2/validate_concrete_type.cc b/xls/dslx/type_system_v2/validate_concrete_type.cc index 6f57726e89..43116d3724 100644 --- a/xls/dslx/type_system_v2/validate_concrete_type.cc +++ b/xls/dslx/type_system_v2/validate_concrete_type.cc @@ -454,126 +454,300 @@ class TypeValidator : public AstNodeVisitorWithDefault { } private: - // Validates that a fuzz test domain type is compatible with the corresponding - // parameter type, recursively for tuples. - absl::Status ValidateFuzzTestDomainType(const Type* domain_type, - const Type* param_type, - const Span& span, - std::string_view domain_str, - std::string_view param_str) { - if (domain_type->IsTuple()) { - const TupleType& domain_tuple = domain_type->AsTuple(); - if (domain_tuple.empty()) { - // Empty domain for this parameter; this is considered an "Arbitrary" - // domain and always matches. - return absl::OkStatus(); + // Evaluates the given expression to an InterpValue using ConstexprEvaluator. + absl::StatusOr Evaluate(const Expr* expr) { + const ParametricEnv env = table_.GetParametricEnv(parametric_context_); + return ConstexprEvaluator::EvaluateToValue( + const_cast(&import_data_), const_cast(&ti_), + &warning_collector_, env, const_cast(expr)); + } + + absl::StatusOr CompleteStructInstance(const StructInstance* node, + Module& module) { + XLS_ASSIGN_OR_RETURN(std::optional struct_ref, + GetStructOrProcRef(node->struct_ref(), import_data_)); + XLS_RET_CHECK(struct_ref.has_value()); + const StructDefBase* struct_def = struct_ref->def; + + std::vector> new_members; + absl::flat_hash_map existing_members; + for (const auto& [name, expr] : node->members()) { + existing_members[name] = expr; + } + + for (const StructMemberNode* formal_member : struct_def->members()) { + std::string name = formal_member->name(); + auto it = existing_members.find(name); + if (it != existing_members.end()) { + XLS_ASSIGN_OR_RETURN(const Expr* completed_member, + CompleteExpr(it->second, module)); + new_members.push_back({name, const_cast(completed_member)}); + } else { + XlsTuple* empty_tuple = + module.Make(node->span(), std::vector(), + /*has_multiline_elements=*/false); + new_members.push_back({name, empty_tuple}); + } + } + + return module.Make( + node->span(), const_cast(node->struct_ref()), + std::move(new_members)); + } + + absl::StatusOr CompleteExpr(const Expr* expr, Module& module) { + if (expr->kind() == AstNodeKind::kStructInstance) { + return CompleteStructInstance( + absl::down_cast(expr), module); + } + if (expr->kind() == AstNodeKind::kXlsTuple) { + const XlsTuple* tuple = absl::down_cast(expr); + std::vector new_members; + for (const Expr* member : tuple->members()) { + XLS_ASSIGN_OR_RETURN(const Expr* completed_member, + CompleteExpr(member, module)); + new_members.push_back(const_cast(completed_member)); } - if (!param_type->IsTuple()) { + return module.Make(tuple->span(), std::move(new_members), + tuple->has_trailing_comma()); + } + return expr; + } + + absl::Status RegisterTypes(const Expr* original, const Expr* completed, + TypeInfo& ti) { + std::optional type = ti.GetItem(original); + XLS_RET_CHECK(type.has_value()); + ti.SetItem(completed, **type); + + if (original->kind() == AstNodeKind::kStructInstance) { + const StructInstance* orig_struct = + absl::down_cast(original); + const StructInstance* comp_struct = + absl::down_cast(completed); + + std::optional ref_type = + ti.GetItem(orig_struct->struct_ref()); + if (ref_type.has_value()) { + ti.SetItem(comp_struct->struct_ref(), **ref_type); + } + + absl::flat_hash_map orig_members; + for (const auto& [name, expr] : orig_struct->members()) { + orig_members[name] = expr; + } + + for (const auto& [name, comp_expr] : comp_struct->members()) { + auto it = orig_members.find(name); + if (it != orig_members.end()) { + XLS_RETURN_IF_ERROR(RegisterTypes(it->second, comp_expr, ti)); + } else { + ti.SetItem(comp_expr, Type::MakeUnit()); + } + } + } else if (original->kind() == AstNodeKind::kXlsTuple) { + const XlsTuple* orig_tuple = absl::down_cast(original); + const XlsTuple* comp_tuple = absl::down_cast(completed); + XLS_RET_CHECK_EQ(orig_tuple->members().size(), + comp_tuple->members().size()); + for (int i = 0; i < orig_tuple->members().size(); ++i) { + XLS_RETURN_IF_ERROR(RegisterTypes(orig_tuple->members()[i], + comp_tuple->members()[i], ti)); + } + } + return absl::OkStatus(); + } + + // Validates that a single scalar domain value is compatible with the + // corresponding parameter type. + absl::Status ValidateScalarDomainValue(const InterpValue& value, + const Type* param_type, + const Span& span, + std::string_view param_str) { + if (auto* bits_type = dynamic_cast(param_type)) { + if (!value.IsBits()) { return TypeInferenceErrorStatus( span, param_type, - "Fuzz test domain implies a tuple type, but parameter is not a " - "tuple.", + absl::StrFormat("Expected bits domain value; got %s", + value.ToString()), file_table_); } - const TupleType& param_tuple = param_type->AsTuple(); - if (domain_tuple.size() != param_tuple.size()) { + XLS_ASSIGN_OR_RETURN(int64_t val_bits, value.GetBitCount()); + if (val_bits != bits_type->size().GetAsInt64().value()) { return TypeInferenceErrorStatus( span, param_type, - absl::Substitute("Fuzz test domain tuple size ($0) does not match " - "parameter tuple size ($1).", - domain_tuple.size(), param_tuple.size()), + absl::Substitute("Fuzz test domain bit count ($0) does not match " + "parameter bit count ($1).", + val_bits, bits_type->size().GetAsInt64().value()), file_table_); } - for (int i = 0; i < domain_tuple.size(); ++i) { - const Type& domain_member = domain_tuple.GetMemberType(i); - const Type& param_member = param_tuple.GetMemberType(i); - XLS_RETURN_IF_ERROR(ValidateFuzzTestDomainType( - &domain_member, ¶m_member, span, domain_member.ToString(), - param_member.ToString())); + if (value.IsSigned() != bits_type->is_signed()) { + return TypeInferenceErrorStatus( + span, param_type, + absl::Substitute("Fuzz test domain signedness ($0) does not match " + "parameter signedness ($1).", + value.IsSigned() ? "signed" : "unsigned", + bits_type->is_signed() ? "signed" : "unsigned"), + file_table_); } return absl::OkStatus(); } - - if (domain_type->IsArray()) { - const ArrayType& array_type = domain_type->AsArray(); - const Type& element_type = array_type.element_type(); - if (!param_type->CompatibleWith(element_type)) { + if (param_type->IsEnum()) { + if (!value.IsEnum()) { return TypeInferenceErrorStatus( - span, param_type, - absl::Substitute("Fuzz test domain `$0` is not compatible with " - "parameter `$1`.", - domain_str, param_str), - file_table_); + span, param_type, "Expected enum domain value", file_table_); + } + const EnumType* enum_type = absl::down_cast(param_type); + std::optional enum_data = value.GetEnumData(); + XLS_RET_CHECK(enum_data.has_value()); + if (enum_data->def != &enum_type->nominal_type()) { + return TypeInferenceErrorStatus(span, param_type, "Enum type mismatch", + file_table_); } return absl::OkStatus(); } - return TypeInferenceErrorStatus( span, param_type, - absl::Substitute("Unsupported fuzz test domain `$0` of type `$1`.", - domain_str, domain_type->ToString()), + absl::Substitute("Unsupported parameter type for scalar domain: $0", + param_type->ToString()), file_table_); } - absl::Status ValidateStructDomain(const StructInstance* domain, - const Type* param_type, - std::string_view param_str) { - if (!param_type->IsStruct()) { - return TypeInferenceErrorStatus( - domain->span(), param_type, - absl::Substitute("Fuzz test domain implies a struct type, but " - "parameter is of type `$0`.", - param_type->ToString()), - file_table_); + // Validates that a fuzz test domain value is compatible with the + // corresponding parameter type, recursively. + absl::Status ValidateFuzzTestDomainValue(const InterpValue& value, + const Type* param_type, + const Span& span, + std::string_view param_str) { + if (value.IsTuple()) { + XLS_ASSIGN_OR_RETURN(const std::vector* elements, + value.GetValues()); + if (elements->empty()) { + return absl::OkStatus(); + } + if (!param_type->IsTuple() && !param_type->IsStruct()) { + return TypeInferenceErrorStatus( + span, param_type, + "Fuzz test domain implies a tuple type, but parameter is not a " + "tuple.", + file_table_); + } } - const StructType* struct_type = - absl::down_cast(param_type); - absl::flat_hash_map formal_members; - for (int i = 0; i < struct_type->size(); ++i) { - formal_members[struct_type->GetMemberName(i)] = - &struct_type->GetMemberType(i); + if (param_type->IsTuple()) { + if (!value.IsTuple()) { + return TypeInferenceErrorStatus( + span, param_type, + "Fuzz test domain is not a tuple, but parameter is a tuple.", + file_table_); + } + XLS_ASSIGN_OR_RETURN(const std::vector* elements, + value.GetValues()); + const TupleType* tuple_type = + absl::down_cast(param_type); + if (elements->size() != tuple_type->size()) { + return TypeInferenceErrorStatus( + span, param_type, + absl::Substitute("Fuzz test domain tuple size ($0) does not match " + "parameter tuple size ($1).", + elements->size(), tuple_type->size()), + file_table_); + } + for (int i = 0; i < elements->size(); ++i) { + const Type& member_type = tuple_type->GetMemberType(i); + XLS_RETURN_IF_ERROR(ValidateFuzzTestDomainValue( + elements->at(i), &member_type, span, member_type.ToString())); + } + return absl::OkStatus(); } - for (const auto& [name, actual_member] : domain->members()) { - auto it = formal_members.find(name); - // Extraneous fields are already caught during type checking in - // ValidateStructInstanceMemberNames, so we can assume they exist here. - XLS_RET_CHECK(it != formal_members.end()) << "Extraneous member " << name; - const Type* formal_member_type = it->second; - XLS_RETURN_IF_ERROR(ValidateFuzzTestDomain( - actual_member, formal_member_type, formal_member_type->ToString())); + if (param_type->IsStruct()) { + if (!value.IsTuple()) { + return TypeInferenceErrorStatus( + span, param_type, + "Fuzz test domain is not a struct, but parameter is a struct.", + file_table_); + } + XLS_ASSIGN_OR_RETURN(const std::vector* elements, + value.GetValues()); + const StructType* struct_type = + absl::down_cast(param_type); + if (elements->size() != struct_type->size()) { + return TypeInferenceErrorStatus( + span, param_type, + absl::Substitute("Fuzz test domain struct size ($0) does not match " + "parameter struct size ($1).", + elements->size(), struct_type->size()), + file_table_); + } + for (int i = 0; i < elements->size(); ++i) { + const Type& member_type = struct_type->GetMemberType(i); + XLS_RETURN_IF_ERROR(ValidateFuzzTestDomainValue( + elements->at(i), &member_type, span, member_type.ToString())); + } + return absl::OkStatus(); } - return absl::OkStatus(); - } - absl::Status ValidateTupleDomain(const XlsTuple* domain, - const Type* param_type, - std::string_view param_str) { - if (domain->members().empty()) { - // Empty domain for this parameter; this is considered an "Arbitrary" - // domain and always matches. + if (param_type->IsArray()) { + const ArrayType* array_type = + absl::down_cast(param_type); + if (!value.IsArray() || value.is_range()) { + return TypeInferenceErrorStatus( + span, param_type, "Expected array of domains for array parameter", + file_table_); + } + XLS_ASSIGN_OR_RETURN(const std::vector* elements, + value.GetValues()); + if (elements->size() != array_type->size().GetAsInt64().value()) { + return TypeInferenceErrorStatus( + span, param_type, + absl::Substitute("Fuzz test domain array size ($0) does not match " + "parameter array size ($1).", + elements->size(), + array_type->size().GetAsInt64().value()), + file_table_); + } + const Type& element_type = array_type->element_type(); + for (int i = 0; i < elements->size(); ++i) { + XLS_RETURN_IF_ERROR(ValidateFuzzTestDomainValue( + elements->at(i), &element_type, span, element_type.ToString())); + } return absl::OkStatus(); } - if (!param_type->IsTuple()) { - return TypeInferenceErrorStatus(domain->span(), param_type, - "Fuzz test domain implies a tuple type, " - "but parameter is not a tuple.", - file_table_); - } - const TupleType* tuple_type = absl::down_cast(param_type); - if (domain->members().size() != tuple_type->size()) { + + if (!value.IsArray()) { return TypeInferenceErrorStatus( - domain->span(), param_type, - absl::Substitute("Fuzz test domain tuple size ($0) does not match " - "parameter tuple size ($1).", - domain->members().size(), tuple_type->size()), + span, param_type, + absl::Substitute("Expected range or set domain for scalar " + "parameter; got type $0", + TagToString(value.tag())), file_table_); } - for (int i = 0; i < domain->members().size(); ++i) { - const Type& member_type = tuple_type->GetMemberType(i); - XLS_RETURN_IF_ERROR(ValidateFuzzTestDomain( - domain->members()[i], &member_type, member_type.ToString())); + + if (value.is_range()) { + std::optional> range_data = + value.GetRangeData(); + if (range_data.has_value()) { + XLS_RETURN_IF_ERROR(ValidateScalarDomainValue( + (*range_data)->start, param_type, span, param_str)); + XLS_RETURN_IF_ERROR(ValidateScalarDomainValue( + (*range_data)->end, param_type, span, param_str)); + return absl::OkStatus(); + } + XLS_ASSIGN_OR_RETURN(const std::vector* elements, + value.GetValues()); + if (elements->empty()) { + return absl::OkStatus(); + } + return ValidateScalarDomainValue(elements->at(0), param_type, span, + param_str); + } + + XLS_ASSIGN_OR_RETURN(const std::vector* elements, + value.GetValues()); + for (const InterpValue& element : *elements) { + XLS_RETURN_IF_ERROR( + ValidateScalarDomainValue(element, param_type, span, param_str)); } return absl::OkStatus(); } @@ -583,22 +757,36 @@ class TypeValidator : public AstNodeVisitorWithDefault { absl::Status ValidateFuzzTestDomain(const Expr* domain, const Type* param_type, std::string_view param_str) { - if (domain->kind() == AstNodeKind::kStructInstance) { - return ValidateStructDomain( - absl::down_cast(domain), param_type, - param_str); + Module* module = domain->owner(); + XLS_RET_CHECK(module != nullptr); + XLS_ASSIGN_OR_RETURN(const Expr* completed_domain, + CompleteExpr(domain, *module)); + if (completed_domain != domain) { + XLS_RETURN_IF_ERROR( + RegisterTypes(domain, completed_domain, const_cast(ti_))); } - if (domain->kind() == AstNodeKind::kXlsTuple) { - return ValidateTupleDomain(absl::down_cast(domain), - param_type, param_str); + absl::StatusOr value_or = Evaluate(completed_domain); + if (!value_or.ok()) { + return TypeInferenceErrorStatus( + domain->span(), param_type, + absl::StrFormat("Fuzz test domain must be a constexpr expression; " + "evaluation failed: %s", + value_or.status().message()), + file_table_); } + const InterpValue& value = *value_or; - std::optional maybe_domain_type = ti_.GetItem(domain); - XLS_RET_CHECK(maybe_domain_type.has_value()); - const Type* domain_type = *maybe_domain_type; - - return ValidateFuzzTestDomainType(domain_type, param_type, domain->span(), - domain->ToString(), param_str); + absl::Status status = ValidateFuzzTestDomainValue( + value, param_type, domain->span(), param_str); + if (!status.ok()) { + return TypeInferenceErrorStatus( + domain->span(), param_type, + absl::Substitute("Fuzz test domain `$0` is not " + "compatible with parameter `$1`: $2", + domain->ToString(), param_str, status.message()), + file_table_); + } + return absl::OkStatus(); } absl::Status HandleStructDefBaseInternal(const StructDefBase* def) { diff --git a/xls/jit/jit_wrapper_generator.py b/xls/jit/jit_wrapper_generator.py index 65a8e236c4..5bfaf482a0 100644 --- a/xls/jit/jit_wrapper_generator.py +++ b/xls/jit/jit_wrapper_generator.py @@ -643,7 +643,8 @@ def wrapped_to_fuzztest( conversion_snippet = to_value_conversion(p.type_proto, p.name) if ( - p.type_proto.type_enum == type_pb2.TypeProto.TUPLE + len(wrapped.params) == 1 + and p.type_proto.type_enum == type_pb2.TypeProto.TUPLE and domain_snippet is not None ): domain_snippet = f"fuzztest::TupleOf({domain_snippet})" diff --git a/xls/tests/fuzz_test/BUILD b/xls/tests/fuzz_test/BUILD index 860ab80c62..64202d48ea 100644 --- a/xls/tests/fuzz_test/BUILD +++ b/xls/tests/fuzz_test/BUILD @@ -66,6 +66,12 @@ dslx_fuzz_test( test_function = "test_tuple", ) +dslx_fuzz_test( + name = "test_1_tuple_fuzz_test", + library = ":fuzz_test_dslx", + test_function = "test_1_tuple", +) + xls_dslx_library( name = "enum_tests_dslx", srcs = ["enum_tests.x"], @@ -117,3 +123,68 @@ dslx_fuzz_test( library = ":array_tests_dslx", test_function = "tuple_with_big_array", ) + +xls_dslx_library( + name = "imported_module_dslx", + srcs = ["imported_module.x"], +) + +xls_dslx_library( + name = "struct_tests_dslx", + srcs = ["struct_tests.x"], + deps = [":imported_module_dslx"], +) + +dslx_fuzz_test( + name = "arbitrary_struct_fuzz_test", + library = ":struct_tests_dslx", + test_function = "arbitrary_struct", +) + +dslx_fuzz_test( + name = "struct_range_fuzz_test", + library = ":struct_tests_dslx", + test_function = "struct_range", +) + +dslx_fuzz_test( + name = "struct_element_of_fuzz_test", + library = ":struct_tests_dslx", + test_function = "struct_element_of", +) + +dslx_fuzz_test( + name = "struct_arbitrary_field_fuzz_test", + library = ":struct_tests_dslx", + test_function = "struct_arbitrary_field", +) + +dslx_fuzz_test( + name = "flat_struct_domain_fuzz_test", + library = ":struct_tests_dslx", + test_function = "test_flat_struct_domain", +) + +dslx_fuzz_test( + name = "nested_struct_domain_fuzz_test", + library = ":struct_tests_dslx", + test_function = "test_nested_struct_domain", +) + +dslx_fuzz_test( + name = "imported_struct_domain_fn_fuzz_test", + library = ":struct_tests_dslx", + test_function = "test_imported_struct_domain", +) + +dslx_fuzz_test( + name = "nested_big_array_fuzz_test", + library = ":array_tests_dslx", + test_function = "nested_big_array", +) + +dslx_fuzz_test( + name = "inline_nested_struct_domain_fuzz_test", + library = ":struct_tests_dslx", + test_function = "test_inline_nested_struct_domain", +) diff --git a/xls/tests/fuzz_test/array_tests.x b/xls/tests/fuzz_test/array_tests.x index 09ea3b2104..7005dcf6d2 100644 --- a/xls/tests/fuzz_test/array_tests.x +++ b/xls/tests/fuzz_test/array_tests.x @@ -36,3 +36,8 @@ fn big_array(x: uN[128][3]) -> bool { fn tuple_with_big_array(x: (uN[128][2], u32)) -> bool { true } + +#[fuzz_test(domains=`()`)] +fn nested_big_array(x: uN[128][2][3]) -> bool { + true +} diff --git a/xls/tests/fuzz_test/fuzz_tests.x b/xls/tests/fuzz_test/fuzz_tests.x index de7bb50d4e..4b77560726 100644 --- a/xls/tests/fuzz_test/fuzz_tests.x +++ b/xls/tests/fuzz_test/fuzz_tests.x @@ -34,3 +34,8 @@ fn test_element_of(x: u32) -> bool { fn test_tuple(t: (u32, u8)) -> bool { t.0 <= u32:10 && (t.1 == u8:1 || t.1 == u8:2) } + +#[fuzz_test(domains=`(u32:0..10,)`)] +fn test_1_tuple(t: (u32,)) -> bool { + t.0 <= u32:10 +} diff --git a/xls/tests/fuzz_test/imported_module.x b/xls/tests/fuzz_test/imported_module.x new file mode 100644 index 0000000000..ba804eee73 --- /dev/null +++ b/xls/tests/fuzz_test/imported_module.x @@ -0,0 +1,18 @@ +// Copyright 2026 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#[fuzz_domain("SDomain")] +pub struct S { + x: u32, +} diff --git a/xls/tests/fuzz_test/struct_tests.x b/xls/tests/fuzz_test/struct_tests.x new file mode 100644 index 0000000000..aaef303877 --- /dev/null +++ b/xls/tests/fuzz_test/struct_tests.x @@ -0,0 +1,104 @@ +// Copyright 2026 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +import xls.tests.fuzz_test.imported_module; + +struct Point { + x: u32, + y: u8, +} + +#[fuzz_test] +fn arbitrary_struct(p: Point) -> bool { + p.x == p.x && p.y == p.y +} + +#[fuzz_test(domains=`Point { x: u32:0..10, y: u8:11..20 }`)] +fn struct_range(p: Point) -> bool { + (p.x as u8) < p.y +} + +#[fuzz_test(domains=`Point { x: [u32:0, 1, 2, 3], y: u8:11..20 }`)] +fn struct_element_of(p: Point) -> bool { + (p.x as u8) < p.y +} + +#[fuzz_test(domains=`Point { x: [u32:0, 1, 2, 3] }`)] +fn struct_arbitrary_field(p: Point) -> bool { + p.x <= u32:3 +} + +#[fuzz_domain("MyStructDomain")] +struct MyStruct { + x: u32, + y: bool, +} + +fn create_flat_domain() -> MyStructDomain { + MyStructDomain { + x: u32:0..10, + y: (), + } +} + +#[fuzz_test(domains=`create_flat_domain()`)] +fn test_flat_struct_domain(s: MyStruct) -> bool { + s.x >= u32:0 && s.x < u32:10 +} + +#[fuzz_domain("InnerDomain")] +struct Inner { + y: u32, +} + +#[fuzz_domain("OuterDomain")] +struct Outer { + x: Inner, +} + +fn create_nested_domain() -> OuterDomain { + OuterDomain { + x: InnerDomain { + y: u32:0..10, + }, + } +} + +#[fuzz_test(domains=`create_nested_domain()`)] +fn test_nested_struct_domain(o: Outer) -> bool { + o.x.y >= u32:0 && o.x.y < u32:10 +} + +fn get_domain() -> imported_module::SDomain { + imported_module::SDomain { + x: u32:0..10, + } +} + +#[fuzz_test(domains=`get_domain()`)] +fn test_imported_struct_domain(s: imported_module::S) -> bool { + s.x >= u32:0 && s.x < u32:10 +} + +struct InnerWithoutDomain { + y: u32, +} + +struct OuterWithInline { + x: InnerWithoutDomain, +} + +#[fuzz_test(domains=`OuterWithInline { x: InnerWithoutDomain { y: u32:0..99999 } }`)] +fn test_inline_nested_struct_domain(o: OuterWithInline) -> bool { + o.x.y < u32:99999 +}