import hmac
from hashlib import sha256
from string import ascii_letters, digits
from typing import Optional, Union

from src.ez_lan_manager.services.DatabaseService import DatabaseService
from src.ez_lan_manager.types.User import User


class NameNotAllowedError(Exception):
    """Raised when a user name contains a character that is not allowed.

    Attributes:
        disallowed_char: The first offending character found in the name.
    """

    def __init__(self, disallowed_char: str) -> None:
        # Forward the character to Exception so ``args`` is populated; without
        # this, str(exc) is empty and the exception does not survive a pickle
        # round-trip (unpickling calls cls(*args)).
        super().__init__(disallowed_char)
        self.disallowed_char = disallowed_char

    def __str__(self) -> str:
        return f"User name contains disallowed character: {self.disallowed_char!r}"


class UserService:
    """Business-logic layer over ``DatabaseService`` for user accounts.

    User names are validated against ``ALLOWED_USER_NAME_SYMBOLS`` and stored
    lower-cased. ``"@"`` is not in that set, which is what allows ``get_user``
    to treat any accessor containing ``"@"`` as a mail address.
    """

    ALLOWED_USER_NAME_SYMBOLS = ascii_letters + digits + "!#$%&*+,-./:;<=>?[]^_{|}~"

    def __init__(self, db_service: DatabaseService) -> None:
        self._db_service = db_service

    def get_user(self, accessor: Optional[Union[str, int]]) -> Optional[User]:
        """Look up a user by id, mail address, or user name.

        Args:
            accessor: An ``int`` is used as a user id. A ``str`` is lower-cased
                first; if it contains ``"@"`` it is looked up as a mail address,
                otherwise as a user name.

        Returns:
            The result of the corresponding database lookup, or ``None`` when
            ``accessor`` is ``None``.
        """
        if accessor is None:
            return None
        if isinstance(accessor, int):
            return self._db_service.get_user_by_id(accessor)
        accessor = accessor.lower()
        if "@" in accessor:
            return self._db_service.get_user_by_mail(accessor)
        return self._db_service.get_user_by_name(accessor)

    def set_profile_picture(self, user_id: int, picture: bytes) -> None:
        """Store ``picture`` as the profile picture of user ``user_id``."""
        self._db_service.set_user_profile_picture(user_id, picture)

    def get_profile_picture(self, user_id: int) -> bytes:
        """Return the stored profile picture of user ``user_id``."""
        return self._db_service.get_user_profile_picture(user_id)

    def create_user(self, user_name: str, user_mail: str, password_clear_text: str) -> User:
        """Validate, normalize and persist a new user.

        The user name and the mail address are both lower-cased before being
        stored, so they match the lower-cased lookups performed by
        ``get_user``.

        Raises:
            NameNotAllowedError: If ``user_name`` contains a character outside
                ``ALLOWED_USER_NAME_SYMBOLS``.
        """
        self._validate_user_name(user_name)
        hashed_pw = self._hash_password(password_clear_text)
        return self._db_service.create_user(user_name.lower(), user_mail.lower(), hashed_pw)

    def update_user(self, user: User) -> User:
        """Validate and persist changes to ``user``.

        Note that ``user.user_name`` is lower-cased in place on the passed-in
        object before it is handed to the database service.

        Raises:
            NameNotAllowedError: If ``user.user_name`` contains a character
                outside ``ALLOWED_USER_NAME_SYMBOLS``.
        """
        self._validate_user_name(user.user_name)
        user.user_name = user.user_name.lower()
        # NOTE(review): unlike create_user(), the mail address is not
        # lower-cased here; consider normalizing it too once the User field
        # name for the mail address is confirmed.
        return self._db_service.update_user(user)

    def is_login_valid(self, user_name_or_mail: str, password_clear_text: str) -> bool:
        """Return True if the credentials match a stored user.

        Args:
            user_name_or_mail: Anything accepted by ``get_user`` as a string.
            password_clear_text: The password as entered by the user.
        """
        user = self.get_user(user_name_or_mail)
        if not user:
            return False
        stored_hash = user.user_password
        # A non-str stored value could never have compared equal to the hex
        # digest before, so treat it as a mismatch rather than crashing.
        if not isinstance(stored_hash, str):
            return False
        # Constant-time comparison avoids leaking how many leading characters
        # of the hash matched through response timing.
        return hmac.compare_digest(
            stored_hash.encode("utf-8"),
            self._hash_password(password_clear_text).encode("utf-8"),
        )

    @staticmethod
    def _hash_password(password_clear_text: str) -> str:
        """Return the hex SHA-256 digest of the UTF-8 encoded password.

        NOTE(review): an unsalted fast hash is weak against offline brute
        force; consider migrating to ``hashlib.scrypt``/``pbkdf2_hmac`` with a
        per-user salt. It is kept unchanged here because switching would
        invalidate every hash already stored.
        """
        return sha256(password_clear_text.encode(encoding="utf-8")).hexdigest()

    def _validate_user_name(self, name: str) -> None:
        """Raise ``NameNotAllowedError`` if ``name`` has a disallowed character."""
        disallowed_char = self._check_for_disallowed_char(name)
        if disallowed_char:
            raise NameNotAllowedError(disallowed_char)

    def _check_for_disallowed_char(self, name: str) -> Optional[str]:
        """Return the first character of ``name`` not in the allowed set, else None."""
        for c in name:
            if c not in self.ALLOWED_USER_NAME_SYMBOLS:
                return c
        return None