
[WIP Branch] Forward Push Exploration#468

Draft
mkolodner-sc wants to merge 19 commits into main from mkolodner-sc/tracer_bullet_cns

Conversation

mkolodner-sc (Collaborator) commented Jan 28, 2026

Scope of work done

The bulk of the code for the forward-push algorithm lives in updates to dist_neighbor_sampler.py, so I'd direct initial reviews there.

Where is the documentation for this feature?: N/A

Did you add automated tests or write a test plan?

Updated Changelog.md? NO

Ready for code review?: NO

mkolodner-sc marked this pull request as draft on February 6, 2026 at 00:27
mkolodner-sc changed the title from "[WIP Branch] CNS Exploration" to "[WIP Branch] Forward Push Exploration" on Feb 6, 2026
Comment on lines +580 to +591
ntype_neighbor_ids = torch.full(
(batch_size, self._max_ppr_nodes),
self._default_node_id,
dtype=torch.long,
device=device,
)
ntype_weights = torch.full(
(batch_size, self._max_ppr_nodes),
self._default_weight,
dtype=torch.float,
device=device,
)
mkolodner-sc (Collaborator, Author) commented:

This output data structure is a quick placeholder for now and may be updated later.

kmontemayor2-sc (Collaborator) left a comment:

Thanks for the exploration, Matt! Since performance is important here, it may be useful to come up with some way to test performance locally.

I'd expect that testing in a local setting won't be exactly equivalent to the distributed setting, but it should give us a baseline and show which knobs we can look into tuning.

Additionally, since a lot of these operations are going to be in the hot loop for our sampler, I feel we should consider migrating this code over to C++ eventually if possible. Given that we are still doing explorations right now I don't think we need to do it immediately, but if this turns out to be a useful tool for us I expect we'd be able to get a large speedup from C++.

Comment on lines 343 to 355
# For homogeneous graphs, node_type is always None
p: list[dict[Tuple[int, Optional[NodeType]], float]] = [
defaultdict(float) for _ in range(batch_size)
]
# Residuals: r[i][(node_id, node_type)] = residual
r: list[dict[Tuple[int, Optional[NodeType]], float]] = [
defaultdict(float) for _ in range(batch_size)
]

# Queue stores (node_id, node_type) tuples
q: list[Set[Tuple[int, Optional[NodeType]]]] = [
set() for _ in range(batch_size)
]
kmontemayor2-sc (Collaborator):
Any reason we can't turn these into tensors? I imagine we'd be able to speed up our operations a lot if we did.

mkolodner-sc (Collaborator, Author):

The queue here is only used as storage for the nodes at each iteration -- there isn't really a vectorized operation happening here, IIUC. We use a set since it has O(1) lookup time.

kmontemayor2-sc (Collaborator):

We use a set since it has O(1) lookup time

FWIW, this is what I mean by "theory": sets are O(1) """in theory""", but due to collisions etc. they may perform worse in practice. For instance, if we look at the (I think) set source, we see that lookups may in fact be linear in the worst case.

Not to mention that Python is just really, really slow, so we're already "behind" by writing this in Python.

mkolodner-sc (Collaborator, Author):

Good to know, thanks!

offset = 0
for node_id, count in zip(nodes_list, counts_list):
cache_key = (node_id, etype)
neighbor_cache[cache_key] = neighbors_list[offset : offset + count]
kmontemayor2-sc (Collaborator):

BTW, in practice how big does this dict get? Have we considered giving it a bounded size and evicting old entries?
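
For example, a minimal sketch of what a size-bounded, LRU-style cache could look like (the class name and bound are hypothetical, not something in this PR):

from collections import OrderedDict

class BoundedNeighborCache(OrderedDict):
    # Size-bounded cache: once max_entries is exceeded, evict the
    # least-recently-used entry.
    def __init__(self, max_entries: int = 100_000):
        super().__init__()
        self.max_entries = max_entries

    def __getitem__(self, key):
        value = super().__getitem__(key)
        self.move_to_end(key)  # mark as recently used
        return value

    def __setitem__(self, key, value):
        super().__setitem__(key, value)
        self.move_to_end(key)
        if len(self) > self.max_entries:
            self.popitem(last=False)  # drop the oldest entry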

mkolodner-sc (Collaborator, Author):

neighbor_cache doesn't get very large in my experiments, since we only add nodes to it if the residual is large enough. The degree_cache does get large (roughly 20x larger), but that's fine since it only stores an int count per entry.

kmontemayor2-sc (Collaborator):

"only" an int is still 3x the memory footprint of a "proper" int64

>>> import sys
>>> sys.getsizeof(1)
28 # bytes, int64 is 8 bytes

I realize I'm harping on this stuff a lot. I think this is all fine for the tracer bullet; I just feel that if this sampling logic is the performance bottleneck, then sticking with Python is not the correct approach here. And if we can get a ~3x memory improvement, maybe we can afford to be less approximate (though it's TBD how "close" to true PPR we need to be).
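
To make the comparison concrete, a rough sketch (with a hypothetical dense-id layout, not how this PR stores things): if node ids within a type can be mapped to a dense local range, the counts can live in one preallocated int64 tensor per edge type, so each count costs 8 bytes and the Python dict only needs one entry per edge type rather than one per node.

import torch

# Hypothetical: node ids are dense local ids in [0, num_nodes); the edge type
# key ("user", "follows", "user") is made up for illustration.
num_nodes = 1_000_000
degree_cache = {
    ("user", "follows", "user"): torch.full((num_nodes,), -1, dtype=torch.int64),
}

def record_degree(etype, local_id: int, count: int) -> None:
    degree_cache[etype][local_id] = count  # 8 bytes per slot

def get_degree(etype, local_id: int) -> int:
    count = int(degree_cache[etype][local_id])
    if count < 0:
        raise KeyError((etype, local_id))  # -1 marks "not yet fetched"
    return count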

mkolodner-sc (Collaborator, Author) commented Feb 6, 2026:

I appreciate the suggestions provided so far :)

Generally, from the experiments I've run so far, even when increasing the fanout to 10_000 per step, memory is no longer the bottleneck here and stays relatively small, assuming we're able to converge. I'll let you know if this changes as I experiment further, though.

I just feel that if this sampling logic is the performance bottleneck, then sticking with Python is not the correct approach here

I agree. I think it largely depends on what sort of performance we see in the settings we care about and where the bottleneck lies (i.e. whether it is network-bound or not). If we can achieve reasonable training and inference using a Pythonic forward-push implementation, we should weigh the ROI of further improvements with C++. Otherwise, if we are seeing significant overhead here due to Python, C++ is 100% the way to go.

Comment on lines +439 to +441
for i in range(batch_size):
for u_node, u_type in nodes_to_process[i]:
ppr_total_nodes_processed += 1
kmontemayor2-sc (Collaborator):

Is it not possible to vectorize this at all? I imagine we should be able to do at least some of it with torch operations.

mkolodner-sc (Collaborator, Author) commented Feb 6, 2026:

I've taken a pass here, and the interplay between the different edge/node types, the prevalence of dictionaries for caching, and the need to collect some neighbors_needing_degree in this setting make it non-trivial to vectorize at the moment. I'll keep thinking on this and see if there are opportunities for improvement here.

I acknowledge that there are still a lot of sequential loops in play, and it'd be great if we could optimize these where possible.
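
For reference, here's a rough sketch of what a fully vectorized push step can look like in the much simpler homogeneous, single-machine case, with a CSR-style adjacency already held as tensors. All names are hypothetical, and this ignores the edge/node-type interplay, caching, and remote fetches discussed above, so it's an illustration rather than a drop-in replacement.

import torch

def push_step(ppr, residual, degree, nbr_ptr, nbr_idx, alpha=0.15, eps=1e-4):
    # One forward-push sweep: ppr, residual, degree are float tensors of shape
    # [num_nodes]; nbr_ptr, nbr_idx are a CSR adjacency (int64 tensors).
    active = torch.nonzero(residual > eps * degree).squeeze(1)
    if active.numel() == 0:
        return False  # converged, nothing left to push
    r_active = residual[active]
    ppr[active] += alpha * r_active
    residual[active] = 0.0

    # Gather every neighbor of every active node into one flat tensor.
    starts = nbr_ptr[active]
    counts = nbr_ptr[active + 1] - starts
    total = int(counts.sum())
    seg_start = torch.repeat_interleave(counts.cumsum(0) - counts, counts)
    offsets = torch.arange(total, device=residual.device) - seg_start
    neighbors = nbr_idx[torch.repeat_interleave(starts, counts) + offsets]

    # Scatter (1 - alpha) * r(u) / deg(u) onto each neighbor's residual.
    push_vals = torch.repeat_interleave((1 - alpha) * r_active / degree[active], counts)
    residual.index_add_(0, neighbors, push_vals)
    return True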

kmontemayor2-sc (Collaborator):

FWIW, I've had OK luck throwing blocks like these at the robots and having them take a shot at vectorizing them. If we have some "correctness tests" then we can be more confident in their solutions here.

something to try

mkolodner-sc (Collaborator, Author):

I've done this as well, but haven't come up with a good vectorized approach through that strategy either.

kmontemayor2-sc (Collaborator):

Took a first pass, thanks for the work :)

Generally, I think we should be trying as hard as possible to vectorize these operations; I think that'd be the quickest way for us to get perf gains.

Though really, I think it'd be best if we can figure out some sort of "local" solution to benchmark the code; it's quite complex, and until we can see where we actually get stuck, doing optimizations purely from a theory standpoint is not going to be fruitful.

The local solution won't include network calls (sadly), but we can benchmark those at scale separately to see whether other optimizations here are even worth doing (e.g. if the network takes up 90% of the time, then we should focus on prefetching as much as possible rather than on the await bit).
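
As a strawman for the prefetch idea (everything here is hypothetical; rpc_fetch_neighbors stands in for whatever the real RPC entry point is): issue the RPCs for the whole frontier up front and await them together, so the round-trips overlap instead of serializing inside the push loop.

import asyncio

async def prefetch_frontier(rpc_fetch_neighbors, frontier):
    # rpc_fetch_neighbors: hypothetical async call returning the neighbor list
    # for one node. Launch all fetches for the frontier at once, then await
    # them together so the network round-trips overlap.
    tasks = [asyncio.create_task(rpc_fetch_neighbors(node)) for node in frontier]
    results = await asyncio.gather(*tasks)
    return dict(zip(frontier, results))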

mkolodner-sc (Collaborator, Author):

Generally, I think we should be trying as hard as possible to vectorize these operations; I think that'd be the quickest way for us to get perf gains.

Agree that we should vectorize operations when possible; I'm open to suggestions if you see any immediate place to do so. As of now, I believe the current number of loops is on par with the implementations in other libraries, but I agree we should strive to improve even on those.

Though really, I think it'd be best if we can figure out some sort of "local" solution to benchmark the code; it's quite complex, and until we can see where we actually get stuck, doing optimizations purely from a theory standpoint is not going to be fruitful.

Can you elaborate on what you mean by "local"? The investigation here has not been pure theory; I have been running jobs to profile how different changes affect the runtime and memory.

kmontemayor2-sc (Collaborator):

Can you elaborate on what you mean by "local"? The investigation here has not been pure theory; I have been running jobs to profile how different changes affect the runtime and memory.

By "local" I mean we just run this python code on one machine, so we can use cprofile or similar to identify hotspots (and equivalent tools to identify where we can optimize on memory usage).

Running the end-to-end jobs is good, and ultimately closer to what we care about, but it's quite time-intensive to launch them for every potential performance optimization. It'd be easier to run some tests on one machine (aka "locally") and then push the promising experiments to the distributed setup to see if we're making progress.
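
Concretely, something like the sketch below, run against a small in-memory graph, would already show where the Python time goes. The sampler object, its sample_from_nodes method, and seed_nodes are all hypothetical placeholders for whatever the single-machine harness ends up constructing.

import cProfile
import pstats

def profile_local_sampling(sampler, seed_nodes):
    # Profile one forward-push sampling call on a single machine (no RPCs),
    # then print the top 20 hotspots by cumulative time.
    with cProfile.Profile() as profiler:
        sampler.sample_from_nodes(seed_nodes)
    stats = pstats.Stats(profiler)
    stats.sort_stats("cumulative").print_stats(20)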

mkolodner-sc (Collaborator, Author):

Additionally, since a lot of these operations are going to be in the hot loop for our sampler, I feel we should consider migrating this code over to C++ eventually if possible. Given that we are still doing explorations right now I don't think we need to do it immediately, but if this turns out to be a useful tool for us I expect we'd be able to get a large speedup from C++.

Agreed, C++ will always be more performant than Python and will certainly be useful. I think a large part of it will come down to how much more useful it would be and what the general runtime/costs look like in a practical setting we care about.

After I do more profiling, if it turns out the majority of the runtime comes from the network calls (which I currently believe to be the case), I don't imagine the gains from C++ optimizations will be very large; otherwise there is certainly room for improvement there.
