import asyncio import logging import time from typing import Callable, Awaitable from dataclasses import dataclass from urllib.parse import urlparse import websockets from websockets.exceptions import InvalidStatus from websockets.asyncio.client import ClientConnection import random from src.errors import TaskCancelledError from src.config.task_runner_config import TaskRunnerConfig from src.errors import ( NoIdleTimeoutHandlerError, TaskMissingError, WebsocketConnectionError, ) from src.message_types.broker import TaskSettings from src.nanoid import nanoid from src.constants import ( RUNNER_NAME, TASK_REJECTED_REASON_AT_CAPACITY, TASK_REJECTED_REASON_OFFER_EXPIRED, TASK_TYPE_PYTHON, OFFER_INTERVAL, OFFER_VALIDITY, OFFER_VALIDITY_MAX_JITTER, OFFER_VALIDITY_LATENCY_BUFFER, TASK_BROKER_WS_PATH, RPC_BROWSER_CONSOLE_LOG_METHOD, LOG_TASK_COMPLETE, LOG_TASK_CANCEL, LOG_TASK_CANCEL_UNKNOWN, LOG_TASK_CANCEL_WAITING, ) from src.message_types import ( BrokerMessage, RunnerMessage, BrokerInfoRequest, BrokerRunnerRegistered, BrokerTaskOfferAccept, BrokerTaskSettings, BrokerTaskCancel, BrokerRpcResponse, RunnerInfo, RunnerTaskOffer, RunnerTaskAccepted, RunnerTaskRejected, RunnerTaskDone, RunnerTaskError, RunnerRpcCall, ) from src.message_serde import MessageSerde from src.task_state import TaskState, TaskStatus from src.task_executor import TaskExecutor from src.task_analyzer import TaskAnalyzer from src.config.security_config import SecurityConfig @dataclass class TaskOffer: offer_id: str valid_until: float @property def has_expired(self) -> bool: return time.time() > self.valid_until class TaskRunner: def __init__( self, config: TaskRunnerConfig, ): self.runner_id = nanoid() self.name = RUNNER_NAME self.config = config self.websocket_connection: ClientConnection | None = None self.can_send_offers = False self.open_offers: dict[str, TaskOffer] = {} self.running_tasks: dict[str, TaskState] = {} self.offers_coroutine: asyncio.Task | None = None self.serde = MessageSerde() self.executor = TaskExecutor() self.security_config = SecurityConfig( stdlib_allow=config.stdlib_allow, external_allow=config.external_allow, builtins_deny=config.builtins_deny, runner_env_deny=config.env_deny, ) self.analyzer = TaskAnalyzer(self.security_config) self.logger = logging.getLogger(__name__) self.idle_coroutine: asyncio.Task | None = None self.on_idle_timeout: Callable[[], Awaitable[None]] | None = None self.last_activity_time = time.time() self.is_shutting_down = False self.task_broker_uri = config.task_broker_uri websocket_host = urlparse(config.task_broker_uri).netloc self.websocket_url = ( f"ws://{websocket_host}{TASK_BROKER_WS_PATH}?id={self.runner_id}" ) @property def running_tasks_count(self) -> int: return len(self.running_tasks) async def start(self) -> None: if self.config.is_auto_shutdown_enabled and not self.on_idle_timeout: raise NoIdleTimeoutHandlerError(self.config.auto_shutdown_timeout) headers = {"Authorization": f"Bearer {self.config.grant_token}"} while not self.is_shutting_down: try: self.websocket_connection = await websockets.connect( self.websocket_url, additional_headers=headers, max_size=self.config.max_payload_size, ) self.logger.info("Connected to broker") await self._listen_for_messages() except InvalidStatus as e: if e.response.status_code == 403: self.logger.error( f"Authentication failed with status {e.response.status_code}: {e}" ) raise self.logger.warning(f"Failed to connect to broker: {e} - retrying...") except Exception as e: self.logger.warning(f"Failed to connect to broker: {e} - retrying...") if not self.is_shutting_down: self.websocket_connection = None self.can_send_offers = False await self._cancel_coroutine(self.offers_coroutine) await self._cancel_coroutine(self.idle_coroutine) await asyncio.sleep(5) async def _cancel_coroutine(self, coroutine: asyncio.Task | None) -> None: if coroutine and not coroutine.done(): coroutine.cancel() try: await coroutine except asyncio.CancelledError: pass # ========== Shutdown ========== async def stop(self) -> None: self.is_shutting_down = True self.can_send_offers = False await self._cancel_coroutine(self.offers_coroutine) await self._cancel_coroutine(self.idle_coroutine) await self._wait_for_tasks() await self._terminate_tasks() if self.websocket_connection: await self.websocket_connection.close() self.logger.info("Disconnected from broker") self.logger.info("Runner stopped") async def _wait_for_tasks(self): if not self.running_tasks: return timeout = self.config.graceful_shutdown_timeout self.logger.debug( f"Waiting for {self.running_tasks_count} tasks to complete (timeout: {timeout}s)..." ) start_time = time.time() while self.running_tasks and (time.time() - start_time) < timeout: await asyncio.sleep(0.5) if self.running_tasks: self.logger.warning( f"Timed out waiting for {self.running_tasks_count} tasks to complete" ) async def _terminate_tasks(self): if not self.running_tasks: return self.logger.warning(f"Terminating {self.running_tasks_count} tasks...") tasks_to_terminate = [ asyncio.to_thread(self.executor.stop_process, task_state.process) for task_state in self.running_tasks.values() if task_state.process ] if tasks_to_terminate: await asyncio.gather(*tasks_to_terminate, return_exceptions=True) self.running_tasks.clear() self.logger.warning("Terminated tasks") # ========== Messages ========== async def _listen_for_messages(self) -> None: if self.websocket_connection is None: raise WebsocketConnectionError(self.task_broker_uri) async for raw_message in self.websocket_connection: try: message = self.serde.deserialize_broker_message(raw_message) await self._handle_message(message) except websockets.ConnectionClosedOK: break except Exception as e: self.logger.error(f"Error handling message: {e}") async def _handle_message(self, message: BrokerMessage) -> None: match message: case BrokerInfoRequest(): await self._handle_info_request() case BrokerRunnerRegistered(): await self._handle_runner_registered() case BrokerTaskOfferAccept(): await self._handle_task_offer_accept(message) case BrokerTaskSettings(): await self._handle_task_settings(message) case BrokerTaskCancel(): await self._handle_task_cancel(message) case BrokerRpcResponse(): pass # currently only logging, already handled by browser case _: self.logger.warning(f"Unhandled message type: {type(message)}") async def _handle_info_request(self) -> None: response = RunnerInfo(name=self.name, types=[TASK_TYPE_PYTHON]) await self._send_message(response) async def _handle_runner_registered(self) -> None: self.can_send_offers = True self.offers_coroutine = asyncio.create_task(self._send_offers_loop()) self.logger.info("Registered with broker") self._reset_idle_timer() async def _handle_task_offer_accept(self, message: BrokerTaskOfferAccept) -> None: offer = self.open_offers.get(message.offer_id) if offer is None or offer.has_expired: response = RunnerTaskRejected( task_id=message.task_id, reason=TASK_REJECTED_REASON_OFFER_EXPIRED, ) await self._send_message(response) return if self.running_tasks_count >= self.config.max_concurrency: response = RunnerTaskRejected( task_id=message.task_id, reason=TASK_REJECTED_REASON_AT_CAPACITY, ) await self._send_message(response) return del self.open_offers[message.offer_id] task_state = TaskState(message.task_id) self.running_tasks[message.task_id] = task_state response = RunnerTaskAccepted(task_id=message.task_id) await self._send_message(response) self.logger.info(f"Accepted task {message.task_id}") self._reset_idle_timer() async def _handle_task_settings(self, message: BrokerTaskSettings) -> None: task_state = self.running_tasks.get(message.task_id) if task_state is None: raise TaskMissingError(message.task_id) if task_state.status != TaskStatus.WAITING_FOR_SETTINGS: self.logger.warning( f"Received settings for task but it is already {task_state.status}. Discarding message." ) return task_state.workflow_name = message.settings.workflow_name task_state.workflow_id = message.settings.workflow_id task_state.node_name = message.settings.node_name task_state.node_id = message.settings.node_id task_state.status = TaskStatus.RUNNING asyncio.create_task(self._execute_task(message.task_id, message.settings)) self.logger.info(f"Received task {message.task_id}") async def _execute_task(self, task_id: str, task_settings: TaskSettings) -> None: start_time = time.time() try: task_state = self.running_tasks.get(task_id) if task_state is None: raise TaskMissingError(task_id) self.analyzer.validate(task_settings.code) process, read_conn, write_conn = self.executor.create_process( code=task_settings.code, node_mode=task_settings.node_mode, items=task_settings.items, security_config=self.security_config, query=task_settings.query, ) task_state.process = process result, print_args, result_size_bytes = await asyncio.to_thread( self.executor.execute_process, process=process, read_conn=read_conn, write_conn=write_conn, task_timeout=self.config.task_timeout, pipe_reader_timeout=self.config.pipe_reader_timeout, continue_on_fail=task_settings.continue_on_fail, ) for print_args_per_call in print_args: await self._send_rpc_message( task_id, RPC_BROWSER_CONSOLE_LOG_METHOD, print_args_per_call ) response = RunnerTaskDone(task_id=task_id, data={"result": result}) await self._send_message(response) self.logger.info( LOG_TASK_COMPLETE.format( task_id=task_id, duration=self._get_duration(start_time), result_size=self._get_result_size(result_size_bytes), **task_state.context(), ) ) except TaskCancelledError as e: response = RunnerTaskError(task_id=task_id, error={"message": str(e)}) await self._send_message(response) except SyntaxError as e: self.logger.warning(f"Task {task_id} failed syntax validation") error = {"message": str(e)} response = RunnerTaskError(task_id=task_id, error=error) await self._send_message(response) except Exception as e: self.logger.error(f"Task {task_id} failed", exc_info=True) error = { "message": getattr(e, "message", str(e)), "description": getattr(e, "description", ""), } response = RunnerTaskError(task_id=task_id, error=error) await self._send_message(response) finally: self.running_tasks.pop(task_id, None) self._reset_idle_timer() async def _handle_task_cancel(self, message: BrokerTaskCancel) -> None: task_id = message.task_id task_state = self.running_tasks.get(task_id) if task_state is None: self.logger.warning(LOG_TASK_CANCEL_UNKNOWN.format(task_id=task_id)) return if task_state.status == TaskStatus.WAITING_FOR_SETTINGS: self.running_tasks.pop(task_id, None) self.logger.info(LOG_TASK_CANCEL_WAITING.format(task_id=task_id)) await self._send_offers() return if task_state.status == TaskStatus.RUNNING: task_state.status = TaskStatus.ABORTING await asyncio.to_thread(self.executor.stop_process, task_state.process) self.logger.info( LOG_TASK_CANCEL.format(task_id=task_id, **task_state.context()) ) async def _send_rpc_message(self, task_id: str, method_name: str, params: list): message = RunnerRpcCall( call_id=nanoid(), task_id=task_id, name=method_name, params=params ) await self._send_message(message) async def _send_message(self, message: RunnerMessage) -> None: if self.websocket_connection is None: raise WebsocketConnectionError(self.task_broker_uri) serialized = self.serde.serialize_runner_message(message) await self.websocket_connection.send(serialized) # ========== Formatting ========== def _get_duration(self, start_time: float) -> str: elapsed = time.time() - start_time if elapsed < 1: return f"{int(elapsed * 1000)}ms" if elapsed < 60: return f"{int(elapsed)}s" return f"{int(elapsed) // 60}m" def _get_result_size(self, size_bytes: int) -> str: if size_bytes < 1024: return f"{size_bytes} bytes" elif size_bytes < 1024 * 1024: return f"{size_bytes / 1024:.1f} KB" else: return f"{size_bytes / (1024 * 1024):.1f} MB" # ========== Offers ========== async def _send_offers_loop(self) -> None: while self.can_send_offers: try: await self._send_offers() await asyncio.sleep(OFFER_INTERVAL) except asyncio.CancelledError: break except Exception as e: self.logger.error(f"Error sending offers: {e}") async def _send_offers(self) -> None: if not self.can_send_offers: return expired_offer_ids = [ offer_id for offer_id, offer in self.open_offers.items() if offer.has_expired ] for offer_id in expired_offer_ids: self.open_offers.pop(offer_id, None) offers_to_send = self.config.max_concurrency - ( len(self.open_offers) + self.running_tasks_count ) for _ in range(offers_to_send): offer_id = nanoid() valid_for_ms = OFFER_VALIDITY + random.randint(0, OFFER_VALIDITY_MAX_JITTER) valid_until = ( time.time() + (valid_for_ms / 1000) + OFFER_VALIDITY_LATENCY_BUFFER ) self.open_offers[offer_id] = TaskOffer(offer_id, valid_until) message = RunnerTaskOffer( offer_id=offer_id, task_type=TASK_TYPE_PYTHON, valid_for=valid_for_ms ) await self._send_message(message) # ========== Inactivity ========== def _reset_idle_timer(self): """Reset idle timer when key event occurs, namely runner registration, task acceptance, and task completion or failure.""" if not self.config.is_auto_shutdown_enabled: return self.last_activity_time = time.time() if self.idle_coroutine and not self.idle_coroutine.done(): self.idle_coroutine.cancel() self.idle_coroutine = asyncio.create_task(self._idle_timer_coroutine()) async def _idle_timer_coroutine(self): try: await asyncio.sleep(self.config.auto_shutdown_timeout) if self.running_tasks_count > 0: return assert self.on_idle_timeout is not None # validated at start() await self.on_idle_timeout() except asyncio.CancelledError: pass