aiomysql refactor

This commit is contained in:
David Rodenkirchen
2024-09-03 14:30:32 +02:00
parent a9597b5c4f
commit 30b32a4c02
24 changed files with 901 additions and 755 deletions
@@ -13,8 +13,8 @@ class AccountingService:
def __init__(self, db_service: DatabaseService) -> None:
self._db_service = db_service
def add_balance(self, user_id: int, balance_to_add: int, reference: str) -> int:
self._db_service.add_transaction(Transaction(
async def add_balance(self, user_id: int, balance_to_add: int, reference: str) -> int:
await self._db_service.add_transaction(Transaction(
user_id=user_id,
value=balance_to_add,
is_debit=False,
@@ -22,13 +22,13 @@ class AccountingService:
transaction_date=datetime.now()
))
logger.debug(f"Added balance of {self.make_euro_string_from_int(balance_to_add)} to user with ID {user_id}")
return self.get_balance(user_id)
return await self.get_balance(user_id)
def remove_balance(self, user_id: int, balance_to_remove: int, reference: str) -> int:
current_balance = self.get_balance(user_id)
async def remove_balance(self, user_id: int, balance_to_remove: int, reference: str) -> int:
current_balance = await self.get_balance(user_id)
if (current_balance - balance_to_remove) < 0:
raise InsufficientFundsError
self._db_service.add_transaction(Transaction(
await self._db_service.add_transaction(Transaction(
user_id=user_id,
value=balance_to_remove,
is_debit=True,
@@ -36,19 +36,19 @@ class AccountingService:
transaction_date=datetime.now()
))
logger.debug(f"Removed balance of {self.make_euro_string_from_int(balance_to_remove)} to user with ID {user_id}")
return self.get_balance(user_id)
return await self.get_balance(user_id)
def get_balance(self, user_id: int) -> int:
async def get_balance(self, user_id: int) -> int:
balance_buffer = 0
for transaction in self._db_service.get_all_transactions_for_user(user_id):
for transaction in await self._db_service.get_all_transactions_for_user(user_id):
if transaction.is_debit:
balance_buffer -= transaction.value
else:
balance_buffer += transaction.value
return balance_buffer
def get_transaction_history(self, user_id: int) -> list[Transaction]:
return self._db_service.get_all_transactions_for_user(user_id)
async def get_transaction_history(self, user_id: int) -> list[Transaction]:
return await self._db_service.get_all_transactions_for_user(user_id)
@staticmethod
def make_euro_string_from_int(cent_int: int) -> str:
+35 -35
View File
@@ -23,93 +23,93 @@ class CateringService:
# ORDERS
def place_order(self, menu_items: CateringMenuItemsWithAmount, user_id: int, is_delivery: bool = True) -> CateringOrder:
async def place_order(self, menu_items: CateringMenuItemsWithAmount, user_id: int, is_delivery: bool = True) -> CateringOrder:
for menu_item in menu_items:
if menu_item.is_disabled:
raise CateringError("Order includes disabled items")
user = self._user_service.get_user(user_id)
user = await self._user_service.get_user(user_id)
if not user:
raise CateringError("User does not exist")
total_price = sum([item.price * quantity for item, quantity in menu_items.items()])
if self._accounting_service.get_balance(user_id) < total_price:
if await self._accounting_service.get_balance(user_id) < total_price:
raise CateringError("Insufficient funds")
order = self._db_service.add_new_order(menu_items, user_id, is_delivery)
order = await self._db_service.add_new_order(menu_items, user_id, is_delivery)
if order:
self._accounting_service.remove_balance(user_id, total_price, f"CATERING - {order.order_id}")
await self._accounting_service.remove_balance(user_id, total_price, f"CATERING - {order.order_id}")
logger.info(f"User '{order.customer.user_name}' (ID:{order.customer.user_id}) ordered from catering for {self._accounting_service.make_euro_string_from_int(total_price)}")
return order
def update_order_status(self, order_id: int, new_status: CateringOrderStatus) -> bool:
async def update_order_status(self, order_id: int, new_status: CateringOrderStatus) -> bool:
if new_status == CateringOrderStatus.CANCELED:
# Cancelled orders need to be refunded
raise CateringError("Orders cannot be canceled this way, use CateringService.cancel_order")
return self._db_service.change_order_status(order_id, new_status)
return await self._db_service.change_order_status(order_id, new_status)
def get_orders(self) -> list[CateringOrder]:
return self._db_service.get_orders()
async def get_orders(self) -> list[CateringOrder]:
return await self._db_service.get_orders()
def get_orders_for_user(self, user_id: int) -> list[CateringOrder]:
return self._db_service.get_orders(user_id=user_id)
async def get_orders_for_user(self, user_id: int) -> list[CateringOrder]:
return await self._db_service.get_orders(user_id=user_id)
def get_orders_by_status(self, status: CateringOrderStatus) -> list[CateringOrder]:
return self._db_service.get_orders(status=status)
async def get_orders_by_status(self, status: CateringOrderStatus) -> list[CateringOrder]:
return await self._db_service.get_orders(status=status)
def cancel_order(self, order: CateringOrder) -> bool:
async def cancel_order(self, order: CateringOrder) -> bool:
if self._db_service.change_order_status(order.order_id, CateringOrderStatus.CANCELED):
self._accounting_service.add_balance(order.customer.user_id, order.price, f"CATERING REFUND - {order.order_id}")
await self._accounting_service.add_balance(order.customer.user_id, order.price, f"CATERING REFUND - {order.order_id}")
return True
return False
# MENU ITEMS
def get_menu(self, category: Optional[CateringMenuItemCategory] = None) -> list[CateringMenuItem]:
items = self._db_service.get_menu_items()
async def get_menu(self, category: Optional[CateringMenuItemCategory] = None) -> list[CateringMenuItem]:
items = await self._db_service.get_menu_items()
if not category:
return items
return list(filter(lambda item: item.category == category, items))
def get_menu_item_by_id(self, menu_item_id: int) -> CateringMenuItem:
item = self._db_service.get_menu_item(menu_item_id)
async def get_menu_item_by_id(self, menu_item_id: int) -> CateringMenuItem:
item = await self._db_service.get_menu_item(menu_item_id)
if not item:
raise CateringError("Menu item not found")
return item
def add_menu_item(self, name: str, info: str, price: int, category: CateringMenuItemCategory, is_disabled: bool = False) -> CateringMenuItem:
if new_item := self._db_service.add_menu_item(name, info, price, category, is_disabled):
async def add_menu_item(self, name: str, info: str, price: int, category: CateringMenuItemCategory, is_disabled: bool = False) -> CateringMenuItem:
if new_item := await self._db_service.add_menu_item(name, info, price, category, is_disabled):
return new_item
raise CateringError(f"Could not add item '{name}' to the menu.")
def remove_menu_item(self, menu_item_id: int) -> bool:
return self._db_service.delete_menu_item(menu_item_id)
async def remove_menu_item(self, menu_item_id: int) -> bool:
return await self._db_service.delete_menu_item(menu_item_id)
def change_menu_item(self, updated_item: CateringMenuItem) -> bool:
return self._db_service.update_menu_item(updated_item)
async def change_menu_item(self, updated_item: CateringMenuItem) -> bool:
return await self._db_service.update_menu_item(updated_item)
def disable_menu_item(self, menu_item_id: int) -> bool:
async def disable_menu_item(self, menu_item_id: int) -> bool:
try:
item = self.get_menu_item_by_id(menu_item_id)
item = await self.get_menu_item_by_id(menu_item_id)
except CateringError:
return False
item.is_disabled = True
return self._db_service.update_menu_item(item)
return await self._db_service.update_menu_item(item)
def enable_menu_item(self, menu_item_id: int) -> bool:
async def enable_menu_item(self, menu_item_id: int) -> bool:
try:
item = self.get_menu_item_by_id(menu_item_id)
item = await self.get_menu_item_by_id(menu_item_id)
except CateringError:
return False
item.is_disabled = False
return self._db_service.update_menu_item(item)
return await self._db_service.update_menu_item(item)
def disable_menu_items_by_category(self, category: CateringMenuItemCategory) -> bool:
items = self.get_menu(category=category)
async def disable_menu_items_by_category(self, category: CateringMenuItemCategory) -> bool:
items = await self.get_menu(category=category)
return all([self.disable_menu_item(item.item_id) for item in items])
def enable_menu_items_by_category(self, category: CateringMenuItemCategory) -> bool:
items = self.get_menu(category=category)
async def enable_menu_items_by_category(self, category: CateringMenuItemCategory) -> bool:
items = await self.get_menu(category=category)
return all([self.enable_menu_item(item.item_id) for item in items])
# CART
File diff suppressed because it is too large Load Diff
+8 -7
View File
@@ -1,5 +1,5 @@
import logging
from datetime import date, datetime
from datetime import date
from typing import Optional
from src.ez_lan_manager.services.DatabaseService import DatabaseService
@@ -11,21 +11,22 @@ class NewsService:
def __init__(self, db_service: DatabaseService) -> None:
self._db_service = db_service
def add_news(self, news: News) -> None:
async def add_news(self, news: News) -> None:
if news.news_id is not None:
logger.warning("Can not add news with ID, ignoring...")
return
self._db_service.add_news(news)
await self._db_service.add_news(news)
def get_news(self, dt_start: Optional[date] = None, dt_end: Optional[date] = None) -> list[News]:
async def get_news(self, dt_start: Optional[date] = None, dt_end: Optional[date] = None) -> list[News]:
if not dt_end:
dt_end = date.today()
if not dt_start:
dt_start = date(1900, 1, 1)
return self._db_service.get_news(dt_start, dt_end)
return await self._db_service.get_news(dt_start, dt_end)
def get_latest_news(self) -> Optional[News]:
async def get_latest_news(self) -> Optional[News]:
try:
return self.get_news(None, date.today())[0]
all_news = await self.get_news(None, date.today())
return all_news[0]
except IndexError:
logger.debug("There are no news to fetch")
+17 -17
View File
@@ -36,27 +36,27 @@ class SeatingService:
ElementTree.parse(self._seating_configuration.base_svg_path).write(self._seating_plan, encoding="unicode")
def get_seating(self) -> list[Seat]:
return self._db_service.get_seating_info()
async def get_seating(self) -> list[Seat]:
return await self._db_service.get_seating_info()
def get_seat(self, seat_id: str, cached_data: Optional[list[Seat]] = None) -> Optional[Seat]:
all_seats = self.get_seating() if not cached_data else cached_data
async def get_seat(self, seat_id: str, cached_data: Optional[list[Seat]] = None) -> Optional[Seat]:
all_seats = await self.get_seating() if not cached_data else cached_data
for seat in all_seats:
if seat.seat_id == seat_id:
return seat
def get_user_seat(self, user_id: int) -> Optional[Seat]:
all_seats = self.get_seating()
async def get_user_seat(self, user_id: int) -> Optional[Seat]:
all_seats = await self.get_seating()
for seat in all_seats:
if seat.user and seat.user.user_id == user_id:
return seat
def seat_user(self, user_id: int, seat_id: str) -> None:
user_ticket = self._ticketing_service.get_user_ticket(user_id)
async def seat_user(self, user_id: int, seat_id: str) -> None:
user_ticket = await self._ticketing_service.get_user_ticket(user_id)
if not user_ticket:
raise NoTicketError
seat = self.get_seat(seat_id)
seat = await self.get_seat(seat_id)
if not seat:
raise SeatNotFoundError
@@ -66,10 +66,10 @@ class SeatingService:
if seat.user is not None:
raise SeatAlreadyTakenError
self._db_service.seat_user(seat_id, user_id)
self.update_svg_with_seating_status()
await self._db_service.seat_user(seat_id, user_id)
await self.update_svg_with_seating_status()
def generate_new_seating_table(self, seating_plan_fp: Path, no_confirm: bool = False) -> None:
async def generate_new_seating_table(self, seating_plan_fp: Path, no_confirm: bool = False) -> None:
if not no_confirm:
confirm = input("WARNING: THIS ACTION WILL DELETE ALL SEATING DATA! TYPE 'AGREE' TO CONTINUE: ")
if confirm != "AGREE":
@@ -95,10 +95,10 @@ class SeatingService:
except TypeError:
continue
self._db_service.generate_fresh_seats_table(sorted(seat_ids, key=lambda sd: sd[0]))
self.update_svg_with_seating_status()
await self._db_service.generate_fresh_seats_table(sorted(seat_ids, key=lambda sd: sd[0]))
await self.update_svg_with_seating_status()
def update_svg_with_seating_status(self) -> None:
async def update_svg_with_seating_status(self) -> None:
et = ElementTree.parse(self._seating_configuration.base_svg_path)
root = et.getroot()
namespace = {'svg': root.tag.split('}')[0].strip('{')} if '}' in root.tag else {}
@@ -113,13 +113,13 @@ class SeatingService:
rect_g_pairs.append((last_rect, elem))
last_rect = None
all_seats = self.get_seating()
all_seats = await self.get_seating()
for rect, g in rect_g_pairs:
seat_id = self.get_seat_id_from_element(g, namespace)
if not seat_id:
continue
seat = self.get_seat(seat_id, cached_data=all_seats)
seat = await self.get_seat(seat_id, cached_data=all_seats)
if not seat.is_blocked and seat.user is None:
rect.set("fill", "rgb(102, 255, 51)")
elif not seat.is_blocked and seat.user is not None:
+15 -15
View File
@@ -21,30 +21,30 @@ class TicketingService:
self._db_service = db_service
self._accounting_service = accounting_service
def get_total_tickets(self) -> int:
async def get_total_tickets(self) -> int:
return sum([self._lan_info.ticket_info.get_available_tickets(c) for c in self._lan_info.ticket_info.categories])
def get_available_tickets(self) -> dict[str, int]:
async def get_available_tickets(self) -> dict[str, int]:
result = self._lan_info.ticket_info.total_available_tickets
all_tickets = self._db_service.get_tickets()
all_tickets = await self._db_service.get_tickets()
for ticket in all_tickets:
result[ticket.category] -= 1
return result
def purchase_ticket(self, user_id: int, category: str) -> Ticket:
if category not in self._lan_info.ticket_info.categories or self.get_available_tickets()[category] < 1:
async def purchase_ticket(self, user_id: int, category: str) -> Ticket:
if category not in self._lan_info.ticket_info.categories or (await self.get_available_tickets())[category] < 1:
raise TicketNotAvailableError(category)
user_balance = self._accounting_service.get_balance(user_id)
user_balance = await self._accounting_service.get_balance(user_id)
if self._lan_info.ticket_info.get_price(category) > user_balance:
raise InsufficientFundsError
if self.get_user_ticket(user_id):
raise UserAlreadyHasTicketError
if new_ticket := self._db_service.generate_ticket_for_user(user_id, category):
self._accounting_service.remove_balance(
if new_ticket := await self._db_service.generate_ticket_for_user(user_id, category):
await self._accounting_service.remove_balance(
user_id,
self._lan_info.ticket_info.get_price(new_ticket.category),
f"TICKET {new_ticket.ticket_id}"
@@ -54,20 +54,20 @@ class TicketingService:
raise RuntimeError("An unknown error occurred while purchasing ticket")
def refund_ticket(self, user_id: int) -> bool:
user_ticket = self.get_user_ticket(user_id)
async def refund_ticket(self, user_id: int) -> bool:
user_ticket = await self.get_user_ticket(user_id)
if not user_ticket:
return False
if self._db_service.delete_ticket(user_ticket.ticket_id):
self._accounting_service.add_balance(user_id, self._lan_info.ticket_info.get_price(user_ticket.category), f"TICKET REFUND {user_ticket.ticket_id}")
await self._accounting_service.add_balance(user_id, self._lan_info.ticket_info.get_price(user_ticket.category), f"TICKET REFUND {user_ticket.ticket_id}")
logger.debug(f"User {user_id} refunded ticket {user_ticket.ticket_id}")
return True
return False
def transfer_ticket(self, ticket_id: int, user_id: int) -> bool:
return self._db_service.change_ticket_owner(ticket_id, user_id)
async def transfer_ticket(self, ticket_id: int, user_id: int) -> bool:
return await self._db_service.change_ticket_owner(ticket_id, user_id)
def get_user_ticket(self, user_id: int) -> Optional[Ticket]:
return self._db_service.get_ticket_for_user(user_id)
async def get_user_ticket(self, user_id: int) -> Optional[Ticket]:
return await self._db_service.get_ticket_for_user(user_id)
+16 -16
View File
@@ -17,26 +17,26 @@ class UserService:
def __init__(self, db_service: DatabaseService) -> None:
self._db_service = db_service
def get_all_users(self) -> list[User]:
return self._db_service.get_all_users()
async def get_all_users(self) -> list[User]:
return await self._db_service.get_all_users()
def get_user(self, accessor: Optional[Union[str, int]]) -> Optional[User]:
async def get_user(self, accessor: Optional[Union[str, int]]) -> Optional[User]:
if accessor is None:
return
if isinstance(accessor, int):
return self._db_service.get_user_by_id(accessor)
return await self._db_service.get_user_by_id(accessor)
accessor = accessor.lower()
if "@" in accessor:
return self._db_service.get_user_by_mail(accessor)
return self._db_service.get_user_by_name(accessor)
return await self._db_service.get_user_by_mail(accessor)
return await self._db_service.get_user_by_name(accessor)
def set_profile_picture(self, user_id: int, picture: bytes) -> None:
self._db_service.set_user_profile_picture(user_id, picture)
async def set_profile_picture(self, user_id: int, picture: bytes) -> None:
await self._db_service.set_user_profile_picture(user_id, picture)
def get_profile_picture(self, user_id: int) -> bytes:
return self._db_service.get_user_profile_picture(user_id)
async def get_profile_picture(self, user_id: int) -> bytes:
return await self._db_service.get_user_profile_picture(user_id)
def create_user(self, user_name: str, user_mail: str, password_clear_text: str) -> User:
async def create_user(self, user_name: str, user_mail: str, password_clear_text: str) -> User:
disallowed_char = self._check_for_disallowed_char(user_name)
if disallowed_char:
raise NameNotAllowedError(disallowed_char)
@@ -44,17 +44,17 @@ class UserService:
user_name = user_name.lower()
hashed_pw = sha256(password_clear_text.encode(encoding="utf-8")).hexdigest()
return self._db_service.create_user(user_name, user_mail, hashed_pw)
return await self._db_service.create_user(user_name, user_mail, hashed_pw)
def update_user(self, user: User) -> User:
async def update_user(self, user: User) -> User:
disallowed_char = self._check_for_disallowed_char(user.user_name)
if disallowed_char:
raise NameNotAllowedError(disallowed_char)
user.user_name = user.user_name.lower()
return self._db_service.update_user(user)
return await self._db_service.update_user(user)
def is_login_valid(self, user_name_or_mail: str, password_clear_text: str) -> bool:
user = self.get_user(user_name_or_mail)
async def is_login_valid(self, user_name_or_mail: str, password_clear_text: str) -> bool:
user = await self.get_user(user_name_or_mail)
if not user:
return False
return user.user_password == sha256(password_clear_text.encode(encoding="utf-8")).hexdigest()