Skip to content

Commit

Permalink
Add infrastructure for custom action server gRPC integration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
radovanZRasa committed Jun 25, 2024
1 parent 2247a58 commit 41513c5
Show file tree
Hide file tree
Showing 6 changed files with 295 additions and 38 deletions.
2 changes: 1 addition & 1 deletion proto/action_webhook.proto
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ service ActionService {
message ActionsRequest {}

message ActionsResponse {
map<string, string> actions = 1;
repeated google.protobuf.Struct actions = 1;
}

message Tracker {
Expand Down
46 changes: 21 additions & 25 deletions rasa_sdk/grpc_py/action_webhook_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 2 additions & 9 deletions rasa_sdk/grpc_py/action_webhook_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,9 @@ class ActionsRequest(_message.Message):

class ActionsResponse(_message.Message):
__slots__ = ["actions"]
class ActionsEntry(_message.Message):
__slots__ = ["key", "value"]
KEY_FIELD_NUMBER: _ClassVar[int]
VALUE_FIELD_NUMBER: _ClassVar[int]
key: str
value: str
def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ...
ACTIONS_FIELD_NUMBER: _ClassVar[int]
actions: _containers.ScalarMap[str, str]
def __init__(self, actions: _Optional[_Mapping[str, str]] = ...) -> None: ...
actions: _containers.RepeatedCompositeFieldContainer[_struct_pb2.Struct]
def __init__(self, actions: _Optional[_Iterable[_Union[_struct_pb2.Struct, _Mapping]]] = ...) -> None: ...

class Tracker(_message.Message):
__slots__ = ["sender_id", "slots", "latest_message", "events", "paused", "followup_action", "active_loop", "latest_action_name", "stack"]
Expand Down
15 changes: 12 additions & 3 deletions rasa_sdk/grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,11 @@ def __init__(
self.auto_reload = auto_reload
self.executor = executor

async def Actions(self, request: ActionsRequest, context) -> ActionsResponse:
async def Actions(
self,
request: ActionsRequest,
context: grpc.aio.ServicerContext,
) -> ActionsResponse:
"""Handle RPC request for the actions.
Args:
Expand All @@ -107,9 +111,14 @@ async def Actions(self, request: ActionsRequest, context) -> ActionsResponse:
if self.auto_reload:
self.executor.reload()

actions = self.executor.list_actions()
actions = [action.model_dump() for action in self.executor.list_actions()]
response = ActionsResponse()
return ParseDict(actions, response)
return ParseDict(
{
"actions": actions,
},
response,
)

async def Webhook(
self,
Expand Down
14 changes: 14 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import pytest

import rasa_sdk
from rasa_sdk import Action, FormValidationAction, Tracker, ValidationAction
from rasa_sdk.events import SlotSet
from rasa_sdk.executor import CollectingDispatcher
Expand All @@ -13,6 +14,7 @@


def get_stack():
"""Return a dialogue stack."""
dialogue_stack = [
{
"frame_id": "CP6JP9GQ",
Expand Down Expand Up @@ -147,3 +149,15 @@ def name(self):
class SubclassTestActionB(SubclassTestActionA):
def name(self):
return "subclass_test_action_b"


@pytest.fixture
def current_rasa_version() -> str:
"""Return current Rasa version."""
return rasa_sdk.__version__


@pytest.fixture
def previous_rasa_version() -> str:
"""Return previous Rasa version."""
return "1.0.0"
245 changes: 245 additions & 0 deletions tests/test_grpc_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
from typing import Union, Any, Dict, Text, List
from unittest.mock import MagicMock, AsyncMock

import grpc
import pytest
from google.protobuf.json_format import MessageToDict, ParseDict

from rasa_sdk import ActionExecutionRejection
from rasa_sdk.executor import ActionName, ActionExecutor
from rasa_sdk.grpc_errors import (
ActionExecutionFailed,
ResourceNotFound,
ResourceNotFoundType,
)
from rasa_sdk.grpc_py import action_webhook_pb2
from rasa_sdk.grpc_server import GRPCActionServerWebhook
from rasa_sdk.interfaces import ActionMissingDomainException, ActionNotFoundException


@pytest.fixture
def sender_id() -> str:
"""Return sender id."""
return "test_sender_id"


@pytest.fixture
def action_name() -> str:
"""Return action name."""
return "action_listen"


@pytest.fixture
def grpc_webhook_request(
sender_id: str,
action_name: str,
current_rasa_version: str,
) -> action_webhook_pb2.WebhookRequest:
"""Create a webhook request."""
return action_webhook_pb2.WebhookRequest(
next_action=action_name,
sender_id=sender_id,
tracker=action_webhook_pb2.Tracker(
sender_id=sender_id,
slots={},
latest_message={},
events=[],
paused=False,
followup_action="",
active_loop={},
latest_action_name="",
stack={},
),
domain=action_webhook_pb2.Domain(
config={},
session_config={},
intents=[],
entities=[],
slots={},
responses={},
actions=[],
forms={},
e2e_actions=[],
),
version=current_rasa_version,
domain_digest="",
)


@pytest.fixture
def mock_executor() -> AsyncMock:
"""Create a mock action executor."""
return AsyncMock(spec=ActionExecutor)


@pytest.fixture
def mock_grpc_service_context() -> MagicMock:
"""Create a mock gRPC service context."""
return MagicMock(spec=grpc.aio.ServicerContext)


@pytest.fixture
def grpc_action_server_webhook(mock_executor: AsyncMock) -> GRPCActionServerWebhook:
"""Create a GRPCActionServerWebhook instance with a mock executor."""
return GRPCActionServerWebhook(executor=mock_executor)


@pytest.fixture
def executor_response() -> Dict[Text, Any]:
"""Create an executor response."""
return {
"events": [{"event": "slot", "name": "test", "value": "foo"}],
"responses": [{"utter": "Hi"}],
}


@pytest.fixture
def expected_grpc_webhook_response(
executor_response: Dict[Text, Any],
) -> action_webhook_pb2.WebhookResponse:
"""Create a gRPC webhook response."""
result = action_webhook_pb2.WebhookResponse()
return ParseDict(executor_response, result)


def action_names() -> List[ActionName]:
"""Create a list of action names."""
return [
ActionName(name="action_listen"),
ActionName(name="action_restart"),
ActionName(name="action_session_start"),
]


def expected_grpc_actions_response() -> action_webhook_pb2.ActionsResponse:
"""Create a gRPC actions response."""
actions = [action.model_dump() for action in action_names()]
result = action_webhook_pb2.ActionsResponse()
return ParseDict(
{
"actions": actions,
},
result,
)


@pytest.mark.parametrize(
"auto_reload, expected_reload_call_count", [(True, 1), (False, 0)]
)
async def test_grpc_action_server_webhook_no_errors(
auto_reload: bool,
expected_reload_call_count: int,
grpc_action_server_webhook: GRPCActionServerWebhook,
grpc_webhook_request: action_webhook_pb2.WebhookRequest,
mock_executor: AsyncMock,
mock_grpc_service_context: MagicMock,
executor_response: Dict[Text, Any],
expected_grpc_webhook_response: action_webhook_pb2.WebhookResponse,
):
"""Test that the gRPC action server webhook can handle a request without errors."""
grpc_action_server_webhook.auto_reload = auto_reload
mock_executor.run.return_value = executor_response
response = await grpc_action_server_webhook.Webhook(
grpc_webhook_request,
mock_grpc_service_context,
)

assert response == expected_grpc_webhook_response

mock_grpc_service_context.set_code.assert_not_called()
mock_grpc_service_context.set_details.assert_not_called()

assert mock_executor.reload.call_count == expected_reload_call_count

expected_action_call = MessageToDict(
grpc_webhook_request,
preserving_proto_field_name=True,
)
mock_executor.run.assert_called_once_with(expected_action_call)


@pytest.mark.parametrize(
"exception, expected_status_code, expected_body",
[
(
ActionExecutionRejection("action_name", "message"),
grpc.StatusCode.INTERNAL,
ActionExecutionFailed(
action_name="action_name", message="message"
).model_dump_json(),
),
(
ActionNotFoundException("action_name", "message"),
grpc.StatusCode.NOT_FOUND,
ResourceNotFound(
action_name="action_name",
message="message",
resource_type=ResourceNotFoundType.ACTION,
).model_dump_json(),
),
(
ActionMissingDomainException("action_name", "message"),
grpc.StatusCode.NOT_FOUND,
ResourceNotFound(
action_name="action_name",
message="message",
resource_type=ResourceNotFoundType.DOMAIN,
).model_dump_json(),
),
],
)
async def test_grpc_action_server_webhook_action_execution_rejected(
exception: Union[
ActionExecutionRejection, ActionNotFoundException, ActionMissingDomainException
],
expected_status_code: grpc.StatusCode,
expected_body: str,
grpc_action_server_webhook: GRPCActionServerWebhook,
grpc_webhook_request: action_webhook_pb2.WebhookRequest,
mock_executor: AsyncMock,
mock_grpc_service_context: MagicMock,
):
"""Test that the gRPC action server webhook can handle a request with an action execution rejection.""" # noqa: E501
mock_executor.run.side_effect = exception
response = await grpc_action_server_webhook.Webhook(
grpc_webhook_request,
mock_grpc_service_context,
)

assert response == action_webhook_pb2.WebhookResponse()

mock_grpc_service_context.set_code.assert_called_once_with(expected_status_code)
mock_grpc_service_context.set_details.assert_called_once_with(expected_body)


@pytest.mark.parametrize(
"given_action_names, expected_grpc_actions_response",
[
(
[],
action_webhook_pb2.ActionsResponse(),
),
(
action_names(),
expected_grpc_actions_response(),
),
],
)
async def test_grpc_action_server_actions(
given_action_names: List[ActionName],
expected_grpc_actions_response: action_webhook_pb2.ActionsResponse,
grpc_action_server_webhook: GRPCActionServerWebhook,
mock_grpc_service_context: MagicMock,
mock_executor: AsyncMock,
):
"""Test that the gRPC action server webhook can handle a request for actions."""
mock_executor.list_actions.return_value = given_action_names

response = await grpc_action_server_webhook.Actions(
action_webhook_pb2.ActionsRequest(), mock_grpc_service_context
)

assert response == expected_grpc_actions_response

mock_grpc_service_context.set_code.assert_not_called()
mock_grpc_service_context.set_details.assert_not_called()

0 comments on commit 41513c5

Please sign in to comment.