MK
摩柯社区 - 一个极简的技术知识社区
AI 面试

Rust线程的使用技巧

2021-08-237.9k 阅读

Rust线程基础

在Rust中,线程是一种轻量级的并发执行单元。Rust标准库提供了std::thread模块来支持线程的创建和管理。创建一个新线程非常简单,通过thread::spawn函数就可以启动一个新线程。例如:

use std::thread;

fn main() {
    thread::spawn(|| {
        println!("This is a new thread!");
    });
    println!("This is the main thread.");
}

在这个例子中,thread::spawn接受一个闭包作为参数,闭包中的代码会在新线程中执行。但是运行这段代码,你可能会发现新线程中的println!语句并没有输出。这是因为主线程在新线程执行完之前就结束了,整个程序随之结束。为了避免这种情况,我们可以使用join方法来等待新线程完成。

use std::thread;

fn main() {
    let handle = thread::spawn(|| {
        println!("This is a new thread!");
    });
    handle.join().unwrap();
    println!("This is the main thread.");
}

join方法会阻塞主线程,直到被调用的线程结束。unwrap方法用于处理join可能返回的错误,这里我们假设线程不会发生错误直接unwrap。

线程间的数据共享

不可变数据共享

在Rust中,共享不可变数据是相对简单的。因为不可变数据不存在数据竞争的问题,多个线程可以安全地读取相同的数据。例如:

use std::thread;

fn main() {
    let data = String::from("Hello, Rust!");
    let handle = thread::spawn(move || {
        println!("Data in new thread: {}", data);
    });
    handle.join().unwrap();
}

这里我们将data通过move关键字转移到新线程中,新线程可以读取这个字符串。因为data是不可变的,所以不存在数据竞争。

可变数据共享

共享可变数据则要复杂一些,因为Rust的所有权和借用规则不允许在多个线程间同时存在可变引用。为了实现线程间可变数据的共享,Rust提供了一些同步原语,比如Mutex(互斥锁)。Mutex可以保证在任何时刻只有一个线程能够访问被它保护的数据。

use std::sync::{Mutex, Arc};
use std::thread;

fn main() {
    let data = Arc::new(Mutex::new(0));
    let mut handles = vec![];

    for _ in 0..10 {
        let data_clone = Arc::clone(&data);
        let handle = thread::spawn(move || {
            let mut num = data_clone.lock().unwrap();
            *num += 1;
        });
        handles.push(handle);
    }

    for handle in handles {
        handle.join().unwrap();
    }

    println!("Final value: {}", data.lock().unwrap());
}

在这个例子中,我们使用Arc(原子引用计数)来在多个线程间共享Mutex,因为Mutex本身不具备SendSync特性,不能直接在线程间传递,而Arc<Mutex<T>>是可以的。lock方法会尝试获取锁,如果锁可用则返回一个MutexGuard,这是一个智能指针,在其生命周期结束时会自动释放锁。通过MutexGuard我们可以安全地修改被保护的数据。

线程通信

使用通道(Channel)

通道是线程间通信的常用方式。Rust标准库提供了std::sync::mpsc模块来创建多生产者 - 单消费者(MPSC)通道。例如:

use std::sync::mpsc;
use std::thread;

fn main() {
    let (sender, receiver) = mpsc::channel();

    thread::spawn(move || {
        let message = String::from("Hello from another thread!");
        sender.send(message).unwrap();
    });

    let received = receiver.recv().unwrap();
    println!("Received: {}", received);
}

在这个例子中,mpsc::channel创建了一个通道,返回一个发送者(sender)和一个接收者(receiver)。发送者线程通过send方法向通道发送数据,接收者线程通过recv方法从通道接收数据。recv方法是阻塞的,直到有数据可用。

无缓冲通道和有缓冲通道

上述例子中的通道是无缓冲的,即send方法会阻塞,直到有线程调用recv方法接收数据。我们也可以创建有缓冲的通道,例如:

use std::sync::mpsc;
use std::thread;

fn main() {
    let (sender, receiver) = mpsc::channel::<i32>(10); // 创建一个缓冲大小为10的通道

    for i in 0..5 {
        sender.send(i).unwrap();
    }

    for _ in 0..5 {
        let received = receiver.recv().unwrap();
        println!("Received: {}", received);
    }
}

在这个例子中,我们创建了一个缓冲大小为10的通道,所以在接收者接收数据之前,发送者可以发送最多10个数据而不会阻塞。

线程安全与数据竞争

数据竞争的定义与危害

数据竞争是指多个线程同时访问共享的可变数据,并且至少有一个线程是在进行写操作,同时没有适当的同步机制来协调这些访问。数据竞争会导致未定义行为,程序可能出现各种奇怪的错误,包括崩溃、数据损坏等。在Rust中,通过所有权系统和同步原语,我们可以有效地避免数据竞争。

Rust如何保证线程安全

