diff --git a/examples/outputs/two_body.mp4 b/examples/outputs/two_body.mp4 new file mode 100644 index 0000000..4a0fcde Binary files /dev/null and b/examples/outputs/two_body.mp4 differ diff --git a/examples/simulate_two_body_problem.py b/examples/simulate_two_body_problem.py index 7ecdc85..a96f88c 100644 --- a/examples/simulate_two_body_problem.py +++ b/examples/simulate_two_body_problem.py @@ -8,7 +8,7 @@ import matplotlib.pyplot as plt from nbodyx.constants import G, M_sun, M_earth, AU from nbodyx.ode import make_n_body_ode -from nbodyx.rendering.opencv import render_n_body_in_opencv +from nbodyx.rendering.opencv import animate_n_body, render_n_body if __name__ == "__main__": ode_fn = make_n_body_ode(jnp.array([M_sun, M_earth])) @@ -18,26 +18,24 @@ theta_earth = jnp.arctan2(x_earth[1], x_earth[0]) v0_earth = jnp.sqrt(G * M_sun / AU) v0_earth = jnp.array([-v0_earth * jnp.sin(theta_earth), v0_earth * jnp.cos(theta_earth)]) - # initial conditions for sun x_sun = jnp.array([0.0, 0.0]) v_sun = jnp.array([0.0, 0.0]) - # initial state y0 = jnp.concatenate([x_sun, x_earth, v_sun, v0_earth]) print("y0", y0) + # state bounds + x_min, x_max = -2 * AU * jnp.ones((1,)), 2 * AU * jnp.ones((1,)) + # evaluate the ODE at the initial state y_d0 = jit(ode_fn)(0.0, y0) print("y_d0", y_d0) # render the image at the initial state - img = render_n_body_in_opencv( - jnp.array([x_sun, x_earth]), 500, 500, -2 * AU * jnp.ones((1,)), 2 * AU * jnp.ones((1,)) - ) + img = render_n_body(jnp.array([x_sun, x_earth]), 500, 500, x_min, x_max) plt.figure(num="Sample rendering") plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) - # plt.title(f"y0 = {y0}") plt.show() # simulation settings @@ -63,3 +61,16 @@ plt.grid(True) plt.box(True) plt.show() + + # animate the solution + animate_n_body( + ts, + x_bds_ts, + 500, + 500, + video_path="examples/outputs/two_body.mp4", + speed_up=ts[-1] / 10, + x_min=x_min, + x_max=x_max, + timestamp_unit="M", + ) diff --git a/pyproject.toml b/pyproject.toml index 10296ce..9b799d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,17 +12,16 @@ classifiers = [ "Intended Audience :: Developers", "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", ] -requires-python = ">=3.8.1" +requires-python = ">=3.10.0" dynamic = ["version"] -dependencies = ["diffrax", "jax", "numpy", "opencv-python"] +dependencies = ["diffrax", "jax", "numpy", "opencv-python", "tqdm"] [project.optional-dependencies] -examples = ["matplotlib", "seaborn"] +examples = ["matplotlib"] spark = ["pyspark>=3.0.0"] test = [ "bandit[toml]==1.7.5", diff --git a/src/nbodyx/rendering/opencv.py b/src/nbodyx/rendering/opencv.py index 31ec8fe..31c1d99 100644 --- a/src/nbodyx/rendering/opencv.py +++ b/src/nbodyx/rendering/opencv.py @@ -3,10 +3,12 @@ import jax.numpy as jnp import matplotlib.pyplot as plt import numpy as onp -from typing import List +from os import PathLike +from pathlib import Path +from tqdm import tqdm -def render_n_body_in_opencv( +def render_n_body( x_bds: Array, width: int, height: int, @@ -14,6 +16,7 @@ def render_n_body_in_opencv( x_max: Array, body_radii: Array = None, body_colors: Array = None, + label: str = None, ) -> onp.ndarray: """Render the n-body problem using OpenCV. @@ -23,6 +26,7 @@ def render_n_body_in_opencv( height: The height of the image. body_radii: The radii of the bodies. Array of shape (num_bodies, ). body_colors: The RGB colors of the bodies. Array of shape (num_bodies, 3). + label: The label of the image. Returns: img: The rendered image. """ @@ -60,4 +64,75 @@ def x_to_uv(x: Array) -> Array: color = tuple(body_colors[i].tolist()) cv2.circle(img, center, int(body_radii[i]), color, -1) + # draw the label + if label is not None: + font = cv2.FONT_HERSHEY_SIMPLEX + bottom_left_corner_of_text = (10, 50) + font_scale = 0.5 + font_color = (255, 255, 255) + line_type = 2 + cv2.putText( + img, + label, + bottom_left_corner_of_text, + font, + font_scale, + font_color, + line_type, + ) + return img + + +def animate_n_body( + ts: Array, + x_bds_ts: Array, + width: int, + height: int, + video_path: PathLike, + speed_up: int = 1, + skip_step: int = 1, + add_timestamp: bool = True, + timestamp_unit: str = "s", + **kwargs, +): + dt = jnp.mean(jnp.diff(ts)).item() + fps = float(speed_up / (skip_step * dt)) + print(f"fps: {fps}") + + # create video + fourcc = cv2.VideoWriter_fourcc(*"MP4V") + Path(video_path).parent.mkdir(parents=True, exist_ok=True) + video = cv2.VideoWriter( + str(video_path), + fourcc, + fps, # fps, + (width, height), + ) + + # skip frames + ts = ts[::skip_step] + x_bds_ts = x_bds_ts[::skip_step] + + for time_idx, t in (pbar := tqdm(enumerate(ts))): + pbar.set_description(f"Rendering frame {time_idx + 1}/{len(ts)}") + + label = None + if add_timestamp: + if timestamp_unit == "s": + label = f"t = {t:.1f} seconds" + elif timestamp_unit == "d": + label = f"t = {t / (24 * 3600):.1f} days" + elif timestamp_unit == "M": + label = f"t = {t / (30 * 24 * 3600):.1f} months" + elif timestamp_unit == "y": + label = f"t = {t / (365 * 24 * 3600):.1f} years" + else: + raise ValueError(f"Invalid timestamp unit: {timestamp_unit}") + + # render the image + img = render_n_body(x_bds_ts[time_idx], width, height, label=label, **kwargs) + + video.write(img) + + video.release()