mirror of
https://github.com/n8n-io/n8n.git
synced 2026-05-31 16:57:08 +02:00
Co-authored-by: Matsu <matias.huhta@n8n.io> Co-authored-by: Dawid Myslak <dawid.myslak@gmail.com> Co-authored-by: Bernhard Wittmann <bernhard.wittmann@n8n.io> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: Dimitri Lavrenük <20122620+dlavrenuek@users.noreply.github.com> Co-authored-by: Benjamin Schroth <68321970+schrothbn@users.noreply.github.com> Co-authored-by: Danny Martini <danny@n8n.io> Co-authored-by: RomanDavydchuk <roman.davydchuk@n8n.io> Co-authored-by: Sandra Zollner <sandra.zollner@n8n.io> Co-authored-by: Milorad FIlipović <milorad@n8n.io> Co-authored-by: Iván Ovejero <ivov.src@gmail.com>
306 lines
11 KiB
Python
306 lines
11 KiB
Python
import ast
|
|
import hashlib
|
|
import re
|
|
from collections import OrderedDict
|
|
|
|
from src.errors import SecurityViolationError
|
|
from src.import_validation import validate_module_import
|
|
from src.config.security_config import SecurityConfig
|
|
from src.constants import (
|
|
MAX_VALIDATION_CACHE_SIZE,
|
|
ERROR_RELATIVE_IMPORT,
|
|
ERROR_DANGEROUS_NAME,
|
|
ERROR_DANGEROUS_ATTRIBUTE,
|
|
ERROR_NAME_MANGLED_ATTRIBUTE,
|
|
ERROR_DYNAMIC_IMPORT,
|
|
ERROR_DANGEROUS_STRING_PATTERN,
|
|
ERROR_MATCH_PATTERN_ATTRIBUTE,
|
|
ERROR_MATCH_POSITIONAL_PATTERN,
|
|
ERROR_GLOBAL_BLOCKED_NAME,
|
|
ERROR_FUNCDEF_BLOCKED_NAME,
|
|
ERROR_CLASSDEF_BLOCKED_NAME,
|
|
BLOCKED_ATTRIBUTES,
|
|
BLOCKED_NAMES,
|
|
)
|
|
|
|
# Matches one `{...}` replacement field of a str.format-style template and
# captures everything between the braces (field name, conversion, format spec).
FORMAT_FIELD_PATTERN = re.compile(r"\{([^}]*)\}")

