Skip to content

Commit

Permalink
make format
Browse files Browse the repository at this point in the history
  • Loading branch information
multiphaseCFD committed Apr 30, 2024
1 parent 6889198 commit a71c2e4
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ namespace Pennylane::LightningTensor::Cutn {
*/

template <class PrecisionT>
class CudaTensor final: public TensorBase<PrecisionT, CudaTensor<PrecisionT>> {
class CudaTensor final : public TensorBase<PrecisionT, CudaTensor<PrecisionT>> {
public:
using BaseType = TensorBase<PrecisionT, CudaTensor>;
using CFP_t = decltype(cuUtil::getCudaType(PrecisionT{}));
Expand All @@ -52,7 +52,7 @@ class CudaTensor final: public TensorBase<PrecisionT, CudaTensor<PrecisionT>> {
: TensorBase<PrecisionT, CudaTensor<PrecisionT>>(rank, modes, extents),
data_buffer_{std::make_shared<DataBuffer<CFP_t>>(
BaseType::getLength(), dev_tag, device_alloc)} {}

CudaTensor() = delete;

~CudaTensor() final = default;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
#include <catch2/catch.hpp>
#include <cmath>

#include "DevTag.hpp"
#include "CudaTensor.hpp"
#include "DevTag.hpp"

#include "TestHelpers.hpp"

Expand All @@ -42,11 +42,12 @@ TEMPLATE_PRODUCT_TEST_CASE("CudaTensor::Constructibility",
using TensorT = TestType;

SECTION("TensorT<TestType>") { REQUIRE(!std::is_constructible_v<TensorT>); }
SECTION("TensorT<TestType> {const size_t, const std::vector<size_t> &, const "
"std::vector<size_t>&, DevTag<int> &}") {
REQUIRE(std::is_constructible_v<TensorT, const size_t, const std::vector<size_t> &,
const std::vector<size_t> &,
DevTag<int> &>);
SECTION(
"TensorT<TestType> {const size_t, const std::vector<size_t> &, const "
"std::vector<size_t>&, DevTag<int> &}") {
REQUIRE(std::is_constructible_v<
TensorT, const size_t, const std::vector<size_t> &,
const std::vector<size_t> &, DevTag<int> &>);
}
}

Expand All @@ -59,9 +60,7 @@ TEMPLATE_TEST_CASE("CudaTensor::baseMethods", "[CudaTensor]", float, double) {

CudaTensor<TestType> tensor{rank, modes, extents, dev_tag};

SECTION("getRank()") {
CHECK(tensor.getRank() == rank);
}
SECTION("getRank()") { CHECK(tensor.getRank() == rank); }

SECTION("getModes()") {
CHECK(tensor.getModes() == Pennylane::Util::approx(modes));
Expand All @@ -71,8 +70,5 @@ TEMPLATE_TEST_CASE("CudaTensor::baseMethods", "[CudaTensor]", float, double) {
CHECK(tensor.getExtents() == Pennylane::Util::approx(extents));
}

SECTION("getLength()") {
CHECK(tensor.getLength() == length);
}

SECTION("getLength()") { CHECK(tensor.getLength() == length); }
}

0 comments on commit a71c2e4

Please sign in to comment.