Vanson's Eternal Blog

Python中的多线程编程

Python thread.png
Published on
/19 mins read/---

定义

它允许程序同时运行多个线程,从而提高程序的执行效率。

然而,由于 Python 的全局解释器锁(GIL)的存在,多线程在 CPU 密集型任务中可能不会带来显著的性能提升,但在 I/O 密集型任务中效果显著。

全局解释器锁(GIL)

GIL 是 Python 解释器的一个特性,它确保同一时刻只有一个线程执行 Python 字节码。

这意味着即使在多核 CPU 上,Python 的多线程也无法实现真正的并行计算。不过,在 I/O 密集型任务中,线程可以在等待 I/O 操作时释放 GIL,从而让其他线程运行,提高程序的效率。

创建

使用 threading 模块

继承 Thread 类

 
import threading
import time
 
class MyThread(threading.Thread):
    def __init__(self, name):
        super().__init__()
        self.name = name
 
    def run(self):
        print(f"Thread {self.name} started")
        time.sleep(2)
        print(f"Thread {self.name} finished")
 
t1 = MyThread("Worker-1")
t2 = MyThread("Worker-2")
 
t1.start()
t2.start()
 
t1.join()
t2.join()
print("All threads completed")
 
 

直接实例化

 
import threading
 
# 定义一个线程任务函数
def print_numbers():
    for i in range(5):
        print(i)
 
# 创建线程
thread = threading.Thread(target=print_numbers)
 
# 启动线程
thread.start()
 
# 等待线程结束
thread.join()
 

我们定义了一个线程任务函数 print_numbers,它会打印从 0 到 4 的数字。

然后,我们使用 threading.Thread 创建了一个线程对象,并通过 start() 方法启动线程。最后,我们使用 join() 方法等待线程结束。

传递参数

如果需要向线程任务函数传递参数,可以通过 args 或 kwargs 参数来实现。

 
import threading
 
def print_message(message, times):
    for _ in range(times):
        print(message)
 
# 创建线程并传递参数
thread = threading.Thread(target=print_message, args=("Hello", 3))
 
# 启动线程
thread.start()
 
# 等待线程结束
thread.join()
 

线程的属性

  • name:线程的名称。
  • daemon:是否为守护线程。守护线程会在主线程退出时自动退出。
  • is_alive():判断线程是否在运行。
import threading
 
def print_numbers():
    for i in range(5):
        print(i)
 
thread = threading.Thread(target=print_numbers, name="MyThread")
thread.daemon = True  # 设置为守护线程
thread.start()
 
print(f"Thread name: {thread.name}")
print(f"Is thread alive: {thread.is_alive()}")
 

线程同步

在多线程程序中,多个线程可能会访问和修改共享资源,这可能会导致数据不一致的问题。为了避免这种情况,需要使用线程同步机制。

Python 提供了多种线程同步工具,如锁(Lock)、事件(Event)、条件(Condition)等。

锁(Lock)

锁是一种最简单的线程同步机制。它确保一次只有一个线程可以访问共享资源。

import threading
 
# 创建一个锁对象
lock = threading.Lock()
 
# 共享资源
counter = 0
 
def increment():
    global counter
    for _ in range(100000):
        lock.acquire()  # 获取锁
        counter += 1
        lock.release()  # 释放锁
 
# 创建两个线程
thread1 = threading.Thread(target=increment)
thread2 = threading.Thread(target=increment)
 
thread1.start()
thread2.start()
 
thread1.join()
thread2.join()
 
print(f"Counter value: {counter}")
 

定义了一个共享资源 counter,并创建了两个线程来对它进行递增操作。

为了避免多个线程同时修改 counter 导致数据不一致,我们在修改 counter 之前获取锁,修改完成后释放锁。

递归锁(RLock)

递归锁允许同一个线程多次获取锁。

import threading
 
# 创建一个递归锁对象
rlock = threading.RLock()
 
def recursive_function(n):
    if n > 0:
        rlock.acquire()
        print(f"Thread {threading.current_thread().name}: n = {n}")
        recursive_function(n - 1)
        rlock.release()
 
# 创建线程
thread = threading.Thread(target=recursive_function, args=(5,))
thread.start()
thread.join()
 

事件(Event)

事件是一种线程同步机制,用于线程之间的通信。一个线程可以设置或清除事件,其他线程可以等待事件的发生。

import threading
import time
 
# 创建一个事件对象
event = threading.Event()
 
def wait_for_event():
    print(f"Thread {threading.current_thread().name}: Waiting for event...")
    event.wait()  # 等待事件发生
    print(f"Thread {threading.current_thread().name}: Event occurred!")
 
