MK
摩柯社区 - 一个极简的技术知识社区
AI 面试

Python threading 模块的全面解析

2024-10-231.3k 阅读

Python threading 模块基础

Python 的 threading 模块是用于多线程编程的标准库,它提供了创建和管理线程的高级接口。在深入了解 threading 模块之前,我们先来回顾一下什么是线程。

线程是操作系统能够进行运算调度的最小单位。它被包含在进程之中,是进程中的实际运作单位。一个进程可以包含多个线程,这些线程共享进程的资源,如内存空间、文件描述符等。多线程编程的主要目的是提高程序的并发性能,尤其是在处理 I/O 密集型任务时,多个线程可以在等待 I/O 操作完成的同时执行其他任务,从而提高整体的运行效率。

简单线程创建

threading 模块中,创建一个线程非常简单。我们可以通过继承 threading.Thread 类或者直接使用 threading.Thread 类来创建线程。

通过继承 threading.Thread 类创建线程

import threading


class MyThread(threading.Thread):
    def run(self):
        print(f"线程 {self.name} 正在运行")


if __name__ == '__main__':
    thread = MyThread()
    thread.start()

在上述代码中,我们定义了一个 MyThread 类,它继承自 threading.Thread 类。我们重写了 run 方法,这个方法会在新线程启动时被执行。然后我们创建了 MyThread 类的实例,并调用 start 方法来启动线程。

直接使用 threading.Thread 类创建线程

import threading


def print_message():
    print(f"线程 {threading.current_thread().name} 正在运行")


if __name__ == '__main__':
    thread = threading.Thread(target=print_message)
    thread.start()

这里我们定义了一个函数 print_message,然后通过 threading.Thread 类的构造函数,将这个函数作为 target 参数传入,同样调用 start 方法启动线程。

线程属性和方法

  1. name 属性:每个线程都有一个名称,可以通过 name 属性来获取或设置线程的名称。
import threading


def print_thread_name():
    print(f"当前线程名称: {threading.current_thread().name}")


if __name__ == '__main__':
    thread = threading.Thread(target=print_thread_name, name="CustomThread")
    thread.start()
  1. is_alive() 方法:用于检查线程是否还在运行。
import threading
import time


def sleep_and_print():
    time.sleep(2)
    print("线程执行完毕")


if __name__ == '__main__':
    thread = threading.Thread(target=sleep_and_print)
    thread.start()
    print(f"线程是否存活: {thread.is_alive()}")
    time.sleep(3)
    print(f"线程是否存活: {thread.is_alive()}")
  1. join(timeout=None) 方法:调用该方法的线程会等待被调用 join 方法的线程执行完毕。timeout 参数用于指定等待的最长时间。
import threading
import time


def long_running_task():
    time.sleep(3)
    print("长时间运行任务完成")


if __name__ == '__main__':
    thread = threading.Thread(target=long_running_task)
    thread.start()
    print("主线程等待子线程完成")
    thread.join()
    print("子线程已完成,主线程继续执行")

线程同步

当多个线程共享资源时,可能会出现资源竞争的问题。例如,多个线程同时对一个全局变量进行读写操作,可能会导致数据不一致。为了解决这类问题,threading 模块提供了多种线程同步机制。

锁(Lock)

锁是最基本的线程同步工具。它只有两种状态:锁定(locked)和未锁定(unlocked)。当一个线程获取到锁时,其他线程试图获取锁就会被阻塞,直到锁被释放。

使用 Lock 示例

import threading


class Counter:
    def __init__(self):
        self.value = 0
        self.lock = threading.Lock()

    def increment(self):
        with self.lock:
            self.value += 1


def worker(counter):
    for _ in range(1000):
        counter.increment()


if __name__ == '__main__':
    counter = Counter()
    threads = []
    for _ in range(10):
        thread = threading.Thread(target=worker, args=(counter,))
        threads.append(thread)
        thread.start()
    for thread in threads:
        thread.join()
    print(f"最终计数器值: {counter.value}")

在上述代码中,Counter 类中有一个 lock 属性,在 increment 方法中,我们使用 with self.lock 语句来获取锁,这样在执行 self.value += 1 时,其他线程无法同时访问,避免了数据竞争。

递归锁(RLock)

递归锁(RLock)是一种特殊的锁,同一个线程可以多次获取它而不会造成死锁。每次获取锁时,锁的内部计数器会加 1,每次释放锁时,计数器会减 1,只有当计数器为 0 时,锁才会真正被释放。

