Skip to content

Commit

Permalink
Support shared sessions
Browse files Browse the repository at this point in the history
  • Loading branch information
filchristou committed Jun 21, 2023
1 parent bd58191 commit 107178b
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 37 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
ReplMaker = "b873ce64-0db9-51f5-a568-4457d8e49576"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
Sockets = "6462fe0b-24de-5631-8697-dd941f90decc"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[compat]
OpenSSH_jll = "8.1"
Expand Down
7 changes: 6 additions & 1 deletion src/RemoteREPL.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
module RemoteREPL

export connect_repl, serve_repl, @remote, connect_remote
using REPL, ReplMaker
using Sockets, Serialization
using UUIDs, Logging
using OpenSSH_jll

export connect_repl, serve_repl, @remote, @remoterepl, connect_remote

const DEFAULT_PORT = 27754
const PROTOCOL_MAGIC = "RemoteREPL"
Expand Down
57 changes: 42 additions & 15 deletions src/client.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
using ReplMaker
using REPL
using Serialization
using Sockets

struct RemoteException <: Exception
msg::String
end
Expand Down Expand Up @@ -114,7 +109,8 @@ mutable struct Connection
region::Union{AbstractString,Nothing}
namespace::Union{AbstractString,Nothing}
socket::Union{IO,Nothing}
in_module::Symbol
in_module::Union{Symbol, Expr}
session_id::UUID
end

function Connection(; host::Union{AbstractString,Sockets.IPAddr}=Sockets.localhost,
Expand All @@ -123,8 +119,11 @@ function Connection(; host::Union{AbstractString,Sockets.IPAddr}=Sockets.localho
ssh_opts::Cmd=``,
region=nothing,
namespace=nothing,
in_module::Symbol=:Main)
conn = Connection(host, port, tunnel, ssh_opts, region, namespace, nothing, in_module)
in_module::Symbol=:Main,
session_id=nothing)
sesid = isnothing(session_id) ? UUIDs.uuid4() : session_id
@info "Using session id $(sesid)"
conn = Connection(host, port, tunnel, ssh_opts, region, namespace, nothing, in_module, sesid)
setup_connection!(conn)
finalizer(close, conn)
end
Expand All @@ -139,6 +138,8 @@ function setup_connection!(conn::Connection)
tunnel=conn.tunnel, ssh_opts=conn.ssh_opts, region=conn.region,
namespace=conn.namespace)
end
# transmit session id
serialize(socket, conn.session_id)
try
verify_header(socket)
catch exc
Expand Down Expand Up @@ -440,9 +441,10 @@ function connect_repl(host=Sockets.localhost, port::Integer=DEFAULT_PORT;
region::Union{AbstractString,Nothing}=nothing,
namespace::Union{AbstractString,Nothing}=nothing,
startup_text::Bool=true,
repl=Base.active_repl)
repl=Base.active_repl,
session_id=nothing)

conn = connect_remote(host, port; tunnel, ssh_opts, region,namespace)
conn = connect_remote(host, port; tunnel, ssh_opts, region, namespace, session_id)
out_stream = stdout
prompt = ReplMaker.initrepl(c->run_remote_repl_command(conn, out_stream, c),
repl = Base.active_repl,
Expand Down Expand Up @@ -473,7 +475,8 @@ function connect_remote(host=Sockets.localhost, port::Integer=DEFAULT_PORT;
tunnel::Symbol = host!=Sockets.localhost ? :ssh : :none,
ssh_opts::Cmd=``,
region::Union{AbstractString,Nothing}=nothing,
namespace::Union{AbstractString,Nothing}=nothing)
namespace::Union{AbstractString,Nothing}=nothing,
session_id=nothing)

global _repl_client_connection

Expand All @@ -485,7 +488,7 @@ function connect_remote(host=Sockets.localhost, port::Integer=DEFAULT_PORT;
end
end
conn = RemoteREPL.Connection(host=host, port=port, tunnel=tunnel,
ssh_opts=ssh_opts, region=region, namespace=namespace)
ssh_opts=ssh_opts, region=region, namespace=namespace, session_id = session_id)

# Record the connection in a global variable so it's accessible to REPL and `@remote`
_repl_client_connection = conn
Expand Down Expand Up @@ -536,12 +539,36 @@ end

_remote_expr(conn, ex) = :(remote_eval_and_fetch($conn, $(QuoteNode(ex))))

