Merge pull request #894 from cquintana92/feature/add-login-with-proton

Add login with proton
This commit is contained in:
Adrià Casajús 2022-05-05 12:29:00 +02:00 committed by GitHub
commit a92981c52d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
24 changed files with 1142 additions and 31 deletions

View file

@ -9,6 +9,7 @@ from .views import (
github,
google,
facebook,
proton,
change_email,
mfa,
fido,

View file

@ -5,6 +5,7 @@ from wtforms import StringField, validators
from app.auth.base import auth_bp
from app.auth.views.login_utils import after_login
from app.config import CONNECT_WITH_PROTON
from app.events.auth_event import LoginEvent
from app.extensions import limiter
from app.log import LOG
@ -67,4 +68,5 @@ def login():
form=form,
next_url=next_url,
show_resend_activation=show_resend_activation,
connect_with_proton=CONNECT_WITH_PROTON,
)

114
app/auth/views/proton.py Normal file
View file

@ -0,0 +1,114 @@
from flask import request, session, redirect, flash, url_for
from flask_limiter.util import get_remote_address
from flask_login import current_user
from requests_oauthlib import OAuth2Session
from app.auth.base import auth_bp
from app.auth.views.login_utils import after_login
from app.config import (
PROTON_BASE_URL,
PROTON_CLIENT_ID,
PROTON_CLIENT_SECRET,
PROTON_VALIDATE_CERTS,
URL,
)
from app.proton.proton_client import HttpProtonClient, convert_access_token
from app.proton.proton_callback_handler import ProtonCallbackHandler, Action
from app.utils import encode_url, sanitize_next_url
_authorization_base_url = PROTON_BASE_URL + "/oauth/authorize"
_token_url = PROTON_BASE_URL + "/oauth/token"
# need to set explicitly redirect_uri instead of leaving the lib to pre-fill redirect_uri
# when served behind nginx, the redirect_uri is localhost... and not the real url
_redirect_uri = URL + "/auth/proton/callback"
def extract_action() -> Action:
action = request.args.get("action")
if action is not None:
if action == "link":
return Action.Link
else:
raise Exception(f"Unknown action: {action}")
return Action.Login
def get_action_from_state() -> Action:
oauth_action = session["oauth_action"]
if oauth_action == Action.Login.value:
return Action.Login
elif oauth_action == Action.Link.value:
return Action.Link
raise Exception(f"Unknown action in state: {oauth_action}")
@auth_bp.route("/proton/login")
def proton_login():
if PROTON_CLIENT_ID is None or PROTON_CLIENT_SECRET is None:
return redirect(url_for("auth.login"))
next_url = sanitize_next_url(request.args.get("next"))
if next_url:
redirect_uri = _redirect_uri + "?next=" + encode_url(next_url)
else:
redirect_uri = _redirect_uri
proton = OAuth2Session(PROTON_CLIENT_ID, redirect_uri=redirect_uri)
authorization_url, state = proton.authorization_url(_authorization_base_url)
# State is used to prevent CSRF, keep this for later.
session["oauth_state"] = state
session["oauth_action"] = extract_action().value
return redirect(authorization_url)
@auth_bp.route("/proton/callback")
def proton_callback():
if PROTON_CLIENT_ID is None or PROTON_CLIENT_SECRET is None:
return redirect(url_for("auth.login"))
# user clicks on cancel
if "error" in request.args:
flash("Please use another sign in method then", "warning")
return redirect("/")
proton = OAuth2Session(
PROTON_CLIENT_ID,
state=session["oauth_state"],
redirect_uri=_redirect_uri,
)
token = proton.fetch_token(
_token_url,
client_secret=PROTON_CLIENT_SECRET,
authorization_response=request.url,
verify=PROTON_VALIDATE_CERTS,
method="GET",
include_client_id=True,
)
credentials = convert_access_token(token["access_token"])
action = get_action_from_state()
proton_client = HttpProtonClient(
PROTON_BASE_URL, credentials, get_remote_address(), verify=PROTON_VALIDATE_CERTS
)
handler = ProtonCallbackHandler(proton_client)
if action == Action.Login:
res = handler.handle_login()
elif action == Action.Link:
res = handler.handle_link(current_user)
else:
raise Exception(f"Unknown Action: {action.name}")
if res.flash_message is not None:
flash(res.flash_message, res.flash_category)
if res.redirect_to_login:
return redirect(url_for("auth.login"))
if res.redirect:
return redirect(res.redirect)
next_url = request.args.get("next") if request.args else None
return after_login(res.user, next_url)

View file

@ -6,6 +6,7 @@ from wtforms import StringField, validators
from app import email_utils, config
from app.auth.base import auth_bp
from app.config import CONNECT_WITH_PROTON
from app.auth.views.login_utils import get_referral
from app.config import URL, HCAPTCHA_SECRET, HCAPTCHA_SITEKEY
from app.db import Session
@ -102,6 +103,7 @@ def register():
form=form,
next_url=next_url,
HCAPTCHA_SITEKEY=HCAPTCHA_SITEKEY,
connect_with_proton=CONNECT_WITH_PROTON,
)

View file

