Skip to content

Commit

Permalink
Merge branch 'main' into 706-stress-testing-script-for-timeouts
Browse files Browse the repository at this point in the history
  • Loading branch information
PythonFZ authored Nov 13, 2024
2 parents 6a4ab0a + e89c023 commit b592bff
Show file tree
Hide file tree
Showing 11 changed files with 184 additions and 102 deletions.
41 changes: 38 additions & 3 deletions app/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,45 @@ export default function App() {
// todo give to particles and bonds
const [colorMode, handleColorMode] = useColorMode();
const [hoveredId, setHoveredId] = useState<number>(-1);
// UPDATE THESE using `vis.config` on the Python side
const [roomConfig, setRoomConfig] = useState({
arrows: {},
scene: { floor: false, particle_size: 1, bond_size: 1 },
PathTracer: { enabled: false, environment: "none" },
arrows: {
colormap: [
[-0.5, 0.9, 0.5],
[0.0, 0.9, 0.5],
],
normalize: true,
colorrange: [0, 1.0],
scale_vector_thickness: false,
opacity: 1.0,
},
scene: {
fps: 30,
material: "MeshStandardMaterial",
particle_size: 1.0,
bond_size: 1.0,
animation_loop: false,
simulation_box: true,
vectorfield: true,
controls: "OrbitControls",
vectors: "",
vector_scale: 1.0,
selection_color: "#ffa500",
camera: "PerspectiveCamera",
camera_near: 0.1,
camera_far: 300,
frame_update: true,
crosshair: false,
floor: false,
},
PathTracer: {
enabled: false,
environment: "studio",
metalness: 0.7,
roughness: 0.2,
clearcoat: 0.0,
clearcoatRoughness: 0.0,
},
});

const [isAuthenticated, setIsAuthenticated] = useState<boolean>(true);
Expand Down
4 changes: 0 additions & 4 deletions app/src/components/progressbar.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,6 @@ const ColoredTiles = ({
setStep: (step: number) => void;
tickInterval: number;
}) => {
useEffect(() => {
console.log("component rerendered");
}, [length, disabledFrames, setStep, tickInterval]);

const [disabledPositions, setdisabledPositions] = useState<number[]>([]);
const [ticks, setTicks] = useState<number[]>([]);

Expand Down
1 change: 1 addition & 0 deletions tests/test_modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def test_run_selection(server, s22):
"""Test the server fixture."""
vis = ZnDraw(url=server, token="test_token")
vis.extend(s22)
vis.step = 0
vis.selection = [0]

run_queue(vis, "selection", {"ConnectedParticles": {}})
Expand Down
9 changes: 9 additions & 0 deletions tests/test_serializer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import json

import ase
import numpy.testing as npt
import pytest
Expand All @@ -17,6 +19,13 @@ def test_ase_converter(s22):
structures_json = znjson.dumps(
s22, cls=znjson.ZnEncoder.from_converters([ASEConverter])
)

non_json = json.loads(structures_json)
assert "numbers" not in non_json[0]["value"]["arrays"]
assert "positions" not in non_json[0]["value"]["arrays"]
assert "pbc" not in non_json[0]["value"]["info"]
assert "cell" not in non_json[0]["value"]["info"]

structures = znjson.loads(
structures_json, cls=znjson.ZnDecoder.from_converters([ASEConverter])
)
Expand Down
9 changes: 4 additions & 5 deletions zndraw/analyse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from pydantic import BaseModel, Field
from pydantic import Field

from zndraw.base import Extension

try:
from zndraw.analyse import mda # noqa: F401
Expand All @@ -32,14 +34,11 @@ def _get_data_from_frames(key, frames: list[ase.Atoms]):
return data


class AnaylsisMethod(BaseModel):
class AnaylsisMethod(Extension):
@classmethod
def model_json_schema_from_atoms(cls, atoms: ase.Atoms) -> dict:
return cls.model_json_schema()

def run(self, vis):
raise NotImplementedError()


class DihedralAngle(AnaylsisMethod):
def run(self, vis):
Expand Down
7 changes: 6 additions & 1 deletion zndraw/base.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
import dataclasses
import logging
import typing as t
from abc import abstractmethod
from collections.abc import MutableSequence

import numpy as np
import splines
from pydantic import BaseModel

if t.TYPE_CHECKING:
from zndraw import ZnDraw

log = logging.getLogger(__name__)


class Extension(BaseModel):
def run(self, vis, **kwargs) -> None:
@abstractmethod
def run(self, vis: "ZnDraw", **kwargs) -> None:
raise NotImplementedError("run method must be implemented in subclass")


Expand Down
1 change: 1 addition & 0 deletions zndraw/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def encode(self, obj: ase.Atoms) -> ASEDict:

# We don't want to send positions twice
arrays.pop("positions", None)
arrays.pop("numbers", None)

return ASEDict(
numbers=numbers,
Expand Down
21 changes: 10 additions & 11 deletions zndraw/modify/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import abc
import enum
import logging
import time
Expand All @@ -8,7 +7,9 @@
import ase.constraints
import numpy as np
from ase.data import chemical_symbols
from pydantic import BaseModel, Field
from pydantic import Field

from zndraw.base import Extension

try:
from zndraw.modify import extras # noqa: F401
Expand All @@ -25,12 +26,7 @@
Symbols = enum.Enum("Symbols", {symbol: symbol for symbol in chemical_symbols})


class UpdateScene(BaseModel, abc.ABC):
@abc.abstractmethod
def run(self, vis: "ZnDraw", timeout: float, **kwargs) -> None:
"""Method called when running the modifier."""
pass

class UpdateScene(Extension):
def apply_selection(
self, atom_ids: list[int], atoms: ase.Atoms
) -> t.Tuple[ase.Atoms, ase.Atoms]:
Expand Down Expand Up @@ -199,7 +195,6 @@ def run(self, vis: "ZnDraw", **kwargs) -> None:

class AddLineParticles(UpdateScene):
symbol: Symbols
steps: int = Field(10, le=100, ge=1)

def run(self, vis: "ZnDraw", **kwargs) -> None:
if len(vis) > vis.step + 1:
Expand All @@ -209,8 +204,12 @@ def run(self, vis: "ZnDraw", **kwargs) -> None:
for point in vis.points:
atoms += ase.Atom(self.symbol.name, position=point)

for _ in range(self.steps):
vis.append(atoms)
del atoms.arrays["colors"]
del atoms.arrays["radii"]
if hasattr(atoms, "connectivity"):
del atoms.connectivity

vis.append(atoms)


class Wrap(UpdateScene):
Expand Down
96 changes: 96 additions & 0 deletions zndraw/queue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import logging
import typing as t

import socketio.exceptions
import znsocket.exceptions

from zndraw.base import Extension

if t.TYPE_CHECKING:
from zndraw import ZnDraw


log = logging.getLogger(__name__)
TASK_RUNNING = "ZNDRAW TASK IS RUNNING"


def check_queue(vis: "ZnDraw") -> None:
"""Main loop to check and process modifier tasks for both private and public queues."""
while True:
if not vis._modifiers:
vis.socket.sleep(1)
continue
try:
process_modifier_queue(vis)
process_public_queue(vis)
vis.socket.sleep(1)
except (znsocket.exceptions.ZnSocketError, socketio.exceptions.SocketIOError):
log.warning("Connection to ZnDraw server lost. Reconnecting...")
vis.socket.disconnect()
vis.socket.sleep(1)


def process_modifier_queue(vis: "ZnDraw") -> None:
"""Process private modifier tasks in the queue."""
modifier_queue = znsocket.Dict(
r=vis.r,
socket=vis._refresh_client,
key=f"queue:{vis.token}:modifier",
)

for key in modifier_queue:
if key in vis._modifiers:
try:
task = modifier_queue.pop(key)
cls = vis._modifiers[key]["cls"]
run_kwargs = vis._modifiers[key]["run_kwargs"]
run_queued_task(vis, cls, task, modifier_queue, run_kwargs)
except IndexError:
pass


def process_public_queue(vis: "ZnDraw") -> None:
"""Process public modifier tasks in the public queue."""
if not any(mod["public"] for mod in vis._modifiers.values()):
return

public_queue = znsocket.Dict(
r=vis.r,
socket=vis._refresh_client,
key="queue:default:modifier",
)

for room, room_queue in public_queue.items():
for key in room_queue:
if key in vis._modifiers and vis._modifiers[key]["public"]:
new_vis = ZnDraw(url=vis.url, token=room, r=vis.r)
try:
task = room_queue.pop(key)
# run_queued_task(new_vis, key, task, room_queue)
cls = vis._modifiers[key]["cls"]
run_kwargs = vis._modifiers[key]["run_kwargs"]
run_queued_task(new_vis, cls, task, room_queue, run_kwargs)
except IndexError:
pass
finally:
new_vis.socket.sleep(1)
new_vis.socket.disconnect()


def run_queued_task(
vis: "ZnDraw",
cls: t.Type[Extension],
task: dict,
queue: znsocket.Dict,
run_kwargs: dict | None = None,
) -> None:
"""Run a specific task and handle exceptions."""
if not run_kwargs:
run_kwargs = {}
try:
queue[TASK_RUNNING] = True
cls(**task).run(vis, **run_kwargs)
except Exception as err:
vis.log(f"Error running `{cls}`: `{err}`")
finally:
queue.pop(TASK_RUNNING)
17 changes: 4 additions & 13 deletions zndraw/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@
from zndraw.config import PathTracer, Scene
from zndraw.draw import geometries
from zndraw.modify import modifier
from zndraw.queue import run_queued_task
from zndraw.selection import selections
from zndraw.utils import load_plots_to_dict

log = logging.getLogger(__name__)

TASK_RUNNING = "ZNDRAW TASK IS RUNNING"


def _get_default_generator(file_io: FileIO) -> t.Iterable[ase.Atoms]:
return [ase.Atoms()]
Expand Down Expand Up @@ -363,9 +362,7 @@ def run_room_worker(room):
if key in selections:
try:
task = selection_queue.pop(key)
selection_queue[TASK_RUNNING] = True
selections[key](**task).run(vis)
selection_queue.pop(TASK_RUNNING)
run_queued_task(vis, selections[key], task, selection_queue)
except IndexError:
pass

Expand All @@ -380,11 +377,7 @@ def run_room_worker(room):
try:
task = analysis_queue.pop(key)
try:
analysis_queue[TASK_RUNNING] = True
analyses[key](**task).run(vis)
analysis_queue.pop(
TASK_RUNNING
) # TODO: does this cause an error when trying to stop on the GUI
run_queued_task(vis, analyses[key], task, analysis_queue)
except Exception as err:
vis.log(f"Error running analysis `{key}`: {err}")
except IndexError:
Expand Down Expand Up @@ -415,9 +408,7 @@ def run_room_worker(room):
try:
task = modifier_queue.pop(key)
try:
modifier_queue[TASK_RUNNING] = True
modifier[key](**task).run(vis)
modifier_queue.pop(TASK_RUNNING)
run_queued_task(vis, modifier[key], task, modifier_queue)
except Exception as err:
vis.log(f"Error running modifier `{key}`: {err}")
except IndexError:
Expand Down
Loading

0 comments on commit b592bff

Please sign in to comment.