"""
@remoterepl cmd
Parse `cmd` and appropriately handle/execute it in the remote REPL server session.
It's a superset of `@remote`, which can only evaluate expressions.
# Examples
Change the session's module evaluation settings to evaluate commands in `TempMod`
```
julia> @remoterepl "%module TempMod"
```
"""
macro remoterepl(cmd)
_remoterepl_expr(:_repl_client_connection, cmd)
end

macro remoterepl(conn, cmd)
_remoterepl_expr(esc(conn), cmd)
end

_remoterepl_expr(conn, cmd) = :(run_remote_repl_command($conn, Base.stdout, $(string(cmd))))

#--------------------------------------------------
"""
remote_eval(cmdstr)
remote_eval(host, port, cmdstr)
Parse a string `cmdstr`, evaluate it in the remote REPL server's `Main` module,
Parse a string `cmdstr`, evaluate it in the remote REPL server's `Main` module or the session with `session_id`,
then close the connection. Returns the result which the REPL would normally
pass to `show()` (likely a `Text` object).
Expand All @@ -553,8 +580,8 @@ RemoteREPL.remote_eval("exit()")
```
"""
function remote_eval(host, port::Integer, cmdstr::AbstractString;
tunnel::Symbol = host!=Sockets.localhost ? :ssh : :none)
conn = Connection(; host=host, port=port, tunnel=tunnel)
tunnel::Symbol = host!=Sockets.localhost ? :ssh : :none, session_id=nothing)
conn = Connection(; host=host, port=port, tunnel=tunnel, session_id=session_id)
local result
try
setup_connection!(conn)
Expand Down
43 changes: 24 additions & 19 deletions src/server.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
using Sockets
using Serialization
using REPL
using Logging

mutable struct ServerSideSession
socket
sockets::Vector
display_properties::Dict
in_module::Module
end

Base.isopen(session::ServerSideSession) = isopen(session.socket)
Base.close(session::ServerSideSession) = close(session.socket)
Base.isopen(session::ServerSideSession) = any(isopen.(session.sockets))
function closeNdelete!(session::ServerSideSession, socket)
close(socket)
filter!(!=(socket), session.sockets)
end

function send_header(io, ser_version=Serialization.ser_version)
write(io, PROTOCOL_MAGIC, PROTOCOL_VERSION)
Expand Down Expand Up @@ -177,8 +175,7 @@ function serialize_responses(socket, response_chan)
end

# Serve a remote REPL session to a single client
function serve_repl_session(session)
socket = session.socket
function serve_repl_session(session, socket)
send_header(socket)
@sync begin
request_chan = Channel(1)
Expand Down Expand Up @@ -248,27 +245,35 @@ end
serve_repl(port::Integer; kws...) = serve_repl(Sockets.localhost, port; kws...)

function serve_repl(server::Base.IOServer; on_client_connect=nothing)
open_sessions = Set{ServerSideSession}()
open_sessions = Dict{UUID, ServerSideSession}()
@sync try
while isopen(server)
socket = accept(server)
session = ServerSideSession(socket, Dict(), Main)
push!(open_sessions, session)
# expect session id
session_id = deserialize(socket)
@show session_id
session = if haskey(open_sessions, session_id)
push!(open_sessions[session_id].sockets, socket)
open_sessions[session_id]
else
open_sessions[session_id] = ServerSideSession([socket], Dict(), Main)
end
socketidx = length(session.sockets)
peer = getpeername(socket)
@async try
if !isnothing(on_client_connect)
on_client_connect(session)
end
serve_repl_session(session)
serve_repl_session(session, socket)
catch exc
if !(exc isa EOFError && !isopen(session))
if !(exc isa EOFError && !isopen(socket))
@warn "Something went wrong evaluating client command" #=
=# exception=exc,catch_backtrace()
end
finally
@info "REPL client exited" peer
close(session)
pop!(open_sessions, session)
closeNdelete!(session, socket)
length(session.sockets) == 0 && delete!(open_sessions, session_id)
end
@info "REPL client opened a connection" peer
end
Expand All @@ -280,8 +285,8 @@ function serve_repl(server::Base.IOServer; on_client_connect=nothing)
@error "Unexpected server failure" isopen(server) exception=exc,catch_backtrace()
rethrow()
finally
for session in open_sessions
close(session)
for session in values(open_sessions)
close.(session.sockets)
end
end
end
Expand Down
2 changes: 0 additions & 2 deletions src/tunnels.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# Utilities for securely tunnelling traffic from client to a remote server

using OpenSSH_jll

# Find a free port on `network_interface`
function find_free_port(network_interface)
# listen on port 0 => kernel chooses a free port. See, for example,
Expand Down

0 comments on commit 107178b

Please sign in to comment.