Receive partner as param in ProtonCallbackHandler

This commit is contained in:
Carlos Quintana 2022-05-23 16:10:24 +02:00
parent 0dfc6c0b0d
commit ed9d2ed816
No known key found for this signature in database
GPG key ID: 15E73DCC410679F8
3 changed files with 72 additions and 46 deletions

View file

@ -14,7 +14,11 @@ from app.config import (
URL, URL,
) )
from app.proton.proton_client import HttpProtonClient, convert_access_token from app.proton.proton_client import HttpProtonClient, convert_access_token
from app.proton.proton_callback_handler import ProtonCallbackHandler, Action from app.proton.proton_callback_handler import (
ProtonCallbackHandler,
Action,
get_proton_partner,
)
from app.utils import sanitize_next_url from app.utils import sanitize_next_url
_authorization_base_url = PROTON_BASE_URL + "/oauth/authorize" _authorization_base_url = PROTON_BASE_URL + "/oauth/authorize"
@ -100,11 +104,12 @@ def proton_callback():
PROTON_BASE_URL, credentials, get_remote_address(), verify=PROTON_VALIDATE_CERTS PROTON_BASE_URL, credentials, get_remote_address(), verify=PROTON_VALIDATE_CERTS
) )
handler = ProtonCallbackHandler(proton_client) handler = ProtonCallbackHandler(proton_client)
proton_partner = get_proton_partner()
if action == Action.Login: if action == Action.Login:
res = handler.handle_login() res = handler.handle_login(proton_partner)
elif action == Action.Link: elif action == Action.Link:
res = handler.handle_link(current_user) res = handler.handle_link(current_user, proton_partner)
else: else:
raise Exception(f"Unknown Action: {action.name}") raise Exception(f"Unknown Action: {action.name}")

View file

