Skip to content

Commit

Permalink
Implement sub-shells
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Dec 21, 2022
1 parent 40c5567 commit 84596c5
Show file tree
Hide file tree
Showing 9 changed files with 437 additions and 402 deletions.
10 changes: 5 additions & 5 deletions ipykernel/control.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import asyncio
from threading import Thread

from tornado.ioloop import IOLoop


class ControlThread(Thread):
def __init__(self, **kwargs):
Thread.__init__(self, name="Control", **kwargs)
self.io_loop = IOLoop(make_current=False)
self.io_loop = asyncio.new_event_loop()
self.pydev_do_not_trace = True
self.is_pydev_daemon_thread = True

def run(self):
self.name = "Control"
asyncio.set_event_loop(self.io_loop)
try:
self.io_loop.start()
self.io_loop.run_forever()
finally:
self.io_loop.close()

Expand All @@ -22,4 +22,4 @@ def stop(self):
This method is threadsafe.
"""
self.io_loop.add_callback(self.io_loop.stop)
self.io_loop.call_soon_threadsafe(self.io_loop.stop)
32 changes: 16 additions & 16 deletions ipykernel/debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,9 @@ async def get_message(self):


class DebugpyClient:
def __init__(self, log, debugpy_stream, event_callback):
def __init__(self, log, debugpy_socket, event_callback):
self.log = log
self.debugpy_stream = debugpy_stream
self.debugpy_socket = debugpy_socket
self.event_callback = event_callback
self.message_queue = DebugpyMessageQueue(self._forward_event, self.log)
self.debugpy_host = "127.0.0.1"
Expand All @@ -192,9 +192,9 @@ def _forward_event(self, msg):
self.init_event_seq = msg["seq"]
self.event_callback(msg)

def _send_request(self, msg):
async def _send_request(self, msg):
if self.routing_id is None:
self.routing_id = self.debugpy_stream.socket.getsockopt(ROUTING_ID)
self.routing_id = self.debugpy_socket.getsockopt(ROUTING_ID)
content = jsonapi.dumps(
msg,
default=json_default,
Expand All @@ -209,7 +209,7 @@ def _send_request(self, msg):
self.log.debug("DEBUGPYCLIENT:")
self.log.debug(self.routing_id)
self.log.debug(buf)
self.debugpy_stream.send_multipart((self.routing_id, buf))
await self.debugpy_socket.send_multipart((self.routing_id, buf))

async def _wait_for_response(self):
# Since events are never pushed to the message_queue
Expand All @@ -227,7 +227,7 @@ async def _handle_init_sequence(self):
"seq": int(self.init_event_seq) + 1,
"command": "configurationDone",
}
self._send_request(configurationDone)
await self._send_request(configurationDone)

# 3] Waits for configurationDone response
await self._wait_for_response()
Expand All @@ -238,7 +238,7 @@ async def _handle_init_sequence(self):

def get_host_port(self):
if self.debugpy_port == -1:
socket = self.debugpy_stream.socket
socket = self.debugpy_socket
socket.bind_to_random_port("tcp://" + self.debugpy_host)
self.endpoint = socket.getsockopt(zmq.LAST_ENDPOINT).decode("utf-8")
socket.unbind(self.endpoint)
Expand All @@ -247,11 +247,11 @@ def get_host_port(self):
return self.debugpy_host, self.debugpy_port

def connect_tcp_socket(self):
self.debugpy_stream.socket.connect(self._get_endpoint())
self.routing_id = self.debugpy_stream.socket.getsockopt(ROUTING_ID)
self.debugpy_socket.connect(self._get_endpoint())
self.routing_id = self.debugpy_socket.getsockopt(ROUTING_ID)

def disconnect_tcp_socket(self):
self.debugpy_stream.socket.disconnect(self._get_endpoint())
self.debugpy_socket.disconnect(self._get_endpoint())
self.routing_id = None
self.init_event = Event()
self.init_event_seq = -1
Expand All @@ -261,7 +261,7 @@ def receive_dap_frame(self, frame):
self.message_queue.put_tcp_frame(frame)

async def send_dap_request(self, msg):
self._send_request(msg)
await self._send_request(msg)
if self.wait_for_attach and msg["command"] == "attach":
rep = await self._handle_init_sequence()
self.wait_for_attach = False
Expand Down Expand Up @@ -290,10 +290,10 @@ class Debugger:
static_debug_msg_types = ["debugInfo", "inspectVariables", "richInspectVariables", "modules"]

def __init__(
self, log, debugpy_stream, event_callback, shell_socket, session, just_my_code=True
self, log, debugpy_socket, event_callback, shell_socket, session, just_my_code=True
):
self.log = log
self.debugpy_client = DebugpyClient(log, debugpy_stream, self._handle_event)
self.debugpy_client = DebugpyClient(log, debugpy_socket, self._handle_event)
self.shell_socket = shell_socket
self.session = session
self.is_started = False
Expand Down Expand Up @@ -375,7 +375,7 @@ async def handle_stopped_event(self):
def tcp_client(self):
return self.debugpy_client

def start(self):
async def start(self):
if not self.debugpy_initialized:
tmp_dir = get_tmp_directory()
if not os.path.exists(tmp_dir):
Expand All @@ -384,7 +384,7 @@ def start(self):
code = "import debugpy;"
code += 'debugpy.listen(("' + host + '",' + port + "))"
content = {"code": code, "silent": True}
self.session.send(
await self.session.async_send(
self.shell_socket,
"execute_request",
content,
Expand Down Expand Up @@ -639,7 +639,7 @@ async def process_request(self, message):
if self.is_started:
self.log.info("The debugger has already started")
else:
self.is_started = self.start()
self.is_started = await self.start()
if self.is_started:
self.log.info("The debugger has started")
else:
Expand Down
68 changes: 38 additions & 30 deletions ipykernel/iostream.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Copyright (c) IPython Development Team.
# Distributed under the terms of the Modified BSD License.

import asyncio
import atexit
import io
import os
Expand All @@ -18,8 +19,6 @@

import zmq
from jupyter_client.session import extract_header
from tornado.ioloop import IOLoop
from zmq.eventloop.zmqstream import ZMQStream

# -----------------------------------------------------------------------------
# Globals
Expand Down Expand Up @@ -57,7 +56,7 @@ def __init__(self, socket, pipe=False):
self.background_socket = BackgroundSocket(self)
self._master_pid = os.getpid()
self._pipe_flag = pipe
self.io_loop = IOLoop(make_current=False)
self.io_loop = asyncio.new_event_loop()
if pipe:
self._setup_pipe_in()
self._local = threading.local()
Expand All @@ -72,20 +71,27 @@ def __init__(self, socket, pipe=False):

def _thread_main(self):
"""The inner loop that's actually run in a thread"""
self.io_loop.start()
self.io_loop.close(all_fds=True)
asyncio.set_event_loop(self.io_loop)
try:
self.io_loop.run_until_complete(self._process_messages())
finally:
self.io_loop.close()

