Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Glacialte committed Dec 13, 2024
1 parent 7fe0406 commit 84a3603
Showing 1 changed file with 22 additions and 25 deletions.
47 changes: 22 additions & 25 deletions tests/gate/batched_gate_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,8 @@ void run_random_batched_gate_apply_pauli(std::uint64_t n_qubits) {

for (std::uint64_t batch_id = 0; batch_id < states.batch_size(); batch_id++) {
for (std::uint64_t i = 0; i < dim; i++) {
check_near((StdComplex<Fp>)states_cp[batch_id][i], test_state[i]);
check_near((StdComplex<Fp>)states_cp[batch_id][i],
(StdComplex<Fp>)states_bef_cp[batch_id][i]);
}
}
}
Expand Down Expand Up @@ -379,7 +380,8 @@ void run_random_batched_gate_apply_pauli(std::uint64_t n_qubits) {

for (std::uint64_t batch_id = 0; batch_id < states.batch_size(); batch_id++) {
for (std::uint64_t i = 0; i < dim; i++) {
check_near((StdComplex<Fp>)states_cp[batch_id][i], test_state[i]);
check_near((StdComplex<Fp>)states_cp[batch_id][i],
(StdComplex<Fp>)states_bef_cp[batch_id][i]);
}
}
}
Expand Down Expand Up @@ -713,10 +715,10 @@ TEST(BatchedGateTest, ApplyDenseMatrixGate) {
run_random_batched_gate_apply_general_dense<float>(6);
run_random_batched_gate_apply_general_dense<double>(6);
}
// TEST(BatchedGateTest, ApplyPauliGate) {
// run_random_batched_gate_apply_pauli<float>(5);
// run_random_batched_gate_apply_pauli<double>(5);
// }
TEST(BatchedGateTest, ApplyPauliGate) {
run_random_batched_gate_apply_pauli<float>(5);
run_random_batched_gate_apply_pauli<double>(5);
}

TEST(BatchedGateTest, ApplyProbablisticGate) {
{
Expand Down Expand Up @@ -778,24 +780,24 @@ void test_batched_gate(Gate<Fp> gate_control,
StateVectorBatched<Fp> states =
StateVectorBatched<Fp>::Haar_random_state(BATCH_SIZE, n_qubits, true);
auto amplitudes = states.get_amplitudes();
StateVectorBatched<Fp> state_controlled(BATCH_SIZE, n_qubits - std::popcount(control_mask));
StateVectorBatched<Fp> states_controlled(BATCH_SIZE, n_qubits - std::popcount(control_mask));
std::vector<std::vector<Complex<Fp>>> amplitudes_controlled(
BATCH_SIZE, std::vector<Complex<Fp>>(state_controlled.dim()));
BATCH_SIZE, std::vector<Complex<Fp>>(states_controlled.dim()));
for (std::size_t i = 0; i < BATCH_SIZE; i++) {
for (std::uint64_t j = 0; j < state_controlled.dim(); j++) {
for (std::uint64_t j = 0; j < states_controlled.dim(); j++) {
amplitudes_controlled[i][j] =
amplitudes[i]
[internal::insert_zero_at_mask_positions(j, control_mask) | control_mask];
}
}
state_controlled.load(amplitudes_controlled);
states_controlled.load(amplitudes_controlled);
gate_control->update_quantum_state(states);
gate_simple->update_quantum_state(state_controlled);
gate_simple->update_quantum_state(states_controlled);
amplitudes = states.get_amplitudes();
amplitudes_controlled = state_controlled.get_amplitudes();
for (std::size_t i = 0; i < BATCH_SIZE; i++) {
for (std::uint64_t j = 0; j < state_controlled.dim(); j++) {
check_near((StdComplex<Fp>)amplitudes[i][j],
amplitudes_controlled = states_controlled.get_amplitudes();
for (std::uint64_t i = 0; i < BATCH_SIZE; i++) {
for (std::uint64_t j : std::views::iota(0ULL, states_controlled.dim())) {
check_near((StdComplex<Fp>)amplitudes_controlled[i][j],
(StdComplex<Fp>)
amplitudes[i][internal::insert_zero_at_mask_positions(j, control_mask) |
control_mask]);
Expand All @@ -809,12 +811,7 @@ template <std::floating_point Fp,
typename Factory>
void test_batched_standard_gate_control(Factory factory, std::uint64_t n) {
Random random;
std::vector<std::uint64_t> shuffled(n);
std::iota(shuffled.begin(), shuffled.end(), 0ULL);
for (std::uint64_t i : std::views::iota(0ULL, n) | std::views::reverse) {
std::uint64_t j = random.int32() % (i + 1);
if (i != j) std::swap(shuffled[i], shuffled[j]);
}
std::vector<std::uint64_t> shuffled = random.permutation(n);
std::vector<std::uint64_t> targets(num_target);
for (std::uint64_t i : std::views::iota(0ULL, num_target)) {
targets[i] = shuffled[i];
Expand Down Expand Up @@ -1018,8 +1015,8 @@ TEST(BatchGateTest, Control) {
test_batched_standard_gate_control<float, 1, 2>(gate::U2<float>, n);
test_batched_standard_gate_control<float, 1, 3>(gate::U3<float>, n);
test_batched_standard_gate_control<float, 2, 0>(gate::Swap<float>, n);
// test_batched_pauli_control<float, false>(n);
// test_batched_pauli_control<float, true>(n);
test_batched_pauli_control<float, false>(n);
test_batched_pauli_control<float, true>(n);
test_batched_matrix_control<float, 0>(n);
test_batched_matrix_control<float, 1>(n);
test_batched_matrix_control<float, 2>(n);
Expand All @@ -1046,8 +1043,8 @@ TEST(BatchGateTest, Control) {
test_batched_standard_gate_control<double, 1, 2>(gate::U2<double>, n);
test_batched_standard_gate_control<double, 1, 3>(gate::U3<double>, n);
test_batched_standard_gate_control<double, 2, 0>(gate::Swap<double>, n);
// test_batched_pauli_control<double, false>(n);
// test_batched_pauli_control<double, true>(n);
test_batched_pauli_control<double, false>(n);
test_batched_pauli_control<double, true>(n);
test_batched_matrix_control<double, 0>(n);
test_batched_matrix_control<double, 1>(n);
test_batched_matrix_control<double, 2>(n);
Expand Down

0 comments on commit 84a3603

Please sign in to comment.