diff --git a/cmake/support_pltensorcutn.cmake b/cmake/support_pltensorcutn.cmake index 33f9f166d0..1fca27ebc5 100644 --- a/cmake/support_pltensorcutn.cmake +++ b/cmake/support_pltensorcutn.cmake @@ -47,7 +47,7 @@ macro(findCutensornet external_libs) get_filename_component(CUTENSORNET_INC_DIR ${CUTENSORNET_INC} DIRECTORY) target_include_directories(cutensornet INTERFACE ${CUTENSORNET_INC_DIR}) - set_target_properties( cutensornet PROPERTIES IMPORTED_LOCATION ${CUTENSORNET_LIB}) + set_target_properties(cutensornet PROPERTIES IMPORTED_LOCATION ${CUTENSORNET_LIB}) target_link_libraries(${external_libs} INTERFACE cutensornet) endif() diff --git a/pennylane_lightning/core/src/simulators/lightning_tensor/cutn/CutnBase.hpp b/pennylane_lightning/core/src/simulators/lightning_tensor/cutn/CutnBase.hpp index 8a07373994..d4c5c96560 100644 --- a/pennylane_lightning/core/src/simulators/lightning_tensor/cutn/CutnBase.hpp +++ b/pennylane_lightning/core/src/simulators/lightning_tensor/cutn/CutnBase.hpp @@ -14,7 +14,7 @@ /** * @file CutnBase.hpp - * Base class for classes backed by the cuTensorNet library. + * Base class for cuTensorNet-backed tensor networks. */ #pragma once @@ -111,7 +111,7 @@ class CutnBase : public TensornetBase { protected: /** - * @brief Get the memory size used for a work space + * @brief Returns the workspace size. * * @return std::size_t */ @@ -127,12 +127,15 @@ class CutnBase : public TensornetBase { /* cutensornetMemspace_t*/ CUTENSORNET_MEMSPACE_DEVICE, /* cutensornetWorkspaceKind_t */ CUTENSORNET_WORKSPACE_SCRATCH, /* int64_t * */ &worksize)); + + // Ensure data is aligned by 256 bytes + worksize += int64_t{256} - worksize % int64_t{256}; return static_cast(worksize); } /** - * @brief Set the memory for a work space + * @brief Set memory for a workspace. * * @param workDesc cutensornet work space descriptor * @param scratchPtr Pointer to scratch memory @@ -174,9 +177,6 @@ class CutnBase : public TensornetBase { std::size_t worksize = getWorkSpaceMemorySize(workDesc); - // Ensure data is aligned by 256 bytes - worksize += std::size_t{256} - worksize % std::size_t{256}; - PL_ABORT_IF(std::size_t(worksize) > scratchSize, "Insufficient workspace size on Device!");