diff options
Diffstat (limited to 'third_party/rlbox_wasm2c_sandbox/include')
4 files changed, 1088 insertions, 0 deletions
diff --git a/third_party/rlbox_wasm2c_sandbox/include/rlbox_wasm2c_sandbox.hpp b/third_party/rlbox_wasm2c_sandbox/include/rlbox_wasm2c_sandbox.hpp new file mode 100644 index 0000000000..ed6975e27e --- /dev/null +++ b/third_party/rlbox_wasm2c_sandbox/include/rlbox_wasm2c_sandbox.hpp @@ -0,0 +1,971 @@ +#pragma once + +#include "rlbox_wasm2c_tls.hpp" +#include "wasm-rt.h" +#include "wasm2c_rt_mem.h" +#include "wasm2c_rt_minwasi.h" + +// Pull the helper header from the main repo for dynamic_check and scope_exit +#include "rlbox_helpers.hpp" + +#include <cstdint> +#include <iostream> +#include <limits> +#include <map> +#include <memory> +#include <mutex> +// RLBox allows applications to provide a custom shared lock implementation +#ifndef RLBOX_USE_CUSTOM_SHARED_LOCK +# include <shared_mutex> +#endif +#include <string> +#include <type_traits> +#include <utility> +#include <vector> + +#if defined(_WIN32) +// Ensure the min/max macro in the header doesn't collide with functions in +// std:: +# ifndef NOMINMAX +# define NOMINMAX +# endif +# include <windows.h> +#else +# include <dlfcn.h> +#endif + +#define RLBOX_WASM2C_UNUSED(...) (void)__VA_ARGS__ + +// Use the same convention as rlbox to allow applications to customize the +// shared lock +#ifndef RLBOX_USE_CUSTOM_SHARED_LOCK +# define RLBOX_SHARED_LOCK(name) std::shared_timed_mutex name +# define RLBOX_ACQUIRE_SHARED_GUARD(name, ...) \ + std::shared_lock<std::shared_timed_mutex> name(__VA_ARGS__) +# define RLBOX_ACQUIRE_UNIQUE_GUARD(name, ...) \ + std::unique_lock<std::shared_timed_mutex> name(__VA_ARGS__) +#else +# if !defined(RLBOX_SHARED_LOCK) || !defined(RLBOX_ACQUIRE_SHARED_GUARD) || \ + !defined(RLBOX_ACQUIRE_UNIQUE_GUARD) +# error \ + "RLBOX_USE_CUSTOM_SHARED_LOCK defined but missing definitions for RLBOX_SHARED_LOCK, RLBOX_ACQUIRE_SHARED_GUARD, RLBOX_ACQUIRE_UNIQUE_GUARD" +# endif +#endif + +#define DEFINE_RLBOX_WASM2C_MODULE_TYPE(modname) \ + struct rlbox_wasm2c_module_type_##modname \ + { \ + using instance_t = w2c_##modname; \ + \ + using create_instance_t = void (*)(instance_t*, \ + struct w2c_env*, \ + struct w2c_wasi__snapshot__preview1*); \ + static constexpr create_instance_t create_instance = \ + &wasm2c_##modname##_instantiate; \ + \ + using free_instance_t = void (*)(instance_t*); \ + static constexpr free_instance_t free_instance = &wasm2c_##modname##_free; \ + \ + using get_func_type_t = wasm_rt_func_type_t (*)(uint32_t, uint32_t, ...); \ + static constexpr get_func_type_t get_func_type = \ + &wasm2c_##modname##_get_func_type; \ + \ + static constexpr const uint64_t* initial_memory_pages = \ + &wasm2c_##modname##_min_env_memory; \ + static constexpr const uint8_t* is_memory_64 = \ + &wasm2c_##modname##_is64_env_memory; \ + static constexpr const uint32_t* initial_func_elements = \ + &wasm2c_##modname##_min_env_0x5F_indirect_function_table; \ + \ + static constexpr const char* prefix = #modname; \ + \ + /* A function that returns the address of the func specified as a \ + * constexpr string */ \ + /* Unfortunately, there is no way to implement the below in C++. */ \ + /* Implement this to fully support multiple static modules. */ \ + /* static constexpr void* dlsym_in_w2c_module(const char* func_name) { */ \ + /* return &w2c_##modname##_%func%; */ \ + /* } */ \ + \ + static constexpr auto malloc_address = &w2c_##modname##_malloc; \ + static constexpr auto free_address = &w2c_##modname##_free; \ + } + +// wasm_module_name module name used when compiling with wasm2c +#ifndef RLBOX_WASM2C_MODULE_NAME +# error "Expected definition for RLBOX_WASM2C_MODULE_NAME" +#endif + +// Need an extra macro to expand RLBOX_WASM2C_MODULE_NAME +#define INVOKE_DEFINE_RLBOX_WASM2C_MODULE_TYPE(modname) \ + DEFINE_RLBOX_WASM2C_MODULE_TYPE(modname) + +INVOKE_DEFINE_RLBOX_WASM2C_MODULE_TYPE(RLBOX_WASM2C_MODULE_NAME); + +// Concat after macro expansion +#define RLBOX_WASM2C_CONCAT2(x, y) x##y +#define RLBOX_WASM2C_CONCAT(x, y) RLBOX_WASM2C_CONCAT2(x, y) + +#define RLBOX_WASM_MODULE_TYPE_CURR \ + RLBOX_WASM2C_CONCAT(rlbox_wasm2c_module_type_, RLBOX_WASM2C_MODULE_NAME) + +#define RLBOX_WASM2C_STRINGIFY(x) RLBOX_WASM2C_STRINGIFY2(x) +#define RLBOX_WASM2C_STRINGIFY2(x) #x + +#define RLBOX_WASM2C_MODULE_NAME_STR \ + RLBOX_WASM2C_STRINGIFY(RLBOX_WASM2C_MODULE_NAME) + +#define RLBOX_WASM2C_MODULE_FUNC_HELPER2(part1, part2, part3) \ + part1##part2##part3 +#define RLBOX_WASM2C_MODULE_FUNC_HELPER(part1, part2, part3) \ + RLBOX_WASM2C_MODULE_FUNC_HELPER2(part1, part2, part3) +#define RLBOX_WASM2C_MODULE_FUNC(name) \ + RLBOX_WASM2C_MODULE_FUNC_HELPER(w2c_, RLBOX_WASM2C_MODULE_NAME, name) + +namespace rlbox { + +namespace wasm2c_detail { + + template<typename T> + constexpr bool false_v = false; + + // https://stackoverflow.com/questions/6512019/can-we-get-the-type-of-a-lambda-argument + namespace return_argument_detail { + template<typename Ret, typename... Rest> + Ret helper(Ret (*)(Rest...)); + + template<typename Ret, typename F, typename... Rest> + Ret helper(Ret (F::*)(Rest...)); + + template<typename Ret, typename F, typename... Rest> + Ret helper(Ret (F::*)(Rest...) const); + + template<typename F> + decltype(helper(&F::operator())) helper(F); + } // namespace return_argument_detail + + template<typename T> + using return_argument = + decltype(return_argument_detail::helper(std::declval<T>())); + + /////////////////////////////////////////////////////////////// + + // https://stackoverflow.com/questions/37602057/why-isnt-a-for-loop-a-compile-time-expression + namespace compile_time_for_detail { + template<std::size_t N> + struct num + { + static const constexpr auto value = N; + }; + + template<class F, std::size_t... Is> + inline void compile_time_for_helper(F func, std::index_sequence<Is...>) + { + (func(num<Is>{}), ...); + } + } // namespace compile_time_for_detail + + template<std::size_t N, typename F> + inline void compile_time_for(F func) + { + compile_time_for_detail::compile_time_for_helper( + func, std::make_index_sequence<N>()); + } + + /////////////////////////////////////////////////////////////// + + template<typename T, typename = void> + struct convert_type_to_wasm_type + { + static_assert(std::is_void_v<T>, "Missing specialization"); + using type = void; + // wasm2c has no void type so use i32 for now + static constexpr wasm_rt_type_t wasm2c_type = WASM_RT_I32; + }; + + template<typename T> + struct convert_type_to_wasm_type< + T, + std::enable_if_t<(std::is_integral_v<T> || std::is_enum_v<T>)&&sizeof(T) <= + sizeof(uint32_t)>> + { + using type = uint32_t; + static constexpr wasm_rt_type_t wasm2c_type = WASM_RT_I32; + }; + + template<typename T> + struct convert_type_to_wasm_type< + T, + std::enable_if_t<(std::is_integral_v<T> || + std::is_enum_v<T>)&&sizeof(uint32_t) < sizeof(T) && + sizeof(T) <= sizeof(uint64_t)>> + { + using type = uint64_t; + static constexpr wasm_rt_type_t wasm2c_type = WASM_RT_I64; + }; + + template<typename T> + struct convert_type_to_wasm_type<T, + std::enable_if_t<std::is_same_v<T, float>>> + { + using type = T; + static constexpr wasm_rt_type_t wasm2c_type = WASM_RT_F32; + }; + + template<typename T> + struct convert_type_to_wasm_type<T, + std::enable_if_t<std::is_same_v<T, double>>> + { + using type = T; + static constexpr wasm_rt_type_t wasm2c_type = WASM_RT_F64; + }; + + template<typename T> + struct convert_type_to_wasm_type< + T, + std::enable_if_t<std::is_pointer_v<T> || std::is_class_v<T>>> + { + // pointers are 32 bit indexes in wasm + // class paramters are passed as a pointer to an object in the stack or heap + using type = uint32_t; + static constexpr wasm_rt_type_t wasm2c_type = WASM_RT_I32; + }; + + /////////////////////////////////////////////////////////////// + + namespace prepend_arg_type_detail { + template<typename T, typename T_ArgNew> + struct helper; + + template<typename T_ArgNew, typename T_Ret, typename... T_Args> + struct helper<T_Ret(T_Args...), T_ArgNew> + { + using type = T_Ret(T_ArgNew, T_Args...); + }; + } + + template<typename T_Func, typename T_ArgNew> + using prepend_arg_type = + typename prepend_arg_type_detail::helper<T_Func, T_ArgNew>::type; + + /////////////////////////////////////////////////////////////// + + namespace change_return_type_detail { + template<typename T, typename T_RetNew> + struct helper; + + template<typename T_RetNew, typename T_Ret, typename... T_Args> + struct helper<T_Ret(T_Args...), T_RetNew> + { + using type = T_RetNew(T_Args...); + }; + } + + template<typename T_Func, typename T_RetNew> + using change_return_type = + typename change_return_type_detail::helper<T_Func, T_RetNew>::type; + + /////////////////////////////////////////////////////////////// + + namespace change_class_arg_types_detail { + template<typename T, typename T_ArgNew> + struct helper; + + template<typename T_ArgNew, typename T_Ret, typename... T_Args> + struct helper<T_Ret(T_Args...), T_ArgNew> + { + using type = + T_Ret(std::conditional_t<std::is_class_v<T_Args>, T_ArgNew, T_Args>...); + }; + } + + template<typename T_Func, typename T_ArgNew> + using change_class_arg_types = + typename change_class_arg_types_detail::helper<T_Func, T_ArgNew>::type; + +} // namespace wasm2c_detail + +// declare the static symbol with weak linkage to keep this header only +#if defined(_MSC_VER) +__declspec(selectany) +#else +__attribute__((weak)) +#endif + std::once_flag rlbox_wasm2c_initialized; + +class rlbox_wasm2c_sandbox +{ +public: + using T_LongLongType = int64_t; + using T_LongType = int32_t; + using T_IntType = int32_t; + using T_PointerType = uint32_t; + using T_ShortType = int16_t; + +private: + mutable typename RLBOX_WASM_MODULE_TYPE_CURR::instance_t wasm2c_instance{ 0 }; + struct w2c_env sandbox_memory_env; + struct w2c_wasi__snapshot__preview1 wasi_env; + bool instance_initialized = false; + wasm_rt_memory_t sandbox_memory_info; + mutable wasm_rt_funcref_table_t sandbox_callback_table; + uintptr_t heap_base; + size_t return_slot_size = 0; + T_PointerType return_slot = 0; + mutable std::vector<T_PointerType> callback_free_list; + + static const size_t MAX_CALLBACKS = 128; + mutable RLBOX_SHARED_LOCK(callback_mutex); + void* callback_unique_keys[MAX_CALLBACKS]{ 0 }; + void* callbacks[MAX_CALLBACKS]{ 0 }; + uint32_t callback_slot_assignment[MAX_CALLBACKS]{ 0 }; + mutable std::map<const void*, uint32_t> internal_callbacks; + mutable std::map<uint32_t, const void*> slot_assignments; + +#ifndef RLBOX_EMBEDDER_PROVIDES_TLS_STATIC_VARIABLES + thread_local static inline rlbox_wasm2c_sandbox_thread_data thread_data{ 0, + 0 }; +#endif + + template<typename T_FormalRet, typename T_ActualRet> + inline auto serialize_to_sandbox(T_ActualRet arg) + { + if constexpr (std::is_class_v<T_FormalRet>) { + // structs returned as pointers into wasm memory/wasm stack + auto ptr = reinterpret_cast<T_FormalRet*>( + impl_get_unsandboxed_pointer<T_FormalRet*>(arg)); + T_FormalRet ret = *ptr; + return ret; + } else { + return arg; + } + } + + template<uint32_t N, typename T_Ret, typename... T_Args> + static typename wasm2c_detail::convert_type_to_wasm_type<T_Ret>::type + callback_interceptor( + void* /* vmContext */, + typename wasm2c_detail::convert_type_to_wasm_type<T_Args>::type... params) + { +#ifdef RLBOX_EMBEDDER_PROVIDES_TLS_STATIC_VARIABLES + auto& thread_data = *get_rlbox_wasm2c_sandbox_thread_data(); +#endif + thread_data.last_callback_invoked = N; + using T_Func = T_Ret (*)(T_Args...); + T_Func func; + { +#ifndef RLBOX_SINGLE_THREADED_INVOCATIONS + RLBOX_ACQUIRE_SHARED_GUARD(lock, thread_data.sandbox->callback_mutex); +#endif + func = reinterpret_cast<T_Func>(thread_data.sandbox->callbacks[N]); + } + // Callbacks are invoked through function pointers, cannot use std::forward + // as we don't have caller context for T_Args, which means they are all + // effectively passed by value + return func( + thread_data.sandbox->template serialize_to_sandbox<T_Args>(params)...); + } + + template<uint32_t N, typename T_Ret, typename... T_Args> + static void callback_interceptor_promoted( + void* /* vmContext */, + typename wasm2c_detail::convert_type_to_wasm_type<T_Ret>::type ret, + typename wasm2c_detail::convert_type_to_wasm_type<T_Args>::type... params) + { +#ifdef RLBOX_EMBEDDER_PROVIDES_TLS_STATIC_VARIABLES + auto& thread_data = *get_rlbox_wasm2c_sandbox_thread_data(); +#endif + thread_data.last_callback_invoked = N; + using T_Func = T_Ret (*)(T_Args...); + T_Func func; + { +#ifndef RLBOX_SINGLE_THREADED_INVOCATIONS + RLBOX_ACQUIRE_SHARED_GUARD(lock, thread_data.sandbox->callback_mutex); +#endif + func = reinterpret_cast<T_Func>(thread_data.sandbox->callbacks[N]); + } + // Callbacks are invoked through function pointers, cannot use std::forward + // as we don't have caller context for T_Args, which means they are all + // effectively passed by value + auto ret_val = func( + thread_data.sandbox->template serialize_to_sandbox<T_Args>(params)...); + // Copy the return value back + auto ret_ptr = reinterpret_cast<T_Ret*>( + thread_data.sandbox->template impl_get_unsandboxed_pointer<T_Ret*>(ret)); + *ret_ptr = ret_val; + } + + template<typename T_Ret, typename... T_Args> + inline wasm_rt_func_type_t get_wasm2c_func_index( + // dummy for template inference + T_Ret (*)(T_Args...) = nullptr) const + { + // Class return types as promoted to args + constexpr bool promoted = std::is_class_v<T_Ret>; + constexpr uint32_t param_count = + promoted ? (sizeof...(T_Args) + 1) : (sizeof...(T_Args)); + constexpr uint32_t ret_count = + promoted ? 0 : (std::is_void_v<T_Ret> ? 0 : 1); + + wasm_rt_func_type_t ret = nullptr; + if constexpr (ret_count == 0) { + ret = RLBOX_WASM_MODULE_TYPE_CURR::get_func_type( + param_count, + ret_count, + wasm2c_detail::convert_type_to_wasm_type<T_Args>::wasm2c_type...); + } else { + ret = RLBOX_WASM_MODULE_TYPE_CURR::get_func_type( + param_count, + ret_count, + wasm2c_detail::convert_type_to_wasm_type<T_Args>::wasm2c_type..., + wasm2c_detail::convert_type_to_wasm_type<T_Ret>::wasm2c_type); + } + + return ret; + } + + void ensure_return_slot_size(size_t size) + { + if (size > return_slot_size) { + if (return_slot_size) { + impl_free_in_sandbox(return_slot); + } + return_slot = impl_malloc_in_sandbox(size); + detail::dynamic_check( + return_slot != 0, + "Error initializing return slot. Sandbox may be out of memory!"); + return_slot_size = size; + } + } + + // function takes a 32-bit value and returns the next power of 2 + // return is a 64-bit value as large 32-bit values will return 2^32 + static inline uint64_t next_power_of_two(uint32_t value) + { + uint64_t power = 1; + while (power < value) { + power *= 2; + } + return power; + } + +protected: +#define rlbox_wasm2c_sandbox_lookup_symbol(func_name) \ + reinterpret_cast<void*>(&RLBOX_WASM2C_MODULE_FUNC(_##func_name)) /* NOLINT \ + */ + + // adding a template so that we can use static_assert to fire only if this + // function is invoked + template<typename T = void> + void* impl_lookup_symbol(const char* func_name) + { + constexpr bool fail = std::is_same_v<T, void>; + static_assert( + !fail, + "The wasm2c_sandbox uses static calls and thus developers should add\n\n" + "#define RLBOX_USE_STATIC_CALLS() rlbox_wasm2c_sandbox_lookup_symbol\n\n" + "to their code, to ensure that static calls are handled correctly."); + return nullptr; + } + +public: +#define FALLIBLE_DYNAMIC_CHECK(infallible, cond, msg) \ + if (infallible) { \ + detail::dynamic_check(cond, msg); \ + } else if (!(cond)) { \ + impl_destroy_sandbox(); \ + return false; \ + } + + /** + * @brief creates the Wasm sandbox from the given shared library + * + * @param infallible if set to true, the sandbox aborts on failure. If false, + * the sandbox returns creation status as a return value + * @param custom_capacity allows optionally overriding the platform-specified + * maximum size of the wasm heap allowed for this sandbox instance. + * @return true when sandbox is successfully created. false when infallible is + * set to false and sandbox was not successfully created. If infallible is set + * to true, this function will never return false. + */ + inline bool impl_create_sandbox( + bool infallible = true, + const w2c_mem_capacity* custom_capacity = nullptr) + { + FALLIBLE_DYNAMIC_CHECK( + infallible, instance_initialized == false, "Sandbox already initialized"); + + bool minwasi_init_succeeded = true; + + std::call_once(rlbox_wasm2c_initialized, [&]() { + wasm_rt_init(); + minwasi_init_succeeded = minwasi_init(); + }); + + FALLIBLE_DYNAMIC_CHECK( + infallible, minwasi_init_succeeded, "Could not initialize min wasi"); + + const bool minwasi_init_inst_succeeded = minwasi_init_instance(&wasi_env); + FALLIBLE_DYNAMIC_CHECK( + infallible, minwasi_init_inst_succeeded, "Could not initialize min wasi instance"); + + if (custom_capacity) { + FALLIBLE_DYNAMIC_CHECK( + infallible, custom_capacity->is_valid, "Invalid capacity"); + } + + sandbox_memory_info = create_wasm2c_memory( + *RLBOX_WASM_MODULE_TYPE_CURR::initial_memory_pages, custom_capacity); + FALLIBLE_DYNAMIC_CHECK(infallible, + sandbox_memory_info.data != nullptr, + "Could not allocate a heap for the wasm2c sandbox"); + + FALLIBLE_DYNAMIC_CHECK(infallible, + *RLBOX_WASM_MODULE_TYPE_CURR::is_memory_64 == 0, + "Does not support Wasm with memory64"); + + const uint32_t max_table_size = 0xffffffffu; /* this means unlimited */ + wasm_rt_allocate_funcref_table( + &sandbox_callback_table, + *RLBOX_WASM_MODULE_TYPE_CURR::initial_func_elements, + max_table_size); + + sandbox_memory_env.sandbox_memory_info = &sandbox_memory_info; + sandbox_memory_env.sandbox_callback_table = &sandbox_callback_table; + wasi_env.instance_memory = &sandbox_memory_info; + RLBOX_WASM_MODULE_TYPE_CURR::create_instance( + &wasm2c_instance, &sandbox_memory_env, &wasi_env); + + heap_base = reinterpret_cast<uintptr_t>(impl_get_memory_location()); + + if constexpr (sizeof(uintptr_t) != sizeof(uint32_t)) { + // On larger platforms, check that the heap is aligned to the pointer size + // i.e. 32-bit pointer => aligned to 4GB. The implementations of + // impl_get_unsandboxed_pointer_no_ctx and + // impl_get_sandboxed_pointer_no_ctx below rely on this. + uintptr_t heap_offset_mask = std::numeric_limits<T_PointerType>::max(); + FALLIBLE_DYNAMIC_CHECK(infallible, + (heap_base & heap_offset_mask) == 0, + "Sandbox heap not aligned to 4GB"); + } + + instance_initialized = true; + + return true; + } + +#undef FALLIBLE_DYNAMIC_CHECK + + inline void impl_destroy_sandbox() + { + if (return_slot_size) { + impl_free_in_sandbox(return_slot); + } + + if (instance_initialized) { + instance_initialized = false; + RLBOX_WASM_MODULE_TYPE_CURR::free_instance(&wasm2c_instance); + } + + destroy_wasm2c_memory(&sandbox_memory_info); + wasm_rt_free_funcref_table(&sandbox_callback_table); + minwasi_cleanup_instance(&wasi_env); + } + + template<typename T> + inline void* impl_get_unsandboxed_pointer(T_PointerType p) const + { + if constexpr (std::is_function_v<std::remove_pointer_t<T>>) { + RLBOX_ACQUIRE_UNIQUE_GUARD(lock, callback_mutex); + auto found = slot_assignments.find(p); + if (found != slot_assignments.end()) { + auto ret = found->second; + return const_cast<void*>(ret); + } else { + return nullptr; + } + } else { + return reinterpret_cast<void*>(heap_base + p); + } + } + + template<typename T> + inline T_PointerType impl_get_sandboxed_pointer(const void* p) const + { + if constexpr (std::is_function_v<std::remove_pointer_t<T>>) { + RLBOX_ACQUIRE_UNIQUE_GUARD(lock, callback_mutex); + + uint32_t slot_number = 0; + auto found = internal_callbacks.find(p); + if (found != internal_callbacks.end()) { + slot_number = found->second; + } else { + + slot_number = new_callback_slot(); + wasm_rt_funcref_t func_val; + func_val.func_type = get_wasm2c_func_index(static_cast<T>(nullptr)); + func_val.func = + reinterpret_cast<wasm_rt_function_ptr_t>(const_cast<void*>(p)); + func_val.module_instance = &wasm2c_instance; + + sandbox_callback_table.data[slot_number] = func_val; + internal_callbacks[p] = slot_number; + slot_assignments[slot_number] = p; + } + return static_cast<T_PointerType>(slot_number); + } else { + if constexpr (sizeof(uintptr_t) == sizeof(uint32_t)) { + return static_cast<T_PointerType>(reinterpret_cast<uintptr_t>(p) - + heap_base); + } else { + return static_cast<T_PointerType>(reinterpret_cast<uintptr_t>(p)); + } + } + } + + template<typename T> + static inline void* impl_get_unsandboxed_pointer_no_ctx( + T_PointerType p, + const void* example_unsandboxed_ptr, + rlbox_wasm2c_sandbox* (*expensive_sandbox_finder)( + const void* example_unsandboxed_ptr)) + { + // on 32-bit platforms we don't assume the heap is aligned + if constexpr (sizeof(uintptr_t) == sizeof(uint32_t)) { + auto sandbox = expensive_sandbox_finder(example_unsandboxed_ptr); + return sandbox->template impl_get_unsandboxed_pointer<T>(p); + } else { + if constexpr (std::is_function_v<std::remove_pointer_t<T>>) { + // swizzling function pointers needs access to the function pointer + // tables and thus cannot be done without context + auto sandbox = expensive_sandbox_finder(example_unsandboxed_ptr); + return sandbox->template impl_get_unsandboxed_pointer<T>(p); + } else { + // grab the memory base from the example_unsandboxed_ptr + uintptr_t heap_base_mask = + std::numeric_limits<uintptr_t>::max() & + ~(static_cast<uintptr_t>(std::numeric_limits<T_PointerType>::max())); + uintptr_t computed_heap_base = + reinterpret_cast<uintptr_t>(example_unsandboxed_ptr) & heap_base_mask; + uintptr_t ret = computed_heap_base | p; + return reinterpret_cast<void*>(ret); + } + } + } + + template<typename T> + static inline T_PointerType impl_get_sandboxed_pointer_no_ctx( + const void* p, + const void* example_unsandboxed_ptr, + rlbox_wasm2c_sandbox* (*expensive_sandbox_finder)( + const void* example_unsandboxed_ptr)) + { + // on 32-bit platforms we don't assume the heap is aligned + if constexpr (sizeof(uintptr_t) == sizeof(uint32_t)) { + auto sandbox = expensive_sandbox_finder(example_unsandboxed_ptr); + return sandbox->template impl_get_sandboxed_pointer<T>(p); + } else { + if constexpr (std::is_function_v<std::remove_pointer_t<T>>) { + // swizzling function pointers needs access to the function pointer + // tables and thus cannot be done without context + auto sandbox = expensive_sandbox_finder(example_unsandboxed_ptr); + return sandbox->template impl_get_sandboxed_pointer<T>(p); + } else { + // Just clear the memory base to leave the offset + RLBOX_WASM2C_UNUSED(example_unsandboxed_ptr); + uintptr_t ret = reinterpret_cast<uintptr_t>(p) & + std::numeric_limits<T_PointerType>::max(); + return static_cast<T_PointerType>(ret); + } + } + } + + static inline bool impl_is_in_same_sandbox(const void* p1, const void* p2) + { + uintptr_t heap_base_mask = std::numeric_limits<uintptr_t>::max() & + ~(std::numeric_limits<T_PointerType>::max()); + return (reinterpret_cast<uintptr_t>(p1) & heap_base_mask) == + (reinterpret_cast<uintptr_t>(p2) & heap_base_mask); + } + + inline bool impl_is_pointer_in_sandbox_memory(const void* p) + { + size_t length = impl_get_total_memory(); + uintptr_t p_val = reinterpret_cast<uintptr_t>(p); + return p_val >= heap_base && p_val < (heap_base + length); + } + + inline bool impl_is_pointer_in_app_memory(const void* p) + { + return !(impl_is_pointer_in_sandbox_memory(p)); + } + + inline size_t impl_get_total_memory() { return sandbox_memory_info.size; } + + inline void* impl_get_memory_location() const + { + return sandbox_memory_info.data; + } + + template<typename T, typename T_Converted, typename... T_Args> + auto impl_invoke_with_func_ptr(T_Converted* func_ptr, T_Args&&... params) + { +#ifdef RLBOX_EMBEDDER_PROVIDES_TLS_STATIC_VARIABLES + auto& thread_data = *get_rlbox_wasm2c_sandbox_thread_data(); +#endif + auto old_sandbox = thread_data.sandbox; + thread_data.sandbox = this; + auto on_exit = + detail::make_scope_exit([&] { thread_data.sandbox = old_sandbox; }); + + // WASM functions are mangled in the following manner + // 1. All primitive types are left as is and follow an LP32 machine model + // (as opposed to the possibly 64-bit application) + // 2. All pointers are changed to u32 types + // 3. Returned class are returned as an out parameter before the actual + // function parameters + // 4. All class parameters are passed as pointers (u32 types) + // 5. The heap address is passed in as the first argument to the function + // + // RLBox accounts for the first 2 differences in T_Converted type, but we + // need to handle the rest + + // Handle point 3 + using T_Ret = wasm2c_detail::return_argument<T_Converted>; + if constexpr (std::is_class_v<T_Ret>) { + using T_Conv1 = wasm2c_detail::change_return_type<T_Converted, void>; + using T_Conv2 = wasm2c_detail::prepend_arg_type<T_Conv1, T_PointerType>; + auto func_ptr_conv = + reinterpret_cast<T_Conv2*>(reinterpret_cast<uintptr_t>(func_ptr)); + ensure_return_slot_size(sizeof(T_Ret)); + impl_invoke_with_func_ptr<T>(func_ptr_conv, return_slot, params...); + + auto ptr = reinterpret_cast<T_Ret*>( + impl_get_unsandboxed_pointer<T_Ret*>(return_slot)); + T_Ret ret = *ptr; + return ret; + } + + // Handle point 4 + constexpr size_t alloc_length = [&] { + if constexpr (sizeof...(params) > 0) { + return ((std::is_class_v<T_Args> ? 1 : 0) + ...); + } else { + return 0; + } + }(); + + // 0 arg functions create 0 length arrays which is not allowed + T_PointerType allocations_buff[alloc_length == 0 ? 1 : alloc_length]; + T_PointerType* allocations = allocations_buff; + + auto serialize_class_arg = + [&](auto arg) -> std::conditional_t<std::is_class_v<decltype(arg)>, + T_PointerType, + decltype(arg)> { + using T_Arg = decltype(arg); + if constexpr (std::is_class_v<T_Arg>) { + auto slot = impl_malloc_in_sandbox(sizeof(T_Arg)); + auto ptr = + reinterpret_cast<T_Arg*>(impl_get_unsandboxed_pointer<T_Arg*>(slot)); + *ptr = arg; + allocations[0] = slot; + allocations++; + return slot; + } else { + return arg; + } + }; + + // 0 arg functions don't use serialize + RLBOX_WASM2C_UNUSED(serialize_class_arg); + + using T_ConvNoClass = + wasm2c_detail::change_class_arg_types<T_Converted, T_PointerType>; + + // Handle Point 5 + using T_ConvHeap = wasm2c_detail::prepend_arg_type< + T_ConvNoClass, + typename RLBOX_WASM_MODULE_TYPE_CURR::instance_t*>; + + // Function invocation + auto func_ptr_conv = + reinterpret_cast<T_ConvHeap*>(reinterpret_cast<uintptr_t>(func_ptr)); + + using T_NoVoidRet = + std::conditional_t<std::is_void_v<T_Ret>, uint32_t, T_Ret>; + T_NoVoidRet ret; + + if constexpr (std::is_void_v<T_Ret>) { + RLBOX_WASM2C_UNUSED(ret); + func_ptr_conv(&wasm2c_instance, serialize_class_arg(params)...); + } else { + ret = func_ptr_conv(&wasm2c_instance, serialize_class_arg(params)...); + } + + for (size_t i = 0; i < alloc_length; i++) { + impl_free_in_sandbox(allocations_buff[i]); + } + + if constexpr (!std::is_void_v<T_Ret>) { + return ret; + } + } + + inline T_PointerType impl_malloc_in_sandbox(size_t size) + { + if constexpr (sizeof(size) > sizeof(uint32_t)) { + detail::dynamic_check(size <= std::numeric_limits<uint32_t>::max(), + "Attempting to malloc more than the heap size"); + } + using T_Func = void*(size_t); + using T_Converted = T_PointerType(uint32_t); + T_PointerType ret = impl_invoke_with_func_ptr<T_Func, T_Converted>( + reinterpret_cast<T_Converted*>( + RLBOX_WASM_MODULE_TYPE_CURR::malloc_address), + static_cast<uint32_t>(size)); + return ret; + } + + inline void impl_free_in_sandbox(T_PointerType p) + { + using T_Func = void(void*); + using T_Converted = void(T_PointerType); + impl_invoke_with_func_ptr<T_Func, T_Converted>( + reinterpret_cast<T_Converted*>(RLBOX_WASM_MODULE_TYPE_CURR::free_address), + p); + } + +private: + // Should be called with callback_mutex held + uint32_t new_callback_slot() const + { + if (callback_free_list.size() > 0) { + uint32_t ret = callback_free_list.back(); + callback_free_list.pop_back(); + return ret; + } + + const uint32_t curr_size = sandbox_callback_table.size; + + detail::dynamic_check( + curr_size < sandbox_callback_table.max_size, + "Could not find an empty row in Wasm instance table. This would " + "happen if you have registered too many callbacks, or unsandboxed " + "too many function pointers."); + + wasm_rt_funcref_t func_val{ 0 }; + // on success, this returns the previous number of elements in the table + const uint32_t ret = + wasm_rt_grow_funcref_table(&sandbox_callback_table, 1, func_val); + + detail::dynamic_check( + ret != 0 && ret != (uint32_t)-1, + "Adding a new callback slot to the wasm instance failed."); + + // We have expanded the number of slots + // Previous slots size: ret + // New slot is at index: ret + const uint32_t slot_number = ret; + return slot_number; + } + + void free_callback_slot(uint32_t slot) const + { + callback_free_list.push_back(slot); + } + +public: + template<typename T_Ret, typename... T_Args> + inline T_PointerType impl_register_callback(void* key, void* callback) + { + bool found = false; + uint32_t found_loc = 0; + wasm_rt_function_ptr_t chosen_interceptor = nullptr; + + RLBOX_ACQUIRE_UNIQUE_GUARD(lock, callback_mutex); + + // need a compile time for loop as we we need I to be a compile time value + // this is because we are setting the I'th callback ineterceptor + wasm2c_detail::compile_time_for<MAX_CALLBACKS>([&](auto I) { + constexpr auto i = I.value; + if (!found && callbacks[i] == nullptr) { + found = true; + found_loc = i; + + if constexpr (std::is_class_v<T_Ret>) { + chosen_interceptor = (wasm_rt_function_ptr_t)( + callback_interceptor_promoted<i, T_Ret, T_Args...>); + } else { + chosen_interceptor = + (wasm_rt_function_ptr_t)(callback_interceptor<i, T_Ret, T_Args...>); + } + } + }); + + detail::dynamic_check( + found, + "Could not find an empty slot in sandbox function table. This would " + "happen if you have registered too many callbacks, or unsandboxed " + "too many function pointers. You can file a bug if you want to " + "increase the maximum allowed callbacks or unsadnboxed functions " + "pointers"); + + wasm_rt_funcref_t func_val; + func_val.func_type = get_wasm2c_func_index<T_Ret, T_Args...>(); + func_val.func = chosen_interceptor; + func_val.module_instance = &wasm2c_instance; + + const uint32_t slot_number = new_callback_slot(); + sandbox_callback_table.data[slot_number] = func_val; + + callback_unique_keys[found_loc] = key; + callbacks[found_loc] = callback; + callback_slot_assignment[found_loc] = slot_number; + slot_assignments[slot_number] = callback; + + return static_cast<T_PointerType>(slot_number); + } + + static inline std::pair<rlbox_wasm2c_sandbox*, void*> + impl_get_executed_callback_sandbox_and_key() + { +#ifdef RLBOX_EMBEDDER_PROVIDES_TLS_STATIC_VARIABLES + auto& thread_data = *get_rlbox_wasm2c_sandbox_thread_data(); +#endif + auto sandbox = thread_data.sandbox; + auto callback_num = thread_data.last_callback_invoked; + void* key = sandbox->callback_unique_keys[callback_num]; + return std::make_pair(sandbox, key); + } + + template<typename T_Ret, typename... T_Args> + inline void impl_unregister_callback(void* key) + { + bool found = false; + uint32_t i = 0; + { + RLBOX_ACQUIRE_UNIQUE_GUARD(lock, callback_mutex); + for (; i < MAX_CALLBACKS; i++) { + if (callback_unique_keys[i] == key) { + const uint32_t slot_number = callback_slot_assignment[i]; + wasm_rt_funcref_t func_val{ 0 }; + sandbox_callback_table.data[slot_number] = func_val; + + callback_unique_keys[i] = nullptr; + callbacks[i] = nullptr; + callback_slot_assignment[i] = 0; + found = true; + break; + } + } + } + + detail::dynamic_check( + found, "Internal error: Could not find callback to unregister"); + + return; + } +}; + +} // namespace rlbox diff --git a/third_party/rlbox_wasm2c_sandbox/include/rlbox_wasm2c_tls.hpp b/third_party/rlbox_wasm2c_sandbox/include/rlbox_wasm2c_tls.hpp new file mode 100644 index 0000000000..b87ac7bf5f --- /dev/null +++ b/third_party/rlbox_wasm2c_sandbox/include/rlbox_wasm2c_tls.hpp @@ -0,0 +1,33 @@ +#pragma once + +#include <stdint.h> + +namespace rlbox { + +class rlbox_wasm2c_sandbox; + +struct rlbox_wasm2c_sandbox_thread_data +{ + rlbox_wasm2c_sandbox* sandbox; + uint32_t last_callback_invoked; +}; + +#ifdef RLBOX_EMBEDDER_PROVIDES_TLS_STATIC_VARIABLES + +rlbox_wasm2c_sandbox_thread_data* get_rlbox_wasm2c_sandbox_thread_data(); + +# define RLBOX_WASM2C_SANDBOX_STATIC_VARIABLES() \ + thread_local rlbox::rlbox_wasm2c_sandbox_thread_data \ + rlbox_wasm2c_sandbox_thread_info{ 0, 0 }; \ + \ + namespace rlbox { \ + rlbox_wasm2c_sandbox_thread_data* get_rlbox_wasm2c_sandbox_thread_data() \ + { \ + return &rlbox_wasm2c_sandbox_thread_info; \ + } \ + } \ + static_assert(true, "Enforce semi-colon") + +#endif + +} // namespace rlbox diff --git a/third_party/rlbox_wasm2c_sandbox/include/wasm2c_rt_mem.h b/third_party/rlbox_wasm2c_sandbox/include/wasm2c_rt_mem.h new file mode 100644 index 0000000000..e61b9fa3eb --- /dev/null +++ b/third_party/rlbox_wasm2c_sandbox/include/wasm2c_rt_mem.h @@ -0,0 +1,46 @@ +#ifndef WASM_RT_OS_H_ +#define WASM_RT_OS_H_ + +#include <stdbool.h> +#include <stddef.h> +#include <stdint.h> +#include <time.h> + +#include "wasm-rt.h" + +#ifdef __cplusplus +extern "C" +{ +#endif + + typedef struct w2c_env + { + wasm_rt_memory_t* sandbox_memory_info; + wasm_rt_funcref_table_t* sandbox_callback_table; + } w2c_env; + + wasm_rt_memory_t* w2c_env_memory(struct w2c_env* instance); + wasm_rt_funcref_table_t* w2c_env_0x5F_indirect_function_table( + struct w2c_env*); + + typedef struct w2c_mem_capacity + { + bool is_valid; + bool is_mem_32; + uint64_t max_pages; + uint64_t max_size; + } w2c_mem_capacity; + + w2c_mem_capacity get_valid_wasm2c_memory_capacity(uint64_t min_capacity, + bool is_mem_32); + + wasm_rt_memory_t create_wasm2c_memory( + uint32_t initial_pages, + const w2c_mem_capacity* custom_capacity); + void destroy_wasm2c_memory(wasm_rt_memory_t* memory); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/third_party/rlbox_wasm2c_sandbox/include/wasm2c_rt_minwasi.h b/third_party/rlbox_wasm2c_sandbox/include/wasm2c_rt_minwasi.h new file mode 100644 index 0000000000..51542b5d1e --- /dev/null +++ b/third_party/rlbox_wasm2c_sandbox/include/wasm2c_rt_minwasi.h @@ -0,0 +1,38 @@ +#ifndef WASM_RT_MINWASI_H_ +#define WASM_RT_MINWASI_H_ + +/* A minimum wasi implementation supporting only stdin, stdout, stderr, argv + * (upto 100 args) and clock functions. */ + +#include <stdbool.h> +#include <stdint.h> + +#include "wasm-rt.h" + +#ifdef __cplusplus +extern "C" +{ +#endif + + typedef struct w2c_wasi__snapshot__preview1 + { + wasm_rt_memory_t* instance_memory; + + uint32_t main_argc; + const char** main_argv; + + uint32_t env_count; + const char** env; + + void* clock_data; + } w2c_wasi__snapshot__preview1; + + bool minwasi_init(); + bool minwasi_init_instance(w2c_wasi__snapshot__preview1* wasi_data); + void minwasi_cleanup_instance(w2c_wasi__snapshot__preview1* wasi_data); + +#ifdef __cplusplus +} +#endif + +#endif |