diff --git a/app/auth/views/proton.py b/app/auth/views/proton.py index e8983797..2b4e4334 100644 --- a/app/auth/views/proton.py +++ b/app/auth/views/proton.py @@ -14,7 +14,11 @@ from app.config import ( URL, ) from app.proton.proton_client import HttpProtonClient, convert_access_token -from app.proton.proton_callback_handler import ProtonCallbackHandler, Action +from app.proton.proton_callback_handler import ( + ProtonCallbackHandler, + Action, + get_proton_partner, +) from app.utils import sanitize_next_url _authorization_base_url = PROTON_BASE_URL + "/oauth/authorize" @@ -100,11 +104,12 @@ def proton_callback(): PROTON_BASE_URL, credentials, get_remote_address(), verify=PROTON_VALIDATE_CERTS ) handler = ProtonCallbackHandler(proton_client) + proton_partner = get_proton_partner() if action == Action.Login: - res = handler.handle_login() + res = handler.handle_login(proton_partner) elif action == Action.Link: - res = handler.handle_link(current_user) + res = handler.handle_link(current_user, proton_partner) else: raise Exception(f"Unknown Action: {action.name}") diff --git a/app/proton/proton_callback_handler.py b/app/proton/proton_callback_handler.py index e444ed16..c12130d8 100644 --- a/app/proton/proton_callback_handler.py +++ b/app/proton/proton_callback_handler.py @@ -12,14 +12,22 @@ from app.utils import random_string PROTON_PARTNER_NAME = "Proton" _PROTON_PARTNER_ID: Optional[int] = None +_PROTON_PARTNER: Optional[Partner] = None + + +def get_proton_partner() -> Partner: + global _PROTON_PARTNER + if _PROTON_PARTNER is None: + partner = Partner.get_by(name=PROTON_PARTNER_NAME) + if partner is None: + raise ProtonPartnerNotSetUp + return _PROTON_PARTNER def get_proton_partner_id() -> int: global _PROTON_PARTNER_ID if _PROTON_PARTNER_ID is None: - partner = Partner.get_by(name=PROTON_PARTNER_NAME) - if partner is None: - raise ProtonPartnerNotSetUp + partner = get_proton_partner() _PROTON_PARTNER_ID = partner.id return _PROTON_PARTNER_ID @@ -39,23 +47,27 @@ class ProtonCallbackResult: user: Optional[User] -def ensure_partner_user_exists(proton_user: ProtonUser, sl_user: User): - proton_partner_id = get_proton_partner_id() - if not PartnerUser.get_by(user_id=sl_user.id, partner_id=proton_partner_id): +def ensure_partner_user_exists( + proton_user: ProtonUser, sl_user: User, partner: Partner +): + if not PartnerUser.get_by(user_id=sl_user.id, partner_id=partner.id): PartnerUser.create( user_id=sl_user.id, - partner_id=proton_partner_id, + partner_id=partner.id, partner_email=proton_user.email, ) Session.commit() class ClientMergeStrategy(ABC): - def __init__(self, proton_user: ProtonUser, sl_user: Optional[User]): + def __init__( + self, proton_user: ProtonUser, sl_user: Optional[User], partner: Partner + ): if self.__class__ == ClientMergeStrategy: raise RuntimeError("Cannot directly instantiate a ClientMergeStrategy") self.proton_user = proton_user self.sl_user = sl_user + self.partner = partner @abstractmethod def process(self) -> ProtonCallbackResult: @@ -65,17 +77,16 @@ class ClientMergeStrategy(ABC): class UnexistantSlClientStrategy(ClientMergeStrategy): def process(self) -> ProtonCallbackResult: # Will create a new SL User with a random password - proton_partner_id = get_proton_partner_id() new_user = User.create( email=self.proton_user.email, name=self.proton_user.name, partner_user_id=self.proton_user.id, - partner_id=proton_partner_id, + partner_id=self.partner.id, password=random_string(20), ) PartnerUser.create( user_id=new_user.id, - partner_id=proton_partner_id, + partner_id=self.partner.id, partner_email=self.proton_user.email, ) # TODO: Adjust plans @@ -92,7 +103,7 @@ class UnexistantSlClientStrategy(ClientMergeStrategy): class ExistingSlClientStrategy(ClientMergeStrategy): def process(self) -> ProtonCallbackResult: - ensure_partner_user_exists(self.proton_user, self.sl_user) + ensure_partner_user_exists(self.proton_user, self.sl_user, self.partner) # TODO: Adjust plans return ProtonCallbackResult( @@ -127,52 +138,54 @@ class AlreadyLinkedUserStrategy(ClientMergeStrategy): def get_login_strategy( - proton_user: ProtonUser, sl_user: Optional[User] + proton_user: ProtonUser, sl_user: Optional[User], partner: Partner ) -> ClientMergeStrategy: if sl_user is None: # We couldn't find any SimpleLogin user with the requested e-mail - return UnexistantSlClientStrategy(proton_user, sl_user) + return UnexistantSlClientStrategy(proton_user, sl_user, partner) # There is a SimpleLogin user with the proton_user's e-mail # Try to find if it has been registered via a partner if sl_user.partner_id is None: # It has not been registered via a Partner - return ExistingSlClientStrategy(proton_user, sl_user) + return ExistingSlClientStrategy(proton_user, sl_user, partner) # It has been registered via a partner # Check if the partner_user_id matches if sl_user.partner_user_id != proton_user.id: # It doesn't match. That means that the SimpleLogin user has a different Proton account linked return ExistingSlUserLinkedWithDifferentProtonAccountStrategy( - proton_user, sl_user + proton_user, sl_user, partner ) # This case means that the sl_user is already linked, so nothing to do - return AlreadyLinkedUserStrategy(proton_user, sl_user) + return AlreadyLinkedUserStrategy(proton_user, sl_user, partner) -def process_login_case(proton_user: ProtonUser) -> ProtonCallbackResult: +def process_login_case( + proton_user: ProtonUser, partner: Partner +) -> ProtonCallbackResult: # Try to find a SimpleLogin user registered with that proton user id - proton_partner_id = get_proton_partner_id() sl_user_with_external_id = User.get_by( - partner_id=proton_partner_id, partner_user_id=proton_user.id + partner_id=partner.id, partner_user_id=proton_user.id ) if sl_user_with_external_id is None: # We didn't find any SimpleLogin user registered with that proton user id # Try to find it using the proton's e-mail address sl_user = User.get_by(email=proton_user.email) - return get_login_strategy(proton_user, sl_user).process() + return get_login_strategy(proton_user, sl_user, partner).process() else: # We found the SL user registered with that proton user id # We're done return AlreadyLinkedUserStrategy( - proton_user, sl_user_with_external_id + proton_user, sl_user_with_external_id, partner ).process() -def link_user(proton_user: ProtonUser, current_user: User) -> ProtonCallbackResult: - proton_partner_id = get_proton_partner_id() +def link_user( + proton_user: ProtonUser, current_user: User, partner: Partner +) -> ProtonCallbackResult: current_user.partner_user_id = proton_user.id - current_user.partner_id = proton_partner_id + current_user.partner_id = partner.id - ensure_partner_user_exists(proton_user, current_user) + ensure_partner_user_exists(proton_user, current_user, partner) Session.commit() return ProtonCallbackResult( @@ -185,16 +198,17 @@ def link_user(proton_user: ProtonUser, current_user: User) -> ProtonCallbackResu def process_link_case( - proton_user: ProtonUser, current_user: User + proton_user: ProtonUser, + current_user: User, + partner: Partner, ) -> ProtonCallbackResult: # Try to find a SimpleLogin user linked with this Proton account - proton_partner_id = get_proton_partner_id() sl_user_linked_to_proton_account = User.get_by( - partner_id=proton_partner_id, partner_user_id=proton_user.id + partner_id=partner.id, partner_user_id=proton_user.id ) if sl_user_linked_to_proton_account is None: # There is no SL user linked with the proton email. Proceed with linking - return link_user(proton_user, current_user) + return link_user(proton_user, current_user, partner) else: # There is a SL user registered with the proton email. Check if is the current one if sl_user_linked_to_proton_account.id == current_user.id: @@ -212,25 +226,27 @@ def process_link_case( sl_user_linked_to_proton_account.partner_user_id = None other_partner_user = PartnerUser.get_by( user_id=sl_user_linked_to_proton_account.id, - partner_id=proton_partner_id, + partner_id=partner.id, ) if other_partner_user is not None: PartnerUser.delete(other_partner_user.id) - return link_user(proton_user, current_user) + return link_user(proton_user, current_user, partner) class ProtonCallbackHandler: def __init__(self, proton_client: ProtonClient): self.proton_client = proton_client - def handle_login(self) -> ProtonCallbackResult: - return process_login_case(self.__get_proton_user()) + def handle_login(self, partner: Partner) -> ProtonCallbackResult: + return process_login_case(self.__get_proton_user(), partner) - def handle_link(self, current_user: Optional[User]) -> ProtonCallbackResult: + def handle_link( + self, current_user: Optional[User], partner: Partner + ) -> ProtonCallbackResult: if current_user is None: raise Exception("Cannot link account with current_user being None") - return process_link_case(self.__get_proton_user(), current_user) + return process_link_case(self.__get_proton_user(), current_user, partner) def __get_proton_user(self) -> ProtonUser: user = self.proton_client.get_user() diff --git a/tests/proton/test_proton_callback_handler.py b/tests/proton/test_proton_callback_handler.py index a5278b83..98e3fc66 100644 --- a/tests/proton/test_proton_callback_handler.py +++ b/tests/proton/test_proton_callback_handler.py @@ -4,6 +4,7 @@ from app.db import Session from app.proton.proton_client import ProtonClient, UserInformation, ProtonPlan from app.proton.proton_callback_handler import ( ProtonCallbackHandler, + get_proton_partner, get_proton_partner_id, get_login_strategy, process_link_case, @@ -84,7 +85,7 @@ def test_proton_callback_handler_unexistant_sl_user(): user=user, plan=ProtonPlan.Professional, organization={} ) handler = ProtonCallbackHandler(mock_client) - res = handler.handle_login() + res = handler.handle_login(get_proton_partner()) assert res.user is not None assert res.user.email == email @@ -102,7 +103,7 @@ def test_proton_callback_handler_existant_sl_user(): user=user, plan=ProtonPlan.Professional, organization={} ) handler = ProtonCallbackHandler(mock_client) - res = handler.handle_login() + res = handler.handle_login(get_proton_partner()) assert res.user is not None assert res.user.id == sl_user.id @@ -116,6 +117,7 @@ def test_get_strategy_unexistant_sl_user(): strategy = get_login_strategy( proton_user=random_proton_user(), sl_user=None, + partner=get_proton_partner(), ) assert isinstance(strategy, UnexistantSlClientStrategy) @@ -126,6 +128,7 @@ def test_get_strategy_existing_sl_user(): strategy = get_login_strategy( proton_user=random_proton_user(email=email), sl_user=sl_user, + partner=get_proton_partner(), ) assert isinstance(strategy, ExistingSlClientStrategy) @@ -137,6 +140,7 @@ def test_get_strategy_already_linked_user(): strategy = get_login_strategy( proton_user=random_proton_user(user_id=proton_user_id, email=email), sl_user=sl_user, + partner=get_proton_partner(), ) assert isinstance(strategy, AlreadyLinkedUserStrategy) @@ -159,6 +163,7 @@ def test_get_strategy_existing_sl_user_linked_with_different_proton_account(): strategy = get_login_strategy( proton_user=proton_user_1, sl_user=sl_user, + partner=get_proton_partner(), ) assert isinstance(strategy, ExistingSlUserLinkedWithDifferentProtonAccountStrategy) @@ -179,7 +184,7 @@ def test_link_account_with_proton_account_same_address(flask_client): proton_user = random_proton_user(user_id=proton_user_id, email=email) sl_user = create_user(email) - res = process_link_case(proton_user, sl_user) + res = process_link_case(proton_user, sl_user, get_proton_partner()) assert res.redirect_to_login is False assert res.redirect is not None assert res.flash_category == "success" @@ -199,7 +204,7 @@ def test_link_account_with_proton_account_different_address(flask_client): proton_user = random_proton_user(user_id=proton_user_id, email=random_email()) sl_user = create_user() - res = process_link_case(proton_user, sl_user) + res = process_link_case(proton_user, sl_user, get_proton_partner()) assert res.redirect_to_login is False assert res.redirect is not None assert res.flash_category == "success" @@ -226,7 +231,7 @@ def test_link_account_with_proton_account_same_address_but_linked_to_other_user( proton_user_id, email=random_email() ) # User already linked with the proton account - res = process_link_case(proton_user, sl_user_1) + res = process_link_case(proton_user, sl_user_1, get_proton_partner()) assert res.redirect_to_login is False assert res.redirect is not None assert res.flash_category == "success" @@ -256,7 +261,7 @@ def test_link_account_with_proton_account_different_address_and_linked_to_other_ proton_user_id, email=random_email() ) # User already linked with the proton account - res = process_link_case(proton_user, sl_user_1) + res = process_link_case(proton_user, sl_user_1, get_proton_partner()) assert res.redirect_to_login is False assert res.redirect is not None assert res.flash_category == "success" @@ -282,4 +287,4 @@ def test_link_account_with_proton_account_different_address_and_linked_to_other_ def test_cannot_create_instance_of_base_strategy(): with pytest.raises(Exception): - ClientMergeStrategy(random_proton_user(), None) + ClientMergeStrategy(random_proton_user(), None, get_proton_partner())