summaryrefslogtreecommitdiffstats
path: root/library/std/src/sys/windows/thread_local_key.rs
blob: 17628b7579b8db3ac2a8eaa89e4fc20b1ce81af7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
use crate::cell::UnsafeCell;
use crate::ptr;
use crate::sync::atomic::{
    AtomicPtr, AtomicU32,
    Ordering::{AcqRel, Acquire, Relaxed, Release},
};
use crate::sys::c;

#[cfg(test)]
mod tests;

/// Raw TLS slot index, in the form returned by `TlsAlloc` and handed to
/// `TlsGetValue`/`TlsSetValue`.
type Key = c::DWORD;
/// Destructor callback attached to a key. `run_dtors` calls it with the
/// non-null pointer it found in the slot (after resetting the slot to null).
type Dtor = unsafe extern "C" fn(*mut u8);

// Turns out, like pretty much everything, Windows is pretty close to the
// functionality that Unix provides, but slightly different! In the case of
// TLS, Windows does not provide an API to provide a destructor for a TLS
// variable. This ends up being pretty crucial to this implementation, so we
// need a way around this.
//
// The solution here ended up being a little obscure, but fear not, the
// internet has informed me [1][2] that this solution is not unique (no way
// I could have thought of it as well!). The key idea is to insert some hook
// somewhere to run arbitrary code on thread termination. With this in place
// we'll be able to run anything we like, including all TLS destructors!
//
// To accomplish this feat, we perform a number of tasks, all contained
// within this module:
//
// * All TLS destructors are tracked by *us*, not the Windows runtime. This
//   means that we have a global list of destructors for each TLS key that
//   we know about.
// * When a thread exits, we run over the entire list and run dtors for all
//   non-null keys. This attempts to match Unix semantics in this regard.
//
// For more details and nitty-gritty, see the code sections below!
//
// [1]: https://www.codeproject.com/Articles/8113/Thread-Local-Storage-The-C-Way
// [2]: https://github.com/ChromiumWebApps/chromium/blob/master/base/threading/thread_local_storage_win.cc#L42

/// A lazily allocated Windows TLS key with an optional destructor.
///
/// All methods take `&'static self`; the underlying TLS index is allocated on
/// first use by `StaticKey::init`.
pub struct StaticKey {
    /// The key value shifted up by one. Since TLS_OUT_OF_INDEXES == DWORD::MAX
    /// is not a valid key value, this allows us to use zero as sentinel value
    /// without risking overflow.
    key: AtomicU32,
    /// Destructor to run at thread/process detach for non-null slot values, if
    /// any. Keys that have one are pushed onto the global `DTORS` list when
    /// they are initialized.
    dtor: Option<Dtor>,
    /// Link to the next entry of the global `DTORS` list. Starts out null and
    /// is written by `register_dtor` when the key is pushed onto the list.
    next: AtomicPtr<StaticKey>,
    /// Currently, destructors cannot be unregistered, so we cannot use racy
    /// initialization for keys. Instead, we need to synchronize initialization.
    /// Use the Windows-provided `Once` since it does not require TLS.
    once: UnsafeCell<c::INIT_ONCE>,
}

impl StaticKey {
    /// Builds a key in its uninitialized state: the stored index is the `0`
    /// sentinel, the list link is null and the `INIT_ONCE` is in its static
    /// initial state. No TLS index is allocated here.
    #[inline]
    pub const fn new(dtor: Option<Dtor>) -> StaticKey {
        Self {
            key: AtomicU32::new(0),
            dtor,
            next: AtomicPtr::new(ptr::null_mut()),
            once: UnsafeCell::new(c::INIT_ONCE_STATIC_INIT),
        }
    }

    /// Stores `val` in the slot of this key via `TlsSetValue`, allocating the
    /// key first if that has not happened yet. Debug builds assert that the
    /// call reported success.
    #[inline]
    pub unsafe fn set(&'static self, val: *mut u8) {
        let index = self.key();
        let ok = c::TlsSetValue(index, val.cast());
        debug_assert_eq!(ok, c::TRUE);
    }

    /// Reads the slot of this key via `TlsGetValue`, allocating the key first
    /// if that has not happened yet.
    #[inline]
    pub unsafe fn get(&'static self) -> *mut u8 {
        let index = self.key();
        c::TlsGetValue(index).cast()
    }

    /// Returns the real TLS index. The stored value is the index plus one, so
    /// a stored `0` means "not allocated yet" and sends us to the cold
    /// `init` path.
    #[inline]
    unsafe fn key(&'static self) -> Key {
        let shifted = self.key.load(Acquire);
        if shifted == 0 { self.init() } else { shifted - 1 }
    }

