# 74 lines · 2.9 KiB · Python
from hashlib import sha256
from hmac import compare_digest
from string import ascii_letters, digits
from typing import Union, Optional

from src.ezgg_lan_manager.services.DatabaseService import DatabaseService
from src.ezgg_lan_manager.types.User import User


class NameNotAllowedError(Exception):
    """Raised when a user name contains a character that is not allowed.

    :param disallowed_char: The first offending character found in the name.
    """

    def __init__(self, disallowed_char: str) -> None:
        # Forward the character to Exception so ``args`` / ``str(exc)`` carry it and
        # the exception can be unpickled (Exception.__reduce__ replays ``args`` into
        # the constructor, which requires ``disallowed_char``).
        super().__init__(disallowed_char)
        self.disallowed_char = disallowed_char


class UserService:
    """Service layer for user accounts on top of ``DatabaseService``.

    User names are validated against ``ALLOWED_USER_NAME_SYMBOLS`` and stored
    lower-cased. String lookups in ``get_user`` are lower-cased as well.
    Passwords are handed to the database only as hex SHA-256 digests.
    """

    # Characters permitted in a user name. Validation happens before lower-casing.
    # Whitespace, quotes, '@' and non-ASCII letters are rejected.
    ALLOWED_USER_NAME_SYMBOLS = ascii_letters + digits + "!#$%&*+,-./:;<=>?[]^_{|}~"
    # NOTE(review): not referenced anywhere in this class, so the length is not
    # enforced by create_user/update_user. Presumably callers validate it; confirm.
    MAX_USERNAME_LENGTH = 14

    def __init__(self, db_service: DatabaseService) -> None:
        """:param db_service: Database facade that all persistence is delegated to."""
        self._db_service = db_service

    async def get_all_users(self) -> list[User]:
        """Return every user as reported by ``DatabaseService.get_all_users``."""
        return await self._db_service.get_all_users()

    async def get_user(self, accessor: Optional[Union[str, int]]) -> Optional[User]:
        """Look up a single user by id, e-mail address or user name.

        :param accessor: An ``int`` is treated as a user id. A ``str`` is
            lower-cased and treated as an e-mail address if it contains ``"@"``,
            otherwise as a user name. ``None`` short-circuits to ``None``.
        :return: Whatever the matching ``DatabaseService`` getter returns
            (presumably ``None`` when no user matches).
        """
        if accessor is None:
            return None
        if isinstance(accessor, int):
            return await self._db_service.get_user_by_id(accessor)
        accessor = accessor.lower()
        if "@" in accessor:
            return await self._db_service.get_user_by_mail(accessor)
        return await self._db_service.get_user_by_name(accessor)

    async def set_profile_picture(self, user_id: int, picture: bytes) -> None:
        """Store ``picture`` as the profile picture of ``user_id``."""
        await self._db_service.set_user_profile_picture(user_id, picture)

    async def remove_profile_picture(self, user_id: int) -> None:
        """Delete the profile picture of ``user_id``."""
        await self._db_service.remove_profile_picture(user_id)

    async def get_profile_picture(self, user_id: int) -> bytes:
        """Return the stored profile picture of ``user_id``.

        NOTE(review): the annotation says ``bytes``. Whether the DB layer may
        return ``None`` for users without a picture is not visible here.
        """
        return await self._db_service.get_user_profile_picture(user_id)

    async def create_user(self, user_name: str, user_mail: str, password_clear_text: str) -> User:
        """Validate, normalize and persist a new user.

        The user name and e-mail address are stored lower-cased. This matches
        ``get_user``, which lower-cases every string accessor before lookup.
        Otherwise a mixed-case address could never be found again by e-mail.

        :raises NameNotAllowedError: If ``user_name`` contains a character outside
            ``ALLOWED_USER_NAME_SYMBOLS``.
        :return: The user created by ``DatabaseService.create_user``.
        """
        disallowed_char = self._check_for_disallowed_char(user_name)
        if disallowed_char:
            raise NameNotAllowedError(disallowed_char)

        return await self._db_service.create_user(
            user_name.lower(),
            user_mail.lower(),
            self._hash_password(password_clear_text),
        )

    async def update_user(self, user: User) -> User:
        """Validate and persist changes to ``user``.

        ``user.user_name`` is lower-cased in place on the passed object before it
        is handed to ``DatabaseService.update_user``.

        :raises NameNotAllowedError: If ``user.user_name`` contains a character
            outside ``ALLOWED_USER_NAME_SYMBOLS``.
        """
        disallowed_char = self._check_for_disallowed_char(user.user_name)
        if disallowed_char:
            raise NameNotAllowedError(disallowed_char)

        user.user_name = user.user_name.lower()
        return await self._db_service.update_user(user)

    async def is_login_valid(self, user_name_or_mail: str, password_clear_text: str) -> bool:
        """Check a login attempt.

        :param user_name_or_mail: Resolved via ``get_user`` (case-insensitive).
        :param password_clear_text: The password as typed by the user.
        :return: ``True`` if the password matches the user's fallback password
            (when one is set) or the regular password. ``False`` otherwise,
            including when no user is found.
        """
        user = await self.get_user(user_name_or_mail)
        # Hash before the existence check, so the same hashing work is done
        # whether or not the user exists.
        candidate_hash = self._hash_password(password_clear_text)
        if not user:
            return False
        if self._hashes_match(user.user_fallback_password, candidate_hash):
            return True
        return self._hashes_match(user.user_password, candidate_hash)

    @staticmethod
    def _hash_password(password_clear_text: str) -> str:
        """Return the hex SHA-256 digest of the UTF-8 encoded password.

        NOTE(review): unsalted SHA-256 is fast to brute-force. Moving to a salted
        KDF (``hashlib.scrypt`` / ``pbkdf2_hmac``) would need a migration of all
        stored hashes, so it is not changed here.
        """
        return sha256(password_clear_text.encode(encoding="utf-8")).hexdigest()

    @staticmethod
    def _hashes_match(stored_hash: Optional[str], candidate_hash: str) -> bool:
        """Compare a stored digest with a candidate digest in constant time.

        A missing, empty or non-string stored hash never matches. This mirrors
        the previous ``==`` semantics without letting ``compare_digest`` raise.
        """
        if not stored_hash or not isinstance(stored_hash, str):
            return False
        # Compare bytes so non-ASCII content cannot trigger compare_digest's TypeError.
        return compare_digest(stored_hash.encode("utf-8"), candidate_hash.encode("utf-8"))

    def _check_for_disallowed_char(self, name: str) -> Optional[str]:
        """Return the first character of ``name`` that is not allowed, else ``None``.

        NOTE(review): an empty name contains no disallowed character and passes.
        """
        allowed = self.ALLOWED_USER_NAME_SYMBOLS
        return next((c for c in name if c not in allowed), None)
|