Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 102 additions & 62 deletions dpnp/backend/include/dpnp4pybind11.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,37 @@ class dpnp_capi
public:
PyTypeObject *PyUSMArrayType_;

// Function-pointer members for the usm_ndarray C-API entry points.
// Each one is set to nullptr in the constructor's initializer list and is
// then assigned the like-named C-API function (same name without the
// trailing underscore) in the constructor body.

// Accessors: each takes the array object and returns one of its properties.
char *(*UsmNDArray_GetData_)(PyUSMArrayObject *);
int (*UsmNDArray_GetNDim_)(PyUSMArrayObject *);
py::ssize_t *(*UsmNDArray_GetShape_)(PyUSMArrayObject *);
py::ssize_t *(*UsmNDArray_GetStrides_)(PyUSMArrayObject *);
int (*UsmNDArray_GetTypenum_)(PyUSMArrayObject *);
int (*UsmNDArray_GetElementSize_)(PyUSMArrayObject *);
int (*UsmNDArray_GetFlags_)(PyUSMArrayObject *);
DPCTLSyclQueueRef (*UsmNDArray_GetQueueRef_)(PyUSMArrayObject *);
py::ssize_t (*UsmNDArray_GetOffset_)(PyUSMArrayObject *);
// NOTE(review): callers in this file release the returned PyObject with
// Py_DECREF, i.e. they treat it as a new reference — confirm against the
// C-API contract.
PyObject *(*UsmNDArray_GetUSMData_)(PyUSMArrayObject *);
// Mutator: takes the array object and an int flag value.
void (*UsmNDArray_SetWritableFlag_)(PyUSMArrayObject *, int);
// Make* entry points: each returns a PyObject *. The meaning of the
// individual arguments is not visible here; see the C-API declarations.
PyObject *(*UsmNDArray_MakeSimpleFromMemory_)(int,
const py::ssize_t *,
int,
Py_MemoryObject *,
py::ssize_t,
char);
PyObject *(*UsmNDArray_MakeSimpleFromPtr_)(size_t,
int,
DPCTLSyclUSMRef,
DPCTLSyclQueueRef,
PyObject *);
PyObject *(*UsmNDArray_MakeFromPtr_)(int,
const py::ssize_t *,
int,
const py::ssize_t *,
DPCTLSyclUSMRef,
DPCTLSyclQueueRef,
py::ssize_t,
PyObject *);

int USM_ARRAY_C_CONTIGUOUS_;
int USM_ARRAY_F_CONTIGUOUS_;
int USM_ARRAY_WRITABLE_;
Expand Down Expand Up @@ -119,7 +150,15 @@ class dpnp_capi
std::shared_ptr<py::object> default_usm_ndarray_;

dpnp_capi()
: PyUSMArrayType_(nullptr), USM_ARRAY_C_CONTIGUOUS_(0),
: PyUSMArrayType_(nullptr), UsmNDArray_GetData_(nullptr),
UsmNDArray_GetNDim_(nullptr), UsmNDArray_GetShape_(nullptr),
UsmNDArray_GetStrides_(nullptr), UsmNDArray_GetTypenum_(nullptr),
UsmNDArray_GetElementSize_(nullptr), UsmNDArray_GetFlags_(nullptr),
UsmNDArray_GetQueueRef_(nullptr), UsmNDArray_GetOffset_(nullptr),
UsmNDArray_GetUSMData_(nullptr), UsmNDArray_SetWritableFlag_(nullptr),
UsmNDArray_MakeSimpleFromMemory_(nullptr),
UsmNDArray_MakeSimpleFromPtr_(nullptr),
UsmNDArray_MakeFromPtr_(nullptr), USM_ARRAY_C_CONTIGUOUS_(0),
USM_ARRAY_F_CONTIGUOUS_(0), USM_ARRAY_WRITABLE_(0), UAR_BOOL_(-1),
UAR_BYTE_(-1), UAR_UBYTE_(-1), UAR_SHORT_(-1), UAR_USHORT_(-1),
UAR_INT_(-1), UAR_UINT_(-1), UAR_LONG_(-1), UAR_ULONG_(-1),
Expand All @@ -135,6 +174,23 @@ class dpnp_capi

this->PyUSMArrayType_ = &PyUSMArrayType;

