diff --git a/example/example_containers_bitset.cpp b/example/example_containers_bitset.cpp index 3e0c86590..ff022d024 100644 --- a/example/example_containers_bitset.cpp +++ b/example/example_containers_bitset.cpp @@ -51,6 +51,11 @@ int main() { std::cout << color_bitset.test(Color::RED) << std::endl; // false std::cout << color_bitset.test(Color::GREEN) << std::endl; // true std::cout << color_bitset.test(Color::BLUE) << std::endl; // true + std::cout << (color_bitset.find(Color::RED) != color_bitset.end()) << std::endl; // false + std::cout << (color_bitset.find(Color::GREEN) != color_bitset.end()) << std::endl; // true + std::cout << (color_bitset.find(Color::BLUE) != color_bitset.end()) << std::endl; // true + for (Color color : color_bitset) { std::cout << magic_enum::enum_name(color) << " "; } // GREEN BLUE + std::cout << std::endl; return 0; } diff --git a/include/magic_enum/magic_enum_containers.hpp b/include/magic_enum/magic_enum_containers.hpp index 9a0208a0d..6f3f34df6 100644 --- a/include/magic_enum/magic_enum_containers.hpp +++ b/include/magic_enum/magic_enum_containers.hpp @@ -35,6 +35,20 @@ #include "magic_enum.hpp" +#if __has_include() +# include +#endif + +#if !defined(__cpp_lib_bitops) || (__cpp_lib_bitops < 201907L) +# if __has_include() +# include +# pragma intrinsic(_BitScanForward) +# pragma intrinsic(_BitScanForward64) +# pragma intrinsic(_BitScanReverse) +# pragma intrinsic(_BitScanReverse64) +# endif +#endif + #if !defined(MAGIC_ENUM_NO_EXCEPTION) && (defined(__cpp_exceptions) || defined(__EXCEPTIONS) || defined(_CPPUNWIND)) #ifndef MAGIC_ENUM_USE_STD_MODULE # include @@ -310,6 +324,56 @@ struct FilteredIterator { [[nodiscard]] friend constexpr bool operator!=(const FilteredIterator& lhs, const FilteredIterator& rhs) { return lhs.current != rhs.current; } }; +template constexpr int countr_zero(T x) noexcept { +#if __cpp_lib_bitops >= 201907L + return std::countr_zero(x); +#elif __has_include() + unsigned long index; + if constexpr (sizeof(T) <= sizeof(unsigned long)) { + return _BitScanForward(&index, x) ? index : (sizeof(T) * 8); + } else { + return _BitScanForward64(&index, x) ? index : (sizeof(T) * 8); + } +#else + if constexpr (sizeof(T) <= sizeof(unsigned int)) { + return x ? __builtin_ctz(x) : (sizeof(T) * 8); + } else if constexpr (sizeof(T) <= sizeof(unsigned long)) { + return x ? __builtin_ctzl(x) : (sizeof(T) * 8); + } else { + return x ? __builtin_ctzll(x) : (sizeof(T) * 8); + } +#endif +} + +template constexpr int countl_zero(T x) noexcept { +#if __cpp_lib_bitops >= 201907L + return std::countl_zero(x); +#elif __has_include() + unsigned long index; + if constexpr (sizeof(T) <= sizeof(unsigned long)) { + return _BitScanReverse(&index, x) ? ((sizeof(T) * 8) - index - 1) : (sizeof(T) * 8); + } else { + return _BitScanReverse64(&index, x) ? ((sizeof(T) * 8) - index - 1) : (sizeof(T) * 8); + } +#else + if constexpr (sizeof(T) <= sizeof(unsigned int)) { + return x ? __builtin_clz(x) : (sizeof(T) * 8); + } else if constexpr (sizeof(T) <= sizeof(unsigned long)) { + return x ? __builtin_clzl(x) : (sizeof(T) * 8); + } else if constexpr (sizeof(T) <= sizeof(unsigned long long)) { + return x ? __builtin_clzll(x) : (sizeof(T) * 8); + } +#endif +} + +template constexpr int bit_width(T x) noexcept { +#if __cpp_lib_int_pow2 >= 202002L + return std::bit_width(x); +#else + return std::numeric_limits::digits - countl_zero(x); +#endif +} + } // namespace detail template @@ -552,11 +616,113 @@ class bitset { return res; } + template + class iterator_impl { + friend class bitset; + + parent_t parent = nullptr; + std::size_t num_index = 0; + base_type bit_index = 0; + public: + using iterator_category = std::bidirectional_iterator_tag; + using value_type = const E; + using difference_type = std::ptrdiff_t; + using pointer = value_type*; + using reference = value_type&; + + constexpr iterator_impl() noexcept = default; + constexpr iterator_impl(const iterator_impl&) noexcept = default; + constexpr iterator_impl& operator=(const iterator_impl&) noexcept = default; + constexpr iterator_impl(iterator_impl&&) noexcept = default; + constexpr iterator_impl& operator=(iterator_impl&&) noexcept = default; + private: + constexpr iterator_impl(parent_t p, std::size_t i) noexcept : iterator_impl(p, std::pair{i / bits_per_base, base_type{1} << (i % bits_per_base)}) {} + + constexpr iterator_impl(parent_t p, std::pair i) noexcept : parent(p), num_index(std::get<0>(i)), bit_index(std::get<1>(i)) {} + + [[nodiscard]] static constexpr iterator_impl begin(parent_t p) noexcept{ + for (std::size_t num_index = 0; num_index < base_type_count; ++num_index) { + if (p->a[num_index] > 0) { + base_type bit_index = p->a[num_index] & -p->a[num_index]; + return iterator_impl(p, std::pair{num_index, bit_index}); + } + } + return end(p); + } + [[nodiscard]] static constexpr iterator_impl end(parent_t p) noexcept { + return iterator_impl(p, enum_count()); + } + + public: + [[nodiscard]] constexpr reference operator*() const noexcept { return *Index::it(num_index * bits_per_base + static_cast(detail::countr_zero(bit_index))); } + + [[nodiscard]] constexpr pointer operator->() const noexcept { return std::addressof(**this); } + + constexpr iterator_impl& operator++() noexcept { + if (num_index >= base_type_count || (num_index == base_type_count - 1 && bit_index > last_value_max)) { + if ((bit_index <<= 1) == 0) { + ++num_index; + bit_index = base_type{1}; + } + return *this; + } + base_type remaining_bits = parent->a[num_index] & ~((bit_index << 1) - 1); + while (remaining_bits == 0 && ++num_index < base_type_count) { + remaining_bits = parent->a[num_index]; + } + if (num_index >= base_type_count) { + return *this = end(parent); + } + bit_index = remaining_bits & -remaining_bits; + return *this; + } + + [[nodiscard]] constexpr iterator_impl operator++(int) noexcept { + iterator_impl cp = *this; + ++*this; + return cp; + } + + constexpr iterator_impl& operator--() noexcept { + if (num_index >= base_type_count || (num_index == base_type_count - 1 && bit_index > last_value_max)) { + if ((bit_index >>= 1) == 0) { + --num_index; + bit_index = base_type{1} << (bits_per_base - 1); + } + return *this; + } + + base_type remaining_bits = parent->a[num_index] & (bit_index - 1); + while (remaining_bits == 0 && num_index != 0) { + remaining_bits = parent->a[--num_index]; + } + if (remaining_bits == 0) { + num_index = std::numeric_limits::max(); + bit_index = base_type{1} << (bits_per_base - 1); + return *this; + } + bit_index = base_type{1} << (detail::bit_width(remaining_bits) - 1); + return *this; + } + + [[nodiscard]] constexpr iterator_impl operator--(int) noexcept { + iterator_impl cp = *this; + --*this; + return cp; + } + + [[nodiscard]] friend constexpr bool operator==(const iterator_impl& lhs, const iterator_impl& rhs) { return lhs.parent == rhs.parent && lhs.num_index == rhs.num_index && lhs.bit_index == rhs.bit_index; } + + [[nodiscard]] friend constexpr bool operator!=(const iterator_impl& lhs, const iterator_impl& rhs) { return !(lhs == rhs); } + }; + public: using index_type = Index; using container_type = std::array; using reference = reference_impl<>; using const_reference = reference_impl; + using iterator = iterator_impl<>; + using const_iterator = iterator; constexpr explicit bitset(detail::raw_access_t = raw_access) noexcept : a{{}} {} @@ -649,6 +815,32 @@ class bitset { return MAGIC_ENUM_ASSERT(i), reference{this, *i}; } + [[nodiscard]] constexpr iterator begin() noexcept { return iterator::begin(this); } + + [[nodiscard]] constexpr const_iterator begin() const noexcept { return const_iterator::begin(this); } + + [[nodiscard]] constexpr const_iterator cbegin() const noexcept { return const_iterator::begin(this); } + + [[nodiscard]] constexpr iterator end() noexcept { return iterator::end(this); } + + [[nodiscard]] constexpr const_iterator end() const noexcept { return const_iterator::end(this); } + + [[nodiscard]] constexpr const_iterator cend() const noexcept { return const_iterator::end(this); } + + [[nodiscard]] constexpr const_iterator find(E pos) const noexcept { + if (auto i = index_type::at(pos); i && static_cast(const_reference(this, *i))) { + return const_iterator(this, *i); + } + return end(); + } + + [[nodiscard]] constexpr iterator find(E pos) noexcept { + if (auto i = index_type::at(pos); i && static_cast(const_reference(this, *i))) { + return iterator(this, *i); + } + return end(); + } + constexpr bool test(E pos) const { if (auto i = index_type::at(pos)) { return static_cast(const_reference(this, *i));