From 0d617649577bf37cfcc4b3cfcbafd0687fa940c6 Mon Sep 17 00:00:00 2001 From: TeunHuijben Date: Wed, 9 Apr 2025 15:17:44 -0700 Subject: [PATCH 1/4] added test script to handle large files --- scripts/test_large_db.py | 81 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 scripts/test_large_db.py diff --git a/scripts/test_large_db.py b/scripts/test_large_db.py new file mode 100644 index 0000000..2eb2e29 --- /dev/null +++ b/scripts/test_large_db.py @@ -0,0 +1,81 @@ +import sqlite3 +from pathlib import Path +import sqlalchemy as sqla +from sqlalchemy import inspect, text +from ultrack.config import MainConfig +from trackedit.arrays.UltrackArray import UltrackArray +from trackedit.widgets.HierarchyWidget import HierarchyLabels, HierarchyVizWidget +import napari +def check_database(db_path: Path): + """Check if database exists and print its structure""" + if not db_path.exists(): + print(f"Database file does not exist at: {db_path}") + return False + + print(f"Database file found at: {db_path}") + print(f"File size: {db_path.stat().st_size / (1024*1024):.2f} MB") + + # Try to connect and list tables + engine = sqla.create_engine(f"sqlite:///{db_path}") + inspector = inspect(engine) + tables = inspector.get_table_names() + print(f"Tables in database: {tables}") + + # Print column information for each table + for table_name in tables: + columns = inspector.get_columns(table_name) + print(f"\nColumns in {table_name} table:") + for column in columns: + print(f" - {column['name']}: {column['type']}") + + # Get a sample row from each table + with engine.connect() as conn: + for table_name in tables: + result = conn.execute(text(f"SELECT * FROM {table_name} LIMIT 1")).fetchone() + if result: + print(f"\nSample row from {table_name}:") + print(result) + + return True + +def initialize_config(): + working_directory = Path("/mnt/md0/Teun/data/Chromatrace/2024_08_14/") + db_filename = "data.db" + db_path = working_directory / db_filename + + # if not check_database(db_path): + # raise FileNotFoundError(f"Database not found or invalid at {db_path}") + + # import db filename properly into an Ultrack config + config_adjusted = MainConfig() + config_adjusted.data_config.working_dir = working_directory + config_adjusted.data_config.database_file_name = db_filename + return config_adjusted + +# def main(): +# config = initialize_config() +# ultrack_array = UltrackArray(config) + +# labels_layer = HierarchyLabels( +# data=ultrack_array, scale=(4,1,1), name="hierarchy" +# ) +# viewer = napari.Viewer() +# viewer.add_layer(labels_layer) +# labels_layer.refresh() +# labels_layer.mode = "pan_zoom" +# napari.run() + + +def main2(): + config = initialize_config() + viewer = napari.Viewer() + hier_widget = HierarchyVizWidget( + viewer=viewer, + scale=(4,1,1), + config=config, + ) + viewer.window.add_dock_widget(hier_widget, area="bottom") + napari.run() + +if __name__ == "__main__": + main2() From 594cdb6515120e440a8fee854e01ce9480e8b790 Mon Sep 17 00:00:00 2001 From: TeunHuijben Date: Tue, 15 Apr 2025 10:50:29 -0700 Subject: [PATCH 2/4] db opens, but is very slow --- pixi.lock | 2 +- scripts/test_large_db.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pixi.lock b/pixi.lock index 50e74e9..9e45288 100644 --- a/pixi.lock +++ b/pixi.lock @@ -13284,7 +13284,7 @@ packages: - pypi: . 
name: trackedit version: 0.0.1 - sha256: e833dddddaa4f3566f2671abe7258a3c670746b7a5fd70d32c21b50fdc0fe68d + sha256: babaf637ee20ddf6cf2b6974e4c4e2d20ceeb7ab5eddd947e48dc54bef5c36d0 requires_dist: - mip>=1.16rc0 requires_python: '>=3.10' diff --git a/scripts/test_large_db.py b/scripts/test_large_db.py index 2eb2e29..feb9105 100644 --- a/scripts/test_large_db.py +++ b/scripts/test_large_db.py @@ -6,6 +6,7 @@ from trackedit.arrays.UltrackArray import UltrackArray from trackedit.widgets.HierarchyWidget import HierarchyLabels, HierarchyVizWidget import napari + def check_database(db_path: Path): """Check if database exists and print its structure""" if not db_path.exists(): From 4b9ba1cd9944acf264562fb8e941d4b42388a1d2 Mon Sep 17 00:00:00 2001 From: TeunHuijben Date: Tue, 15 Apr 2025 11:35:56 -0700 Subject: [PATCH 3/4] added large db to demo --- scripts/demo_neuromast.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/scripts/demo_neuromast.py b/scripts/demo_neuromast.py index e82045c..8886789 100644 --- a/scripts/demo_neuromast.py +++ b/scripts/demo_neuromast.py @@ -14,8 +14,12 @@ # **********INPUTS********* # path to the working directory that contains the database file AND metadata.toml: +# working_directory = Path( +# "/home/teun.huijben/Documents/data/Akila/20241003/neuromast4_t851/adjusted/" +# ) + working_directory = Path( - "/home/teun.huijben/Documents/data/Akila/20241003/neuromast4_t851/adjusted/" + "/mnt/md0/Teun/data/Chromatrace/2024_08_14/" ) db_filename_start = "latest" # name of the database file to start from, or "latest" @@ -45,4 +49,5 @@ allow_overwrite=allow_overwrite, imaging_zarr_file=imaging_zarr_file, imaging_channel=imaging_channel, + flag_show_hierarchy=False, ) From 55067cfa6757915c9159d10dad1c3d80976d7354 Mon Sep 17 00:00:00 2001 From: TeunHuijben Date: Wed, 16 Apr 2025 16:07:42 -0700 Subject: [PATCH 4/4] got large-db working by focussing on a single cell (id) and following this cell --- scripts/demo_neuromast.py | 25 ++- scripts/test_large_db.py | 89 ++++----- trackedit/DatabaseHandler.py | 180 +++++++++++++++--- trackedit/TrackEditClass.py | 1 + trackedit/arrays/DatabaseArray.py | 35 +++- .../{UltrackArray.py => HierarchyArray.py} | 50 +++-- trackedit/run.py | 4 + trackedit/utils/utils.py | 62 ++++++ trackedit/widgets/HierarchyWidget.py | 6 +- 9 files changed, 336 insertions(+), 116 deletions(-) rename trackedit/arrays/{UltrackArray.py => HierarchyArray.py} (81%) diff --git a/scripts/demo_neuromast.py b/scripts/demo_neuromast.py index 7854547..bdef0e4 100644 --- a/scripts/demo_neuromast.py +++ b/scripts/demo_neuromast.py @@ -14,21 +14,29 @@ # **********INPUTS********* # path to the working directory that contains the database file AND metadata.toml: +working_directory = Path( + "/home/teun.huijben/Documents/data/Akila/20241003/neuromast4_t851/adjusted/" +) + # working_directory = Path( -# "/home/teun.huijben/Documents/data/Akila/20241003/neuromast4_t851/adjusted/" +# "/mnt/md0/Teun/data/Chromatrace/2024_08_14/adjusted/" # ) -working_directory = Path( - "/mnt/md0/Teun/data/Chromatrace/2024_08_14/" -) # name of the database file to start from, or "latest" to start from the latest version, defaults to "data.db" -db_filename_start = "latest" +db_filename_start = "data_updated.db" # maximum number of frames display, defaults to None (use all frames) tmax = 600 # (Z),Y,X, defaults to (1, 1, 1) -scale = (2.31, 1, 1) +scale = (4, 1, 1) # overwrite existing database/changelog, defaults to False (not used when db_filename_start is "latest") 
allow_overwrite = False +# NEW: +work_in_existing_db = True +flag_show_hierarchy = True +# focus_id = 7000031 +focus_id = 3000020 +# focus_id = None +margin = 150 # OPTIONAL: imaging data # imaging_zarr_file = ( @@ -49,7 +57,10 @@ tmax=tmax, scale=scale, allow_overwrite=allow_overwrite, + work_in_existing_db=work_in_existing_db, imaging_zarr_file=imaging_zarr_file, imaging_channel=imaging_channel, - flag_show_hierarchy=False, + flag_show_hierarchy=flag_show_hierarchy, + focus_id=focus_id, + margin=margin, ) diff --git a/scripts/test_large_db.py b/scripts/test_large_db.py index feb9105..1433495 100644 --- a/scripts/test_large_db.py +++ b/scripts/test_large_db.py @@ -1,49 +1,50 @@ -import sqlite3 from pathlib import Path -import sqlalchemy as sqla -from sqlalchemy import inspect, text -from ultrack.config import MainConfig -from trackedit.arrays.UltrackArray import UltrackArray -from trackedit.widgets.HierarchyWidget import HierarchyLabels, HierarchyVizWidget + import napari +from ultrack.config import MainConfig + +from trackedit.widgets.HierarchyWidget import HierarchyVizWidget + +# def check_database(db_path: Path): +# """Check if database exists and print its structure""" +# if not db_path.exists(): +# print(f"Database file does not exist at: {db_path}") +# return False + +# print(f"Database file found at: {db_path}") +# print(f"File size: {db_path.stat().st_size / (1024*1024):.2f} MB") + +# # Try to connect and list tables +# engine = sqla.create_engine(f"sqlite:///{db_path}") +# inspector = inspect(engine) +# tables = inspector.get_table_names() +# print(f"Tables in database: {tables}") + +# # Print column information for each table +# for table_name in tables: +# columns = inspector.get_columns(table_name) +# print(f"\nColumns in {table_name} table:") +# for column in columns: +# print(f" - {column['name']}: {column['type']}") + +# # Get a sample row from each table +# with engine.connect() as conn: +# for table_name in tables: +# result = conn.execute( +# text(f"SELECT * FROM {table_name} LIMIT 1") +# ).fetchone() +# if result: +# print(f"\nSample row from {table_name}:") +# print(result) + +# return True -def check_database(db_path: Path): - """Check if database exists and print its structure""" - if not db_path.exists(): - print(f"Database file does not exist at: {db_path}") - return False - - print(f"Database file found at: {db_path}") - print(f"File size: {db_path.stat().st_size / (1024*1024):.2f} MB") - - # Try to connect and list tables - engine = sqla.create_engine(f"sqlite:///{db_path}") - inspector = inspect(engine) - tables = inspector.get_table_names() - print(f"Tables in database: {tables}") - - # Print column information for each table - for table_name in tables: - columns = inspector.get_columns(table_name) - print(f"\nColumns in {table_name} table:") - for column in columns: - print(f" - {column['name']}: {column['type']}") - - # Get a sample row from each table - with engine.connect() as conn: - for table_name in tables: - result = conn.execute(text(f"SELECT * FROM {table_name} LIMIT 1")).fetchone() - if result: - print(f"\nSample row from {table_name}:") - print(result) - - return True def initialize_config(): working_directory = Path("/mnt/md0/Teun/data/Chromatrace/2024_08_14/") db_filename = "data.db" - db_path = working_directory / db_filename - + # db_path = working_directory / db_filename + # if not check_database(db_path): # raise FileNotFoundError(f"Database not found or invalid at {db_path}") @@ -53,13 +54,14 @@ def initialize_config(): 
config_adjusted.data_config.database_file_name = db_filename return config_adjusted + # def main(): # config = initialize_config() -# ultrack_array = UltrackArray(config) - +# ultrack_array = HierarchyArray(config) + # labels_layer = HierarchyLabels( # data=ultrack_array, scale=(4,1,1), name="hierarchy" -# ) +# ) # viewer = napari.Viewer() # viewer.add_layer(labels_layer) # labels_layer.refresh() @@ -72,11 +74,12 @@ def main2(): viewer = napari.Viewer() hier_widget = HierarchyVizWidget( viewer=viewer, - scale=(4,1,1), + scale=(4, 1, 1), config=config, ) viewer.window.add_dock_widget(hier_widget, area="bottom") napari.run() + if __name__ == "__main__": main2() diff --git a/trackedit/DatabaseHandler.py b/trackedit/DatabaseHandler.py index f66f376..5270cdb 100644 --- a/trackedit/DatabaseHandler.py +++ b/trackedit/DatabaseHandler.py @@ -9,7 +9,7 @@ import pandas as pd import toml from motile_toolbox.candidate_graph import NodeAttr -from sqlalchemy import create_engine, inspect, text +from sqlalchemy import and_, create_engine, inspect, text from ultrack.config import MainConfig from ultrack.core.database import ( Column, @@ -19,12 +19,16 @@ set_node_values, ) from ultrack.core.export import tracks_layer_to_networkx, tracks_to_zarr -from ultrack.core.export.utils import solution_dataframe_from_sql + +# from ultrack.core.export.utils import solution_dataframe_from_sql from ultrack.tracks.graph import add_track_ids_to_tracks_df from trackedit.arrays.DatabaseArray import DatabaseArray from trackedit.arrays.ImagingArray import SimpleImageArray -from trackedit.utils.utils import annotations_to_zarr +from trackedit.utils.utils import ( + annotations_to_zarr, + solution_dataframe_from_sql_windowed, +) NodeDB.generic = Column(Integer, default=-1) @@ -44,6 +48,8 @@ def __init__( work_in_existing_db: bool = False, imaging_zarr_file: str = None, imaging_channel: str = None, + focus_id: int = None, + margin: int = 150, ): # inputs @@ -60,6 +66,8 @@ def __init__( self.imaging_zarr_file = imaging_zarr_file self.imaging_channel = imaging_channel self.imaging_flag = True if self.imaging_zarr_file is not None else False + self.focus_id = focus_id + self.margin = margin # Filenames / directories self.extension_string = "" @@ -121,18 +129,32 @@ def __init__( self.add_missing_columns_to_db() - # DatabaseArray() + # Filter cell in the database (Option) + self.db_filters = [] + self.df_full = self.db_to_df( + entire_database=True + ) # find_filters needs df_full to fetch tracks + if self.focus_id is not None: + self.db_filters = self.find_db_filters( + focus_id=self.focus_id, margin=self.margin + ) + else: + self.db_filters = [] + + # DatabaseArrays self.segments = DatabaseArray( database_path=self.db_path_new, shape=self.data_shape_chunk, time_window=self.time_window, color_by_field=NodeDB.id, + extra_filters=self.db_filters, ) self.annotArray = DatabaseArray( database_path=self.db_path_new, shape=self.data_shape_chunk, time_window=self.time_window, color_by_field=NodeDB.generic, + extra_filters=self.db_filters, ) self.check_zarr_existance() if self.imaging_flag: @@ -141,8 +163,8 @@ def __init__( channel=self.imaging_channel, time_window=self.time_window, ) + self.df_full = self.db_to_df(entire_database=True) - # ToDo: df_full might be very large for large datasets, but annotation/redflags/division need it self.nxgraph = self.df_to_nxgraph() self.red_flags = self.find_all_red_flags() self.toannotate = self.find_all_toannotate() @@ -218,7 +240,7 @@ def copy_database( # Check if the old database file exists if not 
old_db_path.exists(): raise FileNotFoundError( - f"Error: {db_filename_old} not found in the working directory." + f"Error: {db_filename_old} not found in the working directory. Search for: {old_db_path}" ) # Check if the new database file already exists @@ -344,6 +366,12 @@ def get_same_db_filename(self, old_filename): """ Generate the next version of a database filename. """ + + if old_filename == "latest": + raise ValueError( + "Cannot use 'latest' as db_filename, when 'working_in_existing_db' is True." + ) + name, ext = os.path.splitext(old_filename) old_filename = name + ext db_filename_new = old_filename @@ -443,6 +471,57 @@ def find_chunk_from_frame(self, frame): chunk = np.where(frame >= self.time_chunk_starts)[0][-1] return chunk + # def find_db_filters(self, focus_id: int, margin: int = 150): + # # ToDo: follow the cell if it moves (annoate "track" not individual cell) + # filters = [ + # NodeDB.t < 5, + # ] + # return filters + + def find_db_filters(self, focus_id: int, margin: int = 150): + """Create filters to follow a cell track and load nearby cells. + + Instead of using a single spatial filter based on one timepoint, we: + 1. First get the track_id for our focus cell + 2. Get all positions of this track over time + 3. Create a filter that combines time-specific spatial bounds + """ + # First get the track_id for our focus cell + track_id = self.df_full[self.df_full["id"] == focus_id]["track_id"].iloc[0] + + # Get all positions of this track over time + track_positions = self.df_full[self.df_full["track_id"] == track_id] + + # Create filters for each timepoint + filters = [] + for _, pos in track_positions.iterrows(): + t = pos["t"] + x = pos["x"] + y = pos["y"] + + # For each timepoint, create a compound filter + time_filter = [ + NodeDB.t == t, # specific timepoint + NodeDB.x.between(x - margin, x + margin), + NodeDB.y.between(y - margin, y + margin), + ] + + if self.ndim == 4: + z = pos["z"] + time_filter.append(NodeDB.z.between(z - margin, z + margin)) + + # Combine the conditions for this timepoint with OR + filters.append(and_(*time_filter)) + + # Combine all timepoint filters with OR + from sqlalchemy import or_ + + final_filter = or_(*filters) + + print(f"Following track {track_id} (started from cell {focus_id})") + + return [final_filter] # Return as list to maintain compatibility + def db_to_df( self, entire_database: bool = False, @@ -466,7 +545,7 @@ def db_to_df( Dataframe with columns: track_id, t, z, y, x """ - df = solution_dataframe_from_sql( + df = solution_dataframe_from_sql_windowed( self.db_path_new, columns=( NodeDB.id, @@ -477,10 +556,30 @@ def db_to_df( NodeDB.x, NodeDB.generic, ), + extra_filters=self.db_filters, ) + + def clean_parent_ids(df): + """ + Set parent_ids to -1 when parent cell does not exist + """ + # Get all valid IDs from the index + valid_ids = set(df.index) + + # Create a mask where parent_id is not in valid_ids + invalid_parents_mask = ~df["parent_id"].isin(valid_ids) + + # Set parent_id to -1 where mask is True + df.loc[invalid_parents_mask, "parent_id"] = -1 + + return df + + # TODO: is remove_past_parents_from_df still needed? 
+ df = clean_parent_ids(df) df = add_track_ids_to_tracks_df(df) df.sort_values(by=["track_id", "t"], inplace=True) + # filter timepoints if not entire_database: if self.time_window is not None: min_time = self.time_window[0] @@ -488,8 +587,9 @@ def db_to_df( df = df[(df.t >= min_time) & (df.t < max_time)].copy() else: df = df[df.t < self.Tmax].copy() + df.loc[:, "t"] = df["t"] - self.time_window[0] - df = self.remove_past_parents_from_df(df) + # df = self.remove_past_parents_from_df(df) if self.ndim == 4: columns = ["track_id", "t", "z", "y", "x"] @@ -511,6 +611,10 @@ def db_to_df( columns.append("generic") df = df[columns] + + if df.empty: + df = pd.DataFrame(columns=columns) + return df def df_to_nxgraph(self) -> nx.DiGraph: @@ -523,6 +627,26 @@ def df_to_nxgraph(self) -> nx.DiGraph: """ # apply scale, only do this here to avoid scaling the original dataframe df_scaled = self.db_to_df() + + def clean_parent_ids(df): + """ + Set parent_ids to -1 when parent cell does not exist + """ + # Get all valid IDs from the index + valid_ids = set(df.index) + + # Create a mask where parent_id is not in valid_ids + invalid_parents_mask = ~df["parent_id"].isin(valid_ids) + + # Set parent_id to -1 where mask is True + df.loc[invalid_parents_mask, "parent_id"] = -1 + + return df + + df_scaled = clean_parent_ids(df_scaled) + + if df_scaled.empty: + return nx.DiGraph() if self.ndim == 4: df_scaled.loc[:, "z"] = df_scaled.z * self.z_scale # apply scale @@ -536,23 +660,6 @@ def df_to_nxgraph(self) -> nx.DiGraph: return nxgraph - def remove_past_parents_from_df(self, df2): - - df2.loc[:, "t"] = df2["t"] - df2["t"].min() - - # Set all parent_id values to -1 for the first time point - df2.loc[df2["t"] == 0, "parent_id"] = -1 - - # find the tracks with parents at the first time point - tracks_with_parents = df2.loc[ - (df2["t"] == 0) & (df2["parent_track_id"] != -1), "track_id" - ] - track_ids_to_update = set(tracks_with_parents) - - # update the parent_track_id to -1 for the tracks with parents at the first time point - df2.loc[df2["track_id"].isin(track_ids_to_update), "parent_track_id"] = -1 - return df2 - def find_all_red_flags(self) -> pd.DataFrame: """ Identify tracking red flags ('added' or 'removed') from one timepoint to the next. @@ -562,6 +669,10 @@ def find_all_red_flags(self) -> pd.DataFrame: pd.DataFrame DataFrame with columns: 't', 'track_id', 'id', 'event' """ + # Return empty DataFrame if df_full is empty + if self.df_full.empty: + return pd.DataFrame(columns=["t", "track_id", "id", "event"]) + df = self.df_full.copy() # Define a continuous range of timepoints. 
@@ -663,6 +774,10 @@ def find_all_divisions(self) -> pd.DataFrame: pd.DataFrame DataFrame with columns: 't', 'track_id', 'id' """ + # Return empty DataFrame if df_full is empty + if self.df_full.empty: + return pd.DataFrame(columns=["t", "track_id", "id", "daughters"]) + # Get all cells that have parents (parent_id != -1) cells_with_parents = self.df_full[self.df_full["parent_id"] != -1] @@ -708,6 +823,10 @@ def find_all_toannotate(self) -> pd.DataFrame: pd.DataFrame DataFrame with columns: track_id, first_t, first_id, sorted by mean appearance time """ + # Return empty DataFrame if df_full is empty + if self.df_full.empty: + return pd.DataFrame(columns=["track_id", "first_t", "first_id"]) + # Get rows with no annotations unannotated = self.df_full[ self.df_full["generic"] == NodeDB.generic.default.arg @@ -734,7 +853,7 @@ def find_all_toannotate(self) -> pd.DataFrame: return to_annotate def recompute_red_flags(self): - """called by update_red_flags in TrackEditClass upon tracks_updated signal in TracksViewer""" + """called by update_red_flags in red_flag_box.py upon tracks_updated signal in TracksViewer""" self.red_flags = self.find_all_red_flags() # Only filter if we have any red flags @@ -744,11 +863,11 @@ def recompute_red_flags(self): ] def recompute_divisions(self): - """called by update_divisions in TrackEditClass upon tracks_updated signal in TracksViewer""" + """called by update_divisions in division_box.py upon tracks_updated signal in TracksViewer""" self.divisions = self.find_all_divisions() def recompute_toannotate(self): - """called by update_toannotate in TrackEditClass upon tracks_updated signal in TracksViewer""" + """called by update_toannotate in toannotate_box.py upon tracks_updated signal in TracksViewer""" self.toannotate = self.find_all_toannotate() def ignore_red_flag(self, id): @@ -768,6 +887,11 @@ def ignore_red_flag(self, id): def export_tracks(self): """Export tracks to a CSV file""" + # Skip export if no data + if self.df_full.empty: + print("Warning: No tracks to export (empty dataset)") + return + print("exporting...") # tracks.csv diff --git a/trackedit/TrackEditClass.py b/trackedit/TrackEditClass.py index 30b69ff..545e33b 100644 --- a/trackedit/TrackEditClass.py +++ b/trackedit/TrackEditClass.py @@ -49,6 +49,7 @@ def __init__( viewer=viewer, scale=self.databasehandler.scale, config=self.databasehandler.config_adjusted, + extra_filters=self.databasehandler.db_filters, ) hier_shape = self.hier_widget.ultrack_array.shape tmax = self.databasehandler.data_shape_chunk[0] diff --git a/trackedit/arrays/DatabaseArray.py b/trackedit/arrays/DatabaseArray.py index 01bb026..c04d37f 100644 --- a/trackedit/arrays/DatabaseArray.py +++ b/trackedit/arrays/DatabaseArray.py @@ -7,7 +7,7 @@ from sqlalchemy.orm import Session from ultrack.core.database import NodeDB -# import traceback +from trackedit.utils.utils import apply_filters class DatabaseArray: @@ -19,15 +19,26 @@ def __init__( color_by_field: Column = NodeDB.id, dtype: np.dtype = np.int32, current_time: int = np.nan, + extra_filters: list[sqla.Column] = [], ): """Create an array that directly visualizes the segments in the ultrack database. Parameters ---------- - config : MainConfig - Configuration file of Ultrack. + database_path : Path + Path to the ultrack database. + shape : Tuple[int, ...] + Shape of the array, e.g. (t, z, y, x) + time_window : Tuple[int, int] + Time window of the array, e.g. (0, 100) + color_by_field : Column + Column to color the array by, e.g. 
NodeDB.id dtype : np.dtype - Data type of the array. + Data type of the array, e.g. np.int32 + current_time : int + Current time point of the array, e.g. 0 + extra_filters : list[sqla.Column] + Additional filters to apply to the query, e.g. [NodeDB.x < 300] """ self.database_path = database_path self.shape = shape @@ -35,6 +46,7 @@ def __init__( self.current_time = current_time self.time_window = time_window self.color_by_field = color_by_field + self.extra_filters = extra_filters self.ndim = len(self.shape) self.array = np.zeros(self.shape[1:], dtype=self.dtype) @@ -131,23 +143,26 @@ def fill_array( ---------- time : int Time point to fill the array + extra_filters : list[sqla.Column] + Additional filters to apply to the query, e.g. [NodeDB.x < 300] Returns ------- None """ + + filters = [NodeDB.t == time, NodeDB.selected] + self.extra_filters + engine = sqla.create_engine(self.database_path) self.array.fill(0) with Session(engine) as session: - query = list( - session.query(self.color_by_field, NodeDB.pickle).where( - NodeDB.t == time, NodeDB.selected - ) - ) + query = session.query(self.color_by_field, NodeDB.pickle) + query = apply_filters(query, filters) + query = list(query) if len(query) == 0: - print(f"query is empty for time {time}") + return for idx, q in enumerate(query): q[1].paint_buffer(self.array, value=q[0], include_time=False) diff --git a/trackedit/arrays/UltrackArray.py b/trackedit/arrays/HierarchyArray.py similarity index 81% rename from trackedit/arrays/UltrackArray.py rename to trackedit/arrays/HierarchyArray.py index e50e420..e7c3ac2 100644 --- a/trackedit/arrays/UltrackArray.py +++ b/trackedit/arrays/HierarchyArray.py @@ -6,13 +6,14 @@ from sqlalchemy.orm import Session from ultrack.config import MainConfig from ultrack.core.database import NodeDB +from trackedit.utils.utils import apply_filters - -class UltrackArray: +class HierarchyArray: def __init__( self, config: MainConfig, dtype: np.dtype = np.int32, + extra_filters: list[sqla.Column] = [], ): """Create an array that directly visualizes the segments in the ultrack database. @@ -22,6 +23,8 @@ def __init__( Configuration file of Ultrack. dtype : np.dtype Data type of the array. + extra_filters : list[sqla.Column] + Additional filters to apply to the query, e.g. 
[NodeDB.x < 300] """ self.config = config @@ -31,6 +34,7 @@ def __init__( self.ndim = len(self.shape) self.array = np.zeros(self.shape[1:], dtype=self.dtype) self.time_window = [0, self.shape[0]] + self.extra_filters = extra_filters self.database_path = config.data_config.database_path self.minmax = self.find_min_max_volume_entire_dataset() @@ -90,39 +94,33 @@ def fill_array( time point to paint the segments """ + volume = float(self.volume) if hasattr(self.volume, 'item') else self.volume + filters = [NodeDB.t == time, NodeDB.area < volume] + self.extra_filters + engine = sqla.create_engine(self.database_path) self.array.fill(0) with Session(engine) as session: - query = list( - session.query(NodeDB.id, NodeDB.pickle, NodeDB.hier_parent_id).where( - NodeDB.t == time - ) - ) + query = session.query(NodeDB.pickle, NodeDB.id, NodeDB.hier_parent_id, NodeDB.area) + query = apply_filters(query, filters) + query = list(query) - idx_to_plot = [] - - for idx, q in enumerate(query): - if q[1].area <= self.volume: - idx_to_plot.append(idx) - - id_to_plot = [q[0] for idx, q in enumerate(query) if idx in idx_to_plot] + if len(query) == 0: + return - to_remove = [] - for idx in idx_to_plot: - if query[idx][2] in id_to_plot: # if parent is also printed - to_remove.append(idx) + nodes, node_ids, parent_ids, areas = zip(*query) - for idx in to_remove: - idx_to_plot.remove(idx) + node_ids_set = set(node_ids) # faster lookup - if len(query) == 0: - print("query is empty!") + count = 0 + for i in range(len(nodes)): + # only paint top-most level of hierarchy + if parent_ids[i] not in node_ids_set: + nodes[i].paint_buffer( + self.array, value=node_ids[i], include_time=False + ) + count += 1 - for idx in idx_to_plot: - query[idx][1].paint_buffer( - self.array, value=query[idx][0], include_time=False - ) def get_tp_num_pixels( self, diff --git a/trackedit/run.py b/trackedit/run.py index 8c18c62..93a8ea1 100644 --- a/trackedit/run.py +++ b/trackedit/run.py @@ -37,6 +37,8 @@ def run_trackedit( imaging_channel: Optional[str] = None, viewer: Optional[napari.Viewer] = None, flag_show_hierarchy: bool = True, + focus_id: Optional[int] = None, + margin: int = 150, ) -> Tuple[napari.Viewer, TrackEditClass]: """Run TrackEdit on a database file. @@ -74,6 +76,8 @@ def run_trackedit( work_in_existing_db=work_in_existing_db, imaging_zarr_file=imaging_zarr_file, imaging_channel=imaging_channel, + focus_id=focus_id, + margin=margin, ) # overwrite some motile functions diff --git a/trackedit/utils/utils.py b/trackedit/utils/utils.py index f2b5d3b..a8b7a06 100644 --- a/trackedit/utils/utils.py +++ b/trackedit/utils/utils.py @@ -62,6 +62,68 @@ def wrap_default_widgets_in_tabs(viewer): viewer.window._dock_widgets["Layer List"] = new_dock +def solution_dataframe_from_sql_windowed( + database_path: str, + columns: Sequence[sqla.Column] = ( + NodeDB.id, + NodeDB.parent_id, + NodeDB.t, + NodeDB.z, + NodeDB.y, + NodeDB.x, + ), + extra_filters: list[sqla.Column] = [], +) -> pd.DataFrame: + """Query `columns` of nodes in current solution (NodeDB.selected == True). + + Parameters + ---------- + database_path : str + SQL database path (e.g. sqlite:///your.database.db) + + columns : Sequence[sqla.Column], optional + Queried columns, MUST include NodeDB.id. + By default (NodeDB.id, NodeDB.parent_id, NodeDB.t, NodeDB.z, NodeDB.y, NodeDB.x) + extra_filters: list[sqla.Column], optional + Additional filters to apply to the query, next to NodeDB.selected==True + e.g. 
[NodeDB.x < 300] + + Returns + ------- + pd.DataFrame + Solution dataframe indexed by NodeDB.id + """ + + filters = [NodeDB.selected] + extra_filters + + # query and convert tracking data to dataframe + engine = sqla.create_engine(database_path) + with Session(engine) as session: + query = session.query(*columns) + query = apply_filters(query, filters) + df = pd.read_sql(query.statement, session.bind, index_col="id") + return df + + +def apply_filters(query, filters): + """ + Apply multiple filters to an sqla query + + Parameters + ---------- + query: SQLAlchemy query object + filters: list of SQLAlchemy filter conditions + e.g. [NodeDB.x < 300, NodeDB.y > 300] + + Returns + ------- + SQLAlchemy query object + """ + for filter in filters: + query = query.where(filter) + return query + + @curry def _query_and_export_data_to_frame( time: int, diff --git a/trackedit/widgets/HierarchyWidget.py b/trackedit/widgets/HierarchyWidget.py index 61a780a..4b6e63a 100644 --- a/trackedit/widgets/HierarchyWidget.py +++ b/trackedit/widgets/HierarchyWidget.py @@ -3,12 +3,13 @@ import napari import numpy as np +import sqlalchemy as sqla from magicgui.widgets import Container, FloatSlider, Label from qtpy.QtCore import QObject, Signal from scipy import interpolate from ultrack.config import MainConfig -from trackedit.arrays.UltrackArray import UltrackArray +from trackedit.arrays.HierarchyArray import HierarchyArray logging.basicConfig() logging.getLogger("sqlachemy.engine").setLevel(logging.INFO) @@ -57,6 +58,7 @@ def __init__( viewer: napari.Viewer, scale: Sequence[float] = (1, 1, 1), config=None, + extra_filters: list[sqla.Column] = [], ) -> None: """ Initialize the HierarchyVizWidget. @@ -78,7 +80,7 @@ def __init__( else: self.config = config - self.ultrack_array = UltrackArray(self.config) + self.ultrack_array = HierarchyArray(self.config, extra_filters=extra_filters) self.mapping = self._create_mapping()
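
Usage sketch (not part of the patches above): a minimal example of the new windowed query helper added in trackedit/utils/utils.py, used on its own. The database path and the window bounds below are placeholders, and the snippet assumes an ultrack SQLite database plus this patch series installed. DatabaseHandler.find_db_filters() builds an equivalent filter list around a focus cell and passes it on to DatabaseArray and HierarchyVizWidget via extra_filters.

from ultrack.core.database import NodeDB

from trackedit.utils.utils import solution_dataframe_from_sql_windowed

# Placeholder database URL (same sqlite:/// form as in the docstring above).
db_path = "sqlite:///path/to/data.db"

# Restrict the solution to the first five frames and a spatial window in x/y,
# mirroring the per-timepoint bounds that find_db_filters() creates.
filters = [
    NodeDB.t < 5,
    NodeDB.x.between(100, 400),
    NodeDB.y.between(100, 400),
]

# Returns the selected nodes inside the window, indexed by NodeDB.id.
df = solution_dataframe_from_sql_windowed(db_path, extra_filters=filters)
print(df[["t", "y", "x", "parent_id"]].head())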