diff --git a/rasa_sdk/endpoint.py b/rasa_sdk/endpoint.py index 0d074aa78..4c2a038f0 100644 --- a/rasa_sdk/endpoint.py +++ b/rasa_sdk/endpoint.py @@ -76,6 +76,7 @@ def create_app( from. cors_origins: CORS origins to allow. auto_reload: When `True`, auto-reloading of actions is enabled. + tracer_provider: Tracer provider to use for tracing. Returns: A new Sanic application ready to be run. diff --git a/tests/tracing/conftest.py b/tests/tracing/conftest.py new file mode 100644 index 000000000..ff6547da7 --- /dev/null +++ b/tests/tracing/conftest.py @@ -0,0 +1,74 @@ +import pathlib +import socket +from concurrent import futures +from typing import Generator, Optional + +import grpc +import opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc as trace_service +import pytest +from google.protobuf.internal.containers import RepeatedCompositeFieldContainer +from opentelemetry.exporter.jaeger.thrift.gen.agent.Agent import emitBatch_args +from opentelemetry.exporter.jaeger.thrift.gen.jaeger.ttypes import Batch +from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ( + ExportTraceServiceRequest, + ExportTraceServiceResponse, +) +from opentelemetry.proto.trace.v1.trace_pb2 import ResourceSpans +from thrift.protocol.TCompactProtocol import TCompactProtocol +from thrift.transport.TTransport import TMemoryBuffer + + +TRACING_TESTS_FIXTURES_DIRECTORY = pathlib.Path(__file__).parent / "fixtures" + + +@pytest.fixture +def udp_server() -> Generator[socket.socket, None, None]: + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.bind(("localhost", 6832)) + yield sock + sock.close() + + +class CapturingTestSpanExporter(trace_service.TraceServiceServicer): + def __init__(self) -> None: + self.spans: Optional[RepeatedCompositeFieldContainer[ResourceSpans]] = None + + def Export( + self, request: ExportTraceServiceRequest, context: grpc.ServicerContext + ) -> ExportTraceServiceResponse: + self.spans = request.resource_spans + + return ExportTraceServiceResponse() + + +@pytest.fixture +def span_exporter() -> CapturingTestSpanExporter: + return CapturingTestSpanExporter() + + +@pytest.fixture +def grpc_server( + span_exporter: CapturingTestSpanExporter, +) -> Generator[grpc.Server, None, None]: + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + + trace_service.add_TraceServiceServicer_to_server( # type: ignore + span_exporter, server + ) + + server.add_insecure_port("[::]:4317") + + server.start() + yield server + server.stop(None) + + +def deserialize_jaeger_batch(data: bytearray) -> Batch: + trans = TMemoryBuffer(data) + prot = TCompactProtocol(trans) + prot.readMessageBegin() + emitBatch = emitBatch_args() # type: ignore + emitBatch.read(prot) # type: ignore + prot.readMessageEnd() + + return emitBatch.batch diff --git a/tests/tracing/fixtures/jaeger_endpoints.yml b/tests/tracing/fixtures/jaeger_endpoints.yml new file mode 100644 index 000000000..93811949e --- /dev/null +++ b/tests/tracing/fixtures/jaeger_endpoints.yml @@ -0,0 +1,4 @@ +tracing: + type: jaeger + host: localhost + port: 6832 diff --git a/tests/tracing/fixtures/otlp_endpoints.yml b/tests/tracing/fixtures/otlp_endpoints.yml new file mode 100644 index 000000000..91df61402 --- /dev/null +++ b/tests/tracing/fixtures/otlp_endpoints.yml @@ -0,0 +1,4 @@ +tracing: + type: otlp + endpoint: localhost:4317 + insecure: true diff --git a/tests/tracing/test_config.py b/tests/tracing/test_config.py new file mode 100644 index 000000000..eea7bbffd --- /dev/null +++ b/tests/tracing/test_config.py @@ -0,0 +1,123 @@ +import socket +import threading + +import pytest +from rasa_sdk.tracing.endpoints import EndpointConfig + +from rasa_sdk.tracing import config +from rasa_sdk.tracing.config import JaegerTracerConfigurer +from tests.conftest import wait +from tests.tracing import conftest +from tests.tracing.conftest import ( + TRACING_TESTS_FIXTURES_DIRECTORY, + CapturingTestSpanExporter, +) + +UDP_BUFFER_SIZE = 2048 + + +def test_jaeger_config_correctly_extracted() -> None: + cfg = EndpointConfig( + host="hostname", + port=1234, + username="user", + password="password", + ) + + extracted = JaegerTracerConfigurer._extract_config(cfg) + + assert extracted["agent_host_name"] == cfg.kwargs["host"] + assert extracted["agent_port"] == cfg.kwargs["port"] + assert extracted["username"] == cfg.kwargs["username"] + assert extracted["password"] == cfg.kwargs["password"] + + +def test_jaeger_config_sets_defaults() -> None: + extracted = JaegerTracerConfigurer._extract_config(EndpointConfig()) + + assert extracted["agent_host_name"] == "localhost" + assert extracted["agent_port"] == 6831 + assert extracted["username"] is None + assert extracted["password"] is None + + +def test_get_tracer_provider_otlp_collector( + span_exporter: CapturingTestSpanExporter, + result_available_event: threading.Event, +) -> None: + endpoints_file = str(TRACING_TESTS_FIXTURES_DIRECTORY / "otlp_endpoints.yml") + + tracer_provider = config.get_tracer_provider(endpoints_file) + assert tracer_provider is not None + + tracer = tracer_provider.get_tracer("foo") + + with tracer.start_as_current_span("otlp_test_span"): + pass + + tracer_provider.force_flush() + + wait( + lambda: span_exporter.spans is not None, + result_available_event=result_available_event, + timeout_seconds=15, + ) + + spans = span_exporter.spans + + assert spans is not None + assert len(spans[0].scope_spans[0].spans) == 1 + assert spans[0].scope_spans[0].spans[0].name == "otlp_test_span" + + +@pytest.mark.skip(reason="Temporary disabled due to TLS timeout error") +def test_get_tracer_provider_tls_otlp_collector( + span_exporter: CapturingTestSpanExporter, + result_available_event: threading.Event, +) -> None: + endpoints_file = str(TRACING_TESTS_FIXTURES_DIRECTORY / "otlp_endpoints_tls.yml") + + tracer_provider = config.get_tracer_provider(endpoints_file) + assert tracer_provider is not None + + tracer = tracer_provider.get_tracer("foo") + + with tracer.start_as_current_span("otlp_test_span"): + pass + + tracer_provider.force_flush() + + wait( + lambda: span_exporter.spans is not None, + result_available_event=result_available_event, + timeout_seconds=15, + ) + + spans = span_exporter.spans + + assert spans is not None + assert len(spans[0].scope_spans[0].spans) == 1 + assert spans[0].scope_spans[0].spans[0].name == "otlp_test_span" + + +def test_get_tracer_provider_jaeger(udp_server: socket.socket) -> None: + endpoints_file = str(TRACING_TESTS_FIXTURES_DIRECTORY / "jaeger_endpoints.yml") + + tracer_provider = config.get_tracer_provider(endpoints_file) + assert tracer_provider is not None + + tracer = tracer_provider.get_tracer(__name__) + + with tracer.start_as_current_span("jaeger_test_span"): + pass + + tracer_provider.force_flush() + + message, _ = udp_server.recvfrom(UDP_BUFFER_SIZE) + + batch = conftest.deserialize_jaeger_batch(bytearray(message)) + + assert batch.process.serviceName == "rasa" + + assert len(batch.spans) == 1 + assert batch.spans[0].operationName == "jaeger_test_span"