diff --git a/servers/graphql-kotlin-spring-server/src/main/kotlin/com/expediagroup/graphql/server/spring/subscriptions/SubscriptionWebSocketHandler.kt b/servers/graphql-kotlin-spring-server/src/main/kotlin/com/expediagroup/graphql/server/spring/subscriptions/SubscriptionWebSocketHandler.kt index de566cfb9c..b4d3e55452 100644 --- a/servers/graphql-kotlin-spring-server/src/main/kotlin/com/expediagroup/graphql/server/spring/subscriptions/SubscriptionWebSocketHandler.kt +++ b/servers/graphql-kotlin-spring-server/src/main/kotlin/com/expediagroup/graphql/server/spring/subscriptions/SubscriptionWebSocketHandler.kt @@ -21,7 +21,7 @@ import com.expediagroup.graphql.server.execution.subscription.GRAPHQL_WS_PROTOCO import com.expediagroup.graphql.server.execution.subscription.GraphQLWebSocketServer import com.expediagroup.graphql.server.types.GraphQLSubscriptionStatus import com.fasterxml.jackson.databind.ObjectMapper -import kotlinx.coroutines.reactive.awaitFirst +import kotlinx.coroutines.reactive.awaitFirstOrNull import kotlinx.coroutines.reactor.flux import org.springframework.web.reactive.socket.CloseStatus import org.springframework.web.reactive.socket.WebSocketHandler @@ -53,7 +53,7 @@ class SubscriptionWebSocketHandler( ) override suspend fun closeSession(session: WebSocketSession, reason: GraphQLSubscriptionStatus) { - session.close(CloseStatus(reason.code, reason.reason)).awaitFirst() + session.close(CloseStatus(reason.code, reason.reason)).awaitFirstOrNull() } override suspend fun sendSubscriptionMessage(session: WebSocketSession, message: String): WebSocketMessage = diff --git a/servers/graphql-kotlin-spring-server/src/test/kotlin/com/expediagroup/graphql/server/spring/subscriptions/SubscriptionWebSocketHandlerTest.kt b/servers/graphql-kotlin-spring-server/src/test/kotlin/com/expediagroup/graphql/server/spring/subscriptions/SubscriptionWebSocketHandlerTest.kt index e8662f52b3..8ad5168ddf 100644 --- a/servers/graphql-kotlin-spring-server/src/test/kotlin/com/expediagroup/graphql/server/spring/subscriptions/SubscriptionWebSocketHandlerTest.kt +++ b/servers/graphql-kotlin-spring-server/src/test/kotlin/com/expediagroup/graphql/server/spring/subscriptions/SubscriptionWebSocketHandlerTest.kt @@ -17,10 +17,16 @@ package com.expediagroup.graphql.server.spring.subscriptions import com.expediagroup.graphql.server.execution.subscription.GRAPHQL_WS_PROTOCOL +import com.expediagroup.graphql.server.types.GraphQLSubscriptionStatus import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper +import io.mockk.every import io.mockk.mockk import org.junit.jupiter.api.Test import kotlin.test.assertEquals +import kotlinx.coroutines.test.runTest +import org.junit.jupiter.api.assertDoesNotThrow +import org.springframework.web.reactive.socket.WebSocketSession +import reactor.core.publisher.Mono class SubscriptionWebSocketHandlerTest { @@ -35,4 +41,16 @@ class SubscriptionWebSocketHandlerTest { val handler = SubscriptionWebSocketHandler(mockk(), mockk(), mockk(), mockk(), 1_000, jacksonObjectMapper()) assertEquals(expected = listOf(GRAPHQL_WS_PROTOCOL), actual = handler.subProtocols) } + + @Test + fun `verify default subscription handler handles init timeout gracefully`() = runTest { + val handler = SubscriptionWebSocketHandler(mockk(), mockk(), mockk(), mockk(), 1_000, jacksonObjectMapper()) + val session = mockk() + + every { session.close(any()) } returns Mono.empty() + + assertDoesNotThrow { + handler.closeSession(session, GraphQLSubscriptionStatus.CONNECTION_INIT_TIMEOUT) + } + } }