Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow with-style streaming APIs for communicating with a running program #71

Merged
merged 6 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Sources/Citadel/Errors.swift
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,4 @@ public enum CitadelError: Error {
case channelFailure
}

public struct AuthenticationFailed: Error, Equatable {}
public struct AuthenticationFailed: Error, Equatable {}
154 changes: 141 additions & 13 deletions Sources/Citadel/TTY/Client/TTY.swift
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,55 @@ public enum ExecCommandOutput {
case stderr(ByteBuffer)
}

struct EmptySequence<Element>: Sendable, AsyncSequence {
struct AsyncIterator: AsyncIteratorProtocol {
func next() async throws -> Element? {
nil
}
}

func makeAsyncIterator() -> AsyncIterator {
AsyncIterator()
}
}

@available(macOS 15.0, *)
public struct TTYOutput: AsyncSequence {
internal let sequence: AsyncThrowingStream<ExecCommandOutput, Error>
public typealias Element = ExecCommandOutput

public struct AsyncIterator: AsyncIteratorProtocol {
public typealias Element = ExecCommandOutput
var iterator: AsyncThrowingStream<ExecCommandOutput, Error>.AsyncIterator

public mutating func next() async throws -> ExecCommandOutput? {
try await iterator.next()
}
}

public func makeAsyncIterator() -> AsyncIterator {
AsyncIterator(iterator: sequence.makeAsyncIterator())
}
}

public struct TTYStdinWriter {
internal let channel: Channel

public func write(_ buffer: ByteBuffer) async throws {
try await channel.writeAndFlush(SSHChannelData(type: .channel, data: .byteBuffer(buffer)))
}

public func changeSize(cols: Int, rows: Int) async throws {
try await channel.triggerUserOutboundEvent(
SSHChannelRequestEvent.WindowChangeRequest(
terminalCharacterWidth: 0,
terminalRowHeight: 0,
terminalPixelWidth: 0,
terminalPixelHeight: 0
)
)
}
}