    /// Slow path of `key`: allocates the TLS index and publishes it.
    ///
    /// Panics with "out of TLS indexes" when `TlsAlloc` fails.
    #[cold]
    unsafe fn init(&'static self) -> Key {
        if self.dtor.is_none() {
            // Without a destructor there is nothing to register, so racy
            // initialization is fine: allocate optimistically and let the
            // compare-exchange pick a single winner.
            let fresh = c::TlsAlloc();
            assert_ne!(fresh, c::TLS_OUT_OF_INDEXES, "out of TLS indexes");

            return match self.key.compare_exchange(0, fresh + 1, AcqRel, Acquire) {
                Ok(_) => fresh,
                Err(winner) => {
                    // Another thread published its index first; give ours
                    // back and adopt theirs.
                    let freed = c::TlsFree(fresh);
                    debug_assert_eq!(freed, c::TRUE);
                    winner - 1
                }
            };
        }

        // Keys with a destructor must be registered exactly once, so their
        // initialization is serialized through the `INIT_ONCE`.
        let mut pending = c::FALSE;
        let began = c::InitOnceBeginInitialize(self.once.get(), 0, &mut pending, ptr::null_mut());
        assert_eq!(began, c::TRUE);

        if pending == c::FALSE {
            // Initialization already finished on another thread; just read
            // the published index.
            return self.key.load(Relaxed) - 1;
        }

        let fresh = c::TlsAlloc();
        if fresh == c::TLS_OUT_OF_INDEXES {
            // Mark the once as failed before panicking so that threads
            // blocked on it are woken up instead of deadlocking.
            c::InitOnceComplete(self.once.get(), c::INIT_ONCE_INIT_FAILED, ptr::null_mut());
            panic!("out of TLS indexes");
        }

        // Publish the index before the key becomes reachable through the
        // destructor list.
        self.key.store(fresh + 1, Release);
        register_dtor(self);

        let completed = c::InitOnceComplete(self.once.get(), 0, ptr::null_mut());
        debug_assert_eq!(completed, c::TRUE);

        fresh
    }
}

// `StaticKey` contains an `UnsafeCell`, so it is not `Sync` automatically
// (whether `Send` would be derived depends on `c::INIT_ONCE`, which is declared
// elsewhere). In the code in this file the cell is only ever handed to the
// Windows `InitOnce*` functions, `dtor` is never written after construction,
// and the remaining mutable state (`key`, `next`) lives in atomics; these impls
// rely on that to share `&'static StaticKey` between threads.
unsafe impl Send for StaticKey {}
unsafe impl Sync for StaticKey {}

// -------------------------------------------------------------------------
// Dtor registration
//
// Windows has no native support for running destructors so we manage our own
// list of destructors to keep track of how to destroy keys. We then install a
// callback later to get invoked whenever a thread exits, running all
// appropriate destructors.
//
// Currently unregistration from this list is not supported. A destructor can be
// registered but cannot be unregistered. There's various simplifying reasons
// for doing this, the big ones being:
//
// 1. Currently we don't even support deallocating TLS keys, so normal operation
//    doesn't need to deallocate a destructor.
// 2. There is no point in time where we know we can unregister a destructor
//    because it could always be getting run by some remote thread.
//
// Typically processes have a statically known set of TLS keys which is pretty
// small, and we'd want to keep this memory alive for the whole process anyway
// really.

/// Head of the global singly linked list of keys that have a destructor.
/// Entries are pushed at the head by `register_dtor`, chained through
/// `StaticKey::next`, and walked by `run_dtors`; nothing in this file removes
/// an entry. Null until the first such key is registered.
static DTORS: AtomicPtr<StaticKey> = AtomicPtr::new(ptr::null_mut());

/// Pushes `key` onto the front of the global `DTORS` list.
///
/// Must be called at most once per key: pushing the same node twice would
/// overwrite its `next` link and could create a cycle or cut off the tail of
/// the list.
unsafe fn register_dtor(key: &'static StaticKey) {
    let node = (key as *const StaticKey).cast_mut();
    // Acquire on the loads (including the failure path of the CAS) so that
    // whatever earlier registrations published is passed along when the new
    // head is stored with release ordering.
    let mut expected = DTORS.load(Acquire);
    loop {
        // Link to the head we currently believe in, then try to swing the
        // head over to our node.
        key.next.store(expected, Relaxed);
        if let Err(actual) = DTORS.compare_exchange_weak(expected, node, Release, Acquire) {
            // Lost a race (or failed spuriously): retry against the head we
            // actually observed.
            expected = actual;
        } else {
            return;
        }
    }
}