def set_event():
    time.sleep(2)  # 模拟一些操作
    print(f"Thread {threading.current_thread().name}: Setting event...")
    event.set()  # 设置事件
 
# 创建线程
thread1 = threading.Thread(target=wait_for_event, name="Thread1")
thread2 = threading.Thread(target=set_event, name="Thread2")
 
thread1.start()
thread2.start()
 
thread1.join()
thread2.join()
 

Thread1 等待事件的发生,而 Thread2 在 2 秒后设置事件。当事件发生时,Thread1 会继续执行。

条件(Condition)

条件是一种更高级的线程同步机制,它允许线程在某些条件下等待或通知其他线程。

import threading
 
# 创建一个条件对象
condition = threading.Condition()
 
# 共享资源
item = None
 
def producer():
    global item
    with condition:
        print(f"Producer: Producing item...")
        item = "Product"
        condition.notify()  # 通知等待的线程
 
def consumer():
    global item
    with condition:
        print(f"Consumer: Waiting for item...")
        condition.wait()  # 等待条件满足
        print(f"Consumer: Consumed item: {item}")
 
# 创建线程
thread1 = threading.Thread(target=producer, name="Producer")
thread2 = threading.Thread(target=consumer, name="Consumer")
 
thread2.start()
thread1.start()
 
thread1.join()
thread2.join()
 

生产者线程 Thread1 生产一个产品并通知消费者线程 Thread2,消费者线程在等待产品可用时会阻塞,直到生产者线程通知它。

线程安全的队列

在多线程程序中,队列是一种常见的数据结构,用于在生产者和消费者之间传递数据。

Python 的 queue.Queue 是一个线程安全的队列实现,它提供了线程安全的 put() 和 get() 方法,非常适合用于多线程的生产者-消费者模型。

生产者-消费者模型

import threading
import queue
import time
import random
 
# 创建一个线程安全的队列
q = queue.Queue()
 
# 生产者线程
def producer():
    for i in range(10):
        item = random.randint(1, 100)
        q.put(item)  # 将数据放入队列
        print(f"Producer: Produced {item}")
        time.sleep(random.random())  # 模拟生产时间
 
# 消费者线程
def consumer():
    while True:
        item = q.get()  # 从队列中获取数据
        print(f"Consumer: Consumed {item}")
        q.task_done()  # 标记任务完成
        time.sleep(random.random())  # 模拟消费时间
 
# 创建生产者和消费者线程
producer_thread = threading.Thread(target=producer)
consumer_thread = threading.Thread(target=consumer, daemon=True)  # 设置为守护线程
 
# 启动线程
producer_thread.start()
consumer_thread.start()
 
# 等待生产者线程结束
producer_thread.join()
 
# 等待队列中的所有任务完成
q.join()
print("All tasks are consumed.")
 

生产者线程会生成随机数并将其放入队列,消费者线程会从队列中取出数据并消费。

queue.Queue 的 put() 和 get() 方法是线程安全的,因此我们不需要额外的锁来保护队列的访问。

线程的中断与异常处理

在多线程程序中,线程可能会因为各种原因抛出异常。为了确保程序的健壮性,我们需要正确地处理线程中的异常。

此外,我们还可以通过中断线程来优雅地关闭线程。

线程中的异常处理

import threading
 
def risky_task():
    try:
        print("Task is running...")
        raise ValueError("Something went wrong!")
    except Exception as e:
        print(f"Exception caught in thread: {e}")
 
# 创建线程
thread = threading.Thread(target=risky_task)
 
# 启动线程
thread.start()
 
# 等待线程结束
thread.join()
 

线程任务函数 risky_task 会抛出一个异常,我们在函数内部捕获并处理了这个异常。

线程的中断

在某些情况下,我们可能需要中断正在运行的线程。可以通过设置一个标志变量来实现线程的中断。

import threading
import time
 
# 中断标志
stop_event = threading.Event()
 
def long_running_task():
    while not stop_event.is_set():
        print("Task is running...")
        time.sleep(1)
    print("Task is interrupted.")
 
# 创建线程
thread = threading.Thread(target=long_running_task)
 
# 启动线程
thread.start()
 
# 模拟运行一段时间后中断线程
time.sleep(3)
stop_event.set()  # 设置中断标志
 
# 等待线程结束
thread.join()
 

定义了一个 stop_event 标志变量,线程在运行过程中会检查这个标志。当主线程调用 stop_event.set() 时,线程会检测到标志被设置并退出。

应用

线程池

线程池是一种管理线程的机制,它可以预先创建一定数量的线程,并在需要时分配给任务。

使用线程池可以避免频繁创建和销毁线程的开销。Python 的 concurrent.futures 模块提供了线程池的实现。

import concurrent.futures
 
def task(n):
    print(f"Task {n} is running")
    return n * n
 
