diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-17 12:02:58 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-17 12:02:58 +0000 |
commit | 698f8c2f01ea549d77d7dc3338a12e04c11057b9 (patch) | |
tree | 173a775858bd501c378080a10dca74132f05bc50 /library/std/src/sys/sgx | |
parent | Initial commit. (diff) | |
download | rustc-698f8c2f01ea549d77d7dc3338a12e04c11057b9.tar.xz rustc-698f8c2f01ea549d77d7dc3338a12e04c11057b9.zip |
Adding upstream version 1.64.0+dfsg1.upstream/1.64.0+dfsg1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'library/std/src/sys/sgx')
36 files changed, 4650 insertions, 0 deletions
diff --git a/library/std/src/sys/sgx/abi/entry.S b/library/std/src/sys/sgx/abi/entry.S new file mode 100644 index 000000000..f61bcf06f --- /dev/null +++ b/library/std/src/sys/sgx/abi/entry.S @@ -0,0 +1,372 @@ +/* This symbol is used at runtime to figure out the virtual address that the */ +/* enclave is loaded at. */ +.section absolute +.global IMAGE_BASE +IMAGE_BASE: + +.section ".note.x86_64-fortanix-unknown-sgx", "", @note + .align 4 + .long 1f - 0f /* name length (not including padding) */ + .long 3f - 2f /* desc length (not including padding) */ + .long 1 /* type = NT_VERSION */ +0: .asciz "toolchain-version" /* name */ +1: .align 4 +2: .long 1 /* desc - toolchain version number, 32-bit LE */ +3: .align 4 + +.section .rodata +/* The XSAVE area needs to be a large chunk of readable memory, but since we are */ +/* going to restore everything to its initial state (XSTATE_BV=0), only certain */ +/* parts need to have a defined value. In particular: */ +/* */ +/* * MXCSR in the legacy area. This register is always restored if RFBM[1] or */ +/* RFBM[2] is set, regardless of the value of XSTATE_BV */ +/* * XSAVE header */ +.align 64 +.Lxsave_clear: +.org .+24 +.Lxsave_mxcsr: + .short 0x1f80 + +/* We can store a bunch of data in the gap between MXCSR and the XSAVE header */ + +/* The following symbols point at read-only data that will be filled in by the */ +/* post-linker. */ + +/* When using this macro, don't forget to adjust the linker version script! */ +.macro globvar name:req size:req + .global \name + .protected \name + .align \size + .size \name , \size + \name : + .org .+\size +.endm + /* The base address (relative to enclave start) of the heap area */ + globvar HEAP_BASE 8 + /* The heap size in bytes */ + globvar HEAP_SIZE 8 + /* Value of the RELA entry in the dynamic table */ + globvar RELA 8 + /* Value of the RELACOUNT entry in the dynamic table */ + globvar RELACOUNT 8 + /* The enclave size in bytes */ + globvar ENCLAVE_SIZE 8 + /* The base address (relative to enclave start) of the enclave configuration area */ + globvar CFGDATA_BASE 8 + /* Non-zero if debugging is enabled, zero otherwise */ + globvar DEBUG 1 + /* The base address (relative to enclave start) of the enclave text section */ + globvar TEXT_BASE 8 + /* The size in bytes of enclacve text section */ + globvar TEXT_SIZE 8 + /* The base address (relative to enclave start) of the enclave .eh_frame_hdr section */ + globvar EH_FRM_HDR_OFFSET 8 + /* The size in bytes of enclave .eh_frame_hdr section */ + globvar EH_FRM_HDR_LEN 8 + /* The base address (relative to enclave start) of the enclave .eh_frame section */ + globvar EH_FRM_OFFSET 8 + /* The size in bytes of enclacve .eh_frame section */ + globvar EH_FRM_LEN 8 + +.org .Lxsave_clear+512 +.Lxsave_header: + .int 0, 0 /* XSTATE_BV */ + .int 0, 0 /* XCOMP_BV */ + .org .+48 /* reserved bits */ + +.data +.Laborted: + .byte 0 + +/* TCS local storage section */ +.equ tcsls_tos, 0x00 /* initialized by loader to *offset* from image base to TOS */ +.equ tcsls_flags, 0x08 /* initialized by loader */ +.equ tcsls_flag_secondary, 0 /* initialized by loader; 0 = standard TCS, 1 = secondary TCS */ +.equ tcsls_flag_init_once, 1 /* initialized by loader to 0 */ +/* 14 unused bits */ +.equ tcsls_user_fcw, 0x0a +.equ tcsls_user_mxcsr, 0x0c +.equ tcsls_last_rsp, 0x10 /* initialized by loader to 0 */ +.equ tcsls_panic_last_rsp, 0x18 /* initialized by loader to 0 */ +.equ tcsls_debug_panic_buf_ptr, 0x20 /* initialized by loader to 0 */ +.equ tcsls_user_rsp, 0x28 +.equ tcsls_user_retip, 0x30 +.equ tcsls_user_rbp, 0x38 +.equ tcsls_user_r12, 0x40 +.equ tcsls_user_r13, 0x48 +.equ tcsls_user_r14, 0x50 +.equ tcsls_user_r15, 0x58 +.equ tcsls_tls_ptr, 0x60 +.equ tcsls_tcs_addr, 0x68 + +.macro load_tcsls_flag_secondary_bool reg:req comments:vararg + .ifne tcsls_flag_secondary /* to convert to a bool, must be the first bit */ + .abort + .endif + mov $(1<<tcsls_flag_secondary),%e\reg + and %gs:tcsls_flags,%\reg +.endm + +/* We place the ELF entry point in a separate section so it can be removed by + elf2sgxs */ +.section .text_no_sgx, "ax" +.Lelf_entry_error_msg: + .ascii "Error: This file is an SGX enclave which cannot be executed as a standard Linux binary.\nSee the installation guide at https://edp.fortanix.com/docs/installation/guide/ on how to use 'cargo run' or follow the steps at https://edp.fortanix.com/docs/tasks/deployment/ for manual deployment.\n" +.Lelf_entry_error_msg_end: + +.global elf_entry +.type elf_entry,function +elf_entry: +/* print error message */ + movq $2,%rdi /* write to stderr (fd 2) */ + lea .Lelf_entry_error_msg(%rip),%rsi + movq $.Lelf_entry_error_msg_end-.Lelf_entry_error_msg,%rdx +.Lelf_entry_call: + movq $1,%rax /* write() syscall */ + syscall + test %rax,%rax + jle .Lelf_exit /* exit on error */ + add %rax,%rsi + sub %rax,%rdx /* all chars written? */ + jnz .Lelf_entry_call + +.Lelf_exit: + movq $60,%rax /* exit() syscall */ + movq $1,%rdi /* exit code 1 */ + syscall + ud2 /* should not be reached */ +/* end elf_entry */ + +/* This code needs to be called *after* the enclave stack has been setup. */ +/* There are 3 places where this needs to happen, so this is put in a macro. */ +.macro entry_sanitize_final +/* Sanitize rflags received from user */ +/* - DF flag: x86-64 ABI requires DF to be unset at function entry/exit */ +/* - AC flag: AEX on misaligned memory accesses leaks side channel info */ + pushfq + andq $~0x40400, (%rsp) + popfq +/* check for abort */ + bt $0,.Laborted(%rip) + jc .Lreentry_panic +.endm + +.text +.global sgx_entry +.type sgx_entry,function +sgx_entry: +/* save user registers */ + mov %rcx,%gs:tcsls_user_retip + mov %rsp,%gs:tcsls_user_rsp + mov %rbp,%gs:tcsls_user_rbp + mov %r12,%gs:tcsls_user_r12 + mov %r13,%gs:tcsls_user_r13 + mov %r14,%gs:tcsls_user_r14 + mov %r15,%gs:tcsls_user_r15 + mov %rbx,%gs:tcsls_tcs_addr + stmxcsr %gs:tcsls_user_mxcsr + fnstcw %gs:tcsls_user_fcw + +/* check for debug buffer pointer */ + testb $0xff,DEBUG(%rip) + jz .Lskip_debug_init + mov %r10,%gs:tcsls_debug_panic_buf_ptr +.Lskip_debug_init: +/* reset cpu state */ + mov %rdx, %r10 + mov $-1, %rax + mov $-1, %rdx + xrstor .Lxsave_clear(%rip) + mov %r10, %rdx + +/* check if returning from usercall */ + mov %gs:tcsls_last_rsp,%r11 + test %r11,%r11 + jnz .Lusercall_ret +/* setup stack */ + mov %gs:tcsls_tos,%rsp /* initially, RSP is not set to the correct value */ + /* here. This is fixed below under "adjust stack". */ +/* check for thread init */ + bts $tcsls_flag_init_once,%gs:tcsls_flags + jc .Lskip_init +/* adjust stack */ + lea IMAGE_BASE(%rip),%rax + add %rax,%rsp + mov %rsp,%gs:tcsls_tos + entry_sanitize_final +/* call tcs_init */ +/* store caller-saved registers in callee-saved registers */ + mov %rdi,%rbx + mov %rsi,%r12 + mov %rdx,%r13 + mov %r8,%r14 + mov %r9,%r15 + load_tcsls_flag_secondary_bool di /* RDI = tcs_init() argument: secondary: bool */ + call tcs_init +/* reload caller-saved registers */ + mov %rbx,%rdi + mov %r12,%rsi + mov %r13,%rdx + mov %r14,%r8 + mov %r15,%r9 + jmp .Lafter_init +.Lskip_init: + entry_sanitize_final +.Lafter_init: +/* call into main entry point */ + load_tcsls_flag_secondary_bool cx /* RCX = entry() argument: secondary: bool */ + call entry /* RDI, RSI, RDX, R8, R9 passed in from userspace */ + mov %rax,%rsi /* RSI = return value */ + /* NOP: mov %rdx,%rdx */ /* RDX = return value */ + xor %rdi,%rdi /* RDI = normal exit */ +.Lexit: +/* clear general purpose register state */ + /* RAX overwritten by ENCLU */ + /* RBX set later */ + /* RCX overwritten by ENCLU */ + /* RDX contains return value */ + /* RSP set later */ + /* RBP set later */ + /* RDI contains exit mode */ + /* RSI contains return value */ + xor %r8,%r8 + xor %r9,%r9 + xor %r10,%r10 + xor %r11,%r11 + /* R12 ~ R15 set by sgx_exit */ +.Lsgx_exit: +/* clear extended register state */ + mov %rdx, %rcx /* save RDX */ + mov $-1, %rax + mov %rax, %rdx + xrstor .Lxsave_clear(%rip) + mov %rcx, %rdx /* restore RDX */ +/* clear flags */ + pushq $0 + popfq +/* restore user registers */ + mov %gs:tcsls_user_r12,%r12 + mov %gs:tcsls_user_r13,%r13 + mov %gs:tcsls_user_r14,%r14 + mov %gs:tcsls_user_r15,%r15 + mov %gs:tcsls_user_retip,%rbx + mov %gs:tcsls_user_rsp,%rsp + mov %gs:tcsls_user_rbp,%rbp + fldcw %gs:tcsls_user_fcw + ldmxcsr %gs:tcsls_user_mxcsr +/* exit enclave */ + mov $0x4,%eax /* EEXIT */ + enclu +/* end sgx_entry */ + +.Lreentry_panic: + orq $8,%rsp + jmp abort_reentry + +/* This *MUST* be called with 6 parameters, otherwise register information */ +/* might leak! */ +.global usercall +usercall: + test %rcx,%rcx /* check `abort` function argument */ + jnz .Lusercall_abort /* abort is set, jump to abort code (unlikely forward conditional) */ + jmp .Lusercall_save_state /* non-aborting usercall */ +.Lusercall_abort: +/* set aborted bit */ + movb $1,.Laborted(%rip) +/* save registers in DEBUG mode, so that debugger can reconstruct the stack */ + testb $0xff,DEBUG(%rip) + jz .Lusercall_noreturn +.Lusercall_save_state: +/* save callee-saved state */ + push %r15 + push %r14 + push %r13 + push %r12 + push %rbp + push %rbx + sub $8, %rsp + fstcw 4(%rsp) + stmxcsr (%rsp) + movq %rsp,%gs:tcsls_last_rsp +.Lusercall_noreturn: +/* clear general purpose register state */ + /* RAX overwritten by ENCLU */ + /* RBX set by sgx_exit */ + /* RCX overwritten by ENCLU */ + /* RDX contains parameter */ + /* RSP set by sgx_exit */ + /* RBP set by sgx_exit */ + /* RDI contains parameter */ + /* RSI contains parameter */ + /* R8 contains parameter */ + /* R9 contains parameter */ + xor %r10,%r10 + xor %r11,%r11 + /* R12 ~ R15 set by sgx_exit */ +/* extended registers/flags cleared by sgx_exit */ +/* exit */ + jmp .Lsgx_exit +.Lusercall_ret: + movq $0,%gs:tcsls_last_rsp +/* restore callee-saved state, cf. "save" above */ + mov %r11,%rsp + ldmxcsr (%rsp) + fldcw 4(%rsp) + add $8, %rsp + entry_sanitize_final + pop %rbx + pop %rbp + pop %r12 + pop %r13 + pop %r14 + pop %r15 +/* return */ + mov %rsi,%rax /* RAX = return value */ + /* NOP: mov %rdx,%rdx */ /* RDX = return value */ + pop %r11 + lfence + jmp *%r11 + +/* +The following functions need to be defined externally: +``` +// Called by entry code on re-entry after exit +extern "C" fn abort_reentry() -> !; + +// Called once when a TCS is first entered +extern "C" fn tcs_init(secondary: bool); + +// Standard TCS entrypoint +extern "C" fn entry(p1: u64, p2: u64, p3: u64, secondary: bool, p4: u64, p5: u64) -> (u64, u64); +``` +*/ + +.global get_tcs_addr +get_tcs_addr: + mov %gs:tcsls_tcs_addr,%rax + pop %r11 + lfence + jmp *%r11 + +.global get_tls_ptr +get_tls_ptr: + mov %gs:tcsls_tls_ptr,%rax + pop %r11 + lfence + jmp *%r11 + +.global set_tls_ptr +set_tls_ptr: + mov %rdi,%gs:tcsls_tls_ptr + pop %r11 + lfence + jmp *%r11 + +.global take_debug_panic_buf_ptr +take_debug_panic_buf_ptr: + xor %rax,%rax + xchg %gs:tcsls_debug_panic_buf_ptr,%rax + pop %r11 + lfence + jmp *%r11 diff --git a/library/std/src/sys/sgx/abi/mem.rs b/library/std/src/sys/sgx/abi/mem.rs new file mode 100644 index 000000000..18e6d5b3f --- /dev/null +++ b/library/std/src/sys/sgx/abi/mem.rs @@ -0,0 +1,93 @@ +use core::arch::asm; + +// Do not remove inline: will result in relocation failure +#[inline(always)] +pub(crate) unsafe fn rel_ptr<T>(offset: u64) -> *const T { + (image_base() + offset) as *const T +} + +// Do not remove inline: will result in relocation failure +#[inline(always)] +pub(crate) unsafe fn rel_ptr_mut<T>(offset: u64) -> *mut T { + (image_base() + offset) as *mut T +} + +extern "C" { + static ENCLAVE_SIZE: usize; + static HEAP_BASE: u64; + static HEAP_SIZE: usize; +} + +/// Returns the base memory address of the heap +pub(crate) fn heap_base() -> *const u8 { + unsafe { rel_ptr_mut(HEAP_BASE) } +} + +/// Returns the size of the heap +pub(crate) fn heap_size() -> usize { + unsafe { HEAP_SIZE } +} + +// Do not remove inline: will result in relocation failure +// For the same reason we use inline ASM here instead of an extern static to +// locate the base +/// Returns address at which current enclave is loaded. +#[inline(always)] +#[unstable(feature = "sgx_platform", issue = "56975")] +pub fn image_base() -> u64 { + let base: u64; + unsafe { + asm!( + "lea IMAGE_BASE(%rip), {}", + lateout(reg) base, + options(att_syntax, nostack, preserves_flags, nomem, pure), + ) + }; + base +} + +/// Returns `true` if the specified memory range is in the enclave. +/// +/// For safety, this function also checks whether the range given overflows, +/// returning `false` if so. +#[unstable(feature = "sgx_platform", issue = "56975")] +pub fn is_enclave_range(p: *const u8, len: usize) -> bool { + let start = p as usize; + + // Subtract one from `len` when calculating `end` in case `p + len` is + // exactly at the end of addressable memory (`p + len` would overflow, but + // the range is still valid). + let end = if len == 0 { + start + } else if let Some(end) = start.checked_add(len - 1) { + end + } else { + return false; + }; + + let base = image_base() as usize; + start >= base && end <= base + (unsafe { ENCLAVE_SIZE } - 1) // unsafe ok: link-time constant +} + +/// Returns `true` if the specified memory range is in userspace. +/// +/// For safety, this function also checks whether the range given overflows, +/// returning `false` if so. +#[unstable(feature = "sgx_platform", issue = "56975")] +pub fn is_user_range(p: *const u8, len: usize) -> bool { + let start = p as usize; + + // Subtract one from `len` when calculating `end` in case `p + len` is + // exactly at the end of addressable memory (`p + len` would overflow, but + // the range is still valid). + let end = if len == 0 { + start + } else if let Some(end) = start.checked_add(len - 1) { + end + } else { + return false; + }; + + let base = image_base() as usize; + end < base || start > base + (unsafe { ENCLAVE_SIZE } - 1) // unsafe ok: link-time constant +} diff --git a/library/std/src/sys/sgx/abi/mod.rs b/library/std/src/sys/sgx/abi/mod.rs new file mode 100644 index 000000000..9508c3874 --- /dev/null +++ b/library/std/src/sys/sgx/abi/mod.rs @@ -0,0 +1,108 @@ +#![cfg_attr(test, allow(unused))] // RT initialization logic is not compiled for test + +use crate::io::Write; +use core::arch::global_asm; +use core::sync::atomic::{AtomicUsize, Ordering}; + +// runtime features +pub(super) mod panic; +mod reloc; + +// library features +pub mod mem; +pub mod thread; +pub mod tls; +#[macro_use] +pub mod usercalls; + +#[cfg(not(test))] +global_asm!(include_str!("entry.S"), options(att_syntax)); + +#[repr(C)] +struct EntryReturn(u64, u64); + +#[cfg(not(test))] +#[no_mangle] +unsafe extern "C" fn tcs_init(secondary: bool) { + // Be very careful when changing this code: it runs before the binary has been + // relocated. Any indirect accesses to symbols will likely fail. + const UNINIT: usize = 0; + const BUSY: usize = 1; + const DONE: usize = 2; + // Three-state spin-lock + static RELOC_STATE: AtomicUsize = AtomicUsize::new(UNINIT); + + if secondary && RELOC_STATE.load(Ordering::Relaxed) != DONE { + rtabort!("Entered secondary TCS before main TCS!") + } + + // Try to atomically swap UNINIT with BUSY. The returned state can be: + match RELOC_STATE.compare_exchange(UNINIT, BUSY, Ordering::Acquire, Ordering::Acquire) { + // This thread just obtained the lock and other threads will observe BUSY + Ok(_) => { + reloc::relocate_elf_rela(); + RELOC_STATE.store(DONE, Ordering::Release); + } + // We need to wait until the initialization is done. + Err(BUSY) => { + while RELOC_STATE.load(Ordering::Acquire) == BUSY { + core::hint::spin_loop(); + } + } + // Initialization is done. + Err(DONE) => {} + _ => unreachable!(), + } +} + +// FIXME: this item should only exist if this is linked into an executable +// (main function exists). If this is a library, the crate author should be +// able to specify this +#[cfg(not(test))] +#[no_mangle] +extern "C" fn entry(p1: u64, p2: u64, p3: u64, secondary: bool, p4: u64, p5: u64) -> EntryReturn { + // FIXME: how to support TLS in library mode? + let tls = Box::new(tls::Tls::new()); + let tls_guard = unsafe { tls.activate() }; + + if secondary { + let join_notifier = super::thread::Thread::entry(); + drop(tls_guard); + drop(join_notifier); + + EntryReturn(0, 0) + } else { + extern "C" { + fn main(argc: isize, argv: *const *const u8) -> isize; + } + + // check entry is being called according to ABI + rtassert!(p3 == 0); + rtassert!(p4 == 0); + rtassert!(p5 == 0); + + unsafe { + // The actual types of these arguments are `p1: *const Arg, p2: + // usize`. We can't currently customize the argument list of Rust's + // main function, so we pass these in as the standard pointer-sized + // values in `argc` and `argv`. + let ret = main(p2 as _, p1 as _); + exit_with_code(ret) + } + } +} + +pub(super) fn exit_with_code(code: isize) -> ! { + if code != 0 { + if let Some(mut out) = panic::SgxPanicOutput::new() { + let _ = write!(out, "Exited with status code {code}"); + } + } + usercalls::exit(code != 0); +} + +#[cfg(not(test))] +#[no_mangle] +extern "C" fn abort_reentry() -> ! { + usercalls::exit(false) +} diff --git a/library/std/src/sys/sgx/abi/panic.rs b/library/std/src/sys/sgx/abi/panic.rs new file mode 100644 index 000000000..229b3b329 --- /dev/null +++ b/library/std/src/sys/sgx/abi/panic.rs @@ -0,0 +1,42 @@ +use super::usercalls::alloc::UserRef; +use crate::cmp; +use crate::io::{self, Write}; +use crate::mem; + +extern "C" { + fn take_debug_panic_buf_ptr() -> *mut u8; + static DEBUG: u8; +} + +pub(crate) struct SgxPanicOutput(Option<&'static mut UserRef<[u8]>>); + +fn empty_user_slice() -> &'static mut UserRef<[u8]> { + unsafe { UserRef::from_raw_parts_mut(1 as *mut u8, 0) } +} + +impl SgxPanicOutput { + pub(crate) fn new() -> Option<Self> { + if unsafe { DEBUG == 0 } { None } else { Some(SgxPanicOutput(None)) } + } + + fn init(&mut self) -> &mut &'static mut UserRef<[u8]> { + self.0.get_or_insert_with(|| unsafe { + let ptr = take_debug_panic_buf_ptr(); + if ptr.is_null() { empty_user_slice() } else { UserRef::from_raw_parts_mut(ptr, 1024) } + }) + } +} + +impl Write for SgxPanicOutput { + fn write(&mut self, src: &[u8]) -> io::Result<usize> { + let dst = mem::replace(self.init(), empty_user_slice()); + let written = cmp::min(src.len(), dst.len()); + dst[..written].copy_from_enclave(&src[..written]); + self.0 = Some(&mut dst[written..]); + Ok(written) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} diff --git a/library/std/src/sys/sgx/abi/reloc.rs b/library/std/src/sys/sgx/abi/reloc.rs new file mode 100644 index 000000000..02dff0ad2 --- /dev/null +++ b/library/std/src/sys/sgx/abi/reloc.rs @@ -0,0 +1,32 @@ +use super::mem; +use crate::slice::from_raw_parts; + +const R_X86_64_RELATIVE: u32 = 8; + +#[repr(packed)] +struct Rela<T> { + offset: T, + info: T, + addend: T, +} + +pub fn relocate_elf_rela() { + extern "C" { + static RELA: u64; + static RELACOUNT: usize; + } + + if unsafe { RELACOUNT } == 0 { + return; + } // unsafe ok: link-time constant + + let relas = unsafe { + from_raw_parts::<Rela<u64>>(mem::rel_ptr(RELA), RELACOUNT) // unsafe ok: link-time constant + }; + for rela in relas { + if rela.info != (/*0 << 32 |*/R_X86_64_RELATIVE as u64) { + rtabort!("Invalid relocation"); + } + unsafe { *mem::rel_ptr_mut::<*const ()>(rela.offset) = mem::rel_ptr(rela.addend) }; + } +} diff --git a/library/std/src/sys/sgx/abi/thread.rs b/library/std/src/sys/sgx/abi/thread.rs new file mode 100644 index 000000000..ef55b821a --- /dev/null +++ b/library/std/src/sys/sgx/abi/thread.rs @@ -0,0 +1,13 @@ +use fortanix_sgx_abi::Tcs; + +/// Gets the ID for the current thread. The ID is guaranteed to be unique among +/// all currently running threads in the enclave, and it is guaranteed to be +/// constant for the lifetime of the thread. More specifically for SGX, there +/// is a one-to-one correspondence of the ID to the address of the TCS. +#[unstable(feature = "sgx_platform", issue = "56975")] +pub fn current() -> Tcs { + extern "C" { + fn get_tcs_addr() -> Tcs; + } + unsafe { get_tcs_addr() } +} diff --git a/library/std/src/sys/sgx/abi/tls/mod.rs b/library/std/src/sys/sgx/abi/tls/mod.rs new file mode 100644 index 000000000..13d96e9a6 --- /dev/null +++ b/library/std/src/sys/sgx/abi/tls/mod.rs @@ -0,0 +1,132 @@ +mod sync_bitset; + +use self::sync_bitset::*; +use crate::cell::Cell; +use crate::mem; +use crate::num::NonZeroUsize; +use crate::ptr; +use crate::sync::atomic::{AtomicUsize, Ordering}; + +#[cfg(target_pointer_width = "64")] +const USIZE_BITS: usize = 64; +const TLS_KEYS: usize = 128; // Same as POSIX minimum +const TLS_KEYS_BITSET_SIZE: usize = (TLS_KEYS + (USIZE_BITS - 1)) / USIZE_BITS; + +#[cfg_attr(test, linkage = "available_externally")] +#[export_name = "_ZN16__rust_internals3std3sys3sgx3abi3tls14TLS_KEY_IN_USEE"] +static TLS_KEY_IN_USE: SyncBitset = SYNC_BITSET_INIT; +macro_rules! dup { + ((* $($exp:tt)*) $($val:tt)*) => (dup!( ($($exp)*) $($val)* $($val)* )); + (() $($val:tt)*) => ([$($val),*]) +} +#[cfg_attr(test, linkage = "available_externally")] +#[export_name = "_ZN16__rust_internals3std3sys3sgx3abi3tls14TLS_DESTRUCTORE"] +static TLS_DESTRUCTOR: [AtomicUsize; TLS_KEYS] = dup!((* * * * * * *) (AtomicUsize::new(0))); + +extern "C" { + fn get_tls_ptr() -> *const u8; + fn set_tls_ptr(tls: *const u8); +} + +#[derive(Copy, Clone)] +#[repr(C)] +pub struct Key(NonZeroUsize); + +impl Key { + fn to_index(self) -> usize { + self.0.get() - 1 + } + + fn from_index(index: usize) -> Self { + Key(NonZeroUsize::new(index + 1).unwrap()) + } + + pub fn as_usize(self) -> usize { + self.0.get() + } + + pub fn from_usize(index: usize) -> Self { + Key(NonZeroUsize::new(index).unwrap()) + } +} + +#[repr(C)] +pub struct Tls { + data: [Cell<*mut u8>; TLS_KEYS], +} + +pub struct ActiveTls<'a> { + tls: &'a Tls, +} + +impl<'a> Drop for ActiveTls<'a> { + fn drop(&mut self) { + let value_with_destructor = |key: usize| { + let ptr = TLS_DESTRUCTOR[key].load(Ordering::Relaxed); + unsafe { mem::transmute::<_, Option<unsafe extern "C" fn(*mut u8)>>(ptr) } + .map(|dtor| (&self.tls.data[key], dtor)) + }; + + let mut any_non_null_dtor = true; + while any_non_null_dtor { + any_non_null_dtor = false; + for (value, dtor) in TLS_KEY_IN_USE.iter().filter_map(&value_with_destructor) { + let value = value.replace(ptr::null_mut()); + if !value.is_null() { + any_non_null_dtor = true; + unsafe { dtor(value) } + } + } + } + } +} + +impl Tls { + pub fn new() -> Tls { + Tls { data: dup!((* * * * * * *) (Cell::new(ptr::null_mut()))) } + } + + pub unsafe fn activate(&self) -> ActiveTls<'_> { + // FIXME: Needs safety information. See entry.S for `set_tls_ptr` definition. + unsafe { set_tls_ptr(self as *const Tls as _) }; + ActiveTls { tls: self } + } + + #[allow(unused)] + pub unsafe fn activate_persistent(self: Box<Self>) { + // FIXME: Needs safety information. See entry.S for `set_tls_ptr` definition. + unsafe { set_tls_ptr((&*self) as *const Tls as _) }; + mem::forget(self); + } + + unsafe fn current<'a>() -> &'a Tls { + // FIXME: Needs safety information. See entry.S for `set_tls_ptr` definition. + unsafe { &*(get_tls_ptr() as *const Tls) } + } + + pub fn create(dtor: Option<unsafe extern "C" fn(*mut u8)>) -> Key { + let index = if let Some(index) = TLS_KEY_IN_USE.set() { + index + } else { + rtabort!("TLS limit exceeded") + }; + TLS_DESTRUCTOR[index].store(dtor.map_or(0, |f| f as usize), Ordering::Relaxed); + Key::from_index(index) + } + + pub fn set(key: Key, value: *mut u8) { + let index = key.to_index(); + rtassert!(TLS_KEY_IN_USE.get(index)); + unsafe { Self::current() }.data[index].set(value); + } + + pub fn get(key: Key) -> *mut u8 { + let index = key.to_index(); + rtassert!(TLS_KEY_IN_USE.get(index)); + unsafe { Self::current() }.data[index].get() + } + + pub fn destroy(key: Key) { + TLS_KEY_IN_USE.clear(key.to_index()); + } +} diff --git a/library/std/src/sys/sgx/abi/tls/sync_bitset.rs b/library/std/src/sys/sgx/abi/tls/sync_bitset.rs new file mode 100644 index 000000000..4eeff8f6e --- /dev/null +++ b/library/std/src/sys/sgx/abi/tls/sync_bitset.rs @@ -0,0 +1,85 @@ +#[cfg(test)] +mod tests; + +use super::{TLS_KEYS_BITSET_SIZE, USIZE_BITS}; +use crate::iter::{Enumerate, Peekable}; +use crate::slice::Iter; +use crate::sync::atomic::{AtomicUsize, Ordering}; + +/// A bitset that can be used synchronously. +pub(super) struct SyncBitset([AtomicUsize; TLS_KEYS_BITSET_SIZE]); + +pub(super) const SYNC_BITSET_INIT: SyncBitset = + SyncBitset([AtomicUsize::new(0), AtomicUsize::new(0)]); + +impl SyncBitset { + pub fn get(&self, index: usize) -> bool { + let (hi, lo) = Self::split(index); + (self.0[hi].load(Ordering::Relaxed) & lo) != 0 + } + + /// Not atomic. + pub fn iter(&self) -> SyncBitsetIter<'_> { + SyncBitsetIter { iter: self.0.iter().enumerate().peekable(), elem_idx: 0 } + } + + pub fn clear(&self, index: usize) { + let (hi, lo) = Self::split(index); + self.0[hi].fetch_and(!lo, Ordering::Relaxed); + } + + /// Sets any unset bit. Not atomic. Returns `None` if all bits were + /// observed to be set. + pub fn set(&self) -> Option<usize> { + 'elems: for (idx, elem) in self.0.iter().enumerate() { + let mut current = elem.load(Ordering::Relaxed); + loop { + if 0 == !current { + continue 'elems; + } + let trailing_ones = (!current).trailing_zeros() as usize; + match elem.compare_exchange( + current, + current | (1 << trailing_ones), + Ordering::AcqRel, + Ordering::Relaxed, + ) { + Ok(_) => return Some(idx * USIZE_BITS + trailing_ones), + Err(previous) => current = previous, + } + } + } + None + } + + fn split(index: usize) -> (usize, usize) { + (index / USIZE_BITS, 1 << (index % USIZE_BITS)) + } +} + +pub(super) struct SyncBitsetIter<'a> { + iter: Peekable<Enumerate<Iter<'a, AtomicUsize>>>, + elem_idx: usize, +} + +impl<'a> Iterator for SyncBitsetIter<'a> { + type Item = usize; + + fn next(&mut self) -> Option<usize> { + self.iter.peek().cloned().and_then(|(idx, elem)| { + let elem = elem.load(Ordering::Relaxed); + let low_mask = (1 << self.elem_idx) - 1; + let next = elem & !low_mask; + let next_idx = next.trailing_zeros() as usize; + self.elem_idx = next_idx + 1; + if self.elem_idx >= 64 { + self.elem_idx = 0; + self.iter.next(); + } + match next_idx { + 64 => self.next(), + _ => Some(idx * USIZE_BITS + next_idx), + } + }) + } +} diff --git a/library/std/src/sys/sgx/abi/tls/sync_bitset/tests.rs b/library/std/src/sys/sgx/abi/tls/sync_bitset/tests.rs new file mode 100644 index 000000000..d7eb2e139 --- /dev/null +++ b/library/std/src/sys/sgx/abi/tls/sync_bitset/tests.rs @@ -0,0 +1,25 @@ +use super::*; + +fn test_data(bitset: [usize; 2], bit_indices: &[usize]) { + let set = SyncBitset([AtomicUsize::new(bitset[0]), AtomicUsize::new(bitset[1])]); + assert_eq!(set.iter().collect::<Vec<_>>(), bit_indices); + for &i in bit_indices { + assert!(set.get(i)); + } +} + +#[test] +fn iter() { + test_data([0b0110_1001, 0], &[0, 3, 5, 6]); + test_data([0x8000_0000_0000_0000, 0x8000_0000_0000_0001], &[63, 64, 127]); + test_data([0, 0], &[]); +} + +#[test] +fn set_get_clear() { + let set = SYNC_BITSET_INIT; + let key = set.set().unwrap(); + assert!(set.get(key)); + set.clear(key); + assert!(!set.get(key)); +} diff --git a/library/std/src/sys/sgx/abi/usercalls/alloc.rs b/library/std/src/sys/sgx/abi/usercalls/alloc.rs new file mode 100644 index 000000000..ea24fedd0 --- /dev/null +++ b/library/std/src/sys/sgx/abi/usercalls/alloc.rs @@ -0,0 +1,732 @@ +#![allow(unused)] + +use crate::arch::asm; +use crate::cell::UnsafeCell; +use crate::cmp; +use crate::convert::TryInto; +use crate::mem; +use crate::ops::{CoerceUnsized, Deref, DerefMut, Index, IndexMut}; +use crate::ptr::{self, NonNull}; +use crate::slice; +use crate::slice::SliceIndex; + +use super::super::mem::{is_enclave_range, is_user_range}; +use fortanix_sgx_abi::*; + +/// A type that can be safely read from or written to userspace. +/// +/// Non-exhaustive list of specific requirements for reading and writing: +/// * **Type is `Copy`** (and therefore also not `Drop`). Copies will be +/// created when copying from/to userspace. Destructors will not be called. +/// * **No references or Rust-style owned pointers** (`Vec`, `Arc`, etc.). When +/// reading from userspace, references into enclave memory must not be +/// created. Also, only enclave memory is considered managed by the Rust +/// compiler's static analysis. When reading from userspace, there can be no +/// guarantee that the value correctly adheres to the expectations of the +/// type. When writing to userspace, memory addresses of data in enclave +/// memory must not be leaked for confidentiality reasons. `User` and +/// `UserRef` are also not allowed for the same reasons. +/// * **No fat pointers.** When reading from userspace, the size or vtable +/// pointer could be automatically interpreted and used by the code. When +/// writing to userspace, memory addresses of data in enclave memory (such +/// as vtable pointers) must not be leaked for confidentiality reasons. +/// +/// Non-exhaustive list of specific requirements for reading from userspace: +/// * **Any bit pattern is valid** for this type (no `enum`s). There can be no +/// guarantee that the value correctly adheres to the expectations of the +/// type, so any value must be valid for this type. +/// +/// Non-exhaustive list of specific requirements for writing to userspace: +/// * **No pointers to enclave memory.** Memory addresses of data in enclave +/// memory must not be leaked for confidentiality reasons. +/// * **No internal padding.** Padding might contain previously-initialized +/// secret data stored at that memory location and must not be leaked for +/// confidentiality reasons. +#[unstable(feature = "sgx_platform", issue = "56975")] +pub unsafe trait UserSafeSized: Copy + Sized {} + +#[unstable(feature = "sgx_platform", issue = "56975")] +unsafe impl UserSafeSized for u8 {} +#[unstable(feature = "sgx_platform", issue = "56975")] +unsafe impl<T> UserSafeSized for FifoDescriptor<T> {} +#[unstable(feature = "sgx_platform", issue = "56975")] +unsafe impl UserSafeSized for ByteBuffer {} +#[unstable(feature = "sgx_platform", issue = "56975")] +unsafe impl UserSafeSized for Usercall {} +#[unstable(feature = "sgx_platform", issue = "56975")] +unsafe impl UserSafeSized for Return {} +#[unstable(feature = "sgx_platform", issue = "56975")] +unsafe impl<T: UserSafeSized> UserSafeSized for [T; 2] {} + +/// A type that can be represented in memory as one or more `UserSafeSized`s. +#[unstable(feature = "sgx_platform", issue = "56975")] +pub unsafe trait UserSafe { + /// Equivalent to `mem::align_of::<Self>`. + fn align_of() -> usize; + + /// Construct a pointer to `Self` given a memory range in user space. + /// + /// N.B., this takes a size, not a length! + /// + /// # Safety + /// + /// The caller must ensure the memory range is in user memory, is the + /// correct size and is correctly aligned and points to the right type. + unsafe fn from_raw_sized_unchecked(ptr: *mut u8, size: usize) -> *mut Self; + + /// Construct a pointer to `Self` given a memory range. + /// + /// N.B., this takes a size, not a length! + /// + /// # Safety + /// + /// The caller must ensure the memory range points to the correct type. + /// + /// # Panics + /// + /// This function panics if: + /// + /// * the pointer is not aligned. + /// * the pointer is null. + /// * the pointed-to range does not fit in the address space. + /// * the pointed-to range is not in user memory. + unsafe fn from_raw_sized(ptr: *mut u8, size: usize) -> NonNull<Self> { + assert!(ptr.wrapping_add(size) >= ptr); + // SAFETY: The caller has guaranteed the pointer is valid + let ret = unsafe { Self::from_raw_sized_unchecked(ptr, size) }; + unsafe { + Self::check_ptr(ret); + NonNull::new_unchecked(ret as _) + } + } + + /// Checks if a pointer may point to `Self` in user memory. + /// + /// # Safety + /// + /// The caller must ensure the memory range points to the correct type and + /// length (if this is a slice). + /// + /// # Panics + /// + /// This function panics if: + /// + /// * the pointer is not aligned. + /// * the pointer is null. + /// * the pointed-to range is not in user memory. + unsafe fn check_ptr(ptr: *const Self) { + let is_aligned = |p| -> bool { 0 == (p as usize) & (Self::align_of() - 1) }; + + assert!(is_aligned(ptr as *const u8)); + assert!(is_user_range(ptr as _, mem::size_of_val(unsafe { &*ptr }))); + assert!(!ptr.is_null()); + } +} + +#[unstable(feature = "sgx_platform", issue = "56975")] +unsafe impl<T: UserSafeSized> UserSafe for T { + fn align_of() -> usize { + mem::align_of::<T>() + } + + unsafe fn from_raw_sized_unchecked(ptr: *mut u8, size: usize) -> *mut Self { + assert_eq!(size, mem::size_of::<T>()); + ptr as _ + } +} + +#[unstable(feature = "sgx_platform", issue = "56975")] +unsafe impl<T: UserSafeSized> UserSafe for [T] { + fn align_of() -> usize { + mem::align_of::<T>() + } + + /// # Safety + /// Behavior is undefined if any of these conditions are violated: + /// * `ptr` must be [valid] for writes of `size` many bytes, and it must be + /// properly aligned. + /// + /// [valid]: core::ptr#safety + /// # Panics + /// + /// This function panics if: + /// + /// * the element size is not a factor of the size + unsafe fn from_raw_sized_unchecked(ptr: *mut u8, size: usize) -> *mut Self { + let elem_size = mem::size_of::<T>(); + assert_eq!(size % elem_size, 0); + let len = size / elem_size; + // SAFETY: The caller must uphold the safety contract for `from_raw_sized_unchecked` + unsafe { slice::from_raw_parts_mut(ptr as _, len) } + } +} + +/// A reference to some type in userspace memory. `&UserRef<T>` is equivalent +/// to `&T` in enclave memory. Access to the memory is only allowed by copying +/// to avoid TOCTTOU issues. After copying, code should make sure to completely +/// check the value before use. +/// +/// It is also possible to obtain a mutable reference `&mut UserRef<T>`. Unlike +/// regular mutable references, these are not exclusive. Userspace may always +/// write to the backing memory at any time, so it can't be assumed that there +/// the pointed-to memory is uniquely borrowed. The two different reference types +/// are used solely to indicate intent: a mutable reference is for writing to +/// user memory, an immutable reference for reading from user memory. +#[unstable(feature = "sgx_platform", issue = "56975")] +pub struct UserRef<T: ?Sized>(UnsafeCell<T>); +/// An owned type in userspace memory. `User<T>` is equivalent to `Box<T>` in +/// enclave memory. Access to the memory is only allowed by copying to avoid +/// TOCTTOU issues. The user memory will be freed when the value is dropped. +/// After copying, code should make sure to completely check the value before +/// use. +#[unstable(feature = "sgx_platform", issue = "56975")] +pub struct User<T: UserSafe + ?Sized>(NonNull<UserRef<T>>); + +trait NewUserRef<T: ?Sized> { + unsafe fn new_userref(v: T) -> Self; +} + +impl<T: ?Sized> NewUserRef<*mut T> for NonNull<UserRef<T>> { + unsafe fn new_userref(v: *mut T) -> Self { + // SAFETY: The caller has guaranteed the pointer is valid + unsafe { NonNull::new_unchecked(v as _) } + } +} + +impl<T: ?Sized> NewUserRef<NonNull<T>> for NonNull<UserRef<T>> { + unsafe fn new_userref(v: NonNull<T>) -> Self { + // SAFETY: The caller has guaranteed the pointer is valid + unsafe { NonNull::new_userref(v.as_ptr()) } + } +} + +#[unstable(feature = "sgx_platform", issue = "56975")] +impl<T: ?Sized> User<T> +where + T: UserSafe, +{ + // This function returns memory that is practically uninitialized, but is + // not considered "unspecified" or "undefined" for purposes of an + // optimizing compiler. This is achieved by returning a pointer from + // from outside as obtained by `super::alloc`. + fn new_uninit_bytes(size: usize) -> Self { + unsafe { + // Mustn't call alloc with size 0. + let ptr = if size > 0 { + // `copy_to_userspace` is more efficient when data is 8-byte aligned + let alignment = cmp::max(T::align_of(), 8); + rtunwrap!(Ok, super::alloc(size, alignment)) as _ + } else { + T::align_of() as _ // dangling pointer ok for size 0 + }; + if let Ok(v) = crate::panic::catch_unwind(|| T::from_raw_sized(ptr, size)) { + User(NonNull::new_userref(v)) + } else { + rtabort!("Got invalid pointer from alloc() usercall") + } + } + } + + /// Copies `val` into freshly allocated space in user memory. + pub fn new_from_enclave(val: &T) -> Self { + unsafe { + let mut user = Self::new_uninit_bytes(mem::size_of_val(val)); + user.copy_from_enclave(val); + user + } + } + + /// Creates an owned `User<T>` from a raw pointer. + /// + /// # Safety + /// The caller must ensure `ptr` points to `T`, is freeable with the `free` + /// usercall and the alignment of `T`, and is uniquely owned. + /// + /// # Panics + /// This function panics if: + /// + /// * The pointer is not aligned + /// * The pointer is null + /// * The pointed-to range is not in user memory + pub unsafe fn from_raw(ptr: *mut T) -> Self { + // SAFETY: the caller must uphold the safety contract for `from_raw`. + unsafe { T::check_ptr(ptr) }; + User(unsafe { NonNull::new_userref(ptr) }) + } + + /// Converts this value into a raw pointer. The value will no longer be + /// automatically freed. + pub fn into_raw(self) -> *mut T { + let ret = self.0; + mem::forget(self); + ret.as_ptr() as _ + } +} + +#[unstable(feature = "sgx_platform", issue = "56975")] +impl<T> User<T> +where + T: UserSafe, +{ + /// Allocate space for `T` in user memory. + pub fn uninitialized() -> Self { + Self::new_uninit_bytes(mem::size_of::<T>()) + } +} + +#[unstable(feature = "sgx_platform", issue = "56975")] +impl<T> User<[T]> +where + [T]: UserSafe, +{ + /// Allocate space for a `[T]` of `n` elements in user memory. + pub fn uninitialized(n: usize) -> Self { + Self::new_uninit_bytes(n * mem::size_of::<T>()) + } + + /// Creates an owned `User<[T]>` from a raw thin pointer and a slice length. + /// + /// # Safety + /// The caller must ensure `ptr` points to `len` elements of `T`, is + /// freeable with the `free` usercall and the alignment of `T`, and is + /// uniquely owned. + /// + /// # Panics + /// This function panics if: + /// + /// * The pointer is not aligned + /// * The pointer is null + /// * The pointed-to range does not fit in the address space + /// * The pointed-to range is not in user memory + pub unsafe fn from_raw_parts(ptr: *mut T, len: usize) -> Self { + User(unsafe { + NonNull::new_userref(<[T]>::from_raw_sized(ptr as _, len * mem::size_of::<T>())) + }) + } +} + +/// Copies `len` bytes of data from enclave pointer `src` to userspace `dst` +/// +/// This function mitigates stale data vulnerabilities by ensuring all writes to untrusted memory are either: +/// - preceded by the VERW instruction and followed by the MFENCE; LFENCE instruction sequence +/// - or are in multiples of 8 bytes, aligned to an 8-byte boundary +/// +/// # Panics +/// This function panics if: +/// +/// * The `src` pointer is null +/// * The `dst` pointer is null +/// * The `src` memory range is not in enclave memory +/// * The `dst` memory range is not in user memory +/// +/// # References +/// - https://www.intel.com/content/www/us/en/security-center/advisory/intel-sa-00615.html +/// - https://www.intel.com/content/www/us/en/developer/articles/technical/software-security-guidance/technical-documentation/processor-mmio-stale-data-vulnerabilities.html#inpage-nav-3-2-2 +pub(crate) unsafe fn copy_to_userspace(src: *const u8, dst: *mut u8, len: usize) { + unsafe fn copy_bytewise_to_userspace(src: *const u8, dst: *mut u8, len: usize) { + unsafe { + let mut seg_sel: u16 = 0; + for off in 0..len { + asm!(" + mov %ds, ({seg_sel}) + verw ({seg_sel}) + movb {val}, ({dst}) + mfence + lfence + ", + val = in(reg_byte) *src.offset(off as isize), + dst = in(reg) dst.offset(off as isize), + seg_sel = in(reg) &mut seg_sel, + options(nostack, att_syntax) + ); + } + } + } + + unsafe fn copy_aligned_quadwords_to_userspace(src: *const u8, dst: *mut u8, len: usize) { + unsafe { + asm!( + "rep movsq (%rsi), (%rdi)", + inout("rcx") len / 8 => _, + inout("rdi") dst => _, + inout("rsi") src => _, + options(att_syntax, nostack, preserves_flags) + ); + } + } + assert!(!src.is_null()); + assert!(!dst.is_null()); + assert!(is_enclave_range(src, len)); + assert!(is_user_range(dst, len)); + assert!(len < isize::MAX as usize); + assert!(!(src as usize).overflowing_add(len).1); + assert!(!(dst as usize).overflowing_add(len).1); + + if len < 8 { + // Can't align on 8 byte boundary: copy safely byte per byte + unsafe { + copy_bytewise_to_userspace(src, dst, len); + } + } else if len % 8 == 0 && dst as usize % 8 == 0 { + // Copying 8-byte aligned quadwords: copy quad word per quad word + unsafe { + copy_aligned_quadwords_to_userspace(src, dst, len); + } + } else { + // Split copies into three parts: + // +--------+ + // | small0 | Chunk smaller than 8 bytes + // +--------+ + // | big | Chunk 8-byte aligned, and size a multiple of 8 bytes + // +--------+ + // | small1 | Chunk smaller than 8 bytes + // +--------+ + + unsafe { + // Copy small0 + let small0_size = (8 - dst as usize % 8) as u8; + let small0_src = src; + let small0_dst = dst; + copy_bytewise_to_userspace(small0_src as _, small0_dst, small0_size as _); + + // Copy big + let small1_size = ((len - small0_size as usize) % 8) as u8; + let big_size = len - small0_size as usize - small1_size as usize; + let big_src = src.offset(small0_size as _); + let big_dst = dst.offset(small0_size as _); + copy_aligned_quadwords_to_userspace(big_src as _, big_dst, big_size); + + // Copy small1 + let small1_src = src.offset(big_size as isize + small0_size as isize); + let small1_dst = dst.offset(big_size as isize + small0_size as isize); + copy_bytewise_to_userspace(small1_src, small1_dst, small1_size as _); + } + } +} + +#[unstable(feature = "sgx_platform", issue = "56975")] +impl<T: ?Sized> UserRef<T> +where + T: UserSafe, +{ + /// Creates a `&UserRef<[T]>` from a raw pointer. + /// + /// # Safety + /// The caller must ensure `ptr` points to `T`. + /// + /// # Panics + /// This function panics if: + /// + /// * The pointer is not aligned + /// * The pointer is null + /// * The pointed-to range is not in user memory + pub unsafe fn from_ptr<'a>(ptr: *const T) -> &'a Self { + // SAFETY: The caller must uphold the safety contract for `from_ptr`. + unsafe { T::check_ptr(ptr) }; + unsafe { &*(ptr as *const Self) } + } + + /// Creates a `&mut UserRef<[T]>` from a raw pointer. See the struct + /// documentation for the nuances regarding a `&mut UserRef<T>`. + /// + /// # Safety + /// The caller must ensure `ptr` points to `T`. + /// + /// # Panics + /// This function panics if: + /// + /// * The pointer is not aligned + /// * The pointer is null + /// * The pointed-to range is not in user memory + pub unsafe fn from_mut_ptr<'a>(ptr: *mut T) -> &'a mut Self { + // SAFETY: The caller must uphold the safety contract for `from_mut_ptr`. + unsafe { T::check_ptr(ptr) }; + unsafe { &mut *(ptr as *mut Self) } + } + + /// Copies `val` into user memory. + /// + /// # Panics + /// This function panics if the destination doesn't have the same size as + /// the source. This can happen for dynamically-sized types such as slices. + pub fn copy_from_enclave(&mut self, val: &T) { + unsafe { + assert_eq!(mem::size_of_val(val), mem::size_of_val(&*self.0.get())); + copy_to_userspace( + val as *const T as *const u8, + self.0.get() as *mut T as *mut u8, + mem::size_of_val(val), + ); + } + } + + /// Copies the value from user memory and place it into `dest`. + /// + /// # Panics + /// This function panics if the destination doesn't have the same size as + /// the source. This can happen for dynamically-sized types such as slices. + pub fn copy_to_enclave(&self, dest: &mut T) { + unsafe { + assert_eq!(mem::size_of_val(dest), mem::size_of_val(&*self.0.get())); + ptr::copy( + self.0.get() as *const T as *const u8, + dest as *mut T as *mut u8, + mem::size_of_val(dest), + ); + } + } + + /// Obtain a raw pointer from this reference. + pub fn as_raw_ptr(&self) -> *const T { + self as *const _ as _ + } + + /// Obtain a raw pointer from this reference. + pub fn as_raw_mut_ptr(&mut self) -> *mut T { + self as *mut _ as _ + } +} + +#[unstable(feature = "sgx_platform", issue = "56975")] +impl<T> UserRef<T> +where + T: UserSafe, +{ + /// Copies the value from user memory into enclave memory. + pub fn to_enclave(&self) -> T { + unsafe { ptr::read(self.0.get()) } + } +} + +#[unstable(feature = "sgx_platform", issue = "56975")] +impl<T> UserRef<[T]> +where + [T]: UserSafe, +{ + /// Creates a `&UserRef<[T]>` from a raw thin pointer and a slice length. + /// + /// # Safety + /// The caller must ensure `ptr` points to `n` elements of `T`. + /// + /// # Panics + /// This function panics if: + /// + /// * The pointer is not aligned + /// * The pointer is null + /// * The pointed-to range does not fit in the address space + /// * The pointed-to range is not in user memory + pub unsafe fn from_raw_parts<'a>(ptr: *const T, len: usize) -> &'a Self { + // SAFETY: The caller must uphold the safety contract for `from_raw_parts`. + unsafe { + &*(<[T]>::from_raw_sized(ptr as _, len * mem::size_of::<T>()).as_ptr() as *const Self) + } + } + + /// Creates a `&mut UserRef<[T]>` from a raw thin pointer and a slice length. + /// See the struct documentation for the nuances regarding a + /// `&mut UserRef<T>`. + /// + /// # Safety + /// The caller must ensure `ptr` points to `n` elements of `T`. + /// + /// # Panics + /// This function panics if: + /// + /// * The pointer is not aligned + /// * The pointer is null + /// * The pointed-to range does not fit in the address space + /// * The pointed-to range is not in user memory + pub unsafe fn from_raw_parts_mut<'a>(ptr: *mut T, len: usize) -> &'a mut Self { + // SAFETY: The caller must uphold the safety contract for `from_raw_parts_mut`. + unsafe { + &mut *(<[T]>::from_raw_sized(ptr as _, len * mem::size_of::<T>()).as_ptr() as *mut Self) + } + } + + /// Obtain a raw pointer to the first element of this user slice. + pub fn as_ptr(&self) -> *const T { + self.0.get() as _ + } + + /// Obtain a raw pointer to the first element of this user slice. + pub fn as_mut_ptr(&mut self) -> *mut T { + self.0.get() as _ + } + + /// Obtain the number of elements in this user slice. + pub fn len(&self) -> usize { + unsafe { (*self.0.get()).len() } + } + + /// Copies the value from user memory and place it into `dest`. Afterwards, + /// `dest` will contain exactly `self.len()` elements. + /// + /// # Panics + /// This function panics if the destination doesn't have the same size as + /// the source. This can happen for dynamically-sized types such as slices. + pub fn copy_to_enclave_vec(&self, dest: &mut Vec<T>) { + if let Some(missing) = self.len().checked_sub(dest.capacity()) { + dest.reserve(missing) + } + // SAFETY: We reserve enough space above. + unsafe { dest.set_len(self.len()) }; + self.copy_to_enclave(&mut dest[..]); + } + + /// Copies the value from user memory into a vector in enclave memory. + pub fn to_enclave(&self) -> Vec<T> { + let mut ret = Vec::with_capacity(self.len()); + self.copy_to_enclave_vec(&mut ret); + ret + } + + /// Returns an iterator over the slice. + pub fn iter(&self) -> Iter<'_, T> + where + T: UserSafe, // FIXME: should be implied by [T]: UserSafe? + { + unsafe { Iter((&*self.as_raw_ptr()).iter()) } + } + + /// Returns an iterator that allows modifying each value. + pub fn iter_mut(&mut self) -> IterMut<'_, T> + where + T: UserSafe, // FIXME: should be implied by [T]: UserSafe? + { + unsafe { IterMut((&mut *self.as_raw_mut_ptr()).iter_mut()) } + } +} + +/// Immutable user slice iterator +/// +/// This struct is created by the `iter` method on `UserRef<[T]>`. +#[unstable(feature = "sgx_platform", issue = "56975")] +pub struct Iter<'a, T: 'a + UserSafe>(slice::Iter<'a, T>); + +#[unstable(feature = "sgx_platform", issue = "56975")] +impl<'a, T: UserSafe> Iterator for Iter<'a, T> { + type Item = &'a UserRef<T>; + + #[inline] + fn next(&mut self) -> Option<Self::Item> { + unsafe { self.0.next().map(|e| UserRef::from_ptr(e)) } + } +} + +/// Mutable user slice iterator +/// +/// This struct is created by the `iter_mut` method on `UserRef<[T]>`. +#[unstable(feature = "sgx_platform", issue = "56975")] +pub struct IterMut<'a, T: 'a + UserSafe>(slice::IterMut<'a, T>); + +#[unstable(feature = "sgx_platform", issue = "56975")] +impl<'a, T: UserSafe> Iterator for IterMut<'a, T> { + type Item = &'a mut UserRef<T>; + + #[inline] + fn next(&mut self) -> Option<Self::Item> { + unsafe { self.0.next().map(|e| UserRef::from_mut_ptr(e)) } + } +} + +#[unstable(feature = "sgx_platform", issue = "56975")] +impl<T: ?Sized> Deref for User<T> +where + T: UserSafe, +{ + type Target = UserRef<T>; + + fn deref(&self) -> &Self::Target { + unsafe { &*self.0.as_ptr() } + } +} + +#[unstable(feature = "sgx_platform", issue = "56975")] +impl<T: ?Sized> DerefMut for User<T> +where + T: UserSafe, +{ + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { &mut *self.0.as_ptr() } + } +} + +#[unstable(feature = "sgx_platform", issue = "56975")] +impl<T: ?Sized> Drop for User<T> +where + T: UserSafe, +{ + fn drop(&mut self) { + unsafe { + let ptr = (*self.0.as_ptr()).0.get(); + super::free(ptr as _, mem::size_of_val(&mut *ptr), T::align_of()); + } + } +} + +#[unstable(feature = "sgx_platform", issue = "56975")] +impl<T: CoerceUnsized<U>, U> CoerceUnsized<UserRef<U>> for UserRef<T> {} + +#[unstable(feature = "sgx_platform", issue = "56975")] +impl<T, I> Index<I> for UserRef<[T]> +where + [T]: UserSafe, + I: SliceIndex<[T]>, + I::Output: UserSafe, +{ + type Output = UserRef<I::Output>; + + #[inline] + fn index(&self, index: I) -> &UserRef<I::Output> { + unsafe { + if let Some(slice) = index.get(&*self.as_raw_ptr()) { + UserRef::from_ptr(slice) + } else { + rtabort!("index out of range for user slice"); + } + } + } +} + +#[unstable(feature = "sgx_platform", issue = "56975")] +impl<T, I> IndexMut<I> for UserRef<[T]> +where + [T]: UserSafe, + I: SliceIndex<[T]>, + I::Output: UserSafe, +{ + #[inline] + fn index_mut(&mut self, index: I) -> &mut UserRef<I::Output> { + unsafe { + if let Some(slice) = index.get_mut(&mut *self.as_raw_mut_ptr()) { + UserRef::from_mut_ptr(slice) + } else { + rtabort!("index out of range for user slice"); + } + } + } +} + +#[unstable(feature = "sgx_platform", issue = "56975")] +impl UserRef<super::raw::ByteBuffer> { + /// Copies the user memory range pointed to by the user `ByteBuffer` to + /// enclave memory. + /// + /// # Panics + /// This function panics if, in the user `ByteBuffer`: + /// + /// * The pointer is null + /// * The pointed-to range does not fit in the address space + /// * The pointed-to range is not in user memory + pub fn copy_user_buffer(&self) -> Vec<u8> { + unsafe { + let buf = self.to_enclave(); + if buf.len > 0 { + User::from_raw_parts(buf.data as _, buf.len).to_enclave() + } else { + // Mustn't look at `data` or call `free` if `len` is `0`. + Vec::with_capacity(0) + } + } + } +} diff --git a/library/std/src/sys/sgx/abi/usercalls/mod.rs b/library/std/src/sys/sgx/abi/usercalls/mod.rs new file mode 100644 index 000000000..79d1db5e1 --- /dev/null +++ b/library/std/src/sys/sgx/abi/usercalls/mod.rs @@ -0,0 +1,323 @@ +use crate::cmp; +use crate::io::{Error as IoError, ErrorKind, IoSlice, IoSliceMut, Result as IoResult}; +use crate::sys::rand::rdrand64; +use crate::time::{Duration, Instant}; + +pub(crate) mod alloc; +#[macro_use] +pub(crate) mod raw; +#[cfg(test)] +mod tests; + +use self::raw::*; + +/// Usercall `read`. See the ABI documentation for more information. +/// +/// This will do a single `read` usercall and scatter the read data among +/// `bufs`. To read to a single buffer, just pass a slice of length one. +#[unstable(feature = "sgx_platform", issue = "56975")] +pub fn read(fd: Fd, bufs: &mut [IoSliceMut<'_>]) -> IoResult<usize> { + unsafe { + let total_len = bufs.iter().fold(0usize, |sum, buf| sum.saturating_add(buf.len())); + let mut userbuf = alloc::User::<[u8]>::uninitialized(total_len); + let ret_len = raw::read(fd, userbuf.as_mut_ptr(), userbuf.len()).from_sgx_result()?; + let userbuf = &userbuf[..ret_len]; + let mut index = 0; + for buf in bufs { + let end = cmp::min(index + buf.len(), userbuf.len()); + if let Some(buflen) = end.checked_sub(index) { + userbuf[index..end].copy_to_enclave(&mut buf[..buflen]); + index += buf.len(); + } else { + break; + } + } + Ok(userbuf.len()) + } +} + +/// Usercall `read_alloc`. See the ABI documentation for more information. +#[unstable(feature = "sgx_platform", issue = "56975")] +pub fn read_alloc(fd: Fd) -> IoResult<Vec<u8>> { + unsafe { + let userbuf = ByteBuffer { data: crate::ptr::null_mut(), len: 0 }; + let mut userbuf = alloc::User::new_from_enclave(&userbuf); + raw::read_alloc(fd, userbuf.as_raw_mut_ptr()).from_sgx_result()?; + Ok(userbuf.copy_user_buffer()) + } +} + +/// Usercall `write`. See the ABI documentation for more information. +/// +/// This will do a single `write` usercall and gather the written data from +/// `bufs`. To write from a single buffer, just pass a slice of length one. +#[unstable(feature = "sgx_platform", issue = "56975")] +pub fn write(fd: Fd, bufs: &[IoSlice<'_>]) -> IoResult<usize> { + unsafe { + let total_len = bufs.iter().fold(0usize, |sum, buf| sum.saturating_add(buf.len())); + let mut userbuf = alloc::User::<[u8]>::uninitialized(total_len); + let mut index = 0; + for buf in bufs { + let end = cmp::min(index + buf.len(), userbuf.len()); + if let Some(buflen) = end.checked_sub(index) { + userbuf[index..end].copy_from_enclave(&buf[..buflen]); + index += buf.len(); + } else { + break; + } + } + raw::write(fd, userbuf.as_ptr(), userbuf.len()).from_sgx_result() + } +} + +/// Usercall `flush`. See the ABI documentation for more information. +#[unstable(feature = "sgx_platform", issue = "56975")] +pub fn flush(fd: Fd) -> IoResult<()> { + unsafe { raw::flush(fd).from_sgx_result() } +} + +/// Usercall `close`. See the ABI documentation for more information. +#[unstable(feature = "sgx_platform", issue = "56975")] +pub fn close(fd: Fd) { + unsafe { raw::close(fd) } +} + +fn string_from_bytebuffer(buf: &alloc::UserRef<ByteBuffer>, usercall: &str, arg: &str) -> String { + String::from_utf8(buf.copy_user_buffer()) + .unwrap_or_else(|_| rtabort!("Usercall {usercall}: expected {arg} to be valid UTF-8")) +} + +/// Usercall `bind_stream`. See the ABI documentation for more information. +#[unstable(feature = "sgx_platform", issue = "56975")] +pub fn bind_stream(addr: &str) -> IoResult<(Fd, String)> { + unsafe { + let addr_user = alloc::User::new_from_enclave(addr.as_bytes()); + let mut local = alloc::User::<ByteBuffer>::uninitialized(); + let fd = raw::bind_stream(addr_user.as_ptr(), addr_user.len(), local.as_raw_mut_ptr()) + .from_sgx_result()?; + let local = string_from_bytebuffer(&local, "bind_stream", "local_addr"); + Ok((fd, local)) + } +} + +/// Usercall `accept_stream`. See the ABI documentation for more information. +#[unstable(feature = "sgx_platform", issue = "56975")] +pub fn accept_stream(fd: Fd) -> IoResult<(Fd, String, String)> { + unsafe { + let mut bufs = alloc::User::<[ByteBuffer; 2]>::uninitialized(); + let mut buf_it = alloc::UserRef::iter_mut(&mut *bufs); // FIXME: can this be done + // without forcing coercion? + let (local, peer) = (buf_it.next().unwrap(), buf_it.next().unwrap()); + let fd = raw::accept_stream(fd, local.as_raw_mut_ptr(), peer.as_raw_mut_ptr()) + .from_sgx_result()?; + let local = string_from_bytebuffer(&local, "accept_stream", "local_addr"); + let peer = string_from_bytebuffer(&peer, "accept_stream", "peer_addr"); + Ok((fd, local, peer)) + } +} + +/// Usercall `connect_stream`. See the ABI documentation for more information. +#[unstable(feature = "sgx_platform", issue = "56975")] +pub fn connect_stream(addr: &str) -> IoResult<(Fd, String, String)> { + unsafe { + let addr_user = alloc::User::new_from_enclave(addr.as_bytes()); + let mut bufs = alloc::User::<[ByteBuffer; 2]>::uninitialized(); + let mut buf_it = alloc::UserRef::iter_mut(&mut *bufs); // FIXME: can this be done + // without forcing coercion? + let (local, peer) = (buf_it.next().unwrap(), buf_it.next().unwrap()); + let fd = raw::connect_stream( + addr_user.as_ptr(), + addr_user.len(), + local.as_raw_mut_ptr(), + peer.as_raw_mut_ptr(), + ) + .from_sgx_result()?; + let local = string_from_bytebuffer(&local, "connect_stream", "local_addr"); + let peer = string_from_bytebuffer(&peer, "connect_stream", "peer_addr"); + Ok((fd, local, peer)) + } +} + +/// Usercall `launch_thread`. See the ABI documentation for more information. +#[unstable(feature = "sgx_platform", issue = "56975")] +pub unsafe fn launch_thread() -> IoResult<()> { + // SAFETY: The caller must uphold the safety contract for `launch_thread`. + unsafe { raw::launch_thread().from_sgx_result() } +} + +/// Usercall `exit`. See the ABI documentation for more information. +#[unstable(feature = "sgx_platform", issue = "56975")] +pub fn exit(panic: bool) -> ! { + unsafe { raw::exit(panic) } +} + +/// Usercall `wait`. See the ABI documentation for more information. +#[unstable(feature = "sgx_platform", issue = "56975")] +pub fn wait(event_mask: u64, mut timeout: u64) -> IoResult<u64> { + if timeout != WAIT_NO && timeout != WAIT_INDEFINITE { + // We don't want people to rely on accuracy of timeouts to make + // security decisions in an SGX enclave. That's why we add a random + // amount not exceeding +/- 10% to the timeout value to discourage + // people from relying on accuracy of timeouts while providing a way + // to make things work in other cases. Note that in the SGX threat + // model the enclave runner which is serving the wait usercall is not + // trusted to ensure accurate timeouts. + if let Ok(timeout_signed) = i64::try_from(timeout) { + let tenth = timeout_signed / 10; + let deviation = (rdrand64() as i64).checked_rem(tenth).unwrap_or(0); + timeout = timeout_signed.saturating_add(deviation) as _; + } + } + unsafe { raw::wait(event_mask, timeout).from_sgx_result() } +} + +/// This function makes an effort to wait for a non-spurious event at least as +/// long as `duration`. Note that in general there is no guarantee about accuracy +/// of time and timeouts in SGX model. The enclave runner serving usercalls may +/// lie about current time and/or ignore timeout values. +/// +/// Once the event is observed, `should_wake_up` will be used to determine +/// whether or not the event was spurious. +#[unstable(feature = "sgx_platform", issue = "56975")] +pub fn wait_timeout<F>(event_mask: u64, duration: Duration, should_wake_up: F) +where + F: Fn() -> bool, +{ + // Calls the wait usercall and checks the result. Returns true if event was + // returned, and false if WouldBlock/TimedOut was returned. + // If duration is None, it will use WAIT_NO. + fn wait_checked(event_mask: u64, duration: Option<Duration>) -> bool { + let timeout = duration.map_or(raw::WAIT_NO, |duration| { + cmp::min((u64::MAX - 1) as u128, duration.as_nanos()) as u64 + }); + match wait(event_mask, timeout) { + Ok(eventset) => { + if event_mask == 0 { + rtabort!("expected wait() to return Err, found Ok."); + } + rtassert!(eventset != 0 && eventset & !event_mask == 0); + true + } + Err(e) => { + rtassert!(e.kind() == ErrorKind::TimedOut || e.kind() == ErrorKind::WouldBlock); + false + } + } + } + + match wait_checked(event_mask, Some(duration)) { + false => return, // timed out + true if should_wake_up() => return, // woken up + true => {} // spurious event + } + + // Drain all cached events. + // Note that `event_mask != 0` is implied if we get here. + loop { + match wait_checked(event_mask, None) { + false => break, // no more cached events + true if should_wake_up() => return, // woken up + true => {} // spurious event + } + } + + // Continue waiting, but take note of time spent waiting so we don't wait + // forever. We intentionally don't call `Instant::now()` before this point + // to avoid the cost of the `insecure_time` usercall in case there are no + // spurious wakeups. + + let start = Instant::now(); + let mut remaining = duration; + loop { + match wait_checked(event_mask, Some(remaining)) { + false => return, // timed out + true if should_wake_up() => return, // woken up + true => {} // spurious event + } + remaining = match duration.checked_sub(start.elapsed()) { + Some(remaining) => remaining, + None => break, + } + } +} + +/// Usercall `send`. See the ABI documentation for more information. +#[unstable(feature = "sgx_platform", issue = "56975")] +pub fn send(event_set: u64, tcs: Option<Tcs>) -> IoResult<()> { + unsafe { raw::send(event_set, tcs).from_sgx_result() } +} + +/// Usercall `insecure_time`. See the ABI documentation for more information. +#[unstable(feature = "sgx_platform", issue = "56975")] +pub fn insecure_time() -> Duration { + let t = unsafe { raw::insecure_time() }; + Duration::new(t / 1_000_000_000, (t % 1_000_000_000) as _) +} + +/// Usercall `alloc`. See the ABI documentation for more information. +#[unstable(feature = "sgx_platform", issue = "56975")] +pub fn alloc(size: usize, alignment: usize) -> IoResult<*mut u8> { + unsafe { raw::alloc(size, alignment).from_sgx_result() } +} + +#[unstable(feature = "sgx_platform", issue = "56975")] +#[doc(inline)] +pub use self::raw::free; + +fn check_os_error(err: Result) -> i32 { + // FIXME: not sure how to make sure all variants of Error are covered + if err == Error::NotFound as _ + || err == Error::PermissionDenied as _ + || err == Error::ConnectionRefused as _ + || err == Error::ConnectionReset as _ + || err == Error::ConnectionAborted as _ + || err == Error::NotConnected as _ + || err == Error::AddrInUse as _ + || err == Error::AddrNotAvailable as _ + || err == Error::BrokenPipe as _ + || err == Error::AlreadyExists as _ + || err == Error::WouldBlock as _ + || err == Error::InvalidInput as _ + || err == Error::InvalidData as _ + || err == Error::TimedOut as _ + || err == Error::WriteZero as _ + || err == Error::Interrupted as _ + || err == Error::Other as _ + || err == Error::UnexpectedEof as _ + || ((Error::UserRangeStart as _)..=(Error::UserRangeEnd as _)).contains(&err) + { + err + } else { + rtabort!("Usercall: returned invalid error value {err}") + } +} + +trait FromSgxResult { + type Return; + + fn from_sgx_result(self) -> IoResult<Self::Return>; +} + +impl<T> FromSgxResult for (Result, T) { + type Return = T; + + fn from_sgx_result(self) -> IoResult<Self::Return> { + if self.0 == RESULT_SUCCESS { + Ok(self.1) + } else { + Err(IoError::from_raw_os_error(check_os_error(self.0))) + } + } +} + +impl FromSgxResult for Result { + type Return = (); + + fn from_sgx_result(self) -> IoResult<Self::Return> { + if self == RESULT_SUCCESS { + Ok(()) + } else { + Err(IoError::from_raw_os_error(check_os_error(self))) + } + } +} diff --git a/library/std/src/sys/sgx/abi/usercalls/raw.rs b/library/std/src/sys/sgx/abi/usercalls/raw.rs new file mode 100644 index 000000000..4267b96cc --- /dev/null +++ b/library/std/src/sys/sgx/abi/usercalls/raw.rs @@ -0,0 +1,251 @@ +#![allow(unused)] + +#[unstable(feature = "sgx_platform", issue = "56975")] +pub use fortanix_sgx_abi::*; + +use crate::num::NonZeroU64; +use crate::ptr::NonNull; + +#[repr(C)] +struct UsercallReturn(u64, u64); + +extern "C" { + fn usercall(nr: NonZeroU64, p1: u64, p2: u64, abort: u64, p3: u64, p4: u64) -> UsercallReturn; +} + +/// Performs the raw usercall operation as defined in the ABI calling convention. +/// +/// # Safety +/// +/// The caller must ensure to pass parameters appropriate for the usercall `nr` +/// and to observe all requirements specified in the ABI. +/// +/// # Panics +/// +/// Panics if `nr` is `0`. +#[unstable(feature = "sgx_platform", issue = "56975")] +#[inline] +pub unsafe fn do_usercall( + nr: NonZeroU64, + p1: u64, + p2: u64, + p3: u64, + p4: u64, + abort: bool, +) -> (u64, u64) { + let UsercallReturn(a, b) = unsafe { usercall(nr, p1, p2, abort as _, p3, p4) }; + (a, b) +} + +type Register = u64; + +trait RegisterArgument { + fn from_register(_: Register) -> Self; + fn into_register(self) -> Register; +} + +trait ReturnValue { + fn from_registers(call: &'static str, regs: (Register, Register)) -> Self; +} + +macro_rules! define_usercalls { + ($(fn $f:ident($($n:ident: $t:ty),*) $(-> $r:tt)*; )*) => { + /// Usercall numbers as per the ABI. + #[repr(u64)] + #[unstable(feature = "sgx_platform", issue = "56975")] + #[derive(Copy, Clone, Hash, PartialEq, Eq, Debug)] + #[allow(missing_docs, non_camel_case_types)] + #[non_exhaustive] + pub enum Usercalls { + #[doc(hidden)] + __enclave_usercalls_invalid = 0, + $($f,)* + } + + $(enclave_usercalls_internal_define_usercalls!(def fn $f($($n: $t),*) $(-> $r)*);)* + }; +} + +macro_rules! define_ra { + (< $i:ident > $t:ty) => { + impl<$i> RegisterArgument for $t { + fn from_register(a: Register) -> Self { + a as _ + } + fn into_register(self) -> Register { + self as _ + } + } + }; + ($i:ty as $t:ty) => { + impl RegisterArgument for $t { + fn from_register(a: Register) -> Self { + a as $i as _ + } + fn into_register(self) -> Register { + self as $i as _ + } + } + }; + ($t:ty) => { + impl RegisterArgument for $t { + fn from_register(a: Register) -> Self { + a as _ + } + fn into_register(self) -> Register { + self as _ + } + } + }; +} + +define_ra!(Register); +define_ra!(i64); +define_ra!(u32); +define_ra!(u32 as i32); +define_ra!(u16); +define_ra!(u16 as i16); +define_ra!(u8); +define_ra!(u8 as i8); +define_ra!(usize); +define_ra!(usize as isize); +define_ra!(<T> *const T); +define_ra!(<T> *mut T); + +impl RegisterArgument for bool { + fn from_register(a: Register) -> bool { + if a != 0 { true } else { false } + } + fn into_register(self) -> Register { + self as _ + } +} + +impl<T: RegisterArgument> RegisterArgument for Option<NonNull<T>> { + fn from_register(a: Register) -> Option<NonNull<T>> { + NonNull::new(a as _) + } + fn into_register(self) -> Register { + self.map_or(0 as _, NonNull::as_ptr) as _ + } +} + +impl ReturnValue for ! { + fn from_registers(call: &'static str, _regs: (Register, Register)) -> Self { + rtabort!("Usercall {call}: did not expect to be re-entered"); + } +} + +impl ReturnValue for () { + fn from_registers(call: &'static str, usercall_retval: (Register, Register)) -> Self { + rtassert!(usercall_retval.0 == 0); + rtassert!(usercall_retval.1 == 0); + () + } +} + +impl<T: RegisterArgument> ReturnValue for T { + fn from_registers(call: &'static str, usercall_retval: (Register, Register)) -> Self { + rtassert!(usercall_retval.1 == 0); + T::from_register(usercall_retval.0) + } +} + +impl<T: RegisterArgument, U: RegisterArgument> ReturnValue for (T, U) { + fn from_registers(_call: &'static str, regs: (Register, Register)) -> Self { + (T::from_register(regs.0), U::from_register(regs.1)) + } +} + +macro_rules! return_type_is_abort { + (!) => { + true + }; + ($r:ty) => { + false + }; +} + +// In this macro: using `$r:tt` because `$r:ty` doesn't match ! in `return_type_is_abort` +macro_rules! enclave_usercalls_internal_define_usercalls { + (def fn $f:ident($n1:ident: $t1:ty, $n2:ident: $t2:ty, + $n3:ident: $t3:ty, $n4:ident: $t4:ty) -> $r:tt) => ( + /// This is the raw function definition, see the ABI documentation for + /// more information. + #[unstable(feature = "sgx_platform", issue = "56975")] + #[inline(always)] + pub unsafe fn $f($n1: $t1, $n2: $t2, $n3: $t3, $n4: $t4) -> $r { + ReturnValue::from_registers(stringify!($f), unsafe { do_usercall( + rtunwrap!(Some, NonZeroU64::new(Usercalls::$f as Register)), + RegisterArgument::into_register($n1), + RegisterArgument::into_register($n2), + RegisterArgument::into_register($n3), + RegisterArgument::into_register($n4), + return_type_is_abort!($r) + ) }) + } + ); + (def fn $f:ident($n1:ident: $t1:ty, $n2:ident: $t2:ty, $n3:ident: $t3:ty) -> $r:tt) => ( + /// This is the raw function definition, see the ABI documentation for + /// more information. + #[unstable(feature = "sgx_platform", issue = "56975")] + #[inline(always)] + pub unsafe fn $f($n1: $t1, $n2: $t2, $n3: $t3) -> $r { + ReturnValue::from_registers(stringify!($f), unsafe { do_usercall( + rtunwrap!(Some, NonZeroU64::new(Usercalls::$f as Register)), + RegisterArgument::into_register($n1), + RegisterArgument::into_register($n2), + RegisterArgument::into_register($n3), + 0, + return_type_is_abort!($r) + ) }) + } + ); + (def fn $f:ident($n1:ident: $t1:ty, $n2:ident: $t2:ty) -> $r:tt) => ( + /// This is the raw function definition, see the ABI documentation for + /// more information. + #[unstable(feature = "sgx_platform", issue = "56975")] + #[inline(always)] + pub unsafe fn $f($n1: $t1, $n2: $t2) -> $r { + ReturnValue::from_registers(stringify!($f), unsafe { do_usercall( + rtunwrap!(Some, NonZeroU64::new(Usercalls::$f as Register)), + RegisterArgument::into_register($n1), + RegisterArgument::into_register($n2), + 0,0, + return_type_is_abort!($r) + ) }) + } + ); + (def fn $f:ident($n1:ident: $t1:ty) -> $r:tt) => ( + /// This is the raw function definition, see the ABI documentation for + /// more information. + #[unstable(feature = "sgx_platform", issue = "56975")] + #[inline(always)] + pub unsafe fn $f($n1: $t1) -> $r { + ReturnValue::from_registers(stringify!($f), unsafe { do_usercall( + rtunwrap!(Some, NonZeroU64::new(Usercalls::$f as Register)), + RegisterArgument::into_register($n1), + 0,0,0, + return_type_is_abort!($r) + ) }) + } + ); + (def fn $f:ident() -> $r:tt) => ( + /// This is the raw function definition, see the ABI documentation for + /// more information. + #[unstable(feature = "sgx_platform", issue = "56975")] + #[inline(always)] + pub unsafe fn $f() -> $r { + ReturnValue::from_registers(stringify!($f), unsafe { do_usercall( + rtunwrap!(Some, NonZeroU64::new(Usercalls::$f as Register)), + 0,0,0,0, + return_type_is_abort!($r) + ) }) + } + ); + (def fn $f:ident($($n:ident: $t:ty),*)) => ( + enclave_usercalls_internal_define_usercalls!(def fn $f($($n: $t),*) -> ()); + ); +} + +invoke_with_usercalls!(define_usercalls); diff --git a/library/std/src/sys/sgx/abi/usercalls/tests.rs b/library/std/src/sys/sgx/abi/usercalls/tests.rs new file mode 100644 index 000000000..cbf7d7d54 --- /dev/null +++ b/library/std/src/sys/sgx/abi/usercalls/tests.rs @@ -0,0 +1,30 @@ +use super::alloc::copy_to_userspace; +use super::alloc::User; + +#[test] +fn test_copy_function() { + let mut src = [0u8; 100]; + let mut dst = User::<[u8]>::uninitialized(100); + + for i in 0..src.len() { + src[i] = i as _; + } + + for size in 0..48 { + // For all possible alignment + for offset in 0..8 { + // overwrite complete dst + dst.copy_from_enclave(&[0u8; 100]); + + // Copy src[0..size] to dst + offset + unsafe { copy_to_userspace(src.as_ptr(), dst.as_mut_ptr().offset(offset), size) }; + + // Verify copy + for byte in 0..size { + unsafe { + assert_eq!(*dst.as_ptr().offset(offset + byte as isize), src[byte as usize]); + } + } + } + } +} diff --git a/library/std/src/sys/sgx/alloc.rs b/library/std/src/sys/sgx/alloc.rs new file mode 100644 index 000000000..4aea28cb8 --- /dev/null +++ b/library/std/src/sys/sgx/alloc.rs @@ -0,0 +1,98 @@ +use crate::alloc::{GlobalAlloc, Layout, System}; +use crate::ptr; +use crate::sys::sgx::abi::mem as sgx_mem; +use core::sync::atomic::{AtomicBool, Ordering}; + +use super::waitqueue::SpinMutex; + +// Using a SpinMutex because we never want to exit the enclave waiting for the +// allocator. +// +// The current allocator here is the `dlmalloc` crate which we've got included +// in the rust-lang/rust repository as a submodule. The crate is a port of +// dlmalloc.c from C to Rust. +#[cfg_attr(test, linkage = "available_externally")] +#[export_name = "_ZN16__rust_internals3std3sys3sgx5alloc8DLMALLOCE"] +static DLMALLOC: SpinMutex<dlmalloc::Dlmalloc<Sgx>> = + SpinMutex::new(dlmalloc::Dlmalloc::new_with_allocator(Sgx {})); + +struct Sgx; + +unsafe impl dlmalloc::Allocator for Sgx { + /// Allocs system resources + fn alloc(&self, _size: usize) -> (*mut u8, usize, u32) { + static INIT: AtomicBool = AtomicBool::new(false); + + // No ordering requirement since this function is protected by the global lock. + if !INIT.swap(true, Ordering::Relaxed) { + (sgx_mem::heap_base() as _, sgx_mem::heap_size(), 0) + } else { + (ptr::null_mut(), 0, 0) + } + } + + fn remap(&self, _ptr: *mut u8, _oldsize: usize, _newsize: usize, _can_move: bool) -> *mut u8 { + ptr::null_mut() + } + + fn free_part(&self, _ptr: *mut u8, _oldsize: usize, _newsize: usize) -> bool { + false + } + + fn free(&self, _ptr: *mut u8, _size: usize) -> bool { + return false; + } + + fn can_release_part(&self, _flags: u32) -> bool { + false + } + + fn allocates_zeros(&self) -> bool { + false + } + + fn page_size(&self) -> usize { + 0x1000 + } +} + +#[stable(feature = "alloc_system_type", since = "1.28.0")] +unsafe impl GlobalAlloc for System { + #[inline] + unsafe fn alloc(&self, layout: Layout) -> *mut u8 { + // SAFETY: the caller must uphold the safety contract for `malloc` + unsafe { DLMALLOC.lock().malloc(layout.size(), layout.align()) } + } + + #[inline] + unsafe fn alloc_zeroed(&self, layout: Layout) -> *mut u8 { + // SAFETY: the caller must uphold the safety contract for `malloc` + unsafe { DLMALLOC.lock().calloc(layout.size(), layout.align()) } + } + + #[inline] + unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) { + // SAFETY: the caller must uphold the safety contract for `malloc` + unsafe { DLMALLOC.lock().free(ptr, layout.size(), layout.align()) } + } + + #[inline] + unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 { + // SAFETY: the caller must uphold the safety contract for `malloc` + unsafe { DLMALLOC.lock().realloc(ptr, layout.size(), layout.align(), new_size) } + } +} + +// The following functions are needed by libunwind. These symbols are named +// in pre-link args for the target specification, so keep that in sync. +#[cfg(not(test))] +#[no_mangle] +pub unsafe extern "C" fn __rust_c_alloc(size: usize, align: usize) -> *mut u8 { + unsafe { crate::alloc::alloc(Layout::from_size_align_unchecked(size, align)) } +} + +#[cfg(not(test))] +#[no_mangle] +pub unsafe extern "C" fn __rust_c_dealloc(ptr: *mut u8, size: usize, align: usize) { + unsafe { crate::alloc::dealloc(ptr, Layout::from_size_align_unchecked(size, align)) } +} diff --git a/library/std/src/sys/sgx/args.rs b/library/std/src/sys/sgx/args.rs new file mode 100644 index 000000000..ef4176c4a --- /dev/null +++ b/library/std/src/sys/sgx/args.rs @@ -0,0 +1,59 @@ +use super::abi::usercalls::{alloc, raw::ByteBuffer}; +use crate::ffi::OsString; +use crate::fmt; +use crate::slice; +use crate::sync::atomic::{AtomicUsize, Ordering}; +use crate::sys::os_str::Buf; +use crate::sys_common::FromInner; + +#[cfg_attr(test, linkage = "available_externally")] +#[export_name = "_ZN16__rust_internals3std3sys3sgx4args4ARGSE"] +static ARGS: AtomicUsize = AtomicUsize::new(0); +type ArgsStore = Vec<OsString>; + +#[cfg_attr(test, allow(dead_code))] +pub unsafe fn init(argc: isize, argv: *const *const u8) { + if argc != 0 { + let args = unsafe { alloc::User::<[ByteBuffer]>::from_raw_parts(argv as _, argc as _) }; + let args = args + .iter() + .map(|a| OsString::from_inner(Buf { inner: a.copy_user_buffer() })) + .collect::<ArgsStore>(); + ARGS.store(Box::into_raw(Box::new(args)) as _, Ordering::Relaxed); + } +} + +pub fn args() -> Args { + let args = unsafe { (ARGS.load(Ordering::Relaxed) as *const ArgsStore).as_ref() }; + if let Some(args) = args { Args(args.iter()) } else { Args([].iter()) } +} + +pub struct Args(slice::Iter<'static, OsString>); + +impl fmt::Debug for Args { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.as_slice().fmt(f) + } +} + +impl Iterator for Args { + type Item = OsString; + fn next(&mut self) -> Option<OsString> { + self.0.next().cloned() + } + fn size_hint(&self) -> (usize, Option<usize>) { + self.0.size_hint() + } +} + +impl ExactSizeIterator for Args { + fn len(&self) -> usize { + self.0.len() + } +} + +impl DoubleEndedIterator for Args { + fn next_back(&mut self) -> Option<OsString> { + self.0.next_back().cloned() + } +} diff --git a/library/std/src/sys/sgx/condvar.rs b/library/std/src/sys/sgx/condvar.rs new file mode 100644 index 000000000..36534e0ef --- /dev/null +++ b/library/std/src/sys/sgx/condvar.rs @@ -0,0 +1,45 @@ +use crate::sys::locks::Mutex; +use crate::sys_common::lazy_box::{LazyBox, LazyInit}; +use crate::time::Duration; + +use super::waitqueue::{SpinMutex, WaitQueue, WaitVariable}; + +pub struct Condvar { + inner: SpinMutex<WaitVariable<()>>, +} + +pub(crate) type MovableCondvar = LazyBox<Condvar>; + +impl LazyInit for Condvar { + fn init() -> Box<Self> { + Box::new(Self::new()) + } +} + +impl Condvar { + pub const fn new() -> Condvar { + Condvar { inner: SpinMutex::new(WaitVariable::new(())) } + } + + #[inline] + pub unsafe fn notify_one(&self) { + let _ = WaitQueue::notify_one(self.inner.lock()); + } + + #[inline] + pub unsafe fn notify_all(&self) { + let _ = WaitQueue::notify_all(self.inner.lock()); + } + + pub unsafe fn wait(&self, mutex: &Mutex) { + let guard = self.inner.lock(); + WaitQueue::wait(guard, || unsafe { mutex.unlock() }); + unsafe { mutex.lock() } + } + + pub unsafe fn wait_timeout(&self, mutex: &Mutex, dur: Duration) -> bool { + let success = WaitQueue::wait_timeout(&self.inner, dur, || unsafe { mutex.unlock() }); + unsafe { mutex.lock() }; + success + } +} diff --git a/library/std/src/sys/sgx/env.rs b/library/std/src/sys/sgx/env.rs new file mode 100644 index 000000000..8043b7c52 --- /dev/null +++ b/library/std/src/sys/sgx/env.rs @@ -0,0 +1,9 @@ +pub mod os { + pub const FAMILY: &str = ""; + pub const OS: &str = ""; + pub const DLL_PREFIX: &str = ""; + pub const DLL_SUFFIX: &str = ".sgxs"; + pub const DLL_EXTENSION: &str = "sgxs"; + pub const EXE_SUFFIX: &str = ".sgxs"; + pub const EXE_EXTENSION: &str = "sgxs"; +} diff --git a/library/std/src/sys/sgx/fd.rs b/library/std/src/sys/sgx/fd.rs new file mode 100644 index 000000000..e5dc5b5ad --- /dev/null +++ b/library/std/src/sys/sgx/fd.rs @@ -0,0 +1,84 @@ +use fortanix_sgx_abi::Fd; + +use super::abi::usercalls; +use crate::io::{self, IoSlice, IoSliceMut}; +use crate::mem; +use crate::sys::{AsInner, FromInner, IntoInner}; + +#[derive(Debug)] +pub struct FileDesc { + fd: Fd, +} + +impl FileDesc { + pub fn new(fd: Fd) -> FileDesc { + FileDesc { fd: fd } + } + + pub fn raw(&self) -> Fd { + self.fd + } + + /// Extracts the actual file descriptor without closing it. + pub fn into_raw(self) -> Fd { + let fd = self.fd; + mem::forget(self); + fd + } + + pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> { + usercalls::read(self.fd, &mut [IoSliceMut::new(buf)]) + } + + pub fn read_vectored(&self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> { + usercalls::read(self.fd, bufs) + } + + #[inline] + pub fn is_read_vectored(&self) -> bool { + true + } + + pub fn write(&self, buf: &[u8]) -> io::Result<usize> { + usercalls::write(self.fd, &[IoSlice::new(buf)]) + } + + pub fn write_vectored(&self, bufs: &[IoSlice<'_>]) -> io::Result<usize> { + usercalls::write(self.fd, bufs) + } + + #[inline] + pub fn is_write_vectored(&self) -> bool { + true + } + + pub fn flush(&self) -> io::Result<()> { + usercalls::flush(self.fd) + } +} + +impl AsInner<Fd> for FileDesc { + fn as_inner(&self) -> &Fd { + &self.fd + } +} + +impl IntoInner<Fd> for FileDesc { + fn into_inner(self) -> Fd { + let fd = self.fd; + mem::forget(self); + fd + } +} + +impl FromInner<Fd> for FileDesc { + fn from_inner(fd: Fd) -> FileDesc { + FileDesc { fd } + } +} + +impl Drop for FileDesc { + fn drop(&mut self) { + usercalls::close(self.fd) + } +} diff --git a/library/std/src/sys/sgx/memchr.rs b/library/std/src/sys/sgx/memchr.rs new file mode 100644 index 000000000..996748219 --- /dev/null +++ b/library/std/src/sys/sgx/memchr.rs @@ -0,0 +1 @@ +pub use core::slice::memchr::{memchr, memrchr}; diff --git a/library/std/src/sys/sgx/mod.rs b/library/std/src/sys/sgx/mod.rs new file mode 100644 index 000000000..696400670 --- /dev/null +++ b/library/std/src/sys/sgx/mod.rs @@ -0,0 +1,167 @@ +//! System bindings for the Fortanix SGX platform +//! +//! This module contains the facade (aka platform-specific) implementations of +//! OS level functionality for Fortanix SGX. +#![deny(unsafe_op_in_unsafe_fn)] + +use crate::io::ErrorKind; +use crate::sync::atomic::{AtomicBool, Ordering}; + +pub mod abi; +mod waitqueue; + +pub mod alloc; +pub mod args; +#[path = "../unix/cmath.rs"] +pub mod cmath; +pub mod env; +pub mod fd; +#[path = "../unsupported/fs.rs"] +pub mod fs; +#[path = "../unsupported/io.rs"] +pub mod io; +pub mod memchr; +pub mod net; +pub mod os; +#[path = "../unix/os_str.rs"] +pub mod os_str; +pub mod path; +#[path = "../unsupported/pipe.rs"] +pub mod pipe; +#[path = "../unsupported/process.rs"] +pub mod process; +pub mod stdio; +pub mod thread; +pub mod thread_local_key; +pub mod time; + +mod condvar; +mod mutex; +mod rwlock; + +pub mod locks { + pub use super::condvar::*; + pub use super::mutex::*; + pub use super::rwlock::*; +} + +// SAFETY: must be called only once during runtime initialization. +// NOTE: this is not guaranteed to run, for example when Rust code is called externally. +pub unsafe fn init(argc: isize, argv: *const *const u8) { + unsafe { + args::init(argc, argv); + } +} + +// SAFETY: must be called only once during runtime cleanup. +// NOTE: this is not guaranteed to run, for example when the program aborts. +pub unsafe fn cleanup() {} + +/// This function is used to implement functionality that simply doesn't exist. +/// Programs relying on this functionality will need to deal with the error. +pub fn unsupported<T>() -> crate::io::Result<T> { + Err(unsupported_err()) +} + +pub fn unsupported_err() -> crate::io::Error { + crate::io::const_io_error!(ErrorKind::Unsupported, "operation not supported on SGX yet") +} + +/// This function is used to implement various functions that doesn't exist, +/// but the lack of which might not be reason for error. If no error is +/// returned, the program might very well be able to function normally. This is +/// what happens when `SGX_INEFFECTIVE_ERROR` is set to `true`. If it is +/// `false`, the behavior is the same as `unsupported`. +pub fn sgx_ineffective<T>(v: T) -> crate::io::Result<T> { + static SGX_INEFFECTIVE_ERROR: AtomicBool = AtomicBool::new(false); + if SGX_INEFFECTIVE_ERROR.load(Ordering::Relaxed) { + Err(crate::io::const_io_error!( + ErrorKind::Uncategorized, + "operation can't be trusted to have any effect on SGX", + )) + } else { + Ok(v) + } +} + +pub fn decode_error_kind(code: i32) -> ErrorKind { + use fortanix_sgx_abi::Error; + + // FIXME: not sure how to make sure all variants of Error are covered + if code == Error::NotFound as _ { + ErrorKind::NotFound + } else if code == Error::PermissionDenied as _ { + ErrorKind::PermissionDenied + } else if code == Error::ConnectionRefused as _ { + ErrorKind::ConnectionRefused + } else if code == Error::ConnectionReset as _ { + ErrorKind::ConnectionReset + } else if code == Error::ConnectionAborted as _ { + ErrorKind::ConnectionAborted + } else if code == Error::NotConnected as _ { + ErrorKind::NotConnected + } else if code == Error::AddrInUse as _ { + ErrorKind::AddrInUse + } else if code == Error::AddrNotAvailable as _ { + ErrorKind::AddrNotAvailable + } else if code == Error::BrokenPipe as _ { + ErrorKind::BrokenPipe + } else if code == Error::AlreadyExists as _ { + ErrorKind::AlreadyExists + } else if code == Error::WouldBlock as _ { + ErrorKind::WouldBlock + } else if code == Error::InvalidInput as _ { + ErrorKind::InvalidInput + } else if code == Error::InvalidData as _ { + ErrorKind::InvalidData + } else if code == Error::TimedOut as _ { + ErrorKind::TimedOut + } else if code == Error::WriteZero as _ { + ErrorKind::WriteZero + } else if code == Error::Interrupted as _ { + ErrorKind::Interrupted + } else if code == Error::Other as _ { + ErrorKind::Uncategorized + } else if code == Error::UnexpectedEof as _ { + ErrorKind::UnexpectedEof + } else { + ErrorKind::Uncategorized + } +} + +pub fn abort_internal() -> ! { + abi::usercalls::exit(true) +} + +// This function is needed by the panic runtime. The symbol is named in +// pre-link args for the target specification, so keep that in sync. +#[cfg(not(test))] +#[no_mangle] +// NB. used by both libunwind and libpanic_abort +pub extern "C" fn __rust_abort() { + abort_internal(); +} + +pub mod rand { + pub fn rdrand64() -> u64 { + unsafe { + let mut ret: u64 = 0; + for _ in 0..10 { + if crate::arch::x86_64::_rdrand64_step(&mut ret) == 1 { + return ret; + } + } + rtabort!("Failed to obtain random data"); + } + } +} + +pub fn hashmap_random_keys() -> (u64, u64) { + (self::rand::rdrand64(), self::rand::rdrand64()) +} + +pub use crate::sys_common::{AsInner, FromInner, IntoInner}; + +pub trait TryIntoInner<Inner>: Sized { + fn try_into_inner(self) -> Result<Inner, Self>; +} diff --git a/library/std/src/sys/sgx/mutex.rs b/library/std/src/sys/sgx/mutex.rs new file mode 100644 index 000000000..513cd77fd --- /dev/null +++ b/library/std/src/sys/sgx/mutex.rs @@ -0,0 +1,62 @@ +use super::waitqueue::{try_lock_or_false, SpinMutex, WaitQueue, WaitVariable}; +use crate::sys_common::lazy_box::{LazyBox, LazyInit}; + +pub struct Mutex { + inner: SpinMutex<WaitVariable<bool>>, +} + +// not movable: see UnsafeList implementation +pub(crate) type MovableMutex = LazyBox<Mutex>; + +impl LazyInit for Mutex { + fn init() -> Box<Self> { + Box::new(Self::new()) + } +} + +// Implementation according to “Operating Systems: Three Easy Pieces”, chapter 28 +impl Mutex { + pub const fn new() -> Mutex { + Mutex { inner: SpinMutex::new(WaitVariable::new(false)) } + } + + #[inline] + pub unsafe fn init(&mut self) {} + + #[inline] + pub unsafe fn lock(&self) { + let mut guard = self.inner.lock(); + if *guard.lock_var() { + // Another thread has the lock, wait + WaitQueue::wait(guard, || {}) + // Another thread has passed the lock to us + } else { + // We are just now obtaining the lock + *guard.lock_var_mut() = true; + } + } + + #[inline] + pub unsafe fn unlock(&self) { + let guard = self.inner.lock(); + if let Err(mut guard) = WaitQueue::notify_one(guard) { + // No other waiters, unlock + *guard.lock_var_mut() = false; + } else { + // There was a thread waiting, just pass the lock + } + } + + #[inline] + pub unsafe fn try_lock(&self) -> bool { + let mut guard = try_lock_or_false!(self.inner); + if *guard.lock_var() { + // Another thread has the lock + false + } else { + // We are just now obtaining the lock + *guard.lock_var_mut() = true; + true + } + } +} diff --git a/library/std/src/sys/sgx/net.rs b/library/std/src/sys/sgx/net.rs new file mode 100644 index 000000000..4c4cd7d1d --- /dev/null +++ b/library/std/src/sys/sgx/net.rs @@ -0,0 +1,541 @@ +use crate::error; +use crate::fmt; +use crate::io::{self, IoSlice, IoSliceMut}; +use crate::net::{Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr, ToSocketAddrs}; +use crate::sync::Arc; +use crate::sys::fd::FileDesc; +use crate::sys::{sgx_ineffective, unsupported, AsInner, FromInner, IntoInner, TryIntoInner}; +use crate::time::Duration; + +use super::abi::usercalls; + +const DEFAULT_FAKE_TTL: u32 = 64; + +#[derive(Debug, Clone)] +pub struct Socket { + inner: Arc<FileDesc>, + local_addr: Option<String>, +} + +impl Socket { + fn new(fd: usercalls::raw::Fd, local_addr: String) -> Socket { + Socket { inner: Arc::new(FileDesc::new(fd)), local_addr: Some(local_addr) } + } +} + +impl AsInner<FileDesc> for Socket { + fn as_inner(&self) -> &FileDesc { + &self.inner + } +} + +impl TryIntoInner<FileDesc> for Socket { + fn try_into_inner(self) -> Result<FileDesc, Socket> { + let Socket { inner, local_addr } = self; + Arc::try_unwrap(inner).map_err(|inner| Socket { inner, local_addr }) + } +} + +impl FromInner<(FileDesc, Option<String>)> for Socket { + fn from_inner((inner, local_addr): (FileDesc, Option<String>)) -> Socket { + Socket { inner: Arc::new(inner), local_addr } + } +} + +#[derive(Clone)] +pub struct TcpStream { + inner: Socket, + peer_addr: Option<String>, +} + +impl fmt::Debug for TcpStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut res = f.debug_struct("TcpStream"); + + if let Some(ref addr) = self.inner.local_addr { + res.field("addr", addr); + } + + if let Some(ref peer) = self.peer_addr { + res.field("peer", peer); + } + + res.field("fd", &self.inner.inner.as_inner()).finish() + } +} + +fn io_err_to_addr(result: io::Result<&SocketAddr>) -> io::Result<String> { + match result { + Ok(saddr) => Ok(saddr.to_string()), + // need to downcast twice because io::Error::into_inner doesn't return the original + // value if the conversion fails + Err(e) => { + if e.get_ref().and_then(|e| e.downcast_ref::<NonIpSockAddr>()).is_some() { + Ok(e.into_inner().unwrap().downcast::<NonIpSockAddr>().unwrap().host) + } else { + Err(e) + } + } + } +} + +fn addr_to_sockaddr(addr: &Option<String>) -> io::Result<SocketAddr> { + addr.as_ref() + .ok_or(io::ErrorKind::AddrNotAvailable)? + .to_socket_addrs() + // unwrap OK: if an iterator is returned, we're guaranteed to get exactly one entry + .map(|mut it| it.next().unwrap()) +} + +impl TcpStream { + pub fn connect(addr: io::Result<&SocketAddr>) -> io::Result<TcpStream> { + let addr = io_err_to_addr(addr)?; + let (fd, local_addr, peer_addr) = usercalls::connect_stream(&addr)?; + Ok(TcpStream { inner: Socket::new(fd, local_addr), peer_addr: Some(peer_addr) }) + } + + pub fn connect_timeout(addr: &SocketAddr, dur: Duration) -> io::Result<TcpStream> { + if dur == Duration::default() { + return Err(io::const_io_error!( + io::ErrorKind::InvalidInput, + "cannot set a 0 duration timeout", + )); + } + Self::connect(Ok(addr)) // FIXME: ignoring timeout + } + + pub fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> { + match dur { + Some(dur) if dur == Duration::default() => { + return Err(io::const_io_error!( + io::ErrorKind::InvalidInput, + "cannot set a 0 duration timeout", + )); + } + _ => sgx_ineffective(()), + } + } + + pub fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> { + match dur { + Some(dur) if dur == Duration::default() => { + return Err(io::const_io_error!( + io::ErrorKind::InvalidInput, + "cannot set a 0 duration timeout", + )); + } + _ => sgx_ineffective(()), + } + } + + pub fn read_timeout(&self) -> io::Result<Option<Duration>> { + sgx_ineffective(None) + } + + pub fn write_timeout(&self) -> io::Result<Option<Duration>> { + sgx_ineffective(None) + } + + pub fn peek(&self, _: &mut [u8]) -> io::Result<usize> { + Ok(0) + } + + pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> { + self.inner.inner.read(buf) + } + + pub fn read_vectored(&self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> { + self.inner.inner.read_vectored(bufs) + } + + #[inline] + pub fn is_read_vectored(&self) -> bool { + self.inner.inner.is_read_vectored() + } + + pub fn write(&self, buf: &[u8]) -> io::Result<usize> { + self.inner.inner.write(buf) + } + + pub fn write_vectored(&self, bufs: &[IoSlice<'_>]) -> io::Result<usize> { + self.inner.inner.write_vectored(bufs) + } + + #[inline] + pub fn is_write_vectored(&self) -> bool { + self.inner.inner.is_write_vectored() + } + + pub fn peer_addr(&self) -> io::Result<SocketAddr> { + addr_to_sockaddr(&self.peer_addr) + } + + pub fn socket_addr(&self) -> io::Result<SocketAddr> { + addr_to_sockaddr(&self.inner.local_addr) + } + + pub fn shutdown(&self, _: Shutdown) -> io::Result<()> { + sgx_ineffective(()) + } + + pub fn duplicate(&self) -> io::Result<TcpStream> { + Ok(self.clone()) + } + + pub fn set_linger(&self, _: Option<Duration>) -> io::Result<()> { + sgx_ineffective(()) + } + + pub fn linger(&self) -> io::Result<Option<Duration>> { + sgx_ineffective(None) + } + + pub fn set_nodelay(&self, _: bool) -> io::Result<()> { + sgx_ineffective(()) + } + + pub fn nodelay(&self) -> io::Result<bool> { + sgx_ineffective(false) + } + + pub fn set_ttl(&self, _: u32) -> io::Result<()> { + sgx_ineffective(()) + } + + pub fn ttl(&self) -> io::Result<u32> { + sgx_ineffective(DEFAULT_FAKE_TTL) + } + + pub fn take_error(&self) -> io::Result<Option<io::Error>> { + Ok(None) + } + + pub fn set_nonblocking(&self, _: bool) -> io::Result<()> { + sgx_ineffective(()) + } +} + +impl AsInner<Socket> for TcpStream { + fn as_inner(&self) -> &Socket { + &self.inner + } +} + +// `Inner` includes `peer_addr` so that a `TcpStream` maybe correctly +// reconstructed if `Socket::try_into_inner` fails. +impl IntoInner<(Socket, Option<String>)> for TcpStream { + fn into_inner(self) -> (Socket, Option<String>) { + (self.inner, self.peer_addr) + } +} + +impl FromInner<(Socket, Option<String>)> for TcpStream { + fn from_inner((inner, peer_addr): (Socket, Option<String>)) -> TcpStream { + TcpStream { inner, peer_addr } + } +} + +#[derive(Clone)] +pub struct TcpListener { + inner: Socket, +} + +impl fmt::Debug for TcpListener { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut res = f.debug_struct("TcpListener"); + + if let Some(ref addr) = self.inner.local_addr { + res.field("addr", addr); + } + + res.field("fd", &self.inner.inner.as_inner()).finish() + } +} + +impl TcpListener { + pub fn bind(addr: io::Result<&SocketAddr>) -> io::Result<TcpListener> { + let addr = io_err_to_addr(addr)?; + let (fd, local_addr) = usercalls::bind_stream(&addr)?; + Ok(TcpListener { inner: Socket::new(fd, local_addr) }) + } + + pub fn socket_addr(&self) -> io::Result<SocketAddr> { + addr_to_sockaddr(&self.inner.local_addr) + } + + pub fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> { + let (fd, local_addr, peer_addr) = usercalls::accept_stream(self.inner.inner.raw())?; + let peer_addr = Some(peer_addr); + let ret_peer = addr_to_sockaddr(&peer_addr).unwrap_or_else(|_| ([0; 4], 0).into()); + Ok((TcpStream { inner: Socket::new(fd, local_addr), peer_addr }, ret_peer)) + } + + pub fn duplicate(&self) -> io::Result<TcpListener> { + Ok(self.clone()) + } + + pub fn set_ttl(&self, _: u32) -> io::Result<()> { + sgx_ineffective(()) + } + + pub fn ttl(&self) -> io::Result<u32> { + sgx_ineffective(DEFAULT_FAKE_TTL) + } + + pub fn set_only_v6(&self, _: bool) -> io::Result<()> { + sgx_ineffective(()) + } + + pub fn only_v6(&self) -> io::Result<bool> { + sgx_ineffective(false) + } + + pub fn take_error(&self) -> io::Result<Option<io::Error>> { + Ok(None) + } + + pub fn set_nonblocking(&self, _: bool) -> io::Result<()> { + sgx_ineffective(()) + } +} + +impl AsInner<Socket> for TcpListener { + fn as_inner(&self) -> &Socket { + &self.inner + } +} + +impl IntoInner<Socket> for TcpListener { + fn into_inner(self) -> Socket { + self.inner + } +} + +impl FromInner<Socket> for TcpListener { + fn from_inner(inner: Socket) -> TcpListener { + TcpListener { inner } + } +} + +pub struct UdpSocket(!); + +impl UdpSocket { + pub fn bind(_: io::Result<&SocketAddr>) -> io::Result<UdpSocket> { + unsupported() + } + + pub fn peer_addr(&self) -> io::Result<SocketAddr> { + self.0 + } + + pub fn socket_addr(&self) -> io::Result<SocketAddr> { + self.0 + } + + pub fn recv_from(&self, _: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + self.0 + } + + pub fn peek_from(&self, _: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + self.0 + } + + pub fn send_to(&self, _: &[u8], _: &SocketAddr) -> io::Result<usize> { + self.0 + } + + pub fn duplicate(&self) -> io::Result<UdpSocket> { + self.0 + } + + pub fn set_read_timeout(&self, _: Option<Duration>) -> io::Result<()> { + self.0 + } + + pub fn set_write_timeout(&self, _: Option<Duration>) -> io::Result<()> { + self.0 + } + + pub fn read_timeout(&self) -> io::Result<Option<Duration>> { + self.0 + } + + pub fn write_timeout(&self) -> io::Result<Option<Duration>> { + self.0 + } + + pub fn set_broadcast(&self, _: bool) -> io::Result<()> { + self.0 + } + + pub fn broadcast(&self) -> io::Result<bool> { + self.0 + } + + pub fn set_multicast_loop_v4(&self, _: bool) -> io::Result<()> { + self.0 + } + + pub fn multicast_loop_v4(&self) -> io::Result<bool> { + self.0 + } + + pub fn set_multicast_ttl_v4(&self, _: u32) -> io::Result<()> { + self.0 + } + + pub fn multicast_ttl_v4(&self) -> io::Result<u32> { + self.0 + } + + pub fn set_multicast_loop_v6(&self, _: bool) -> io::Result<()> { + self.0 + } + + pub fn multicast_loop_v6(&self) -> io::Result<bool> { + self.0 + } + + pub fn join_multicast_v4(&self, _: &Ipv4Addr, _: &Ipv4Addr) -> io::Result<()> { + self.0 + } + + pub fn join_multicast_v6(&self, _: &Ipv6Addr, _: u32) -> io::Result<()> { + self.0 + } + + pub fn leave_multicast_v4(&self, _: &Ipv4Addr, _: &Ipv4Addr) -> io::Result<()> { + self.0 + } + + pub fn leave_multicast_v6(&self, _: &Ipv6Addr, _: u32) -> io::Result<()> { + self.0 + } + + pub fn set_ttl(&self, _: u32) -> io::Result<()> { + self.0 + } + + pub fn ttl(&self) -> io::Result<u32> { + self.0 + } + + pub fn take_error(&self) -> io::Result<Option<io::Error>> { + self.0 + } + + pub fn set_nonblocking(&self, _: bool) -> io::Result<()> { + self.0 + } + + pub fn recv(&self, _: &mut [u8]) -> io::Result<usize> { + self.0 + } + + pub fn peek(&self, _: &mut [u8]) -> io::Result<usize> { + self.0 + } + + pub fn send(&self, _: &[u8]) -> io::Result<usize> { + self.0 + } + + pub fn connect(&self, _: io::Result<&SocketAddr>) -> io::Result<()> { + self.0 + } +} + +impl fmt::Debug for UdpSocket { + fn fmt(&self, _f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0 + } +} + +#[derive(Debug)] +pub struct NonIpSockAddr { + host: String, +} + +impl error::Error for NonIpSockAddr { + #[allow(deprecated)] + fn description(&self) -> &str { + "Failed to convert address to SocketAddr" + } +} + +impl fmt::Display for NonIpSockAddr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Failed to convert address to SocketAddr: {}", self.host) + } +} + +pub struct LookupHost(!); + +impl LookupHost { + fn new(host: String) -> io::Result<LookupHost> { + Err(io::Error::new(io::ErrorKind::Uncategorized, NonIpSockAddr { host })) + } + + pub fn port(&self) -> u16 { + self.0 + } +} + +impl Iterator for LookupHost { + type Item = SocketAddr; + fn next(&mut self) -> Option<SocketAddr> { + self.0 + } +} + +impl TryFrom<&str> for LookupHost { + type Error = io::Error; + + fn try_from(v: &str) -> io::Result<LookupHost> { + LookupHost::new(v.to_owned()) + } +} + +impl<'a> TryFrom<(&'a str, u16)> for LookupHost { + type Error = io::Error; + + fn try_from((host, port): (&'a str, u16)) -> io::Result<LookupHost> { + LookupHost::new(format!("{host}:{port}")) + } +} + +#[allow(bad_style)] +pub mod netc { + pub const AF_INET: u8 = 0; + pub const AF_INET6: u8 = 1; + pub type sa_family_t = u8; + + #[derive(Copy, Clone)] + pub struct in_addr { + pub s_addr: u32, + } + + #[derive(Copy, Clone)] + pub struct sockaddr_in { + pub sin_family: sa_family_t, + pub sin_port: u16, + pub sin_addr: in_addr, + } + + #[derive(Copy, Clone)] + pub struct in6_addr { + pub s6_addr: [u8; 16], + } + + #[derive(Copy, Clone)] + pub struct sockaddr_in6 { + pub sin6_family: sa_family_t, + pub sin6_port: u16, + pub sin6_addr: in6_addr, + pub sin6_flowinfo: u32, + pub sin6_scope_id: u32, + } + + #[derive(Copy, Clone)] + pub struct sockaddr {} +} diff --git a/library/std/src/sys/sgx/os.rs b/library/std/src/sys/sgx/os.rs new file mode 100644 index 000000000..5da0257f3 --- /dev/null +++ b/library/std/src/sys/sgx/os.rs @@ -0,0 +1,140 @@ +use fortanix_sgx_abi::{Error, RESULT_SUCCESS}; + +use crate::collections::HashMap; +use crate::error::Error as StdError; +use crate::ffi::{OsStr, OsString}; +use crate::fmt; +use crate::io; +use crate::marker::PhantomData; +use crate::path::{self, PathBuf}; +use crate::str; +use crate::sync::atomic::{AtomicUsize, Ordering}; +use crate::sync::Mutex; +use crate::sync::Once; +use crate::sys::{decode_error_kind, sgx_ineffective, unsupported}; +use crate::vec; + +pub fn errno() -> i32 { + RESULT_SUCCESS +} + +pub fn error_string(errno: i32) -> String { + if errno == RESULT_SUCCESS { + "operation successful".into() + } else if ((Error::UserRangeStart as _)..=(Error::UserRangeEnd as _)).contains(&errno) { + format!("user-specified error {errno:08x}") + } else { + decode_error_kind(errno).as_str().into() + } +} + +pub fn getcwd() -> io::Result<PathBuf> { + unsupported() +} + +pub fn chdir(_: &path::Path) -> io::Result<()> { + sgx_ineffective(()) +} + +pub struct SplitPaths<'a>(!, PhantomData<&'a ()>); + +pub fn split_paths(_unparsed: &OsStr) -> SplitPaths<'_> { + panic!("unsupported") +} + +impl<'a> Iterator for SplitPaths<'a> { + type Item = PathBuf; + fn next(&mut self) -> Option<PathBuf> { + self.0 + } +} + +#[derive(Debug)] +pub struct JoinPathsError; + +pub fn join_paths<I, T>(_paths: I) -> Result<OsString, JoinPathsError> +where + I: Iterator<Item = T>, + T: AsRef<OsStr>, +{ + Err(JoinPathsError) +} + +impl fmt::Display for JoinPathsError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + "not supported in SGX yet".fmt(f) + } +} + +impl StdError for JoinPathsError { + #[allow(deprecated)] + fn description(&self) -> &str { + "not supported in SGX yet" + } +} + +pub fn current_exe() -> io::Result<PathBuf> { + unsupported() +} + +#[cfg_attr(test, linkage = "available_externally")] +#[export_name = "_ZN16__rust_internals3std3sys3sgx2os3ENVE"] +static ENV: AtomicUsize = AtomicUsize::new(0); +#[cfg_attr(test, linkage = "available_externally")] +#[export_name = "_ZN16__rust_internals3std3sys3sgx2os8ENV_INITE"] +static ENV_INIT: Once = Once::new(); +type EnvStore = Mutex<HashMap<OsString, OsString>>; + +fn get_env_store() -> Option<&'static EnvStore> { + unsafe { (ENV.load(Ordering::Relaxed) as *const EnvStore).as_ref() } +} + +fn create_env_store() -> &'static EnvStore { + ENV_INIT.call_once(|| { + ENV.store(Box::into_raw(Box::new(EnvStore::default())) as _, Ordering::Relaxed) + }); + unsafe { &*(ENV.load(Ordering::Relaxed) as *const EnvStore) } +} + +pub type Env = vec::IntoIter<(OsString, OsString)>; + +pub fn env() -> Env { + let clone_to_vec = |map: &HashMap<OsString, OsString>| -> Vec<_> { + map.iter().map(|(k, v)| (k.clone(), v.clone())).collect() + }; + + get_env_store().map(|env| clone_to_vec(&env.lock().unwrap())).unwrap_or_default().into_iter() +} + +pub fn getenv(k: &OsStr) -> Option<OsString> { + get_env_store().and_then(|s| s.lock().unwrap().get(k).cloned()) +} + +pub fn setenv(k: &OsStr, v: &OsStr) -> io::Result<()> { + let (k, v) = (k.to_owned(), v.to_owned()); + create_env_store().lock().unwrap().insert(k, v); + Ok(()) +} + +pub fn unsetenv(k: &OsStr) -> io::Result<()> { + if let Some(env) = get_env_store() { + env.lock().unwrap().remove(k); + } + Ok(()) +} + +pub fn temp_dir() -> PathBuf { + panic!("no filesystem in SGX") +} + +pub fn home_dir() -> Option<PathBuf> { + None +} + +pub fn exit(code: i32) -> ! { + super::abi::exit_with_code(code as _) +} + +pub fn getpid() -> u32 { + panic!("no pids in SGX") +} diff --git a/library/std/src/sys/sgx/path.rs b/library/std/src/sys/sgx/path.rs new file mode 100644 index 000000000..c805c15e7 --- /dev/null +++ b/library/std/src/sys/sgx/path.rs @@ -0,0 +1,25 @@ +use crate::ffi::OsStr; +use crate::io; +use crate::path::{Path, PathBuf, Prefix}; +use crate::sys::unsupported; + +#[inline] +pub fn is_sep_byte(b: u8) -> bool { + b == b'/' +} + +#[inline] +pub fn is_verbatim_sep(b: u8) -> bool { + b == b'/' +} + +pub fn parse_prefix(_: &OsStr) -> Option<Prefix<'_>> { + None +} + +pub const MAIN_SEP_STR: &str = "/"; +pub const MAIN_SEP: char = '/'; + +pub(crate) fn absolute(_path: &Path) -> io::Result<PathBuf> { + unsupported() +} diff --git a/library/std/src/sys/sgx/rwlock.rs b/library/std/src/sys/sgx/rwlock.rs new file mode 100644 index 000000000..a97fb9ab0 --- /dev/null +++ b/library/std/src/sys/sgx/rwlock.rs @@ -0,0 +1,212 @@ +#[cfg(test)] +mod tests; + +use crate::num::NonZeroUsize; +use crate::sys_common::lazy_box::{LazyBox, LazyInit}; + +use super::waitqueue::{ + try_lock_or_false, NotifiedTcs, SpinMutex, SpinMutexGuard, WaitQueue, WaitVariable, +}; +use crate::mem; + +pub struct RwLock { + readers: SpinMutex<WaitVariable<Option<NonZeroUsize>>>, + writer: SpinMutex<WaitVariable<bool>>, +} + +pub(crate) type MovableRwLock = LazyBox<RwLock>; + +impl LazyInit for RwLock { + fn init() -> Box<Self> { + Box::new(Self::new()) + } +} + +// Check at compile time that RwLock size matches C definition (see test_c_rwlock_initializer below) +// +// # Safety +// Never called, as it is a compile time check. +#[allow(dead_code)] +unsafe fn rw_lock_size_assert(r: RwLock) { + unsafe { mem::transmute::<RwLock, [u8; 144]>(r) }; +} + +impl RwLock { + pub const fn new() -> RwLock { + RwLock { + readers: SpinMutex::new(WaitVariable::new(None)), + writer: SpinMutex::new(WaitVariable::new(false)), + } + } + + #[inline] + pub unsafe fn read(&self) { + let mut rguard = self.readers.lock(); + let wguard = self.writer.lock(); + if *wguard.lock_var() || !wguard.queue_empty() { + // Another thread has or is waiting for the write lock, wait + drop(wguard); + WaitQueue::wait(rguard, || {}); + // Another thread has passed the lock to us + } else { + // No waiting writers, acquire the read lock + *rguard.lock_var_mut() = + NonZeroUsize::new(rguard.lock_var().map_or(0, |n| n.get()) + 1); + } + } + + #[inline] + pub unsafe fn try_read(&self) -> bool { + let mut rguard = try_lock_or_false!(self.readers); + let wguard = try_lock_or_false!(self.writer); + if *wguard.lock_var() || !wguard.queue_empty() { + // Another thread has or is waiting for the write lock + false + } else { + // No waiting writers, acquire the read lock + *rguard.lock_var_mut() = + NonZeroUsize::new(rguard.lock_var().map_or(0, |n| n.get()) + 1); + true + } + } + + #[inline] + pub unsafe fn write(&self) { + let rguard = self.readers.lock(); + let mut wguard = self.writer.lock(); + if *wguard.lock_var() || rguard.lock_var().is_some() { + // Another thread has the lock, wait + drop(rguard); + WaitQueue::wait(wguard, || {}); + // Another thread has passed the lock to us + } else { + // We are just now obtaining the lock + *wguard.lock_var_mut() = true; + } + } + + #[inline] + pub unsafe fn try_write(&self) -> bool { + let rguard = try_lock_or_false!(self.readers); + let mut wguard = try_lock_or_false!(self.writer); + if *wguard.lock_var() || rguard.lock_var().is_some() { + // Another thread has the lock + false + } else { + // We are just now obtaining the lock + *wguard.lock_var_mut() = true; + true + } + } + + #[inline] + unsafe fn __read_unlock( + &self, + mut rguard: SpinMutexGuard<'_, WaitVariable<Option<NonZeroUsize>>>, + wguard: SpinMutexGuard<'_, WaitVariable<bool>>, + ) { + *rguard.lock_var_mut() = NonZeroUsize::new(rguard.lock_var().unwrap().get() - 1); + if rguard.lock_var().is_some() { + // There are other active readers + } else { + if let Ok(mut wguard) = WaitQueue::notify_one(wguard) { + // A writer was waiting, pass the lock + *wguard.lock_var_mut() = true; + wguard.drop_after(rguard); + } else { + // No writers were waiting, the lock is released + rtassert!(rguard.queue_empty()); + } + } + } + + #[inline] + pub unsafe fn read_unlock(&self) { + let rguard = self.readers.lock(); + let wguard = self.writer.lock(); + unsafe { self.__read_unlock(rguard, wguard) }; + } + + #[inline] + unsafe fn __write_unlock( + &self, + rguard: SpinMutexGuard<'_, WaitVariable<Option<NonZeroUsize>>>, + wguard: SpinMutexGuard<'_, WaitVariable<bool>>, + ) { + match WaitQueue::notify_one(wguard) { + Err(mut wguard) => { + // No writers waiting, release the write lock + *wguard.lock_var_mut() = false; + if let Ok(mut rguard) = WaitQueue::notify_all(rguard) { + // One or more readers were waiting, pass the lock to them + if let NotifiedTcs::All { count } = rguard.notified_tcs() { + *rguard.lock_var_mut() = Some(count) + } else { + unreachable!() // called notify_all + } + rguard.drop_after(wguard); + } else { + // No readers waiting, the lock is released + } + } + Ok(wguard) => { + // There was a thread waiting for write, just pass the lock + wguard.drop_after(rguard); + } + } + } + + #[inline] + pub unsafe fn write_unlock(&self) { + let rguard = self.readers.lock(); + let wguard = self.writer.lock(); + unsafe { self.__write_unlock(rguard, wguard) }; + } + + // only used by __rust_rwlock_unlock below + #[inline] + #[cfg_attr(test, allow(dead_code))] + unsafe fn unlock(&self) { + let rguard = self.readers.lock(); + let wguard = self.writer.lock(); + if *wguard.lock_var() == true { + unsafe { self.__write_unlock(rguard, wguard) }; + } else { + unsafe { self.__read_unlock(rguard, wguard) }; + } + } +} + +// The following functions are needed by libunwind. These symbols are named +// in pre-link args for the target specification, so keep that in sync. +#[cfg(not(test))] +const EINVAL: i32 = 22; + +#[cfg(not(test))] +#[no_mangle] +pub unsafe extern "C" fn __rust_rwlock_rdlock(p: *mut RwLock) -> i32 { + if p.is_null() { + return EINVAL; + } + unsafe { (*p).read() }; + return 0; +} + +#[cfg(not(test))] +#[no_mangle] +pub unsafe extern "C" fn __rust_rwlock_wrlock(p: *mut RwLock) -> i32 { + if p.is_null() { + return EINVAL; + } + unsafe { (*p).write() }; + return 0; +} +#[cfg(not(test))] +#[no_mangle] +pub unsafe extern "C" fn __rust_rwlock_unlock(p: *mut RwLock) -> i32 { + if p.is_null() { + return EINVAL; + } + unsafe { (*p).unlock() }; + return 0; +} diff --git a/library/std/src/sys/sgx/rwlock/tests.rs b/library/std/src/sys/sgx/rwlock/tests.rs new file mode 100644 index 000000000..479996115 --- /dev/null +++ b/library/std/src/sys/sgx/rwlock/tests.rs @@ -0,0 +1,31 @@ +use super::*; + +// Verify that the byte pattern libunwind uses to initialize an RwLock is +// equivalent to the value of RwLock::new(). If the value changes, +// `src/UnwindRustSgx.h` in libunwind needs to be changed too. +#[test] +fn test_c_rwlock_initializer() { + #[rustfmt::skip] + const C_RWLOCK_INIT: &[u8] = &[ + /* 0x00 */ 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + /* 0x10 */ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + /* 0x20 */ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + /* 0x30 */ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + /* 0x40 */ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + /* 0x50 */ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + /* 0x60 */ 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + /* 0x70 */ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + /* 0x80 */ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + ]; + + // For the test to work, we need the padding/unused bytes in RwLock to be + // initialized as 0. In practice, this is the case with statics. + static RUST_RWLOCK_INIT: RwLock = RwLock::new(); + + unsafe { + // If the assertion fails, that not necessarily an issue with the value + // of C_RWLOCK_INIT. It might just be an issue with the way padding + // bytes are initialized in the test code. + assert_eq!(&crate::mem::transmute_copy::<_, [u8; 144]>(&RUST_RWLOCK_INIT), C_RWLOCK_INIT); + }; +} diff --git a/library/std/src/sys/sgx/stdio.rs b/library/std/src/sys/sgx/stdio.rs new file mode 100644 index 000000000..2e680e740 --- /dev/null +++ b/library/std/src/sys/sgx/stdio.rs @@ -0,0 +1,88 @@ +use fortanix_sgx_abi as abi; + +use crate::io; +#[cfg(not(test))] +use crate::slice; +#[cfg(not(test))] +use crate::str; +use crate::sys::fd::FileDesc; + +pub struct Stdin(()); +pub struct Stdout(()); +pub struct Stderr(()); + +fn with_std_fd<F: FnOnce(&FileDesc) -> R, R>(fd: abi::Fd, f: F) -> R { + let fd = FileDesc::new(fd); + let ret = f(&fd); + fd.into_raw(); + ret +} + +impl Stdin { + pub const fn new() -> Stdin { + Stdin(()) + } +} + +impl io::Read for Stdin { + fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { + with_std_fd(abi::FD_STDIN, |fd| fd.read(buf)) + } +} + +impl Stdout { + pub const fn new() -> Stdout { + Stdout(()) + } +} + +impl io::Write for Stdout { + fn write(&mut self, buf: &[u8]) -> io::Result<usize> { + with_std_fd(abi::FD_STDOUT, |fd| fd.write(buf)) + } + + fn flush(&mut self) -> io::Result<()> { + with_std_fd(abi::FD_STDOUT, |fd| fd.flush()) + } +} + +impl Stderr { + pub const fn new() -> Stderr { + Stderr(()) + } +} + +impl io::Write for Stderr { + fn write(&mut self, buf: &[u8]) -> io::Result<usize> { + with_std_fd(abi::FD_STDERR, |fd| fd.write(buf)) + } + + fn flush(&mut self) -> io::Result<()> { + with_std_fd(abi::FD_STDERR, |fd| fd.flush()) + } +} + +pub const STDIN_BUF_SIZE: usize = crate::sys_common::io::DEFAULT_BUF_SIZE; + +pub fn is_ebadf(err: &io::Error) -> bool { + // FIXME: Rust normally maps Unix EBADF to `Uncategorized` + err.raw_os_error() == Some(abi::Error::BrokenPipe as _) +} + +pub fn panic_output() -> Option<impl io::Write> { + super::abi::panic::SgxPanicOutput::new() +} + +// This function is needed by libunwind. The symbol is named in pre-link args +// for the target specification, so keep that in sync. +#[cfg(not(test))] +#[no_mangle] +pub unsafe extern "C" fn __rust_print_err(m: *mut u8, s: i32) { + if s < 0 { + return; + } + let buf = unsafe { slice::from_raw_parts(m as *const u8, s as _) }; + if let Ok(s) = str::from_utf8(&buf[..buf.iter().position(|&b| b == 0).unwrap_or(buf.len())]) { + eprint!("{s}"); + } +} diff --git a/library/std/src/sys/sgx/thread.rs b/library/std/src/sys/sgx/thread.rs new file mode 100644 index 000000000..d745a6196 --- /dev/null +++ b/library/std/src/sys/sgx/thread.rs @@ -0,0 +1,152 @@ +#![cfg_attr(test, allow(dead_code))] // why is this necessary? +use super::unsupported; +use crate::ffi::CStr; +use crate::io; +use crate::num::NonZeroUsize; +use crate::time::Duration; + +use super::abi::usercalls; + +pub struct Thread(task_queue::JoinHandle); + +pub const DEFAULT_MIN_STACK_SIZE: usize = 4096; + +pub use self::task_queue::JoinNotifier; + +mod task_queue { + use super::wait_notify; + use crate::sync::{Mutex, MutexGuard, Once}; + + pub type JoinHandle = wait_notify::Waiter; + + pub struct JoinNotifier(Option<wait_notify::Notifier>); + + impl Drop for JoinNotifier { + fn drop(&mut self) { + self.0.take().unwrap().notify(); + } + } + + pub(super) struct Task { + p: Box<dyn FnOnce()>, + done: JoinNotifier, + } + + impl Task { + pub(super) fn new(p: Box<dyn FnOnce()>) -> (Task, JoinHandle) { + let (done, recv) = wait_notify::new(); + let done = JoinNotifier(Some(done)); + (Task { p, done }, recv) + } + + pub(super) fn run(self) -> JoinNotifier { + (self.p)(); + self.done + } + } + + #[cfg_attr(test, linkage = "available_externally")] + #[export_name = "_ZN16__rust_internals3std3sys3sgx6thread15TASK_QUEUE_INITE"] + static TASK_QUEUE_INIT: Once = Once::new(); + #[cfg_attr(test, linkage = "available_externally")] + #[export_name = "_ZN16__rust_internals3std3sys3sgx6thread10TASK_QUEUEE"] + static mut TASK_QUEUE: Option<Mutex<Vec<Task>>> = None; + + pub(super) fn lock() -> MutexGuard<'static, Vec<Task>> { + unsafe { + TASK_QUEUE_INIT.call_once(|| TASK_QUEUE = Some(Default::default())); + TASK_QUEUE.as_ref().unwrap().lock().unwrap() + } + } +} + +/// This module provides a synchronization primitive that does not use thread +/// local variables. This is needed for signaling that a thread has finished +/// execution. The signal is sent once all TLS destructors have finished at +/// which point no new thread locals should be created. +pub mod wait_notify { + use super::super::waitqueue::{SpinMutex, WaitQueue, WaitVariable}; + use crate::sync::Arc; + + pub struct Notifier(Arc<SpinMutex<WaitVariable<bool>>>); + + impl Notifier { + /// Notify the waiter. The waiter is either notified right away (if + /// currently blocked in `Waiter::wait()`) or later when it calls the + /// `Waiter::wait()` method. + pub fn notify(self) { + let mut guard = self.0.lock(); + *guard.lock_var_mut() = true; + let _ = WaitQueue::notify_one(guard); + } + } + + pub struct Waiter(Arc<SpinMutex<WaitVariable<bool>>>); + + impl Waiter { + /// Wait for a notification. If `Notifier::notify()` has already been + /// called, this will return immediately, otherwise the current thread + /// is blocked until notified. + pub fn wait(self) { + let guard = self.0.lock(); + if *guard.lock_var() { + return; + } + WaitQueue::wait(guard, || {}); + } + } + + pub fn new() -> (Notifier, Waiter) { + let inner = Arc::new(SpinMutex::new(WaitVariable::new(false))); + (Notifier(inner.clone()), Waiter(inner)) + } +} + +impl Thread { + // unsafe: see thread::Builder::spawn_unchecked for safety requirements + pub unsafe fn new(_stack: usize, p: Box<dyn FnOnce()>) -> io::Result<Thread> { + let mut queue_lock = task_queue::lock(); + unsafe { usercalls::launch_thread()? }; + let (task, handle) = task_queue::Task::new(p); + queue_lock.push(task); + Ok(Thread(handle)) + } + + pub(super) fn entry() -> JoinNotifier { + let mut pending_tasks = task_queue::lock(); + let task = rtunwrap!(Some, pending_tasks.pop()); + drop(pending_tasks); // make sure to not hold the task queue lock longer than necessary + task.run() + } + + pub fn yield_now() { + let wait_error = rtunwrap!(Err, usercalls::wait(0, usercalls::raw::WAIT_NO)); + rtassert!(wait_error.kind() == io::ErrorKind::WouldBlock); + } + + pub fn set_name(_name: &CStr) { + // FIXME: could store this pointer in TLS somewhere + } + + pub fn sleep(dur: Duration) { + usercalls::wait_timeout(0, dur, || true); + } + + pub fn join(self) { + self.0.wait(); + } +} + +pub fn available_parallelism() -> io::Result<NonZeroUsize> { + unsupported() +} + +pub mod guard { + pub type Guard = !; + pub unsafe fn current() -> Option<Guard> { + None + } + pub unsafe fn init() -> Option<Guard> { + None + } +} diff --git a/library/std/src/sys/sgx/thread_local_key.rs b/library/std/src/sys/sgx/thread_local_key.rs new file mode 100644 index 000000000..b21784475 --- /dev/null +++ b/library/std/src/sys/sgx/thread_local_key.rs @@ -0,0 +1,28 @@ +use super::abi::tls::{Key as AbiKey, Tls}; + +pub type Key = usize; + +#[inline] +pub unsafe fn create(dtor: Option<unsafe extern "C" fn(*mut u8)>) -> Key { + Tls::create(dtor).as_usize() +} + +#[inline] +pub unsafe fn set(key: Key, value: *mut u8) { + Tls::set(AbiKey::from_usize(key), value) +} + +#[inline] +pub unsafe fn get(key: Key) -> *mut u8 { + Tls::get(AbiKey::from_usize(key)) +} + +#[inline] +pub unsafe fn destroy(key: Key) { + Tls::destroy(AbiKey::from_usize(key)) +} + +#[inline] +pub fn requires_synchronized_create() -> bool { + false +} diff --git a/library/std/src/sys/sgx/time.rs b/library/std/src/sys/sgx/time.rs new file mode 100644 index 000000000..db4cf2804 --- /dev/null +++ b/library/std/src/sys/sgx/time.rs @@ -0,0 +1,46 @@ +use super::abi::usercalls; +use crate::time::Duration; + +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)] +pub struct Instant(Duration); + +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)] +pub struct SystemTime(Duration); + +pub const UNIX_EPOCH: SystemTime = SystemTime(Duration::from_secs(0)); + +impl Instant { + pub fn now() -> Instant { + Instant(usercalls::insecure_time()) + } + + pub fn checked_sub_instant(&self, other: &Instant) -> Option<Duration> { + self.0.checked_sub(other.0) + } + + pub fn checked_add_duration(&self, other: &Duration) -> Option<Instant> { + Some(Instant(self.0.checked_add(*other)?)) + } + + pub fn checked_sub_duration(&self, other: &Duration) -> Option<Instant> { + Some(Instant(self.0.checked_sub(*other)?)) + } +} + +impl SystemTime { + pub fn now() -> SystemTime { + SystemTime(usercalls::insecure_time()) + } + + pub fn sub_time(&self, other: &SystemTime) -> Result<Duration, Duration> { + self.0.checked_sub(other.0).ok_or_else(|| other.0 - self.0) + } + + pub fn checked_add_duration(&self, other: &Duration) -> Option<SystemTime> { + Some(SystemTime(self.0.checked_add(*other)?)) + } + + pub fn checked_sub_duration(&self, other: &Duration) -> Option<SystemTime> { + Some(SystemTime(self.0.checked_sub(*other)?)) + } +} diff --git a/library/std/src/sys/sgx/waitqueue/mod.rs b/library/std/src/sys/sgx/waitqueue/mod.rs new file mode 100644 index 000000000..61bb11d9a --- /dev/null +++ b/library/std/src/sys/sgx/waitqueue/mod.rs @@ -0,0 +1,240 @@ +//! A simple queue implementation for synchronization primitives. +//! +//! This queue is used to implement condition variable and mutexes. +//! +//! Users of this API are expected to use the `WaitVariable<T>` type. Since +//! that type is not `Sync`, it needs to be protected by e.g., a `SpinMutex` to +//! allow shared access. +//! +//! Since userspace may send spurious wake-ups, the wakeup event state is +//! recorded in the enclave. The wakeup event state is protected by a spinlock. +//! The queue and associated wait state are stored in a `WaitVariable`. + +#[cfg(test)] +mod tests; + +mod spin_mutex; +mod unsafe_list; + +use crate::num::NonZeroUsize; +use crate::ops::{Deref, DerefMut}; +use crate::time::Duration; + +use super::abi::thread; +use super::abi::usercalls; +use fortanix_sgx_abi::{Tcs, EV_UNPARK, WAIT_INDEFINITE}; + +pub use self::spin_mutex::{try_lock_or_false, SpinMutex, SpinMutexGuard}; +use self::unsafe_list::{UnsafeList, UnsafeListEntry}; + +/// An queue entry in a `WaitQueue`. +struct WaitEntry { + /// TCS address of the thread that is waiting + tcs: Tcs, + /// Whether this thread has been notified to be awoken + wake: bool, +} + +/// Data stored with a `WaitQueue` alongside it. This ensures accesses to the +/// queue and the data are synchronized, since the type itself is not `Sync`. +/// +/// Consumers of this API should use a synchronization primitive for shared +/// access, such as `SpinMutex`. +#[derive(Default)] +pub struct WaitVariable<T> { + queue: WaitQueue, + lock: T, +} + +impl<T> WaitVariable<T> { + pub const fn new(var: T) -> Self { + WaitVariable { queue: WaitQueue::new(), lock: var } + } + + pub fn queue_empty(&self) -> bool { + self.queue.is_empty() + } + + pub fn lock_var(&self) -> &T { + &self.lock + } + + pub fn lock_var_mut(&mut self) -> &mut T { + &mut self.lock + } +} + +#[derive(Copy, Clone)] +pub enum NotifiedTcs { + Single(Tcs), + All { count: NonZeroUsize }, +} + +/// An RAII guard that will notify a set of target threads as well as unlock +/// a mutex on drop. +pub struct WaitGuard<'a, T: 'a> { + mutex_guard: Option<SpinMutexGuard<'a, WaitVariable<T>>>, + notified_tcs: NotifiedTcs, +} + +/// A queue of threads that are waiting on some synchronization primitive. +/// +/// `UnsafeList` entries are allocated on the waiting thread's stack. This +/// avoids any global locking that might happen in the heap allocator. This is +/// safe because the waiting thread will not return from that stack frame until +/// after it is notified. The notifying thread ensures to clean up any +/// references to the list entries before sending the wakeup event. +pub struct WaitQueue { + // We use an inner Mutex here to protect the data in the face of spurious + // wakeups. + inner: UnsafeList<SpinMutex<WaitEntry>>, +} +unsafe impl Send for WaitQueue {} + +impl Default for WaitQueue { + fn default() -> Self { + Self::new() + } +} + +impl<'a, T> WaitGuard<'a, T> { + /// Returns which TCSes will be notified when this guard drops. + pub fn notified_tcs(&self) -> NotifiedTcs { + self.notified_tcs + } + + /// Drop this `WaitGuard`, after dropping another `guard`. + pub fn drop_after<U>(self, guard: U) { + drop(guard); + drop(self); + } +} + +impl<'a, T> Deref for WaitGuard<'a, T> { + type Target = SpinMutexGuard<'a, WaitVariable<T>>; + + fn deref(&self) -> &Self::Target { + self.mutex_guard.as_ref().unwrap() + } +} + +impl<'a, T> DerefMut for WaitGuard<'a, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.mutex_guard.as_mut().unwrap() + } +} + +impl<'a, T> Drop for WaitGuard<'a, T> { + fn drop(&mut self) { + drop(self.mutex_guard.take()); + let target_tcs = match self.notified_tcs { + NotifiedTcs::Single(tcs) => Some(tcs), + NotifiedTcs::All { .. } => None, + }; + rtunwrap!(Ok, usercalls::send(EV_UNPARK, target_tcs)); + } +} + +impl WaitQueue { + pub const fn new() -> Self { + WaitQueue { inner: UnsafeList::new() } + } + + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + /// Adds the calling thread to the `WaitVariable`'s wait queue, then wait + /// until a wakeup event. + /// + /// This function does not return until this thread has been awoken. + pub fn wait<T, F: FnOnce()>(mut guard: SpinMutexGuard<'_, WaitVariable<T>>, before_wait: F) { + // very unsafe: check requirements of UnsafeList::push + unsafe { + let mut entry = UnsafeListEntry::new(SpinMutex::new(WaitEntry { + tcs: thread::current(), + wake: false, + })); + let entry = guard.queue.inner.push(&mut entry); + drop(guard); + before_wait(); + while !entry.lock().wake { + // don't panic, this would invalidate `entry` during unwinding + let eventset = rtunwrap!(Ok, usercalls::wait(EV_UNPARK, WAIT_INDEFINITE)); + rtassert!(eventset & EV_UNPARK == EV_UNPARK); + } + } + } + + /// Adds the calling thread to the `WaitVariable`'s wait queue, then wait + /// until a wakeup event or timeout. If event was observed, returns true. + /// If not, it will remove the calling thread from the wait queue. + pub fn wait_timeout<T, F: FnOnce()>( + lock: &SpinMutex<WaitVariable<T>>, + timeout: Duration, + before_wait: F, + ) -> bool { + // very unsafe: check requirements of UnsafeList::push + unsafe { + let mut entry = UnsafeListEntry::new(SpinMutex::new(WaitEntry { + tcs: thread::current(), + wake: false, + })); + let entry_lock = lock.lock().queue.inner.push(&mut entry); + before_wait(); + usercalls::wait_timeout(EV_UNPARK, timeout, || entry_lock.lock().wake); + // acquire the wait queue's lock first to avoid deadlock. + let mut guard = lock.lock(); + let success = entry_lock.lock().wake; + if !success { + // nobody is waking us up, so remove our entry from the wait queue. + guard.queue.inner.remove(&mut entry); + } + success + } + } + + /// Either find the next waiter on the wait queue, or return the mutex + /// guard unchanged. + /// + /// If a waiter is found, a `WaitGuard` is returned which will notify the + /// waiter when it is dropped. + pub fn notify_one<T>( + mut guard: SpinMutexGuard<'_, WaitVariable<T>>, + ) -> Result<WaitGuard<'_, T>, SpinMutexGuard<'_, WaitVariable<T>>> { + unsafe { + if let Some(entry) = guard.queue.inner.pop() { + let mut entry_guard = entry.lock(); + let tcs = entry_guard.tcs; + entry_guard.wake = true; + drop(entry); + Ok(WaitGuard { mutex_guard: Some(guard), notified_tcs: NotifiedTcs::Single(tcs) }) + } else { + Err(guard) + } + } + } + + /// Either find any and all waiters on the wait queue, or return the mutex + /// guard unchanged. + /// + /// If at least one waiter is found, a `WaitGuard` is returned which will + /// notify all waiters when it is dropped. + pub fn notify_all<T>( + mut guard: SpinMutexGuard<'_, WaitVariable<T>>, + ) -> Result<WaitGuard<'_, T>, SpinMutexGuard<'_, WaitVariable<T>>> { + unsafe { + let mut count = 0; + while let Some(entry) = guard.queue.inner.pop() { + count += 1; + let mut entry_guard = entry.lock(); + entry_guard.wake = true; + } + if let Some(count) = NonZeroUsize::new(count) { + Ok(WaitGuard { mutex_guard: Some(guard), notified_tcs: NotifiedTcs::All { count } }) + } else { + Err(guard) + } + } + } +} diff --git a/library/std/src/sys/sgx/waitqueue/spin_mutex.rs b/library/std/src/sys/sgx/waitqueue/spin_mutex.rs new file mode 100644 index 000000000..f6e851cca --- /dev/null +++ b/library/std/src/sys/sgx/waitqueue/spin_mutex.rs @@ -0,0 +1,80 @@ +//! Trivial spinlock-based implementation of `sync::Mutex`. +// FIXME: Perhaps use Intel TSX to avoid locking? + +#[cfg(test)] +mod tests; + +use crate::cell::UnsafeCell; +use crate::hint; +use crate::ops::{Deref, DerefMut}; +use crate::sync::atomic::{AtomicBool, Ordering}; + +#[derive(Default)] +pub struct SpinMutex<T> { + value: UnsafeCell<T>, + lock: AtomicBool, +} + +unsafe impl<T: Send> Send for SpinMutex<T> {} +unsafe impl<T: Send> Sync for SpinMutex<T> {} + +pub struct SpinMutexGuard<'a, T: 'a> { + mutex: &'a SpinMutex<T>, +} + +impl<'a, T> !Send for SpinMutexGuard<'a, T> {} +unsafe impl<'a, T: Sync> Sync for SpinMutexGuard<'a, T> {} + +impl<T> SpinMutex<T> { + pub const fn new(value: T) -> Self { + SpinMutex { value: UnsafeCell::new(value), lock: AtomicBool::new(false) } + } + + #[inline(always)] + pub fn lock(&self) -> SpinMutexGuard<'_, T> { + loop { + match self.try_lock() { + None => { + while self.lock.load(Ordering::Relaxed) { + hint::spin_loop() + } + } + Some(guard) => return guard, + } + } + } + + #[inline(always)] + pub fn try_lock(&self) -> Option<SpinMutexGuard<'_, T>> { + if self.lock.compare_exchange(false, true, Ordering::Acquire, Ordering::Acquire).is_ok() { + Some(SpinMutexGuard { mutex: self }) + } else { + None + } + } +} + +/// Lock the Mutex or return false. +pub macro try_lock_or_false($e:expr) { + if let Some(v) = $e.try_lock() { v } else { return false } +} + +impl<'a, T> Deref for SpinMutexGuard<'a, T> { + type Target = T; + + fn deref(&self) -> &T { + unsafe { &*self.mutex.value.get() } + } +} + +impl<'a, T> DerefMut for SpinMutexGuard<'a, T> { + fn deref_mut(&mut self) -> &mut T { + unsafe { &mut *self.mutex.value.get() } + } +} + +impl<'a, T> Drop for SpinMutexGuard<'a, T> { + fn drop(&mut self) { + self.mutex.lock.store(false, Ordering::Release) + } +} diff --git a/library/std/src/sys/sgx/waitqueue/spin_mutex/tests.rs b/library/std/src/sys/sgx/waitqueue/spin_mutex/tests.rs new file mode 100644 index 000000000..4c5994bea --- /dev/null +++ b/library/std/src/sys/sgx/waitqueue/spin_mutex/tests.rs @@ -0,0 +1,23 @@ +#![allow(deprecated)] + +use super::*; +use crate::sync::Arc; +use crate::thread; +use crate::time::Duration; + +#[test] +fn sleep() { + let mutex = Arc::new(SpinMutex::<i32>::default()); + let mutex2 = mutex.clone(); + let guard = mutex.lock(); + let t1 = thread::spawn(move || { + *mutex2.lock() = 1; + }); + + thread::sleep(Duration::from_millis(50)); + + assert_eq!(*guard, 0); + drop(guard); + t1.join().unwrap(); + assert_eq!(*mutex.lock(), 1); +} diff --git a/library/std/src/sys/sgx/waitqueue/tests.rs b/library/std/src/sys/sgx/waitqueue/tests.rs new file mode 100644 index 000000000..bf91fdd08 --- /dev/null +++ b/library/std/src/sys/sgx/waitqueue/tests.rs @@ -0,0 +1,20 @@ +use super::*; +use crate::sync::Arc; +use crate::thread; + +#[test] +fn queue() { + let wq = Arc::new(SpinMutex::<WaitVariable<()>>::default()); + let wq2 = wq.clone(); + + let locked = wq.lock(); + + let t1 = thread::spawn(move || { + // if we obtain the lock, the main thread should be waiting + assert!(WaitQueue::notify_one(wq2.lock()).is_ok()); + }); + + WaitQueue::wait(locked, || {}); + + t1.join().unwrap(); +} diff --git a/library/std/src/sys/sgx/waitqueue/unsafe_list.rs b/library/std/src/sys/sgx/waitqueue/unsafe_list.rs new file mode 100644 index 000000000..c736cab57 --- /dev/null +++ b/library/std/src/sys/sgx/waitqueue/unsafe_list.rs @@ -0,0 +1,156 @@ +//! A doubly-linked list where callers are in charge of memory allocation +//! of the nodes in the list. + +#[cfg(test)] +mod tests; + +use crate::mem; +use crate::ptr::NonNull; + +pub struct UnsafeListEntry<T> { + next: NonNull<UnsafeListEntry<T>>, + prev: NonNull<UnsafeListEntry<T>>, + value: Option<T>, +} + +impl<T> UnsafeListEntry<T> { + fn dummy() -> Self { + UnsafeListEntry { next: NonNull::dangling(), prev: NonNull::dangling(), value: None } + } + + pub fn new(value: T) -> Self { + UnsafeListEntry { value: Some(value), ..Self::dummy() } + } +} + +// WARNING: self-referential struct! +pub struct UnsafeList<T> { + head_tail: NonNull<UnsafeListEntry<T>>, + head_tail_entry: Option<UnsafeListEntry<T>>, +} + +impl<T> UnsafeList<T> { + pub const fn new() -> Self { + unsafe { UnsafeList { head_tail: NonNull::new_unchecked(1 as _), head_tail_entry: None } } + } + + /// # Safety + unsafe fn init(&mut self) { + if self.head_tail_entry.is_none() { + self.head_tail_entry = Some(UnsafeListEntry::dummy()); + // SAFETY: `head_tail_entry` must be non-null, which it is because we assign it above. + self.head_tail = + unsafe { NonNull::new_unchecked(self.head_tail_entry.as_mut().unwrap()) }; + // SAFETY: `self.head_tail` must meet all requirements for a mutable reference. + unsafe { self.head_tail.as_mut() }.next = self.head_tail; + unsafe { self.head_tail.as_mut() }.prev = self.head_tail; + } + } + + pub fn is_empty(&self) -> bool { + if self.head_tail_entry.is_some() { + let first = unsafe { self.head_tail.as_ref() }.next; + if first == self.head_tail { + // ,-------> /---------\ next ---, + // | |head_tail| | + // `--- prev \---------/ <-------` + // SAFETY: `self.head_tail` must meet all requirements for a reference. + unsafe { rtassert!(self.head_tail.as_ref().prev == first) }; + true + } else { + false + } + } else { + true + } + } + + /// Pushes an entry onto the back of the list. + /// + /// # Safety + /// + /// The entry must remain allocated until the entry is removed from the + /// list AND the caller who popped is done using the entry. Special + /// care must be taken in the caller of `push` to ensure unwinding does + /// not destroy the stack frame containing the entry. + pub unsafe fn push<'a>(&mut self, entry: &'a mut UnsafeListEntry<T>) -> &'a T { + unsafe { self.init() }; + + // BEFORE: + // /---------\ next ---> /---------\ + // ... |prev_tail| |head_tail| ... + // \---------/ <--- prev \---------/ + // + // AFTER: + // /---------\ next ---> /-----\ next ---> /---------\ + // ... |prev_tail| |entry| |head_tail| ... + // \---------/ <--- prev \-----/ <--- prev \---------/ + let mut entry = unsafe { NonNull::new_unchecked(entry) }; + let mut prev_tail = mem::replace(&mut unsafe { self.head_tail.as_mut() }.prev, entry); + // SAFETY: `entry` must meet all requirements for a mutable reference. + unsafe { entry.as_mut() }.prev = prev_tail; + unsafe { entry.as_mut() }.next = self.head_tail; + // SAFETY: `prev_tail` must meet all requirements for a mutable reference. + unsafe { prev_tail.as_mut() }.next = entry; + // unwrap ok: always `Some` on non-dummy entries + unsafe { (*entry.as_ptr()).value.as_ref() }.unwrap() + } + + /// Pops an entry from the front of the list. + /// + /// # Safety + /// + /// The caller must make sure to synchronize ending the borrow of the + /// return value and deallocation of the containing entry. + pub unsafe fn pop<'a>(&mut self) -> Option<&'a T> { + unsafe { self.init() }; + + if self.is_empty() { + None + } else { + // BEFORE: + // /---------\ next ---> /-----\ next ---> /------\ + // ... |head_tail| |first| |second| ... + // \---------/ <--- prev \-----/ <--- prev \------/ + // + // AFTER: + // /---------\ next ---> /------\ + // ... |head_tail| |second| ... + // \---------/ <--- prev \------/ + let mut first = unsafe { self.head_tail.as_mut() }.next; + let mut second = unsafe { first.as_mut() }.next; + unsafe { self.head_tail.as_mut() }.next = second; + unsafe { second.as_mut() }.prev = self.head_tail; + unsafe { first.as_mut() }.next = NonNull::dangling(); + unsafe { first.as_mut() }.prev = NonNull::dangling(); + // unwrap ok: always `Some` on non-dummy entries + Some(unsafe { (*first.as_ptr()).value.as_ref() }.unwrap()) + } + } + + /// Removes an entry from the list. + /// + /// # Safety + /// + /// The caller must ensure that `entry` has been pushed onto `self` + /// prior to this call and has not moved since then. + pub unsafe fn remove(&mut self, entry: &mut UnsafeListEntry<T>) { + rtassert!(!self.is_empty()); + // BEFORE: + // /----\ next ---> /-----\ next ---> /----\ + // ... |prev| |entry| |next| ... + // \----/ <--- prev \-----/ <--- prev \----/ + // + // AFTER: + // /----\ next ---> /----\ + // ... |prev| |next| ... + // \----/ <--- prev \----/ + let mut prev = entry.prev; + let mut next = entry.next; + // SAFETY: `prev` and `next` must meet all requirements for a mutable reference.entry + unsafe { prev.as_mut() }.next = next; + unsafe { next.as_mut() }.prev = prev; + entry.next = NonNull::dangling(); + entry.prev = NonNull::dangling(); + } +} diff --git a/library/std/src/sys/sgx/waitqueue/unsafe_list/tests.rs b/library/std/src/sys/sgx/waitqueue/unsafe_list/tests.rs new file mode 100644 index 000000000..c653dee17 --- /dev/null +++ b/library/std/src/sys/sgx/waitqueue/unsafe_list/tests.rs @@ -0,0 +1,105 @@ +use super::*; +use crate::cell::Cell; + +/// # Safety +/// List must be valid. +unsafe fn assert_empty<T>(list: &mut UnsafeList<T>) { + assert!(unsafe { list.pop() }.is_none(), "assertion failed: list is not empty"); +} + +#[test] +fn init_empty() { + unsafe { + assert_empty(&mut UnsafeList::<i32>::new()); + } +} + +#[test] +fn push_pop() { + unsafe { + let mut node = UnsafeListEntry::new(1234); + let mut list = UnsafeList::new(); + assert_eq!(list.push(&mut node), &1234); + assert_eq!(list.pop().unwrap(), &1234); + assert_empty(&mut list); + } +} + +#[test] +fn push_remove() { + unsafe { + let mut node = UnsafeListEntry::new(1234); + let mut list = UnsafeList::new(); + assert_eq!(list.push(&mut node), &1234); + list.remove(&mut node); + assert_empty(&mut list); + } +} + +#[test] +fn push_remove_pop() { + unsafe { + let mut node1 = UnsafeListEntry::new(11); + let mut node2 = UnsafeListEntry::new(12); + let mut node3 = UnsafeListEntry::new(13); + let mut node4 = UnsafeListEntry::new(14); + let mut node5 = UnsafeListEntry::new(15); + let mut list = UnsafeList::new(); + assert_eq!(list.push(&mut node1), &11); + assert_eq!(list.push(&mut node2), &12); + assert_eq!(list.push(&mut node3), &13); + assert_eq!(list.push(&mut node4), &14); + assert_eq!(list.push(&mut node5), &15); + + list.remove(&mut node1); + assert_eq!(list.pop().unwrap(), &12); + list.remove(&mut node3); + assert_eq!(list.pop().unwrap(), &14); + list.remove(&mut node5); + assert_empty(&mut list); + + assert_eq!(list.push(&mut node1), &11); + assert_eq!(list.pop().unwrap(), &11); + assert_empty(&mut list); + + assert_eq!(list.push(&mut node3), &13); + assert_eq!(list.push(&mut node4), &14); + list.remove(&mut node3); + list.remove(&mut node4); + assert_empty(&mut list); + } +} + +#[test] +fn complex_pushes_pops() { + unsafe { + let mut node1 = UnsafeListEntry::new(1234); + let mut node2 = UnsafeListEntry::new(4567); + let mut node3 = UnsafeListEntry::new(9999); + let mut node4 = UnsafeListEntry::new(8642); + let mut list = UnsafeList::new(); + list.push(&mut node1); + list.push(&mut node2); + assert_eq!(list.pop().unwrap(), &1234); + list.push(&mut node3); + assert_eq!(list.pop().unwrap(), &4567); + assert_eq!(list.pop().unwrap(), &9999); + assert_empty(&mut list); + list.push(&mut node4); + assert_eq!(list.pop().unwrap(), &8642); + assert_empty(&mut list); + } +} + +#[test] +fn cell() { + unsafe { + let mut node = UnsafeListEntry::new(Cell::new(0)); + let mut list = UnsafeList::new(); + let noderef = list.push(&mut node); + assert_eq!(noderef.get(), 0); + list.pop().unwrap().set(1); + assert_empty(&mut list); + assert_eq!(noderef.get(), 1); + } +} |