diff --git a/Project.toml b/Project.toml index 4bc0ce3..9de338b 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" ReplMaker = "b873ce64-0db9-51f5-a568-4457d8e49576" Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" Sockets = "6462fe0b-24de-5631-8697-dd941f90decc" +UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [compat] OpenSSH_jll = "8.1" diff --git a/src/RemoteREPL.jl b/src/RemoteREPL.jl index b22427c..ebbac84 100644 --- a/src/RemoteREPL.jl +++ b/src/RemoteREPL.jl @@ -1,6 +1,11 @@ module RemoteREPL -export connect_repl, serve_repl, @remote, connect_remote +using REPL, ReplMaker +using Sockets, Serialization +using UUIDs, Logging +using OpenSSH_jll + +export connect_repl, serve_repl, @remote, @remoterepl, connect_remote const DEFAULT_PORT = 27754 const PROTOCOL_MAGIC = "RemoteREPL" diff --git a/src/client.jl b/src/client.jl index 4c28b2e..d743b98 100644 --- a/src/client.jl +++ b/src/client.jl @@ -1,8 +1,3 @@ -using ReplMaker -using REPL -using Serialization -using Sockets - struct RemoteException <: Exception msg::String end @@ -114,7 +109,8 @@ mutable struct Connection region::Union{AbstractString,Nothing} namespace::Union{AbstractString,Nothing} socket::Union{IO,Nothing} - in_module::Symbol + in_module::Union{Symbol, Expr} + session_id::UUID end function Connection(; host::Union{AbstractString,Sockets.IPAddr}=Sockets.localhost, @@ -123,8 +119,11 @@ function Connection(; host::Union{AbstractString,Sockets.IPAddr}=Sockets.localho ssh_opts::Cmd=``, region=nothing, namespace=nothing, - in_module::Symbol=:Main) - conn = Connection(host, port, tunnel, ssh_opts, region, namespace, nothing, in_module) + in_module::Symbol=:Main, + session_id=nothing) + sesid = isnothing(session_id) ? UUIDs.uuid4() : session_id + @info "Using session id $(sesid)" + conn = Connection(host, port, tunnel, ssh_opts, region, namespace, nothing, in_module, sesid) setup_connection!(conn) finalizer(close, conn) end @@ -139,6 +138,8 @@ function setup_connection!(conn::Connection) tunnel=conn.tunnel, ssh_opts=conn.ssh_opts, region=conn.region, namespace=conn.namespace) end + # transmit session id + serialize(socket, conn.session_id) try verify_header(socket) catch exc @@ -440,9 +441,10 @@ function connect_repl(host=Sockets.localhost, port::Integer=DEFAULT_PORT; region::Union{AbstractString,Nothing}=nothing, namespace::Union{AbstractString,Nothing}=nothing, startup_text::Bool=true, - repl=Base.active_repl) + repl=Base.active_repl, + session_id=nothing) - conn = connect_remote(host, port; tunnel, ssh_opts, region,namespace) + conn = connect_remote(host, port; tunnel, ssh_opts, region, namespace, session_id) out_stream = stdout prompt = ReplMaker.initrepl(c->run_remote_repl_command(conn, out_stream, c), repl = Base.active_repl, @@ -473,7 +475,8 @@ function connect_remote(host=Sockets.localhost, port::Integer=DEFAULT_PORT; tunnel::Symbol = host!=Sockets.localhost ? :ssh : :none, ssh_opts::Cmd=``, region::Union{AbstractString,Nothing}=nothing, - namespace::Union{AbstractString,Nothing}=nothing) + namespace::Union{AbstractString,Nothing}=nothing, + session_id=nothing) global _repl_client_connection @@ -485,7 +488,7 @@ function connect_remote(host=Sockets.localhost, port::Integer=DEFAULT_PORT; end end conn = RemoteREPL.Connection(host=host, port=port, tunnel=tunnel, - ssh_opts=ssh_opts, region=region, namespace=namespace) + ssh_opts=ssh_opts, region=region, namespace=namespace, session_id = session_id) # Record the connection in a global variable so it's accessible to REPL and `@remote` _repl_client_connection = conn @@ -536,12 +539,36 @@ end _remote_expr(conn, ex) = :(remote_eval_and_fetch($conn, $(QuoteNode(ex)))) +""" + @remoterepl cmd + +Parse `cmd` and appropriately handle/execute it in the remote REPL server session. +It's a superset of `@remote`, which can only evaluate expressions. + +# Examples + +Change the session's module evaluation settings to evaluate commands in `TempMod` + +``` +julia> @remoterepl "%module TempMod" +``` +""" +macro remoterepl(cmd) + _remoterepl_expr(:_repl_client_connection, cmd) +end + +macro remoterepl(conn, cmd) + _remoterepl_expr(esc(conn), cmd) +end + +_remoterepl_expr(conn, cmd) = :(run_remote_repl_command($conn, Base.stdout, $(string(cmd)))) + #-------------------------------------------------- """ remote_eval(cmdstr) remote_eval(host, port, cmdstr) -Parse a string `cmdstr`, evaluate it in the remote REPL server's `Main` module, +Parse a string `cmdstr`, evaluate it in the remote REPL server's `Main` module or the session with `session_id`, then close the connection. Returns the result which the REPL would normally pass to `show()` (likely a `Text` object). @@ -553,8 +580,8 @@ RemoteREPL.remote_eval("exit()") ``` """ function remote_eval(host, port::Integer, cmdstr::AbstractString; - tunnel::Symbol = host!=Sockets.localhost ? :ssh : :none) - conn = Connection(; host=host, port=port, tunnel=tunnel) + tunnel::Symbol = host!=Sockets.localhost ? :ssh : :none, session_id=nothing) + conn = Connection(; host=host, port=port, tunnel=tunnel, session_id=session_id) local result try setup_connection!(conn) diff --git a/src/server.jl b/src/server.jl index f05c4c4..65a9019 100644 --- a/src/server.jl +++ b/src/server.jl @@ -1,16 +1,14 @@ -using Sockets -using Serialization -using REPL -using Logging - mutable struct ServerSideSession - socket + sockets::Vector display_properties::Dict in_module::Module end -Base.isopen(session::ServerSideSession) = isopen(session.socket) -Base.close(session::ServerSideSession) = close(session.socket) +Base.isopen(session::ServerSideSession) = any(isopen.(session.sockets)) +function closeNdelete!(session::ServerSideSession, socket) + close(socket) + filter!(!=(socket), session.sockets) +end function send_header(io, ser_version=Serialization.ser_version) write(io, PROTOCOL_MAGIC, PROTOCOL_VERSION) @@ -177,8 +175,7 @@ function serialize_responses(socket, response_chan) end # Serve a remote REPL session to a single client -function serve_repl_session(session) - socket = session.socket +function serve_repl_session(session, socket) send_header(socket) @sync begin request_chan = Channel(1) @@ -248,27 +245,35 @@ end serve_repl(port::Integer; kws...) = serve_repl(Sockets.localhost, port; kws...) function serve_repl(server::Base.IOServer; on_client_connect=nothing) - open_sessions = Set{ServerSideSession}() + open_sessions = Dict{UUID, ServerSideSession}() @sync try while isopen(server) socket = accept(server) - session = ServerSideSession(socket, Dict(), Main) - push!(open_sessions, session) + # expect session id + session_id = deserialize(socket) + @show session_id + session = if haskey(open_sessions, session_id) + push!(open_sessions[session_id].sockets, socket) + open_sessions[session_id] + else + open_sessions[session_id] = ServerSideSession([socket], Dict(), Main) + end + socketidx = length(session.sockets) peer = getpeername(socket) @async try if !isnothing(on_client_connect) on_client_connect(session) end - serve_repl_session(session) + serve_repl_session(session, socket) catch exc - if !(exc isa EOFError && !isopen(session)) + if !(exc isa EOFError && !isopen(socket)) @warn "Something went wrong evaluating client command" #= =# exception=exc,catch_backtrace() end finally @info "REPL client exited" peer - close(session) - pop!(open_sessions, session) + closeNdelete!(session, socket) + length(session.sockets) == 0 && delete!(open_sessions, session_id) end @info "REPL client opened a connection" peer end @@ -280,8 +285,8 @@ function serve_repl(server::Base.IOServer; on_client_connect=nothing) @error "Unexpected server failure" isopen(server) exception=exc,catch_backtrace() rethrow() finally - for session in open_sessions - close(session) + for session in values(open_sessions) + close.(session.sockets) end end end diff --git a/src/tunnels.jl b/src/tunnels.jl index c9b3d85..263d8b1 100644 --- a/src/tunnels.jl +++ b/src/tunnels.jl @@ -1,7 +1,5 @@ # Utilities for securely tunnelling traffic from client to a remote server -using OpenSSH_jll - # Find a free port on `network_interface` function find_free_port(network_interface) # listen on port 0 => kernel chooses a free port. See, for example,