Skip to content

Commit

Permalink
More strict typings + all code fixed up (#77)
Browse files Browse the repository at this point in the history
* Various typing fixes for tests

* More typing fixes

* Typing for examples

* Ignore torch interop example in build

* Correctly ignore torch interop!

* Switch unnecary type ignore to warning due to cross-platform issues
  • Loading branch information
ccummingsNV authored Aug 19, 2024
1 parent 5ccf497 commit cd3c7b8
Show file tree
Hide file tree
Showing 35 changed files with 275 additions and 221 deletions.
3 changes: 1 addition & 2 deletions .vscode-default/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,5 @@
"python.testing.pytestEnabled": true,
"python.testing.pytestArgs": [
"./src/sgl"
],
"python.analysis.typeCheckingMode": "standard"
]
}
43 changes: 29 additions & 14 deletions examples/pathtracer/pathtracer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Optional
import sgl
import numpy as np
import numpy.typing as npt
from pathlib import Path
from dataclasses import dataclass
import struct
Expand All @@ -11,6 +13,7 @@

class Camera:
def __init__(self):
super().__init__()
self.width = 100
self.height = 100
self.aspect_ratio = 1.0
Expand Down Expand Up @@ -52,6 +55,7 @@ class CameraController:
MOVE_SHIFT_FACTOR = 10.0

def __init__(self, camera: Camera):
super().__init__()
self.camera = camera
self.mouse_down = False
self.mouse_pos = sgl.float2()
Expand Down Expand Up @@ -129,12 +133,16 @@ def on_mouse_event(self, event: sgl.MouseEvent):


class Material:
def __init__(self, base_color=sgl.float3(0.5)):
def __init__(self, base_color: "sgl.float3param" = sgl.float3(0.5)):
super().__init__()
self.base_color = base_color


class Mesh:
def __init__(self, vertices, indices):
def __init__(
self, vertices: npt.NDArray[np.float32], indices: npt.NDArray[np.uint32]
):
super().__init__()
assert vertices.ndim == 2 and vertices.dtype == np.float32
assert indices.ndim == 2 and indices.dtype == np.uint32
self.vertices = vertices
Expand All @@ -153,7 +161,7 @@ def index_count(self):
return self.triangle_count * 3

