Skip to content

Commit

Permalink
Add conditional splitting support when using WAE
Browse files Browse the repository at this point in the history
  • Loading branch information
mlxd committed Sep 13, 2024
1 parent 07b51b7 commit 749fa01
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 7 deletions.
57 changes: 51 additions & 6 deletions .github/workflows/interface-unit-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,22 @@ jobs:
"default": ["3.10"]
}
EOF
elif [ "${{ inputs.python_warning_level }}" == "error" ];
then
cat >python_versions.json <<-EOF
{
"default": ["3.12"],
"torch-tests": ["3.12"],
"tf-tests": ["3.12"],
"jax-tests": ["3.12"],
"all-interfaces-tests": ["3.12"],
"external-libraries-tests": ["3.12"],
"qcut-tests": ["3.12"],
"qchem-tests": ["3.12"],
"gradients-tests": ["3.12"],
"data-tests": ["3.12"],
"device-tests": ["3.12"] }
EOF
else
cat >python_versions.json <<-EOF
{
Expand Down Expand Up @@ -107,6 +123,18 @@ jobs:
"device-tests": 2
}
EOF
elif [ "${{ inputs.python_warning_level }}" == "error" ];
then
cat >matrix_max_parallel.json <<-EOF
{
"default": 1,
"core-tests": 1,
"gradients-tests": 1,
"jax-tests": 1,
"tf-tests": 1,
"device-tests": 1
}
EOF
else
cat >matrix_max_parallel.json <<-EOF
{
Expand All @@ -123,6 +151,19 @@ jobs:
jq . matrix_max_parallel.json
echo "matrix_max_parallel=$(jq -r tostring matrix_max_parallel.json)" >> $GITHUB_OUTPUT
- name: Enable splitting of tests
id: job_split
run: |
if [ "${{ inputs.python_warning_level }}" == "error" ];
then
ENABLE_SPLIT=0
else
ENABLE_SPLIT=1
fi
jq . matrix_max_parallel.json
echo "enable_split=$ENABLE_SPLIT" >> $GITHUB_OUTPUT
- name: Setup Job to Skip
id: jobs_to_skip
env:
Expand All @@ -141,6 +182,7 @@ jobs:
matrix-max-parallel: ${{ steps.max_parallel.outputs.matrix_max_parallel }}
python-version: ${{ steps.python_versions.outputs.python_versions }}
jobs-to-skip: ${{ steps.jobs_to_skip.outputs.jobs_to_skip }}
enable_split: ${{ steps.jobs_to_skip.outputs.enable_split }}

torch-tests:
needs:
Expand Down Expand Up @@ -219,7 +261,8 @@ jobs:
|| fromJSON(needs.setup-ci-load.outputs.matrix-max-parallel).default
}}
matrix:
group: [1, 2, 3]
group: >-
${{ fromJSON(needs.setup-ci-load.outputs.enable-split) && [1, 2, 3] || [1] }}
python-version: >-
${{
fromJSON(needs.setup-ci-load.outputs.python-version).tf-tests
Expand All @@ -239,7 +282,7 @@ jobs:
install_pennylane_lightning_master: true
pytest_coverage_flags: ${{ inputs.pytest_coverage_flags }}
pytest_markers: tf and not qcut and not finite-diff and not param-shift
pytest_additional_args: --splits 3 --group ${{ matrix.group }}
pytest_additional_args: ${{ fromJSON(needs.setup-ci-load.outputs.enable-split) == "1" && format('--splits {0} --group {1}', '3', matrix.group) || format('--splits {0} --group {1}', '1', matrix.group) }}
pytest_durations_file_path: '.github/workflows/tf_tests_durations.json'
pytest_store_durations: ${{ inputs.pytest_store_durations }}
additional_pip_packages: pytest-split
Expand All @@ -258,7 +301,7 @@ jobs:
|| fromJSON(needs.setup-ci-load.outputs.matrix-max-parallel).default
}}
matrix:
group: [1, 2, 3, 4, 5]
group: ${{ fromJSON(needs.setup-ci-load.outputs.enable-split) && [1, 2, 3, 4, 5] || [1] }}
python-version: >-
${{
fromJSON(needs.setup-ci-load.outputs.python-version).jax-tests
Expand All @@ -278,7 +321,9 @@ jobs:
install_pennylane_lightning_master: true
pytest_coverage_flags: ${{ inputs.pytest_coverage_flags }}
pytest_markers: jax and not qcut and not finite-diff and not param-shift
pytest_additional_args: --splits 5 --group ${{ matrix.group }}
pytest_additional_args: ${{ fromJSON(needs.setup-ci-load.outputs.enable-split) == "1" && format('--splits {0} --group {1}', '5', matrix.group) || format('--splits {0} --group {1}', '1', matrix.group) }}


pytest_durations_file_path: '.github/workflows/jax_tests_durations.json'
pytest_store_durations: ${{ inputs.pytest_store_durations }}
additional_pip_packages: pytest-split
Expand All @@ -297,7 +342,7 @@ jobs:
|| fromJSON(needs.setup-ci-load.outputs.matrix-max-parallel).default
}}
matrix:
group: [1, 2, 3, 4, 5]
group: ${{ fromJSON(needs.setup-ci-load.outputs.enable-split) && [1, 2, 3, 4, 5] || [1] }}
python-version: >-
${{
fromJSON(needs.setup-ci-load.outputs.python-version).core-tests
Expand All @@ -317,7 +362,7 @@ jobs:
install_pennylane_lightning_master: true
pytest_coverage_flags: ${{ inputs.pytest_coverage_flags }}
pytest_markers: core and not qcut and not finite-diff and not param-shift
pytest_additional_args: --splits 5 --group ${{ matrix.group }}
pytest_additional_args: ${{ fromJSON(needs.setup-ci-load.outputs.enable-split) == "1" && format('--splits {0} --group {1}', '5', matrix.group) || format('--splits {0} --group {1}', '1', matrix.group) }}
pytest_durations_file_path: '.github/workflows/core_tests_durations.json'
pytest_store_durations: ${{ inputs.pytest_store_durations }}
additional_pip_packages: pytest-split
Expand Down
1 change: 0 additions & 1 deletion tests/pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,3 @@ addopts = --benchmark-disable
xfail_strict=true
filterwarnings =
error::pennylane.PennyLaneDeprecationWarning
ignore:ImportWarning

0 comments on commit 749fa01

Please sign in to comment.