Skip to content

Commit

Permalink
update memory set
Browse files Browse the repository at this point in the history
  • Loading branch information
multiphaseCFD committed May 3, 2024
1 parent 0a91c9e commit 1318aa5
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion cmake/support_pltensorcutn.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ macro(findCutensornet external_libs)

get_filename_component(CUTENSORNET_INC_DIR ${CUTENSORNET_INC} DIRECTORY)
target_include_directories(cutensornet INTERFACE ${CUTENSORNET_INC_DIR})
set_target_properties( cutensornet PROPERTIES IMPORTED_LOCATION ${CUTENSORNET_LIB})
set_target_properties(cutensornet PROPERTIES IMPORTED_LOCATION ${CUTENSORNET_LIB})

target_link_libraries(${external_libs} INTERFACE cutensornet)
endif()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

/**
* @file CutnBase.hpp
* Base class for classes backed by the cuTensorNet library.
* Base class for cuTensorNet-backed tensor networks.
*/

#pragma once
Expand Down Expand Up @@ -111,7 +111,7 @@ class CutnBase : public TensornetBase<Precision, Derived> {

protected:
/**
* @brief Get the memory size used for a work space
* @brief Returns the workspace size.
*
* @return std::size_t
*/
Expand All @@ -127,12 +127,15 @@ class CutnBase : public TensornetBase<Precision, Derived> {
/* cutensornetMemspace_t*/ CUTENSORNET_MEMSPACE_DEVICE,
/* cutensornetWorkspaceKind_t */ CUTENSORNET_WORKSPACE_SCRATCH,
/* int64_t * */ &worksize));

// Ensure data is aligned by 256 bytes
worksize += int64_t{256} - worksize % int64_t{256};

return static_cast<std::size_t>(worksize);
}

/**
* @brief Set the memory for a work space
* @brief Set memory for a workspace.
*
* @param workDesc cutensornet work space descriptor
* @param scratchPtr Pointer to scratch memory
Expand Down Expand Up @@ -174,9 +177,6 @@ class CutnBase : public TensornetBase<Precision, Derived> {

std::size_t worksize = getWorkSpaceMemorySize(workDesc);

// Ensure data is aligned by 256 bytes
worksize += std::size_t{256} - worksize % std::size_t{256};

PL_ABORT_IF(std::size_t(worksize) > scratchSize,
"Insufficient workspace size on Device!");

Expand Down

0 comments on commit 1318aa5

Please sign in to comment.