From 8d3c64747f9a6d99b29558b6e612dee83db7cdbb Mon Sep 17 00:00:00 2001 From: Brian Quinlan Date: Wed, 6 Mar 2024 18:17:49 -0800 Subject: [PATCH] Add support for negotiating a subprotocol (#1150) --- pkgs/web_socket/CHANGELOG.md | 2 +- .../lib/src/browser_web_socket.dart | 19 ++++- pkgs/web_socket/lib/src/io_web_socket.dart | 32 ++++++++- pkgs/web_socket/lib/src/web_socket.dart | 15 ++++ .../example/client_test.dart | 3 + .../lib/src/connect_uri_tests.dart | 18 +++++ .../lib/src/protocol_server.dart | 47 +++++++++++++ .../lib/src/protocol_server_vm.dart | 12 ++++ .../lib/src/protocol_server_web.dart | 9 +++ .../lib/src/protocol_tests.dart | 70 +++++++++++++++++++ .../lib/web_socket_conformance_tests.dart | 4 ++ 11 files changed, 226 insertions(+), 5 deletions(-) create mode 100644 pkgs/web_socket_conformance_tests/lib/src/connect_uri_tests.dart create mode 100644 pkgs/web_socket_conformance_tests/lib/src/protocol_server.dart create mode 100644 pkgs/web_socket_conformance_tests/lib/src/protocol_server_vm.dart create mode 100644 pkgs/web_socket_conformance_tests/lib/src/protocol_server_web.dart create mode 100644 pkgs/web_socket_conformance_tests/lib/src/protocol_tests.dart diff --git a/pkgs/web_socket/CHANGELOG.md b/pkgs/web_socket/CHANGELOG.md index 3d138c0f6e..3bd731a51b 100644 --- a/pkgs/web_socket/CHANGELOG.md +++ b/pkgs/web_socket/CHANGELOG.md @@ -1,3 +1,3 @@ ## 0.1.0-wip -- Abstract interface definition. +- Basic functionality in place. diff --git a/pkgs/web_socket/lib/src/browser_web_socket.dart b/pkgs/web_socket/lib/src/browser_web_socket.dart index 069f2781f0..80135fdc3e 100644 --- a/pkgs/web_socket/lib/src/browser_web_socket.dart +++ b/pkgs/web_socket/lib/src/browser_web_socket.dart @@ -18,9 +18,23 @@ class BrowserWebSocket implements WebSocket { final web.WebSocket _webSocket; final _events = StreamController(); + /// Create a new WebSocket connection using the JavaScript WebSocket API. + /// + /// The URL supplied in [url] must use the scheme ws or wss. + /// + /// If provided, the [protocols] argument indicates that subprotocols that + /// the peer is able to select. See + /// [RFC-6455 1.9](https://datatracker.ietf.org/doc/html/rfc6455#section-1.9). static Future connect(Uri url, {Iterable? protocols}) async { - final webSocket = web.WebSocket(url.toString())..binaryType = 'arraybuffer'; + if (!url.isScheme('ws') && !url.isScheme('wss')) { + throw ArgumentError.value( + url, 'url', 'only ws: and wss: schemes are supported'); + } + + final webSocket = web.WebSocket(url.toString(), + protocols?.map((e) => e.toJS).toList().toJS ?? JSArray()) + ..binaryType = 'arraybuffer'; final browserSocket = BrowserWebSocket._(webSocket); final webSocketConnected = Completer(); @@ -126,6 +140,9 @@ class BrowserWebSocket implements WebSocket { @override Stream get events => _events.stream; + + @override + String get protocol => _webSocket.protocol; } const connect = BrowserWebSocket.connect; diff --git a/pkgs/web_socket/lib/src/io_web_socket.dart b/pkgs/web_socket/lib/src/io_web_socket.dart index b0bc3c7ea5..d44a33ecfa 100644 --- a/pkgs/web_socket/lib/src/io_web_socket.dart +++ b/pkgs/web_socket/lib/src/io_web_socket.dart @@ -6,8 +6,8 @@ import 'dart:async'; import 'dart:io' as io; import 'dart:typed_data'; -import '../web_socket.dart'; import 'utils.dart'; +import 'web_socket.dart'; /// A `dart-io`-based [WebSocket] implementation. /// @@ -16,14 +16,37 @@ class IOWebSocket implements WebSocket { final io.WebSocket _webSocket; final _events = StreamController(); + /// Create a new WebSocket connection using dart:io WebSocket. + /// + /// The URL supplied in [url] must use the scheme ws or wss. + /// + /// If provided, the [protocols] argument indicates that subprotocols that + /// the peer is able to select. See + /// [RFC-6455 1.9](https://datatracker.ietf.org/doc/html/rfc6455#section-1.9). static Future connect(Uri url, {Iterable? protocols}) async { + if (!url.isScheme('ws') && !url.isScheme('wss')) { + throw ArgumentError.value( + url, 'url', 'only ws: and wss: schemes are supported'); + } + + final io.WebSocket webSocket; try { - final webSocket = await io.WebSocket.connect(url.toString()); - return IOWebSocket._(webSocket); + webSocket = + await io.WebSocket.connect(url.toString(), protocols: protocols); } on io.WebSocketException catch (e) { throw WebSocketException(e.message); } + + if (webSocket.protocol != null && + !(protocols ?? []).contains(webSocket.protocol)) { + // dart:io WebSocket does not correctly validate the returned protocol. + // See https://github.com/dart-lang/sdk/issues/55106 + await webSocket.close(1002); // protocol error + throw WebSocketException( + 'unexpected protocol selected by peer: ${webSocket.protocol}'); + } + return IOWebSocket._(webSocket); } IOWebSocket._(this._webSocket) { @@ -90,6 +113,9 @@ class IOWebSocket implements WebSocket { @override Stream get events => _events.stream; + + @override + String get protocol => _webSocket.protocol ?? ''; } const connect = IOWebSocket.connect; diff --git a/pkgs/web_socket/lib/src/web_socket.dart b/pkgs/web_socket/lib/src/web_socket.dart index 945fc57aed..560a5eb04e 100644 --- a/pkgs/web_socket/lib/src/web_socket.dart +++ b/pkgs/web_socket/lib/src/web_socket.dart @@ -115,6 +115,13 @@ class WebSocketConnectionClosed extends WebSocketException { /// socket.sendText('Hello Dart WebSockets! 🎉'); /// } abstract interface class WebSocket { + /// Create a new WebSocket connection. + /// + /// The URL supplied in [url] must use the scheme ws or wss. + /// + /// If provided, the [protocols] argument indicates that subprotocols that + /// the peer is able to select. See + /// [RFC-6455 1.9](https://datatracker.ietf.org/doc/html/rfc6455#section-1.9). static Future connect(Uri url, {Iterable? protocols}) => connector.connect(url, protocols: protocols); @@ -169,4 +176,12 @@ abstract interface class WebSocket { /// /// Errors will never appear in this [Stream]. Stream get events; + + /// The WebSocket subprotocol negotiated with the peer. + /// + /// Will be the empty string if no subprotocol was negotiated. + /// + /// See + /// [RFC-6455 1.9](https://datatracker.ietf.org/doc/html/rfc6455#section-1.9). + String get protocol; } diff --git a/pkgs/web_socket_conformance_tests/example/client_test.dart b/pkgs/web_socket_conformance_tests/example/client_test.dart index ec3d01c17a..d08dad94c1 100644 --- a/pkgs/web_socket_conformance_tests/example/client_test.dart +++ b/pkgs/web_socket_conformance_tests/example/client_test.dart @@ -20,6 +20,9 @@ class MyWebSocketImplementation implements WebSocket { @override void sendText(String s) => throw UnimplementedError(); + + @override + String get protocol => throw UnimplementedError(); } void main() { diff --git a/pkgs/web_socket_conformance_tests/lib/src/connect_uri_tests.dart b/pkgs/web_socket_conformance_tests/lib/src/connect_uri_tests.dart new file mode 100644 index 0000000000..0caa9e6f8c --- /dev/null +++ b/pkgs/web_socket_conformance_tests/lib/src/connect_uri_tests.dart @@ -0,0 +1,18 @@ +// Copyright (c) 2024, the Dart project authors. Please see the AUTHORS file +// for details. All rights reserved. Use of this source code is governed by a +// BSD-style license that can be found in the LICENSE file. + +import 'package:test/test.dart'; +import 'package:web_socket/web_socket.dart'; + +/// Tests that the [WebSocket] rejects invalid connection URIs. +void testConnectUri( + Future Function(Uri uri, {Iterable? protocols}) + channelFactory) { + group('connect uri', () { + test('no protocol', () async { + await expectLater(() => channelFactory(Uri.https('www.example.com', '/')), + throwsA(isA())); + }); + }); +} diff --git a/pkgs/web_socket_conformance_tests/lib/src/protocol_server.dart b/pkgs/web_socket_conformance_tests/lib/src/protocol_server.dart new file mode 100644 index 0000000000..c0df5b6ea4 --- /dev/null +++ b/pkgs/web_socket_conformance_tests/lib/src/protocol_server.dart @@ -0,0 +1,47 @@ +// Copyright (c) 2024, the Dart project authors. Please see the AUTHORS file +// for details. All rights reserved. Use of this source code is governed by a +// BSD-style license that can be found in the LICENSE file. + +import 'dart:async'; +import 'dart:convert'; +import 'dart:io'; + +import 'package:crypto/crypto.dart'; +import 'package:stream_channel/stream_channel.dart'; + +const _webSocketGuid = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'; + +/// Starts an WebSocket server that responds with a scripted subprotocol. +void hybridMain(StreamChannel channel) async { + late final HttpServer server; + server = (await HttpServer.bind('localhost', 0)) + ..listen((request) async { + final serverProtocol = request.requestedUri.queryParameters['protocol']; + var key = request.headers.value('Sec-WebSocket-Key'); + var digest = sha1.convert('$key$_webSocketGuid'.codeUnits); + var accept = base64.encode(digest.bytes); + channel.sink.add(request.headers['Sec-WebSocket-Protocol']); + request.response + ..statusCode = HttpStatus.switchingProtocols + ..headers.add(HttpHeaders.connectionHeader, 'Upgrade') + ..headers.add(HttpHeaders.upgradeHeader, 'websocket') + ..headers.add('Sec-WebSocket-Accept', accept); + if (serverProtocol != null) { + request.response.headers.add('Sec-WebSocket-Protocol', serverProtocol); + } + request.response.contentLength = 0; + final socket = await request.response.detachSocket(); + final webSocket = WebSocket.fromUpgradedSocket(socket, + protocol: serverProtocol, serverSide: true); + webSocket.listen((e) async { + webSocket.add(e); + await webSocket.close(); + }); + }); + + channel.sink.add(server.port); + + await channel + .stream.first; // Any writes indicates that the server should exit. + unawaited(server.close()); +} diff --git a/pkgs/web_socket_conformance_tests/lib/src/protocol_server_vm.dart b/pkgs/web_socket_conformance_tests/lib/src/protocol_server_vm.dart new file mode 100644 index 0000000000..a31da9ec1e --- /dev/null +++ b/pkgs/web_socket_conformance_tests/lib/src/protocol_server_vm.dart @@ -0,0 +1,12 @@ +// Generated by generate_server_wrappers.dart. Do not edit. + +import 'package:stream_channel/stream_channel.dart'; + +import 'protocol_server.dart'; + +/// Starts the redirect test HTTP server in the same process. +Future> startServer() async { + final controller = StreamChannelController(sync: true); + hybridMain(controller.foreign); + return controller.local; +} diff --git a/pkgs/web_socket_conformance_tests/lib/src/protocol_server_web.dart b/pkgs/web_socket_conformance_tests/lib/src/protocol_server_web.dart new file mode 100644 index 0000000000..a752ed7ac2 --- /dev/null +++ b/pkgs/web_socket_conformance_tests/lib/src/protocol_server_web.dart @@ -0,0 +1,9 @@ +// Generated by generate_server_wrappers.dart. Do not edit. + +import 'package:stream_channel/stream_channel.dart'; +import 'package:test/test.dart'; + +/// Starts the redirect test HTTP server out-of-process. +Future> startServer() async => spawnHybridUri(Uri( + scheme: 'package', + path: 'web_socket_conformance_tests/src/protocol_server.dart')); diff --git a/pkgs/web_socket_conformance_tests/lib/src/protocol_tests.dart b/pkgs/web_socket_conformance_tests/lib/src/protocol_tests.dart new file mode 100644 index 0000000000..4af977fb7c --- /dev/null +++ b/pkgs/web_socket_conformance_tests/lib/src/protocol_tests.dart @@ -0,0 +1,70 @@ +// Copyright (c) 2024, the Dart project authors. Please see the AUTHORS file +// for details. All rights reserved. Use of this source code is governed by a +// BSD-style license that can be found in the LICENSE file. + +import 'package:async/async.dart'; +import 'package:stream_channel/stream_channel.dart'; +import 'package:test/test.dart'; +import 'package:web_socket/web_socket.dart'; + +import 'protocol_server_vm.dart' + if (dart.library.html) 'protocol_server_web.dart'; + +/// Tests that the [WebSocket] can correctly negotiate a subprotocol with the +/// peer. +/// +/// See +/// [RFC-6455 1.9](https://datatracker.ietf.org/doc/html/rfc6455#section-1.9). +void testProtocols( + Future Function(Uri uri, {Iterable? protocols}) + channelFactory) { + group('protocols', () { + late Uri uri; + late StreamChannel httpServerChannel; + late StreamQueue httpServerQueue; + + setUp(() async { + httpServerChannel = await startServer(); + httpServerQueue = StreamQueue(httpServerChannel.stream); + uri = Uri.parse('ws://localhost:${await httpServerQueue.next}'); + }); + tearDown(() => httpServerChannel.sink.add(null)); + + test('no protocol', () async { + final socket = await channelFactory(uri); + + expect(await httpServerQueue.next, null); + expect(socket.protocol, ''); + socket.sendText('Hello World!'); + }); + + test('single protocol', () async { + final socket = await channelFactory( + uri.replace(queryParameters: {'protocol': 'chat.example.com'}), + protocols: ['chat.example.com']); + + expect(await httpServerQueue.next, ['chat.example.com']); + expect(socket.protocol, 'chat.example.com'); + socket.sendText('Hello World!'); + }); + + test('mutiple protocols', () async { + final socket = await channelFactory( + uri.replace(queryParameters: {'protocol': 'text.example.com'}), + protocols: ['chat.example.com', 'text.example.com']); + + expect( + await httpServerQueue.next, ['chat.example.com, text.example.com']); + expect(socket.protocol, 'text.example.com'); + socket.sendText('Hello World!'); + }); + + test('protocol mismatch', () async { + await expectLater( + () => channelFactory( + uri.replace(queryParameters: {'protocol': 'example.example.com'}), + protocols: ['chat.example.com']), + throwsA(isA())); + }); + }); +} diff --git a/pkgs/web_socket_conformance_tests/lib/web_socket_conformance_tests.dart b/pkgs/web_socket_conformance_tests/lib/web_socket_conformance_tests.dart index 248fc3870a..9e6e011628 100644 --- a/pkgs/web_socket_conformance_tests/lib/web_socket_conformance_tests.dart +++ b/pkgs/web_socket_conformance_tests/lib/web_socket_conformance_tests.dart @@ -5,10 +5,12 @@ import 'package:web_socket/web_socket.dart'; import 'src/close_local_tests.dart'; import 'src/close_remote_tests.dart'; +import 'src/connect_uri_tests.dart'; import 'src/disconnect_after_upgrade_tests.dart'; import 'src/no_upgrade_tests.dart'; import 'src/payload_transfer_tests.dart'; import 'src/peer_protocol_errors_tests.dart'; +import 'src/protocol_tests.dart'; /// Runs the entire test suite against the given [WebSocket]. void testAll( @@ -16,8 +18,10 @@ void testAll( webSocketFactory) { testCloseLocal(webSocketFactory); testCloseRemote(webSocketFactory); + testConnectUri(webSocketFactory); testDisconnectAfterUpgrade(webSocketFactory); testNoUpgrade(webSocketFactory); testPayloadTransfer(webSocketFactory); testPeerProtocolErrors(webSocketFactory); + testProtocols(webSocketFactory); }