Skip to content

Commit 34a1e7a

Browse files
committed
reduce method complexity
1 parent a570653 commit 34a1e7a

4 files changed

Lines changed: 150 additions & 159 deletions

File tree

include/dtlmod/CompressionReductionMethod.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,13 @@ class CompressionReductionMethod : public ReductionMethod {
6363
/// @brief Derive the compression ratio from accuracy and compressor profile.
6464
static double derive_compression_ratio(double accuracy, std::string_view profile, double data_smoothness);
6565

66+
/// @brief Validate compressor profile string. Throws if invalid.
67+
static void validate_compressor_profile(std::string_view profile);
68+
69+
/// @brief Validate and resolve the compression ratio from parsed parameters.
70+
static double resolve_compression_ratio(double ratio, bool ratio_explicitly_set, bool is_new,
71+
std::string_view profile, double accuracy, double data_smoothness);
72+
6673
public:
6774
using ReductionMethod::ReductionMethod;
6875
void parameterize_for_variable(const Variable& var,

include/dtlmod/DecimationReductionMethod.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,12 @@ class DecimationReductionMethod : public ReductionMethod {
5858

5959
std::map<const Variable*, std::shared_ptr<ParameterizedDecimation>> per_variable_parameterizations_;
6060

61+
/// @brief Parse and validate a comma-separated stride string against a variable's shape.
62+
static std::vector<size_t> parse_stride(std::string_view value, const Variable& var);
63+
64+
/// @brief Validate that an interpolation method is compatible with the variable's dimensionality.
65+
static void validate_interpolation(std::string_view method, const Variable& var);
66+
6167
protected:
6268
void parameterize_for_variable(const Variable& var,
6369
const std::map<std::string, std::string, std::less<>>& parameters) override;

src/CompressionReductionMethod.cpp

Lines changed: 72 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,12 @@ double CompressionReductionMethod::derive_compression_ratio(double accuracy, std
5555
{
5656
if (profile == "sz") {
5757
// SZ-like prediction-based compressor: empirical fit from published benchmarks on scientific data.
58-
// Higher smoothness better prediction higher ratio.
58+
// Higher smoothness -> better prediction -> higher ratio.
5959
double alpha = 3.0;
6060
double beta = 0.8;
6161
return std::max(1.0, alpha * std::pow(-std::log10(accuracy), beta) * (0.5 + data_smoothness));
62-
} else if (profile == "zfp") {
62+
}
63+
if (profile == "zfp") {
6364
// ZFP-like transform-based compressor: rate = bits-per-value derived from accuracy.
6465
// 64 bits (double) / rate gives the compression ratio.
6566
double rate = std::max(1.0, -std::log2(accuracy) + 1.0);
@@ -69,104 +70,91 @@ double CompressionReductionMethod::derive_compression_ratio(double accuracy, std
6970
return 1.0;
7071
}
7172

73+
void CompressionReductionMethod::validate_compressor_profile(std::string_view profile)
74+
{
75+
if (profile != "fixed" && profile != "sz" && profile != "zfp")
76+
throw UnknownCompressionOptionException(XBT_THROW_POINT, "Unknown compressor profile: " + std::string(profile) +
77+
" (options are: fixed, sz, or zfp).");
78+
}
79+
80+
double CompressionReductionMethod::resolve_compression_ratio(double ratio, bool ratio_explicitly_set, bool is_new,
81+
std::string_view profile, double accuracy,
82+
double data_smoothness)
83+
{
84+
if (ratio_explicitly_set) {
85+
if (ratio < 1.0)
86+
throw InconsistentCompressionRatioException(XBT_THROW_POINT, "Compression ratio must be >= 1.0");
87+
return ratio;
88+
}
89+
if (is_new) {
90+
if (profile == "fixed")
91+
throw InconsistentCompressionRatioException(
92+
XBT_THROW_POINT, "Compressor profile 'fixed' requires an explicit 'compression_ratio' parameter.");
93+
return derive_compression_ratio(accuracy, profile, data_smoothness);
94+
}
95+
return ratio; // Keep existing ratio for partial updates without explicit ratio
96+
}
97+
7298
void CompressionReductionMethod::parameterize_for_variable(
7399
const Variable& var, const std::map<std::string, std::string, std::less<>>& parameters)
74100
{
75-
double new_accuracy = 1e-3;
76-
double new_compression_cost_per_element = 1.0;
77-
double new_decompression_cost_per_element = 1.0;
78-
double new_compression_ratio = 0.0; // 0 means "not specified, must be derived"
79-
std::string new_compressor_profile = "fixed";
80-
double new_data_smoothness = 0.5;
81-
double new_ratio_variability = 0.0;
82-
83-
// Detect existing parameterization (if any).
84-
auto it = per_variable_parameterizations_.find(&var);
85-
const bool exists = (it != per_variable_parameterizations_.end());
86-
87-
// Initialize from existing values (if present) to support partial updates.
88-
if (exists) {
101+
// Start from existing values (if any) to support partial updates.
102+
auto it = per_variable_parameterizations_.find(&var);
103+
bool is_new = (it == per_variable_parameterizations_.end());
104+
105+
double accuracy = 1e-3;
106+
double compression_cost_per_element = 1.0;
107+
double decompression_cost_per_element = 1.0;
108+
double compression_ratio = 0.0;
109+
std::string compressor_profile = "fixed";
110+
double data_smoothness = 0.5;
111+
double ratio_variability = 0.0;
112+
113+
if (!is_new) {
89114
const auto& existing = it->second;
90-
new_accuracy = existing->get_accuracy();
91-
new_compression_cost_per_element = existing->get_compression_cost_per_element();
92-
new_decompression_cost_per_element = existing->get_decompression_cost_per_element();
93-
new_compression_ratio = existing->get_compression_ratio();
94-
new_compressor_profile = existing->get_compressor_profile();
95-
new_data_smoothness = existing->get_data_smoothness();
96-
new_ratio_variability = existing->get_ratio_variability();
115+
accuracy = existing->get_accuracy();
116+
compression_cost_per_element = existing->get_compression_cost_per_element();
117+
decompression_cost_per_element = existing->get_decompression_cost_per_element();
118+
compression_ratio = existing->get_compression_ratio();
119+
compressor_profile = existing->get_compressor_profile();
120+
data_smoothness = existing->get_data_smoothness();
121+
ratio_variability = existing->get_ratio_variability();
97122
}
98123

99124
bool ratio_explicitly_set = false;
100125

101126
for (const auto& [key, value] : parameters) {
102-
if (key == "accuracy") {
103-
new_accuracy = std::stod(value);
104-
} else if (key == "compression_cost_per_element") {
105-
new_compression_cost_per_element = std::stod(value);
106-
} else if (key == "decompression_cost_per_element") {
107-
new_decompression_cost_per_element = std::stod(value);
108-
} else if (key == "compression_ratio") {
109-
new_compression_ratio = std::stod(value);
110-
ratio_explicitly_set = true;
127+
if (key == "accuracy")
128+
accuracy = std::stod(value);
129+
else if (key == "compression_cost_per_element")
130+
compression_cost_per_element = std::stod(value);
131+
else if (key == "decompression_cost_per_element")
132+
decompression_cost_per_element = std::stod(value);
133+
else if (key == "compression_ratio") {
134+
compression_ratio = std::stod(value);
135+
ratio_explicitly_set = true;
111136
} else if (key == "compressor") {
112-
if (value == "fixed" || value == "sz" || value == "zfp")
113-
new_compressor_profile = value;
114-
else
115-
throw UnknownCompressionOptionException(XBT_THROW_POINT, "Unknown compressor profile: " + value +
116-
" (options are: fixed, sz, or zfp).");
117-
} else if (key == "data_smoothness") {
118-
new_data_smoothness = std::stod(value);
119-
} else if (key == "ratio_variability") {
120-
new_ratio_variability = std::stod(value);
121-
} else {
137+
validate_compressor_profile(value);
138+
compressor_profile = value;
139+
} else if (key == "data_smoothness")
140+
data_smoothness = std::stod(value);
141+
else if (key == "ratio_variability")
142+
ratio_variability = std::stod(value);
143+
else
122144
throw UnknownCompressionOptionException(XBT_THROW_POINT, key);
123-
}
124145
}
125146

126-
// Derive compression ratio if not explicitly specified
127-
if (!ratio_explicitly_set && !exists) {
128-
if (new_compressor_profile == "fixed")
129-
throw InconsistentCompressionRatioException(
130-
XBT_THROW_POINT, "Compressor profile 'fixed' requires an explicit 'compression_ratio' parameter.");
131-
new_compression_ratio = derive_compression_ratio(new_accuracy, new_compressor_profile, new_data_smoothness);
132-
} else if (ratio_explicitly_set && new_compression_ratio < 1.0) {
133-
throw InconsistentCompressionRatioException(XBT_THROW_POINT, "Compression ratio must be >= 1.0");
134-
}
147+
compression_ratio = resolve_compression_ratio(compression_ratio, ratio_explicitly_set, is_new, compressor_profile,
148+
accuracy, data_smoothness);
135149

136150
XBT_DEBUG("Compression parameterization for Variable %s: profile=%s, accuracy=%.2e, ratio=%.2f, "
137151
"compression_cost=%.2f, decompression_cost=%.2f, smoothness=%.2f, variability=%.2f",
138-
var.get_cname(), new_compressor_profile.c_str(), new_accuracy, new_compression_ratio,
139-
new_compression_cost_per_element, new_decompression_cost_per_element, new_data_smoothness,
140-
new_ratio_variability);
141-
142-
if (!exists) {
143-
per_variable_parameterizations_.try_emplace(
144-
&var, std::make_shared<ParameterizedCompression>(
145-
var, new_accuracy, new_compression_cost_per_element, new_decompression_cost_per_element,
146-
new_compression_ratio, new_compressor_profile, new_data_smoothness, new_ratio_variability));
147-
return;
148-
}
152+
var.get_cname(), compressor_profile.c_str(), accuracy, compression_ratio, compression_cost_per_element,
153+
decompression_cost_per_element, data_smoothness, ratio_variability);
149154

150-
// If already exists, update only if changed.
151-
const auto& existing = it->second;
152-
153-
if (existing->get_accuracy() != new_accuracy)
154-
existing->set_accuracy(new_accuracy);
155-
if (existing->get_compression_cost_per_element() != new_compression_cost_per_element)
156-
existing->set_compression_cost_per_element(new_compression_cost_per_element);
157-
if (existing->get_decompression_cost_per_element() != new_decompression_cost_per_element)
158-
existing->set_decompression_cost_per_element(new_decompression_cost_per_element);
159-
if (ratio_explicitly_set || new_compressor_profile != existing->get_compressor_profile()) {
160-
double updated_ratio = ratio_explicitly_set
161-
? new_compression_ratio
162-
: derive_compression_ratio(new_accuracy, new_compressor_profile, new_data_smoothness);
163-
existing->set_compression_ratio(updated_ratio);
164-
}
165-
if (existing->get_compressor_profile() != new_compressor_profile)
166-
existing->set_compressor_profile(new_compressor_profile);
167-
if (existing->get_data_smoothness() != new_data_smoothness)
168-
existing->set_data_smoothness(new_data_smoothness);
169-
if (existing->get_ratio_variability() != new_ratio_variability)
170-
existing->set_ratio_variability(new_ratio_variability);
155+
// Always (re)create the parameterization — avoids field-by-field update complexity.
156+
per_variable_parameterizations_[&var] = std::make_shared<ParameterizedCompression>(
157+
var, accuracy, compression_cost_per_element, decompression_cost_per_element, compression_ratio,
158+
compressor_profile, data_smoothness, ratio_variability);
171159
}
172160
} // namespace dtlmod

src/DecimationReductionMethod.cpp

Lines changed: 65 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -36,95 +36,85 @@ double DecimationReductionMethod::ParameterizedDecimation::get_flop_amount_to_de
3636
{
3737
XBT_DEBUG("Compute decimation cost with: cost_per_element = %.2f and interpolation_method = %s", cost_per_element_,
3838
interpolation_method_.c_str());
39-
double amount = cost_per_element_;
39+
double amount = cost_per_element_;
4040
auto local_size = static_cast<double>(var_->get_local_size());
41-
if (interpolation_method_.empty()) {
42-
amount *= local_size;
43-
} else if (interpolation_method_ == "linear") {
44-
amount = 2 * amount * local_size;
45-
} else if (interpolation_method_ == "quadratic") {
46-
amount = 4 * amount * local_size;
47-
} else if (interpolation_method_ == "cubic") {
48-
amount = 8 * amount * local_size;
49-
} // Sanity check done when parameterizing the reduction method for this variable
50-
return amount;
41+
int multiplier = 1;
42+
43+
if (interpolation_method_ == "linear")
44+
multiplier = 2;
45+
else if (interpolation_method_ == "quadratic")
46+
multiplier = 4;
47+
else if (interpolation_method_ == "cubic")
48+
multiplier = 8;
49+
50+
return multiplier * amount * local_size;
51+
}
52+
53+
std::vector<size_t> DecimationReductionMethod::parse_stride(std::string_view value, const Variable& var)
54+
{
55+
std::vector<std::string> tokens;
56+
std::string value_str(value);
57+
boost::split(tokens, value_str, boost::is_any_of(","), boost::token_compress_on);
58+
59+
if (var.get_shape().size() != tokens.size())
60+
throw InconsistentDecimationStrideException(
61+
XBT_THROW_POINT, "Decimation Stride and Variable Shape vectors must have the same size. Stride: " +
62+
std::to_string(tokens.size()) + ", Shape: " + std::to_string(var.get_shape().size()));
63+
64+
std::vector<size_t> stride;
65+
stride.reserve(tokens.size());
66+
for (const auto& t : tokens) {
67+
auto dim_stride = std::stoul(t);
68+
if (t[0] == '-' || dim_stride == 0)
69+
throw InconsistentDecimationStrideException(XBT_THROW_POINT, "Stride values must be strictly positive");
70+
stride.push_back(dim_stride);
71+
}
72+
return stride;
73+
}
74+
75+
void DecimationReductionMethod::validate_interpolation(std::string_view method, const Variable& var)
76+
{
77+
if (method != "linear" && method != "quadratic" && method != "cubic")
78+
throw UnknownDecimationInterpolationException(XBT_THROW_POINT, std::string("Unknown interpolation method: ") +
79+
std::string(method) +
80+
" (options are: linear, cubic, or quadratic).");
81+
82+
if ((method == "quadratic" && var.get_shape().size() < 2) || (method == "cubic" && var.get_shape().size() < 3))
83+
throw InconsistentDecimationInterpolationException(
84+
XBT_THROW_POINT, "Variable has not enough dimensions to apply this interpolation method");
5185
}
5286

5387
void DecimationReductionMethod::parameterize_for_variable(
5488
const Variable& var, const std::map<std::string, std::string, std::less<>>& parameters)
5589
{
56-
std::vector<size_t> new_stride;
57-
std::string new_interpolation_method;
58-
double new_cost_per_element = 1.0;
59-
60-
// Detect existing parameterization (if any).
61-
auto it = per_variable_parameterizations_.find(&var);
62-
const bool exists = (it != per_variable_parameterizations_.end());
63-
64-
// Initialize from existing values (if present) to support partial updates.
65-
if (exists) {
66-
const auto& existing = it->second;
67-
// Replace these getters with your actual API:
68-
new_stride = existing->get_stride();
69-
new_interpolation_method = existing->get_interpolation_method();
70-
new_cost_per_element = existing->get_cost_per_element();
90+
// Start from existing values (if any) to support partial updates.
91+
auto it = per_variable_parameterizations_.find(&var);
92+
93+
std::vector<size_t> stride;
94+
std::string interpolation_method;
95+
double cost_per_element = 1.0;
96+
97+
if (it != per_variable_parameterizations_.end()) {
98+
stride = it->second->get_stride();
99+
interpolation_method = it->second->get_interpolation_method();
100+
cost_per_element = it->second->get_cost_per_element();
71101
}
72102

73103
for (const auto& [key, value] : parameters) {
74-
if (key == "stride") {
75-
std::vector<std::string> tokens;
76-
boost::split(tokens, value, boost::is_any_of(","), boost::token_compress_on);
77-
78-
if (var.get_shape().size() != tokens.size())
79-
throw InconsistentDecimationStrideException(
80-
XBT_THROW_POINT, "Decimation Stride and Variable Shape vectors must have the same size. Stride: " +
81-
std::to_string(tokens.size()) + ", Shape: " + std::to_string(var.get_shape().size()));
82-
83-
std::vector<size_t> parsed_stride;
84-
parsed_stride.reserve(tokens.size());
85-
for (const auto& t : tokens) {
86-
auto dim_stride = std::stoul(t);
87-
if (t[0] == '-' || dim_stride == 0)
88-
throw InconsistentDecimationStrideException(XBT_THROW_POINT, "Stride values must be strictly positive");
89-
parsed_stride.push_back(dim_stride);
90-
}
91-
new_stride = std::move(parsed_stride);
92-
93-
} else if (key == "interpolation") {
94-
if (value == "linear" || value == "quadratic" || value == "cubic")
95-
new_interpolation_method = value;
96-
else
97-
throw UnknownDecimationInterpolationException(XBT_THROW_POINT,
98-
std::string("Unknown interpolation method: ") + value +
99-
" (options are: linear, cubic, or quadratic).");
100-
101-
if ((value == "quadratic" && var.get_shape().size() < 2) || (value == "cubic" && var.get_shape().size() < 3))
102-
throw InconsistentDecimationInterpolationException(
103-
XBT_THROW_POINT, "Variable has not enough dimensions to apply this interpolation method");
104+
if (key == "stride")
105+
stride = parse_stride(value, var);
106+
else if (key == "interpolation") {
107+
validate_interpolation(value, var);
108+
interpolation_method = value;
104109
} else if (key == "cost_per_element")
105-
new_cost_per_element = std::stod(value);
110+
cost_per_element = std::stod(value);
106111
else
107112
throw UnknownDecimationOptionException(XBT_THROW_POINT, key);
108113
}
109114

110-
if (!exists) {
111-
// First-time parameterization
112-
per_variable_parameterizations_.try_emplace(
113-
&var,
114-
std::make_shared<ParameterizedDecimation>(var, new_stride, new_interpolation_method, new_cost_per_element));
115-
return;
116-
}
117-
118-
// If already exists, update only if changed.
119-
const auto& existing = it->second;
120-
121-
// Compare with existing to avoid unnecessary churn
122-
if (existing->get_stride() != new_stride)
123-
existing->set_stride(new_stride);
124-
if (existing->get_interpolation_method() != new_interpolation_method)
125-
existing->set_interpolation_method(new_interpolation_method);
126-
if (existing->get_cost_per_element() != new_cost_per_element)
127-
existing->set_cost_per_element(new_cost_per_element);
115+
// Always (re)create the parameterization — avoids field-by-field update complexity.
116+
per_variable_parameterizations_[&var] =
117+
std::make_shared<ParameterizedDecimation>(var, stride, interpolation_method, cost_per_element);
128118
}
129119

130120
void DecimationReductionMethod::reduce_variable(const Variable& var)

0 commit comments

Comments
 (0)