@ -12,14 +12,22 @@ from app.utils import random_string
PROTON_PARTNER_NAME = "Proton" PROTON_PARTNER_NAME = "Proton"
_PROTON_PARTNER_ID: Optional[int] = None _PROTON_PARTNER_ID: Optional[int] = None
_PROTON_PARTNER: Optional[Partner] = None
def get_proton_partner() -> Partner:
global _PROTON_PARTNER
if _PROTON_PARTNER is None:
partner = Partner.get_by(name=PROTON_PARTNER_NAME)
if partner is None:
raise ProtonPartnerNotSetUp
return _PROTON_PARTNER
def get_proton_partner_id() -> int: def get_proton_partner_id() -> int:
global _PROTON_PARTNER_ID global _PROTON_PARTNER_ID
if _PROTON_PARTNER_ID is None: if _PROTON_PARTNER_ID is None:
partner = Partner.get_by(name=PROTON_PARTNER_NAME) partner = get_proton_partner()
if partner is None:
raise ProtonPartnerNotSetUp
_PROTON_PARTNER_ID = partner.id _PROTON_PARTNER_ID = partner.id
return _PROTON_PARTNER_ID return _PROTON_PARTNER_ID
@ -39,23 +47,27 @@ class ProtonCallbackResult:
user: Optional[User] user: Optional[User]
def ensure_partner_user_exists(proton_user: ProtonUser, sl_user: User): def ensure_partner_user_exists(
proton_partner_id = get_proton_partner_id() proton_user: ProtonUser, sl_user: User, partner: Partner
if not PartnerUser.get_by(user_id=sl_user.id, partner_id=proton_partner_id): ):
if not PartnerUser.get_by(user_id=sl_user.id, partner_id=partner.id):
PartnerUser.create( PartnerUser.create(
user_id=sl_user.id, user_id=sl_user.id,
partner_id=proton_partner_id, partner_id=partner.id,
partner_email=proton_user.email, partner_email=proton_user.email,
) )
Session.commit() Session.commit()
class ClientMergeStrategy(ABC): class ClientMergeStrategy(ABC):
def __init__(self, proton_user: ProtonUser, sl_user: Optional[User]): def __init__(
self, proton_user: ProtonUser, sl_user: Optional[User], partner: Partner
):
if self.__class__ == ClientMergeStrategy: if self.__class__ == ClientMergeStrategy:
raise RuntimeError("Cannot directly instantiate a ClientMergeStrategy") raise RuntimeError("Cannot directly instantiate a ClientMergeStrategy")
self.proton_user = proton_user self.proton_user = proton_user
self.sl_user = sl_user self.sl_user = sl_user
self.partner = partner
@abstractmethod @abstractmethod
def process(self) -> ProtonCallbackResult: def process(self) -> ProtonCallbackResult:
@ -65,17 +77,16 @@ class ClientMergeStrategy(ABC):
class UnexistantSlClientStrategy(ClientMergeStrategy): class UnexistantSlClientStrategy(ClientMergeStrategy):
def process(self) -> ProtonCallbackResult: def process(self) -> ProtonCallbackResult:
# Will create a new SL User with a random password # Will create a new SL User with a random password
proton_partner_id = get_proton_partner_id()
new_user = User.create( new_user = User.create(
email=self.proton_user.email, email=self.proton_user.email,
name=self.proton_user.name, name=self.proton_user.name,
partner_user_id=self.proton_user.id, partner_user_id=self.proton_user.id,
partner_id=proton_partner_id, partner_id=self.partner.id,
password=random_string(20), password=random_string(20),
) )
PartnerUser.create( PartnerUser.create(
user_id=new_user.id, user_id=new_user.id,
partner_id=proton_partner_id, partner_id=self.partner.id,
partner_email=self.proton_user.email, partner_email=self.proton_user.email,
) )
# TODO: Adjust plans # TODO: Adjust plans
@ -92,7 +103,7 @@ class UnexistantSlClientStrategy(ClientMergeStrategy):
class ExistingSlClientStrategy(ClientMergeStrategy): class ExistingSlClientStrategy(ClientMergeStrategy):
def process(self) -> ProtonCallbackResult: def process(self) -> ProtonCallbackResult:
ensure_partner_user_exists(self.proton_user, self.sl_user) ensure_partner_user_exists(self.proton_user, self.sl_user, self.partner)
# TODO: Adjust plans # TODO: Adjust plans
return ProtonCallbackResult( return ProtonCallbackResult(
@ -127,52 +138,54 @@ class AlreadyLinkedUserStrategy(ClientMergeStrategy):
def get_login_strategy( def get_login_strategy(
proton_user: ProtonUser, sl_user: Optional[User] proton_user: ProtonUser, sl_user: Optional[User], partner: Partner
) -> ClientMergeStrategy: ) -> ClientMergeStrategy:
if sl_user is None: if sl_user is None:
# We couldn't find any SimpleLogin user with the requested e-mail # We couldn't find any SimpleLogin user with the requested e-mail
return UnexistantSlClientStrategy(proton_user, sl_user) return UnexistantSlClientStrategy(proton_user, sl_user, partner)
# There is a SimpleLogin user with the proton_user's e-mail # There is a SimpleLogin user with the proton_user's e-mail
# Try to find if it has been registered via a partner # Try to find if it has been registered via a partner
if sl_user.partner_id is None: if sl_user.partner_id is None:
# It has not been registered via a Partner # It has not been registered via a Partner
return ExistingSlClientStrategy(proton_user, sl_user) return ExistingSlClientStrategy(proton_user, sl_user, partner)
# It has been registered via a partner # It has been registered via a partner
# Check if the partner_user_id matches # Check if the partner_user_id matches
if sl_user.partner_user_id != proton_user.id: if sl_user.partner_user_id != proton_user.id:
# It doesn't match. That means that the SimpleLogin user has a different Proton account linked # It doesn't match. That means that the SimpleLogin user has a different Proton account linked
return ExistingSlUserLinkedWithDifferentProtonAccountStrategy( return ExistingSlUserLinkedWithDifferentProtonAccountStrategy(
proton_user, sl_user proton_user, sl_user, partner
) )
# This case means that the sl_user is already linked, so nothing to do # This case means that the sl_user is already linked, so nothing to do
return AlreadyLinkedUserStrategy(proton_user, sl_user) return AlreadyLinkedUserStrategy(proton_user, sl_user, partner)
def process_login_case(proton_user: ProtonUser) -> ProtonCallbackResult: def process_login_case(
proton_user: ProtonUser, partner: Partner
) -> ProtonCallbackResult:
# Try to find a SimpleLogin user registered with that proton user id # Try to find a SimpleLogin user registered with that proton user id
proton_partner_id = get_proton_partner_id()
sl_user_with_external_id = User.get_by( sl_user_with_external_id = User.get_by(
partner_id=proton_partner_id, partner_user_id=proton_user.id partner_id=partner.id, partner_user_id=proton_user.id
) )
if sl_user_with_external_id is None: if sl_user_with_external_id is None:
# We didn't find any SimpleLogin user registered with that proton user id # We didn't find any SimpleLogin user registered with that proton user id
# Try to find it using the proton's e-mail address # Try to find it using the proton's e-mail address
sl_user = User.get_by(email=proton_user.email) sl_user = User.get_by(email=proton_user.email)
return get_login_strategy(proton_user, sl_user).process() return get_login_strategy(proton_user, sl_user, partner).process()
else: else:
# We found the SL user registered with that proton user id # We found the SL user registered with that proton user id
# We're done # We're done
return AlreadyLinkedUserStrategy( return AlreadyLinkedUserStrategy(
proton_user, sl_user_with_external_id proton_user, sl_user_with_external_id, partner
).process() ).process()
def link_user(proton_user: ProtonUser, current_user: User) -> ProtonCallbackResult: def link_user(
proton_partner_id = get_proton_partner_id() proton_user: ProtonUser, current_user: User, partner: Partner
) -> ProtonCallbackResult:
current_user.partner_user_id = proton_user.id current_user.partner_user_id = proton_user.id
current_user.partner_id = proton_partner_id current_user.partner_id = partner.id
ensure_partner_user_exists(proton_user, current_user) ensure_partner_user_exists(proton_user, current_user, partner)
Session.commit() Session.commit()
return ProtonCallbackResult( return ProtonCallbackResult(
@ -185,16 +198,17 @@ def link_user(proton_user: ProtonUser, current_user: User) -> ProtonCallbackResu
def process_link_case( def process_link_case(
proton_user: ProtonUser, current_user: User proton_user: ProtonUser,
current_user: User,
partner: Partner,
) -> ProtonCallbackResult: ) -> ProtonCallbackResult:
# Try to find a SimpleLogin user linked with this Proton account # Try to find a SimpleLogin user linked with this Proton account
proton_partner_id = get_proton_partner_id()
sl_user_linked_to_proton_account = User.get_by( sl_user_linked_to_proton_account = User.get_by(
partner_id=proton_partner_id, partner_user_id=proton_user.id partner_id=partner.id, partner_user_id=proton_user.id
) )
if sl_user_linked_to_proton_account is None: if sl_user_linked_to_proton_account is None:
# There is no SL user linked with the proton email. Proceed with linking # There is no SL user linked with the proton email. Proceed with linking
return link_user(proton_user, current_user) return link_user(proton_user, current_user, partner)
else: else:
# There is a SL user registered with the proton email. Check if is the current one # There is a SL user registered with the proton email. Check if is the current one
if sl_user_linked_to_proton_account.id == current_user.id: if sl_user_linked_to_proton_account.id == current_user.id:
@ -212,25 +226,27 @@ def process_link_case(
sl_user_linked_to_proton_account.partner_user_id = None sl_user_linked_to_proton_account.partner_user_id = None
other_partner_user = PartnerUser.get_by( other_partner_user = PartnerUser.get_by(
user_id=sl_user_linked_to_proton_account.id, user_id=sl_user_linked_to_proton_account.id,
partner_id=proton_partner_id, partner_id=partner.id,
) )
if other_partner_user is not None: if other_partner_user is not None:
PartnerUser.delete(other_partner_user.id) PartnerUser.delete(other_partner_user.id)
return link_user(proton_user, current_user) return link_user(proton_user, current_user, partner)
class ProtonCallbackHandler: class ProtonCallbackHandler:
def __init__(self, proton_client: ProtonClient): def __init__(self, proton_client: ProtonClient):
self.proton_client = proton_client self.proton_client = proton_client
def handle_login(self) -> ProtonCallbackResult: def handle_login(self, partner: Partner) -> ProtonCallbackResult:
return process_login_case(self.__get_proton_user()) return process_login_case(self.__get_proton_user(), partner)
def handle_link(self, current_user: Optional[User]) -> ProtonCallbackResult: def handle_link(
self, current_user: Optional[User], partner: Partner
) -> ProtonCallbackResult:
if current_user is None: if current_user is None:
raise Exception("Cannot link account with current_user being None") raise Exception("Cannot link account with current_user being None")
return process_link_case(self.__get_proton_user(), current_user) return process_link_case(self.__get_proton_user(), current_user, partner)
def __get_proton_user(self) -> ProtonUser: def __get_proton_user(self) -> ProtonUser:
user = self.proton_client.get_user() user = self.proton_client.get_user()

View file

@ -4,6 +4,7 @@ from app.db import Session
from app.proton.proton_client import ProtonClient, UserInformation, ProtonPlan from app.proton.proton_client import ProtonClient, UserInformation, ProtonPlan
from app.proton.proton_callback_handler import ( from app.proton.proton_callback_handler import (
ProtonCallbackHandler, ProtonCallbackHandler,
get_proton_partner,
get_proton_partner_id, get_proton_partner_id,
get_login_strategy, get_login_strategy,
process_link_case, process_link_case,
@ -84,7 +85,7 @@ def test_proton_callback_handler_unexistant_sl_user():
user=user, plan=ProtonPlan.Professional, organization={} user=user, plan=ProtonPlan.Professional, organization={}
) )
handler = ProtonCallbackHandler(mock_client) handler = ProtonCallbackHandler(mock_client)
res = handler.handle_login() res = handler.handle_login(get_proton_partner())
assert res.user is not None assert res.user is not None
assert res.user.email == email assert res.user.email == email
@ -102,7 +103,7 @@ def test_proton_callback_handler_existant_sl_user():
user=user, plan=ProtonPlan.Professional, organization={} user=user, plan=ProtonPlan.Professional, organization={}
) )
handler = ProtonCallbackHandler(mock_client) handler = ProtonCallbackHandler(mock_client)
res = handler.handle_login() res = handler.handle_login(get_proton_partner())
assert res.user is not None assert res.user is not None
assert res.user.id == sl_user.id assert res.user.id == sl_user.id
@ -116,6 +117,7 @@ def test_get_strategy_unexistant_sl_user():
strategy = get_login_strategy( strategy = get_login_strategy(
proton_user=random_proton_user(), proton_user=random_proton_user(),
sl_user=None, sl_user=None,
partner=get_proton_partner(),
) )
assert isinstance(strategy, UnexistantSlClientStrategy) assert isinstance(strategy, UnexistantSlClientStrategy)
@ -126,6 +128,7 @@ def test_get_strategy_existing_sl_user():
strategy = get_login_strategy( strategy = get_login_strategy(
proton_user=random_proton_user(email=email), proton_user=random_proton_user(email=email),
sl_user=sl_user, sl_user=sl_user,
partner=get_proton_partner(),
) )
assert isinstance(strategy, ExistingSlClientStrategy) assert isinstance(strategy, ExistingSlClientStrategy)
@ -137,6 +140,7 @@ def test_get_strategy_already_linked_user():
strategy = get_login_strategy( strategy = get_login_strategy(
proton_user=random_proton_user(user_id=proton_user_id, email=email), proton_user=random_proton_user(user_id=proton_user_id, email=email),
sl_user=sl_user, sl_user=sl_user,
partner=get_proton_partner(),
) )
assert isinstance(strategy, AlreadyLinkedUserStrategy) assert isinstance(strategy, AlreadyLinkedUserStrategy)
@ -159,6 +163,7 @@ def test_get_strategy_existing_sl_user_linked_with_different_proton_account():
strategy = get_login_strategy( strategy = get_login_strategy(
proton_user=proton_user_1, proton_user=proton_user_1,
sl_user=sl_user, sl_user=sl_user,
partner=get_proton_partner(),
) )
assert isinstance(strategy, ExistingSlUserLinkedWithDifferentProtonAccountStrategy) assert isinstance(strategy, ExistingSlUserLinkedWithDifferentProtonAccountStrategy)
@ -179,7 +184,7 @@ def test_link_account_with_proton_account_same_address(flask_client):
proton_user = random_proton_user(user_id=proton_user_id, email=email) proton_user = random_proton_user(user_id=proton_user_id, email=email)
sl_user = create_user(email) sl_user = create_user(email)
res = process_link_case(proton_user, sl_user) res = process_link_case(proton_user, sl_user, get_proton_partner())
assert res.redirect_to_login is False assert res.redirect_to_login is False
assert res.redirect is not None assert res.redirect is not None
assert res.flash_category == "success" assert res.flash_category == "success"
@ -199,7 +204,7 @@ def test_link_account_with_proton_account_different_address(flask_client):
proton_user = random_proton_user(user_id=proton_user_id, email=random_email()) proton_user = random_proton_user(user_id=proton_user_id, email=random_email())
sl_user = create_user() sl_user = create_user()
res = process_link_case(proton_user, sl_user) res = process_link_case(proton_user, sl_user, get_proton_partner())
assert res.redirect_to_login is False assert res.redirect_to_login is False
assert res.redirect is not None assert res.redirect is not None
assert res.flash_category == "success" assert res.flash_category == "success"
@ -226,7 +231,7 @@ def test_link_account_with_proton_account_same_address_but_linked_to_other_user(
proton_user_id, email=random_email() proton_user_id, email=random_email()
) # User already linked with the proton account ) # User already linked with the proton account
res = process_link_case(proton_user, sl_user_1) res = process_link_case(proton_user, sl_user_1, get_proton_partner())
assert res.redirect_to_login is False assert res.redirect_to_login is False
assert res.redirect is not None assert res.redirect is not None
assert res.flash_category == "success" assert res.flash_category == "success"
@ -256,7 +261,7 @@ def test_link_account_with_proton_account_different_address_and_linked_to_other_
proton_user_id, email=random_email() proton_user_id, email=random_email()
) # User already linked with the proton account ) # User already linked with the proton account
res = process_link_case(proton_user, sl_user_1) res = process_link_case(proton_user, sl_user_1, get_proton_partner())
assert res.redirect_to_login is False assert res.redirect_to_login is False
assert res.redirect is not None assert res.redirect is not None
assert res.flash_category == "success" assert res.flash_category == "success"
@ -282,4 +287,4 @@ def test_link_account_with_proton_account_different_address_and_linked_to_other_
def test_cannot_create_instance_of_base_strategy(): def test_cannot_create_instance_of_base_strategy():
with pytest.raises(Exception): with pytest.raises(Exception):
ClientMergeStrategy(random_proton_user(), None) ClientMergeStrategy(random_proton_user(), None, get_proton_partner())