# Add user-facing TreeMetadata object returned by `PyTreeCheckpointHa…` (#12288)
# This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters.
# Workflow identity and triggers.
# Runs on pushes to `main` and to any branch matching `test_*`, and on pull
# requests that target `main`.
name: build

on:
  push:
    branches:
      - main
      - 'test_*'
  pull_request:
    branches:
      - main

jobs:
  # Builds and tests the package under `checkout-root/checkpoint` for each
  # Python / jax combination in the matrix.
  build-checkpoint:
    name: "build-checkpoint (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})"
    runs-on: ubuntu-latest
    defaults:
      run:
        working-directory: checkpoint
    strategy:
      matrix:
        python-version: ["3.10", "3.11"]
        jax-version: ["newest"]
        # Extra combination on top of the cross product above: the pinned
        # minimum jax version, tested on Python 3.10 only.
        include:
          - python-version: "3.10"
            jax-version: "0.4.34"  # keep in sync with minimum version in checkpoint/pyproject.toml
    steps:
      - name: Cancel previous
        # NOTE(review): the pinned version was unreadable in the reviewed copy
        # of this file; 0.12.1 is a published release - confirm the intended pin.
        uses: styfle/cancel-workflow-action@0.12.1
        with:
          access_token: ${{ github.token }}
      - uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        # TODO(b/275613424): remove `pip install -e .` and `pip uninstall -y orbax`.
        # Currently in place to override remote orbax import due to flax dependency.
        # The extras spec is quoted so the shell cannot glob-expand `.[testing]`.
        run: |
          pip install -e .
          pip install -e ".[testing]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          pip uninstall -y orbax
          if [[ "${{ matrix.jax-version }}" == "newest" ]]; then
            pip install -U jax jaxlib
          else
            pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}"
          fi
      - name: Test with pytest
        run: |
          python -m pytest
      # The below step just reports the success or failure of tests as a "commit status".
      # This is needed for copybara integration.
      - name: Report success or failure as github status
        if: always()
        shell: bash
        run: |
          status="${{ job.status }}"
          lowercase_status=$(echo "$status" | tr '[:upper:]' '[:lower:]')
          # The commit-status API only accepts error/failure/pending/success;
          # map any other job status (e.g. "cancelled") to "error" so the
          # request is not rejected.
          case "$lowercase_status" in
            success|failure) state="$lowercase_status" ;;
            *) state="error" ;;
          esac
          curl -sS --request POST \
          --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \
          --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \
          --header 'content-type: application/json' \
          --data '{
          "state": "'"$state"'",
          "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}",
          "description": "'"$status"'",
          "context": "github-actions/build"
          }'

  # Builds and tests the package under `checkout-root/export` for each
  # Python / jax combination in the matrix.
  build-export:
    name: "build-export (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})"
    runs-on: ubuntu-latest
    defaults:
      run:
        working-directory: export
    strategy:
      matrix:
        python-version: ["3.10"]
        jax-version: ["newest", "0.4.30"]  # keep in sync with minimum version in export/pyproject.toml
    steps:
      - name: Cancel previous
        # NOTE(review): the pinned version was unreadable in the reviewed copy
        # of this file; 0.12.1 is a published release - confirm the intended pin.
        uses: styfle/cancel-workflow-action@0.12.1
        with:
          access_token: ${{ github.token }}
      - uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      # Exposes the branch name as the step output `branch`: the PR head ref
      # when set, otherwise the pushed ref with `refs/heads/` stripped.
      - name: Extract branch name
        shell: bash
        run: echo "branch=${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}" >> $GITHUB_OUTPUT
        id: extract_branch
      - name: Install dependencies
        # The extras spec is quoted so the shell cannot glob-expand `.[testing]`.
        run: |
          pip install .
          pip install ".[testing]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          if [[ "${{ matrix.jax-version }}" == "newest" ]]; then
            pip install -U jax jaxlib
          else
            pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}"
          fi
      - name: Test with pytest
        # Each top-level `*_test.py` (plus conftest.py) is copied into a temp
        # directory and run from there, one pytest invocation per file.
        # NOTE(review): presumably so tests import the installed package rather
        # than the source tree - confirm.
        run: |
          test_dir=$(mktemp -d)
          cp orbax/export/conftest.py "${test_dir}"
          for t in $(find orbax/export -maxdepth 1 -name '*_test.py'); do
            cp "${t}" "${test_dir}"
            XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest "${test_dir}/$(basename "${t}")"
          done
      # The below step just reports the success or failure of tests as a "commit status".
      # This is needed for copybara integration.
      - name: Report success or failure as github status
        if: always()
        shell: bash
        run: |
          status="${{ job.status }}"
          lowercase_status=$(echo "$status" | tr '[:upper:]' '[:lower:]')
          # The commit-status API only accepts error/failure/pending/success;
          # map any other job status (e.g. "cancelled") to "error" so the
          # request is not rejected.
          case "$lowercase_status" in
            success|failure) state="$lowercase_status" ;;
            *) state="error" ;;
          esac
          curl -sS --request POST \
          --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \
          --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \
          --header 'content-type: application/json' \
          --data '{
          "state": "'"$state"'",
          "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}",
          "description": "'"$status"'",
          "context": "github-actions/build"
          }'