diff --git a/Framework/Core/include/Framework/ASoA.h b/Framework/Core/include/Framework/ASoA.h index 84c6e3ae705fb..cfece12767612 100644 --- a/Framework/Core/include/Framework/ASoA.h +++ b/Framework/Core/include/Framework/ASoA.h @@ -30,6 +30,8 @@ #include #include #include +#include +#include #include #include @@ -2172,6 +2174,82 @@ std::tuple getRowData(arrow::Table* table, T rowIterator, { return std::make_tuple(getSingleRowData(table, rowIterator, ci, ai, globalIndex)...); } + +template +R getColumnValue(const T& rowIterator) +{ + return static_cast(static_cast(rowIterator).get()); +} + +template +using ColumnGetterFunction = R (*)(const T&); + +template +concept dynamic_with_common_getter = is_dynamic_column && + // lambda is callable without additional free args + framework::pack_size(typename T::bindings_t{}) == framework::pack_size(typename T::callable_t::args{}) && + requires(T t) { + { t.get() } -> std::convertible_to; + }; + +template +concept persistent_with_common_getter = is_persistent_v && requires(T t) { + { t.get() } -> std::convertible_to; +}; + +template C> +ColumnGetterFunction createGetterPtr(const std::string_view& targetColumnLabel) +{ + return targetColumnLabel == C::columnLabel() ? &getColumnValue : nullptr; +} + +template C> +ColumnGetterFunction createGetterPtr(const std::string_view& targetColumnLabel) +{ + std::string_view columnLabel(C::columnLabel()); + + // allows user to use consistent formatting (with prefix) of all column labels + // by default there isn't 'f' prefix for dynamic column labels + if (targetColumnLabel.starts_with("f") && targetColumnLabel.substr(1) == columnLabel) { + return &getColumnValue; + } + + // check also exact match if user is aware of prefix missing + if (targetColumnLabel == columnLabel) { + return &getColumnValue; + } + + return nullptr; +} + +template +ColumnGetterFunction getColumnGetterByLabel(o2::framework::pack, const std::string_view& targetColumnLabel) +{ + ColumnGetterFunction func; + + (void)((func = createGetterPtr(targetColumnLabel), func) || ...); + + if (!func) { + throw framework::runtime_error_f("Getter for \"%s\" not found", targetColumnLabel); + } + + return func; +} + +template +using with_common_getter_t = typename std::conditional || dynamic_with_common_getter, std::true_type, std::false_type>::type; + +template +ColumnGetterFunction getColumnGetterByLabel(const std::string_view& targetColumnLabel) +{ + using TypesWithCommonGetter = o2::framework::selected_pack_multicondition, typename T::columns_t>; + + if (targetColumnLabel.size() == 0) { + throw framework::runtime_error("columnLabel: must not be empty"); + } + + return getColumnGetterByLabel(TypesWithCommonGetter{}, targetColumnLabel); +} } // namespace row_helpers } // namespace o2::soa diff --git a/Framework/Core/include/Framework/BinningPolicy.h b/Framework/Core/include/Framework/BinningPolicy.h index ea04aa3b5a5b3..b5e9ba546c4d9 100644 --- a/Framework/Core/include/Framework/BinningPolicy.h +++ b/Framework/Core/include/Framework/BinningPolicy.h @@ -12,7 +12,6 @@ #ifndef FRAMEWORK_BINNINGPOLICY_H #define FRAMEWORK_BINNINGPOLICY_H -#include "Framework/ASoA.h" #include "Framework/HistogramSpec.h" // only for VARIABLE_WIDTH #include "Framework/Pack.h" diff --git a/Framework/Core/test/benchmark_ASoA.cxx b/Framework/Core/test/benchmark_ASoA.cxx index 8dfac9e735c0b..4001e2a725a15 100644 --- a/Framework/Core/test/benchmark_ASoA.cxx +++ b/Framework/Core/test/benchmark_ASoA.cxx @@ -29,6 +29,7 @@ DECLARE_SOA_COLUMN_FULL(X, x, float, "x"); DECLARE_SOA_COLUMN_FULL(Y, y, float, "y"); DECLARE_SOA_COLUMN_FULL(Z, z, float, "z"); DECLARE_SOA_DYNAMIC_COLUMN(Sum, sum, [](float x, float y) { return x + y; }); +DECLARE_SOA_DYNAMIC_COLUMN(SumFreeArgs, sumFreeArgs, [](float x, float y, float freeArg) { return x + y + freeArg; }); } // namespace test DECLARE_SOA_TABLE(TestTable, "AOD", "TESTTBL", test::X, test::Y, test::Z, test::Sum); @@ -290,6 +291,36 @@ static void BM_ASoADynamicColumnPresent(benchmark::State& state) BENCHMARK(BM_ASoADynamicColumnPresent)->Range(8, 8 << maxrange); +static void BM_ASoADynamicColumnPresentGetGetterByLabel(benchmark::State& state) +{ + // Seed with a real random value, if available + std::default_random_engine e1(1234567891); + std::uniform_real_distribution uniform_dist(0, 1); + + TableBuilder builder; + auto rowWriter = builder.persist({"x", "y", "z"}); + for (auto i = 0; i < state.range(0); ++i) { + rowWriter(0, uniform_dist(e1), uniform_dist(e1), uniform_dist(e1)); + } + auto table = builder.finalize(); + + using Test = o2::soa::InPlaceTable<"A/0"_h, test::X, test::Y, test::Z, test::Sum>; + + for (auto _ : state) { + Test tests{table}; + float sum = 0; + auto xGetter = o2::soa::row_helpers::getColumnGetterByLabel("x"); + auto yGetter = o2::soa::row_helpers::getColumnGetterByLabel("y"); + for (auto& test : tests) { + sum += xGetter(test) + yGetter(test); + } + benchmark::DoNotOptimize(sum); + } + state.SetBytesProcessed(state.iterations() * state.range(0) * sizeof(float) * 2); +} + +BENCHMARK(BM_ASoADynamicColumnPresentGetGetterByLabel)->Range(8, 8 << maxrange); + static void BM_ASoADynamicColumnCall(benchmark::State& state) { // Seed with a real random value, if available @@ -317,4 +348,33 @@ static void BM_ASoADynamicColumnCall(benchmark::State& state) } BENCHMARK(BM_ASoADynamicColumnCall)->Range(8, 8 << maxrange); +static void BM_ASoADynamicColumnCallGetGetterByLabel(benchmark::State& state) +{ + // Seed with a real random value, if available + std::default_random_engine e1(1234567891); + std::uniform_real_distribution uniform_dist(0, 1); + + TableBuilder builder; + auto rowWriter = builder.persist({"x", "y", "z"}); + for (auto i = 0; i < state.range(0); ++i) { + rowWriter(0, uniform_dist(e1), uniform_dist(e1), uniform_dist(e1)); + } + auto table = builder.finalize(); + + // SumFreeArgs presence checks if dynamic columns get() is handled correctly during compilation + using Test = o2::soa::InPlaceTable<"A/0"_h, test::X, test::Y, test::Sum, test::SumFreeArgs>; + + Test tests{table}; + for (auto _ : state) { + float sum = 0; + auto sumGetter = o2::soa::row_helpers::getColumnGetterByLabel("Sum"); + for (auto& test : tests) { + sum += sumGetter(test); + } + benchmark::DoNotOptimize(sum); + } + state.SetBytesProcessed(state.iterations() * state.range(0) * sizeof(float) * 2); +} +BENCHMARK(BM_ASoADynamicColumnCallGetGetterByLabel)->Range(8, 8 << maxrange); + BENCHMARK_MAIN();