Skip to content

Commit

Permalink
Allow mutiple cache entries
Browse files Browse the repository at this point in the history
  • Loading branch information
WyattBlue committed Dec 30, 2024
1 parent f09d290 commit b77ddc5
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 20 deletions.
11 changes: 10 additions & 1 deletion auto_editor/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,16 @@ def get_domain(url: str) -> str:


def main() -> None:
subcommands = ("test", "info", "levels", "subdump", "desc", "repl", "palet")
subcommands = (
"test",
"info",
"levels",
"subdump",
"desc",
"repl",
"palet",
"cache",
)

if len(sys.argv) > 1 and sys.argv[1] in subcommands:
obj = __import__(
Expand Down
42 changes: 23 additions & 19 deletions auto_editor/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import re
from dataclasses import dataclass
from fractions import Fraction
from hashlib import sha1
from math import ceil
from tempfile import gettempdir
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -154,8 +155,10 @@ def iter_motion(

def obj_tag(path: Path, kind: str, tb: Fraction, obj: Sequence[object]) -> str:
mod_time = int(path.stat().st_mtime)
key = f"{path.name}:{mod_time:x}:{kind}:{tb}:"
return key + ",".join(f"{v}" for v in obj)
key = f"{path}:{mod_time:x}:{tb.numerator}:{tb.denominator}:"
part1 = sha1(key.encode()).hexdigest()[:16]

return f"{part1}{kind}," + ",".join(f"{v}" for v in obj)


@dataclass(slots=True)
Expand Down Expand Up @@ -206,31 +209,31 @@ def read_cache(self, kind: str, obj: Sequence[object]) -> None | np.ndarray:
if self.no_cache:
return None

workfile = os.path.join(gettempdir(), f"ae-{__version__}", "cache.npz")
key = obj_tag(self.src.path, kind, self.tb, obj)
cache_file = os.path.join(gettempdir(), f"ae-{__version__}", f"{key}.npz")

try:
npzfile = np.load(workfile, allow_pickle=False)
with np.load(cache_file, allow_pickle=False) as npzfile:
return npzfile["data"]
except Exception as e:
self.log.debug(e)
return None

key = obj_tag(self.src.path, kind, self.tb, obj)
if key not in npzfile.files:
return None

self.log.debug("Using cache")
return npzfile[key]

def cache(self, arr: np.ndarray, kind: str, obj: Sequence[object]) -> np.ndarray:
if self.no_cache:
return arr

workdur = os.path.join(gettempdir(), f"ae-{__version__}")
if not os.path.exists(workdur):
os.mkdir(workdur)
workdir = os.path.join(gettempdir(), f"ae-{__version__}")
if not os.path.exists(workdir):
os.mkdir(workdir)

key = obj_tag(self.src.path, kind, self.tb, obj)
np.savez(os.path.join(workdur, "cache.npz"), **{key: arr})
cache_file = os.path.join(workdir, f"{key}.npz")

try:
np.savez(cache_file, data=arr)
except Exception as e:
self.log.warning(f"Cache write failed: {e}")

return arr

Expand All @@ -257,14 +260,15 @@ def audio(self, stream: int) -> NDArray[np.float32]:
bar = self.bar
bar.start(inaccurate_dur, "Analyzing audio volume")

result = np.zeros((inaccurate_dur), dtype=np.float32)
result: NDArray[np.float32] = np.zeros(inaccurate_dur, dtype=np.float32)
index = 0

for value in iter_audio(audio, self.tb):
if index > len(result) - 1:
result = np.concatenate(
(result, np.zeros((len(result)), dtype=np.float32))
(result, np.zeros(len(result), dtype=np.float32))
)

result[index] = value
bar.tick(index)
index += 1
Expand Down Expand Up @@ -296,13 +300,13 @@ def motion(self, stream: int, blur: int, width: int) -> NDArray[np.float32]:
bar = self.bar
bar.start(inaccurate_dur, "Analyzing motion")

result = np.zeros((inaccurate_dur), dtype=np.float32)
result: NDArray[np.float32] = np.zeros(inaccurate_dur, dtype=np.float32)
index = 0

for value in iter_motion(video, self.tb, blur, width):
if index > len(result) - 1:
result = np.concatenate(
(result, np.zeros((len(result)), dtype=np.float32))
(result, np.zeros(len(result), dtype=np.float32))
)
result[index] = value
bar.tick(index)
Expand Down
69 changes: 69 additions & 0 deletions auto_editor/subcommands/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import glob
import os
import sys
from shutil import rmtree
from tempfile import gettempdir

import numpy as np

from auto_editor import __version__


def main(sys_args: list[str] = sys.argv[1:]) -> None:
cache_dir = os.path.join(gettempdir(), f"ae-{__version__}")

if sys_args and sys_args[0] in ("clean", "clear"):
rmtree(cache_dir, ignore_errors=True)
return

if not os.path.exists(cache_dir):
print("Empty cache")
return

cache_files = glob.glob(os.path.join(cache_dir, "*.npz"))
if not cache_files:
print("Empty cache")
return

def format_bytes(size: float) -> str:
for unit in ("B", "KiB", "MiB", "GiB", "TiB"):
if size < 1024:
return f"{size:.2f} {unit}"
size /= 1024
return f"{size:.2f} PiB"

GRAY = "\033[90m"
GREEN = "\033[32m"
BLUE = "\033[34m"
YELLOW = "\033[33m"
RESET = "\033[0m"

total_size = 0
for cache_file in cache_files:
try:
with np.load(cache_file, allow_pickle=False) as npzfile:
array = npzfile["data"]
key = os.path.basename(cache_file)[:-4] # Remove .npz extension

hash_part = key[:16]
rest_part = key[16:]

size = array.nbytes
total_size += size
size_str = format_bytes(size)
size_num, size_unit = size_str.rsplit(" ", 1)

print(
f"{YELLOW}entry: {GRAY}{hash_part}{RESET}{rest_part} "
f"{YELLOW}size: {GREEN}{size_num} {BLUE}{size_unit}{RESET}"
)
except Exception as e:
print(f"Error reading {cache_file}: {e}")

total_str = format_bytes(total_size)
total_num, total_unit = total_str.rsplit(" ", 1)
print(f"\n{YELLOW}total cache size: {GREEN}{total_num} {BLUE}{total_unit}{RESET}")


if __name__ == "__main__":
main()

0 comments on commit b77ddc5

Please sign in to comment.