// dpnp.tensor.usm_ndarray API
this->UsmNDArray_GetData_ = UsmNDArray_GetData;
this->UsmNDArray_GetNDim_ = UsmNDArray_GetNDim;
this->UsmNDArray_GetShape_ = UsmNDArray_GetShape;
this->UsmNDArray_GetStrides_ = UsmNDArray_GetStrides;
this->UsmNDArray_GetTypenum_ = UsmNDArray_GetTypenum;
this->UsmNDArray_GetElementSize_ = UsmNDArray_GetElementSize;
this->UsmNDArray_GetFlags_ = UsmNDArray_GetFlags;
this->UsmNDArray_GetQueueRef_ = UsmNDArray_GetQueueRef;
this->UsmNDArray_GetOffset_ = UsmNDArray_GetOffset;
this->UsmNDArray_GetUSMData_ = UsmNDArray_GetUSMData;
this->UsmNDArray_SetWritableFlag_ = UsmNDArray_SetWritableFlag;
this->UsmNDArray_MakeSimpleFromMemory_ =
UsmNDArray_MakeSimpleFromMemory;
this->UsmNDArray_MakeSimpleFromPtr_ = UsmNDArray_MakeSimpleFromPtr;
this->UsmNDArray_MakeFromPtr_ = UsmNDArray_MakeFromPtr;

// constants
this->USM_ARRAY_C_CONTIGUOUS_ = USM_ARRAY_C_CONTIGUOUS;
this->USM_ARRAY_F_CONTIGUOUS_ = USM_ARRAY_F_CONTIGUOUS;
Expand Down Expand Up @@ -269,7 +325,9 @@ class usm_ndarray : public py::object
// Returns the data pointer that the usm_ndarray C-API reports for the
// wrapped array. The value is obtained through the dpnp_capi accessor
// rather than by reading the object's fields directly.
char *get_data() const
{
    PyUSMArrayObject *raw_ar = usm_array_ptr();

    auto const &api = detail::dpnp_capi::get();
    return api.UsmNDArray_GetData_(raw_ar);
}

template <typename T>
Expand All @@ -281,13 +339,17 @@ class usm_ndarray : public py::object
// Returns the number of dimensions that the usm_ndarray C-API reports for
// the wrapped array, obtained through the dpnp_capi accessor.
int get_ndim() const
{
    PyUSMArrayObject *raw_ar = usm_array_ptr();

    auto const &api = detail::dpnp_capi::get();
    return api.UsmNDArray_GetNDim_(raw_ar);
}

// Returns the shape pointer that the usm_ndarray C-API reports for the
// wrapped array, obtained through the dpnp_capi accessor.
// NOTE(review): the pointer is presumably owned by the array object and
// must not be freed by the caller — confirm against the C-API contract.
const py::ssize_t *get_shape_raw() const
{
    PyUSMArrayObject *raw_ar = usm_array_ptr();

    auto const &api = detail::dpnp_capi::get();
    return api.UsmNDArray_GetShape_(raw_ar);
}

std::vector<py::ssize_t> get_shape_vector() const
Expand All @@ -308,7 +370,9 @@ class usm_ndarray : public py::object
// Returns the strides pointer that the usm_ndarray C-API reports for the
// wrapped array, obtained through the dpnp_capi accessor.
// NOTE(review): whether this can be a null pointer (and what that means)
// is not visible here — confirm against the C-API contract.
const py::ssize_t *get_strides_raw() const
{
    PyUSMArrayObject *raw_ar = usm_array_ptr();

    auto const &api = detail::dpnp_capi::get();
    return api.UsmNDArray_GetStrides_(raw_ar);
}

