Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RoiCluster #7

Merged
merged 13 commits into from
Jun 3, 2024
Merged
16 changes: 14 additions & 2 deletions sdcat/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import click
from sdcat.logger import err, info, create_logger_file
from sdcat import __version__
from sdcat.cluster.commands import run_cluster
from sdcat.cluster.commands import run_cluster_det, run_cluster_roi
from sdcat.detect.commands import run_detect


Expand All @@ -28,7 +28,19 @@ def cli():
pass

cli.add_command(run_detect)
cli.add_command(run_cluster)


@cli.group(name="cluster")
def cli_cluster():
"""
Commands related to converting data
"""
pass


cli.add_command(cli_cluster)
cli_cluster.add_command(run_cluster_det)
cli_cluster.add_command(run_cluster_roi)


if __name__ == '__main__':
Expand Down
34 changes: 24 additions & 10 deletions sdcat/cluster/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import MinMaxScaler
from sdcat.logger import info, warn, debug, err
from sdcat.cluster.utils import cluster_grid, crop_square_image
from sdcat.cluster.utils import cluster_grid, crop_square_image, square_image
from sdcat.cluster.embedding import fetch_embedding, has_cached_embedding, compute_norm_embedding

if find_spec("multicore_tsne"):
Expand Down Expand Up @@ -227,12 +227,14 @@ def cluster_vits(
cluster_selection_epsilon: float,
min_similarity: float,
min_cluster_size: int,
min_samples: int):
min_samples: int,
roi: bool = False) -> pd.DataFrame:
""" Cluster the crops using the VITS embeddings.
:param prefix: A unique prefix to save artifacts from clustering
:param model: The model to use for clustering
:param df_dets: The dataframe with the detections
:param output_path: The output path to save the clustering artifacts to
:param roi: Whether the detections are already cropped to the ROI
:param cluster_selection_epsilon: The epsilon parameter for HDBSCAN
:param alpha: The alpha parameter for HDBSCAN
:param min_similarity: The minimum similarity score to use for -1 cluster reassignment
Expand All @@ -250,12 +252,18 @@ def cluster_vits(

# Skip cropping if all the crops are already done
if num_crop != len(df_dets):
# Crop and squaring the images in parallel using multiprocessing to speed up the processing
info(f'Cropping {len(df_dets)} detections in parallel using {multiprocessing.cpu_count()} processes...')
num_processes = min(multiprocessing.cpu_count(), len(df_dets))
with multiprocessing.Pool(num_processes) as pool:
args = [(row, 224) for index, row in df_dets.iterrows()]
pool.starmap(crop_square_image, args)
if roi == True:
info('ROI crops already exist. Creating square crops in parallel using {multiprocessing.cpu_count()} processes...')
with multiprocessing.Pool(num_processes) as pool:
args = [(row, 224) for index, row in df_dets.iterrows()]
pool.starmap(square_image, args)
else:
# Crop and squaring the images in parallel using multiprocessing to speed up the processing
info(f'Cropping {len(df_dets)} detections in parallel using {multiprocessing.cpu_count()} processes...')
with multiprocessing.Pool(num_processes) as pool:
args = [(row, 224) for index, row in df_dets.iterrows()]
pool.starmap(crop_square_image, args)

# Drop any rows with crop_path that have files that don't exist - sometimes the crops fail
df_dets = df_dets[df_dets['crop_path'].apply(lambda x: os.path.exists(x))]
Expand All @@ -279,9 +287,15 @@ def cluster_vits(
(output_path / prefix).mkdir(parents=True)

# Remove everything except ancillary data to include in clustering
ancillary_df = df_dets.drop(
columns=['x', 'y', 'xx', 'xy', 'w', 'h', 'image_width', 'image_height', 'cluster_id', 'cluster', 'score',
'class', 'image_path', 'crop_path'])
columns = ['x', 'y', 'xx', 'xy', 'w', 'h', 'image_width', 'image_height', 'cluster_id', 'cluster', 'score',
'class', 'image_path', 'crop_path']
# Check if the columns exist in the dataframe
if all(col in df_dets.columns for col in columns):
ancillary_df = df_dets.drop(
columns=['x', 'y', 'xx', 'xy', 'w', 'h', 'image_width', 'image_height', 'cluster_id', 'cluster', 'score',
'class', 'image_path', 'crop_path'])
else:
ancillary_df = df_dets

# Cluster the images
cluster_sim, unique_clusters, cluster_means, coverage = _run_hdbscan_assign(prefix,
Expand Down
7 changes: 4 additions & 3 deletions sdcat/cluster/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@
import pandas as pd
import pytz
import torch
from PIL import Image

from sdcat import common_args
from sdcat.config import config as cfg
from sdcat.logger import info, err, warn
from sdcat.cluster.cluster import cluster_vits


@click.command('cluster', help='Cluster detections. See cluster --config-ini to override cluster defaults.')
@click.command('detections', help='Cluster detections. See cluster --config-ini to override cluster defaults.')
@common_args.config_ini
@common_args.start_image
@common_args.end_image
Expand All @@ -31,7 +32,7 @@
@click.option('--alpha', help='Alpha is a parameter that controls the linkage. See https://hdbscan.readthedocs.io/en/latest/parameter_selection.html. Default is 0.92. Increase for less conservative clustering, e.g. 1.0', type=float)
@click.option('--cluster-selection-epsilon', help='Epsilon is a parameter that controls the linkage. Default is 0. Increase for less conservative clustering', type=float)
@click.option('--min-cluster-size', help='The minimum number of samples in a group for that group to be considered a cluster. Default is 2. Increase for less conservative clustering, e.g. 5, 15', type=int)
def run_cluster(det_dir, save_dir, device, config_ini, alpha, cluster_selection_epsilon, min_cluster_size, start_image, end_image):
def run_cluster_det(det_dir, save_dir, device, config_ini, alpha, cluster_selection_epsilon, min_cluster_size, start_image, end_image):
config = cfg.Config(config_ini)
max_area = int(config('cluster', 'max_area'))
min_area = int(config('cluster', 'min_area'))
Expand Down Expand Up @@ -259,7 +260,7 @@ def is_day(utc_dt):
shutil.copy(Path(config_ini), save_dir / f'{prefix}_config.ini')
else:
warn(f'No detections found to cluster')

@click.command('roi', help='Cluster roi. See cluster --config-ini to override cluster defaults.')
@common_args.config_ini
@click.option('--roi-dir', help='Input folder(s) with raw ROI images', multiple=True)
Expand Down
38 changes: 37 additions & 1 deletion sdcat/cluster/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,45 @@ def gen_grid(with_attention: bool):
# gen_grid(with_attention=True)


def square_image(row, square_dim: int):
"""
Squares an image to the model dimension, filling it with black bars if necessary
:param row:
:param square_dim: dimension of the square image
:return:
"""
try:
if not Path(row.image_path).exists():
warn(f'Skipping {row.crop_path} because the image {row.image_path} does not exist')
return

if Path(row.crop_path).exists(): # If the crop already exists, skip it
return

# Determine the size of the new square
max_side = max(row.image_width, row.image_height)

# Create a new square image with a black background
new_image = Image.new('RGB', (max_side, max_side), (0, 0, 0))

img = Image.open(row.image_path)

# Paste the original image onto the center of the new image
new_image.paste(img, ((max_side - row.image_width) // 2, (max_side - row.image_height) // 2))

# Resize the image to square_dim x square_dim
img = img.resize((square_dim, square_dim), Image.LANCZOS)

# Save the image
img.save(row.crop_path)
img.close()
except Exception as e:
exception(f'Error cropping {row.image_path} {e}')
raise e

def crop_square_image(row, square_dim: int):
"""
Crop the image to a square padding the shorted dimension, then resize it to square_dim x square_dim
Crop the image to a square padding the shortest dimension, then resize it to square_dim x square_dim
This also adjusts the crop to make sure the crop is fully in the frame, otherwise the crop that
exceeds the frame is filled with black bars - these produce clusters of "edge" objects instead
of the detection
Expand Down
Loading