diff --git a/CMakeLists.txt b/CMakeLists.txt
index f06220661..3a23dc9c6 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -172,6 +172,7 @@ include(HandleLLVMOptions)
 set(IMEX_INCLUDE_TESTS 1 CACHE BOOL "Include targets for IMEX tests")
 set(IMEX_ENABLE_SYCL_RUNTIME 0 CACHE BOOL "Enable the Sycl Runtime")
+set(IMEX_ENABLE_OPENCL_RUNTIME 0 CACHE BOOL "Enable the OpenCL Runtime")
 set(IMEX_ENABLE_L0_RUNTIME 0 CACHE BOOL "Enable the Level Zero Runtime")
 set(IMEX_ENABLE_BENCHMARK 0 CACHE BOOL "Enable the IMEX Benchmark (Depending on SYCL Runtime)")
 # Useful when building IMEX as an LLVM external project.
@@ -192,6 +193,12 @@ else ()
   set(IMEX_ENABLE_SYCL_RUNTIME 0)
 endif()
 
+if (IMEX_ENABLE_OPENCL_RUNTIME)
+  set(IMEX_ENABLE_OPENCL_RUNTIME 1)
+else ()
+  set(IMEX_ENABLE_OPENCL_RUNTIME 0)
+endif()
+
 if (IMEX_ENABLE_L0_RUNTIME)
   set(IMEX_ENABLE_L0_RUNTIME 1)
 else ()
diff --git a/README.md b/README.md
index 4d4855968..fedd4221d 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@
 Intel® Extension for MLIR (IMEX) is a collection of MLIR dialects and passes from Intel for supporting MLIR lowering to Intel silicon (CPU, GPU, …). Goal of this project is to support development of MLIR enhancements for upstream contribution, and to provide a sandbox for validation independent of front end frameworks. Current project scope includes:
 * Dialects and passes needed to lower and execute MLIR entry dialect (linalg, CFG, and etc) on Intel GPU.
-* Wrapper libraries to inteface with level zero runtime and sycl runtime supporting Intel GPU.
+* Wrapper libraries to interface with level zero runtime, sycl runtime, and OpenCL runtime supporting Intel GPU.
 * Other experimental dialects: NDArray, Dist
 
 ## Requirements for building and development
@@ -44,6 +44,13 @@ cmake --build build --target install
 ```
 * Binary package for system-wide install: https://github.com/oneapi-src/level-zero/releases
 
+#### Getting OpenCL headers and libraries (For OpenCL runtime)
+ * Install using OS-provided package (Ubuntu 22.04)
+```sh
+sudo apt install -y intel-opencl-icd opencl-c-headers
+```
+ * Or, download and install package from: https://github.com/intel/compute-runtime/releases
+
 ### Example: Setting up requirements using Conda
 ```sh
 conda create -n imex-dev -c intel -c defaults -c conda-forge pip">=21.2.4" pre-commit cmake clang-format lit doxygen
@@ -79,6 +86,9 @@ cmake -G Ninja -B build -S llvm \
 # For GPU support pass thes cmake variables to enable the required runtime libraries
 # -DIMEX_ENABLE_L0_RUNTIME=1
 # -DIMEX_ENABLE_SYCL_RUNTIME=1
+# -DIMEX_ENABLE_OPENCL_RUNTIME=1
+# Additional if OpenCL library is not found by CMake
+# -DOpenCL_LIBRARY=/PATH_TO/libOpenCL.so.1 ## usually at /usr/lib/x86_64-linux-gnu/libOpenCL.so.1
 # Additional if using a non system wide Level Zero Loader built from source
 # -DLEVEL_ZERO_DIR=/PATH_TO/level-zero-install
@@ -107,6 +117,9 @@ cmake -G Ninja -B build -S . \
 # For GPU support pass thes cmake variables to enable the required runtime libraries
 # -DIMEX_ENABLE_L0_RUNTIME=1
 # -DIMEX_ENABLE_SYCL_RUNTIME=1
+# -DIMEX_ENABLE_OPENCL_RUNTIME=1
+# Additional if OpenCL library is not found by CMake
+# -DOpenCL_LIBRARY=/PATH_TO/libOpenCL.so.1 ## usually at /usr/lib/x86_64-linux-gnu/libOpenCL.so.1
 # Additional if using a non system wide Level Zero Loader built from source
 # -DLEVEL_ZERO_DIR=/PATH_TO/level-zero-install
@@ -127,6 +140,9 @@ cmake -G Ninja -B build -S . \
 # For GPU support pass thes cmake variables to enable the required runtime libraries
 # -DIMEX_ENABLE_L0_RUNTIME=1
 # -DIMEX_ENABLE_SYCL_RUNTIME=1
+# -DIMEX_ENABLE_OPENCL_RUNTIME=1
+# Additional if OpenCL library is not found by CMake
+# -DOpenCL_LIBRARY=/PATH_TO/libOpenCL.so.1 ## usually at /usr/lib/x86_64-linux-gnu/libOpenCL.so.1
 # Additional if using a non system wide Level Zero Loader built from source
 # -DLEVEL_ZERO_DIR=/PATH_TO/level-zero-install
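Putting the README pieces together, a configure invocation that enables the new runtime could look like the sketch below. This is only an illustration: the LLVM/MLIR options are the ones already shown in the configure lines above, and the `-DOpenCL_LIBRARY` override is needed only when CMake cannot locate OpenCL on its own (the path shown is the usual Ubuntu location mentioned in the patch).

```sh
# Sketch: configure IMEX with the OpenCL runtime wrappers enabled.
# LLVM/MLIR-related options are omitted here; use the ones from the
# configure lines in the README above.
cmake -G Ninja -B build -S . \
    -DIMEX_ENABLE_OPENCL_RUNTIME=1 \
    -DOpenCL_LIBRARY=/usr/lib/x86_64-linux-gnu/libOpenCL.so.1
cmake --build build --target install
```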
diff --git a/lib/ExecutionEngine/CMakeLists.txt b/lib/ExecutionEngine/CMakeLists.txt
index c6b51ea40..06d90c4c6 100644
--- a/lib/ExecutionEngine/CMakeLists.txt
+++ b/lib/ExecutionEngine/CMakeLists.txt
@@ -6,6 +6,10 @@ if(IMEX_ENABLE_SYCL_RUNTIME)
   add_subdirectory(SYCLRUNTIME)
 endif()
 
+if(IMEX_ENABLE_OPENCL_RUNTIME)
+  add_subdirectory(OPENCLRUNTIME)
+endif()
+
 add_mlir_library(imex_runner_utils
   SHARED
   ImexRunnerUtils.cpp
diff --git a/lib/ExecutionEngine/OPENCLRUNTIME/CMakeLists.txt b/lib/ExecutionEngine/OPENCLRUNTIME/CMakeLists.txt
new file mode 100644
index 000000000..be672e48c
--- /dev/null
+++ b/lib/ExecutionEngine/OPENCLRUNTIME/CMakeLists.txt
@@ -0,0 +1,40 @@
+# Copyright 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+find_package(OpenCL REQUIRED)
+
+if(NOT OpenCL_FOUND)
+  message(FATAL_ERROR "OpenCL not found.")
+endif()
+
+add_mlir_library(opencl-runtime
+  SHARED
+  OpenCLRuntimeWrappers.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+  )
+
+check_cxx_compiler_flag("-frtti" CXX_HAS_FRTTI_FLAG)
+if(NOT CXX_HAS_FRTTI_FLAG)
+  message(FATAL_ERROR "CXX compiler does not accept flag -frtti")
+endif()
+target_compile_options (opencl-runtime PUBLIC -fexceptions -frtti)
+
+target_include_directories(opencl-runtime PRIVATE
+  ${MLIR_INCLUDE_DIRS}
+  ${OpenCL_INCLUDE_DIRS}
+  )
+
+message(STATUS "OpenCL Libraries: ${OpenCL_LIBRARIES}")
+target_link_libraries(opencl-runtime PRIVATE ${OpenCL_LIBRARIES})
diff --git a/lib/ExecutionEngine/OPENCLRUNTIME/OpenCLRuntimeWrappers.cpp b/lib/ExecutionEngine/OPENCLRUNTIME/OpenCLRuntimeWrappers.cpp
new file mode 100644
index 000000000..5347bbebb
--- /dev/null
+++ b/lib/ExecutionEngine/OPENCLRUNTIME/OpenCLRuntimeWrappers.cpp
@@ -0,0 +1,428 @@
+// Copyright 2024 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <CL/cl_ext.h>
+#include <array>
+#include <atomic>
+#include <cassert>
+#include <cstdio>
+#include <mutex>
+#include <stdexcept>
+#include <vector>
+
+#ifdef _WIN32
+#define OCL_RUNTIME_EXPORT __declspec(dllexport)
+#else
+#define OCL_RUNTIME_EXPORT
+#endif // _WIN32
+
+namespace {
+
+#define CL_SAFE_CALL2(a)                                                      \
+  do {                                                                        \
+    (a);                                                                      \
+    if (err != CL_SUCCESS) {                                                  \
+      fprintf(stderr, "FAIL: err=%d @ line=%d (%s)\n", err, __LINE__, (#a));  \
+      abort();                                                                \
+    }                                                                         \
+  } while (0)
+
+#define CL_SAFE_CALL(call)                                                    \
+  {                                                                           \
+    auto status = (call);                                                     \
+    if (status != CL_SUCCESS) {                                               \
+      fprintf(stderr, "CL error %d @ line=%d (%s)\n", status, __LINE__,       \
+              (#call));                                                       \
+      abort();                                                                \
+    }                                                                         \
+  }
+
+constexpr char DeviceMemAllocName[] = "clDeviceMemAllocINTEL";
+constexpr char SharedMemAllocName[] = "clSharedMemAllocINTEL";
+constexpr char MemBlockingFreeName[] = "clMemBlockingFreeINTEL";
+constexpr char SetKernelArgMemPointerName[] = "clSetKernelArgMemPointerINTEL";
+static constexpr char EnqueueMemcpyName[] = "clEnqueueMemcpyINTEL";
+
+void *queryCLExtFunc(cl_platform_id CurPlatform, const char *FuncName) {
+  void *ret = clGetExtensionFunctionAddressForPlatform(CurPlatform, FuncName);
+
+  if (!ret) {
+    fflush(stderr);
+    abort();
+  }
+  return ret;
+}
+
+void *queryCLExtFunc(cl_device_id dev, const char *FuncName) {
+  cl_platform_id CurPlatform;
+  CL_SAFE_CALL(clGetDeviceInfo(dev, CL_DEVICE_PLATFORM, sizeof(cl_platform_id),
+                               &CurPlatform, nullptr));
+  return queryCLExtFunc(CurPlatform, FuncName);
+}
+
+struct CLExtTable {
+  clDeviceMemAllocINTEL_fn allocDev;
+  clSharedMemAllocINTEL_fn allocShared;
+  clMemBlockingFreeINTEL_fn blockingFree;
+  clSetKernelArgMemPointerINTEL_fn setKernelArgMemPtr;
+  clEnqueueMemcpyINTEL_fn enqueneMemcpy;
+  CLExtTable() = default;
+  CLExtTable(cl_device_id dev) {
+    cl_platform_id plat;
+    CL_SAFE_CALL(clGetDeviceInfo(dev, CL_DEVICE_PLATFORM,
+                                 sizeof(cl_platform_id), &plat, nullptr));
+    allocDev =
+        (clDeviceMemAllocINTEL_fn)queryCLExtFunc(plat, DeviceMemAllocName);
+    allocShared =
+        (clSharedMemAllocINTEL_fn)queryCLExtFunc(plat, SharedMemAllocName);
+    blockingFree =
+        (clMemBlockingFreeINTEL_fn)queryCLExtFunc(plat, MemBlockingFreeName);
+    setKernelArgMemPtr = (clSetKernelArgMemPointerINTEL_fn)queryCLExtFunc(
+        plat, SetKernelArgMemPointerName);
+    enqueneMemcpy =
+        (clEnqueueMemcpyINTEL_fn)queryCLExtFunc(plat, EnqueueMemcpyName);
+  }
+};
+
+// An "almost" lock-free cache mapping a cl_device_id to its CL extension
+// function table. Reading from the table is lock-free; writing to it (on a
+// cache miss) requires taking the lock.
+struct CLExtTableCache {
+  static constexpr int numExtCache = 16;
+  std::array<std::atomic<cl_device_id>, numExtCache> devices;
+  std::array<CLExtTable, numExtCache> tables;
+  std::mutex lock;
+  CLExtTableCache() { std::fill(devices.begin(), devices.end(), nullptr); }
+  static CLExtTableCache &get() {
+    static CLExtTableCache v;
+    return v;
+  }
+  CLExtTable *query(cl_device_id dev) {
+    bool found = false;
+    int firstSearch = search(dev, 0, found);
+    if (found) {
+      return &tables[firstSearch];
+    }
+    if (firstSearch == numExtCache) {
+      return nullptr;
+    }
+    {
+      std::lock_guard<std::mutex> guard{lock};
+      int secondSearch = search(dev, firstSearch, found);
+      if (found) {
+        return &tables[secondSearch];
+      }
+      if (secondSearch == numExtCache) {
+        return nullptr;
+      }
+      tables[secondSearch] = CLExtTable(dev);
+      devices[secondSearch].store(dev, std::memory_order_release);
+      return &tables[secondSearch];
+    }
+  }
+
+private:
+  int search(cl_device_id dev, int startIdx, bool &found) {
+    for (int i = startIdx; i < numExtCache; i++) {
+      auto val = devices[i].load(std::memory_order_acquire);
+      if (!val) {
+        found = false;
+        return i;
+      }
+      if (val == dev) {
+        found = true;
+        return i;
+      }
+    }
+    found = false;
+    return numExtCache;
+  }
+};
+
+} // namespace
+
+struct ParamDesc {
+  void *data;
+  size_t size;
+
+  bool operator==(const ParamDesc &rhs) const {
+    return data == rhs.data && size == rhs.size;
+  }
+
+  bool operator!=(const ParamDesc &rhs) const { return !(*this == rhs); }
+};
+
+template <typename T> size_t countUntil(T *ptr, T &&elem) {
+  assert(ptr);
+  auto curr = ptr;
+  while (*curr != elem) {
+    ++curr;
+  }
+  return static_cast<size_t>(curr - ptr);
+}
+
+static cl_device_id getDevice(cl_device_type *devtype) {
+  cl_platform_id platform; // OpenCL platform
+  cl_device_id device;     // device ID
+  CL_SAFE_CALL(clGetPlatformIDs(1, &platform, NULL));
+  CL_SAFE_CALL(clGetDeviceIDs(platform, *devtype, 1, &device, NULL));
+  return device;
+}
+
+struct GPUCLQUEUE {
+
+  cl_device_id device_ = nullptr;
+  cl_context context_ = nullptr;
+  cl_command_queue queue_ = nullptr;
+  bool context_owned_ = false;
+  bool queue_owned_ = false;
+  CLExtTable *ext_table_ = nullptr;
+  std::vector<cl_program> programs_;
+  std::vector<cl_kernel> kernels_;
+
+  GPUCLQUEUE(cl_device_type *device, cl_context context,
+             cl_command_queue queue) {
+    cl_device_type defaultdev = CL_DEVICE_TYPE_GPU;
+    if (!device) {
+      device = &defaultdev;
+    }
+    device_ = getDevice(device);
+    init_context(context, queue, device_);
+    ext_table_ = CLExtTableCache::get().query(device_);
+  }
+  GPUCLQUEUE(cl_device_id device, cl_context context, cl_command_queue queue) {
+    if (!device) {
+      cl_device_type defaultdev = CL_DEVICE_TYPE_GPU;
+      device = getDevice(&defaultdev);
+    }
+    device_ = device;
+    init_context(context, queue, device_);
+    ext_table_ = CLExtTableCache::get().query(device_);
+  }
+  ~GPUCLQUEUE() {
+    for (auto p : kernels_) {
+      clReleaseKernel(p);
+    }
+    for (auto p : programs_) {
+      clReleaseProgram(p);
+    }
+    if (queue_ && queue_owned_)
+      clReleaseCommandQueue(queue_);
+    if (context_ && context_owned_)
+      clReleaseContext(context_);
+  }
+
+private:
+  void init_context(cl_context context, cl_command_queue queue,
+                    cl_device_id device) {
+    if (queue) {
+      if (!context) {
+        throw std::runtime_error(
+            "Cannot create QUEUE wrapper with queue and without context");
+      }
+      queue_ = queue;
+      queue_owned_ = true;
+      context_ = context;
+      context_owned_ = true;
+      return;
+    }
+    cl_int err;
+    if (!context) {
+      CL_SAFE_CALL2(context_ =
+                        clCreateContext(NULL, 1, &device, NULL, NULL, &err));
+      context_owned_ = true;
+    } else {
+      context_ = context;
+    }
+    CL_SAFE_CALL2(queue_ = clCreateCommandQueueWithProperties(context_, device,
+                                                              0, &err));
+    queue_owned_ = true;
+  }
+}; // end of GPUCLQUEUE
+
+static void *allocDeviceMemory(GPUCLQUEUE *queue, size_t size,
+                               size_t alignment, bool isShared) {
+  void *memPtr = nullptr;
+  cl_int err;
+  if (isShared) {
+    auto func = queue->ext_table_ ? queue->ext_table_->allocShared
+                                  : (clSharedMemAllocINTEL_fn)queryCLExtFunc(
+                                        queue->device_, SharedMemAllocName);
+    CL_SAFE_CALL2(memPtr = func(queue->context_, queue->device_, nullptr, size,
+                                alignment, &err));
+  } else {
+    auto func = queue->ext_table_ ? queue->ext_table_->allocDev
+                                  : (clDeviceMemAllocINTEL_fn)queryCLExtFunc(
+                                        queue->device_, DeviceMemAllocName);
+    CL_SAFE_CALL2(memPtr = func(queue->context_, queue->device_, nullptr, size,
+                                alignment, &err));
+  }
+  return memPtr;
+}
+
+static void deallocDeviceMemory(GPUCLQUEUE *queue, void *ptr) {
+  auto func = queue->ext_table_ ? queue->ext_table_->blockingFree
+                                : (clMemBlockingFreeINTEL_fn)queryCLExtFunc(
+                                      queue->device_, MemBlockingFreeName);
+  CL_SAFE_CALL(func(queue->context_, ptr));
+}
+
+static cl_program loadModule(GPUCLQUEUE *queue, const unsigned char *data,
+                             size_t dataSize) {
+  assert(data);
+  cl_int errNum = 0;
+  const unsigned char *codes[1] = {data};
+  size_t sizes[1] = {dataSize};
+  cl_program program;
+  cl_int err;
+  CL_SAFE_CALL2(program = clCreateProgramWithBinary(queue->context_, 1,
+                                                    &queue->device_, sizes,
+                                                    codes, &err, &errNum));
+  const char *build_flags = "-cl-kernel-arg-info -x spir";
+  // enable large register file if needed
+  if (getenv("IMEX_ENABLE_LARGE_REG_FILE")) {
+    build_flags =
+        "-vc-codegen -doubleGRF -Xfinalizer -noLocalSplit -Xfinalizer "
+        "-DPASTokenReduction -Xfinalizer -SWSBDepReduction -Xfinalizer "
+        "'-printregusage -enableBCR' -cl-kernel-arg-info -x spir";
+  }
+  CL_SAFE_CALL(clBuildProgram(program, 0, NULL, build_flags, NULL, NULL));
+  queue->programs_.push_back(program);
+  return program;
+}
+
+static cl_kernel getKernel(GPUCLQUEUE *queue, cl_program program,
+                           const char *name) {
+  cl_kernel kernel;
+  cl_int err;
+  CL_SAFE_CALL2(kernel = clCreateKernel(program, name, &err));
+  cl_bool TrueVal = CL_TRUE;
+  CL_SAFE_CALL(clSetKernelExecInfo(
+      kernel, CL_KERNEL_EXEC_INFO_INDIRECT_HOST_ACCESS_INTEL, sizeof(cl_bool),
+      &TrueVal));
+  CL_SAFE_CALL(clSetKernelExecInfo(
+      kernel, CL_KERNEL_EXEC_INFO_INDIRECT_DEVICE_ACCESS_INTEL,
+      sizeof(cl_bool), &TrueVal));
+  CL_SAFE_CALL(clSetKernelExecInfo(
+      kernel, CL_KERNEL_EXEC_INFO_INDIRECT_SHARED_ACCESS_INTEL,
+      sizeof(cl_bool), &TrueVal));
+  queue->kernels_.push_back(kernel);
+  return kernel;
+}
+
+static void launchKernel(GPUCLQUEUE *queue, cl_kernel kernel, size_t gridX,
+                         size_t gridY, size_t gridZ, size_t blockX,
+                         size_t blockY, size_t blockZ, size_t sharedMemBytes,
+                         ParamDesc *params) {
+  auto func = queue->ext_table_
+                  ? queue->ext_table_->setKernelArgMemPtr
+                  : (clSetKernelArgMemPointerINTEL_fn)queryCLExtFunc(
+                        queue->device_, SetKernelArgMemPointerName);
+  auto paramsCount = countUntil(params, ParamDesc{nullptr, 0});
+  // The assumption is, if there is a param for the shared local memory,
+  // then that will always be the last argument.
+  if (sharedMemBytes) {
+    paramsCount = paramsCount - 1;
+  }
+  for (size_t i = 0; i < paramsCount; i++) {
+    cl_kernel_arg_address_qualifier name;
+    size_t nameSize = sizeof(name);
+    CL_SAFE_CALL(clGetKernelArgInfo(kernel, i, CL_KERNEL_ARG_ADDRESS_QUALIFIER,
+                                    sizeof(name), &name, &nameSize));
+    auto param = params[i];
+    if (param.size == sizeof(void *) && name == CL_KERNEL_ARG_ADDRESS_GLOBAL) {
+      CL_SAFE_CALL(func(kernel, i, *(void **)param.data));
+    } else {
+      CL_SAFE_CALL(clSetKernelArg(kernel, i, param.size, param.data));
+    }
+  }
+  if (sharedMemBytes) {
+    CL_SAFE_CALL(clSetKernelArg(kernel, paramsCount, sharedMemBytes, nullptr));
+  }
+  size_t globalSize[3] = {gridX * blockX, gridY * blockY, gridZ * blockZ};
+  size_t localSize[3] = {blockX, blockY, blockZ};
+  CL_SAFE_CALL(clEnqueueNDRangeKernel(queue->queue_, kernel, 3, NULL,
+                                      globalSize, localSize, 0, NULL, NULL));
+}
+
+static GPUCLQUEUE *getDefaultQueue() {
+  static GPUCLQUEUE defaultq(static_cast<cl_device_id>(nullptr), nullptr,
+                             nullptr);
+  return &defaultq;
+}
+
+// Wrappers
+
+extern "C" OCL_RUNTIME_EXPORT GPUCLQUEUE *gpuCreateStream(void *device,
+                                                          void *context) {
+  // todo: this is a workaround of issue of gpux generating multiple streams
+  if (!device && !context) {
+    return getDefaultQueue();
+  }
+  return new GPUCLQUEUE(reinterpret_cast<cl_device_id>(device),
+                        reinterpret_cast<cl_context>(context), nullptr);
+}
+
+extern "C" OCL_RUNTIME_EXPORT void gpuStreamDestroy(GPUCLQUEUE *queue) {
+  // todo: this is a workaround of issue of gpux generating multiple streams
+  // should uncomment the below line to release the queue
+  // delete queue;
+}
+
+extern "C" OCL_RUNTIME_EXPORT void *
+gpuMemAlloc(GPUCLQUEUE *queue, size_t size, size_t alignment, bool isShared) {
+  if (queue) {
+    return allocDeviceMemory(queue, size, alignment, isShared);
+  }
+  return nullptr;
+}
+
+extern "C" OCL_RUNTIME_EXPORT void gpuMemFree(GPUCLQUEUE *queue, void *ptr) {
+  if (queue && ptr) {
+    deallocDeviceMemory(queue, ptr);
+  }
+}
+
+extern "C" OCL_RUNTIME_EXPORT cl_program
+gpuModuleLoad(GPUCLQUEUE *queue, const unsigned char *data, size_t dataSize) {
+  if (queue) {
+    return loadModule(queue, data, dataSize);
+  }
+  return nullptr;
+}
+
+extern "C" OCL_RUNTIME_EXPORT cl_kernel gpuKernelGet(GPUCLQUEUE *queue,
+                                                     cl_program module,
+                                                     const char *name) {
+  if (queue) {
+    return getKernel(queue, module, name);
+  }
+  return nullptr;
+}
+
+extern "C" OCL_RUNTIME_EXPORT void
+gpuLaunchKernel(GPUCLQUEUE *queue, cl_kernel kernel, size_t gridX, size_t gridY,
+                size_t gridZ, size_t blockX, size_t blockY, size_t blockZ,
+                size_t sharedMemBytes, void *params) {
+  if (queue) {
+    launchKernel(queue, kernel, gridX, gridY, gridZ, blockX, blockY, blockZ,
+                 sharedMemBytes, static_cast<ParamDesc *>(params));
+  }
+}
+
+extern "C" OCL_RUNTIME_EXPORT void gpuWait(GPUCLQUEUE *queue) {
+  if (queue) {
+    CL_SAFE_CALL(clFinish(queue->queue_));
+  }
+}
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index 6f2c3f48d..bf9572773 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -52,13 +52,19 @@ if(IMEX_ENABLE_SYCL_RUNTIME)
     )
 endif()
 
+if(IMEX_ENABLE_OPENCL_RUNTIME)
+  list(APPEND IMEX_TEST_DEPENDS
+    opencl-runtime
+    )
+endif()
+
 if(IMEX_ENABLE_L0_RUNTIME)
   list(APPEND IMEX_TEST_DEPENDS
     level-zero-runtime
     )
 endif()
 
-if(IMEX_ENABLE_L0_RUNTIME OR IMEX_ENABLE_SYCL_RUNTIME)
+if(IMEX_ENABLE_L0_RUNTIME OR IMEX_ENABLE_SYCL_RUNTIME OR IMEX_ENABLE_OPENCL_RUNTIME)
   list(APPEND IMEX_TEST_DEPENDS
     l0-fp64-checker
     )
diff --git a/test/Conversion/GPUToSPIRV/printf_with_runner.mlir b/test/Conversion/GPUToSPIRV/printf_with_runner.mlir
index
4d40f1def..afe229a64 100644 --- a/test/Conversion/GPUToSPIRV/printf_with_runner.mlir +++ b/test/Conversion/GPUToSPIRV/printf_with_runner.mlir @@ -2,6 +2,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/gpu-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module attributes { gpu.container_module }{ diff --git a/test/Gen/PlaidML/CppEdsl.Convolution.mlir.in b/test/Gen/PlaidML/CppEdsl.Convolution.mlir.in index bdbbdaa2c..36959199c 100644 --- a/test/Gen/PlaidML/CppEdsl.Convolution.mlir.in +++ b/test/Gen/PlaidML/CppEdsl.Convolution.mlir.in @@ -14,6 +14,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%irunner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)> diff --git a/test/Gen/PlaidML/CppEdsl.DotF16_AccF32.mlir.in b/test/Gen/PlaidML/CppEdsl.DotF16_AccF32.mlir.in index a88c02842..5c6e7cc54 100644 --- a/test/Gen/PlaidML/CppEdsl.DotF16_AccF32.mlir.in +++ b/test/Gen/PlaidML/CppEdsl.DotF16_AccF32.mlir.in @@ -14,6 +14,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1, d2) -> (d0, d2)> #map2 = affine_map<(d0, d1, d2) -> (d2, d1)> diff --git a/test/Gen/PlaidML/OpTest.Argmax.mlir.in b/test/Gen/PlaidML/OpTest.Argmax.mlir.in index 2acb30705..2e6581e1a 100644 --- a/test/Gen/PlaidML/OpTest.Argmax.mlir.in +++ b/test/Gen/PlaidML/OpTest.Argmax.mlir.in @@ -13,6 +13,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> #map1 = affine_map<(d0, d1, d2, d3) -> ()> #map2 = affine_map<() -> ()> diff --git a/test/Gen/PlaidML/OpTest.BroadcastNonNumpy.mlir.in b/test/Gen/PlaidML/OpTest.BroadcastNonNumpy.mlir.in index cf599405f..609fe0cc7 100644 --- a/test/Gen/PlaidML/OpTest.BroadcastNonNumpy.mlir.in +++ b/test/Gen/PlaidML/OpTest.BroadcastNonNumpy.mlir.in @@ -14,6 
+14,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d0)> #map1 = affine_map<(d0, d1) -> (d0, d1)> module @broadcast_non_numpy { diff --git a/test/Gen/PlaidML/OpTest.ComplexConv2D.mlir.in b/test/Gen/PlaidML/OpTest.ComplexConv2D.mlir.in index c803be274..adf4f6774 100644 --- a/test/Gen/PlaidML/OpTest.ComplexConv2D.mlir.in +++ b/test/Gen/PlaidML/OpTest.ComplexConv2D.mlir.in @@ -14,6 +14,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%irunner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 * 2 + d5 * 3, d2 * 2 + d6 * 3, d3, d7)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d3, d7, d4)> diff --git a/test/Gen/PlaidML/OpTest.EltwiseAdd.mlir.in b/test/Gen/PlaidML/OpTest.EltwiseAdd.mlir.in index 4cc12391c..d9a276e69 100644 --- a/test/Gen/PlaidML/OpTest.EltwiseAdd.mlir.in +++ b/test/Gen/PlaidML/OpTest.EltwiseAdd.mlir.in @@ -14,6 +14,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%opencl_runtime --filecheck #map = affine_map<(d0, d1) -> (d0, d1)> module @eltwise_add { func.func @test(%arg0: tensor<10x20x@DTYPE@>, %arg1: tensor<10x20x@DTYPE@>) -> tensor<10x20x@DTYPE@> { diff --git a/test/Gen/PlaidML/OpTest.EltwiseAddInt.mlir.in b/test/Gen/PlaidML/OpTest.EltwiseAddInt.mlir.in index c5185d2c8..3bad5cd13 100644 --- a/test/Gen/PlaidML/OpTest.EltwiseAddInt.mlir.in +++ b/test/Gen/PlaidML/OpTest.EltwiseAddInt.mlir.in @@ -13,6 +13,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%opencl_runtime --filecheck #map = affine_map<(d0, d1) -> (d0, d1)> module @eltwise_add { func.func @test(%arg0: tensor<10x20x@DTYPE@>, %arg1: tensor<10x20x@DTYPE@>) -> tensor<10x20x@DTYPE@> { diff --git a/test/Gen/PlaidML/OpTest.ExplicitPadding.mlir.in b/test/Gen/PlaidML/OpTest.ExplicitPadding.mlir.in index 36f503b8e..65e9863b9 100644 --- a/test/Gen/PlaidML/OpTest.ExplicitPadding.mlir.in +++ b/test/Gen/PlaidML/OpTest.ExplicitPadding.mlir.in @@ -14,6 +14,10 @@ // 
RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1) -> (d0 + 2, d1 + 1)> module @explicit_padding { diff --git a/test/Gen/PlaidML/OpTest.LogicalAnd_mixed.mlir.in b/test/Gen/PlaidML/OpTest.LogicalAnd_mixed.mlir.in index 4f2f50cad..c3c6773ed 100644 --- a/test/Gen/PlaidML/OpTest.LogicalAnd_mixed.mlir.in +++ b/test/Gen/PlaidML/OpTest.LogicalAnd_mixed.mlir.in @@ -12,6 +12,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%opencl_runtime --filecheck #map = affine_map<(d0, d1) -> (d0, d1)> module @logical_and { func.func @main() { diff --git a/test/Gen/PlaidML/OpTest.MaxPool1D.mlir.in b/test/Gen/PlaidML/OpTest.MaxPool1D.mlir.in index 028678447..dafcc7166 100644 --- a/test/Gen/PlaidML/OpTest.MaxPool1D.mlir.in +++ b/test/Gen/PlaidML/OpTest.MaxPool1D.mlir.in @@ -13,6 +13,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d1)> #map1 = affine_map<(d0, d1) -> (d0)> module @max_pool_1d { diff --git a/test/Gen/PlaidML/OpTest.Quantize.mlir.in b/test/Gen/PlaidML/OpTest.Quantize.mlir.in index 9efadc565..535d3d8fb 100644 --- a/test/Gen/PlaidML/OpTest.Quantize.mlir.in +++ b/test/Gen/PlaidML/OpTest.Quantize.mlir.in @@ -12,6 +12,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0) -> (d0)> #map1 = affine_map<(d0) -> ()> module @quantize { diff --git a/test/Gen/PlaidML/OpTest.Relu.mlir.in b/test/Gen/PlaidML/OpTest.Relu.mlir.in index f7dbe99fe..d1dd5dc9b 100644 --- a/test/Gen/PlaidML/OpTest.Relu.mlir.in +++ b/test/Gen/PlaidML/OpTest.Relu.mlir.in @@ -14,6 +14,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: 
--shared-libs=%mlir_runner_utils,%irunner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1) -> ()> module @relu { diff --git a/test/Gen/PlaidML/OpTest.Softmax.mlir.in b/test/Gen/PlaidML/OpTest.Softmax.mlir.in index cf10324b3..857221057 100644 --- a/test/Gen/PlaidML/OpTest.Softmax.mlir.in +++ b/test/Gen/PlaidML/OpTest.Softmax.mlir.in @@ -15,6 +15,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1) -> (d0, 0)> module @softmax { diff --git a/test/Gen/PlaidML/OpTest.Sum.mlir.in b/test/Gen/PlaidML/OpTest.Sum.mlir.in index 3c6ff8c28..cd913ac25 100644 --- a/test/Gen/PlaidML/OpTest.Sum.mlir.in +++ b/test/Gen/PlaidML/OpTest.Sum.mlir.in @@ -14,6 +14,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1) -> ()> module @sum { diff --git a/test/Gen/PlaidML/OpTest.Transpose.mlir.in b/test/Gen/PlaidML/OpTest.Transpose.mlir.in index 6bc65e224..ee6844879 100644 --- a/test/Gen/PlaidML/OpTest.Transpose.mlir.in +++ b/test/Gen/PlaidML/OpTest.Transpose.mlir.in @@ -14,6 +14,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d1, d0)> #map1 = affine_map<(d0, d1) -> (d0, d1)> module @transpose { diff --git a/test/Integration/Dialect/Func/gemm-with-func-and-intel-intrinsic.mlir b/test/Integration/Dialect/Func/gemm-with-func-and-intel-intrinsic.mlir index cd2f0e2dd..fd6771189 100644 --- a/test/Integration/Dialect/Func/gemm-with-func-and-intel-intrinsic.mlir +++ b/test/Integration/Dialect/Func/gemm-with-func-and-intel-intrinsic.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/func-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module, spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { diff --git 
a/test/Integration/Dialect/Func/load2d_dpas_store2d_with_intrinsic.mlir b/test/Integration/Dialect/Func/load2d_dpas_store2d_with_intrinsic.mlir index 6c79f0a48..7ab954a5a 100644 --- a/test/Integration/Dialect/Func/load2d_dpas_store2d_with_intrinsic.mlir +++ b/test/Integration/Dialect/Func/load2d_dpas_store2d_with_intrinsic.mlir @@ -7,6 +7,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/func-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module} { memref.global "private" @__constant_8x16xf16 : memref<8x16xf16> = dense<1.000000e+00> diff --git a/test/Integration/Dialect/XeGPU/dynamic_memref.vc.mlir b/test/Integration/Dialect/XeGPU/dynamic_memref.vc.mlir index a70a08c76..d4fa44a85 100644 --- a/test/Integration/Dialect/XeGPU/dynamic_memref.vc.mlir +++ b/test/Integration/Dialect/XeGPU/dynamic_memref.vc.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module} { func.func @test(%A : memref<8x16xf16>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index diff --git a/test/Integration/Dialect/XeGPU/exp_f32.vc.mlir b/test/Integration/Dialect/XeGPU/exp_f32.vc.mlir index 45036fbbf..de6f2dd53 100644 --- a/test/Integration/Dialect/XeGPU/exp_f32.vc.mlir +++ b/test/Integration/Dialect/XeGPU/exp_f32.vc.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module} { func.func @test(%A: memref<8x16xf16>, %B: memref<16x16xf16> ) -> (memref<8x16xf32>, memref<8x16xf32>) attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index diff --git a/test/Integration/Dialect/XeGPU/flash_attention_fwd.mlir b/test/Integration/Dialect/XeGPU/flash_attention_fwd.mlir index 65ca640e8..ce0f77700 100644 --- a/test/Integration/Dialect/XeGPU/flash_attention_fwd.mlir +++ b/test/Integration/Dialect/XeGPU/flash_attention_fwd.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: IMEX_ENABLE_LARGE_REG_FILE=1 %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner 
imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @flash_attention attributes {gpu.container_module} { gpu.module @flash_attention_fwd attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { gpu.func @flash_attention_fwd( diff --git a/test/Integration/Dialect/XeGPU/fmax_f32.vc.mlir b/test/Integration/Dialect/XeGPU/fmax_f32.vc.mlir index b2bb6829a..8e4b21a62 100644 --- a/test/Integration/Dialect/XeGPU/fmax_f32.vc.mlir +++ b/test/Integration/Dialect/XeGPU/fmax_f32.vc.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module} { func.func @test(%A: memref<8x32xf16>, %B: memref<16x32xf16> ) -> memref<8x16xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index diff --git a/test/Integration/Dialect/XeGPU/gemm_1024x1016x1016_f16_f16_f32.mlir b/test/Integration/Dialect/XeGPU/gemm_1024x1016x1016_f16_f16_f32.mlir index 28f2a3210..13084ed15 100644 --- a/test/Integration/Dialect/XeGPU/gemm_1024x1016x1016_f16_f16_f32.mlir +++ b/test/Integration/Dialect/XeGPU/gemm_1024x1016x1016_f16_f16_f32.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module} { memref.global "private" @__constant_1024x1016xf16 : memref<1024x1016xf16> = dense<1.0> memref.global "private" @__constant_1016x1016xf16_ : memref<1016x1016xf16> = dense<1.0> diff --git a/test/Integration/Dialect/XeGPU/gemm_1024x1024xbf16.mlir b/test/Integration/Dialect/XeGPU/gemm_1024x1024xbf16.mlir index c96bcc20d..0f8e710b2 100644 --- a/test/Integration/Dialect/XeGPU/gemm_1024x1024xbf16.mlir +++ b/test/Integration/Dialect/XeGPU/gemm_1024x1024xbf16.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module} { memref.global "private" @__constant_1024x1024xbf16 : memref<1024x1024xbf16> = dense<0.0> memref.global "private" @__constant_1024x1024xbf16_ : memref<1024x1024xbf16> = dense<0.0> diff --git a/test/Integration/Dialect/XeGPU/gemm_1024x1024xf16.mlir 
b/test/Integration/Dialect/XeGPU/gemm_1024x1024xf16.mlir index f056d75e0..30644a4d7 100644 --- a/test/Integration/Dialect/XeGPU/gemm_1024x1024xf16.mlir +++ b/test/Integration/Dialect/XeGPU/gemm_1024x1024xf16.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module} { memref.global "private" @__constant_1024x1024xf16 : memref<1024x1024xf16> = dense<0.0> memref.global "private" @__constant_1024x1024xf16_ : memref<1024x1024xf16> = dense<0.0> diff --git a/test/Integration/Dialect/XeGPU/gemm_1024x1024xf16.using.updateoffset.mlir b/test/Integration/Dialect/XeGPU/gemm_1024x1024xf16.using.updateoffset.mlir index 0807fa5de..a3ec8bf93 100644 --- a/test/Integration/Dialect/XeGPU/gemm_1024x1024xf16.using.updateoffset.mlir +++ b/test/Integration/Dialect/XeGPU/gemm_1024x1024xf16.using.updateoffset.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module} { memref.global "private" @__constant_1024x1024xf16 : memref<1024x1024xf16> = dense<0.0> memref.global "private" @__constant_1024x1024xf16_ : memref<1024x1024xf16> = dense<0.0> diff --git a/test/Integration/Dialect/XeGPU/gemm_4kx4kx4k_dpas_sized_loads_f16_f16_f32.mlir b/test/Integration/Dialect/XeGPU/gemm_4kx4kx4k_dpas_sized_loads_f16_f16_f32.mlir index e0ee10a29..bcc2e5307 100644 --- a/test/Integration/Dialect/XeGPU/gemm_4kx4kx4k_dpas_sized_loads_f16_f16_f32.mlir +++ b/test/Integration/Dialect/XeGPU/gemm_4kx4kx4k_dpas_sized_loads_f16_f16_f32.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: IMEX_ENABLE_LARGE_REG_FILE=1 %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module} { func.func @test(%A: memref<4096x4096xf16>, %B: memref<4096x4096xf16>, %C: memref<4096x4096xf32>) -> memref<4096x4096xf32> attributes {llvm.emit_c_interface} { diff --git a/test/Integration/Dialect/XeGPU/gemm_4kx4kx4k_f16_f16_f16.mlir b/test/Integration/Dialect/XeGPU/gemm_4kx4kx4k_f16_f16_f16.mlir index ae7954564..800b1a952 100644 --- a/test/Integration/Dialect/XeGPU/gemm_4kx4kx4k_f16_f16_f16.mlir +++ b/test/Integration/Dialect/XeGPU/gemm_4kx4kx4k_f16_f16_f16.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // 
RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: IMEX_ENABLE_LARGE_REG_FILE=1 %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module} { func.func @test(%A: memref<4096x4096xf16>, %B: memref<4096x4096xf16>, %C: memref<4096x4096xf16>) -> memref<4096x4096xf16> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index diff --git a/test/Integration/Dialect/XeGPU/gemm_4kx4kx4k_f16_f16_f16_w_8x32xf16_stores.mlir b/test/Integration/Dialect/XeGPU/gemm_4kx4kx4k_f16_f16_f16_w_8x32xf16_stores.mlir index a65503775..dd917c452 100644 --- a/test/Integration/Dialect/XeGPU/gemm_4kx4kx4k_f16_f16_f16_w_8x32xf16_stores.mlir +++ b/test/Integration/Dialect/XeGPU/gemm_4kx4kx4k_f16_f16_f16_w_8x32xf16_stores.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: IMEX_ENABLE_LARGE_REG_FILE=1 %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module} { func.func @test(%A: memref<4096x4096xf16>, %B: memref<4096x4096xf16>, %C: memref<4096x4096xf16>) -> memref<4096x4096xf16> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index diff --git a/test/Integration/Dialect/XeGPU/gemm_4kx4kx4k_f16_f16_f16_w_simple_B_prefetch.mlir b/test/Integration/Dialect/XeGPU/gemm_4kx4kx4k_f16_f16_f16_w_simple_B_prefetch.mlir index b65ed81c3..e2d49e4e0 100644 --- a/test/Integration/Dialect/XeGPU/gemm_4kx4kx4k_f16_f16_f16_w_simple_B_prefetch.mlir +++ b/test/Integration/Dialect/XeGPU/gemm_4kx4kx4k_f16_f16_f16_w_simple_B_prefetch.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: IMEX_ENABLE_LARGE_REG_FILE=1 %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module} { func.func @test(%A: memref<4096x4096xf16>, %B: memref<4096x4096xf16>, %C: memref<4096x4096xf16>) -> memref<4096x4096xf16> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index diff --git a/test/Integration/Dialect/XeGPU/gemm_SIMT_1024x1024x1024xf16_f16_f32.mlir b/test/Integration/Dialect/XeGPU/gemm_SIMT_1024x1024x1024xf16_f16_f32.mlir index 53f036197..43051899d 100644 --- a/test/Integration/Dialect/XeGPU/gemm_SIMT_1024x1024x1024xf16_f16_f32.mlir +++ b/test/Integration/Dialect/XeGPU/gemm_SIMT_1024x1024x1024xf16_f16_f32.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: 
%python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-llvm-joint-matrix.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck // NOTE: This test case provides an end-to-end example of XeGPU SIMT mode ops to SPIR-V JointMatrix ops lowering diff --git a/test/Integration/Dialect/XeGPU/gemm_with_transposed_B_1kx1kx1k_f16_f16_f32.mlir b/test/Integration/Dialect/XeGPU/gemm_with_transposed_B_1kx1kx1k_f16_f16_f32.mlir index c0d5efb5f..422e12de7 100644 --- a/test/Integration/Dialect/XeGPU/gemm_with_transposed_B_1kx1kx1k_f16_f16_f32.mlir +++ b/test/Integration/Dialect/XeGPU/gemm_with_transposed_B_1kx1kx1k_f16_f16_f32.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module} { memref.global "private" @__constant_1024x1024xf16 : memref<1024x1024xf16> = dense<0.0> memref.global "private" @__constant_1024x1024xf16_ : memref<1024x1024xf16> = dense<0.0> diff --git a/test/Integration/Dialect/XeGPU/large_stores_8x32xf16_w_1d_vector_shuffle.vc.mlir b/test/Integration/Dialect/XeGPU/large_stores_8x32xf16_w_1d_vector_shuffle.vc.mlir index 4fc95bd4d..6d411852f 100644 --- a/test/Integration/Dialect/XeGPU/large_stores_8x32xf16_w_1d_vector_shuffle.vc.mlir +++ b/test/Integration/Dialect/XeGPU/large_stores_8x32xf16_w_1d_vector_shuffle.vc.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module} { func.func @test(%arg0: memref<8x32xf16>) -> memref<8x32xf16> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index diff --git a/test/Integration/Dialect/XeGPU/large_stores_8x32xf16_w_2d_vector_shuffle.vc.mlir b/test/Integration/Dialect/XeGPU/large_stores_8x32xf16_w_2d_vector_shuffle.vc.mlir index 9f5639279..1b4377492 100644 --- a/test/Integration/Dialect/XeGPU/large_stores_8x32xf16_w_2d_vector_shuffle.vc.mlir +++ b/test/Integration/Dialect/XeGPU/large_stores_8x32xf16_w_2d_vector_shuffle.vc.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module} { func.func 
@test(%arg0: memref<8x32xf16>) -> memref<8x32xf16> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index diff --git a/test/Integration/Dialect/XeGPU/large_stores_8x32xf16_w_constant_vector_shuffle.vc.mlir b/test/Integration/Dialect/XeGPU/large_stores_8x32xf16_w_constant_vector_shuffle.vc.mlir index fe0cff169..72ccf1a18 100644 --- a/test/Integration/Dialect/XeGPU/large_stores_8x32xf16_w_constant_vector_shuffle.vc.mlir +++ b/test/Integration/Dialect/XeGPU/large_stores_8x32xf16_w_constant_vector_shuffle.vc.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module} { func.func @test() -> memref<8x32xf16> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index diff --git a/test/Integration/Dialect/XeGPU/load2d-padding-f32.mlir b/test/Integration/Dialect/XeGPU/load2d-padding-f32.mlir index ef6db8b5e..1c2bbd7f1 100644 --- a/test/Integration/Dialect/XeGPU/load2d-padding-f32.mlir +++ b/test/Integration/Dialect/XeGPU/load2d-padding-f32.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module} { memref.global "private" constant @__constant_8x16xf32 : memref<8x16xf32> = dense<1.0> func.func @test(%arg0: memref<8x16xf32>,%arg1:index)attributes {llvm.emit_c_interface} { diff --git a/test/Integration/Dialect/XeGPU/load2d-padding.mlir b/test/Integration/Dialect/XeGPU/load2d-padding.mlir index 22a97f496..84a081a06 100644 --- a/test/Integration/Dialect/XeGPU/load2d-padding.mlir +++ b/test/Integration/Dialect/XeGPU/load2d-padding.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module} { // memref.global "private" constant @__constant_8x16xf16 : memref<8x16xf16> = dense<1.0> memref.global "private" constant @__constant_8x16xf16 : memref<8x16xf16> = dense<1.0> diff --git a/test/Integration/Dialect/XeGPU/load2d_dpas_store2d.mlir b/test/Integration/Dialect/XeGPU/load2d_dpas_store2d.mlir index 98de82596..bdcf4952a 100644 --- a/test/Integration/Dialect/XeGPU/load2d_dpas_store2d.mlir +++ b/test/Integration/Dialect/XeGPU/load2d_dpas_store2d.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: 
--entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module} { memref.global "private" @__constant_8x16xf16 : memref<8x16xf16> = dense<1.0> memref.global "private" @__constant_16x16xf16 : memref<16x16xf16> = dense<1.0> diff --git a/test/Integration/Dialect/XeGPU/load_with_block_array_16_16_2.vc.mlir b/test/Integration/Dialect/XeGPU/load_with_block_array_16_16_2.vc.mlir index ad318ecd3..3791e330b 100644 --- a/test/Integration/Dialect/XeGPU/load_with_block_array_16_16_2.vc.mlir +++ b/test/Integration/Dialect/XeGPU/load_with_block_array_16_16_2.vc.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module} { memref.global "private" constant @__constant_16x32xf16 : memref<16x32xf16> = dense<5.000000e-01> func.func @test(%arg0: memref<16x32xf16>) -> memref<16x32xf32> attributes {llvm.emit_c_interface} { diff --git a/test/Integration/Dialect/XeGPU/load_with_block_array_32_16_2.vc.mlir b/test/Integration/Dialect/XeGPU/load_with_block_array_32_16_2.vc.mlir index b672f457c..517f1a261 100644 --- a/test/Integration/Dialect/XeGPU/load_with_block_array_32_16_2.vc.mlir +++ b/test/Integration/Dialect/XeGPU/load_with_block_array_32_16_2.vc.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module} { func.func @test(%arg0: memref<32x32xf16>) -> memref<32x32xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index diff --git a/test/Integration/Dialect/XeGPU/load_with_block_array_8_16_2.vc.mlir b/test/Integration/Dialect/XeGPU/load_with_block_array_8_16_2.vc.mlir index 72a06b2e6..9cfb73fa0 100644 --- a/test/Integration/Dialect/XeGPU/load_with_block_array_8_16_2.vc.mlir +++ b/test/Integration/Dialect/XeGPU/load_with_block_array_8_16_2.vc.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: 
--shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module} { memref.global "private" constant @__constant_8x32xf16 : memref<8x32xf16> = dense<5.000000e-01> func.func @test(%arg0: memref<8x32xf16>) -> memref<8x32xf32> attributes {llvm.emit_c_interface} { diff --git a/test/Integration/Dialect/XeGPU/preop_dpas.mlir b/test/Integration/Dialect/XeGPU/preop_dpas.mlir index c3bef77ca..199fdbcf7 100644 --- a/test/Integration/Dialect/XeGPU/preop_dpas.mlir +++ b/test/Integration/Dialect/XeGPU/preop_dpas.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module} { memref.global "private" @__constant_8x16xf32 : memref<8x16xf32> = dense<0.0> func.func @test(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} { diff --git a/test/Integration/Dialect/XeGPU/transpose_8x16xf16.mlir b/test/Integration/Dialect/XeGPU/transpose_8x16xf16.mlir index 67b5ddbb7..ac5cdb0e0 100644 --- a/test/Integration/Dialect/XeGPU/transpose_8x16xf16.mlir +++ b/test/Integration/Dialect/XeGPU/transpose_8x16xf16.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @transpose attributes {gpu.container_module} { func.func @test(%arg0: memref<32x32xf16>) -> memref<32x32xf16> attributes {llvm.emit_c_interface} { %c4 = arith.constant 4 : index diff --git a/test/Integration/Dialect/XeGPU/transpose_8x16xf32.mlir b/test/Integration/Dialect/XeGPU/transpose_8x16xf32.mlir index ee5b0a2d8..1afda09ca 100644 --- a/test/Integration/Dialect/XeGPU/transpose_8x16xf32.mlir +++ b/test/Integration/Dialect/XeGPU/transpose_8x16xf32.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @transpose attributes {gpu.container_module} { func.func @test(%arg0: memref<32x32xf32>) -> memref<32x32xf32> attributes {llvm.emit_c_interface} { %c4 = arith.constant 4 : index diff --git a/test/Integration/Dialect/XeGPU/transpose_8x8xf16.mlir b/test/Integration/Dialect/XeGPU/transpose_8x8xf16.mlir index 53786b92e..9aece81ad 100644 --- a/test/Integration/Dialect/XeGPU/transpose_8x8xf16.mlir +++ 
b/test/Integration/Dialect/XeGPU/transpose_8x8xf16.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @transpose attributes {gpu.container_module} { func.func @test(%arg0: memref<32x32xf16>) -> memref<32x32xf16> attributes {llvm.emit_c_interface} { %c4 = arith.constant 4 : index diff --git a/test/Integration/Dialect/XeGPU/transpose_8x8xf32.mlir b/test/Integration/Dialect/XeGPU/transpose_8x8xf32.mlir index 51579ad83..da20916b5 100644 --- a/test/Integration/Dialect/XeGPU/transpose_8x8xf32.mlir +++ b/test/Integration/Dialect/XeGPU/transpose_8x8xf32.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @transpose attributes {gpu.container_module} { func.func @test(%arg0: memref<32x32xf32>) -> memref<32x32xf32> attributes {llvm.emit_c_interface} { %c4 = arith.constant 4 : index diff --git a/test/Integration/Dialect/XeGPU/vector_broadcast_1.mlir b/test/Integration/Dialect/XeGPU/vector_broadcast_1.mlir index 4cbfab6bc..6263c7bf3 100644 --- a/test/Integration/Dialect/XeGPU/vector_broadcast_1.mlir +++ b/test/Integration/Dialect/XeGPU/vector_broadcast_1.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module} { func.func @test(%A: memref<8x16xf16>, %B: memref<16x16xf16>, %bcast : memref<1x32xf16> ) -> memref<8x16xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index diff --git a/test/Integration/Dialect/XeGPU/vector_broadcast_2.mlir b/test/Integration/Dialect/XeGPU/vector_broadcast_2.mlir index e523f0c9e..dfb33796d 100644 --- a/test/Integration/Dialect/XeGPU/vector_broadcast_2.mlir +++ b/test/Integration/Dialect/XeGPU/vector_broadcast_2.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes 
{gpu.container_module} { func.func @test(%A: memref<8x16xf16>, %B: memref<16x16xf16>, %bcast : memref<1x32xf16> ) -> memref<8x16xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index diff --git a/test/Integration/Dialect/XeGPU/vector_extract_strided_slice_1.vc.mlir b/test/Integration/Dialect/XeGPU/vector_extract_strided_slice_1.vc.mlir index 53494e0cf..774bda3f2 100644 --- a/test/Integration/Dialect/XeGPU/vector_extract_strided_slice_1.vc.mlir +++ b/test/Integration/Dialect/XeGPU/vector_extract_strided_slice_1.vc.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module} { func.func @test(%A: memref<8x16xf16>, %B: memref<16x16xf16> ) -> memref<8x16xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index diff --git a/test/Integration/Dialect/XeGPU/vector_extract_strided_slice_2.vc.mlir b/test/Integration/Dialect/XeGPU/vector_extract_strided_slice_2.vc.mlir index 5803b748c..29b769288 100644 --- a/test/Integration/Dialect/XeGPU/vector_extract_strided_slice_2.vc.mlir +++ b/test/Integration/Dialect/XeGPU/vector_extract_strided_slice_2.vc.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module} { func.func @test(%A: memref<32x16xf32>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index diff --git a/test/Integration/Dialect/XeGPU/vector_insert_1.mlir b/test/Integration/Dialect/XeGPU/vector_insert_1.mlir index 53b57afc8..966a9948b 100644 --- a/test/Integration/Dialect/XeGPU/vector_insert_1.mlir +++ b/test/Integration/Dialect/XeGPU/vector_insert_1.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module} { func.func @test(%A: memref<8x16xf16> ) -> memref<8x16xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index diff --git a/test/Integration/Dialect/XeGPU/vector_insert_2.mlir b/test/Integration/Dialect/XeGPU/vector_insert_2.mlir index bcca52e98..1f109159f 100644 --- a/test/Integration/Dialect/XeGPU/vector_insert_2.mlir +++ b/test/Integration/Dialect/XeGPU/vector_insert_2.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // 
RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module} { func.func @test(%A: memref<8x16xf16> ) -> memref<8x16xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index diff --git a/test/Integration/Dialect/XeGPU/xegpu-to-vc.mlir b/test/Integration/Dialect/XeGPU/xegpu-to-vc.mlir index fe13bcc47..4d45db375 100644 --- a/test/Integration/Dialect/XeGPU/xegpu-to-vc.mlir +++ b/test/Integration/Dialect/XeGPU/xegpu-to-vc.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module, spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { memref.global "private" constant @__constant_8x16xf16 : memref<8x16xf16> = dense<5.000000e-01> diff --git a/test/Integration/Dialect/XeTile/block_broadcast_dim_0_fp32.mlir b/test/Integration/Dialect/XeTile/block_broadcast_dim_0_fp32.mlir index 5ee8c44d5..a23998c33 100644 --- a/test/Integration/Dialect/XeTile/block_broadcast_dim_0_fp32.mlir +++ b/test/Integration/Dialect/XeTile/block_broadcast_dim_0_fp32.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @softmax attributes {gpu.container_module} { func.func @broadcast_test() -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index diff --git a/test/Integration/Dialect/XeTile/block_broadcast_dim_1_fp32.mlir b/test/Integration/Dialect/XeTile/block_broadcast_dim_1_fp32.mlir index c7491fc0d..ccd69ae72 100644 --- a/test/Integration/Dialect/XeTile/block_broadcast_dim_1_fp32.mlir +++ b/test/Integration/Dialect/XeTile/block_broadcast_dim_1_fp32.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @softmax attributes {gpu.container_module} { func.func @broadcast_test() -> memref<1024x1024xf32> attributes 
{llvm.emit_c_interface} { %c1 = arith.constant 1 : index diff --git a/test/Integration/Dialect/XeTile/block_reduce_dim_0_fp32.mlir b/test/Integration/Dialect/XeTile/block_reduce_dim_0_fp32.mlir index 0d5076100..d93008947 100644 --- a/test/Integration/Dialect/XeTile/block_reduce_dim_0_fp32.mlir +++ b/test/Integration/Dialect/XeTile/block_reduce_dim_0_fp32.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @softmax attributes {gpu.container_module} { func.func @reduce_test(%a: memref<16x1024xf32>) -> memref<1x1024xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index diff --git a/test/Integration/Dialect/XeTile/block_reduce_dim_1_fp32.mlir b/test/Integration/Dialect/XeTile/block_reduce_dim_1_fp32.mlir index 3183930d8..72b437a46 100644 --- a/test/Integration/Dialect/XeTile/block_reduce_dim_1_fp32.mlir +++ b/test/Integration/Dialect/XeTile/block_reduce_dim_1_fp32.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @softmax attributes {gpu.container_module} { func.func @reduce_test(%a: memref<1024x32xf32>) -> memref<1x1024xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index diff --git a/test/Integration/Dialect/XeTile/block_softmax_dim_0_fp32.mlir b/test/Integration/Dialect/XeTile/block_softmax_dim_0_fp32.mlir index 6613172d3..f807c31ff 100644 --- a/test/Integration/Dialect/XeTile/block_softmax_dim_0_fp32.mlir +++ b/test/Integration/Dialect/XeTile/block_softmax_dim_0_fp32.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @softmax attributes {gpu.container_module} { func.func @block_softmax_test(%a: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index diff --git a/test/Integration/Dialect/XeTile/block_softmax_dim_1_fp32.mlir b/test/Integration/Dialect/XeTile/block_softmax_dim_1_fp32.mlir index 317d42576..2667eac86 100644 --- a/test/Integration/Dialect/XeTile/block_softmax_dim_1_fp32.mlir +++ b/test/Integration/Dialect/XeTile/block_softmax_dim_1_fp32.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: 
--shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @block_softmax attributes {gpu.container_module} { func.func @block_softmax_test(%a: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index diff --git a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_bf16_bf16_f32.mlir b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_bf16_bf16_f32.mlir index 1c26bab1f..a182b126c 100644 --- a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_bf16_bf16_f32.mlir +++ b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_bf16_bf16_f32.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck // NOTES : // This example assumes one subgroup per one workgroup and the kernel specifies the computation diff --git a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f16_blk_16x32x32.mlir b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f16_blk_16x32x32.mlir index 5ef63af49..af115d8a9 100644 --- a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f16_blk_16x32x32.mlir +++ b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f16_blk_16x32x32.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck // NOTES : // This example assumes one subgroup per one workgroup and the kernel specifies the computation diff --git a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32.mlir b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32.mlir index a3bd39dd1..0f03f17e4 100644 --- a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32.mlir +++ b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck // NOTES : // This example assumes one subgroup per one workgroup and the kernel specifies the computation diff --git 
a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32_transpose_a.mlir b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32_transpose_a.mlir index 6622e4223..cef16d1d6 100644 --- a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32_transpose_a.mlir +++ b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32_transpose_a.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck // NOTES : // This example assumes one subgroup per one workgroup and the kernel specifies the computation diff --git a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32_transpose_b.mlir b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32_transpose_b.mlir index 14a52dbb1..e796e40fd 100644 --- a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32_transpose_b.mlir +++ b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32_transpose_b.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck // NOTES : // This example assumes one subgroup per one workgroup and the kernel specifies the computation diff --git a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f32_f32_f32_with_truncf_a_b.mlir b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f32_f32_f32_with_truncf_a_b.mlir index 21d9e4664..bfb264cc4 100644 --- a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f32_f32_f32_with_truncf_a_b.mlir +++ b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f32_f32_f32_with_truncf_a_b.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck // NOTES : // This example assumes one subgroup per one workgroup and the kernel specifies the computation diff --git a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_preop_postop_bf16_bf16_f32.mlir b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_preop_postop_bf16_bf16_f32.mlir index eb0406c55..e238b8c1f 100644 --- a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_preop_postop_bf16_bf16_f32.mlir +++ b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_preop_postop_bf16_bf16_f32.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// 
RUN: IMEX_ENABLE_LARGE_REG_FILE=1 %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck // NOTES : // This example assumes one subgroup per one workgroup and the kernel specifies the computation diff --git a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_transposed_b_preop_postop_bf16_bf16_f32.mlir b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_transposed_b_preop_postop_bf16_bf16_f32.mlir index 0b20ac34a..ca0e57eff 100644 --- a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_transposed_b_preop_postop_bf16_bf16_f32.mlir +++ b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_transposed_b_preop_postop_bf16_bf16_f32.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: IMEX_ENABLE_LARGE_REG_FILE=1 %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck // NOTES : // This example assumes one subgroup per one workgroup and the kernel specifies the computation diff --git a/test/Integration/Dialect/XeTile/sg_gemm_postop.mlir b/test/Integration/Dialect/XeTile/sg_gemm_postop.mlir index 36d794c36..d5c5e0ed6 100644 --- a/test/Integration/Dialect/XeTile/sg_gemm_postop.mlir +++ b/test/Integration/Dialect/XeTile/sg_gemm_postop.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module} { func.func @test(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { diff --git a/test/Integration/Dialect/XeTile/sg_gemm_pre_broadcast_a.mlir b/test/Integration/Dialect/XeTile/sg_gemm_pre_broadcast_a.mlir index 965c2e65d..aa50d10e9 100644 --- a/test/Integration/Dialect/XeTile/sg_gemm_pre_broadcast_a.mlir +++ b/test/Integration/Dialect/XeTile/sg_gemm_pre_broadcast_a.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module} { func.func @test(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>, %bcast: memref<1x1024xf16>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { diff 
--git a/test/Integration/Dialect/XeTile/sg_gemm_pre_broadcast_b.mlir b/test/Integration/Dialect/XeTile/sg_gemm_pre_broadcast_b.mlir index ba730b0b4..0fb02d40f 100644 --- a/test/Integration/Dialect/XeTile/sg_gemm_pre_broadcast_b.mlir +++ b/test/Integration/Dialect/XeTile/sg_gemm_pre_broadcast_b.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module} { func.func @test(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>, %bcast: memref<1x1024xf16>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { diff --git a/test/Integration/Dialect/XeTile/sg_gemm_preop_a.mlir b/test/Integration/Dialect/XeTile/sg_gemm_preop_a.mlir index a83e8dcfb..cb1da255b 100644 --- a/test/Integration/Dialect/XeTile/sg_gemm_preop_a.mlir +++ b/test/Integration/Dialect/XeTile/sg_gemm_preop_a.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module} { func.func @test(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { diff --git a/test/Integration/Dialect/XeTile/sg_gemm_preop_a_b.mlir b/test/Integration/Dialect/XeTile/sg_gemm_preop_a_b.mlir index 79fb4efde..6ff19624a 100644 --- a/test/Integration/Dialect/XeTile/sg_gemm_preop_a_b.mlir +++ b/test/Integration/Dialect/XeTile/sg_gemm_preop_a_b.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module} { func.func @test(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { diff --git a/test/Integration/Dialect/XeTile/sg_gemm_preop_b.mlir b/test/Integration/Dialect/XeTile/sg_gemm_preop_b.mlir index 011c8ce9b..78bb3e1e1 100644 --- a/test/Integration/Dialect/XeTile/sg_gemm_preop_b.mlir +++ b/test/Integration/Dialect/XeTile/sg_gemm_preop_b.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner 
--requires=opencl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module} { func.func @test(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { diff --git a/test/Integration/Dialect/XeTile/transpose_1kx1kx1k_f16_f16_f16.mlir b/test/Integration/Dialect/XeTile/transpose_1kx1kx1k_f16_f16_f16.mlir index 5901acefb..3e4697675 100644 --- a/test/Integration/Dialect/XeTile/transpose_1kx1kx1k_f16_f16_f16.mlir +++ b/test/Integration/Dialect/XeTile/transpose_1kx1kx1k_f16_f16_f16.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck // NOTES : // This example assumes one subgroup per one workgroup and the kernel specifies the computation diff --git a/test/Jax/gordon/jit__logsm_from_logmhalo_jax_kern_0_before_linalg.mlir b/test/Jax/gordon/jit__logsm_from_logmhalo_jax_kern_0_before_linalg.mlir index a8ca09e52..de3acae24 100644 --- a/test/Jax/gordon/jit__logsm_from_logmhalo_jax_kern_0_before_linalg.mlir +++ b/test/Jax/gordon/jit__logsm_from_logmhalo_jax_kern_0_before_linalg.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<() -> ()> #map1 = affine_map<(d0) -> ()> #map2 = affine_map<(d0) -> (d0)> diff --git a/test/Jax/janet/jit__get_age_weights_from_tables.8_linalg.mlir b/test/Jax/janet/jit__get_age_weights_from_tables.8_linalg.mlir index d61fb7506..8c1e08a74 100644 --- a/test/Jax/janet/jit__get_age_weights_from_tables.8_linalg.mlir +++ b/test/Jax/janet/jit__get_age_weights_from_tables.8_linalg.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0) -> (d0)> #map1 = affine_map<(d0) -> ()> #map2 = affine_map<() -> ()> diff --git a/test/Jax/janet/jit__get_lgt_birth.7_linalg.mlir b/test/Jax/janet/jit__get_lgt_birth.7_linalg.mlir index 8aabfa3d8..18543f9ff 100644 --- a/test/Jax/janet/jit__get_lgt_birth.7_linalg.mlir +++ b/test/Jax/janet/jit__get_lgt_birth.7_linalg.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: 
--entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0) -> ()> #map1 = affine_map<(d0) -> (d0)> #map2 = affine_map<() -> ()> diff --git a/test/Jax/janet/jit__get_met_weights_singlegal.43_linalg.mlir b/test/Jax/janet/jit__get_met_weights_singlegal.43_linalg.mlir index 35c129aef..fb8de1889 100644 --- a/test/Jax/janet/jit__get_met_weights_singlegal.43_linalg.mlir +++ b/test/Jax/janet/jit__get_met_weights_singlegal.43_linalg.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0) -> (d0)> #map1 = affine_map<(d0) -> ()> #map2 = affine_map<() -> ()> diff --git a/test/Jax/janet/jit__net_loss.91_linalg.mlir b/test/Jax/janet/jit__net_loss.91_linalg.mlir index c33e58b07..4bce3abe5 100644 --- a/test/Jax/janet/jit__net_loss.91_linalg.mlir +++ b/test/Jax/janet/jit__net_loss.91_linalg.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d1)> #map1 = affine_map<(d0, d1) -> (d0, d1)> #map2 = affine_map<(d0, d1) -> ()> diff --git a/test/Jax/janet/jit__unit_scale_traindata.47_linalg.mlir b/test/Jax/janet/jit__unit_scale_traindata.47_linalg.mlir index 7f4da94eb..16e589819 100644 --- a/test/Jax/janet/jit__unit_scale_traindata.47_linalg.mlir +++ b/test/Jax/janet/jit__unit_scale_traindata.47_linalg.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0) -> (d0)> #map1 = affine_map<(d0, d1) -> (d0)> #map2 = affine_map<(d0, d1) -> (d0, d1)> diff --git a/test/Jax/janet/jit_prim_fun.50_linalg.mlir b/test/Jax/janet/jit_prim_fun.50_linalg.mlir index 05514c173..8bbbc45aa 100644 --- a/test/Jax/janet/jit_prim_fun.50_linalg.mlir +++ b/test/Jax/janet/jit_prim_fun.50_linalg.mlir @@ -10,6 +10,10 @@ // RUN-GPU: --runner imex-cpu-runner -e main \ // RUN-GPU: --entry-point-result=void \ // RUN-GPU: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN-GPU: %python_executable %imex_runner 
--requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN-GPU: --runner imex-cpu-runner -e main \ +// RUN-GPU: --entry-point-result=void \ +// RUN-GPU: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0, d1) -> (d0, d1)> module @jit_prim_fun.50 { diff --git a/test/Jax/jax_qmc/jit__linspace.39_linalg.mlir b/test/Jax/jax_qmc/jit__linspace.39_linalg.mlir index 2c6cc8273..6eea5c2b2 100644 --- a/test/Jax/jax_qmc/jit__linspace.39_linalg.mlir +++ b/test/Jax/jax_qmc/jit__linspace.39_linalg.mlir @@ -10,6 +10,10 @@ // RUN-GPU: --runner imex-cpu-runner -e main \ // RUN-GPU: --entry-point-result=void \ // RUN-GPU: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN-GPU: %python_executable %imex_runner --requires=opencl-runtime,igpu-fp64 %igpu_fp64 -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN-GPU: --runner imex-cpu-runner -e main \ +// RUN-GPU: --entry-point-result=void \ +// RUN-GPU: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<() -> ()> #map1 = affine_map<(d0) -> ()> #map2 = affine_map<(d0) -> (d0)> diff --git a/test/Jax/jax_qmc/jit__mean.46_linalg.mlir b/test/Jax/jax_qmc/jit__mean.46_linalg.mlir index 7c22f5931..3c0e2d753 100644 --- a/test/Jax/jax_qmc/jit__mean.46_linalg.mlir +++ b/test/Jax/jax_qmc/jit__mean.46_linalg.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime,igpu-fp64 %igpu_fp64 -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1, d2) -> (d0, d2, d1)> #map1 = affine_map<(d0, d1, d2) -> (d0, d1)> #map2 = affine_map<(d0, d1) -> (d0, d1)> diff --git a/test/Jax/jax_qmc/jit__mean.51_linalg.mlir b/test/Jax/jax_qmc/jit__mean.51_linalg.mlir index 77b814fa2..53d12620a 100644 --- a/test/Jax/jax_qmc/jit__mean.51_linalg.mlir +++ b/test/Jax/jax_qmc/jit__mean.51_linalg.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime,igpu-fp64 %igpu_fp64 -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1) -> ()> #map2 = affine_map<() -> ()> diff --git a/test/Jax/jax_qmc/jit_pionless_2b_lo.41_linalg.mlir b/test/Jax/jax_qmc/jit_pionless_2b_lo.41_linalg.mlir index a43b567a1..219079953 100644 --- a/test/Jax/jax_qmc/jit_pionless_2b_lo.41_linalg.mlir +++ b/test/Jax/jax_qmc/jit_pionless_2b_lo.41_linalg.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime,igpu-fp64 %igpu_fp64 -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// 
RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<() -> ()> #map1 = affine_map<(d0) -> (d0)> #map2 = affine_map<(d0) -> (0)> diff --git a/test/Jax/jax_qmc/jit_prim_fun.13_linalg.mlir b/test/Jax/jax_qmc/jit_prim_fun.13_linalg.mlir index 94f1a3ea7..7f53ad096 100644 --- a/test/Jax/jax_qmc/jit_prim_fun.13_linalg.mlir +++ b/test/Jax/jax_qmc/jit_prim_fun.13_linalg.mlir @@ -10,6 +10,10 @@ // RUN-GPU: --runner imex-cpu-runner -e main \ // RUN-GPU: --entry-point-result=void \ // RUN-GPU: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN-GPU: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN-GPU: --runner imex-cpu-runner -e main \ +// RUN-GPU: --entry-point-result=void \ +// RUN-GPU: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0) -> (d0)> module @jit_prim_fun.13 { diff --git a/test/Jax/jax_qmc/jit_v_em.42_linalg.mlir b/test/Jax/jax_qmc/jit_v_em.42_linalg.mlir index 3e430c779..e6d6c8b87 100644 --- a/test/Jax/jax_qmc/jit_v_em.42_linalg.mlir +++ b/test/Jax/jax_qmc/jit_v_em.42_linalg.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime,igpu-fp64 %igpu_fp64 -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<() -> ()> module @jit_v_em.42 { diff --git a/test/Jax/qoc/jit__diag.11_linalg.mlir b/test/Jax/qoc/jit__diag.11_linalg.mlir index 9b87b1200..ad8fd8aed 100644 --- a/test/Jax/qoc/jit__diag.11_linalg.mlir +++ b/test/Jax/qoc/jit__diag.11_linalg.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0) -> (d0)> #map1 = affine_map<(d0, d1) -> (d0)> #map2 = affine_map<(d0, d1) -> (d0, d1)> diff --git a/test/Jax/qoc/jit__reduce_sum.357_linalg.mlir b/test/Jax/qoc/jit__reduce_sum.357_linalg.mlir index 1895693e4..050d7a2e6 100644 --- a/test/Jax/qoc/jit__reduce_sum.357_linalg.mlir +++ b/test/Jax/qoc/jit__reduce_sum.357_linalg.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0) -> (d0)> #map1 = affine_map<(d0) -> ()> module @jit__reduce_sum.357 { diff --git a/test/Jax/qoc/jit_absolute.341_linalg.mlir b/test/Jax/qoc/jit_absolute.341_linalg.mlir index c295e7b5b..636bc1baf 100644 --- a/test/Jax/qoc/jit_absolute.341_linalg.mlir +++ 
b/test/Jax/qoc/jit_absolute.341_linalg.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<() -> ()> module @jit_absolute.341 { diff --git a/test/Jax/qoc/jit_conjugate.307_linalg.mlir b/test/Jax/qoc/jit_conjugate.307_linalg.mlir index 2d7b6a4c1..4c3f9cc78 100644 --- a/test/Jax/qoc/jit_conjugate.307_linalg.mlir +++ b/test/Jax/qoc/jit_conjugate.307_linalg.mlir @@ -10,6 +10,10 @@ // RUN-GPU: --runner imex-cpu-runner -e main \ // RUN-GPU: --entry-point-result=void \ // RUN-GPU: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN-GPU: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN-GPU: --runner imex-cpu-runner -e main \ +// RUN-GPU: --entry-point-result=void \ +// RUN-GPU: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0, d1) -> (d0, d1)> module @jit_conjugate.307 { func.func private @callee(%arg0: tensor<1x2xcomplex>) -> tensor<1x2xcomplex> { diff --git a/test/Jax/qoc/jit_matmul.338_linalg.mlir b/test/Jax/qoc/jit_matmul.338_linalg.mlir index 783253e5b..a8d164d32 100644 --- a/test/Jax/qoc/jit_matmul.338_linalg.mlir +++ b/test/Jax/qoc/jit_matmul.338_linalg.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d3)> #map1 = affine_map<(d0, d1, d2, d3) -> (d1, d3, d3)> #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> diff --git a/test/Jax/qoc/jit_prim_fun.335_linalg.mlir b/test/Jax/qoc/jit_prim_fun.335_linalg.mlir index 693714511..f3bb44dd6 100644 --- a/test/Jax/qoc/jit_prim_fun.335_linalg.mlir +++ b/test/Jax/qoc/jit_prim_fun.335_linalg.mlir @@ -10,6 +10,10 @@ // RUN-GPU: --runner imex-cpu-runner -e main \ // RUN-GPU: --entry-point-result=void \ // RUN-GPU: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN-GPU: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN-GPU: --runner imex-cpu-runner -e main \ +// RUN-GPU: --entry-point-result=void \ +// RUN-GPU: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0) -> (d0)> module @jit_prim_fun.335 { diff --git a/test/Jax/qoc/jit_real.364_linalg.mlir b/test/Jax/qoc/jit_real.364_linalg.mlir index 07753a7af..9849504c2 100644 --- a/test/Jax/qoc/jit_real.364_linalg.mlir +++ b/test/Jax/qoc/jit_real.364_linalg.mlir @@ -10,6 +10,10 @@ // RUN-GPU: --runner imex-cpu-runner -e main \ // RUN-GPU: --entry-point-result=void \ // RUN-GPU: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN-GPU: %python_executable %imex_runner 
--requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN-GPU: --runner imex-cpu-runner -e main \ +// RUN-GPU: --entry-point-result=void \ +// RUN-GPU: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0) -> (d0)> module @jit_real.364 { diff --git a/test/Jax/qoc/jit_swapaxes.304_linalg.mlir b/test/Jax/qoc/jit_swapaxes.304_linalg.mlir index ae7688096..0e2b92732 100644 --- a/test/Jax/qoc/jit_swapaxes.304_linalg.mlir +++ b/test/Jax/qoc/jit_swapaxes.304_linalg.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d1, d0)> #map1 = affine_map<(d0, d1) -> (d0, d1)> module @jit_swapaxes.304 { diff --git a/test/Jax/qoc/jit_trace.340_linalg.mlir b/test/Jax/qoc/jit_trace.340_linalg.mlir index e864f850c..049be9730 100644 --- a/test/Jax/qoc/jit_trace.340_linalg.mlir +++ b/test/Jax/qoc/jit_trace.340_linalg.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0) -> (d0)> #map1 = affine_map<(d0, d1) -> (d0)> #map2 = affine_map<(d0, d1) -> (d0, d1)> diff --git a/test/Jax/qoc/jit_true_divide.316_linalg.mlir b/test/Jax/qoc/jit_true_divide.316_linalg.mlir index 22bac4239..7ad288d68 100644 --- a/test/Jax/qoc/jit_true_divide.316_linalg.mlir +++ b/test/Jax/qoc/jit_true_divide.316_linalg.mlir @@ -10,6 +10,10 @@ // RUN-GPU: --runner imex-cpu-runner -e main \ // RUN-GPU: --entry-point-result=void \ // RUN-GPU: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN-GPU: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN-GPU: --runner imex-cpu-runner -e main \ +// RUN-GPU: --entry-point-result=void \ +// RUN-GPU: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<() -> ()> module @jit_true_divide.316 { func.func private @callee(%arg0: tensor>, %arg1: tensor) -> tensor> { diff --git a/test/Models/Mobilenet-v3/mobilenetv3-linalg-without-tensor-pad.mlir b/test/Models/Mobilenet-v3/mobilenetv3-linalg-without-tensor-pad.mlir index 8303e33f9..579d9001a 100644 --- a/test/Models/Mobilenet-v3/mobilenetv3-linalg-without-tensor-pad.mlir +++ b/test/Models/Mobilenet-v3/mobilenetv3-linalg-without-tensor-pad.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: 
--shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> #map1 = affine_map<(d0, d1, d2, d3) -> ()> #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)> diff --git a/test/Models/Mobilenet-v3/mobilenetv3-linalg.mlir b/test/Models/Mobilenet-v3/mobilenetv3-linalg.mlir index ed5e01d36..71eabcb10 100644 --- a/test/Models/Mobilenet-v3/mobilenetv3-linalg.mlir +++ b/test/Models/Mobilenet-v3/mobilenetv3-linalg.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> #map1 = affine_map<(d0, d1, d2, d3) -> ()> #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)> diff --git a/test/Models/Resnet-50/resnet-50-linalg-without-tensor-pad.mlir b/test/Models/Resnet-50/resnet-50-linalg-without-tensor-pad.mlir index 97d16de0c..d41ca5cc5 100644 --- a/test/Models/Resnet-50/resnet-50-linalg-without-tensor-pad.mlir +++ b/test/Models/Resnet-50/resnet-50-linalg-without-tensor-pad.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)> #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)> diff --git a/test/Models/Resnet-50/resnet-50-linalg.mlir b/test/Models/Resnet-50/resnet-50-linalg.mlir index e89ce9f5a..b0a263120 100644 --- a/test/Models/Resnet-50/resnet-50-linalg.mlir +++ b/test/Models/Resnet-50/resnet-50-linalg.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)> #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)> diff --git a/test/PlaidML/CppEdsl.Add.mlir b/test/PlaidML/CppEdsl.Add.mlir index 3da63bb55..2089b34e9 100644 --- a/test/PlaidML/CppEdsl.Add.mlir +++ b/test/PlaidML/CppEdsl.Add.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm-caching.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: 
--shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0, d1) -> (d0, d1)> module @add { func.func @main() { diff --git a/test/PlaidML/CppEdsl.Atan.mlir b/test/PlaidML/CppEdsl.Atan.mlir index fe29fb7ae..acee0198d 100644 --- a/test/PlaidML/CppEdsl.Atan.mlir +++ b/test/PlaidML/CppEdsl.Atan.mlir @@ -9,7 +9,11 @@ // RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0, d1) -> (d0, d1)> module @atan { func.func @test(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> { diff --git a/test/PlaidML/CppEdsl.BigDot.mlir b/test/PlaidML/CppEdsl.BigDot.mlir index 711b323e8..f33f1baf0 100644 --- a/test/PlaidML/CppEdsl.BigDot.mlir +++ b/test/PlaidML/CppEdsl.BigDot.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1, d2) -> (d0, d2)> #map1 = affine_map<(d0, d1, d2) -> (d2, d1)> #map2 = affine_map<(d0, d1, d2) -> (d0, d1)> diff --git a/test/PlaidML/CppEdsl.BitAnd.mlir b/test/PlaidML/CppEdsl.BitAnd.mlir index 5f9d88a52..ce48c7f58 100644 --- a/test/PlaidML/CppEdsl.BitAnd.mlir +++ b/test/PlaidML/CppEdsl.BitAnd.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0, d1) -> (d0, d1)> module @bit_and { func.func @main() { diff --git a/test/PlaidML/CppEdsl.BitAndScalar.mlir b/test/PlaidML/CppEdsl.BitAndScalar.mlir index bfd8f26e8..003ef4f0e 100644 --- a/test/PlaidML/CppEdsl.BitAndScalar.mlir +++ b/test/PlaidML/CppEdsl.BitAndScalar.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1) -> ()> module @bit_and { diff --git a/test/PlaidML/CppEdsl.BitLeft.mlir 
b/test/PlaidML/CppEdsl.BitLeft.mlir index 556d038f8..0edddc7be 100644 --- a/test/PlaidML/CppEdsl.BitLeft.mlir +++ b/test/PlaidML/CppEdsl.BitLeft.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0, d1) -> (d0, d1)> module @bit_left { func.func @main() { diff --git a/test/PlaidML/CppEdsl.BitNot.mlir b/test/PlaidML/CppEdsl.BitNot.mlir index 1a264f42b..73124c476 100644 --- a/test/PlaidML/CppEdsl.BitNot.mlir +++ b/test/PlaidML/CppEdsl.BitNot.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0, d1) -> (d0, d1)> module @bit_not { func.func @main() { diff --git a/test/PlaidML/CppEdsl.BitOr.mlir b/test/PlaidML/CppEdsl.BitOr.mlir index c7e990476..393ab32e7 100644 --- a/test/PlaidML/CppEdsl.BitOr.mlir +++ b/test/PlaidML/CppEdsl.BitOr.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0, d1) -> (d0, d1)> module @bit_or { func.func @main() { diff --git a/test/PlaidML/CppEdsl.BitRightScalar.mlir b/test/PlaidML/CppEdsl.BitRightScalar.mlir index 60eb23a01..212f456f1 100644 --- a/test/PlaidML/CppEdsl.BitRightScalar.mlir +++ b/test/PlaidML/CppEdsl.BitRightScalar.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1) -> ()> module @bit_right_scalar { diff --git a/test/PlaidML/CppEdsl.BitRightTensor.mlir b/test/PlaidML/CppEdsl.BitRightTensor.mlir index bb8bece50..003e4d477 100644 --- a/test/PlaidML/CppEdsl.BitRightTensor.mlir +++ b/test/PlaidML/CppEdsl.BitRightTensor.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner 
imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0, d1) -> (d0, d1)> module @bit_right_tensor { func.func @main() { diff --git a/test/PlaidML/CppEdsl.BitXor.mlir b/test/PlaidML/CppEdsl.BitXor.mlir index 73cf52db4..3a2f627be 100644 --- a/test/PlaidML/CppEdsl.BitXor.mlir +++ b/test/PlaidML/CppEdsl.BitXor.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0, d1) -> (d0, d1)> module @bit_xor { func.func @main() { diff --git a/test/PlaidML/CppEdsl.BroadcastCmp.mlir b/test/PlaidML/CppEdsl.BroadcastCmp.mlir index 6af103c13..d19e60a81 100644 --- a/test/PlaidML/CppEdsl.BroadcastCmp.mlir +++ b/test/PlaidML/CppEdsl.BroadcastCmp.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1) -> (d0, 0)> module @broadcast_cmp { diff --git a/test/PlaidML/CppEdsl.Cast.mlir b/test/PlaidML/CppEdsl.Cast.mlir index 27d55e8d4..70d8f3634 100644 --- a/test/PlaidML/CppEdsl.Cast.mlir +++ b/test/PlaidML/CppEdsl.Cast.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0, d1) -> (d0, d1)> module @cast { func.func @test(%arg0: tensor<3x3xi64>) -> tensor<3x3xi32> { diff --git a/test/PlaidML/CppEdsl.ConstAdd.mlir b/test/PlaidML/CppEdsl.ConstAdd.mlir index 0499abb84..67332b25c 100644 --- a/test/PlaidML/CppEdsl.ConstAdd.mlir +++ b/test/PlaidML/CppEdsl.ConstAdd.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0) -> (d0)> module @const_add { func.func @test(%arg0: tensor<4xi32> {stdx.const}, %arg1: tensor<4xi32> {stdx.const}) -> tensor<4xi32> { diff --git a/test/PlaidML/CppEdsl.ConvI8.mlir b/test/PlaidML/CppEdsl.ConvI8.mlir index 7b6efd06a..65b9b0e11 100644 --- a/test/PlaidML/CppEdsl.ConvI8.mlir +++ 
b/test/PlaidML/CppEdsl.ConvI8.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)> diff --git a/test/PlaidML/CppEdsl.Convolution.mlir b/test/PlaidML/CppEdsl.Convolution.mlir index af40c89ae..379f87ecc 100644 --- a/test/PlaidML/CppEdsl.Convolution.mlir +++ b/test/PlaidML/CppEdsl.Convolution.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)> diff --git a/test/PlaidML/CppEdsl.Cos.mlir b/test/PlaidML/CppEdsl.Cos.mlir index 9fcf963f9..8d5811ad0 100644 --- a/test/PlaidML/CppEdsl.Cos.mlir +++ b/test/PlaidML/CppEdsl.Cos.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0, d1) -> (d0, d1)> module @cos { func.func @test(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> { diff --git a/test/PlaidML/CppEdsl.CumSum.mlir b/test/PlaidML/CppEdsl.CumSum.mlir index 860db8140..1556918b6 100644 --- a/test/PlaidML/CppEdsl.CumSum.mlir +++ b/test/PlaidML/CppEdsl.CumSum.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d1)> #map1 = affine_map<(d0, d1) -> (d0)> #set = affine_set<(d0, d1) : (d0 - d1 >= 0)> diff --git a/test/PlaidML/CppEdsl.DefractLong.mlir b/test/PlaidML/CppEdsl.DefractLong.mlir index 6625fcd2b..d21051960 100644 --- a/test/PlaidML/CppEdsl.DefractLong.mlir +++ b/test/PlaidML/CppEdsl.DefractLong.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime 
--filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10) -> (d1, -d0 + d3 + d8 + 1, -d0 + d4 + d5 + d9, d10)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10) -> (d0 * 4 + d2 * 2 + d7 - d8 * 2, d0 * 4 + d2 - d4 + d7 - d9 * 2 + 3, d6, d10)> diff --git a/test/PlaidML/CppEdsl.Dot.mlir b/test/PlaidML/CppEdsl.Dot.mlir index 2f5f7e222..b04e95e42 100644 --- a/test/PlaidML/CppEdsl.Dot.mlir +++ b/test/PlaidML/CppEdsl.Dot.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1, d2) -> (d0, d2)> #map1 = affine_map<(d0, d1, d2) -> (d2, d1)> #map2 = affine_map<(d0, d1, d2) -> (d0, d1)> diff --git a/test/PlaidML/CppEdsl.DotF16.mlir b/test/PlaidML/CppEdsl.DotF16.mlir index e55d62e0e..80c81993d 100644 --- a/test/PlaidML/CppEdsl.DotF16.mlir +++ b/test/PlaidML/CppEdsl.DotF16.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1, d2) -> (d0, d2)> #map1 = affine_map<(d0, d1, d2) -> (d2, d1)> #map2 = affine_map<(d0, d1, d2) -> (d0, d1)> diff --git a/test/PlaidML/CppEdsl.DotF16_AccF32.mlir b/test/PlaidML/CppEdsl.DotF16_AccF32.mlir index e831aeadf..d24bbdbe9 100644 --- a/test/PlaidML/CppEdsl.DotF16_AccF32.mlir +++ b/test/PlaidML/CppEdsl.DotF16_AccF32.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1, d2) -> (d0, d2)> #map2 = affine_map<(d0, d1, d2) -> (d2, d1)> diff --git a/test/PlaidML/CppEdsl.DoubleDot.mlir b/test/PlaidML/CppEdsl.DoubleDot.mlir index 700b77358..79168112d 100644 --- a/test/PlaidML/CppEdsl.DoubleDot.mlir +++ b/test/PlaidML/CppEdsl.DoubleDot.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// 
RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1, d2) -> (d0, d2)> #map1 = affine_map<(d0, d1, d2) -> (d2, d1)> #map2 = affine_map<(d0, d1, d2) -> (d0, d1)> diff --git a/test/PlaidML/CppEdsl.DupOut.mlir b/test/PlaidML/CppEdsl.DupOut.mlir index 558810ae8..2d719c051 100644 --- a/test/PlaidML/CppEdsl.DupOut.mlir +++ b/test/PlaidML/CppEdsl.DupOut.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1, d2) -> (d0, d2)> #map1 = affine_map<(d0, d1, d2) -> (d2, d1)> #map2 = affine_map<(d0, d1, d2) -> (d0, d1)> diff --git a/test/PlaidML/CppEdsl.EltwiseMod.mlir b/test/PlaidML/CppEdsl.EltwiseMod.mlir index 58e7419f7..c4c057e55 100644 --- a/test/PlaidML/CppEdsl.EltwiseMod.mlir +++ b/test/PlaidML/CppEdsl.EltwiseMod.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0, d1) -> (d0, d1)> module @mod { func.func @main() { diff --git a/test/PlaidML/CppEdsl.Erf.mlir b/test/PlaidML/CppEdsl.Erf.mlir index 35d33fe47..ecb2ae421 100644 --- a/test/PlaidML/CppEdsl.Erf.mlir +++ b/test/PlaidML/CppEdsl.Erf.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0, d1) -> (d0, d1)> module @erf { func.func @main() { diff --git a/test/PlaidML/OpTest.Abs.mlir b/test/PlaidML/OpTest.Abs.mlir index 308ca8621..861e2d910 100644 --- a/test/PlaidML/OpTest.Abs.mlir +++ b/test/PlaidML/OpTest.Abs.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> #map1 = affine_map<(d0, d1, d2, d3) -> ()> module @abs { diff --git a/test/PlaidML/OpTest.Argmax.mlir b/test/PlaidML/OpTest.Argmax.mlir index c5c786985..b6ffd6a0d 100644 --- a/test/PlaidML/OpTest.Argmax.mlir +++ b/test/PlaidML/OpTest.Argmax.mlir @@ 
-10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> #map1 = affine_map<(d0, d1, d2, d3) -> ()> #map2 = affine_map<() -> ()> diff --git a/test/PlaidML/OpTest.BinaryCrossentropy.mlir b/test/PlaidML/OpTest.BinaryCrossentropy.mlir index 8fb3e9481..3641c5b40 100644 --- a/test/PlaidML/OpTest.BinaryCrossentropy.mlir +++ b/test/PlaidML/OpTest.BinaryCrossentropy.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> #map1 = affine_map<(d0, d1, d2, d3) -> ()> module @binary_crossentropy { diff --git a/test/PlaidML/OpTest.BroadcastNonNumpy.mlir b/test/PlaidML/OpTest.BroadcastNonNumpy.mlir index 70d5871d4..b64966910 100644 --- a/test/PlaidML/OpTest.BroadcastNonNumpy.mlir +++ b/test/PlaidML/OpTest.BroadcastNonNumpy.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d0)> #map1 = affine_map<(d0, d1) -> (d0, d1)> module @broadcast_non_numpy { diff --git a/test/PlaidML/OpTest.ComplexConv2D.mlir b/test/PlaidML/OpTest.ComplexConv2D.mlir index 1a14bca12..4dd965cf9 100644 --- a/test/PlaidML/OpTest.ComplexConv2D.mlir +++ b/test/PlaidML/OpTest.ComplexConv2D.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 * 2 + d5 * 3, d2 * 2 + d6 * 3, d3, d7)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d3, d7, d4)> diff --git a/test/PlaidML/OpTest.Conv1D.mlir b/test/PlaidML/OpTest.Conv1D.mlir index a2839f423..001f3ce5b 100644 --- a/test/PlaidML/OpTest.Conv1D.mlir +++ b/test/PlaidML/OpTest.Conv1D.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: 
--shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d3, d4)> #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)> #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> diff --git a/test/PlaidML/OpTest.Conv2DDilated.mlir b/test/PlaidML/OpTest.Conv2DDilated.mlir index 27aecce68..b9ad8ae91 100644 --- a/test/PlaidML/OpTest.Conv2DDilated.mlir +++ b/test/PlaidML/OpTest.Conv2DDilated.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4 * 2, d2 + d5 * 3, d6)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> diff --git a/test/PlaidML/OpTest.Dot.mlir b/test/PlaidML/OpTest.Dot.mlir index 1141bccd9..77c228f9e 100644 --- a/test/PlaidML/OpTest.Dot.mlir +++ b/test/PlaidML/OpTest.Dot.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1, d2) -> (d0, d2)> #map1 = affine_map<(d0, d1, d2) -> (d2, d1)> #map2 = affine_map<(d0, d1, d2) -> (d0, d1)> diff --git a/test/PlaidML/OpTest.DotF16.mlir b/test/PlaidML/OpTest.DotF16.mlir index 3084b21ac..51f681815 100644 --- a/test/PlaidML/OpTest.DotF16.mlir +++ b/test/PlaidML/OpTest.DotF16.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map1 = affine_map<(d0, d1, d2) -> (d2, d1)> #map2 = affine_map<(d0, d1, d2) -> (d0, d1)> module @dot_f16 { diff --git a/test/PlaidML/OpTest.EltwiseAdd.mlir b/test/PlaidML/OpTest.EltwiseAdd.mlir index f4ba71d32..62d3e9559 100644 --- a/test/PlaidML/OpTest.EltwiseAdd.mlir +++ b/test/PlaidML/OpTest.EltwiseAdd.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner 
imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0, d1) -> (d0, d1)> module @eltwise_add { func.func @test(%arg0: tensor<10x20xf32>, %arg1: tensor<10x20xf32>) -> tensor<10x20xf32> { diff --git a/test/PlaidML/OpTest.ExplicitPadding.mlir b/test/PlaidML/OpTest.ExplicitPadding.mlir index ff7bf50ae..75ae17b9c 100644 --- a/test/PlaidML/OpTest.ExplicitPadding.mlir +++ b/test/PlaidML/OpTest.ExplicitPadding.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1) -> (d0 + 2, d1 + 1)> module @explicit_padding { diff --git a/test/PlaidML/OpTest.ExplicitPaddingInf.mlir b/test/PlaidML/OpTest.ExplicitPaddingInf.mlir index b1ec820b2..9fe99da7c 100644 --- a/test/PlaidML/OpTest.ExplicitPaddingInf.mlir +++ b/test/PlaidML/OpTest.ExplicitPaddingInf.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1) -> (d0 + 2, d1 + 1)> module @explicit_padding { diff --git a/test/PlaidML/OpTest.ExplicitPaddingNegInf.mlir b/test/PlaidML/OpTest.ExplicitPaddingNegInf.mlir index 89ecb6df6..b06c16947 100644 --- a/test/PlaidML/OpTest.ExplicitPaddingNegInf.mlir +++ b/test/PlaidML/OpTest.ExplicitPaddingNegInf.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1) -> (d0 + 2, d1 + 1)> module @explicit_padding { diff --git a/test/PlaidML/OpTest.Floor.mlir b/test/PlaidML/OpTest.Floor.mlir index 667d19900..1075a696d 100644 --- a/test/PlaidML/OpTest.Floor.mlir +++ b/test/PlaidML/OpTest.Floor.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0, d1) -> (d0, d1)> module @floor { func.func @main() { diff --git 
a/test/PlaidML/OpTest.GEMM_FLOAT32.mlir b/test/PlaidML/OpTest.GEMM_FLOAT32.mlir index bfebd5c21..c6842bb67 100644 --- a/test/PlaidML/OpTest.GEMM_FLOAT32.mlir +++ b/test/PlaidML/OpTest.GEMM_FLOAT32.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm-caching.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1, d2) -> (d0, d2)> #map1 = affine_map<(d0, d1, d2) -> (d2, d1)> #map2 = affine_map<(d0, d1, d2) -> (d0, d1)> diff --git a/test/PlaidML/OpTest.GEMM_INT32.mlir b/test/PlaidML/OpTest.GEMM_INT32.mlir index aca89d3d3..121a087e6 100644 --- a/test/PlaidML/OpTest.GEMM_INT32.mlir +++ b/test/PlaidML/OpTest.GEMM_INT32.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1, d2) -> (d0, d2)> #map1 = affine_map<(d0, d1, d2) -> (d2, d1)> #map2 = affine_map<(d0, d1, d2) -> (d0, d1)> diff --git a/test/PlaidML/OpTest.GEMM_INT64.mlir b/test/PlaidML/OpTest.GEMM_INT64.mlir index 43717033d..0c956c258 100644 --- a/test/PlaidML/OpTest.GEMM_INT64.mlir +++ b/test/PlaidML/OpTest.GEMM_INT64.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1, d2) -> (d0, d2)> #map1 = affine_map<(d0, d1, d2) -> (d2, d1)> #map2 = affine_map<(d0, d1, d2) -> (d0, d1)> diff --git a/test/PlaidML/OpTest.GEMM_INT8.mlir b/test/PlaidML/OpTest.GEMM_INT8.mlir index cc52e278f..4b059f055 100644 --- a/test/PlaidML/OpTest.GEMM_INT8.mlir +++ b/test/PlaidML/OpTest.GEMM_INT8.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1, d2) -> (d0, d2)> #map1 = affine_map<(d0, d1, d2) -> (d2, d1)> #map2 = affine_map<(d0, d1, d2) -> (d0, d1)> diff --git a/test/PlaidML/OpTest.GEMM_UINT64.mlir b/test/PlaidML/OpTest.GEMM_UINT64.mlir index 475521cab..a25522d35 100644 --- a/test/PlaidML/OpTest.GEMM_UINT64.mlir +++ b/test/PlaidML/OpTest.GEMM_UINT64.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: 
--entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1, d2) -> (d0, d2)> #map1 = affine_map<(d0, d1, d2) -> (d2, d1)> #map2 = affine_map<(d0, d1, d2) -> (d0, d1)> diff --git a/test/PlaidML/OpTest.GEMVC_INT64.mlir b/test/PlaidML/OpTest.GEMVC_INT64.mlir index 5c80df9f3..dc696c35d 100644 --- a/test/PlaidML/OpTest.GEMVC_INT64.mlir +++ b/test/PlaidML/OpTest.GEMVC_INT64.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1) -> ()> #map2 = affine_map<(d0, d1) -> (d1)> diff --git a/test/PlaidML/OpTest.GEMV_FLOAT32.mlir b/test/PlaidML/OpTest.GEMV_FLOAT32.mlir index ffcfd8748..1b65e77a0 100644 --- a/test/PlaidML/OpTest.GEMV_FLOAT32.mlir +++ b/test/PlaidML/OpTest.GEMV_FLOAT32.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1) -> (d1)> #map2 = affine_map<(d0, d1) -> (d0)> diff --git a/test/PlaidML/OpTest.GEMV_INT32.mlir b/test/PlaidML/OpTest.GEMV_INT32.mlir index 66d0e7f58..983a91333 100644 --- a/test/PlaidML/OpTest.GEMV_INT32.mlir +++ b/test/PlaidML/OpTest.GEMV_INT32.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1) -> (d1)> #map2 = affine_map<(d0, d1) -> (d0)> diff --git a/test/PlaidML/OpTest.GEMV_INT64.mlir b/test/PlaidML/OpTest.GEMV_INT64.mlir index 6d90ca9d9..89769f5eb 100644 --- a/test/PlaidML/OpTest.GEMV_INT64.mlir +++ b/test/PlaidML/OpTest.GEMV_INT64.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: 
--shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1) -> (d1)> #map2 = affine_map<(d0, d1) -> (d0)> diff --git a/test/PlaidML/OpTest.GEMV_INT8.mlir b/test/PlaidML/OpTest.GEMV_INT8.mlir index 557b56dbf..1f95a52ff 100644 --- a/test/PlaidML/OpTest.GEMV_INT8.mlir +++ b/test/PlaidML/OpTest.GEMV_INT8.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1) -> (d1)> #map2 = affine_map<(d0, d1) -> (d0)> diff --git a/test/PlaidML/OpTest.GEMV_UINT64.mlir b/test/PlaidML/OpTest.GEMV_UINT64.mlir index 6d90ca9d9..89769f5eb 100644 --- a/test/PlaidML/OpTest.GEMV_UINT64.mlir +++ b/test/PlaidML/OpTest.GEMV_UINT64.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1) -> (d1)> #map2 = affine_map<(d0, d1) -> (d0)> diff --git a/test/PlaidML/OpTest.GlobalMin.mlir b/test/PlaidML/OpTest.GlobalMin.mlir index 8f663f1e9..c1492410a 100644 --- a/test/PlaidML/OpTest.GlobalMin.mlir +++ b/test/PlaidML/OpTest.GlobalMin.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> #map1 = affine_map<(d0, d1, d2) -> ()> #map2 = affine_map<() -> ()> diff --git a/test/PlaidML/OpTest.HigherPrecisioConstants.mlir b/test/PlaidML/OpTest.HigherPrecisioConstants.mlir index 3eff7140a..3be0d85b9 100644 --- a/test/PlaidML/OpTest.HigherPrecisioConstants.mlir +++ b/test/PlaidML/OpTest.HigherPrecisioConstants.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime,igpu-fp64 %igpu_fp64 -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1) -> ()> module @higher_precision_constants { diff --git a/test/PlaidML/OpTest.LarsMomentum4d.mlir b/test/PlaidML/OpTest.LarsMomentum4d.mlir index 
7e36c095b..e36422394 100644 --- a/test/PlaidML/OpTest.LarsMomentum4d.mlir +++ b/test/PlaidML/OpTest.LarsMomentum4d.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> #map1 = affine_map<(d0, d1, d2, d3) -> ()> #map2 = affine_map<() -> ()> diff --git a/test/PlaidML/OpTest.Layer.mlir b/test/PlaidML/OpTest.Layer.mlir index 2be73a8cd..dfcbe4a1b 100644 --- a/test/PlaidML/OpTest.Layer.mlir +++ b/test/PlaidML/OpTest.Layer.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1) -> ()> module @relu { diff --git a/test/PlaidML/OpTest.LayerEmbeddedConst.mlir b/test/PlaidML/OpTest.LayerEmbeddedConst.mlir index 027217f95..481f86140 100644 --- a/test/PlaidML/OpTest.LayerEmbeddedConst.mlir +++ b/test/PlaidML/OpTest.LayerEmbeddedConst.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0, d1) -> (d0, d1)> module @LayerEmbeddedConst { func.func @main() { diff --git a/test/PlaidML/OpTest.LayerException.mlir b/test/PlaidML/OpTest.LayerException.mlir index dd9ecfb5e..1ab457201 100644 --- a/test/PlaidML/OpTest.LayerException.mlir +++ b/test/PlaidML/OpTest.LayerException.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0, d1) -> (d0, d1)> module @LayerException { func.func @main() { diff --git a/test/PlaidML/OpTest.LayerMulti.mlir b/test/PlaidML/OpTest.LayerMulti.mlir index 22e58f232..7d5d3c266 100644 --- a/test/PlaidML/OpTest.LayerMulti.mlir +++ b/test/PlaidML/OpTest.LayerMulti.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s 
--pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0, d1) -> (d0, d1)> module @LayerMulti { func.func @main() { diff --git a/test/PlaidML/OpTest.LayerOperandOrder.mlir b/test/PlaidML/OpTest.LayerOperandOrder.mlir index 8968701ee..cdf8cb302 100644 --- a/test/PlaidML/OpTest.LayerOperandOrder.mlir +++ b/test/PlaidML/OpTest.LayerOperandOrder.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0, d1) -> (d0, d1)> module @LayerOperandOrder { func.func @main() { diff --git a/test/PlaidML/OpTest.LayerUnusedOperand.mlir b/test/PlaidML/OpTest.LayerUnusedOperand.mlir index 866bca0cd..a85c7b084 100644 --- a/test/PlaidML/OpTest.LayerUnusedOperand.mlir +++ b/test/PlaidML/OpTest.LayerUnusedOperand.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0, d1) -> (d0, d1)> module @LayerUnusedOperand { func.func @main() { diff --git a/test/PlaidML/OpTest.Lens.mlir b/test/PlaidML/OpTest.Lens.mlir index b1bab2f1d..9cab5e55c 100644 --- a/test/PlaidML/OpTest.Lens.mlir +++ b/test/PlaidML/OpTest.Lens.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d1, d0)> #map1 = affine_map<(d0, d1) -> (d0, d1)> module @transpose_nm { diff --git a/test/PlaidML/OpTest.LogicalAnd_mixed.mlir b/test/PlaidML/OpTest.LogicalAnd_mixed.mlir index aafec5a6e..9aa6d5b6e 100644 --- a/test/PlaidML/OpTest.LogicalAnd_mixed.mlir +++ b/test/PlaidML/OpTest.LogicalAnd_mixed.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0, d1) -> (d0, d1)> module @logical_and { func.func @main() { diff --git a/test/PlaidML/OpTest.LogicalAnd_uint64.mlir b/test/PlaidML/OpTest.LogicalAnd_uint64.mlir index 
c289f8c1d..51335eddf 100644 --- a/test/PlaidML/OpTest.LogicalAnd_uint64.mlir +++ b/test/PlaidML/OpTest.LogicalAnd_uint64.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0, d1) -> (d0, d1)> module @logical_and { func.func @main() { diff --git a/test/PlaidML/OpTest.LogicalNot_float.mlir b/test/PlaidML/OpTest.LogicalNot_float.mlir index abda4b077..1db25e2c7 100644 --- a/test/PlaidML/OpTest.LogicalNot_float.mlir +++ b/test/PlaidML/OpTest.LogicalNot_float.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0, d1) -> (d0, d1)> module @logical_not { func.func @main() { diff --git a/test/PlaidML/OpTest.LogicalNot_int32.mlir b/test/PlaidML/OpTest.LogicalNot_int32.mlir index 7c7a85c28..5df5039a4 100644 --- a/test/PlaidML/OpTest.LogicalNot_int32.mlir +++ b/test/PlaidML/OpTest.LogicalNot_int32.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0, d1) -> (d0, d1)> module @logical_not { func.func @main() { diff --git a/test/PlaidML/OpTest.LogicalOr_float.mlir b/test/PlaidML/OpTest.LogicalOr_float.mlir index f2bf9c7ae..8b5743ecb 100644 --- a/test/PlaidML/OpTest.LogicalOr_float.mlir +++ b/test/PlaidML/OpTest.LogicalOr_float.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0, d1) -> (d0, d1)> module @logical_or { func.func @main() { diff --git a/test/PlaidML/OpTest.LogicalOr_int32.mlir b/test/PlaidML/OpTest.LogicalOr_int32.mlir index 3f5c56f8f..fe75c1920 100644 --- a/test/PlaidML/OpTest.LogicalOr_int32.mlir +++ b/test/PlaidML/OpTest.LogicalOr_int32.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s 
--pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0, d1) -> (d0, d1)> module @logical_or { func.func @main() { diff --git a/test/PlaidML/OpTest.LogicalOr_uint64.mlir b/test/PlaidML/OpTest.LogicalOr_uint64.mlir index f86c2fe2c..06f826556 100644 --- a/test/PlaidML/OpTest.LogicalOr_uint64.mlir +++ b/test/PlaidML/OpTest.LogicalOr_uint64.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0, d1) -> (d0, d1)> module @logical_or { func.func @main() { diff --git a/test/PlaidML/OpTest.Matmul.mlir b/test/PlaidML/OpTest.Matmul.mlir index e81d850ed..c8b24b109 100644 --- a/test/PlaidML/OpTest.Matmul.mlir +++ b/test/PlaidML/OpTest.Matmul.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @linalg_matmul{ func.func @matmul(%arg0: tensor<5x3xf32>, %arg1: tensor<3x2xf32>) -> (tensor<5x2xf32>) { %cst = arith.constant 0.0 : f32 diff --git a/test/PlaidML/OpTest.Max.mlir b/test/PlaidML/OpTest.Max.mlir index 3d1e997a8..16dc4996a 100644 --- a/test/PlaidML/OpTest.Max.mlir +++ b/test/PlaidML/OpTest.Max.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1) -> (d0)> module @max { diff --git a/test/PlaidML/OpTest.MaxPool1D.mlir b/test/PlaidML/OpTest.MaxPool1D.mlir index b2a40022d..a698584bb 100644 --- a/test/PlaidML/OpTest.MaxPool1D.mlir +++ b/test/PlaidML/OpTest.MaxPool1D.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d1)> #map1 = affine_map<(d0, d1) -> (d0)> module @max_pool_1d { diff --git a/test/PlaidML/OpTest.MnistCnn.mlir b/test/PlaidML/OpTest.MnistCnn.mlir index a667b5942..a54705451 100644 --- 
a/test/PlaidML/OpTest.MnistCnn.mlir +++ b/test/PlaidML/OpTest.MnistCnn.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> #map1 = affine_map<(d0, d1, d2, d3) -> (d3)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> diff --git a/test/PlaidML/OpTest.MnistMlp.mlir b/test/PlaidML/OpTest.MnistMlp.mlir index 5e5b19895..846a7d1e7 100644 --- a/test/PlaidML/OpTest.MnistMlp.mlir +++ b/test/PlaidML/OpTest.MnistMlp.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d1)> #map1 = affine_map<(d0, d1) -> (d0, d1)> #map2 = affine_map<(d0, d1, d2) -> (d0, d2)> diff --git a/test/PlaidML/OpTest.Pow.mlir b/test/PlaidML/OpTest.Pow.mlir index 1a4b44af9..9f090518a 100644 --- a/test/PlaidML/OpTest.Pow.mlir +++ b/test/PlaidML/OpTest.Pow.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0, d1) -> (d0, d1)> module @pow { func.func @main() { diff --git a/test/PlaidML/OpTest.Quantize.mlir b/test/PlaidML/OpTest.Quantize.mlir index d2191271d..514ccbd4f 100644 --- a/test/PlaidML/OpTest.Quantize.mlir +++ b/test/PlaidML/OpTest.Quantize.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0) -> (d0)> #map1 = affine_map<(d0) -> ()> module @quantize { diff --git a/test/PlaidML/OpTest.Reciprocal.mlir b/test/PlaidML/OpTest.Reciprocal.mlir index b66446c5d..ef6bc8138 100644 --- a/test/PlaidML/OpTest.Reciprocal.mlir +++ b/test/PlaidML/OpTest.Reciprocal.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e 
main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0) -> ()> #map1 = affine_map<(d0) -> (d0)> module @reciprocal { diff --git a/test/PlaidML/OpTest.Relu.mlir b/test/PlaidML/OpTest.Relu.mlir index 25a91053b..ec4dea6f4 100644 --- a/test/PlaidML/OpTest.Relu.mlir +++ b/test/PlaidML/OpTest.Relu.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1) -> ()> module @relu { diff --git a/test/PlaidML/OpTest.RepeatElements.mlir b/test/PlaidML/OpTest.RepeatElements.mlir index 8a7102290..7bf329d03 100644 --- a/test/PlaidML/OpTest.RepeatElements.mlir +++ b/test/PlaidML/OpTest.RepeatElements.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1, d2) -> (d0, d1 floordiv 3, d2)> #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> module @repeat_elts { diff --git a/test/PlaidML/OpTest.ReshapeFold.mlir b/test/PlaidML/OpTest.ReshapeFold.mlir index 888a5eb23..0b99cf10a 100644 --- a/test/PlaidML/OpTest.ReshapeFold.mlir +++ b/test/PlaidML/OpTest.ReshapeFold.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0, d1) -> (d0, d1)> module @reshape_fold { func.func @main() { diff --git a/test/PlaidML/OpTest.ReshapeIntoScalar.mlir b/test/PlaidML/OpTest.ReshapeIntoScalar.mlir index 2ee68cb96..ee14b2343 100644 --- a/test/PlaidML/OpTest.ReshapeIntoScalar.mlir +++ b/test/PlaidML/OpTest.ReshapeIntoScalar.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<() -> ()> module @reshape_into_scalar { func.func @main() { diff --git a/test/PlaidML/OpTest.Select.mlir b/test/PlaidML/OpTest.Select.mlir index 15ae46c60..02bb11b80 100644 --- a/test/PlaidML/OpTest.Select.mlir +++ b/test/PlaidML/OpTest.Select.mlir @@ -10,6 +10,10 @@ 
// RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1) -> ()> module @select { diff --git a/test/PlaidML/OpTest.Shape.mlir b/test/PlaidML/OpTest.Shape.mlir index bcc07e478..4288a61c8 100644 --- a/test/PlaidML/OpTest.Shape.mlir +++ b/test/PlaidML/OpTest.Shape.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @shape { func.func @main() { %0= arith.constant dense<[[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]>:tensor<2x3xf32> diff --git a/test/PlaidML/OpTest.SimpleAdd.mlir b/test/PlaidML/OpTest.SimpleAdd.mlir index 2aa9b9213..855d823c2 100644 --- a/test/PlaidML/OpTest.SimpleAdd.mlir +++ b/test/PlaidML/OpTest.SimpleAdd.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0, d1) -> (d0, d1)> module @simple_add { func.func @main() { diff --git a/test/PlaidML/OpTest.Sin.mlir b/test/PlaidML/OpTest.Sin.mlir index 74198f5c3..df4b47502 100644 --- a/test/PlaidML/OpTest.Sin.mlir +++ b/test/PlaidML/OpTest.Sin.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0, d1) -> (d0, d1)> module @sin { func.func @main() { diff --git a/test/PlaidML/OpTest.SinH.mlir b/test/PlaidML/OpTest.SinH.mlir index 4a7225572..a80196627 100644 --- a/test/PlaidML/OpTest.SinH.mlir +++ b/test/PlaidML/OpTest.SinH.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0, d1) -> (d0, d1)> module @sinh { func.func @main() { diff --git 
a/test/PlaidML/OpTest.Softmax.mlir b/test/PlaidML/OpTest.Softmax.mlir index 4081a5a02..b55f1b65b 100644 --- a/test/PlaidML/OpTest.Softmax.mlir +++ b/test/PlaidML/OpTest.Softmax.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1) -> (d0, 0)> module @softmax { diff --git a/test/PlaidML/OpTest.Sum.dynamic.mlir b/test/PlaidML/OpTest.Sum.dynamic.mlir index d21cb0c90..0c490779d 100644 --- a/test/PlaidML/OpTest.Sum.dynamic.mlir +++ b/test/PlaidML/OpTest.Sum.dynamic.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1) -> ()> module @sum { diff --git a/test/PlaidML/OpTest.Sum.mlir b/test/PlaidML/OpTest.Sum.mlir index c17359ce8..48ecfa8c3 100644 --- a/test/PlaidML/OpTest.Sum.mlir +++ b/test/PlaidML/OpTest.Sum.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1) -> ()> module @sum { diff --git a/test/PlaidML/OpTest.Tan.mlir b/test/PlaidML/OpTest.Tan.mlir index 12189ac1d..efd76387d 100644 --- a/test/PlaidML/OpTest.Tan.mlir +++ b/test/PlaidML/OpTest.Tan.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0, d1) -> (d0, d1)> module @tan { func.func @main() { diff --git a/test/PlaidML/OpTest.Transpose.mlir b/test/PlaidML/OpTest.Transpose.mlir index 7db708c75..4f7b8c8ab 100644 --- a/test/PlaidML/OpTest.Transpose.mlir +++ b/test/PlaidML/OpTest.Transpose.mlir @@ -10,6 +10,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: 
--entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map0 = affine_map<(d0, d1) -> (d1, d0)> #map1 = affine_map<(d0, d1) -> (d0, d1)> module @transpose { diff --git a/test/PlaidML/OpTest.UniqueNames.mlir b/test/PlaidML/OpTest.UniqueNames.mlir index 82bdce5b8..9433f86a1 100644 --- a/test/PlaidML/OpTest.UniqueNames.mlir +++ b/test/PlaidML/OpTest.UniqueNames.mlir @@ -9,7 +9,11 @@ // RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck #map = affine_map<(d0) -> (d0)> module @unique_names { func.func @main() { diff --git a/test/SPIRV/CppEdsl.Convolution_BF16.mlir b/test/SPIRV/CppEdsl.Convolution_BF16.mlir index a08ed05cd..34c50307a 100644 --- a/test/SPIRV/CppEdsl.Convolution_BF16.mlir +++ b/test/SPIRV/CppEdsl.Convolution_BF16.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/spirv-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @convolution attributes {gpu.container_module} { memref.global "private" constant @__constant_3x3x64x64xbf16 : memref<3x3x64x64xbf16> = dense<5.000000e-01> diff --git a/test/SPIRV/IntelVectorExtension/DPAS_Dynamic_Size_BF16.mlir b/test/SPIRV/IntelVectorExtension/DPAS_Dynamic_Size_BF16.mlir index 5bd009831..5216cfdc1 100644 --- a/test/SPIRV/IntelVectorExtension/DPAS_Dynamic_Size_BF16.mlir +++ b/test/SPIRV/IntelVectorExtension/DPAS_Dynamic_Size_BF16.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/spirv-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck /// A simple Matrix Multiplication using DPAS instruction diff --git a/test/SPIRV/IntelVectorExtension/DPAS_Static_Size_BF16.mlir b/test/SPIRV/IntelVectorExtension/DPAS_Static_Size_BF16.mlir index a0d15aa93..598573993 100644 --- a/test/SPIRV/IntelVectorExtension/DPAS_Static_Size_BF16.mlir +++ b/test/SPIRV/IntelVectorExtension/DPAS_Static_Size_BF16.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/spirv-to-llvm.pp \ +// 
RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck /// A simple Matrix Multiplication using DPAS instruction /// A and B are in bf16, while the result C is f32 diff --git a/test/SPIRV/IntelVectorExtension/DPAS_raw_send.mlir b/test/SPIRV/IntelVectorExtension/DPAS_raw_send.mlir index a02705a56..6e988b9f5 100644 --- a/test/SPIRV/IntelVectorExtension/DPAS_raw_send.mlir +++ b/test/SPIRV/IntelVectorExtension/DPAS_raw_send.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/spirv-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck /// A simple Matrix Multiplication using DPAS instruction diff --git a/test/SPIRV/IntelVectorExtension/GEMM_4kx4kx4k_BF16.mlir b/test/SPIRV/IntelVectorExtension/GEMM_4kx4kx4k_BF16.mlir index 5582f6ab4..2523dd4b6 100644 --- a/test/SPIRV/IntelVectorExtension/GEMM_4kx4kx4k_BF16.mlir +++ b/test/SPIRV/IntelVectorExtension/GEMM_4kx4kx4k_BF16.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: IMEX_ENABLE_LARGE_REG_FILE=1 %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/spirv-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module attributes {gpu.container_module} { diff --git a/test/SPIRV/IntelVectorExtension/Load_1d_raw_send.mlir b/test/SPIRV/IntelVectorExtension/Load_1d_raw_send.mlir index eb27e5744..ca28c8782 100644 --- a/test/SPIRV/IntelVectorExtension/Load_1d_raw_send.mlir +++ b/test/SPIRV/IntelVectorExtension/Load_1d_raw_send.mlir @@ -6,7 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck - +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/spirv-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck // A simple test case showing how to use raw_send2 VC intrinsics for doing a load1d module attributes {gpu.container_module} { diff --git a/test/SPIRV/IntelVectorExtension/Load_1d_slm.mlir b/test/SPIRV/IntelVectorExtension/Load_1d_slm.mlir index f4f9a7627..163ad65cb 100644 --- a/test/SPIRV/IntelVectorExtension/Load_1d_slm.mlir +++ b/test/SPIRV/IntelVectorExtension/Load_1d_slm.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/spirv-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: 
--shared-libs=%mlir_runner_utils,%irunner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module attributes {gpu.container_module} { diff --git a/test/SPIRV/IntelVectorExtension/Load_2d_raw_send.mlir b/test/SPIRV/IntelVectorExtension/Load_2d_raw_send.mlir index 9f44d7fcf..77dec0cb5 100644 --- a/test/SPIRV/IntelVectorExtension/Load_2d_raw_send.mlir +++ b/test/SPIRV/IntelVectorExtension/Load_2d_raw_send.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/spirv-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck /// A simple load2d/store2d example /// This example loads and stores 16x16xf32 elements using raw_send2/store2d diff --git a/test/SPIRV/IntelVectorExtension/Store2d_raw_send.mlir b/test/SPIRV/IntelVectorExtension/Store2d_raw_send.mlir index e06d80ae0..e07c13314 100644 --- a/test/SPIRV/IntelVectorExtension/Store2d_raw_send.mlir +++ b/test/SPIRV/IntelVectorExtension/Store2d_raw_send.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/spirv-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%irunner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck /// A simple load2d/store2d example /// This example loads and stores 16x16xf32 elements using load2d/raw_sends2 diff --git a/test/SPIRV/JointMatrix/gemm_using_joint_matrix_Physical_64_addressing_matrixUse_Param_level_zero.mlir b/test/SPIRV/JointMatrix/gemm_using_joint_matrix_Physical_64_addressing_matrixUse_Param_level_zero.mlir index 288db45e1..9bfdfc848 100644 --- a/test/SPIRV/JointMatrix/gemm_using_joint_matrix_Physical_64_addressing_matrixUse_Param_level_zero.mlir +++ b/test/SPIRV/JointMatrix/gemm_using_joint_matrix_Physical_64_addressing_matrixUse_Param_level_zero.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/spirv-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck // CHECK-COUNT-4194304: 4970.23 module @gemm_using_jointmatrix_module attributes {gpu.container_module} { memref.global "private" constant @__constant_A_2048x2048xbf16 : memref<2048x2048xbf16> = dense<1.100000e+00> diff --git a/test/SPIRV/OpTest.ArgMax_BF16.mlir b/test/SPIRV/OpTest.ArgMax_BF16.mlir index 3e634bcfd..c22985981 100644 --- a/test/SPIRV/OpTest.ArgMax_BF16.mlir +++ b/test/SPIRV/OpTest.ArgMax_BF16.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner 
--requires=opencl-runtime -i %s --pass-pipeline-file=%p/spirv-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @argmax attributes {gpu.container_module} { memref.global "private" constant @__constant_1x4x4x3xbf16 : memref<1x4x4x3xbf16> = dense<[[[[9.000000e+00, 8.000000e+00, 0.000000e+00], [1.000000e+00, 5.000000e+00, 0.000000e+00], [1.000000e+00, 1.000000e+00, 7.000000e+00], [8.000000e+00, 2.000000e+00, 2.000000e+00]], [[8.000000e+00, 0.000000e+00, 4.000000e+00], [7.000000e+00, 5.000000e+00, 5.000000e+00], [8.000000e+00, 2.000000e+00, 0.000000e+00], [0.000000e+00, 9.000000e+00, 5.000000e+00]], [[4.000000e+00, 7.000000e+00, 2.000000e+00], [4.000000e+00, 5.000000e+00, 1.000000e+00], [3.000000e+00, 3.000000e+00, 6.000000e+00], [8.000000e+00, 0.000000e+00, 1.000000e+00]], [[2.000000e+00, 8.000000e+00, 4.000000e+00], [0.000000e+00, 5.000000e+00, 5.000000e+00], [6.000000e+00, 1.000000e+00, 1.000000e+00], [3.000000e+00, 3.000000e+00, 1.000000e+00]]]]> diff --git a/test/SPIRV/OpTest.Argmax_FLOAT32.mlir b/test/SPIRV/OpTest.Argmax_FLOAT32.mlir index 5555853d1..6af9f9a8d 100644 --- a/test/SPIRV/OpTest.Argmax_FLOAT32.mlir +++ b/test/SPIRV/OpTest.Argmax_FLOAT32.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/spirv-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @argmax attributes {gpu.container_module} { memref.global "private" constant @__constant_1x4x4x3xf32 : memref<1x4x4x3xf32> = dense<[[[[9.000000e+00, 8.000000e+00, 0.000000e+00], [1.000000e+00, 5.000000e+00, 0.000000e+00], [1.000000e+00, 1.000000e+00, 7.000000e+00], [8.000000e+00, 2.000000e+00, 2.000000e+00]], [[8.000000e+00, 0.000000e+00, 4.000000e+00], [7.000000e+00, 5.000000e+00, 5.000000e+00], [8.000000e+00, 2.000000e+00, 0.000000e+00], [0.000000e+00, 9.000000e+00, 5.000000e+00]], [[4.000000e+00, 7.000000e+00, 2.000000e+00], [4.000000e+00, 5.000000e+00, 1.000000e+00], [3.000000e+00, 3.000000e+00, 6.000000e+00], [8.000000e+00, 0.000000e+00, 1.000000e+00]], [[2.000000e+00, 8.000000e+00, 4.000000e+00], [0.000000e+00, 5.000000e+00, 5.000000e+00], [6.000000e+00, 1.000000e+00, 1.000000e+00], [3.000000e+00, 3.000000e+00, 1.000000e+00]]]]> diff --git a/test/SPIRV/OpTest.BroadcastNonNumpy_BF16.mlir b/test/SPIRV/OpTest.BroadcastNonNumpy_BF16.mlir index ac7c92939..c31119d2a 100644 --- a/test/SPIRV/OpTest.BroadcastNonNumpy_BF16.mlir +++ b/test/SPIRV/OpTest.BroadcastNonNumpy_BF16.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/spirv-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @broadcast_non_numpy attributes {gpu.container_module} { memref.global "private" constant @__constant_3xbf16 : 
memref<3xbf16> = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> diff --git a/test/SPIRV/OpTest.BroadcastNonNumpy_FLOAT32.mlir b/test/SPIRV/OpTest.BroadcastNonNumpy_FLOAT32.mlir index 8e8490994..8df9effff 100644 --- a/test/SPIRV/OpTest.BroadcastNonNumpy_FLOAT32.mlir +++ b/test/SPIRV/OpTest.BroadcastNonNumpy_FLOAT32.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/spirv-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @broadcast_non_numpy attributes {gpu.container_module} { memref.global "private" constant @__constant_3xf32 : memref<3xf32> = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> diff --git a/test/SPIRV/OpTest.Conv2D_FLOAT32.mlir b/test/SPIRV/OpTest.Conv2D_FLOAT32.mlir index 2879b769e..618ec5d84 100644 --- a/test/SPIRV/OpTest.Conv2D_FLOAT32.mlir +++ b/test/SPIRV/OpTest.Conv2D_FLOAT32.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/spirv-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @complex_conv_2d attributes {gpu.container_module} { diff --git a/test/SPIRV/OpTest.EltwiseAdd_BF16.mlir b/test/SPIRV/OpTest.EltwiseAdd_BF16.mlir index 035c1d1b8..becb4eefc 100644 --- a/test/SPIRV/OpTest.EltwiseAdd_BF16.mlir +++ b/test/SPIRV/OpTest.EltwiseAdd_BF16.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/spirv-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @eltwise_add attributes {gpu.container_module} { memref.global "private" constant @__constant_10x20xbf16 : memref<10x20xbf16> = dense<[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0], [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0], diff --git a/test/SPIRV/OpTest.EltwiseAdd_FLOAT32.mlir b/test/SPIRV/OpTest.EltwiseAdd_FLOAT32.mlir index dd2f97998..051acdd40 100644 --- a/test/SPIRV/OpTest.EltwiseAdd_FLOAT32.mlir +++ b/test/SPIRV/OpTest.EltwiseAdd_FLOAT32.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/spirv-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: 
--shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @eltwise_add attributes {gpu.container_module} { diff --git a/test/SPIRV/OpTest.ExplicitPadding_FLOAT32.mlir b/test/SPIRV/OpTest.ExplicitPadding_FLOAT32.mlir index 3f01dc794..5d44cde8a 100644 --- a/test/SPIRV/OpTest.ExplicitPadding_FLOAT32.mlir +++ b/test/SPIRV/OpTest.ExplicitPadding_FLOAT32.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/spirv-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @explicit_padding attributes {gpu.container_module} { diff --git a/test/SPIRV/OpTest.GEMM_BF16.mlir b/test/SPIRV/OpTest.GEMM_BF16.mlir index 54eaf982a..884f796c3 100644 --- a/test/SPIRV/OpTest.GEMM_BF16.mlir +++ b/test/SPIRV/OpTest.GEMM_BF16.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/spirv-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module} { memref.global "private" constant @__constant_3x3xbf16_result : memref<3x3xbf16> = dense<1.000000e+00> diff --git a/test/SPIRV/OpTest.GEMM_BF16_ACC_F32.mlir b/test/SPIRV/OpTest.GEMM_BF16_ACC_F32.mlir index 1347ff0ea..3eae1b633 100644 --- a/test/SPIRV/OpTest.GEMM_BF16_ACC_F32.mlir +++ b/test/SPIRV/OpTest.GEMM_BF16_ACC_F32.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/spirv-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module} { memref.global "private" constant @__constant_3x3xf32 : memref<3x3xf32> = dense<1.000000e+00> diff --git a/test/SPIRV/OpTest.GEMM_F16_ACC_F32.mlir b/test/SPIRV/OpTest.GEMM_F16_ACC_F32.mlir index 4d9675596..bf32c16f8 100644 --- a/test/SPIRV/OpTest.GEMM_F16_ACC_F32.mlir +++ b/test/SPIRV/OpTest.GEMM_F16_ACC_F32.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/spirv-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @gemm attributes {gpu.container_module} { diff --git a/test/SPIRV/OpTest.MaxPool1D_INT64.mlir b/test/SPIRV/OpTest.MaxPool1D_INT64.mlir index 
c84cf00ea..9303564a5 100644 --- a/test/SPIRV/OpTest.MaxPool1D_INT64.mlir +++ b/test/SPIRV/OpTest.MaxPool1D_INT64.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/spirv-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @max_pool_1d attributes {gpu.container_module} { diff --git a/test/SPIRV/OpTest.Quantize_FLOAT32.mlir b/test/SPIRV/OpTest.Quantize_FLOAT32.mlir index 1dfb018bc..b1db7a019 100644 --- a/test/SPIRV/OpTest.Quantize_FLOAT32.mlir +++ b/test/SPIRV/OpTest.Quantize_FLOAT32.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/spirv-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @quantize attributes {gpu.container_module} { diff --git a/test/SPIRV/OpTest.Relu_FLOAT32.mlir b/test/SPIRV/OpTest.Relu_FLOAT32.mlir index 942f67dc0..8cb3117ee 100644 --- a/test/SPIRV/OpTest.Relu_FLOAT32.mlir +++ b/test/SPIRV/OpTest.Relu_FLOAT32.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/spirv-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @relu attributes {gpu.container_module} { diff --git a/test/SPIRV/OpTest.SlmDynamic.mlir b/test/SPIRV/OpTest.SlmDynamic.mlir index a96336d3e..e148cfd5b 100644 --- a/test/SPIRV/OpTest.SlmDynamic.mlir +++ b/test/SPIRV/OpTest.SlmDynamic.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/spirv-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @slm attributes {gpu.container_module} { memref.global "private" constant @__constant_4x128xf32 : memref<4x128xf32> = dense<[ diff --git a/test/SPIRV/OpTest.Softmax_FLOAT32.mlir b/test/SPIRV/OpTest.Softmax_FLOAT32.mlir index bf8a8133a..11285a4ba 100644 --- a/test/SPIRV/OpTest.Softmax_FLOAT32.mlir +++ b/test/SPIRV/OpTest.Softmax_FLOAT32.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s 
--pass-pipeline-file=%p/spirv-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @softmax attributes {gpu.container_module} { diff --git a/test/SPIRV/OpTest.Sum_FLOAT32.mlir b/test/SPIRV/OpTest.Sum_FLOAT32.mlir index 8b11d4e89..0f57e77eb 100644 --- a/test/SPIRV/OpTest.Sum_FLOAT32.mlir +++ b/test/SPIRV/OpTest.Sum_FLOAT32.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/spirv-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @sum attributes {gpu.container_module} { diff --git a/test/SPIRV/OpTest.Transpose_FLOAT32.mlir b/test/SPIRV/OpTest.Transpose_FLOAT32.mlir index 73ef05c1e..3883487df 100644 --- a/test/SPIRV/OpTest.Transpose_FLOAT32.mlir +++ b/test/SPIRV/OpTest.Transpose_FLOAT32.mlir @@ -6,6 +6,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/spirv-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @transpose attributes {gpu.container_module} { diff --git a/test/SPIRV/OpTest.spirv.CL.printf.mlir b/test/SPIRV/OpTest.spirv.CL.printf.mlir index a7be187f6..9464792d3 100644 --- a/test/SPIRV/OpTest.spirv.CL.printf.mlir +++ b/test/SPIRV/OpTest.spirv.CL.printf.mlir @@ -2,6 +2,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/spirv-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @print_simple attributes {gpu.container_module} { diff --git a/test/SPIRV/relu.slm.static.8x32.mlir b/test/SPIRV/relu.slm.static.8x32.mlir index e73eda5fd..a1e07d3bb 100644 --- a/test/SPIRV/relu.slm.static.8x32.mlir +++ b/test/SPIRV/relu.slm.static.8x32.mlir @@ -2,6 +2,10 @@ // RUN: --runner imex-cpu-runner -e main \ // RUN: --entry-point-result=void \ // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=opencl-runtime -i %s --pass-pipeline-file=%p/spirv-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime --filecheck module @test attributes {gpu.container_module} { memref.global "private" constant @__constant_8x32xf32 : memref<8x32xf32> = dense<[ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
diff --git a/test/lit.cfg.py b/test/lit.cfg.py
index 4d7e9890f..fa19b9d04 100644
--- a/test/lit.cfg.py
+++ b/test/lit.cfg.py
@@ -41,6 +41,8 @@
 config.substitutions.append(('%python_executable', config.python_executable))
 if config.imex_enable_sycl_runtime:
     config.substitutions.append(('%sycl_runtime', config.sycl_runtime))
+if config.imex_enable_opencl_runtime:
+    config.substitutions.append(('%opencl_runtime', config.opencl_runtime))
 if config.imex_enable_l0_runtime:
     config.substitutions.append(('%levelzero_runtime', config.levelzero_runtime))
 if config.imex_enable_igpu:
diff --git a/test/lit.site.cfg.py.in b/test/lit.site.cfg.py.in
index 829e46573..a71dc657a 100644
--- a/test/lit.site.cfg.py.in
+++ b/test/lit.site.cfg.py.in
@@ -37,6 +37,7 @@ config.imex_tools_dir = "@IMEX_TOOLS_DIR@"
 config.imex_lib_dir = "@IMEX_LIB_DIR@"
 config.imex_enable_l0_runtime = @IMEX_ENABLE_L0_RUNTIME@
 config.imex_enable_sycl_runtime = @IMEX_ENABLE_SYCL_RUNTIME@
+config.imex_enable_opencl_runtime = @IMEX_ENABLE_OPENCL_RUNTIME@
 config.imex_enable_bf16_tests = @IMEX_ENABLE_BF16_TESTS@
 config.imex_enable_excluded_tests = @IMEX_ENABLE_EXCLUDED_TESTS@
 config.imex_enable_ats_target = @IMEX_ENABLE_ATS_TARGET@
@@ -53,9 +54,11 @@ if config.enable_vulkan_runner:
     config.vulkan_runtime_wrappers = os.path.normpath(os.path.join(config.mlir_runner_utils_dir, config.shlib_prefix + "vulkan-runtime-wrappers" + config.llvm_shlib_ext))
 if config.imex_enable_sycl_runtime:
     config.sycl_runtime = os.path.normpath(os.path.join(config.imex_lib_dir, config.shlib_prefix + "sycl-runtime" + config.llvm_shlib_ext))
+if config.imex_enable_opencl_runtime:
+    config.opencl_runtime = os.path.normpath(os.path.join(config.imex_lib_dir, config.shlib_prefix + "opencl-runtime" + config.llvm_shlib_ext))
 if config.imex_enable_l0_runtime:
     config.levelzero_runtime = os.path.normpath(os.path.join(config.imex_lib_dir, config.shlib_prefix + "level-zero-runtime" + config.llvm_shlib_ext))
-config.imex_enable_igpu = config.imex_enable_l0_runtime or config.imex_enable_sycl_runtime
+config.imex_enable_igpu = config.imex_enable_l0_runtime or config.imex_enable_sycl_runtime or config.imex_enable_opencl_runtime
 if config.imex_enable_igpu:
     config.l0_fp64_checker = os.path.normpath(os.path.join(config.imex_tools_dir, "l0-fp64-checker"))
 try:
diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt
index 41a1f23f0..b9cf66b55 100644
--- a/tools/CMakeLists.txt
+++ b/tools/CMakeLists.txt
@@ -1,6 +1,6 @@
 add_subdirectory(imex-runner)
 add_subdirectory(imex-opt)
-if(IMEX_ENABLE_L0_RUNTIME OR IMEX_ENABLE_SYCL_RUNTIME)
+if((IMEX_ENABLE_L0_RUNTIME OR IMEX_ENABLE_SYCL_RUNTIME) OR IMEX_ENABLE_OPENCL_RUNTIME)
   add_subdirectory(l0-fp64-checker)
 endif()
 add_subdirectory(imex-cpu-runner)
diff --git a/tools/imex-runner/imex-runner.py.in b/tools/imex-runner/imex-runner.py.in
index a94a9a8b7..48af07abc 100644
--- a/tools/imex-runner/imex-runner.py.in
+++ b/tools/imex-runner/imex-runner.py.in
@@ -53,10 +53,11 @@
 import subprocess
 imex_enable_vulkan_runner = @IMEX_ENABLE_VULKAN_RUNNER@
 imex_enable_l0_runtime = @IMEX_ENABLE_L0_RUNTIME@
 imex_enable_sycl_runtime = @IMEX_ENABLE_SYCL_RUNTIME@
+imex_enable_opencl_runtime = @IMEX_ENABLE_OPENCL_RUNTIME@
 runner_choices = ['imex-cpu-runner']
 enabled_features = []
-all_features = ['vulkan-runner', 'l0-runtime', 'sycl-runtime', 'igpu-fp64']
+all_features = ['vulkan-runner', 'l0-runtime', 'sycl-runtime', 'opencl-runtime', 'igpu-fp64']
 if imex_enable_vulkan_runner:
     runner_choices.append('mlir-vulkan-runner')
     enabled_features.append('vulkan-runner')
@@ -64,6 +65,8 @@ if imex_enable_l0_runtime:
     enabled_features.append('l0-runtime')
 if imex_enable_sycl_runtime:
     enabled_features.append('sycl-runtime')
+if imex_enable_opencl_runtime:
+    enabled_features.append('opencl-runtime')

 class SplitArgs(argparse.Action):
     def __call__(self, parser, namespace, values, option_string=None):
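
The `--requires=opencl-runtime` clause in the added RUN lines only stays harmless on builds without the OpenCL runtime because imex-runner skips invocations whose required feature is not in `enabled_features`. The following is a minimal illustrative sketch of that gating, assuming a comma-separated `--requires` flag and the feature list populated from the CMake options above; it is not the actual imex-runner implementation, and the names and skip behavior are placeholders.

```python
# Illustrative sketch only -- not the actual imex-runner code.
# Shows how a --requires=<feature[,feature]> flag can gate a test
# invocation on the runtimes enabled at configure time.
import argparse
import sys

# In imex-runner.py.in these come from CMake-substituted @...@ values.
enabled_features = ['l0-runtime', 'sycl-runtime', 'opencl-runtime']

parser = argparse.ArgumentParser()
parser.add_argument('--requires', default='',
                    help='comma-separated features this invocation needs')
args, _ = parser.parse_known_args()

required = [f for f in args.requires.split(',') if f]
missing = [f for f in required if f not in enabled_features]
if missing:
    print('skipping: feature(s) not enabled: ' + ', '.join(missing))
    sys.exit(0)  # a zero exit keeps lit from reporting this RUN line as a failure

print('would run the pass pipeline, imex-cpu-runner and FileCheck here')
```

In this sketch, exiting with status 0 when a feature is missing makes the RUN line a no-op instead of an error, so the same test file can carry Level Zero, SYCL, and OpenCL RUN lines regardless of which runtimes were built.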