diff options
-rw-r--r-- | Userland/Libraries/LibTLS/ClientHandshake.cpp | 12 | ||||
-rw-r--r-- | Userland/Libraries/LibTLS/Exchange.cpp | 9 | ||||
-rw-r--r-- | Userland/Libraries/LibTLS/Handshake.cpp | 24 | ||||
-rw-r--r-- | Userland/Libraries/LibTLS/Record.cpp | 7 | ||||
-rw-r--r-- | Userland/Libraries/LibTLS/Socket.cpp | 2 | ||||
-rw-r--r-- | Userland/Libraries/LibTLS/TLSv12.cpp | 7 | ||||
-rw-r--r-- | Userland/Libraries/LibTLS/TLSv12.h | 28 |
7 files changed, 60 insertions, 29 deletions
diff --git a/Userland/Libraries/LibTLS/ClientHandshake.cpp b/Userland/Libraries/LibTLS/ClientHandshake.cpp index 60635d56f8..1e693fff19 100644 --- a/Userland/Libraries/LibTLS/ClientHandshake.cpp +++ b/Userland/Libraries/LibTLS/ClientHandshake.cpp @@ -154,6 +154,7 @@ ssize_t TLSv12::handle_hello(ReadonlyBytes buffer, WritePacketStage& write_packe return (i8)Error::NeedMoreData; } + dbgln("Encountered extension {} with length {}", (u16)extension_type, extension_length); // SNI if (extension_type == HandshakeExtension::ServerName) { u16 sni_host_length = AK::convert_between_host_and_network_endian(*(const u16*)buffer.offset_pointer(res + 3)); @@ -192,8 +193,13 @@ ssize_t TLSv12::handle_hello(ReadonlyBytes buffer, WritePacketStage& write_packe dbgln("supported signatures: "); print_buffer(buffer.slice(res, extension_length)); // FIXME: what are we supposed to do here? + } else { + dbgln("Encountered unknown extension {} with length {}", (u16)extension_type, extension_length); } res += extension_length; + } else { + // Zero-length extensions. + dbgln("Encountered unknown extension {} with length {}", (u16)extension_type, extension_length); } } @@ -268,7 +274,8 @@ void TLSv12::build_random(PacketBuilder& builder) m_context.premaster_key = ByteBuffer::copy(random_bytes, bytes); - const auto& certificate_option = verify_chain_and_get_matching_certificate(m_context.SNI); // if the SNI is empty, we'll make a special case and match *a* leaf certificate. + // const auto& certificate_option = verify_chain_and_get_matching_certificate(m_context.extensions.SNI); // if the SNI is empty, we'll make a special case and match *a* leaf certificate. + Optional<size_t> certificate_option = 0; if (!certificate_option.has_value()) { dbgln("certificate verification failed :("); alert(AlertLevel::Critical, AlertDescription::BadCertificate); @@ -520,7 +527,7 @@ ssize_t TLSv12::handle_payload(ReadonlyBytes vbuffer) } if (type != HelloRequest) { - update_hash(buffer.slice(0, payload_size + 1)); + update_hash(buffer.slice(0, payload_size + 1), 0); } // if something went wrong, send an alert about it @@ -655,5 +662,4 @@ ssize_t TLSv12::handle_payload(ReadonlyBytes vbuffer) } return original_length; } - } diff --git a/Userland/Libraries/LibTLS/Exchange.cpp b/Userland/Libraries/LibTLS/Exchange.cpp index 26f71e2697..a723d54d55 100644 --- a/Userland/Libraries/LibTLS/Exchange.cpp +++ b/Userland/Libraries/LibTLS/Exchange.cpp @@ -128,7 +128,8 @@ void TLSv12::pseudorandom_function(Bytes output, ReadonlyBytes secret, const u8* auto label_seed_buffer = Bytes { l_seed, l_seed_size }; label_seed_buffer.overwrite(0, label, label_length); label_seed_buffer.overwrite(label_length, seed.data(), seed.size()); - label_seed_buffer.overwrite(label_length + seed.size(), seed_b.data(), seed_b.size()); + if (seed_b.size() > 0) + label_seed_buffer.overwrite(label_length + seed.size(), seed_b.data(), seed_b.size()); auto digest_size = hmac.digest_size(); @@ -182,7 +183,7 @@ bool TLSv12::compute_master_secret(size_t length) ByteBuffer TLSv12::build_certificate() { - PacketBuilder builder { MessageType::Handshake, m_context.version }; + PacketBuilder builder { MessageType::Handshake, m_context.options.version }; Vector<const Certificate*> certificates; Vector<Certificate>* local_certificates = nullptr; @@ -237,7 +238,7 @@ ByteBuffer TLSv12::build_certificate() ByteBuffer TLSv12::build_change_cipher_spec() { - PacketBuilder builder { MessageType::ChangeCipher, m_context.version, 64 }; + PacketBuilder builder { MessageType::ChangeCipher, m_context.options.version, 64 }; builder.append((u8)1); auto packet = builder.build(); update_packet(packet); @@ -253,7 +254,7 @@ ByteBuffer TLSv12::build_server_key_exchange() ByteBuffer TLSv12::build_client_key_exchange() { - PacketBuilder builder { MessageType::Handshake, m_context.version }; + PacketBuilder builder { MessageType::Handshake, m_context.options.version }; builder.append((u8)HandshakeType::ClientKeyExchange); build_random(builder); diff --git a/Userland/Libraries/LibTLS/Handshake.cpp b/Userland/Libraries/LibTLS/Handshake.cpp index 623c38e963..06ebd61f17 100644 --- a/Userland/Libraries/LibTLS/Handshake.cpp +++ b/Userland/Libraries/LibTLS/Handshake.cpp @@ -35,8 +35,8 @@ ByteBuffer TLSv12::build_hello() { fill_with_random(&m_context.local_random, 32); - auto packet_version = (u16)m_context.version; - auto version = (u16)m_context.version; + auto packet_version = (u16)m_context.options.version; + auto version = (u16)m_context.options.version; PacketBuilder builder { MessageType::Handshake, packet_version }; builder.append((u8)ClientHello); @@ -73,20 +73,18 @@ ByteBuffer TLSv12::build_hello() } // Ciphers - builder.append((u16)(5 * sizeof(u16))); - builder.append((u16)CipherSuite::RSA_WITH_AES_128_CBC_SHA256); - builder.append((u16)CipherSuite::RSA_WITH_AES_256_CBC_SHA256); - builder.append((u16)CipherSuite::RSA_WITH_AES_128_CBC_SHA); - builder.append((u16)CipherSuite::RSA_WITH_AES_256_CBC_SHA); - builder.append((u16)CipherSuite::RSA_WITH_AES_128_GCM_SHA256); + builder.append((u16)(m_context.options.usable_cipher_suites.size() * sizeof(u16))); + for (auto suite : m_context.options.usable_cipher_suites) + builder.append((u16)suite); // we don't like compression + VERIFY(!m_context.options.use_compression); builder.append((u8)1); - builder.append((u8)0); + builder.append((u8)m_context.options.use_compression); - // set SNI if we have one + // set SNI if we have one, and the user hasn't explicitly asked us to omit it. auto sni_length = 0; - if (!m_context.extensions.SNI.is_null()) + if (!m_context.extensions.SNI.is_null() && m_context.options.use_sni) sni_length = m_context.extensions.SNI.length(); if (sni_length) @@ -130,7 +128,7 @@ ByteBuffer TLSv12::build_hello() ByteBuffer TLSv12::build_alert(bool critical, u8 code) { - PacketBuilder builder(MessageType::Alert, (u16)m_context.version); + PacketBuilder builder(MessageType::Alert, (u16)m_context.options.version); builder.append((u8)(critical ? AlertLevel::Critical : AlertLevel::Warning)); builder.append(code); @@ -145,7 +143,7 @@ ByteBuffer TLSv12::build_alert(bool critical, u8 code) ByteBuffer TLSv12::build_finished() { - PacketBuilder builder { MessageType::Handshake, m_context.version, 12 + 64 }; + PacketBuilder builder { MessageType::Handshake, m_context.options.version, 12 + 64 }; builder.append((u8)HandshakeType::Finished); u32 out_size = 12; diff --git a/Userland/Libraries/LibTLS/Record.cpp b/Userland/Libraries/LibTLS/Record.cpp index 3b0eefcf71..91e0349b8f 100644 --- a/Userland/Libraries/LibTLS/Record.cpp +++ b/Userland/Libraries/LibTLS/Record.cpp @@ -61,7 +61,7 @@ void TLSv12::update_packet(ByteBuffer& packet) if (packet[0] == (u8)MessageType::Handshake && packet.size() > header_size) { u8 handshake_type = packet[header_size]; if (handshake_type != HandshakeType::HelloRequest && handshake_type != HandshakeType::HelloVerifyRequest) { - update_hash(packet.bytes().slice(header_size, packet.size() - header_size)); + update_hash(packet.bytes(), header_size); } } if (m_context.cipher_spec_set && m_context.crypto.created) { @@ -190,9 +190,10 @@ void TLSv12::update_packet(ByteBuffer& packet) ++m_context.local_sequence_number; } -void TLSv12::update_hash(ReadonlyBytes message) +void TLSv12::update_hash(ReadonlyBytes message, size_t header_size) { - m_context.handshake_hash.update(message); + dbgln("Update hash with message of size {}", message.size()); + m_context.handshake_hash.update(message.slice(header_size)); } ByteBuffer TLSv12::hmac_message(const ReadonlyBytes& buf, const Optional<ReadonlyBytes> buf2, size_t mac_length, bool local) diff --git a/Userland/Libraries/LibTLS/Socket.cpp b/Userland/Libraries/LibTLS/Socket.cpp index 3e8b712eaa..289b940bff 100644 --- a/Userland/Libraries/LibTLS/Socket.cpp +++ b/Userland/Libraries/LibTLS/Socket.cpp @@ -83,7 +83,7 @@ bool TLSv12::write(ReadonlyBytes buffer) return false; } - PacketBuilder builder { MessageType::ApplicationData, m_context.version, buffer.size() }; + PacketBuilder builder { MessageType::ApplicationData, m_context.options.version, buffer.size() }; builder.append(buffer); auto packet = builder.build(); diff --git a/Userland/Libraries/LibTLS/TLSv12.cpp b/Userland/Libraries/LibTLS/TLSv12.cpp index 7f0abb2e39..3a70aa15bb 100644 --- a/Userland/Libraries/LibTLS/TLSv12.cpp +++ b/Userland/Libraries/LibTLS/TLSv12.cpp @@ -737,6 +737,9 @@ void TLSv12::set_root_certificates(Vector<Certificate> certificates) bool Context::verify_chain() const { + if (!options.validate_certificates) + return true; + const Vector<Certificate>* local_chain = nullptr; if (is_server) { dbgln("Unsupported: Server mode"); @@ -813,10 +816,10 @@ Optional<size_t> TLSv12::verify_chain_and_get_matching_certificate(const StringV return {}; } -TLSv12::TLSv12(Core::Object* parent, Version version) +TLSv12::TLSv12(Core::Object* parent, Options options) : Core::Socket(Core::Socket::Type::TCP, parent) { - m_context.version = version; + m_context.options = move(options); m_context.is_server = false; m_context.tls_buffer = ByteBuffer::create_uninitialized(0); #ifdef SOCK_NONBLOCK diff --git a/Userland/Libraries/LibTLS/TLSv12.h b/Userland/Libraries/LibTLS/TLSv12.h index 4e3562008f..235d2bb98b 100644 --- a/Userland/Libraries/LibTLS/TLSv12.h +++ b/Userland/Libraries/LibTLS/TLSv12.h @@ -195,6 +195,27 @@ enum ClientVerificationStaus { VerificationNeeded, }; +struct Options { +#define OPTION_WITH_DEFAULTS(typ, name, ...) \ + static typ default_##name() { return typ { __VA_ARGS__ }; } \ + typ name = default_##name(); + + OPTION_WITH_DEFAULTS(Vector<CipherSuite>, usable_cipher_suites, + CipherSuite::RSA_WITH_AES_128_CBC_SHA256, + CipherSuite::RSA_WITH_AES_256_CBC_SHA256, + CipherSuite::RSA_WITH_AES_128_CBC_SHA, + CipherSuite::RSA_WITH_AES_256_CBC_SHA, + CipherSuite::RSA_WITH_AES_128_GCM_SHA256) + + OPTION_WITH_DEFAULTS(Version, version, Version::V12) + + OPTION_WITH_DEFAULTS(bool, use_sni, true) + OPTION_WITH_DEFAULTS(bool, use_compression, false) + OPTION_WITH_DEFAULTS(bool, validate_certificates, true) + +#undef OPTION_WITH_DEFAULTS +}; + struct Context { String to_string() const; bool verify() const; @@ -202,12 +223,13 @@ struct Context { static void print_file(const StringView& fname); + Options options; + u8 remote_random[32]; u8 local_random[32]; u8 session_id[32]; u8 session_id_size { 0 }; CipherSuite cipher; - Version version; bool is_server { false }; Vector<Certificate> certificates; Certificate private_key; @@ -334,7 +356,7 @@ public: Function<void(TLSv12&)> on_tls_certificate_request; private: - explicit TLSv12(Core::Object* parent, Version version = Version::V12); + explicit TLSv12(Core::Object* parent, Options = {}); virtual bool common_connect(const struct sockaddr*, socklen_t) override; @@ -344,7 +366,7 @@ private: void ensure_hmac(size_t digest_size, bool local); void update_packet(ByteBuffer& packet); - void update_hash(ReadonlyBytes in); + void update_hash(ReadonlyBytes in, size_t header_size); void write_packet(ByteBuffer& packet); |