Skip to content

Commit

Permalink
Move plots into the main module
Browse files Browse the repository at this point in the history
  • Loading branch information
jonperdomo committed Sep 24, 2023
1 parent 5627c53 commit 1c5414b
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 78 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ This implementation utilizes motion scoring only (no appearance scoring)

```$ pip install openmht```

For plotting tracks with TrackVis, also install matplotlib:
To also plot tracks after completion, install matplotlib:

```$ pip install matplotlib```

Expand Down Expand Up @@ -52,9 +52,9 @@ Track tree pruning parameters:

```$ python -m openmht InputDetections.csv OutputDetections.csv ParameterFile.txt```

**TrackVis** takes a CSV in the MHT output format and shows a figure of the resulting color-coded tracks. Optionally, it saves the figure to the provided image filepath.
For generating track plots, add the **--plot** parameter (requires **matplotlib**):

```$ python -m trackvis SampleOutput.csv -o OutputTracks.png```
```$ python -m openmht ... --plot```

## Example Results

Expand Down
66 changes: 39 additions & 27 deletions openmht/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def read_uv_csv(file_path, frame_max=100):
detections[detection_index].append([u, v])
line_count += 1

logging.info(f"Reading inputs complete. Processed {line_count} lines.")
logging.info("Reading inputs complete. Processed %d lines.", line_count)

return detections

Expand All @@ -69,25 +69,14 @@ def write_uv_csv(file_path, solution_coordinates):

csv_rows.append([frame_index, track_index, u, v])

# for i in range(len(solution_coordinates)):
# track_coordinates = solution_coordinates[i]
# for j in range(len(track_coordinates)):
# coordinate = track_coordinates[j]
# if coordinate is None:
# u = v = 'None'
# else:
# u, v = [str(x) for x in coordinate]

# csv_rows.append([j, i, u, v])

# Sort the results by frame number
csv_rows.sort(key=lambda x: x[0])
with open(file_path, 'w', encoding='utf-8-sig') as csv_file:
writer = csv.writer(csv_file, lineterminator='\n')
writer.writerow(['frame', 'track', 'u', 'v'])
writer.writerows(csv_rows)

logging.info(f"CSV saved to {file_path}")
logging.info("CSV saved to %s", file_path)


def read_parameters(params_file_path):
Expand All @@ -96,16 +85,16 @@ def read_parameters(params_file_path):
params = {}

# Open the parameter file and read in the parameters
with open(params_file_path, encoding='utf-8-sig') as f:
for line in f:
with open(params_file_path, encoding='utf-8-sig') as file:
for line in file:
line_data = line.split("#")[0].split('=')
if len(line_data) == 2:
key, val = [s.strip() for s in line_data]
if key in param_keys:
try:
val = float(val)
except ValueError:
raise AssertionError(f"Incorrect value type in params.txt: {line}")
except ValueError as exc:
raise AssertionError(f"Incorrect value type in params.txt: {line}") from exc

param_keys.remove(key)
params[key] = val
Expand All @@ -120,11 +109,24 @@ def read_parameters(params_file_path):

def run(cli_args=None):
"""Read in the command line parameters and run MHT."""

# Get the version from the package
import pkg_resources
__version__ = pkg_resources.require("openmht")[0].version
logging.info("OpenMHT version %s", __version__)

# MHT parameters
parser = argparse.ArgumentParser()
parser.add_argument('ifile', help="Input CSV file path")
parser.add_argument('ofile', help="Output CSV file path")
parser.add_argument('pfile', help='Path to the parameter text file')

# Version parameter
parser.add_argument('-V', '--version', action='version', version=f"OpenMHT version {__version__}")

# Track visualization parameters
parser.add_argument('-p', '--plot', action='store_true', help="Plot the tracks")

# Parse arguments
args = parser.parse_args(cli_args)
input_file = args.ifile
Expand All @@ -139,29 +141,39 @@ def run(cli_args=None):
assert Path(output_file).suffix == '.csv', f"Output file is not CSV: {output_file}"
assert Path(param_file).suffix == '.txt', f"Parameter file is not TXT: {param_file}"

except AssertionError as e:
print(e)
except AssertionError as param_error:
print(param_error)
sys.exit(2)

logging.info(f"Input file is: {input_file}")
logging.info(f"Output file is: {output_file}")
logging.info(f"Parameter file is: {param_file}")
logging.info("Input file is: %s", input_file)
logging.info("Output file is: %s", output_file)
logging.info("Parameter file is: %s", param_file)

# Read MHT parameters
try:
params = read_parameters(param_file)
logging.info(f"MHT parameters: {params}")
logging.info("MHT parameters: %s", params)

except AssertionError as e:
print(e)
except AssertionError as param_error:
print(param_error)
sys.exit(2)

