diff --git a/Detectors/TRD/pid/include/TRDPID/ML.h b/Detectors/TRD/pid/include/TRDPID/ML.h index c43de8921630a..210dfb56d80ff 100644 --- a/Detectors/TRD/pid/include/TRDPID/ML.h +++ b/Detectors/TRD/pid/include/TRDPID/ML.h @@ -70,10 +70,14 @@ class ML : public PIDBase [](void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location, const char* message) { LOG(warn) << "Ort " << severity << ": [" << logid << "|" << category << "|" << code_location << "]: " << message << ((intptr_t)param == 3 ? " [valid]" : " [error]"); }, - (void*)3}; ///< ONNX enviroment - const OrtApi& mApi{Ort::GetApi()}; ///< ONNX api + (void*)3}; ///< ONNX enviroment + const OrtApi& mApi{Ort::GetApi()}; ///< ONNX api +#if __has_include() std::unique_ptr mSession; ///< ONNX session - Ort::SessionOptions mSessionOptions; ///< ONNX session options +#else + std::unique_ptr mSession; ///< ONNX session +#endif + Ort::SessionOptions mSessionOptions; ///< ONNX session options Ort::AllocatorWithDefaultOptions mAllocator; // Input/Output diff --git a/Detectors/TRD/pid/src/ML.cxx b/Detectors/TRD/pid/src/ML.cxx index 785f3b05cc112..bee46b27767a8 100644 --- a/Detectors/TRD/pid/src/ML.cxx +++ b/Detectors/TRD/pid/src/ML.cxx @@ -68,7 +68,11 @@ void ML::init(o2::framework::ProcessingContext& pc) LOG(info) << "Set GraphOptimizationLevel to " << mParams.graphOptimizationLevel; // create actual session +#if __has_include() mSession = std::make_unique(mEnv, reinterpret_cast(model_data.data()), model_data.size(), mSessionOptions); +#else + mSession = std::make_unique(mEnv, reinterpret_cast(model_data.data()), model_data.size(), mSessionOptions); +#endif LOG(info) << "ONNX runtime session created"; // print name/shape of inputs @@ -104,8 +108,15 @@ float ML::process(const TrackTRD& trk, const o2::globaltracking::RecoContainer& try { auto input = prepareModelInput(trk, inputTracks); // create memory mapping to vector above +#if __has_include() auto inputTensor = Ort::Experimental::Value::CreateTensor(input.data(), input.size(), {static_cast(input.size()) / mInputShapes[0][1], mInputShapes[0][1]}); +#else + Ort::MemoryInfo mem_info = + Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault); + auto inputTensor = Ort::Value::CreateTensor(mem_info, input.data(), input.size(), + {static_cast(input.size()) / mInputShapes[0][1], mInputShapes[0][1]}); +#endif std::vector ortTensor; ortTensor.push_back(std::move(inputTensor)); auto outTensor = mSession->Run(mInputNames, ortTensor, mOutputNames); diff --git a/dependencies/FindONNXRuntime.cmake b/dependencies/FindONNXRuntime.cmake index 0983d44644a6c..b783c2e1c7bf3 100644 --- a/dependencies/FindONNXRuntime.cmake +++ b/dependencies/FindONNXRuntime.cmake @@ -17,4 +17,7 @@ endif() if (NOT ONNXRuntime::ONNXRuntime_FOUND) find_package(onnxruntime CONFIG) + if (onnxruntime_FOUND) + add_library(ONNXRuntime::ONNXRuntime ALIAS onnxruntime::onnxruntime) + endif() endif()