ezgg-lan-manager/src/ez_lan_manager/services/DatabaseService.py
2024-08-19 11:29:00 +02:00

145 lines
5.6 KiB
Python

import logging
import sys
from typing import Optional
import mariadb
from mariadb import Cursor
from src.ez_lan_manager.types.ConfigurationTypes import DatabaseConfiguration
from src.ez_lan_manager.types.Transaction import Transaction
from src.ez_lan_manager.types.User import User
# Name the logger after the last dotted component of the module path only,
# e.g. a module named "pkg.sub.DatabaseService" logs as "DatabaseService".
logger = logging.getLogger(__name__.rpartition(".")[-1])
class DuplicationError(Exception):
    """Raised by ``DatabaseService`` when a user write is rejected by the database.

    ``create_user`` and ``update_user`` raise it in place of the
    ``mariadb.IntegrityError`` reported by the driver.
    """
class DatabaseService:
    """Data-access layer for users and transactions over one MariaDB connection.

    A single connection is opened in ``__init__`` and shared by every method.
    Each method opens its own cursor and closes it again before returning.
    """

    def __init__(self, database_config: DatabaseConfiguration) -> None:
        """Connect to the database described by *database_config*.

        Logs a fatal message and terminates the process with exit code 1 if
        ``mariadb.connect`` raises ``mariadb.Error``.
        """
        self._database_config = database_config
        try:
            logger.info(
                f"Connecting to database '{self._database_config.db_name}' on "
                f"{self._database_config.db_user}@{self._database_config.db_host}:{self._database_config.db_port}"
            )
            # NOTE(review): no ``autocommit`` argument is passed, so the
            # connector's default applies (documented as "off" for MariaDB
            # Connector/Python). The read methods below never commit, so reads
            # between two writes may share one snapshot and miss changes made
            # by other connections. Confirm whether that matters here.
            self._connection = mariadb.connect(
                user=self._database_config.db_user,
                password=self._database_config.db_password,
                host=self._database_config.db_host,
                port=self._database_config.db_port,
                database=self._database_config.db_name
            )
        except mariadb.Error as e:
            logger.fatal(f"Error connecting to database: {e}")
            sys.exit(1)

    def _get_cursor(self) -> Cursor:
        """Return a new cursor on the shared connection; the caller must close it."""
        return self._connection.cursor()

    @staticmethod
    def _map_db_result_to_user(data: tuple) -> User:
        """Build a ``User`` from one row of ``SELECT * FROM users``.

        The mapping is positional. NOTE(review): it assumes the table's column
        order is id, name, mail, password, first name, last name, birth day,
        is_active, is_team_member, is_admin, created_at, last_updated_at.
        A schema reorder would break it silently.
        """
        return User(
            user_id=data[0],
            user_name=data[1],
            user_mail=data[2],
            user_password=data[3],
            user_first_name=data[4],
            user_last_name=data[5],
            user_birth_day=data[6],
            is_active=bool(data[7]),
            is_team_member=bool(data[8]),
            is_admin=bool(data[9]),
            created_at=data[10],
            last_updated_at=data[11]
        )

    def _fetch_one_user(self, query: str, params: tuple) -> Optional[User]:
        """Run a parameterized user query and map its first row.

        Returns None when the query yields no row.
        """
        cursor = self._get_cursor()
        try:
            cursor.execute(query, params)
            row = cursor.fetchone()
        finally:
            # Always release the cursor, even if execute/fetch raised.
            cursor.close()
        if not row:
            return None
        return self._map_db_result_to_user(row)

    def get_user_by_name(self, user_name: str) -> Optional[User]:
        """Return the user whose ``user_name`` equals *user_name*, or None."""
        return self._fetch_one_user("SELECT * FROM users WHERE user_name=?", (user_name,))

    def get_user_by_id(self, user_id: int) -> Optional[User]:
        """Return the user whose ``user_id`` equals *user_id*, or None."""
        return self._fetch_one_user("SELECT * FROM users WHERE user_id=?", (user_id,))

    def get_user_by_mail(self, user_mail: str) -> Optional[User]:
        """Return the user with the given mail address, or None.

        The address is lower-cased before the lookup, matching how
        ``create_user`` and ``update_user`` store it.
        """
        return self._fetch_one_user("SELECT * FROM users WHERE user_mail=?", (user_mail.lower(),))

    def create_user(self, user_name: str, user_mail: str, password_hash: str) -> User:
        """Insert a new user and return it as re-read from the database.

        The mail address is stored lower-cased.

        Raises:
            DuplicationError: if the INSERT raises ``mariadb.IntegrityError``.
                NOTE(review): IntegrityError is not limited to duplicate keys
                (e.g. NOT NULL violations raise it too).
        """
        cursor = self._get_cursor()
        try:
            cursor.execute(
                "INSERT INTO users (user_name, user_mail, user_password) "
                "VALUES (?, ?, ?)", (user_name, user_mail.lower(), password_hash)
            )
            self._connection.commit()
        except mariadb.IntegrityError as e:
            logger.warning(f"Aborted duplication entry: {e}")
            raise DuplicationError from e
        finally:
            cursor.close()
        # Re-read so DB-populated fields (id, timestamps) are filled in.
        return self.get_user_by_name(user_name)

    def update_user(self, user: User) -> User:
        """Write every editable field of *user* back to its row (matched by ``user_id``).

        The mail address is stored lower-cased; the passed-in object is
        returned unchanged.

        Raises:
            DuplicationError: if the UPDATE raises ``mariadb.IntegrityError``.
        """
        cursor = self._get_cursor()
        try:
            # NOTE(review): the column is named ``user_birth_date`` here while the
            # User field is ``user_birth_day``. Confirm this against the schema.
            cursor.execute(
                "UPDATE users SET user_name=?, user_mail=?, user_password=?, user_first_name=?, user_last_name=?, user_birth_date=?, "
                "is_active=?, is_team_member=?, is_admin=? WHERE (user_id=?)",
                (
                    user.user_name, user.user_mail.lower(), user.user_password,
                    user.user_first_name, user.user_last_name, user.user_birth_day,
                    user.is_active, user.is_team_member, user.is_admin,
                    user.user_id,
                )
            )
            self._connection.commit()
        except mariadb.IntegrityError as e:
            logger.warning(f"Aborted duplication entry: {e}")
            raise DuplicationError from e
        finally:
            cursor.close()
        return user

    def add_transaction(self, transaction: Transaction) -> Optional[Transaction]:
        """Insert *transaction* into the ``transactions`` table.

        Returns the given transaction on success, or None if the database
        raised ``mariadb.Error``; the error is logged as a warning.
        """
        cursor = self._get_cursor()
        try:
            cursor.execute(
                "INSERT INTO transactions (user_id, value, is_debit, transaction_date, transaction_reference) "
                "VALUES (?, ?, ?, ?, ?)",
                (transaction.user_id, transaction.value, transaction.is_debit, transaction.transaction_date, transaction.reference)
            )
            self._connection.commit()
        except mariadb.Error as e:
            # Only database errors are treated as a soft failure; programming
            # errors (e.g. a malformed Transaction) propagate to the caller.
            logger.warning(f"Error adding Transaction: {e}")
            return None
        finally:
            cursor.close()
        return transaction

    def get_all_transactions_for_user(self, user_id: int) -> list[Transaction]:
        """Return every transaction stored for *user_id*.

        Returns an empty list if the query raises ``mariadb.Error`` (the error
        is logged). NOTE(review): rows are mapped positionally, assuming columns
        2-5 are value, is_debit, transaction_date and transaction_reference.
        """
        cursor = self._get_cursor()
        try:
            cursor.execute("SELECT * FROM transactions WHERE user_id=?", (user_id,))
            rows = cursor.fetchall()
        except mariadb.Error as e:
            logger.error(f"Error getting all transactions for user: {e}")
            return []
        finally:
            cursor.close()
        return [
            Transaction(
                user_id=user_id,
                value=int(row[2]),
                is_debit=bool(row[3]),
                transaction_date=row[4],
                reference=row[5],
            )
            for row in rows
        ]