Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion sycl/include/sycl/accessor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
#include <sycl/detail/type_traits.hpp> // for const_if_const_AS
#include <sycl/exception.hpp> // for make_error_code
#include <sycl/ext/oneapi/accessor_property_list.hpp> // for accessor_prope...
#include <sycl/ext/oneapi/weak_object_base.hpp> // for getSyclWeakObj...
#include <sycl/id.hpp> // for id
#include <sycl/multi_ptr.hpp> // for multi_ptr
#include <sycl/pointers.hpp> // for local_ptr, glo...
Expand Down
13 changes: 11 additions & 2 deletions sycl/include/sycl/detail/owner_less_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,20 @@

#pragma once

#include <sycl/detail/impl_utils.hpp> // for getSyclObjImpl
#include <sycl/ext/oneapi/weak_object_base.hpp> // for getSyclWeakObjImpl
#include <sycl/detail/impl_utils.hpp> // for getSyclObjImpl

namespace sycl {
inline namespace _V1 {

namespace ext::oneapi::detail {
template <typename SYCLObjT> class weak_object_base;

// Helper function for getting the underlying weak_ptr from a weak_object.
template <typename SYCLObjT>
decltype(weak_object_base<SYCLObjT>::MObjWeakPtr)
getSyclWeakObjImpl(const weak_object_base<SYCLObjT> &WeakObj);
} // namespace ext::oneapi::detail

namespace detail {

// Common CRTP base class supplying a common definition of owner-before ordering
Expand Down
181 changes: 108 additions & 73 deletions sycl/include/sycl/ext/oneapi/weak_object.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,97 @@

#pragma once

#include <sycl/access/access.hpp> // for target, mode
#include <sycl/accessor.hpp> // for accessor
#include <sycl/buffer.hpp> // for buffer
#include <sycl/detail/impl_utils.hpp> // for createSyc...
#include <sycl/detail/memcpy.hpp> // for detail
#include <sycl/exception.hpp> // for make_erro...
#include <sycl/ext/oneapi/weak_object_base.hpp> // for weak_obje...
#include <sycl/range.hpp> // for range
#include <sycl/stream.hpp> // for stream

#include <memory> // for shared_ptr
#include <optional> // for optional
#include <stddef.h> // for size_t
#include <memory>
#include <optional>
#include <stddef.h>
#include <sycl/access/access.hpp>
#include <sycl/accessor.hpp>
#include <sycl/buffer.hpp>
#include <sycl/detail/impl_utils.hpp>
#include <sycl/detail/memcpy.hpp>
#include <sycl/exception.hpp>
#include <sycl/range.hpp>
#include <sycl/stream.hpp>

namespace sycl {
inline namespace _V1 {
namespace ext::oneapi {
template <typename SYCLObjT> class weak_object;
namespace detail {
// Import from detail:: into ext::oneapi::detail:: to improve readability later
using namespace ::sycl::detail;
using namespace sycl::detail;
template <typename SYCLObjT> class weak_object_base;

// Helper function for getting the underlying weak_ptr from a weak_object.
template <typename SYCLObjT>
decltype(weak_object_base<SYCLObjT>::MObjWeakPtr)
getSyclWeakObjImpl(const weak_object_base<SYCLObjT> &WeakObj) {
return WeakObj.MObjWeakPtr;
}

// Common base class for weak_object.
template <typename SYCLObjT> class weak_object_base {
public:
using object_type = SYCLObjT;

constexpr weak_object_base() noexcept : MObjWeakPtr() {}
weak_object_base(const SYCLObjT &SYCLObj) noexcept
#ifndef __SYCL_DEVICE_ONLY__
: MObjWeakPtr(getSyclObjImpl(SYCLObj))
#endif
{
(void)SYCLObj;
}
weak_object_base(const weak_object_base &Other) noexcept = default;
weak_object_base(weak_object_base &&Other) noexcept = default;

weak_object_base &operator=(const weak_object_base &Other) noexcept = default;
weak_object_base &operator=(weak_object_base &&Other) noexcept = default;

void reset() noexcept { MObjWeakPtr.reset(); }
void swap(weak_object_base &Other) noexcept {
MObjWeakPtr.swap(Other.MObjWeakPtr);
}

bool expired() const noexcept { return MObjWeakPtr.expired(); }

#ifndef __SYCL_DEVICE_ONLY__
bool owner_before(const SYCLObjT &Other) const noexcept {
return MObjWeakPtr.owner_before(getSyclObjImpl(Other));
}
bool owner_before(const weak_object_base &Other) const noexcept {
return MObjWeakPtr.owner_before(Other.MObjWeakPtr);
}
SYCLObjT lock() const {
std::optional<SYCLObjT> OptionalObj =
static_cast<const weak_object<SYCLObjT> *>(this)->try_lock();
if (!OptionalObj)
throw sycl::exception(sycl::make_error_code(sycl::errc::invalid),
"Referenced object has expired.");
return *std::move(OptionalObj);
}
#else
// On device calls to these functions are disallowed, so declare them but
// don't define them to avoid compilation failures.
bool owner_before(const SYCLObjT &Other) const noexcept;
bool owner_before(const weak_object_base &Other) const noexcept;
std::optional<SYCLObjT> try_lock() const noexcept;
SYCLObjT lock() const;
#endif // __SYCL_DEVICE_ONLY__

protected:
#ifndef __SYCL_DEVICE_ONLY__
// Store a weak variant of the impl in the SYCLObjT.
typename std::remove_reference_t<decltype(getSyclObjImpl(
std::declval<SYCLObjT>()))>::weak_type MObjWeakPtr;
#else
// On device we may not have an impl, so we pad with an unused void pointer.
std::weak_ptr<void> MObjWeakPtr;
#endif // __SYCL_DEVICE_ONLY__

template <class Obj>
friend decltype(weak_object_base<Obj>::MObjWeakPtr)
getSyclWeakObjImpl(const weak_object_base<Obj> &WeakObj);
};

// Helper for creating ranges for empty weak_objects.
template <int Dims> static range<Dims> createDummyRange() {
Expand All @@ -50,19 +121,18 @@ template <int Dims> static range<Dims> createDummyRange() {
// weak_object_base class.
template <typename SYCLObjT>
class weak_object : public detail::weak_object_base<SYCLObjT> {
using weak_object_base = detail::weak_object_base<SYCLObjT>;

public:
using object_type = typename detail::weak_object_base<SYCLObjT>::object_type;
using object_type = typename weak_object_base::object_type;

constexpr weak_object() noexcept = default;
weak_object(const SYCLObjT &SYCLObj) noexcept
: detail::weak_object_base<SYCLObjT>(SYCLObj) {}
weak_object(const SYCLObjT &SYCLObj) noexcept : weak_object_base(SYCLObj) {}
weak_object(const weak_object &Other) noexcept = default;
weak_object(weak_object &&Other) noexcept = default;

weak_object &operator=(const SYCLObjT &SYCLObj) noexcept {
// Create weak_ptr from the shared_ptr to SYCLObj's implementation object.
this->MObjWeakPtr =
detail::weak_object_base<SYCLObjT>::GetWeakImpl(SYCLObj);
weak_object_base::operator=(SYCLObj);
return *this;
}
weak_object &operator=(const weak_object &Other) noexcept = default;
Expand All @@ -73,20 +143,8 @@ class weak_object : public detail::weak_object_base<SYCLObjT> {
auto MObjImplPtr = this->MObjWeakPtr.lock();
if (!MObjImplPtr)
return std::nullopt;
return sycl::detail::createSyclObjFromImpl<SYCLObjT>(MObjImplPtr);
}
SYCLObjT lock() const {
std::optional<SYCLObjT> OptionalObj = try_lock();
if (!OptionalObj)
throw sycl::exception(sycl::make_error_code(sycl::errc::invalid),
"Referenced object has expired.");
return *OptionalObj;
return detail::createSyclObjFromImpl<SYCLObjT>(std::move(MObjImplPtr));
}
#else
// On device calls to these functions are disallowed, so declare them but
// don't define them to avoid compilation failures.
std::optional<SYCLObjT> try_lock() const noexcept;
SYCLObjT lock() const;
#endif // __SYCL_DEVICE_ONLY__
};

Expand All @@ -95,28 +153,26 @@ class weak_object : public detail::weak_object_base<SYCLObjT> {
template <typename T, int Dimensions, typename AllocatorT>
class weak_object<buffer<T, Dimensions, AllocatorT>>
: public detail::weak_object_base<buffer<T, Dimensions, AllocatorT>> {
private:
using weak_object_base =
detail::weak_object_base<buffer<T, Dimensions, AllocatorT>>;
using buffer_type = buffer<T, Dimensions, AllocatorT>;

public:
using object_type =
typename detail::weak_object_base<buffer_type>::object_type;
using object_type = typename weak_object_base::object_type;

constexpr weak_object() noexcept
: detail::weak_object_base<buffer_type>(),
MRange{detail::createDummyRange<Dimensions>()}, MOffsetInBytes{0},
: MRange{detail::createDummyRange<Dimensions>()}, MOffsetInBytes{0},
MIsSubBuffer{false} {}
weak_object(const buffer_type &SYCLObj) noexcept
: detail::weak_object_base<buffer_type>(SYCLObj), MRange{SYCLObj.Range},
: weak_object_base(SYCLObj), MRange{SYCLObj.Range},
MOffsetInBytes{SYCLObj.OffsetInBytes},
MIsSubBuffer{SYCLObj.IsSubBuffer} {}
weak_object(const weak_object &Other) noexcept = default;
weak_object(weak_object &&Other) noexcept = default;

weak_object &operator=(const buffer_type &SYCLObj) noexcept {
// Create weak_ptr from the shared_ptr to SYCLObj's implementation object.
this->MObjWeakPtr = detail::weak_object_base<
buffer<T, Dimensions, AllocatorT>>::GetWeakImpl(SYCLObj);
weak_object_base::operator=(SYCLObj);
this->MRange = SYCLObj.Range;
this->MOffsetInBytes = SYCLObj.OffsetInBytes;
this->MIsSubBuffer = SYCLObj.IsSubBuffer;
Expand All @@ -126,7 +182,7 @@ class weak_object<buffer<T, Dimensions, AllocatorT>>
weak_object &operator=(weak_object &&Other) noexcept = default;

void swap(weak_object &Other) noexcept {
this->MObjWeakPtr.swap(Other.MObjWeakPtr);
weak_object_base::swap(Other);
std::swap(MRange, Other.MRange);
std::swap(MOffsetInBytes, Other.MOffsetInBytes);
std::swap(MIsSubBuffer, Other.MIsSubBuffer);
Expand All @@ -138,20 +194,9 @@ class weak_object<buffer<T, Dimensions, AllocatorT>>
if (!MObjImplPtr)
return std::nullopt;
// To reconstruct the buffer we use the reinterpret constructor.
return buffer_type{MObjImplPtr, MRange, MOffsetInBytes, MIsSubBuffer};
return buffer_type{std::move(MObjImplPtr), MRange, MOffsetInBytes,
MIsSubBuffer};
}
buffer_type lock() const {
std::optional<buffer_type> OptionalObj = try_lock();
if (!OptionalObj)
throw sycl::exception(sycl::make_error_code(sycl::errc::invalid),
"Referenced object has expired.");
return *OptionalObj;
}
#else
// On device calls to these functions are disallowed, so declare them but
// don't define them to avoid compilation failures.
std::optional<buffer_type> try_lock() const noexcept;
buffer_type lock() const;
#endif // __SYCL_DEVICE_ONLY__

private:
Expand All @@ -165,8 +210,10 @@ class weak_object<buffer<T, Dimensions, AllocatorT>>
// to reconstruct the original stream.
template <>
class weak_object<stream> : public detail::weak_object_base<stream> {
using weak_object_base = detail::weak_object_base<stream>;

public:
using object_type = typename detail::weak_object_base<stream>::object_type;
using object_type = typename weak_object_base::object_type;

constexpr weak_object() noexcept : detail::weak_object_base<stream>() {}
weak_object(const stream &SYCLObj) noexcept
Expand All @@ -178,8 +225,7 @@ class weak_object<stream> : public detail::weak_object_base<stream> {
weak_object(weak_object &&Other) noexcept = default;

weak_object &operator=(const stream &SYCLObj) noexcept {
// Create weak_ptr from the shared_ptr to SYCLObj's implementation object.
this->MObjWeakPtr = detail::weak_object_base<stream>::GetWeakImpl(SYCLObj);
weak_object_base::operator=(SYCLObj);
MWeakGlobalBuf = SYCLObj.GlobalBuf;
MWeakGlobalOffset = SYCLObj.GlobalOffset;
MWeakGlobalFlushBuf = SYCLObj.GlobalFlushBuf;
Expand All @@ -189,7 +235,7 @@ class weak_object<stream> : public detail::weak_object_base<stream> {
weak_object &operator=(weak_object &&Other) noexcept = default;

void swap(weak_object &Other) noexcept {
this->MObjWeakPtr.swap(Other.MObjWeakPtr);
weak_object_base::swap(Other);
MWeakGlobalBuf.swap(Other.MWeakGlobalBuf);
MWeakGlobalOffset.swap(Other.MWeakGlobalOffset);
MWeakGlobalFlushBuf.swap(Other.MWeakGlobalFlushBuf);
Expand All @@ -210,20 +256,9 @@ class weak_object<stream> : public detail::weak_object_base<stream> {
auto GlobalFlushBuf = MWeakGlobalFlushBuf.try_lock();
if (!ObjImplPtr || !GlobalBuf || !GlobalOffset || !GlobalFlushBuf)
return std::nullopt;
return stream{ObjImplPtr, *GlobalBuf, *GlobalOffset, *GlobalFlushBuf};
return stream{std::move(ObjImplPtr), *std::move(GlobalBuf),
*std::move(GlobalOffset), *std::move(GlobalFlushBuf)};
}
stream lock() const {
std::optional<stream> OptionalObj = try_lock();
if (!OptionalObj)
throw sycl::exception(sycl::make_error_code(sycl::errc::invalid),
"Referenced object has expired.");
return *OptionalObj;
}
#else
// On device calls to these functions are disallowed, so declare them but
// don't define them to avoid compilation failures.
std::optional<stream> try_lock() const noexcept;
stream lock() const;
#endif // __SYCL_DEVICE_ONLY__

private:
Expand Down
85 changes: 0 additions & 85 deletions sycl/include/sycl/ext/oneapi/weak_object_base.hpp

This file was deleted.

Loading
Loading