diff --git a/tree/ntuple/inc/ROOT/RNTupleJoinTable.hxx b/tree/ntuple/inc/ROOT/RNTupleJoinTable.hxx index edb1e1b6c9070..2a049626ba11b 100644 --- a/tree/ntuple/inc/ROOT/RNTupleJoinTable.hxx +++ b/tree/ntuple/inc/ROOT/RNTupleJoinTable.hxx @@ -209,47 +209,13 @@ public: /// \brief Get an entry index (if it exists) for the given join field value(s), from any partition. /// /// \param[in] valuePtrs A vector of pointers to the join field values to look up. - /// - /// \note If one or more corresponding entries exist for the given value(s), the first entry index found in the join - /// table is returned. To get *all* the entry indexes, use GetEntryIndexes. + /// \param[in] throwOnMultipleMatches If set to `true` an RException will be thrown when multiple corresponding entry + /// indexes are found. When set to `false`, the first entry index that is found is returned, even if there are + /// multple corresponding entries. /// /// \return An entry number that corresponds to `valuePtrs`. When there are no corresponding entries, /// `kInvalidNTupleIndex` is returned. - ROOT::NTupleSize_t GetEntryIndex(const std::vector &valuePtrs) const; - - ///////////////////////////////////////////////////////////////////////////// - /// \brief Get all entry indexes for the given join field value(s) within a partition. - /// - /// \param[in] valuePtrs A vector of pointers to the join field values to look up. - /// \param[in] partitionKey The partition key to use for the lookup. If not provided, it will use the default - /// partition key. - /// - /// \return The entry numbers that correspond to `valuePtrs`. When there are no corresponding entries, an empty - /// vector is returned. - std::vector - GetEntryIndexes(const std::vector &valuePtrs, PartitionKey_t partitionKey = kDefaultPartitionKey) const; - - ///////////////////////////////////////////////////////////////////////////// - /// \brief Get all entry indexes for the given join field value(s) for a specific set of partitions. - /// - /// \param[in] valuePtrs A vector of pointers to the join field values to look up. - /// \param[in] partitionKeys The partition keys to use for the lookup. - /// - /// \return The entry numbers that correspond to `valuePtrs`, grouped by partition. When there are no corresponding - /// entries, an empty map is returned. - std::unordered_map> - GetPartitionedEntryIndexes(const std::vector &valuePtrs, - const std::vector &partitionKeys) const; - - ///////////////////////////////////////////////////////////////////////////// - /// \brief Get all entry indexes for the given join field value(s) for all partitions. - /// - /// \param[in] valuePtrs A vector of pointers to the join field values to look up. - /// - /// \return The entry numbers that correspond to `valuePtrs`, grouped by partition. When there are no corresponding - /// entries, an empty map is returned. - std::unordered_map> - GetPartitionedEntryIndexes(const std::vector &valuePtrs) const; + ROOT::NTupleSize_t GetEntryIndex(const std::vector &valuePtrs, bool throwOnMultipleMatches = true) const; }; } // namespace Internal } // namespace Experimental diff --git a/tree/ntuple/src/RNTupleJoinTable.cxx b/tree/ntuple/src/RNTupleJoinTable.cxx index d664539b47954..bbd7da5b6f3e7 100644 --- a/tree/ntuple/src/RNTupleJoinTable.cxx +++ b/tree/ntuple/src/RNTupleJoinTable.cxx @@ -125,14 +125,19 @@ ROOT::Experimental::Internal::RNTupleJoinTable::Add(ROOT::Internal::RPageSource return *this; } -ROOT::NTupleSize_t -ROOT::Experimental::Internal::RNTupleJoinTable::GetEntryIndex(const std::vector &valuePtrs) const +ROOT::NTupleSize_t ROOT::Experimental::Internal::RNTupleJoinTable::GetEntryIndex(const std::vector &valuePtrs, + bool throwOnMultipleMatches) const { for (const auto &partition : fPartitions) { for (const auto &joinMapping : partition.second) { auto entriesForMapping = joinMapping->GetEntryIndexes(valuePtrs); if (entriesForMapping) { + if (throwOnMultipleMatches && entriesForMapping->size() > 1) { + // TODO (fdegeus): make error message more informative + throw RException(R__FAIL("found more than one corresponding entry index")); + } + return (*entriesForMapping)[0]; } } @@ -140,57 +145,3 @@ ROOT::Experimental::Internal::RNTupleJoinTable::GetEntryIndex(const std::vector< return kInvalidNTupleIndex; } - -std::vector -ROOT::Experimental::Internal::RNTupleJoinTable::GetEntryIndexes(const std::vector &valuePtrs, - PartitionKey_t partitionKey) const -{ - auto partition = fPartitions.find(partitionKey); - if (partition == fPartitions.end()) - return {}; - - std::vector entryIdxs{}; - - for (const auto &joinMapping : partition->second) { - auto entriesForMapping = joinMapping->GetEntryIndexes(valuePtrs); - if (entriesForMapping) - entryIdxs.insert(entryIdxs.end(), entriesForMapping->begin(), entriesForMapping->end()); - } - - return entryIdxs; -} - -std::unordered_map> -ROOT::Experimental::Internal::RNTupleJoinTable::GetPartitionedEntryIndexes( - const std::vector &valuePtrs, const std::vector &partitionKeys) const -{ - std::unordered_map> entryIdxs{}; - - for (const auto &partitionKey : partitionKeys) { - auto entriesForPartition = GetEntryIndexes(valuePtrs, partitionKey); - if (!entriesForPartition.empty()) { - entryIdxs[partitionKey].insert(entryIdxs[partitionKey].end(), entriesForPartition.begin(), - entriesForPartition.end()); - } - } - - return entryIdxs; -} - -std::unordered_map> -ROOT::Experimental::Internal::RNTupleJoinTable::GetPartitionedEntryIndexes(const std::vector &valuePtrs) const -{ - std::unordered_map> entryIdxs{}; - - for (const auto &partition : fPartitions) { - for (const auto &joinMapping : partition.second) { - auto entriesForMapping = joinMapping->GetEntryIndexes(valuePtrs); - if (entriesForMapping) { - entryIdxs[partition.first].insert(entryIdxs[partition.first].end(), entriesForMapping->begin(), - entriesForMapping->end()); - } - } - } - - return entryIdxs; -} diff --git a/tree/ntuple/test/ntuple_join_table.cxx b/tree/ntuple/test/ntuple_join_table.cxx index b5672932a9eb9..49c462768c4d4 100644 --- a/tree/ntuple/test/ntuple_join_table.cxx +++ b/tree/ntuple/test/ntuple_join_table.cxx @@ -22,7 +22,7 @@ TEST(RNTupleJoinTable, Basic) std::uint64_t fldValue = 0; // No entry mappings have been added to the join table yet - EXPECT_EQ(std::vector{}, joinTable->GetEntryIndexes({&fldValue})); + EXPECT_EQ(ROOT::kInvalidNTupleIndex, joinTable->GetEntryIndex({&fldValue})); // Now add the entry mapping for the page source joinTable->Add(*pageSource); @@ -33,7 +33,7 @@ TEST(RNTupleJoinTable, Basic) for (unsigned i = 0; i < ntuple->GetNEntries(); ++i) { fldValue = fld(i); EXPECT_EQ(fldValue, i * 2); - EXPECT_EQ(joinTable->GetEntryIndexes({&fldValue}), std::vector{static_cast(i)}); + EXPECT_EQ(joinTable->GetEntryIndex({&fldValue}), i); } } @@ -127,13 +127,12 @@ TEST(RNTupleJoinTable, SparseSecondary) auto event = fldEvent(i); if (i % 2 == 1) { - EXPECT_EQ(joinTable->GetEntryIndexes({&event}), std::vector{}) + EXPECT_EQ(joinTable->GetEntryIndex({&event}), ROOT::kInvalidNTupleIndex) << "entry should not be present in the join table"; } else { - auto entryIdxs = joinTable->GetEntryIndexes({&event}); - ASSERT_EQ(1ul, entryIdxs.size()); - EXPECT_EQ(entryIdxs[0], i / 2); - EXPECT_FLOAT_EQ(fldX(entryIdxs[0]), static_cast(entryIdxs[0]) / 3.14); + auto entryIdx = joinTable->GetEntryIndex({&event}); + EXPECT_EQ(entryIdx, i / 2); + EXPECT_FLOAT_EQ(fldX(entryIdx), static_cast(entryIdx) / 3.14); } } } @@ -171,26 +170,25 @@ TEST(RNTupleJoinTable, MultipleFields) for (std::uint64_t i = 0; i < pageSource->GetNEntries(); ++i) { run = i / 5; event = i % 5; - auto entryIdxs = joinTable->GetEntryIndexes({&run, &event}); - ASSERT_EQ(1ul, entryIdxs.size()); - EXPECT_EQ(fld(entryIdxs[0]), fld(i)); + auto entryIdx = joinTable->GetEntryIndex({&run, &event}); + EXPECT_EQ(fld(entryIdx), fld(i)); } run = 1; event = 2; - auto idx1 = joinTable->GetEntryIndexes({&run, &event}); - auto idx2 = joinTable->GetEntryIndexes({&event, &run}); + auto idx1 = joinTable->GetEntryIndex({&run, &event}); + auto idx2 = joinTable->GetEntryIndex({&event, &run}); EXPECT_NE(idx1, idx2); try { - joinTable->GetEntryIndexes({&run, &event, &event}); + joinTable->GetEntryIndex({&run, &event, &event}); FAIL() << "querying the join table with more values than join field values should not be possible"; } catch (const ROOT::RException &err) { EXPECT_THAT(err.what(), testing::HasSubstr("number of value pointers must match number of join fields")); } try { - joinTable->GetEntryIndexes({&run}); + joinTable->GetEntryIndex({&run}); FAIL() << "querying the join table with fewer values than join field values should not be possible"; } catch (const ROOT::RException &err) { EXPECT_THAT(err.what(), testing::HasSubstr("number of value pointers must match number of join fields")); @@ -210,8 +208,6 @@ TEST(RNTupleJoinTable, MultipleMatches) for (int i = 0; i < 10; ++i) { if (i > 4) *fldRun = 2; - if (i > 7) - *fldRun = 3; ntuple->Fill(); } } @@ -221,102 +217,15 @@ TEST(RNTupleJoinTable, MultipleMatches) joinTable->Add(*pageSource); std::uint64_t run = 1; - auto entryIdxs = joinTable->GetEntryIndexes({&run}); - auto expected = std::vector{0, 1, 2, 3, 4}; - EXPECT_EQ(expected, entryIdxs); - entryIdxs = joinTable->GetEntryIndexes({&(++run)}); - expected = {5, 6, 7}; - EXPECT_EQ(expected, entryIdxs); - entryIdxs = joinTable->GetEntryIndexes({&(++run)}); - expected = {8, 9}; - EXPECT_EQ(expected, entryIdxs); - entryIdxs = joinTable->GetEntryIndexes({&(++run)}); - EXPECT_EQ(std::vector{}, entryIdxs); -} - -TEST(RNTupleJoinTable, Partitions) -{ - auto fnWriteNTuple = [](const FileRaii &fileGuard, std::uint16_t run) { - auto model = RNTupleModel::Create(); - *model->MakeField("run") = run; - auto fldI = model->MakeField("i"); - - auto ntuple = RNTupleWriter::Recreate(std::move(model), "ntuple", fileGuard.GetPath()); - - for (int i = 0; i < 5; ++i) { - *fldI = i; - ntuple->Fill(); - } - }; - - auto joinTable = RNTupleJoinTable::Create({"i"}); - - std::vector fileGuards; - fileGuards.emplace_back("test_ntuple_join_partition1.root"); - fileGuards.emplace_back("test_ntuple_join_partition2.root"); - fileGuards.emplace_back("test_ntuple_join_partition3.root"); - fileGuards.emplace_back("test_ntuple_join_partition4.root"); - std::int16_t runNumbers[4] = {1, 2, 3, 3}; - std::vector> pageSources; - - // Create four ntuples where with their corresponding run numbers and add them to the join table using this run - // number as the partition key. - for (unsigned i = 0; i < fileGuards.size(); ++i) { - fnWriteNTuple(fileGuards[i], runNumbers[i]); - pageSources.emplace_back(RPageSource::Create("ntuple", fileGuards[i].GetPath())); - joinTable->Add(*pageSources.back(), runNumbers[i]); - } - - std::vector openSpec; - for (const auto &fileGuard : fileGuards) { - openSpec.emplace_back("ntuple", fileGuard.GetPath()); - } - auto proc = RNTupleProcessor::CreateChain(openSpec); - - auto run = proc->RequestField("run"); - auto i = proc->RequestField("i"); - - std::uint32_t idx = 0; - - // When getting the entry indexes for all partitions, we expect multiple resulting entry indexes (i.e., one entry - // index for each ntuple in the chain). - std::unordered_map> expectedEntryIdxMap = { - {1, {0}}, - {2, {0}}, - {3, {0, 0}}, - }; - EXPECT_EQ(expectedEntryIdxMap, joinTable->GetPartitionedEntryIndexes({&idx})); - EXPECT_EQ(expectedEntryIdxMap, joinTable->GetPartitionedEntryIndexes({&idx}, {1, 2, 3})); - - expectedEntryIdxMap = { - {1, {0}}, - {3, {0, 0}}, - }; - EXPECT_EQ(expectedEntryIdxMap, joinTable->GetPartitionedEntryIndexes({&idx}, {1, 3})); - - // Calling GetEntryIndexes with a partition key not present in the join table shouldn't fail; it should just return - // an empty vector. - EXPECT_EQ(std::vector{}, joinTable->GetEntryIndexes({&idx}, 4)); - - // Similarly, calling GetEntryIndexes with a partition key that is present in the join table but a join value that - // isn't shouldn't fail; it should just return an empty vector. - idx = 99; - EXPECT_EQ(std::vector{}, joinTable->GetEntryIndexes({&idx}, 3)); - - expectedEntryIdxMap.clear(); - EXPECT_EQ(expectedEntryIdxMap, joinTable->GetPartitionedEntryIndexes({&idx})); - EXPECT_EQ(expectedEntryIdxMap, joinTable->GetPartitionedEntryIndexes({&idx}, {1, 2, 3})); - EXPECT_EQ(expectedEntryIdxMap, joinTable->GetPartitionedEntryIndexes({&idx}, {1, 3})); - - for (auto it = proc->begin(); it != proc->end(); it++) { - auto entryIdxs = joinTable->GetEntryIndexes({i.GetRawPtr()}, *run); - - // Because two ntuples store their events under run number 3 and their entries for `i` are identical, two entry - // indexes are expected. For the other case (run == 1 and run == 2), only one entry index is expected. - if (*run == 3) - EXPECT_EQ(entryIdxs.size(), 2ul); - else - EXPECT_EQ(entryIdxs.size(), 1ul); + try { + joinTable->GetEntryIndex({&run}, /*throwOnMultipleMatches=*/true); + FAIL() << "should throw on multiple entry matches"; + } catch (const ROOT::RException &err) { + EXPECT_THAT(err.what(), testing::HasSubstr("found more than one corresponding entry index")); } + auto entryIdx = joinTable->GetEntryIndex({&(++run)}, /*throwOnMultipleMatches=*/false); + EXPECT_EQ(5, entryIdx); + entryIdx = joinTable->GetEntryIndex({&(++run)}); + EXPECT_EQ(ROOT::kInvalidNTupleIndex, entryIdx); }