# Cache key: (SHA-256 hex digest of the code, allowlists tuple).
CacheKey = tuple[str, tuple]
# Violation messages recorded for one analyzed piece of code.
CachedViolations = list[str]
# Recency-ordered mapping from cache key to the violations found for it.
ValidationCache = OrderedDict[CacheKey, CachedViolations]
|
|
|
|
|
|
class SecurityValidator(ast.NodeVisitor):
    """AST visitor that enforces import allowlists and blocks dangerous attribute access.

    Usage: call ``visit(tree)`` on a parsed module, then inspect ``violations``.
    Problems are collected as ``"Line <n>: <message>"`` strings rather than raised,
    so a single pass reports everything it finds.
    """

    def __init__(self, security_config: SecurityConfig):
        # Top-level package names already checked by _validate_import(), so each
        # package is validated (and reported) at most once per visitor.
        self.checked_modules: set[str] = set()
        # Formatted violation messages, in the order they were found.
        self.violations: list[str] = []
        # Allowlist configuration forwarded to validate_module_import().
        self.security_config = security_config

    # ========== Detection ==========

    def visit_Import(self, node: ast.Import) -> None:
        """Detect bare import statements (e.g., import os), including aliased (e.g., import numpy as np).

        An alias binds a name without producing an ``ast.Name`` node, so it is
        checked explicitly: ``import json as __builtins__`` is rejected the same
        way ``global __builtins__`` is.
        """
        for alias in node.names:
            self._validate_import(alias.name, node.lineno)
            if alias.asname is not None:
                self._check_binding(alias.asname, node.lineno)

        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """Detect from import statements (e.g., from os import path).

        ``from mod import name`` reads the attribute ``mod.name``, so every
        imported name goes through the same checks as ``visit_Attribute``;
        otherwise ``from allowed_module import <blocked attribute>`` would
        sidestep them. The name that ends up bound (the alias when one is given)
        must also not be a blocked name.
        """
        if node.level > 0:
            self._add_violation(node.lineno, ERROR_RELATIVE_IMPORT)
        elif node.module:
            self._validate_import(node.module, node.lineno)

        for alias in node.names:
            self._check_attribute_name(alias.name, node.lineno)
            self._check_binding(alias.asname or alias.name, node.lineno)

        self.generic_visit(node)

    def visit_Name(self, node: ast.Name) -> None:
        """Detect any reference to a blocked name, whether it is read, assigned or deleted."""
        if node.id in BLOCKED_NAMES:
            self._add_violation(node.lineno, ERROR_DANGEROUS_NAME.format(name=node.id))

        self.generic_visit(node)

    def visit_Attribute(self, node: ast.Attribute) -> None:
        """Detect access to unsafe attributes that could bypass security restrictions."""
        self._check_attribute_name(node.attr, node.lineno)

        self.generic_visit(node)

    def visit_Call(self, node: ast.Call) -> None:
        """Detect calls to __import__() that could bypass security restrictions.

        A literal module name is validated like a normal import; any other
        argument form (variable, expression, keyword-only) cannot be checked
        statically and is rejected as a dynamic import.
        """
        is_import_call = (
            # __import__()
            (isinstance(node.func, ast.Name) and node.func.id == "__import__")
            or
            # builtins.__import__() or __builtins__.__import__()
            (
                isinstance(node.func, ast.Attribute)
                and node.func.attr == "__import__"
                and isinstance(node.func.value, ast.Name)
                and node.func.value.id in {"builtins", "__builtins__"}
            )
        )

        if is_import_call:
            if (
                node.args
                and isinstance(node.args[0], ast.Constant)
                and isinstance(node.args[0].value, str)
            ):
                module_name = node.args[0].value
                self._validate_import(module_name, node.lineno)
            else:
                self._add_violation(node.lineno, ERROR_DYNAMIC_IMPORT)

        self.generic_visit(node)

    def visit_Subscript(self, node: ast.Subscript) -> None:
        """Detect dict access to blocked attributes, e.g. __builtins__['__spec__']"""
        is_builtins_access = (
            # __builtins__['__spec__']
            (
                isinstance(node.value, ast.Name)
                and node.value.id in {"__builtins__", "builtins"}
            )
            # obj.__builtins__['__spec__']
            or (
                isinstance(node.value, ast.Attribute)
                and node.value.attr in {"__builtins__", "builtins"}
            )
        )

        # Only literal string keys can be checked statically.
        if (
            is_builtins_access
            and isinstance(node.slice, ast.Constant)
            and isinstance(node.slice.value, str)
        ):
            key = node.slice.value
            if key in BLOCKED_ATTRIBUTES:
                self._add_violation(
                    node.lineno, ERROR_DANGEROUS_ATTRIBUTE.format(attr=key)
                )

        self.generic_visit(node)

    def visit_Constant(self, node: ast.Constant) -> None:
        """Detect string constants containing dangerous format patterns."""
        if isinstance(node.value, str):
            self._check_format_string(node.value, node.lineno)

        self.generic_visit(node)

    def visit_Global(self, node: ast.Global) -> None:
        """Detect global declarations of blocked names, e.g. `global __builtins__`"""
        for name in node.names:
            if name in BLOCKED_NAMES:
                self._add_violation(
                    node.lineno, ERROR_GLOBAL_BLOCKED_NAME.format(name=name)
                )
        self.generic_visit(node)

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        """Detect function definitions with blocked names, e.g. `def __builtins__(): pass`"""
        if node.name in BLOCKED_NAMES:
            self._add_violation(
                node.lineno, ERROR_FUNCDEF_BLOCKED_NAME.format(name=node.name)
            )
        self.generic_visit(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        """Detect async function definitions with blocked names."""
        if node.name in BLOCKED_NAMES:
            self._add_violation(
                node.lineno, ERROR_FUNCDEF_BLOCKED_NAME.format(name=node.name)
            )
        self.generic_visit(node)

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        """Detect class definitions with blocked names, e.g. `class __builtins__: pass`"""
        if node.name in BLOCKED_NAMES:
            self._add_violation(
                node.lineno, ERROR_CLASSDEF_BLOCKED_NAME.format(name=node.name)
            )
        self.generic_visit(node)

    def visit_MatchClass(self, node: ast.MatchClass) -> None:
        """Detect match patterns that extract blocked attributes, e.g. `case AttributeError(obj=x)`"""
        # Positional sub-patterns are mapped to attributes through the class's
        # __match_args__, which is not visible in the AST, so they are rejected
        # outright (presumably because the extracted attribute can't be vetted).
        if node.patterns:
            self._add_violation(node.lineno, ERROR_MATCH_POSITIONAL_PATTERN)

        for attr in node.kwd_attrs:
            if attr in BLOCKED_ATTRIBUTES:
                self._add_violation(
                    node.lineno, ERROR_MATCH_PATTERN_ATTRIBUTE.format(attr=attr)
                )

        self.generic_visit(node)

    def _check_format_string(self, s: str, lineno: int) -> None:
        """Check if a string contains format patterns that access blocked attributes.

        Looks inside each ``{...}`` field for ``.attr`` and ``[key]`` accessors,
        which str.format would resolve on its arguments at runtime.
        """
        # escaped braces produce literal braces, not format fields
        s = s.replace("{{", "").replace("}}", "")

        for match in FORMAT_FIELD_PATTERN.finditer(s):
            field = match.group(1)

            # attribute access
            for attr_match in re.finditer(r"\.(\w+)", field):
                attr = attr_match.group(1)
                if attr in BLOCKED_ATTRIBUTES or attr in BLOCKED_NAMES:
                    self._add_violation(
                        lineno, ERROR_DANGEROUS_STRING_PATTERN.format(attr=attr)
                    )

            # subscript access
            for subscript_match in re.finditer(r"\[(['\"]?)(\w+)\1\]", field):
                key = subscript_match.group(2)
                if key in BLOCKED_ATTRIBUTES or key in BLOCKED_NAMES:
                    self._add_violation(
                        lineno, ERROR_DANGEROUS_STRING_PATTERN.format(attr=key)
                    )

    def _check_attribute_name(self, attr: str, lineno: int) -> None:
        """Flag an attribute name that is blocked outright or written in name-mangled form."""
        if attr in BLOCKED_ATTRIBUTES:
            self._add_violation(lineno, ERROR_DANGEROUS_ATTRIBUTE.format(attr=attr))

        # `_Cls__attr` is the mangled spelling Python gives a private `__attr`
        # defined in class `Cls`; writing it directly reaches that private member.
        head, sep, _tail = attr.partition("__")
        if sep and head.startswith("_"):
            self._add_violation(lineno, ERROR_NAME_MANGLED_ATTRIBUTE)

    def _check_binding(self, name: str, lineno: int) -> None:
        """Flag an import that would bind a blocked name in the executing namespace."""
        if name in BLOCKED_NAMES:
            self._add_violation(lineno, ERROR_DANGEROUS_NAME.format(name=name))

    # ========== Validation ==========

    def _validate_import(self, module_path: str, lineno: int) -> None:
        """Validate that a module import is allowed based on allowlists. Also disallow relative imports.

        Deduplication is keyed on the top-level package: once ``os`` has been
        checked, later imports of ``os`` or ``os.<sub>`` are not re-validated
        against the allowlists (and therefore not re-reported).
        """
        if module_path.startswith("."):
            self._add_violation(lineno, ERROR_RELATIVE_IMPORT)
            return

        module_name = module_path.split(".")[0]  # e.g., os.path -> os

        if module_name in self.checked_modules:
            return

        self.checked_modules.add(module_name)

        is_allowed, error_msg = validate_module_import(
            module_path, self.security_config
        )

        if not is_allowed:
            assert error_msg is not None
            self._add_violation(lineno, error_msg)

    def _add_violation(self, lineno: int, message: str) -> None:
        """Record a violation prefixed with its source line number."""
        self.violations.append(f"Line {lineno}: {message}")
|
|
|
|
|
|
class TaskAnalyzer:
    """Run SecurityValidator over user code, memoizing the outcome per code and allowlists.

    ``validate`` returns silently for acceptable code and raises
    SecurityViolationError (listing every violation) otherwise.
    """

    # Shared by every TaskAnalyzer instance. Keys carry the allowlists alongside
    # the code hash, so analyzers with different allowlists never share an entry.
    # Order is recency: hits are moved to the end, eviction pops from the front.
    _cache: ValidationCache = OrderedDict()

    def __init__(self, security_config: SecurityConfig):
        self._security_config = security_config
        allowlists = (security_config.stdlib_allow, security_config.external_allow)
        # Sorted so equal allowlists always yield the same hashable cache key.
        self._allowlists = tuple(tuple(sorted(entry)) for entry in allowlists)
        # When both lists contain "*", validate() skips analysis entirely — note
        # this also skips the blocked name/attribute checks, not only imports.
        self._allow_all = all("*" in entry for entry in allowlists)

    def validate(self, code: str) -> None:
        """Raise SecurityViolationError if ``code`` breaks the security policy.

        Both clean and failing results are cached. A SyntaxError raised by
        ``ast.parse`` propagates to the caller and nothing is cached for it.
        """
        if self._allow_all:
            return

        key = self._to_cache_key(code)
        violations = self._cache.get(key)

        if violations is None:
            checker = SecurityValidator(self._security_config)
            checker.visit(ast.parse(code))
            violations = checker.violations
            self._set_in_cache(key, violations)
        else:
            # Cache hit: mark the entry as most recently used.
            self._cache.move_to_end(key)

        if violations:
            self._raise_security_error(violations)

    def _raise_security_error(self, violations: CachedViolations) -> None:
        """Raise SecurityViolationError with one violation message per line."""
        details = "\n".join(violations)
        raise SecurityViolationError(
            message="Security violations detected", description=details
        )

    def _to_cache_key(self, code: str) -> CacheKey:
        """Build the cache key: SHA-256 hex digest of ``code`` plus the sorted allowlists."""
        return hashlib.sha256(code.encode()).hexdigest(), self._allowlists

    def _set_in_cache(self, cache_key: CacheKey, violations: CachedViolations) -> None:
        """Store a copy of ``violations``, evicting the least recently used entry when full."""
        cache = self._cache
        if len(cache) >= MAX_VALIDATION_CACHE_SIZE:
            # The front of the OrderedDict is the least recently used entry (LRU).
            cache.popitem(last=False)
        cache[cache_key] = violations.copy()
        cache.move_to_end(cache_key)
|