diff --git a/compiler-rt/lib/tsan/rtl/tsan_interceptors.cc b/compiler-rt/lib/tsan/rtl/tsan_interceptors.cc index 9a184c79798..733decfe52c 100644 --- a/compiler-rt/lib/tsan/rtl/tsan_interceptors.cc +++ b/compiler-rt/lib/tsan/rtl/tsan_interceptors.cc @@ -1021,7 +1021,7 @@ TSAN_INTERCEPTOR(int, pthread_create, TSAN_INTERCEPTOR(int, pthread_join, void *th, void **ret) { SCOPED_INTERCEPTOR_RAW(pthread_join, th, ret); - int tid = ThreadTid(thr, pc, (uptr)th); + int tid = ThreadConsumeTid(thr, pc, (uptr)th); ThreadIgnoreBegin(thr, pc); int res = BLOCK_REAL(pthread_join)(th, ret); ThreadIgnoreEnd(thr, pc); @@ -1034,8 +1034,8 @@ TSAN_INTERCEPTOR(int, pthread_join, void *th, void **ret) { DEFINE_REAL_PTHREAD_FUNCTIONS TSAN_INTERCEPTOR(int, pthread_detach, void *th) { - SCOPED_TSAN_INTERCEPTOR(pthread_detach, th); - int tid = ThreadTid(thr, pc, (uptr)th); + SCOPED_INTERCEPTOR_RAW(pthread_detach, th); + int tid = ThreadConsumeTid(thr, pc, (uptr)th); int res = REAL(pthread_detach)(th); if (res == 0) { ThreadDetach(thr, pc, tid); @@ -1055,8 +1055,8 @@ TSAN_INTERCEPTOR(void, pthread_exit, void *retval) { #if SANITIZER_LINUX TSAN_INTERCEPTOR(int, pthread_tryjoin_np, void *th, void **ret) { - SCOPED_TSAN_INTERCEPTOR(pthread_tryjoin_np, th, ret); - int tid = ThreadTid(thr, pc, (uptr)th); + SCOPED_INTERCEPTOR_RAW(pthread_tryjoin_np, th, ret); + int tid = ThreadConsumeTid(thr, pc, (uptr)th); ThreadIgnoreBegin(thr, pc); int res = REAL(pthread_tryjoin_np)(th, ret); ThreadIgnoreEnd(thr, pc); @@ -1069,8 +1069,8 @@ TSAN_INTERCEPTOR(int, pthread_tryjoin_np, void *th, void **ret) { TSAN_INTERCEPTOR(int, pthread_timedjoin_np, void *th, void **ret, const struct timespec *abstime) { - SCOPED_TSAN_INTERCEPTOR(pthread_timedjoin_np, th, ret, abstime); - int tid = ThreadTid(thr, pc, (uptr)th); + SCOPED_INTERCEPTOR_RAW(pthread_timedjoin_np, th, ret, abstime); + int tid = ThreadConsumeTid(thr, pc, (uptr)th); ThreadIgnoreBegin(thr, pc); int res = BLOCK_REAL(pthread_timedjoin_np)(th, ret, abstime); ThreadIgnoreEnd(thr, pc); diff --git a/compiler-rt/lib/tsan/rtl/tsan_rtl.h b/compiler-rt/lib/tsan/rtl/tsan_rtl.h index 3a8231bda9a..30e144fbd00 100644 --- a/compiler-rt/lib/tsan/rtl/tsan_rtl.h +++ b/compiler-rt/lib/tsan/rtl/tsan_rtl.h @@ -772,7 +772,7 @@ int ThreadCreate(ThreadState *thr, uptr pc, uptr uid, bool detached); void ThreadStart(ThreadState *thr, int tid, tid_t os_id, ThreadType thread_type); void ThreadFinish(ThreadState *thr); -int ThreadTid(ThreadState *thr, uptr pc, uptr uid); +int ThreadConsumeTid(ThreadState *thr, uptr pc, uptr uid); void ThreadJoin(ThreadState *thr, uptr pc, int tid); void ThreadDetach(ThreadState *thr, uptr pc, int tid); void ThreadFinalize(ThreadState *thr); diff --git a/compiler-rt/lib/tsan/rtl/tsan_rtl_thread.cc b/compiler-rt/lib/tsan/rtl/tsan_rtl_thread.cc index fd95cfed4f5..13e457bd770 100644 --- a/compiler-rt/lib/tsan/rtl/tsan_rtl_thread.cc +++ b/compiler-rt/lib/tsan/rtl/tsan_rtl_thread.cc @@ -285,19 +285,34 @@ void ThreadFinish(ThreadState *thr) { ctx->thread_registry->FinishThread(thr->tid); } -static bool FindThreadByUid(ThreadContextBase *tctx, void *arg) { - uptr uid = (uptr)arg; - if (tctx->user_id == uid && tctx->status != ThreadStatusInvalid) { +struct ConsumeThreadContext { + uptr uid; + ThreadContextBase* tctx; +}; + +static bool ConsumeThreadByUid(ThreadContextBase *tctx, void *arg) { + ConsumeThreadContext *findCtx = (ConsumeThreadContext*)arg; + if (tctx->user_id == findCtx->uid && tctx->status != ThreadStatusInvalid) { + if (findCtx->tctx) { + // Ensure that user_id is unique. If it's not the case we are screwed. + // Something went wrong before, but now there is no way to recover. + // Returning a wrong thread is not an option, it may lead to very hard + // to debug false positives (e.g. if we join a wrong thread). + Report("ThreadSanitizer: dup thread with used id 0x%zx\n", findCtx->uid); + Die(); + } + findCtx->tctx = tctx; tctx->user_id = 0; - return true; } return false; } -int ThreadTid(ThreadState *thr, uptr pc, uptr uid) { - int res = ctx->thread_registry->FindThread(FindThreadByUid, (void*)uid); - DPrintf("#%d: ThreadTid uid=%zu tid=%d\n", thr->tid, uid, res); - return res; +int ThreadConsumeTid(ThreadState *thr, uptr pc, uptr uid) { + ConsumeThreadContext findCtx = {uid, nullptr}; + ctx->thread_registry->FindThread(ConsumeThreadByUid, &findCtx); + int tid = findCtx.tctx ? findCtx.tctx->tid : ThreadRegistry::kUnknownTid; + DPrintf("#%d: ThreadTid uid=%zu tid=%d\n", thr->tid, uid, tid); + return tid; } void ThreadJoin(ThreadState *thr, uptr pc, int tid) {