To improve the performance of the tracking server and the various stores, we propose rewriting the server and store implementations in Go.
This package is not yet available on PyPI and currently requires the Go SDK to be installed.
You can then install the package via pip:
```bash
pip install git+https://github.com/jgiannuzzi/mlflow-go.git
```
This repository uses Mage to run common development tasks.
```bash
# Install mage (already done in the dev container)
go install github.com/magefile/[email protected]

# See all targets
mage

# Execute a single target
mage generate
```
The main advantage of Mage is that our scripts are written in regular Go code. That said, we are not tied to this tool.
To integrate with MLflow, we need the MLflow source code. The mlflow/mlflow repository contains the proto files that define the tracking API, as well as the Python tests we use to verify that our Go implementation produces identical behaviour.
We use a `.mlflow.ref` file to specify the exact location from which to pull our sources. The format is `remote#reference`, where `remote` is a git remote and `reference` is a branch, tag, or commit SHA.
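Splitting the `remote#reference` format is straightforward. The sketch below is a hypothetical Python helper (`parse_mlflow_ref` and `MlflowRef` are illustrative names, not part of this repository; the Mage targets that actually read the file are Go code). It assumes the remote part itself contains no `#` and therefore splits on the first one.

```python
from typing import NamedTuple


class MlflowRef(NamedTuple):
    """Hypothetical container for the two parts of a .mlflow.ref entry."""

    remote: str     # a git remote (URL or remote name)
    reference: str  # a branch, tag, or commit SHA


def parse_mlflow_ref(text: str) -> MlflowRef:
    """Parse the contents of a .mlflow.ref file in `remote#reference` form."""
    value = text.strip()  # ignore surrounding whitespace and a trailing newline
    # Split on the first '#'; assumes the remote part contains no '#'.
    remote, sep, reference = value.partition("#")
    if not sep or not remote or not reference:
        raise ValueError(f"expected 'remote#reference', got {value!r}")
    return MlflowRef(remote=remote, reference=reference)
```

For example, `parse_mlflow_ref(Path(".mlflow.ref").read_text())` would return the remote and reference to check out; the example values in any test are illustrative only.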
If the `.mlflow.ref` file is modified and becomes out of sync with the current source files, the mage targets will detect this automatically. To force a sync manually, run `mage repo:update`.
To ensure we stay compatible with the Python implementation, we aim to generate as much as possible from the `.proto` files.
Running `mage generate` generates Go code from the proto files in the `.mlflow.repo` repository.
This includes the generation of:
- Structs for each endpoint (`pkg/protos`)
- Go interfaces for each service (`pkg/contract/service/*.g.go`)
- Fiber routes for each endpoint (`pkg/server/routes/*.g.go`)
Any change in the proto files should therefore ripple into the Go code once it is regenerated.
We use Go validator to validate all incoming request structs. As the proto files don't specify any validation rules, we map them manually in `pkg/cmd/generate/validations.go`.
Once the mapping is in place, validation is invoked automatically in the generated Fiber code.
When the need arises, we can write custom validation functions in `pkg/validation/validation.go`.
Initially, we want to focus on supporting PostgreSQL. We chose GORM as the ORM to interact with the database.
We do not generate any Go code from the database schema. GORM has code generation capabilities, but they did not fit our needs. We plan to eventually add an integration test that asserts the current code still matches the database schema.
All the models use pointers for their fields. We do this for performance reasons and to distinguish between zero values and null values.
We have enabled various linters from golangci-lint; you can run them via:

```bash
pre-commit run golangci-lint --all-files
```
Sometimes `golangci-lint` complains about unrelated files; run `golangci-lint cache clean` to clear the cache.
To use the Go server, run the `mlflow-go server` command.
```bash
# Start the Go server with a database URI
# Other databases are supported as well: sqlite, mysql and mssql
mlflow-go server --backend-store-uri postgresql://postgres:postgres@localhost:5432/postgres
```
This launches the Python process as usual. Within Python, a random port is chosen for the existing server, and a Go child process is spawned. The Go server listens on the user-specified port (5000 by default) and spawns the actual Python server (`gunicorn` or `waitress`) as its own child process on that random port.
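Choosing a random free port is usually done by binding a socket to port 0 and letting the operating system pick an unused one. The sketch below shows that common technique; it is an assumption about how such a port could be chosen, not the actual mlflow-go code, and `pick_free_port` is a hypothetical name.

```python
import socket


def pick_free_port(host: str = "127.0.0.1") -> int:
    """Ask the OS for an unused TCP port by binding to port 0."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind((host, 0))  # port 0 lets the kernel choose a free port
        return sock.getsockname()[1]  # (host, port) tuple; return the port
```

Note that the socket is closed before the port is handed to the child server, so another process could in principle grab it in between; this small race is the usual trade-off of the approach.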
Any incoming requests the Go server cannot process will be proxied to the existing Python server.
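Conceptually, the Go server is a router with a fallback: requests that have a native route are handled in Go, and everything else is forwarded to the Python server. The real implementation is Go code built on Fiber; the Python model below only illustrates that decision. `make_dispatcher`, the simplified `(method, path)` request tuple, and the route table are all hypothetical.

```python
from typing import Callable, Dict, Tuple

Request = Tuple[str, str]  # (HTTP method, path), simplified for illustration
Handler = Callable[[Request], str]


def make_dispatcher(go_routes: Dict[Request, Handler], proxy: Handler) -> Handler:
    """Serve a request natively if a route exists, otherwise proxy it."""

    def dispatch(request: Request) -> str:
        handler = go_routes.get(request)
        if handler is not None:
            return handler(request)  # handled by the native (Go) implementation
        # Anything without a native route falls through to the Python server.
        return proxy(request)

    return dispatch
```

With this design, endpoints can be ported to Go one at a time: adding a route moves that endpoint to the fast path, while all other requests keep working through the proxy.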
Any Go-specific options can be passed with `--go-opts`, which takes a comma-separated list of key-value pairs:
```bash
mlflow-go server --backend-store-uri postgresql://postgres:postgres@localhost:5432/postgres --go-opts log_level=debug,shutdown_timeout=5s
```
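The `--go-opts` value is a comma-separated list of `key=value` pairs. The hypothetical `parse_go_opts` helper below sketches how such a string splits into a dictionary; it is not the mlflow-go parser. Values are kept as strings, on the assumption that the Go side interprets them (for example, reading `5s` as a duration).

```python
from typing import Dict


def parse_go_opts(value: str) -> Dict[str, str]:
    """Parse 'key1=value1,key2=value2' into a dict of strings."""
    opts: Dict[str, str] = {}
    for pair in value.split(","):
        pair = pair.strip()
        if not pair:
            continue  # tolerate empty input and trailing commas
        key, sep, val = pair.partition("=")
        if not sep or not key.strip():
            raise ValueError(f"invalid --go-opts entry: {pair!r}")
        opts[key.strip()] = val.strip()
    return opts
```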
The MLflow client can then be pointed at the Go server:

```python
import mlflow

# Use the Go server
mlflow.set_tracking_uri("http://localhost:5000")

# Use MLflow as usual
mlflow.set_experiment("my-experiment")

with mlflow.start_run():
    mlflow.log_param("param", 1)
    mlflow.log_metric("metric", 2)
```
To ensure everything still compiles:
```bash
go build -o /dev/null ./pkg/cmd/server
```

or

```bash
python -m mlflow_go.lib . /tmp
```
The Go implementation can also be used directly from the Python client, without running a separate server:

```python
import mlflow
import mlflow_go

# Enable the Go client implementation (disabled by default)
mlflow_go.enable_go()

# Set the tracking URI (you can also set it via the environment variable MLFLOW_TRACKING_URI)
# Currently only database URIs are supported
mlflow.set_tracking_uri("sqlite:///mlflow.db")

# Use MLflow as usual
mlflow.set_experiment("my-experiment")

with mlflow.start_run():
    mlflow.log_param("param", 1)
    mlflow.log_metric("metric", 2)
```
To debug the Go implementation from Python, enable debug logging and call the tracking and model registry store methods directly:

```python
import logging

import mlflow
import mlflow_go

# Enable debug logging
logging.basicConfig()
logging.getLogger('mlflow_go').setLevel(logging.DEBUG)

# Enable the Go client implementation (disabled by default)
mlflow_go.enable_go()

# Instantiate the tracking store with a database URI
tracking_store = mlflow.tracking._tracking_service.utils._get_store('sqlite:///mlflow.db')

# Call any tracking store method
tracking_store.get_experiment(0)

# Instantiate the model registry store with a database URI
model_registry_store = mlflow.tracking._model_registry.utils._get_store('sqlite:///mlflow.db')

# Call any model registry store method
model_registry_store.get_latest_versions("model")
```
Sometimes it is very useful to modify failing tests and add `print` statements to display the current state, or the differences between objects returned by the Python and Go services.
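When comparing results from the Python and Go services, printing only the fields that differ is often clearer than printing whole objects. `diff_objects` below is a hypothetical helper, not part of MLflow or mlflow-go; it assumes each object is either a mapping or exposes its state through `vars()`.

```python
from collections.abc import Mapping
from typing import Any, Dict, Tuple

_MISSING = "<missing>"  # marker for a field present on only one side


def _as_dict(obj: Any) -> Dict[str, Any]:
    # Mappings are copied directly; other objects are read via their __dict__.
    return dict(obj) if isinstance(obj, Mapping) else dict(vars(obj))


def diff_objects(py_obj: Any, go_obj: Any) -> Dict[str, Tuple[Any, Any]]:
    """Return {field: (python_value, go_value)} for every field that differs."""
    py_fields = _as_dict(py_obj)
    go_fields = _as_dict(go_obj)
    diffs: Dict[str, Tuple[Any, Any]] = {}
    for name in sorted(set(py_fields) | set(go_fields)):
        py_val = py_fields.get(name, _MISSING)
        go_val = go_fields.get(name, _MISSING)
        if py_val != go_val:
            diffs[name] = (py_val, go_val)
    return diffs
```

In a failing test, something like `print(diff_objects(python_result, go_result))` (variable names illustrative) shows just the mismatched fields.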
Adding `"-vv"` to the `pytest` command in `magefiles/tests.go` can also provide more information when assertions fail.
At times, you might want to run store calls against your local database so you can investigate the results of certain read operations through the local tracking server.
You can achieve this by editing the test file located in `.mlflow.repo` and changing:

```python
def test_search_runs_datasets(store: SqlAlchemyStore):
```

to:

```python
# Requires `from pathlib import Path` among the test file's imports
def test_search_runs_datasets():
    db_uri = "postgresql://postgres:postgres@localhost:5432/postgres"
    artifact_uri = Path("/tmp/artifacts")
    artifact_uri.mkdir(exist_ok=True)
    store = SqlAlchemyStore(db_uri, artifact_uri.as_uri())
    # ... rest of the test body unchanged
```
The currently supported endpoints can be listed by running the following mage target:

```bash
mage endpoints
```