# 创建线程池
with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
    # 提交任务
    futures = [executor.submit(task, i) for i in range(5)]
 
    # 获取任务结果
    for future in concurrent.futures.as_completed(futures):
        print(f"Task result: {future.result()}")
 

创建了一个最大工作线程数为 3 的线程池,并提交了 5 个任务。

as_completed() 方法会返回一个迭代器,当任务完成时,它会生成一个 Future 对象,我们可以通过调用 result() 方法获取任务的结果。

Semaphore

信号量(Semaphore)是一种高级同步原语,用于控制同时访问共享资源的线程数量。它通过维护一个内部计数器来实现这一功能。

信号量的计数器在初始化时设置一个最大值,每次线程获取信号量时,计数器减一;每次线程释放信号量时,计数器加一。

当计数器为零时,后续尝试获取信号量的线程将被阻塞,直到其他线程释放信号量。

import threading
import time
 
# 创建一个信号量,允许最多3个线程同时访问
semaphore = threading.Semaphore(3)
 
def access_resource(worker_id):
    with semaphore:  # 使用上下文管理器自动管理信号量的获取和释放
        print(f"Worker {worker_id} accessing resource")
        time.sleep(2)  # 模拟资源访问时间
        print(f"Worker {worker_id} released resource")
 
# 创建并启动5个线程
for i in range(5):
    threading.Thread(target=access_resource, args=(i,)).start()
 

信号量的特点

  • 限制并发数量:信号量可以限制同时访问共享资源的线程数量,从而避免资源过载。
  • 线程阻塞与唤醒:当信号量的计数器为零时,后续尝试获取信号量的线程将被阻塞;当其他线程释放信号量时,阻塞的线程将被唤醒。
  • 灵活性:信号量的计数器值可以根据需要设置,允许灵活控制并发数量。

使用场景

  • 限制资源访问:当资源有限(如数据库连接池、文件句柄等)时,可以使用信号量限制同时访问资源的线程数量。
  • 控制并发数量:在多线程程序中,需要限制同时运行的线程数量,以避免系统过载。

线程局部变量

线程局部数据(Thread Local)是一种在多线程程序中非常有用的机制,它允许为每个线程分配独立的变量副本,从而避免线程之间的数据共享和竞争条件。

在 Python 中,可以通过 threading.local() 来创建线程局部数据。

import threading
 
# 创建线程局部数据
thread_local = threading.local()
 
def show_data():
    # 打印当前线程的线程局部数据
    print(f"Thread {threading.current_thread().name}: {thread_local.data}")
 
def worker(value):
    # 为当前线程设置线程局部数据
    thread_local.data = value
    # 显示当前线程的线程局部数据
    show_data()
 
# 创建并启动两个线程
threading.Thread(target=worker, args=("A",), name="Thread-1").start()
threading.Thread(target=worker, args=("B",), name="Thread-2").start()
 

线程局部数据的特点

  • 线程独立:每个线程都有自己的线程局部数据副本,线程之间不会相互影响。
  • 动态属性:线程局部数据对象可以动态地添加属性,而无需在创建时预先定义。
  • 生命周期:线程局部数据的生命周期与线程的生命周期一致。当线程结束时,线程局部数据也会被自动清理。

使用场景

  • 避免线程间的数据共享:当需要为每个线程分配独立的变量副本时,线程局部数据是一个很好的选择。例如,存储线程的用户身份信息、数据库连接对象等。
  • 简化线程同步:通过使用线程局部数据,可以避免复杂的线程同步机制,从而简化代码逻辑。例如,在多线程的 Web 服务器中,每个线程可以独立地处理请求,而无需担心数据冲突。

多线程的异步编程

Python 的 asyncio 模块提供了异步编程的支持,但它主要用于单线程的异步 I/O 操作

然而,在某些情况下,我们可能需要将多线程与异步编程结合起来,以充分利用多线程的优势。可以通过 asyncio 的 run_in_executor 方法来实现。

 
import asyncio
import concurrent.futures
import time
 
# 模拟一个耗时的 I/O 操作
def blocking_io():
    print("Start blocking_io")
    time.sleep(2)
    print("End blocking_io")
    return "Blocking IO result"
 
# 异步主函数
async def main():
    # 创建线程池
    with concurrent.futures.ThreadPoolExecutor() as pool:
        # 将耗时的 I/O 操作提交到线程池
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(pool, blocking_io)
        print(f"Result: {result}")
 
# 运行异步主函数
asyncio.run(main())
 

定义了一个耗时的 I/O 操作函数 blocking_io,它会阻塞 2 秒。在异步主函数 main 中,我们使用 asyncio.get_running_loop().run_in_executor 方法将 blocking_io 提交给线程池执行,从而避免了阻塞主线程。

