Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ATO-1652] Add gRPC support #1109

Merged
merged 19 commits into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ install:
poetry run python -m pip install -U pip
poetry install


clean:
find . -name '*.pyc' -exec rm -f {} +
find . -name '*.pyo' -exec rm -f {} +
Expand Down Expand Up @@ -62,3 +61,13 @@ cleanup-generated-changelog:
release:
poetry run python scripts/release.py

download-protoc-compiler:
radovanZRasa marked this conversation as resolved.
Show resolved Hide resolved
curl -0L https://github.com/protocolbuffers/protobuf/releases/download/v25.0/protoc-25.0-osx-aarch_64.zip --output protoc-25.0-osx-aarch_64.zip

generate-grpc:
python -m grpc_tools.protoc \
-Irasa_sdk/grpc_py=./proto \
--python_out=. \
--grpc_python_out=. \
--pyi_out=. \
proto/action_webhook.proto
348 changes: 270 additions & 78 deletions poetry.lock

Large diffs are not rendered by default.

61 changes: 61 additions & 0 deletions proto/action_webhook.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
syntax = "proto3";

package action_server_webhook;
import "google/protobuf/struct.proto";

service ActionServerWebhook {
rpc webhook (WebhookRequest) returns (WebhookResponse);
}

message Tracker {
string sender_id = 1;
google.protobuf.Struct slots = 2;
google.protobuf.Struct latest_message = 3;
repeated google.protobuf.Struct events = 4;
bool paused = 5;
optional string followup_action = 6;
map<string, string> active_loop = 7;
optional string latest_action_name = 8;
repeated google.protobuf.Struct stack = 9;
}

message Intent {
string string_value = 1;
google.protobuf.Struct dict_value = 2;
}

message Entity {
string string_value = 1;
google.protobuf.Struct dict_value = 2;
}

message Action {
string string_value = 1;
google.protobuf.Struct dict_value = 2;
}

message Domain {
google.protobuf.Struct config = 1;
google.protobuf.Struct session_config = 2;
repeated Intent intents = 3;
repeated Entity entities = 4;
google.protobuf.Struct slots = 5;
google.protobuf.Struct responses = 6;
repeated Action actions = 7;
google.protobuf.Struct forms = 8;
repeated google.protobuf.Struct e2e_actions = 9;
}

message WebhookRequest {
string next_action = 1;
string sender_id = 2;
Tracker tracker = 3;
Domain domain = 4;
string version = 5;
optional string domain_digest = 6;
}

message WebhookResponse {
repeated google.protobuf.Struct events = 1;
repeated google.protobuf.Struct responses = 2;
}
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ opentelemetry-api = "~1.15.0"
opentelemetry-sdk = "~1.15.0"
opentelemetry-exporter-jaeger = "~1.15.0"
opentelemetry-exporter-otlp = "~1.15.0"
grpcio = "1.59.3"
protobuf = "4.25.3"
grpcio-tools = "1.56.2"
pydantic = "2.6.4"

[tool.poetry.dev-dependencies]
pytest-cov = "^4.1.0"
Expand All @@ -111,3 +115,4 @@ asyncio_mode = "auto"
[tool.poetry.group.dev.dependencies]
ruff = ">=0.0.256,<0.0.286"
pytest-asyncio = "^0.21.0"
types-protobuf = "4.25.0.20240417"
36 changes: 26 additions & 10 deletions rasa_sdk/__main__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import logging
import asyncio

from rasa_sdk import utils
from rasa_sdk.endpoint import create_argument_parser, run
from rasa_sdk.constants import APPLICATION_ROOT_LOGGER_NAME
from rasa_sdk.grpc_server import run_grpc

logger = logging.getLogger(__name__)


def main_from_args(args):
Expand All @@ -18,16 +22,28 @@ def main_from_args(args):
)
utils.update_sanic_log_level()

run(
args.actions,
args.port,
args.cors,
args.ssl_certificate,
args.ssl_keyfile,
args.ssl_password,
args.auto_reload,
args.endpoints,
)
if args.grpc:
asyncio.run(
run_grpc(
args.actions,
args.port,
args.ssl_certificate,
args.ssl_keyfile,
args.ssl_password,
args.endpoints,
)
)
else:
run(
args.actions,
args.port,
args.cors,
args.ssl_certificate,
args.ssl_keyfile,
args.ssl_password,
args.auto_reload,
radovanZRasa marked this conversation as resolved.
Show resolved Hide resolved
args.endpoints,
)


