diff --git a/Libraries/LibGemini/GeminiJob.cpp b/Libraries/LibGemini/GeminiJob.cpp index dac0237155f..91cb58cd2f0 100644 --- a/Libraries/LibGemini/GeminiJob.cpp +++ b/Libraries/LibGemini/GeminiJob.cpp @@ -63,6 +63,10 @@ void GeminiJob::start() m_socket->on_tls_finished = [this] { finish_up(); }; + m_socket->on_tls_certificate_request = [this](auto&) { + if (on_certificate_requested) + on_certificate_requested(*this); + }; bool success = ((TLS::TLSv12&)*m_socket).connect(m_request.url().host(), m_request.url().port()); if (!success) { deferred_invoke([this](auto&) { @@ -89,6 +93,15 @@ void GeminiJob::read_while_data_available(Function read) } } +void GeminiJob::set_certificate(String certificate, String private_key) +{ + if (!m_socket->add_client_key(ByteBuffer::wrap(certificate.characters(), certificate.length()), ByteBuffer::wrap(private_key.characters(), private_key.length()))) { + dbg() << "LibGemini: Failed to set a client certificate"; + // FIXME: Do something about this failure + ASSERT_NOT_REACHED(); + } +} + void GeminiJob::register_on_ready_to_read(Function callback) { m_socket->on_tls_ready_to_read = [callback = move(callback)](auto&) { diff --git a/Libraries/LibGemini/GeminiJob.h b/Libraries/LibGemini/GeminiJob.h index 2a67557199f..79b33d14843 100644 --- a/Libraries/LibGemini/GeminiJob.h +++ b/Libraries/LibGemini/GeminiJob.h @@ -48,6 +48,9 @@ public: virtual void start() override; virtual void shutdown() override; + void set_certificate(String certificate, String key); + + Function on_certificate_requested; protected: virtual void register_on_ready_to_read(Function) override; diff --git a/Libraries/LibHTTP/HttpsJob.cpp b/Libraries/LibHTTP/HttpsJob.cpp index d5495b5de0f..8b5ae352b01 100644 --- a/Libraries/LibHTTP/HttpsJob.cpp +++ b/Libraries/LibHTTP/HttpsJob.cpp @@ -64,6 +64,10 @@ void HttpsJob::start() m_socket->on_tls_finished = [&] { finish_up(); }; + m_socket->on_tls_certificate_request = [this](auto&) { + if (on_certificate_requested) + on_certificate_requested(*this); + }; bool success = ((TLS::TLSv12&)*m_socket).connect(m_request.url().host(), m_request.url().port()); if (!success) { deferred_invoke([this](auto&) { @@ -82,6 +86,15 @@ void HttpsJob::shutdown() m_socket = nullptr; } +void HttpsJob::set_certificate(String certificate, String private_key) +{ + if (!m_socket->add_client_key(ByteBuffer::wrap(certificate.characters(), certificate.length()), ByteBuffer::wrap(private_key.characters(), private_key.length()))) { + dbg() << "LibHTTP: Failed to set a client certificate"; + // FIXME: Do something about this failure + ASSERT_NOT_REACHED(); + } +} + void HttpsJob::read_while_data_available(Function read) { while (m_socket->can_read()) { diff --git a/Libraries/LibHTTP/HttpsJob.h b/Libraries/LibHTTP/HttpsJob.h index eb8dac4f63b..5eef2a773f4 100644 --- a/Libraries/LibHTTP/HttpsJob.h +++ b/Libraries/LibHTTP/HttpsJob.h @@ -49,6 +49,9 @@ public: virtual void start() override; virtual void shutdown() override; + void set_certificate(String certificate, String key); + + Function on_certificate_requested; protected: virtual void register_on_ready_to_read(Function) override; diff --git a/Libraries/LibProtocol/Client.cpp b/Libraries/LibProtocol/Client.cpp index f5e22979e42..3531bfaa767 100644 --- a/Libraries/LibProtocol/Client.cpp +++ b/Libraries/LibProtocol/Client.cpp @@ -68,6 +68,13 @@ bool Client::stop_download(Badge, Download& download) return send_sync(download.id())->success(); } +bool Client::set_certificate(Badge, Download& download, String certificate, String key) +{ + if (!m_downloads.contains(download.id())) + return false; + return send_sync(download.id(), move(certificate), move(key))->success(); +} + void Client::handle(const Messages::ProtocolClient::DownloadFinished& message) { RefPtr download; @@ -85,4 +92,13 @@ void Client::handle(const Messages::ProtocolClient::DownloadProgress& message) } } +OwnPtr Client::handle(const Messages::ProtocolClient::CertificateRequested& message) +{ + if (auto download = const_cast(m_downloads.get(message.download_id()).value_or(nullptr))) { + download->did_request_certificates({}); + } + + return make(); +} + } diff --git a/Libraries/LibProtocol/Client.h b/Libraries/LibProtocol/Client.h index 37351710fa7..6724a307799 100644 --- a/Libraries/LibProtocol/Client.h +++ b/Libraries/LibProtocol/Client.h @@ -46,14 +46,15 @@ public: bool is_supported_protocol(const String&); RefPtr start_download(const String& url, const HashMap& request_headers = {}); - bool stop_download(Badge, Download&); + bool set_certificate(Badge, Download&, String, String); private: Client(); virtual void handle(const Messages::ProtocolClient::DownloadProgress&) override; virtual void handle(const Messages::ProtocolClient::DownloadFinished&) override; + virtual OwnPtr handle(const Messages::ProtocolClient::CertificateRequested&) override; HashMap> m_downloads; }; diff --git a/Libraries/LibProtocol/Download.cpp b/Libraries/LibProtocol/Download.cpp index 0a2a29ed687..35c22fee042 100644 --- a/Libraries/LibProtocol/Download.cpp +++ b/Libraries/LibProtocol/Download.cpp @@ -67,4 +67,14 @@ void Download::did_progress(Badge, Optional total_size, u32 downloa if (on_progress) on_progress(total_size, downloaded_size); } + +void Download::did_request_certificates(Badge) +{ + if (on_certificate_requested) { + auto result = on_certificate_requested(); + if (!m_client->set_certificate({}, *this, result.certificate, result.key)) { + dbg() << "Download: set_certificate failed"; + } + } +} } diff --git a/Libraries/LibProtocol/Download.h b/Libraries/LibProtocol/Download.h index 12f7ce90528..9601ab06da8 100644 --- a/Libraries/LibProtocol/Download.h +++ b/Libraries/LibProtocol/Download.h @@ -40,6 +40,11 @@ class Client; class Download : public RefCounted { public: + struct CertificateAndKey { + String certificate; + String key; + }; + static NonnullRefPtr create_from_id(Badge, Client& client, i32 download_id) { return adopt(*new Download(client, download_id)); @@ -50,9 +55,11 @@ public: Function payload_storage, const HashMap& response_headers, Optional status_code)> on_finish; Function total_size, u32 downloaded_size)> on_progress; + Function on_certificate_requested; void did_finish(Badge, bool success, Optional status_code, u32 total_size, i32 shbuf_id, const IPC::Dictionary& response_headers); void did_progress(Badge, Optional total_size, u32 downloaded_size); + void did_request_certificates(Badge); private: explicit Download(Client&, i32 download_id); diff --git a/Libraries/LibTLS/TLSv12.cpp b/Libraries/LibTLS/TLSv12.cpp index f87e73481bb..5e5e4018818 100644 --- a/Libraries/LibTLS/TLSv12.cpp +++ b/Libraries/LibTLS/TLSv12.cpp @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -721,4 +722,28 @@ TLSv12::TLSv12(Core::Object* parent, Version version) } } +bool TLSv12::add_client_key(const ByteBuffer& certificate_pem_buffer, const ByteBuffer& rsa_key) // FIXME: This should not be bound to RSA +{ + if (certificate_pem_buffer.is_empty() || rsa_key.is_empty()) { + return true; + } + auto decoded_certificate = decode_pem(certificate_pem_buffer.span(), 0); + if (decoded_certificate.is_empty()) { + dbg() << "Certificate not PEM"; + return false; + } + + auto maybe_certificate = parse_asn1(decoded_certificate); + if (!maybe_certificate.has_value()) { + dbg() << "Invalid certificate"; + return false; + } + + Crypto::PK::RSA rsa(rsa_key); + auto certificate = maybe_certificate.value(); + certificate.private_key = rsa.private_key(); + + return add_client_key(certificate); +} + } diff --git a/Libraries/LibTLS/TLSv12.h b/Libraries/LibTLS/TLSv12.h index ab9ca33ff5c..6e279ba5ed2 100644 --- a/Libraries/LibTLS/TLSv12.h +++ b/Libraries/LibTLS/TLSv12.h @@ -206,6 +206,7 @@ struct Certificate { CertificateKeyAlgorithm ec_algorithm; ByteBuffer exponent; Crypto::PK::RSAPublicKey public_key; + Crypto::PK::RSAPrivateKey private_key; String issuer_country; String issuer_state; String issuer_location; @@ -318,6 +319,13 @@ public: bool load_certificates(const ByteBuffer& pem_buffer); bool load_private_key(const ByteBuffer& pem_buffer); + bool add_client_key(const ByteBuffer& certificate_pem_buffer, const ByteBuffer& key_pem_buffer); + bool add_client_key(Certificate certificate) + { + m_context.client_certificates.append(move(certificate)); + return true; + } + ByteBuffer finish_build(); const StringView& alpn() const { return m_context.negotiated_alpn; } @@ -349,6 +357,7 @@ public: Function on_tls_error; Function on_tls_connected; Function on_tls_finished; + Function on_tls_certificate_request; private: explicit TLSv12(Core::Object* parent, Version version = Version::V12); diff --git a/Libraries/LibWeb/Loader/ResourceLoader.cpp b/Libraries/LibWeb/Loader/ResourceLoader.cpp index e5727cd871a..04c2349fc11 100644 --- a/Libraries/LibWeb/Loader/ResourceLoader.cpp +++ b/Libraries/LibWeb/Loader/ResourceLoader.cpp @@ -179,6 +179,9 @@ void ResourceLoader::load(const URL& url, Functionon_certificate_requested = []() -> Protocol::Download::CertificateAndKey { + return {}; + }; ++m_pending_loads; if (on_load_counter_change) on_load_counter_change(); diff --git a/Services/ProtocolServer/ClientConnection.cpp b/Services/ProtocolServer/ClientConnection.cpp index cbe21dd1ef7..a33baf29733 100644 --- a/Services/ProtocolServer/ClientConnection.cpp +++ b/Services/ProtocolServer/ClientConnection.cpp @@ -111,6 +111,11 @@ void ClientConnection::did_progress_download(Badge, Download& download post_message(Messages::ProtocolClient::DownloadProgress(download.id(), download.total_size(), download.downloaded_size())); } +void ClientConnection::did_request_certificates(Badge, Download& download) +{ + post_message(Messages::ProtocolClient::CertificateRequested(download.id())); +} + OwnPtr ClientConnection::handle(const Messages::ProtocolServer::Greet&) { return make(client_id()); @@ -122,4 +127,15 @@ OwnPtr ClientConnection::h return make(); } +OwnPtr ClientConnection::handle(const Messages::ProtocolServer::SetCertificate& message) +{ + auto* download = const_cast(m_downloads.get(message.download_id()).value_or(nullptr)); + bool success = false; + if (download) { + download->set_certificate(message.certificate(), message.key()); + success = true; + } + return make(success); +} + } diff --git a/Services/ProtocolServer/ClientConnection.h b/Services/ProtocolServer/ClientConnection.h index bedba4801f0..b592c5f29be 100644 --- a/Services/ProtocolServer/ClientConnection.h +++ b/Services/ProtocolServer/ClientConnection.h @@ -28,8 +28,8 @@ #include #include -#include #include +#include namespace ProtocolServer { @@ -46,6 +46,7 @@ public: void did_finish_download(Badge, Download&, bool success); void did_progress_download(Badge, Download&); + void did_request_certificates(Badge, Download&); private: virtual OwnPtr handle(const Messages::ProtocolServer::Greet&) override; @@ -53,6 +54,7 @@ private: virtual OwnPtr handle(const Messages::ProtocolServer::StartDownload&) override; virtual OwnPtr handle(const Messages::ProtocolServer::StopDownload&) override; virtual OwnPtr handle(const Messages::ProtocolServer::DisownSharedBuffer&) override; + virtual OwnPtr handle(const Messages::ProtocolServer::SetCertificate&); HashMap> m_downloads; HashMap> m_shared_buffers; diff --git a/Services/ProtocolServer/Download.cpp b/Services/ProtocolServer/Download.cpp index 32531894c6d..d0d9aa2ab8c 100644 --- a/Services/ProtocolServer/Download.cpp +++ b/Services/ProtocolServer/Download.cpp @@ -25,8 +25,8 @@ */ #include -#include #include +#include namespace ProtocolServer { @@ -59,6 +59,10 @@ void Download::set_response_headers(const HashMap total_size, u32 downloaded_size) m_client.did_progress_download({}, *this); } +void Download::did_request_certificates() +{ + m_client.did_request_certificates({}, *this); +} + } diff --git a/Services/ProtocolServer/Download.h b/Services/ProtocolServer/Download.h index 2cc0487a6d5..f0d0342006d 100644 --- a/Services/ProtocolServer/Download.h +++ b/Services/ProtocolServer/Download.h @@ -49,6 +49,7 @@ public: const HashMap& response_headers() const { return m_response_headers; } void stop(); + virtual void set_certificate(String, String); protected: explicit Download(ClientConnection&); @@ -56,6 +57,7 @@ protected: void did_finish(bool success); void did_progress(Optional total_size, u32 downloaded_size); void set_status_code(u32 status_code) { m_status_code = status_code; } + void did_request_certificates(); void set_payload(const ByteBuffer&); void set_response_headers(const HashMap&); diff --git a/Services/ProtocolServer/Forward.h b/Services/ProtocolServer/Forward.h index 13dc9f249fd..06eec6c7ab0 100644 --- a/Services/ProtocolServer/Forward.h +++ b/Services/ProtocolServer/Forward.h @@ -31,7 +31,9 @@ namespace ProtocolServer { class ClientConnection; class Download; class GeminiProtocol; +class HttpDownload; class HttpProtocol; +class HttpsDownload; class HttpsProtocol; class Protocol; diff --git a/Services/ProtocolServer/GeminiDownload.cpp b/Services/ProtocolServer/GeminiDownload.cpp index 2e9cff4d91d..79aba506c59 100644 --- a/Services/ProtocolServer/GeminiDownload.cpp +++ b/Services/ProtocolServer/GeminiDownload.cpp @@ -59,6 +59,14 @@ GeminiDownload::GeminiDownload(ClientConnection& client, NonnullRefPtron_progress = [this](Optional total, u32 current) { did_progress(total, current); }; + m_job->on_certificate_requested = [this](auto&) { + did_request_certificates(); + }; +} + +void GeminiDownload::set_certificate(String certificate, String key) +{ + m_job->set_certificate(move(certificate), move(key)); } GeminiDownload::~GeminiDownload() diff --git a/Services/ProtocolServer/GeminiDownload.h b/Services/ProtocolServer/GeminiDownload.h index b6e796e952a..c429bac7b1c 100644 --- a/Services/ProtocolServer/GeminiDownload.h +++ b/Services/ProtocolServer/GeminiDownload.h @@ -41,6 +41,8 @@ public: private: explicit GeminiDownload(ClientConnection&, NonnullRefPtr); + virtual void set_certificate(String certificate, String key) override; + NonnullRefPtr m_job; }; diff --git a/Services/ProtocolServer/HttpsDownload.cpp b/Services/ProtocolServer/HttpsDownload.cpp index 899c204019a..fe381d216ea 100644 --- a/Services/ProtocolServer/HttpsDownload.cpp +++ b/Services/ProtocolServer/HttpsDownload.cpp @@ -51,6 +51,14 @@ HttpsDownload::HttpsDownload(ClientConnection& client, NonnullRefPtron_progress = [this](Optional total, u32 current) { did_progress(total, current); }; + m_job->on_certificate_requested = [this](auto&) { + did_request_certificates(); + }; +} + +void HttpsDownload::set_certificate(String certificate, String key) +{ + m_job->set_certificate(move(certificate), move(key)); } HttpsDownload::~HttpsDownload() diff --git a/Services/ProtocolServer/HttpsDownload.h b/Services/ProtocolServer/HttpsDownload.h index 4b0ee573fc3..48f255b2fac 100644 --- a/Services/ProtocolServer/HttpsDownload.h +++ b/Services/ProtocolServer/HttpsDownload.h @@ -41,6 +41,8 @@ public: private: explicit HttpsDownload(ClientConnection&, NonnullRefPtr); + virtual void set_certificate(String certificate, String key) override; + NonnullRefPtr m_job; }; diff --git a/Services/ProtocolServer/ProtocolClient.ipc b/Services/ProtocolServer/ProtocolClient.ipc index e4f7ab7a5b3..ef00d760ced 100644 --- a/Services/ProtocolServer/ProtocolClient.ipc +++ b/Services/ProtocolServer/ProtocolClient.ipc @@ -3,4 +3,7 @@ endpoint ProtocolClient = 13 // Download notifications DownloadProgress(i32 download_id, Optional total_size, u32 downloaded_size) =| DownloadFinished(i32 download_id, bool success, Optional status_code, u32 total_size, i32 shbuf_id, IPC::Dictionary response_headers) =| + + // Certificate requests + CertificateRequested(i32 download_id) => () } diff --git a/Services/ProtocolServer/ProtocolServer.ipc b/Services/ProtocolServer/ProtocolServer.ipc index 80097ebff44..e819eb3234b 100644 --- a/Services/ProtocolServer/ProtocolServer.ipc +++ b/Services/ProtocolServer/ProtocolServer.ipc @@ -12,4 +12,5 @@ endpoint ProtocolServer = 9 // Download API StartDownload(URL url, IPC::Dictionary request_headers) => (i32 download_id) StopDownload(i32 download_id) => (bool success) + SetCertificate(i32 download_id, String certificate, String key) => (bool success) }