diff --git a/TrkDiag/data/TrkQual_BDT1_v2.0.ubj b/TrkDiag/data/TrkQual_BDT1_v2.0.ubj new file mode 100644 index 0000000..699871a Binary files /dev/null and b/TrkDiag/data/TrkQual_BDT1_v2.0.ubj differ diff --git a/TrkDiag/src/SConscript b/TrkDiag/src/SConscript index efb348f..7d02d04 100644 --- a/TrkDiag/src/SConscript +++ b/TrkDiag/src/SConscript @@ -103,7 +103,8 @@ helper.make_plugins( [ 'TMVA', 'ROOTTMVASofie', 'pthread', - 'onnxruntime' + 'onnxruntime', + 'xgboost' ] ) helper.make_dict_and_map( [ diff --git a/TrkDiag/src/TrackQuality_module.cc b/TrkDiag/src/TrackQuality_module.cc index b1cf2eb..6fafdce 100644 --- a/TrkDiag/src/TrackQuality_module.cc +++ b/TrkDiag/src/TrackQuality_module.cc @@ -23,6 +23,8 @@ #include "Offline/RecoDataProducts/inc/MVAResult.hh" // ONNXRuntime #include "onnxruntime/onnxruntime_cxx_api.h" +// XGBoost +#include // C++ #include #include @@ -47,6 +49,7 @@ namespace mu2e fhicl::Atom kalSeedPtrTag{Name("KalSeedPtrCollection"), Comment("Input tag for KalSeedPtrCollection")}; fhicl::Atom printMVA{Name("PrintMVA"), Comment("Print the MVA used"), false}; fhicl::Atom onnxFilename{Name("onnxFilename"), Comment("Filename for the .onnx file to use")}; + fhicl::Atom xgbFileName{Name("xgbFilename"), Comment("Path to XGBoost .ubj model file")}; fhicl::Atom debug{Name("debugLevel"), Comment("Debug printout level"), 0}; }; @@ -63,6 +66,7 @@ namespace mu2e ConfigFileLookupPolicy _configFileLookup; + // for the ANN Ort::Env _env; Ort::SessionOptions _session_options; Ort::Session _session; @@ -74,6 +78,11 @@ namespace mu2e size_t _total_size; Ort::MemoryInfo _memory_info; Ort::AllocatedStringPtr _output_name; + + // for the BDT + BoosterHandle _booster; // Booster object + // Number of features is fixed, must match training! + static constexpr size_t _nFeatures = 7; }; TrackQuality::TrackQuality(const Parameters& conf) : @@ -91,7 +100,8 @@ namespace mu2e _memory_info(Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault)), _output_name(_session.GetOutputNameAllocated(0, _allocator)) { - produces(); + produces("ANN"); + produces("BDT"); // Handle dynamic dimensions if needed for (auto& dim : _input_shape) { @@ -103,11 +113,29 @@ namespace mu2e for (auto dim : _input_shape) { _total_size *= dim; } + + // Load XGBoost model + if (XGBoosterCreate(nullptr, 0, &_booster) != 0) { + throw std::runtime_error(std::string("XGBoosterCreate failed: ") + XGBGetLastError()); + } + + std::string modelPath = ConfigFileLookupPolicy()(conf().xgbFileName()); + if (XGBoosterLoadModel(_booster, modelPath.c_str()) != 0) { + throw std::runtime_error(std::string("XGBoosterLoadModel failed: ") + XGBGetLastError()); + } + + // verify the loaded model matches the expected feature count + bst_ulong nFeaturesModel = 0; + if (XGBoosterGetNumFeature(_booster, &nFeaturesModel) != 0) { + throw std::runtime_error(std::string("XGBoosterGetNumFeature failed: ") + XGBGetLastError()); + } + } void TrackQuality::produce(art::Event& event ) { // create output - unique_ptr mvacol(new MVAResultCollection()); + unique_ptr anncol(new MVAResultCollection()); + unique_ptr bdtcol(new MVAResultCollection()); // get the KalSeedPtrs art::Handle kalSeedPtrHandle; @@ -186,7 +214,7 @@ namespace mu2e _input_shape.data(), _input_shape.size() ); - // Run inference + // Run ANN inference const char* input_names[] = {_input_name.get()}; const char* output_names[] = {_output_name.get()}; auto output_tensors = _session.Run( @@ -198,28 +226,55 @@ namespace mu2e 1 ); // Get output - float* mvaout = output_tensors[0].GetTensorMutableData(); + float* annout = output_tensors[0].GetTensorMutableData(); if (!entrance_found) { - mvaout[0] = 0; // this is not a good track + annout[0] = 0; // this is not a good track } + // run XGBoost inference + // See reference in preamble comments... + float bdt_score = 0.0f; + DMatrixHandle dmat; + if (XGDMatrixCreateFromMat(features.data(), 1, _nFeatures, NAN, &dmat) != 0) { // use same features vector + throw std::runtime_error(std::string("XGDMatrixCreateFromMat failed: ") + XGBGetLastError()); + } + + bst_ulong out_len = 0; + const float* out_result = nullptr; + if (XGBoosterPredict(_booster, dmat, 0, 0, 0, &out_len, &out_result) != 0) { + XGDMatrixFree(dmat); + throw std::runtime_error(std::string("XGBoosterPredict failed: ") + XGBGetLastError()); + } + if (out_len < 1 || out_result == nullptr) { + XGDMatrixFree(dmat); + throw std::runtime_error("XGBoosterPredict returned no result"); + } + + bdt_score = out_result[0]; + if (XGDMatrixFree(dmat) != 0) { + throw std::runtime_error(std::string("XGDMatrixFree failed: ") + XGBGetLastError()); + } + + if(_debug > 0) { - printf("[TrackQuality::%s::%s] Inputs = %.0f, %.4f, %.4f, %.4f, %.4f, %.4f %.4f --> output = %.4f\n", + printf("[TrackQuality::%s::%s] Inputs = %.0f, %.4f, %.4f, %.4f, %.4f, %.4f %.4f --> ANN output = %.4fm BDT output = %.4fm\n", __func__, moduleDescription().moduleLabel().c_str(), - features[0], features[1], features[2], features[3], features[4], features[5], features[6], mvaout[0]); + features[0], features[1], features[2], features[3], features[4], features[5], features[6], annout[0], bdt_score); } - mvacol->push_back(MVAResult(mvaout[0])); + anncol->push_back(MVAResult(annout[0])); + bdtcol->push_back(MVAResult(bdt_score)); } - if ( (mvacol->size() != kalSeedPtrs.size()) ) { - throw cet::exception("TrackQuality") << "KalSeedPtr and MVAResult sizes are inconsistent (" << kalSeedPtrs.size() << ", " << mvacol->size(); + if ( (anncol->size() != kalSeedPtrs.size()) ) { + throw cet::exception("TrackQuality") << "KalSeedPtr and MVAResult sizes are inconsistent (" << kalSeedPtrs.size() << ", " << anncol->size(); } // put the output products into the event - event.put(move(mvacol)); + event.put(move(anncol), "ANN"); + event.put(move(bdtcol), "BDT"); } }// mu2e