# run MHT on detections
# Run MHT on detections
detections = read_uv_csv(input_file)
start = time.time()
mht = MHT(detections, params)
solution_coordinates = mht.run()
write_uv_csv(output_file, solution_coordinates)
end = time.time()
elapsed_seconds = end - start
logging.info(f"Elapsed time (seconds): {elapsed_seconds:.3f}")
logging.info("Elapsed time (seconds): %.3f", elapsed_seconds)

# Plot the tracks
if args.plot:

# Import here to allow running without matplotlib
from .plot_tracks import plot_2d_tracks

logging.info("Plotting tracks...")
plot_2d_tracks(output_file)
logging.info("Done.")
11 changes: 4 additions & 7 deletions openmht/mht.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def __generate_track_trees(self):

# Log the N-miss pruning
if nmiss_prune_count > 0:
logging.info("Pruned %d branch(es) at frame %d [nmiss]", nmiss_prune_count, frame_index)
logging.info("[nmiss] Pruned %d branch(es) at frame %d", nmiss_prune_count, frame_index)

# Prune subtrees that diverge from the solution_trees at frame k-N
prune_index = max(0, frame_index-n_scan)
Expand Down Expand Up @@ -146,7 +146,7 @@ def __generate_track_trees(self):

# Log the N-scan pruning
if n_scan_prune_count > 0:
logging.info("Pruned %d branch(es) at frame N-%d [nscan]", n_scan_prune_count, n_scan)
logging.info("[nscan] Pruned %d branch(es) at frame N-%d", n_scan_prune_count, n_scan)

# Prune branches that exceed the maximum number of branches and keep only the top b_th branches
branch_count = branches_added - len(prune_ids)
Expand All @@ -165,16 +165,13 @@ def __generate_track_trees(self):
# Log the B-threshold pruning
b_th_prune_count = branches_added - len(prune_ids)
if b_th_prune_count > 0:
logging.info("Pruned %d branch(es) using B-threshold [bth].", b_th_prune_count)
logging.info("[bth] Pruned %d branch(es) using B-threshold.", b_th_prune_count)

# Prune tracks identified by n-scan, n-miss, and b-threshold
for k in sorted(prune_ids, reverse=True):
del track_detections[k]
del kalman_filters[k]

# Log the total pruning for this frame
logging.info("Pruned %d total branch(es) at frame %d", len(prune_ids), frame_index)


frame_index += 1

logging.info("Generated %d track trees", len(solution_coordinates))
Expand Down
23 changes: 10 additions & 13 deletions trackvis/plot_tracks.py → openmht/plot_tracks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
import os
import matplotlib.pyplot as plt

def plot_2d_tracks(input_csv, output_png):
def plot_2d_tracks(input_csv):
"""Plot tracks from a file in CSV format."""
# Get the filename for the plot title
filename = os.path.basename(input_csv)
plot_title = "Tracks from {}".format(filename)
plot_title = f"Tracks from {filename}"

# Read the CSV file
with open(input_csv, 'r') as f:
with open(input_csv, 'r', encoding='utf-8-sig') as f:
lines = f.readlines()

# Remove the header
Expand Down Expand Up @@ -48,25 +48,22 @@ def plot_2d_tracks(input_csv, output_png):
track.append((float(x.strip()), float(y.strip())))
# Plot the tracks
fig = plt.figure()
ax = fig.add_subplot(111)
plot_axis = fig.add_subplot(111)
for track in tracks:
# Set a unique color for each track
color = next(ax._get_lines.prop_cycler)['color']
color = next(plot_axis._get_lines.prop_cycler)['color']

# Convert the track to a list of X and Y coordinates
x, y = zip(*track)

# Plot the track as a line
ax.plot(x, y, color=color)
plot_axis.plot(x, y, color=color)

# Plot each point in the track as a black dot
ax.scatter(x, y, color='black')
plot_axis.scatter(x, y, color='black')

# Set the axis labels and title
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_title(plot_title)
if output_png is not None:
plt.savefig(output_png)

plot_axis.set_xlabel('X')
plot_axis.set_ylabel('Y')
plot_axis.set_title(plot_title)
plt.show()
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""Setup script for OpenMHT."""
import setuptools

with open("README.md", "r") as fh:
with open("README.md", "r", encoding="utf-8") as fh:
long_description = fh.read()

setuptools.setup(
name="openmht",
version="2.0.0",
version="2.0.1",
author="Jonathan Elliot Perdomo",
author_email="[email protected]",
description="OpenMHT",
Expand All @@ -22,7 +23,6 @@
entry_points={
'console_scripts': [
'openmht = openmht.__main__:main',
'trackvis = trackvis.__main__:main'
]
},
)
25 changes: 0 additions & 25 deletions trackvis/__main__.py

This file was deleted.

0 comments on commit 1c5414b

Please sign in to comment.