33import logging , math , random , time
44import multiprocessing as mp
55from collections import defaultdict
6+ from concurrent .futures import ThreadPoolExecutor , as_completed
67
78import fastremap
89import numpy as np
1516from pychunkedgraph .graph .types import empty_2d
1617from pychunkedgraph .utils .general import chunked
1718
18- from .utils import exists_as_parent , get_parent_timestamps
19+ from .utils import exists_as_parent , get_end_timestamps , get_parent_timestamps
1920
2021
2122CHILDREN = {}
@@ -51,7 +52,7 @@ def _get_cx_edges_at_timestamp(node, response, ts):
5152
5253
5354def _populate_cx_edges_with_timestamps (
54- cg : ChunkedGraph , layer : int , nodes : list , nodes_ts : list , earliest_ts
55+ cg : ChunkedGraph , layer : int , nodes : list , nodes_ts : list
5556):
5657 """
5758 Collect timestamps of edits from children, since we use the same timestamp
@@ -63,7 +64,8 @@ def _populate_cx_edges_with_timestamps(
6364 all_children = np .concatenate (list (CHILDREN .values ()))
6465 response = cg .client .read_nodes (node_ids = all_children , properties = attrs )
6566 timestamps_d = get_parent_timestamps (cg , nodes )
66- for node , node_ts in zip (nodes , nodes_ts ):
67+ end_timestamps = get_end_timestamps (cg , nodes , nodes_ts , CHILDREN )
68+ for node , node_ts , node_end_ts in zip (nodes , nodes_ts , end_timestamps ):
6769 CX_EDGES [node ] = {}
6870 timestamps = timestamps_d [node ]
6971 cx_edges_d_node_ts = _get_cx_edges_at_timestamp (node , response , node_ts )
@@ -75,8 +77,8 @@ def _populate_cx_edges_with_timestamps(
7577 CX_EDGES [node ][node_ts ] = cx_edges_d_node_ts
7678
7779 for ts in sorted (timestamps ):
78- if ts < earliest_ts :
79- ts = earliest_ts
80+ if ts > node_end_ts :
81+ break
8082 CX_EDGES [node ][ts ] = _get_cx_edges_at_timestamp (node , response , ts )
8183
8284
@@ -107,7 +109,7 @@ def update_cross_edges(cg: ChunkedGraph, layer, node, node_ts, earliest_ts) -> l
107109
108110 row_id = serializers .serialize_uint64 (node )
109111 for ts , cx_edges_d in CX_EDGES [node ].items ():
110- if node_ts > ts :
112+ if ts < node_ts :
111113 continue
112114 edges = get_latest_edges_wrapper (cg , cx_edges_d , parent_ts = ts )
113115 if edges .size == 0 :
@@ -129,17 +131,29 @@ def update_cross_edges(cg: ChunkedGraph, layer, node, node_ts, earliest_ts) -> l
129131 return rows
130132
131133
def _update_cross_edges_helper_thread(args):
    """
    Thread-pool adapter for `update_cross_edges`.

    `executor.submit` is given one packed tuple per node; this unpacks
    `(cg, layer, node, node_ts, earliest_ts)` and forwards the values
    positionally, returning whatever `update_cross_edges` returns.
    """
    graph, node_layer, node_id, node_timestamp, earliest_timestamp = args
    return update_cross_edges(
        graph, node_layer, node_id, node_timestamp, earliest_timestamp
    )
137+
138+
def _update_cross_edges_helper(args):
    """
    Worker-process entry point: update cross edges for one batch of nodes.

    `args` is a single tuple `(cg_info, layer, nodes, nodes_ts, earliest_ts)`
    so the function can be mapped over a task list by a process pool.
    `cg_info` holds the keyword arguments used to build a `ChunkedGraph`
    inside this process; `nodes` and `nodes_ts` are iterated in lockstep.

    Each node that has a parent is handed to `update_cross_edges` on a small
    thread pool; the rows returned for all nodes are collected and written
    with a single `cg.client.write` call. Returns None.
    """
    cg_info, layer, nodes, nodes_ts, earliest_ts = args
    cg = ChunkedGraph(**cg_info)
    parents = cg.get_parents(nodes, fail_to_zero=True)

    # parent == 0 marks an invalid id caused by a failed ingest task / edits;
    # such nodes are skipped.
    tasks = [
        (cg, layer, node, node_ts, earliest_ts)
        for node, parent, node_ts in zip(nodes, parents, nodes_ts)
        if parent != 0
    ]

    rows = []
    # Every node is processed exactly once, via the thread pool only.
    # NOTE(review): the same `cg` instance is shared by all threads and the
    # worker count is hard-coded to 4 - confirm the client is safe to share
    # and that the per-node work is dominated by backend reads.
    with ThreadPoolExecutor(max_workers=4) as executor:
        futures = [
            executor.submit(_update_cross_edges_helper_thread, task) for task in tasks
        ]
        # Collect in completion order; `result()` re-raises any exception
        # raised inside a worker thread.
        for future in tqdm(as_completed(futures), total=len(futures)):
            rows.extend(future.result())

    cg.client.write(rows)
144158
145159
@@ -159,7 +173,7 @@ def update_chunk(
159173 nodes = list (CHILDREN .keys ())
160174 random .shuffle (nodes )
161175 nodes_ts = cg .get_node_timestamps (nodes , return_numpy = False , normalize = True )
162- _populate_cx_edges_with_timestamps (cg , layer , nodes , nodes_ts , earliest_ts )
176+ _populate_cx_edges_with_timestamps (cg , layer , nodes , nodes_ts )
163177
164178 task_size = int (math .ceil (len (nodes ) / mp .cpu_count () / 2 ))
165179 chunked_nodes = chunked (nodes , task_size )
@@ -171,8 +185,9 @@ def update_chunk(
171185 args = (cg_info , layer , chunk , ts_chunk , earliest_ts )
172186 tasks .append (args )
173187
174- logging .info (f"Processing { len (nodes )} nodes." )
175- with mp .Pool (min (mp .cpu_count (), len (tasks ))) as pool :
188+ processes = min (mp .cpu_count () * 2 , len (tasks ))
189+ logging .info (f"Processing { len (nodes )} nodes with { processes } workers." )
190+ with mp .Pool (processes ) as pool :
176191 _ = list (
177192 tqdm (
178193 pool .imap_unordered (_update_cross_edges_helper , tasks ),
0 commit comments