Skip to content

Commit

Permalink
Protocol: support tunneling the connection over a Stream. (#301)
Browse files Browse the repository at this point in the history
  • Loading branch information
tmds authored Aug 18, 2024
1 parent 10c2d26 commit db7af5a
Show file tree
Hide file tree
Showing 8 changed files with 238 additions and 10 deletions.
2 changes: 1 addition & 1 deletion src/Tmds.DBus.Protocol/ClientConnectionOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ public ClientConnectionOptions(string address)
_address = address;
}

protected ClientConnectionOptions()
protected internal ClientConnectionOptions()
{
_address = string.Empty;
}
Expand Down
9 changes: 9 additions & 0 deletions src/Tmds.DBus.Protocol/ClientSetupResult.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ public ClientSetupResult(string address)
ConnectionAddress = address ?? throw new ArgumentNullException(nameof(address));
}

public ClientSetupResult() :
this("")
{ }

public string ConnectionAddress { get; }

public object? TeardownToken { get; set; }
Expand All @@ -16,4 +20,9 @@ public ClientSetupResult(string address)
public string? MachineId { get; set; }

public bool SupportsFdPassing { get; set; }

// SupportsFdPassing and ConnectionAddress are ignored when this is set.
// The implementation assumes that it is safe to Dispose the Stream
// while there are on-going reads/writes, and that these on-going operations will be aborted.
public Stream? ConnectionStream { get; set; }
}
9 changes: 8 additions & 1 deletion src/Tmds.DBus.Protocol/Connection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,14 @@ private async Task<DBusConnection> DoConnectAsync()
_setupResult = await _connectionOptions.SetupAsync(_connectCts.Token).ConfigureAwait(false);
connection = _connection = new DBusConnection(this, _setupResult.MachineId ?? DBusEnvironment.MachineId);

await connection.ConnectAsync(_setupResult.ConnectionAddress, _setupResult.UserId, _setupResult.SupportsFdPassing, _connectCts.Token).ConfigureAwait(false);
if (_setupResult.ConnectionStream is Stream stream)
{
await connection.ConnectAsync(stream, _setupResult.UserId, _connectCts.Token).ConfigureAwait(false);
}
else
{
await connection.ConnectAsync(_setupResult.ConnectionAddress, _setupResult.UserId, _setupResult.SupportsFdPassing, _connectCts.Token).ConfigureAwait(false);
}

