From 30b32a4c025390242ebabdaef933214a96c58c03 Mon Sep 17 00:00:00 2001 From: David Rodenkirchen Date: Tue, 3 Sep 2024 14:30:32 +0200 Subject: [PATCH] aiomysql refactor --- requirements.txt | Bin 170 -> 190 bytes src/EzLanManager.py | 17 +- .../components/CateringOrderItem.py | 4 +- .../components/DesktopNavigation.py | 8 +- .../components/ShoppingCartAndOrders.py | 9 +- .../components/UserInfoAndLoginBox.py | 51 +- src/ez_lan_manager/pages/Account.py | 39 +- src/ez_lan_manager/pages/BasePage.py | 7 +- src/ez_lan_manager/pages/CateringPage.py | 56 +- src/ez_lan_manager/pages/ContactPage.py | 14 +- src/ez_lan_manager/pages/DbErrorPage.py | 12 +- src/ez_lan_manager/pages/EditProfile.py | 43 +- src/ez_lan_manager/pages/ForgotPassword.py | 4 +- src/ez_lan_manager/pages/GuestsPage.py | 31 +- src/ez_lan_manager/pages/NewsPage.py | 2 +- src/ez_lan_manager/pages/RegisterPage.py | 4 +- .../services/AccountingService.py | 22 +- .../services/CateringService.py | 70 +- .../services/DatabaseService.py | 1149 +++++++++-------- src/ez_lan_manager/services/NewsService.py | 15 +- src/ez_lan_manager/services/SeatingService.py | 34 +- .../services/TicketingService.py | 30 +- src/ez_lan_manager/services/UserService.py | 32 +- src/ez_lan_manager/types/User.py | 3 + 24 files changed, 901 insertions(+), 755 deletions(-) diff --git a/requirements.txt b/requirements.txt index ffb0c349e554dd2e41200e45af8bc020c7172e8e..54b23df4e6b70533b8f4a90e613f680d0a59643a 100644 GIT binary patch delta 60 zcmZ3*xQ}synsg#VCPO|$E<+_lF+(9k4nrM-ErS7r9)l5s-b6=XVOuEI5-4QMP{feT JU@$Q_8~}))3(f!l delta 40 pcmdnTxQcOtnshEhB0~{FCXh^FNMfjCuw^i0(1YNKj=~c&!T`W52 None: + init_result = await a.default_attachments[3].init_db_pool() + if not init_result: + logger.fatal("Could not connect to database, exiting...") + sys.exit(1) + app = App( name="EZ LAN Manager", pages=[ @@ -138,6 +140,7 @@ if __name__ == "__main__": assets_dir=Path(__file__).parent / "assets", default_attachments=services, on_session_start=on_session_start, + on_app_start=on_app_start, icon=from_root("src/ez_lan_manager/assets/img/favicon.png"), meta_tags={ "robots": "INDEX,FOLLOW", diff --git a/src/ez_lan_manager/components/CateringOrderItem.py b/src/ez_lan_manager/components/CateringOrderItem.py index 4018c1a..f82a49e 100644 --- a/src/ez_lan_manager/components/CateringOrderItem.py +++ b/src/ez_lan_manager/components/CateringOrderItem.py @@ -1,10 +1,8 @@ from datetime import datetime -from typing import Callable import rio -from rio import Component, Row, Text, IconButton, TextStyle, Color +from rio import Component, Row, Text, TextStyle, Color -from src.ez_lan_manager import AccountingService from src.ez_lan_manager.types.CateringOrder import CateringOrderStatus MAX_LEN = 24 diff --git a/src/ez_lan_manager/components/DesktopNavigation.py b/src/ez_lan_manager/components/DesktopNavigation.py index dbcb381..ae86f70 100644 --- a/src/ez_lan_manager/components/DesktopNavigation.py +++ b/src/ez_lan_manager/components/DesktopNavigation.py @@ -6,19 +6,13 @@ from src.ez_lan_manager.components.UserInfoAndLoginBox import UserInfoAndLoginBo from src.ez_lan_manager.types.SessionStorage import SessionStorage class DesktopNavigation(Component): - def __post_init__(self) -> None: - self.session[SessionStorage].subscribe_to_logged_in_or_out_event(self.__class__.__name__, self.refresh_cb) - - async def refresh_cb(self) -> None: - await self.force_refresh() - def build(self) -> Component: lan_info = self.session[ConfigurationService].get_lan_info() return Card( Column( Text(lan_info.name, align_x=0.5, margin_top=0.3, style=TextStyle(fill=self.session.theme.hud_color, font_size=2.5)), Text(f"Edition {lan_info.iteration}", align_x=0.5, style=TextStyle(fill=self.session.theme.hud_color, font_size=1.2), margin_top=0.3, margin_bottom=2), - UserInfoAndLoginBox(refresh_cb=self.refresh_cb), + UserInfoAndLoginBox(), DesktopNavigationButton("News", "./news"), Spacer(min_height=1), DesktopNavigationButton(f"Über {lan_info.name} {lan_info.iteration}", "./overview"), diff --git a/src/ez_lan_manager/components/ShoppingCartAndOrders.py b/src/ez_lan_manager/components/ShoppingCartAndOrders.py index 1639626..b07d10a 100644 --- a/src/ez_lan_manager/components/ShoppingCartAndOrders.py +++ b/src/ez_lan_manager/components/ShoppingCartAndOrders.py @@ -1,3 +1,5 @@ +from typing import Optional + import rio from rio import Component, Column, Text, TextStyle, Button, Row, ScrollContainer, Spacer @@ -11,9 +13,11 @@ from src.ez_lan_manager.types.SessionStorage import SessionStorage class ShoppingCartAndOrders(Component): show_cart: bool = True + orders: list[CateringOrder] = [] async def switch(self) -> None: self.show_cart = not self.show_cart + self.orders = await self.session[CateringService].get_orders_for_user(self.session[SessionStorage].user_id) async def on_remove_item(self, list_id: int) -> None: catering_service = self.session[CateringService] @@ -36,7 +40,7 @@ class ShoppingCartAndOrders(Component): if not user_id: return cart = catering_service.get_cart(user_id) - cart.append(catering_service.get_menu_item_by_id(article_id)) + cart.append(await catering_service.get_menu_item_by_id(article_id)) catering_service.save_cart(user_id, cart) await self.force_refresh() @@ -99,14 +103,13 @@ class ShoppingCartAndOrders(Component): ) ) else: - orders = catering_service.get_orders_for_user(user_id) orders_container = ScrollContainer( content=Column( *[CateringOrderItem( order_id=order_item.order_id, order_datetime=order_item.order_date, order_status=order_item.status, - ) for order_item in orders], + ) for order_item in self.orders], Spacer(grow_y=True) ), min_height=8, diff --git a/src/ez_lan_manager/components/UserInfoAndLoginBox.py b/src/ez_lan_manager/components/UserInfoAndLoginBox.py index 769cd2d..9feccd1 100644 --- a/src/ez_lan_manager/components/UserInfoAndLoginBox.py +++ b/src/ez_lan_manager/components/UserInfoAndLoginBox.py @@ -1,15 +1,17 @@ import logging from random import choice -from typing import Callable +from typing import Optional -from rio import Component, Column, Text, Row, Rectangle, Button, TextStyle, Color, Spacer, TextInput, Link +from rio import Component, Column, Text, Row, Rectangle, Button, TextStyle, Color, Spacer, TextInput, Link, event from src.ez_lan_manager import UserService from src.ez_lan_manager.components.UserInfoBoxButton import UserInfoBoxButton from src.ez_lan_manager.services.AccountingService import AccountingService -from src.ez_lan_manager.services.DatabaseService import NoDatabaseConnectionError, DatabaseService from src.ez_lan_manager.services.TicketingService import TicketingService from src.ez_lan_manager.services.SeatingService import SeatingService +from src.ez_lan_manager.types.Seat import Seat +from src.ez_lan_manager.types.Ticket import Ticket +from src.ez_lan_manager.types.User import User from src.ez_lan_manager.types.SessionStorage import SessionStorage logger = logging.getLogger(__name__.split(".")[-1]) @@ -39,9 +41,20 @@ class StatusButton(Component): class UserInfoAndLoginBox(Component): - refresh_cb: Callable TEXT_STYLE = TextStyle(fill=Color.from_hex("02dac5"), font_size=0.9) show_login: bool = True + user: Optional[User] = None + user_balance: Optional[int] = 0 + user_ticket: Optional[Ticket] = None + user_seat: Optional[Seat] = None + + @event.on_populate + async def async_init(self) -> None: + if self.session[SessionStorage].user_id: + self.user = await self.session[UserService].get_user(self.session[SessionStorage].user_id) + self.user_balance = await self.session[AccountingService].get_balance(self.user.user_id) + self.user_ticket = await self.session[TicketingService].get_user_ticket(self.user.user_id) + self.user_seat = await self.session[SeatingService].get_user_seat(self.user.user_id) @staticmethod def get_greeting() -> str: @@ -64,13 +77,13 @@ class UserInfoAndLoginBox(Component): async def _on_login_pressed(self) -> None: user_name = self.user_name_input.text.lower() - if self.session[UserService].is_login_valid(user_name, self.password_input.text): + if await self.session[UserService].is_login_valid(user_name, self.password_input.text): self.user_name_input.is_valid = True self.password_input.is_valid = True self.login_button.is_loading = False - await self.session[SessionStorage].set_user_id(self.session[UserService].get_user(user_name).user_id) + await self.session[SessionStorage].set_user_id((await self.session[UserService].get_user(user_name)).user_id) + await self.async_init() self.show_login = False - await self.refresh_cb() else: self.user_name_input.is_valid = False self.password_input.is_valid = False @@ -114,7 +127,7 @@ class UserInfoAndLoginBox(Component): on_press=lambda: self.session.navigate_to("./forgot-password") ) - if self.show_login and self.session[SessionStorage].user_id is None: + if self.user is None and self.session[SessionStorage].user_id is None: return Rectangle( content=Column( self.user_name_input, @@ -139,25 +152,31 @@ class UserInfoAndLoginBox(Component): margin_top=0.3, margin_bottom=2 ) + elif self.user is None and self.session[SessionStorage].user_id is not None: + return Rectangle( + content=Column(), + fill=Color.TRANSPARENT, + min_height=8, + min_width=12, + align_x=0.5, + margin_top=0.3, + margin_bottom=2 + ) else: - user = self.session[UserService].get_user(self.session[SessionStorage].user_id) - if user is None: - logger.warning("User could not be found, this should not have happend.") - a_s = self.session[AccountingService] return Rectangle( content=Column( Text(f"{self.get_greeting()},", style=TextStyle(fill=Color.from_hex("02dac5"), font_size=0.9), justify="center"), - Text(f"{user.user_name}", style=TextStyle(fill=Color.from_hex("02dac5"), font_size=1.2), justify="center"), + Text(f"{self.user.user_name}", style=TextStyle(fill=Color.from_hex("02dac5"), font_size=1.2), justify="center"), Row( StatusButton(label="TICKET", target_url="./buy_ticket", - enabled=self.session[TicketingService].get_user_ticket(user.user_id) is not None), + enabled=self.user_ticket is not None), StatusButton(label="SITZPLATZ", target_url="./seating", - enabled=self.session[SeatingService].get_user_seat(user.user_id) is not None), + enabled=self.user_seat is not None), proportions=(50, 50), grow_y=False ), UserInfoBoxButton("Profil bearbeiten", "./edit-profile"), - UserInfoBoxButton(f"Guthaben: {a_s.make_euro_string_from_int(a_s.get_balance(user.user_id))}", "./account"), + UserInfoBoxButton(f"Guthaben: {self.session[AccountingService].make_euro_string_from_int(self.user_balance)}", "./account"), Button( content=Text("Ausloggen", style=TextStyle(fill=Color.from_hex("02dac5"), font_size=0.6)), shape="rectangle", diff --git a/src/ez_lan_manager/pages/Account.py b/src/ez_lan_manager/pages/Account.py index 5f3adbf..0006531 100644 --- a/src/ez_lan_manager/pages/Account.py +++ b/src/ez_lan_manager/pages/Account.py @@ -1,22 +1,47 @@ -from rio import Column, Component, event, Text, TextStyle, Button, Color, Spacer, Revealer, Row +from asyncio import sleep +from typing import Optional + +from rio import Column, Component, event, Text, TextStyle, Button, Color, Spacer, Revealer, Row, ProgressCircle from src.ez_lan_manager import ConfigurationService, UserService, AccountingService from src.ez_lan_manager.components.MainViewContentBox import MainViewContentBox from src.ez_lan_manager.pages import BasePage from src.ez_lan_manager.types.SessionStorage import SessionStorage +from src.ez_lan_manager.types.Transaction import Transaction +from src.ez_lan_manager.types.User import User class AccountPage(Component): + user: Optional[User] = None + balance: Optional[int] = None + transaction_history: list[Transaction] = list() + @event.on_populate async def on_populate(self) -> None: await self.session.set_title(f"{self.session[ConfigurationService].get_lan_info().name} - Guthabenkonto") + self.user = await self.session[UserService].get_user(self.session[SessionStorage].user_id) + self.balance = await self.session[AccountingService].get_balance(self.user.user_id) + self.transaction_history = await self.session[AccountingService].get_transaction_history(self.user.user_id) async def _on_banking_info_press(self): self.banking_info_revealer.is_open = not self.banking_info_revealer.is_open def build(self) -> Component: - user = self.session[UserService].get_user(self.session[SessionStorage].user_id) - a_s = self.session[AccountingService] + if not self.user and not self.balance: + return BasePage( + content=Column( + MainViewContentBox( + ProgressCircle( + color="secondary", + align_x=0.5, + margin_top=2, + margin_bottom=2 + ) + ), + align_y = 0, + ) + ) + self.banking_info_revealer = Revealer( header=None, content=Column( @@ -45,7 +70,7 @@ class AccountPage(Component): align_x=0.2 ), Text( - f"AUFLADUNG - {user.user_id} - {user.user_name}", + f"AUFLADUNG - {self.user.user_id} - {self.user.user_name}", style=TextStyle( fill=self.session.theme.neutral_color ), @@ -73,7 +98,7 @@ class AccountPage(Component): ) ) - for transaction in sorted(self.session[AccountingService].get_transaction_history(user.user_id), key=lambda t: t.transaction_date, reverse=True): + for transaction in sorted(self.transaction_history, key=lambda t: t.transaction_date, reverse=True): transaction_history.add( Row( Text( @@ -89,7 +114,7 @@ class AccountPage(Component): align_x=0 ), Text( - f"{'-' if transaction.is_debit else '+'}{a_s.make_euro_string_from_int(transaction.value)}", + f"{'-' if transaction.is_debit else '+'}{AccountingService.make_euro_string_from_int(transaction.value)}", style=TextStyle( fill=self.session.theme.danger_color if transaction.is_debit else self.session.theme.success_color, font_size=0.8 @@ -106,7 +131,7 @@ class AccountPage(Component): content=Column( MainViewContentBox( content=Text( - f"Kontostand: {a_s.make_euro_string_from_int(a_s.get_balance(user.user_id))}", + f"Kontostand: {AccountingService.make_euro_string_from_int(self.balance)}", style=TextStyle( fill=self.session.theme.background_color, font_size=1.2 diff --git a/src/ez_lan_manager/pages/BasePage.py b/src/ez_lan_manager/pages/BasePage.py index a89efb4..f179ad3 100644 --- a/src/ez_lan_manager/pages/BasePage.py +++ b/src/ez_lan_manager/pages/BasePage.py @@ -4,7 +4,7 @@ from typing import * # type: ignore from rio import Component, event, Spacer, Card, Container, Column, Row, TextStyle, Color, Text -from src.ez_lan_manager import ConfigurationService, DatabaseService +from src.ez_lan_manager import ConfigurationService from src.ez_lan_manager.components.DesktopNavigation import DesktopNavigation class BasePage(Component): @@ -14,11 +14,6 @@ class BasePage(Component): async def on_window_size_change(self): await self.force_refresh() - @event.on_populate - async def check_db_connection(self): - if not self.session[DatabaseService].is_connected: - self.session.navigate_to("./db-error") - def build(self) -> Component: if self.content is None: content = Spacer() diff --git a/src/ez_lan_manager/pages/CateringPage.py b/src/ez_lan_manager/pages/CateringPage.py index 09af45f..ee87e56 100644 --- a/src/ez_lan_manager/pages/CateringPage.py +++ b/src/ez_lan_manager/pages/CateringPage.py @@ -1,16 +1,19 @@ -from rio import Column, Component, event, TextStyle, Text, Spacer, Revealer, SwitcherBar, SwitcherBarChangeEvent +from typing import Optional + +from rio import Column, Component, event, TextStyle, Text, Spacer, Revealer, SwitcherBar, SwitcherBarChangeEvent, ProgressCircle from src.ez_lan_manager import ConfigurationService, CateringService from src.ez_lan_manager.components.CateringSelectionItem import CateringSelectionItem from src.ez_lan_manager.components.MainViewContentBox import MainViewContentBox from src.ez_lan_manager.components.ShoppingCartAndOrders import ShoppingCartAndOrders from src.ez_lan_manager.pages import BasePage -from src.ez_lan_manager.types.CateringMenuItem import CateringMenuItemCategory +from src.ez_lan_manager.types.CateringMenuItem import CateringMenuItemCategory, CateringMenuItem from src.ez_lan_manager.types.SessionStorage import SessionStorage class CateringPage(Component): show_cart = True + all_menu_items: Optional[list[CateringMenuItem]] = None def __post_init__(self) -> None: self.session[SessionStorage].subscribe_to_logged_in_or_out_event(self.__class__.__name__, self.on_user_logged_in_status_changed) @@ -18,6 +21,8 @@ class CateringPage(Component): @event.on_populate async def on_populate(self) -> None: await self.session.set_title(f"{self.session[ConfigurationService].get_lan_info().name} - Catering") + self.all_menu_items = await self.session[CateringService].get_menu() + async def on_user_logged_in_status_changed(self) -> None: await self.force_refresh() @@ -25,9 +30,13 @@ class CateringPage(Component): async def on_switcher_bar_changed(self, _: SwitcherBarChangeEvent) -> None: await self.shopping_cart_and_orders.switch() + @staticmethod + def get_menu_items_by_category(all_menu_items: list[CateringMenuItem], category: Optional[CateringMenuItemCategory]) -> list[CateringMenuItem]: + return list(filter(lambda item: item.category == category, all_menu_items)) + + def build(self) -> Component: user_id = self.session[SessionStorage].user_id - catering_service = self.session[CateringService] self.shopping_cart_and_orders = ShoppingCartAndOrders() switcher_bar = SwitcherBar( values=["cart", "orders"], @@ -58,12 +67,14 @@ class CateringPage(Component): ) ) if user_id else Spacer() - return BasePage( - content=Column( - # SHOPPING CART - shopping_cart_and_orders_container, - # ITEM SELECTION - MainViewContentBox( + menu = [MainViewContentBox( + ProgressCircle( + color="secondary", + align_x=0.5, + margin_top=2, + margin_bottom=2 + ) + )] if not self.all_menu_items else [MainViewContentBox( Revealer( header="Snacks", content=Column( @@ -75,7 +86,7 @@ class CateringPage(Component): is_sensitive=(user_id is not None) and not catering_menu_item.is_disabled, additional_info=catering_menu_item.additional_info, is_grey=idx % 2 == 0 - ) for idx, catering_menu_item in enumerate(catering_service.get_menu(CateringMenuItemCategory.SNACK))], + ) for idx, catering_menu_item in enumerate(self.get_menu_items_by_category(self.all_menu_items, CateringMenuItemCategory.SNACK))], ), header_style=TextStyle( fill=self.session.theme.background_color, @@ -97,7 +108,7 @@ class CateringPage(Component): is_sensitive=(user_id is not None) and not catering_menu_item.is_disabled, additional_info=catering_menu_item.additional_info, is_grey=idx % 2 == 0 - ) for idx, catering_menu_item in enumerate(catering_service.get_menu(CateringMenuItemCategory.BREAKFAST))], + ) for idx, catering_menu_item in enumerate(self.get_menu_items_by_category(self.all_menu_items, CateringMenuItemCategory.BREAKFAST))], ), header_style=TextStyle( fill=self.session.theme.background_color, @@ -119,7 +130,7 @@ class CateringPage(Component): is_sensitive=(user_id is not None) and not catering_menu_item.is_disabled, additional_info=catering_menu_item.additional_info, is_grey=idx % 2 == 0 - ) for idx, catering_menu_item in enumerate(catering_service.get_menu(CateringMenuItemCategory.MAIN_COURSE))], + ) for idx, catering_menu_item in enumerate(self.get_menu_items_by_category(self.all_menu_items, CateringMenuItemCategory.MAIN_COURSE))], ), header_style=TextStyle( fill=self.session.theme.background_color, @@ -141,7 +152,7 @@ class CateringPage(Component): is_sensitive=(user_id is not None) and not catering_menu_item.is_disabled, additional_info=catering_menu_item.additional_info, is_grey=idx % 2 == 0 - ) for idx, catering_menu_item in enumerate(catering_service.get_menu(CateringMenuItemCategory.DESSERT))], + ) for idx, catering_menu_item in enumerate(self.get_menu_items_by_category(self.all_menu_items, CateringMenuItemCategory.DESSERT))], ), header_style=TextStyle( fill=self.session.theme.background_color, @@ -163,7 +174,7 @@ class CateringPage(Component): is_sensitive=(user_id is not None) and not catering_menu_item.is_disabled, additional_info=catering_menu_item.additional_info, is_grey=idx % 2 == 0 - ) for idx, catering_menu_item in enumerate(catering_service.get_menu(CateringMenuItemCategory.BEVERAGE_NON_ALCOHOLIC))], + ) for idx, catering_menu_item in enumerate(self.get_menu_items_by_category(self.all_menu_items, CateringMenuItemCategory.BEVERAGE_NON_ALCOHOLIC))], ), header_style=TextStyle( fill=self.session.theme.background_color, @@ -185,7 +196,7 @@ class CateringPage(Component): is_sensitive=(user_id is not None) and not catering_menu_item.is_disabled, additional_info=catering_menu_item.additional_info, is_grey=idx % 2 == 0 - ) for idx, catering_menu_item in enumerate(catering_service.get_menu(CateringMenuItemCategory.BEVERAGE_ALCOHOLIC))], + ) for idx, catering_menu_item in enumerate(self.get_menu_items_by_category(self.all_menu_items, CateringMenuItemCategory.BEVERAGE_ALCOHOLIC))], ), header_style=TextStyle( fill=self.session.theme.background_color, @@ -207,7 +218,7 @@ class CateringPage(Component): is_sensitive=(user_id is not None) and not catering_menu_item.is_disabled, additional_info=catering_menu_item.additional_info, is_grey=idx % 2 == 0 - ) for idx, catering_menu_item in enumerate(catering_service.get_menu(CateringMenuItemCategory.BEVERAGE_COCKTAIL))], + ) for idx, catering_menu_item in enumerate(self.get_menu_items_by_category(self.all_menu_items, CateringMenuItemCategory.BEVERAGE_COCKTAIL))], ), header_style=TextStyle( fill=self.session.theme.background_color, @@ -229,7 +240,7 @@ class CateringPage(Component): is_sensitive=(user_id is not None) and not catering_menu_item.is_disabled, additional_info=catering_menu_item.additional_info, is_grey=idx % 2 == 0 - ) for idx, catering_menu_item in enumerate(catering_service.get_menu(CateringMenuItemCategory.BEVERAGE_SHOT))], + ) for idx, catering_menu_item in enumerate(self.get_menu_items_by_category(self.all_menu_items, CateringMenuItemCategory.BEVERAGE_SHOT))], ), header_style=TextStyle( fill=self.session.theme.background_color, @@ -251,7 +262,7 @@ class CateringPage(Component): is_sensitive=(user_id is not None) and not catering_menu_item.is_disabled, additional_info=catering_menu_item.additional_info, is_grey=idx % 2 == 0 - ) for idx, catering_menu_item in enumerate(catering_service.get_menu(CateringMenuItemCategory.NON_FOOD))], + ) for idx, catering_menu_item in enumerate(self.get_menu_items_by_category(self.all_menu_items, CateringMenuItemCategory.NON_FOOD))], ), header_style=TextStyle( fill=self.session.theme.background_color, @@ -260,7 +271,14 @@ class CateringPage(Component): margin=1, align_y=0.5 ) - ), + )] + + return BasePage( + content=Column( + # SHOPPING CART + shopping_cart_and_orders_container, + # ITEM SELECTION + *menu, align_y=0 ) ) diff --git a/src/ez_lan_manager/pages/ContactPage.py b/src/ez_lan_manager/pages/ContactPage.py index 24b454d..6fdb7d9 100644 --- a/src/ez_lan_manager/pages/ContactPage.py +++ b/src/ez_lan_manager/pages/ContactPage.py @@ -1,4 +1,5 @@ from datetime import datetime, timedelta +from typing import Optional from rio import Text, Column, TextStyle, Component, event, TextInput, MultiLineTextInput, Row, Button @@ -7,6 +8,7 @@ from src.ez_lan_manager.components.AnimatedText import AnimatedText from src.ez_lan_manager.components.MainViewContentBox import MainViewContentBox from src.ez_lan_manager.pages import BasePage from src.ez_lan_manager.types.SessionStorage import SessionStorage +from src.ez_lan_manager.types.User import User class ContactPage(Component): @@ -14,10 +16,15 @@ class ContactPage(Component): # Using list to bypass this behavior last_message_sent: list[datetime] = [datetime(day=1, month=1, year=2000)] display_printing: list[bool] = [False] + user: Optional[User] = None @event.on_populate async def on_populate(self) -> None: await self.session.set_title(f"{self.session[ConfigurationService].get_lan_info().name} - Kontakt") + if self.session[SessionStorage].user_id is not None: + self.user = await self.session[UserService].get_user(self.session[SessionStorage].user_id) + else: + self.user = None async def on_send_pressed(self) -> None: error_msg = "" @@ -51,11 +58,6 @@ class ContactPage(Component): await self.animated_text.display_text(True, "Nachricht erfolgreich gesendet!") def build(self) -> Component: - if self.session[SessionStorage].user_id is not None: - user = self.session[UserService].get_user(self.session[SessionStorage].user_id) - else: - user = None - self.animated_text = AnimatedText( margin_top = 2, margin_bottom = 1, @@ -64,7 +66,7 @@ class ContactPage(Component): self.email_input = TextInput( label="E-Mail Adresse", - text="" if not user else user.user_mail, + text="" if not self.user else self.user.user_mail, margin_left=1, margin_right=1, margin_bottom=1, diff --git a/src/ez_lan_manager/pages/DbErrorPage.py b/src/ez_lan_manager/pages/DbErrorPage.py index 97c50f0..f3da2a3 100644 --- a/src/ez_lan_manager/pages/DbErrorPage.py +++ b/src/ez_lan_manager/pages/DbErrorPage.py @@ -15,12 +15,12 @@ class DbErrorPage(Component): async def on_window_size_change(self) -> None: await self.force_refresh() - @event.on_mount - async def retry_db_connect(self) -> None: - await self.session.set_title(f"{self.session[ConfigurationService].get_lan_info().name} - Fehler") - while not self.session[DatabaseService].is_connected: - await sleep(2) - self.session.navigate_to("./") + # @event.on_mount + # async def retry_db_connect(self) -> None: + # await self.session.set_title(f"{self.session[ConfigurationService].get_lan_info().name} - Fehler") + # while not self.session[DatabaseService].is_connected: + # await sleep(2) + # self.session.navigate_to("./") def build(self) -> Component: content = Card( diff --git a/src/ez_lan_manager/pages/EditProfile.py b/src/ez_lan_manager/pages/EditProfile.py index 774a7be..4543337 100644 --- a/src/ez_lan_manager/pages/EditProfile.py +++ b/src/ez_lan_manager/pages/EditProfile.py @@ -3,7 +3,8 @@ from hashlib import sha256 from typing import Optional from from_root import from_root -from rio import Column, Component, event, Text, TextStyle, Button, Color, Row, TextInput, Image, TextInputChangeEvent, NoFileSelectedError +from rio import Column, Component, event, Text, TextStyle, Button, Color, Row, TextInput, Image, TextInputChangeEvent, NoFileSelectedError, \ + ProgressCircle from email_validator import validate_email, EmailNotValidError from src.ez_lan_manager import ConfigurationService, UserService @@ -15,6 +16,8 @@ from src.ez_lan_manager.types.User import User class EditProfilePage(Component): + user: Optional[User] = None + pfp: Optional[bytes] = None @staticmethod def optional_date_to_str(d: Optional[date]) -> str: if not d: @@ -24,6 +27,8 @@ class EditProfilePage(Component): @event.on_populate async def on_populate(self) -> None: await self.session.set_title(f"{self.session[ConfigurationService].get_lan_info().name} - Profil bearbeiten") + self.user = await self.session[UserService].get_user(self.session[SessionStorage].user_id) + self.pfp = await self.session[UserService].get_profile_picture(self.user.user_id) def on_email_changed(self, change_event: TextInputChangeEvent) -> None: try: @@ -58,7 +63,7 @@ class EditProfilePage(Component): return image_data = await new_pfp.read_bytes() - self.session[UserService].set_profile_picture(self.session[SessionStorage].user_id, image_data) + await self.session[UserService].set_profile_picture(self.session[SessionStorage].user_id, image_data) self.pfp_image_container.image = image_data await self.animated_text.display_text(True, "Gespeichert!") @@ -72,7 +77,7 @@ class EditProfilePage(Component): await self.animated_text.display_text(False, "Passwörter nicht gleich!") return - user: User = self.session[UserService].get_user(self.session[SessionStorage].user_id) + user: User = await self.session[UserService].get_user(self.session[SessionStorage].user_id) user.user_mail = self.email_input.text if len(self.birthday_input.text) == 0: @@ -86,12 +91,24 @@ class EditProfilePage(Component): if len(self.new_pw_1_input.text.strip()) > 0: user.user_password = sha256(self.new_pw_1_input.text.encode(encoding="utf-8")).hexdigest() - self.session[UserService].update_user(user) + await self.session[UserService].update_user(user) await self.animated_text.display_text(True, "Gespeichert!") def build(self) -> Component: - user = self.session[UserService].get_user(self.session[SessionStorage].user_id) - pfp = self.session[UserService].get_profile_picture(user.user_id) + if not self.user: + return BasePage( + content=Column( + MainViewContentBox( + ProgressCircle( + color="secondary", + align_x=0.5, + margin_top=2, + margin_bottom=2 + ) + ), + align_y = 0 + ) + ) self.animated_text = AnimatedText( margin_top=2, @@ -101,7 +118,7 @@ class EditProfilePage(Component): self.email_input = TextInput( label="E-Mail Adresse", - text=user.user_mail, + text=self.user.user_mail, margin_left=1, margin_right=1, margin_bottom=1, @@ -110,20 +127,20 @@ class EditProfilePage(Component): ) self.first_name_input = TextInput( label="Vorname", - text=user.user_first_name, + text=self.user.user_first_name, margin_left=1, margin_right=1, grow_x=True ) self.last_name_input = TextInput( label="Nachname", - text=user.user_last_name, + text=self.user.user_last_name, margin_right=1, grow_x=True ) self.birthday_input = TextInput( label="Geburtstag (TT.MM.JJJJ)", - text=self.optional_date_to_str(user.user_birth_day), + text=self.optional_date_to_str(self.user.user_birth_day), margin_left=1, margin_right=1, margin_bottom=1, @@ -150,7 +167,7 @@ class EditProfilePage(Component): ) self.pfp_image_container = Image( - from_root("src/ez_lan_manager/assets/img/anon_pfp.png") if pfp is None else pfp, + from_root("src/ez_lan_manager/assets/img/anon_pfp.png") if self.pfp is None else self.pfp, align_x=0.5, min_width=10, min_height=10, @@ -176,8 +193,8 @@ class EditProfilePage(Component): on_press=self.upload_new_pfp ), Row( - TextInput(label="Deine User-ID", text=user.user_id, is_sensitive=False, margin_left=1, grow_x=False), - TextInput(label="Dein Nickname", text=user.user_name, is_sensitive=False, margin_left=1, margin_right=1, grow_x=True), + TextInput(label="Deine User-ID", text=self.user.user_id, is_sensitive=False, margin_left=1, grow_x=False), + TextInput(label="Dein Nickname", text=self.user.user_name, is_sensitive=False, margin_left=1, margin_right=1, grow_x=True), margin_bottom=1 ), self.email_input, diff --git a/src/ez_lan_manager/pages/ForgotPassword.py b/src/ez_lan_manager/pages/ForgotPassword.py index 7dc0212..eafba10 100644 --- a/src/ez_lan_manager/pages/ForgotPassword.py +++ b/src/ez_lan_manager/pages/ForgotPassword.py @@ -24,11 +24,11 @@ class ForgotPasswordPage(Component): lan_info = self.session[ConfigurationService].get_lan_info() user_service = self.session[UserService] mailing_service = self.session[MailingService] - user = user_service.get_user(self.email_input.text.strip()) + user = await user_service.get_user(self.email_input.text.strip()) if user is not None: new_password = "".join(choices(user_service.ALLOWED_USER_NAME_SYMBOLS, k=16)) user.user_password = sha256(new_password.encode(encoding="utf-8")).hexdigest() - user_service.update_user(user) + await user_service.update_user(user) await mailing_service.send_email( subject=f"Dein neues Passwort für {lan_info.name}", body=f"Du hast für den EZ-LAN Manager der {lan_info.name} ein neues Passwort angefragt. " diff --git a/src/ez_lan_manager/pages/GuestsPage.py b/src/ez_lan_manager/pages/GuestsPage.py index 363a556..bb35ce1 100644 --- a/src/ez_lan_manager/pages/GuestsPage.py +++ b/src/ez_lan_manager/pages/GuestsPage.py @@ -5,37 +5,50 @@ from rio import Column, Component, event, TextStyle, Text, Button, Row, TextInpu from src.ez_lan_manager import ConfigurationService, UserService, TicketingService, SeatingService from src.ez_lan_manager.components.MainViewContentBox import MainViewContentBox from src.ez_lan_manager.pages import BasePage +from src.ez_lan_manager.types.Seat import Seat from src.ez_lan_manager.types.User import User class GuestsPage(Component): table_elements: list[Button] = [] users_with_tickets: list[User] = [] + users_with_seats: dict[User, Seat] = {} user_filter: Optional[str] = None - - def __post_init__(self) -> None: - user_service = self.session[UserService] - all_users = user_service.get_all_users() - ticketing_service = self.session[TicketingService] - self.users_with_tickets = list(filter(lambda user: ticketing_service.get_user_ticket(user.user_id) is not None, all_users)) - @event.on_populate async def on_populate(self) -> None: await self.session.set_title(f"{self.session[ConfigurationService].get_lan_info().name} - Teilnehmer") + user_service = self.session[UserService] + all_users = await user_service.get_all_users() + ticketing_service = self.session[TicketingService] + seating_service = self.session[SeatingService] + u_w_t = [] + u_w_s = {} + for user in all_users: + ticket = await ticketing_service.get_user_ticket(user.user_id) + seat = await seating_service.get_user_seat(user.user_id) + if ticket is not None: + u_w_t.append(user) + if seat is not None: + u_w_s[user] = seat + + self.users_with_tickets = u_w_t + self.users_with_seats = u_w_s def on_searchbar_content_change(self, change_event: TextInputChangeEvent) -> None: self.user_filter = change_event.text def build(self) -> Component: - seating_service = self.session[SeatingService] if self.user_filter: users = [user for user in self.users_with_tickets if self.user_filter.lower() in user.user_name or self.user_filter.lower() in str(user.user_id)] else: users = self.users_with_tickets self.table_elements.clear() for idx, user in enumerate(users): - seat = seating_service.get_user_seat(user.user_id) + try: + seat = self.users_with_seats[user] + except KeyError: + seat = None self.table_elements.append( Button( content=Row(Text(text=f"{user.user_id:0>4}", align_x=0, margin_right=1), Text(text=user.user_name, grow_x=True, wrap="ellipsize"), Text(text="-" if seat is None else seat.seat_id, align_x=1)), diff --git a/src/ez_lan_manager/pages/NewsPage.py b/src/ez_lan_manager/pages/NewsPage.py index d8c73d5..1d83807 100644 --- a/src/ez_lan_manager/pages/NewsPage.py +++ b/src/ez_lan_manager/pages/NewsPage.py @@ -12,7 +12,7 @@ class NewsPage(Component): @event.on_populate async def on_populate(self) -> None: await self.session.set_title(f"{self.session[ConfigurationService].get_lan_info().name} - Neuigkeiten") - self.news_posts = self.session[NewsService].get_news()[:8] + self.news_posts = (await self.session[NewsService].get_news())[:8] def build(self) -> Component: posts = [NewsPost( diff --git a/src/ez_lan_manager/pages/RegisterPage.py b/src/ez_lan_manager/pages/RegisterPage.py index d03e459..1ee6d66 100644 --- a/src/ez_lan_manager/pages/RegisterPage.py +++ b/src/ez_lan_manager/pages/RegisterPage.py @@ -62,13 +62,13 @@ class RegisterPage(Component): mailing_service = self.session[MailingService] lan_info = self.session[ConfigurationService].get_lan_info() - if user_service.get_user(self.email_input.text) is not None or user_service.get_user(self.user_name_input.text) is not None: + if await user_service.get_user(self.email_input.text) is not None or await user_service.get_user(self.user_name_input.text) is not None: await self.animated_text.display_text(False, "Benutzername oder E-Mail bereits regestriert!") self.submit_button.is_loading = False return try: - new_user = user_service.create_user(self.user_name_input.text, self.email_input.text, self.pw_1.text) + new_user = await user_service.create_user(self.user_name_input.text, self.email_input.text, self.pw_1.text) if not new_user: raise RuntimeError("User could not be created") except Exception as e: diff --git a/src/ez_lan_manager/services/AccountingService.py b/src/ez_lan_manager/services/AccountingService.py index e3bec1e..120ac85 100644 --- a/src/ez_lan_manager/services/AccountingService.py +++ b/src/ez_lan_manager/services/AccountingService.py @@ -13,8 +13,8 @@ class AccountingService: def __init__(self, db_service: DatabaseService) -> None: self._db_service = db_service - def add_balance(self, user_id: int, balance_to_add: int, reference: str) -> int: - self._db_service.add_transaction(Transaction( + async def add_balance(self, user_id: int, balance_to_add: int, reference: str) -> int: + await self._db_service.add_transaction(Transaction( user_id=user_id, value=balance_to_add, is_debit=False, @@ -22,13 +22,13 @@ class AccountingService: transaction_date=datetime.now() )) logger.debug(f"Added balance of {self.make_euro_string_from_int(balance_to_add)} to user with ID {user_id}") - return self.get_balance(user_id) + return await self.get_balance(user_id) - def remove_balance(self, user_id: int, balance_to_remove: int, reference: str) -> int: - current_balance = self.get_balance(user_id) + async def remove_balance(self, user_id: int, balance_to_remove: int, reference: str) -> int: + current_balance = await self.get_balance(user_id) if (current_balance - balance_to_remove) < 0: raise InsufficientFundsError - self._db_service.add_transaction(Transaction( + await self._db_service.add_transaction(Transaction( user_id=user_id, value=balance_to_remove, is_debit=True, @@ -36,19 +36,19 @@ class AccountingService: transaction_date=datetime.now() )) logger.debug(f"Removed balance of {self.make_euro_string_from_int(balance_to_remove)} to user with ID {user_id}") - return self.get_balance(user_id) + return await self.get_balance(user_id) - def get_balance(self, user_id: int) -> int: + async def get_balance(self, user_id: int) -> int: balance_buffer = 0 - for transaction in self._db_service.get_all_transactions_for_user(user_id): + for transaction in await self._db_service.get_all_transactions_for_user(user_id): if transaction.is_debit: balance_buffer -= transaction.value else: balance_buffer += transaction.value return balance_buffer - def get_transaction_history(self, user_id: int) -> list[Transaction]: - return self._db_service.get_all_transactions_for_user(user_id) + async def get_transaction_history(self, user_id: int) -> list[Transaction]: + return await self._db_service.get_all_transactions_for_user(user_id) @staticmethod def make_euro_string_from_int(cent_int: int) -> str: diff --git a/src/ez_lan_manager/services/CateringService.py b/src/ez_lan_manager/services/CateringService.py index 33effae..5c4853a 100644 --- a/src/ez_lan_manager/services/CateringService.py +++ b/src/ez_lan_manager/services/CateringService.py @@ -23,93 +23,93 @@ class CateringService: # ORDERS - def place_order(self, menu_items: CateringMenuItemsWithAmount, user_id: int, is_delivery: bool = True) -> CateringOrder: + async def place_order(self, menu_items: CateringMenuItemsWithAmount, user_id: int, is_delivery: bool = True) -> CateringOrder: for menu_item in menu_items: if menu_item.is_disabled: raise CateringError("Order includes disabled items") - user = self._user_service.get_user(user_id) + user = await self._user_service.get_user(user_id) if not user: raise CateringError("User does not exist") total_price = sum([item.price * quantity for item, quantity in menu_items.items()]) - if self._accounting_service.get_balance(user_id) < total_price: + if await self._accounting_service.get_balance(user_id) < total_price: raise CateringError("Insufficient funds") - order = self._db_service.add_new_order(menu_items, user_id, is_delivery) + order = await self._db_service.add_new_order(menu_items, user_id, is_delivery) if order: - self._accounting_service.remove_balance(user_id, total_price, f"CATERING - {order.order_id}") + await self._accounting_service.remove_balance(user_id, total_price, f"CATERING - {order.order_id}") logger.info(f"User '{order.customer.user_name}' (ID:{order.customer.user_id}) ordered from catering for {self._accounting_service.make_euro_string_from_int(total_price)}") return order - def update_order_status(self, order_id: int, new_status: CateringOrderStatus) -> bool: + async def update_order_status(self, order_id: int, new_status: CateringOrderStatus) -> bool: if new_status == CateringOrderStatus.CANCELED: # Cancelled orders need to be refunded raise CateringError("Orders cannot be canceled this way, use CateringService.cancel_order") - return self._db_service.change_order_status(order_id, new_status) + return await self._db_service.change_order_status(order_id, new_status) - def get_orders(self) -> list[CateringOrder]: - return self._db_service.get_orders() + async def get_orders(self) -> list[CateringOrder]: + return await self._db_service.get_orders() - def get_orders_for_user(self, user_id: int) -> list[CateringOrder]: - return self._db_service.get_orders(user_id=user_id) + async def get_orders_for_user(self, user_id: int) -> list[CateringOrder]: + return await self._db_service.get_orders(user_id=user_id) - def get_orders_by_status(self, status: CateringOrderStatus) -> list[CateringOrder]: - return self._db_service.get_orders(status=status) + async def get_orders_by_status(self, status: CateringOrderStatus) -> list[CateringOrder]: + return await self._db_service.get_orders(status=status) - def cancel_order(self, order: CateringOrder) -> bool: + async def cancel_order(self, order: CateringOrder) -> bool: if self._db_service.change_order_status(order.order_id, CateringOrderStatus.CANCELED): - self._accounting_service.add_balance(order.customer.user_id, order.price, f"CATERING REFUND - {order.order_id}") + await self._accounting_service.add_balance(order.customer.user_id, order.price, f"CATERING REFUND - {order.order_id}") return True return False # MENU ITEMS - def get_menu(self, category: Optional[CateringMenuItemCategory] = None) -> list[CateringMenuItem]: - items = self._db_service.get_menu_items() + async def get_menu(self, category: Optional[CateringMenuItemCategory] = None) -> list[CateringMenuItem]: + items = await self._db_service.get_menu_items() if not category: return items return list(filter(lambda item: item.category == category, items)) - def get_menu_item_by_id(self, menu_item_id: int) -> CateringMenuItem: - item = self._db_service.get_menu_item(menu_item_id) + async def get_menu_item_by_id(self, menu_item_id: int) -> CateringMenuItem: + item = await self._db_service.get_menu_item(menu_item_id) if not item: raise CateringError("Menu item not found") return item - def add_menu_item(self, name: str, info: str, price: int, category: CateringMenuItemCategory, is_disabled: bool = False) -> CateringMenuItem: - if new_item := self._db_service.add_menu_item(name, info, price, category, is_disabled): + async def add_menu_item(self, name: str, info: str, price: int, category: CateringMenuItemCategory, is_disabled: bool = False) -> CateringMenuItem: + if new_item := await self._db_service.add_menu_item(name, info, price, category, is_disabled): return new_item raise CateringError(f"Could not add item '{name}' to the menu.") - def remove_menu_item(self, menu_item_id: int) -> bool: - return self._db_service.delete_menu_item(menu_item_id) + async def remove_menu_item(self, menu_item_id: int) -> bool: + return await self._db_service.delete_menu_item(menu_item_id) - def change_menu_item(self, updated_item: CateringMenuItem) -> bool: - return self._db_service.update_menu_item(updated_item) + async def change_menu_item(self, updated_item: CateringMenuItem) -> bool: + return await self._db_service.update_menu_item(updated_item) - def disable_menu_item(self, menu_item_id: int) -> bool: + async def disable_menu_item(self, menu_item_id: int) -> bool: try: - item = self.get_menu_item_by_id(menu_item_id) + item = await self.get_menu_item_by_id(menu_item_id) except CateringError: return False item.is_disabled = True - return self._db_service.update_menu_item(item) + return await self._db_service.update_menu_item(item) - def enable_menu_item(self, menu_item_id: int) -> bool: + async def enable_menu_item(self, menu_item_id: int) -> bool: try: - item = self.get_menu_item_by_id(menu_item_id) + item = await self.get_menu_item_by_id(menu_item_id) except CateringError: return False item.is_disabled = False - return self._db_service.update_menu_item(item) + return await self._db_service.update_menu_item(item) - def disable_menu_items_by_category(self, category: CateringMenuItemCategory) -> bool: - items = self.get_menu(category=category) + async def disable_menu_items_by_category(self, category: CateringMenuItemCategory) -> bool: + items = await self.get_menu(category=category) return all([self.disable_menu_item(item.item_id) for item in items]) - def enable_menu_items_by_category(self, category: CateringMenuItemCategory) -> bool: - items = self.get_menu(category=category) + async def enable_menu_items_by_category(self, category: CateringMenuItemCategory) -> bool: + items = await self.get_menu(category=category) return all([self.enable_menu_item(item.item_id) for item in items]) # CART diff --git a/src/ez_lan_manager/services/DatabaseService.py b/src/ez_lan_manager/services/DatabaseService.py index bbcd4ba..3bca5c4 100644 --- a/src/ez_lan_manager/services/DatabaseService.py +++ b/src/ez_lan_manager/services/DatabaseService.py @@ -1,11 +1,9 @@ import logging -from time import sleep from datetime import date, datetime -from typing import Optional, Coroutine +from typing import Optional -import mariadb -from mariadb import Cursor +import aiomysql from src.ez_lan_manager.types.CateringOrder import CateringOrder from src.ez_lan_manager.types.CateringMenuItem import CateringMenuItem, CateringMenuItemCategory @@ -29,56 +27,27 @@ class DatabaseService: MAX_CONNECTION_RETRIES = 5 def __init__(self, database_config: DatabaseConfiguration) -> None: self._database_config = database_config + self._connection_pool: Optional[aiomysql.Pool] = None + + async def init_db_pool(self) -> bool: logger.info( f"Connecting to database '{self._database_config.db_name}' on " f"{self._database_config.db_user}@{self._database_config.db_host}:{self._database_config.db_port}" ) - self._connection: Optional[mariadb.Connection] = None - self._reestablishment_lock = False - self.establish_new_connection() - - @property - def is_connected(self) -> bool: try: - self._connection.ping() - except Exception: - try: - self.establish_new_connection() - return True - except NoDatabaseConnectionError: - return False + self._connection_pool = await aiomysql.create_pool( + host=self._database_config.db_host, + port=self._database_config.db_port, + user=self._database_config.db_user, + password=self._database_config.db_password, + db=self._database_config.db_name, + minsize=1, + maxsize=20 + ) + except aiomysql.OperationalError: + return False return True - def establish_new_connection(self) -> None: - if self._reestablishment_lock: - return - self._reestablishment_lock = True - - if isinstance(self._connection, mariadb.Connection): - self._connection.close() - self._connection = None - - for _ in range(DatabaseService.MAX_CONNECTION_RETRIES): - try: - self._connection = mariadb.connect( - user=self._database_config.db_user, - password=self._database_config.db_password, - host=self._database_config.db_host, - port=self._database_config.db_port, - database=self._database_config.db_name - ) - except mariadb.Error: - sleep(0.4) - continue - self._reestablishment_lock = False - return - self._reestablishment_lock = False - raise NoDatabaseConnectionError - - - def _get_cursor(self) -> Cursor: - return self._connection.cursor() - @staticmethod def _map_db_result_to_user(data: tuple) -> User: return User( @@ -96,554 +65,640 @@ class DatabaseService: last_updated_at=data[11] ) - def get_user_by_name(self, user_name: str) -> Optional[User]: - cursor = self._get_cursor() - cursor.execute("SELECT * FROM users WHERE user_name=?", (user_name,)) - self._connection.commit() - result = cursor.fetchone() - if not result: - return - return self._map_db_result_to_user(result) + async def get_user_by_name(self, user_name: str) -> Optional[User]: + async with self._connection_pool.acquire() as conn: + async with conn.cursor(aiomysql.Cursor) as cursor: + await cursor.execute("SELECT * FROM users WHERE user_name=%s", (user_name,)) + result = await cursor.fetchone() + if not result: + return + return self._map_db_result_to_user(result) - def get_user_by_id(self, user_id: int) -> Optional[User]: - cursor = self._get_cursor() - cursor.execute("SELECT * FROM users WHERE user_id=?", (user_id,)) - self._connection.commit() - result = cursor.fetchone() - if not result: - return - return self._map_db_result_to_user(result) - def get_user_by_mail(self, user_mail: str) -> Optional[User]: - cursor = self._get_cursor() - cursor.execute("SELECT * FROM users WHERE user_mail=?", (user_mail.lower(),)) - self._connection.commit() - result = cursor.fetchone() - if not result: - return - return self._map_db_result_to_user(result) + async def get_user_by_id(self, user_id: int) -> Optional[User]: + async with self._connection_pool.acquire() as conn: + async with conn.cursor(aiomysql.Cursor) as cursor: + await cursor.execute("SELECT * FROM users WHERE user_id=%s", (user_id,)) + result = await cursor.fetchone() + if not result: + return + return self._map_db_result_to_user(result) - def create_user(self, user_name: str, user_mail: str, password_hash: str) -> User: - cursor = self._get_cursor() - try: - cursor.execute( - "INSERT INTO users (user_name, user_mail, user_password) " - "VALUES (?, ?, ?)", (user_name, user_mail.lower(), password_hash) - ) - self._connection.commit() - except mariadb.InterfaceError: - self.establish_new_connection() - return self.create_user(user_name, user_mail, password_hash) - except mariadb.IntegrityError as e: - logger.warning(f"Aborted duplication entry: {e}") - raise DuplicationError + async def get_user_by_mail(self, user_mail: str) -> Optional[User]: + async with self._connection_pool.acquire() as conn: + async with conn.cursor(aiomysql.Cursor) as cursor: + await cursor.execute("SELECT * FROM users WHERE user_mail=%s", (user_mail.lower(),)) + result = await cursor.fetchone() + if not result: + return + return self._map_db_result_to_user(result) - return self.get_user_by_name(user_name) + async def create_user(self, user_name: str, user_mail: str, password_hash: str) -> User: + async with self._connection_pool.acquire() as conn: + async with conn.cursor(aiomysql.Cursor) as cursor: + try: + await cursor.execute( + "INSERT INTO users (user_name, user_mail, user_password) " + "VALUES (%s, %s, %s)", (user_name, user_mail.lower(), password_hash) + ) + await conn.commit() + except aiomysql.InterfaceError: + pool_init_result = await self.init_db_pool() + if not pool_init_result: + raise NoDatabaseConnectionError + return await self.create_user(user_name, user_mail, password_hash) + except aiomysql.IntegrityError as e: + logger.warning(f"Aborted duplication entry: {e}") + raise DuplicationError - def update_user(self, user: User) -> User: - cursor = self._get_cursor() - try: - cursor.execute( - "UPDATE users SET user_name=?, user_mail=?, user_password=?, user_first_name=?, user_last_name=?, user_birth_date=?, " - "is_active=?, is_team_member=?, is_admin=? WHERE (user_id=?)", (user.user_name, user.user_mail.lower(), user.user_password, - user.user_first_name, user.user_last_name, user.user_birth_day, - user.is_active, user.is_team_member, user.is_admin, - user.user_id) - ) - self._connection.commit() - except mariadb.InterfaceError: - self.establish_new_connection() - return self.update_user(user) - except mariadb.IntegrityError as e: - logger.warning(f"Aborted duplication entry: {e}") - raise DuplicationError - return user + return await self.get_user_by_name(user_name) + + - def add_transaction(self, transaction: Transaction) -> Optional[Transaction]: - cursor = self._get_cursor() - try: - cursor.execute( - "INSERT INTO transactions (user_id, value, is_debit, transaction_date, transaction_reference) " - "VALUES (?, ?, ?, ?, ?)", - (transaction.user_id, transaction.value, transaction.is_debit, transaction.transaction_date, transaction.reference) - ) - self._connection.commit() - except mariadb.InterfaceError: - self.establish_new_connection() - return self.add_transaction(transaction) - except Exception as e: - logger.warning(f"Error adding Transaction: {e}") - return + async def update_user(self, user: User) -> User: + async with self._connection_pool.acquire() as conn: + async with conn.cursor(aiomysql.Cursor) as cursor: + try: + await cursor.execute( + "UPDATE users SET user_name=%s, user_mail=%s, user_password=%s, user_first_name=%s, user_last_name=%s, user_birth_date=%s, " + "is_active=%s, is_team_member=%s, is_admin=%s WHERE (user_id=%s)", (user.user_name, user.user_mail.lower(), user.user_password, + user.user_first_name, user.user_last_name, user.user_birth_day, + user.is_active, user.is_team_member, user.is_admin, + user.user_id) + ) + await conn.commit() + except aiomysql.InterfaceError: + pool_init_result = await self.init_db_pool() + if not pool_init_result: + raise NoDatabaseConnectionError + return await self.update_user(user) + except aiomysql.IntegrityError as e: + logger.warning(f"Aborted duplication entry: {e}") + raise DuplicationError + return user - return transaction + async def add_transaction(self, transaction: Transaction) -> Optional[Transaction]: + async with self._connection_pool.acquire() as conn: + async with conn.cursor(aiomysql.Cursor) as cursor: + try: + await cursor.execute( + "INSERT INTO transactions (user_id, value, is_debit, transaction_date, transaction_reference) " + "VALUES (%s, %s, %s, %s, %s)", + (transaction.user_id, transaction.value, transaction.is_debit, transaction.transaction_date, transaction.reference) + ) + await conn.commit() + except aiomysql.InterfaceError: + pool_init_result = await self.init_db_pool() + if not pool_init_result: + raise NoDatabaseConnectionError + return await self.add_transaction(transaction) + except Exception as e: + logger.warning(f"Error adding Transaction: {e}") + return - def get_all_transactions_for_user(self, user_id: int) -> list[Transaction]: - transactions = [] + return transaction - cursor = self._get_cursor() - try: - cursor.execute("SELECT * FROM transactions WHERE user_id=?", (user_id,)) - self._connection.commit() - result = cursor.fetchall() - except mariadb.InterfaceError: - self.establish_new_connection() - return self.get_all_transactions_for_user(user_id) - except mariadb.Error as e: - logger.error(f"Error getting all transactions for user: {e}") - return [] + async def get_all_transactions_for_user(self, user_id: int) -> list[Transaction]: + async with self._connection_pool.acquire() as conn: + async with conn.cursor(aiomysql.Cursor) as cursor: + transactions = [] + try: + await cursor.execute("SELECT * FROM transactions WHERE user_id=%s", (user_id,)) + await conn.commit() + result = await cursor.fetchall() + except aiomysql.InterfaceError: + pool_init_result = await self.init_db_pool() + if not pool_init_result: + raise NoDatabaseConnectionError + return await self.get_all_transactions_for_user(user_id) + except aiomysql.Error as e: + logger.error(f"Error getting all transactions for user: {e}") + return [] - for transaction_raw in result: - transactions.append(Transaction( - user_id=user_id, - value=int(transaction_raw[2]), - is_debit=bool(transaction_raw[3]), - transaction_date=transaction_raw[4], - reference=transaction_raw[5] - )) - return transactions + for transaction_raw in result: + transactions.append(Transaction( + user_id=user_id, + value=int(transaction_raw[2]), + is_debit=bool(transaction_raw[3]), + transaction_date=transaction_raw[4], + reference=transaction_raw[5] + )) + return transactions - def add_news(self, news: News) -> None: - cursor = self._get_cursor() - try: - cursor.execute( - "INSERT INTO news (news_content, news_title, news_subtitle, news_author, news_date) " - "VALUES (?, ?, ?, ?, ?)", - (news.content, news.title, news.subtitle, news.author.user_id, news.news_date) - ) - self._connection.commit() - except mariadb.InterfaceError: - self.establish_new_connection() - return self.add_news(news) - except Exception as e: - logger.warning(f"Error adding Transaction: {e}") - def get_news(self, dt_start: date, dt_end: date) -> list[News]: - results = [] - cursor = self._get_cursor() - try: - cursor.execute("SELECT * FROM news INNER JOIN users ON news.news_author = users.user_id WHERE news_date BETWEEN ? AND ?;", (dt_start, dt_end)) - self._connection.commit() - except mariadb.InterfaceError: - self.establish_new_connection() - return self.get_news(dt_start, dt_end) - except Exception as e: - logger.warning(f"Error fetching news: {e}") - return [] + async def add_news(self, news: News) -> None: + async with self._connection_pool.acquire() as conn: + async with conn.cursor(aiomysql.Cursor) as cursor: + try: + await cursor.execute( + "INSERT INTO news (news_content, news_title, news_subtitle, news_author, news_date) " + "VALUES (%s, %s, %s, %s, %s)", + (news.content, news.title, news.subtitle, news.author.user_id, news.news_date) + ) + await conn.commit() + except aiomysql.InterfaceError: + pool_init_result = await self.init_db_pool() + if not pool_init_result: + raise NoDatabaseConnectionError + return await self.add_news(news) + except Exception as e: + logger.warning(f"Error adding Transaction: {e}") - for news_raw in cursor.fetchall(): - user = self._map_db_result_to_user(news_raw[6:]) - results.append(News( - news_id=news_raw[0], - title=news_raw[2], - subtitle=news_raw[3], - author=user, - content=news_raw[1], - news_date=news_raw[5] - )) - return results + async def get_news(self, dt_start: date, dt_end: date) -> list[News]: + async with self._connection_pool.acquire() as conn: + async with conn.cursor(aiomysql.Cursor) as cursor: + results = [] + try: + await cursor.execute("SELECT * FROM news INNER JOIN users ON news.news_author = users.user_id WHERE news_date BETWEEN %s AND %s;", (dt_start, dt_end)) + await conn.commit() + except aiomysql.InterfaceError: + pool_init_result = await self.init_db_pool() + if not pool_init_result: + raise NoDatabaseConnectionError + return await self.get_news(dt_start, dt_end) + except Exception as e: + logger.warning(f"Error fetching news: {e}") + return [] - def get_tickets(self) -> list[Ticket]: - results = [] - cursor = self._get_cursor() - try: - cursor.execute("SELECT * FROM tickets INNER JOIN users ON tickets.user = users.user_id;", ()) - self._connection.commit() - except mariadb.InterfaceError: - self.establish_new_connection() - return self.get_tickets() - except Exception as e: - logger.warning(f"Error fetching tickets: {e}") - return [] + for news_raw in await cursor.fetchall(): + user = self._map_db_result_to_user(news_raw[6:]) + results.append(News( + news_id=news_raw[0], + title=news_raw[2], + subtitle=news_raw[3], + author=user, + content=news_raw[1], + news_date=news_raw[5] + )) - for ticket_raw in cursor.fetchall(): - user = self._map_db_result_to_user(ticket_raw[3:]) - results.append(Ticket( - ticket_id=ticket_raw[0], - category=ticket_raw[1], - purchase_date=ticket_raw[3], - owner=user - )) + return results - return results - def get_ticket_for_user(self, user_id: int) -> Optional[Ticket]: - cursor = self._get_cursor() - try: - cursor.execute("SELECT * FROM tickets INNER JOIN users ON tickets.user = users.user_id WHERE user_id=?;", (user_id, )) - self._connection.commit() - except mariadb.InterfaceError: - self.establish_new_connection() - return self.get_ticket_for_user(user_id) - except Exception as e: - logger.warning(f"Error fetching ticket for user: {e}") - return - result = cursor.fetchone() - if not result: - return + async def get_tickets(self) -> list[Ticket]: + async with self._connection_pool.acquire() as conn: + async with conn.cursor(aiomysql.Cursor) as cursor: + results = [] + try: + await cursor.execute("SELECT * FROM tickets INNER JOIN users ON tickets.user = users.user_id;", ()) + await conn.commit() + except aiomysql.InterfaceError: + pool_init_result = await self.init_db_pool() + if not pool_init_result: + raise NoDatabaseConnectionError + return await self.get_tickets() + except Exception as e: + logger.warning(f"Error fetching tickets: {e}") + return [] - user = self._map_db_result_to_user(result[3:]) - return Ticket( - ticket_id=result[0], - category=result[1], - purchase_date=result[3], - owner=user - ) + for ticket_raw in await cursor.fetchall(): + user = self._map_db_result_to_user(ticket_raw[3:]) + results.append(Ticket( + ticket_id=ticket_raw[0], + category=ticket_raw[1], + purchase_date=ticket_raw[3], + owner=user + )) - def generate_ticket_for_user(self, user_id: int, category: str) -> Optional[Ticket]: - cursor = self._get_cursor() - try: - cursor.execute("INSERT INTO tickets (ticket_category, user) VALUES (?, ?)", (category, user_id)) - self._connection.commit() - except mariadb.InterfaceError: - self.establish_new_connection() - return self.generate_ticket_for_user(user_id, category) - except Exception as e: - logger.warning(f"Error generating ticket for user: {e}") - return + return results - return self.get_ticket_for_user(user_id) - def change_ticket_owner(self, ticket_id: int, new_owner_id: int) -> bool: - cursor = self._get_cursor() - try: - cursor.execute("UPDATE tickets SET user = ? WHERE ticket_id = ?;", (new_owner_id, ticket_id)) - affected_rows = cursor.rowcount - self._connection.commit() - except mariadb.InterfaceError: - self.establish_new_connection() - return self.change_ticket_owner(ticket_id, new_owner_id) - except Exception as e: - logger.warning(f"Error transferring ticket to user: {e}") - return False - return bool(affected_rows) + async def get_ticket_for_user(self, user_id: int) -> Optional[Ticket]: + async with self._connection_pool.acquire() as conn: + async with conn.cursor(aiomysql.Cursor) as cursor: + try: + await cursor.execute("SELECT * FROM tickets INNER JOIN users ON tickets.user = users.user_id WHERE user_id=%s;", (user_id, )) + await conn.commit() + except aiomysql.InterfaceError: + pool_init_result = await self.init_db_pool() + if not pool_init_result: + raise NoDatabaseConnectionError + return await self.get_ticket_for_user(user_id) + except Exception as e: + logger.warning(f"Error fetching ticket for user: {e}") + return - def delete_ticket(self, ticket_id: int) -> bool: - cursor = self._get_cursor() - try: - cursor.execute("DELETE FROM tickets WHERE ticket_id = ?;", (ticket_id, )) - self._connection.commit() - except mariadb.InterfaceError: - self.establish_new_connection() - return self.change_ticket_owner(ticket_id) - except Exception as e: - logger.warning(f"Error deleting ticket: {e}") - return False - return True + result = await cursor.fetchone() + if not result: + return - def generate_fresh_seats_table(self, seats: list[tuple[str, str]]) -> None: + user = self._map_db_result_to_user(result[3:]) + return Ticket( + ticket_id=result[0], + category=result[1], + purchase_date=result[3], + owner=user + ) + + async def generate_ticket_for_user(self, user_id: int, category: str) -> Optional[Ticket]: + async with self._connection_pool.acquire() as conn: + async with conn.cursor(aiomysql.Cursor) as cursor: + try: + await cursor.execute("INSERT INTO tickets (ticket_category, user) VALUES (%s, %s)", (category, user_id)) + await conn.commit() + except aiomysql.InterfaceError: + pool_init_result = await self.init_db_pool() + if not pool_init_result: + raise NoDatabaseConnectionError + return await self.generate_ticket_for_user(user_id, category) + except Exception as e: + logger.warning(f"Error generating ticket for user: {e}") + return + + return await self.get_ticket_for_user(user_id) + + + async def change_ticket_owner(self, ticket_id: int, new_owner_id: int) -> bool: + async with self._connection_pool.acquire() as conn: + async with conn.cursor(aiomysql.Cursor) as cursor: + try: + await cursor.execute("UPDATE tickets SET user = %s WHERE ticket_id = %s;", (new_owner_id, ticket_id)) + affected_rows = cursor.rowcount + await conn.commit() + except aiomysql.InterfaceError: + pool_init_result = await self.init_db_pool() + if not pool_init_result: + raise NoDatabaseConnectionError + return await self.change_ticket_owner(ticket_id, new_owner_id) + except Exception as e: + logger.warning(f"Error transferring ticket to user: {e}") + return False + return affected_rows > 0 + + async def delete_ticket(self, ticket_id: int) -> bool: + async with self._connection_pool.acquire() as conn: + async with conn.cursor(aiomysql.Cursor) as cursor: + try: + await cursor.execute("DELETE FROM tickets WHERE ticket_id = %s;", (ticket_id, )) + await conn.commit() + except aiomysql.InterfaceError: + pool_init_result = await self.init_db_pool() + if not pool_init_result: + raise NoDatabaseConnectionError + return await self.change_ticket_owner(ticket_id) + except Exception as e: + logger.warning(f"Error deleting ticket: {e}") + return False + return True + + async def generate_fresh_seats_table(self, seats: list[tuple[str, str]]) -> None: """ WARNING: THIS WILL DELETE ALL EXISTING DATA! DO NOT USE ON PRODUCTION DATABASE! """ - cursor = self._get_cursor() - try: - cursor.execute("TRUNCATE seats;") - for seat in seats: - cursor.execute("INSERT INTO seats (seat_id, seat_category) VALUES (?, ?);", (seat[0], seat[1])) - self._connection.commit() - except mariadb.InterfaceError: - self.establish_new_connection() - return self.generate_fresh_seats_table(seats) - except Exception as e: - logger.warning(f"Error generating fresh seats table: {e}") - return + async with self._connection_pool.acquire() as conn: + async with conn.cursor(aiomysql.Cursor) as cursor: + try: + await cursor.execute("TRUNCATE seats;") + for seat in seats: + await cursor.execute("INSERT INTO seats (seat_id, seat_category) VALUES (%s, %s);", (seat[0], seat[1])) + await conn.commit() + except aiomysql.InterfaceError: + pool_init_result = await self.init_db_pool() + if not pool_init_result: + raise NoDatabaseConnectionError + return await self.generate_fresh_seats_table(seats) + except Exception as e: + logger.warning(f"Error generating fresh seats table: {e}") + return - def get_seating_info(self) -> list[Seat]: - results = [] - cursor = self._get_cursor() - try: - cursor.execute("SELECT seats.*, users.* FROM seats LEFT JOIN users ON seats.user = users.user_id;") - self._connection.commit() - except mariadb.InterfaceError: - self.establish_new_connection() - return self.get_seating_info() - except Exception as e: - logger.warning(f"Error getting seats table: {e}") - return results + async def get_seating_info(self) -> list[Seat]: + async with self._connection_pool.acquire() as conn: + async with conn.cursor(aiomysql.Cursor) as cursor: + results = [] + try: + await cursor.execute("SELECT seats.*, users.* FROM seats LEFT JOIN users ON seats.user = users.user_id;") + await conn.commit() + except aiomysql.InterfaceError: + pool_init_result = await self.init_db_pool() + if not pool_init_result: + raise NoDatabaseConnectionError + return await self.get_seating_info() + except Exception as e: + logger.warning(f"Error getting seats table: {e}") + return results + for seat_raw in await cursor.fetchall(): + if seat_raw[3] is None: # Empty seat + results.append(Seat(seat_raw[0], bool(seat_raw[1]), seat_raw[2], None)) + else: + user = self._map_db_result_to_user(seat_raw[4:]) + results.append(Seat(seat_raw[0], bool(seat_raw[1]), seat_raw[2], user)) - for seat_raw in cursor.fetchall(): - if seat_raw[3] is None: # Empty seat - results.append(Seat(seat_raw[0], bool(seat_raw[1]), seat_raw[2], None)) - else: - user = self._map_db_result_to_user(seat_raw[4:]) - results.append(Seat(seat_raw[0], bool(seat_raw[1]), seat_raw[2], user)) + return results - return results + async def seat_user(self, seat_id: str, user_id: int) -> bool: + async with self._connection_pool.acquire() as conn: + async with conn.cursor(aiomysql.Cursor) as cursor: + try: + await cursor.execute("UPDATE seats SET user = %s WHERE seat_id = %s;", (user_id, seat_id)) + affected_rows = cursor.rowcount + await conn.commit() + except aiomysql.InterfaceError: + pool_init_result = await self.init_db_pool() + if not pool_init_result: + raise NoDatabaseConnectionError + return await self.seat_user(seat_id, user_id) + except Exception as e: + logger.warning(f"Error seating user: {e}") + return False + return affected_rows > 0 - def seat_user(self, seat_id: str, user_id: int) -> bool: - cursor = self._get_cursor() - try: - cursor.execute("UPDATE seats SET user = ? WHERE seat_id = ?;", (user_id, seat_id)) - affected_rows = cursor.rowcount - self._connection.commit() - except mariadb.InterfaceError: - self.establish_new_connection() - return self.seat_user(seat_id, user_id) - except Exception as e: - logger.warning(f"Error seating user: {e}") - return False - return bool(affected_rows) + async def get_menu_items(self) -> list[CateringMenuItem]: + async with self._connection_pool.acquire() as conn: + async with conn.cursor(aiomysql.Cursor) as cursor: + results = [] + try: + await cursor.execute("SELECT * FROM catering_menu_items;") + await conn.commit() + except aiomysql.InterfaceError: + pool_init_result = await self.init_db_pool() + if not pool_init_result: + raise NoDatabaseConnectionError + return await self.get_menu_items() + except Exception as e: + logger.warning(f"Error fetching menu items: {e}") + return results - def get_menu_items(self) -> list[CateringMenuItem]: - results = [] - cursor = self._get_cursor() - try: - cursor.execute("SELECT * FROM catering_menu_items;") - self._connection.commit() - except mariadb.InterfaceError: - self.establish_new_connection() - return self.get_menu_items() - except Exception as e: - logger.warning(f"Error fetching menu items: {e}") - return results + for menu_item_raw in await cursor.fetchall(): + results.append(CateringMenuItem( + item_id=menu_item_raw[0], + name=menu_item_raw[1], + additional_info=menu_item_raw[2], + price=menu_item_raw[3], + category=CateringMenuItemCategory(menu_item_raw[4]), + is_disabled=bool(menu_item_raw[5]) + )) - for menu_item_raw in cursor.fetchall(): - results.append(CateringMenuItem( - item_id=menu_item_raw[0], - name=menu_item_raw[1], - additional_info=menu_item_raw[2], - price=menu_item_raw[3], - category=CateringMenuItemCategory(menu_item_raw[4]), - is_disabled=bool(menu_item_raw[5]) - )) + return results - return results + async def get_menu_item(self, menu_item_id) -> Optional[CateringMenuItem]: + async with self._connection_pool.acquire() as conn: + async with conn.cursor(aiomysql.Cursor) as cursor: + try: + await cursor.execute("SELECT * FROM catering_menu_items WHERE catering_menu_item_id = %s;", (menu_item_id, )) + await conn.commit() + except aiomysql.InterfaceError: + pool_init_result = await self.init_db_pool() + if not pool_init_result: + raise NoDatabaseConnectionError + return await self.get_menu_item(menu_item_id) + except Exception as e: + logger.warning(f"Error fetching menu items: {e}") + return - def get_menu_item(self, menu_item_id) -> Optional[CateringMenuItem]: - cursor = self._get_cursor() - try: - cursor.execute("SELECT * FROM catering_menu_items WHERE catering_menu_item_id = ?;", (menu_item_id, )) - self._connection.commit() - except mariadb.InterfaceError: - self.establish_new_connection() - return self.get_menu_item(menu_item_id) - except Exception as e: - logger.warning(f"Error fetching menu items: {e}") - return + raw_data = await cursor.fetchone() + if raw_data is None: + return + return CateringMenuItem( + item_id=raw_data[0], + name=raw_data[1], + additional_info=raw_data[2], + price=raw_data[3], + category=CateringMenuItemCategory(raw_data[4]), + is_disabled=bool(raw_data[5]) + ) - raw_data = cursor.fetchone() - if raw_data is None: - return - return CateringMenuItem( - item_id=raw_data[0], - name=raw_data[1], - additional_info=raw_data[2], - price=raw_data[3], - category=CateringMenuItemCategory(raw_data[4]), - is_disabled=bool(raw_data[5]) - ) + async def add_menu_item(self, name: str, info: str, price: int, category: CateringMenuItemCategory, is_disabled: bool = False) -> Optional[CateringMenuItem]: + async with self._connection_pool.acquire() as conn: + async with conn.cursor(aiomysql.Cursor) as cursor: + try: + await cursor.execute( + "INSERT INTO catering_menu_items (name, additional_info, price, category, is_disabled) VALUES (%s, %s, %s, %s, %s);", + (name, info, price, category.value, is_disabled) + ) + await conn.commit() + except aiomysql.InterfaceError: + pool_init_result = await self.init_db_pool() + if not pool_init_result: + raise NoDatabaseConnectionError + return await self.add_menu_item(name, info, price, category, is_disabled) + except Exception as e: + logger.warning(f"Error adding menu item: {e}") + return - def add_menu_item(self, name: str, info: str, price: int, category: CateringMenuItemCategory, is_disabled: bool = False) -> Optional[CateringMenuItem]: - cursor = self._get_cursor() - try: - cursor.execute( - "INSERT INTO catering_menu_items (name, additional_info, price, category, is_disabled) VALUES (?, ?, ?, ?, ?);", - (name, info, price, category.value, is_disabled) - ) - self._connection.commit() - except mariadb.InterfaceError: - self.establish_new_connection() - return self.add_menu_item(name, info, price, category, is_disabled) - except Exception as e: - logger.warning(f"Error adding menu item: {e}") - return - - return CateringMenuItem( - item_id=cursor.lastrowid, - name=name, - additional_info=info, - price=price, - category=category, - is_disabled=is_disabled - ) - - def delete_menu_item(self, menu_item_id: int) -> bool: - cursor = self._get_cursor() - try: - cursor.execute("DELETE FROM catering_menu_items WHERE catering_menu_item_id = ?;", (menu_item_id,)) - self._connection.commit() - except mariadb.InterfaceError: - self.establish_new_connection() - return self.delete_menu_item(menu_item_id) - except Exception as e: - logger.warning(f"Error deleting menu item: {e}") - return False - return bool(cursor.affected_rows) - - def update_menu_item(self, updated_item: CateringMenuItem) -> bool: - cursor = self._get_cursor() - try: - cursor.execute( - "UPDATE catering_menu_items SET name = ?, additional_info = ?, price = ?, category = ?, is_disabled = ? WHERE catering_menu_item_id = ?;", - (updated_item.name, updated_item.additional_info, updated_item.price, updated_item.category.value, updated_item.is_disabled, updated_item.item_id) - ) - affected_rows = cursor.rowcount - self._connection.commit() - except mariadb.InterfaceError: - self.establish_new_connection() - return self.update_menu_item(updated_item) - except Exception as e: - logger.warning(f"Error updating menu item: {e}") - return False - return bool(affected_rows) - - def add_new_order(self, menu_items: CateringMenuItemsWithAmount, user_id: int, is_delivery: bool) -> Optional[CateringOrder]: - now = datetime.now() - cursor = self._get_cursor() - try: - cursor.execute( - "INSERT INTO orders (status, user, is_delivery, order_date) VALUES (?, ?, ?, ?);", - (CateringOrderStatus.RECEIVED.value, user_id, is_delivery, now) - ) - order_id = cursor.lastrowid - for menu_item, quantity in menu_items.items(): - cursor.execute( - "INSERT INTO order_catering_menu_item (order_id, catering_menu_item_id, quantity) VALUES (?, ?, ?);", - (order_id, menu_item.item_id, quantity) + return CateringMenuItem( + item_id=cursor.lastrowid, + name=name, + additional_info=info, + price=price, + category=category, + is_disabled=is_disabled ) - self._connection.commit() - return CateringOrder( - order_id=order_id, - order_date=now, - status=CateringOrderStatus.RECEIVED, - items=menu_items, - customer=self.get_user_by_id(user_id), - is_delivery=is_delivery - ) - except mariadb.InterfaceError: - self.establish_new_connection() - return self.add_new_order(menu_items, user_id, is_delivery) - except Exception as e: - logger.warning(f"Error placing order: {e}") - return - def change_order_status(self, order_id: int, status: CateringOrderStatus) -> bool: - cursor = self._get_cursor() - try: - cursor.execute( - "UPDATE orders SET status = ? WHERE order_id = ?;", - (status.value, order_id) - ) - affected_rows = cursor.rowcount - self._connection.commit() - except mariadb.InterfaceError: - self.establish_new_connection() - return self.change_order_status(order_id, status) - except Exception as e: - logger.warning(f"Error updating menu item: {e}") - return False - return bool(affected_rows) + async def delete_menu_item(self, menu_item_id: int) -> bool: + async with self._connection_pool.acquire() as conn: + async with conn.cursor(aiomysql.Cursor) as cursor: + try: + await cursor.execute("DELETE FROM catering_menu_items WHERE catering_menu_item_id = %s;", (menu_item_id,)) + await conn.commit() + except aiomysql.InterfaceError: + pool_init_result = await self.init_db_pool() + if not pool_init_result: + raise NoDatabaseConnectionError + return await self.delete_menu_item(menu_item_id) + except Exception as e: + logger.warning(f"Error deleting menu item: {e}") + return False + return cursor.affected_rows > 0 - def get_orders(self, user_id: Optional[int] = None, status: Optional[CateringOrderStatus] = None) -> list[CateringOrder]: - fetched_orders = [] - query = "SELECT * FROM orders LEFT JOIN users ON orders.user = users.user_id" - if user_id is not None and status is None: - query += f" WHERE user = {user_id};" - elif status is not None and user_id is None: - query += f" WHERE status = '{status.value}';" - elif status is not None and user_id is not None: - query += f" WHERE user = {user_id} AND status = '{status.value}';" - else: - query += ";" - cursor = self._get_cursor() - try: - cursor.execute(query) - self._connection.commit() - except mariadb.InterfaceError: - self.establish_new_connection() - return self.get_orders(user_id, status) - except Exception as e: - logger.warning(f"Error getting orders: {e}") - return fetched_orders + async def update_menu_item(self, updated_item: CateringMenuItem) -> bool: + async with self._connection_pool.acquire() as conn: + async with conn.cursor(aiomysql.Cursor) as cursor: + try: + await cursor.execute( + "UPDATE catering_menu_items SET name = %s, additional_info = %s, price = %s, category = %s, is_disabled = %s WHERE catering_menu_item_id = %s;", + (updated_item.name, updated_item.additional_info, updated_item.price, updated_item.category.value, updated_item.is_disabled, updated_item.item_id) + ) + affected_rows = cursor.rowcount + await conn.commit() + except aiomysql.InterfaceError: + pool_init_result = await self.init_db_pool() + if not pool_init_result: + raise NoDatabaseConnectionError + return await self.update_menu_item(updated_item) + except Exception as e: + logger.warning(f"Error updating menu item: {e}") + return False + return affected_rows > 0 - for raw_order in cursor.fetchall(): - fetched_orders.append( - CateringOrder( - order_id=raw_order[0], - status=CateringOrderStatus(raw_order[1]), - customer=self._map_db_result_to_user(raw_order[5:]), - items=self.get_menu_items_for_order(raw_order[0]), - is_delivery=bool(raw_order[4]), - order_date=raw_order[3], - ) - ) + async def add_new_order(self, menu_items: CateringMenuItemsWithAmount, user_id: int, is_delivery: bool) -> Optional[CateringOrder]: + async with self._connection_pool.acquire() as conn: + async with conn.cursor(aiomysql.Cursor) as cursor: + now = datetime.now() + try: + await cursor.execute( + "INSERT INTO orders (status, user, is_delivery, order_date) VALUES (%s, %s, %s, %s);", + (CateringOrderStatus.RECEIVED.value, user_id, is_delivery, now) + ) + order_id = cursor.lastrowid + for menu_item, quantity in menu_items.items(): + await cursor.execute( + "INSERT INTO order_catering_menu_item (order_id, catering_menu_item_id, quantity) VALUES (%s, %s, %s);", + (order_id, menu_item.item_id, quantity) + ) + await conn.commit() + return CateringOrder( + order_id=order_id, + order_date=now, + status=CateringOrderStatus.RECEIVED, + items=menu_items, + customer=await self.get_user_by_id(user_id), + is_delivery=is_delivery + ) + except aiomysql.InterfaceError: + pool_init_result = await self.init_db_pool() + if not pool_init_result: + raise NoDatabaseConnectionError + return await self.add_new_order(menu_items, user_id, is_delivery) + except Exception as e: + logger.warning(f"Error placing order: {e}") + return - return fetched_orders + async def change_order_status(self, order_id: int, status: CateringOrderStatus) -> bool: + async with self._connection_pool.acquire() as conn: + async with conn.cursor(aiomysql.Cursor) as cursor: + try: + await cursor.execute( + "UPDATE orders SET status = %s WHERE order_id = %s;", + (status.value, order_id) + ) + affected_rows = cursor.rowcount + await conn.commit() + except aiomysql.InterfaceError: + pool_init_result = await self.init_db_pool() + if not pool_init_result: + raise NoDatabaseConnectionError + return await self.change_order_status(order_id, status) + except Exception as e: + logger.warning(f"Error updating menu item: {e}") + return False + return affected_rows > 0 - def get_menu_items_for_order(self, order_id: int) -> CateringMenuItemsWithAmount: - cursor = self._get_cursor() - result = {} - try: - cursor.execute( - "SELECT * FROM order_catering_menu_item " - "LEFT JOIN catering_menu_items ON order_catering_menu_item.catering_menu_item_id = catering_menu_items.catering_menu_item_id " - "WHERE order_id = ?;", - (order_id, ) - ) - self._connection.commit() - except mariadb.InterfaceError: - self.establish_new_connection() - return self.get_menu_items_for_order(order_id) - except Exception as e: - logger.warning(f"Error getting order items: {e}") - return result + async def get_orders(self, user_id: Optional[int] = None, status: Optional[CateringOrderStatus] = None) -> list[CateringOrder]: + async with self._connection_pool.acquire() as conn: + async with conn.cursor(aiomysql.Cursor) as cursor: + fetched_orders = [] + query = "SELECT * FROM orders LEFT JOIN users ON orders.user = users.user_id" + if user_id is not None and status is None: + query += f" WHERE user = {user_id};" + elif status is not None and user_id is None: + query += f" WHERE status = '{status.value}';" + elif status is not None and user_id is not None: + query += f" WHERE user = {user_id} AND status = '{status.value}';" + else: + query += ";" + try: + await cursor.execute(query) + await conn.commit() + except aiomysql.InterfaceError: + pool_init_result = await self.init_db_pool() + if not pool_init_result: + raise NoDatabaseConnectionError + return await self.get_orders(user_id, status) + except Exception as e: + logger.warning(f"Error getting orders: {e}") + return fetched_orders - for order_catering_menu_item_raw in cursor.fetchall(): - result[CateringMenuItem( - item_id=order_catering_menu_item_raw[1], - name=order_catering_menu_item_raw[4], - additional_info=order_catering_menu_item_raw[5], - price=order_catering_menu_item_raw[6], - category=CateringMenuItemCategory(order_catering_menu_item_raw[7]), - is_disabled=bool(order_catering_menu_item_raw[8]) - )] = order_catering_menu_item_raw[2] + for raw_order in await cursor.fetchall(): + fetched_orders.append( + CateringOrder( + order_id=raw_order[0], + status=CateringOrderStatus(raw_order[1]), + customer=self._map_db_result_to_user(raw_order[5:]), + items=await self.get_menu_items_for_order(raw_order[0]), + is_delivery=bool(raw_order[4]), + order_date=raw_order[3], + ) + ) - return result + return fetched_orders - def set_user_profile_picture(self, user_id: int, picture_data: bytes) -> None: - cursor = self._get_cursor() - try: - cursor.execute( - "INSERT INTO user_profile_picture (user_id, picture) VALUES (?, ?) ON DUPLICATE KEY UPDATE picture = VALUES(picture)", - (user_id, picture_data) - ) - self._connection.commit() - except mariadb.InterfaceError: - self.establish_new_connection() - return self.set_user_profile_picture(user_id, picture_data) - except Exception as e: - logger.warning(f"Error setting user profile picture: {e}") + async def get_menu_items_for_order(self, order_id: int) -> CateringMenuItemsWithAmount: + async with self._connection_pool.acquire() as conn: + async with conn.cursor(aiomysql.Cursor) as cursor: + result = {} + try: + await cursor.execute( + "SELECT * FROM order_catering_menu_item " + "LEFT JOIN catering_menu_items ON order_catering_menu_item.catering_menu_item_id = catering_menu_items.catering_menu_item_id " + "WHERE order_id = %s;", + (order_id, ) + ) + await conn.commit() + except aiomysql.InterfaceError: + pool_init_result = await self.init_db_pool() + if not pool_init_result: + raise NoDatabaseConnectionError + return await self.get_menu_items_for_order(order_id) + except Exception as e: + logger.warning(f"Error getting order items: {e}") + return result - def get_user_profile_picture(self, user_id: int) -> Optional[bytes]: - cursor = self._get_cursor() - try: - cursor.execute("SELECT (picture) FROM user_profile_picture WHERE user_id = ?", (user_id, )) - self._connection.commit() - r = cursor.fetchone() - if r is None: - return - return r[0] - except mariadb.InterfaceError: - self.establish_new_connection() - return self.get_user_profile_picture(user_id) - except Exception as e: - logger.warning(f"Error setting user profile picture: {e}") - return None + for order_catering_menu_item_raw in await cursor.fetchall(): + result[CateringMenuItem( + item_id=order_catering_menu_item_raw[1], + name=order_catering_menu_item_raw[4], + additional_info=order_catering_menu_item_raw[5], + price=order_catering_menu_item_raw[6], + category=CateringMenuItemCategory(order_catering_menu_item_raw[7]), + is_disabled=bool(order_catering_menu_item_raw[8]) + )] = order_catering_menu_item_raw[2] - def get_all_users(self) -> list[User]: - results = [] - cursor = self._get_cursor() - try: - cursor.execute("SELECT * FROM users;") - self._connection.commit() - except mariadb.InterfaceError: - self.establish_new_connection() - return self.get_all_users() - except Exception as e: - logger.warning(f"Error getting all users: {e}") - return results + return result - for user_raw in cursor.fetchall(): - results.append(self._map_db_result_to_user(user_raw)) + async def set_user_profile_picture(self, user_id: int, picture_data: bytes) -> None: + async with self._connection_pool.acquire() as conn: + async with conn.cursor(aiomysql.Cursor) as cursor: + try: + await cursor.execute( + "INSERT INTO user_profile_picture (user_id, picture) VALUES (%s, %s) ON DUPLICATE KEY UPDATE picture = VALUES(picture)", + (user_id, picture_data) + ) + await conn.commit() + except aiomysql.InterfaceError: + pool_init_result = await self.init_db_pool() + if not pool_init_result: + raise NoDatabaseConnectionError + return await self.set_user_profile_picture(user_id, picture_data) + except Exception as e: + logger.warning(f"Error setting user profile picture: {e}") + + async def get_user_profile_picture(self, user_id: int) -> Optional[bytes]: + async with self._connection_pool.acquire() as conn: + async with conn.cursor(aiomysql.Cursor) as cursor: + try: + await cursor.execute("SELECT (picture) FROM user_profile_picture WHERE user_id = %s", (user_id, )) + await conn.commit() + r = await cursor.fetchone() + if r is None: + return + return r[0] + except aiomysql.InterfaceError: + pool_init_result = await self.init_db_pool() + if not pool_init_result: + raise NoDatabaseConnectionError + return await self.get_user_profile_picture(user_id) + except Exception as e: + logger.warning(f"Error setting user profile picture: {e}") + return None + + async def get_all_users(self) -> list[User]: + + async with self._connection_pool.acquire() as conn: + async with conn.cursor(aiomysql.Cursor) as cursor: + results = [] + try: + await cursor.execute("SELECT * FROM users;") + await conn.commit() + except aiomysql.InterfaceError: + pool_init_result = await self.init_db_pool() + if not pool_init_result: + raise NoDatabaseConnectionError + return await self.get_all_users() + except Exception as e: + logger.warning(f"Error getting all users: {e}") + return results + + for user_raw in await cursor.fetchall(): + results.append(self._map_db_result_to_user(user_raw)) return results diff --git a/src/ez_lan_manager/services/NewsService.py b/src/ez_lan_manager/services/NewsService.py index 14ab77b..6c829a5 100644 --- a/src/ez_lan_manager/services/NewsService.py +++ b/src/ez_lan_manager/services/NewsService.py @@ -1,5 +1,5 @@ import logging -from datetime import date, datetime +from datetime import date from typing import Optional from src.ez_lan_manager.services.DatabaseService import DatabaseService @@ -11,21 +11,22 @@ class NewsService: def __init__(self, db_service: DatabaseService) -> None: self._db_service = db_service - def add_news(self, news: News) -> None: + async def add_news(self, news: News) -> None: if news.news_id is not None: logger.warning("Can not add news with ID, ignoring...") return - self._db_service.add_news(news) + await self._db_service.add_news(news) - def get_news(self, dt_start: Optional[date] = None, dt_end: Optional[date] = None) -> list[News]: + async def get_news(self, dt_start: Optional[date] = None, dt_end: Optional[date] = None) -> list[News]: if not dt_end: dt_end = date.today() if not dt_start: dt_start = date(1900, 1, 1) - return self._db_service.get_news(dt_start, dt_end) + return await self._db_service.get_news(dt_start, dt_end) - def get_latest_news(self) -> Optional[News]: + async def get_latest_news(self) -> Optional[News]: try: - return self.get_news(None, date.today())[0] + all_news = await self.get_news(None, date.today()) + return all_news[0] except IndexError: logger.debug("There are no news to fetch") diff --git a/src/ez_lan_manager/services/SeatingService.py b/src/ez_lan_manager/services/SeatingService.py index 6279df0..ee80b3e 100644 --- a/src/ez_lan_manager/services/SeatingService.py +++ b/src/ez_lan_manager/services/SeatingService.py @@ -36,27 +36,27 @@ class SeatingService: ElementTree.parse(self._seating_configuration.base_svg_path).write(self._seating_plan, encoding="unicode") - def get_seating(self) -> list[Seat]: - return self._db_service.get_seating_info() + async def get_seating(self) -> list[Seat]: + return await self._db_service.get_seating_info() - def get_seat(self, seat_id: str, cached_data: Optional[list[Seat]] = None) -> Optional[Seat]: - all_seats = self.get_seating() if not cached_data else cached_data + async def get_seat(self, seat_id: str, cached_data: Optional[list[Seat]] = None) -> Optional[Seat]: + all_seats = await self.get_seating() if not cached_data else cached_data for seat in all_seats: if seat.seat_id == seat_id: return seat - def get_user_seat(self, user_id: int) -> Optional[Seat]: - all_seats = self.get_seating() + async def get_user_seat(self, user_id: int) -> Optional[Seat]: + all_seats = await self.get_seating() for seat in all_seats: if seat.user and seat.user.user_id == user_id: return seat - def seat_user(self, user_id: int, seat_id: str) -> None: - user_ticket = self._ticketing_service.get_user_ticket(user_id) + async def seat_user(self, user_id: int, seat_id: str) -> None: + user_ticket = await self._ticketing_service.get_user_ticket(user_id) if not user_ticket: raise NoTicketError - seat = self.get_seat(seat_id) + seat = await self.get_seat(seat_id) if not seat: raise SeatNotFoundError @@ -66,10 +66,10 @@ class SeatingService: if seat.user is not None: raise SeatAlreadyTakenError - self._db_service.seat_user(seat_id, user_id) - self.update_svg_with_seating_status() + await self._db_service.seat_user(seat_id, user_id) + await self.update_svg_with_seating_status() - def generate_new_seating_table(self, seating_plan_fp: Path, no_confirm: bool = False) -> None: + async def generate_new_seating_table(self, seating_plan_fp: Path, no_confirm: bool = False) -> None: if not no_confirm: confirm = input("WARNING: THIS ACTION WILL DELETE ALL SEATING DATA! TYPE 'AGREE' TO CONTINUE: ") if confirm != "AGREE": @@ -95,10 +95,10 @@ class SeatingService: except TypeError: continue - self._db_service.generate_fresh_seats_table(sorted(seat_ids, key=lambda sd: sd[0])) - self.update_svg_with_seating_status() + await self._db_service.generate_fresh_seats_table(sorted(seat_ids, key=lambda sd: sd[0])) + await self.update_svg_with_seating_status() - def update_svg_with_seating_status(self) -> None: + async def update_svg_with_seating_status(self) -> None: et = ElementTree.parse(self._seating_configuration.base_svg_path) root = et.getroot() namespace = {'svg': root.tag.split('}')[0].strip('{')} if '}' in root.tag else {} @@ -113,13 +113,13 @@ class SeatingService: rect_g_pairs.append((last_rect, elem)) last_rect = None - all_seats = self.get_seating() + all_seats = await self.get_seating() for rect, g in rect_g_pairs: seat_id = self.get_seat_id_from_element(g, namespace) if not seat_id: continue - seat = self.get_seat(seat_id, cached_data=all_seats) + seat = await self.get_seat(seat_id, cached_data=all_seats) if not seat.is_blocked and seat.user is None: rect.set("fill", "rgb(102, 255, 51)") elif not seat.is_blocked and seat.user is not None: diff --git a/src/ez_lan_manager/services/TicketingService.py b/src/ez_lan_manager/services/TicketingService.py index 13975cb..3502ff2 100644 --- a/src/ez_lan_manager/services/TicketingService.py +++ b/src/ez_lan_manager/services/TicketingService.py @@ -21,30 +21,30 @@ class TicketingService: self._db_service = db_service self._accounting_service = accounting_service - def get_total_tickets(self) -> int: + async def get_total_tickets(self) -> int: return sum([self._lan_info.ticket_info.get_available_tickets(c) for c in self._lan_info.ticket_info.categories]) - def get_available_tickets(self) -> dict[str, int]: + async def get_available_tickets(self) -> dict[str, int]: result = self._lan_info.ticket_info.total_available_tickets - all_tickets = self._db_service.get_tickets() + all_tickets = await self._db_service.get_tickets() for ticket in all_tickets: result[ticket.category] -= 1 return result - def purchase_ticket(self, user_id: int, category: str) -> Ticket: - if category not in self._lan_info.ticket_info.categories or self.get_available_tickets()[category] < 1: + async def purchase_ticket(self, user_id: int, category: str) -> Ticket: + if category not in self._lan_info.ticket_info.categories or (await self.get_available_tickets())[category] < 1: raise TicketNotAvailableError(category) - user_balance = self._accounting_service.get_balance(user_id) + user_balance = await self._accounting_service.get_balance(user_id) if self._lan_info.ticket_info.get_price(category) > user_balance: raise InsufficientFundsError if self.get_user_ticket(user_id): raise UserAlreadyHasTicketError - if new_ticket := self._db_service.generate_ticket_for_user(user_id, category): - self._accounting_service.remove_balance( + if new_ticket := await self._db_service.generate_ticket_for_user(user_id, category): + await self._accounting_service.remove_balance( user_id, self._lan_info.ticket_info.get_price(new_ticket.category), f"TICKET {new_ticket.ticket_id}" @@ -54,20 +54,20 @@ class TicketingService: raise RuntimeError("An unknown error occurred while purchasing ticket") - def refund_ticket(self, user_id: int) -> bool: - user_ticket = self.get_user_ticket(user_id) + async def refund_ticket(self, user_id: int) -> bool: + user_ticket = await self.get_user_ticket(user_id) if not user_ticket: return False if self._db_service.delete_ticket(user_ticket.ticket_id): - self._accounting_service.add_balance(user_id, self._lan_info.ticket_info.get_price(user_ticket.category), f"TICKET REFUND {user_ticket.ticket_id}") + await self._accounting_service.add_balance(user_id, self._lan_info.ticket_info.get_price(user_ticket.category), f"TICKET REFUND {user_ticket.ticket_id}") logger.debug(f"User {user_id} refunded ticket {user_ticket.ticket_id}") return True return False - def transfer_ticket(self, ticket_id: int, user_id: int) -> bool: - return self._db_service.change_ticket_owner(ticket_id, user_id) + async def transfer_ticket(self, ticket_id: int, user_id: int) -> bool: + return await self._db_service.change_ticket_owner(ticket_id, user_id) - def get_user_ticket(self, user_id: int) -> Optional[Ticket]: - return self._db_service.get_ticket_for_user(user_id) + async def get_user_ticket(self, user_id: int) -> Optional[Ticket]: + return await self._db_service.get_ticket_for_user(user_id) diff --git a/src/ez_lan_manager/services/UserService.py b/src/ez_lan_manager/services/UserService.py index d24f213..2120ae4 100644 --- a/src/ez_lan_manager/services/UserService.py +++ b/src/ez_lan_manager/services/UserService.py @@ -17,26 +17,26 @@ class UserService: def __init__(self, db_service: DatabaseService) -> None: self._db_service = db_service - def get_all_users(self) -> list[User]: - return self._db_service.get_all_users() + async def get_all_users(self) -> list[User]: + return await self._db_service.get_all_users() - def get_user(self, accessor: Optional[Union[str, int]]) -> Optional[User]: + async def get_user(self, accessor: Optional[Union[str, int]]) -> Optional[User]: if accessor is None: return if isinstance(accessor, int): - return self._db_service.get_user_by_id(accessor) + return await self._db_service.get_user_by_id(accessor) accessor = accessor.lower() if "@" in accessor: - return self._db_service.get_user_by_mail(accessor) - return self._db_service.get_user_by_name(accessor) + return await self._db_service.get_user_by_mail(accessor) + return await self._db_service.get_user_by_name(accessor) - def set_profile_picture(self, user_id: int, picture: bytes) -> None: - self._db_service.set_user_profile_picture(user_id, picture) + async def set_profile_picture(self, user_id: int, picture: bytes) -> None: + await self._db_service.set_user_profile_picture(user_id, picture) - def get_profile_picture(self, user_id: int) -> bytes: - return self._db_service.get_user_profile_picture(user_id) + async def get_profile_picture(self, user_id: int) -> bytes: + return await self._db_service.get_user_profile_picture(user_id) - def create_user(self, user_name: str, user_mail: str, password_clear_text: str) -> User: + async def create_user(self, user_name: str, user_mail: str, password_clear_text: str) -> User: disallowed_char = self._check_for_disallowed_char(user_name) if disallowed_char: raise NameNotAllowedError(disallowed_char) @@ -44,17 +44,17 @@ class UserService: user_name = user_name.lower() hashed_pw = sha256(password_clear_text.encode(encoding="utf-8")).hexdigest() - return self._db_service.create_user(user_name, user_mail, hashed_pw) + return await self._db_service.create_user(user_name, user_mail, hashed_pw) - def update_user(self, user: User) -> User: + async def update_user(self, user: User) -> User: disallowed_char = self._check_for_disallowed_char(user.user_name) if disallowed_char: raise NameNotAllowedError(disallowed_char) user.user_name = user.user_name.lower() - return self._db_service.update_user(user) + return await self._db_service.update_user(user) - def is_login_valid(self, user_name_or_mail: str, password_clear_text: str) -> bool: - user = self.get_user(user_name_or_mail) + async def is_login_valid(self, user_name_or_mail: str, password_clear_text: str) -> bool: + user = await self.get_user(user_name_or_mail) if not user: return False return user.user_password == sha256(password_clear_text.encode(encoding="utf-8")).hexdigest() diff --git a/src/ez_lan_manager/types/User.py b/src/ez_lan_manager/types/User.py index 164abac..a397962 100644 --- a/src/ez_lan_manager/types/User.py +++ b/src/ez_lan_manager/types/User.py @@ -17,3 +17,6 @@ class User: is_admin: bool created_at: datetime last_updated_at: datetime + + def __hash__(self) -> int: + return hash(f"{self.user_id}{self.user_name}{self.user_mail}")