Skip to content

Commit

Permalink
Merge branch 'cellfinder-to-keras-3' into it/keras3-pytorch
Browse files Browse the repository at this point in the history
  • Loading branch information
IgorTatarnikov authored Apr 5, 2024
2 parents 64bde71 + ca80c6d commit 585e7da
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 5 deletions.
10 changes: 6 additions & 4 deletions .github/workflows/test_and_deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ jobs:
include:
- os: macos-latest
python-version: "3.10"
- os: windows-latest
python-version: "3.10"

steps:
# Cache the Keras model so we don't have to remake it every time
Expand Down Expand Up @@ -95,6 +97,8 @@ jobs:
needs: [linting, manifest]
name: Run brainmapper tests to check for breakages
runs-on: ubuntu-latest
env:
KERAS_BACKEND: jax
steps:
- name: Cache Keras model
uses: actions/cache@v3
Expand All @@ -115,10 +119,8 @@ jobs:
- name: Install test dependencies
run: |
python -m pip install --upgrade pip wheel
# Install cellfinder from the latest SHA on this branch
python -m pip install git+$GITHUB_SERVER_URL/$GITHUB_REPOSITORY@$GITHUB_SHA
# Install tensorflow as keras' default backend
python -m pip install "tf-nightly==2.16.0.dev20240101"
# Install cellfinder from the latest SHA on this branch (Keras with JAX backend)
python -m pip install "cellfinder[jax] @ git+$GITHUB_SERVER_URL/$GITHUB_REPOSITORY@$GITHUB_SHA"
# Install checked out copy of brainglobe-workflows
python -m pip install .[dev]
Expand Down
5 changes: 4 additions & 1 deletion cellfinder/core/train/train_yml.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,10 @@ def run(

ensure_directory_exists(output_dir)
model_weights = prep_model_weights(
model_weights, install_path, model, n_free_cpus
model_weights=model_weights,
install_path=install_path,
model_name=model,
n_free_cpus=n_free_cpus,
)

yaml_contents = parse_yaml(yaml_file)
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ python =
3.9: py39-{tf,jax,torch} # On GA python=3.9 job, run tox with the tf and jax environments
3.10: py310-{tf,jax,torch} # On GA python=3.10 job, run tox with the tf and jax environments
[testenv]
commands = python -m pytest -v --color=yes
deps =
Expand Down Expand Up @@ -156,4 +157,6 @@ passenv =
XAUTHORITY
NUMPY_EXPERIMENTAL_ARRAY_FUNCTION
PYVISTA_OFF_SCREEN
platform =
tf: linux|darwin # skip TF backend on windows
"""

0 comments on commit 585e7da

Please sign in to comment.