lock (_gate)
{
Expand Down
34 changes: 34 additions & 0 deletions src/Tmds.DBus.Protocol/DBusConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,40 @@ public async ValueTask ConnectAsync(string address, string? userId, bool support
throw new ArgumentException("No addresses were found", nameof(address));
}

public async ValueTask ConnectAsync(Stream stream, string? userId, CancellationToken cancellationToken)
{
_state = ConnectionState.Connecting;

try
{
MessageStream messageStream = new MessageStream(stream);
_messageStream = messageStream;

await messageStream.DoClientAuthAsync(default(Guid), userId, supportsFdPassing: false).ConfigureAwait(false);

messageStream.ReceiveMessages(
static (Exception? exception, Message message, DBusConnection connection) =>
connection.HandleMessages(exception, message), this);

lock (_gate)
{
if (_state != ConnectionState.Connecting)
{
throw new DisconnectedException(DisconnectReason);
}
_state = ConnectionState.Connected;
}

_localName = await GetLocalNameAsync().ConfigureAwait(false);
}
catch
{
stream.Dispose();

throw;
}
}

private async Task<string?> GetLocalNameAsync()
{
MyValueTaskSource<string?> vts = new();
Expand Down
45 changes: 39 additions & 6 deletions src/Tmds.DBus.Protocol/MessageStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ namespace Tmds.DBus.Protocol;
class MessageStream : IMessageStream
{
private static readonly ReadOnlyMemory<byte> OneByteArray = new[] { (byte)0 };
private readonly Socket _socket;
private readonly Socket? _socket;
private readonly Stream _stream;
private UnixFdCollection? _fdCollection;
private bool _supportsFdPassing;
private readonly MessagePool _messagePool;
Expand All @@ -26,8 +27,18 @@ class MessageStream : IMessageStream
private bool _isMonitor;

public MessageStream(Socket socket)
: this(socket, null)
{}

public MessageStream(Stream stream)
: this(null, stream)
{}

private MessageStream(Socket? socket, Stream? stream)
{
_socket = socket;
_stream = stream ?? new NetworkStream(_socket!, ownsSocket: true);

Channel<MessageBuffer> channel = Channel.CreateUnbounded<MessageBuffer>(new UnboundedChannelOptions
{
AllowSynchronousContinuations = true,
Expand Down Expand Up @@ -56,7 +67,16 @@ private async void ReadFromSocketIntoPipe()
while (true)
{
Memory<byte> memory = writer.GetMemory(1024);
int bytesRead = await _socket.ReceiveAsync(memory, _fdCollection).ConfigureAwait(false);
int bytesRead;
if (_socket is not null)
{
bytesRead = await _socket.ReceiveAsync(memory, _fdCollection).ConfigureAwait(false);
}
else
{
bytesRead = await _stream.ReadAsync(memory).ConfigureAwait(false);
}

if (bytesRead == 0)
{
throw new IOException("Connection closed by peer");
Expand Down Expand Up @@ -89,14 +109,14 @@ private async void ReadMessagesIntoSocket()
var buffer = message.AsReadOnlySequence();
if (buffer.IsSingleSegment)
{
await _socket.SendAsync(buffer.First, handles).ConfigureAwait(false);
await WriteAsync(buffer.First, handles).ConfigureAwait(false);
}
else
{
SequencePosition position = buffer.Start;
while (buffer.TryGet(ref position, out ReadOnlyMemory<byte> memory))
{
await _socket.SendAsync(memory, handles).ConfigureAwait(false);
await WriteAsync(memory, handles).ConfigureAwait(false);
handles = null;
}
}
Expand All @@ -113,6 +133,18 @@ private async void ReadMessagesIntoSocket()
}
}

private ValueTask WriteAsync(ReadOnlyMemory<byte> memory, UnixFdCollection? handles)
{
if (_socket is not null)
{
return _socket.SendAsync(memory, handles);
}
else
{
return _stream.WriteAsync(memory);
}
}

public async void ReceiveMessages<T>(IMessageStream.MessageReceivedHandler<T> handler, T state)
{
var reader = _pipeReader;
Expand Down Expand Up @@ -165,7 +197,7 @@ public async ValueTask DoClientAuthAsync(Guid guid, string? userId, bool support
ReadFromSocketIntoPipe();

// send 1 byte
await _socket.SendAsync(OneByteArray, SocketFlags.None).ConfigureAwait(false);
await _stream.WriteAsync(OneByteArray).ConfigureAwait(false);
// auth
var authenticationResult = await SendAuthCommandsAsync(userId, supportsFdPassing).ConfigureAwait(false);
_supportsFdPassing = authenticationResult.SupportsFdPassing;
Expand Down Expand Up @@ -334,7 +366,7 @@ private async ValueTask WriteAsync(string message, Memory<byte> lineBuffer)
{
int length = Encoding.ASCII.GetBytes(message.AsSpan(), lineBuffer.Span);
lineBuffer = lineBuffer.Slice(0, length);
await _socket.SendAsync(lineBuffer, SocketFlags.None).ConfigureAwait(false);
await _stream.WriteAsync(lineBuffer).ConfigureAwait(false);
}

private async ValueTask<int> ReadLineAsync(Memory<byte> lineBuffer)
Expand Down Expand Up @@ -396,6 +428,7 @@ private Exception CloseCore(Exception closeReason)
if (previous is null)
{
_socket?.Dispose();
_stream.Dispose();
_messageWriter.Complete();
}
return previous ?? closeReason;
Expand Down
86 changes: 86 additions & 0 deletions src/Tmds.DBus.Protocol/Polyfill/StreamExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

#if NETSTANDARD2_0

namespace System.IO
{
internal static class StreamExtensions
{
public static ValueTask<int> ReadAsync(this Stream stream, Memory<byte> buffer, CancellationToken cancellationToken = default)
{
if (MemoryMarshal.TryGetArray(buffer, out ArraySegment<byte> array))
{
return new ValueTask<int>(stream.ReadAsync(array.Array, array.Offset, array.Count, cancellationToken));
}
else
{
byte[] sharedBuffer = ArrayPool<byte>.Shared.Rent(buffer.Length);
return FinishReadAsync(stream.ReadAsync(sharedBuffer, 0, buffer.Length, cancellationToken), sharedBuffer, buffer);

static async ValueTask<int> FinishReadAsync(Task<int> readTask, byte[] localBuffer, Memory<byte> localDestination)
{
try
{
int result = await readTask.ConfigureAwait(false);
new Span<byte>(localBuffer, 0, result).CopyTo(localDestination.Span);
return result;
}
finally
{
ArrayPool<byte>.Shared.Return(localBuffer);
}
}
}
}

public static void Write(this Stream stream, ReadOnlyMemory<byte> buffer)
{
if (MemoryMarshal.TryGetArray(buffer, out ArraySegment<byte> array))
{
stream.Write(array.Array, array.Offset, array.Count);
}
else
{
byte[] sharedBuffer = ArrayPool<byte>.Shared.Rent(buffer.Length);
try
{
buffer.Span.CopyTo(sharedBuffer);
stream.Write(sharedBuffer, 0, buffer.Length);
}
finally
{
ArrayPool<byte>.Shared.Return(sharedBuffer);
}
}
}

public static ValueTask WriteAsync(this Stream stream, ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
{
if (MemoryMarshal.TryGetArray(buffer, out ArraySegment<byte> array))
{
return new ValueTask(stream.WriteAsync(array.Array, array.Offset, array.Count, cancellationToken));
}
else
{
byte[] sharedBuffer = ArrayPool<byte>.Shared.Rent(buffer.Length);
buffer.Span.CopyTo(sharedBuffer);
return new ValueTask(FinishWriteAsync(stream.WriteAsync(sharedBuffer, 0, buffer.Length, cancellationToken), sharedBuffer));
}
}

private static async Task FinishWriteAsync(Task writeTask, byte[] localBuffer)
{
try
{
await writeTask.ConfigureAwait(false);
}
finally
{
ArrayPool<byte>.Shared.Return(localBuffer);
}
}
}
}

#endif
10 changes: 8 additions & 2 deletions test/Tmds.DBus.Protocol.Tests/DBusDaemon.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,14 @@ public void Dispose()
File.Delete(_configFile);
}
_state = State.Disposed;
_process?.Kill();
_process?.Dispose();

try
{
_process?.Kill();
_process?.Dispose();
}
catch
{ }
}

public Task StartAsync(DBusDaemonProtocol protocol = DBusDaemonProtocol.Default, string? socketPath = null)
Expand Down
53 changes: 53 additions & 0 deletions test/Tmds.DBus.Protocol.Tests/TransportTests.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
using System;
using System.IO;
using System.Net.Sockets;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Win32.SafeHandles;
using Xunit;
Expand Down Expand Up @@ -29,6 +32,44 @@ public async Task TransportAsync(DBusDaemonProtocol protocol)
}
}

[Fact]
public async Task ConnectionStream()
{
var tokenTcs = new TaskCompletionSource<object?>();
var token = new object();
using (var dbusDaemon = new DBusDaemon())
{
await dbusDaemon.StartAsync(DBusDaemonProtocol.Unix);
var address = dbusDaemon.Address!;

Socket socket = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified);
Assert.StartsWith("unix:path=", address);
string path = address.Substring(10);
path = path.Substring(0, path.IndexOf(',')); // strip ',guid=...'
await socket.ConnectAsync(new UnixDomainSocketEndPoint(path));

var connection = new Connection(new MyConnectionOptions
{
ConnectFunction = () => ValueTask.FromResult(
new ClientSetupResult()
{
TeardownToken = token,
ConnectionStream = new NetworkStream(socket, ownsSocket: true)
}),
DisposeAction = o => tokenTcs.SetResult(o)
});

await connection.ConnectAsync();
Assert.True(socket.Connected);
Assert.StartsWith(":", connection.UniqueName);

connection.Dispose();
Assert.False(socket.Connected); // The ConnectionStream was disposed
var disposeToken = await tokenTcs.Task;
Assert.Equal(token, disposeToken);
}
}

[Fact]
public async Task TryMultipleAddressesAsync()
{
Expand Down Expand Up @@ -143,5 +184,17 @@ private ValueTask ReceiveHandle(MethodContext context)
return default;
}
}

private class MyConnectionOptions : ClientConnectionOptions
{
public required Func<ValueTask<ClientSetupResult>> ConnectFunction { get; set; }
public required Action<object?> DisposeAction { get; set; }

protected internal override ValueTask<ClientSetupResult> SetupAsync(CancellationToken cancellationToken)
=> ConnectFunction();

protected internal override void Teardown(object? token)
=> DisposeAction(token);
}
}
}

0 comments on commit db7af5a

Please sign in to comment.