From b47eefe615922b730b1fde0fcc5546f859634a06 Mon Sep 17 00:00:00 2001 From: David Rodenkirchen Date: Mon, 23 Feb 2026 15:15:39 +0100 Subject: [PATCH] Overhaul Sessioning --- src/EzggLanManager.py | 12 +++--- src/ezgg_lan_manager/__init__.py | 7 ++-- .../components/DesktopNavigation.py | 36 ++++------------ src/ezgg_lan_manager/components/LoginBox.py | 17 +++++--- .../components/SeatingPlanInfoBox.py | 23 +++++++---- .../components/SeatingPlanPixels.py | 8 +++- .../components/ShoppingCartAndOrders.py | 41 +++++++++++++------ .../components/TicketBuyCard.py | 5 +-- .../components/UserEditForm.py | 11 +++-- .../components/UserInfoAndLoginBox.py | 13 +++--- .../components/UserInfoBox.py | 29 ++++++------- src/ezgg_lan_manager/helpers/LoggedInGuard.py | 22 +++++++--- src/ezgg_lan_manager/pages/Account.py | 14 ++++--- src/ezgg_lan_manager/pages/BuyTicketPage.py | 14 +++++-- src/ezgg_lan_manager/pages/CateringPage.py | 12 +++--- src/ezgg_lan_manager/pages/ContactPage.py | 8 ++-- src/ezgg_lan_manager/pages/EditProfile.py | 11 +---- src/ezgg_lan_manager/pages/ManageUsersPage.py | 8 +++- src/ezgg_lan_manager/pages/SeatingPlanPage.py | 7 +++- src/ezgg_lan_manager/pages/TeamsPage.py | 12 ++++-- .../pages/TournamentDetailsPage.py | 13 +++--- .../services/LocalDataService.py | 10 ++--- .../services/RefreshService.py | 17 ++++++++ src/ezgg_lan_manager/types/SessionStorage.py | 36 ---------------- src/ezgg_lan_manager/types/UserSession.py | 9 ++++ 25 files changed, 216 insertions(+), 179 deletions(-) create mode 100644 src/ezgg_lan_manager/services/RefreshService.py delete mode 100644 src/ezgg_lan_manager/types/SessionStorage.py create mode 100644 src/ezgg_lan_manager/types/UserSession.py diff --git a/src/EzggLanManager.py b/src/EzggLanManager.py index c6cc71f..3305adc 100644 --- a/src/EzggLanManager.py +++ b/src/EzggLanManager.py @@ -1,5 +1,4 @@ import logging -from asyncio import get_event_loop import sys @@ -8,11 +7,9 @@ from pathlib import Path from rio import App, Theme, Color, Font, ComponentPage, Session from from_root import from_root -from src.ezgg_lan_manager import pages, init_services +from src.ezgg_lan_manager import pages, init_services, LocalDataService from src.ezgg_lan_manager.helpers.LoggedInGuard import logged_in_guard, not_logged_in_guard, team_guard -from src.ezgg_lan_manager.services.DatabaseService import NoDatabaseConnectionError from src.ezgg_lan_manager.services.LocalDataService import LocalData -from src.ezgg_lan_manager.types.SessionStorage import SessionStorage logger = logging.getLogger("EzggLanManager") @@ -30,14 +27,17 @@ if __name__ == "__main__": corner_radius_large=0, font=Font(from_root("src/ezgg_lan_manager/assets/fonts/joystix.otf")) ) - default_attachments: list = [LocalData()] + default_attachments: list = [LocalData(stored_session_token=None)] default_attachments.extend(init_services()) lan_info = default_attachments[3].get_lan_info() async def on_session_start(session: Session) -> None: await session.set_title(lan_info.name) - session.attach(SessionStorage()) + if session[LocalData].stored_session_token: + user_session = session[LocalDataService].verify_token(session[LocalData].stored_session_token) + if user_session is not None: + session.attach(user_session) async def on_app_start(a: App) -> None: init_result = await a.default_attachments[4].init_db_pool() diff --git a/src/ezgg_lan_manager/__init__.py b/src/ezgg_lan_manager/__init__.py index b52831d..b81b07d 100644 --- a/src/ezgg_lan_manager/__init__.py +++ b/src/ezgg_lan_manager/__init__.py @@ -10,6 +10,7 @@ from src.ezgg_lan_manager.services.DatabaseService import DatabaseService from src.ezgg_lan_manager.services.LocalDataService import LocalDataService from src.ezgg_lan_manager.services.MailingService import MailingService from src.ezgg_lan_manager.services.NewsService import NewsService +from src.ezgg_lan_manager.services.RefreshService import RefreshService from src.ezgg_lan_manager.services.ReceiptPrintingService import ReceiptPrintingService from src.ezgg_lan_manager.services.SeatingService import SeatingService from src.ezgg_lan_manager.services.TeamService import TeamService @@ -19,7 +20,7 @@ from src.ezgg_lan_manager.services.UserService import UserService from src.ezgg_lan_manager.types import * # Inits services in the correct order -def init_services() -> tuple[AccountingService, CateringService, ConfigurationService, DatabaseService, MailingService, NewsService, SeatingService, TicketingService, UserService, LocalDataService, ReceiptPrintingService, TournamentService, TeamService]: +def init_services() -> tuple[AccountingService, CateringService, ConfigurationService, DatabaseService, MailingService, NewsService, SeatingService, TicketingService, UserService, LocalDataService, ReceiptPrintingService, TournamentService, TeamService, RefreshService]: logging.basicConfig(level=logging.DEBUG) configuration_service = ConfigurationService(from_root("config.toml")) db_service = DatabaseService(configuration_service.get_database_configuration()) @@ -34,6 +35,6 @@ def init_services() -> tuple[AccountingService, CateringService, ConfigurationSe local_data_service = LocalDataService() tournament_service = TournamentService(db_service, user_service) team_service = TeamService(db_service) + refresh_service = RefreshService() - - return accounting_service, catering_service, configuration_service, db_service, mailing_service, news_service, seating_service, ticketing_service, user_service, local_data_service, receipt_printing_service, tournament_service, team_service + return accounting_service, catering_service, configuration_service, db_service, mailing_service, news_service, seating_service, ticketing_service, user_service, local_data_service, receipt_printing_service, tournament_service, team_service, refresh_service diff --git a/src/ezgg_lan_manager/components/DesktopNavigation.py b/src/ezgg_lan_manager/components/DesktopNavigation.py index 51ae7ce..aaf2e19 100644 --- a/src/ezgg_lan_manager/components/DesktopNavigation.py +++ b/src/ezgg_lan_manager/components/DesktopNavigation.py @@ -1,47 +1,29 @@ from typing import Optional, Callable -from rio import * +from rio import Component, event, Spacer, Card, Column, Text, TextStyle -from src.ezgg_lan_manager import ConfigurationService, UserService, LocalDataService +from src.ezgg_lan_manager.services.ConfigurationService import ConfigurationService +from src.ezgg_lan_manager.services.UserService import UserService from src.ezgg_lan_manager.components.DesktopNavigationButton import DesktopNavigationButton from src.ezgg_lan_manager.components.NavigationSponsorBox import NavigationSponsorBox from src.ezgg_lan_manager.components.UserInfoAndLoginBox import UserInfoAndLoginBox -from src.ezgg_lan_manager.services.LocalDataService import LocalData -from src.ezgg_lan_manager.types.SessionStorage import SessionStorage from src.ezgg_lan_manager.types.User import User +from src.ezgg_lan_manager.types.UserSession import UserSession class DesktopNavigation(Component): user: Optional[User] = None - force_login_box_refresh: list[Callable] = [] @event.on_populate - async def async_init(self) -> None: - self.session[SessionStorage].subscribe_to_logged_in_or_out_event(str(self.__class__), self.async_init) - local_data = self.session[LocalData] - if local_data.stored_session_token: - session_ = self.session[LocalDataService].verify_token(local_data.stored_session_token) - if session_: - self.session.detach(SessionStorage) - self.session.attach(session_) - self.user = await self.session[UserService].get_user(session_.user_id) - try: - # Hack-around, maybe fix in the future - self.force_login_box_refresh[-1]() - except IndexError: - pass - - return - - if self.session[SessionStorage].user_id: - self.user = await self.session[UserService].get_user(self.session[SessionStorage].user_id) - else: + async def on_populate(self) -> None: + try: + self.user = await self.session[UserService].get_user(self.session[UserSession].user_id) + except KeyError: self.user = None def build(self) -> Component: lan_info = self.session[ConfigurationService].get_lan_info() - user_info_and_login_box = UserInfoAndLoginBox() - self.force_login_box_refresh.append(user_info_and_login_box.force_refresh) + user_info_and_login_box = UserInfoAndLoginBox(state_changed_cb=self.on_populate) navigation = [ DesktopNavigationButton("News", "./news"), Spacer(min_height=0.7), diff --git a/src/ezgg_lan_manager/components/LoginBox.py b/src/ezgg_lan_manager/components/LoginBox.py index 6a83129..bef1744 100644 --- a/src/ezgg_lan_manager/components/LoginBox.py +++ b/src/ezgg_lan_manager/components/LoginBox.py @@ -1,10 +1,13 @@ -from rio import Component, TextStyle, Color, TextInput, Button, Text, Rectangle, Column, Row, Spacer, \ - EventHandler +import uuid +from rio import Component, TextStyle, Color, TextInput, Button, Text, Rectangle, Column, Row, Spacer, \ + EventHandler, Webview + +from src.ezgg_lan_manager import RefreshService from src.ezgg_lan_manager.services.LocalDataService import LocalDataService, LocalData from src.ezgg_lan_manager.services.UserService import UserService -from src.ezgg_lan_manager.types.SessionStorage import SessionStorage from src.ezgg_lan_manager.types.User import User +from src.ezgg_lan_manager.types.UserSession import UserSession class LoginBox(Component): @@ -26,11 +29,13 @@ class LoginBox(Component): self.password_input_is_valid = True self.login_button_is_loading = False self.is_account_locked = False - await self.session[SessionStorage].set_user_id_and_team_member_flag(user.user_id, user.is_team_member) - token = self.session[LocalDataService].set_session(self.session[SessionStorage]) + user_session = UserSession(id=uuid.uuid4(), user_id=user.user_id, is_team_member=user.is_team_member) + self.session.attach(user_session) + token = self.session[LocalDataService].set_session(user_session) self.session[LocalData].stored_session_token = token self.session.attach(self.session[LocalData]) - self.status_change_cb() + await self.status_change_cb() + await self.session[RefreshService].trigger_refresh() else: self.user_name_input_is_valid = False self.password_input_is_valid = False diff --git a/src/ezgg_lan_manager/components/SeatingPlanInfoBox.py b/src/ezgg_lan_manager/components/SeatingPlanInfoBox.py index 70bbd91..7feff63 100644 --- a/src/ezgg_lan_manager/components/SeatingPlanInfoBox.py +++ b/src/ezgg_lan_manager/components/SeatingPlanInfoBox.py @@ -1,11 +1,10 @@ from decimal import Decimal -from functools import partial from typing import Optional, Callable from rio import Component, Column, Text, TextStyle, Button, Spacer, event from src.ezgg_lan_manager import TicketingService -from src.ezgg_lan_manager.types.SessionStorage import SessionStorage +from src.ezgg_lan_manager.types.UserSession import UserSession class SeatingPlanInfoBox(Component): @@ -22,11 +21,14 @@ class SeatingPlanInfoBox(Component): @event.on_populate async def check_ticket(self) -> None: - if self.session[SessionStorage].user_id: - user_ticket = await self.session[TicketingService].get_user_ticket(self.session[SessionStorage].user_id) + try: + user_id = self.session[UserSession].user_id + user_ticket = await self.session[TicketingService].get_user_ticket(user_id) self.has_user_ticket = not (user_ticket is None) self.booking_button_text = "Buchen" if self.has_user_ticket else "Ticket kaufen" self.force_refresh() + except KeyError: + return async def purchase_clicked(self): if self.has_user_ticket: @@ -35,6 +37,11 @@ class SeatingPlanInfoBox(Component): self.session.navigate_to("./buy_ticket") def build(self) -> Component: + try: + user_id = self.session[UserSession].user_id + except KeyError: + user_id = None + if self.override_text: return Column(Text(self.override_text, margin=1, style=TextStyle(fill=self.session.theme.neutral_color, font_size=1.4), overflow="wrap", @@ -75,9 +82,9 @@ class SeatingPlanInfoBox(Component): grow_y=False, is_sensitive=not self.is_booking_blocked, on_press=self.purchase_clicked - ) if self.session[SessionStorage].user_id else Text(f"Du musst eingeloggt sein um einen Sitzplatz zu buchen", - margin=1, - style=TextStyle(fill=self.session.theme.neutral_color), - overflow="wrap", justify="center"), + ) if user_id is not None else Text(f"Du musst eingeloggt sein um einen Sitzplatz zu buchen", + margin=1, + style=TextStyle(fill=self.session.theme.neutral_color), + overflow="wrap", justify="center"), min_height=10 ) diff --git a/src/ezgg_lan_manager/components/SeatingPlanPixels.py b/src/ezgg_lan_manager/components/SeatingPlanPixels.py index dd4f53f..3c937f5 100644 --- a/src/ezgg_lan_manager/components/SeatingPlanPixels.py +++ b/src/ezgg_lan_manager/components/SeatingPlanPixels.py @@ -4,7 +4,7 @@ from rio import Component, Text, Icon, TextStyle, Rectangle, Spacer, Color, Poin from typing import Optional, Callable, Literal from src.ezgg_lan_manager.types.Seat import Seat -from src.ezgg_lan_manager.types.SessionStorage import SessionStorage +from src.ezgg_lan_manager.types.UserSession import UserSession class SeatPixel(Component): @@ -14,7 +14,11 @@ class SeatPixel(Component): seat_orientation: Literal["top", "bottom"] def determine_color(self) -> Color: - if self.seat.user is not None and self.seat.user.user_id == self.session[SessionStorage].user_id: + try: + user_id = self.session[UserSession].user_id + except KeyError: + user_id = None + if self.seat.user is not None and self.seat.user.user_id == user_id: return Color.from_hex("800080") elif self.seat.is_blocked or self.seat.user is not None: return self.session.theme.danger_color diff --git a/src/ezgg_lan_manager/components/ShoppingCartAndOrders.py b/src/ezgg_lan_manager/components/ShoppingCartAndOrders.py index dc51237..01c54db 100644 --- a/src/ezgg_lan_manager/components/ShoppingCartAndOrders.py +++ b/src/ezgg_lan_manager/components/ShoppingCartAndOrders.py @@ -1,5 +1,6 @@ from asyncio import sleep, create_task from decimal import Decimal +from typing import Optional from rio import Component, Column, Text, TextStyle, Button, Row, ScrollContainer, Spacer, Popup, Table, event, Card @@ -8,7 +9,7 @@ from src.ezgg_lan_manager.components.CateringOrderItem import CateringOrderItem from src.ezgg_lan_manager.services.AccountingService import AccountingService from src.ezgg_lan_manager.services.CateringService import CateringService, CateringError, CateringErrorType from src.ezgg_lan_manager.types.CateringOrder import CateringOrder, CateringMenuItemsWithAmount -from src.ezgg_lan_manager.types.SessionStorage import SessionStorage +from src.ezgg_lan_manager.types.UserSession import UserSession POPUP_CLOSE_TIMEOUT_SECONDS = 3 @@ -22,16 +23,21 @@ class ShoppingCartAndOrders(Component): @event.periodic(5) async def periodic_refresh_of_orders(self) -> None: - if not self.show_cart and not self.popup_is_shown: - self.orders = await self.session[CateringService].get_orders_for_user(self.session[SessionStorage].user_id) + user_id = self._get_user_id() + if not self.show_cart and not self.popup_is_shown and user_id is not None: + self.orders = await self.session[CateringService].get_orders_for_user(user_id) async def switch(self) -> None: self.show_cart = not self.show_cart - self.orders = await self.session[CateringService].get_orders_for_user(self.session[SessionStorage].user_id) + user_id = self._get_user_id() + if user_id is not None: + self.orders = await self.session[CateringService].get_orders_for_user(user_id) async def on_remove_item(self, list_id: int) -> None: catering_service = self.session[CateringService] - user_id = self.session[SessionStorage].user_id + user_id = self._get_user_id() + if user_id is None: + return cart = catering_service.get_cart(user_id) try: cart.pop(list_id) @@ -41,13 +47,16 @@ class ShoppingCartAndOrders(Component): self.force_refresh() async def on_empty_cart_pressed(self) -> None: - self.session[CateringService].save_cart(self.session[SessionStorage].user_id, []) + user_id = self._get_user_id() + if user_id is None: + return + self.session[CateringService].save_cart(user_id, []) self.force_refresh() async def on_add_item(self, article_id: int) -> None: catering_service = self.session[CateringService] - user_id = self.session[SessionStorage].user_id - if not user_id: + user_id = self._get_user_id() + if user_id is None: return cart = catering_service.get_cart(user_id) item_to_add = await catering_service.get_menu_item_by_id(article_id) @@ -68,7 +77,9 @@ class ShoppingCartAndOrders(Component): self.order_button_loading = True self.force_refresh() - user_id = self.session[SessionStorage].user_id + user_id = self._get_user_id() + if user_id is None: + return cart = self.session[CateringService].get_cart(user_id) show_popup_task = None if len(cart) < 1: @@ -90,7 +101,7 @@ class ShoppingCartAndOrders(Component): else: show_popup_task = create_task(self.show_popup("Unbekannter Fehler", True)) else: - self.session[CateringService].save_cart(self.session[SessionStorage].user_id, []) + self.session[CateringService].save_cart(user_id, []) self.order_button_loading = False if not show_popup_task: show_popup_task = create_task(self.show_popup("Bestellung erfolgreich aufgegeben!", False)) @@ -133,10 +144,16 @@ class ShoppingCartAndOrders(Component): ) await dialog.wait_for_close() + def _get_user_id(self) -> Optional[int]: + try: + return self.session[UserSession].user_id + except KeyError: + return None + def build(self) -> Component: - user_id = self.session[SessionStorage].user_id + user_id = self._get_user_id() catering_service = self.session[CateringService] - cart = catering_service.get_cart(user_id) + cart = catering_service.get_cart(user_id) if user_id is not None else [] if self.show_cart: cart_container = ScrollContainer( content=Column( diff --git a/src/ezgg_lan_manager/components/TicketBuyCard.py b/src/ezgg_lan_manager/components/TicketBuyCard.py index 3937e3e..86584b3 100644 --- a/src/ezgg_lan_manager/components/TicketBuyCard.py +++ b/src/ezgg_lan_manager/components/TicketBuyCard.py @@ -2,7 +2,6 @@ from functools import partial from typing import Callable, Optional from decimal import Decimal -import rio from rio import Component, Card, Column, Text, Row, Button, TextStyle, ProgressBar, event, Spacer from src.ezgg_lan_manager import TicketingService @@ -22,10 +21,10 @@ class TicketBuyCard(Component): available_tickets: int = 0 @event.on_populate - async def async_init(self) -> None: + async def on_populate(self) -> None: self.available_tickets = await self.session[TicketingService].get_available_tickets_for_category(self.category) - def build(self) -> rio.Component: + def build(self) -> Component: ticket_description_style = TextStyle( fill=self.session.theme.neutral_color, font_size=1.2, diff --git a/src/ezgg_lan_manager/components/UserEditForm.py b/src/ezgg_lan_manager/components/UserEditForm.py index 50f2780..b3aa0de 100644 --- a/src/ezgg_lan_manager/components/UserEditForm.py +++ b/src/ezgg_lan_manager/components/UserEditForm.py @@ -9,8 +9,8 @@ from rio import Component, Column, Button, Color, TextStyle, Text, TextInput, Ro from src.ezgg_lan_manager.services.UserService import UserService, NameNotAllowedError from src.ezgg_lan_manager.services.ConfigurationService import ConfigurationService -from src.ezgg_lan_manager.types.SessionStorage import SessionStorage from src.ezgg_lan_manager.types.User import User +from src.ezgg_lan_manager.types.UserSession import UserSession class UserEditForm(Component): @@ -35,8 +35,13 @@ class UserEditForm(Component): async def on_populate(self) -> None: await self.session.set_title(f"{self.session[ConfigurationService].get_lan_info().name} - Profil bearbeiten") if self.is_own_profile: - self.user = await self.session[UserService].get_user(self.session[SessionStorage].user_id) - self.profile_picture = await self.session[UserService].get_profile_picture(self.user.user_id) + try: + user_id = self.session[UserSession].user_id + except KeyError: + self.session.navigate_to("/") + else: + self.user = await self.session[UserService].get_user(user_id) + self.profile_picture = await self.session[UserService].get_profile_picture(self.user.user_id) else: self.profile_picture = await self.session[UserService].get_profile_picture(self.user.user_id) diff --git a/src/ezgg_lan_manager/components/UserInfoAndLoginBox.py b/src/ezgg_lan_manager/components/UserInfoAndLoginBox.py index dad5015..aa41a01 100644 --- a/src/ezgg_lan_manager/components/UserInfoAndLoginBox.py +++ b/src/ezgg_lan_manager/components/UserInfoAndLoginBox.py @@ -1,15 +1,18 @@ import logging +from typing import Callable from rio import Component from src.ezgg_lan_manager.components.LoginBox import LoginBox from src.ezgg_lan_manager.components.UserInfoBox import UserInfoBox -from src.ezgg_lan_manager.types.SessionStorage import SessionStorage +from src.ezgg_lan_manager.types.UserSession import UserSession logger = logging.getLogger(__name__.split(".")[-1]) class UserInfoAndLoginBox(Component): + state_changed_cb: Callable def build(self) -> Component: - if self.session[SessionStorage].user_id is None: - return LoginBox(status_change_cb=self.force_refresh) - else: - return UserInfoBox(status_change_cb=self.force_refresh) + try: + user_id = self.session[UserSession].user_id + return UserInfoBox(status_change_cb=self.state_changed_cb, user_id=user_id) + except KeyError: + return LoginBox(status_change_cb=self.state_changed_cb) diff --git a/src/ezgg_lan_manager/components/UserInfoBox.py b/src/ezgg_lan_manager/components/UserInfoBox.py index 4db6d55..9038b68 100644 --- a/src/ezgg_lan_manager/components/UserInfoBox.py +++ b/src/ezgg_lan_manager/components/UserInfoBox.py @@ -6,6 +6,7 @@ from rio import Component, TextStyle, Color, Button, Text, Rectangle, Column, Ro from src.ezgg_lan_manager.components.UserInfoBoxButton import UserInfoBoxButton from src.ezgg_lan_manager.services.LocalDataService import LocalData, LocalDataService +from src.ezgg_lan_manager.services.RefreshService import RefreshService from src.ezgg_lan_manager.services.UserService import UserService from src.ezgg_lan_manager.services.AccountingService import AccountingService from src.ezgg_lan_manager.services.TicketingService import TicketingService @@ -13,7 +14,7 @@ from src.ezgg_lan_manager.services.SeatingService import SeatingService from src.ezgg_lan_manager.types.Seat import Seat from src.ezgg_lan_manager.types.Ticket import Ticket from src.ezgg_lan_manager.types.User import User -from src.ezgg_lan_manager.types.SessionStorage import SessionStorage +from src.ezgg_lan_manager.types.UserSession import UserSession class StatusButton(Component): @@ -41,6 +42,7 @@ class StatusButton(Component): class UserInfoBox(Component): + user_id: int status_change_cb: EventHandler = None TEXT_STYLE = TextStyle(fill=Color.from_hex("02dac5"), font_size=0.9) user: Optional[User] = None @@ -53,31 +55,26 @@ class UserInfoBox(Component): return choice(["Guten Tacho", "Tuten Gag", "Servus", "Moinjour", "Hallöchen", "Heyho", "Moinsen"]) async def logout(self) -> None: - await self.session[SessionStorage].clear() + self.session.detach(UserSession) self.user = None self.session[LocalDataService].del_session(self.session[LocalData].stored_session_token) self.session[LocalData].stored_session_token = None self.session.attach(self.session[LocalData]) - self.status_change_cb() - self.session.navigate_to("/") + await self.status_change_cb() + await self.session[RefreshService].trigger_refresh() @event.on_populate async def async_init(self) -> None: - if self.session[SessionStorage].user_id: - self.user = await self.session[UserService].get_user(self.session[SessionStorage].user_id) - self.user_balance = await self.session[AccountingService].get_balance(self.user.user_id) - self.user_ticket = await self.session[TicketingService].get_user_ticket(self.user.user_id) - self.user_seat = await self.session[SeatingService].get_user_seat(self.user.user_id) - self.session[AccountingService].add_update_hook(self.update) - - async def update(self) -> None: - if not self.user: - self.user = await self.session[UserService].get_user(self.session[SessionStorage].user_id) - if not self.user: - return + self.user = await self.session[UserService].get_user(self.user_id) self.user_balance = await self.session[AccountingService].get_balance(self.user.user_id) self.user_ticket = await self.session[TicketingService].get_user_ticket(self.user.user_id) self.user_seat = await self.session[SeatingService].get_user_seat(self.user.user_id) + self.session[AccountingService].add_update_hook(self.update) + + async def update(self) -> None: + self.user_balance = await self.session[AccountingService].get_balance(self.user_id) + self.user_ticket = await self.session[TicketingService].get_user_ticket(self.user_id) + self.user_seat = await self.session[SeatingService].get_user_seat(self.user_id) def build(self) -> Component: if not self.user: diff --git a/src/ezgg_lan_manager/helpers/LoggedInGuard.py b/src/ezgg_lan_manager/helpers/LoggedInGuard.py index 1d51ca8..4d9c714 100644 --- a/src/ezgg_lan_manager/helpers/LoggedInGuard.py +++ b/src/ezgg_lan_manager/helpers/LoggedInGuard.py @@ -3,22 +3,32 @@ from typing import Optional from rio import URL, GuardEvent from src.ezgg_lan_manager.services.UserService import UserService -from src.ezgg_lan_manager.types.SessionStorage import SessionStorage +from src.ezgg_lan_manager.types.UserSession import UserSession # Guards pages against access from users that are NOT logged in def logged_in_guard(event: GuardEvent) -> Optional[URL]: - if event.session[SessionStorage].user_id is None: + try: + _ = event.session[UserSession].user_id + return None + except KeyError: return URL("./") # Guards pages against access from users that ARE logged in def not_logged_in_guard(event: GuardEvent) -> Optional[URL]: - if event.session[SessionStorage].user_id is not None: + try: + _ = event.session[UserSession].user_id return URL("./") + except KeyError: + return None # Guards pages against access from users that are NOT logged in and NOT team members def team_guard(event: GuardEvent) -> Optional[URL]: - user_id = event.session[SessionStorage].user_id - is_team_member = event.session[SessionStorage].is_team_member - if user_id is None or not is_team_member: + try: + user_id = event.session[UserSession].user_id + is_team_member = event.session[UserSession].is_team_member + if user_id and is_team_member: + return None + return URL("./") + except KeyError: return URL("./") diff --git a/src/ezgg_lan_manager/pages/Account.py b/src/ezgg_lan_manager/pages/Account.py index 8374037..7222943 100644 --- a/src/ezgg_lan_manager/pages/Account.py +++ b/src/ezgg_lan_manager/pages/Account.py @@ -1,14 +1,13 @@ from decimal import Decimal -from functools import partial from typing import Optional from rio import Column, Component, event, Text, TextStyle, Button, Color, Revealer, Row, ProgressCircle, Link from src.ezgg_lan_manager import ConfigurationService, UserService, AccountingService from src.ezgg_lan_manager.components.MainViewContentBox import MainViewContentBox -from src.ezgg_lan_manager.types.SessionStorage import SessionStorage from src.ezgg_lan_manager.types.Transaction import Transaction from src.ezgg_lan_manager.types.User import User +from src.ezgg_lan_manager.types.UserSession import UserSession class AccountPage(Component): @@ -21,9 +20,14 @@ class AccountPage(Component): @event.on_populate async def on_populate(self) -> None: await self.session.set_title(f"{self.session[ConfigurationService].get_lan_info().name} - Guthabenkonto") - self.user = await self.session[UserService].get_user(self.session[SessionStorage].user_id) - self.balance = await self.session[AccountingService].get_balance(self.user.user_id) - self.transaction_history = await self.session[AccountingService].get_transaction_history(self.user.user_id) + try: + user_id = self.session[UserSession].user_id + except KeyError: + pass + else: + self.user = await self.session[UserService].get_user(user_id) + self.balance = await self.session[AccountingService].get_balance(user_id) + self.transaction_history = await self.session[AccountingService].get_transaction_history(user_id) async def _on_banking_info_press(self) -> None: self.banking_info_revealer_open = not self.banking_info_revealer_open diff --git a/src/ezgg_lan_manager/pages/BuyTicketPage.py b/src/ezgg_lan_manager/pages/BuyTicketPage.py index 4ae6a6f..88339c4 100644 --- a/src/ezgg_lan_manager/pages/BuyTicketPage.py +++ b/src/ezgg_lan_manager/pages/BuyTicketPage.py @@ -2,14 +2,14 @@ from typing import Optional from rio import Text, Column, TextStyle, Component, event, Button, Popup -from src.ezgg_lan_manager import ConfigurationService, UserService, TicketingService +from src.ezgg_lan_manager import ConfigurationService, UserService, TicketingService, RefreshService from src.ezgg_lan_manager.components.MainViewContentBox import MainViewContentBox from src.ezgg_lan_manager.components.TicketBuyCard import TicketBuyCard from src.ezgg_lan_manager.services.AccountingService import InsufficientFundsError from src.ezgg_lan_manager.services.TicketingService import TicketNotAvailableError, UserAlreadyHasTicketError -from src.ezgg_lan_manager.types.SessionStorage import SessionStorage from src.ezgg_lan_manager.types.Ticket import Ticket from src.ezgg_lan_manager.types.User import User +from src.ezgg_lan_manager.types.UserSession import UserSession class BuyTicketPage(Component): @@ -23,12 +23,18 @@ class BuyTicketPage(Component): @event.on_populate async def on_populate(self) -> None: - self.session[SessionStorage].subscribe_to_logged_in_or_out_event(str(self.__class__), self.on_populate) + self.session[RefreshService].subscribe(self.on_populate) await self.session.set_title(f"{self.session[ConfigurationService].get_lan_info().name} - Ticket kaufen") - self.user = await self.session[UserService].get_user(self.session[SessionStorage].user_id) + try: + user_id = self.session[UserSession].user_id + except KeyError: + self.user = None + else: + self.user = await self.session[UserService].get_user(user_id) if self.user is None: # No user logged in self.is_buying_enabled = False self.is_user_logged_in = False + self.user_ticket = None else: # User is logged in self.is_user_logged_in = True possible_ticket = await self.session[TicketingService].get_user_ticket(self.user.user_id) diff --git a/src/ezgg_lan_manager/pages/CateringPage.py b/src/ezgg_lan_manager/pages/CateringPage.py index 420c989..74d40fc 100644 --- a/src/ezgg_lan_manager/pages/CateringPage.py +++ b/src/ezgg_lan_manager/pages/CateringPage.py @@ -6,8 +6,9 @@ from src.ezgg_lan_manager import ConfigurationService, CateringService from src.ezgg_lan_manager.components.CateringSelectionItem import CateringSelectionItem from src.ezgg_lan_manager.components.MainViewContentBox import MainViewContentBox from src.ezgg_lan_manager.components.ShoppingCartAndOrders import ShoppingCartAndOrders +from src.ezgg_lan_manager.services.RefreshService import RefreshService from src.ezgg_lan_manager.types.CateringMenuItem import CateringMenuItemCategory, CateringMenuItem -from src.ezgg_lan_manager.types.SessionStorage import SessionStorage +from src.ezgg_lan_manager.types.UserSession import UserSession class CateringPage(Component): @@ -15,11 +16,9 @@ class CateringPage(Component): all_menu_items: Optional[list[CateringMenuItem]] = None shopping_cart_and_orders: list[ShoppingCartAndOrders] = [] - def __post_init__(self) -> None: - self.session[SessionStorage].subscribe_to_logged_in_or_out_event(self.__class__.__name__, self.on_user_logged_in_status_changed) - @event.on_populate async def on_populate(self) -> None: + self.session[RefreshService].subscribe(self.on_populate) await self.session.set_title(f"{self.session[ConfigurationService].get_lan_info().name} - Catering") self.all_menu_items = await self.session[CateringService].get_menu() @@ -34,7 +33,10 @@ class CateringPage(Component): return list(filter(lambda item: item.category == category, all_menu_items)) def build(self) -> Component: - user_id = self.session[SessionStorage].user_id + try: + user_id = self.session[UserSession].user_id + except KeyError: + user_id = None if len(self.shopping_cart_and_orders) == 0: self.shopping_cart_and_orders.append(ShoppingCartAndOrders()) if len(self.shopping_cart_and_orders) > 1: diff --git a/src/ezgg_lan_manager/pages/ContactPage.py b/src/ezgg_lan_manager/pages/ContactPage.py index de25100..1d573fe 100644 --- a/src/ezgg_lan_manager/pages/ContactPage.py +++ b/src/ezgg_lan_manager/pages/ContactPage.py @@ -5,8 +5,8 @@ from rio import Text, Column, TextStyle, Component, event, TextInput, MultiLineT from src.ezgg_lan_manager import ConfigurationService, UserService, MailingService from src.ezgg_lan_manager.components.MainViewContentBox import MainViewContentBox -from src.ezgg_lan_manager.types.SessionStorage import SessionStorage from src.ezgg_lan_manager.types.User import User +from src.ezgg_lan_manager.types.UserSession import UserSession class ContactPage(Component): @@ -25,9 +25,9 @@ class ContactPage(Component): @event.on_populate async def on_populate(self) -> None: await self.session.set_title(f"{self.session[ConfigurationService].get_lan_info().name} - Kontakt") - if self.session[SessionStorage].user_id is not None: - self.user = await self.session[UserService].get_user(self.session[SessionStorage].user_id) - else: + try: + self.user = await self.session[UserService].get_user(self.session[UserSession].user_id) + except KeyError: self.user = None self.e_mail = "" if not self.user else self.user.user_mail diff --git a/src/ezgg_lan_manager/pages/EditProfile.py b/src/ezgg_lan_manager/pages/EditProfile.py index 2d86c9a..516bcf9 100644 --- a/src/ezgg_lan_manager/pages/EditProfile.py +++ b/src/ezgg_lan_manager/pages/EditProfile.py @@ -1,23 +1,14 @@ -from typing import Optional - from rio import Column, Component, event, Spacer -from src.ezgg_lan_manager import ConfigurationService, UserService +from src.ezgg_lan_manager import ConfigurationService from src.ezgg_lan_manager.components.MainViewContentBox import MainViewContentBox from src.ezgg_lan_manager.components.UserEditForm import UserEditForm -from src.ezgg_lan_manager.types.SessionStorage import SessionStorage -from src.ezgg_lan_manager.types.User import User class EditProfilePage(Component): - user: Optional[User] = None - pfp: Optional[bytes] = None - @event.on_populate async def on_populate(self) -> None: await self.session.set_title(f"{self.session[ConfigurationService].get_lan_info().name} - Profil bearbeiten") - self.user = await self.session[UserService].get_user(self.session[SessionStorage].user_id) - self.pfp = await self.session[UserService].get_profile_picture(self.user.user_id) def build(self) -> Component: return Column( diff --git a/src/ezgg_lan_manager/pages/ManageUsersPage.py b/src/ezgg_lan_manager/pages/ManageUsersPage.py index c62896f..6632739 100644 --- a/src/ezgg_lan_manager/pages/ManageUsersPage.py +++ b/src/ezgg_lan_manager/pages/ManageUsersPage.py @@ -11,9 +11,9 @@ from src.ezgg_lan_manager.components.MainViewContentBox import MainViewContentBo from src.ezgg_lan_manager.components.NewTransactionForm import NewTransactionForm from src.ezgg_lan_manager.components.UserEditForm import UserEditForm from src.ezgg_lan_manager.services.AccountingService import InsufficientFundsError -from src.ezgg_lan_manager.types.SessionStorage import SessionStorage from src.ezgg_lan_manager.types.Transaction import Transaction from src.ezgg_lan_manager.types.User import User +from src.ezgg_lan_manager.types.UserSession import UserSession logger = logging.getLogger(__name__.split(".")[-1]) @@ -84,7 +84,11 @@ class ManageUsersPage(Component): await self.session[UserService].update_user(self.selected_user) async def on_new_transaction(self, transaction: Transaction) -> None: - if not self.session[SessionStorage].is_team_member: # Better safe than sorry + try: + user = await self.session[UserService].get_user(self.session[UserSession].user_id) + if not user.is_team_member: # Better safe than sorry + return + except KeyError: return logger.info(f"Got new transaction for user with ID '{transaction.user_id}' over " diff --git a/src/ezgg_lan_manager/pages/SeatingPlanPage.py b/src/ezgg_lan_manager/pages/SeatingPlanPage.py index 08a7ccf..12b44ee 100644 --- a/src/ezgg_lan_manager/pages/SeatingPlanPage.py +++ b/src/ezgg_lan_manager/pages/SeatingPlanPage.py @@ -12,8 +12,8 @@ from src.ezgg_lan_manager.components.SeatingPlanInfoBox import SeatingPlanInfoBo from src.ezgg_lan_manager.components.SeatingPurchaseBox import SeatingPurchaseBox from src.ezgg_lan_manager.services.SeatingService import NoTicketError, SeatNotFoundError, WrongCategoryError, SeatAlreadyTakenError from src.ezgg_lan_manager.types.Seat import Seat -from src.ezgg_lan_manager.types.SessionStorage import SessionStorage from src.ezgg_lan_manager.types.User import User +from src.ezgg_lan_manager.types.UserSession import UserSession logger = logging.getLogger(__name__.split(".")[-1]) @@ -37,7 +37,10 @@ class SeatingPlanPage(Component): async def on_populate(self) -> None: await self.session.set_title(f"{self.session[ConfigurationService].get_lan_info().name} - Sitzplan") self.seating_info = await self.session[SeatingService].get_seating() - self.user = await self.session[UserService].get_user(self.session[SessionStorage].user_id) + try: + self.user = await self.session[UserService].get_user(self.session[UserSession].user_id) + except KeyError: + self.user = None if not self.user: self.is_booking_blocked = True else: diff --git a/src/ezgg_lan_manager/pages/TeamsPage.py b/src/ezgg_lan_manager/pages/TeamsPage.py index b6c94a6..9eb90bc 100644 --- a/src/ezgg_lan_manager/pages/TeamsPage.py +++ b/src/ezgg_lan_manager/pages/TeamsPage.py @@ -6,11 +6,12 @@ from src.ezgg_lan_manager import ConfigurationService from src.ezgg_lan_manager.components.MainViewContentBox import MainViewContentBox from src.ezgg_lan_manager.components.TeamRevealer import TeamRevealer from src.ezgg_lan_manager.components.TeamsDialogHandler import * +from src.ezgg_lan_manager.services.RefreshService import RefreshService from src.ezgg_lan_manager.services.TeamService import TeamService from src.ezgg_lan_manager.services.UserService import UserService -from src.ezgg_lan_manager.types.SessionStorage import SessionStorage from src.ezgg_lan_manager.types.Team import Team from src.ezgg_lan_manager.types.User import User +from src.ezgg_lan_manager.types.UserSession import UserSession class TeamsPage(Component): @@ -26,10 +27,13 @@ class TeamsPage(Component): @event.on_populate async def on_populate(self) -> None: - self.all_teams = await self.session[TeamService].get_all_teams() - self.user = await self.session[UserService].get_user(self.session[SessionStorage].user_id) await self.session.set_title(f"{self.session[ConfigurationService].get_lan_info().name} - Teams") - self.session[SessionStorage].subscribe_to_logged_in_or_out_event(str(self.__class__), self.on_populate) + self.session[RefreshService].subscribe(self.on_populate) + self.all_teams = await self.session[TeamService].get_all_teams() + try: + self.user = await self.session[UserService].get_user(self.session[UserSession].user_id) + except KeyError: + self.user = None async def on_join_button_pressed(self, team: Team) -> None: if self.user is None: diff --git a/src/ezgg_lan_manager/pages/TournamentDetailsPage.py b/src/ezgg_lan_manager/pages/TournamentDetailsPage.py index d852a30..b2c3f9f 100644 --- a/src/ezgg_lan_manager/pages/TournamentDetailsPage.py +++ b/src/ezgg_lan_manager/pages/TournamentDetailsPage.py @@ -11,11 +11,11 @@ from src.ezgg_lan_manager import ConfigurationService, TournamentService, UserSe from src.ezgg_lan_manager.components.MainViewContentBox import MainViewContentBox from src.ezgg_lan_manager.components.TournamentDetailsInfoRow import TournamentDetailsInfoRow from src.ezgg_lan_manager.types.DateUtil import weekday_to_display_text -from src.ezgg_lan_manager.types.SessionStorage import SessionStorage from src.ezgg_lan_manager.types.Team import Team, TeamStatus from src.ezgg_lan_manager.types.Tournament import Tournament from src.ezgg_lan_manager.types.TournamentBase import TournamentStatus, tournament_status_to_display_text, tournament_format_to_display_texts, ParticipantType from src.ezgg_lan_manager.types.User import User +from src.ezgg_lan_manager.types.UserSession import UserSession logger = logging.getLogger(__name__.split(".")[-1]) @@ -53,9 +53,13 @@ class TournamentDetailsPage(Component): else: await self.session.set_title(f"{self.session[ConfigurationService].get_lan_info().name} - Turniere") - self.user = await self.session[UserService].get_user(self.session[SessionStorage].user_id) - if self.user is not None: - self.user_teams = await self.session[TeamService].get_teams_for_user_by_id(self.user.user_id) + try: + user_id = self.session[UserSession].user_id + self.user = await self.session[UserService].get_user(user_id) + self.user_teams = await self.session[TeamService].get_teams_for_user_by_id(user_id) + except KeyError: + self.user = None + self.user_teams = [] self.loading_done() @@ -167,7 +171,6 @@ class TournamentDetailsPage(Component): def loading_done(self) -> None: if self.tournament is None: self.tournament = "Turnier konnte nicht gefunden werden" - self.session[SessionStorage].subscribe_to_logged_in_or_out_event(str(self.__class__), self.on_populate) def build(self) -> Component: if self.tournament is None: diff --git a/src/ezgg_lan_manager/services/LocalDataService.py b/src/ezgg_lan_manager/services/LocalDataService.py index 993bfcb..a8fb882 100644 --- a/src/ezgg_lan_manager/services/LocalDataService.py +++ b/src/ezgg_lan_manager/services/LocalDataService.py @@ -3,20 +3,20 @@ from typing import Optional from rio import UserSettings -from src.ezgg_lan_manager.types.SessionStorage import SessionStorage +from src.ezgg_lan_manager.types.UserSession import UserSession class LocalData(UserSettings): - stored_session_token: Optional[str] = None + stored_session_token: Optional[str] class LocalDataService: def __init__(self) -> None: - self._session: dict[str, SessionStorage] = {} + self._session: dict[str, UserSession] = {} - def verify_token(self, token: str) -> Optional[SessionStorage]: + def verify_token(self, token: str) -> Optional[UserSession]: return self._session.get(token) - def set_session(self, session: SessionStorage) -> str: + def set_session(self, session: UserSession) -> str: key = secrets.token_hex(32) self._session[key] = session return key diff --git a/src/ezgg_lan_manager/services/RefreshService.py b/src/ezgg_lan_manager/services/RefreshService.py new file mode 100644 index 0000000..baaef84 --- /dev/null +++ b/src/ezgg_lan_manager/services/RefreshService.py @@ -0,0 +1,17 @@ +from typing import Callable + + +class RefreshService: + """ + rio.Components can subscribe to this service with their on_populate method. + Those methods get called whenever a overall refresh is needed. Usually when the user logs in or out. + """ + def __init__(self) -> None: + self.subscribers: set[Callable] = set() + + def subscribe(self, refresh_cb: Callable) -> None: + self.subscribers.add(refresh_cb) + + async def trigger_refresh(self) -> None: + for refresh_cb in self.subscribers: + await refresh_cb() diff --git a/src/ezgg_lan_manager/types/SessionStorage.py b/src/ezgg_lan_manager/types/SessionStorage.py deleted file mode 100644 index 206cdd2..0000000 --- a/src/ezgg_lan_manager/types/SessionStorage.py +++ /dev/null @@ -1,36 +0,0 @@ -import logging -from collections.abc import Callable -from dataclasses import dataclass, field -from typing import Optional - -logger = logging.getLogger(__name__.split(".")[-1]) - - -# ToDo: Persist between reloads: https://rio.dev/docs/howto/persistent-settings -# Note for ToDo: rio.UserSettings are saved LOCALLY, do not just read a user_id here! -@dataclass(frozen=False) -class SessionStorage: - _user_id: Optional[int] = None # DEBUG: Put user ID here to skip login - _is_team_member: bool = False - _notification_callbacks: dict[str, Callable] = field(default_factory=dict) - - async def clear(self) -> None: - await self.set_user_id_and_team_member_flag(None, False) - - def subscribe_to_logged_in_or_out_event(self, component_id: str, callback: Callable) -> None: - self._notification_callbacks[component_id] = callback - - @property - def user_id(self) -> Optional[int]: - return self._user_id - - @property - def is_team_member(self) -> bool: - return self._is_team_member - - async def set_user_id_and_team_member_flag(self, user_id: Optional[int], is_team_member: bool) -> None: - self._user_id = user_id - self._is_team_member = is_team_member - for component_id, callback in self._notification_callbacks.items(): - logger.debug(f"Calling logged in callback from {component_id}") - await callback() diff --git a/src/ezgg_lan_manager/types/UserSession.py b/src/ezgg_lan_manager/types/UserSession.py new file mode 100644 index 0000000..52e4e4f --- /dev/null +++ b/src/ezgg_lan_manager/types/UserSession.py @@ -0,0 +1,9 @@ +from dataclasses import dataclass +from uuid import UUID + + +@dataclass +class UserSession: + id: UUID + user_id: int + is_team_member: bool