48 lines
1.8 KiB
Python
48 lines
1.8 KiB
Python
import asyncio
|
||
from typing import Any, Callable, List, Coroutine
|
||
|
||
async def batch_processor(
    queue: asyncio.Queue,
    batch_size: int,
    linger_sec: float,
    execute_fn: Callable[[List[Any]], Coroutine[Any, Any, Any]],
) -> None:
    """Group items from *queue* into batches and hand each batch to *execute_fn*.

    A batch is flushed as soon as it holds ``batch_size`` items, or once
    ``linger_sec`` seconds have passed since its first item was received,
    whichever happens first. While no item is pending the coroutine simply
    blocks on ``queue.get()``; it does not wake up periodically.

    Args:
        queue: Source of items to process.
        batch_size: Number of buffered items that triggers an immediate flush.
        linger_sec: Maximum time in seconds an item may sit in a partial
            batch before that batch is flushed anyway.
        execute_fn: Coroutine function awaited with each batch. It receives
            a list that is never touched again by this function, so it may
            keep a reference to it.

    Returns:
        Never returns normally. The loop runs until the task is cancelled or
        ``execute_fn`` raises; either exception propagates to the caller, and
        items buffered but not yet flushed at that point are not processed.
    """
    loop = asyncio.get_running_loop()
    pending: List[Any] = []
    deadline = 0.0

    while True:
        linger_expired = False

        if not pending:
            # Nothing buffered: wait without a timeout, and start the linger
            # window only when the first item of the new batch arrives.
            pending.append(await queue.get())
            deadline = loop.time() + linger_sec
        else:
            remaining = deadline - loop.time()
            if remaining <= 0:
                # wait_for() with a non-positive timeout cancels the get()
                # before it can run, so skip the call and flush directly.
                linger_expired = True
            else:
                # Only the queue wait is guarded: a TimeoutError raised by
                # execute_fn must propagate instead of being mistaken for a
                # linger timeout and causing the batch to be executed twice.
                try:
                    item = await asyncio.wait_for(queue.get(), timeout=remaining)
                except asyncio.TimeoutError:
                    linger_expired = True
                else:
                    pending.append(item)

        if linger_expired or len(pending) >= batch_size:
            # Hand the list off and start a fresh one rather than clearing it
            # in place, so the callback's reference stays intact.
            ready, pending = pending, []
            await execute_fn(ready)
|