async def _process_messages(self):
asyncio.create_task(self._handle_event())
if self._pipe_flag:
asyncio.create_task(self._handle_pipe_msg())
await asyncio.Future() # wait forever

def _setup_event_pipe(self):
"""Create the PULL socket listening for events that should fire in this thread."""
ctx = self.socket.context
pipe_in = ctx.socket(zmq.PULL)
pipe_in.linger = 0
self._pipe_in0 = ctx.socket(zmq.PULL)
self._pipe_in0.linger = 0

_uuid = b2a_hex(os.urandom(16)).decode("ascii")
iface = self._event_interface = "inproc://%s" % _uuid
pipe_in.bind(iface)
self._event_puller = ZMQStream(pipe_in, self.io_loop)
self._event_puller.on_recv(self._handle_event)
self._pipe_in0.bind(iface)

@property
def _event_pipe(self):
Expand All @@ -104,20 +110,22 @@ def _event_pipe(self):
self._event_pipes.add(event_pipe)
return event_pipe

def _handle_event(self, msg):
async def _handle_event(self):
"""Handle an event on the event pipe
Content of the message is ignored.
Whenever *an* event arrives on the event stream,
*all* waiting events are processed in order.
"""
# freeze event count so new writes don't extend the queue
# while we are processing
n_events = len(self._events)
for _ in range(n_events):
event_f = self._events.popleft()
event_f()
while True:
await self._pipe_in0.recv_multipart()
# freeze event count so new writes don't extend the queue
# while we are processing
n_events = len(self._events)
for _ in range(n_events):
event_f = self._events.popleft()
event_f()

def _setup_pipe_in(self):
"""setup listening pipe for IOPub from forked subprocesses"""
Expand All @@ -126,30 +134,30 @@ def _setup_pipe_in(self):
# use UUID to authenticate pipe messages
self._pipe_uuid = os.urandom(16)

pipe_in = ctx.socket(zmq.PULL)
pipe_in.linger = 0
self._pipe_in1 = ctx.socket(zmq.PULL)
self._pipe_in1.linger = 0

try:
self._pipe_port = pipe_in.bind_to_random_port("tcp://127.0.0.1")
self._pipe_port = self._pipe_in1.bind_to_random_port("tcp://127.0.0.1")
except zmq.ZMQError as e:
warnings.warn(
"Couldn't bind IOPub Pipe to 127.0.0.1: %s" % e
+ "\nsubprocess output will be unavailable."
)
self._pipe_flag = False
pipe_in.close()
self._pipe_in1.close()
return
self._pipe_in = ZMQStream(pipe_in, self.io_loop)
self._pipe_in.on_recv(self._handle_pipe_msg)