@classmethod
def create_quad(cls, size=sgl.float2(1)):
def create_quad(cls, size: "sgl.float2param" = sgl.float2(1)):
vertices = np.array(
[
# position, normal, uv
Expand All @@ -175,7 +183,7 @@ def create_quad(cls, size=sgl.float2(1)):
return Mesh(vertices, indices)

@classmethod
def create_cube(cls, size=sgl.float3(1)):
def create_cube(cls, size: "sgl.float3param" = sgl.float3(1)):
vertices = np.array(
[
# position, normal, uv
Expand Down Expand Up @@ -237,6 +245,7 @@ def create_cube(cls, size=sgl.float3(1)):

class Transform:
def __init__(self):
super().__init__()
self.translation = sgl.float3(0)
self.scaling = sgl.float3(1)
self.rotation = sgl.float3(0)
Expand All @@ -251,6 +260,7 @@ def update_matrix(self):

class Stage:
def __init__(self):
super().__init__()
self.camera = Camera()
self.materials = []
self.meshes = []
Expand Down Expand Up @@ -292,17 +302,17 @@ def demo(cls):
for _ in range(10):
cube_materials.append(
stage.add_material(
Material(base_color=np.random.rand(3).astype(np.float32))
Material(base_color=sgl.float3(np.random.rand(3).astype(np.float32))) # type: ignore (TYPINGTODO: need explicit np->float conversion)
)
)
cube_mesh = stage.add_mesh(Mesh.create_cube([0.1, 0.1, 0.1]))

for i in range(1000):
transform = Transform()
transform.translation = (np.random.rand(3) * 2 - 1).astype(np.float32)
transform.translation = sgl.float3((np.random.rand(3) * 2 - 1).astype(np.float32)) # type: ignore (TYPINGTODO: need explicit np->float conversion)
transform.translation[1] += 1
transform.scaling = (np.random.rand(3) + 0.5).astype(np.float32)
transform.rotation = (np.random.rand(3) * 10).astype(np.float32)
transform.scaling = sgl.float3((np.random.rand(3) + 0.5).astype(np.float32)) # type: ignore (TYPINGTODO: need explicit np->float conversion)
transform.rotation = sgl.float3((np.random.rand(3) * 10).astype(np.float32)) # type: ignore (TYPINGTODO: need explicit np->float conversion)
transform.update_matrix()
cube_transform = stage.add_transform(transform)
stage.add_instance(
Expand Down Expand Up @@ -348,6 +358,7 @@ def pack(self):
return struct.pack("III", self.mesh_id, self.material_id, self.transform_id)

def __init__(self, device: sgl.Device, stage: Stage):
super().__init__()
self.device = device

self.camera = stage.camera
Expand Down Expand Up @@ -584,6 +595,7 @@ def bind(self, cursor: sgl.ShaderCursor):

class PathTracer:
def __init__(self, device: sgl.Device, scene: Scene):
super().__init__()
self.device = device
self.scene = scene

Expand Down Expand Up @@ -611,17 +623,18 @@ def execute(

class Accumulator:
def __init__(self, device: sgl.Device):
super().__init__()
self.device = device
self.program = self.device.load_program("accumulator.slang", ["main"])
self.kernel = self.device.create_compute_kernel(self.program)
self.accumulator: sgl.Texture = None
self.accumulator: Optional[sgl.Texture] = None

def execute(
self,
command_buffer: sgl.CommandBuffer,
input: sgl.Texture,
output: sgl.Texture,
reset=False,
reset: bool = False,
):
if (
self.accumulator == None
Expand Down Expand Up @@ -653,6 +666,7 @@ def execute(

class ToneMapper:
def __init__(self, device: sgl.Device):
super().__init__()
self.device = device
self.program = self.device.load_program("tone_mapper.slang", ["main"])
self.kernel = self.device.create_compute_kernel(self.program)
Expand All @@ -674,6 +688,7 @@ def execute(

class App:
def __init__(self):
super().__init__()
self.window = sgl.Window(
width=1920, height=1080, title="PathTracer", resizable=True
)
Expand All @@ -688,9 +703,9 @@ def __init__(self):
enable_vsync=False,
)

self.render_texture: sgl.Texture = None
self.accum_texture: sgl.Texture = None
self.output_texture: sgl.Texture = None
self.render_texture: sgl.Texture = None # type: ignore (will be set immediately)
self.accum_texture: sgl.Texture = None # type: ignore (will be set immediately)
self.output_texture: sgl.Texture = None # type: ignore (will be set immediately)

self.window.on_keyboard_event = self.on_keyboard_event
self.window.on_mouse_event = self.on_mouse_event
Expand Down Expand Up @@ -726,7 +741,7 @@ def on_keyboard_event(self, event: sgl.KeyboardEvent):
def on_mouse_event(self, event: sgl.MouseEvent):
self.camera_controller.on_mouse_event(event)

def on_resize(self, width, height):
def on_resize(self, width: int, height: int):
self.device.wait()
self.swapchain.resize(width, height)

Expand Down
4 changes: 2 additions & 2 deletions examples/texture_array/texture_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ def __init__(self, app: sgl.App):
program = self.device.load_program(str(EXAMPLE_DIR / "draw.slang"), ["main"])
self.kernel = self.device.create_compute_kernel(program)

self.render_texture: sgl.Texture = None
self.render_texture: sgl.Texture = None # type: ignore (will be immediately initialized)

self.setup_ui()

def setup_ui(self):
window = sgl.ui.Window(self.screen, "Settings", size=(500, 300))
window = sgl.ui.Window(self.screen, "Settings", size=sgl.float2(500, 300))

self.layer = sgl.ui.SliderInt(
window, "Layer", value=0, min=0, max=self.texture.array_size - 1
Expand Down
6 changes: 3 additions & 3 deletions examples/window/window.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
# SPDX-License-Identifier: Apache-2.0

import sgl
import numpy as np
from pathlib import Path

EXAMPLE_DIR = Path(__file__).parent


class App:
def __init__(self):
super().__init__()
self.window = sgl.Window(
width=1920, height=1280, title="Example", resizable=True
)
Expand Down Expand Up @@ -48,7 +48,7 @@ def __init__(self):

def setup_ui(self):
screen = self.ui.screen
window = sgl.ui.Window(screen, "Settings", size=(500, 300))
window = sgl.ui.Window(screen, "Settings", size=sgl.float2(500, 300))

self.fps_text = sgl.ui.Text(window, "FPS: 0")

Expand Down Expand Up @@ -104,7 +104,7 @@ def on_mouse_event(self, event: sgl.MouseEvent):
if event.button == sgl.MouseButton.left:
self.mouse_down = False

def on_resize(self, width, height):
def on_resize(self, width: int, height: int):
self.framebuffers.clear()
self.device.wait()
self.swapchain.resize(width, height)
Expand Down
10 changes: 8 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@ requires = ["setuptools", "wheel", "typing-extensions"]
build-backend = "setuptools.build_meta"

[tool.pyright]
include = ["./src","./tools"]
include = ["./src","./tools","./examples"]
extraPaths = ["./src/sgl/device/tests"]
ignore = ["./tools/host","./tools/download","./src/sgl/device/tests/test_torch_interop.py"]
ignore = ["./tools/host","./tools/download","./src/sgl/device/tests/test_torch_interop.py","./examples/torch_interop/torch_interop.py"]
pythonVersion = "3.10"
typeCheckingMode = "basic"
reportUnusedImport = "error"
reportMissingSuperCall = "error"
reportInvalidStringEscapeSequence = "error"
reportMissingParameterType = "error"
reportMissingTypeArgument = "warning"
45 changes: 27 additions & 18 deletions src/sgl/core/tests/test_bitmap.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
# SPDX-License-Identifier: Apache-2.0

from pathlib import Path
from typing import Any, Optional, Sequence
import pytest
from sgl import Bitmap, Struct
import numpy as np
import numpy.typing as npt

PIXEL_FORMAT_TO_CHANNELS = {
Bitmap.PixelFormat.y: 1,
Expand All @@ -27,7 +30,13 @@
}


def create_test_array(width, height, channels, dtype, type_range):
def create_test_array(
width: int,
height: int,
channels: int,
dtype: npt.DTypeLike,
type_range: tuple[float, float],
):
img = np.zeros((height, width, channels), dtype)
for i in range(height):
for j in range(width):
Expand All @@ -41,8 +50,8 @@ def create_test_array(width, height, channels, dtype, type_range):


def create_test_image(
width,
height,
width: int,
height: int,
pixel_format: Bitmap.PixelFormat,
component_type: Bitmap.ComponentType,
):
Expand All @@ -56,15 +65,15 @@ def create_test_image(


def write_read_test(
directory,
ext,
width,
height,
pixel_format,
component_type,
quality=None,
rtol=None,
atol=None,
directory: Path,
ext: str,
width: int,
height: int,
pixel_format: Bitmap.PixelFormat,
component_type: Bitmap.ComponentType,
quality: Optional[int] = None,
rtol: Optional[float] = None,
atol: Optional[float] = None,
):
path = directory / f"test_{width}x{height}_{pixel_format}_{component_type}.{ext}"

Expand Down Expand Up @@ -148,7 +157,7 @@ def test_bitmap_vflip():


@pytest.mark.parametrize("layout", EXR_LAYOUTS)
def test_exr_io(tmp_path, layout):
def test_exr_io(tmp_path: Path, layout: Sequence[Any]):
extra = layout[4] if len(layout) > 4 else {}
write_read_test(
tmp_path, "exr", layout[0], layout[1], layout[2], layout[3], **extra
Expand All @@ -162,7 +171,7 @@ def test_exr_io(tmp_path, layout):


@pytest.mark.parametrize("layout", BMP_LAYOUTS)
def test_bmp_io(tmp_path, layout):
def test_bmp_io(tmp_path: Path, layout: Sequence[Any]):
extra = layout[4] if len(layout) > 4 else {}
write_read_test(
tmp_path, "bmp", layout[0], layout[1], layout[2], layout[3], **extra
Expand All @@ -177,7 +186,7 @@ def test_bmp_io(tmp_path, layout):


@pytest.mark.parametrize("layout", TGA_LAYOUTS)
def test_tga_io(tmp_path, layout):
def test_tga_io(tmp_path: Path, layout: Sequence[Any]):
extra = layout[4] if len(layout) > 4 else {}
write_read_test(
tmp_path, "tga", layout[0], layout[1], layout[2], layout[3], **extra
Expand All @@ -199,7 +208,7 @@ def test_tga_io(tmp_path, layout):


@pytest.mark.parametrize("layout", PNG_LAYOUTS)
def test_png_io(tmp_path, layout):
def test_png_io(tmp_path: Path, layout: Sequence[Any]):
extra = layout[4] if len(layout) > 4 else {}
write_read_test(
tmp_path, "png", layout[0], layout[1], layout[2], layout[3], **extra
Expand All @@ -220,7 +229,7 @@ def test_png_io(tmp_path, layout):


@pytest.mark.parametrize("layout", JPG_LAYOUTS)
def test_jpg_io(tmp_path, layout):
def test_jpg_io(tmp_path: Path, layout: Sequence[Any]):
extra = layout[4] if len(layout) > 4 else {}
write_read_test(
tmp_path, "jpg", layout[0], layout[1], layout[2], layout[3], **extra
Expand All @@ -233,7 +242,7 @@ def test_jpg_io(tmp_path, layout):


@pytest.mark.parametrize("layout", HDR_LAYOUTS)
def test_hdr_io(tmp_path, layout):
def test_hdr_io(tmp_path: Path, layout: Sequence[Any]):
extra = layout[4] if len(layout) > 4 else {}
write_read_test(
tmp_path, "hdr", layout[0], layout[1], layout[2], layout[3], **extra
Expand Down
Loading

0 comments on commit cd3c7b8

Please sign in to comment.