HomeDashboard/.venv/lib/python3.12/site-packages/nicegui/testing/general.py

84 lines
2.5 KiB
Python
Raw Permalink Normal View History

2026-01-03 14:54:18 +01:00
import contextlib
import gc
import sys
from copy import copy
from starlette.routing import Route
from .. import app, binding, core, event, run, ui
from ..client import Client
def prepare_simulation() -> None:
"""Prepare the simulation by adding the run config and setting the storage secret."""
core.app.config.add_run_config(
reload=False,
title='Test App',
viewport='',
favicon=None,
dark=False,
language='en-US',
binding_refresh_interval=0.1,
reconnect_timeout=3.0,
message_history_length=1000,
tailwind=True,
prod_js=True,
show_welcome_message=False,
)
@contextlib.contextmanager
def nicegui_reset_globals():
"""Reset the global state of the NiceGUI package."""
for route in list(app.routes):
if isinstance(route, Route) and (
not route.path.startswith('/_nicegui/')
or route.path.startswith('/_nicegui/auto/static')
or route.path.startswith('/_nicegui/client/')
):
app.remove_route(route.path)
app.openapi_schema = None
app.middleware_stack = None
app.user_middleware.clear()
app.urls.clear()
core.reset()
element_types: list[type[ui.element]] = [ui.element, *_find_all_subclasses(ui.element)]
default_classes = {t: copy(t._default_classes) for t in element_types} # pylint: disable=protected-access
default_styles = {t: copy(t._default_style) for t in element_types} # pylint: disable=protected-access
default_props = {t: copy(t._default_props) for t in element_types} # pylint: disable=protected-access
Client.instances.clear()
Client.page_routes.clear()
app.reset()
binding.reset()
gc.collect()
try:
yield
finally:
gc.collect()
app.reset()
event.reset()
run.reset()
for t in element_types:
t._default_classes = default_classes[t] # pylint: disable=protected-access
t._default_style = default_styles[t] # pylint: disable=protected-access
t._default_props = default_props[t] # pylint: disable=protected-access
for func in Client.page_routes:
if not func.__module__.startswith('tests.'):
sys.modules.pop(func.__module__, None)
def _find_all_subclasses(cls: type) -> list[type]:
subclasses = []
for subclass in cls.__subclasses__():
subclasses.append(subclass)
subclasses.extend(_find_all_subclasses(subclass))
return subclasses