Skip to content

Commit

Permalink
Merge pull request #8 from JuliaWeb/fixes
Browse files Browse the repository at this point in the history
Forwarder and authentication fixes
  • Loading branch information
JamesWrigley authored Mar 10, 2024
2 parents 85c8eb5 + abc3cbf commit 83e948a
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 29 deletions.
4 changes: 4 additions & 0 deletions docs/src/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Changelog](https://keepachangelog.com).
[`Gssapi.principal_name()`](@ref) was added to get the name of the default
principal if one is available ([#6]).
- An experimental [`authenticate()`](@ref) function to simplify authenticating ([#7]).
- A do-constructor for [`Session(::Function)`](@ref) ([#8]).

### Changed

Expand All @@ -36,6 +37,9 @@ Changelog](https://keepachangelog.com).
escaped properly ([#6]).
- Fixed a bug in [`Base.run(::Cmd, ::Session)`](@ref) that would clear the
output buffer when printing ([#6]).
- Changed [`poll_loop()`](@ref) to poll the stdout and stderr streams, which
fixes a bug where callbacks would sometimes not get executed even when data
was available ([#8]).

## [v0.2.1] - 2024-02-27

Expand Down
1 change: 1 addition & 0 deletions docs/src/sessions_and_channels.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ session must be authenticated before being able to do anything with it.
Session
Session(::Union{AbstractString, Sockets.IPAddr})
Session(::lib.ssh_session)
Session(::Function)
connect
disconnect
isconnected
Expand Down
64 changes: 48 additions & 16 deletions src/channel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -332,20 +332,25 @@ function poll_loop(sshchan::SshChannel)

ret = SSH_ERROR
while true
# We always check if the channel and session are open within the loop
# because ssh_channel_poll() will execute callbacks, which could close
# them before returning.
if !isopen(sshchan)
return nothing
end
# Poll stdout and stderr
for io_stream in (0, 1)
# We always check if the channel and session are open within the loop
# because ssh_channel_poll() will execute callbacks, which could close
# them before returning.
if !isopen(sshchan)
return nothing
end

# Note that we don't actually read any data in this loop, that's
# handled by the callbacks, which are called by ssh_channel_poll().
ret = lib.ssh_channel_poll(sshchan.ptr, 0)
# Note that we don't actually read any data in this loop, that's
# handled by the callbacks, which are called by ssh_channel_poll().
ret = lib.ssh_channel_poll(sshchan.ptr, io_stream)

# Break if there was an error, or if an EOF has been sent
if ret == SSH_ERROR || ret == SSH_EOF
break
# Break if there was an error, or if an EOF has been sent. We use a
# @goto here (Knuth forgive me) to break out of the outer loop as
# well as the inner one.
if ret == SSH_ERROR || ret == SSH_EOF
@goto loop_end
end
end

if !isopen(sshchan.session)
Expand All @@ -355,6 +360,8 @@ function poll_loop(sshchan::SshChannel)
wait(sshchan.session)
end

@label loop_end

return Int(ret)
end

Expand Down Expand Up @@ -615,7 +622,11 @@ Base.success(cmd::Cmd, session::Session) = success(run(cmd, session; print_out=f
function _on_client_channel_data(session, sshchan, data, is_stderr, client)
_logcb(client, "Received $(length(data)) bytes from server")

write(client.sock, data)
if isopen(client.sock)
write(client.sock, data)
else
@warn "Client socket has been closed, dropping $(length(data)) bytes from the remote forwarded port"
end

return length(data)
end
Expand All @@ -638,12 +649,22 @@ end
# the channel and forwarding data to the server and client.
function _handle_forwarding_client(client)
# Start polling the client channel
poller = Threads.@spawn poll_loop(client.sshchan)
poller = errormonitor(Threads.@spawn poll_loop(client.sshchan))

# Read data from the socket while it's open
sock = client.sock
while isopen(sock)
data = readavailable(sock)
local data
try
# This will throw an IOError if the socket is closed during the read
data = readavailable(sock)
catch ex
if ex isa Base.IOError
continue
else
rethrow()
end
end

if !isempty(data) && isopen(client.sshchan)
write(client.sshchan, data)
Expand Down Expand Up @@ -715,6 +736,7 @@ This object manages a direct forwarding channel between `localport` and `remoteh
mutable struct Forwarder
remotehost::String
remoteport::Int
localinterface::Sockets.IPAddr
localport::Int

_listen_server::TCPServer
Expand Down Expand Up @@ -744,7 +766,7 @@ mutable struct Forwarder
verbose=false, localinterface::Sockets.IPAddr=IPv4(0))
listen_server = Sockets.listen(localinterface, localport)

self = new(remotehost, remoteport, localport,
self = new(remotehost, remoteport, localinterface, localport,
listen_server, nothing, _ForwardingClient[],
session, verbose)

Expand All @@ -759,6 +781,14 @@ mutable struct Forwarder
end
end

function Base.show(io::IO, f::Forwarder)
if !isopen(f)
print(io, Forwarder, "()")
else
print(io, Forwarder, "($(f.localinterface):$(f.localport)$(f.remotehost):$(f.remoteport))")
end
end

"""
$(TYPEDSIGNATURES)
Expand Down Expand Up @@ -792,6 +822,8 @@ function Base.close(forwarder::Forwarder)
end
end

Base.isopen(forwarder::Forwarder) = isopen(forwarder._listen_server)

# This function accepts connections on the local port and sets up
# _ForwardingClient's for them.
function _fwd_listen(forwarder::Forwarder)
Expand Down
51 changes: 38 additions & 13 deletions src/session.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ mutable struct Session
$(TYPEDSIGNATURES)
This is only useful if you already have a `ssh_session` (i.e. in a
server). Do not use it if you want a client, use the other constructor.
server). Do not use it if you want a client, use the host/port constructor.
# Arguments
- `ptr`: A pointer to the `lib.ssh_session` to wrap.
Expand Down Expand Up @@ -135,6 +135,21 @@ end
"""
$(TYPEDSIGNATURES)
Do-constructor for [`Session`](@ref). All arguments are forwarded to the other
constructors.
"""
function Session(f::Function, args...; kwargs...)
session = Session(args...; kwargs...)
try
return f(session)
finally
close(session)
end
end

"""
$(TYPEDSIGNATURES)
Check if the `Session` holds a valid pointer to a `lib.ssh_session`. This will
be `false` if the session has been closed.
"""
Expand Down Expand Up @@ -469,25 +484,25 @@ end

# Helper function to call userauth_kbdint() until we get a non-AuthStatus_Info
# response.
function _try_userauth_kbdint(session::Session; answers=nothing)
function _try_userauth_kbdint(session::Session, answers, throw_on_error)
# We keep track of when we need to start an keyboard-interactive auth
# session with the server through the _require_init_kbdint field.
if session._require_init_kbdint
userauth_kbdint(session)
userauth_kbdint(session; throw_on_error)
end

if !isnothing(answers)
userauth_kbdint_setanswers(session, answers)
end

status = userauth_kbdint(session)
status = userauth_kbdint(session; throw_on_error)
if status == AuthStatus_Info
prompts = userauth_kbdint_getprompts(session)

# If the server responds with Info but doesn't send any prompts, then we
# just keep trying until we get something different. Servers can do that.
if isempty(prompts)
return _try_userauth_kbdint(session)
return _try_userauth_kbdint(session, nothing, throw_on_error)
end
end

Expand Down Expand Up @@ -541,14 +556,18 @@ It can return either:
server. Use [`userauth_kbdint_getprompts()`](@ref) to get the prompts if
`authenticate()` returns `AuthMethod_Interactive` and then pass the answers in
the next call.
- `throw_on_error=true`: Whether to throw if there's an internal error while
authenticating (`AuthStatus_Error`).
# Throws
- `ArgumentError`: If the session isn't connected, or if both `password` and
`kbdint_answers` are passed.
- `ErrorException`: If there are no more supported authentication methods
available.
- `LibSSHException`: If there's an internal error and `throw_on_error=true`.
"""
function authenticate(session::Session; password=nothing, kbdint_answers=nothing)
function authenticate(session::Session; password=nothing, kbdint_answers=nothing,
throw_on_error=true)
if !isconnected(session)
throw(ArgumentError("Session is disconnected, cannot authenticate"))
elseif !isnothing(password) && !isnothing(kbdint_answers)
Expand All @@ -563,12 +582,18 @@ function authenticate(session::Session; password=nothing, kbdint_answers=nothing
# attempt authentication if so.
if !isnothing(password) || !isnothing(kbdint_answers)
status = if !isnothing(password)
userauth_password(session, password)
userauth_password(session, password; throw_on_error)
else
_try_userauth_kbdint(session; answers=kbdint_answers)
_try_userauth_kbdint(session, kbdint_answers, throw_on_error)
end

# For the sake of consistency we never return AuthStatus_Info to the
# caller.
if !isnothing(kbdint_answers) && status == AuthStatus_Info
status = AuthMethod_Interactive
end

return status == AuthStatus_Partial ? authenticate(session) : status
return status == AuthStatus_Partial ? authenticate(session; throw_on_error) : status
end

if isempty(session._auth_methods)
Expand All @@ -582,22 +607,22 @@ function authenticate(session::Session; password=nothing, kbdint_answers=nothing
if (_can_attempt_auth(session, AuthMethod_GSSAPI_MIC)
&& Gssapi.isavailable()
&& !isnothing(Gssapi.principal_name()))
status = userauth_gssapi(session)
status = userauth_gssapi(session; throw_on_error)

if status == AuthStatus_Denied
push!(session._attempted_auth_methods, AuthMethod_GSSAPI_MIC)

# If the ticket isn't valid but there are still other methods
# available, continue trying. Otherwise just return Denied.
if length(session._auth_methods) > 1
return authenticate(session)
return authenticate(session; throw_on_error)
else
return status
end
elseif status == AuthStatus_Partial
# If we're now partially authenticated, then we continue with some
# other method.
return authenticate(session)
return authenticate(session; throw_on_error)
else
return status
end
Expand All @@ -613,7 +638,7 @@ function authenticate(session::Session; password=nothing, kbdint_answers=nothing
# Start a keyboard-interactive session if necessary. We call this now so
# that the caller can call userauth_kbdint_getprompts() immediately.
if session._require_init_kbdint
userauth_kbdint(session)
userauth_kbdint(session; throw_on_error)
end

return AuthMethod_Interactive
Expand Down

0 comments on commit 83e948a

Please sign in to comment.