def _handle_pipe_msg(self, msg):
async def _handle_pipe_msg(self):
"""handle a pipe message from a subprocess"""
if not self._pipe_flag or not self._is_master_process():
return
if msg[0] != self._pipe_uuid:
print("Bad pipe message: %s", msg, file=sys.__stderr__)
return
self.send_multipart(msg[1:])
while True:
msg = await self._pipe_in1.recv_multipart()
if not self._pipe_flag or not self._is_master_process():
return
if msg[0] != self._pipe_uuid:
print("Bad pipe message: %s", msg, file=sys.__stderr__)
return
self.send_multipart(msg[1:])

def _setup_pipe_out(self):
# must be new context after fork
Expand Down
40 changes: 24 additions & 16 deletions ipykernel/ipkernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,17 @@
import asyncio
import builtins
import getpass
import signal
import sys
import threading
import typing as t
from contextlib import contextmanager
from functools import partial

import comm
import zmq.asyncio
from IPython.core import release
from IPython.utils.tokenutil import line_at_cursor, token_at_cursor
from traitlets import Any, Bool, HasTraits, Instance, List, Type, observe, observe_compat
from zmq.eventloop.zmqstream import ZMQStream

from .comm.comm import BaseComm
from .comm.manager import CommManager
Expand All @@ -39,6 +38,11 @@
_use_experimental_60_completion = False


def DEBUG(msg):
with open("debug.log", "a") as f:
f.write(f"{msg}\n")


_EXPERIMENTAL_KEY_NAME = "_jupyter_types_experimental"


Expand Down Expand Up @@ -75,7 +79,9 @@ class IPythonKernel(KernelBase):
help="Set this flag to False to deactivate the use of experimental IPython completion APIs.",
).tag(config=True)

debugpy_stream = Instance(ZMQStream, allow_none=True) if _is_debugpy_available else None
debugpy_socket = (
Instance(zmq.asyncio.Socket, allow_none=True) if _is_debugpy_available else None
)

user_module = Any()

Expand Down Expand Up @@ -106,7 +112,7 @@ def __init__(self, **kwargs):
if _is_debugpy_available:
self.debugger = Debugger(
self.log,
self.debugpy_stream,
self.debugpy_socket,
self._publish_debug_event,
self.debug_shell_socket,
self.session,
Expand Down Expand Up @@ -188,8 +194,9 @@ def __init__(self, **kwargs):
"file_extension": ".py",
}

def dispatch_debugpy(self, msg):
if _is_debugpy_available:
async def dispatch_debugpy(self):
while True:
msg = await self.debugpy_socket.recv_multipart()
# The first frame is the socket id, we can drop it
frame = msg[1].bytes.decode("utf-8")
self.log.debug("Debugpy received: %s", frame)
Expand All @@ -205,15 +212,13 @@ async def poll_stopped_queue(self):

def start(self):
self.shell.exit_now = False
if self.debugpy_stream is None:
self.log.warning("debugpy_stream undefined, debugging will not be enabled")
else:
self.debugpy_stream.on_recv(self.dispatch_debugpy, copy=False)
if self.debugpy_socket is None:
self.log.warning("debugpy_socket undefined, debugging will not be enabled")
elif _is_debugpy_available:
asyncio.create_task(self.dispatch_debugpy())
super().start()
if self.debugpy_stream:
asyncio.run_coroutine_threadsafe(
self.poll_stopped_queue(), self.control_thread.io_loop.asyncio_loop
)
if self.debugpy_socket:
asyncio.run_coroutine_threadsafe(self.poll_stopped_queue(), self.control_thread.io_loop)

def set_parent(self, ident, parent, channel="shell"):
"""Overridden from parent to tell the display hook and output streams
Expand Down Expand Up @@ -309,6 +314,8 @@ def cancel_unless_done(f, _ignored):
future.add_done_callback(partial(cancel_unless_done, sigint_future))

def handle_sigint(*args):
DEBUG("handle_sigint")

def set_sigint_result():
if sigint_future.cancelled() or sigint_future.done():
return
Expand All @@ -318,12 +325,13 @@ def set_sigint_result():
self.io_loop.add_callback(set_sigint_result)

# set the custom sigint hander during this context
save_sigint = signal.signal(signal.SIGINT, handle_sigint)
# save_sigint = signal.signal(signal.SIGINT, handle_sigint)
try:
yield
finally:
# restore the previous sigint handler
signal.signal(signal.SIGINT, save_sigint)
# signal.signal(signal.SIGINT, save_sigint)
pass

async def do_execute(
self,
Expand Down
Loading

0 comments on commit 84596c5

Please sign in to comment.