diff --git a/AK/IPv4Address.h b/AK/IPv4Address.h index 7c29704be34..e4356d492ac 100644 --- a/AK/IPv4Address.h +++ b/AK/IPv4Address.h @@ -67,6 +67,7 @@ public: } in_addr_t to_in_addr_t() const { return m_data_as_u32; } + u32 to_u32() const { return m_data_as_u32; } bool operator==(const IPv4Address& other) const { return m_data_as_u32 == other.m_data_as_u32; } bool operator!=(const IPv4Address& other) const { return m_data_as_u32 != other.m_data_as_u32; } diff --git a/Kernel/FileSystem/ProcFS.cpp b/Kernel/FileSystem/ProcFS.cpp index 3398921ccfc..fe10c7659b3 100644 --- a/Kernel/FileSystem/ProcFS.cpp +++ b/Kernel/FileSystem/ProcFS.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -46,6 +47,7 @@ enum ProcFileType { FI_Root_uptime, FI_Root_cmdline, FI_Root_netadapters, + FI_Root_net_tcp, FI_Root_self, // symlink FI_Root_sys, // directory __FI_Root_End, @@ -278,6 +280,23 @@ Optional procfs$netadapters(InodeIdentifier) return builder.to_byte_buffer(); } +Optional procfs$net_tcp(InodeIdentifier) +{ + JsonArray json; + TCPSocket::for_each([&json](auto& socket) { + JsonObject obj; + obj.set("local_address", socket->local_address().to_string()); + obj.set("local_port", socket->local_port()); + obj.set("peer_address", socket->peer_address().to_string()); + obj.set("peer_port", socket->peer_port()); + obj.set("state", TCPSocket::to_string(socket->state())); + obj.set("ack_number", socket->ack_number()); + obj.set("sequence_number", socket->sequence_number()); + json.append(obj); + }); + return json.serialized().to_byte_buffer(); +} + Optional procfs$pid_vmo(InodeIdentifier identifier) { auto handle = ProcessInspectionHandle::from_pid(to_pid(identifier)); @@ -1077,6 +1096,7 @@ ProcFS::ProcFS() m_entries[FI_Root_uptime] = { "uptime", FI_Root_uptime, procfs$uptime }; m_entries[FI_Root_cmdline] = { "cmdline", FI_Root_cmdline, procfs$cmdline }; m_entries[FI_Root_netadapters] = { "netadapters", FI_Root_netadapters, procfs$netadapters }; + m_entries[FI_Root_net_tcp] = { "net_tcp", FI_Root_net_tcp, procfs$net_tcp }; m_entries[FI_Root_sys] = { "sys", FI_Root_sys }; m_entries[FI_PID_vm] = { "vm", FI_PID_vm, procfs$pid_vm }; diff --git a/Kernel/Net/IPv4Socket.cpp b/Kernel/Net/IPv4Socket.cpp index a8a319cd113..a97005bb11c 100644 --- a/Kernel/Net/IPv4Socket.cpp +++ b/Kernel/Net/IPv4Socket.cpp @@ -89,6 +89,22 @@ KResult IPv4Socket::bind(const sockaddr* address, socklen_t address_size) return protocol_bind(); } +KResult IPv4Socket::listen(int backlog) +{ + int rc = allocate_local_port_if_needed(); + if (rc < 0) + return KResult(-EADDRINUSE); + + if (m_local_address.to_u32() == 0) + return KResult(-EADDRINUSE); + + set_backlog(backlog); + + kprintf("IPv4Socket{%p} listening with backlog=%d\n", this, backlog); + + return protocol_listen(); +} + KResult IPv4Socket::connect(FileDescription& description, const sockaddr* address, socklen_t address_size, ShouldBlock should_block) { if (address_size != sizeof(sockaddr_in)) @@ -157,6 +173,9 @@ ssize_t IPv4Socket::sendto(FileDescription&, const void* data, size_t data_lengt if (!adapter) return -EHOSTUNREACH; + if (m_local_address.to_u32() == 0) + m_local_address = adapter->ipv4_address(); + int rc = allocate_local_port_if_needed(); if (rc < 0) return rc; diff --git a/Kernel/Net/IPv4Socket.h b/Kernel/Net/IPv4Socket.h index 2a12c6ffb48..81667b99c9f 100644 --- a/Kernel/Net/IPv4Socket.h +++ b/Kernel/Net/IPv4Socket.h @@ -2,10 +2,11 @@ #include #include -#include #include +#include #include #include +#include #include class IPv4SocketHandle; @@ -23,6 +24,7 @@ public: virtual KResult bind(const sockaddr*, socklen_t) override; virtual KResult connect(FileDescription&, const sockaddr*, socklen_t, ShouldBlock = ShouldBlock::Yes) override; + virtual KResult listen(int) override; virtual bool get_local_address(sockaddr*, socklen_t*) override; virtual bool get_peer_address(sockaddr*, socklen_t*) override; virtual void attach(FileDescription&) override; @@ -34,7 +36,7 @@ public: void did_receive(const IPv4Address& peer_address, u16 peer_port, KBuffer&&); - const IPv4Address& local_address() const; + const IPv4Address& local_address() const { return m_local_address; } u16 local_port() const { return m_local_port; } void set_local_port(u16 port) { m_local_port = port; } @@ -42,6 +44,8 @@ public: u16 peer_port() const { return m_peer_port; } void set_peer_port(u16 port) { m_peer_port = port; } + IPv4SocketTuple tuple() const { return IPv4SocketTuple(m_local_address, m_local_port, m_peer_address, m_peer_port); } + protected: IPv4Socket(int type, int protocol); virtual const char* class_name() const override { return "IPv4Socket"; } @@ -49,12 +53,16 @@ protected: int allocate_local_port_if_needed(); virtual KResult protocol_bind() { return KSuccess; } + virtual KResult protocol_listen() { return KSuccess; } virtual int protocol_receive(const KBuffer&, void*, size_t, int) { return -ENOTIMPL; } virtual int protocol_send(const void*, int) { return -ENOTIMPL; } virtual KResult protocol_connect(FileDescription&, ShouldBlock) { return KSuccess; } virtual int protocol_allocate_local_port() { return 0; } virtual bool protocol_is_disconnected() const { return false; } + void set_local_address(IPv4Address address) { m_local_address = address; } + void set_peer_address(IPv4Address address) { m_peer_address = address; } + private: virtual bool is_ipv4() const override { return true; } diff --git a/Kernel/Net/IPv4SocketTuple.h b/Kernel/Net/IPv4SocketTuple.h new file mode 100644 index 00000000000..8e512ba3df7 --- /dev/null +++ b/Kernel/Net/IPv4SocketTuple.h @@ -0,0 +1,63 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +class IPv4SocketTuple { +public: + IPv4SocketTuple(IPv4Address local_address, u16 local_port, IPv4Address peer_address, u16 peer_port) + : m_local_address(local_address) + , m_local_port(local_port) + , m_peer_address(peer_address) + , m_peer_port(peer_port) {}; + + IPv4Address local_address() const { return m_local_address; }; + u16 local_port() const { return m_local_port; }; + IPv4Address peer_address() const { return m_peer_address; }; + u16 peer_port() const { return m_peer_port; }; + + bool operator==(const IPv4SocketTuple other) const + { + return other.local_address() == m_local_address && other.local_port() == m_local_port && other.peer_address() == m_peer_address && other.peer_port() == m_peer_port; + }; + + String to_string() const + { + return String::format( + "%s:%d -> %s:%d", + m_local_address.to_string().characters(), + m_local_port, + m_peer_address.to_string().characters(), + m_peer_port); + } + +private: + IPv4Address m_local_address; + u16 m_local_port { 0 }; + IPv4Address m_peer_address; + u16 m_peer_port { 0 }; +}; + +namespace AK { + +template<> +struct Traits : public GenericTraits { + static unsigned hash(const IPv4SocketTuple& tuple) + { + auto h1 = pair_int_hash(tuple.local_address().to_u32(), tuple.local_port()); + auto h2 = pair_int_hash(tuple.peer_address().to_u32(), tuple.peer_port()); + return pair_int_hash(h1, h2); + } + + static void dump(const IPv4SocketTuple& tuple) + { + kprintf("%s", tuple.to_string().characters()); + } +}; + +} diff --git a/Kernel/Net/LocalSocket.cpp b/Kernel/Net/LocalSocket.cpp index 4eacfe00597..dde13cd2a1b 100644 --- a/Kernel/Net/LocalSocket.cpp +++ b/Kernel/Net/LocalSocket.cpp @@ -114,6 +114,16 @@ KResult LocalSocket::connect(FileDescription& description, const sockaddr* addre return KSuccess; } +KResult LocalSocket::listen(int backlog) +{ + LOCKER(lock()); + if (type() != SOCK_STREAM) + return KResult(-EOPNOTSUPP); + set_backlog(backlog); + kprintf("LocalSocket{%p} listening with backlog=%d\n", this, backlog); + return KSuccess; +} + void LocalSocket::attach(FileDescription& description) { switch (description.socket_role()) { diff --git a/Kernel/Net/LocalSocket.h b/Kernel/Net/LocalSocket.h index 9a651931744..ae82b4f0056 100644 --- a/Kernel/Net/LocalSocket.h +++ b/Kernel/Net/LocalSocket.h @@ -13,6 +13,7 @@ public: // ^Socket virtual KResult bind(const sockaddr*, socklen_t) override; virtual KResult connect(FileDescription&, const sockaddr*, socklen_t, ShouldBlock = ShouldBlock::Yes) override; + virtual KResult listen(int) override; virtual bool get_local_address(sockaddr*, socklen_t*) override; virtual bool get_peer_address(sockaddr*, socklen_t*) override; virtual void attach(FileDescription&) override; diff --git a/Kernel/Net/NetworkTask.cpp b/Kernel/Net/NetworkTask.cpp index 00e0f86918c..8b8325c2219 100644 --- a/Kernel/Net/NetworkTask.cpp +++ b/Kernel/Net/NetworkTask.cpp @@ -14,6 +14,7 @@ #include //#define ETHERNET_DEBUG +//#define ETHERNET_VERY_DEBUG //#define IPV4_DEBUG //#define ICMP_DEBUG //#define UDP_DEBUG @@ -84,6 +85,28 @@ void NetworkTask_main() packet.size()); #endif +#ifdef ETHERNET_VERY_DEBUG + u8* data = packet.data(); + + for (size_t i = 0; i < packet.size(); i++) { + kprintf("%b", data[i]); + + switch (i % 16) { + case 7: + kprintf(" "); + break; + case 15: + kprintf("\n"); + break; + default: + kprintf(" "); + break; + } + } + + kprintf("\n"); +#endif + switch (eth.ether_type()) { case EtherType::ARP: handle_arp(eth, packet.size()); @@ -279,7 +302,7 @@ void handle_tcp(const EthernetFrameHeader& eth, int frame_size) size_t payload_size = ipv4_packet.payload_size() - tcp_packet.header_size(); #ifdef TCP_DEBUG - kprintf("handle_tcp: source=%s:%u, destination=%s:%u seq_no=%u, ack_no=%u, flags=%w (%s %s), window_size=%u, payload_size=%u\n", + kprintf("handle_tcp: source=%s:%u, destination=%s:%u seq_no=%u, ack_no=%u, flags=%w (%s%s%s%s), window_size=%u, payload_size=%u\n", ipv4_packet.source().to_string().characters(), tcp_packet.source_port(), ipv4_packet.destination().to_string().characters(), @@ -287,15 +310,19 @@ void handle_tcp(const EthernetFrameHeader& eth, int frame_size) tcp_packet.sequence_number(), tcp_packet.ack_number(), tcp_packet.flags(), - tcp_packet.has_syn() ? "SYN" : "", - tcp_packet.has_ack() ? "ACK" : "", + tcp_packet.has_syn() ? "SYN " : "", + tcp_packet.has_ack() ? "ACK " : "", + tcp_packet.has_fin() ? "FIN " : "", + tcp_packet.has_rst() ? "RST " : "", tcp_packet.window_size(), payload_size); #endif - auto socket = TCPSocket::from_port(tcp_packet.destination_port()); + IPv4SocketTuple tuple(ipv4_packet.destination(), tcp_packet.destination_port(), ipv4_packet.source(), tcp_packet.source_port()); + + auto socket = TCPSocket::from_tuple(tuple); if (!socket) { - kprintf("handle_tcp: No TCP socket for port %u\n", tcp_packet.destination_port()); + kprintf("handle_tcp: No TCP socket for tuple %s\n", tuple.to_string().characters()); return; } @@ -307,39 +334,168 @@ void handle_tcp(const EthernetFrameHeader& eth, int frame_size) return; } - if (tcp_packet.has_syn() && tcp_packet.has_ack()) { - socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); - socket->send_tcp_packet(TCPFlags::ACK); - socket->set_connected(true); - kprintf("handle_tcp: Connection established!\n"); - socket->set_state(TCPSocket::State::Connected); - return; - } +#ifdef TCP_DEBUG + kprintf("handle_tcp: state=%s\n", TCPSocket::to_string(socket->state())); +#endif - if (tcp_packet.has_fin()) { - kprintf("handle_tcp: Got FIN, payload_size=%u\n", payload_size); + switch (socket->state()) { + case TCPSocket::State::Closed: + kprintf("handle_tcp: unexpected flags in Closed state\n"); + socket->send_tcp_packet(TCPFlags::RST); + socket->set_state(TCPSocket::State::Closed); + kprintf("handle_tcp: Closed -> Closed\n"); + return; + case TCPSocket::State::TimeWait: + kprintf("handle_tcp: unexpected flags in TimeWait state\n"); + socket->send_tcp_packet(TCPFlags::RST); + socket->set_state(TCPSocket::State::Closed); + kprintf("handle_tcp: TimeWait -> Closed\n"); + return; + case TCPSocket::State::Listen: + switch (tcp_packet.flags()) { + case TCPFlags::SYN: + kprintf("handle_tcp: incoming connections not supported\n"); + // socket->send_tcp_packet(TCPFlags::RST); + return; + default: + kprintf("handle_tcp: unexpected flags in Listen state\n"); + // socket->send_tcp_packet(TCPFlags::RST); + return; + } + case TCPSocket::State::SynSent: + switch (tcp_packet.flags()) { + case TCPFlags::SYN: + socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); + socket->send_tcp_packet(TCPFlags::ACK); + socket->set_state(TCPSocket::State::SynReceived); + kprintf("handle_tcp: SynSent -> SynReceived\n"); + return; + case TCPFlags::SYN | TCPFlags::ACK: + socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); + socket->send_tcp_packet(TCPFlags::ACK); + socket->set_state(TCPSocket::State::Established); + socket->set_connected(true); + kprintf("handle_tcp: SynSent -> Established\n"); + return; + default: + kprintf("handle_tcp: unexpected flags in SynSent state\n"); + socket->send_tcp_packet(TCPFlags::RST); + socket->set_state(TCPSocket::State::Closed); + kprintf("handle_tcp: SynSent -> Closed\n"); + return; + } + case TCPSocket::State::SynReceived: + switch (tcp_packet.flags()) { + case TCPFlags::ACK: + socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); + socket->set_state(TCPSocket::State::Established); + socket->set_connected(true); + kprintf("handle_tcp: SynReceived -> Established\n"); + return; + default: + kprintf("handle_tcp: unexpected flags in SynReceived state\n"); + socket->send_tcp_packet(TCPFlags::RST); + socket->set_state(TCPSocket::State::Closed); + kprintf("handle_tcp: SynReceived -> Closed\n"); + return; + } + case TCPSocket::State::CloseWait: + switch (tcp_packet.flags()) { + default: + kprintf("handle_tcp: unexpected flags in CloseWait state\n"); + socket->send_tcp_packet(TCPFlags::RST); + socket->set_state(TCPSocket::State::Closed); + kprintf("handle_tcp: CloseWait -> Closed\n"); + return; + } + case TCPSocket::State::LastAck: + switch (tcp_packet.flags()) { + case TCPFlags::ACK: + socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); + socket->set_state(TCPSocket::State::Closed); + kprintf("handle_tcp: LastAck -> Closed\n"); + return; + default: + kprintf("handle_tcp: unexpected flags in LastAck state\n"); + socket->send_tcp_packet(TCPFlags::RST); + socket->set_state(TCPSocket::State::Closed); + kprintf("handle_tcp: LastAck -> Closed\n"); + return; + } + case TCPSocket::State::FinWait1: + switch (tcp_packet.flags()) { + case TCPFlags::ACK: + socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); + socket->set_state(TCPSocket::State::FinWait2); + kprintf("handle_tcp: FinWait1 -> FinWait2\n"); + return; + case TCPFlags::FIN: + socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); + socket->set_state(TCPSocket::State::Closing); + kprintf("handle_tcp: FinWait1 -> Closing\n"); + return; + default: + kprintf("handle_tcp: unexpected flags in FinWait1 state\n"); + socket->send_tcp_packet(TCPFlags::RST); + socket->set_state(TCPSocket::State::Closed); + kprintf("handle_tcp: FinWait1 -> Closed\n"); + return; + } + case TCPSocket::State::FinWait2: + switch (tcp_packet.flags()) { + case TCPFlags::FIN: + socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); + socket->set_state(TCPSocket::State::TimeWait); + kprintf("handle_tcp: FinWait2 -> TimeWait\n"); + return; + default: + kprintf("handle_tcp: unexpected flags in FinWait2 state\n"); + socket->send_tcp_packet(TCPFlags::RST); + socket->set_state(TCPSocket::State::Closed); + kprintf("handle_tcp: FinWait2 -> Closed\n"); + return; + } + case TCPSocket::State::Closing: + switch (tcp_packet.flags()) { + case TCPFlags::ACK: + socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); + socket->set_state(TCPSocket::State::TimeWait); + kprintf("handle_tcp: Closing -> TimeWait\n"); + return; + default: + kprintf("handle_tcp: unexpected flags in Closing state\n"); + socket->send_tcp_packet(TCPFlags::RST); + socket->set_state(TCPSocket::State::Closed); + kprintf("handle_tcp: Closing -> Closed\n"); + return; + } + case TCPSocket::State::Established: + if (tcp_packet.has_fin()) { + if (payload_size != 0) + socket->did_receive(ipv4_packet.source(), tcp_packet.source_port(), KBuffer::copy(&ipv4_packet, sizeof(IPv4Packet) + ipv4_packet.payload_size())); + + socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); + socket->send_tcp_packet(TCPFlags::ACK); + socket->set_state(TCPSocket::State::CloseWait); + socket->set_connected(false); + kprintf("handle_tcp: Established -> CloseWait\n"); + return; + } + + socket->set_ack_number(tcp_packet.sequence_number() + payload_size); + +#ifdef TCP_DEBUG + kprintf("Got packet with ack_no=%u, seq_no=%u, payload_size=%u, acking it with new ack_no=%u, seq_no=%u\n", + tcp_packet.ack_number(), + tcp_packet.sequence_number(), + payload_size, + socket->ack_number(), + socket->sequence_number()); +#endif + + socket->send_tcp_packet(TCPFlags::ACK); if (payload_size != 0) socket->did_receive(ipv4_packet.source(), tcp_packet.source_port(), KBuffer::copy(&ipv4_packet, sizeof(IPv4Packet) + ipv4_packet.payload_size())); - - socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); - socket->send_tcp_packet(TCPFlags::FIN | TCPFlags::ACK); - socket->set_state(TCPSocket::State::Disconnecting); - socket->set_connected(false); - return; } - - socket->set_ack_number(tcp_packet.sequence_number() + payload_size); -#ifdef TCP_DEBUG - kprintf("Got packet with ack_no=%u, seq_no=%u, payload_size=%u, acking it with new ack_no=%u, seq_no=%u\n", - tcp_packet.ack_number(), - tcp_packet.sequence_number(), - payload_size, - socket->ack_number(), - socket->sequence_number()); -#endif - socket->send_tcp_packet(TCPFlags::ACK); - - if (payload_size != 0) - socket->did_receive(ipv4_packet.source(), tcp_packet.source_port(), KBuffer::copy(&ipv4_packet, sizeof(IPv4Packet) + ipv4_packet.payload_size())); } diff --git a/Kernel/Net/Socket.h b/Kernel/Net/Socket.h index bbfef8e90b7..9094f7dd63e 100644 --- a/Kernel/Net/Socket.h +++ b/Kernel/Net/Socket.h @@ -1,8 +1,8 @@ #pragma once #include -#include #include +#include #include #include #include @@ -35,10 +35,10 @@ public: bool can_accept() const { return !m_pending.is_empty(); } RefPtr accept(); bool is_connected() const { return m_connected; } - KResult listen(int backlog); virtual KResult bind(const sockaddr*, socklen_t) = 0; virtual KResult connect(FileDescription&, const sockaddr*, socklen_t, ShouldBlock) = 0; + virtual KResult listen(int) = 0; virtual bool get_local_address(sockaddr*, socklen_t*) = 0; virtual bool get_peer_address(sockaddr*, socklen_t*) = 0; virtual bool is_local() const { return false; } @@ -73,6 +73,9 @@ protected: void load_receive_deadline(); void load_send_deadline(); + int backlog() const { return m_backlog; } + void set_backlog(int backlog) { m_backlog = backlog; } + virtual const char* class_name() const override { return "Socket"; } private: diff --git a/Kernel/Net/TCP.h b/Kernel/Net/TCP.h index b994b9d47fd..20e0fc97237 100644 --- a/Kernel/Net/TCP.h +++ b/Kernel/Net/TCP.h @@ -39,6 +39,7 @@ public: bool has_syn() const { return flags() & TCPFlags::SYN; } bool has_ack() const { return flags() & TCPFlags::ACK; } bool has_fin() const { return flags() & TCPFlags::FIN; } + bool has_rst() const { return flags() & TCPFlags::RST; } u8 data_offset() const { return (m_flags_and_data_offset & 0xf000) >> 12; } void set_data_offset(u16 data_offset) { m_flags_and_data_offset = (m_flags_and_data_offset & ~0xf000) | data_offset << 12; } diff --git a/Kernel/Net/TCPSocket.cpp b/Kernel/Net/TCPSocket.cpp index a50215c9549..557cbcb4d01 100644 --- a/Kernel/Net/TCPSocket.cpp +++ b/Kernel/Net/TCPSocket.cpp @@ -1,28 +1,35 @@ #include +#include #include #include #include #include -#include #include //#define TCP_SOCKET_DEBUG -Lockable>& TCPSocket::sockets_by_port() +void TCPSocket::for_each(Function callback) { - static Lockable>* s_map; + LOCKER(sockets_by_tuple().lock()); + for (auto& it : sockets_by_tuple().resource()) + callback(it.value); +} + +Lockable>& TCPSocket::sockets_by_tuple() +{ + static Lockable>* s_map; if (!s_map) - s_map = new Lockable>; + s_map = new Lockable>; return *s_map; } -TCPSocketHandle TCPSocket::from_port(u16 port) +TCPSocketHandle TCPSocket::from_tuple(const IPv4SocketTuple& tuple) { RefPtr socket; { - LOCKER(sockets_by_port().lock()); - auto it = sockets_by_port().resource().find(port); - if (it == sockets_by_port().resource().end()) + LOCKER(sockets_by_tuple().lock()); + auto it = sockets_by_tuple().resource().find(tuple); + if (it == sockets_by_tuple().resource().end()) return {}; socket = (*it).value; ASSERT(socket); @@ -30,6 +37,11 @@ TCPSocketHandle TCPSocket::from_port(u16 port) return { move(socket) }; } +TCPSocketHandle TCPSocket::from_endpoints(const IPv4Address& local_address, u16 local_port, const IPv4Address& peer_address, u16 peer_port) +{ + return from_tuple(IPv4SocketTuple(local_address, local_port, peer_address, peer_port)); +} + TCPSocket::TCPSocket(int protocol) : IPv4Socket(SOCK_STREAM, protocol) { @@ -37,8 +49,8 @@ TCPSocket::TCPSocket(int protocol) TCPSocket::~TCPSocket() { - LOCKER(sockets_by_port().lock()); - sockets_by_port().resource().remove(local_port()); + LOCKER(sockets_by_tuple().lock()); + sockets_by_tuple().resource().remove(tuple()); } NonnullRefPtr TCPSocket::create(int protocol) @@ -62,18 +74,13 @@ int TCPSocket::protocol_receive(const KBuffer& packet_buffer, void* buffer, size int TCPSocket::protocol_send(const void* data, int data_length) { - auto* adapter = adapter_for_route_to(peer_address()); - if (!adapter) - return -EHOSTUNREACH; send_tcp_packet(TCPFlags::PUSH | TCPFlags::ACK, data, data_length); return data_length; } void TCPSocket::send_tcp_packet(u16 flags, const void* payload, int payload_size) { - // FIXME: Maybe the socket should be bound to an adapter instead of looking it up every time? - auto* adapter = adapter_for_route_to(peer_address()); - ASSERT(adapter); + ASSERT(m_adapter); auto buffer = ByteBuffer::create_zeroed(sizeof(TCPPacket) + payload_size); auto& tcp_packet = *(TCPPacket*)(buffer.pointer()); @@ -95,19 +102,21 @@ void TCPSocket::send_tcp_packet(u16 flags, const void* payload, int payload_size } memcpy(tcp_packet.payload(), payload, payload_size); - tcp_packet.set_checksum(compute_tcp_checksum(adapter->ipv4_address(), peer_address(), tcp_packet, payload_size)); + tcp_packet.set_checksum(compute_tcp_checksum(local_address(), peer_address(), tcp_packet, payload_size)); #ifdef TCP_SOCKET_DEBUG - kprintf("sending tcp packet from %s:%u to %s:%u with (%s %s) seq_no=%u, ack_no=%u\n", - adapter->ipv4_address().to_string().characters(), + kprintf("sending tcp packet from %s:%u to %s:%u with (%s%s%s%s) seq_no=%u, ack_no=%u\n", + local_address().to_string().characters(), local_port(), peer_address().to_string().characters(), peer_port(), tcp_packet.has_syn() ? "SYN" : "", tcp_packet.has_ack() ? "ACK" : "", + tcp_packet.has_fin() ? "FIN" : "", + tcp_packet.has_rst() ? "RST" : "", tcp_packet.sequence_number(), tcp_packet.ack_number()); #endif - adapter->send_ipv4(MACAddress(), peer_address(), IPv4Protocol::TCP, buffer.data(), buffer.size()); + m_adapter->send_ipv4(MACAddress(), peer_address(), IPv4Protocol::TCP, buffer.data(), buffer.size()); } NetworkOrdered TCPSocket::compute_tcp_checksum(const IPv4Address& source, const IPv4Address& destination, const TCPPacket& packet, u16 payload_size) @@ -152,11 +161,36 @@ NetworkOrdered TCPSocket::compute_tcp_checksum(const IPv4Address& source, c return ~(checksum & 0xffff); } +KResult TCPSocket::protocol_bind() +{ + if (!m_adapter) { + m_adapter = NetworkAdapter::from_ipv4_address(local_address()); + if (!m_adapter) + return KResult(-EADDRNOTAVAIL); + } + + return KSuccess; +} + +KResult TCPSocket::protocol_listen() +{ + LOCKER(sockets_by_tuple().lock()); + if (sockets_by_tuple().resource().contains(tuple())) + return KResult(-EADDRINUSE); + sockets_by_tuple().resource().set(tuple(), this); + set_state(State::Listen); + return KSuccess; +} + KResult TCPSocket::protocol_connect(FileDescription& description, ShouldBlock should_block) { - auto* adapter = adapter_for_route_to(peer_address()); - if (!adapter) - return KResult(-EHOSTUNREACH); + if (!m_adapter) { + m_adapter = adapter_for_route_to(peer_address()); + if (!m_adapter) + return KResult(-EHOSTUNREACH); + + set_local_address(m_adapter->ipv4_address()); + } allocate_local_port_if_needed(); @@ -164,7 +198,7 @@ KResult TCPSocket::protocol_connect(FileDescription& description, ShouldBlock sh m_ack_number = 0; send_tcp_packet(TCPFlags::SYN); - m_state = State::Connecting; + m_state = State::SynSent; if (should_block == ShouldBlock::Yes) { if (current->block(description) == Thread::BlockResult::InterruptedBySignal) @@ -183,12 +217,14 @@ int TCPSocket::protocol_allocate_local_port() static const u16 ephemeral_port_range_size = last_ephemeral_port - first_ephemeral_port; u16 first_scan_port = first_ephemeral_port + RandomDevice::random_value() % ephemeral_port_range_size; - LOCKER(sockets_by_port().lock()); + LOCKER(sockets_by_tuple().lock()); for (u16 port = first_scan_port;;) { - auto it = sockets_by_port().resource().find(port); - if (it == sockets_by_port().resource().end()) { + IPv4SocketTuple proposed_tuple(local_address(), port, peer_address(), peer_port()); + + auto it = sockets_by_tuple().resource().find(proposed_tuple); + if (it == sockets_by_tuple().resource().end()) { set_local_port(port); - sockets_by_port().resource().set(port, this); + sockets_by_tuple().resource().set(proposed_tuple, this); return port; } ++port; @@ -202,14 +238,16 @@ int TCPSocket::protocol_allocate_local_port() bool TCPSocket::protocol_is_disconnected() const { - return m_state == State::Disconnecting || m_state == State::Disconnected; -} - -KResult TCPSocket::protocol_bind() -{ - LOCKER(sockets_by_port().lock()); - if (sockets_by_port().resource().contains(local_port())) - return KResult(-EADDRINUSE); - sockets_by_port().resource().set(local_port(), this); - return KSuccess; + switch (m_state) { + case State::Closed: + case State::CloseWait: + case State::LastAck: + case State::FinWait1: + case State::FinWait2: + case State::Closing: + case State::TimeWait: + return true; + default: + return false; + } } diff --git a/Kernel/Net/TCPSocket.h b/Kernel/Net/TCPSocket.h index 4a5775bbc6c..f310fbb01eb 100644 --- a/Kernel/Net/TCPSocket.h +++ b/Kernel/Net/TCPSocket.h @@ -1,19 +1,58 @@ #pragma once +#include #include class TCPSocket final : public IPv4Socket { public: + static void for_each(Function); static NonnullRefPtr create(int protocol); virtual ~TCPSocket() override; enum class State { - Disconnected, - Connecting, - Connected, - Disconnecting, + Closed, + Listen, + SynSent, + SynReceived, + Established, + CloseWait, + LastAck, + FinWait1, + FinWait2, + Closing, + TimeWait, }; + static const char* to_string(State state) + { + switch (state) { + case State::Closed: + return "Closed"; + case State::Listen: + return "Listen"; + case State::SynSent: + return "SynSent"; + case State::SynReceived: + return "SynReceived"; + case State::Established: + return "Established"; + case State::CloseWait: + return "CloseWait"; + case State::LastAck: + return "LastAck"; + case State::FinWait1: + return "FinWait1"; + case State::FinWait2: + return "FinWait2"; + case State::Closing: + return "Closing"; + case State::TimeWait: + return "TimeWait"; + default: + return "None"; + } + } + State state() const { return m_state; } void set_state(State state) { m_state = state; } @@ -24,8 +63,9 @@ public: void send_tcp_packet(u16 flags, const void* = nullptr, int = 0); - static Lockable>& sockets_by_port(); - static TCPSocketHandle from_port(u16); + static Lockable>& sockets_by_tuple(); + static TCPSocketHandle from_tuple(const IPv4SocketTuple& tuple); + static TCPSocketHandle from_endpoints(const IPv4Address& local_address, u16 local_port, const IPv4Address& peer_address, u16 peer_port); private: explicit TCPSocket(int protocol); @@ -39,10 +79,12 @@ private: virtual int protocol_allocate_local_port() override; virtual bool protocol_is_disconnected() const override; virtual KResult protocol_bind() override; + virtual KResult protocol_listen() override; + NetworkAdapter* m_adapter { nullptr }; u32 m_sequence_number { 0 }; u32 m_ack_number { 0 }; - State m_state { State::Disconnected }; + State m_state { State::Closed }; }; class TCPSocketHandle : public SocketHandle {