145 lines
5.6 KiB
Python
145 lines
5.6 KiB
Python
import logging
|
|
import sys
|
|
from typing import Optional
|
|
|
|
import mariadb
|
|
from mariadb import Cursor
|
|
|
|
from src.ez_lan_manager.types.ConfigurationTypes import DatabaseConfiguration
|
|
from src.ez_lan_manager.types.Transaction import Transaction
|
|
from src.ez_lan_manager.types.User import User
|
|
|
|
# Module logger named after only the final dotted component of this module's path.
logger = logging.getLogger(__name__.rpartition(".")[2])
|
|
|
|
class DuplicationError(Exception):
    """Raised by DatabaseService when an INSERT/UPDATE fails with an integrity error.

    In this module it is raised in place of ``mariadb.IntegrityError`` (for
    example when an entry would be duplicated).
    """
|
|
|
|
class DatabaseService:
    """Data-access layer over a single MariaDB connection.

    Provides lookups and writes for the ``users`` and ``transactions`` tables.
    Every cursor opened by this class is closed before the method returns, and
    failed write statements are rolled back so the shared connection is not
    left inside an aborted transaction.

    NOTE(review): read methods never commit. If the connection runs with
    autocommit off (the DB-API default) under InnoDB REPEATABLE READ, reads may
    keep seeing an old snapshot until the next commit on this connection —
    confirm the intended isolation/autocommit settings.
    """

    def __init__(self, database_config: DatabaseConfiguration) -> None:
        """Connect to the database described by *database_config*.

        Terminates the process via ``sys.exit(1)`` if the connection cannot be
        established.
        """
        self._database_config = database_config
        try:
            logger.info(
                "Connecting to database '%s' on %s@%s:%s",
                self._database_config.db_name,
                self._database_config.db_user,
                self._database_config.db_host,
                self._database_config.db_port,
            )
            self._connection = mariadb.connect(
                user=self._database_config.db_user,
                password=self._database_config.db_password,
                host=self._database_config.db_host,
                port=self._database_config.db_port,
                database=self._database_config.db_name,
            )
        except mariadb.Error as e:
            logger.critical("Error connecting to database: %s", e)
            sys.exit(1)

    def _get_cursor(self) -> Cursor:
        """Return a new cursor on the shared connection; the caller must close it."""
        return self._connection.cursor()

    def _fetch_one(self, query: str, params: tuple) -> Optional[tuple]:
        """Run a read query and return its first row (``None`` if there is none)."""
        cursor = self._get_cursor()
        try:
            cursor.execute(query, params)
            return cursor.fetchone()
        finally:
            cursor.close()

    def _fetch_all(self, query: str, params: tuple) -> list:
        """Run a read query and return every row it produced."""
        cursor = self._get_cursor()
        try:
            cursor.execute(query, params)
            return cursor.fetchall()
        finally:
            cursor.close()

    def _rollback_quietly(self) -> None:
        """Roll back the current transaction, logging (not raising) any failure.

        Used while another exception is already propagating, so a failing
        rollback must not replace the original error.
        """
        try:
            self._connection.rollback()
        except mariadb.Error as rollback_error:
            logger.error("Rollback failed: %s", rollback_error)

    def _execute_and_commit(self, query: str, params: tuple) -> None:
        """Run a single write statement and commit it.

        Raises:
            mariadb.Error: re-raised unchanged after the transaction is rolled back.
        """
        cursor = self._get_cursor()
        try:
            cursor.execute(query, params)
            self._connection.commit()
        except mariadb.Error:
            # Discard the failed statement's transaction so later calls on the
            # shared connection start clean.
            self._rollback_quietly()
            raise
        finally:
            cursor.close()

    @staticmethod
    def _map_db_result_to_user(data: tuple) -> User:
        """Build a User from a ``SELECT * FROM users`` row.

        Relies on the table's physical column order: user_id, user_name,
        user_mail, user_password, first name, last name, birth day, is_active,
        is_team_member, is_admin, created_at, last_updated_at.
        """
        return User(
            user_id=data[0],
            user_name=data[1],
            user_mail=data[2],
            user_password=data[3],
            user_first_name=data[4],
            user_last_name=data[5],
            user_birth_day=data[6],
            # Flags are converted with bool(), so integer columns map to True/False.
            is_active=bool(data[7]),
            is_team_member=bool(data[8]),
            is_admin=bool(data[9]),
            created_at=data[10],
            last_updated_at=data[11],
        )

    def get_user_by_name(self, user_name: str) -> Optional[User]:
        """Return the user whose ``user_name`` matches, or ``None`` if absent."""
        row = self._fetch_one("SELECT * FROM users WHERE user_name=?", (user_name,))
        return self._map_db_result_to_user(row) if row else None

    def get_user_by_id(self, user_id: int) -> Optional[User]:
        """Return the user with primary key *user_id*, or ``None`` if absent."""
        row = self._fetch_one("SELECT * FROM users WHERE user_id=?", (user_id,))
        return self._map_db_result_to_user(row) if row else None

    def get_user_by_mail(self, user_mail: str) -> Optional[User]:
        """Return the user with *user_mail* (compared lower-cased), or ``None``."""
        # Mail addresses are stored lower-cased by create_user/update_user.
        row = self._fetch_one("SELECT * FROM users WHERE user_mail=?", (user_mail.lower(),))
        return self._map_db_result_to_user(row) if row else None

    def create_user(self, user_name: str, user_mail: str, password_hash: str) -> User:
        """Insert a new user and return it as re-read from the database.

        Args:
            user_name: login name of the new user.
            user_mail: mail address; stored lower-cased.
            password_hash: already-hashed password to store verbatim.

        Raises:
            DuplicationError: if the INSERT fails with ``mariadb.IntegrityError``.
        """
        try:
            self._execute_and_commit(
                "INSERT INTO users (user_name, user_mail, user_password) "
                "VALUES (?, ?, ?)",
                (user_name, user_mail.lower(), password_hash),
            )
        except mariadb.IntegrityError as e:
            logger.warning("Aborted duplication entry: %s", e)
            raise DuplicationError from e

        return self.get_user_by_name(user_name)

    def update_user(self, user: User) -> User:
        """Persist every editable field of *user* (matched by ``user_id``) and return it.

        Raises:
            DuplicationError: if the UPDATE fails with ``mariadb.IntegrityError``.
        """
        # NOTE(review): the column is named user_birth_date while the User field is
        # user_birth_day — confirm this matches the schema.
        try:
            self._execute_and_commit(
                "UPDATE users SET user_name=?, user_mail=?, user_password=?, user_first_name=?, user_last_name=?, user_birth_date=?, "
                "is_active=?, is_team_member=?, is_admin=? WHERE (user_id=?)",
                (
                    user.user_name, user.user_mail.lower(), user.user_password,
                    user.user_first_name, user.user_last_name, user.user_birth_day,
                    user.is_active, user.is_team_member, user.is_admin,
                    user.user_id,
                ),
            )
        except mariadb.IntegrityError as e:
            logger.warning("Aborted duplication entry: %s", e)
            raise DuplicationError from e

        return user

    def add_transaction(self, transaction: Transaction) -> Optional[Transaction]:
        """Insert *transaction* and return it, or ``None`` if the database rejects it.

        Database errors are logged and swallowed (best-effort insert).
        """
        try:
            self._execute_and_commit(
                "INSERT INTO transactions (user_id, value, is_debit, transaction_date, transaction_reference) "
                "VALUES (?, ?, ?, ?, ?)",
                (
                    transaction.user_id,
                    transaction.value,
                    transaction.is_debit,
                    transaction.transaction_date,
                    transaction.reference,
                ),
            )
        except mariadb.Error as e:
            logger.warning("Error adding Transaction: %s", e)
            return None

        return transaction

    def get_all_transactions_for_user(self, user_id: int) -> list[Transaction]:
        """Return all transactions of *user_id*; an empty list on database errors.

        Rows are read by position from ``SELECT * FROM transactions``: index 2 is
        the value, 3 the debit flag, 4 the date and 5 the reference.
        """
        try:
            rows = self._fetch_all("SELECT * FROM transactions WHERE user_id=?", (user_id,))
        except mariadb.Error as e:
            logger.error("Error getting all transactions for user: %s", e)
            return []

        return [
            Transaction(
                user_id=user_id,
                value=int(row[2]),
                is_debit=bool(row[3]),
                transaction_date=row[4],
                reference=row[5],
            )
            for row in rows
        ]
|