
Commit

update base on comments
multiphaseCFD committed May 2, 2024
1 parent 81e745e commit acc870f
Showing 2 changed files with 13 additions and 13 deletions.
@@ -64,14 +64,14 @@ class CutnBase : public TensornetBase<Precision, Derived> {
public:
CutnBase() = delete;

- CutnBase(const std::size_t numQubits, DevTag<int> &dev_tag)
+ explicit CutnBase(const std::size_t numQubits, DevTag<int> &dev_tag)
: BaseType(numQubits), handle_(make_shared_cutn_handle()),
dev_tag_(dev_tag) {
initHelper_();
}

- CutnBase(const std::size_t numQubits, int device_id = 0,
-          cudaStream_t stream_id = 0)
+ explicit CutnBase(const std::size_t numQubits, int device_id = 0,
+          cudaStream_t stream_id = 0)
: BaseType(numQubits), handle_(make_shared_cutn_handle()),
dev_tag_({device_id, stream_id}) {
initHelper_();
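
For context, marking these constructors `explicit` blocks implicit conversion from a bare integer qubit count into a state object, since the second overload is callable with a single argument. Below is a minimal, self-contained sketch of the same pattern, using a hypothetical `State` stand-in rather than the real `CutnBase`:

#include <cstddef>

// Hypothetical stand-in for CutnBase; only the `explicit` behaviour matters here.
struct State {
    explicit State(std::size_t numQubits, int device_id = 0) : n(numQubits) {
        (void)device_id;
    }
    std::size_t n;
};

void run(const State &s) { (void)s; }

int main() {
    // run(4);      // would not compile: no implicit std::size_t -> State conversion
    run(State{4});  // fine: the construction is spelled out at the call site
    return 0;
}
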
@@ -109,6 +109,7 @@ class CutnBase : public TensornetBase<Precision, Derived> {
return dev_tag_;
}

+ protected:
/**
* @brief Get the memory size used for a work space
*
@@ -127,7 +128,7 @@ class CutnBase : public TensornetBase<Precision, Derived> {
/* cutensornetWorkspaceKind_t */ CUTENSORNET_WORKSPACE_SCRATCH,
/* int64_t * */ &worksize));

- return worksize;
+ return static_cast<std::size_t>(worksize);
}

/**
@@ -138,17 +139,16 @@ class CutnBase : public TensornetBase<Precision, Derived> {
* @param worksize Memory size of a work space
*/
void setWorkSpaceMemory(cutensornetWorkspaceDescriptor_t &workDesc,
- void *scratchPtr, int64_t &worksize) {
+ void *scratchPtr, size_t worksize) {
PL_CUTENSORNET_IS_SUCCESS(cutensornetWorkspaceSetMemory(
/* const cutensornetHandle_t */ getCutnHandle(),
/* cutensornetWorkspaceDescriptor_t */ workDesc,
/* cutensornetMemspace_t*/ CUTENSORNET_MEMSPACE_DEVICE,
/* cutensornetWorkspaceKind_t */ CUTENSORNET_WORKSPACE_SCRATCH,
/* void *const */ scratchPtr,
- /* int64_t */ worksize));
+ /* int64_t */ static_cast<int64_t>(worksize)));
}

- protected:
/**
* @brief Save quantumState information to data provided by a user
*
@@ -172,12 +172,12 @@
/* cutensornetWorkspaceDescriptor_t */ workDesc,
/* cudaStream_t unused in v24.03*/ 0x0));

- int64_t worksize = getWorkSpaceMemorySize(workDesc);
+ std::size_t worksize = getWorkSpaceMemorySize(workDesc);

// Ensure data is aligned by 256 bytes
- worksize += int64_t{256} - worksize % int64_t{256};
+ worksize += std::size_t{256} - worksize % std::size_t{256};

- PL_ABORT_IF(static_cast<std::size_t>(worksize) > scratchSize,
+ PL_ABORT_IF(std::size_t(worksize) > scratchSize,
"Insufficient workspace size on Device!");

const std::size_t d_scratch_length = worksize / sizeof(std::size_t);
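
The worksize handling above keeps the host-side bookkeeping in std::size_t and casts back to int64_t only at the cuTensorNet API boundary; the padding expression rounds the size up to the next 256-byte boundary and, as written, adds a full 256 bytes when the size is already aligned. A small standalone check of that arithmetic, illustrative only and not part of this diff:

#include <cassert>
#include <cstddef>

// Same round-up expression as the diff: pad to the next 256-byte boundary.
std::size_t padTo256(std::size_t worksize) {
    worksize += std::size_t{256} - worksize % std::size_t{256};
    return worksize;
}

int main() {
    assert(padTo256(1) == 256);
    assert(padTo256(255) == 256);
    assert(padTo256(257) == 512);
    assert(padTo256(512) == 768); // an already-aligned size still grows by one block
    return 0;
}
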
@@ -123,8 +123,8 @@ class MPSCutn final : public CutnBase<Precision, MPSCutn<Precision>> {
[[nodiscard]] auto getTensorsDataPtr() -> std::vector<uint64_t *> {
std::vector<uint64_t *> tensorsDataPtr(BaseType::getNumQubits());
for (std::size_t i = 0; i < BaseType::getNumQubits(); i++) {
- tensorsDataPtr[i] =
-     reinterpret_cast<uint64_t *>(tensors_[i].getDataBuffer().getData());
+ tensorsDataPtr[i] = reinterpret_cast<uint64_t *>(
+     tensors_[i].getDataBuffer().getData());
}
return tensorsDataPtr;
}
@@ -335,7 +335,7 @@ class MPSCutn final : public CutnBase<Precision, MPSCutn<Precision>> {
CUTENSORNET_BOUNDARY_CONDITION_OPEN,
/*const int64_t *const*/ extentsIn,
/*const int64_t *const*/ nullptr,
- /*void **/ reinterpret_cast<void**>(tensorsIn)));
+ /*void **/ reinterpret_cast<void **>(tensorsIn)));
}
};
} // namespace Pennylane::LightningTensor::Cutn
