-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmulti_tqdm.py
More file actions
104 lines (80 loc) · 3.03 KB
/
multi_tqdm.py
File metadata and controls
104 lines (80 loc) · 3.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import multiprocessing
from itertools import cycle
from tqdm.auto import tqdm as _tqdm
def _generator_multi_tqdm(func, kwargs_list):
    """Yield a copy of each kwargs dict with the target function attached.

    The function is stored under the ``"function"`` key so that a single
    picklable payload can be sent to a worker process and unpacked by
    ``_wrapper_function_multi_tqdm``. A shallow copy is yielded so the
    caller's dicts are never modified; previously the injected key leaked
    back into ``kwargs_list`` and broke any later reuse of it.

    Args:
        func (Callable): Function to be applied.
        kwargs_list (Iterable[dict]): Keyword-argument dictionaries.

    Yields:
        dict: A new dict containing the original items plus ``"function": func``
        (overriding any existing ``"function"`` entry in the copy).
    """
    for kwargs in kwargs_list:
        yield {**kwargs, "function": func}
def _wrapper_function_multi_tqdm(kwargs: dict):
    """Unpack a payload built by ``_generator_multi_tqdm`` and call its function.

    The payload is copied before ``"function"`` is removed, so the dict passed
    in is left untouched even when this is called in the same process.

    Args:
        kwargs (dict): Dictionary containing ``"function"`` plus the keyword
            arguments to call it with.

    Returns:
        Any: Result of ``kwargs["function"](**remaining_kwargs)``.

    Raises:
        KeyError: If ``kwargs`` has no ``"function"`` entry.
    """
    call_kwargs = dict(kwargs)
    func = call_kwargs.pop("function")
    return func(**call_kwargs)
def multi_tqdm_kwargs_list(total: int, **kwargs) -> list[dict]:
    """Build ``total`` keyword-argument dicts for use with ``multi_tqdm``.

    Every value is cycled independently:

    - A ``list`` value supplies its elements in order, wrapping around when it
      is exhausted. Elements beyond ``total`` are ignored.
    - Any other value, including tuples, is treated as a single element and is
      therefore repeated (the same object) in every dict.

    Args:
        total (int): Number of dicts to generate. Values <= 0 yield ``[]``.
        **kwargs: Argument names mapped to either a list of per-call values or
            a single value shared by all calls.

    Returns:
        list[dict]: ``total`` dicts, each holding one value per key.

    Raises:
        ValueError: If ``total`` is positive and some key maps to an empty
            list, since there is no value to draw for it.

    Example:
        >>> multi_tqdm_kwargs_list(3, x=[1, 2], y="a")
        [{'x': 1, 'y': 'a'}, {'x': 2, 'y': 'a'}, {'x': 1, 'y': 'a'}]
    """
    if total > 0:
        # next(cycle([])) would leak a bare StopIteration; fail clearly instead.
        empty_keys = [k for k, v in kwargs.items() if isinstance(v, list) and not v]
        if empty_keys:
            raise ValueError(
                f"cannot cycle an empty list for argument(s): {', '.join(empty_keys)}"
            )
    # Wrap scalars in a one-element list so every key can be cycled uniformly.
    iterators = {
        k: cycle(v if isinstance(v, list) else [v]) for k, v in kwargs.items()
    }
    return [{k: next(it) for k, it in iterators.items()} for _ in range(total)]
def _indexed_call_multi_tqdm(job):
    """Run one pool job and tag its result with the job's input index.

    Defined at module level so ``multiprocessing`` can pickle it.

    Args:
        job (tuple): ``(index, func, kwargs)``. It is sent to a worker
            process, so ``func`` and the values in ``kwargs`` must be picklable.

    Returns:
        tuple: ``(index, func(**kwargs))``, letting the caller restore input order.
    """
    index, func, kwargs = job
    return index, func(**kwargs)


def multi_tqdm(
    func,
    kwargs_list: list[dict],
    num_workers: int = 0,
    ncols: int = 100,
    desc: str = "",
) -> list:
    """Run a function over multiple argument sets with a progress bar and optional multiprocessing.

    Args:
        func (Callable): The function to apply. With ``num_workers`` set it
            must be picklable (e.g. defined at module level).
        kwargs_list (list[dict]): Keyword-argument dicts, one per call. The
            dicts are not modified.
        num_workers (int, optional): Number of worker processes. Use -1 for
            half of the CPU cores (at least 1). Defaults to 0 (run serially
            in the current process).
        ncols (int, optional): Width of the tqdm progress bar. Defaults to 100.
        desc (str, optional): Description prefix for tqdm. Defaults to "".

    Returns:
        list: Results of ``func(**kwargs)``, in the same order as
        ``kwargs_list`` for both the serial and the parallel path.
    """
    outputs = []
    with _tqdm(total=len(kwargs_list), ncols=ncols, desc=desc) as t:
        if num_workers:
            if num_workers == -1:
                # os.cpu_count() returns None when the count cannot be determined.
                num_workers = max(1, (os.cpu_count() or 1) // 2)
            # imap_unordered keeps the progress bar moving as soon as any job
            # finishes; the attached index puts each result back in input order.
            outputs = [None] * len(kwargs_list)
            jobs = ((i, func, kwargs) for i, kwargs in enumerate(kwargs_list))
            with multiprocessing.Pool(num_workers) as p:
                for index, result in p.imap_unordered(_indexed_call_multi_tqdm, jobs):
                    outputs[index] = result
                    t.update()
        else:
            for kwargs in kwargs_list:
                outputs.append(func(**kwargs))
                t.update()
    return outputs