Skip to content

Commit 27be9e4

Browse files
committed
feat: Introduce merge_small_segments to collapse empty graph segments and update pseudotime warnings.
1 parent 0277419 commit 27be9e4

5 files changed

Lines changed: 290 additions & 2 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
__pycache__/
33
*.py[cod]
44
*$py.class
5+
.DS_Store
56

67
# C extensions
78
*.so
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
"""Test for merge_small_segments function."""
2+
import warnings
3+
import numpy as np
4+
import scFates as scf
5+
6+
scf.settings.verbosity = 0
7+
8+
9+
def test_merge_small_segments():
10+
"""Test that merge_small_segments handles empty segments correctly."""
11+
# Load test data and build a tree
12+
adata = scf.datasets.test_adata()
13+
14+
# Build a tree with many nodes to potentially create short segments
15+
scf.tl.tree(
16+
adata,
17+
Nodes=100,
18+
use_rep="pca",
19+
method="ppt",
20+
device="cpu",
21+
ppt_sigma=1,
22+
ppt_lambda=200,
23+
seed=42,
24+
)
25+
scf.tl.cleanup(adata)
26+
scf.tl.root(adata, adata.uns["graph"]["tips"][0])
27+
28+
# Run pseudotime and check for warning about empty segments
29+
with warnings.catch_warnings(record=True) as w:
30+
warnings.simplefilter("always")
31+
scf.tl.pseudotime(adata)
32+
empty_seg_warnings = [
33+
warning for warning in w
34+
if "Some segs have no cell assigned" in str(warning.message)
35+
]
36+
37+
# If there are empty segments, test the merge function
38+
if len(empty_seg_warnings) > 0:
39+
# Get initial state
40+
initial_n_nodes = adata.uns["graph"]["B"].shape[0]
41+
42+
# Run merge_small_segments
43+
scf.tl.merge_small_segments(adata)
44+
45+
# Verify nodes were removed
46+
final_n_nodes = adata.uns["graph"]["B"].shape[0]
47+
assert final_n_nodes <= initial_n_nodes, "Nodes should be merged"
48+
49+
# Verify no empty segments remain after re-running pseudotime
50+
with warnings.catch_warnings(record=True) as w2:
51+
warnings.simplefilter("always")
52+
scf.tl.pseudotime(adata)
53+
new_warnings = [
54+
warning for warning in w2
55+
if "Some segs have no cell assigned" in str(warning.message)
56+
]
57+
assert len(new_warnings) == 0, "No empty segments should remain"
58+
else:
59+
# No empty segments - just verify the function runs without error
60+
scf.tl.merge_small_segments(adata)
61+
62+
# Verify basic graph integrity
63+
assert adata.uns["graph"]["B"].shape[0] == adata.uns["graph"]["B"].shape[1]
64+
assert adata.uns["graph"]["F"].shape[1] == adata.uns["graph"]["B"].shape[0]
65+
assert adata.obsm["X_R"].shape[1] == adata.uns["graph"]["B"].shape[0]
66+
assert "t" in adata.obs
67+
assert "seg" in adata.obs
68+
69+
70+
def test_merge_small_segments_synthetic_empty():
71+
"""Test merge with synthetically created empty segment."""
72+
adata = scf.datasets.test_adata()
73+
74+
scf.tl.tree(
75+
adata,
76+
Nodes=50,
77+
use_rep="pca",
78+
method="ppt",
79+
device="cpu",
80+
ppt_sigma=1,
81+
ppt_lambda=100,
82+
seed=1,
83+
)
84+
scf.tl.cleanup(adata)
85+
scf.tl.root(adata, adata.uns["graph"]["tips"][0])
86+
scf.tl.pseudotime(adata)
87+
88+
# Force an empty segment by reassigning cells
89+
pp_seg = adata.uns["graph"]["pp_seg"]
90+
if len(pp_seg) > 1:
91+
# Find smallest segment and remove its cells
92+
seg_counts = adata.obs.seg.value_counts()
93+
smallest_seg = seg_counts.idxmin()
94+
95+
# Move cells to another segment
96+
other_segs = seg_counts.index[seg_counts.index != smallest_seg]
97+
if len(other_segs) > 0:
98+
adata.obs.loc[adata.obs.seg == smallest_seg, "seg"] = other_segs[0]
99+
100+
# Now run merge
101+
initial_nodes = adata.uns["graph"]["B"].shape[0]
102+
scf.tl.merge_small_segments(adata)
103+
104+
# Verify merge happened
105+
assert adata.uns["graph"]["B"].shape[0] <= initial_nodes
106+
107+
108+
def test_merge_no_empty_segments():
109+
"""Test that function handles case with no empty segments gracefully."""
110+
adata = scf.datasets.test_adata()
111+
112+
scf.tl.tree(
113+
adata,
114+
Nodes=30,
115+
use_rep="pca",
116+
method="ppt",
117+
device="cpu",
118+
ppt_sigma=1,
119+
ppt_lambda=100,
120+
seed=1,
121+
)
122+
scf.tl.cleanup(adata)
123+
scf.tl.root(adata, adata.uns["graph"]["tips"][0])
124+
scf.tl.pseudotime(adata)
125+
126+
initial_nodes = adata.uns["graph"]["B"].shape[0]
127+
128+
# Should complete without error even if no empty segments
129+
scf.tl.merge_small_segments(adata)
130+
131+
# Verify graph is still valid
132+
assert adata.uns["graph"]["B"].shape[0] <= initial_nodes
133+
assert "t" in adata.obs

