diff --git a/Project.toml b/Project.toml index de89f71..8a061db 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LibSSH" uuid = "00483490-30f8-4353-8aba-35b82f51f4d0" authors = ["James Wrigley and contributors"] -version = "0.1.0" +version = "0.2.0" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" diff --git a/docs/src/changelog.md b/docs/src/changelog.md index 2d66b4f..885b8ee 100644 --- a/docs/src/changelog.md +++ b/docs/src/changelog.md @@ -3,6 +3,19 @@ This documents notable changes in LibSSH.jl. The format is based on [Keep a Changelog](https://keepachangelog.com). +## [v0.2.0] - 2024-02-01 + +### Changed + +- The [Command execution](@ref) API was completely rewritten to match Julia's + API ([#2]). This is a breaking change, any code using the old `ssh.execute()` + will need to be rewritten. + +### Fixed + +- A cause of segfaults was fixed by storing callbacks properly, so they don't get + garbage collected accidentally ([#2]). + ## [v0.1.0] - 2024-01-29 The initial release 🎉 ✨ diff --git a/docs/src/examples.jl b/docs/src/examples.jl index a2b71e9..99219ff 100644 --- a/docs/src/examples.jl +++ b/docs/src/examples.jl @@ -73,7 +73,7 @@ ssh.userauth_list(session) # Now we're authenticated to the server and we can actually do something, like # running a command: -ssh.execute(session, "echo 'Hello world!'") +@assert read(`echo 'Hello world!'`, session, String) == "Hello world!\n" # What we get back is a tuple of the return code and the output from the # command. diff --git a/docs/src/sessions_and_channels.md b/docs/src/sessions_and_channels.md index b1d7049..4b7b3ac 100644 --- a/docs/src/sessions_and_channels.md +++ b/docs/src/sessions_and_channels.md @@ -98,8 +98,19 @@ You should prefer using these instead of more low-level methods, if you can. #### Command execution +LibSSH.jl attempts to mimic Julia's API for running local commands with `run()` +etc. But some features are not supported and we attempt to document all of the +differences. + ```@docs -execute +SshProcessFailedException +SshProcess +Base.wait(::SshProcess) +Base.success(::SshProcess) +Base.run(::Cmd, ::Session) +Base.read(::Cmd, ::Session) +Base.read(::Cmd, ::Session, ::Type{String}) +Base.success(::Cmd, ::Session) ``` #### Direct port forwarding diff --git a/src/LibSSH.jl b/src/LibSSH.jl index 68c16d6..bb35149 100644 --- a/src/LibSSH.jl +++ b/src/LibSSH.jl @@ -158,9 +158,27 @@ function lib_version() VersionNumber(lib.LIBSSH_VERSION_MAJOR, lib.LIBSSH_VERSION_MINOR, lib.LIBSSH_VERSION_MICRO) end +# Safe wrapper around poll_fd(). There's a race condition in older Julia +# versions between the loop condition evaluation and this line, so we wrap +# poll_fd() in a try-catch in case the bind (and thus the file descriptor) has +# been closed in the meantime, which would cause poll_fd() to throw an IOError: +# https://github.com/JuliaLang/julia/pull/52377 +function _safe_poll_fd(args...; kwargs...) + result = nothing + try + result = FileWatching.poll_fd(args...; kwargs...) + catch ex + if !(ex isa Base.IOError) + rethrow() + end + end + + return result +end + include("pki.jl") -include("session.jl") include("callbacks.jl") +include("session.jl") include("channel.jl") include("message.jl") include("server.jl") diff --git a/src/channel.jl b/src/channel.jl index e280b0e..9d3615a 100644 --- a/src/channel.jl +++ b/src/channel.jl @@ -20,6 +20,7 @@ mutable struct SshChannel session::Union{Session, Nothing} close_lock::ReentrantLock local_eof::Bool + callbacks::Union{Callbacks.ChannelCallbacks, Nothing} @doc """ $(TYPEDSIGNATURES) @@ -36,7 +37,7 @@ mutable struct SshChannel elseif own && !isnothing(session) && !session.owning throw(ArgumentError("Cannot create a SshChannel from a non-owning Session")) end - self = new(ptr, own, session, ReentrantLock(), false) + self = new(ptr, own, session, ReentrantLock(), false, nothing) if own push!(session.channels, self) @@ -267,6 +268,7 @@ function set_channel_callbacks(sshchan::SshChannel, callbacks::Callbacks.Channel if ret != SSH_OK throw(LibSSHException("Error when setting channel callbacks: $(ret)")) end + sshchan.callbacks = callbacks end """ @@ -343,101 +345,252 @@ end ## execute() -function _log(msg, userdata) - if userdata[:verbose] +function _log(msg, process) + if process._verbose @info "execute(): $(msg)" end end -function _on_channel_data(session, sshchan, data, is_stderr, userdata) +function _on_channel_data(session, sshchan, data, is_stderr, process) is_stderr = Bool(is_stderr) fd_msg = is_stderr ? "stderr" : "stdout" - _log("channel_data $(length(data)) bytes from $fd_msg", userdata) + _log("channel_data $(length(data)) bytes from $fd_msg", process) - put!(userdata[:channel], copy(data)) + append!(is_stderr ? process.err : process.out, data) return length(data) end -function _on_channel_eof(session, sshchan, userdata) - _log("channel_eof", userdata) +function _on_channel_eof(session, sshchan, process) + _log("channel_eof", process) end -function _on_channel_close(session, sshchan, userdata) - _log("channel_close", userdata) +function _on_channel_close(session, sshchan, process) + _log("channel_close", process) end -function _on_channel_exit_status(session, sshchan, ret, userdata) - _log("exit_status $ret", userdata) - userdata[:exit_code] = Int(ret) +function _on_channel_exit_status(session, sshchan, ret, process) + _log("exit_status $ret", process) + process.exitcode = Int(ret) end +""" +$(TYPEDEF) +$(TYPEDFIELDS) + +This is analogous to `Base.Process`, it represents a command running over an +SSH session. The stdout and stderr output are stored as byte arrays in +`SshProcess.out` and `SshProcess.err` respectively. They can be converted to +strings using e.g. `String(process.out)`. +""" +@kwdef mutable struct SshProcess + out::Vector{UInt8} = Vector{UInt8}() + err::Vector{UInt8} = Vector{UInt8}() + + cmd::Union{Cmd, Nothing} = nothing + exitcode::Int = typemin(Int) + + _sshchan::Union{SshChannel, Nothing} = nothing + _task::Union{Task, Nothing} = nothing + _verbose::Bool = false +end + +function Base.show(io::IO, process::SshProcess) + status = process_running(process) ? "ProcessRunning" : "ProcessExited($(process.exitcode))" + print(io, SshProcess, "(cmd=$(process.cmd), $status)") +end + +Base.process_running(process::SshProcess) = !istaskdone(process._task) +Base.process_exited(process::SshProcess) = istaskdone(process._task) + +""" +$(TYPEDSIGNATURES) + +Check if the process succeeded. +""" +Base.success(process::SshProcess) = process_exited(process) && process.exitcode == 0 + """ $(TYPEDSIGNATURES) -Execute `command` remotely. This will return a tuple of -`(return_code::Union{Int, Nothing}, output::String)`. The `return_code` may be -`nothing` if it wasn't sent by the server (which would point to an incorrect -server implementation). +# Throws +- [`SshProcessFailedException`](@ref): if `ignorestatus()` wasn't used. """ -function execute(session::Session, command::AbstractString; verbose=false) - userdata = Dict{Symbol, Any}(:channel => Channel(), - :exit_code => nothing, - :verbose => verbose) - callbacks = Callbacks.ChannelCallbacks(userdata; +function Base.wait(process::SshProcess) + try + wait(process._task) + catch ex + if !process.cmd.ignorestatus + rethrow() + end + end +end + +""" +$(TYPEDEF) +$(TYPEDFIELDS) + +This is analogous to `ProcessFailedException`. +""" +struct SshProcessFailedException <: Exception + process::SshProcess +end + +function _exec_command(process::SshProcess) + sshchan = process._sshchan + session = sshchan.session + cmd_str = join(process.cmd.exec, " ") + + # Open the session channel + ret = _session_trywait(session) do + lib.ssh_channel_open_session(sshchan.ptr) + end + if ret != SSH_OK + throw(LibSSHException("Failed to open a session channel: $(ret)")) + end + + # Make the request + ret = _session_trywait(session) do + GC.@preserve cmd_str begin + lib.ssh_channel_request_exec(sshchan.ptr, Base.unsafe_convert(Ptr{Cchar}, cmd_str)) + end + end + if ret != SSH_OK + err = get_error(session) + throw(LibSSHException("Error from lib.ssh_channel_request_exec, could not execute command: $(err)")) + end + + # Wait for data to be read + ret = poll_loop(sshchan) + + # Close the channel + lib.ssh_channel_send_eof(sshchan.ptr) + close(sshchan) + + # Check the result of the read for an error + if ret == SSH_ERROR + throw(LibSSHException("Error while reading data from channel: $(ret)")) + end + + if !process.cmd.ignorestatus && process.exitcode != 0 + throw(SshProcessFailedException(process)) + end +end + +""" +$(TYPEDSIGNATURES) + +Run a command on the remote host over an SSH session. Things that aren't +supported compared to `run()`: +- Pipelined commands (use a regular pipe like `foo | bar` instead). +- Setting the directory to execute the command in. +- Setting environment variables (support is possible, it just hasn't been + implemented yet). + +# Throws +- [`SshProcessFailedException`](@ref): if the command fails and `ignorestatus()` + wasn't used. + +# Arguments +- `cmd`: The command to run. This will be converted to a string for running + remotely. +- `session`: The session to run the command over. +- `wait=true`: Wait for the command to finish before returning. +- `verbose=false`: Print debug logging messages. Note that this is not the same + as setting the `log_verbosity` on a [`Session`](@ref). +- `combine_outputs=true`: Write the `stderr` command output to the `IOBuffer` + for the commands `stdout`. If this is `true` then `SshProcess.out` and + `SshProcess.err` will refer to the same object. +- `print_out=true`: Print the output (stdout + stderr by default) of the + command. + +# Examples +```julia-repl +julia> import LibSSH as ssh + +julia> ssh.Demo.DemoServer(2222; password="foo") do + session = ssh.Session("127.0.0.1", 2222) + @assert ssh.userauth_password(session, "foo") == ssh.AuthStatus_Success + + @info "1" + run(`echo foo`, session) + + println() + @info "2" + run(ignorestatus(`foo`), session) + end +[ Info: 1 +foo + +[ Info: 2 +sh: line 1: foo: command not found +``` +""" +function Base.run(cmd::Cmd, session::Session; + wait::Bool=true, verbose::Bool=false, + combine_outputs::Bool=true, print_out::Bool=true) + process = SshProcess(; cmd, _verbose=verbose) + if combine_outputs + process.err = process.out + end + + callbacks = Callbacks.ChannelCallbacks(process; on_eof=_on_channel_eof, on_close=_on_channel_close, on_data=_on_channel_data, on_exit_status=_on_channel_exit_status) + process._sshchan = SshChannel(session) + set_channel_callbacks(process._sshchan, callbacks) - SshChannel(session) do sshchan - set_channel_callbacks(sshchan, callbacks) + process._task = Threads.@spawn _exec_command(process) + if wait + # Note the use of Base.wait() to avoid aliasing with the `wait` argument + Base.wait(process._task) - # Open the session - ret = _session_trywait(session) do - lib.ssh_channel_open_session(sshchan.ptr) - end - if ret != SSH_OK - throw(LibSSHException("Failed to open a session channel: $(ret)")) + if print_out + print(String(process.out)) end + end - # Make the request - ret = _session_trywait(session) do - GC.@preserve command begin - lib.ssh_channel_request_exec(sshchan.ptr, Base.unsafe_convert(Ptr{Cchar}, command)) - end - end - if ret != SSH_OK - err = get_error(session) - throw(LibSSHException("Error from channel_request_exec, could not execute command: $(err)")) - end + return process +end - # Start a task to read incoming data and append it to a vector - cmd_output = String[] - reader_task = Threads.@spawn for data in userdata[:channel] - try - push!(cmd_output, String(data)) - catch ex - @error "Error handling command output" exception=(ex, catch_backtrace()) - end - end +""" +$(TYPEDSIGNATURES) - # Wait for data to be read - ret = poll_loop(sshchan) +Read the output from the command in bytes. +""" +function Base.read(cmd::Cmd, session::Session) + process = run(cmd, session; print_out=false) + return process.out +end - # Close the reader task and send an EOF - close(userdata[:channel]) - wait(reader_task) - lib.ssh_channel_send_eof(sshchan.ptr) +""" +$(TYPEDSIGNATURES) - # Check the result of the read for an error - if ret == SSH_ERROR - throw(LibSSHException("Error while reading data from channel: $(ret)")) - end +Read the output from the command as a String. - return (userdata[:exit_code]::Union{Int, Nothing}, string(cmd_output...)) - end -end +# Examples +```julia-repl +julia> import LibSSH as ssh + +julia> ssh.Demo.DemoServer(2222; password="foo") do + session = ssh.Session("127.0.0.1", 2222) + @assert ssh.userauth_password(session, "foo") == ssh.AuthStatus_Success + + @show read(`echo foo`, session, String) + end +read(`echo foo`, session, String) = "foo\\n" +``` +""" +Base.read(cmd::Cmd, session::Session, ::Type{String}) = String(read(cmd, session)) + +""" +$(TYPEDSIGNATURES) + +Check the command succeeded. +""" +Base.success(cmd::Cmd, session::Session) = success(run(cmd, session; print_out=false)) ## Direct port forwarding diff --git a/src/server.jl b/src/server.jl index 3017493..d040f12 100644 --- a/src/server.jl +++ b/src/server.jl @@ -193,6 +193,10 @@ mutable struct Bind end end +function Base.show(io::IO, bind::Bind) + print(io, Bind, "(addr=$(bind.addr), port=$(bind.port))") +end + """ $(TYPEDSIGNATURES) @@ -334,8 +338,8 @@ doesn't matter much. It'll only control how frequently the listen loop wakes up to check if the bind has been closed yet. """ function listen(handler::Function, bind::Bind; poll_timeout=0.1) - if poll_timeout < 0 - throw(ArgumentError("poll_timeout cannot be negative!")) + if poll_timeout <= 0 + throw(ArgumentError("poll_timeout=$(poll_timeout), it must be greater than 0")) end ret = lib.ssh_bind_listen(bind.ptr) @@ -357,20 +361,11 @@ function listen(handler::Function, bind::Bind; poll_timeout=0.1) notify(bind._listener_event) end - # Wait for new connection attempts. Note that there's a race condition - # between the loop condition evaluation and this line, so we wrap - # poll_fd() in a try-catch in case the bind (and thus the file - # descriptor) has been closed in the meantime, which would cause - # poll_fd() to throw an IOError. - local poll_result - try - poll_result = FileWatching.poll_fd(fd, poll_timeout; readable=true) - catch ex - if ex isa Base.IOError - continue - else - rethrow() - end + poll_result = _safe_poll_fd(fd, poll_timeout; readable=true) + if isnothing(poll_result) + # This means the session's file descriptor has been closed (see the + # comments for _safe_poll_fd()). + continue end # The first thing we do is check if the Bind has been closed, because @@ -473,6 +468,7 @@ function set_server_callbacks(session::Session, callbacks::ServerCallbacks) if ret != SSH_OK throw(LibSSHException("Error setting server callbacks: $(ret)")) end + session.server_callbacks = callbacks end """ @@ -679,7 +675,6 @@ end authenticated::Bool = false session_event::Union{ssh.SessionEvent, Nothing} = nothing - server_callbacks::ServerCallbacks = ServerCallbacks() channel_callbacks::ChannelCallbacks = ChannelCallbacks() unclaimed_channels::Vector{ssh.SshChannel} = ssh.SshChannel[] channel_operations::Vector{Any} = [] @@ -697,6 +692,7 @@ function Base.close(client::Client) end close(client.session_event) + close(client.session) wait(client.task) end @@ -714,6 +710,10 @@ $(TYPEDFIELDS) clients::Vector{Client} = Client[] end +function Base.show(io::IO, ds::DemoServer) + print(io, DemoServer, "(bind.port=$(ds.bind.port))") +end + """ $(TYPEDSIGNATURES) @@ -824,11 +824,11 @@ function _handle_client(session::ssh.Session, ds::DemoServer) session, password=ds.password, verbose=ds.verbose) - client.server_callbacks = ServerCallbacks(client; - on_auth_password=on_auth_password, - on_auth_none=on_auth_none, - on_service_request=on_service_request, - on_channel_open_request_session=on_channel_open) + server_callbacks = ServerCallbacks(client; + on_auth_password=on_auth_password, + on_auth_none=on_auth_none, + on_service_request=on_service_request, + on_channel_open_request_session=on_channel_open) client.channel_callbacks = ChannelCallbacks(client; on_eof=on_channel_eof, on_close=on_channel_close, @@ -837,7 +837,7 @@ function _handle_client(session::ssh.Session, ds::DemoServer) on_env_request=on_channel_env_request) client.task = current_task() - ssh.set_server_callbacks(session, client.server_callbacks) + ssh.set_server_callbacks(session, server_callbacks) if !ssh.handle_key_exchange(session) @error "Key exchange failed" return @@ -877,12 +877,13 @@ $(TYPEDSIGNATURES) Stop a [`DemoServer`](@ref). """ function stop(demo_server::DemoServer) - for client in demo_server.clients - close(client) - end - if !isnothing(demo_server.listener_task) close(demo_server.bind) + + for client in demo_server.clients + close(client) + end + wait(demo_server.listener_task) demo_server.listener_task = nothing end @@ -1011,19 +1012,18 @@ end @kwdef mutable struct Forwarder client::Client sshchan::ssh.SshChannel - channel_callbacks::ChannelCallbacks = ChannelCallbacks() socket::Sockets.TCPSocket = Sockets.TCPSocket() task::Union{Task, Nothing} = nothing end function Forwarder(client::Client, sshchan::ssh.SshChannel, hostname::String, port::Integer) self = Forwarder(; client, sshchan) - self.channel_callbacks = ChannelCallbacks(self; - on_eof=on_fwd_channel_eof, - on_close=on_fwd_channel_close, - on_data=on_fwd_channel_data, - on_exit_status=on_fwd_channel_exit_status) - ssh.set_channel_callbacks(sshchan, self.channel_callbacks) + channel_callbacks = ChannelCallbacks(self; + on_eof=on_fwd_channel_eof, + on_close=on_fwd_channel_close, + on_data=on_fwd_channel_data, + on_exit_status=on_fwd_channel_exit_status) + ssh.set_channel_callbacks(sshchan, channel_callbacks) # Set up the listener socket. Restrict ourselves to IPv4 for simplicity # since the test HTTP servers bind to the IPv4 loopback interface (and diff --git a/src/session.jl b/src/session.jl index 88b3119..6f09d2e 100644 --- a/src/session.jl +++ b/src/session.jl @@ -12,6 +12,7 @@ mutable struct Session owning::Bool log_verbosity::Int channels::Vector{Any} + server_callbacks::Union{Callbacks.ServerCallbacks, Nothing} @doc """ $(TYPEDSIGNATURES) @@ -33,7 +34,7 @@ mutable struct Session # Set to non-blocking mode lib.ssh_set_blocking(ptr, 0) - session = new(ptr, own, -1, []) + session = new(ptr, own, -1, [], nothing) if !isnothing(log_verbosity) session.log_verbosity = log_verbosity end @@ -46,6 +47,14 @@ mutable struct Session end end +function Base.show(io::IO, session::Session) + if isopen(session) + print(io, Session, "(host=$(session.host), port=$(session.port), user=$(session.user), connected=$(isconnected(session)))") + else + print(io, Session, "()") + end +end + # Non-throwing finalizer for Session objects function _finalizer(session::Session) try @@ -179,7 +188,7 @@ const SESSION_PROPERTY_OPTIONS = Dict(:host => (SSH_OPTIONS_HOST, Cstring), const SAVED_PROPERTIES = (:log_verbosity,) function Base.propertynames(::Session, private::Bool=false) - (:host, :port, :user, :log_verbosity, :owning, (private ? (:ptr, :channels) : ())...) + (:host, :port, :user, :log_verbosity, :owning, (private ? (:ptr, :channels, :server_callbacks) : ())...) end function Base.getproperty(session::Session, name::Symbol) @@ -229,7 +238,7 @@ function Base.setproperty!(session::Session, name::Symbol, value) error("type Session has no field $(name)") end - if name == :ptr + if name == :ptr || name == :server_callbacks return setfield!(session, name, value) end @@ -272,8 +281,15 @@ $(TYPEDSIGNATURES) Waits for a session in non-blocking mode. If the session is in blocking mode the function will return immediately. + +The `poll_timeout` argument has the same meaning as [`listen(::Function, +::Bind)`](@ref). """ -function Base.wait(session::Session) +function Base.wait(session::Session; poll_timeout=0.1) + if poll_timeout <= 0 + throw(ArgumentError("poll_timeout=$(poll_timeout), it must be greater than 0")) + end + if lib.ssh_is_blocking(session.ptr) == 1 return end @@ -283,7 +299,16 @@ function Base.wait(session::Session) writable = (poll_flags & lib.SSH_WRITE_PENDING) > 0 fd = RawFD(lib.ssh_get_fd(session.ptr)) - FileWatching.poll_fd(fd; readable, writable) + while isopen(session) + result = _safe_poll_fd(fd, poll_timeout; readable, writable) + if isnothing(result) + # This means the session's file descriptor has been closed (see the + # comments for _safe_poll_fd()). + continue + elseif !result.timedout + break + end + end return nothing end diff --git a/test/LibSSHTests.jl b/test/LibSSHTests.jl index 48743da..d3eeb7f 100644 --- a/test/LibSSHTests.jl +++ b/test/LibSSHTests.jl @@ -182,6 +182,12 @@ end end @test length(demo_server.clients) == 2 end + + # Test that the DemoServer cleans up lingering sessions + server_task = Threads.@spawn DemoServer(2222; timeout=10) do + session = ssh.Session("127.0.0.1", 2222) + end + @test timedwait(() -> istaskdone(server_task), 5) == :ok end @testset "Session" begin @@ -319,18 +325,20 @@ end end @testset "Executing commands" begin - # Test executing commands - demo_server_with_session(2222) do session - ret, output = ssh.execute(session, "whoami") - @test ret == 0 - @test strip(output) == username() - end - - # Check that we read stderr as well as stdout demo_server_with_session(2222) do session - ret, output = ssh.execute(session, "thisdoesntexist") - @test ret == 127 - @test !isempty(output) + # Smoke test + process = run(`whoami`, session; print_out=false) + @test success(process) + @test chomp(String(process.out)) == username() + + # Check that we read stderr as well as stdout + process = run(ignorestatus(`thisdoesntexist`), session; print_out=false) + @test process.exitcode == 127 + @test !isempty(String(process.out)) + + # Test Base methods + @test read(`echo foo`, session, String) == "foo\n" + @test success(`whoami`, session) end end