final class ExecCommandHandler: ChannelDuplexHandler {
enum Output {
Expand Down Expand Up @@ -126,7 +175,12 @@ extension SSHClient {
/// - maxResponseSize: The maximum size of the response. If the response is larger, the command will fail.
/// - mergeStreams: If the answer should also include stderr.
/// - inShell: Whether to request the remote server to start a shell before executing the command.
public func executeCommand(_ command: String, maxResponseSize: Int = .max, mergeStreams: Bool = false, inShell: Bool = false) async throws -> ByteBuffer {
public func executeCommand(
_ command: String,
maxResponseSize: Int = .max,
mergeStreams: Bool = false,
inShell: Bool = false
) async throws -> ByteBuffer {
var result = ByteBuffer()
let stream = try await executeCommandStream(command, inShell: inShell)

Expand Down Expand Up @@ -156,12 +210,27 @@ extension SSHClient {
/// - Parameters:
/// - command: The command to execute.
/// - inShell: Whether to request the remote server to start a shell before executing the command.
public func executeCommandStream(_ command: String, inShell: Bool = false) async throws -> AsyncThrowingStream<ExecCommandOutput, Error> {
var streamContinuation: AsyncThrowingStream<ExecCommandOutput, Error>.Continuation!
let stream = AsyncThrowingStream<ExecCommandOutput, Error>(bufferingPolicy: .unbounded) { continuation in
streamContinuation = continuation
}

public func executeCommandStream(
_ command: String,
environment: [SSHChannelRequestEvent.EnvironmentRequest] = [],
inShell: Bool = false
) async throws -> AsyncThrowingStream<ExecCommandOutput, Error> {
try await _executeCommandStream(
environment: environment,
mode: inShell ? .tty(command: command) : .command(command)
).output
}

enum CommandMode {
case pty(SSHChannelRequestEvent.PseudoTerminalRequest), tty(command: String?), command(String)
}

internal func _executeCommandStream(
environment: [SSHChannelRequestEvent.EnvironmentRequest] = [],
mode: CommandMode
) async throws -> (channel: Channel, output: AsyncThrowingStream<ExecCommandOutput, Error>) {
let (stream, streamContinuation) = AsyncThrowingStream<ExecCommandOutput, Error>.makeStream()

var hasReceivedChannelSuccess = false
var exitCode: Int?

Expand All @@ -180,9 +249,11 @@ extension SSHClient {
streamContinuation.finish()
}
case .channelSuccess:
if inShell, !hasReceivedChannelSuccess {
let commandData = SSHChannelData(type: .channel,
data: .byteBuffer(ByteBuffer(string: command + ";exit\n")))
if case .tty(.some(let command)) = mode, !hasReceivedChannelSuccess {
let commandData = SSHChannelData(
type: .channel,
data: .byteBuffer(ByteBuffer(string: command + ";exit\n"))
)
channel.writeAndFlush(commandData, promise: nil)
hasReceivedChannelSuccess = true
}
Expand All @@ -204,18 +275,75 @@ extension SSHClient {
return createChannel.futureResult
}.get()

if inShell {
for env in environment {
try await channel.triggerUserOutboundEvent(env)
}

switch mode {
case .pty(let request):
try await channel.triggerUserOutboundEvent(request)
fallthrough
case .tty:
try await channel.triggerUserOutboundEvent(SSHChannelRequestEvent.ShellRequest(
wantReply: true
))
} else {
case .command(let command):
try await channel.triggerUserOutboundEvent(SSHChannelRequestEvent.ExecRequest(
command: command,
wantReply: true
))
}

return stream
return (channel, stream)
}

@available(macOS 15.0, *)
public func withPTY(
_ request: SSHChannelRequestEvent.PseudoTerminalRequest,
environment: [SSHChannelRequestEvent.EnvironmentRequest] = [],
perform: (_ inbound: TTYOutput, _ outbound: TTYStdinWriter) async throws -> Void
) async throws {
let (channel, output) = try await _executeCommandStream(
environment: environment,
mode: .pty(request)
)

func close() async throws {
try await channel.close()
}

do {
let inbound = TTYOutput(sequence: output)
try await perform(inbound, TTYStdinWriter(channel: channel))
try await close()
} catch {
try await close()
throw error
}
}

@available(macOS 15.0, *)
public func withTTY(
environment: [SSHChannelRequestEvent.EnvironmentRequest] = [],
perform: (_ inbound: TTYOutput, _ outbound: TTYStdinWriter) async throws -> Void
) async throws {
let (channel, output) = try await _executeCommandStream(
environment: environment,
mode: .tty(command: nil)
)

func close() async throws {
try await channel.close()
}

do {
let inbound = TTYOutput(sequence: output)
try await perform(inbound, TTYStdinWriter(channel: channel))
try await close()
} catch {
try await close()
throw error
}
}

/// Executes a command on the remote server. This will return the pair of streams stdout and stderr of the command. If the command fails, the error will be thrown.
Expand Down
53 changes: 53 additions & 0 deletions Tests/CitadelTests/Citadel2Tests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -280,4 +280,57 @@ final class Citadel2Tests: XCTestCase {

try await client.close()
}

@available(macOS 15.0, *)
func testStdinStream() async throws {
guard
let host = ProcessInfo.processInfo.environment["SSH_HOST"],
let _port = ProcessInfo.processInfo.environment["SSH_PORT"],
let port = Int(_port),
let username = ProcessInfo.processInfo.environment["SSH_USERNAME"],
let password = ProcessInfo.processInfo.environment["SSH_PASSWORD"]
else {
throw XCTSkip()
}

let client = try await SSHClient.connect(
host: host,
port: port,
authenticationMethod: .passwordBased(username: username, password: password),
hostKeyValidator: .acceptAnything(),
reconnect: .never
)

try await client.withTTY { inbound, outbound in
try await outbound.write(ByteBuffer(string: "cat"))
try await withThrowingTaskGroup(of: Void.self) { group in
group.addTask {
var i = UInt8.min
for try await value in inbound {
switch value {
case .stdout(let value):
for byte in value.readableBytesView {
XCTAssertEqual(byte, i)
i = i &+ 1
}
case .stderr:
XCTFail("Unexpected stderr")
}
}
}

group.addTask {
for i: UInt8 in .min ... .max {
let value = ByteBufferAllocator().buffer(integer: i)
try await outbound.write(value)
}
}

try await group.next()
group.cancelAll()
}
}

try await client.close()
}
}
Loading