diff --git a/src/gallium/state_trackers/clover/api/memory.cpp b/src/gallium/state_trackers/clover/api/memory.cpp index 1efb95b5ce7..ffd1d9d7e3d 100644 --- a/src/gallium/state_trackers/clover/api/memory.cpp +++ b/src/gallium/state_trackers/clover/api/memory.cpp @@ -28,37 +28,53 @@ using namespace clover; namespace { - const cl_mem_flags dev_access_flags = - CL_MEM_READ_WRITE | CL_MEM_WRITE_ONLY | CL_MEM_READ_ONLY; - const cl_mem_flags host_ptr_flags = - CL_MEM_USE_HOST_PTR | CL_MEM_ALLOC_HOST_PTR | CL_MEM_COPY_HOST_PTR; - const cl_mem_flags host_access_flags = - CL_MEM_HOST_WRITE_ONLY | CL_MEM_HOST_READ_ONLY | CL_MEM_HOST_NO_ACCESS; - const cl_mem_flags all_mem_flags = - dev_access_flags | host_ptr_flags | host_access_flags; + cl_mem_flags + validate_flags(cl_mem d_parent, cl_mem_flags d_flags) { + const cl_mem_flags dev_access_flags = + CL_MEM_READ_WRITE | CL_MEM_WRITE_ONLY | CL_MEM_READ_ONLY; + const cl_mem_flags host_ptr_flags = + CL_MEM_USE_HOST_PTR | CL_MEM_ALLOC_HOST_PTR | CL_MEM_COPY_HOST_PTR; + const cl_mem_flags host_access_flags = + CL_MEM_HOST_WRITE_ONLY | CL_MEM_HOST_READ_ONLY | CL_MEM_HOST_NO_ACCESS; + const cl_mem_flags valid_flags = + dev_access_flags | host_access_flags | (d_parent ? 0 : host_ptr_flags); - void - validate_flags(cl_mem_flags flags, cl_mem_flags valid) { - if ((flags & ~valid) || - util_bitcount(flags & dev_access_flags) > 1 || - util_bitcount(flags & host_access_flags) > 1) + if ((d_flags & ~valid_flags) || + util_bitcount(d_flags & dev_access_flags) > 1 || + util_bitcount(d_flags & host_access_flags) > 1) throw error(CL_INVALID_VALUE); - if ((flags & CL_MEM_USE_HOST_PTR) && - (flags & (CL_MEM_COPY_HOST_PTR | CL_MEM_ALLOC_HOST_PTR))) + if ((d_flags & CL_MEM_USE_HOST_PTR) && + (d_flags & (CL_MEM_COPY_HOST_PTR | CL_MEM_ALLOC_HOST_PTR))) throw error(CL_INVALID_VALUE); + + if (d_parent) { + const auto &parent = obj(d_parent); + const cl_mem_flags flags = (d_flags | + (d_flags & dev_access_flags ? 0 : + parent.flags() & dev_access_flags) | + (d_flags & host_access_flags ? 0 : + parent.flags() & host_access_flags) | + (parent.flags() & host_ptr_flags)); + + if (~flags & parent.flags() & + ((dev_access_flags & ~CL_MEM_READ_WRITE) | host_access_flags)) + throw error(CL_INVALID_VALUE); + + return flags; + + } else { + return d_flags | (d_flags & dev_access_flags ? 0 : CL_MEM_READ_WRITE); + } } } CLOVER_API cl_mem clCreateBuffer(cl_context d_ctx, cl_mem_flags d_flags, size_t size, void *host_ptr, cl_int *r_errcode) try { - const cl_mem_flags flags = d_flags | - (d_flags & dev_access_flags ? 0 : CL_MEM_READ_WRITE); + const cl_mem_flags flags = validate_flags(NULL, d_flags); auto &ctx = obj(d_ctx); - validate_flags(d_flags, all_mem_flags); - if (bool(host_ptr) != bool(flags & (CL_MEM_USE_HOST_PTR | CL_MEM_COPY_HOST_PTR))) throw error(CL_INVALID_HOST_PTR); @@ -82,16 +98,7 @@ clCreateSubBuffer(cl_mem d_mem, cl_mem_flags d_flags, cl_buffer_create_type op, const void *op_info, cl_int *r_errcode) try { auto &parent = obj(d_mem); - const cl_mem_flags flags = d_flags | - (d_flags & dev_access_flags ? 0 : parent.flags() & dev_access_flags) | - (d_flags & host_access_flags ? 0 : parent.flags() & host_access_flags) | - (parent.flags() & host_ptr_flags); - - validate_flags(d_flags, dev_access_flags | host_access_flags); - - if (~flags & parent.flags() & - ((dev_access_flags & ~CL_MEM_READ_WRITE) | host_access_flags)) - throw error(CL_INVALID_VALUE); + const cl_mem_flags flags = validate_flags(d_mem, d_flags); if (op == CL_BUFFER_CREATE_TYPE_REGION) { auto reg = reinterpret_cast(op_info); @@ -121,12 +128,9 @@ clCreateImage2D(cl_context d_ctx, cl_mem_flags d_flags, const cl_image_format *format, size_t width, size_t height, size_t row_pitch, void *host_ptr, cl_int *r_errcode) try { - const cl_mem_flags flags = d_flags | - (d_flags & dev_access_flags ? 0 : CL_MEM_READ_WRITE); + const cl_mem_flags flags = validate_flags(NULL, d_flags); auto &ctx = obj(d_ctx); - validate_flags(d_flags, all_mem_flags); - if (!any_of(std::mem_fn(&device::image_support), ctx.devices())) throw error(CL_INVALID_OPERATION); @@ -158,12 +162,9 @@ clCreateImage3D(cl_context d_ctx, cl_mem_flags d_flags, size_t width, size_t height, size_t depth, size_t row_pitch, size_t slice_pitch, void *host_ptr, cl_int *r_errcode) try { - const cl_mem_flags flags = d_flags | - (d_flags & dev_access_flags ? 0 : CL_MEM_READ_WRITE); + const cl_mem_flags flags = validate_flags(NULL, d_flags); auto &ctx = obj(d_ctx); - validate_flags(d_flags, all_mem_flags); - if (!any_of(std::mem_fn(&device::image_support), ctx.devices())) throw error(CL_INVALID_OPERATION); @@ -196,7 +197,7 @@ clGetSupportedImageFormats(cl_context d_ctx, cl_mem_flags flags, auto &ctx = obj(d_ctx); auto formats = supported_formats(ctx, type); - validate_flags(flags, all_mem_flags); + validate_flags(NULL, flags); if (r_buf && !r_count) throw error(CL_INVALID_VALUE);