def main():
Expand Down
5 changes: 5 additions & 0 deletions rasa_sdk/cli/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,8 @@ def add_endpoint_arguments(parser):
default=DEFAULT_ENDPOINTS_PATH,
help="Configuration file for the assistant as a yml file.",
)
parser.add_argument(
"--grpc",
help="Starts grpc server instead of http",
action="store_true"
)
1 change: 1 addition & 0 deletions rasa_sdk/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@
"https://docs.python.org/3/library/logging.config.html#dictionary-schema-details"
)
DEFAULT_ENDPOINTS_PATH = "endpoints.yml"
NO_GRACE_PERIOD = 0
56 changes: 50 additions & 6 deletions rasa_sdk/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,15 @@ def utter_message(

# deprecated
def utter_custom_message(self, *elements: Dict[Text, Any], **kwargs: Any) -> None:
"""Sends a message with custom elements to the output channel.

Deprecated:
Use `utter_message(elements=<list of elements>)` instead.

Args:
elements: List of elements to be sent to the output channel.
kwargs: Additional parameters to be sent to the output channel.
"""
warnings.warn(
"Use of `utter_custom_message` is deprecated. "
"Use `utter_message(elements=<list of elements>)` instead.",
Expand Down Expand Up @@ -166,6 +175,13 @@ def __init__(self) -> None:
self.domain_digest: Optional[Text] = None

def register_action(self, action: Union[Type[Action], Action]) -> None:
"""Register an action with the executor.

Args:
action: Action to be registered. It can either be an instance of
`Action` subclass class or an actual `Action` subclass.
"""

if inspect.isclass(action):
action = cast(Type[Action], action)
if action.__module__.startswith("rasa."):
Expand All @@ -189,7 +205,13 @@ def register_action(self, action: Union[Type[Action], Action]) -> None:
"a function, use `register_function` instead."
)

def register_function(self, name: Text, f: Callable) -> None:
def register_function(self, action_name: Text, f: Callable) -> None:
"""Register an executor function for an action.

Args:
action_name: Name of the action.
f: Function to be registered.
"""
valid_keys = utils.arguments_of(f)
if len(valid_keys) < 3:
raise Exception(
Expand All @@ -200,12 +222,12 @@ def register_function(self, name: Text, f: Callable) -> None:
"parameters."
)

if name in self.actions:
logger.info(f"Re-registered function for '{name}'.")
if action_name in self.actions:
logger.info(f"Re-registered function for '{action_name}'.")
else:
logger.info(f"Registered function for '{name}'.")
logger.info(f"Registered function for '{action_name}'.")

self.actions[name] = f
self.actions[action_name] = f

def _import_submodules(
self, package: Union[Text, types.ModuleType], recursive: bool = True
Expand Down Expand Up @@ -349,7 +371,20 @@ def _create_api_response(
return {"events": events, "responses": messages}

@staticmethod
def validate_events(events: List[Dict[Text, Any]], action_name: Text):
def validate_events(
events: List[Dict[Text, Any]],
action_name: Text,
) -> List[Dict[Text, Any]]:
"""Validate the events returned by the action.

Args:
events: List of events returned by the action.

action_name: Name of the action that should be executed.

Returns:
List of validated events.
"""
validated = []
for event in events:
if isinstance(event, dict):
Expand Down Expand Up @@ -425,6 +460,15 @@ def update_and_return_domain(
return self.domain

async def run(self, action_call: Dict[Text, Any]) -> Optional[Dict[Text, Any]]:
"""Run the action and return the response.

Args:
action_call: Request payload containing the action data.

Returns:
Response containing the events and messages or None if
the action does not exist.
"""
from rasa_sdk.interfaces import Tracker

action_name = action_call.get("next_action")
Expand Down
37 changes: 37 additions & 0 deletions rasa_sdk/grpc_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from __future__ import annotations

from pydantic import BaseModel, Field

from enum import Enum


class ActionExecutionFailed(BaseModel):
radovanZRasa marked this conversation as resolved.
Show resolved Hide resolved
"""Error which indicates that an action execution failed.

Attributes:
action_name: Name of the action that failed.
message: Message which describes the error.
"""

action_name: str = Field(alias="action_name")
message: str = Field(alias="message")


class ResourceNotFoundType(str, Enum):
"""Type of resource that was not found."""

ACTION = "ACTION"
DOMAIN = "DOMAIN"


class ResourceNotFound(BaseModel):
"""Error which indicates that a resource was not found.

Attributes:
action_name: Name of the action that was not found.
message: Message which describes the error.
"""

action_name: str = Field(alias="action_name")
message: str = Field(alias="message")
resource_type: ResourceNotFoundType = Field(alias="resource_type")
Empty file added rasa_sdk/grpc_py/__init__.py
Empty file.
45 changes: 45 additions & 0 deletions rasa_sdk/grpc_py/action_webhook_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading