This repository contains a machine learning project designed for multispectral image analysis using PyTorch Lightning. The project is structured to provide an easy-to-use command-line interface for training the model, along with a set of utilities for data processing and model configuration.
Requirements:
- Python 3.8 or higher
- Poetry for dependency management and packaging
Installation:
- Clone the repository to your local machine.
- Install the dependencies using Poetry. Navigate to the project directory and run:
poetry install
This command reads the pyproject.toml and poetry.lock files and installs the required dependencies into a virtual environment.
Before training, you may need to adjust the configuration to match your dataset and training requirements. The configuration options are defined in config.py; review and modify them as needed.
Note: you also need to create a data directory containing the EuroSAT multispectral data, i.e. data/EuroSAT_MS_Samples.
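A quick way to confirm the data is in place before launching training is a small check like the one below. It is a sketch that assumes the directory follows the standard EuroSAT release layout (one subdirectory per land-cover class, each containing GeoTIFF .tif tiles); the function name check_eurosat_dir is hypothetical and not part of this repository.

```python
from pathlib import Path
from typing import Dict


def check_eurosat_dir(root: str = "data/EuroSAT_MS_Samples") -> Dict[str, int]:
    """Return the number of .tif files found in each class subdirectory of root.

    Hypothetical helper; assumes the EuroSAT layout root/<ClassName>/*.tif.
    """
    root_path = Path(root)
    if not root_path.is_dir():
        raise FileNotFoundError(f"Dataset directory not found: {root_path}")

    counts = {}
    for class_dir in sorted(p for p in root_path.iterdir() if p.is_dir()):
        # Count the multispectral GeoTIFF tiles in this class folder
        counts[class_dir.name] = sum(1 for _ in class_dir.glob("*.tif"))

    if not counts:
        raise ValueError(f"No class subdirectories found in {root_path}")
    return counts
```

Calling check_eurosat_dir() from the project root returns a mapping such as {"AnnualCrop": ..., "Forest": ...}; a class with a count of zero usually means the files were not copied correctly.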
To train the model, use the CLI provided in trainer.py, which accepts options for customizing the training run, such as the number of epochs and the batch size.
poetry run python trainer.py [OPTIONS]
Replace [OPTIONS] with your desired command-line options to customize the training session. The available options include:
- --dataset_path: Path to the dataset. Default is specified in config.py.
- --num_epochs: Number of training epochs. Default is specified in config.py.
- --batch_size: Batch size for training. Default is specified in config.py.
- --learning_rate: Learning rate for the optimizer. Default is specified in config.py.
- --num_input_channels: Number of input channels for the model. Default is specified in config.py.
To train your model with custom configurations, you can run the following command:
poetry run python trainer.py --dataset_path "/path/to/dataset" --num_epochs 100 --batch_size 32 --learning_rate 0.001 --num_input_channels 3
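If you want to extend trainer.py, the documented options map naturally onto Python's argparse. The sketch below is illustrative only: the DEFAULTS dictionary stands in for whatever config.py actually defines, and its keys and values are placeholders rather than the project's real settings.

```python
import argparse
from typing import List, Optional

# Placeholder defaults standing in for the values defined in config.py
# (hypothetical names and values; check config.py for the real ones).
DEFAULTS = {
    "dataset_path": "data/EuroSAT_MS_Samples",
    "num_epochs": 10,
    "batch_size": 32,
    "learning_rate": 1e-3,
    "num_input_channels": 13,
}


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the training options documented in this README."""
    parser = argparse.ArgumentParser(description="Train the multispectral model.")
    parser.add_argument("--dataset_path", type=str, default=DEFAULTS["dataset_path"],
                        help="Path to the dataset.")
    parser.add_argument("--num_epochs", type=int, default=DEFAULTS["num_epochs"],
                        help="Number of training epochs.")
    parser.add_argument("--batch_size", type=int, default=DEFAULTS["batch_size"],
                        help="Batch size for training.")
    parser.add_argument("--learning_rate", type=float, default=DEFAULTS["learning_rate"],
                        help="Learning rate for the optimizer.")
    parser.add_argument("--num_input_channels", type=int,
                        default=DEFAULTS["num_input_channels"],
                        help="Number of input channels for the model.")
    # argv=None makes argparse read sys.argv[1:], as a real CLI would
    return parser.parse_args(argv)
```

Passing the arguments from the example command, e.g. parse_args(["--num_epochs", "100", "--batch_size", "32"]), yields num_epochs=100 and batch_size=32, while any omitted flag falls back to its placeholder default.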
The training session logs are saved in the tb_logs directory. You can visualize the logs using TensorBoard by running the following command:
tensorboard --logdir=tb_logs --port=8080
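Pointing TensorBoard at tb_logs shows every run at once. To inspect only the most recent run, a helper like the one below can locate it. This sketch assumes trainer.py uses PyTorch Lightning's TensorBoardLogger with tb_logs as its save directory, so each run lands in tb_logs/<name>/version_<N>/; the function latest_run_dir is hypothetical.

```python
import re
from pathlib import Path
from typing import Optional

# Lightning's TensorBoardLogger names run directories version_0, version_1, ...
_VERSION_RE = re.compile(r"^version_(\d+)$")


def latest_run_dir(log_root: str = "tb_logs") -> Optional[Path]:
    """Return the version_N directory with the highest N under log_root/*/.

    Returns None if log_root is missing or contains no run directories.
    """
    best = None
    best_num = -1
    for path in Path(log_root).glob("*/version_*"):
        match = _VERSION_RE.match(path.name)
        if path.is_dir() and match and int(match.group(1)) > best_num:
            best, best_num = path, int(match.group(1))
    return best
```

The returned path can then be passed to TensorBoard, e.g. tensorboard --logdir=<that path> --port=8080.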
Project files:
- dataset.py: Defines the data module for handling the dataset.
- model.py: Contains the definition of the machine learning model.
- utils.py: Provides additional utilities for data processing and model training.
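As a rough illustration of the kind of work a data module such as the one in dataset.py typically performs, the sketch below splits a list of image paths into training and validation sets reproducibly. The function split_files, the 20% validation fraction, and the seed are hypothetical choices, not taken from this repository; its actual split logic may differ.

```python
import random
from typing import List, Sequence, Tuple


def split_files(files: Sequence[str], val_fraction: float = 0.2,
                seed: int = 42) -> Tuple[List[str], List[str]]:
    """Deterministically shuffle file paths and split them into (train, val).

    Hypothetical helper for illustration only.
    """
    if not 0.0 <= val_fraction < 1.0:
        raise ValueError("val_fraction must be in [0, 1)")
    # Sort first so the result does not depend on the input order
    shuffled = sorted(files)
    # A local RNG keeps the split reproducible without touching global state
    random.Random(seed).shuffle(shuffled)
    n_val = int(len(shuffled) * val_fraction)
    return shuffled[n_val:], shuffled[:n_val]
```

Using a fixed seed means the same images end up in the validation set on every run, which keeps metrics comparable across training sessions.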