mirror of
https://github.com/n8n-io/n8n.git
synced 2026-05-31 00:37:10 +02:00
548 lines
16 KiB
Python
548 lines
16 KiB
Python
import builtins
import io
import json
import logging
import multiprocessing
import os
import sys
import textwrap
import traceback
from multiprocessing.connection import Connection
from multiprocessing.context import ForkServerProcess
from typing import cast

from src.config.security_config import SecurityConfig
from src.constants import (
    EXECUTOR_CIRCULAR_REFERENCE_KEY,
    EXECUTOR_USER_OUTPUT_KEY,
    EXECUTOR_ALL_ITEMS_FILENAME,
    EXECUTOR_PER_ITEM_FILENAME,
    SIGTERM_EXIT_CODE,
    SIGKILL_EXIT_CODE,
    PIPE_MSG_PREFIX_LENGTH,
)
from src.errors import (
    TaskCancelledError,
    TaskKilledError,
    TaskResultMissingError,
    TaskResultReadError,
    TaskRuntimeError,
    TaskTimeoutError,
    TaskSubprocessFailedError,
    SecurityViolationError,
)
from src.import_validation import validate_module_import
from src.message_types.broker import NodeMode, Items, Query
from src.message_types.pipe import (
    PipeResultMessage,
    PipeErrorMessage,
    TaskErrorInfo,
    PrintArgs,
)
from src.pipe_reader import PipeReader
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Every task subprocess and its result pipe are created through this
# "forkserver" multiprocessing context (see TaskExecutor.create_process).
MULTIPROCESSING_CONTEXT = multiprocessing.get_context("forkserver")
# Upper bound on the number of print() calls forwarded per task; anything past
# this is replaced by a single "[Output truncated ...]" entry
# (see TaskExecutor._truncate_print_args).
MAX_PRINT_ARGS_ALLOWED = 100