使用 C 扩展模块

C 扩展模块(如 NumPy、Pandas 等)可以在内部释放 GIL,从而允许多线程在 CPU 密集型任务中获得更好的性能。

import numpy as np
import threading
 
def compute_sum(arr):
    return np.sum(arr)
 
# 创建一个大数组
data = np.random.rand(10000000)
 
# 分割数组
split_data = np.array_split(data, 4)
 
# 创建线程
threads = []
results = [None] * 4
 
def worker(index, data):
    results[index] = compute_sum(data)
 
for i, chunk in enumerate(split_data):
    thread = threading.Thread(target=worker, args=(i, chunk))
    threads.append(thread)
    thread.start()
 
# 等待线程结束
for thread in threads:
    thread.join()
 
# 合并结果
total_sum = np.sum(results)
print(f"Total sum: {total_sum}")
 

我们使用 NumPy 的 np.sum 函数来计算数组的和。由于 NumPy 是一个 C 扩展模块,它可以在内部释放 GIL,因此多线程可以提高计算性能。

Numpy实现:

NumPy 在执行底层计算时,会暂时释放 GIL,以允许其他线程在 Python 解释器中运行。

NumPy 的核心数组操作(如 np.dot)在底层 C 代码中会释放 GIL,从而允许并行执行多个数值计算任务‌。

  • 释放 GIL:在 C 扩展函数中,使用 Py_BEGIN_ALLOW_THREADS 宏来释放 GIL。
  • 执行计算:在释放 GIL 的期间,执行计算密集型任务,这些任务通常不涉及 Python C API 的调用。
  • 重新获取 GIL:计算完成后,使用 Py_END_ALLOW_THREADS 宏重新获取 GIL。

框架中的多线程

Flask

Flask 是一个轻量级的 Web 框架,它支持多线程来处理并发请求。在 Flask 中,可以通过 app.run() 方法的 threaded 参数来启用多线程。

在 Flask 的 app.run() 方法中,threaded 参数会传递到 werkzeug.serving.run_simple() 函数中,最终影响服务器的行为

 
def run(self, host=None, port=None, debug=None, **options):
    from werkzeug.serving import run_simple
    if host is None:
        host = '127.0.0.1'
    if port is None:
        server_name = self.config['SERVER_NAME']
        if server_name and ':' in server_name:
            port = int(server_name.rsplit(':', 1)[1])
        else:
            port = 5000
    if debug is not None:
        self.debug = bool(debug)
    options.setdefault('use_reloader', self.debug)
    options.setdefault('use_debugger', self.debug)
    options.setdefault("threaded", True)  # 默认启用多线程
    try:
        run_simple(host, port, self, **options)
    finally:
        self._got_first_request = False
 

在 werkzeug.serving.run_simple() 函数中,根据 threaded 参数选择不同的服务器类。

def make_server(host=None, port=None, app=None, threaded=False, processes=1,
                request_handler=None, passthrough_errors=False,
                ssl_context=None, fd=None):
    """Create a new server instance that is either threaded, or forks
    or just processes one request after another.
    """
    if threaded and processes > 1:
        raise ValueError("cannot have a multithreaded and multi process server.")
    elif threaded:
        return ThreadedWSGIServer(host, port, app, request_handler, passthrough_errors, ssl_context, fd=fd)
    elif processes > 1:
        return ForkingWSGIServer(host, port, app, processes, request_handler, passthrough_errors, ssl_context, fd=fd)
    else:
        return BaseWSGIServer(host, port, app, request_handler, passthrough_errors, ssl_context, fd=fd)
 

ThreadedWSGIServer 是 Flask 用于多线程的服务器类,它继承自 socketserver.ThreadingMixIn 和 BaseWSGIServer。

import socketserver
 
class ThreadedWSGIServer(socketserver.ThreadingMixIn, BaseWSGIServer):
    """A WSGI server that does threading."""
    multithread = True
    daemon_threads = True
 

socketserver.ThreadingMixIn 是 Python 标准库中的一个混入类,它会为每个请求创建一个新的线程。

import threading
 
class ThreadingMixIn:
    """Mix-in class to handle each request in a new thread."""
    daemon_threads = False
 
    def process_request_thread(self, request, client_address):
        """Same as in BaseServer but as a thread.
        In addition, exception handling is done here.
        """
        try:
            self.finish_request(request, client_address)
            self.shutdown_request(request)
        except:
            self.handle_error(request, client_address)
            self.shutdown_request(request)
 
    def process_request(self, request, client_address):
        """Start a new thread to process the request."""
        t = threading.Thread(target=self.process_request_thread,
                             args=(request, client_address))
        t.daemon = self.daemon_threads
        t.start()