// -------------------------------------------------------------------------
// Where the Magic (TM) Happens
//
// If you're looking at this code, and wondering "what is this doing?",
// you're not alone! I'll try to break this down step by step:
//
// # What's up with CRT$XLB?
//
// For anything about TLS destructors to work on Windows, we have to be able
// to run *something* when a thread exits. To do so, we place a very special
// static in a very special location. If this is encoded in just the right
// way, the kernel's loader is apparently nice enough to run some function
// of ours whenever a thread exits! How nice of the kernel!
//
// Lots of detailed information can be found in source [1] above, but the
// gist of it is that this is leveraging a feature of Microsoft's PE format
// (executable format) which is not actually used by any compilers today.
// This apparently translates to any callbacks in the ".CRT$XLB" section
// being run on certain events.
//
// So after all that, we use the compiler's #[link_section] feature to place
// a callback pointer into the magic section so it ends up being called.
//
// # What's up with this callback?
//
// The callback specified receives a number of parameters from... someone!
// (the kernel? the runtime? I'm not quite sure!) There are a few events that
// this gets invoked for, but we're currently only interested in when a
// thread or a process "detaches" (exits). The process part happens for the
// last thread and the thread part happens for any normal thread.
//
// # Ok, what's up with running all these destructors?
//
// This will likely need to be improved over time, but this function
// attempts a "poor man's" destructor callback system. Once we've got a list
// of what to run, we iterate over all keys, check their values, and then run
// destructors if the values turn out to be non null (setting them to null just
// beforehand). We do this a few times in a loop to basically match Unix
// semantics. If we don't reach a fixed point after a short while then we just
// inevitably leak something most likely.
//
// # The article mentions weird stuff about "/INCLUDE"?
//
// It sure does! Specifically we're talking about this quote:
//
//      The Microsoft run-time library facilitates this process by defining a
//      memory image of the TLS Directory and giving it the special name
//      “__tls_used” (Intel x86 platforms) or “_tls_used” (other platforms). The
//      linker looks for this memory image and uses the data there to create the
//      TLS Directory. Other compilers that support TLS and work with the
//      Microsoft linker must use this same technique.
//
// Basically what this means is that if we want support for our TLS
// destructors/our hook being called then we need to make sure the linker does
// not omit this symbol. Otherwise it will omit it and our callback won't be
// wired up.
//
// We don't actually use the `/INCLUDE` linker flag here like the article
// mentions because the Rust compiler doesn't propagate linker flags, but
// instead we use a shim function which performs a volatile 1-byte load from
// the address of the symbol to ensure it sticks around.

/// Pointer to `on_tls_callback`, emitted into the `.CRT$XLB` section so that
/// the callback gets invoked on thread/process events (see "What's up with
/// CRT$XLB?" above).
#[link_section = ".CRT$XLB"]
#[allow(dead_code, unused_variables)]
#[used] // we don't want LLVM eliminating this symbol for any reason, and
// when the symbol makes it to the linker the linker will take over
pub static p_thread_callback: unsafe extern "system" fn(c::LPVOID, c::DWORD, c::LPVOID) =
    on_tls_callback;

/// TLS callback installed through `p_thread_callback`.
///
/// On thread or process detach it runs the destructors of all registered keys
/// (and, where `target_thread_local` is set, the keyless destructors as well).
/// `h` and `pv` are not used.
#[allow(dead_code, unused_variables)]
unsafe extern "system" fn on_tls_callback(h: c::LPVOID, dwReason: c::DWORD, pv: c::LPVOID) {
    // Keeps `_tls_used` referenced through a volatile load so that the symbol
    // is not dropped (see the "/INCLUDE" discussion above). Only MSVC targets
    // need this; everywhere else the helper is a no-op.
    #[cfg(target_env = "msvc")]
    unsafe fn reference_tls_used() {
        extern "C" {
            static _tls_used: u8;
        }
        crate::intrinsics::volatile_load(&_tls_used);
    }
    #[cfg(not(target_env = "msvc"))]
    unsafe fn reference_tls_used() {}

    let is_detach = dwReason == c::DLL_THREAD_DETACH || dwReason == c::DLL_PROCESS_DETACH;
    if is_detach {
        run_dtors();
        #[cfg(target_thread_local)]
        super::thread_local_dtor::run_keyless_dtors();
    }

    reference_tls_used();
}

/// Runs the destructors of all registered keys whose slot is non-null.
///
/// Destructors may store new values into slots, so the list is re-scanned
/// until a pass runs nothing, up to a fixed number of passes; whatever is
/// still set after the last pass is left as is.
#[allow(dead_code)] // called from `on_tls_callback`
unsafe fn run_dtors() {
    const MAX_PASSES: usize = 5;

    for _ in 0..MAX_PASSES {
        let mut ran_any = false;

        // Acquire pairs with the release store in `register_dtor`, making the
        // initialization of every listed key visible here.
        let mut node = DTORS.load(Acquire);
        while !node.is_null() {
            // Nodes are only ever created from `&'static StaticKey`.
            let entry = &*node;
            let index = entry.key.load(Relaxed) - 1;
            // Only keys that have a destructor get registered.
            let dtor = entry.dtor.unwrap();

            let value = c::TlsGetValue(index);
            if !value.is_null() {
                // Reset the slot before running the destructor.
                c::TlsSetValue(index, ptr::null_mut());
                dtor(value.cast());
                ran_any = true;
            }

            node = entry.next.load(Relaxed);
        }

        if !ran_any {
            return;
        }
    }
}