Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add model manager for machine learning models #918

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/frequenz/sdk/ml/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# License: MIT
# Copyright © 2024 Frequenz Energy-as-a-Service GmbH

"""Model interface."""

from ._model_manager import ModelManager

__all__ = [
"ModelManager",
]
143 changes: 143 additions & 0 deletions src/frequenz/sdk/ml/_model_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# License: MIT
# Copyright © 2024 Frequenz Energy-as-a-Service GmbH

"""Load, update, monitor and retrieve machine learning models."""

import asyncio
import logging
import pickle
from dataclasses import dataclass
from pathlib import Path
from typing import Generic, TypeVar, cast

from frequenz.channels.file_watcher import EventType, FileWatcher
from typing_extensions import override

from frequenz.sdk.actor import BackgroundService

_logger = logging.getLogger(__name__)

T = TypeVar("T")


@dataclass
class _Model(Generic[T]):
"""Represent a machine learning model."""

data: T
path: Path
idlir-shkurti-frequenz marked this conversation as resolved.
Show resolved Hide resolved


class ModelNotFoundError(Exception):
"""Exception raised when a model is not found."""

def __init__(self, key: str) -> None:
"""Initialize the exception with the specified model key.

Args:
key: The key of the model that was not found.
"""
super().__init__(f"Model with key '{key}' is not found.")


class ModelManager(BackgroundService, Generic[T]):
"""Load, update, monitor and retrieve machine learning models."""

def __init__(self, model_paths: dict[str, Path], *, name: str | None = None):
"""Initialize the model manager with the specified model paths.

Args:
model_paths: A dictionary of model keys and their corresponding file paths.
name: The name of the model manager service.
"""
super().__init__(name=name)
self._models: dict[str, _Model[T]] = {}
self.model_paths = model_paths
self.load_models()

def load_models(self) -> None:
"""Load the models from the specified paths."""
for key, path in self.model_paths.items():
self._models[key] = _Model(data=self._load(path), path=path)

@staticmethod
def _load(path: Path) -> T:
"""Load the model from the specified path.

Args:
path: The path to the model file.

Returns:
T: The loaded model data.

Raises:
ModelNotFoundError: If the model file does not exist.
"""
try:
with path.open("rb") as file:
return cast(T, pickle.load(file))
except FileNotFoundError as exc:
raise ModelNotFoundError(str(path)) from exc

@override
def start(self) -> None:
idlir-shkurti-frequenz marked this conversation as resolved.
Show resolved Hide resolved
"""Start the model monitoring service by creating a background task."""
if not self.is_running:
task = asyncio.create_task(self._monitor_paths())
self._tasks.add(task)
_logger.info(
"%s: Started ModelManager service with task %s",
self.name,
task,
)

async def _monitor_paths(self) -> None:
"""Monitor model file paths and reload models as necessary."""
model_paths = [model.path for model in self._models.values()]
file_watcher = FileWatcher(
paths=list(model_paths), event_types=[EventType.CREATE, EventType.MODIFY]
)
_logger.info("%s: Monitoring model paths for changes.", self.name)
async for event in file_watcher:
_logger.info(
"%s: Reloading model from file %s due to a %s event...",
self.name,
event.path,
event.type.name,
)
self.reload_model(Path(event.path))

def reload_model(self, path: Path) -> None:
"""Reload the model from the specified path.

Args:
path: The path to the model file.
"""
for key, model in self._models.items():
if model.path == path:
try:
model.data = self._load(path)
_logger.info(
"%s: Successfully reloaded model from %s",
self.name,
path,
)
except Exception: # pylint: disable=broad-except
_logger.exception("Failed to reload model from %s", path)

def get_model(self, key: str) -> T:
"""Retrieve a loaded model by key.

Args:
key: The key of the model to retrieve.

Returns:
The loaded model data.

Raises:
KeyError: If the model with the specified key is not found.
"""
try:
return self._models[key].data
except KeyError as exc:
raise KeyError(f"Model with key '{key}' is not found.") from exc
4 changes: 4 additions & 0 deletions tests/model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# License: MIT
# Copyright © 2024 Frequenz Energy-as-a-Service GmbH

