/* * Copyright (c) 2018-2020, Andreas Kling * * SPDX-License-Identifier: BSD-2-Clause */ #pragma once #include #include #include #include #include #include #include namespace Kernel { class TCPSocket final : public IPv4Socket { public: static void for_each(Function); static KResultOr> create(int protocol, NonnullOwnPtr receive_buffer); virtual ~TCPSocket() override; enum class Direction { Unspecified, Outgoing, Incoming, Passive, }; static StringView to_string(Direction direction) { switch (direction) { case Direction::Unspecified: return "Unspecified"sv; case Direction::Outgoing: return "Outgoing"sv; case Direction::Incoming: return "Incoming"sv; case Direction::Passive: return "Passive"sv; default: return "None"sv; } } enum class State { Closed, Listen, SynSent, SynReceived, Established, CloseWait, LastAck, FinWait1, FinWait2, Closing, TimeWait, }; static StringView to_string(State state) { switch (state) { case State::Closed: return "Closed"sv; case State::Listen: return "Listen"sv; case State::SynSent: return "SynSent"sv; case State::SynReceived: return "SynReceived"sv; case State::Established: return "Established"sv; case State::CloseWait: return "CloseWait"sv; case State::LastAck: return "LastAck"sv; case State::FinWait1: return "FinWait1"sv; case State::FinWait2: return "FinWait2"sv; case State::Closing: return "Closing"sv; case State::TimeWait: return "TimeWait"sv; default: return "None"; } } enum class Error { None, FINDuringConnect, RSTDuringConnect, UnexpectedFlagsDuringConnect, RetransmitTimeout, }; static StringView to_string(Error error) { switch (error) { case Error::None: return "None"sv; case Error::FINDuringConnect: return "FINDuringConnect"sv; case Error::RSTDuringConnect: return "RSTDuringConnect"sv; case Error::UnexpectedFlagsDuringConnect: return "UnexpectedFlagsDuringConnect"sv; default: return "Invalid"sv; } } State state() const { return m_state; } void set_state(State state); Direction direction() const { return m_direction; } bool has_error() const { return m_error != Error::None; } Error error() const { return m_error; } void set_error(Error error) { m_error = error; } void set_ack_number(u32 n) { m_ack_number = n; } void set_sequence_number(u32 n) { m_sequence_number = n; } u32 ack_number() const { return m_ack_number; } u32 sequence_number() const { return m_sequence_number; } u32 packets_in() const { return m_packets_in; } u32 bytes_in() const { return m_bytes_in; } u32 packets_out() const { return m_packets_out; } u32 bytes_out() const { return m_bytes_out; } // FIXME: Make this configurable? static constexpr u32 maximum_duplicate_acks = 5; void set_duplicate_acks(u32 acks) { m_duplicate_acks = acks; } u32 duplicate_acks() const { return m_duplicate_acks; } KResult send_ack(bool allow_duplicate = false); KResult send_tcp_packet(u16 flags, const UserOrKernelBuffer* = nullptr, size_t = 0, RoutingDecision* = nullptr); void receive_tcp_packet(const TCPPacket&, u16 size); bool should_delay_next_ack() const; static MutexProtected>& sockets_by_tuple(); static RefPtr from_tuple(const IPv4SocketTuple& tuple); static MutexProtected>>& closing_sockets(); RefPtr create_client(const IPv4Address& local_address, u16 local_port, const IPv4Address& peer_address, u16 peer_port); void set_originator(TCPSocket& originator) { m_originator = originator; } bool has_originator() { return !!m_originator; } void release_to_originator(); void release_for_accept(RefPtr); void retransmit_packets(); virtual KResult close() override; virtual bool can_write(const FileDescription&, size_t) const override; static NetworkOrdered compute_tcp_checksum(IPv4Address const& source, IPv4Address const& destination, TCPPacket const&, u16 payload_size); protected: void set_direction(Direction direction) { m_direction = direction; } private: explicit TCPSocket(int protocol, NonnullOwnPtr receive_buffer, OwnPtr scratch_buffer); virtual StringView class_name() const override { return "TCPSocket"; } virtual void shut_down_for_writing() override; virtual KResultOr protocol_receive(ReadonlyBytes raw_ipv4_packet, UserOrKernelBuffer& buffer, size_t buffer_size, int flags) override; virtual KResultOr protocol_send(const UserOrKernelBuffer&, size_t) override; virtual KResult protocol_connect(FileDescription&, ShouldBlock) override; virtual KResultOr protocol_allocate_local_port() override; virtual bool protocol_is_disconnected() const override; virtual KResult protocol_bind() override; virtual KResult protocol_listen(bool did_allocate_port) override; void enqueue_for_retransmit(); void dequeue_for_retransmit(); WeakPtr m_originator; HashMap> m_pending_release_for_accept; Direction m_direction { Direction::Unspecified }; Error m_error { Error::None }; RefPtr m_adapter; u32 m_sequence_number { 0 }; u32 m_ack_number { 0 }; State m_state { State::Closed }; u32 m_packets_in { 0 }; u32 m_bytes_in { 0 }; u32 m_packets_out { 0 }; u32 m_bytes_out { 0 }; struct OutgoingPacket { u32 ack_number { 0 }; RefPtr buffer; size_t ipv4_payload_offset; WeakPtr adapter; int tx_counter { 0 }; }; struct UnackedPackets { SinglyLinkedList packets; size_t size { 0 }; }; MutexProtected m_unacked_packets; u32 m_duplicate_acks { 0 }; u32 m_last_ack_number_sent { 0 }; Time m_last_ack_sent_time; // FIXME: Make this configurable (sysctl) static constexpr u32 maximum_retransmits = 5; Time m_last_retransmit_time; u32 m_retransmit_attempts { 0 }; // FIXME: Parse window size TCP option from the peer u32 m_send_window_size { 64 * KiB }; IntrusiveListNode m_retransmit_list_node; public: using RetransmitList = IntrusiveList, &TCPSocket::m_retransmit_list_node>; static MutexProtected& sockets_for_retransmit(); }; }