@ -84,7 +84,6 @@ BOUNCE_PREFIX_FOR_REPLY_PHASE = (
os.environ.get("BOUNCE_PREFIX_FOR_REPLY_PHASE") or "bounce_reply"
)
# VERP for transactional email: mail_from set to BOUNCE_PREFIX + email_log.id + BOUNCE_SUFFIX
TRANSACTIONAL_BOUNCE_PREFIX = (
os.environ.get("TRANSACTIONAL_BOUNCE_PREFIX") or "transactional+"
@ -159,7 +158,6 @@ if "DKIM_PRIVATE_KEY_PATH" in os.environ:
with open(DKIM_PRIVATE_KEY_PATH) as f:
DKIM_PRIVATE_KEY = f.read()
# Database
DB_URI = os.environ["DB_URI"]
@ -240,6 +238,14 @@ GOOGLE_CLIENT_SECRET = os.environ.get("GOOGLE_CLIENT_SECRET")
FACEBOOK_CLIENT_ID = os.environ.get("FACEBOOK_CLIENT_ID")
FACEBOOK_CLIENT_SECRET = os.environ.get("FACEBOOK_CLIENT_SECRET")
PROTON_CLIENT_ID = os.environ.get("PROTON_CLIENT_ID")
PROTON_CLIENT_SECRET = os.environ.get("PROTON_CLIENT_SECRET")
PROTON_BASE_URL = os.environ.get(
"PROTON_BASE_URL", "https://account.protonmail.com/api"
)
PROTON_VALIDATE_CERTS = "PROTON_VALIDATE_CERTS" in os.environ
CONNECT_WITH_PROTON = "CONNECT_WITH_PROTON" in os.environ
# in seconds
AVATAR_URL_EXPIRATION = 3600 * 24 * 7 # 1h*24h/d*7d=1week
@ -287,7 +293,6 @@ STATUS_PAGE_URL = os.environ.get("STATUS_PAGE_URL") or "https://status.simplelog
# Loading PGP keys when mail_handler runs. To be used locally when init_app is not called.
LOAD_PGP_EMAIL_HANDLER = "LOAD_PGP_EMAIL_HANDLER" in os.environ
# Used when querying info on Apple API
# for iOS App
APPLE_API_SECRET = os.environ.get("APPLE_API_SECRET")

View file

@ -11,6 +11,7 @@ from flask import (
from flask_login import login_required, current_user
from flask_wtf import FlaskForm
from flask_wtf.file import FileField
from typing import Optional
from wtforms import StringField, validators
from wtforms.fields.html5 import EmailField
@ -19,6 +20,7 @@ from app.config import (
URL,
FIRST_ALIAS_DOMAIN,
ALIAS_RANDOM_SUFFIX_LENGTH,
CONNECT_WITH_PROTON,
)
from app.dashboard.base import dashboard_bp
from app.db import Session
@ -43,7 +45,9 @@ from app.models import (
SLDomain,
CoinbaseSubscription,
AppleSubscription,
PartnerUser,
)
from app.proton.proton_callback_handler import get_proton_partner_id
from app.utils import random_string, sanitize_email
@ -62,6 +66,21 @@ class PromoCodeForm(FlaskForm):
code = StringField("Name", validators=[validators.DataRequired()])
def get_proton_linked_account() -> Optional[str]:
# Check if the current user has a partner_id
proton_partner_id = get_proton_partner_id()
if current_user.partner_id != proton_partner_id:
return None
# It has. Retrieve the information for the PartnerUser
proton_linked_account = PartnerUser.get_by(
user_id=current_user.id, partner_id=proton_partner_id
)
if proton_linked_account is None:
return None
return proton_linked_account.partner_email
@dashboard_bp.route("/setting", methods=["GET", "POST"])
@login_required
def setting():
@ -332,6 +351,7 @@ def setting():
manual_sub = ManualSubscription.get_by(user_id=current_user.id)
apple_sub = AppleSubscription.get_by(user_id=current_user.id)
coinbase_sub = CoinbaseSubscription.get_by(user_id=current_user.id)
proton_linked_account = get_proton_linked_account()
return render_template(
"dashboard/setting.html",
@ -348,6 +368,8 @@ def setting():
coinbase_sub=coinbase_sub,
FIRST_ALIAS_DOMAIN=FIRST_ALIAS_DOMAIN,
ALIAS_RAND_SUFFIX_LENGTH=ALIAS_RANDOM_SUFFIX_LENGTH,
connect_with_proton=CONNECT_WITH_PROTON,
proton_linked_account=proton_linked_account,
)
@ -409,3 +431,18 @@ def cancel_email_change():
"You have no pending email change. Redirect back to Setting page", "warning"
)
return redirect(url_for("dashboard.setting"))
@dashboard_bp.route("/unlink_proton_account", methods=["GET", "POST"])
@login_required
def unlink_proton_account():
current_user.partner_id = None
current_user.partner_user_id = None
partner_user = PartnerUser.get_by(
user_id=current_user.id, partner_id=get_proton_partner_id()
)
if partner_user is not None:
PartnerUser.delete(partner_user.id)
Session.commit()
flash("Your Proton account has been unlinked", "success")
return redirect(url_for("dashboard.setting"))

View file

@ -479,6 +479,15 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
sa.Boolean, default=False, nullable=False, server_default="1"
)
partner_id = sa.Column(sa.BigInteger, unique=False, nullable=True)
partner_user_id = sa.Column(sa.String(128), unique=False, nullable=True)
__table_args__ = (
sa.UniqueConstraint(
"partner_id", "partner_user_id", name="uq_partner_id_partner_user_id"
),
)
@property
def directory_quota(self):
return min(
@ -3037,3 +3046,42 @@ class ProviderComplaint(Base, ModelMixin):
user = orm.relationship(User, foreign_keys=[user_id])
refused_email = orm.relationship(RefusedEmail, foreign_keys=[refused_email_id])
class Partner(Base, ModelMixin):
__tablename__ = "partner"
name = sa.Column(sa.String(128), unique=True, nullable=False)
contact_email = sa.Column(sa.String(128), unique=True, nullable=False)
class PartnerApiToken(Base, ModelMixin):
__tablename__ = "partner_api_token"
token = sa.Column(sa.String(32), unique=True, nullable=False, index=True)
partner_id = sa.Column(
sa.ForeignKey("partner.id", ondelete="cascade"), nullable=False, index=True
)
expiration_time = sa.Column(ArrowType, unique=False, nullable=True)
class PartnerUser(Base, ModelMixin):
__tablename__ = "partner_user"
user_id = sa.Column(
sa.ForeignKey("users.id", ondelete="cascade"),
unique=False,
nullable=False,
index=True,
)
partner_id = sa.Column(
sa.ForeignKey("partner.id", ondelete="cascade"), nullable=False, index=True
)
partner_email = sa.Column(sa.String(255), unique=False, nullable=True)
__table_args__ = (
sa.UniqueConstraint("user_id", "partner_id", name="uq_user_id_partner_id"),
)
# endregion

0
app/proton/__init__.py Normal file
View file

View file

@ -0,0 +1,237 @@
import enum
from abc import ABC, abstractmethod
from dataclasses import dataclass
from flask import url_for
from typing import Optional
from app.db import Session
from app.models import User, PartnerUser, Partner
from app.proton.proton_client import ProtonClient, ProtonUser
from app.utils import random_string
PROTON_PARTNER_NAME = "Proton"
_PROTON_PARTNER_ID: Optional[int] = None
def get_proton_partner_id() -> int:
global _PROTON_PARTNER_ID
if _PROTON_PARTNER_ID is None:
partner = Partner.get_by(name=PROTON_PARTNER_NAME)
if partner is None:
raise Exception("Could not find Proton Partner instance")
_PROTON_PARTNER_ID = partner.id
return _PROTON_PARTNER_ID
class Action(enum.Enum):
Login = 1
Link = 2
@dataclass
class ProtonCallbackResult:
redirect_to_login: bool
flash_message: Optional[str]
flash_category: Optional[str]
redirect: Optional[str]
user: Optional[User]
def ensure_partner_user_exists(proton_user: ProtonUser, sl_user: User):
proton_partner_id = get_proton_partner_id()
if not PartnerUser.get_by(user_id=sl_user.id, partner_id=proton_partner_id):
PartnerUser.create(
user_id=sl_user.id,
partner_id=proton_partner_id,
partner_email=proton_user.email,
)
Session.commit()
class ClientMergeStrategy(ABC):
def __init__(self, proton_user: ProtonUser, sl_user: Optional[User]):
if self.__class__ == ClientMergeStrategy:
raise RuntimeError("Cannot directly instantiate a ClientMergeStrategy")
self.proton_user = proton_user
self.sl_user = sl_user
@abstractmethod
def process(self) -> ProtonCallbackResult:
pass
class UnexistantSlClientStrategy(ClientMergeStrategy):
def process(self) -> ProtonCallbackResult:
# Will create a new SL User with a random password
proton_partner_id = get_proton_partner_id()
new_user = User.create(
email=self.proton_user.email,
name=self.proton_user.name,
partner_user_id=self.proton_user.id,
partner_id=proton_partner_id,
password=random_string(20),
)
PartnerUser.create(
user_id=new_user.id,
partner_id=proton_partner_id,
partner_email=self.proton_user.email,
)
# TODO: Adjust plans
Session.commit()
return ProtonCallbackResult(
redirect_to_login=False,
flash_message=None,
flash_category=None,
redirect=None,
user=new_user,
)
class ExistingSlClientStrategy(ClientMergeStrategy):
def process(self) -> ProtonCallbackResult:
ensure_partner_user_exists(self.proton_user, self.sl_user)
# TODO: Adjust plans
return ProtonCallbackResult(
redirect_to_login=False,
flash_message=None,
flash_category=None,
redirect=None,
user=self.sl_user,
)
class ExistingSlUserLinkedWithDifferentProtonAccountStrategy(ClientMergeStrategy):
def process(self) -> ProtonCallbackResult:
return ProtonCallbackResult(
redirect_to_login=True,
flash_message="This Proton account is already linked to another account",
flash_category="error",
user=None,
redirect=None,
)
class AlreadyLinkedUserStrategy(ClientMergeStrategy):
def process(self) -> ProtonCallbackResult:
return ProtonCallbackResult(
redirect_to_login=False,
flash_message=None,
flash_category=None,
redirect=None,
user=self.sl_user,
)
def get_login_strategy(
proton_user: ProtonUser, sl_user: Optional[User]
) -> ClientMergeStrategy:
if sl_user is None:
# We couldn't find any SimpleLogin user with the requested e-mail
return UnexistantSlClientStrategy(proton_user, sl_user)
# There is a SimpleLogin user with the proton_user's e-mail
# Try to find if it has been registered via a partner
if sl_user.partner_id is None:
# It has not been registered via a Partner
return ExistingSlClientStrategy(proton_user, sl_user)
# It has been registered via a partner
# Check if the partner_user_id matches
if sl_user.partner_user_id != proton_user.id:
# It doesn't match. That means that the SimpleLogin user has a different Proton account linked
return ExistingSlUserLinkedWithDifferentProtonAccountStrategy(
proton_user, sl_user
)
# This case means that the sl_user is already linked, so nothing to do
return AlreadyLinkedUserStrategy(proton_user, sl_user)
def process_login_case(proton_user: ProtonUser) -> ProtonCallbackResult:
# Try to find a SimpleLogin user registered with that proton user id
proton_partner_id = get_proton_partner_id()
sl_user_with_external_id = User.get_by(
partner_id=proton_partner_id, partner_user_id=proton_user.id
)
if sl_user_with_external_id is None:
# We didn't find any SimpleLogin user registered with that proton user id
# Try to find it using the proton's e-mail address
sl_user = User.get_by(email=proton_user.email)
return get_login_strategy(proton_user, sl_user).process()
else:
# We found the SL user registered with that proton user id
# We're done
return AlreadyLinkedUserStrategy(
proton_user, sl_user_with_external_id
).process()
def link_user(proton_user: ProtonUser, current_user: User) -> ProtonCallbackResult:
proton_partner_id = get_proton_partner_id()
current_user.partner_user_id = proton_user.id
current_user.partner_id = proton_partner_id
ensure_partner_user_exists(proton_user, current_user)
Session.commit()
return ProtonCallbackResult(
redirect_to_login=False,
redirect=url_for("dashboard.setting"),
flash_category="success",
flash_message="Account successfully linked",
user=current_user,
)
def process_link_case(
proton_user: ProtonUser, current_user: User
) -> ProtonCallbackResult:
# Try to find a SimpleLogin user linked with this Proton account
proton_partner_id = get_proton_partner_id()
sl_user_linked_to_proton_account = User.get_by(
partner_id=proton_partner_id, partner_user_id=proton_user.id
)
if sl_user_linked_to_proton_account is None:
# There is no SL user linked with the proton email. Proceed with linking
return link_user(proton_user, current_user)
else:
# There is a SL user registered with the proton email. Check if is the current one
if sl_user_linked_to_proton_account.id == current_user.id:
# It's the same user. No need to do anything
return ProtonCallbackResult(
redirect_to_login=False,
redirect=url_for("dashboard.setting"),
flash_category="success",
flash_message="Account successfully linked",
user=current_user,
)
else:
# It's a different user. Unlink the other account and link the current one
sl_user_linked_to_proton_account.partner_id = None
sl_user_linked_to_proton_account.partner_user_id = None
other_partner_user = PartnerUser.get_by(
user_id=sl_user_linked_to_proton_account.id,
partner_id=proton_partner_id,
)
if other_partner_user is not None:
PartnerUser.delete(other_partner_user.id)
return link_user(proton_user, current_user)
class ProtonCallbackHandler:
def __init__(self, proton_client: ProtonClient):
self.proton_client = proton_client
def handle_login(self) -> ProtonCallbackResult:
return process_login_case(self.__get_proton_user())
def handle_link(self, current_user: Optional[User]) -> ProtonCallbackResult:
if current_user is None:
raise Exception("Cannot link account with current_user being None")
return process_link_case(self.__get_proton_user(), current_user)
def __get_proton_user(self) -> ProtonUser:
user = self.proton_client.get_user()
plan = self.proton_client.get_plan()
return ProtonUser(email=user.email, plan=plan, name=user.name, id=user.id)

175
app/proton/proton_client.py Normal file
View file

@ -0,0 +1,175 @@
import dataclasses
from abc import ABC, abstractmethod
from enum import Enum
from http import HTTPStatus
from requests import Response, Session
from typing import Optional
_APP_VERSION = "OauthClient_1.0.0"
PROTON_ERROR_CODE_NOT_EXISTS = 2501
class ProtonPlan(Enum):
Free = 0
Professional = 1
Visionary = 2
def name(self):
if self == self.Free:
return "Free"
elif self == self.Professional:
return "Professional"
elif self == self.Visionary:
return "Visionary"
else:
raise Exception("Unknown plan")
def plan_from_name(name: str) -> ProtonPlan:
name_lower = name.lower()
if name_lower == "free":
return ProtonPlan.Free
elif name_lower == "professional":
return ProtonPlan.Professional
elif name_lower == "visionary":
return ProtonPlan.Visionary
else:
raise Exception(f"Unknown plan [{name}]")
@dataclasses.dataclass
class UserInformation:
email: str
name: str
id: str
class AuthorizeResponse:
def __init__(self, code: str, has_accepted: bool):
self.code = code
self.has_accepted = has_accepted
def __str__(self):
return f"[code={self.code}] [has_accepted={self.has_accepted}]"
@dataclasses.dataclass
class SessionResponse:
state: str
expires_in: int
token_type: str
refresh_token: str
access_token: str
session_id: str
@dataclasses.dataclass
class ProtonUser:
id: str
name: str
email: str
plan: ProtonPlan
@dataclasses.dataclass
class AccessCredentials:
access_token: str
session_id: str
def convert_access_token(access_token_response: str) -> AccessCredentials:
"""
The Access token response contains both the Proton Session ID and the Access Token.
The Session ID is necessary in order to use the Proton API. However, the OAuth response does not allow us to return
extra content.
This method takes the Access token response and extracts the session ID and the access token.
"""
parts = access_token_response.split("-")
if len(parts) != 3:
raise Exception("Invalid access token response")
if parts[0] != "pt":
raise Exception("Invalid access token response format")
return AccessCredentials(
session_id=parts[1],
access_token=parts[2],
)
class ProtonClient(ABC):
@abstractmethod
def get_user(self) -> UserInformation:
pass
@abstractmethod
def get_organization(self) -> dict:
pass
@abstractmethod
def get_plan(self) -> ProtonPlan:
pass
class HttpProtonClient(ProtonClient):
def __init__(
self,
base_url: str,
credentials: AccessCredentials,
original_ip: Optional[str],
verify: bool = True,
):
self.base_url = base_url
self.access_token = credentials.access_token
client = Session()
client.verify = verify
headers = {
"x-pm-appversion": _APP_VERSION,
"x-pm-apiversion": "3",
"x-pm-uid": credentials.session_id,
"authorization": f"Bearer {credentials.access_token}",
"accept": "application/vnd.protonmail.v1+json",
"user-agent": "ProtonOauthClient",
}
if original_ip is not None:
headers["x-forwarded-for"] = original_ip
client.headers.update(headers)
self.client = client
def get_user(self) -> UserInformation:
info = self.__get("/users")["User"]
return UserInformation(
email=info.get("Email"), name=info.get("Name"), id=info.get("ID")
)
def get_organization(self) -> dict:
return self.__get("/code/v4/organizations")["Organization"]
def get_plan(self) -> ProtonPlan:
url = f"{self.base_url}/core/v4/organizations"
res = self.client.get(url)
status = res.status_code
if status == HTTPStatus.UNPROCESSABLE_ENTITY:
as_json = res.json()
error_code = as_json.get("Code")
if error_code == PROTON_ERROR_CODE_NOT_EXISTS:
return ProtonPlan.Free
org = self.__validate_response(res).get("Organization")
if org is None:
return ProtonPlan.Free
return plan_from_name(org["PlanName"])
def __get(self, route: str) -> dict:
url = f"{self.base_url}{route}"
res = self.client.get(url)
return self.__validate_response(res)
@staticmethod
def __validate_response(res: Response) -> dict:
status = res.status_code
if status != HTTPStatus.OK:
raise Exception(
f"Unexpected status code. Wanted 200 and got {status}: " + res.text
)
return res.json()

View file

@ -108,6 +108,13 @@ WORDS_FILE_PATH=local_data/test_words.txt
# FACEBOOK_CLIENT_ID=to_fill
# FACEBOOK_CLIENT_SECRET=to_fill
# Login with Proton
# PROTON_CLIENT_ID=to_fill
# PROTON_CLIENT_SECRET=to_fill
# PROTON_BASE_URL=to_fill
# PROTON_VALIDATE_CERTS=true
# CONNECT_WITH_PROTON=true
# Flask profiler
# FLASK_PROFILER_PATH=/tmp/flask-profiler.sql
# FLASK_PROFILER_PASSWORD=password

View file

@ -4,8 +4,9 @@ from app.config import (
)
from app.db import Session
from app.log import LOG
from app.models import Mailbox, Contact, SLDomain
from app.models import Mailbox, Contact, SLDomain, Partner
from app.pgp_utils import load_public_key
from app.proton.proton_callback_handler import PROTON_PARTNER_NAME
from server import create_light_app
@ -53,6 +54,16 @@ def add_sl_domains():
Session.commit()
def add_proton_partner():
proton_partner = Partner.get_by(name=PROTON_PARTNER_NAME)
if not proton_partner:
Partner.create(
name=PROTON_PARTNER_NAME,
contact_email="simplelogin@protonmail.com",
)
Session.commit()
if __name__ == "__main__":
# wrap in an app context to benefit from app setup like database cleanup, sentry integration, etc
with create_light_app().app_context():

View file

@ -0,0 +1,76 @@
"""Add partner tables
Revision ID: e866ad0e78e1
Revises: 0aaad1740797
Create Date: 2022-05-05 12:10:01.229457
"""
import sqlalchemy_utils
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'e866ad0e78e1'
down_revision = '0aaad1740797'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('partner',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('created_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=False),
sa.Column('updated_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=True),
sa.Column('name', sa.String(length=128), nullable=False),
sa.Column('contact_email', sa.String(length=128), nullable=False),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('contact_email'),
sa.UniqueConstraint('name')
)
op.create_table('partner_api_token',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('created_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=False),
sa.Column('updated_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=True),
sa.Column('token', sa.String(length=32), nullable=False),
sa.Column('partner_id', sa.Integer(), nullable=False),
sa.Column('expiration_time', sqlalchemy_utils.types.arrow.ArrowType(), nullable=True),
sa.ForeignKeyConstraint(['partner_id'], ['partner.id'], ondelete='cascade'),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_partner_api_token_partner_id'), 'partner_api_token', ['partner_id'], unique=False)
op.create_index(op.f('ix_partner_api_token_token'), 'partner_api_token', ['token'], unique=True)
op.create_table('partner_user',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('created_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=False),
sa.Column('updated_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=True),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('partner_id', sa.Integer(), nullable=False),
sa.Column('partner_email', sa.String(length=255), nullable=True),
sa.ForeignKeyConstraint(['partner_id'], ['partner.id'], ondelete='cascade'),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='cascade'),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('user_id', 'partner_id', name='uq_user_id_partner_id')
)
op.create_index(op.f('ix_partner_user_partner_id'), 'partner_user', ['partner_id'], unique=False)
op.create_index(op.f('ix_partner_user_user_id'), 'partner_user', ['user_id'], unique=False)
op.add_column('users', sa.Column('partner_id', sa.BigInteger(), nullable=True))
op.add_column('users', sa.Column('partner_user_id', sa.String(length=128), nullable=True))
op.create_unique_constraint('uq_partner_id_partner_user_id', 'users', ['partner_id', 'partner_user_id'])
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint('uq_partner_id_partner_user_id', 'users', type_='unique')
op.drop_column('users', 'partner_user_id')
op.drop_column('users', 'partner_id')
op.drop_index(op.f('ix_partner_user_user_id'), table_name='partner_user')
op.drop_index(op.f('ix_partner_user_partner_id'), table_name='partner_user')
op.drop_table('partner_user')
op.drop_index(op.f('ix_partner_api_token_token'), table_name='partner_api_token')
op.drop_index(op.f('ix_partner_api_token_partner_id'), table_name='partner_api_token')
op.drop_table('partner_api_token')
op.drop_table('partner')
# ### end Alembic commands ###

8
static/style.css vendored
View file

@ -186,4 +186,10 @@ textarea.parsley-error {
#help-menu-item {
display: none;
}
}
}
.proton-button {
border-color:#6d4aff;
background-color:white;
color:#6d4aff;
}

View file

@ -13,39 +13,46 @@
</div>
{% endif %}
<form class="card" style="border-radius: 2%" method="post">
{{ form.csrf_token }}
<div class="card" style="border-radius: 2%">
<div class="card-body p-6">
<h1 class="card-title">Welcome back!</h1>
<div class="form-group">
<label class="form-label">Email address</label>
{{ form.email(class="form-control", type="email", autofocus="true") }}
{{ render_field_errors(form.email) }}
</div>
<form method="post">
{{ form.csrf_token }}
<div class="form-group">
<label class="form-label">
Password
</label>
{{ form.password(class="form-control", type="password") }}
{{ render_field_errors(form.password) }}
<div class="text-muted">
<a href="{{ url_for('auth.forgot_password') }}" class="small">
I forgot my password
</a>
<div class="form-group">
<label class="form-label">Email address</label>
{{ form.email(class="form-control", type="email", autofocus="true") }}
{{ render_field_errors(form.email) }}
</div>
</div>
<div class="form-footer">
<button type="submit" class="btn btn-primary btn-block">Log in</button>
</div>
<div class="form-group">
<label class="form-label">
Password
</label>
{{ form.password(class="form-control", type="password") }}
{{ render_field_errors(form.password) }}
<div class="text-muted">
<a href="{{ url_for('auth.forgot_password') }}" class="small">
I forgot my password
</a>
</div>
</div>
<div class="form-footer">
<button type="submit" class="btn btn-primary btn-block">Log in</button>
</div>
</form>
{% if connect_with_proton %}
<div class="text-center my-2 text-gray"><span>or</span></div>
<a class="btn btn-primary btn-block mt-2 proton-button" href="{{ url_for("auth.proton_login") }}">Log in with Proton</a>
{% endif %}
</div>
</form>
</div>
<div class="text-center text-muted mt-2">
Don't have an account yet? <a href="{{ url_for('auth.register') }}">Sign up</a>
</div>
{% endblock %}
{% endblock %}

View file

@ -48,6 +48,11 @@
<div class="mt-2">
<button type="submit" class="btn btn-primary btn-block">Create Account</button>
</div>
{% if connect_with_proton %}
<div class="text-center my-2 text-gray"><span>or</span></div>
<a class="btn btn-primary btn-block mt-2 proton-button" href="{{ url_for("auth.proton_login") }}">Sign up with Proton</a>
{% endif %}
</div>
</form>
<div class="text-center text-muted mb-6">

View file

@ -208,6 +208,32 @@
</div>
<!-- END Change email -->
<!-- Connect with Proton -->
{% if connect_with_proton %}
<div class="card">
<div class="card-body">
<div class="card-title">
Connect with Proton
</div>
{% if proton_linked_account != None %}
<div class="mb-3">
You have linked your Proton account: {{ proton_linked_account }} <br>
</div>
<a
class="btn btn-primary mt-2 proton-button"
href="{{ url_for('dashboard.unlink_proton_account') }}"
>Unlink account</a>
{% else %}
<div class="mb-3">
You can connect your Proton account with your SimpleLogin one. <br>
</div>
<a class="btn btn-primary mt-2 proton-button" href="{{ url_for("auth.proton_login", action="link") }}">Connect with Proton</a>
{% endif %}
</div>
</div>
{% endif %}
<!-- END Connect with Proton -->
<!-- Change password -->
<div class="card" id="change_password">
<div class="card-body">
@ -539,7 +565,7 @@
<div class="form-check">
<input type="checkbox" id="include-sender-header" name="enable"
{% if current_user.include_header_email_header %} checked {% endif %} class="form-check-input">
{% if current_user.include_header_email_header %} checked {% endif %} class="form-check-input">
<label for="include-sender-header">Include sender address in email headers</label>
</div>

23
tests/auth/test_proton.py Normal file
View file

@ -0,0 +1,23 @@
from flask import url_for
from urllib.parse import parse_qs
from urllib3.util import parse_url
from app.config import URL, PROTON_CLIENT_ID
def test_login_with_proton(flask_client):
r = flask_client.get(
url_for("auth.proton_login"),
follow_redirects=False,
)
location = r.headers.get("Location")
assert location is not None
parsed = parse_url(location)
query = parse_qs(parsed.query)
expected_redirect_url = f"{URL}/auth/proton/callback"
assert "code" == query["response_type"][0]
assert PROTON_CLIENT_ID == query["client_id"][0]
assert expected_redirect_url == query["redirect_uri"][0]

View file

@ -16,7 +16,7 @@ from psycopg2.errorcodes import DEPENDENT_OBJECTS_STILL_EXIST
import pytest
from server import create_app
from init_app import add_sl_domains
from init_app import add_sl_domains, add_proton_partner
app = create_app()
app.config["TESTING"] = True
@ -34,6 +34,7 @@ with engine.connect() as conn:
conn.execute("Rollback")
add_sl_domains()
add_proton_partner()
@pytest.fixture

0
tests/proton/__init__.py Normal file
View file

View file

@ -0,0 +1,285 @@
import pytest
from app.db import Session
from app.proton.proton_client import ProtonClient, UserInformation, ProtonPlan
from app.proton.proton_callback_handler import (
ProtonCallbackHandler,
get_proton_partner_id,
get_login_strategy,
process_link_case,
ProtonUser,
UnexistantSlClientStrategy,
ExistingSlClientStrategy,
AlreadyLinkedUserStrategy,
ExistingSlUserLinkedWithDifferentProtonAccountStrategy,
ClientMergeStrategy,
)
from app.models import User, PartnerUser
from app.utils import random_string
class MockProtonClient(ProtonClient):
def __init__(self, user: UserInformation, plan: ProtonPlan, organization: dict):
self.user = user
self.plan = plan
self.organization = organization
def get_organization(self) -> dict:
return self.organization
def get_user(self) -> UserInformation:
return self.user
def get_plan(self) -> ProtonPlan:
return self.plan
def random_email() -> str:
return "{rand}@{rand}.com".format(rand=random_string(20))
def random_proton_user(
user_id: str = None,
name: str = None,
email: str = None,
plan: ProtonPlan = None,
) -> ProtonUser:
user_id = user_id if user_id is not None else random_string()
name = name if name is not None else random_string()
email = (
email
if email is not None
else "{rand}@{rand}.com".format(rand=random_string(20))
)
plan = plan if plan is not None else ProtonPlan.Free
return ProtonUser(id=user_id, name=name, email=email, plan=plan)
def create_user(email: str = None) -> User:
email = email if email is not None else random_email()
user = User.create(email=email)
Session.commit()
return user
def create_user_for_partner(partner_user_id: str, email: str = None) -> User:
email = email if email is not None else random_email()
user = User.create(email=email)
user.partner_id = get_proton_partner_id()
user.partner_user_id = partner_user_id
PartnerUser.create(
user_id=user.id, partner_id=get_proton_partner_id(), partner_email=email
)
Session.commit()
return user
def test_proton_callback_handler_unexistant_sl_user():
email = random_email()
name = random_string()
external_id = random_string()
user = UserInformation(email=email, name=name, id=external_id)
mock_client = MockProtonClient(
user=user, plan=ProtonPlan.Professional, organization={}
)
handler = ProtonCallbackHandler(mock_client)
res = handler.handle_login()
assert res.user is not None
assert res.user.email == email
assert res.user.name == name
assert res.user.partner_user_id == external_id
def test_proton_callback_handler_existant_sl_user():
email = random_email()
sl_user = User.create(email, commit=True)
external_id = random_string()
user = UserInformation(email=email, name=random_string(), id=external_id)
mock_client = MockProtonClient(
user=user, plan=ProtonPlan.Professional, organization={}
)
handler = ProtonCallbackHandler(mock_client)
res = handler.handle_login()
assert res.user is not None
assert res.user.id == sl_user.id
sa = PartnerUser.get_by(user_id=sl_user.id, partner_id=get_proton_partner_id())
assert sa is not None
assert sa.partner_email == user.email
def test_get_strategy_unexistant_sl_user():
strategy = get_login_strategy(
proton_user=random_proton_user(),
sl_user=None,
)
assert isinstance(strategy, UnexistantSlClientStrategy)
def test_get_strategy_existing_sl_user():
email = random_email()
sl_user = User.create(email, commit=True)
strategy = get_login_strategy(
proton_user=random_proton_user(email=email),
sl_user=sl_user,
)
assert isinstance(strategy, ExistingSlClientStrategy)
def test_get_strategy_already_linked_user():
email = random_email()
proton_user_id = random_string()
sl_user = create_user_for_partner(proton_user_id, email=email)
strategy = get_login_strategy(
proton_user=random_proton_user(user_id=proton_user_id, email=email),
sl_user=sl_user,
)
assert isinstance(strategy, AlreadyLinkedUserStrategy)
def test_get_strategy_existing_sl_user_linked_with_different_proton_account():
# In this scenario we have
# - ProtonUser1 (ID1, email1@proton)
# - ProtonUser2 (ID2, email2@proton)
# - SimpleLoginUser1 registered with email1@proton, but linked to account ID2
# We will try to log in with email1@proton
email1 = random_email()
email2 = random_email()
proton_user_id_1 = random_string()
proton_user_id_2 = random_string()
proton_user_1 = random_proton_user(user_id=proton_user_id_1, email=email1)
proton_user_2 = random_proton_user(user_id=proton_user_id_2, email=email2)
sl_user = create_user_for_partner(proton_user_2.id, email=proton_user_1.email)
strategy = get_login_strategy(
proton_user=proton_user_1,
sl_user=sl_user,
)
assert isinstance(strategy, ExistingSlUserLinkedWithDifferentProtonAccountStrategy)
##
# LINK
def test_link_account_with_proton_account_same_address(flask_client):
# This is the most basic scenario
# In this scenario we have:
# - ProtonUser (email1@proton)
# - SimpleLoginUser registered with email1@proton
# We will try to link both accounts
email = random_email()
proton_user_id = random_string()
proton_user = random_proton_user(user_id=proton_user_id, email=email)
sl_user = create_user(email)
res = process_link_case(proton_user, sl_user)
assert res.redirect_to_login is False
assert res.redirect is not None
assert res.flash_category == "success"
assert res.flash_message is not None
updated_user = User.get(sl_user.id)
assert updated_user.partner_id == get_proton_partner_id()
assert updated_user.partner_user_id == proton_user_id
def test_link_account_with_proton_account_different_address(flask_client):
# In this scenario we have:
# - ProtonUser (foo@proton)
# - SimpleLoginUser (bar@somethingelse)
# We will try to link both accounts
proton_user_id = random_string()
proton_user = random_proton_user(user_id=proton_user_id, email=random_email())
sl_user = create_user()
res = process_link_case(proton_user, sl_user)
assert res.redirect_to_login is False
assert res.redirect is not None
assert res.flash_category == "success"
assert res.flash_message is not None
updated_user = User.get(sl_user.id)
assert updated_user.partner_id == get_proton_partner_id()
assert updated_user.partner_user_id == proton_user_id
def test_link_account_with_proton_account_same_address_but_linked_to_other_user(
flask_client,
):
# In this scenario we have:
# - ProtonUser (foo@proton)
# - SimpleLoginUser1 (foo@proton)
# - SimpleLoginUser2 (other@somethingelse) linked with foo@proton
# We will unlink SimpleLoginUser2 and link SimpleLoginUser1 with foo@proton
proton_user_id = random_string()
proton_email = random_email()
proton_user = random_proton_user(user_id=proton_user_id, email=proton_email)
sl_user_1 = create_user(proton_email)
sl_user_2 = create_user_for_partner(
proton_user_id, email=random_email()
) # User already linked with the proton account
res = process_link_case(proton_user, sl_user_1)
assert res.redirect_to_login is False
assert res.redirect is not None
assert res.flash_category == "success"
assert res.flash_message is not None
updated_user_1 = User.get(sl_user_1.id)
assert updated_user_1.partner_id == get_proton_partner_id()
assert updated_user_1.partner_user_id == proton_user_id
updated_user_2 = User.get(sl_user_2.id)
assert updated_user_2.partner_id is None
assert updated_user_2.partner_user_id is None
def test_link_account_with_proton_account_different_address_and_linked_to_other_user(
flask_client,
):
# In this scenario we have:
# - ProtonUser (foo@proton)
# - SimpleLoginUser1 (bar@somethingelse)
# - SimpleLoginUser2 (other@somethingelse) linked with foo@proton
# We will unlink SimpleLoginUser2 and link SimpleLoginUser1 with foo@proton
proton_user_id = random_string()
proton_user = random_proton_user(user_id=proton_user_id, email=random_email())
sl_user_1 = create_user(random_email())
sl_user_2 = create_user_for_partner(
proton_user_id, email=random_email()
) # User already linked with the proton account
res = process_link_case(proton_user, sl_user_1)
assert res.redirect_to_login is False
assert res.redirect is not None
assert res.flash_category == "success"
assert res.flash_message is not None
updated_user_1 = User.get(sl_user_1.id)
assert updated_user_1.partner_id == get_proton_partner_id()
assert updated_user_1.partner_user_id == proton_user_id
partner_user_1 = PartnerUser.get_by(
user_id=sl_user_1.id, partner_id=get_proton_partner_id()
)
assert partner_user_1 is not None
assert partner_user_1.partner_email == proton_user.email
updated_user_2 = User.get(sl_user_2.id)
assert updated_user_2.partner_id is None
assert updated_user_2.partner_user_id is None
partner_user_2 = PartnerUser.get_by(
user_id=sl_user_2.id, partner_id=get_proton_partner_id()
)
assert partner_user_2 is None
def test_cannot_create_instance_of_base_strategy():
with pytest.raises(Exception):
ClientMergeStrategy(random_proton_user(), None)

View file

@ -0,0 +1,21 @@
import pytest
from app.proton import proton_client
def test_convert_access_token_valid():
res = proton_client.convert_access_token("pt-abc-123")
assert res.session_id == "abc"
assert res.access_token == "123"
def test_convert_access_token_not_containing_pt():
with pytest.raises(Exception):
proton_client.convert_access_token("pb-abc-123")
def test_convert_access_token_not_containing_invalid_length():
cases = ["pt-abc-too-long", "pt-short"]
for case in cases:
with pytest.raises(Exception):
proton_client.convert_access_token(case)

View file

@ -54,4 +54,9 @@ PGP_SENDER_PRIVATE_KEY_PATH=local_data/private-pgp.asc
ALIAS_AUTOMATIC_DISABLE=true
ALLOWED_REDIRECT_DOMAINS=["test.simplelogin.local"]
DMARC_CHECK_ENABLED=true
DMARC_CHECK_ENABLED=true
PROTON_CLIENT_ID=to_fill
PROTON_CLIENT_SECRET=to_fill
PROTON_BASE_URL=https://localhost/api

View file

@ -1,4 +1,5 @@
from typing import List
from urllib.parse import parse_qs
import pytest
@ -40,3 +41,19 @@ def generate_sanitize_url_cases() -> List:
def test_sanitize_url(url, expected):
sanitized = sanitize_next_url(url)
assert expected == sanitized
def test_parse_querystring():
cases = [
{"input": "", "expected": {}},
{"input": "a=b", "expected": {"a": ["b"]}},
{"input": "a=b&c=d", "expected": {"a": ["b"], "c": ["d"]}},
{"input": "a=b&a=c", "expected": {"a": ["b", "c"]}},
]
for case in cases:
expected = case["expected"]
res = parse_qs(case["input"])
assert len(res) == len(expected)
for k, v in expected.items():
assert res[k] == v