proper per object ts

Joel Wejdenstål 2023-08-26 14:48:23 +02:00
parent a7c38cbf91
commit 59c7e73a7a
No known key found for this signature in database
GPG key ID: DF03CEFBB1A915AA
9 changed files with 499 additions and 84 deletions

View file

@@ -4,8 +4,9 @@ version = "0.1.0"
 edition = "2021"

 [features]
-default = ["tls-accel"]
+default = ["tls-accel", "per-object"]
 tls-accel = []
+per-object = []

 [dependencies]
 cfg-if = "1.0.0"

View file

@@ -1,3 +1,9 @@
 mod thread_id;
+#[cfg(feature = "per-object")]
+mod thread_local_per_object;
 pub use thread_id::thread_id;
+#[cfg(feature = "per-object")]
+pub use thread_local_per_object::{Iter, ThreadLocal};

View file

@@ -1,19 +1,28 @@
-mod allocator;
+mod thread;
 mod tids;
-use allocator::SmallestFitAllocator;
 use once_cell::unsync::Lazy;
 use std::num::NonZeroUsize;
 use std::sync::Mutex;
-use tids::ThreadIdStorage;
-use tids::ThreadIdStorageImpl;
+use thread::SmallestFitAllocator;
+use thread::Thread;
+use tids::ThreadStorage;
+use tids::ThreadStorageImpl;

 static ID_ALLOCATOR: Mutex<Lazy<SmallestFitAllocator>> =
     Mutex::new(Lazy::new(SmallestFitAllocator::new));

 pub fn thread_id() -> NonZeroUsize {
-    if let Some(id) = ThreadIdStorageImpl::get() {
-        id
+    if let Some(thread) = ThreadStorageImpl::get() {
+        thread.id
     } else {
-        lazy_init()
+        lazy_init().id
     }
 }
+
+pub fn thread() -> Thread {
+    if let Some(thread) = ThreadStorageImpl::get() {
+        thread
+    } else {
+        lazy_init()
+    }
+}

@@ -21,13 +30,15 @@ pub fn thread_id() -> NonZeroUsize {
 #[cold]
 #[inline(never)]
-fn lazy_init() -> NonZeroUsize {
+fn lazy_init() -> Thread {
     let id = ID_ALLOCATOR.lock().unwrap().allocate();
-    ThreadIdStorageImpl::set(id);
+    let thread = Thread::new(id);
+    ThreadStorageImpl::set(thread);
     tids::set_dtor(|| {
-        let id = ThreadIdStorageImpl::get().unwrap();
-        ID_ALLOCATOR.lock().unwrap().deallocate(id);
+        let thread = ThreadStorageImpl::get().unwrap();
+        ID_ALLOCATOR.lock().unwrap().deallocate(thread.id);
     });
-    id
+    thread
 }
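
Net effect of this hunk: a thread's ID now travels inside the richer `Thread` record, while the `set_dtor` hook still returns the ID to the allocator at thread exit. A sketch of the observable behavior, written as a hypothetical in-crate test; it assumes `SmallestFitAllocator` hands out the smallest free ID first, as its name and the `BinaryHeap`/`Reverse` imports in thread.rs suggest:

use crate::thread_id;
use std::num::NonZeroUsize;

#[test]
fn id_reuse_sketch() {
    // A short-lived thread claims an ID; when it exits, the dtor registered
    // in lazy_init returns that ID to ID_ALLOCATOR.
    let first: NonZeroUsize = std::thread::spawn(thread_id).join().unwrap();
    // The next thread to ask should then be handed the same smallest free ID.
    let second: NonZeroUsize = std::thread::spawn(thread_id).join().unwrap();
    assert_eq!(first, second);
}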

View file

@@ -1,6 +1,29 @@
-use std::num::NonZeroUsize;
-use std::collections::BinaryHeap;
 use std::cmp::Reverse;
+use std::collections::BinaryHeap;
+use std::num::NonZeroUsize;

+#[derive(Clone, Copy)]
+pub struct Thread {
+    pub(crate) id: NonZeroUsize,
+    pub(crate) bucket: usize,
+    pub(crate) bucket_size: usize,
+    pub(crate) index: usize,
+}
+
+impl Thread {
+    pub(crate) fn new(id: NonZeroUsize) -> Self {
+        let bucket = (usize::BITS as usize) - id.leading_zeros() as usize;
+        let bucket_size = 1 << bucket.saturating_sub(1);
+        let index = id.get() ^ bucket_size;
+        Thread {
+            id,
+            bucket,
+            bucket_size,
+            index,
+        }
+    }
+}
+
 pub struct SmallestFitAllocator {
     next: NonZeroUsize,
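
(The hunk cuts off inside `SmallestFitAllocator`.) The new `Thread` struct precomputes where a thread's slot lives in the bucketed layout used by `thread_local_per_object`: bucket `b` (for `b >= 1`) covers IDs `[2^(b-1), 2^b)` and holds `2^(b-1)` entries, with `index` being the ID with the bucket's top bit cleared; bucket 0 is never produced since IDs are nonzero. A standalone sketch of the same math, with a few spot checks:

use std::num::NonZeroUsize;

// Mirrors Thread::new above.
fn classify(id: usize) -> (usize, usize, usize) {
    let id = NonZeroUsize::new(id).unwrap();
    let bucket = (usize::BITS - id.leading_zeros()) as usize;
    let bucket_size = 1 << bucket.saturating_sub(1);
    let index = id.get() ^ bucket_size;
    (bucket, bucket_size, index)
}

fn main() {
    assert_eq!(classify(1), (1, 1, 0));
    assert_eq!(classify(2), (2, 2, 0));
    assert_eq!(classify(3), (2, 2, 1));
    assert_eq!(classify(4), (3, 4, 0));
    assert_eq!(classify(7), (3, 4, 3));
}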

View file

@@ -1,46 +1,36 @@
-use super::ThreadIdStorage;
+use super::super::thread::Thread;
+use super::ThreadStorage;
 use std::arch::asm;
-use std::num::NonZeroUsize;
+use std::ptr;

 #[link_section = ".tdata"]
-static THREAD_ID: usize = 0;
+static THREAD: Option<Thread> = None;

-pub struct AARCH64LinuxThreadIdStorage;
+unsafe fn tls_ptr() -> *mut Option<Thread> {
+    let ptr: *mut Option<Thread>;
+    asm!(
+        "mrs {hi_tmp}, TPIDR_EL0",
+        "add {hi_tmp}, {hi_tmp}, :tprel_hi12:{symb}",
+        "add {ptr}, {hi_tmp}, :tprel_lo12_nc:{symb}",
+        symb = sym THREAD,
+        ptr = out(reg) ptr,
+        hi_tmp = out(reg) _,
+        options(nostack, pure, readonly, preserves_flags),
+    );
-impl ThreadIdStorage for AARCH64LinuxThreadIdStorage {
-    fn get() -> Option<NonZeroUsize> {
-        unsafe {
-            let thread_id: usize;
+    ptr
+}
-            asm!(
-                "mrs {hi_tmp}, TPIDR_EL0",
-                "add {hi_tmp}, {hi_tmp}, :tprel_hi12:{symb}",
-                "add {lo_tmp}, {hi_tmp}, :tprel_lo12_nc:{symb}",
-                "ldr {thread_id}, [{lo_tmp}]",
-                symb = sym THREAD_ID,
-                thread_id = lateout(reg) thread_id,
-                lo_tmp = out(reg) _,
-                hi_tmp = out(reg) _,
-                options(nostack, pure, readonly, preserves_flags),
-            );
+
+pub struct AARCH64LinuxThreadStorage;
-            NonZeroUsize::new(thread_id)
-        }
+
+impl ThreadStorage for AARCH64LinuxThreadStorage {
+    fn get() -> Option<Thread> {
+        unsafe { ptr::read(tls_ptr()) }
     }
-    fn set(value: NonZeroUsize) {
+    fn set(value: Thread) {
         unsafe {
-            asm!(
-                "mrs {hi_tmp}, TPIDR_EL0",
-                "add {hi_tmp}, {hi_tmp}, :tprel_hi12:{symb}",
-                "add {lo_tmp}, {hi_tmp}, :tprel_lo12_nc:{symb}",
-                "str {thread_id}, [{lo_tmp}]",
-                symb = sym THREAD_ID,
-                thread_id = in(reg) value.get(),
-                lo_tmp = out(reg) _,
-                hi_tmp = out(reg) _,
-                options(nostack, preserves_flags),
-            );
+            ptr::write(tls_ptr(), Option::Some(value));
         }
     }
 }

View file

@@ -1,19 +1,19 @@
-use super::ThreadIdStorage;
+use super::super::thread::Thread;
+use super::ThreadStorage;
 use std::cell::Cell;
-use std::num::NonZeroUsize;

 thread_local! {
-    static ID: Cell<Option<NonZeroUsize>> = Cell::new(None);
+    static ID: Cell<Option<Thread>> = Cell::new(None);
 }

-pub struct FallbackThreadIdStorage;
+pub struct FallbackThreadStorage;

-impl ThreadIdStorage for FallbackThreadIdStorage {
-    fn get() -> Option<NonZeroUsize> {
+impl ThreadStorage for FallbackThreadStorage {
+    fn get() -> Option<Thread> {
         ID.with(|id| id.get())
     }
-    fn set(value: NonZeroUsize) {
+    fn set(value: Thread) {
         ID.with(|id| id.set(Some(value)));
     }
 }

View file

@@ -1,36 +1,45 @@
+use super::thread::Thread;
 use cfg_if::cfg_if;
 use std::cell::Cell;
-use std::num::NonZeroUsize;

-pub trait ThreadIdStorage {
-    fn get() -> Option<NonZeroUsize>;
-    fn set(value: NonZeroUsize);
+pub trait ThreadStorage {
+    fn get() -> Option<Thread>;
+    fn set(value: Thread);
 }

 cfg_if! {
     if #[cfg(all(feature = "tls-accel", target_arch = "x86_64", target_os = "linux"))] {
         mod x86_64_linux;
-        pub use x86_64_linux::X86_64LinuxThreadIdStorage as ThreadIdStorageImpl;
+        pub use x86_64_linux::X86_64LinuxThreadStorage as ThreadStorageImpl;
     } else if #[cfg(all(feature = "tls-accel", target_arch = "aarch64", target_os = "linux"))] {
         mod aarch64_linux;
-        pub use aarch64_linux::AARCH64LinuxThreadIdStorage as ThreadIdStorageImpl;
+        pub use aarch64_linux::AARCH64LinuxThreadStorage as ThreadStorageImpl;
     } else {
         mod fallback;
-        pub use fallback::FallbackThreadIdStorage as ThreadIdStorageImpl;
+        pub use fallback::FallbackThreadStorage as ThreadStorageImpl;
     }
 }

+struct Dtor(fn());
+
+impl Drop for Dtor {
+    fn drop(&mut self) {
+        (self.0)();
+    }
+}
+
 thread_local! {
-    static DTOR: Cell<fn()> = Cell::new(|| {});
+    static DTOR: Cell<Dtor> = Cell::new(Dtor(|| {}));
 }

 pub fn set_dtor(dtor: fn()) {
-    DTOR.with(|cell| cell.set(dtor));
+    DTOR.with(|cell| cell.set(Dtor(dtor)));
 }

 #[cfg(test)]
 mod test {
-    use super::{ThreadIdStorage, ThreadIdStorageImpl};
+    use super::super::thread::Thread;
+    use super::{ThreadStorage, ThreadStorageImpl};
     use once_cell::sync::Lazy;
     use std::collections::HashSet;
     use std::hint;

@@ -46,11 +55,12 @@ mod test {
     static KNOWN: Mutex<Lazy<HashSet<usize>>> = Mutex::new(Lazy::new(HashSet::new));

     fn thread_id() -> usize {
-        match ThreadIdStorageImpl::get() {
-            Some(id) => id.get(),
+        match ThreadStorageImpl::get() {
+            Some(id) => id.id.get(),
             None => {
                 let id = NEXT_ID.fetch_add(1, Ordering::Relaxed);
-                ThreadIdStorageImpl::set(unsafe { NonZeroUsize::new_unchecked(id) });
+                let thread = Thread::new(unsafe { NonZeroUsize::new_unchecked(id) });
+                ThreadStorageImpl::set(thread);
                 id
             }
         }
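
The `Dtor` newtype is the substantive fix in this file: dropping the old `Cell<fn()>` never ran the stored callback, whereas a value with a `Drop` impl fires when the thread_local is torn down at thread exit. A self-contained sketch of the pattern (the printed message is illustrative):

use std::cell::Cell;

struct Dtor(fn());

impl Drop for Dtor {
    fn drop(&mut self) {
        // Runs when the owning thread_local is destroyed at thread exit.
        (self.0)();
    }
}

thread_local! {
    static DTOR: Cell<Dtor> = Cell::new(Dtor(|| {}));
}

fn set_dtor(dtor: fn()) {
    // Cell::set drops the placeholder Dtor(|| {}) immediately, harmlessly
    // running the empty callback once.
    DTOR.with(|cell| cell.set(Dtor(dtor)));
}

fn main() {
    std::thread::spawn(|| set_dtor(|| println!("thread exiting")))
        .join()
        .unwrap(); // the message prints during the spawned thread's teardown
}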

View file

@@ -1,34 +1,51 @@
-use super::ThreadIdStorage;
+use super::super::thread::Thread;
+use super::ThreadStorage;
 use std::arch::asm;
-use std::num::NonZeroUsize;
+use std::mem::MaybeUninit;

 #[link_section = ".tdata"]
-static THREAD_ID: usize = 0;
+static THREAD: Option<Thread> = None;

-pub struct X86_64LinuxThreadIdStorage;
-
-impl ThreadIdStorage for X86_64LinuxThreadIdStorage {
-    fn get() -> Option<NonZeroUsize> {
-        let thread_id: usize;
+pub struct X86_64LinuxThreadStorage;
+
+impl ThreadStorage for X86_64LinuxThreadStorage {
+    fn get() -> Option<Thread> {
         unsafe {
+            let mut thread: MaybeUninit<Thread> = MaybeUninit::uninit();
             asm!(
-                "mov {thread_id}, qword ptr fs:[{symb}@TPOFF]",
-                symb = sym THREAD_ID,
-                thread_id = lateout(reg) thread_id,
-                options(nostack, pure, readonly, preserves_flags),
+                "mov {tmp}, qword ptr fs:[{symb}@TPOFF]",
+                "mov [{thread_id_ptr}], {tmp}",
+                "mov {tmp}, qword ptr fs:[{symb}@TPOFF+8]",
+                "mov [{thread_id_ptr}+8], {tmp}",
+                "mov {tmp}, qword ptr fs:[{symb}@TPOFF+16]",
+                "mov [{thread_id_ptr}+16], {tmp}",
+                "mov {tmp}, qword ptr fs:[{symb}@TPOFF+24]",
+                "mov [{thread_id_ptr}+24], {tmp}",
+                symb = sym THREAD,
+                thread_id_ptr = in(reg) &mut thread,
+                tmp = out(reg) _,
+                options(nostack, preserves_flags),
             );
-            NonZeroUsize::new(thread_id)
+            Some(thread.assume_init())
         }
     }
-    fn set(value: NonZeroUsize) {
+    fn set(value: Thread) {
         unsafe {
             asm!(
-                "mov qword ptr fs:[{symb}@TPOFF], {thread_id}",
-                symb = sym THREAD_ID,
-                thread_id = in(reg) value.get(),
+                "mov {tmp}, [{thread_id_ptr}]",
+                "mov qword ptr fs:[{symb}@TPOFF], {tmp}",
+                "mov {tmp}, [{thread_id_ptr}+8]",
+                "mov qword ptr fs:[{symb}@TPOFF+8], {tmp}",
+                "mov {tmp}, [{thread_id_ptr}+16]",
+                "mov qword ptr fs:[{symb}@TPOFF+16], {tmp}",
+                "mov {tmp}, [{thread_id_ptr}+24]",
+                "mov qword ptr fs:[{symb}@TPOFF+24], {tmp}",
+                symb = sym THREAD,
+                thread_id_ptr = in(reg) &value,
+                tmp = out(reg) _,
                 options(nostack, preserves_flags),
             );
         }
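
Both asm blocks now copy four 8-byte words instead of one, which bakes in a layout assumption worth pinning down: on 64-bit Linux `Thread` is four pointer-sized fields (32 bytes), and `Option<Thread>` is expected to stay the same size because `NonZeroUsize` gives rustc a niche for `None` (the all-zero `id`). A compile-time guard along these lines could document that (a sketch, not part of the commit):

use std::mem::size_of;
use std::num::NonZeroUsize;

#[allow(dead_code)]
#[derive(Clone, Copy)]
struct Thread {
    id: NonZeroUsize,
    bucket: usize,
    bucket_size: usize,
    index: usize,
}

// The four fs-relative moves copy exactly this many bytes.
const _: () = assert!(size_of::<Thread>() == 4 * size_of::<usize>());
// In practice rustc packs Option<Thread> into the NonZeroUsize niche.
const _: () = assert!(size_of::<Option<Thread>>() == size_of::<Thread>());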

View file

@@ -0,0 +1,357 @@
use super::thread_id::thread;
use std::cell::UnsafeCell;
use std::mem::{self, MaybeUninit};
use std::ptr;
use std::sync::atomic::{self, AtomicBool, AtomicPtr, AtomicUsize, Ordering};
const BUCKETS: usize = (usize::BITS + 1) as usize;
pub struct ThreadLocal<T: Send> {
buckets: [AtomicPtr<Entry<T>>; BUCKETS],
pub threads: AtomicUsize,
}
struct Entry<T> {
present: AtomicBool,
value: UnsafeCell<MaybeUninit<T>>,
}
impl<T> Drop for Entry<T> {
fn drop(&mut self) {
unsafe {
if *self.present.get_mut() {
ptr::drop_in_place((*self.value.get()).as_mut_ptr());
}
}
}
}
unsafe impl<T: Send> Sync for ThreadLocal<T> {}
impl<T> ThreadLocal<T>
where
T: Send,
{
pub fn new() -> ThreadLocal<T> {
Self::with_capacity(2)
}
pub fn with_capacity(capacity: usize) -> ThreadLocal<T> {
let allocated_buckets = capacity
.checked_sub(1)
.map(|c| (usize::BITS as usize) - (c.leading_zeros() as usize) + 1)
.unwrap_or(0);
let mut buckets = [ptr::null_mut(); BUCKETS];
let mut bucket_size = 1;
for (i, bucket) in buckets[..allocated_buckets].iter_mut().enumerate() {
*bucket = allocate_bucket::<T>(bucket_size);
if i != 0 {
bucket_size <<= 1;
}
}
ThreadLocal {
// safety: `AtomicPtr` has the same representation as a pointer
buckets: unsafe { mem::transmute(buckets) },
threads: AtomicUsize::new(0),
}
}
pub fn get_or(&self, create: impl Fn() -> T) -> &T {
let thread = thread();
let bucket = unsafe { self.buckets.get_unchecked(thread.bucket) };
let mut bucket_ptr = bucket.load(Ordering::Acquire);
if bucket_ptr.is_null() {
let new_bucket = allocate_bucket(thread.bucket_size);
match bucket.compare_exchange(
ptr::null_mut(),
new_bucket,
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(_) => bucket_ptr = new_bucket,
// if the bucket value changed (from null), that means
// another thread stored a new bucket before we could,
// and we can free our bucket and use that one instead
Err(other) => unsafe {
let _ = Box::from_raw(ptr::slice_from_raw_parts_mut(
new_bucket,
thread.bucket_size,
));
bucket_ptr = other;
},
}
}
unsafe {
let entry = &*bucket_ptr.add(thread.index);
// read without atomic operations as only this thread can set the value.
if (&entry.present as *const _ as *const bool).read() {
(*entry.value.get()).assume_init_ref()
} else {
// insert the new element into the bucket
entry.value.get().write(MaybeUninit::new(create()));
entry.present.store(true, Ordering::Release);
self.threads.fetch_add(1, Ordering::Relaxed);
// seqcst: synchronize with the fence in `retire`:
// - if this fence comes first, the thread retiring will see the new thread count
// and our entry
// - if their fence comes first, we will see the new values of any pointers being
// retired by that thread
atomic::fence(Ordering::SeqCst);
(*entry.value.get()).assume_init_ref()
}
}
}
#[cfg(test)]
fn get(&self) -> Option<&T> {
let thread = thread();
let bucket_ptr =
unsafe { self.buckets.get_unchecked(thread.bucket) }.load(Ordering::Acquire);
if bucket_ptr.is_null() {
return None;
}
unsafe {
let entry = &*bucket_ptr.add(thread.index);
// read without atomic operations as only this thread can set the value.
if (&entry.present as *const _ as *const bool).read() {
Some((*entry.value.get()).assume_init_ref())
} else {
None
}
}
}
pub fn iter(&self) -> Iter<'_, T> {
Iter {
bucket: 0,
bucket_size: 1,
index: 0,
thread_local: self,
}
}
}
impl<T> Default for ThreadLocal<T>
where
T: Send,
{
fn default() -> Self {
Self::new()
}
}
impl<T> Drop for ThreadLocal<T>
where
T: Send,
{
fn drop(&mut self) {
let mut bucket_size = 1;
for (i, bucket) in self.buckets.iter_mut().enumerate() {
let bucket_ptr = *bucket.get_mut();
let this_bucket_size = bucket_size;
if i != 0 {
bucket_size <<= 1;
}
if bucket_ptr.is_null() {
continue;
}
unsafe { Box::from_raw(std::slice::from_raw_parts_mut(bucket_ptr, this_bucket_size)) };
}
}
}
pub struct Iter<'a, T>
where
T: Send,
{
thread_local: &'a ThreadLocal<T>,
bucket: usize,
bucket_size: usize,
index: usize,
}
impl<'a, T> Iterator for Iter<'a, T>
where
T: Send,
{
type Item = &'a T;
fn next(&mut self) -> Option<Self::Item> {
// we have to check all the buckets here because we reuse
// thread IDs. keeping track of the number of values and only
// yielding that many here wouldn't work, because a new thread
// could join and be inserted into a middle bucket, and we would
// yield that instead of an active thread that actually needs to
// participate in reference counting. yielding extra values is fine,
// but not yielding all originally active threads is not.
while self.bucket < BUCKETS {
let bucket = unsafe {
self.thread_local
.buckets
.get_unchecked(self.bucket)
.load(Ordering::Acquire)
};
if !bucket.is_null() {
while self.index < self.bucket_size {
let entry = unsafe { &*bucket.add(self.index) };
self.index += 1;
if entry.present.load(Ordering::Acquire) {
return Some(unsafe { &*(*entry.value.get()).as_ptr() });
}
}
}
if self.bucket != 0 {
self.bucket_size <<= 1;
}
self.bucket += 1;
self.index = 0;
}
None
}
}
fn allocate_bucket<T>(size: usize) -> *mut Entry<T> {
Box::into_raw(
(0..size)
.map(|_| Entry::<T> {
present: AtomicBool::new(false),
value: UnsafeCell::new(MaybeUninit::uninit()),
})
.collect(),
) as *mut _
}
#[cfg(test)]
#[allow(clippy::redundant_closure)]
mod tests {
use super::ThreadLocal;
use std::cell::RefCell;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering::Relaxed;
use std::sync::Arc;
use std::thread;
fn make_create() -> Arc<dyn Fn() -> usize + Send + Sync> {
let count = AtomicUsize::new(0);
Arc::new(move || count.fetch_add(1, Relaxed))
}
#[test]
fn same_thread() {
let create = make_create();
let tls = ThreadLocal::with_capacity(1);
assert_eq!(None, tls.get());
assert_eq!(0, *tls.get_or(|| create()));
assert_eq!(Some(&0), tls.get());
assert_eq!(0, *tls.get_or(|| create()));
assert_eq!(Some(&0), tls.get());
assert_eq!(0, *tls.get_or(|| create()));
assert_eq!(Some(&0), tls.get());
}
#[test]
fn different_thread() {
let create = make_create();
let tls = Arc::new(ThreadLocal::with_capacity(1));
assert_eq!(None, tls.get());
assert_eq!(0, *tls.get_or(|| create()));
assert_eq!(Some(&0), tls.get());
let tls2 = tls.clone();
let create2 = create.clone();
thread::spawn(move || {
assert_eq!(None, tls2.get());
assert_eq!(1, *tls2.get_or(|| create2()));
assert_eq!(Some(&1), tls2.get());
})
.join()
.unwrap();
assert_eq!(Some(&0), tls.get());
assert_eq!(0, *tls.get_or(|| create()));
}
#[test]
fn iter() {
let tls = Arc::new(ThreadLocal::with_capacity(1));
tls.get_or(|| Box::new(1));
let tls2 = tls.clone();
thread::spawn(move || {
tls2.get_or(|| Box::new(2));
let tls3 = tls2.clone();
thread::spawn(move || {
tls3.get_or(|| Box::new(3));
})
.join()
.unwrap();
drop(tls2);
})
.join()
.unwrap();
let tls = Arc::try_unwrap(tls).unwrap_or_else(|_| panic!("."));
let mut v = tls.iter().map(|x| **x).collect::<Vec<i32>>();
v.sort_unstable();
assert_eq!(vec![1, 2, 3], v);
}
#[test]
fn iter_snapshot() {
let tls = Arc::new(ThreadLocal::with_capacity(1));
tls.get_or(|| Box::new(1));
let iterator = tls.iter();
tls.get_or(|| Box::new(2));
let v = iterator.map(|x| **x).collect::<Vec<i32>>();
assert_eq!(vec![1], v);
}
#[test]
fn test_drop() {
let local = ThreadLocal::with_capacity(1);
struct Dropped(Arc<AtomicUsize>);
impl Drop for Dropped {
fn drop(&mut self) {
self.0.fetch_add(1, Relaxed);
}
}
let dropped = Arc::new(AtomicUsize::new(0));
local.get_or(|| Dropped(dropped.clone()));
assert_eq!(dropped.load(Relaxed), 0);
drop(local);
assert_eq!(dropped.load(Relaxed), 1);
}
#[test]
fn is_sync() {
fn foo<T: Sync>() {}
foo::<ThreadLocal<String>>();
foo::<ThreadLocal<RefCell<String>>>();
}
}
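
The tests above exercise the API one thread at a time; the intended aggregate pattern is per-thread state written locally and then combined through `iter`. A usage sketch, assuming the crate's re-exported `ThreadLocal` (see the lib.rs hunk) is in scope:

use std::sync::atomic::{AtomicUsize, Ordering};

fn main() {
    let tls = ThreadLocal::<AtomicUsize>::new();
    std::thread::scope(|s| {
        for _ in 0..4 {
            s.spawn(|| {
                // Each thread lazily claims (or, if its ID was reused,
                // inherits) a counter slot and bumps it.
                tls.get_or(|| AtomicUsize::new(0))
                    .fetch_add(1, Ordering::Relaxed);
            });
        }
    });
    // Entries outlive their threads, so the sum sees every increment even
    // when thread IDs were reused across the spawned threads.
    let total: usize = tls.iter().map(|c| c.load(Ordering::Relaxed)).sum();
    assert_eq!(total, 4);
}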