Skip to content

Commit

Permalink
move to zndraw.ZnDraw in seperate process
Browse files Browse the repository at this point in the history
  • Loading branch information
PythonFZ committed Oct 3, 2023
1 parent 77b6614 commit 0e799d4
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 125 deletions.
141 changes: 35 additions & 106 deletions zndraw/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
import ase
import numpy as np
import tqdm
from flask import Flask, render_template, session
from flask import Flask, render_template, session, request
from flask_socketio import SocketIO, emit

from zndraw.data import atoms_from_json, atoms_to_json
from zndraw.draw import Geometry
from zndraw.select import get_selection_class
from zndraw.settings import GlobalConfig
from zndraw.zndraw import ZnDraw
from zndraw.zndraw import ZnDraw, FileIO

app = Flask(__name__)
app.config["SECRET_KEY"] = str(uuid.uuid4())
Expand All @@ -27,14 +27,27 @@
@app.route("/")
def index():
"""Render the main ZnDraw page."""
session["uuid"] = str(uuid.uuid4())
if "upgrade_insecure_requests" in app.config:
return render_template(
"index.html",
upgrade_insecure_requests=app.config["upgrade_insecure_requests"],
uuid=session["uuid"],
if "uuid" not in session:
session["uuid"] = str(uuid.uuid4())
proc = mp.Process(
target=ZnDraw,
kwargs={
"url": request.url_root,
"file": FileIO(
name=app.config["filename"],
start=app.config["start"],
stop=app.config["stop"],
step=app.config["step"],
),
},
)
return render_template("index.html")
proc.start()

return render_template(
"index.html",
upgrade_insecure_requests=app.config["upgrade_insecure_requests"],
uuid=session["uuid"],
)


@app.route("/exit")
Expand All @@ -45,50 +58,10 @@ def exit_route():
return "Server shutting down..."


def _read_file(filename, start, stop, step, compute_bonds, url=None):
if url is None:
if compute_bonds:
instance = ZnDraw(socket=io, display_new=False)
else:
instance = ZnDraw(socket=io, display_new=False, bonds_calculator=None)
else:
if compute_bonds:
instance = ZnDraw(url=url, display_new=False)
else:
instance = ZnDraw(url=url, display_new=False, bonds_calculator=None)

instance.read(filename, start, stop, step)


@io.on("atoms:request")
def atoms_request(url):
"""Return the atoms."""

if "filename" in app.config:
if app.config["multiprocessing"]:
proc = mp.Process(
target=_read_file,
args=(
app.config["filename"],
app.config["start"],
app.config["stop"],
app.config["step"],
app.config["compute_bonds"],
url,
),
)
proc.start()
else:
io.start_background_task(
target=_read_file,
filename=app.config["filename"],
start=app.config["start"],
stop=app.config["stop"],
step=app.config["step"],
compute_bonds=app.config["compute_bonds"],
)
else:
emit("atoms:upload", {0: atoms_to_json(ase.Atoms())})
emit("atoms:request", url, broadcast=True, include_self=False)


@io.on("modifier:schema")
Expand All @@ -114,39 +87,7 @@ def modifier_schema():

@io.on("modifier:run")
def modifier_run(data):
import ase

points = np.array([[val["x"], val["y"], val["z"]] for val in data["points"]])
segments = np.array(data["segments"])

if "atoms" in data:
atoms = atoms_from_json(data["atoms"])
else:
atoms = ase.Atoms()

module_name, function_name = data["name"].rsplit(".", 1)
module = importlib.import_module(module_name)
modifier_cls = getattr(module, function_name)
modifier = modifier_cls(**data["params"])
# available_methods = {x.__name__: x for x in [Explode, Duplicate]}

# modifier = available_methods[data["name"]](**data["params"])
print(f"modifier:run {modifier = }")
atoms_list = modifier.run(
atom_ids=data["selection"],
atoms=atoms,
points=points,
segments=segments,
json_data=data["atoms"] if "atoms" in data else None,
url=data["url"],
)
io.emit("atoms:clear", int(data["step"]) + 1)
for idx, atoms in tqdm.tqdm(enumerate(atoms_list)):
atoms_dict = atoms_to_json(atoms)
io.emit("atoms:upload", {idx + 1 + int(data["step"]): atoms_dict})

io.emit("view:set", int(data["step"]) + 1)
io.emit("view:play")
emit("modifier:run", data, broadcast=True, include_self=False)


@io.on("analysis:schema")
Expand Down Expand Up @@ -175,34 +116,13 @@ def selection_schema():

@io.on("selection:run")
def selection_run(data):
import ase

if "atoms" in data:
atoms = atoms_from_json(data["atoms"])
else:
atoms = ase.Atoms()

try:
selection = get_selection_class()(**data["params"])
selected_ids = selection.get_ids(atoms, data["selection"])
io.emit("selection:run", selected_ids)
except ValueError as err:
print(err)
emit("selection:run", data, broadcast=True, include_self=False)


@io.on("analysis:run")
def analysis_run(data):
atoms_list = [atoms_from_json(x) for x in data["atoms_list"].values()]
emit("analysis:run", data, broadcast=True, include_self=False)

print(f"Analysing {len(atoms_list)} frames")

module_name, function_name = data["name"].rsplit(".", 1)
module = importlib.import_module(module_name)
cls = getattr(module, function_name)
instance = cls(**data["params"])

fig = instance.run(atoms_list, data["selection"])
return fig.to_json()


@io.on("config")
Expand Down Expand Up @@ -363,6 +283,15 @@ def selection_get(data):
emit("selection:get", data, broadcast=True, include_self=False)


@io.on("selection:set")
def selection_get(data):
emit("selection:set", data, broadcast=True, include_self=False)


@io.on("draw:get_line")
def draw_points(data):
emit("draw:get_line", data, broadcast=True, include_self=False)

@io.on("analysis:figure")
def analysis_figure(data):
emit("analysis:figure", data, broadcast=True, include_self=False)
34 changes: 17 additions & 17 deletions zndraw/static/UI/json_editor.js
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,22 @@ function analysis_editor(socket, cache, world) {
console.log(new Date().toISOString(), "running analysis");
const value = editor.getValue();

socket.on("analysis:figure", (data) => {
Plotly.newPlot("analysisPlot", JSON.parse(data));

function buildPlot() {
Plotly.newPlot("analysisPlot", JSON.parse(data));
const myplot = document.getElementById("analysisPlot");
myplot.on("plotly_click", (data) => {
const point = data.points[0];
const step = point.x;
world.setStep(step);
});
}

buildPlot();
});

socket.emit(
"analysis:run",
{
Expand All @@ -123,23 +139,7 @@ function analysis_editor(socket, cache, world) {
atoms: cache.get(world.getStep()),
selection: world.getSelection(),
step: world.getStep(),
atoms_list: cache.getAllAtoms(),
},
(data) => {
Plotly.newPlot("analysisPlot", JSON.parse(data));

function buildPlot() {
Plotly.newPlot("analysisPlot", JSON.parse(data));
const myplot = document.getElementById("analysisPlot");
myplot.on("plotly_click", (data) => {
const point = data.points[0];
const step = point.x;
world.setStep(step);
});
}

buildPlot();
},
}
);

document.getElementById("analysis-json-editor-submit").disabled = true;
Expand Down
2 changes: 1 addition & 1 deletion zndraw/static/World/systems/select.js
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class Selection {
return particlesGroup.get_center();
};

this.socket.on("selection:run", (data) => {
this.socket.on("selection:set", (data) => {
const particlesGroup = this.scene.getObjectByName("particlesGroup");
particlesGroup.selection = data;
particlesGroup.step();
Expand Down
85 changes: 84 additions & 1 deletion zndraw/zndraw.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from zndraw.bonds import ASEComputeBonds
from zndraw.data import atoms_from_json, atoms_to_json
from zndraw.utils import ZnDrawLoggingHandler, get_port
from zndraw.select import get_selection_class


def _await_answer(socket, channel, data=None, timeout=5):
Expand All @@ -41,6 +42,14 @@ def on_answer(data):
return answer


@dataclasses.dataclass
class FileIO:
name: str
start: int = 0
stop: int = None
step: int = 1


@dataclasses.dataclass
class ZnDraw(collections.abc.MutableSequence):
url: str = None
Expand All @@ -49,6 +58,7 @@ class ZnDraw(collections.abc.MutableSequence):
bonds_calculator: ASEComputeBonds = dataclasses.field(
default_factory=ASEComputeBonds
)
file: FileIO = None

display_new: bool = True
_retries: int = 5
Expand Down Expand Up @@ -85,6 +95,16 @@ def __post_init__(self):
"connect", lambda: print(f"Connected to ZnDraw server at {self.url}")
)

self.socket.on(
"atoms:request",
lambda url: self.read(
self.file.name, self.file.start, self.file.stop, self.file.step
),
)
self.socket.on("modifier:run", self._run_modifier)
self.socket.on("selection:run", self._run_selection)
self.socket.on("analysis:run", self._run_analysis)

self.socket.on("disconnect", lambda: self.disconnect())

for _ in range(self._retries):
Expand Down Expand Up @@ -206,7 +226,7 @@ def log(self, message: str) -> None:
def get_logging_handler(self) -> ZnDrawLoggingHandler:
return ZnDrawLoggingHandler(self.socket)

def read(self, filename: str, start: int, stop: int, step: int):
def read(self, filename: str, start: int = 0, stop: int = None, step: int = 1):
"""Read atoms from file and return a list of atoms dicts.
Parameters
Expand Down Expand Up @@ -245,3 +265,66 @@ def get_line(self) -> tuple[np.ndarray, np.ndarray]:
segments = np.array(data["segments"])

return points, segments

def set_selection(self, selection: list[int]) -> None:
"""Set the selected atoms"""
self.socket.emit("selection:set", selection)

def _run_modifier(self, data):
import ase, importlib

points = np.array([[val["x"], val["y"], val["z"]] for val in data["points"]])
segments = np.array(data["segments"])

if "atoms" in data:
atoms = atoms_from_json(data["atoms"])
else:
atoms = ase.Atoms()

module_name, function_name = data["name"].rsplit(".", 1)
module = importlib.import_module(module_name)
modifier_cls = getattr(module, function_name)
modifier = modifier_cls(**data["params"])
# available_methods = {x.__name__: x for x in [Explode, Duplicate]}

# modifier = available_methods[data["name"]](**data["params"])
print(f"modifier:run {modifier = }")
atoms_list = modifier.run(
atom_ids=data["selection"],
atoms=atoms,
points=points,
segments=segments,
json_data=data["atoms"] if "atoms" in data else None,
url=data["url"],
)
del self[data["step"] :]
self.extend(atoms_list)

def _run_selection(self, data):
import ase

if "atoms" in data:
atoms = atoms_from_json(data["atoms"])
else:
atoms = ase.Atoms()

try:
selection = get_selection_class()(**data["params"])
selected_ids = selection.get_ids(atoms, data["selection"])
self.set_selection(selected_ids)
except ValueError as err:
print(err)

def _run_analysis(self, data):
import importlib
atoms_list = list(self)

print(f"Analysing {len(atoms_list)} frames")

module_name, function_name = data["name"].rsplit(".", 1)
module = importlib.import_module(module_name)
cls = getattr(module, function_name)
instance = cls(**data["params"])

fig = instance.run(atoms_list, data["selection"])
self.socket.emit("analysis:figure", fig.to_json())

0 comments on commit 0e799d4

Please sign in to comment.