完整代码:https://github.com/chiehw/hello_rust/blob/main/crates/counter/src/lib.rs |
定义 Trait
Trait 可以看作是一种能力的抽象,和接口有点类似。Trait 还能作为泛型约束条件,作为参数的限制条件。
pub trait AtomicCounter: Send + Sync {type PrimitiveType;fn get(&self) -> Self::PrimitiveType; // 获取当前计数器的值。fn increase(&self) -> Self::PrimitiveType; // 自增,并返回上一次的值fn add(&self, count: Self::PrimitiveType) -> Self::PrimitiveType; // 添加一个数,并返回上一次的值fn reset(&self) -> Self::PrimitiveType; // 重置计数器fn into_inner(self) -> Self::PrimitiveType; // 获取内部值
}
简单的测试用例TDD
使用测试驱动开发可以让目标更明确,这里先写个简单的测试案例。
#[cfg(test)]
mod tests {use super::*;fn test_simple<Counter>(counter: Counter)whereCounter: AtomicCounter<PrimitiveType = usize>, // 使用 Trait 作为泛型约束条件{counter.reset();assert_eq!(0, counter.add(5));assert_eq!(5, counter.increase());assert_eq!(6, counter.get())}#[test]fn it_works() {test_simple(RelaxedCounter::new(10));}
}
亿点细节
直接封装 AtomicUsize
#[derive(Default, Debug)]
pub struct ConsistentCounter(AtomicUsize);impl ConsistentCounter {pub fn new(init_num: usize) -> ConsistentCounter {ConsistentCounter(AtomicUsize::new(init_num))}
}impl AtomicCounter for ConsistentCounter {type PrimitiveType = usize;fn get(&self) -> Self::PrimitiveType {self.0.load(Ordering::SeqCst)}fn increase(&self) -> Self::PrimitiveType {self.add(1)}fn add(&self, count: Self::PrimitiveType) -> Self::PrimitiveType {self.0.fetch_add(count, Ordering::SeqCst)}fn reset(&self) -> Self::PrimitiveType {self.0.swap(0, Ordering::SeqCst)}fn into_inner(self) -> Self::PrimitiveType {self.0.into_inner()}
}
增加测试用例
使用多线程同时对计数器进行操作,然后判断计数的结果是否正确。更多的测试案例请查看【完整代码】
fn test_increase<Counter>(counter: Arc<Counter>)whereCounter: AtomicCounter<PrimitiveType = usize> + Debug + 'static,{println!("[+] test_increase: Spawning {} thread, each with {}", NUM_THREADS, NUM_ITERATIONS);let mut join_handles = Vec::new();// 创建 NUM_THREADS 个线程,同时使用 increase 函数for _ in 0..NUM_THREADS {let counter_ref = counter.clone();join_handles.push(thread::spawn(move || {let counter: &Counter = counter_ref.deref();for _ in 0..NUM_ITERATIONS {counter.increase();}}));}// 等待线程完成for handle in join_handles {handle.join().unwrap();}let count = Arc::try_unwrap(counter).unwrap().into_inner();let excepted_num = NUM_ITERATIONS * NUM_THREADS;println!("[+] test_increase: get count {}, excepted num is {}", count, excepted_num);// 确定 count 正确assert_eq!(count, excepted_num)}
参考教程:
- 谈谈 C++ 中的内存顺序 (Memory Order):https://luyuhuang.tech/2022/06/25/cpp-memory-order.html#happens-before