Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added TrkDiag/data/TrkQual_BDT1_v2.0.ubj
Binary file not shown.
3 changes: 2 additions & 1 deletion TrkDiag/src/SConscript
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ helper.make_plugins( [
'TMVA',
'ROOTTMVASofie',
'pthread',
'onnxruntime'
'onnxruntime',
'xgboost'
] )

helper.make_dict_and_map( [
Expand Down
77 changes: 66 additions & 11 deletions TrkDiag/src/TrackQuality_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
#include "Offline/RecoDataProducts/inc/MVAResult.hh"
// ONNXRuntime
#include "onnxruntime/onnxruntime_cxx_api.h"
// XGBoost
#include <xgboost/c_api.h>
// C++
#include <iostream>
#include <fstream>
Expand All @@ -47,6 +49,7 @@ namespace mu2e
fhicl::Atom<art::InputTag> kalSeedPtrTag{Name("KalSeedPtrCollection"), Comment("Input tag for KalSeedPtrCollection")};
fhicl::Atom<bool> printMVA{Name("PrintMVA"), Comment("Print the MVA used"), false};
fhicl::Atom<std::string> onnxFilename{Name("onnxFilename"), Comment("Filename for the .onnx file to use")};
fhicl::Atom<std::string> xgbFileName{Name("xgbFilename"), Comment("Path to XGBoost .ubj model file")};
fhicl::Atom<int> debug{Name("debugLevel"), Comment("Debug printout level"), 0};
};

Expand All @@ -63,6 +66,7 @@ namespace mu2e

ConfigFileLookupPolicy _configFileLookup;

// for the ANN
Ort::Env _env;
Ort::SessionOptions _session_options;
Ort::Session _session;
Expand All @@ -74,6 +78,11 @@ namespace mu2e
size_t _total_size;
Ort::MemoryInfo _memory_info;
Ort::AllocatedStringPtr _output_name;

// for the BDT
BoosterHandle _booster; // Booster object
// Number of features is fixed, must match training!
static constexpr size_t _nFeatures = 7;
};

TrackQuality::TrackQuality(const Parameters& conf) :
Expand All @@ -91,7 +100,8 @@ namespace mu2e
_memory_info(Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault)),
_output_name(_session.GetOutputNameAllocated(0, _allocator))
{
produces<MVAResultCollection>();
produces<MVAResultCollection>("ANN");
produces<MVAResultCollection>("BDT");

// Handle dynamic dimensions if needed
for (auto& dim : _input_shape) {
Expand All @@ -103,11 +113,29 @@ namespace mu2e
for (auto dim : _input_shape) {
_total_size *= dim;
}

// Load XGBoost model
if (XGBoosterCreate(nullptr, 0, &_booster) != 0) {
throw std::runtime_error(std::string("XGBoosterCreate failed: ") + XGBGetLastError());
}

std::string modelPath = ConfigFileLookupPolicy()(conf().xgbFileName());
if (XGBoosterLoadModel(_booster, modelPath.c_str()) != 0) {
throw std::runtime_error(std::string("XGBoosterLoadModel failed: ") + XGBGetLastError());
}

// verify the loaded model matches the expected feature count
bst_ulong nFeaturesModel = 0;
if (XGBoosterGetNumFeature(_booster, &nFeaturesModel) != 0) {
throw std::runtime_error(std::string("XGBoosterGetNumFeature failed: ") + XGBGetLastError());
}

}

void TrackQuality::produce(art::Event& event ) {
// create output
unique_ptr<MVAResultCollection> mvacol(new MVAResultCollection());
unique_ptr<MVAResultCollection> anncol(new MVAResultCollection());
unique_ptr<MVAResultCollection> bdtcol(new MVAResultCollection());

// get the KalSeedPtrs
art::Handle<KalSeedPtrCollection> kalSeedPtrHandle;
Expand Down Expand Up @@ -186,7 +214,7 @@ namespace mu2e
_input_shape.data(),
_input_shape.size()
);
// Run inference
// Run ANN inference
const char* input_names[] = {_input_name.get()};
const char* output_names[] = {_output_name.get()};
auto output_tensors = _session.Run(
Expand All @@ -198,28 +226,55 @@ namespace mu2e
1
);
// Get output
float* mvaout = output_tensors[0].GetTensorMutableData<float>();
float* annout = output_tensors[0].GetTensorMutableData<float>();

if (!entrance_found) {
mvaout[0] = 0; // this is not a good track
annout[0] = 0; // this is not a good track
}

// run XGBoost inference
// See reference in preamble comments...
float bdt_score = 0.0f;
DMatrixHandle dmat;
if (XGDMatrixCreateFromMat(features.data(), 1, _nFeatures, NAN, &dmat) != 0) { // use same features vector
throw std::runtime_error(std::string("XGDMatrixCreateFromMat failed: ") + XGBGetLastError());
}

bst_ulong out_len = 0;
const float* out_result = nullptr;
if (XGBoosterPredict(_booster, dmat, 0, 0, 0, &out_len, &out_result) != 0) {
XGDMatrixFree(dmat);
throw std::runtime_error(std::string("XGBoosterPredict failed: ") + XGBGetLastError());
}
if (out_len < 1 || out_result == nullptr) {
XGDMatrixFree(dmat);
throw std::runtime_error("XGBoosterPredict returned no result");
}

bdt_score = out_result[0];
if (XGDMatrixFree(dmat) != 0) {
throw std::runtime_error(std::string("XGDMatrixFree failed: ") + XGBGetLastError());
}


if(_debug > 0) {
printf("[TrackQuality::%s::%s] Inputs = %.0f, %.4f, %.4f, %.4f, %.4f, %.4f %.4f --> output = %.4f\n",
printf("[TrackQuality::%s::%s] Inputs = %.0f, %.4f, %.4f, %.4f, %.4f, %.4f %.4f --> ANN output = %.4fm BDT output = %.4fm\n",
__func__, moduleDescription().moduleLabel().c_str(),
features[0], features[1], features[2], features[3], features[4], features[5], features[6], mvaout[0]);
features[0], features[1], features[2], features[3], features[4], features[5], features[6], annout[0], bdt_score);
}

mvacol->push_back(MVAResult(mvaout[0]));
anncol->push_back(MVAResult(annout[0]));
bdtcol->push_back(MVAResult(bdt_score));
}

if ( (mvacol->size() != kalSeedPtrs.size()) ) {
throw cet::exception("TrackQuality") << "KalSeedPtr and MVAResult sizes are inconsistent (" << kalSeedPtrs.size() << ", " << mvacol->size();
if ( (anncol->size() != kalSeedPtrs.size()) ) {
throw cet::exception("TrackQuality") << "KalSeedPtr and MVAResult sizes are inconsistent (" << kalSeedPtrs.size() << ", " << anncol->size();
}


// put the output products into the event
event.put(move(mvacol));
event.put(move(anncol), "ANN");
event.put(move(bdtcol), "BDT");
}
}// mu2e

Expand Down