Skip to content
Closed
4 changes: 2 additions & 2 deletions Tools/PIDML/KaonPidTask.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ struct KaonPidTask {
SliceCache cache;
Preslice<aod::Tracks> perCol = aod::track::collisionId;

std::shared_ptr<PidONNXModel> pidModel; // creates a shared pointer to a new instance 'pidmodel'.
std::shared_ptr<PidONNXModel<o2::aod::MyTracks>> pidModel; // creates a shared pointer to a new instance 'pidmodel'.
HistogramRegistry histos{"Histos", {}, OutputObjHandlingPolicy::AnalysisObject};

Configurable<float> cfgZvtxCut{"cfgZvtxCut", 10, "Z vtx cut"};
Expand Down Expand Up @@ -84,7 +84,7 @@ struct KaonPidTask {
if (cfgUseCCDB) {
ccdbApi.init(cfgCCDBURL); // Initializes ccdbApi when cfgUseCCDB is set to 'true'
}
pidModel = std::make_shared<PidONNXModel>(cfgPathLocal.value, cfgPathCCDB.value, cfgUseCCDB.value, ccdbApi, cfgTimestamp.value, cfgPid.value, cfgCertainty.value);
pidModel = std::make_shared<PidONNXModel<o2::aod::MyTracks>>(cfgPathLocal.value, cfgPathCCDB.value, cfgUseCCDB.value, ccdbApi, cfgTimestamp.value, cfgPid.value, cfgCertainty.value);

histos.add("hChargePos", ";z;", kTH1F, {{3, -1.5, 1.5}});
histos.add("hChargeNeg", ";z;", kTH1F, {{3, -1.5, 1.5}});
Expand Down
2 changes: 2 additions & 0 deletions Tools/PIDML/pidML.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

#include "Framework/AnalysisDataModel.h"
#include "Common/DataModel/PIDResponse.h"
#include "Common/DataModel/Centrality.h"
#include "Common/DataModel/Multiplicity.h"

namespace o2::aod
{
Expand Down
6 changes: 3 additions & 3 deletions Tools/PIDML/pidMLBatchEffAndPurProducer.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ struct PidMlBatchEffAndPurProducer {
std::array<std::shared_ptr<TH1>, kNPids> hMCPositive;

o2::ccdb::CcdbApi ccdbApi;
std::vector<PidONNXModel> models;

Configurable<std::vector<int32_t>> cfgPids{"pids", std::vector<int32_t>(kPids, kPids + kNPids), "PIDs to predict"};
Configurable<std::array<double, kNDetectors>> cfgDetectorsPLimits{"detectors-p-limits", std::array<double, kNDetectors>(pidml_pt_cuts::defaultModelPLimits), "\"use {detector} when p >= y_{detector}\": array of 3 doubles [y_TPC, y_TOF, y_TRD]"};
Expand All @@ -82,6 +81,7 @@ struct PidMlBatchEffAndPurProducer {
using BigTracks = soa::Filtered<soa::Join<aod::FullTracks, aod::TracksDCA, aod::pidTOFbeta, aod::TrackSelection, aod::TOFSignal, aod::McTrackLabels,
aod::pidTPCFullPi, aod::pidTPCFullKa, aod::pidTPCFullPr, aod::pidTPCFullEl, aod::pidTPCFullMu,
aod::pidTOFFullPi, aod::pidTOFFullKa, aod::pidTOFFullPr, aod::pidTOFFullEl, aod::pidTOFFullMu>>;
std::vector<PidONNXModel<BigTracks>> models;

void initHistos()
{
Expand Down Expand Up @@ -203,11 +203,11 @@ struct PidMlBatchEffAndPurProducer {
if (cfgUseCCDB && bc.runNumber() != currentRunNumber) {
uint64_t timestamp = cfgUseFixedTimestamp ? cfgTimestamp.value : bc.timestamp();
for (const int32_t& pid : cfgPids.value)
models.emplace_back(PidONNXModel(cfgPathLocal.value, cfgPathCCDB.value, cfgUseCCDB.value,
models.emplace_back(PidONNXModel<BigTracks>(cfgPathLocal.value, cfgPathCCDB.value, cfgUseCCDB.value,
ccdbApi, timestamp, pid, 1.1, &cfgDetectorsPLimits.value[0]));
} else {
for (int32_t& pid : cfgPids.value)
models.emplace_back(PidONNXModel(cfgPathLocal.value, cfgPathCCDB.value, cfgUseCCDB.value,
models.emplace_back(PidONNXModel<BigTracks>(cfgPathLocal.value, cfgPathCCDB.value, cfgUseCCDB.value,
ccdbApi, -1, pid, 1.1, &cfgDetectorsPLimits.value[0]));
}

Expand Down
7 changes: 3 additions & 4 deletions Tools/PIDML/pidMLEffAndPurProducer.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
#include "Common/DataModel/TrackSelectionTables.h"
#include "Common/DataModel/PIDResponse.h"
#include "Tools/PIDML/pidOnnxModel.h"
#include "pidOnnxModel.h"
#include "Tools/PIDML/pidUtils.h"

using namespace o2;
Expand All @@ -35,7 +34,6 @@ using namespace pidml::pidutils;
struct PidMlEffAndPurProducer {
HistogramRegistry histos{"histos", {}, OutputObjHandlingPolicy::AnalysisObject};

PidONNXModel pidModel;
Configurable<int> cfgPid{"pid", 211, "PID to predict"};
Configurable<double> cfgNSigmaCut{"n-sigma-cut", 3.0f, "TPC and TOF PID nSigma cut"};
Configurable<std::array<double, kNDetectors>> cfgDetectorsPLimits{"detectors-p-limits", std::array<double, kNDetectors>(pidml_pt_cuts::defaultModelPLimits), "\"use {detector} when p >= y_{detector}\": array of 3 doubles [y_TPC, y_TOF, y_TRD]"};
Expand All @@ -57,6 +55,7 @@ struct PidMlEffAndPurProducer {
using BigTracks = soa::Filtered<soa::Join<aod::FullTracks, aod::TracksDCA, aod::pidTOFbeta, aod::TrackSelection, aod::TOFSignal, aod::McTrackLabels,
aod::pidTPCFullPi, aod::pidTPCFullKa, aod::pidTPCFullPr, aod::pidTPCFullEl, aod::pidTPCFullMu,
aod::pidTOFFullPi, aod::pidTOFFullKa, aod::pidTOFFullPr, aod::pidTOFFullEl, aod::pidTOFFullMu>>;
PidONNXModel<BigTracks> pidModel;

typedef struct nSigma_t {
double tpc, tof;
Expand Down Expand Up @@ -116,7 +115,7 @@ struct PidMlEffAndPurProducer {
if (cfgUseCCDB) {
ccdbApi.init(cfgCCDBURL);
} else {
pidModel = PidONNXModel(cfgPathLocal.value, cfgPathCCDB.value, cfgUseCCDB.value, ccdbApi, -1,
pidModel = PidONNXModel<BigTracks>(cfgPathLocal.value, cfgPathCCDB.value, cfgUseCCDB.value, ccdbApi, -1,
cfgPid.value, cfgCertainty.value, &cfgDetectorsPLimits.value[0]);
}

Expand Down Expand Up @@ -153,7 +152,7 @@ struct PidMlEffAndPurProducer {
auto bc = collisions.iteratorAt(0).bc_as<aod::BCsWithTimestamps>();
if (cfgUseCCDB && bc.runNumber() != currentRunNumber) {
uint64_t timestamp = cfgUseFixedTimestamp ? cfgTimestamp.value : bc.timestamp();
pidModel = PidONNXModel(cfgPathLocal.value, cfgPathCCDB.value, cfgUseCCDB.value, ccdbApi, timestamp,
pidModel = PidONNXModel<BigTracks>(cfgPathLocal.value, cfgPathCCDB.value, cfgUseCCDB.value, ccdbApi, timestamp,
cfgPid.value, cfgCertainty.value, &cfgDetectorsPLimits.value[0]);
}

Expand Down
9 changes: 4 additions & 5 deletions Tools/PIDML/pidOnnxInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ static const std::vector<std::string> cutVarLabels = {

} // namespace pidml_pt_cuts

template <typename T>
struct PidONNXInterface {
PidONNXInterface(std::string& localPath, std::string& ccdbPath, bool useCCDB, o2::ccdb::CcdbApi& ccdbApi, uint64_t timestamp, std::vector<int> const& pids, o2::framework::LabeledArray<double> const& pLimits, std::vector<double> const& minCertainties, bool autoMode) : mNPids{pids.size()}, mPLimits{pLimits}
{
Expand Down Expand Up @@ -78,8 +79,7 @@ struct PidONNXInterface {
PidONNXInterface& operator=(const PidONNXInterface&) = delete;
~PidONNXInterface() = default;

template <typename T>
float applyModel(const T& track, int pid)
float applyModel(const T::iterator& track, int pid)
{
for (std::size_t i = 0; i < mNPids; i++) {
if (mModels[i].mPid == pid) {
Expand All @@ -90,8 +90,7 @@ struct PidONNXInterface {
return -1.0f;
}

template <typename T>
bool applyModelBoolean(const T& track, int pid)
bool applyModelBoolean(const T::iterator& track, int pid)
{
for (std::size_t i = 0; i < mNPids; i++) {
if (mModels[i].mPid == pid) {
Expand All @@ -110,7 +109,7 @@ struct PidONNXInterface {
minCertainties = std::vector<double>(mNPids, 0.5);
}

std::vector<PidONNXModel> mModels;
std::vector<PidONNXModel<T>> mModels;
std::size_t mNPids;
o2::framework::LabeledArray<double> mPLimits;
};
Expand Down
89 changes: 46 additions & 43 deletions Tools/PIDML/pidOnnxModel.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,16 @@
#ifndef TOOLS_PIDML_PIDONNXMODEL_H_
#define TOOLS_PIDML_PIDONNXMODEL_H_

#include <Framework/ASoA.h>
#include <array>
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <limits>
#include <optional>
#include <string>
#include <algorithm>
#include <map>
#include <type_traits>
#include <utility>
#include <memory>
#include <vector>
Expand Down Expand Up @@ -73,6 +77,7 @@ bool readJsonFile(const std::string& config, rapidjson::Document& d)
}
} // namespace

template <typename T>
struct PidONNXModel {
public:
PidONNXModel(std::string& localPath, std::string& ccdbPath, bool useCCDB, o2::ccdb::CcdbApi& ccdbApi, uint64_t timestamp,
Expand Down Expand Up @@ -135,14 +140,12 @@ struct PidONNXModel {
PidONNXModel& operator=(const PidONNXModel&) = delete;
~PidONNXModel() = default;

template <typename T>
float applyModel(const T& track)
float applyModel(const typename T::iterator& track)
{
return getModelOutput(track);
}

template <typename T>
bool applyModelBoolean(const T& track)
bool applyModelBoolean(const typename T::iterator& track)
{
return getModelOutput(track) >= mMinCertainty;
}
Expand Down Expand Up @@ -203,7 +206,9 @@ struct PidONNXModel {
LOG(info) << "Using configuration files: " << localTrainColumnsPath << ", " << localScalingParamsPath;
if (readJsonFile(localTrainColumnsPath, trainColumnsDoc)) {
for (auto& param : trainColumnsDoc["columns_for_training"].GetArray()) {
mTrainColumns.emplace_back(param.GetString());
auto columnLabel = param.GetString();
mTrainColumns.emplace_back(columnLabel);
mGetters.emplace_back(o2::soa::row_helpers::getColumnGetterByLabel<float, T>(columnLabel));
}
}
if (readJsonFile(localScalingParamsPath, scalingParamsDoc)) {
Expand All @@ -213,52 +218,49 @@ struct PidONNXModel {
}
}

template <typename T>
std::vector<float> createInputsSingle(const T& track)
static float scale(float value, const std::pair<float, float>& scalingParams)
{
// TODO: Hardcoded for now. Planning to implement RowView extension to get runtime access to selected columns
// sign is short, trackType and tpcNClsShared uint8_t
return (value - scalingParams.first) / scalingParams.second;
}

float scaledTPCSignal = (track.tpcSignal() - mScalingParams.at("fTPCSignal").first) / mScalingParams.at("fTPCSignal").second;
std::vector<float> getValues(const typename T::iterator& track)
{
std::vector<float> output;
output.reserve(mTrainColumns.size());

std::vector<float> inputValues{scaledTPCSignal};
bool useTOF = !tofMissing(track) && inPLimit(track, mPLimits[kTPCTOF]);
bool useTRD = !trdMissing(track) && inPLimit(track, mPLimits[kTPCTOFTRD]);

// When TRD Signal shouldn't be used we pass quiet_NaNs to the network
if (!inPLimit(track, mPLimits[kTPCTOFTRD]) || trdMissing(track)) {
inputValues.push_back(std::numeric_limits<float>::quiet_NaN());
inputValues.push_back(std::numeric_limits<float>::quiet_NaN());
} else {
float scaledTRDSignal = (track.trdSignal() - mScalingParams.at("fTRDSignal").first) / mScalingParams.at("fTRDSignal").second;
inputValues.push_back(scaledTRDSignal);
inputValues.push_back(track.trdPattern());
}
for (uint32_t i = 0; i < mTrainColumns.size(); ++i) {
auto& columnLabel = mTrainColumns[i];

// When TOF Signal shouldn't be used we pass quiet_NaNs to the network
if (!inPLimit(track, mPLimits[kTPCTOF]) || tofMissing(track)) {
inputValues.push_back(std::numeric_limits<float>::quiet_NaN());
inputValues.push_back(std::numeric_limits<float>::quiet_NaN());
} else {
float scaledTOFSignal = (track.tofSignal() - mScalingParams.at("fTOFSignal").first) / mScalingParams.at("fTOFSignal").second;
float scaledBeta = (track.beta() - mScalingParams.at("fBeta").first) / mScalingParams.at("fBeta").second;
inputValues.push_back(scaledTOFSignal);
inputValues.push_back(scaledBeta);
}
if (
((columnLabel == "fTRDSignal" || columnLabel == "fTRDPattern") && !useTRD) ||
((columnLabel == "fTOFSignal" || columnLabel == "fBeta") && !useTOF)) {
output.push_back(std::numeric_limits<float>::quiet_NaN());
continue;
}

std::optional<std::pair<float, float>> scalingParams = std::nullopt;

auto scalingParamsEntry = mScalingParams.find(columnLabel);
if (scalingParamsEntry != mScalingParams.end()) {
scalingParams = scalingParamsEntry->second;
}

float scaledX = (track.x() - mScalingParams.at("fX").first) / mScalingParams.at("fX").second;
float scaledY = (track.y() - mScalingParams.at("fY").first) / mScalingParams.at("fY").second;
float scaledZ = (track.z() - mScalingParams.at("fZ").first) / mScalingParams.at("fZ").second;
float scaledAlpha = (track.alpha() - mScalingParams.at("fAlpha").first) / mScalingParams.at("fAlpha").second;
float scaledTPCNClsShared = (static_cast<float>(track.tpcNClsShared()) - mScalingParams.at("fTPCNClsShared").first) / mScalingParams.at("fTPCNClsShared").second;
float scaledDcaXY = (track.dcaXY() - mScalingParams.at("fDcaXY").first) / mScalingParams.at("fDcaXY").second;
float scaledDcaZ = (track.dcaZ() - mScalingParams.at("fDcaZ").first) / mScalingParams.at("fDcaZ").second;
float value = mGetters[i](track);

inputValues.insert(inputValues.end(), {track.p(), track.pt(), track.px(), track.py(), track.pz(), static_cast<float>(track.sign()), scaledX, scaledY, scaledZ, scaledAlpha, static_cast<float>(track.trackType()), scaledTPCNClsShared, scaledDcaXY, scaledDcaZ});
if (scalingParams) {
value = scale(value, scalingParams.value());
}

output.push_back(value);
}

return inputValues;
return output;
}

template <typename T>
float getModelOutput(const T& track)
float getModelOutput(const typename T::iterator& track)
{
// First rank of the expected model input is -1 which means that it is dynamic axis.
// Axis is exported as dynamic to make it possible to run model inference with the batch of
Expand All @@ -268,7 +270,7 @@ struct PidONNXModel {
auto input_shape = mInputShapes[0];
input_shape[0] = batch_size;

std::vector<float> inputTensorValues = createInputsSingle(track);
std::vector<float> inputTensorValues = getValues(track);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just saw that you are still converting our columns to rows. I though you said you managed to do it with bindings. Did I understand it wrong?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I.e., I understood you found a way to setup an OrtValue directly from a signle IoBinding which is attached to the values of the column, without an intermediate representation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As Maja said previously in the discussion in this PR, we do not use IOBinding. We wanted to use it, but we still needed to make a copy of a value to standardize it. Anton suggested using gandiva to copy and standardize the whole column and then use IOBinding, so I implemented the first part of it: standardization using gandiva, but there was an issue. It copies the whole column, but even our example task used filtered iterator and just needed to iterate over only 1/10 of the column, because of the filtering.

  • So copying it and IOBinding would mean copying x (10 for my example) times more data than we need and then we can't use IOBinding, because it needs contiguous memory.
  • Another solution can be copying only filtered rows of column to the new column and then making a second copy to gandiva standardized column (we need continuous memory to use IOBinding), but for us, it still seems not ideal. (of course, we can skip gandiva stage and copy it once with standard arithmetic operations, but I am not sure which one would be faster)
  • We could also say that we ban Filtered<> from our PIDML API and the user would need to prepare the table contiguous, but this would be hard to use our API then
  • On top of that it wouldn't support dynamic columns unless we evaluate lambda and save it to a newly created column
    So we decided not to use IOBinding.

std::vector<Ort::Value> inputTensors;

#if __has_include(<onnxruntime/core/session/onnxruntime_cxx_api.h>)
Expand Down Expand Up @@ -323,6 +325,7 @@ struct PidONNXModel {
}

std::vector<std::string> mTrainColumns;
std::vector<float(*)(const typename T::iterator&)> mGetters;
std::map<std::string, std::pair<float, float>> mScalingParams;

std::shared_ptr<Ort::Env> mEnv = nullptr;
Expand Down
27 changes: 14 additions & 13 deletions Tools/PIDML/qaPidML.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -342,11 +342,6 @@ struct QaPidML {
}
}

// one model for one particle
PidONNXModel model211;
PidONNXModel model2212;
PidONNXModel model321;

Configurable<std::string> cfgPathCCDB{"ccdb-path", "Users/m/mkabus/PIDML", "base path to the CCDB directory with ONNX models"};
Configurable<std::string> cfgCCDBURL{"ccdb-url", "http://alice-ccdb.cern.ch", "URL of the CCDB repository"};
Configurable<bool> cfgUseCCDB{"useCCDB", true, "Whether to autofetch ML model from CCDB. If false, local file will be used."};
Expand All @@ -355,26 +350,32 @@ struct QaPidML {
o2::ccdb::CcdbApi ccdbApi;
int currentRunNumber = -1;

Filter trackFilter = requireGlobalTrackInFilter();
using pidTracks = soa::Filtered<soa::Join<aod::Tracks, aod::TracksExtra, aod::McTrackLabels, aod::TracksDCA, aod::TrackSelection, aod::pidTOFbeta, aod::TOFSignal>>;

// one model for one particle
PidONNXModel<pidTracks> model211;
PidONNXModel<pidTracks> model2212;
PidONNXModel<pidTracks> model321;

void init(InitContext const&)
{
if (cfgUseCCDB) {
ccdbApi.init(cfgCCDBURL);
} else {
model211 = PidONNXModel(cfgPathLocal.value, cfgPathCCDB.value, cfgUseCCDB.value, ccdbApi, -1, 211, 0.5f, pSwitchValue[0]);
model2212 = PidONNXModel(cfgPathLocal.value, cfgPathCCDB.value, cfgUseCCDB.value, ccdbApi, -1, 2211, 0.5f, pSwitchValue[1]);
model321 = PidONNXModel(cfgPathLocal.value, cfgPathCCDB.value, cfgUseCCDB.value, ccdbApi, -1, 321, 0.5f, pSwitchValue[2]);
model211 = PidONNXModel<pidTracks>(cfgPathLocal.value, cfgPathCCDB.value, cfgUseCCDB.value, ccdbApi, -1, 211, 0.5f, pSwitchValue[0]);
model2212 = PidONNXModel<pidTracks>(cfgPathLocal.value, cfgPathCCDB.value, cfgUseCCDB.value, ccdbApi, -1, 2211, 0.5f, pSwitchValue[1]);
model321 = PidONNXModel<pidTracks>(cfgPathLocal.value, cfgPathCCDB.value, cfgUseCCDB.value, ccdbApi, -1, 321, 0.5f, pSwitchValue[2]);
}
}

Filter trackFilter = requireGlobalTrackInFilter();
using pidTracks = soa::Filtered<soa::Join<aod::Tracks, aod::TracksExtra, aod::McTrackLabels, aod::TracksDCA, aod::TrackSelection, aod::pidTOFbeta, aod::TOFSignal>>;
void process(aod::Collisions const& collisions, pidTracks const& tracks, aod::McParticles const& /*mcParticles*/, aod::BCsWithTimestamps const&)
{
auto bc = collisions.iteratorAt(0).bc_as<aod::BCsWithTimestamps>();
if (cfgUseCCDB && bc.runNumber() != currentRunNumber) {
model211 = PidONNXModel(cfgPathLocal.value, cfgPathCCDB.value, cfgUseCCDB.value, ccdbApi, bc.timestamp(), 211, 0.5f, pSwitchValue[0]);
model2212 = PidONNXModel(cfgPathLocal.value, cfgPathCCDB.value, cfgUseCCDB.value, ccdbApi, bc.timestamp(), 2211, 0.5f, pSwitchValue[1]);
model321 = PidONNXModel(cfgPathLocal.value, cfgPathCCDB.value, cfgUseCCDB.value, ccdbApi, bc.timestamp(), 321, 0.5f, pSwitchValue[2]);
model211 = PidONNXModel<pidTracks>(cfgPathLocal.value, cfgPathCCDB.value, cfgUseCCDB.value, ccdbApi, bc.timestamp(), 211, 0.5f, pSwitchValue[0]);
model2212 = PidONNXModel<pidTracks>(cfgPathLocal.value, cfgPathCCDB.value, cfgUseCCDB.value, ccdbApi, bc.timestamp(), 2211, 0.5f, pSwitchValue[1]);
model321 = PidONNXModel<pidTracks>(cfgPathLocal.value, cfgPathCCDB.value, cfgUseCCDB.value, ccdbApi, bc.timestamp(), 321, 0.5f, pSwitchValue[2]);
}

for (auto& track : tracks) {
Expand Down
8 changes: 4 additions & 4 deletions Tools/PIDML/simpleApplyPidOnnxInterface.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ DECLARE_SOA_TABLE(MlPidResults, "AOD", "MLPIDRESULTS", o2::soa::Index<>, mlpidre
} // namespace o2::aod

struct SimpleApplyOnnxInterface {
PidONNXInterface pidInterface; // One instance to manage all needed ONNX models

Configurable<LabeledArray<double>> cfgPTCuts{"pT_cuts", {pidml_pt_cuts::cuts[0], pidml_pt_cuts::nPids, pidml_pt_cuts::nCutVars, pidml_pt_cuts::pidLabels, pidml_pt_cuts::cutVarLabels}, "pT cuts for each output pid and each detector configuration"};
Configurable<std::vector<int>> cfgPids{"pids", std::vector<int>{pidml_pt_cuts::pids_v}, "PIDs to predict"};
Configurable<std::vector<double>> cfgCertainties{"certainties", std::vector<double>{pidml_pt_cuts::certainties_v}, "Min certainties of the models to accept given particle to be of given kind"};
Expand All @@ -65,12 +63,14 @@ struct SimpleApplyOnnxInterface {
// Filter on isGlobalTrack (TracksSelection)
using BigTracks = soa::Filtered<soa::Join<aod::FullTracks, aod::TracksDCA, aod::pidTOFbeta, aod::TrackSelection, aod::TOFSignal>>;

PidONNXInterface<BigTracks> pidInterface; // One instance to manage all needed ONNX models

void init(InitContext const&)
{
if (cfgUseCCDB) {
ccdbApi.init(cfgCCDBURL);
} else {
pidInterface = PidONNXInterface(cfgPathLocal.value, cfgPathCCDB.value, cfgUseCCDB.value, ccdbApi, -1, cfgPids.value, cfgPTCuts.value, cfgCertainties.value, cfgAutoMode.value);
pidInterface = PidONNXInterface<BigTracks>(cfgPathLocal.value, cfgPathCCDB.value, cfgUseCCDB.value, ccdbApi, -1, cfgPids.value, cfgPTCuts.value, cfgCertainties.value, cfgAutoMode.value);
}
}

Expand All @@ -79,7 +79,7 @@ struct SimpleApplyOnnxInterface {
auto bc = collisions.iteratorAt(0).bc_as<aod::BCsWithTimestamps>();
if (cfgUseCCDB && bc.runNumber() != currentRunNumber) {
uint64_t timestamp = cfgUseFixedTimestamp ? cfgTimestamp.value : bc.timestamp();
pidInterface = PidONNXInterface(cfgPathLocal.value, cfgPathCCDB.value, cfgUseCCDB.value, ccdbApi, timestamp, cfgPids.value, cfgPTCuts.value, cfgCertainties.value, cfgAutoMode.value);
pidInterface = PidONNXInterface<BigTracks>(cfgPathLocal.value, cfgPathCCDB.value, cfgUseCCDB.value, ccdbApi, timestamp, cfgPids.value, cfgPTCuts.value, cfgCertainties.value, cfgAutoMode.value);
}

for (auto& track : tracks) {
Expand Down
Loading