Skip to content

Commit

Permalink
Use pycrdt instead of y-py
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Oct 26, 2023
1 parent b8f7390 commit d611f2b
Show file tree
Hide file tree
Showing 9 changed files with 66 additions and 62 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ classifiers = [
dependencies = [
"anyio >=3.6.2,<5",
"aiosqlite >=0.18.0,<1",
"y-py >=0.6.0,<0.7.0",
"pycrdt >=0.3.4,<0.4.0",
"typing_extensions; python_version < '3.8'",
]

Expand Down
14 changes: 7 additions & 7 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
import subprocess

import pytest
import y_py as Y
from pycrdt import Array, Doc
from websockets import serve # type: ignore

from ypy_websocket import WebsocketServer


class TestYDoc:
def __init__(self):
self.ydoc = Y.YDoc()
self.array = self.ydoc.get_array("array")
self.ydoc = Doc()
self.array = Array()
self.ydoc["array"] = self.array
self.state = None
self.value = 0

def update(self):
with self.ydoc.begin_transaction() as txn:
self.array.append(txn, self.value)
self.array.append(self.value)
self.value += 1
update = Y.encode_state_as_update(self.ydoc, self.state)
self.state = Y.encode_state_vector(self.ydoc)
update = self.ydoc.get_update(self.state)
self.state = self.ydoc.get_state()
return update


Expand Down
17 changes: 9 additions & 8 deletions tests/test_asgi.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import uvicorn
import y_py as Y
from anyio import create_task_group, sleep
from pycrdt import Doc, Map
from websockets import connect # type: ignore

from ypy_websocket import ASGIServer, WebsocketProvider, WebsocketServer
Expand All @@ -22,23 +22,24 @@ async def test_asgi(unused_tcp_port):

# clients
# client 1
ydoc1 = Y.YDoc()
ymap1 = ydoc1.get_map("map")
with ydoc1.begin_transaction() as t:
ymap1.set(t, "key", "value")
ydoc1 = Doc()
ymap1 = Map()
ydoc1["map"] = ymap1
ymap1["key"] = "value"
async with connect(
f"ws://localhost:{unused_tcp_port}/my-roomname"
) as websocket1, WebsocketProvider(ydoc1, websocket1):
await sleep(0.1)

# client 2
ydoc2 = Y.YDoc()
ydoc2 = Doc()
async with connect(
f"ws://localhost:{unused_tcp_port}/my-roomname"
) as websocket2, WebsocketProvider(ydoc2, websocket2):
await sleep(0.1)

ymap2 = ydoc2.get_map("map")
assert ymap2.to_json() == '{"key":"value"}'
ymap2 = Map()
ydoc2["map"] = ymap2
assert str(ymap2) == '{"key":"value"}'

tg.cancel_scope.cancel()
30 changes: 16 additions & 14 deletions tests/test_ypy_yjs.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,31 @@
import json

import pytest
import y_py as Y
from anyio import Event, create_task_group, move_on_after, sleep
from pycrdt import Array, Doc, Map
from websockets import connect # type: ignore

from ypy_websocket import WebsocketProvider


class YTest:
def __init__(self, ydoc: Y.YDoc, timeout: float = 1.0):
def __init__(self, ydoc: Doc, timeout: float = 1.0):
self.ydoc = ydoc
self.timeout = timeout
self.ytest = ydoc.get_map("_test")
self.ytest = Map()
self.ydoc["_test"] = self.ytest
self.clock = -1.0

def run_clock(self):
self.clock = max(self.clock, 0.0)
with self.ydoc.begin_transaction() as t:
self.ytest.set(t, "clock", self.clock)
self.ytest["clock"] = self.clock

async def clock_run(self):
change = Event()

def callback(event):
if "clock" in event.keys:
clk = self.ytest["clock"]
clk = event.keys["clock"]["newValue"]
if clk > self.clock:
self.clock = clk + 1.0
change.set()
Expand All @@ -41,16 +41,16 @@ def callback(event):
@pytest.mark.anyio
@pytest.mark.parametrize("yjs_client", "0", indirect=True)
async def test_ypy_yjs_0(yws_server, yjs_client):
ydoc = Y.YDoc()
ydoc = Doc()
ytest = YTest(ydoc)
async with connect("ws://127.0.0.1:1234/my-roomname") as websocket, WebsocketProvider(
ydoc, websocket
):
ymap = ydoc.get_map("map")
ymap = Map()
ydoc["map"] = ymap
# set a value in "in"
for v_in in range(10):
with ydoc.begin_transaction() as t:
ymap.set(t, "in", float(v_in))
ymap["in"] = float(v_in)
ytest.run_clock()
await ytest.clock_run()
v_out = ymap["out"]
Expand All @@ -73,7 +73,9 @@ async def test_ypy_yjs_1(yws_server, yjs_client):
ytest = YTest(ydoc)
ytest.run_clock()
await ytest.clock_run()
ycells = ydoc.get_array("cells")
ystate = ydoc.get_map("state")
assert json.loads(ycells.to_json()) == [{"metadata": {"foo": "bar"}, "source": "1 + 2"}]
assert json.loads(ystate.to_json()) == {"state": {"dirty": False}}
ycells = Array()
ystate = Map()
ydoc["cells"] = ycells
ydoc["state"] = ystate
assert json.loads(str(ycells)) == [{"metadata": {"foo": "bar"}, "source": "1 + 2"}]
assert json.loads(str(ystate)) == {"state": {"dirty": False}}
16 changes: 8 additions & 8 deletions ypy_websocket/django_channels_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from logging import getLogger
from typing import TypedDict

import y_py as Y
from channels.generic.websocket import AsyncWebsocketConsumer # type: ignore
from pycrdt import Doc

from .websocket import Websocket
from .yutils import YMessageType, process_sync_message, sync
Expand Down Expand Up @@ -79,7 +79,7 @@ class YjsConsumer(AsyncWebsocketConsumer):
A full example of a custom consumer showcasing all of these options is:
```py
import y_py as Y
from pycrdt import Doc
from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from ypy_websocket.django_channels_consumer import YjsConsumer
Expand All @@ -91,10 +91,10 @@ def make_room_name(self) -> str:
# modify the room name here
return self.scope["url_route"]["kwargs"]["room"]
async def make_ydoc(self) -> Y.YDoc:
doc = Y.YDoc()
async def make_ydoc(self) -> Doc:
doc = Doc()
# fill doc with data from DB here
doc.observe_after_transaction(self.on_update_event)
doc.observe(self.on_update_event)
return doc
async def connect(self):
Expand All @@ -110,7 +110,7 @@ def on_update_event(self, event):
async def doc_update(self, update_wrapper):
update = update_wrapper["update"]
Y.apply_update(self.ydoc, update)
self.ydoc.apply_update(update)
await self.group_send_message(create_update_message(update))
Expand All @@ -137,7 +137,7 @@ def make_room_name(self) -> str:
"""
return self.scope["url_route"]["kwargs"]["room"]

async def make_ydoc(self) -> Y.YDoc:
async def make_ydoc(self) -> Doc:
"""Make the YDoc for a new channel.
Override to customize the YDoc when a channel is created
Expand All @@ -146,7 +146,7 @@ async def make_ydoc(self) -> Y.YDoc:
Returns:
The YDoc for a new channel. Defaults to a new empty YDoc.
"""
return Y.YDoc()
return Doc()

def _make_websocket_shim(self, path: str) -> _WebsocketShim:
return _WebsocketShim(path, self.group_send_message)
Expand Down
8 changes: 4 additions & 4 deletions ypy_websocket/websocket_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from functools import partial
from logging import Logger, getLogger

import y_py as Y
from anyio import (
TASK_STATUS_IGNORED,
Event,
Expand All @@ -13,6 +12,7 @@
)
from anyio.abc import TaskGroup, TaskStatus
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pycrdt import Doc

from .websocket import Websocket
from .yutils import (
Expand All @@ -27,14 +27,14 @@
class WebsocketProvider:
"""WebSocket provider."""

_ydoc: Y.YDoc
_ydoc: Doc
_update_send_stream: MemoryObjectSendStream
_update_receive_stream: MemoryObjectReceiveStream
_started: Event | None
_starting: bool
_task_group: TaskGroup | None

def __init__(self, ydoc: Y.YDoc, websocket: Websocket, log: Logger | None = None) -> None:
def __init__(self, ydoc: Doc, websocket: Websocket, log: Logger | None = None) -> None:
"""Initialize the object.
The WebsocketProvider instance should preferably be used as an async context manager:
Expand Down Expand Up @@ -64,7 +64,7 @@ def __init__(self, ydoc: Y.YDoc, websocket: Websocket, log: Logger | None = None
self._started = None
self._starting = False
self._task_group = None
ydoc.observe_after_transaction(partial(put_updates, self._update_send_stream))
ydoc.observe(partial(put_updates, self._update_send_stream))

@property
def started(self) -> Event:
Expand Down
8 changes: 4 additions & 4 deletions ypy_websocket/yroom.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from logging import Logger, getLogger
from typing import Awaitable, Callable

import y_py as Y
from anyio import (
TASK_STATUS_IGNORED,
Event,
Expand All @@ -15,6 +14,7 @@
)
from anyio.abc import TaskGroup, TaskStatus
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pycrdt import Doc

from .awareness import Awareness
from .websocket import Websocket
Expand All @@ -31,7 +31,7 @@
class YRoom:

clients: list
ydoc: Y.YDoc
ydoc: Doc
ystore: BaseYStore | None
_on_message: Callable[[bytes], Awaitable[bool] | bool] | None
_update_send_stream: MemoryObjectSendStream
Expand Down Expand Up @@ -64,7 +64,7 @@ def __init__(
ystore: An optional store in which to persist document updates.
log: An optional logger.
"""
self.ydoc = Y.YDoc()
self.ydoc = Doc()
self.awareness = Awareness(self.ydoc)
self._update_send_stream, self._update_receive_stream = create_memory_object_stream(
max_buffer_size=65536
Expand Down Expand Up @@ -101,7 +101,7 @@ def ready(self, value: bool) -> None:
value: True if the internal YDoc is ready to be synchronized, False otherwise."""
self._ready = value
if value:
self.ydoc.observe_after_transaction(partial(put_updates, self._update_send_stream))
self.ydoc.observe(partial(put_updates, self._update_send_stream))

@property
def on_message(self) -> Callable[[bytes], Awaitable[bool] | bool] | None:
Expand Down
16 changes: 8 additions & 8 deletions ypy_websocket/ystore.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@

import aiosqlite
import anyio
import y_py as Y
from anyio import TASK_STATUS_IGNORED, Event, Lock, create_task_group
from anyio.abc import TaskGroup, TaskStatus
from pycrdt import Doc

from .yutils import Decoder, get_new_path, write_var_uint

Expand Down Expand Up @@ -111,23 +111,23 @@ async def get_metadata(self) -> bytes:
metadata = cast(bytes, metadata)
return metadata

async def encode_state_as_update(self, ydoc: Y.YDoc) -> None:
async def encode_state_as_update(self, ydoc: Doc) -> None:
"""Store a YDoc state.
Arguments:
ydoc: The YDoc from which to store the state.
"""
update = Y.encode_state_as_update(ydoc) # type: ignore
update = ydoc.get_update()
await self.write(update)

async def apply_updates(self, ydoc: Y.YDoc) -> None:
async def apply_updates(self, ydoc: Doc) -> None:
"""Apply all stored updates to the YDoc.
Arguments:
ydoc: The YDoc on which to apply the updates.
"""
async for update, *rest in self.read(): # type: ignore
Y.apply_update(ydoc, update) # type: ignore
ydoc.apply_update(update)


class FileYStore(BaseYStore):
Expand Down Expand Up @@ -421,16 +421,16 @@ async def write(self, data: bytes) -> None:

if self.document_ttl is not None and diff > self.document_ttl:
# squash updates
ydoc = Y.YDoc()
ydoc = Doc()
async with db.execute(
"SELECT yupdate FROM yupdates WHERE path = ?", (self.path,)
) as cursor:
async for update, in cursor:
Y.apply_update(ydoc, update)
ydoc.apply_update(update)
# delete history
await db.execute("DELETE FROM yupdates WHERE path = ?", (self.path,))
# insert squashed updates
squashed_update = Y.encode_state_as_update(ydoc)
squashed_update = ydoc.get_update()
metadata = await self.get_metadata()
await db.execute(
"INSERT INTO yupdates VALUES (?, ?, ?, ?)",
Expand Down
Loading

0 comments on commit d611f2b

Please sign in to comment.