Skip to content

Commit

Permalink
Add animation
Browse files Browse the repository at this point in the history
  • Loading branch information
mstoelzle committed Mar 7, 2024
1 parent ce3cf3d commit f8ae9df
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 14 deletions.
Binary file added examples/outputs/two_body.mp4
Binary file not shown.
25 changes: 18 additions & 7 deletions examples/simulate_two_body_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import matplotlib.pyplot as plt
from nbodyx.constants import G, M_sun, M_earth, AU
from nbodyx.ode import make_n_body_ode
from nbodyx.rendering.opencv import render_n_body_in_opencv
from nbodyx.rendering.opencv import animate_n_body, render_n_body

if __name__ == "__main__":
ode_fn = make_n_body_ode(jnp.array([M_sun, M_earth]))
Expand All @@ -18,26 +18,24 @@
theta_earth = jnp.arctan2(x_earth[1], x_earth[0])
v0_earth = jnp.sqrt(G * M_sun / AU)
v0_earth = jnp.array([-v0_earth * jnp.sin(theta_earth), v0_earth * jnp.cos(theta_earth)])

# initial conditions for sun
x_sun = jnp.array([0.0, 0.0])
v_sun = jnp.array([0.0, 0.0])

# initial state
y0 = jnp.concatenate([x_sun, x_earth, v_sun, v0_earth])
print("y0", y0)

# state bounds
x_min, x_max = -2 * AU * jnp.ones((1,)), 2 * AU * jnp.ones((1,))

# evaluate the ODE at the initial state
y_d0 = jit(ode_fn)(0.0, y0)
print("y_d0", y_d0)

# render the image at the initial state
img = render_n_body_in_opencv(
jnp.array([x_sun, x_earth]), 500, 500, -2 * AU * jnp.ones((1,)), 2 * AU * jnp.ones((1,))
)
img = render_n_body(jnp.array([x_sun, x_earth]), 500, 500, x_min, x_max)
plt.figure(num="Sample rendering")
plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
# plt.title(f"y0 = {y0}")
plt.show()

# simulation settings
Expand All @@ -63,3 +61,16 @@
plt.grid(True)
plt.box(True)
plt.show()

# animate the solution
animate_n_body(
ts,
x_bds_ts,
500,
500,
video_path="examples/outputs/two_body.mp4",
speed_up=ts[-1] / 10,
x_min=x_min,
x_max=x_max,
timestamp_unit="M",
)
9 changes: 4 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,16 @@ classifiers = [
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
]
requires-python = ">=3.8.1"
requires-python = ">=3.10.0"
dynamic = ["version"]
dependencies = ["diffrax", "jax", "numpy", "opencv-python"]
dependencies = ["diffrax", "jax", "numpy", "opencv-python", "tqdm"]

[project.optional-dependencies]
examples = ["matplotlib", "seaborn"]
examples = ["matplotlib"]
spark = ["pyspark>=3.0.0"]
test = [
"bandit[toml]==1.7.5",
Expand Down
79 changes: 77 additions & 2 deletions src/nbodyx/rendering/opencv.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,20 @@
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as onp
from typing import List
from os import PathLike
from pathlib import Path
from tqdm import tqdm


def render_n_body_in_opencv(
def render_n_body(
x_bds: Array,
width: int,
height: int,
x_min: Array,
x_max: Array,
body_radii: Array = None,
body_colors: Array = None,
label: str = None,
) -> onp.ndarray:
"""Render the n-body problem using OpenCV.
Expand All @@ -23,6 +26,7 @@ def render_n_body_in_opencv(
height: The height of the image.
body_radii: The radii of the bodies. Array of shape (num_bodies, ).
body_colors: The RGB colors of the bodies. Array of shape (num_bodies, 3).
label: The label of the image.
Returns:
img: The rendered image.
"""
Expand Down Expand Up @@ -60,4 +64,75 @@ def x_to_uv(x: Array) -> Array:
color = tuple(body_colors[i].tolist())
cv2.circle(img, center, int(body_radii[i]), color, -1)

# draw the label
if label is not None:
font = cv2.FONT_HERSHEY_SIMPLEX
bottom_left_corner_of_text = (10, 50)
font_scale = 0.5
font_color = (255, 255, 255)
line_type = 2
cv2.putText(
img,
label,
bottom_left_corner_of_text,
font,
font_scale,
font_color,
line_type,
)

return img


def animate_n_body(
ts: Array,
x_bds_ts: Array,
width: int,
height: int,
video_path: PathLike,
speed_up: int = 1,
skip_step: int = 1,
add_timestamp: bool = True,
timestamp_unit: str = "s",
**kwargs,
):
dt = jnp.mean(jnp.diff(ts)).item()
fps = float(speed_up / (skip_step * dt))
print(f"fps: {fps}")

# create video
fourcc = cv2.VideoWriter_fourcc(*"MP4V")
Path(video_path).parent.mkdir(parents=True, exist_ok=True)
video = cv2.VideoWriter(
str(video_path),
fourcc,
fps, # fps,
(width, height),
)

# skip frames
ts = ts[::skip_step]
x_bds_ts = x_bds_ts[::skip_step]

for time_idx, t in (pbar := tqdm(enumerate(ts))):
pbar.set_description(f"Rendering frame {time_idx + 1}/{len(ts)}")

label = None
if add_timestamp:
if timestamp_unit == "s":
label = f"t = {t:.1f} seconds"
elif timestamp_unit == "d":
label = f"t = {t / (24 * 3600):.1f} days"
elif timestamp_unit == "M":
label = f"t = {t / (30 * 24 * 3600):.1f} months"
elif timestamp_unit == "y":
label = f"t = {t / (365 * 24 * 3600):.1f} years"
else:
raise ValueError(f"Invalid timestamp unit: {timestamp_unit}")

# render the image
img = render_n_body(x_bds_ts[time_idx], width, height, label=label, **kwargs)

video.write(img)

video.release()

0 comments on commit f8ae9df

Please sign in to comment.