From cfe7ab383be3f395f759e806cc19f67ded0e3ddd Mon Sep 17 00:00:00 2001 From: Peter Hristov Date: Wed, 24 Apr 2024 15:57:44 +0300 Subject: [PATCH 1/2] Additional changes for newer versions of ONNXRuntime --- Detectors/TRD/pid/include/TRDPID/ML.h | 4 ++++ Detectors/TRD/pid/src/ML.cxx | 11 +++++++++++ dependencies/FindONNXRuntime.cmake | 3 +++ 3 files changed, 18 insertions(+) diff --git a/Detectors/TRD/pid/include/TRDPID/ML.h b/Detectors/TRD/pid/include/TRDPID/ML.h index c43de8921630a..e51a758a79b81 100644 --- a/Detectors/TRD/pid/include/TRDPID/ML.h +++ b/Detectors/TRD/pid/include/TRDPID/ML.h @@ -72,7 +72,11 @@ class ML : public PIDBase }, (void*)3}; ///< ONNX enviroment const OrtApi& mApi{Ort::GetApi()}; ///< ONNX api +#if __has_include() std::unique_ptr mSession; ///< ONNX session +#else + std::unique_ptr mSession; ///< ONNX session +#endif Ort::SessionOptions mSessionOptions; ///< ONNX session options Ort::AllocatorWithDefaultOptions mAllocator; diff --git a/Detectors/TRD/pid/src/ML.cxx b/Detectors/TRD/pid/src/ML.cxx index 785f3b05cc112..f553f1200bcb9 100644 --- a/Detectors/TRD/pid/src/ML.cxx +++ b/Detectors/TRD/pid/src/ML.cxx @@ -68,7 +68,11 @@ void ML::init(o2::framework::ProcessingContext& pc) LOG(info) << "Set GraphOptimizationLevel to " << mParams.graphOptimizationLevel; // create actual session +#if __has_include() mSession = std::make_unique(mEnv, reinterpret_cast(model_data.data()), model_data.size(), mSessionOptions); +#else + mSession = std::make_unique(mEnv, reinterpret_cast(model_data.data()), model_data.size(), mSessionOptions); +#endif LOG(info) << "ONNX runtime session created"; // print name/shape of inputs @@ -104,8 +108,15 @@ float ML::process(const TrackTRD& trk, const o2::globaltracking::RecoContainer& try { auto input = prepareModelInput(trk, inputTracks); // create memory mapping to vector above +#if __has_include() auto inputTensor = Ort::Experimental::Value::CreateTensor(input.data(), input.size(), {static_cast(input.size()) / mInputShapes[0][1], mInputShapes[0][1]}); +#else + Ort::MemoryInfo mem_info = + Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault); + auto inputTensor = Ort::Value::CreateTensor(mem_info, input.data(), input.size(), + {static_cast(input.size()) / mInputShapes[0][1], mInputShapes[0][1]}); +#endif std::vector ortTensor; ortTensor.push_back(std::move(inputTensor)); auto outTensor = mSession->Run(mInputNames, ortTensor, mOutputNames); diff --git a/dependencies/FindONNXRuntime.cmake b/dependencies/FindONNXRuntime.cmake index 0983d44644a6c..b783c2e1c7bf3 100644 --- a/dependencies/FindONNXRuntime.cmake +++ b/dependencies/FindONNXRuntime.cmake @@ -17,4 +17,7 @@ endif() if (NOT ONNXRuntime::ONNXRuntime_FOUND) find_package(onnxruntime CONFIG) + if (onnxruntime_FOUND) + add_library(ONNXRuntime::ONNXRuntime ALIAS onnxruntime::onnxruntime) + endif() endif() From aac755eb37e756020fd3f6097ee9e802386a4abb Mon Sep 17 00:00:00 2001 From: Peter Hristov Date: Wed, 24 Apr 2024 16:20:18 +0300 Subject: [PATCH 2/2] Clang-format --- Detectors/TRD/pid/include/TRDPID/ML.h | 6 +++--- Detectors/TRD/pid/src/ML.cxx | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Detectors/TRD/pid/include/TRDPID/ML.h b/Detectors/TRD/pid/include/TRDPID/ML.h index e51a758a79b81..210dfb56d80ff 100644 --- a/Detectors/TRD/pid/include/TRDPID/ML.h +++ b/Detectors/TRD/pid/include/TRDPID/ML.h @@ -70,14 +70,14 @@ class ML : public PIDBase [](void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location, const char* message) { LOG(warn) << "Ort " << severity << ": [" << logid << "|" << category << "|" << code_location << "]: " << message << ((intptr_t)param == 3 ? " [valid]" : " [error]"); }, - (void*)3}; ///< ONNX enviroment - const OrtApi& mApi{Ort::GetApi()}; ///< ONNX api + (void*)3}; ///< ONNX enviroment + const OrtApi& mApi{Ort::GetApi()}; ///< ONNX api #if __has_include() std::unique_ptr mSession; ///< ONNX session #else std::unique_ptr mSession; ///< ONNX session #endif - Ort::SessionOptions mSessionOptions; ///< ONNX session options + Ort::SessionOptions mSessionOptions; ///< ONNX session options Ort::AllocatorWithDefaultOptions mAllocator; // Input/Output diff --git a/Detectors/TRD/pid/src/ML.cxx b/Detectors/TRD/pid/src/ML.cxx index f553f1200bcb9..bee46b27767a8 100644 --- a/Detectors/TRD/pid/src/ML.cxx +++ b/Detectors/TRD/pid/src/ML.cxx @@ -115,7 +115,7 @@ float ML::process(const TrackTRD& trk, const o2::globaltracking::RecoContainer& Ort::MemoryInfo mem_info = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault); auto inputTensor = Ort::Value::CreateTensor(mem_info, input.data(), input.size(), - {static_cast(input.size()) / mInputShapes[0][1], mInputShapes[0][1]}); + {static_cast(input.size()) / mInputShapes[0][1], mInputShapes[0][1]}); #endif std::vector ortTensor; ortTensor.push_back(std::move(inputTensor));