std::vector<py::ssize_t> get_strides_vector() const
Expand Down Expand Up @@ -343,8 +407,9 @@ class usm_ndarray : public py::object
{
PyUSMArrayObject *raw_ar = usm_array_ptr();

int ndim = raw_ar->nd_;
const py::ssize_t *shape = raw_ar->shape_;
auto const &api = detail::dpnp_capi::get();
int ndim = api.UsmNDArray_GetNDim_(raw_ar);
const py::ssize_t *shape = api.UsmNDArray_GetShape_(raw_ar);

py::ssize_t nelems = 1;
for (int i = 0; i < ndim; ++i) {
Expand All @@ -359,9 +424,10 @@ class usm_ndarray : public py::object
{
PyUSMArrayObject *raw_ar = usm_array_ptr();

int nd = raw_ar->nd_;
const py::ssize_t *shape = raw_ar->shape_;
const py::ssize_t *strides = raw_ar->strides_;
auto const &api = detail::dpnp_capi::get();
int nd = api.UsmNDArray_GetNDim_(raw_ar);
const py::ssize_t *shape = api.UsmNDArray_GetShape_(raw_ar);
const py::ssize_t *strides = api.UsmNDArray_GetStrides_(raw_ar);

py::ssize_t offset_min = 0;
py::ssize_t offset_max = 0;
Expand Down Expand Up @@ -389,77 +455,43 @@ class usm_ndarray : public py::object
// Returns a copy of the sycl::queue behind the queue handle that the
// usm_ndarray C-API reports for the wrapped array.
sycl::queue get_queue() const
{
    PyUSMArrayObject *raw_ar = usm_array_ptr();

    auto const &api = detail::dpnp_capi::get();
    // Single declaration of the handle, taken from the array-level
    // accessor. NOTE(review): the handle is treated as borrowed (it is
    // never released here) — confirm against the C-API contract.
    DPCTLSyclQueueRef QRef = api.UsmNDArray_GetQueueRef_(raw_ar);
    return *(reinterpret_cast<sycl::queue *>(QRef));
}

// Returns the sycl::device of the queue behind the handle that the
// usm_ndarray C-API reports for the wrapped array.
sycl::device get_device() const
{
    PyUSMArrayObject *raw_ar = usm_array_ptr();

    auto const &api = detail::dpnp_capi::get();
    // Single declaration of the handle, taken from the array-level
    // accessor. NOTE(review): the handle is treated as borrowed (it is
    // never released here) — confirm against the C-API contract.
    DPCTLSyclQueueRef QRef = api.UsmNDArray_GetQueueRef_(raw_ar);
    return reinterpret_cast<sycl::queue *>(QRef)->get_device();
}

// Returns the type number that the usm_ndarray C-API reports for the
// wrapped array, obtained through the dpnp_capi accessor.
int get_typenum() const
{
    PyUSMArrayObject *raw_ar = usm_array_ptr();

    auto const &api = detail::dpnp_capi::get();
    return api.UsmNDArray_GetTypenum_(raw_ar);
}

// Returns the flags value that the usm_ndarray C-API reports for the
// wrapped array, obtained through the dpnp_capi accessor.
int get_flags() const
{
    PyUSMArrayObject *raw_ar = usm_array_ptr();

    auto const &api = detail::dpnp_capi::get();
    return api.UsmNDArray_GetFlags_(raw_ar);
}

// Returns the element size that the usm_ndarray C-API reports for the
// wrapped array. The size comes from the C-API accessor instead of a
// hand-maintained typenum-to-size table, so it cannot drift from the
// array implementation.
// NOTE(review): the value returned for an unrecognised typenum is
// whatever the C-API function reports; callers that test for 0 should be
// checked against that contract.
int get_elemsize() const
{
    PyUSMArrayObject *raw_ar = usm_array_ptr();

    auto const &api = detail::dpnp_capi::get();
    return api.UsmNDArray_GetElementSize_(raw_ar);
}

bool is_c_contiguous() const
Expand Down Expand Up @@ -487,9 +519,10 @@ class usm_ndarray : public py::object
py::object get_usm_data() const
{
PyUSMArrayObject *raw_ar = usm_array_ptr();

auto const &api = detail::dpnp_capi::get();
// base_ is the Memory object - return new reference
PyObject *usm_data = raw_ar->base_;
Py_XINCREF(usm_data);
PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);

// pass reference ownership to py::object
return py::reinterpret_steal<py::object>(usm_data);
Expand All @@ -498,28 +531,34 @@ class usm_ndarray : public py::object
// Reports whether the array's USM data object is a dpctl Memory object
// that carries a non-null opaque pointer. Returns false when the USM data
// object is not an instance of the dpctl Memory type.
bool is_managed_by_smart_ptr() const
{
    PyUSMArrayObject *raw_ar = usm_array_ptr();

    auto const &api = detail::dpnp_capi::get();
    // Single declaration of usm_data, taken from the C-API accessor.
    // NOTE(review): assumes UsmNDArray_GetUSMData_ returns a new (owned)
    // reference; it is released on every exit path below — confirm.
    PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);

    auto const &dpctl_api = ::dpctl::detail::dpctl_capi::get();
    if (!PyObject_TypeCheck(usm_data, dpctl_api.Py_MemoryType_)) {
        Py_DECREF(usm_data);
        return false;
    }

    Py_MemoryObject *mem_obj =
        reinterpret_cast<Py_MemoryObject *>(usm_data);
    const void *opaque_ptr = dpctl_api.Memory_GetOpaquePointer_(mem_obj);

    // Only the pointer value is inspected, so the reference can be
    // dropped before returning.
    Py_DECREF(usm_data);
    return bool(opaque_ptr);
}

const std::shared_ptr<void> &get_smart_ptr_owner() const
{
PyUSMArrayObject *raw_ar = usm_array_ptr();
PyObject *usm_data = raw_ar->base_;

auto const &dpctl_api = ::dpctl::detail::dpctl_capi::get();
auto const &api = detail::dpnp_capi::get();
PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);

auto const &dpctl_api = ::dpctl::detail::dpctl_capi::get();
if (!PyObject_TypeCheck(usm_data, dpctl_api.Py_MemoryType_)) {
Py_DECREF(usm_data);
throw std::runtime_error(
"usm_ndarray object does not have Memory object "
"managing lifetime of USM allocation");
Expand All @@ -528,6 +567,7 @@ class usm_ndarray : public py::object
Py_MemoryObject *mem_obj =
reinterpret_cast<Py_MemoryObject *>(usm_data);
void *opaque_ptr = dpctl_api.Memory_GetOpaquePointer_(mem_obj);
Py_DECREF(usm_data);

if (opaque_ptr) {
auto shptr_ptr =
Expand Down
Loading
Loading