Skip to content

Commit

Permalink
modify code for cuda 16bit
Browse files Browse the repository at this point in the history
  • Loading branch information
KowerKoint authored and KowerKoint committed Dec 13, 2024
1 parent b721756 commit 6c2de1e
Show file tree
Hide file tree
Showing 42 changed files with 1,478 additions and 470 deletions.
12 changes: 12 additions & 0 deletions exe/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,17 @@ int main() {
std::cout << std::chrono::duration_cast<std::chrono::milliseconds>(ed - st).count()
<< std::endl;
}
{
using Fp = scaluq::BF16;
scaluq::StateVector<Fp> state(n_qubits);
auto st = std::chrono::system_clock::now();
for (int i = 0; i < 10000; i++) {
auto x_gate = scaluq::gate::X<Fp>(dist(mt));
x_gate->update_quantum_state(state);
}
auto ed = std::chrono::system_clock::now();
std::cout << std::chrono::duration_cast<std::chrono::milliseconds>(ed - st).count()
<< std::endl;
}
Kokkos::finalize();
}
4 changes: 2 additions & 2 deletions include/scaluq/circuit/circuit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

namespace scaluq {

template <std::floating_point Fp>
template <FloatingPoint Fp>
class Circuit {
public:
using GateWithKey = std::variant<Gate<Fp>, std::pair<ParamGate<Fp>, std::string>>;
Expand Down Expand Up @@ -76,7 +76,7 @@ class Circuit {

#ifdef SCALUQ_USE_NANOBIND
namespace internal {
template <std::floating_point Fp>
template <FloatingPoint Fp>
void bind_circuit_circuit_hpp(nb::module_& m) {
nb::class_<Circuit<Fp>>(m, "Circuit", "Quantum circuit represented as gate array")
.def(nb::init<std::uint64_t>(), "Initialize empty circuit of specified qubits.")
Expand Down
48 changes: 24 additions & 24 deletions include/scaluq/constant.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,22 @@ KOKKOS_INLINE_FUNCTION
constexpr double SINPI8() { return 0.382683432365090; }

//! identity matrix
template <std::floating_point Fp>
template <FloatingPoint Fp>
KOKKOS_INLINE_FUNCTION Matrix2x2<Fp> I_GATE() {
return {1, 0, 0, 1};
}
//! Pauli matrix X
template <std::floating_point Fp>
template <FloatingPoint Fp>
KOKKOS_INLINE_FUNCTION Matrix2x2<Fp> X_GATE() {
return {0, 1, 1, 0};
}
//! Pauli matrix Y
template <std::floating_point Fp>
template <FloatingPoint Fp>
KOKKOS_INLINE_FUNCTION Matrix2x2<Fp> Y_GATE() {
return {0, Complex<Fp>(0, -1), Complex<Fp>(0, 1), 0};
}
//! Pauli matrix Z
template <std::floating_point Fp>
template <FloatingPoint Fp>
KOKKOS_INLINE_FUNCTION Matrix2x2<Fp> Z_GATE() {
return {1, 0, 0, -1};
}
Expand All @@ -47,84 +47,84 @@ KOKKOS_INLINE_FUNCTION Matrix2x2<Fp> Z_GATE() {
// Y_GATE, Z_GATE};

//! S-gate
template <std::floating_point Fp>
template <FloatingPoint Fp>
KOKKOS_INLINE_FUNCTION Matrix2x2<Fp> S_GATE_MATRIX() {
return {1, 0, 0, Complex<Fp>(0, 1)};
}
//! Sdag-gate
template <std::floating_point Fp>
template <FloatingPoint Fp>
KOKKOS_INLINE_FUNCTION Matrix2x2<Fp> S_DAG_GATE_MATRIX() {
return {1, 0, 0, Complex<Fp>(0, -1)};
}
//! T-gate
template <std::floating_point Fp>
template <FloatingPoint Fp>
KOKKOS_INLINE_FUNCTION Matrix2x2<Fp> T_GATE_MATRIX() {
return {1, 0, 0, Complex<Fp>(INVERSE_SQRT2(), INVERSE_SQRT2())};
}
//! Tdag-gate
template <std::floating_point Fp>
template <FloatingPoint Fp>
KOKKOS_INLINE_FUNCTION Matrix2x2<Fp> T_DAG_GATE_MATRIX() {
return {1, 0, 0, Complex<Fp>(INVERSE_SQRT2(), -INVERSE_SQRT2())};
}
//! Hadamard gate
template <std::floating_point Fp>
template <FloatingPoint Fp>
KOKKOS_INLINE_FUNCTION Matrix2x2<Fp> HADAMARD_MATRIX() {
constexpr Fp ISQRT2 = static_cast<Fp>(INVERSE_SQRT2());
Fp ISQRT2 = static_cast<Fp>(INVERSE_SQRT2());
return {ISQRT2, ISQRT2, ISQRT2, -ISQRT2};
}
//! square root of X gate
template <std::floating_point Fp>
template <FloatingPoint Fp>
KOKKOS_INLINE_FUNCTION Matrix2x2<Fp> SQRT_X_GATE_MATRIX() {
constexpr Fp HALF = static_cast<Fp>(0.5);
Fp HALF = static_cast<Fp>(0.5);
return {Complex<Fp>(HALF, HALF),
Complex<Fp>(HALF, -HALF),
Complex<Fp>(HALF, -HALF),
Complex<Fp>(HALF, HALF)};
}
//! square root of Y gate
template <std::floating_point Fp>
template <FloatingPoint Fp>
KOKKOS_INLINE_FUNCTION Matrix2x2<Fp> SQRT_Y_GATE_MATRIX() {
constexpr Fp HALF = static_cast<Fp>(0.5);
Fp HALF = static_cast<Fp>(0.5);
return {Complex<Fp>(HALF, HALF),
Complex<Fp>(-HALF, -HALF),
Complex<Fp>(HALF, HALF),
Complex<Fp>(HALF, HALF)};
}
//! square root dagger of X gate
template <std::floating_point Fp>
template <FloatingPoint Fp>
KOKKOS_INLINE_FUNCTION Matrix2x2<Fp> SQRT_X_DAG_GATE_MATRIX() {
constexpr Fp HALF = static_cast<Fp>(0.5);
Fp HALF = static_cast<Fp>(0.5);
return {Complex<Fp>(HALF, -HALF),
Complex<Fp>(HALF, HALF),
Complex<Fp>(HALF, HALF),
Complex<Fp>(HALF, -HALF)};
}
//! square root dagger of Y gate
template <std::floating_point Fp>
template <FloatingPoint Fp>
KOKKOS_INLINE_FUNCTION Matrix2x2<Fp> SQRT_Y_DAG_GATE_MATRIX() {
constexpr Fp HALF = static_cast<Fp>(0.5);
Fp HALF = static_cast<Fp>(0.5);
return {Complex<Fp>(HALF, -HALF),
Complex<Fp>(HALF, -HALF),
Complex<Fp>(-HALF, HALF),
Complex<Fp>(HALF, -HALF)};
}
//! Projection to 0
template <std::floating_point Fp>
template <FloatingPoint Fp>
KOKKOS_INLINE_FUNCTION Matrix2x2<Fp> PROJ_0_MATRIX() {
return {1, 0, 0, 0};
return {Fp{1}, Fp{0}, Fp{0}, Fp{0}};
}
//! Projection to 1
template <std::floating_point Fp>
template <FloatingPoint Fp>
KOKKOS_INLINE_FUNCTION Matrix2x2<Fp> PROJ_1_MATRIX() {
return {0, 0, 0, 1};
return {Fp{0}, Fp{0}, Fp{0}, Fp{1}};
}
//! complex values for exp(j * i*pi/4 )
template <std::floating_point Fp>
template <FloatingPoint Fp>
KOKKOS_INLINE_FUNCTION Kokkos::Array<Complex<Fp>, 4> PHASE_90ROT() {
return {Fp{1}, Complex<Fp>(0, 1), Fp{-1}, Complex<Fp>(0, -1)};
}
//! complex values for exp(-j * i*pi/4 )
template <std::floating_point Fp>
template <FloatingPoint Fp>
KOKKOS_INLINE_FUNCTION Kokkos::Array<Complex<Fp>, 4> PHASE_M90ROT() {
return {Fp{1}, Complex<Fp>(0, -1), Fp{-1}, Complex<Fp>(0, 1)};
}
Expand Down
72 changes: 36 additions & 36 deletions include/scaluq/gate/gate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,68 +8,68 @@ namespace scaluq {
namespace internal {
// forward declarations

template <std::floating_point Fp>
template <FloatingPoint Fp>
class GateBase;

template <std::floating_point Fp>
template <FloatingPoint Fp>
class IGateImpl;
template <std::floating_point Fp>
template <FloatingPoint Fp>
class GlobalPhaseGateImpl;
template <std::floating_point Fp>
template <FloatingPoint Fp>
class XGateImpl;
template <std::floating_point Fp>
template <FloatingPoint Fp>
class YGateImpl;
template <std::floating_point Fp>
template <FloatingPoint Fp>
class ZGateImpl;
template <std::floating_point Fp>
template <FloatingPoint Fp>
class HGateImpl;
template <std::floating_point Fp>
template <FloatingPoint Fp>
class SGateImpl;
template <std::floating_point Fp>
template <FloatingPoint Fp>
class SdagGateImpl;
template <std::floating_point Fp>
template <FloatingPoint Fp>
class TGateImpl;
template <std::floating_point Fp>
template <FloatingPoint Fp>
class TdagGateImpl;
template <std::floating_point Fp>
template <FloatingPoint Fp>
class SqrtXGateImpl;
template <std::floating_point Fp>
template <FloatingPoint Fp>
class SqrtXdagGateImpl;
template <std::floating_point Fp>
template <FloatingPoint Fp>
class SqrtYGateImpl;
template <std::floating_point Fp>
template <FloatingPoint Fp>
class SqrtYdagGateImpl;
template <std::floating_point Fp>
template <FloatingPoint Fp>
class P0GateImpl;
template <std::floating_point Fp>
template <FloatingPoint Fp>
class P1GateImpl;
template <std::floating_point Fp>
template <FloatingPoint Fp>
class RXGateImpl;
template <std::floating_point Fp>
template <FloatingPoint Fp>
class RYGateImpl;
template <std::floating_point Fp>
template <FloatingPoint Fp>
class RZGateImpl;
template <std::floating_point Fp>
template <FloatingPoint Fp>
class U1GateImpl;
template <std::floating_point Fp>
template <FloatingPoint Fp>
class U2GateImpl;
template <std::floating_point Fp>
template <FloatingPoint Fp>
class U3GateImpl;
template <std::floating_point Fp>
template <FloatingPoint Fp>
class OneTargetMatrixGateImpl;
template <std::floating_point Fp>
template <FloatingPoint Fp>
class SwapGateImpl;
template <std::floating_point Fp>
template <FloatingPoint Fp>
class TwoTargetMatrixGateImpl;
template <std::floating_point Fp>
template <FloatingPoint Fp>
class PauliGateImpl;
template <std::floating_point Fp>
template <FloatingPoint Fp>
class PauliRotationGateImpl;
template <std::floating_point Fp>
template <FloatingPoint Fp>
class ProbablisticGateImpl;
template <std::floating_point Fp>
template <FloatingPoint Fp>
class SparseMatrixGateImpl;
template <std::floating_point Fp>
template <FloatingPoint Fp>
class DenseMatrixGateImpl;

} // namespace internal
Expand Down Expand Up @@ -108,7 +108,7 @@ enum class GateType {
Probablistic
};

template <typename T, std::floating_point S>
template <typename T, FloatingPoint S>
constexpr GateType get_gate_type() {
using TWithoutConst = std::remove_cv_t<T>;
if constexpr (std::is_same_v<TWithoutConst, internal::GateBase<S>>)
Expand Down Expand Up @@ -179,7 +179,7 @@ constexpr GateType get_gate_type() {

namespace internal {
// GateBase テンプレートクラス
template <std::floating_point _FloatType>
template <FloatingPoint _FloatType>
class GateBase : public std::enable_shared_from_this<GateBase<_FloatType>> {
public:
using Fp = _FloatType;
Expand Down Expand Up @@ -287,7 +287,7 @@ class GatePtr {

} // namespace internal

template <std::floating_point Fp>
template <FloatingPoint Fp>
using Gate = internal::GatePtr<internal::GateBase<Fp>>;

#ifdef SCALUQ_USE_NANOBIND
Expand Down Expand Up @@ -342,7 +342,7 @@ namespace internal {
[](const GATE_TYPE<FLOAT>& gate) { return gate->to_string(""); }, \
"Get string representation of the gate.")

template <std::floating_point Fp>
template <FloatingPoint Fp>
nb::class_<Gate<Fp>> gate_base_def;

#define DEF_GATE(GATE_TYPE, FLOAT, DESCRIPTION) \
Expand Down Expand Up @@ -386,7 +386,7 @@ void bind_gate_gate_hpp_without_precision(nb::module_& m) {
.value("PauliRotation", GateType::PauliRotation);
}

template <std::floating_point Fp>
template <FloatingPoint Fp>
void bind_gate_gate_hpp(nb::module_& m) {
gate_base_def<Fp> =
DEF_GATE_BASE(Gate,
Expand Down
Loading

0 comments on commit 6c2de1e

Please sign in to comment.