From a7d06441734b2813371c9bbf96302d9da4bec292 Mon Sep 17 00:00:00 2001 From: gandalfr-KY Date: Fri, 13 Dec 2024 02:52:32 +0000 Subject: [PATCH] =?UTF-8?q?=E4=B8=80=E9=83=A8=E3=81=AEget=5Ffrom=5Fjson?= =?UTF-8?q?=E3=82=92=E5=AE=9F=E8=A3=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- exe/main.cpp | 15 +++- include/scaluq/gate/gate.hpp | 35 +++++++- include/scaluq/gate/gate_standard.hpp | 125 ++++++++++++++++++++++++++ 3 files changed, 168 insertions(+), 7 deletions(-) diff --git a/exe/main.cpp b/exe/main.cpp index fd8fd536..6247f33a 100644 --- a/exe/main.cpp +++ b/exe/main.cpp @@ -103,9 +103,18 @@ int main() { std::cout << Json(paramprobgate) << std::endl; } { - Gate g; - XGate x; - g = x; + auto x = gate::X(1, {2}); + Json j = x; + std::cout << j << std::endl; + Gate gate = j; + std::cout << gate << std::endl; + } + { + auto x = gate::RX(1, 0.5, {2}); + Json j = x; + std::cout << j << std::endl; + Gate gate = j; + std::cout << gate << std::endl; } Kokkos::finalize(); diff --git a/include/scaluq/gate/gate.hpp b/include/scaluq/gate/gate.hpp index d3288726..86d4f5d3 100644 --- a/include/scaluq/gate/gate.hpp +++ b/include/scaluq/gate/gate.hpp @@ -223,6 +223,9 @@ class GateBase : public std::enable_shared_from_this> { template concept GateImpl = std::derived_from>; +template +inline std::shared_ptr get_from_json(const Json&); + template class GatePtr { friend class GateFactory; @@ -289,10 +292,34 @@ class GatePtr { friend void to_json(Json& j, const GatePtr& gate) { gate->get_as_json(j); } friend void from_json(const Json& j, GatePtr& gate) { std::string type = j.at("type"); - if (type == "X") { - auto target = j.at("target"); - auto control = j.at("control"); - gate = std::shared_ptr(); + if (type == "I") { + gate = get_from_json>(j); + } else if (type == "GlobalPhase") { + gate = get_from_json>(j); + } else if (type == "X") { + gate = get_from_json>(j); + } else if (type == "Y") { + gate = get_from_json>(j); + } else if (type == "Z") { + gate = get_from_json>(j); + } else if (type == "H") { + gate = get_from_json>(j); + } else if (type == "S") { + gate = get_from_json>(j); + } else if (type == "Sdag") { + gate = get_from_json>(j); + } else if (type == "T") { + gate = get_from_json>(j); + } else if (type == "Tdag") { + gate = get_from_json>(j); + } + + else if (type == "RX") { + gate = get_from_json>(j); + } else if (type == "RY") { + gate = get_from_json>(j); + } else if (type == "RZ") { + gate = get_from_json>(j); } } }; diff --git a/include/scaluq/gate/gate_standard.hpp b/include/scaluq/gate/gate_standard.hpp index 73909323..05c5f484 100644 --- a/include/scaluq/gate/gate_standard.hpp +++ b/include/scaluq/gate/gate_standard.hpp @@ -561,6 +561,131 @@ class SwapGateImpl : public GateBase { } // namespace internal +template +using IGate = internal::GatePtr>; +template +using GlobalPhaseGate = internal::GatePtr>; +template +using XGate = internal::GatePtr>; +template +using YGate = internal::GatePtr>; +template +using ZGate = internal::GatePtr>; +template +using HGate = internal::GatePtr>; +template +using SGate = internal::GatePtr>; +template +using SdagGate = internal::GatePtr>; +template +using TGate = internal::GatePtr>; +template +using TdagGate = internal::GatePtr>; +template +using SqrtXGate = internal::GatePtr>; +template +using SqrtXdagGate = internal::GatePtr>; +template +using SqrtYGate = internal::GatePtr>; +template +using SqrtYdagGate = internal::GatePtr>; +template +using P0Gate = internal::GatePtr>; +template +using P1Gate = internal::GatePtr>; +template +using RXGate = internal::GatePtr>; +template +using RYGate = internal::GatePtr>; +template +using RZGate = internal::GatePtr>; +template +using U1Gate = internal::GatePtr>; +template +using U2Gate = internal::GatePtr>; +template +using U3Gate = internal::GatePtr>; +template +using SwapGate = internal::GatePtr>; + +namespace internal { // for json implemention +template <> +inline std::shared_ptr> get_from_json(const Json&) { + return std::make_shared>(); +} +template <> +inline std::shared_ptr> get_from_json(const Json&) { + return std::make_shared>(); +} + +template <> +inline std::shared_ptr> get_from_json(const Json& j) { + auto controls = j.at("control").get>(); + double phase = j.at("phase").get(); + return std::make_shared>(vector_to_mask(controls), phase); +} +template <> +inline std::shared_ptr> get_from_json(const Json& j) { + auto controls = j.at("control").get>(); + float phase = j.at("phase").get(); + return std::make_shared>(vector_to_mask(controls), phase); +} + +#define DECLARE_GET_FROM_JSON(Impl) \ + template <> \ + inline std::shared_ptr> get_from_json(const Json& j) { \ + auto targets = j.at("target").get>(); \ + auto controls = j.at("control").get>(); \ + return std::make_shared>(vector_to_mask(targets), \ + vector_to_mask(controls)); \ + } \ + template <> \ + inline std::shared_ptr> get_from_json(const Json& j) { \ + auto targets = j.at("target").get>(); \ + auto controls = j.at("control").get>(); \ + return std::make_shared>(vector_to_mask(targets), \ + vector_to_mask(controls)); \ + } + +DECLARE_GET_FROM_JSON(XGateImpl); +DECLARE_GET_FROM_JSON(YGateImpl); +DECLARE_GET_FROM_JSON(ZGateImpl); +DECLARE_GET_FROM_JSON(HGateImpl); +DECLARE_GET_FROM_JSON(SGateImpl); +DECLARE_GET_FROM_JSON(SdagGateImpl); +DECLARE_GET_FROM_JSON(TGateImpl); +DECLARE_GET_FROM_JSON(TdagGateImpl); +DECLARE_GET_FROM_JSON(SqrtXGateImpl); +DECLARE_GET_FROM_JSON(SqrtXdagGateImpl); +DECLARE_GET_FROM_JSON(SqrtYGateImpl); +DECLARE_GET_FROM_JSON(SqrtYdagGateImpl); +DECLARE_GET_FROM_JSON(P0GateImpl); +DECLARE_GET_FROM_JSON(P1GateImpl); + +#define DECLARE_GET_FROM_JSON_RGATE(Impl) \ + template <> \ + inline std::shared_ptr> get_from_json(const Json& j) { \ + auto targets = j.at("target").get>(); \ + auto controls = j.at("control").get>(); \ + double angle = j.at("angle").get(); \ + return std::make_shared>( \ + vector_to_mask(targets), vector_to_mask(controls), angle); \ + } \ + template <> \ + inline std::shared_ptr> get_from_json(const Json& j) { \ + auto targets = j.at("target").get>(); \ + auto controls = j.at("control").get>(); \ + float angle = j.at("angle").get(); \ + return std::make_shared>( \ + vector_to_mask(targets), vector_to_mask(controls), angle); \ + } + +DECLARE_GET_FROM_JSON_RGATE(RXGateImpl); +DECLARE_GET_FROM_JSON_RGATE(RYGateImpl); +DECLARE_GET_FROM_JSON_RGATE(RZGateImpl); + +} // namespace internal + #ifdef SCALUQ_USE_NANOBIND namespace internal { void bind_gate_gate_standard_hpp(nb::module_& m) {