Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Day 22: Speed up with Cython #204

Merged
merged 1 commit into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/py.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ jobs:
- uses: actions/upload-artifact@v4
with:
name: aoc2024-py
path: py/dist/*.whl
path: py/dist/*

run:
needs: [ get-inputs, build ]
Expand Down
4 changes: 4 additions & 0 deletions py/.gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
__pycache__/
build/
dist/
*.c
*.pyd
*.so
*~
41 changes: 5 additions & 36 deletions py/aoc2024/day22.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
Day 22: Monkey Market
"""

from collections import deque
from ctypes import c_bool, c_int, memset
from functools import reduce
from array import array

from aoc2024.day22c import part1 as _part1, part2 as _part2

SAMPLE_INPUT_1 = """\
1
Expand All @@ -20,51 +20,20 @@
"""


def _step(num: int) -> int:
num = num ^ num << 6 & 16777215
num = num ^ num >> 5 & 16777215
num = num ^ num << 11 & 16777215
return num


def part1(data: str) -> int:
"""
>>> part1(SAMPLE_INPUT_1)
37327623
"""
return sum(
reduce(lambda num, _: _step(num), range(2000), int(line))
for line in data.splitlines()
)
return _part1(array("I", (int(line) for line in data.splitlines())))


def part2(data: str) -> int:
"""
>>> part2(SAMPLE_INPUT_2)
23
"""
output = (c_int * (19 * 19 * 19 * 19))()
seen = (c_bool * (19 * 19 * 19 * 19))()
best = 0
for line in data.splitlines():
num = int(line)
memset(seen, False, len(seen))
window = deque((num % 10,), 5)
for _ in range(2001):
if len(window) == window.maxlen:
window.popleft()
num = _step(num)
window.append(num % 10)
if len(window) == window.maxlen:
a, b, c, d, e = window
key = (((a + 9 - b) * 19 + b + 9 - c) * 19 + c + 9 - d) * 19 + d + 9 - e
if not seen[key]:
value = output[key] + e
output[key] = value
if best < value:
best = value
seen[key] = True
return best
return _part2(array("I", (int(line) for line in data.splitlines())))


parts = (part1, part2)
5 changes: 5 additions & 0 deletions py/aoc2024/day22c.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
cdef extern from "<stdatomic.h>":
ctypedef unsigned int atomic_uint
cdef unsigned int atomic_load_uint "atomic_load" (atomic_uint *obj) nogil
cdef unsigned int atomic_fetch_add_uint "atomic_fetch_add" (atomic_uint *obj, unsigned int arg) nogil
cdef unsigned int atomic_compare_exchange_weak_uint "atomic_compare_exchange_weak" (atomic_uint *obj, unsigned int *expected, unsigned int desired) nogil
89 changes: 89 additions & 0 deletions py/aoc2024/day22c.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# cython: boundscheck=False, wraparound=False, initializedcheck=False, embedsignature=True
# ruff: noqa: F821
"""
Day 22: Monkey Market
"""

import cython
from cython.parallel import prange
from cython.cimports.libc.string import memset


@cython.cfunc
@cython.nogil
def _step(num: cython.uint) -> cython.uint:
num = num ^ num << 6 & 16777215
num = num ^ num >> 5 & 16777215
num = num ^ num << 11 & 16777215
return num


@cython.ccall
@cython.nogil
def part1(data: cython.uint[:]) -> cython.ulong:
i: cython.int
result: cython.ulong = 0
for i in prange(data.shape[0], nogil=True):
j: cython.int
secret: cython.uint = data[i]
for j in range(2000):
secret = _step(secret)
result += secret
return result


@cython.ccall
@cython.nogil
def part2(data: cython.uint[:]) -> cython.uint:
i: cython.int
acc: cython.uint[19 * 19 * 19 * 19]
memset(cython.address(acc[0]), 0, cython.sizeof(acc))
result: cython.uint = 0
for i in prange(data.shape[0]):
j: cython.int
secret: cython.uint = data[i]
seen: cython.bint[19 * 19 * 19 * 19]
memset(cython.address(seen[0]), 0, cython.sizeof(seen))
window: cython.uint[4]
window[0] = secret % 10
best: cython.uint = 0
cur: cython.uint
for j in range(1, 2001):
secret = _step(secret)
price: cython.uint = secret % 10
if j >= 4:
p0: cython.uint = window[j % 4]
p1: cython.uint = window[(j + 1) % 4]
p2: cython.uint = window[(j + 2) % 4]
p3: cython.uint = window[(j + 3) % 4]
d1: cython.int = p0 - p1
d2: cython.int = p1 - p2
d3: cython.int = p2 - p3
d4: cython.int = p3 - price
key: cython.uint = (
19 * (19 * (19 * (d1 + 9) + d2 + 9) + d3 + 9) + d4 + 9
)
if not seen[key]:
seen[key] = True
cur = (
atomic_fetch_add_uint(
cython.cast(
cython.pointer(atomic_uint), cython.address(acc[key])
),
price,
)
+ price
)
if best < cur:
best = cur
window[j % 4] = price
cur = atomic_load_uint(
cython.cast(cython.pointer(atomic_uint), cython.address(result))
)
while cur < best:
atomic_compare_exchange_weak_uint(
cython.cast(cython.pointer(atomic_uint), cython.address(result)),
cython.address(cur),
best,
)
return result
77 changes: 76 additions & 1 deletion py/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

37 changes: 37 additions & 0 deletions py/poetry_build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import os
import shutil
from distutils.command.build_ext import build_ext
from distutils.core import Distribution, Extension

from Cython.Build import cythonize


def build():
ext_modules = cythonize(
module_list=[
Extension(
name="*",
sources=["aoc2024/*c.py"],
),
],
compiler_directives={"language_level": 3},
)

distribution = Distribution({"name": "extended", "ext_modules": ext_modules})
distribution.package_dir = "extended"

cmd = build_ext(distribution)
cmd.ensure_finalized()
cmd.run()

# Copy built extensions back to the project
for output in cmd.get_outputs():
relative_extension = os.path.relpath(output, cmd.build_lib)
shutil.copyfile(output, relative_extension)
mode = os.stat(relative_extension).st_mode
mode |= (mode & 0o444) >> 2
os.chmod(relative_extension, mode)


if __name__ == "__main__":
build()
20 changes: 18 additions & 2 deletions py/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,33 @@ license = "BSD-3-Clause"
readme = "README.md"
repository = "https://github.com/ephemient/aoc2024/tree/main/py"

[[tool.poetry.include]]
path = "test_benchmark.py"
format = "sdist"

[[tool.poetry.include]]
path = "aoc2024/**/*.pyd"
format = "wheel"

[[tool.poetry.include]]
path = "aoc2024/**/*.so"
format = "wheel"

[tool.poetry.dependencies]
python = "^3.13"
natsort = "^8.4.0"
cython = "^3.0.11"

[tool.poetry.group.dev.dependencies]
ruff = "^0.8.5"
pytest = "^8.3.4"
pytest-benchmark = { version = "^5.1.0", extras = ["histogram"] }

[tool.poetry.build]
script = "poetry_build.py"

[tool.pytest.ini_options]
addopts = '--doctest-modules --benchmark-disable --benchmark-sort=fullname'
addopts = '--doctest-modules --benchmark-disable --benchmark-sort=fullname --import-mode=importlib'
required_plugins = ['pytest-benchmark']

[tool.poetry.scripts]
Expand Down Expand Up @@ -51,5 +67,5 @@ day24 = "aoc2024.day24:parts"
day25 = "aoc2024.day25:parts"

[build-system]
requires = ["poetry-core"]
requires = ["poetry-core>=1.0", "cython", "setuptools"]
build-backend = "poetry.core.masonry.api"
Loading