Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tests] AsyncSubtensor (Part 1) #2398

Merged
merged 12 commits into from
Nov 8, 2024
Merged
6 changes: 3 additions & 3 deletions bittensor/core/async_subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,13 +147,13 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):

async def encode_params(
self,
call_definition: list["ParamWithTypes"],
call_definition: dict[str, list["ParamWithTypes"]],
params: Union[list[Any], dict[str, Any]],
) -> str:
"""Returns a hex encoded string of the params using their types."""
param_data = scalecodec.ScaleBytes(b"")

for i, param in enumerate(call_definition["params"]): # type: ignore
for i, param in enumerate(call_definition["params"]):
scale_obj = await self.substrate.create_scale_object(param["type"])
if isinstance(params, list):
param_data += scale_obj.encode(params[i])
Expand Down Expand Up @@ -440,7 +440,7 @@ async def query_runtime_api(

return_type = call_definition["type"]

as_scale_bytes = scalecodec.ScaleBytes(json_result["result"]) # type: ignore
as_scale_bytes = scalecodec.ScaleBytes(json_result["result"])

rpc_runtime_config = RuntimeConfiguration()
rpc_runtime_config.update_type_registry(load_type_registry_preset("legacy"))
Expand Down
8 changes: 4 additions & 4 deletions bittensor/core/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,15 +389,15 @@ def add_args(cls, parser: "argparse.ArgumentParser", prefix: Optional[str] = Non
@networking.ensure_connected
def _encode_params(
self,
call_definition: list["ParamWithTypes"],
call_definition: dict[str, list["ParamWithTypes"]],
params: Union[list[Any], dict[str, Any]],
) -> str:
"""Returns a hex encoded string of the params using their types."""
param_data = scalecodec.ScaleBytes(b"")

for i, param in enumerate(call_definition["params"]): # type: ignore
for i, param in enumerate(call_definition["params"]):
scale_obj = self.substrate.create_scale_object(param["type"])
if type(params) is list:
if isinstance(params, list):
param_data += scale_obj.encode(params[i])
else:
if param["name"] not in params:
Expand Down Expand Up @@ -1232,7 +1232,7 @@ def get_subnet_hyperparameters(
else:
bytes_result = bytes.fromhex(hex_bytes_result)

return SubnetHyperparameters.from_vec_u8(bytes_result) # type: ignore
return SubnetHyperparameters.from_vec_u8(bytes_result)

# Community uses this method
# Returns network ImmunityPeriod hyper parameter.
Expand Down
293 changes: 293 additions & 0 deletions tests/unit_tests/test_async_subtensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,293 @@
from pickle import FALSE

import pytest

from bittensor.core import async_subtensor


@pytest.fixture
def subtensor(mocker):
fake_async_substrate = mocker.AsyncMock(
autospec=async_subtensor.AsyncSubstrateInterface
)
mocker.patch.object(
async_subtensor, "AsyncSubstrateInterface", return_value=fake_async_substrate
)
return async_subtensor.AsyncSubtensor()


def test_decode_ss58_tuples_in_proposal_vote_data(mocker):
"""Tests that ProposalVoteData instance instantiation works properly,"""
# Preps
mocked_decode_account_id = mocker.patch.object(async_subtensor, "decode_account_id")
fake_proposal_dict = {
"index": "0",
"threshold": 1,
"ayes": ("0 line", "1 line"),
"nays": ("2 line", "3 line"),
"end": 123,
}

# Call
async_subtensor.ProposalVoteData(fake_proposal_dict)

# Asserts
assert mocked_decode_account_id.call_count == len(fake_proposal_dict["ayes"]) + len(
fake_proposal_dict["nays"]
)
assert mocked_decode_account_id.mock_calls == [
mocker.call("0"),
mocker.call("1"),
mocker.call("2"),
mocker.call("3"),
]


@pytest.mark.asyncio
async def test_encode_params(subtensor, mocker):
roman-opentensor marked this conversation as resolved.
Show resolved Hide resolved
"""Tests encode_params happy path."""
# Preps
subtensor.substrate.create_scale_object = mocker.AsyncMock(
autospec=async_subtensor.AsyncSubstrateInterface.create_scale_object
)
subtensor.substrate.create_scale_object.return_value.encode = mocker.Mock(
return_value=b""
)

call_definition = {
"params": [
{"name": "coldkey", "type": "Vec<u8>"},
{"name": "uid", "type": "u16"},
]
}
params = ["coldkey", "uid"]

# Call
decoded_params = await subtensor.encode_params(
call_definition=call_definition, params=params
)

# Asserts
subtensor.substrate.create_scale_object.call_args(
mocker.call("coldkey"),
mocker.call("Vec<u8>"),
mocker.call("uid"),
mocker.call("u16"),
)
assert decoded_params == "0x"


@pytest.mark.asyncio
async def test_encode_params_raises_error(subtensor, mocker):
"""Tests encode_params with raised error."""
# Preps
subtensor.substrate.create_scale_object = mocker.AsyncMock(
autospec=async_subtensor.AsyncSubstrateInterface.create_scale_object
)
subtensor.substrate.create_scale_object.return_value.encode = mocker.Mock(
return_value=b""
)

call_definition = {
"params": [
{"name": "coldkey", "type": "Vec<u8>"},
]
}
params = {"undefined param": "some value"}

# Call and assert
with pytest.raises(ValueError):
await subtensor.encode_params(call_definition=call_definition, params=params)

subtensor.substrate.create_scale_object.return_value.encode.assert_not_called()


@pytest.mark.asyncio
async def test_get_current_block(subtensor):
"""Tests get_current_block method."""
# Call
result = await subtensor.get_current_block()

# Asserts
subtensor.substrate.get_block_number.assert_called_once()
assert result == subtensor.substrate.get_block_number.return_value


@pytest.mark.asyncio
async def test_get_block_hash_without_block_id_aka_none(subtensor):
"""Tests get_block_hash method without passed block_id."""
# Call
result = await subtensor.get_block_hash()

# Asserts
assert result == subtensor.substrate.get_chain_head.return_value


@pytest.mark.asyncio
async def test_get_block_hash_with_block_id(subtensor):
"""Tests get_block_hash method with passed block_id."""
# Call
result = await subtensor.get_block_hash(block_id=1)

# Asserts
assert result == subtensor.substrate.get_block_hash.return_value


@pytest.mark.asyncio
async def test_is_hotkey_registered_any(subtensor, mocker):
"""Tests is_hotkey_registered_any method."""
# Preps
mocked_get_netuids_for_hotkey = mocker.AsyncMock(
return_value=[1, 2], autospec=subtensor.get_netuids_for_hotkey
)
subtensor.get_netuids_for_hotkey = mocked_get_netuids_for_hotkey

# Call
result = await subtensor.is_hotkey_registered_any(
hotkey_ss58="hotkey", block_hash="FAKE_HASH"
)

# Asserts
assert result is (len(mocked_get_netuids_for_hotkey.return_value) > 0)


@pytest.mark.asyncio
async def test_get_subnet_burn_cost(subtensor, mocker):
"""Tests get_subnet_burn_cost method."""
# Preps
mocked_query_runtime_api = mocker.AsyncMock(autospec=subtensor.query_runtime_api)
subtensor.query_runtime_api = mocked_query_runtime_api
fake_block_hash = None

# Call
result = await subtensor.get_subnet_burn_cost(block_hash=fake_block_hash)

# Assert
assert result == mocked_query_runtime_api.return_value
mocked_query_runtime_api.assert_called_once_with(
runtime_api="SubnetRegistrationRuntimeApi",
method="get_network_registration_cost",
params=[],
block_hash=fake_block_hash,
)


@pytest.mark.asyncio
async def test_get_total_subnets(subtensor, mocker):
"""Tests get_total_subnets method."""
# Preps
mocked_substrate_query = mocker.AsyncMock(
autospec=async_subtensor.AsyncSubstrateInterface.query
)
subtensor.substrate.query = mocked_substrate_query
fake_block_hash = None

# Call
result = await subtensor.get_total_subnets(block_hash=fake_block_hash)

# Assert
assert result == mocked_substrate_query.return_value
mocked_substrate_query.assert_called_once_with(
module="SubtensorModule",
storage_function="TotalNetworks",
params=[],
block_hash=fake_block_hash,
)


@pytest.mark.parametrize(
"records, response",
[([(0, True), (1, False), (3, False), (3, True)], [0, 3]), ([], [])],
ids=["with records", "empty-records"],
)
@pytest.mark.asyncio
async def test_get_subnets(subtensor, mocker, records, response):
"""Tests get_subnets method with any return."""
# Preps
fake_result = mocker.AsyncMock(autospec=list)
fake_result.records = records
fake_result.__aiter__.return_value = iter(records)

mocked_substrate_query_map = mocker.AsyncMock(
autospec=async_subtensor.AsyncSubstrateInterface.query_map,
return_value=fake_result,
)

subtensor.substrate.query_map = mocked_substrate_query_map
fake_block_hash = None

# Call
result = await subtensor.get_subnets(block_hash=fake_block_hash)

# Asserts
mocked_substrate_query_map.assert_called_once_with(
module="SubtensorModule",
storage_function="NetworksAdded",
block_hash=fake_block_hash,
reuse_block_hash=True,
)
assert result == response


@pytest.mark.parametrize(
"hotkey_ss58_in_result",
[True, False],
ids=["hotkey-exists", "hotkey-doesnt-exist"],
)
@pytest.mark.asyncio
async def test_is_hotkey_delegate(subtensor, mocker, hotkey_ss58_in_result):
"""Tests is_hotkey_delegate method with any return."""
# Preps
fake_hotkey_ss58 = "hotkey_58"
mocked_get_delegates = mocker.AsyncMock(
return_value=[
mocker.Mock(hotkey_ss58=fake_hotkey_ss58 if hotkey_ss58_in_result else "")
]
)
subtensor.get_delegates = mocked_get_delegates

# Call
result = await subtensor.is_hotkey_delegate(
hotkey_ss58=fake_hotkey_ss58, block_hash=None, reuse_block=True
)

# Asserts
assert result == hotkey_ss58_in_result
mocked_get_delegates.assert_called_once_with(block_hash=None, reuse_block=True)


@pytest.mark.parametrize(
"fake_hex_bytes_result, response", [(None, []), ("0xaabbccdd", b"\xaa\xbb\xcc\xdd")]
)
@pytest.mark.asyncio
async def test_get_delegates(subtensor, mocker, fake_hex_bytes_result, response):
"""Tests get_delegates method."""
# Preps
mocked_query_runtime_api = mocker.AsyncMock(
autospec=subtensor.query_runtime_api, return_value=fake_hex_bytes_result
)
subtensor.query_runtime_api = mocked_query_runtime_api
mocked_delegate_info_list_from_vec_u8 = mocker.Mock()
async_subtensor.DelegateInfo.list_from_vec_u8 = (
mocked_delegate_info_list_from_vec_u8
)

# Call
result = await subtensor.get_delegates(block_hash=None, reuse_block=True)

# Asserts
if fake_hex_bytes_result:
assert result == mocked_delegate_info_list_from_vec_u8.return_value
mocked_delegate_info_list_from_vec_u8.assert_called_once_with(
bytes.fromhex(fake_hex_bytes_result[2:])
)
else:
assert result == response

mocked_query_runtime_api.assert_called_once_with(
runtime_api="DelegateInfoRuntimeApi",
method="get_delegates",
params=[],
block_hash=None,
reuse_block=True,
)
Loading