"""Tests for the model package."""
118 changes: 118 additions & 0 deletions tests/model/test_model_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# License: MIT
# Copyright © 2024 Frequenz Energy-as-a-Service GmbH

"""Tests for machine learning model manager."""

import pickle
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, MagicMock, mock_open, patch

import pytest

from frequenz.sdk.ml import ModelManager


@dataclass
class MockModel:
"""Mock model for unit testing purposes."""

data: int | str

def predict(self) -> int | str:
"""Make a prediction based on the model data."""
return self.data


async def test_model_manager_loading() -> None:
"""Test loading models using ModelManager with direct configuration."""
model1 = MockModel("Model 1 Data")
model2 = MockModel("Model 2 Data")
pickled_model1 = pickle.dumps(model1)
pickled_model2 = pickle.dumps(model2)

model_paths = {
"model1": Path("path/to/model1.pkl"),
"model2": Path("path/to/model2.pkl"),
}

mock_files = {
"path/to/model1.pkl": mock_open(read_data=pickled_model1)(),
"path/to/model2.pkl": mock_open(read_data=pickled_model2)(),
}

def mock_open_func(file_path: Path, *__args: Any, **__kwargs: Any) -> Any:
"""Mock open function to return the correct mock file object.

Args:
file_path: The path to the file to open.
*__args: Variable length argument list. This can be used to pass additional
positional parameters typically used in file opening operations,
such as `mode` or `buffering`.
**__kwargs: Arbitrary keyword arguments. This can include parameters like
`encoding` and `errors`, common in file opening operations.

Returns:
Any: The mock file object.

Raises:
FileNotFoundError: If the file path is not in the mock files dictionary.
"""
file_path_str = str(file_path)
if file_path_str in mock_files:
file_handle = MagicMock()
file_handle.__enter__.return_value = mock_files[file_path_str]
return file_handle
raise FileNotFoundError(f"No mock setup for {file_path_str}")

with patch("pathlib.Path.open", new=mock_open_func):
with patch.object(Path, "exists", return_value=True):
model_manager: ModelManager[MockModel] = ModelManager(
model_paths=model_paths
)

with patch(
"frequenz.channels.file_watcher.FileWatcher", new_callable=AsyncMock
):
model_manager.start() # Start the service

assert isinstance(model_manager.get_model("model1"), MockModel)
assert model_manager.get_model("model1").data == "Model 1 Data"
assert model_manager.get_model("model2").data == "Model 2 Data"

with pytest.raises(KeyError):
model_manager.get_model("key3")

await model_manager.stop() # Stop the service to clean up


async def test_model_manager_update() -> None:
"""Test updating a model in ModelManager."""
original_model = MockModel("Original Data")
updated_model = MockModel("Updated Data")
pickled_original_model = pickle.dumps(original_model)
pickled_updated_model = pickle.dumps(updated_model)

model_paths = {"model1": Path("path/to/model1.pkl")}

mock_file = mock_open(read_data=pickled_original_model)
with (
patch("pathlib.Path.open", mock_file),
patch.object(Path, "exists", return_value=True),
):
model_manager = ModelManager[MockModel](model_paths=model_paths)
with patch(
"frequenz.channels.file_watcher.FileWatcher", new_callable=AsyncMock
):
model_manager.start() # Start the service

assert model_manager.get_model("model1").data == "Original Data"

# Simulate updating the model file
mock_file.return_value.read.return_value = pickled_updated_model
with patch("pathlib.Path.open", mock_file):
model_manager.reload_model(Path("path/to/model1.pkl"))
assert model_manager.get_model("model1").data == "Updated Data"

await model_manager.stop() # Stop the service to clean up
Loading