使用 RLock 示例

import threading


class Resource:
    def __init__(self):
        self.rlock = threading.RLock()

    def method1(self):
        with self.rlock:
            print("进入 method1")
            self.method2()
            print("离开 method1")

    def method2(self):
        with self.rlock:
            print("进入 method2")
            print("离开 method2")


if __name__ == '__main__':
    resource = Resource()
    thread = threading.Thread(target=resource.method1)
    thread.start()

在这个例子中,method1 调用了 method2,如果使用普通的 Lock,在 method1 中获取锁后,再在 method2 中尝试获取锁会导致死锁,而 RLock 可以避免这种情况。

信号量(Semaphore)

信号量是一个计数器,它允许一定数量的线程同时访问某个资源。当一个线程获取信号量时,计数器会减 1,当线程释放信号量时,计数器会加 1。如果计数器为 0,其他线程获取信号量就会被阻塞。

使用 Semaphore 示例

import threading
import time


class Database:
    def __init__(self):
        self.semaphore = threading.Semaphore(3)

    def query(self):
        with self.semaphore:
            print(f"{threading.current_thread().name} 正在查询数据库")
            time.sleep(2)
            print(f"{threading.current_thread().name} 完成查询")


if __name__ == '__main__':
    database = Database()
    threads = []
    for i in range(5):
        thread = threading.Thread(target=database.query, name=f"Thread-{i}")
        threads.append(thread)
        thread.start()
    for thread in threads:
        thread.join()

在上述代码中,Semaphore(3) 表示最多允许 3 个线程同时访问数据库查询操作。

事件(Event)

事件是一种简单的线程同步机制,它允许一个线程通知其他线程发生了某个事件。事件有一个内部标志,可以通过 set() 方法设置为 True,通过 clear() 方法设置为 False。其他线程可以通过 wait() 方法等待事件的发生,当事件标志为 True 时,wait() 方法会立即返回,否则会阻塞。

使用 Event 示例

import threading
import time


def worker(event):
    print(f"{threading.current_thread().name} 等待事件")
    event.wait()
    print(f"{threading.current_thread().name} 事件已发生")


if __name__ == '__main__':
    event = threading.Event()
    threads = []
    for i in range(3):
        thread = threading.Thread(target=worker, args=(event,), name=f"Thread-{i}")
        threads.append(thread)
        thread.start()
    time.sleep(2)
    print("设置事件")
    event.set()
    for thread in threads:
        thread.join()

在这个例子中,worker 线程通过 event.wait() 等待事件发生,主线程在等待 2 秒后通过 event.set() 设置事件,从而唤醒所有等待的线程。

线程池

在实际应用中,频繁地创建和销毁线程会带来一定的开销。线程池可以解决这个问题,它预先创建一定数量的线程,这些线程可以重复使用来执行不同的任务。

concurrent.futures 模块中的线程池

虽然 threading 模块本身没有直接提供线程池的实现,但 concurrent.futures 模块提供了 ThreadPoolExecutor 类来实现线程池。

使用 ThreadPoolExecutor 示例

import concurrent.futures
import time


def task(x):
    print(f"开始任务 {x}")
    time.sleep(2)
    print(f"完成任务 {x}")
    return x * x


if __name__ == '__main__':
    with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
        future_to_x = {executor.submit(task, x): x for x in range(5)}
        for future in concurrent.futures.as_completed(future_to_x):
            x = future_to_x[future]
            try:
                result = future.result()
            except Exception as e:
                print(f"任务 {x} 抛出异常: {e}")
            else:
                print(f"任务 {x} 的结果是: {result}")

在上述代码中,我们创建了一个最大线程数为 3 的线程池 ThreadPoolExecutor。通过 submit 方法向线程池提交任务,as_completed 函数用于迭代已完成的任务,result 方法用于获取任务的返回值。

线程与 GIL(全局解释器锁)

Python 的 threading 模块虽然提供了多线程编程的能力,但由于 GIL(Global Interpreter Lock)的存在,在同一时刻,Python 解释器只能执行一个线程的字节码。这意味着在 CPU 密集型任务中,多线程并不能充分利用多核 CPU 的优势。

GIL 的设计初衷是为了简化 Python 解释器的内存管理,因为 Python 的内存管理不是线程安全的。在执行 I/O 密集型任务时,由于线程大部分时间在等待 I/O 操作完成,GIL 的影响较小,多线程可以提高程序的整体性能。但对于 CPU 密集型任务,我们可以考虑使用 multiprocessing 模块来利用多核 CPU 的优势。

