282 lines
11 KiB
Python
282 lines
11 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import gzip
|
|
import json
|
|
import logging
|
|
import re
|
|
from collections.abc import AsyncIterator
|
|
from dataclasses import dataclass
|
|
from typing import TYPE_CHECKING, Any
|
|
from uuid import uuid4
|
|
|
|
import socketio
|
|
import socketio.exceptions
|
|
|
|
from . import background_tasks, core, helpers
|
|
from .client import Client
|
|
from .dataclasses import KWONLY_SLOTS
|
|
from .logging import log
|
|
from .timer import Timer as timer
|
|
|
|
if TYPE_CHECKING:
|
|
import httpx
|
|
|
|
RELAY_HOST = 'https://on-air.nicegui.io/'
|
|
|
|
|
|
@dataclass(**KWONLY_SLOTS)
class Stream:
    """A response body that is handed to the relay piece by piece.

    Instances are stored in ``Air.streams`` under a generated stream ID and are
    consumed one chunk per ``read-stream`` event.
    """

    # asynchronous source of the body chunks; advanced once per "read-stream" event
    data: AsyncIterator[bytes]
    # response the chunks belong to; closed via ``aclose()`` when the stream ends
    response: httpx.Response
|
|
|
|
|
|
class Air:
    """Bridge between the local app and the NiceGUI On Air relay.

    The instance keeps a SocketIO connection to the relay (``RELAY_HOST``),
    answers HTTP and range requests the relay forwards, serves large or streamed
    response bodies chunk-wise via ``self.streams``, and passes client events
    (handshake, event, ack, ...) on to the matching ``Client`` instance.
    """

    def __init__(self, token: str) -> None:
        """Create the relay client and register all SocketIO event handlers.

        :param token: device token that is sent to the relay as ``device_token`` when connecting
        """
        import httpx  # pylint: disable=import-outside-toplevel

        self.log = logging.getLogger('nicegui.air')
        self.token = token
        self.relay = socketio.AsyncClient()
        # requests forwarded by the relay are answered in-process through the ASGI app
        self.client = httpx.AsyncClient(transport=httpx.ASGITransport(app=core.app))
        # range requests are sent with a regular client so the body can be streamed
        self.streaming_client = httpx.AsyncClient()
        self.connecting = False
        self.streams: dict[str, Stream] = {}
        self.remote_url: str | None = None
        self._host_unreachable_warning = f'On Air host "{RELAY_HOST}" is not reachable. Please check your internet connection.'

        timer(5, self.connect)  # ensure we stay connected

        @self.relay.on('http')
        async def _handle_http(data: dict[str, Any]) -> dict[str, Any]:
            """Answer a forwarded HTTP request with a gzip-compressed response."""
            headers: dict[str, Any] = data['headers']
            headers.update({'Accept-Encoding': 'identity', 'X-Forwarded-Prefix': data['prefix']})
            url = 'http://test' + data['path']
            request = self.client.build_request(
                data['method'],
                url,
                params=data['params'],
                headers=headers,
                content=data['body'],
            )
            response = await self.client.send(request)
            self.client.cookies.clear()
            instance_id = data['instance-id']
            # inject the relay instance ID into the page's extra headers ...
            content = response.content.replace(
                b'const extraHeaders = {};',
                (f'const extraHeaders = {{ "fly-force-instance-id" : "{instance_id}" }};').encode(),
            )
            # ... and into its connection query object, if the page defines one
            match = re.search(rb'const query = ({.*?})', content)
            if match:
                new_js_object = match.group(1).decode().rstrip('}') + ", 'fly_instance_id' : '" + instance_id + "'}"
                content = content.replace(match.group(0), f'const query = {new_js_object}'.encode())
            compressed = gzip.compress(content)
            total_size = len(compressed)
            response.headers.update({'content-encoding': 'gzip', 'content-length': str(total_size)})
            result = {
                'status_code': response.status_code,
                'headers': response.headers.multi_items(),
                'content': compressed,
            }

            # NOTE: chunk large responses to stay within the SocketIO limit
            if total_size > 1_000_000:
                async def chunk_iterator() -> AsyncIterator[bytes]:
                    chunk_size = 512 * 1024
                    for i in range(0, total_size, chunk_size):
                        yield compressed[i:i + chunk_size]
                stream_id = str(uuid4())
                self.streams[stream_id] = Stream(data=chunk_iterator(), response=response)
                result['stream_id'] = stream_id
                result['total_size'] = total_size
                del result['content']  # the body is fetched via "read-stream" instead

            return result

        @self.relay.on('range-request')
        async def _handle_range_request(data: dict[str, Any]) -> dict[str, Any]:
            """Start a streamed response for a forwarded range request."""
            headers: dict[str, Any] = data['headers']
            # pick any known app URL that is not the On Air URL itself
            local_url = next((u for u in core.app.urls if u != self.remote_url), None)
            if local_url is None:
                # NOTE: fail with a clear message instead of a StopIteration leaking out of the coroutine
                raise RuntimeError('No local URL available to serve the range request.')
            url = local_url + data['path']
            data['params']['nicegui_chunk_size'] = 1024
            request = self.client.build_request(
                data['method'],
                url,
                params=data['params'],
                headers=headers,
            )
            response = await self.streaming_client.send(request, stream=True)
            stream_id = str(uuid4())
            self.streams[stream_id] = Stream(data=response.aiter_bytes(), response=response)
            return {
                'status_code': response.status_code,
                'headers': response.headers.multi_items(),
                'stream_id': stream_id,
            }

        @self.relay.on('read-stream')
        async def _handle_read_stream(stream_id: str) -> bytes | None:
            """Return the next chunk of a stream or ``None`` once it is exhausted."""
            try:
                return await self.streams[stream_id].data.__anext__()
            except StopAsyncIteration:
                await self._close_stream(stream_id)
                return None
            except Exception:
                # NOTE: cleanup tolerates unknown IDs, so the original error is the one that propagates
                await self._close_stream(stream_id)
                raise

        @self.relay.on('close-stream')
        async def _handle_close_stream(stream_id: str) -> None:
            """Close a stream on request of the relay."""
            await self._close_stream(stream_id)

        @self.relay.on('ready')
        def _handle_ready(data: dict[str, Any]) -> None:
            """Remember the On Air URL and announce it if the welcome message is enabled."""
            core.app.urls.add(data['device_url'])
            self.remote_url = data['device_url']
            if core.app.config.show_welcome_message:
                print(f'NiceGUI is on air at {data["device_url"]}', flush=True)

        @self.relay.on('error')
        def _handle_error(data: dict[str, Any]) -> None:
            """Print an error message sent by the relay."""
            print('Error:', data['message'], flush=True)

        @self.relay.on('handshake')
        def _handle_handshake(data: dict[str, Any]) -> bool:
            """Perform the handshake for a known client; return ``False`` for unknown clients."""
            client_id = data['client_id']
            if client_id not in Client.instances:
                return False
            client = Client.instances[client_id]
            client.environ = data['environ']
            if data.get('old_tab_id'):
                core.app.storage.copy_tab(data['old_tab_id'], data['tab_id'])
            client.tab_id = data['tab_id']
            client.on_air = True
            client.handle_handshake(data['sid'], data['document_id'], data.get('next_message_id'))
            return True

        @self.relay.on('client_disconnect')
        def _handle_client_disconnect(data: dict[str, Any]) -> None:
            """Forward a client disconnect to the matching client, if it still exists."""
            self.log.debug('client disconnected.')
            client_id = data['client_id']
            if client_id not in Client.instances:
                return
            Client.instances[client_id].handle_disconnect(data['sid'])

        @self.relay.on('connect')
        async def _handle_connect() -> None:
            """Log the connection and re-arm the "host unreachable" warning."""
            self.log.debug('connected.')
            # NOTE: reset the warning so it can be shown again if connection breaks in the future
            helpers._shown_warnings.discard(self._host_unreachable_warning)  # pylint: disable=protected-access

        @self.relay.on('disconnect')
        async def _handle_disconnect() -> None:
            """Log that the relay connection was closed."""
            self.log.debug('disconnected.')

        @self.relay.on('connect_error')
        async def _handle_connect_error(data) -> None:
            """Report a failed connection attempt and disconnect from the relay."""
            message = data.get('message', 'Unknown error') if isinstance(data, dict) else data
            if message == 'Connection error':
                helpers.warn_once(self._host_unreachable_warning)
            else:
                self.log.warning(f'Connection error: {message}')
            await self.relay.disconnect()

        @self.relay.on('event')
        def _handle_event(data: dict[str, Any]) -> None:
            """Forward a UI event to the matching client."""
            client_id = data['client_id']
            if client_id not in Client.instances:
                return
            client = Client.instances[client_id]
            args = data['msg']['args']
            if args and isinstance(args[0], str) and args[0].startswith('{"socket_id":'):
                arg0 = json.loads(args[0])
                arg0['socket_id'] = client_id  # HACK: translate socket_id of ui.scene's init event
                args[0] = json.dumps(arg0)
            client.handle_event(data['msg'])

        @self.relay.on('javascript_response')
        def _handle_javascript_response(data: dict[str, Any]) -> None:
            """Forward a JavaScript response to the matching client."""
            client_id = data['client_id']
            if client_id not in Client.instances:
                return
            client = Client.instances[client_id]
            client.handle_javascript_response(data['msg'])

        @self.relay.on('ack')
        def _handle_ack(data: dict[str, Any]) -> None:
            """Prune the outbox history of the matching client up to the acknowledged message."""
            client_id = data['client_id']
            if client_id not in Client.instances:
                return
            client = Client.instances[client_id]
            client.outbox.prune_history(data['msg']['next_message_id'])

        @self.relay.on('out_of_time')
        async def _handle_out_of_time() -> None:
            """Tell the user that the preview time is used up and try to connect again."""
            print('Sorry, you have reached the time limit of this NiceGUI On Air preview.', flush=True)
            await self.connect()

    async def _close_stream(self, stream_id: str) -> None:
        """Forget the stream with the given ID and close its response.

        Unknown (or already closed) IDs are ignored.
        The stream is removed *before* closing so that it does not stay registered if closing fails.

        :param stream_id: ID under which the stream was registered in ``self.streams``
        """
        stream = self.streams.pop(stream_id, None)
        if stream is not None:
            await stream.response.aclose()

    async def connect(self) -> None:
        """Connect to the NiceGUI On Air server.

        Does nothing if a connection attempt is already running or the relay is connected.
        Connection failures are logged and not raised.
        """
        if self.connecting:
            self.log.debug('Already connecting.')
            return
        if self.relay.connected:
            return
        self.log.debug('Going to connect...')
        self.connecting = True
        try:
            if self.relay.connected:
                await asyncio.wait_for(self.disconnect(), timeout=5)
            self.log.debug('Connecting...')
            await self.relay.connect(
                f'{RELAY_HOST}?device_token={self.token}',
                socketio_path='/on_air/socket.io',
                transports=['websocket', 'polling'],  # favor websocket over polling
                wait_timeout=5,
            )
            assert self.relay.connected
            return
        except socketio.exceptions.ConnectionError:
            self.log.debug('Connection error.', stack_info=True)
        except ValueError:  # NOTE this sometimes happens when the internal socketio client is not yet ready
            self.log.debug('ValueError while connecting.', stack_info=True)
        except Exception:
            log.exception('Could not connect to NiceGUI On Air server.')
        finally:
            self.connecting = False

    async def disconnect(self) -> None:
        """Disconnect from the NiceGUI On Air server and close all open streams."""
        self.log.debug('Disconnecting...')
        if self.relay.connected:
            await self.relay.disconnect()
        # NOTE: take streams out one by one instead of iterating the dict;
        # it may change while we await, and streams registered in the meantime get closed as well
        while self.streams:
            _, stream = self.streams.popitem()
            await stream.response.aclose()
        self.log.debug('Disconnected.')

    async def emit(self, message_type: str, data: dict[str, Any], room: str) -> None:
        """Emit a message to the NiceGUI On Air server.

        The message is silently dropped if the relay is not connected.

        :param message_type: name of the event to forward
        :param data: payload of the event
        :param room: room the relay should forward the event to
        """
        if self.relay.connected:
            await self.relay.emit('forward', {'event': message_type, 'data': data, 'room': room})

    @staticmethod
    def is_air_target(target_id: str) -> bool:
        """Whether the given target ID is an On Air client or a SocketIO room."""
        if target_id in Client.instances:
            return Client.instances[target_id].on_air
        return target_id in core.sio.manager.rooms
|
|
|
|
|
|
def connect() -> None:
    """Start connecting to NiceGUI On Air in a background task, provided an air instance exists."""
    air = core.air
    if not air:
        return  # On Air is not in use
    background_tasks.create(air.connect(), name='On Air connect')
|
|
|
|
|
|
def disconnect() -> None:
    """Start disconnecting from NiceGUI On Air in a background task, provided an air instance exists."""
    air = core.air
    if not air:
        return  # On Air is not in use
    background_tasks.create(air.disconnect(), name='On Air disconnect')
|