# One end of the unidirectional pipe between the runner and a task subprocess.
type PipeConnection = Connection
|
|
|
|
class TaskExecutor:
    """Responsible for executing Python code tasks in isolated subprocesses.

    All members are static. The public entry points are:

    - ``create_process``: build (but do not start) the subprocess and its pipe.
    - ``execute_process``: start the subprocess, wait for it, and decode the
      length-prefixed JSON message it wrote to the pipe.
    - ``stop_process``: terminate, then kill, a running subprocess.

    The underscore-prefixed members run inside the subprocess (``_all_items``,
    ``_per_item`` and their helpers).
    """

    @staticmethod
    def create_process(
        code: str,
        node_mode: NodeMode,
        items: Items,
        security_config: SecurityConfig,
        query: Query = None,
    ) -> tuple[ForkServerProcess, PipeConnection, PipeConnection]:
        """Create a subprocess for executing a Python code task and a pipe for communication.

        Args:
            code: User code to run.
            node_mode: ``"all_items"`` selects ``_all_items``; any other value
                selects ``_per_item``.
            items: Input items handed to the user code.
            security_config: Restrictions applied inside the subprocess.
            query: Forwarded to the target; only ``_all_items`` exposes it to
                user code (as ``_query``).

        Returns:
            ``(process, read_conn, write_conn)``. The process is not started.
        """

        fn = (
            TaskExecutor._all_items
            if node_mode == "all_items"
            else TaskExecutor._per_item
        )

        # thread in runner process reads, subprocess writes
        read_conn, write_conn = MULTIPROCESSING_CONTEXT.Pipe(duplex=False)

        process = MULTIPROCESSING_CONTEXT.Process(
            target=fn,
            args=(
                code,
                items,
                write_conn,
                security_config,
                query,
            ),
        )

        return process, read_conn, write_conn

    @staticmethod
    def execute_process(
        process: ForkServerProcess,
        read_conn: PipeConnection,
        write_conn: PipeConnection,
        task_timeout: int,
        continue_on_fail: bool,
    ) -> tuple[Items, PrintArgs, int]:
        """Execute a subprocess for a Python code task.

        Args:
            process: Unstarted process from ``create_process``.
            read_conn: Read end of the pipe; consumed by a ``PipeReader``.
            write_conn: Write end of the pipe; closed in this process once the
                subprocess start has been attempted.
            task_timeout: Seconds to wait for the subprocess, and then again
                for the pipe reader.
            continue_on_fail: When true, any ``Exception`` is converted into a
                single error item instead of being raised.

        Returns:
            ``(result, print_args, result_size_bytes)`` where the size is the
            pipe reader's ``message_size``. With ``continue_on_fail`` and a
            failure: ``([{"json": {"error": str(e)}}], print_args, 0)``.

        Raises:
            TaskSubprocessFailedError: The process failed to start or exited
                with an unexpected non-zero code.
            TaskTimeoutError: The process was still alive after ``task_timeout``.
            TaskCancelledError: Exit code equals ``SIGTERM_EXIT_CODE``.
            TaskKilledError: Exit code equals ``SIGKILL_EXIT_CODE``.
            TaskResultReadError: The pipe reader timed out or reported an error.
            TaskResultMissingError: No message, or a message with neither
                ``"error"`` nor ``"result"``.
            TaskRuntimeError: The subprocess reported an error message.
        """

        print_args: PrintArgs = []

        pipe_reader = PipeReader(read_conn.fileno(), read_conn)
        pipe_reader.start()

        try:
            try:
                process.start()
            except Exception as e:
                # chain explicitly so the original start failure stays attached
                raise TaskSubprocessFailedError(-1, e) from e
            finally:
                # this process never writes; only the subprocess does
                write_conn.close()

            process.join(timeout=task_timeout)

            if process.is_alive():
                TaskExecutor.stop_process(process)
                raise TaskTimeoutError(task_timeout)

            if process.exitcode == SIGTERM_EXIT_CODE:
                raise TaskCancelledError()

            if process.exitcode == SIGKILL_EXIT_CODE:
                raise TaskKilledError()

            if process.exitcode != 0:
                assert process.exitcode is not None
                raise TaskSubprocessFailedError(process.exitcode)

            pipe_reader.join(timeout=task_timeout)

            if pipe_reader.is_alive():
                try:
                    read_conn.close()
                except Exception:
                    # best effort: the timeout below is the error worth reporting
                    pass
                raise TaskResultReadError(
                    TimeoutError(f"Pipe reader timed out after {task_timeout}s")
                )

            if pipe_reader.error:
                raise TaskResultReadError(pipe_reader.error)

            if pipe_reader.pipe_message is None:
                raise TaskResultMissingError()

            returned = pipe_reader.pipe_message

            if "error" in returned:
                error_msg = cast(PipeErrorMessage, returned)
                # _put_error ships the print() output collected before the
                # failure; keep it so the continue_on_fail path can return it
                print_args = error_msg.get("print_args", [])
                raise TaskRuntimeError(error_msg["error"])

            if "result" not in returned:
                raise TaskResultMissingError()

            result_msg = cast(PipeResultMessage, returned)
            result = result_msg["result"]
            print_args = result_msg.get("print_args", [])
            assert pipe_reader.message_size is not None
            result_size_bytes = pipe_reader.message_size

            return result, print_args, result_size_bytes

        except Exception as e:
            if continue_on_fail:
                return [{"json": {"error": str(e)}}], print_args, 0
            raise

    @staticmethod
    def stop_process(process: ForkServerProcess | None) -> None:
        """Stop a running subprocess, gracefully else force-killing.

        Does nothing when ``process`` is ``None`` or no longer alive. Otherwise
        terminates it, waits up to one second, and kills it if still alive.
        """

        if process is None or not process.is_alive():
            return

        try:
            process.terminate()
            process.join(timeout=1)  # 1s grace period

            if process.is_alive():
                process.kill()
                process.join()
        except (ProcessLookupError, ConnectionError):
            # subprocess is dead or unreachable
            # (BrokenPipeError is a ConnectionError subclass)
            pass

    @staticmethod
    def _prepare_subprocess(security_config: SecurityConfig) -> io.StringIO:
        """Harden the current process for user code and start capturing stderr.

        Clears ``os.environ`` when ``security_config.runner_env_deny`` is set,
        sanitizes ``sys.modules``, then replaces ``sys.stderr`` with an
        in-memory buffer, which is returned.
        """

        if security_config.runner_env_deny:
            os.environ.clear()

        TaskExecutor._sanitize_sys_modules(security_config)

        sys.stderr = stderr_capture = io.StringIO()
        return stderr_capture

    @staticmethod
    def _all_items(
        raw_code: str,
        items: Items,
        write_conn,
        security_config: SecurityConfig,
        query: Query = None,
    ):
        """Execute a Python code task in all-items mode.

        Runs the wrapped user code once with ``_items`` and ``_query`` in its
        globals and writes either the result or the error to ``write_conn``.
        """

        stderr_capture = TaskExecutor._prepare_subprocess(security_config)
        print_args: PrintArgs = []

        try:
            wrapped_code = TaskExecutor._wrap_code(raw_code)
            compiled_code = compile(wrapped_code, EXECUTOR_ALL_ITEMS_FILENAME, "exec")

            exec_globals = {
                "__builtins__": TaskExecutor._filter_builtins(security_config),
                "_items": items,
                "_query": query,
                "print": TaskExecutor._create_custom_print(print_args),
            }

            exec(compiled_code, exec_globals)

            result = cast(Items, exec_globals[EXECUTOR_USER_OUTPUT_KEY])
            TaskExecutor._put_result(write_conn.fileno(), result, print_args)

        except BaseException as e:
            # BaseException on purpose: SystemExit raised by user code is
            # reported through the pipe as well (see _put_error)
            TaskExecutor._put_error(
                write_conn.fileno(), e, stderr_capture.getvalue(), print_args
            )

    @staticmethod
    def _per_item(
        raw_code: str,
        items: Items,
        write_conn,
        security_config: SecurityConfig,
        _query: Query = None,  # unused, only to keep signatures consistent across modes
    ):
        """Execute a Python code task in per-item mode.

        Runs the wrapped user code once per input item with ``_item`` in its
        globals. A ``None`` return skips the item; every other return becomes
        an output item paired with the input index. Writes either the
        collected result or the error to ``write_conn``.
        """

        stderr_capture = TaskExecutor._prepare_subprocess(security_config)
        print_args: PrintArgs = []

        try:
            wrapped_code = TaskExecutor._wrap_code(raw_code)
            compiled_code = compile(wrapped_code, EXECUTOR_PER_ITEM_FILENAME, "exec")

            filtered_builtins = TaskExecutor._filter_builtins(security_config)
            custom_print = TaskExecutor._create_custom_print(print_args)

            result: Items = []
            for index, item in enumerate(items):
                # fresh globals per item so state cannot leak between items
                exec_globals = {
                    "__builtins__": filtered_builtins,
                    "_item": item,
                    "print": custom_print,
                }

                exec(compiled_code, exec_globals)

                user_output = exec_globals[EXECUTOR_USER_OUTPUT_KEY]

                if user_output is None:
                    continue

                json_data = TaskExecutor._extract_json_data_per_item(user_output)

                output_item = {"json": json_data, "pairedItem": {"item": index}}

                if isinstance(user_output, dict) and "binary" in user_output:
                    output_item["binary"] = user_output["binary"]

                result.append(output_item)

            TaskExecutor._put_result(write_conn.fileno(), result, print_args)

        except BaseException as e:
            # BaseException on purpose: SystemExit raised by user code is
            # reported through the pipe as well (see _put_error)
            TaskExecutor._put_error(
                write_conn.fileno(), e, stderr_capture.getvalue(), print_args
            )

    @staticmethod
    def _wrap_code(raw_code: str) -> str:
        """Wrap user code in ``_user_function`` and bind its return value.

        The generated source defines ``_user_function`` with the indented user
        code as its body and assigns the call result to the global named by
        ``EXECUTOR_USER_OUTPUT_KEY``.
        """

        indented_code = textwrap.indent(raw_code, " ")
        return f"def _user_function():\n{indented_code}\n\n{EXECUTOR_USER_OUTPUT_KEY} = _user_function()"

    @staticmethod
    def _extract_json_data_per_item(user_output):
        """Pick the ``json`` payload out of one per-item return value.

        Non-dicts are returned unchanged. For dicts: the ``"json"`` value if
        present; else the dict without its ``"binary"`` key if it has one;
        else the dict itself.
        """

        if not isinstance(user_output, dict):
            return user_output

        if "json" in user_output:
            return user_output["json"]

        if "binary" in user_output:
            return {k: v for k, v in user_output.items() if k != "binary"}

        return user_output

    @staticmethod
    def _send_message(write_fd: int, message) -> None:
        """Serialize ``message`` to JSON and write it length-prefixed, then close ``write_fd``.

        The prefix is the payload size as a ``PIPE_MSG_PREFIX_LENGTH``-byte
        big-endian integer. Values that are not JSON-serializable are
        stringified (``default=str``).
        """

        # Serialize before the try block: if this raises, the fd must stay
        # open so the caller can still report the failure via _put_error.
        data = json.dumps(message, default=str, ensure_ascii=False).encode("utf-8")
        length_bytes = len(data).to_bytes(PIPE_MSG_PREFIX_LENGTH, "big")

        try:
            TaskExecutor._write_bytes(write_fd, length_bytes)
            TaskExecutor._write_bytes(write_fd, data)
        finally:
            try:
                os.close(write_fd)
            except Exception:
                # best effort: a failed close must not mask a write error
                pass

    @staticmethod
    def _put_result(write_fd: int, result: Items, print_args: PrintArgs):
        """Write a result message (result plus truncated print args) to the pipe."""

        message: PipeResultMessage = {
            "result": result,
            "print_args": TaskExecutor._truncate_print_args(print_args),
        }

        TaskExecutor._send_message(write_fd, message)

    @staticmethod
    def _put_error(
        write_fd: int,
        e: BaseException,
        stderr: str = "",
        print_args: PrintArgs | None = None,
    ):
        """Write an error message describing ``e`` to the pipe.

        Args:
            write_fd: File descriptor of the pipe's write end; closed afterwards.
            e: The exception to report. For ``SystemExit`` the message is
                ``"Process exited with code <code>"``, otherwise ``str(e)``.
            stderr: Captured stderr text to include.
            print_args: print() output collected so far; ``None`` means none.
        """

        if print_args is None:
            print_args = []

        task_error_info: TaskErrorInfo = {
            "message": f"Process exited with code {e.code}"
            if isinstance(e, SystemExit)
            else str(e),
            "description": getattr(e, "description", ""),
            # format the exception object itself rather than the "current"
            # exception, so this does not depend on running inside `except`
            "stack": "".join(traceback.format_exception(e)),
            "stderr": stderr,
        }

        message: PipeErrorMessage = {
            "error": task_error_info,
            "print_args": TaskExecutor._truncate_print_args(print_args),
        }

        TaskExecutor._send_message(write_fd, message)

    # ========== print() ==========

    @staticmethod
    def _create_custom_print(print_args: PrintArgs):
        """Build the ``print`` replacement exposed to user code.

        Each call appends the formatted arguments to ``print_args`` and also
        prints them with a ``[user code]`` prefix.
        """

        def custom_print(*args):
            serializable_args = []

            for arg in args:
                try:
                    json.dumps(arg, default=str, ensure_ascii=False)
                    serializable_args.append(arg)
                except Exception:
                    # Ensure args are serializable so they are transmissible
                    # through the result pipe and via websockets.
                    serializable_args.append(
                        {
                            EXECUTOR_CIRCULAR_REFERENCE_KEY: repr(arg),
                            "__type__": type(arg).__name__,
                        }
                    )

            formatted = TaskExecutor._format_print_args(*serializable_args)
            print_args.append(formatted)
            print("[user code]", *args)

        return custom_print

    @staticmethod
    def _format_print_args(*args) -> list[str]:
        """
        Takes the args passed to a `print()` call in user code and converts them
        to string representations suitable for display in a browser console.

        Strings are wrapped in single quotes, `None` and numbers (including
        bools) use `str()`, placeholder dicts carrying
        `EXECUTOR_CIRCULAR_REFERENCE_KEY` become `[Circular <type>]`, and
        everything else is JSON-encoded.

        Expects all args to be serializable.
        """

        formatted = []

        for arg in args:
            if isinstance(arg, str):
                formatted.append(f"'{arg}'")

            # bool is an int subclass, so it is covered here as well
            elif arg is None or isinstance(arg, (int, float)):
                formatted.append(str(arg))

            elif isinstance(arg, dict) and EXECUTOR_CIRCULAR_REFERENCE_KEY in arg:
                formatted.append(f"[Circular {arg.get('__type__', 'Object')}]")

            else:
                formatted.append(json.dumps(arg, default=str, ensure_ascii=False))

        return formatted

    @staticmethod
    def _truncate_print_args(print_args: PrintArgs) -> PrintArgs:
        """Truncate print_args to prevent pipe buffer overflow.

        Lists with at most ``MAX_PRINT_ARGS_ALLOWED`` entries are returned
        as-is. Longer lists yield a new list holding the first
        ``MAX_PRINT_ARGS_ALLOWED`` entries plus one notice entry stating how
        many were dropped; the input list is not modified.
        """

        if not print_args or len(print_args) <= MAX_PRINT_ARGS_ALLOWED:
            return print_args

        truncated = print_args[:MAX_PRINT_ARGS_ALLOWED]
        truncated.append(
            [
                f"[Output truncated - {len(print_args) - MAX_PRINT_ARGS_ALLOWED} more print statements]"
            ]
        )

        return truncated

    # ========== security ==========

    @staticmethod
    def _filter_builtins(security_config: SecurityConfig):
        """Get __builtins__ with denied ones removed.

        Returns a read-only, mapping-like object over a copy of the builtins
        namespace without the names in ``security_config.builtins_deny`` and
        with ``__import__`` replaced by the validating import from
        ``_create_safe_import``.
        """

        # Read the namespace from the `builtins` module: the module-global
        # `__builtins__` is a CPython implementation detail that is a dict in
        # imported modules but a module object in `__main__`.
        denied = security_config.builtins_deny
        filtered = {k: v for k, v in vars(builtins).items() if k not in denied}

        filtered["__import__"] = TaskExecutor._create_safe_import(security_config)

        class _ImmutableBuiltins:
            """Read-only view over ``filtered``; attribute writes and deletes raise."""

            __slots__ = ()

            def __getitem__(self, key):
                return filtered[key]

            def __contains__(self, key):
                return key in filtered

            def __iter__(self):
                return iter(filtered)

            def __len__(self):
                return len(filtered)

            def keys(self):
                return filtered.keys()

            def values(self):
                return filtered.values()

            def items(self):
                return filtered.items()

            def get(self, key, default=None):
                return filtered.get(key, default)

            def __getattr__(self, name):
                try:
                    return filtered[name]
                except KeyError:
                    raise AttributeError(name) from None

            def __setattr__(self, name, value):
                raise AttributeError("read-only")

            def __delattr__(self, name):
                raise AttributeError("read-only")

            def __repr__(self):
                return f"ImmutableBuiltins({len(filtered)} keys)"

        return _ImmutableBuiltins()

    @staticmethod
    def _sanitize_sys_modules(security_config: SecurityConfig):
        """Remove every module from ``sys.modules`` that is not considered safe.

        Safe modules are a fixed core set plus the configured stdlib and
        external allow-lists (``"*"`` allows all stdlib modules, respectively
        all currently loaded non-stdlib modules). Submodules of safe modules
        are kept as well.
        """

        safe_modules = {
            "builtins",
            "__main__",
            "sys",
            "traceback",
            "linecache",
            "importlib",
            "importlib.machinery",
        }

        if "*" in security_config.stdlib_allow:
            safe_modules.update(sys.stdlib_module_names)
        else:
            safe_modules.update(security_config.stdlib_allow)

        if "*" in security_config.external_allow:
            safe_modules.update(
                name for name in sys.modules if name not in sys.stdlib_module_names
            )
        else:
            safe_modules.update(security_config.external_allow)

        # keep modules marked as safe and submodules of those
        safe_prefixes = tuple(f"{safe}." for safe in safe_modules)

        # snapshot the names first: sys.modules is mutated in the loop below
        modules_to_remove = [
            name
            for name in list(sys.modules)
            if name not in safe_modules and not name.startswith(safe_prefixes)
        ]

        for module_name in modules_to_remove:
            del sys.modules[module_name]

    @staticmethod
    def _create_safe_import(security_config: SecurityConfig):
        """Build an ``__import__`` replacement that validates the module name first.

        The returned function raises ``SecurityViolationError`` when
        ``validate_module_import`` rejects the name and otherwise delegates to
        the real ``__import__``.
        """

        original_import = builtins.__import__

        def safe_import(name, *args, **kwargs):
            is_allowed, error_msg = validate_module_import(name, security_config)

            if not is_allowed:
                assert error_msg is not None
                raise SecurityViolationError(
                    message="Security violation detected",
                    description=error_msg,
                )

            return original_import(name, *args, **kwargs)

        return safe_import

    # ========== pipe I/O ==========

    @staticmethod
    def _write_bytes(fd: int, data: bytes):
        """Write all of ``data`` to ``fd``, looping over partial writes.

        Raises:
            OSError: ``os.write`` reported zero bytes written, or failed.
        """

        # a memoryview slice does not copy, unlike re-slicing the bytes
        # object after every partial write
        view = memoryview(data)
        total_written = 0
        while total_written < len(view):
            written = os.write(fd, view[total_written:])
            if written == 0:
                raise OSError("Write failed")
            total_written += written