diff --git a/README.md b/README.md index 85c8899..8cc2534 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ This implementation utilizes motion scoring only (no appearance scoring) ```$ pip install openmht``` -For plotting tracks with TrackVis, also install matplotlib: +To also plot tracks after completion, install matplotlib: ```$ pip install matplotlib``` @@ -52,9 +52,9 @@ Track tree pruning parameters: ```$ python -m openmht InputDetections.csv OutputDetections.csv ParameterFile.txt``` -**TrackVis** takes a CSV in the MHT output format and shows a figure of the resulting color-coded tracks. Optionally, it saves the figure to the provided image filepath. +For generating track plots, add the **--plot** parameter (requires **matplotlib**): -```$ python -m trackvis SampleOutput.csv -o OutputTracks.png``` +```$ python -m openmht ... --plot``` ## Example Results diff --git a/openmht/cli.py b/openmht/cli.py index 0cf18d8..9dfa537 100644 --- a/openmht/cli.py +++ b/openmht/cli.py @@ -46,7 +46,7 @@ def read_uv_csv(file_path, frame_max=100): detections[detection_index].append([u, v]) line_count += 1 - logging.info(f"Reading inputs complete. Processed {line_count} lines.") + logging.info("Reading inputs complete. Processed %d lines.", line_count) return detections @@ -69,17 +69,6 @@ def write_uv_csv(file_path, solution_coordinates): csv_rows.append([frame_index, track_index, u, v]) - # for i in range(len(solution_coordinates)): - # track_coordinates = solution_coordinates[i] - # for j in range(len(track_coordinates)): - # coordinate = track_coordinates[j] - # if coordinate is None: - # u = v = 'None' - # else: - # u, v = [str(x) for x in coordinate] - - # csv_rows.append([j, i, u, v]) - # Sort the results by frame number csv_rows.sort(key=lambda x: x[0]) with open(file_path, 'w', encoding='utf-8-sig') as csv_file: @@ -87,7 +76,7 @@ def write_uv_csv(file_path, solution_coordinates): writer.writerow(['frame', 'track', 'u', 'v']) writer.writerows(csv_rows) - logging.info(f"CSV saved to {file_path}") + logging.info("CSV saved to %s", file_path) def read_parameters(params_file_path): @@ -96,16 +85,16 @@ def read_parameters(params_file_path): params = {} # Open the parameter file and read in the parameters - with open(params_file_path, encoding='utf-8-sig') as f: - for line in f: + with open(params_file_path, encoding='utf-8-sig') as file: + for line in file: line_data = line.split("#")[0].split('=') if len(line_data) == 2: key, val = [s.strip() for s in line_data] if key in param_keys: try: val = float(val) - except ValueError: - raise AssertionError(f"Incorrect value type in params.txt: {line}") + except ValueError as exc: + raise AssertionError(f"Incorrect value type in params.txt: {line}") from exc param_keys.remove(key) params[key] = val @@ -120,11 +109,24 @@ def read_parameters(params_file_path): def run(cli_args=None): """Read in the command line parameters and run MHT.""" + + # Get the version from the package + import pkg_resources + __version__ = pkg_resources.require("openmht")[0].version + logging.info("OpenMHT version %s", __version__) + + # MHT parameters parser = argparse.ArgumentParser() parser.add_argument('ifile', help="Input CSV file path") parser.add_argument('ofile', help="Output CSV file path") parser.add_argument('pfile', help='Path to the parameter text file') + # Version parameter + parser.add_argument('-V', '--version', action='version', version=f"OpenMHT version {__version__}") + + # Track visualization parameters + parser.add_argument('-p', '--plot', action='store_true', help="Plot the tracks") + # Parse arguments args = parser.parse_args(cli_args) input_file = args.ifile @@ -139,24 +141,24 @@ def run(cli_args=None): assert Path(output_file).suffix == '.csv', f"Output file is not CSV: {output_file}" assert Path(param_file).suffix == '.txt', f"Parameter file is not TXT: {param_file}" - except AssertionError as e: - print(e) + except AssertionError as param_error: + print(param_error) sys.exit(2) - logging.info(f"Input file is: {input_file}") - logging.info(f"Output file is: {output_file}") - logging.info(f"Parameter file is: {param_file}") + logging.info("Input file is: %s", input_file) + logging.info("Output file is: %s", output_file) + logging.info("Parameter file is: %s", param_file) # Read MHT parameters try: params = read_parameters(param_file) - logging.info(f"MHT parameters: {params}") + logging.info("MHT parameters: %s", params) - except AssertionError as e: - print(e) + except AssertionError as param_error: + print(param_error) sys.exit(2) - # run MHT on detections + # Run MHT on detections detections = read_uv_csv(input_file) start = time.time() mht = MHT(detections, params) @@ -164,4 +166,14 @@ def run(cli_args=None): write_uv_csv(output_file, solution_coordinates) end = time.time() elapsed_seconds = end - start - logging.info(f"Elapsed time (seconds): {elapsed_seconds:.3f}") + logging.info("Elapsed time (seconds): %.3f", elapsed_seconds) + + # Plot the tracks + if args.plot: + + # Import here to allow running without matplotlib + from .plot_tracks import plot_2d_tracks + + logging.info("Plotting tracks...") + plot_2d_tracks(output_file) + logging.info("Done.") diff --git a/openmht/mht.py b/openmht/mht.py index 3014970..e99dddd 100644 --- a/openmht/mht.py +++ b/openmht/mht.py @@ -117,7 +117,7 @@ def __generate_track_trees(self): # Log the N-miss pruning if nmiss_prune_count > 0: - logging.info("Pruned %d branch(es) at frame %d [nmiss]", nmiss_prune_count, frame_index) + logging.info("[nmiss] Pruned %d branch(es) at frame %d", nmiss_prune_count, frame_index) # Prune subtrees that diverge from the solution_trees at frame k-N prune_index = max(0, frame_index-n_scan) @@ -146,7 +146,7 @@ def __generate_track_trees(self): # Log the N-scan pruning if n_scan_prune_count > 0: - logging.info("Pruned %d branch(es) at frame N-%d [nscan]", n_scan_prune_count, n_scan) + logging.info("[nscan] Pruned %d branch(es) at frame N-%d", n_scan_prune_count, n_scan) # Prune branches that exceed the maximum number of branches and keep only the top b_th branches branch_count = branches_added - len(prune_ids) @@ -165,16 +165,13 @@ def __generate_track_trees(self): # Log the B-threshold pruning b_th_prune_count = branches_added - len(prune_ids) if b_th_prune_count > 0: - logging.info("Pruned %d branch(es) using B-threshold [bth].", b_th_prune_count) + logging.info("[bth] Pruned %d branch(es) using B-threshold.", b_th_prune_count) # Prune tracks identified by n-scan, n-miss, and b-threshold for k in sorted(prune_ids, reverse=True): del track_detections[k] del kalman_filters[k] - - # Log the total pruning for this frame - logging.info("Pruned %d total branch(es) at frame %d", len(prune_ids), frame_index) - + frame_index += 1 logging.info("Generated %d track trees", len(solution_coordinates)) diff --git a/trackvis/plot_tracks.py b/openmht/plot_tracks.py similarity index 78% rename from trackvis/plot_tracks.py rename to openmht/plot_tracks.py index 9166f79..98dd1ad 100644 --- a/trackvis/plot_tracks.py +++ b/openmht/plot_tracks.py @@ -3,14 +3,14 @@ import os import matplotlib.pyplot as plt -def plot_2d_tracks(input_csv, output_png): +def plot_2d_tracks(input_csv): """Plot tracks from a file in CSV format.""" # Get the filename for the plot title filename = os.path.basename(input_csv) - plot_title = "Tracks from {}".format(filename) + plot_title = f"Tracks from {filename}" # Read the CSV file - with open(input_csv, 'r') as f: + with open(input_csv, 'r', encoding='utf-8-sig') as f: lines = f.readlines() # Remove the header @@ -48,25 +48,22 @@ def plot_2d_tracks(input_csv, output_png): track.append((float(x.strip()), float(y.strip()))) # Plot the tracks fig = plt.figure() - ax = fig.add_subplot(111) + plot_axis = fig.add_subplot(111) for track in tracks: # Set a unique color for each track - color = next(ax._get_lines.prop_cycler)['color'] + color = next(plot_axis._get_lines.prop_cycler)['color'] # Convert the track to a list of X and Y coordinates x, y = zip(*track) # Plot the track as a line - ax.plot(x, y, color=color) + plot_axis.plot(x, y, color=color) # Plot each point in the track as a black dot - ax.scatter(x, y, color='black') + plot_axis.scatter(x, y, color='black') # Set the axis labels and title - ax.set_xlabel('X') - ax.set_ylabel('Y') - ax.set_title(plot_title) - if output_png is not None: - plt.savefig(output_png) - + plot_axis.set_xlabel('X') + plot_axis.set_ylabel('Y') + plot_axis.set_title(plot_title) plt.show() diff --git a/setup.py b/setup.py index f174621..7607d7b 100644 --- a/setup.py +++ b/setup.py @@ -1,11 +1,12 @@ +"""Setup script for OpenMHT.""" import setuptools -with open("README.md", "r") as fh: +with open("README.md", "r", encoding="utf-8") as fh: long_description = fh.read() setuptools.setup( name="openmht", - version="2.0.0", + version="2.0.1", author="Jonathan Elliot Perdomo", author_email="jonperdomodb@gmail.com", description="OpenMHT", @@ -22,7 +23,6 @@ entry_points={ 'console_scripts': [ 'openmht = openmht.__main__:main', - 'trackvis = trackvis.__main__:main' ] }, ) diff --git a/trackvis/__main__.py b/trackvis/__main__.py deleted file mode 100644 index c3e92a0..0000000 --- a/trackvis/__main__.py +++ /dev/null @@ -1,25 +0,0 @@ -#!/usr/bin/env python - -"""Plot MHT tracks from a file.""" - -import argparse - -from trackvis.plot_tracks import plot_2d_tracks - - -def main(): - """Plot the tracks in the given file.""" - # Create the parser - parser = argparse.ArgumentParser(description="Plot tracks from a file.") - parser.add_argument("file", help="The file to plot.") - parser.add_argument("-o", "--output", help="The output file name.") - - # Parse the arguments - args = parser.parse_args() - - # Plot the tracks - plot_2d_tracks(args.file, args.output) - - -if __name__ == '__main__': - main()