LibIPC+LibC: Add and use a helper to encode/decoder container sizes

While refactoring the IPC encoders and decoders for fallibility, the
inconsistency in which we transfer container sizes was a frequent thing
to trip over. We currently transfer sizes as any of i32, u32, and u64.
This adds a helper to transfer sizes in one consistent way.

Two special cases here are DeprecatedString and Vector, whose encoding
is depended upon by netdb, so that is also updated here.
This commit is contained in:
Timothy Flynn 2023-01-04 07:08:29 -05:00 committed by Linus Groh
parent 40165f5846
commit 7c6b5ed161
Notes: sideshowbarker 2024-07-17 02:11:37 +09:00
5 changed files with 48 additions and 35 deletions

View file

@ -122,20 +122,19 @@ hostent* gethostbyname(char const* name)
close(fd);
});
size_t unsigned_name_length = strlen(name);
VERIFY(unsigned_name_length <= NumericLimits<i32>::max());
i32 name_length = static_cast<i32>(unsigned_name_length);
auto name_length = strlen(name);
VERIFY(name_length <= NumericLimits<i32>::max());
struct [[gnu::packed]] {
u32 message_size;
u32 endpoint_magic;
i32 message_id;
i32 name_length;
u32 name_length;
} request_header = {
(u32)(sizeof(request_header) - sizeof(request_header.message_size) + name_length),
lookup_server_endpoint_magic,
1,
name_length,
static_cast<u32>(name_length),
};
if (auto nsent = write(fd, &request_header, sizeof(request_header)); nsent < 0) {
h_errno = TRY_AGAIN;
@ -148,7 +147,7 @@ hostent* gethostbyname(char const* name)
if (auto nsent = write(fd, name, name_length); nsent < 0) {
h_errno = TRY_AGAIN;
return nullptr;
} else if (nsent != name_length) {
} else if (static_cast<size_t>(nsent) != name_length) {
h_errno = NO_RECOVERY;
return nullptr;
}
@ -158,7 +157,7 @@ hostent* gethostbyname(char const* name)
u32 endpoint_magic;
i32 message_id;
i32 code;
u64 addresses_count;
u32 addresses_count;
} response_header;
if (auto nreceived = read(fd, &response_header, sizeof(response_header)); nreceived < 0) {
@ -270,7 +269,7 @@ hostent* gethostbyaddr(void const* addr, socklen_t addr_size, int type)
u32 endpoint_magic;
i32 message_id;
i32 code;
i32 name_length;
u32 name_length;
} response_header;
if (auto nreceived = read(fd, &response_header, sizeof(response_header)); nreceived < 0) {
@ -293,7 +292,7 @@ hostent* gethostbyaddr(void const* addr, socklen_t addr_size, int type)
if (auto nreceived = read(fd, buffer, response_header.name_length); nreceived < 0) {
h_errno = TRY_AGAIN;
return nullptr;
} else if (nreceived != response_header.name_length) {
} else if (static_cast<u32>(nreceived) != response_header.name_length) {
h_errno = NO_RECOVERY;
return nullptr;
}

View file

@ -5,6 +5,7 @@
*/
#include <AK/JsonValue.h>
#include <AK/NumericLimits.h>
#include <AK/URL.h>
#include <LibCore/AnonymousBuffer.h>
#include <LibCore/DateTime.h>
@ -16,19 +17,24 @@
namespace IPC {
ErrorOr<size_t> Decoder::decode_size()
{
return static_cast<size_t>(TRY(decode<u32>()));
}
template<>
ErrorOr<DeprecatedString> decode(Decoder& decoder)
{
auto length = TRY(decoder.decode<i32>());
if (length < 0)
auto length = TRY(decoder.decode_size());
if (length == NumericLimits<u32>::max())
return DeprecatedString {};
if (length == 0)
return DeprecatedString::empty();
char* text_buffer = nullptr;
auto text_impl = StringImpl::create_uninitialized(static_cast<size_t>(length), text_buffer);
auto text_impl = StringImpl::create_uninitialized(length, text_buffer);
Bytes bytes { text_buffer, static_cast<size_t>(length) };
Bytes bytes { text_buffer, length };
TRY(decoder.decode_into(bytes));
return DeprecatedString { *text_impl };
@ -37,8 +43,8 @@ ErrorOr<DeprecatedString> decode(Decoder& decoder)
template<>
ErrorOr<ByteBuffer> decode(Decoder& decoder)
{
auto length = TRY(decoder.decode<i32>());
if (length <= 0)
auto length = TRY(decoder.decode_size());
if (length == 0)
return ByteBuffer {};
auto buffer = TRY(ByteBuffer::create_uninitialized(length));
@ -65,10 +71,7 @@ ErrorOr<URL> decode(Decoder& decoder)
template<>
ErrorOr<Dictionary> decode(Decoder& decoder)
{
auto size = TRY(decoder.decode<u64>());
if (size >= NumericLimits<i32>::max())
VERIFY_NOT_REACHED();
auto size = TRY(decoder.decode_size());
Dictionary dictionary {};
for (size_t i = 0; i < size; ++i) {
@ -99,7 +102,7 @@ ErrorOr<Core::AnonymousBuffer> decode(Decoder& decoder)
if (auto valid = TRY(decoder.decode<bool>()); !valid)
return Core::AnonymousBuffer {};
auto size = TRY(decoder.decode<u32>());
auto size = TRY(decoder.decode_size());
auto anon_file = TRY(decoder.decode<IPC::File>());
return Core::AnonymousBuffer::create_from_anon_fd(anon_file.take_fd(), size);

View file

@ -50,6 +50,8 @@ public:
return {};
}
ErrorOr<size_t> decode_size();
Core::Stream::LocalSocket& socket() { return m_socket; }
private:
@ -96,11 +98,9 @@ ErrorOr<Empty> decode(Decoder&);
template<Concepts::Vector T>
ErrorOr<T> decode(Decoder& decoder)
{
auto size = TRY(decoder.decode<u64>());
if (size > NumericLimits<i32>::max())
return Error::from_string_literal("IPC: Invalid Vector size");
T vector;
auto size = TRY(decoder.decode_size());
TRY(vector.try_ensure_capacity(size));
for (size_t i = 0; i < size; ++i) {
@ -114,12 +114,11 @@ ErrorOr<T> decode(Decoder& decoder)
template<Concepts::HashMap T>
ErrorOr<T> decode(Decoder& decoder)
{
auto size = TRY(decoder.decode<u32>());
if (size > NumericLimits<i32>::max())
return Error::from_string_literal("IPC: Invalid HashMap size");
T hashmap;
auto size = TRY(decoder.decode_size());
TRY(hashmap.try_ensure_capacity(size));
for (size_t i = 0; i < size; ++i) {
auto key = TRY(decoder.decode<typename T::KeyType>());
auto value = TRY(decoder.decode<typename T::ValueType>());

View file

@ -10,6 +10,7 @@
#include <AK/DeprecatedString.h>
#include <AK/JsonObject.h>
#include <AK/JsonValue.h>
#include <AK/NumericLimits.h>
#include <AK/URL.h>
#include <LibCore/AnonymousBuffer.h>
#include <LibCore/DateTime.h>
@ -21,6 +22,13 @@
namespace IPC {
ErrorOr<void> Encoder::encode_size(size_t size)
{
if (static_cast<u64>(size) > static_cast<u64>(NumericLimits<u32>::max()))
return Error::from_string_literal("Container exceeds the maximum allowed size");
return encode(static_cast<u32>(size));
}
template<>
ErrorOr<void> encode(Encoder& encoder, float const& value)
{
@ -43,10 +51,11 @@ ErrorOr<void> encode(Encoder& encoder, StringView const& value)
template<>
ErrorOr<void> encode(Encoder& encoder, DeprecatedString const& value)
{
// NOTE: Do not change this encoding without also updating LibC/netdb.cpp.
if (value.is_null())
return encoder.encode(-1);
return encoder.encode(NumericLimits<u32>::max());
TRY(encoder.encode(static_cast<i32>(value.length())));
TRY(encoder.encode_size(value.length()));
TRY(encoder.encode(value.view()));
return {};
}
@ -54,7 +63,7 @@ ErrorOr<void> encode(Encoder& encoder, DeprecatedString const& value)
template<>
ErrorOr<void> encode(Encoder& encoder, ByteBuffer const& value)
{
TRY(encoder.encode(static_cast<i32>(value.size())));
TRY(encoder.encode_size(value.size()));
TRY(encoder.append(value.data(), value.size()));
return {};
}
@ -74,7 +83,7 @@ ErrorOr<void> encode(Encoder& encoder, URL const& value)
template<>
ErrorOr<void> encode(Encoder& encoder, Dictionary const& dictionary)
{
TRY(encoder.encode(static_cast<u64>(dictionary.size())));
TRY(encoder.encode_size(dictionary.size()));
TRY(dictionary.try_for_each_entry([&](auto const& key, auto const& value) -> ErrorOr<void> {
TRY(encoder.encode(key));
@ -109,7 +118,7 @@ ErrorOr<void> encode(Encoder& encoder, Core::AnonymousBuffer const& buffer)
TRY(encoder.encode(buffer.is_valid()));
if (buffer.is_valid()) {
TRY(encoder.encode(static_cast<u32>(buffer.size())));
TRY(encoder.encode_size(buffer.size()));
TRY(encoder.encode(IPC::File { buffer.fd() }));
}

View file

@ -57,6 +57,8 @@ public:
return m_buffer.fds.try_append(move(auto_fd));
}
ErrorOr<void> encode_size(size_t size);
private:
void encode_u32(u32);
void encode_u64(u64);
@ -134,7 +136,8 @@ ErrorOr<void> encode(Encoder&, Empty const&);
template<Concepts::Vector T>
ErrorOr<void> encode(Encoder& encoder, T const& vector)
{
TRY(encoder.encode(static_cast<u64>(vector.size())));
// NOTE: Do not change this encoding without also updating LibC/netdb.cpp.
TRY(encoder.encode_size(vector.size()));
for (auto const& value : vector)
TRY(encoder.encode(value));
@ -145,7 +148,7 @@ ErrorOr<void> encode(Encoder& encoder, T const& vector)
template<Concepts::HashMap T>
ErrorOr<void> encode(Encoder& encoder, T const& hashmap)
{
TRY(encoder.encode(static_cast<u32>(hashmap.size())));
TRY(encoder.encode_size(hashmap.size()));
for (auto it : hashmap) {
TRY(encoder.encode(it.key));