Skip to content

Commit 308ed00

Browse files
committed
Propagate reduction method to subscribers for all reduction methods + test
1 parent 3cce4e5 commit 308ed00

File tree

6 files changed

+41
-5
lines changed

6 files changed

+41
-5
lines changed

include/dtlmod/CompressionReductionMethod.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ class CompressionReductionMethod : public ReductionMethod {
8282

8383
/// @brief Copy a publisher variable's parameterization to a subscriber variable.
8484
/// Called by Stream::inquire_variable so that Engine::get() can compute decompression costs.
85-
void propagate_for_subscriber(const Variable& publisher_var, const Variable& subscriber_var);
85+
void propagate_for_subscriber(const Variable& publisher_var, const Variable& subscriber_var) override;
8686
};
8787
/// \endcond
8888
} // namespace dtlmod

include/dtlmod/DecimationReductionMethod.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,10 @@ class DecimationReductionMethod : public ReductionMethod {
9898

9999
public:
100100
using ReductionMethod::ReductionMethod;
101+
102+
/// @brief Copy a publisher variable's parameterization to a subscriber variable.
103+
/// Called by Stream::inquire_variable so that subscribers can query reduced sizes.
104+
void propagate_for_subscriber(const Variable& publisher_var, const Variable& subscriber_var) override;
101105
};
102106
///\endcond
103107
} // namespace dtlmod

include/dtlmod/ReductionMethod.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class ReductionMethod {
3838
get_reduced_start_and_count_for(const Variable& var, simgrid::s4u::ActorPtr publisher) const = 0;
3939
virtual double get_flop_amount_to_reduce_variable(const Variable& var) const = 0;
4040
virtual double get_flop_amount_to_decompress_variable(const Variable& /*var*/) const { return 0.0; }
41+
virtual void propagate_for_subscriber(const Variable& /*publisher_var*/, const Variable& /*subscriber_var*/) {}
4142

4243
/// @brief Helper function to print out the name of the ReductionMethod.
4344
/// @return The corresponding string

src/DecimationReductionMethod.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,26 @@ void DecimationReductionMethod::parameterize_for_variable(
123123
}
124124
}
125125

126+
void DecimationReductionMethod::propagate_for_subscriber(const Variable& publisher_var, const Variable& subscriber_var)
127+
{
128+
auto it = per_variable_parameterizations_.find(&publisher_var);
129+
if (it == per_variable_parameterizations_.end())
130+
return;
131+
132+
auto pub_param = it->second;
133+
auto sub_param = std::make_shared<ParameterizedDecimation>(subscriber_var, pub_param->get_stride(),
134+
pub_param->get_interpolation_method(),
135+
pub_param->get_cost_per_element());
136+
sub_param->set_reduced_shape(pub_param->get_reduced_shape());
137+
138+
// The subscriber receives the full reduced variable, so its local region is the entire reduced shape.
139+
const auto& reduced_shape = pub_param->get_reduced_shape();
140+
std::vector<size_t> reduced_start(reduced_shape.size(), 0);
141+
sub_param->set_reduced_local_start_and_count(sg4::Actor::self(), reduced_start, reduced_shape);
142+
143+
per_variable_parameterizations_[&subscriber_var] = std::move(sub_param);
144+
}
145+
126146
void DecimationReductionMethod::reduce_variable(const Variable& var)
127147
{
128148
auto parameterization = per_variable_parameterizations_[&var];

src/Stream.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -410,10 +410,9 @@ std::shared_ptr<Variable> Stream::inquire_variable(std::string_view name) const
410410
if (var->second->is_reduced()) {
411411
new_var->is_reduced_with_ = var->second->get_reduction_method();
412412
new_var->reduction_origin_ = var->second->reduction_origin_;
413-
// Register the subscriber's variable in the compression method's map so that
414-
// Engine::get() can compute decompression costs via get_flop_amount_to_decompress_variable.
415-
if (auto compressor = std::dynamic_pointer_cast<CompressionReductionMethod>(new_var->is_reduced_with_))
416-
compressor->propagate_for_subscriber(*var->second, *new_var);
413+
// Register the subscriber's variable in the reduction method's map so that
414+
// Engine::get() can compute costs, and subscribers can query reduced sizes.
415+
new_var->is_reduced_with_->propagate_for_subscriber(*var->second, *new_var);
417416
}
418417

419418
return new_var;

test/dtl_reduction.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,9 +501,21 @@ TEST_F(DTLReductionTest, DecimationStagingEngine)
501501
auto dtl = dtlmod::DTL::connect();
502502
auto stream = dtl->add_stream("my-output");
503503
auto engine = stream->open("my-output", dtlmod::Stream::Mode::Subscribe);
504+
XBT_INFO("Wait for the publisher to have set the decimation reduction operation");
505+
sg4::this_actor::sleep_for(0.5);
504506
auto var = stream->inquire_variable("var");
505507

508+
ASSERT_TRUE(var->is_reduced());
509+
ASSERT_TRUE(var->is_reduced_by_publisher());
510+
511+
XBT_INFO("Verify that the subscriber can access the reduction method set by the publisher");
512+
auto reduction = var->get_reduction_method();
513+
ASSERT_TRUE(reduction != nullptr);
514+
XBT_INFO("Verify that the subscriber can get the reduced local size");
515+
auto reduced_size = reduction->get_reduced_variable_local_size(*var);
516+
ASSERT_DOUBLE_EQ(reduced_size, 5000 * 5000 * 8.0);
506517
XBT_INFO("Get the decimated variable");
518+
507519
engine->begin_transaction();
508520
ASSERT_NO_THROW(engine->get(var));
509521
engine->end_transaction();

0 commit comments

Comments
 (0)