From 8b87d78d5d06420f9e2f51331c8af42f985fde1a Mon Sep 17 00:00:00 2001 From: David Rodenkirchen Date: Mon, 19 Aug 2024 10:38:57 +0200 Subject: [PATCH] implement UserService --- src/EzLanManager.py | 14 +++++- .../services/ConfigurationService.py | 2 +- .../services/DatabaseService.py | 24 +++++++++-- src/ez_lan_manager/services/UserService.py | 43 +++++++++++++++++++ 4 files changed, 77 insertions(+), 6 deletions(-) create mode 100644 src/ez_lan_manager/services/UserService.py diff --git a/src/EzLanManager.py b/src/EzLanManager.py index fb285e1..fa14888 100644 --- a/src/EzLanManager.py +++ b/src/EzLanManager.py @@ -1,4 +1,5 @@ import logging +from datetime import datetime from from_root import from_root @@ -7,6 +8,9 @@ from src.ez_lan_manager.services.DatabaseService import DatabaseService from random import randint +from src.ez_lan_manager.services.UserService import UserService +from src.ez_lan_manager.types.User import User + logger = logging.getLogger(__name__.split(".")[-1]) if __name__ == "__main__": @@ -14,4 +18,12 @@ if __name__ == "__main__": configuration_service = ConfigurationService(from_root("config.toml")) db_config = configuration_service.get_database_configuration() db_service = DatabaseService(db_config) - print(db_service.create_user(f"TestUser{randint(0, 9999)}", f"TestMail{randint(0, 9999)}", "pw123")) + user_service = UserService(db_service) + user_service.create_user("Mamfred", "Peter@peterson.com", "MamaHalloDoo") + # print(db_service.create_user(f"TestUser{randint(0, 9999)}", f"TestMail{randint(0, 9999)}", "pw123")) + # print(db_service.update_user( + # User(user_id=19, user_name='TestUser838', user_mail='TestMail3142', user_password='pw123', user_first_name=None, user_last_name=None, + # user_birth_day=None, is_active=False, is_team_member=False, is_admin=False, created_at=datetime(2024, 8, 19, 10, 10, 39), + # last_updated_at=datetime(2024, 8, 19, 10, 10, 39), balance=0) + # + # )) \ No newline at end of file diff --git a/src/ez_lan_manager/services/ConfigurationService.py b/src/ez_lan_manager/services/ConfigurationService.py index bb261a3..44a6788 100644 --- a/src/ez_lan_manager/services/ConfigurationService.py +++ b/src/ez_lan_manager/services/ConfigurationService.py @@ -8,7 +8,7 @@ from src.ez_lan_manager.types.ConfigurationTypes import DatabaseConfiguration logger = logging.getLogger(__name__.split(".")[-1]) class ConfigurationService: - def __init__(self, config_file_path: Path): + def __init__(self, config_file_path: Path) -> None: try: with open(config_file_path, "rb") as config_file: self._config = tomllib.load(config_file) diff --git a/src/ez_lan_manager/services/DatabaseService.py b/src/ez_lan_manager/services/DatabaseService.py index c58fb00..bfd12e5 100644 --- a/src/ez_lan_manager/services/DatabaseService.py +++ b/src/ez_lan_manager/services/DatabaseService.py @@ -14,7 +14,7 @@ class DuplicationError(Exception): pass class DatabaseService: - def __init__(self, database_config: DatabaseConfiguration): + def __init__(self, database_config: DatabaseConfiguration) -> None: self._database_config = database_config try: logger.info( @@ -69,9 +69,9 @@ class DatabaseService: return return self._map_db_result_to_user(result) - def get_user_by_main(self, user_mail: str) -> Optional[User]: + def get_user_by_mail(self, user_mail: str) -> Optional[User]: cursor = self._get_cursor() - cursor.execute("SELECT * FROM users WHERE user_mail=?", (user_mail,)) + cursor.execute("SELECT * FROM users WHERE user_mail=?", (user_mail.lower(),)) result = cursor.fetchone() if not result: return @@ -82,7 +82,7 @@ class DatabaseService: try: cursor.execute( "INSERT INTO users (user_name, user_mail, user_password) " - "VALUES (?, ?, ?)", (user_name, user_mail, password_hash) + "VALUES (?, ?, ?)", (user_name, user_mail.lower(), password_hash) ) self._connection.commit() except mariadb.IntegrityError as e: @@ -90,3 +90,19 @@ class DatabaseService: raise DuplicationError return self.get_user_by_name(user_name) + + def update_user(self, user: User) -> User: + cursor = self._get_cursor() + try: + cursor.execute( + "UPDATE users SET user_name=?, user_mail=?, user_password=?, user_first_name=?, user_last_name=?, user_birth_date=?, " + "is_active=?, is_team_member=?, is_admin=?, balance=? WHERE (user_id=?)", (user.user_name, user.user_mail.lower(), user.user_password, + user.user_first_name, user.user_last_name, user.user_birth_day, + user.is_active, user.is_team_member, user.is_admin, + user.balance, user.user_id) + ) + self._connection.commit() + except mariadb.IntegrityError as e: + logger.warning(f"Aborted duplication entry: {e}") + raise DuplicationError + return user diff --git a/src/ez_lan_manager/services/UserService.py b/src/ez_lan_manager/services/UserService.py new file mode 100644 index 0000000..c1e056f --- /dev/null +++ b/src/ez_lan_manager/services/UserService.py @@ -0,0 +1,43 @@ +from hashlib import sha256 +from typing import Union, Optional +from string import ascii_letters, digits + +from src.ez_lan_manager.services.DatabaseService import DatabaseService +from src.ez_lan_manager.types.User import User + + +class NameNotAllowedError(Exception): + def __init__(self, disallowed_char: str) -> None: + self.disallowed_char = disallowed_char + +class UserService: + ALLOWED_USER_NAME_SYMBOLS = ascii_letters + digits + "!#$%&*+,-./:;<=>?[]^_{|}~" + + def __init__(self, db_service: DatabaseService) -> None: + self._db_service = db_service + + def get_user(self, accessor: Union[str, int]) -> User: + if isinstance(accessor, int): + return self._db_service.get_user_by_id(accessor) + if "@" in accessor: + return self._db_service.get_user_by_mail(accessor) + return self._db_service.get_user_by_name(accessor) + + def create_user(self, user_name: str, user_mail: str, password_clear_text: str) -> User: + disallowed_char = self._check_for_disallowed_char(user_name) + if disallowed_char: + raise NameNotAllowedError(disallowed_char) + + hashed_pw = sha256(password_clear_text.encode(encoding="utf-8")).hexdigest() + return self._db_service.create_user(user_name, user_mail, hashed_pw) + + def update_user(self, user: User) -> User: + disallowed_char = self._check_for_disallowed_char(user.user_name) + if disallowed_char: + raise NameNotAllowedError(disallowed_char) + return self._db_service.update_user(user) + + def _check_for_disallowed_char(self, name: str) -> Optional[str]: + for c in name: + if c not in self.ALLOWED_USER_NAME_SYMBOLS: + return c