Skip to content

Commit

Permalink
task referencer (#18914)
Browse files Browse the repository at this point in the history
* put otherwise unheld tasks into the pit

* and the rest too

* cull less often

* task referencer

* private

* report unexpectedly unreferenced tasks

* also warn for automated catching in tests

* reporting task

* undo some

* fixup for 3.9

* ban asyncio.create_task

* oof

* simplify including using `Task.add_done_callback()` for culling
  • Loading branch information
altendky authored Jan 2, 2025
1 parent 48fa660 commit 1ffa73e
Show file tree
Hide file tree
Showing 41 changed files with 229 additions and 158 deletions.
7 changes: 4 additions & 3 deletions benchmarks/mempool.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from chia.types.spend_bundle import SpendBundle
from chia.util.batches import to_batches
from chia.util.ints import uint32, uint64
from chia.util.task_referencer import create_referenced_task

NUM_ITERS = 200
NUM_PEERS = 5
Expand Down Expand Up @@ -189,7 +190,7 @@ async def add_spend_bundles(spend_bundles: list[SpendBundle]) -> None:
start = monotonic()
for peer in range(NUM_PEERS):
total_bundles += len(large_spend_bundles[peer])
tasks.append(asyncio.create_task(add_spend_bundles(large_spend_bundles[peer])))
tasks.append(create_referenced_task(add_spend_bundles(large_spend_bundles[peer])))
await asyncio.gather(*tasks)
stop = monotonic()
print(f" time: {stop - start:0.4f}s")
Expand All @@ -208,7 +209,7 @@ async def add_spend_bundles(spend_bundles: list[SpendBundle]) -> None:
start = monotonic()
for peer in range(NUM_PEERS):
total_bundles += len(spend_bundles[peer])
tasks.append(asyncio.create_task(add_spend_bundles(spend_bundles[peer])))
tasks.append(create_referenced_task(add_spend_bundles(spend_bundles[peer])))
await asyncio.gather(*tasks)
stop = monotonic()
print(f" time: {stop - start:0.4f}s")
Expand All @@ -221,7 +222,7 @@ async def add_spend_bundles(spend_bundles: list[SpendBundle]) -> None:
start = monotonic()
for peer in range(NUM_PEERS):
total_bundles += len(replacement_spend_bundles[peer])
tasks.append(asyncio.create_task(add_spend_bundles(replacement_spend_bundles[peer])))
tasks.append(create_referenced_task(add_spend_bundles(replacement_spend_bundles[peer])))
await asyncio.gather(*tasks)
stop = monotonic()
print(f" time: {stop - start:0.4f}s")
Expand Down
3 changes: 2 additions & 1 deletion chia/_tests/core/data_layer/test_data_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
from chia.util.hash import std_hash
from chia.util.ints import uint8, uint16, uint32, uint64
from chia.util.keychain import bytes_to_mnemonic
from chia.util.task_referencer import create_referenced_task
from chia.util.timing import adjusted_timeout, backoff_times
from chia.wallet.trading.offer import Offer as TradingOffer
from chia.wallet.transaction_record import TransactionRecord
Expand Down Expand Up @@ -2191,7 +2192,7 @@ async def test_issue_15955_deadlock(
while time.monotonic() < end:
with anyio.fail_after(adjusted_timeout(timeout)):
await asyncio.gather(
*(asyncio.create_task(data_layer.get_value(store_id=store_id, key=key)) for _ in range(10))
*(create_referenced_task(data_layer.get_value(store_id=store_id, key=key)) for _ in range(10))
)


Expand Down
5 changes: 3 additions & 2 deletions chia/_tests/core/farmer/test_farmer_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from asyncio import Task, create_task, gather, sleep
from asyncio import Task, gather, sleep
from collections.abc import Coroutine
from typing import Any, Optional, TypeVar

Expand All @@ -20,13 +20,14 @@
from chia.server.outbound_message import Message, NodeType
from chia.util.hash import std_hash
from chia.util.ints import uint8, uint32, uint64
from chia.util.task_referencer import create_referenced_task

T = TypeVar("T")


async def begin_task(coro: Coroutine[Any, Any, T]) -> Task[T]:
    """Awaitable function that adds a coroutine to the event loop and sets it running.

    Wraps *coro* in a tracked task via ``create_referenced_task`` (the project's
    replacement for ``asyncio.create_task`` that keeps a reference so the task is
    not garbage collected), then yields control once so the task actually begins
    executing before this helper returns.

    Returns the started :class:`asyncio.Task` for the caller to await or cancel.
    """
    task = create_referenced_task(coro)
    # Yield to the event loop once so the freshly scheduled task gets a chance
    # to start running before control returns to the caller.
    await sleep(0)

    return task
Expand Down
5 changes: 3 additions & 2 deletions chia/_tests/core/full_node/stores/test_block_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from chia.util.db_wrapper import get_host_parameter_limit
from chia.util.full_block_utils import GeneratorBlockInfo
from chia.util.ints import uint8, uint32, uint64
from chia.util.task_referencer import create_referenced_task

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -242,12 +243,12 @@ async def test_deadlock(tmp_dir: Path, db_version: int, bt: BlockTools, use_cach
rand_i = random.randint(0, 9)
if random.random() < 0.5:
tasks.append(
asyncio.create_task(
create_referenced_task(
store.add_full_block(blocks[rand_i].header_hash, blocks[rand_i], block_records[rand_i])
)
)
if random.random() < 0.5:
tasks.append(asyncio.create_task(store.get_full_block(blocks[rand_i].header_hash)))
tasks.append(create_referenced_task(store.get_full_block(blocks[rand_i].header_hash)))
await asyncio.gather(*tasks)


Expand Down
13 changes: 6 additions & 7 deletions chia/_tests/core/full_node/test_full_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
from chia.util.ints import uint8, uint16, uint32, uint64, uint128
from chia.util.limited_semaphore import LimitedSemaphore
from chia.util.recursive_replace import recursive_replace
from chia.util.task_referencer import create_referenced_task
from chia.util.vdf_prover import get_vdf_info_and_proof
from chia.wallet.util.tx_config import DEFAULT_TX_CONFIG
from chia.wallet.wallet_spend_bundle import WalletSpendBundle
Expand Down Expand Up @@ -807,13 +808,13 @@ async def test_new_peak(self, wallet_nodes, self_hostname):
uint32(0),
block.reward_chain_block.get_unfinished().get_hash(),
)
task_1 = asyncio.create_task(full_node_1.new_peak(new_peak, dummy_peer))
task_1 = create_referenced_task(full_node_1.new_peak(new_peak, dummy_peer))
await time_out_assert(10, time_out_messages(incoming_queue, "request_block", 1))
task_1.cancel()

await full_node_1.full_node.add_block(block, peer)
# Ignores, already have
task_2 = asyncio.create_task(full_node_1.new_peak(new_peak, dummy_peer))
task_2 = create_referenced_task(full_node_1.new_peak(new_peak, dummy_peer))
await time_out_assert(10, time_out_messages(incoming_queue, "request_block", 0))
task_2.cancel()

Expand All @@ -829,8 +830,7 @@ async def suppress_value_error(coro: Coroutine) -> None:
uint32(0),
blocks_reorg[-2].reward_chain_block.get_unfinished().get_hash(),
)
# TODO: stop dropping tasks on the floor
asyncio.create_task(suppress_value_error(full_node_1.new_peak(new_peak, dummy_peer))) # noqa: RUF006
create_referenced_task(suppress_value_error(full_node_1.new_peak(new_peak, dummy_peer)))
await time_out_assert(10, time_out_messages(incoming_queue, "request_block", 0))

# Does not ignore equal weight
Expand All @@ -841,8 +841,7 @@ async def suppress_value_error(coro: Coroutine) -> None:
uint32(0),
blocks_reorg[-1].reward_chain_block.get_unfinished().get_hash(),
)
# TODO: stop dropping tasks on the floor
asyncio.create_task(suppress_value_error(full_node_1.new_peak(new_peak, dummy_peer))) # noqa: RUF006
create_referenced_task(suppress_value_error(full_node_1.new_peak(new_peak, dummy_peer)))
await time_out_assert(10, time_out_messages(incoming_queue, "request_block", 1))

@pytest.mark.anyio
Expand Down Expand Up @@ -1568,7 +1567,7 @@ async def test_double_blocks_same_pospace(self, wallet_nodes, self_hostname):
block_2 = recursive_replace(block_2, "foliage.foliage_transaction_block_signature", new_fbh_sig)
block_2 = recursive_replace(block_2, "transactions_generator", None)

rb_task = asyncio.create_task(full_node_2.full_node.add_block(block_2, dummy_peer))
rb_task = create_referenced_task(full_node_2.full_node.add_block(block_2, dummy_peer))

await time_out_assert(10, time_out_messages(incoming_queue, "request_block", 1))
rb_task.cancel()
Expand Down
3 changes: 2 additions & 1 deletion chia/_tests/core/full_node/test_tx_processing_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from chia.full_node.tx_processing_queue import TransactionQueue, TransactionQueueFull
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.transaction_queue_entry import TransactionQueueEntry
from chia.util.task_referencer import create_referenced_task

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -76,7 +77,7 @@ async def test_one_peer_and_await(seeded_random: random.Random) -> None:
assert list_txs[i - 20] == resulting_txs[i]

# now we validate that the pop command is blocking
task = asyncio.create_task(transaction_queue.pop())
task = create_referenced_task(transaction_queue.pop())
with pytest.raises(asyncio.InvalidStateError): # task is not done, so we expect an error when getting result
task.result()
# add a tx to test task completion
Expand Down
5 changes: 3 additions & 2 deletions chia/_tests/core/server/flood.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import time

from chia._tests.util.misc import create_logger
from chia.util.task_referencer import create_referenced_task

# TODO: CAMPid 0945094189459712842390t591
IP = "127.0.0.1"
Expand Down Expand Up @@ -62,15 +63,15 @@ async def dun() -> None:

task.cancel()

file_task = asyncio.create_task(dun())
file_task = create_referenced_task(dun())

with out_path.open(mode="w") as file:
logger = create_logger(file=file)

async def f() -> None:
await asyncio.gather(*[tcp_echo_client(task_counter=f"{i}", logger=logger) for i in range(0, NUM_CLIENTS)])

task = asyncio.create_task(f())
task = create_referenced_task(f())
try:
await task
except asyncio.CancelledError:
Expand Down
3 changes: 2 additions & 1 deletion chia/_tests/core/server/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from chia._tests.util.misc import create_logger
from chia.server.chia_policy import ChiaPolicy
from chia.server.start_service import async_run
from chia.util.task_referencer import create_referenced_task

if sys.platform == "win32":
import _winapi
Expand Down Expand Up @@ -86,7 +87,7 @@ async def dun() -> None:

thread_end_event.set()

file_task = asyncio.create_task(dun())
file_task = create_referenced_task(dun())

loop = asyncio.get_event_loop()
server = await loop.create_server(functools.partial(EchoServer, logger=logger), ip, port)
Expand Down
3 changes: 2 additions & 1 deletion chia/_tests/core/server/test_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from chia._tests.core.server import serve
from chia._tests.util.misc import create_logger
from chia.server import chia_policy
from chia.util.task_referencer import create_referenced_task
from chia.util.timing import adjusted_timeout

here = pathlib.Path(__file__).parent
Expand Down Expand Up @@ -123,7 +124,7 @@ def _run(self) -> None:
asyncio.set_event_loop_policy(original_event_loop_policy)

async def main(self) -> None:
self.server_task = asyncio.create_task(
self.server_task = create_referenced_task(
serve.async_main(
out_path=self.out_path,
ip=self.ip,
Expand Down
13 changes: 7 additions & 6 deletions chia/_tests/db/test_db_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from chia._tests.util.db_connection import DBConnection, PathDBConnection
from chia._tests.util.misc import Marks, boolean_datacases, datacases
from chia.util.db_wrapper import DBWrapper2, ForeignKeyError, InternalError, NestedForeignKeyDelayedRequestError
from chia.util.task_referencer import create_referenced_task

if TYPE_CHECKING:
ConnectionContextManager = contextlib.AbstractAsyncContextManager[aiosqlite.core.Connection]
Expand Down Expand Up @@ -119,7 +120,7 @@ async def test_concurrent_writers(acquire_outside: bool, get_reader_method: GetR

tasks = []
for index in range(concurrent_task_count):
task = asyncio.create_task(increment_counter(db_wrapper))
task = create_referenced_task(increment_counter(db_wrapper))
tasks.append(task)

await asyncio.wait_for(asyncio.gather(*tasks), timeout=None)
Expand Down Expand Up @@ -263,7 +264,7 @@ async def write() -> None:
async with get_reader() as reader:
assert await query_value(connection=reader) == 0

task = asyncio.create_task(write())
task = create_referenced_task(write())
await writer_committed.wait()

assert await query_value(connection=reader) == 0 if transactioned else 1
Expand Down Expand Up @@ -342,7 +343,7 @@ async def test_concurrent_readers(acquire_outside: bool, get_reader_method: GetR
tasks = []
values: list[int] = []
for index in range(concurrent_task_count):
task = asyncio.create_task(sum_counter(db_wrapper, values))
task = create_referenced_task(sum_counter(db_wrapper, values))
tasks.append(task)

await asyncio.wait_for(asyncio.gather(*tasks), timeout=None)
Expand Down Expand Up @@ -371,11 +372,11 @@ async def test_mixed_readers_writers(acquire_outside: bool, get_reader_method: G
tasks = []
values: list[int] = []
for index in range(concurrent_task_count):
task = asyncio.create_task(increment_counter(db_wrapper))
task = create_referenced_task(increment_counter(db_wrapper))
tasks.append(task)
task = asyncio.create_task(decrement_counter(db_wrapper))
task = create_referenced_task(decrement_counter(db_wrapper))
tasks.append(task)
task = asyncio.create_task(sum_counter(db_wrapper, values))
task = create_referenced_task(sum_counter(db_wrapper, values))
tasks.append(task)

await asyncio.wait_for(asyncio.gather(*tasks), timeout=None)
Expand Down
7 changes: 4 additions & 3 deletions chia/_tests/util/test_limited_semaphore.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

from chia.util.limited_semaphore import LimitedSemaphore, LimitedSemaphoreFullError
from chia.util.task_referencer import create_referenced_task


@pytest.mark.anyio
Expand All @@ -27,16 +28,16 @@ async def acquire(entered_event: Optional[asyncio.Event] = None) -> None:
waiting_events = [asyncio.Event() for _ in range(waiting_limit)]
failed_events = [asyncio.Event() for _ in range(beyond_limit)]

entered_tasks = [asyncio.create_task(acquire(entered_event=event)) for event in entered_events]
waiting_tasks = [asyncio.create_task(acquire(entered_event=event)) for event in waiting_events]
entered_tasks = [create_referenced_task(acquire(entered_event=event)) for event in entered_events]
waiting_tasks = [create_referenced_task(acquire(entered_event=event)) for event in waiting_events]

await asyncio.gather(*(event.wait() for event in entered_events))
assert all(event.is_set() for event in entered_events)
assert all(not event.is_set() for event in waiting_events)

assert semaphore._available_count == 0

failure_tasks = [asyncio.create_task(acquire()) for _ in range(beyond_limit)]
failure_tasks = [create_referenced_task(acquire()) for _ in range(beyond_limit)]

failure_results = await asyncio.gather(*failure_tasks, return_exceptions=True)
assert [str(error) for error in failure_results] == [str(LimitedSemaphoreFullError())] * beyond_limit
Expand Down
17 changes: 9 additions & 8 deletions chia/_tests/util/test_priority_mutex.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from chia._tests.util.misc import Marks, datacases
from chia._tests.util.time_out_assert import time_out_assert_custom_interval
from chia.util.priority_mutex import NestedLockUnsupportedError, PriorityMutex
from chia.util.task_referencer import create_referenced_task
from chia.util.timing import adjusted_timeout

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -65,10 +66,10 @@ async def do_low(i: int) -> None:
log.warning(f"Spend {time.time() - t1} waiting for low {i}")
await kind_of_slow_func()

h = asyncio.create_task(do_high())
h = create_referenced_task(do_high())
l_tasks = []
for i in range(50):
l_tasks.append(asyncio.create_task(do_low(i)))
l_tasks.append(create_referenced_task(do_low(i)))

winner = None

Expand Down Expand Up @@ -334,13 +335,13 @@ async def queued_after() -> None:
async with mutex.acquire(priority=MutexPriority.high):
pass

block_task = asyncio.create_task(block())
block_task = create_referenced_task(block())
await blocker_acquired_event.wait()

cancel_task = asyncio.create_task(to_be_cancelled(mutex=mutex))
cancel_task = create_referenced_task(to_be_cancelled(mutex=mutex))
await wait_queued(mutex=mutex, task=cancel_task)

queued_after_task = asyncio.create_task(queued_after())
queued_after_task = create_referenced_task(queued_after())
await wait_queued(mutex=mutex, task=queued_after_task)

cancel_task.cancel()
Expand Down Expand Up @@ -441,7 +442,7 @@ async def create_acquire_tasks_in_controlled_order(
release_event = asyncio.Event()

for request in requests:
task = asyncio.create_task(request.acquire(mutex=mutex, wait_for=release_event))
task = create_referenced_task(request.acquire(mutex=mutex, wait_for=release_event))
tasks.append(task)
await wait_queued(mutex=mutex, task=task)

Expand All @@ -461,14 +462,14 @@ async def other_task_function() -> None:
await other_task_allow_release_event.wait()

async with mutex.acquire(priority=MutexPriority.high):
other_task = asyncio.create_task(other_task_function())
other_task = create_referenced_task(other_task_function())
await wait_queued(mutex=mutex, task=other_task)

async def another_task_function() -> None:
async with mutex.acquire(priority=MutexPriority.high):
pass

another_task = asyncio.create_task(another_task_function())
another_task = create_referenced_task(another_task_function())
await wait_queued(mutex=mutex, task=another_task)
other_task_allow_release_event.set()

Expand Down
7 changes: 3 additions & 4 deletions chia/daemon/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from chia.util.ints import uint32
from chia.util.json_util import dict_to_json_str
from chia.util.task_referencer import create_referenced_task
from chia.util.ws_message import WsRpcMessage, create_payload_dict


Expand Down Expand Up @@ -67,8 +68,7 @@ async def listener_task() -> None:
finally:
await self.close()

# TODO: stop dropping tasks on the floor
asyncio.create_task(listener_task()) # noqa: RUF006
create_referenced_task(listener_task(), known_unreferenced=True)
await asyncio.sleep(1)

async def listener(self) -> None:
Expand All @@ -92,8 +92,7 @@ async def _get(self, request: WsRpcMessage) -> WsRpcMessage:
string = dict_to_json_str(request)
if self.websocket is None or self.websocket.closed:
raise Exception("Websocket is not connected")
# TODO: stop dropping tasks on the floor
asyncio.create_task(self.websocket.send_str(string)) # noqa: RUF006
create_referenced_task(self.websocket.send_str(string), known_unreferenced=True)
try:
await asyncio.wait_for(self._request_dict[request_id].wait(), timeout=30)
self._request_dict.pop(request_id)
Expand Down
Loading

0 comments on commit 1ffa73e

Please sign in to comment.