scFates/tools/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
attach_tree,
66
simplify,
77
convert_to_soft,
8+
merge_small_segments,
89
)
910
from .graph_fitting import tree, curve, circle, explore_sigma
1011
from .root import root, roots

scFates/tools/graph_operations.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,157 @@ def simplify(adata: AnnData, n_nodes: int = 10, copy: bool = False):
589589
return adata if copy else None
590590

591591

592+
def merge_small_segments(adata: AnnData, copy: bool = False):
593+
"""\
594+
Merge segments that have no cells assigned into their neighboring nodes.
595+
596+
Small segments (typically between two high-degree nodes like forks) can end up
597+
with no cells assigned during pseudotime calculation. This function merges
598+
such segments by collapsing intermediate nodes into the earlier milestone,
599+
resolving downstream analysis failures.
600+
601+
This function should be called after pseudotime calculation if warnings about
602+
segments having no cells assigned are encountered.
603+
604+
Parameters
605+
----------
606+
adata
607+
Annotated data matrix with computed pseudotime.
608+
copy
609+
Return a copy instead of writing to adata.
610+
611+
Returns
612+
-------
613+
adata : anndata.AnnData
614+
if `copy=True` it returns or else updates fields in `adata`:
615+
616+
`.uns['graph']['B']`
617+
updated adjacency matrix with merged nodes.
618+
`.uns['graph']['F']`
619+
updated coordinates with merged nodes.
620+
`.obsm['X_R']`
621+
updated soft assignment matrix with merged nodes.
622+
`.uns['graph']['pp_seg']`
623+
updated segment definitions.
624+
`.uns['graph']['pp_info']`
625+
updated node information.
626+
"""
627+
628+
logg.info("merging small segments without cells", reset=True)
629+
630+
adata = adata.copy() if copy else adata
631+
632+
if "t" not in adata.obs:
633+
raise ValueError(
634+
"You need to run `tl.pseudotime` before merging small segments."
635+
)
636+
637+
graph = adata.uns["graph"]
638+
pp_seg = graph["pp_seg"].copy()
639+
pp_info = graph["pp_info"].copy()
640+
B = graph["B"].copy()
641+
F = graph["F"].copy()
642+
R = adata.obsm["X_R"].copy()
643+
644+
# Identify segments with no cells assigned
645+
cell_segs = adata.obs.seg.value_counts()
646+
all_segs = pp_seg.n.values
647+
empty_segs = [s for s in all_segs if s not in cell_segs.index]
648+
649+
if len(empty_segs) == 0:
650+
logg.info(" no empty segments found", time=False)
651+
return adata if copy else None
652+
653+
logg.info(f" found {len(empty_segs)} empty segment(s): {empty_segs}", time=False)
654+
655+
# Build graph for path finding
656+
g = igraph.Graph.Adjacency((B > 0).tolist(), mode="undirected")
657+
658+
# Track nodes to remove
659+
nodes_to_remove = set()
660+
661+
for seg_id in empty_segs:
662+
seg_row = pp_seg.loc[pp_seg.n == seg_id].iloc[0]
663+
from_node = int(seg_row["from"])
664+
to_node = int(seg_row["to"])
665+
666+
# Get all nodes in this segment
667+
seg_nodes = pp_info.index[pp_info.seg == seg_id].tolist()
668+
669+
# Identify intermediate nodes (not the from/to milestones)
670+
intermediate_nodes = [n for n in seg_nodes if n != from_node and n != to_node]
671+
672+
# The from_node is earlier in pseudotime, we keep it
673+
# Merge intermediate nodes and to_node into from_node
674+
675+
# Combine soft assignments: add R columns of removed nodes to from_node
676+
for node in intermediate_nodes + [to_node]:
677+
if node < R.shape[1]:
678+
R[:, from_node] = R[:, from_node] + R[:, node]
679+
nodes_to_remove.add(node)
680+
681+
# Update adjacency: connect from_node to to_node's neighbors
682+
if to_node < B.shape[0]:
683+
to_neighbors = np.where(B[to_node, :] > 0)[0]
684+
for neighbor in to_neighbors:
685+
if neighbor != from_node and neighbor not in nodes_to_remove:
686+
B[from_node, neighbor] = 1
687+
B[neighbor, from_node] = 1
688+
689+
if len(nodes_to_remove) == 0:
690+
logg.info(" no nodes to merge", time=False)
691+
return adata if copy else None
692+
693+
# Remove nodes from matrices
694+
nodes_to_keep = [i for i in range(B.shape[0]) if i not in nodes_to_remove]
695+
nodes_to_keep = np.array(nodes_to_keep)
696+
697+
B = B[np.ix_(nodes_to_keep, nodes_to_keep)]
698+
F = F[:, nodes_to_keep]
699+
R = R[:, nodes_to_keep]
700+
701+
# Normalize R
702+
Rsum = R.sum(axis=1)
703+
Rsum[Rsum == 0] = 1 # Avoid division by zero
704+
R = R / Rsum.reshape(-1, 1)
705+
706+
# Update graph structures
707+
g = igraph.Graph.Adjacency((B > 0).tolist(), mode="undirected")
708+
tips = np.argwhere(np.array(g.degree()) == 1).flatten()
709+
forks = np.argwhere(np.array(g.degree()) > 2).flatten()
710+
711+
# Update adata
712+
adata.uns["graph"]["B"] = B
713+
adata.uns["graph"]["F"] = F
714+
adata.uns["graph"]["tips"] = tips
715+
adata.uns["graph"]["forks"] = forks
716+
adata.obsm["X_R"] = R
717+
718+
# Create mapping from old to new indices
719+
old_to_new = {old: new for new, old in enumerate(nodes_to_keep)}
720+
721+
# Update milestones
722+
if "milestones" in graph:
723+
new_milestones = {}
724+
for name, old_idx in graph["milestones"].items():
725+
if old_idx in old_to_new:
726+
new_milestones[name] = old_to_new[old_idx]
727+
adata.uns["graph"]["milestones"] = new_milestones
728+
729+
# Update root
730+
if "root" in graph and graph["root"] in old_to_new:
731+
adata.uns["graph"]["root"] = old_to_new[graph["root"]]
732+
733+
# Recalculate pseudotime with new graph structure
734+
root(adata, adata.uns["graph"]["root"])
735+
pseudotime(adata)
736+
737+
n_removed = len(nodes_to_remove)
738+
logg.info(" finished", time=True, end=" " if settings.verbosity > 2 else "\n")
739+
logg.hint(f"merged {n_removed} nodes from {len(empty_segs)} empty segment(s)")
740+
741+
return adata if copy else None
742+
592743
def getpath(adata, root_milestone, milestones, include_root=False):
593744
"""\
594745
Obtain dataframe of cell of a given path.

scFates/tools/pseudotime.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,8 @@ def pseudotime(
157157

158158
if allsegs.shape[1]!=allsegs_complete.shape[1]:
159159
missing = allsegs_complete.columns[~allsegs_complete.columns.isin(allsegs.columns)]
160-
message=f"Some segs have no cell assigned: {missing.tolist()}"
160+
message=f"Some segs have no cell assigned: {missing.tolist()}. " \
161+
"This may cause downstream failures. Consider running `scf.tl.merge_small_segments(adata)`."
161162
warnings.warn(message)
162163

163164
for c in allsegs.columns:
@@ -248,7 +249,8 @@ def pseudotime(
248249
adata.obs["milestones"] = milestones_str.astype("category")
249250

250251
if adata.obs["milestones"].isna().sum()>0:
251-
message = "Some cells have no milestones assigned. This is likely due to the fact that these uniquely compose a segment."
252+
message = "Some cells have no milestones assigned. This is likely due to the fact that these uniquely compose a segment. " \
253+
"Consider running `scf.tl.merge_small_segments(adata)`."
252254
warnings.warn(message)
253255

254256
adata.uns["graph"]["milestones"] = dict(

0 commit comments

Comments
 (0)