LibWasm+LibWeb+test-wasm: Refcount Wasm::Module for function references

Prior to funcref, a partial chunk of an invalid module was never needed,
but funcref allows a partially instantiated module to modify imported
tables with references to its own functions, which means we need to keep
the second module alive while that function reference is present within
the imported table.
This was tested by the spectests, but very rarely caught as our GC does
not behave particularly predictably, making it so the offending module
remains in memory just long enough to let the tests pass.

This commit makes it so all function references keep their respective
modules alive.
This commit is contained in:
Ali Mohammad Pur 2024-08-22 01:13:37 +02:00 committed by Andreas Kling
parent 5606ce412e
commit a60ecea16a
Notes: github-actions[bot] 2024-08-22 07:37:30 +00:00
9 changed files with 54 additions and 37 deletions

View file

@ -53,7 +53,7 @@ public:
Wasm::Module& module() { return *m_module; }
Wasm::ModuleInstance& module_instance() { return *m_module_instance; }
static JS::ThrowCompletionOr<WebAssemblyModule*> create(JS::Realm& realm, Wasm::Module module, HashMap<Wasm::Linker::Name, Wasm::ExternValue> const& imports)
static JS::ThrowCompletionOr<WebAssemblyModule*> create(JS::Realm& realm, NonnullRefPtr<Wasm::Module> module, HashMap<Wasm::Linker::Name, Wasm::ExternValue> const& imports)
{
auto& vm = realm.vm();
auto instance = realm.heap().allocate<WebAssemblyModule>(realm, realm.intrinsics().object_prototype());
@ -148,7 +148,7 @@ private:
static HashMap<Wasm::Linker::Name, Wasm::ExternValue> s_spec_test_namespace;
static Wasm::AbstractMachine m_machine;
Optional<Wasm::Module> m_module;
RefPtr<Wasm::Module> m_module;
OwnPtr<Wasm::ModuleInstance> m_module_instance;
};
@ -379,13 +379,15 @@ JS_DEFINE_NATIVE_FUNCTION(WebAssemblyModule::wasm_invoke)
arguments.append(Wasm::Value(bits));
break;
}
case Wasm::ValueType::Kind::FunctionReference:
case Wasm::ValueType::Kind::FunctionReference: {
if (argument.is_null()) {
arguments.append(Wasm::Value(Wasm::Reference { Wasm::Reference::Null { Wasm::ValueType(Wasm::ValueType::Kind::FunctionReference) } }));
break;
}
arguments.append(Wasm::Value(Wasm::Reference { Wasm::Reference::Func { static_cast<u64>(double_value) } }));
Wasm::FunctionAddress addr = static_cast<u64>(double_value);
arguments.append(Wasm::Value(Wasm::Reference { Wasm::Reference::Func { addr, machine().store().get_module_for(addr) } }));
break;
}
case Wasm::ValueType::Kind::ExternReference:
if (argument.is_null()) {
arguments.append(Wasm::Value(Wasm::Reference { Wasm::Reference::Null { Wasm::ValueType(Wasm::ValueType::Kind::ExternReference) } }));

View file

@ -14,14 +14,14 @@
namespace Wasm {
Optional<FunctionAddress> Store::allocate(ModuleInstance& module, CodeSection::Code const& code, TypeIndex type_index)
Optional<FunctionAddress> Store::allocate(ModuleInstance& instance, Module const& module, CodeSection::Code const& code, TypeIndex type_index)
{
FunctionAddress address { m_functions.size() };
if (type_index.value() > module.types().size())
if (type_index.value() > instance.types().size())
return {};
auto& type = module.types()[type_index.value()];
m_functions.empend(WasmFunction { type, module, code });
auto& type = instance.types()[type_index.value()];
m_functions.empend(WasmFunction { type, instance, module, code });
return address;
}
@ -84,6 +84,14 @@ FunctionInstance* Store::get(FunctionAddress address)
return &m_functions[value];
}
Module const* Store::get_module_for(Wasm::FunctionAddress address)
{
auto* function = get(address);
if (!function || function->has<HostFunction>())
return nullptr;
return function->get<WasmFunction>().module_ref().ptr();
}
TableInstance* Store::get(TableAddress address)
{
auto value = address.value();
@ -223,7 +231,7 @@ InstantiationResult AbstractMachine::instantiate(Module const& module, Vector<Ex
size_t i = 0;
for (auto& code : module.code_section().functions()) {
auto type_index = module.function_section().types()[i];
auto address = m_store.allocate(main_module_instance, code, type_index);
auto address = m_store.allocate(main_module_instance, module, code, type_index);
VERIFY(address.has_value());
auxiliary_instance.functions().append(*address);
module_functions.append(*address);

View file

@ -51,6 +51,7 @@ public:
};
struct Func {
FunctionAddress address;
RefPtr<Module> source_module; // null if host function.
};
struct Extern {
ExternAddress address;
@ -139,7 +140,7 @@ public:
// 2: null funcref
// 3: null externref
ref.ref().visit(
[&](Reference::Func const& func) { m_value = u128(bit_cast<u64>(func.address), 0); },
[&](Reference::Func const& func) { m_value = u128(bit_cast<u64>(func.address), bit_cast<u64>(func.source_module.ptr())); },
[&](Reference::Extern const& func) { m_value = u128(bit_cast<u64>(func.address), 1); },
[&](Reference::Null const& null) { m_value = u128(0, null.type.kind() == ValueType::Kind::FunctionReference ? 2 : 3); });
}
@ -177,17 +178,15 @@ public:
return bit_cast<f64>(m_value.low());
}
if constexpr (IsSame<T, Reference>) {
switch (m_value.high()) {
switch (m_value.high() & 3) {
case 0:
return Reference { Reference::Func { bit_cast<FunctionAddress>(m_value.low()) } };
return Reference { Reference::Func { bit_cast<FunctionAddress>(m_value.low()), bit_cast<Wasm::Module*>(m_value.high()) } };
case 1:
return Reference { Reference::Extern { bit_cast<ExternAddress>(m_value.low()) } };
case 2:
return Reference { Reference::Null { ValueType(ValueType::Kind::FunctionReference) } };
case 3:
return Reference { Reference::Null { ValueType(ValueType::Kind::ExternReference) } };
default:
VERIFY_NOT_REACHED();
}
}
VERIFY_NOT_REACHED();
@ -341,20 +340,23 @@ private:
class WasmFunction {
public:
explicit WasmFunction(FunctionType const& type, ModuleInstance const& module, CodeSection::Code const& code)
explicit WasmFunction(FunctionType const& type, ModuleInstance const& instance, Module const& module, CodeSection::Code const& code)
: m_type(type)
, m_module(module)
, m_module(module.make_weak_ptr())
, m_module_instance(instance)
, m_code(code)
{
}
auto& type() const { return m_type; }
auto& module() const { return m_module; }
auto& module() const { return m_module_instance; }
auto& code() const { return m_code; }
RefPtr<Module const> module_ref() const { return m_module.strong_ref(); }
private:
FunctionType m_type;
ModuleInstance const& m_module;
WeakPtr<Module const> m_module;
ModuleInstance const& m_module_instance;
CodeSection::Code const& m_code;
};
@ -554,7 +556,7 @@ class Store {
public:
Store() = default;
Optional<FunctionAddress> allocate(ModuleInstance&, CodeSection::Code const&, TypeIndex);
Optional<FunctionAddress> allocate(ModuleInstance&, Module const&, CodeSection::Code const&, TypeIndex);
Optional<FunctionAddress> allocate(HostFunction&&);
Optional<TableAddress> allocate(TableType const&);
Optional<MemoryAddress> allocate(MemoryType const&);
@ -562,6 +564,7 @@ public:
Optional<GlobalAddress> allocate(GlobalType const&, Value);
Optional<ElementAddress> allocate(ValueType const&, Vector<Reference>);
Module const* get_module_for(FunctionAddress);
FunctionInstance* get(FunctionAddress);
TableInstance* get(TableAddress);
MemoryInstance* get(MemoryAddress);

View file

@ -864,7 +864,7 @@ ALWAYS_INLINE void BytecodeInterpreter::interpret_instruction(Configuration& con
auto index = instruction.arguments().get<FunctionIndex>().value();
auto& functions = configuration.frame().module().functions();
auto address = functions[index];
configuration.value_stack().append(Value(address.value()));
configuration.value_stack().append(Value(Reference { Reference::Func { address, configuration.store().get_module_for(address) } }));
return;
}
case Instructions::ref_is_null.value(): {

View file

@ -1248,7 +1248,7 @@ ParseResult<SectionId> SectionId::parse(Stream& stream)
}
}
ParseResult<Module> Module::parse(Stream& stream)
ParseResult<NonnullRefPtr<Module>> Module::parse(Stream& stream)
{
ScopeLogger<WASM_BINPARSER_DEBUG> logger("Module"sv);
u8 buf[4];
@ -1263,7 +1263,9 @@ ParseResult<Module> Module::parse(Stream& stream)
return with_eof_check(stream, ParseError::InvalidModuleVersion);
auto last_section_id = SectionId::SectionIdKind::Custom;
Module module;
auto module_ptr = make_ref_counted<Module>();
auto& module = *module_ptr;
while (!stream.is_eof()) {
auto section_id = TRY(SectionId::parse(stream));
size_t section_size = TRY_READ(stream, LEB128<u32>, ParseError::ExpectedSize);
@ -1324,7 +1326,7 @@ ParseResult<Module> Module::parse(Stream& stream)
return ParseError::SectionSizeMismatch;
}
return module;
return module_ptr;
}
ByteString parse_error_to_byte_string(ParseError error)

View file

@ -14,6 +14,7 @@
#include <AK/String.h>
#include <AK/UFixedBigInt.h>
#include <AK/Variant.h>
#include <AK/WeakPtr.h>
#include <LibWasm/Constants.h>
#include <LibWasm/Forward.h>
#include <LibWasm/Opcode.h>
@ -982,7 +983,8 @@ private:
Optional<u32> m_count;
};
class Module {
class Module : public RefCounted<Module>
, public Weakable<Module> {
public:
enum class ValidationStatus {
Unchecked,
@ -1027,7 +1029,7 @@ public:
StringView validation_error() const { return *m_validation_error; }
void set_validation_error(ByteString error) { m_validation_error = move(error); }
static ParseResult<Module> parse(Stream& stream);
static ParseResult<NonnullRefPtr<Module>> parse(Stream& stream);
private:
void set_validation_status(ValidationStatus status) { m_validation_status = status; }

View file

@ -430,7 +430,7 @@ JS::ThrowCompletionOr<Wasm::Value> to_webassembly_value(JS::VM& vm, JS::Value va
auto& cache = get_cache(*vm.current_realm());
for (auto& entry : cache.function_instances()) {
if (entry.value == &function)
return Wasm::Value { Wasm::Reference { Wasm::Reference::Func { entry.key } } };
return Wasm::Value { Wasm::Reference { Wasm::Reference::Func { entry.key, cache.abstract_machine().store().get_module_for(entry.key) } } };
}
}

View file

@ -29,12 +29,12 @@ WebIDL::ExceptionOr<JS::Value> instantiate(JS::VM&, Module const& module_object,
namespace Detail {
struct CompiledWebAssemblyModule : public RefCounted<CompiledWebAssemblyModule> {
explicit CompiledWebAssemblyModule(Wasm::Module&& module)
explicit CompiledWebAssemblyModule(NonnullRefPtr<Wasm::Module> module)
: module(move(module))
{
}
Wasm::Module module;
NonnullRefPtr<Wasm::Module> module;
};
class WebAssemblyCache {

View file

@ -491,7 +491,7 @@ static bool pre_interpret_hook(Wasm::Configuration& config, Wasm::InstructionPoi
}
}
static Optional<Wasm::Module> parse(StringView filename)
static RefPtr<Wasm::Module> parse(StringView filename)
{
auto result = Core::MappedFile::map(filename);
if (result.is_error()) {
@ -603,7 +603,7 @@ ErrorOr<int> serenity_main(Main::Arguments arguments)
attempt_instantiate = true;
auto parse_result = parse(filename);
if (!parse_result.has_value())
if (parse_result.is_null())
return 1;
g_stdout = TRY(Core::File::standard_output());
@ -611,7 +611,7 @@ ErrorOr<int> serenity_main(Main::Arguments arguments)
if (print && !attempt_instantiate) {
Wasm::Printer printer(*g_stdout);
printer.print(parse_result.value());
printer.print(*parse_result);
}
if (attempt_instantiate) {
@ -653,14 +653,14 @@ ErrorOr<int> serenity_main(Main::Arguments arguments)
// First, resolve the linked modules
Vector<NonnullOwnPtr<Wasm::ModuleInstance>> linked_instances;
Vector<Wasm::Module> linked_modules;
Vector<NonnullRefPtr<Wasm::Module>> linked_modules;
for (auto& name : modules_to_link_in) {
auto parse_result = parse(name);
if (!parse_result.has_value()) {
if (parse_result.is_null()) {
warnln("Failed to parse linked module '{}'", name);
return 1;
}
linked_modules.append(parse_result.release_value());
linked_modules.append(parse_result.release_nonnull());
Wasm::Linker linker { linked_modules.last() };
for (auto& instance : linked_instances)
linker.link(*instance);
@ -678,7 +678,7 @@ ErrorOr<int> serenity_main(Main::Arguments arguments)
linked_instances.append(instantiation_result.release_value());
}
Wasm::Linker linker { parse_result.value() };
Wasm::Linker linker { *parse_result };
for (auto& instance : linked_instances)
linker.link(*instance);
@ -704,7 +704,7 @@ ErrorOr<int> serenity_main(Main::Arguments arguments)
for (auto& entry : linker.unresolved_imports()) {
if (!entry.type.has<Wasm::TypeIndex>())
continue;
auto type = parse_result.value().type_section().types()[entry.type.get<Wasm::TypeIndex>().value()];
auto type = parse_result->type_section().types()[entry.type.get<Wasm::TypeIndex>().value()];
auto address = machine.store().allocate(Wasm::HostFunction(
[name = entry.name, type = type](auto&, auto& arguments) -> Wasm::Result {
StringBuilder argument_builder;
@ -744,7 +744,7 @@ ErrorOr<int> serenity_main(Main::Arguments arguments)
print_link_error(link_result.error());
return 1;
}
auto result = machine.instantiate(parse_result.value(), link_result.release_value());
auto result = machine.instantiate(*parse_result, link_result.release_value());
if (result.is_error()) {
warnln("Module instantiation failed: {}", result.error().error);
return 1;