Rust通过类型系统和所有权规则来保证内存安全,同时提供了一系列同步原语来保证线程安全。例如MutexRwLockRwLock(读写锁)允许多个线程同时进行读操作,但只允许一个线程进行写操作。

use std::sync::{RwLock, Arc};
use std::thread;

fn main() {
    let data = Arc::new(RwLock::new(0));
    let mut handles = vec![];

    for _ in 0..5 {
        let data_clone = Arc::clone(&data);
        handles.push(thread::spawn(move || {
            let read_data = data_clone.read().unwrap();
            println!("Read value: {}", read_data);
        }));
    }

    for _ in 0..5 {
        let data_clone = Arc::clone(&data);
        handles.push(thread::spawn(move || {
            let mut write_data = data_clone.write().unwrap();
            *write_data += 1;
        }));
    }

    for handle in handles {
        handle.join().unwrap();
    }

    println!("Final value: {}", data.read().unwrap());
}

在这个例子中,读操作可以同时进行,因为read方法返回的RwLockReadGuard允许多个同时存在。而写操作通过write方法获取RwLockWriteGuard,此时其他线程无法进行读写操作,从而保证了数据的一致性。

线程池

为什么需要线程池

在一些应用场景中,频繁地创建和销毁线程会带来较大的开销。线程池可以复用线程,减少这种开销。线程池预先创建一定数量的线程,当有任务到达时,线程池中的线程可以执行这些任务,任务完成后线程不会被销毁,而是等待下一个任务。

实现简单的线程池

下面是一个简单的线程池实现示例:

use std::sync::{Arc, Mutex};
use std::thread;
use std::sync::mpsc::{channel, Receiver, Sender};
use std::time::Duration;

struct ThreadPool {
    workers: Vec<Worker>,
    sender: Option<Sender<Job>>,
}

type Job = Box<dyn FnOnce() + Send + 'static>;

struct Worker {
    id: usize,
    thread: thread::JoinHandle<()>,
}

impl Worker {
    fn new(id: usize, receiver: Receiver<Job>) -> Worker {
        let thread = thread::spawn(move || {
            loop {
                match receiver.recv() {
                    Ok(job) => {
                        println!("Worker {} got a job; executing.", id);
                        job();
                    }
                    Err(_) => {
                        println!("Worker {} shutting down.", id);
                        break;
                    }
                }
            }
        });

        Worker {
            id,
            thread,
        }
    }
}

impl ThreadPool {
    fn new(size: usize) -> ThreadPool {
        assert!(size > 0);

        let (sender, receiver) = channel();
        let receiver = Arc::new(Mutex::new(receiver));

        let mut workers = Vec::with_capacity(size);

        for id in 0..size {
            let receiver_clone = Arc::clone(&receiver);
            workers.push(Worker::new(id, receiver_clone));
        }

        ThreadPool {
            workers,
            sender: Some(sender),
        }
    }

    fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        let job = Box::new(f);
        self.sender.as_ref().unwrap().send(job).unwrap();
    }
}

impl Drop for ThreadPool {
    fn drop(&mut self) {
        drop(self.sender.take());

        for worker in &mut self.workers {
            println!("Shutting down worker {}", worker.id);
        }

        for worker in &mut self.workers {
            if let Err(e) = worker.thread.join() {
                println!("Error joining thread: {}", e);
            }
        }
    }
}

你可以这样使用这个线程池:

fn main() {
    let pool = ThreadPool::new(4);

    for i in 0..8 {
        let i = i;
        pool.execute(move || {
            println!("Task {} is running on a thread from the pool.", i);
            thread::sleep(Duration::from_secs(1));
            println!("Task {} has finished.", i);
        });
    }
    println!("All tasks have been submitted.");
    thread::sleep(Duration::from_secs(5));
    println!("Main thread is done.");
}

在这个线程池实现中,ThreadPool结构体包含一个workers向量,用于存储线程,以及一个Sender用于发送任务。Worker结构体包含线程的ID和线程的JoinHandleexecute方法用于向线程池提交任务,Drop实现用于在ThreadPool被销毁时正确地关闭所有线程。

线程本地化存储(TLS)

TLS的概念

线程本地化存储(TLS)允许每个线程拥有自己独立的数据副本。在Rust中,thread::local模块提供了对TLS的支持。这在一些场景下非常有用,比如每个线程需要有自己独立的日志记录器或者数据库连接等。

使用TLS

下面是一个简单的使用TLS的示例:

use std::thread;
use std::thread::local;

fn main() {
    let local_num = local!(static LOCAL_NUM: u32 = 0);

    local_num.with(|num| {
        *num.borrow_mut() = 10;
    });

    let handle = thread::spawn(move || {
        local_num.with(|num| {
            println!("Local num in new thread: {}", *num.borrow());
        });
    });

    local_num.with(|num| {
        println!("Local num in main thread: {}", *num.borrow());
    });

    handle.join().unwrap();
}

