Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix MPSNDArrayDescriptor wrapper #502

Merged
merged 3 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions lib/mps/MPS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ import GPUArrays

const MtlFloat = Union{Float32, Float16}

const MPSShape = NSArray#{NSNumber}
Base.convert(::Type{MPSShape}, tuple::Union{Vector{N},NTuple{N, <:Integer}}) where N = NSArray(NSNumber.(collect(tuple)))

is_supported(dev::MTLDevice) = ccall(:MPSSupportsMTLDevice, Bool, (id{MTLDevice},), dev)

include("size.jl")
Expand Down
2 changes: 1 addition & 1 deletion lib/mps/matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ end
"""
matmul!(a::MtlMatrix, b::MtlMatrix, c::MtlMatrix, alpha=1, beta=1,
transpose_left=false, transpose_right=false)
A `MPSMatrixMultiplication` kernel thay computes:
A `MPSMatrixMultiplication` kernel that computes:
`c = alpha * op(a) * beta * op(b) + beta * C`

This function should not typically be used. Rather, use the normal `LinearAlgebra` interface
Expand Down
37 changes: 28 additions & 9 deletions lib/mps/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,9 @@ function MPSNDArrayDescriptor(dataType::DataType, dimensionCount, dimensionSizes
end

function MPSNDArrayDescriptor(dataType::DataType, shape::DenseVector{T}) where {T<:Union{Int,UInt}}
revshape = collect(reverse(shape))
obj = GC.@preserve revshape begin
shapeptr = pointer(revshape)
MPSNDArrayDescriptor(dataType, length(revshape), shapeptr)
obj = GC.@preserve shape begin
shapeptr = pointer(shape)
MPSNDArrayDescriptor(dataType, length(shape), shapeptr)
end
return obj
end
Expand Down Expand Up @@ -75,6 +74,11 @@ else
end
end

function Base.size(ndarr::MPSNDArray)
ndims = Int(ndarr.numberOfDimensions)
Tuple([Int(lengthOfDimension(ndarr,i)) for i in 0:ndims-1])
end

@objcwrapper immutable=false MPSTemporaryNDArray <: MPSNDArray

@objcproperties MPSTemporaryNDArray begin
Expand Down Expand Up @@ -130,20 +134,23 @@ end

function MPSNDArray(arr::MtlArray{T,N}) where {T,N}
arrsize = size(arr)
@assert arrsize[end]*sizeof(T) % 16 == 0 "Final dimension of arr must have a byte size divisible by 16"
@assert arrsize[1]*sizeof(T) % 16 == 0 "First dimension of arr must have a byte size divisible by 16"
desc = MPSNDArrayDescriptor(T, arrsize)
return MPSNDArray(arr.data[], UInt(arr.offset), desc)
end

function Metal.MtlArray(ndarr::MPSNDArray; storage = Metal.DefaultStorageMode, async = false)
ndims = Int(ndarr.numberOfDimensions)
arrsize = [lengthOfDimension(ndarr,i) for i in 0:ndims-1]
arrsize = size(ndarr)
T = convert(DataType, ndarr.dataType)
arr = MtlArray{T,ndims,storage}(undef, reverse(arrsize)...)
arr = MtlArray{T,length(arrsize),storage}(undef, (arrsize)...)
return exportToMtlArray!(arr, ndarr; async)
end

function exportToMtlArray!(arr::MtlArray{T}, ndarr::MPSNDArray; async=false) where T
dev = device(arr)

cmdBuf = MTLCommandBuffer(global_queue(dev)) do cmdBuf
exportDataWithCommandBuffer(ndarr, cmdBuf, arr.data[], T, 0, collect(sizeof(T) .* reverse(strides(arr))))
exportDataWithCommandBuffer(ndarr, cmdBuf, arr.data[], T, arr.offset)
end

async || wait_completed(cmdBuf)
Expand All @@ -157,6 +164,12 @@ exportDataWithCommandBuffer(ndarr::MPSNDArray, cmdbuf::MTLCommandBuffer, toBuffe
destinationDataType:destinationDataType::MPSDataType
offset:offset::NSUInteger
rowStrides:pointer(rowStrides)::Ptr{NSInteger}]::Nothing
exportDataWithCommandBuffer(ndarr::MPSNDArray, cmdbuf::MTLCommandBuffer, toBuffer, destinationDataType, offset) =
@objc [ndarr::MPSNDArray exportDataWithCommandBuffer:cmdbuf::id{MTLCommandBuffer}
toBuffer:toBuffer::id{MTLBuffer}
destinationDataType:destinationDataType::MPSDataType
offset:offset::NSUInteger
rowStrides:nil::id{ObjectiveC.Object}]::Nothing

# rowStrides in Bytes
importDataWithCommandBuffer!(ndarr::MPSNDArray, cmdbuf::MTLCommandBuffer, fromBuffer, sourceDataType, offset, rowStrides) =
Expand All @@ -165,6 +178,12 @@ importDataWithCommandBuffer!(ndarr::MPSNDArray, cmdbuf::MTLCommandBuffer, fromBu
sourceDataType:sourceDataType::MPSDataType
offset:offset::NSUInteger
rowStrides:pointer(rowStrides)::Ptr{NSInteger}]::Nothing
importDataWithCommandBuffer!(ndarr::MPSNDArray, cmdbuf::MTLCommandBuffer, fromBuffer, sourceDataType, offset) =
@objc [ndarr::MPSNDArray importDataWithCommandBuffer:cmdbuf::id{MTLCommandBuffer}
fromBuffer:fromBuffer::id{MTLBuffer}
sourceDataType:sourceDataType::MPSDataType
offset:offset::NSUInteger
rowStrides:nil::id{ObjectiveC.Object}]::Nothing

# TODO
# exportDataWithCommandBuffer(toImages, offset)
Expand Down
9 changes: 5 additions & 4 deletions test/mps/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using .MPS: MPSNDArrayDescriptor, MPSDataType, lengthOfDimension
T = Float32
DT = convert(MPSDataType, T)

desc1 = MPSNDArrayDescriptor(T, 5,4,3,2,1)
desc1 = MPSNDArrayDescriptor(T,1,2,3,4,5)
@test desc1 isa MPSNDArrayDescriptor
@test desc1.dataType == DT
@test desc1.preferPackedRows == false
Expand All @@ -19,7 +19,7 @@ using .MPS: MPSNDArrayDescriptor, MPSDataType, lengthOfDimension
@test lengthOfDimension(desc1,4) == 4
@test lengthOfDimension(desc1,3) == 5

desc2 = MPSNDArrayDescriptor(T, (4,3,2,1))
desc2 = MPSNDArrayDescriptor(T, (1,2,3,4))
@test desc2 isa MPSNDArrayDescriptor
@test desc2.dataType == DT
@test desc2.numberOfDimensions == 4
Expand Down Expand Up @@ -51,6 +51,7 @@ using .MPS: MPSNDArray
@test ndarr1.label == "Test1"
@test ndarr1.numberOfDimensions == 5
@test ndarr1.parent === nothing
@test size(ndarr1) == (5,4,3,2,1)

ndarr2 = MPSNDArray(dev, 4)
@test ndarr2 isa MPSNDArray
Expand All @@ -63,9 +64,9 @@ using .MPS: MPSNDArray
@test ndarr2.parent === nothing

arr3 = MtlArray(ones(Float16, 2,3,4))
@test_throws "Final dimension of arr must have a byte size divisible by 16" MPSNDArray(arr3)
@test_throws "First dimension of arr must have a byte size divisible by 16" MPSNDArray(arr3)

arr4 = MtlArray(ones(Float16, 2,3,8))
arr4 = MtlArray(ones(Float16, 8,3,2))

@static if Metal.macos_version() >= v"15"
@test ndarr1.descriptor isa MPSNDArrayDescriptor
Expand Down
Loading