# Source: n8n/packages/@n8n/task-runner-python/src/task_executor.py
# (506 lines, 15 KiB, Python)

import builtins
import io
import json
import logging
import multiprocessing
import os
import sys
import textwrap
import traceback
from multiprocessing.connection import Connection
from multiprocessing.context import ForkServerProcess
from typing import cast

from src.errors import (
    TaskCancelledError,
    TaskKilledError,
    TaskResultMissingError,
    TaskResultReadError,
    TaskRuntimeError,
    TaskTimeoutError,
    TaskSubprocessFailedError,
    SecurityViolationError,
)
from src.import_validation import validate_module_import
from src.config.security_config import SecurityConfig
from src.message_types.broker import NodeMode, Items, Query
from src.message_types.pipe import (
    PipeResultMessage,
    PipeErrorMessage,
    TaskErrorInfo,
    PrintArgs,
)
from src.pipe_reader import PipeReader
from src.constants import (
    EXECUTOR_CIRCULAR_REFERENCE_KEY,
    EXECUTOR_USER_OUTPUT_KEY,
    EXECUTOR_ALL_ITEMS_FILENAME,
    EXECUTOR_PER_ITEM_FILENAME,
    SIGTERM_EXIT_CODE,
    SIGKILL_EXIT_CODE,
    PIPE_MSG_PREFIX_LENGTH,
)
logger = logging.getLogger(__name__)
MULTIPROCESSING_CONTEXT = multiprocessing.get_context("forkserver")
MAX_PRINT_ARGS_ALLOWED = 100
type PipeConnection = Connection
class TaskExecutor:
    """Responsible for executing Python code tasks in isolated subprocesses.

    Runner side: `create_process` builds an unstarted forkserver subprocess
    plus a one-way pipe, and `execute_process` runs it and decodes its output.

    Subprocess side: `_all_items` / `_per_item` run the user code and write at
    most one length-prefixed JSON message (a result or an error) to the pipe.
    """

    @staticmethod
    def create_process(
        code: str,
        node_mode: NodeMode,
        items: Items,
        security_config: SecurityConfig,
        query: Query = None,
    ) -> tuple[ForkServerProcess, PipeConnection, PipeConnection]:
        """Create a subprocess for executing a Python code task and a pipe for communication.

        The returned process has not been started yet.

        Args:
            code: User-provided Python source.
            node_mode: `"all_items"` runs the code once over all items; any
                other value runs it once per item.
            items: Input items exposed to the user code.
            security_config: Sandbox settings applied inside the subprocess.
            query: Value exposed to the user code as `_query` (all-items mode
                only; per-item mode ignores it).

        Returns:
            `(process, read_conn, write_conn)`: `read_conn` is read by the
            runner, `write_conn` is written by the subprocess.
        """
        target = (
            TaskExecutor._all_items
            if node_mode == "all_items"
            else TaskExecutor._per_item
        )

        # One-way pipe: a thread in the runner process reads, the subprocess writes.
        read_conn, write_conn = MULTIPROCESSING_CONTEXT.Pipe(duplex=False)

        process = MULTIPROCESSING_CONTEXT.Process(
            target=target,
            args=(code, items, write_conn, security_config, query),
        )
        return process, read_conn, write_conn

    @staticmethod
    def execute_process(
        process: ForkServerProcess,
        read_conn: PipeConnection,
        write_conn: PipeConnection,
        task_timeout: int,
        continue_on_fail: bool,
    ) -> tuple[Items, PrintArgs, int]:
        """Execute a subprocess for a Python code task.

        Starts `process`, waits up to `task_timeout` seconds for it to exit,
        then waits (again up to `task_timeout`) for the pipe reader and
        decodes the message the subprocess wrote.

        Args:
            process: Unstarted process from `create_process`.
            read_conn: Read end of the pipe from `create_process`.
            write_conn: Write end of the pipe; closed here after starting.
            task_timeout: Seconds to wait for the process (and, separately,
                for the pipe reader).
            continue_on_fail: When true, any `Exception` is converted into a
                single error item instead of being raised.

        Returns:
            `(result_items, print_args, result_size_bytes)`. With
            `continue_on_fail` and a failure: `([{"json": {"error": str(e)}}],
            print_args, 0)`, where `print_args` holds any print output the
            subprocess reported.

        Raises:
            TaskSubprocessFailedError: start failed, or unexpected non-zero exit.
            TaskTimeoutError: the process was still alive after `task_timeout`.
            TaskCancelledError: exit code equals `SIGTERM_EXIT_CODE`.
            TaskKilledError: exit code equals `SIGKILL_EXIT_CODE`.
            TaskResultReadError: the pipe reader errored or timed out.
            TaskResultMissingError: no message, or a message without a result.
            TaskRuntimeError: the user code raised; carries the error info.
        """
        print_args: PrintArgs = []

        # The reader runs concurrently with the join below, so a large message
        # cannot fill the pipe buffer and leave the child blocked on write.
        pipe_reader = PipeReader(read_conn.fileno(), read_conn)
        pipe_reader.start()

        try:
            try:
                process.start()
            except Exception as e:
                raise TaskSubprocessFailedError(-1, e)
            finally:
                # Drop the runner's copy of the write end; afterwards only the
                # child (if started) holds a write end of the pipe.
                write_conn.close()

            process.join(timeout=task_timeout)

            if process.is_alive():
                TaskExecutor.stop_process(process)
                raise TaskTimeoutError(task_timeout)

            exit_code = process.exitcode
            if exit_code == SIGTERM_EXIT_CODE:
                raise TaskCancelledError()
            if exit_code == SIGKILL_EXIT_CODE:
                raise TaskKilledError()
            if exit_code != 0:
                assert exit_code is not None
                raise TaskSubprocessFailedError(exit_code)

            pipe_reader.join(timeout=task_timeout)

            if pipe_reader.is_alive():
                # Best effort: closing the read end may unblock the reader thread.
                try:
                    read_conn.close()
                except Exception:
                    pass
                raise TaskResultReadError(
                    TimeoutError(f"Pipe reader timed out after {task_timeout}s")
                )

            if pipe_reader.error:
                raise TaskResultReadError(pipe_reader.error)

            returned = pipe_reader.pipe_message
            if returned is None:
                raise TaskResultMissingError()

            if "error" in returned:
                error_msg = cast(PipeErrorMessage, returned)
                # The error message also carries the user's print output; keep
                # it so the continue_on_fail branch below can return it.
                print_args = error_msg.get("print_args", [])
                raise TaskRuntimeError(error_msg["error"])

            if "result" not in returned:
                raise TaskResultMissingError()

            result_msg = cast(PipeResultMessage, returned)
            result = result_msg["result"]
            print_args = result_msg.get("print_args", [])

            assert pipe_reader.message_size is not None
            result_size_bytes = pipe_reader.message_size
            return result, print_args, result_size_bytes
        except Exception as e:
            if continue_on_fail:
                return [{"json": {"error": str(e)}}], print_args, 0
            raise

    @staticmethod
    def stop_process(process: ForkServerProcess | None):
        """Stop a running subprocess, gracefully else force-killing.

        Calls `terminate()`, waits up to 1s, then calls `kill()` if the process
        is still alive. Does nothing for `None` or a process that is not alive.
        Errors from a process that is already gone or unreachable are ignored.
        """
        if process is None or not process.is_alive():
            return

        try:
            process.terminate()
            process.join(timeout=1)  # 1s grace period
            if process.is_alive():
                process.kill()
                process.join()
        except (ProcessLookupError, ConnectionError, BrokenPipeError):
            # subprocess is dead or unreachable
            pass

    # ========== subprocess entry points ==========

    @staticmethod
    def _prepare_subprocess(security_config: SecurityConfig) -> io.StringIO:
        """Apply sandbox settings inside the subprocess and start capturing stderr.

        Clears `os.environ` when `runner_env_deny` is set, prunes
        `sys.modules`, and redirects `sys.stderr` to an in-memory buffer whose
        contents are attached to any error report.

        Returns:
            The buffer now installed as `sys.stderr`.
        """
        if security_config.runner_env_deny:
            os.environ.clear()
        TaskExecutor._sanitize_sys_modules(security_config)

        stderr_capture = io.StringIO()
        sys.stderr = stderr_capture
        return stderr_capture

    @staticmethod
    def _all_items(
        raw_code: str,
        items: Items,
        write_conn,
        security_config: SecurityConfig,
        query: Query = None,
    ):
        """Execute a Python code task in all-items mode (runs in the subprocess).

        The user code runs once with `_items` and `_query` in scope; its return
        value is sent to the runner as the task result.
        """
        stderr_capture = TaskExecutor._prepare_subprocess(security_config)
        print_args: PrintArgs = []

        try:
            wrapped_code = TaskExecutor._wrap_code(raw_code)
            compiled_code = compile(wrapped_code, EXECUTOR_ALL_ITEMS_FILENAME, "exec")
            exec_globals = {
                "__builtins__": TaskExecutor._filter_builtins(security_config),
                "_items": items,
                "_query": query,
                "print": TaskExecutor._create_custom_print(print_args),
            }
            exec(compiled_code, exec_globals)
            result = cast(Items, exec_globals[EXECUTOR_USER_OUTPUT_KEY])
            TaskExecutor._put_result(write_conn.fileno(), result, print_args)
        except BaseException as e:
            # BaseException so that e.g. SystemExit raised by user code is
            # still reported through the pipe.
            TaskExecutor._put_error(
                write_conn.fileno(), e, stderr_capture.getvalue(), print_args
            )

    @staticmethod
    def _per_item(
        raw_code: str,
        items: Items,
        write_conn,
        security_config: SecurityConfig,
        _query: Query = None,  # unused, only to keep signatures consistent across modes
    ):
        """Execute a Python code task in per-item mode (runs in the subprocess).

        The user code runs once per item with `_item` in scope. A `None` return
        produces no output item; any other return becomes an output item paired
        with the input item's index.
        """
        stderr_capture = TaskExecutor._prepare_subprocess(security_config)
        print_args: PrintArgs = []

        try:
            wrapped_code = TaskExecutor._wrap_code(raw_code)
            compiled_code = compile(wrapped_code, EXECUTOR_PER_ITEM_FILENAME, "exec")
            filtered_builtins = TaskExecutor._filter_builtins(security_config)
            custom_print = TaskExecutor._create_custom_print(print_args)

            result: Items = []
            for index, item in enumerate(items):
                # Fresh globals dict per item (the builtins dict is shared).
                exec_globals = {
                    "__builtins__": filtered_builtins,
                    "_item": item,
                    "print": custom_print,
                }
                exec(compiled_code, exec_globals)
                user_output = exec_globals[EXECUTOR_USER_OUTPUT_KEY]
                if user_output is None:
                    continue

                output_item = {
                    "json": TaskExecutor._extract_json_data_per_item(user_output),
                    "pairedItem": {"item": index},
                }
                if isinstance(user_output, dict) and "binary" in user_output:
                    output_item["binary"] = user_output["binary"]
                result.append(output_item)

            TaskExecutor._put_result(write_conn.fileno(), result, print_args)
        except BaseException as e:
            # BaseException so that e.g. SystemExit raised by user code is
            # still reported through the pipe.
            TaskExecutor._put_error(
                write_conn.fileno(), e, stderr_capture.getvalue(), print_args
            )

    @staticmethod
    def _wrap_code(raw_code: str) -> str:
        """Wrap user code in a function so that a top-level `return` works.

        The code becomes the body of `_user_function`, whose return value is
        stored in the global named by `EXECUTOR_USER_OUTPUT_KEY`. Line N of the
        user code is line N + 1 of the wrapped source.

        NOTE: `textwrap.indent` also prefixes continuation lines inside
        multi-line string literals, so their content gains the indent.
        """
        indented_code = textwrap.indent(raw_code, " ")
        return f"def _user_function():\n{indented_code}\n\n{EXECUTOR_USER_OUTPUT_KEY} = _user_function()"

    @staticmethod
    def _extract_json_data_per_item(user_output):
        """Return the `json` payload for a per-item output.

        - non-dict output: returned unchanged;
        - dict with a `"json"` key: the value of that key;
        - dict with `"binary"` but no `"json"`: a copy without `"binary"`;
        - any other dict: returned unchanged.
        """
        if not isinstance(user_output, dict):
            return user_output
        if "json" in user_output:
            return user_output["json"]
        if "binary" in user_output:
            return {k: v for k, v in user_output.items() if k != "binary"}
        return user_output

    # ========== pipe messages ==========

    @staticmethod
    def _put_result(write_fd: int, result: Items, print_args: PrintArgs):
        """Send the task result plus (truncated) print output, then close `write_fd`."""
        message: PipeResultMessage = {
            "result": result,
            "print_args": TaskExecutor._truncate_print_args(print_args),
        }
        TaskExecutor._write_message(write_fd, message)

    @staticmethod
    def _put_error(
        write_fd: int,
        e: BaseException,
        stderr: str = "",
        print_args: PrintArgs | None = None,
    ):
        """Send an error report for `e`, then close `write_fd`.

        Must be called while `e` is being handled: the stack trace comes from
        `traceback.format_exc()`. A `SystemExit` is reported by its exit code.
        """
        if print_args is None:
            print_args = []

        if isinstance(e, SystemExit):
            error_message = f"Process exited with code {e.code}"
        else:
            error_message = str(e)

        task_error_info: TaskErrorInfo = {
            "message": error_message,
            "description": getattr(e, "description", ""),
            "stack": traceback.format_exc(),
            "stderr": stderr,
        }
        message: PipeErrorMessage = {
            "error": task_error_info,
            "print_args": TaskExecutor._truncate_print_args(print_args),
        }
        TaskExecutor._write_message(write_fd, message)

    @staticmethod
    def _write_message(write_fd: int, message: "PipeResultMessage | PipeErrorMessage"):
        """Write `message` to `write_fd` as one length-prefixed frame, then close the fd.

        Frame layout: the byte length of the UTF-8 JSON payload as a
        `PIPE_MSG_PREFIX_LENGTH`-byte big-endian integer, then the payload.

        Serialization happens before the write/close block on purpose: if
        `json.dumps` raises, `write_fd` is still open, so `_put_error` can
        still report the failure of a `_put_result` call.
        """
        data = json.dumps(message, default=str, ensure_ascii=False).encode("utf-8")
        length_bytes = len(data).to_bytes(PIPE_MSG_PREFIX_LENGTH, "big")
        try:
            TaskExecutor._write_bytes(write_fd, length_bytes)
            TaskExecutor._write_bytes(write_fd, data)
        finally:
            try:
                os.close(write_fd)
            except OSError:
                pass

    # ========== print() ==========

    @staticmethod
    def _create_custom_print(print_args: PrintArgs):
        """Build the `print` replacement exposed to user code.

        Each call appends the formatted positional args to `print_args` (sent
        back to the runner) and forwards the call to the real `print`,
        prefixed with "[user code]". Keyword arguments (`sep`, `end`, `file`,
        `flush`) are forwarded as well, so user code can call `print` exactly
        as it would the builtin.
        """

        def custom_print(*args, **kwargs):
            serializable_args = []
            for arg in args:
                try:
                    json.dumps(arg, default=str, ensure_ascii=False)
                except Exception:
                    # With default=str this fails mainly on circular references
                    # (or dict keys JSON cannot encode). Ensure args are
                    # serializable so they are transmissible through the pipe
                    # and via websockets.
                    serializable_args.append(
                        {
                            EXECUTOR_CIRCULAR_REFERENCE_KEY: repr(arg),
                            "__type__": type(arg).__name__,
                        }
                    )
                else:
                    serializable_args.append(arg)

            formatted = TaskExecutor._format_print_args(*serializable_args)
            print_args.append(formatted)
            print("[user code]", *args, **kwargs)

        return custom_print

    @staticmethod
    def _format_print_args(*args) -> list[str]:
        """
        Takes the args passed to a `print()` call in user code and converts them
        to string representations suitable for display in a browser console.
        Expects all args to be serializable.

        Strings are single-quoted, `None`/numbers/bools use `str()`, circular
        reference markers become `[Circular <type>]`, everything else is JSON.
        """
        formatted = []
        for arg in args:
            if isinstance(arg, str):
                formatted.append(f"'{arg}'")
            elif arg is None or isinstance(arg, (int, float, bool)):
                formatted.append(str(arg))
            elif isinstance(arg, dict) and EXECUTOR_CIRCULAR_REFERENCE_KEY in arg:
                formatted.append(f"[Circular {arg.get('__type__', 'Object')}]")
            else:
                formatted.append(json.dumps(arg, default=str, ensure_ascii=False))
        return formatted

    @staticmethod
    def _truncate_print_args(print_args: PrintArgs) -> PrintArgs:
        """Limit the number of print calls sent back in one pipe message.

        Keeps the first `MAX_PRINT_ARGS_ALLOWED` entries and appends one marker
        entry stating how many were dropped. Returns the input list unchanged
        when it is within the limit; otherwise returns a new list.
        """
        if not print_args or len(print_args) <= MAX_PRINT_ARGS_ALLOWED:
            return print_args

        truncated = print_args[:MAX_PRINT_ARGS_ALLOWED]
        truncated.append(
            [
                f"[Output truncated - {len(print_args) - MAX_PRINT_ARGS_ALLOWED} more print statements]"
            ]
        )
        return truncated

    # ========== security ==========

    @staticmethod
    def _filter_builtins(security_config: SecurityConfig):
        """Get a builtins namespace with the names in `builtins_deny` removed.

        Reads the `builtins` module rather than the module-level `__builtins__`
        (whose type is a CPython implementation detail). `__import__` is always
        replaced by the validating wrapper from `_create_safe_import`, even if
        it is listed in `builtins_deny`.
        """
        denied = security_config.builtins_deny
        filtered = {
            name: value
            for name, value in vars(builtins).items()
            if name not in denied
        }
        filtered["__import__"] = TaskExecutor._create_safe_import(security_config)
        return filtered

    @staticmethod
    def _sanitize_sys_modules(security_config: SecurityConfig):
        """Remove modules from `sys.modules` that are not allowed by `security_config`.

        Always kept: a fixed set of core modules, the allowed stdlib modules
        (all of `sys.stdlib_module_names` for `"*"`), the allowed external
        modules (every already-loaded non-stdlib module for `"*"`), and any
        submodule of a kept module.
        """
        safe_modules = {
            "builtins",
            "__main__",
            "sys",
            "traceback",
            "linecache",
            "importlib",
            "importlib.machinery",
        }

        if "*" in security_config.stdlib_allow:
            safe_modules.update(sys.stdlib_module_names)
        else:
            safe_modules.update(security_config.stdlib_allow)

        if "*" in security_config.external_allow:
            # Classify by top-level package: sys.stdlib_module_names lists only
            # top-level names, so submodules such as "os.path" or "json.decoder"
            # must not be mistaken for external modules.
            safe_modules.update(
                name
                for name in sys.modules
                if name.partition(".")[0] not in sys.stdlib_module_names
            )
        else:
            safe_modules.update(security_config.external_allow)

        # keep modules marked as safe and submodules of those
        safe_prefixes = tuple(f"{safe}." for safe in safe_modules)
        modules_to_remove = [
            name
            for name in sys.modules
            if name not in safe_modules and not name.startswith(safe_prefixes)
        ]
        for module_name in modules_to_remove:
            del sys.modules[module_name]

    @staticmethod
    def _create_safe_import(security_config: SecurityConfig):
        """Return an `__import__` replacement that validates the module name first.

        Calls `validate_module_import(name, security_config)`. Disallowed
        imports raise `SecurityViolationError`; allowed ones are delegated to
        the original `builtins.__import__` with all arguments unchanged.
        """
        original_import = builtins.__import__

        def safe_import(name, *args, **kwargs):
            is_allowed, error_msg = validate_module_import(name, security_config)
            if not is_allowed:
                assert error_msg is not None
                raise SecurityViolationError(
                    message="Security violation detected",
                    description=error_msg,
                )
            return original_import(name, *args, **kwargs)

        return safe_import

    # ========== pipe I/O ==========

    @staticmethod
    def _write_bytes(fd: int, data: bytes):
        """Write all of `data` to `fd`, looping over partial `os.write` results.

        A `memoryview` avoids copying the remaining bytes on each partial write.

        Raises:
            OSError: if `os.write` fails or reports zero bytes written.
        """
        with memoryview(data) as view:
            total_written = 0
            while total_written < len(view):
                written = os.write(fd, view[total_written:])
                if written == 0:
                    raise OSError("Write failed")
                total_written += written