Skip to content

Commit

Permalink
Move MTL tests and add a few (#491)
Browse files Browse the repository at this point in the history
  • Loading branch information
christiangnrd authored Dec 10, 2024
1 parent e6abcc6 commit d14426c
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 8 deletions.
2 changes: 1 addition & 1 deletion lib/mtl/device.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ MTLDevice(i::Integer) = devices()[i]

export supports_family, is_m3, is_m2, is_m1

@cenum MTLGPUFamily::NSUInteger begin
@cenum MTLGPUFamily::NSInteger begin
MTLGPUFamilyMetal3 = 5001 # Metal 3 support

MTLGPUFamilyApple9 = 1009 # M3, M4 & A17
Expand Down
2 changes: 1 addition & 1 deletion lib/mtl/heap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# heap enums
#

@cenum MTLHeapType::NSUInteger begin
@cenum MTLHeapType::NSInteger begin
MTLHeapTypeAutomatic = 0
MTLHeapTypePlacement = 1
end
Expand Down
2 changes: 1 addition & 1 deletion lib/mtl/size.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,5 @@ struct MTLRegion
origin::MTLOrigin # The top-left corner of the region
size::MTLSize # The size of the region

MTLRegion(x=0, y=0, z=0) = new(x, y, z)
MTLRegion(origin=MTLOrigin(), size=MTLSize()) = new(origin, size)
end
10 changes: 5 additions & 5 deletions test/metal.jl → test/mtl/metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ let lib = MTLLibrary(dev, "", opts)
@test isempty(lib.functionNames)
end

metal_code = read(joinpath(@__DIR__, "dummy.metal"), String)
metal_code = read(joinpath(@__DIR__, "..", "dummy.metal"), String)
let lib = MTLLibrary(dev, metal_code, opts)
@test lib.device == dev
@test lib.label === nothing
Expand All @@ -77,7 +77,7 @@ let lib = MTLLibrary(dev, metal_code, opts)
@test "kernel_2" in fns
end

binary_path = joinpath(@__DIR__, "dummy.metallib")
binary_path = joinpath(@__DIR__, "..", "dummy.metallib")
let lib = MTLLibraryFromFile(dev, binary_path)
@test lib.device == dev
@test lib.label === nothing
Expand Down Expand Up @@ -119,7 +119,7 @@ desc.specializedName = "MySpecializedKernel"


dev = first(devices())
lib = MTLLibraryFromFile(dev, joinpath(@__DIR__, "dummy.metallib"))
lib = MTLLibraryFromFile(dev, joinpath(@__DIR__, "..", "dummy.metallib"))
fun = MTLFunction(lib, "kernel_1")

compact_str = sprint(io->show(io, fun))
Expand Down Expand Up @@ -347,7 +347,7 @@ end
@testset "compute pipeline" begin

dev = first(devices())
lib = MTLLibraryFromFile(dev, joinpath(@__DIR__, "dummy.metallib"))
lib = MTLLibraryFromFile(dev, joinpath(@__DIR__, "..", "dummy.metallib"))
fun = MTLFunction(lib, "kernel_1")

pipeline = MTLComputePipelineState(dev, fun)
Expand Down Expand Up @@ -390,7 +390,7 @@ end
@testset "binary archive" begin

dev = first(devices())
lib = MTLLibraryFromFile(dev, joinpath(@__DIR__, "dummy.metallib"))
lib = MTLLibraryFromFile(dev, joinpath(@__DIR__, "..", "dummy.metallib"))
fun = MTLFunction(lib, "kernel_1")

desc = MTLBinaryArchiveDescriptor()
Expand Down
28 changes: 28 additions & 0 deletions test/mtl/size.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
@testset "size.jl" begin
@testset "size" begin
dim1 = rand(UInt64)
dim2 = rand(UInt64)
dim3 = rand(UInt64)

@test MTL.MTLSize(dim1) == MTL.MTLSize((dim1,))
@test MTL.MTLSize(dim1,dim2) == MTL.MTLSize((dim1,dim2))
@test MTL.MTLSize(dim1,dim2,dim3) == MTL.MTLSize((dim1,dim2,dim3))
end

@testset "origin" begin
dim1 = rand(UInt64)
dim2 = rand(UInt64)
dim3 = rand(UInt64)

orig = MTL.MTLOrigin(dim1,dim2,dim3)
@test orig.x == dim1
@test orig.y == dim2
@test orig.z == dim3
end

@testset "region" begin
reg = MTL.MTLRegion()
@test reg.origin isa MTL.MTLOrigin
@test reg.size isa MTL.MTLSize
end
end

0 comments on commit d14426c

Please sign in to comment.