From a60ecea16abe62aae988ba877fdb98466d2919d3 Mon Sep 17 00:00:00 2001 From: Ali Mohammad Pur Date: Thu, 22 Aug 2024 01:13:37 +0200 Subject: [PATCH] LibWasm+LibWeb+test-wasm: Refcount Wasm::Module for function references Prior to funcref, a partial chunk of an invalid module was never needed, but funcref allows a partially instantiated module to modify imported tables with references to its own functions, which means we need to keep the second module alive while that function reference is present within the imported table. This was tested by the spectests, but very rarely caught as our GC does not behave particularly predictably, making it so the offending module remains in memory just long enough to let the tests pass. This commit makes it so all function references keep their respective modules alive. --- Tests/LibWasm/test-wasm.cpp | 10 ++++---- .../AbstractMachine/AbstractMachine.cpp | 18 +++++++++++---- .../LibWasm/AbstractMachine/AbstractMachine.h | 23 +++++++++++-------- .../AbstractMachine/BytecodeInterpreter.cpp | 2 +- Userland/Libraries/LibWasm/Parser/Parser.cpp | 8 ++++--- Userland/Libraries/LibWasm/Types.h | 6 +++-- .../LibWeb/WebAssembly/WebAssembly.cpp | 2 +- .../LibWeb/WebAssembly/WebAssembly.h | 4 ++-- Userland/Utilities/wasm.cpp | 18 +++++++-------- 9 files changed, 54 insertions(+), 37 deletions(-) diff --git a/Tests/LibWasm/test-wasm.cpp b/Tests/LibWasm/test-wasm.cpp index 7560cac1155..6a601fdedfd 100644 --- a/Tests/LibWasm/test-wasm.cpp +++ b/Tests/LibWasm/test-wasm.cpp @@ -53,7 +53,7 @@ public: Wasm::Module& module() { return *m_module; } Wasm::ModuleInstance& module_instance() { return *m_module_instance; } - static JS::ThrowCompletionOr create(JS::Realm& realm, Wasm::Module module, HashMap const& imports) + static JS::ThrowCompletionOr create(JS::Realm& realm, NonnullRefPtr module, HashMap const& imports) { auto& vm = realm.vm(); auto instance = realm.heap().allocate(realm, realm.intrinsics().object_prototype()); @@ -148,7 +148,7 @@ private: static HashMap s_spec_test_namespace; static Wasm::AbstractMachine m_machine; - Optional m_module; + RefPtr m_module; OwnPtr m_module_instance; }; @@ -379,13 +379,15 @@ JS_DEFINE_NATIVE_FUNCTION(WebAssemblyModule::wasm_invoke) arguments.append(Wasm::Value(bits)); break; } - case Wasm::ValueType::Kind::FunctionReference: + case Wasm::ValueType::Kind::FunctionReference: { if (argument.is_null()) { arguments.append(Wasm::Value(Wasm::Reference { Wasm::Reference::Null { Wasm::ValueType(Wasm::ValueType::Kind::FunctionReference) } })); break; } - arguments.append(Wasm::Value(Wasm::Reference { Wasm::Reference::Func { static_cast(double_value) } })); + Wasm::FunctionAddress addr = static_cast(double_value); + arguments.append(Wasm::Value(Wasm::Reference { Wasm::Reference::Func { addr, machine().store().get_module_for(addr) } })); break; + } case Wasm::ValueType::Kind::ExternReference: if (argument.is_null()) { arguments.append(Wasm::Value(Wasm::Reference { Wasm::Reference::Null { Wasm::ValueType(Wasm::ValueType::Kind::ExternReference) } })); diff --git a/Userland/Libraries/LibWasm/AbstractMachine/AbstractMachine.cpp b/Userland/Libraries/LibWasm/AbstractMachine/AbstractMachine.cpp index cecbce36246..142e8859454 100644 --- a/Userland/Libraries/LibWasm/AbstractMachine/AbstractMachine.cpp +++ b/Userland/Libraries/LibWasm/AbstractMachine/AbstractMachine.cpp @@ -14,14 +14,14 @@ namespace Wasm { -Optional Store::allocate(ModuleInstance& module, CodeSection::Code const& code, TypeIndex type_index) +Optional Store::allocate(ModuleInstance& instance, Module const& module, CodeSection::Code const& code, TypeIndex type_index) { FunctionAddress address { m_functions.size() }; - if (type_index.value() > module.types().size()) + if (type_index.value() > instance.types().size()) return {}; - auto& type = module.types()[type_index.value()]; - m_functions.empend(WasmFunction { type, module, code }); + auto& type = instance.types()[type_index.value()]; + m_functions.empend(WasmFunction { type, instance, module, code }); return address; } @@ -84,6 +84,14 @@ FunctionInstance* Store::get(FunctionAddress address) return &m_functions[value]; } +Module const* Store::get_module_for(Wasm::FunctionAddress address) +{ + auto* function = get(address); + if (!function || function->has()) + return nullptr; + return function->get().module_ref().ptr(); +} + TableInstance* Store::get(TableAddress address) { auto value = address.value(); @@ -223,7 +231,7 @@ InstantiationResult AbstractMachine::instantiate(Module const& module, Vector source_module; // null if host function. }; struct Extern { ExternAddress address; @@ -139,7 +140,7 @@ public: // 2: null funcref // 3: null externref ref.ref().visit( - [&](Reference::Func const& func) { m_value = u128(bit_cast(func.address), 0); }, + [&](Reference::Func const& func) { m_value = u128(bit_cast(func.address), bit_cast(func.source_module.ptr())); }, [&](Reference::Extern const& func) { m_value = u128(bit_cast(func.address), 1); }, [&](Reference::Null const& null) { m_value = u128(0, null.type.kind() == ValueType::Kind::FunctionReference ? 2 : 3); }); } @@ -177,17 +178,15 @@ public: return bit_cast(m_value.low()); } if constexpr (IsSame) { - switch (m_value.high()) { + switch (m_value.high() & 3) { case 0: - return Reference { Reference::Func { bit_cast(m_value.low()) } }; + return Reference { Reference::Func { bit_cast(m_value.low()), bit_cast(m_value.high()) } }; case 1: return Reference { Reference::Extern { bit_cast(m_value.low()) } }; case 2: return Reference { Reference::Null { ValueType(ValueType::Kind::FunctionReference) } }; case 3: return Reference { Reference::Null { ValueType(ValueType::Kind::ExternReference) } }; - default: - VERIFY_NOT_REACHED(); } } VERIFY_NOT_REACHED(); @@ -341,20 +340,23 @@ private: class WasmFunction { public: - explicit WasmFunction(FunctionType const& type, ModuleInstance const& module, CodeSection::Code const& code) + explicit WasmFunction(FunctionType const& type, ModuleInstance const& instance, Module const& module, CodeSection::Code const& code) : m_type(type) - , m_module(module) + , m_module(module.make_weak_ptr()) + , m_module_instance(instance) , m_code(code) { } auto& type() const { return m_type; } - auto& module() const { return m_module; } + auto& module() const { return m_module_instance; } auto& code() const { return m_code; } + RefPtr module_ref() const { return m_module.strong_ref(); } private: FunctionType m_type; - ModuleInstance const& m_module; + WeakPtr m_module; + ModuleInstance const& m_module_instance; CodeSection::Code const& m_code; }; @@ -554,7 +556,7 @@ class Store { public: Store() = default; - Optional allocate(ModuleInstance&, CodeSection::Code const&, TypeIndex); + Optional allocate(ModuleInstance&, Module const&, CodeSection::Code const&, TypeIndex); Optional allocate(HostFunction&&); Optional allocate(TableType const&); Optional allocate(MemoryType const&); @@ -562,6 +564,7 @@ public: Optional allocate(GlobalType const&, Value); Optional allocate(ValueType const&, Vector); + Module const* get_module_for(FunctionAddress); FunctionInstance* get(FunctionAddress); TableInstance* get(TableAddress); MemoryInstance* get(MemoryAddress); diff --git a/Userland/Libraries/LibWasm/AbstractMachine/BytecodeInterpreter.cpp b/Userland/Libraries/LibWasm/AbstractMachine/BytecodeInterpreter.cpp index 9e441cb0f94..7e61accaaac 100644 --- a/Userland/Libraries/LibWasm/AbstractMachine/BytecodeInterpreter.cpp +++ b/Userland/Libraries/LibWasm/AbstractMachine/BytecodeInterpreter.cpp @@ -864,7 +864,7 @@ ALWAYS_INLINE void BytecodeInterpreter::interpret_instruction(Configuration& con auto index = instruction.arguments().get().value(); auto& functions = configuration.frame().module().functions(); auto address = functions[index]; - configuration.value_stack().append(Value(address.value())); + configuration.value_stack().append(Value(Reference { Reference::Func { address, configuration.store().get_module_for(address) } })); return; } case Instructions::ref_is_null.value(): { diff --git a/Userland/Libraries/LibWasm/Parser/Parser.cpp b/Userland/Libraries/LibWasm/Parser/Parser.cpp index 0a7831618ef..6a1d30fbe81 100644 --- a/Userland/Libraries/LibWasm/Parser/Parser.cpp +++ b/Userland/Libraries/LibWasm/Parser/Parser.cpp @@ -1248,7 +1248,7 @@ ParseResult SectionId::parse(Stream& stream) } } -ParseResult Module::parse(Stream& stream) +ParseResult> Module::parse(Stream& stream) { ScopeLogger logger("Module"sv); u8 buf[4]; @@ -1263,7 +1263,9 @@ ParseResult Module::parse(Stream& stream) return with_eof_check(stream, ParseError::InvalidModuleVersion); auto last_section_id = SectionId::SectionIdKind::Custom; - Module module; + auto module_ptr = make_ref_counted(); + auto& module = *module_ptr; + while (!stream.is_eof()) { auto section_id = TRY(SectionId::parse(stream)); size_t section_size = TRY_READ(stream, LEB128, ParseError::ExpectedSize); @@ -1324,7 +1326,7 @@ ParseResult Module::parse(Stream& stream) return ParseError::SectionSizeMismatch; } - return module; + return module_ptr; } ByteString parse_error_to_byte_string(ParseError error) diff --git a/Userland/Libraries/LibWasm/Types.h b/Userland/Libraries/LibWasm/Types.h index 3aa21516458..fd557ba5967 100644 --- a/Userland/Libraries/LibWasm/Types.h +++ b/Userland/Libraries/LibWasm/Types.h @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -982,7 +983,8 @@ private: Optional m_count; }; -class Module { +class Module : public RefCounted + , public Weakable { public: enum class ValidationStatus { Unchecked, @@ -1027,7 +1029,7 @@ public: StringView validation_error() const { return *m_validation_error; } void set_validation_error(ByteString error) { m_validation_error = move(error); } - static ParseResult parse(Stream& stream); + static ParseResult> parse(Stream& stream); private: void set_validation_status(ValidationStatus status) { m_validation_status = status; } diff --git a/Userland/Libraries/LibWeb/WebAssembly/WebAssembly.cpp b/Userland/Libraries/LibWeb/WebAssembly/WebAssembly.cpp index bfe89b0bbf6..717847601b5 100644 --- a/Userland/Libraries/LibWeb/WebAssembly/WebAssembly.cpp +++ b/Userland/Libraries/LibWeb/WebAssembly/WebAssembly.cpp @@ -430,7 +430,7 @@ JS::ThrowCompletionOr to_webassembly_value(JS::VM& vm, JS::Value va auto& cache = get_cache(*vm.current_realm()); for (auto& entry : cache.function_instances()) { if (entry.value == &function) - return Wasm::Value { Wasm::Reference { Wasm::Reference::Func { entry.key } } }; + return Wasm::Value { Wasm::Reference { Wasm::Reference::Func { entry.key, cache.abstract_machine().store().get_module_for(entry.key) } } }; } } diff --git a/Userland/Libraries/LibWeb/WebAssembly/WebAssembly.h b/Userland/Libraries/LibWeb/WebAssembly/WebAssembly.h index 41b696513b8..19397b50164 100644 --- a/Userland/Libraries/LibWeb/WebAssembly/WebAssembly.h +++ b/Userland/Libraries/LibWeb/WebAssembly/WebAssembly.h @@ -29,12 +29,12 @@ WebIDL::ExceptionOr instantiate(JS::VM&, Module const& module_object, namespace Detail { struct CompiledWebAssemblyModule : public RefCounted { - explicit CompiledWebAssemblyModule(Wasm::Module&& module) + explicit CompiledWebAssemblyModule(NonnullRefPtr module) : module(move(module)) { } - Wasm::Module module; + NonnullRefPtr module; }; class WebAssemblyCache { diff --git a/Userland/Utilities/wasm.cpp b/Userland/Utilities/wasm.cpp index 246051628e8..1f64387c72b 100644 --- a/Userland/Utilities/wasm.cpp +++ b/Userland/Utilities/wasm.cpp @@ -491,7 +491,7 @@ static bool pre_interpret_hook(Wasm::Configuration& config, Wasm::InstructionPoi } } -static Optional parse(StringView filename) +static RefPtr parse(StringView filename) { auto result = Core::MappedFile::map(filename); if (result.is_error()) { @@ -603,7 +603,7 @@ ErrorOr serenity_main(Main::Arguments arguments) attempt_instantiate = true; auto parse_result = parse(filename); - if (!parse_result.has_value()) + if (parse_result.is_null()) return 1; g_stdout = TRY(Core::File::standard_output()); @@ -611,7 +611,7 @@ ErrorOr serenity_main(Main::Arguments arguments) if (print && !attempt_instantiate) { Wasm::Printer printer(*g_stdout); - printer.print(parse_result.value()); + printer.print(*parse_result); } if (attempt_instantiate) { @@ -653,14 +653,14 @@ ErrorOr serenity_main(Main::Arguments arguments) // First, resolve the linked modules Vector> linked_instances; - Vector linked_modules; + Vector> linked_modules; for (auto& name : modules_to_link_in) { auto parse_result = parse(name); - if (!parse_result.has_value()) { + if (parse_result.is_null()) { warnln("Failed to parse linked module '{}'", name); return 1; } - linked_modules.append(parse_result.release_value()); + linked_modules.append(parse_result.release_nonnull()); Wasm::Linker linker { linked_modules.last() }; for (auto& instance : linked_instances) linker.link(*instance); @@ -678,7 +678,7 @@ ErrorOr serenity_main(Main::Arguments arguments) linked_instances.append(instantiation_result.release_value()); } - Wasm::Linker linker { parse_result.value() }; + Wasm::Linker linker { *parse_result }; for (auto& instance : linked_instances) linker.link(*instance); @@ -704,7 +704,7 @@ ErrorOr serenity_main(Main::Arguments arguments) for (auto& entry : linker.unresolved_imports()) { if (!entry.type.has()) continue; - auto type = parse_result.value().type_section().types()[entry.type.get().value()]; + auto type = parse_result->type_section().types()[entry.type.get().value()]; auto address = machine.store().allocate(Wasm::HostFunction( [name = entry.name, type = type](auto&, auto& arguments) -> Wasm::Result { StringBuilder argument_builder; @@ -744,7 +744,7 @@ ErrorOr serenity_main(Main::Arguments arguments) print_link_error(link_result.error()); return 1; } - auto result = machine.instantiate(parse_result.value(), link_result.release_value()); + auto result = machine.instantiate(*parse_result, link_result.release_value()); if (result.is_error()) { warnln("Module instantiation failed: {}", result.error().error); return 1;