Skip to content

Commit

Permalink
Make model zoo dependencies more granular, and re-include the Beans b…
Browse files Browse the repository at this point in the history
…aseline model.

PiperOrigin-RevId: 698776329
  • Loading branch information
sdenton4 authored and copybara-github committed Dec 10, 2024
1 parent 025f7f4 commit d9fdd78
Show file tree
Hide file tree
Showing 12 changed files with 720 additions and 343 deletions.
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,20 @@ poetry install

## Notes on Dependencies

Machine learning framework libraries are pretty heavy! It can also be difficult to coordinate CUDA versions across multiple frameworks to ensure good GPU behavior. Thus, we provide some ability to select dependencies according to your needs.

Tensorflow is used in the `agile` library for training linear classifiers. If you do not need the `agile` library or any of the tensorflow models in the `zoo`, you can use poetry to install without tensorflow like so:

```bash
poetry install --without tf
```

The primary place where multiple frameworks may be needed is in the `zoo` library, which provides wrappers for various bioacoustic models. To install with JAX (allowing use of some models in the `zoo`):

```bash
poetry install --with jax
```

# Disclaimer

This is not an officially supported Google product. This project is not
Expand Down
12 changes: 5 additions & 7 deletions hoplite/agile/colab_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,21 +73,19 @@ def load_configs(
# Put the DB in the same directory as the audio.
db_path = epath.Path(next(iter(audio_sources.audio_globs)).base_path)

model_key, embedding_dim, model_config = (
model_configs.get_preset_model_config(model_config_key)
)
preset_info = model_configs.get_preset_model_config(model_config_key)
db_model_config = embed.ModelConfig(
model_key=model_key,
embedding_dim=embedding_dim,
model_config=model_config,
model_key=preset_info.model_key,
embedding_dim=preset_info.embedding_dim,
model_config=preset_info.model_config,
)
db_config = config_dict.ConfigDict({
'db_path': db_path,
})
if db_key == 'sqlite_usearch':
# A sane default.
db_config.usearch_cfg = sqlite_usearch_impl.get_default_usearch_config(
embedding_dim
preset_info.embedding_dim
)

return AgileConfigs(
Expand Down
2 changes: 1 addition & 1 deletion hoplite/agile/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __init__(
self.model_config = model_config
self.audio_sources = audio_sources
if embedding_model is None:
model_class = model_configs.MODEL_CLASS_MAP[model_config.model_key]
model_class = model_configs.get_model_class(model_config.model_key)
self.embedding_model = model_class.from_config(model_config.model_config)
else:
self.embedding_model = embedding_model
Expand Down
14 changes: 14 additions & 0 deletions hoplite/zoo/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,17 @@ The primary function in the `EmbeddingModel` interface is
`EmbeddingModel.embed(audio_array)` which runs model inference on the provided
audio array. The outputs are an `zoo_interface.InferenceOutputs` instance, which
contains optional embeddings, logits, and separated audio.

## Dependency Management

Models will typically depend on one of tensorflow, jax, and pytorch. We allow
specifying particular ML frameworks during package installation: as a result,
imports need to be carefully managed to avoid errors when using the `zoo`
library with only a subset of the frameworks installed.

* The `zoo_interface.py` and `zoo_test.py` have no framework dependencies.
* `model_configs.py` contains convenience loaders for all supported models.
The relevant framework for each model is imported lazily, instead of at the
top of the file.
* Model tests with specific framework dependencies should be placed in a
subdirectory. (Such as `tests_tf` for models with tensorflow dependencies.)
129 changes: 129 additions & 0 deletions hoplite/zoo/aves_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# coding=utf-8
# Copyright 2024 The Perch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Zoo models using Jax."""

from collections.abc import Callable, Sequence
import dataclasses
import functools
from typing import Any

from hoplite.zoo import zoo_interface
import jax
from jax import numpy as jnp
from jax import random
import jaxonnxruntime
from jaxonnxruntime.core import call_onnx
from jaxonnxruntime.core import handler
from jaxonnxruntime.core import onnx_node
from ml_collections.config_dict import config_dict
import numpy as np

import onnx


@handler.register_op('InstanceNormalization')
class InstanceNormalization(handler.Handler):
"""Implementation of the ONNX InstanceNormalization operator."""

@classmethod
def _prepare(
cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any
):
node.attrs_dict['epsilon'] = node.attrs.get('epsilon', 1e-5)

@classmethod
def version_6(
cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]
) -> Callable[..., Any]:
"""ONNX version_6 InstanceNormalization op."""
cls._prepare(node, inputs, onnx_instancenormalization)
return onnx_instancenormalization


@functools.partial(jax.jit, static_argnames=('epsilon',))
def onnx_instancenormalization(*input_args, epsilon: float):
"""https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#InstanceNormalization for more details."""
input_, scale, b = input_args

dims_input = len(input_.shape)
input_mean = jnp.mean(
input_, axis=tuple(range(dims_input))[2:], keepdims=True
)
input_var = jnp.var(input_, axis=tuple(range(dims_input))[2:], keepdims=True)

dim_ones = (1,) * (dims_input - 2)
scale = scale.reshape(-1, *dim_ones)
b = b.reshape(-1, *dim_ones)

return (input_ - input_mean) / jnp.sqrt(input_var + epsilon) * scale + b


@dataclasses.dataclass
class AVES(zoo_interface.EmbeddingModel):
"""Wrapper around AVES ONNX model.
This model was originally trained to take audio with a 16 kHz sample rate.
Each time the model gets called with a new input shape, a new JAX function
gets created and compiled, and all parameters get copied. This could be
slow.
"""

model_path: str = ''
model: onnx.onnx_ml_pb2.ModelProto = dataclasses.field(init=False)
input_shape: tuple[int, ...] | None = dataclasses.field(
default=None, init=False
)
model_func: Callable[[list[np.ndarray[Any, Any]]], list[jax.Array]] | None = (
dataclasses.field(default=None, init=False)
)

@classmethod
def from_config(
cls, model_config: config_dict.ConfigDict
) -> zoo_interface.EmbeddingModel:
return cls(**model_config)

def __post_init__(self):
jaxonnxruntime.config.update(
'jaxort_only_allow_initializers_as_static_args', False
)
self.model = onnx.load(self.model_path)

def embed(
self, audio_array: np.ndarray[Any, Any]
) -> zoo_interface.InferenceOutputs:
return zoo_interface.embed_from_batch_embed_fn(
self.batch_embed, audio_array
)

def batch_embed(
self, audio_batch: np.ndarray[Any, Any]
) -> zoo_interface.InferenceOutputs:
# Compile new function if necessary
if audio_batch.shape != self.input_shape:
key = random.PRNGKey(0)
random_input = random.normal(key, audio_batch.shape, dtype=jnp.float32)
model_func, model_params = call_onnx.call_onnx_model(
self.model, [random_input]
)
self.input_shape = audio_batch.shape
self.model_func = functools.partial(model_func, model_params)

# Embed and add a single channel dimension
embeddings = np.asarray(self.model_func([audio_batch])[0])
embeddings = embeddings[:, :, np.newaxis, :]
return zoo_interface.InferenceOutputs(embeddings=embeddings, batched=True)
110 changes: 110 additions & 0 deletions hoplite/zoo/handcrafted_features_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# coding=utf-8
# Copyright 2024 The Perch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Handcrafted features for linear models."""

import dataclasses
from typing import Self

from hoplite.zoo import zoo_interface
import librosa
from ml_collections import config_dict
import numpy as np


@dataclasses.dataclass
class HandcraftedFeaturesModel(zoo_interface.EmbeddingModel):
"""Wrapper for simple feature extraction."""

window_size_s: float
hop_size_s: float
melspec_config: config_dict.ConfigDict
aggregation: str = 'beans'

@classmethod
def from_config(
cls, config: config_dict.ConfigDict
) -> 'HandcraftedFeaturesModel':
return cls(**config)

@classmethod
def beans_baseline(
cls, sample_rate=32000, frame_rate=100
) -> 'HandcraftedFeaturesModel':
stride = sample_rate // frame_rate
mel_config = config_dict.ConfigDict({
'sample_rate': sample_rate,
'features': 160,
'stride': stride,
'kernel_size': 2 * stride,
'freq_range': (60.0, sample_rate / 2.0),
'power': 2.0,
})
features_config = config_dict.ConfigDict({
'compute_mfccs': True,
'aggregation': 'beans',
})
config = config_dict.ConfigDict({
'sample_rate': sample_rate,
'melspec_config': mel_config,
'features_config': features_config,
'window_size_s': 1.0,
'hop_size_s': 1.0,
})
# pylint: disable=unexpected-keyword-arg
return HandcraftedFeaturesModel.from_config(config)

def melspec(self, audio_array: np.ndarray) -> np.ndarray:
framed_audio = self.frame_audio(
audio_array, self.window_size_s, self.hop_size_s
)
specs = []
for frame in framed_audio:
specs.append(
librosa.feature.melspectrogram(
y=frame,
sr=self.sample_rate,
hop_length=self.melspec_config.stride,
win_length=self.melspec_config.kernel_size,
center=True,
n_mels=self.melspec_config.features,
power=self.melspec_config.power,
)
)
return np.stack(specs, axis=0)

def embed(self, audio_array: np.ndarray) -> zoo_interface.InferenceOutputs:
# Melspecs will have shape [melspec_channels, frames]
melspecs = self.melspec(audio_array)
if self.aggregation == 'beans':
features = np.concatenate(
[
melspecs.mean(axis=-1),
melspecs.std(axis=-1),
melspecs.min(axis=-1),
melspecs.max(axis=-1),
],
axis=-2,
)
else:
raise ValueError(f'unrecognized aggregation: {self.aggregation}')
# Add a trivial channels dimension.
features = features[:, np.newaxis, :]
return zoo_interface.InferenceOutputs(features, None, None)

def batch_embed(
self, audio_batch: np.ndarray
) -> zoo_interface.InferenceOutputs:
return zoo_interface.batch_embed_from_embed_fn(self.embed, audio_batch)
Loading

0 comments on commit d9fdd78

Please sign in to comment.