在这个例子中,我们通过local!(static ...)语法创建了一个线程本地变量LOCAL_NUMwith方法用于在每个线程中访问和修改这个本地变量。每个线程中的LOCAL_NUM是独立的,互不影响。

线程错误处理

线程恐慌(Panic)

当线程中的代码发生恐慌(panic)时,默认情况下整个程序会终止。但是我们可以通过一些方法来捕获线程中的恐慌,避免程序崩溃。例如:

use std::thread;

fn main() {
    let handle = thread::spawn(|| {
        panic!("This thread is panicking!");
    });

    match handle.join() {
        Ok(_) => println!("Thread finished successfully."),
        Err(panic) => println!("Thread panicked: {:?}", panic),
    }
}

在这个例子中,join方法返回一个Result,如果线程正常结束则为Ok,如果线程恐慌则为Err,通过这种方式我们可以捕获并处理线程中的恐慌。

自定义错误处理

除了捕获恐慌,我们还可以在创建线程时设置自定义的错误处理函数。例如:

use std::thread;
use std::panic;

fn main() {
    let handle = thread::Builder::new()
        .name("my_thread".to_string())
        .panic_handler(Box::new(|panic_info| {
            println!("Thread panicked: {:?}", panic_info);
        }))
        .spawn(|| {
            panic!("This thread is panicking!");
        })
       .unwrap();

    handle.join().unwrap();
}

在这个例子中,我们通过thread::Builder创建线程,并设置了一个自定义的恐慌处理函数。当线程发生恐慌时,会调用这个自定义的处理函数,而不是默认的终止程序行为。

高级线程技巧

条件变量(Condvar)

条件变量用于线程间的同步,它通常与Mutex一起使用。条件变量允许一个线程等待某个条件满足,而其他线程可以通知这个条件变量,唤醒等待的线程。例如:

use std::sync::{Arc, Condvar, Mutex};
use std::thread;

fn main() {
    let pair = Arc::new((Mutex::new(false), Condvar::new()));
    let pair2 = Arc::clone(&pair);

    thread::spawn(move || {
        let (lock, cvar) = &*pair;
        let mut started = lock.lock().unwrap();
        *started = true;
        println!("Notifying the other thread...");
        cvar.notify_one();
    });

    let (lock, cvar) = &*pair2;
    let mut started = lock.lock().unwrap();
    while!*started {
        started = cvar.wait(started).unwrap();
    }
    println!("Condition has been met.");
}

在这个例子中,主线程等待条件变量被通知,子线程修改条件并通知条件变量。wait方法会释放锁并阻塞线程,直到收到通知。收到通知后,wait方法会重新获取锁并返回修改后的锁。

自旋锁(Spinlock)

自旋锁是一种特殊的锁,当线程尝试获取自旋锁时,如果锁不可用,线程不会进入睡眠状态,而是在循环中不断尝试获取锁,直到锁可用。自旋锁适用于锁被持有时间较短的场景。Rust标准库没有直接提供自旋锁,但我们可以使用第三方库,比如spin库来实现自旋锁。

use spin::Mutex;
use std::thread;

fn main() {
    let data = Mutex::new(0);
    let mut handles = vec![];

    for _ in 0..10 {
        let data_clone = data.clone();
        handles.push(thread::spawn(move || {
            let mut num = data_clone.lock();
            *num += 1;
        }));
    }

    for handle in handles {
        handle.join().unwrap();
    }

    println!("Final value: {}", *data.lock());
}

在这个例子中,我们使用spin::Mutex作为自旋锁,它的使用方式与标准库中的Mutex类似,但内部实现为自旋锁。

原子操作

原子操作是一种不可分割的操作,在多线程环境中可以保证操作的原子性。Rust标准库提供了std::sync::atomic模块来支持原子操作。例如,AtomicI32可以用于原子地操作32位整数。

use std::sync::atomic::{AtomicI32, Ordering};
use std::thread;

fn main() {
    let num = AtomicI32::new(0);
    let mut handles = vec![];

    for _ in 0..10 {
        let num_clone = num.clone();
        handles.push(thread::spawn(move || {
            num_clone.fetch_add(1, Ordering::SeqCst);
        }));
    }

    for handle in handles {
        handle.join().unwrap();
    }

    println!("Final value: {}", num.load(Ordering::SeqCst));
}

在这个例子中,fetch_add方法是一个原子操作,它会原子地将指定的值加到AtomicI32上,并返回原来的值。Ordering参数用于指定内存顺序,这里我们使用SeqCst(顺序一致性),这是最严格的内存顺序。

通过以上对Rust线程的各种使用技巧的介绍,希望能帮助你更好地在Rust项目中利用线程实现高效的并发编程。无论是简单的线程创建,还是复杂的线程间通信、数据共享和同步,Rust都提供了丰富且安全的机制来支持。在实际应用中,根据具体的需求和场景选择合适的线程使用方式,是实现高性能和可靠的并发程序的关键。