CPU 密集型任务对比示例

import threading
import multiprocessing
import time


def cpu_bound_task():
    result = 0
    for i in range(100000000):
        result += i
    return result


if __name__ == '__main__':
    start_time = time.time()
    threads = []
    for _ in range(4):
        thread = threading.Thread(target=cpu_bound_task)
        threads.append(thread)
        thread.start()
    for thread in threads:
        thread.join()
    print(f"多线程 CPU 密集型任务耗时: {time.time() - start_time} 秒")

    start_time = time.time()
    processes = []
    for _ in range(4):
        process = multiprocessing.Process(target=cpu_bound_task)
        processes.append(process)
        process.start()
    for process in processes:
        process.join()
    print(f"多进程 CPU 密集型任务耗时: {time.time() - start_time} 秒")

在上述代码中,我们对比了使用多线程和多进程执行 CPU 密集型任务的耗时,可以明显看到多进程在多核 CPU 上的优势。

线程安全的数据结构

在多线程编程中,使用线程安全的数据结构可以避免数据竞争问题。Python 标准库提供了一些线程安全的数据结构,如 queue.Queue

queue.Queue

queue.Queue 是一个线程安全的队列,它提供了 putget 方法来向队列中添加和取出元素。在添加或取出元素时,会自动进行线程同步。

使用 queue.Queue 示例

import threading
import queue


def producer(queue):
    for i in range(10):
        queue.put(i)
        print(f"生产者放入元素: {i}")


def consumer(queue):
    while True:
        item = queue.get()
        if item is None:
            break
        print(f"消费者取出元素: {item}")
        queue.task_done()


if __name__ == '__main__':
    q = queue.Queue()
    producer_thread = threading.Thread(target=producer, args=(q,))
    consumer_thread = threading.Thread(target=consumer, args=(q,))
    producer_thread.start()
    consumer_thread.start()
    producer_thread.join()
    q.put(None)
    consumer_thread.join()

在这个例子中,生产者线程向队列中放入元素,消费者线程从队列中取出元素,queue.Queue 保证了操作的线程安全性。

高级线程应用

线程本地数据(Thread - Local Data)

在多线程编程中,有时我们希望每个线程都有自己独立的数据副本,而不是共享数据。threading.local 类可以帮助我们实现这一点。

使用 threading.local 示例

import threading


class ThreadLocalData:
    def __init__(self):
        self.local_data = threading.local()

    def set_value(self, value):
        self.local_data.value = value

    def get_value(self):
        return getattr(self.local_data, 'value', None)


def worker(thread_local_data, value):
    thread_local_data.set_value(value)
    print(f"线程 {threading.current_thread().name} 设置值: {value}")
    print(f"线程 {threading.current_thread().name} 获取值: {thread_local_data.get_value()}")


if __name__ == '__main__':
    thread_local_data = ThreadLocalData()
    threads = []
    for i in range(3):
        thread = threading.Thread(target=worker, args=(thread_local_data, i), name=f"Thread-{i}")
        threads.append(thread)
        thread.start()
    for thread in threads:
        thread.join()

在上述代码中,threading.local 类创建了一个线程本地数据对象 local_data,每个线程都可以独立地设置和获取 local_data.value,不会相互干扰。

守护线程(Daemon Threads)

守护线程是一种特殊的线程,当主线程结束时,所有守护线程会自动终止。守护线程通常用于执行一些后台任务,如垃圾回收、数据缓存管理等。

设置守护线程示例

import threading
import time


def background_task():
    while True:
        print(f"守护线程 {threading.current_thread().name} 正在运行")
        time.sleep(1)


if __name__ == '__main__':
    daemon_thread = threading.Thread(target=background_task)
    daemon_thread.daemon = True
    daemon_thread.start()
    time.sleep(3)
    print("主线程结束")

在这个例子中,我们将 daemon_thread 设置为守护线程,当主线程等待 3 秒结束后,守护线程也会随之终止。

通过对 threading 模块的全面解析,我们深入了解了 Python 多线程编程的各个方面,包括线程的创建、同步、线程池、GIL 相关问题以及一些高级应用。在实际开发中,合理运用多线程技术可以显著提高程序的性能和响应能力,尤其是在处理 I/O 密集型任务时。但同时也要注意线程同步和数据竞争等问题,以确保程序的正确性和稳定性。