From dd0e80cba9356a8a776c60950730c00fcab13aea Mon Sep 17 00:00:00 2001 From: Lachlan Grose Date: Wed, 26 Nov 2025 16:46:50 +1030 Subject: [PATCH 01/16] fix: move extra arguments to init --- map2loop/project.py | 34 ++++--- map2loop/sorter.py | 211 +++++++++++++++++++++----------------------- 2 files changed, 125 insertions(+), 120 deletions(-) diff --git a/map2loop/project.py b/map2loop/project.py index cf5840d1..97ac1862 100644 --- a/map2loop/project.py +++ b/map2loop/project.py @@ -560,7 +560,18 @@ def calculate_stratigraphic_order(self, take_best=False): ) self.contact_extractor.extract_all_contacts() if take_best: - sorters = [SorterUseHint(), SorterAgeBased(), SorterAlpha(), SorterUseNetworkX()] + sorters = [ + SorterUseHint( + unit_relationships=self.topology.get_unit_unit_relationships(), + ), + SorterAgeBased(), + SorterAlpha( + contacts=self.contact_extractor.contacts, + ), + SorterUseNetworkX( + unit_relationships=self.topology.get_unit_unit_relationships(), + ), + ] logger.info( f"Calculating best stratigraphic column from {[sorter.sorter_label for sorter in sorters]}" ) @@ -568,11 +579,6 @@ def calculate_stratigraphic_order(self, take_best=False): columns = [ sorter.sort( self.stratigraphic_column.stratigraphicUnits, - self.topology.get_unit_unit_relationships(), - self.contact_extractor.contacts, - self.map_data.get_map_data(Datatype.GEOLOGY), - self.map_data.get_map_data(Datatype.STRUCTURE), - self.map_data.get_map_data(Datatype.DTM), ) for sorter in sorters ] @@ -600,13 +606,19 @@ def calculate_stratigraphic_order(self, take_best=False): self.stratigraphic_column.column = column else: logger.info(f'Calculating stratigraphic column using sorter {self.sorter.sorter_label}') + # Update sorter with current data based on what it needs + if hasattr(self.sorter, 'unit_relationships') and self.sorter.unit_relationships is None: + self.sorter.unit_relationships = self.topology.get_unit_unit_relationships() + if hasattr(self.sorter, 'contacts') and self.sorter.contacts is None: + self.sorter.contacts = self.contact_extractor.contacts + if hasattr(self.sorter, 'geology_data') and self.sorter.geology_data is None: + self.sorter.geology_data = self.map_data.get_map_data(Datatype.GEOLOGY) + if hasattr(self.sorter, 'structure_data') and self.sorter.structure_data is None: + self.sorter.structure_data = self.map_data.get_map_data(Datatype.STRUCTURE) + if hasattr(self.sorter, 'dtm_data') and self.sorter.dtm_data is None: + self.sorter.dtm_data = self.map_data.get_map_data(Datatype.DTM) self.stratigraphic_column.column = self.sorter.sort( self.stratigraphic_column.stratigraphicUnits, - self.topology.get_unit_unit_relationships(), - self.contact_extractor.contacts, - self.map_data.get_map_data(Datatype.GEOLOGY), - self.map_data.get_map_data(Datatype.STRUCTURE), - self.map_data.get_map_data(Datatype.DTM), ) @beartype.beartype diff --git a/map2loop/sorter.py b/map2loop/sorter.py index 0428ace6..5a76bcf8 100644 --- a/map2loop/sorter.py +++ b/map2loop/sorter.py @@ -3,7 +3,7 @@ import pandas import numpy as np import math -from typing import Union +from typing import Union, Optional from osgeo import gdal import geopandas @@ -21,11 +21,31 @@ class Sorter(ABC): ABC (ABC): Derived from Abstract Base Class """ - def __init__(self): + def __init__( + self, + *, + unit_relationships: Optional[pandas.DataFrame] = None, + contacts: Optional[pandas.DataFrame] = None, + geology_data: Optional[geopandas.GeoDataFrame] = None, + structure_data: Optional[geopandas.GeoDataFrame] = None, + dtm_data: Optional[gdal.Dataset] = None, + ): """ Initialiser of for Sorter + + Args: + unit_relationships (pandas.DataFrame): the relationships between units (columns must contain ["Index1", "Unitname1", "Index2", "Unitname2"]) + contacts (pandas.DataFrame): unit contacts with length of the contacts in metres + geology_data (geopandas.GeoDataFrame): the geology data + structure_data (geopandas.GeoDataFrame): the structure data + dtm_data (gdal.Dataset): the dtm data """ self.sorter_label = "SorterBaseClass" + self.unit_relationships = unit_relationships + self.contacts = contacts + self.geology_data = geology_data + self.structure_data = structure_data + self.dtm_data = dtm_data def type(self): """ @@ -38,25 +58,12 @@ def type(self): @beartype.beartype @abstractmethod - def sort( - self, - units: pandas.DataFrame, - unit_relationships: pandas.DataFrame, - contacts: pandas.DataFrame, - geology_data: geopandas.GeoDataFrame = None, - structure_data: geopandas.GeoDataFrame = None, - dtm_data: gdal.Dataset = None, - ) -> list: + def sort(self, units: pandas.DataFrame) -> list: """ Execute sorter method (abstract method) Args: units (pandas.DataFrame): the data frame to sort (columns must contain ["layerId", "name", "minAge", "maxAge", "group"]) - units_relationships (pandas.DataFrame): the relationships between units (columns must contain ["Index1", "Unitname1", "Index2", "Unitname2"]) - contacts (pandas.DataFrame): unit contacts with length of the contacts in metres - geology_data (geopandas.GeoDataFrame): the geology data - structure_data (geopandas.GeoDataFrame): the structure data - dtm_data (ggdal.Dataset): the dtm data Returns: list: sorted list of unit names @@ -69,29 +76,27 @@ class SorterUseNetworkX(Sorter): Sorter class which returns a sorted list of units based on the unit relationships using a topological graph sorting algorithm """ - def __init__(self): + def __init__( + self, + *, + unit_relationships: Optional[pandas.DataFrame] = None, + ): """ Initialiser for networkx graph sorter + + Args: + unit_relationships (pandas.DataFrame): the relationships between units """ + super().__init__(unit_relationships=unit_relationships) self.sorter_label = "SorterUseNetworkX" @beartype.beartype - def sort( - self, - units: pandas.DataFrame, - unit_relationships: pandas.DataFrame, - contacts: pandas.DataFrame, - geology_data: geopandas.GeoDataFrame = None, - structure_data: geopandas.GeoDataFrame = None, - dtm_data: gdal.Dataset = None, - ) -> list: + def sort(self, units: pandas.DataFrame) -> list: """ - Execute sorter method takes unit data, relationships and a hint and returns the sorted unit names based on this algorithm. + Execute sorter method takes unit data and returns the sorted unit names based on this algorithm. Args: units (pandas.DataFrame): the data frame to sort - units_relationships (pandas.DataFrame): the relationships between units - contacts (pandas.DataFrame): unit contacts with length of the contacts in metres Returns: list: the sorted unit names @@ -103,7 +108,7 @@ def sort( for row in units.iterrows(): graph.add_node(int(row[1]["layerId"]), name=row[1]["name"]) name_to_index[row[1]["name"]] = int(row[1]["layerId"]) - for row in unit_relationships.iterrows(): + for row in self.unit_relationships.iterrows(): graph.add_edge(name_to_index[row[1]["UNITNAME_1"]], name_to_index[row[1]["UNITNAME_2"]]) cycles = list(nx.simple_cycles(graph)) @@ -124,12 +129,18 @@ def sort( class SorterUseHint(SorterUseNetworkX): - def __init__(self): + def __init__( + self, + *, + unit_relationships: Optional[pandas.DataFrame] = None, + ): logger.info( "SorterUseHint is deprecated in v3.2. Use SorterUseNetworkX instead" ) - super().__init__() - + super().__init__(unit_relationships=unit_relationships) + def sort(self, units: pandas.DataFrame) -> list: + raise NotImplementedError("SorterUseHint is deprecated in v3.2. Use SorterUseNetworkX instead") + class SorterAgeBased(Sorter): """ @@ -140,25 +151,15 @@ def __init__(self): """ Initialiser for age based sorter """ + super().__init__() self.sorter_label = "SorterAgeBased" - def sort( - self, - units: pandas.DataFrame, - unit_relationships: pandas.DataFrame, - contacts: pandas.DataFrame, - geology_data: geopandas.GeoDataFrame = None, - structure_data: geopandas.GeoDataFrame = None, - dtm_data: gdal.Dataset = None, - ) -> list: + def sort(self, units: pandas.DataFrame) -> list: """ - Execute sorter method takes unit data, relationships and a hint and returns the sorted unit names based on this algorithm. + Execute sorter method takes unit data and returns the sorted unit names based on this algorithm. Args: units (pandas.DataFrame): the data frame to sort - units_relationships (pandas.DataFrame): the relationships between units - stratigraphic_order_hint (list): a list of unit names to use as a hint to sorting the units - contacts (pandas.DataFrame): unit contacts with length of the contacts in metres Returns: list: the sorted unit names @@ -189,44 +190,41 @@ class SorterAlpha(Sorter): prioritising the units with lower number of contacting units """ - def __init__(self): + def __init__( + self, + *, + contacts: Optional[pandas.DataFrame] = None, + ): """ Initialiser for adjacency based sorter + + Args: + contacts (pandas.DataFrame): unit contacts with length of the contacts in metres """ + super().__init__(contacts=contacts) self.sorter_label = "SorterAlpha" - def sort( - self, - units: pandas.DataFrame, - unit_relationships: pandas.DataFrame, - contacts: pandas.DataFrame, - geology_data: geopandas.GeoDataFrame = None, - structure_data: geopandas.GeoDataFrame = None, - dtm_data: gdal.Dataset = None, - ) -> list: + def sort(self, units: pandas.DataFrame) -> list: """ - Execute sorter method takes unit data, relationships and a hint and returns the sorted unit names based on this algorithm. + Execute sorter method takes unit data and returns the sorted unit names based on this algorithm. Args: units (pandas.DataFrame): the data frame to sort - units_relationships (pandas.DataFrame): the relationships between units - stratigraphic_order_hint (list): a list of unit names to use as a hint to sorting the units - contacts (pandas.DataFrame): unit contacts with length of the contacts in metres Returns: list: the sorted unit names """ import networkx as nx - contacts = contacts.sort_values(by="length", ascending=False)[ + sorted_contacts = self.contacts.sort_values(by="length", ascending=False)[ ["UNITNAME_1", "UNITNAME_2", "length"] ] - units = list(units["name"].unique()) + unit_names = list(units["name"].unique()) graph = nx.Graph() - for unit in units: + for unit in unit_names: graph.add_node(unit, name=unit) - max_weight = max(list(contacts["length"])) + 1 - for _, row in contacts.iterrows(): + max_weight = max(list(sorted_contacts["length"])) + 1 + for _, row in sorted_contacts.iterrows(): graph.add_edge( row["UNITNAME_1"], row["UNITNAME_2"], weight=int(max_weight - row["length"]) ) @@ -273,34 +271,30 @@ class SorterMaximiseContacts(Sorter): prioritising the maximum length of each contact """ - def __init__(self): + def __init__( + self, + *, + contacts: Optional[pandas.DataFrame] = None, + ): """ Initialiser for adjacency based sorter - + + Args: + contacts (pandas.DataFrame): unit contacts with length of the contacts in metres """ + super().__init__(contacts=contacts) self.sorter_label = "SorterMaximiseContacts" # variables for visualising/interrogating the sorter self.graph = None self.route = None self.directed_graph = None - def sort( - self, - units: pandas.DataFrame, - unit_relationships: pandas.DataFrame, - contacts: pandas.DataFrame, - geology_data: geopandas.GeoDataFrame = None, - structure_data: geopandas.GeoDataFrame = None, - dtm_data: gdal.Dataset = None, - ) -> list: + def sort(self, units: pandas.DataFrame) -> list: """ - Execute sorter method takes unit data, relationships and a hint and returns the sorted unit names based on this algorithm. + Execute sorter method takes unit data and returns the sorted unit names based on this algorithm. Args: units (pandas.DataFrame): the data frame to sort - units_relationships (pandas.DataFrame): the relationships between units - stratigraphic_order_hint (list): a list of unit names to use as a hint to sorting the units - contacts (pandas.DataFrame): unit contacts with length of the contacts in metres Returns: list: the sorted unit names @@ -308,15 +302,15 @@ def sort( import networkx as nx import networkx.algorithms.approximation as nx_app - sorted_contacts = contacts.sort_values(by="length", ascending=False) + sorted_contacts = self.contacts.sort_values(by="length", ascending=False) self.graph = nx.Graph() - units = list(units["name"].unique()) - for unit in units: + unit_names = list(units["name"].unique()) + for unit in unit_names: ## some units may not have any contacts e.g. if they are intrusives or sills. If we leave this then the ## sorter crashes if ( - unit not in sorted_contacts['UNITNAME_1'] - or unit not in sorted_contacts['UNITNAME_2'] + unit not in sorted_contacts['UNITNAME_1'].values + and unit not in sorted_contacts['UNITNAME_2'].values ): continue self.graph.add_node(unit, name=unit) @@ -356,37 +350,36 @@ class SorterObservationProjections(Sorter): using the direction of observations to predict which unit is adjacent to the current one """ - def __init__(self, length: Union[float, int] = 1000): + def __init__( + self, + *, + contacts: Optional[pandas.DataFrame] = None, + geology_data: Optional[geopandas.GeoDataFrame] = None, + structure_data: Optional[geopandas.GeoDataFrame] = None, + dtm_data: Optional[gdal.Dataset] = None, + length: Union[float, int] = 1000 + ): """ Initialiser for adjacency based sorter Args: + contacts (pandas.DataFrame): unit contacts with length of the contacts in metres + geology_data (geopandas.GeoDataFrame): the geology data + structure_data (geopandas.GeoDataFrame): the structure data + dtm_data (gdal.Dataset): the dtm data length (int): the length of the projection in metres """ + super().__init__(contacts=contacts, geology_data=geology_data, structure_data=structure_data, dtm_data=dtm_data) self.sorter_label = "SorterObservationProjections" self.length = length self.lines = [] - def sort( - self, - units: pandas.DataFrame, - unit_relationships: pandas.DataFrame, - contacts: pandas.DataFrame, - geology_data: geopandas.GeoDataFrame, - structure_data: geopandas.GeoDataFrame, - dtm_data: gdal.Dataset - ) -> list: + def sort(self, units: pandas.DataFrame) -> list: """ - Execute sorter method takes unit data, relationships and a hint and returns the sorted unit names based on this algorithm. + Execute sorter method takes unit data and returns the sorted unit names based on this algorithm. Args: units (pandas.DataFrame): the data frame to sort - units_relationships (pandas.DataFrame): the relationships between units - stratigraphic_order_hint (list): a list of unit names to use as a hint to sorting the units - contacts (pandas.DataFrame): unit contacts with length of the contacts in metres - geology_data (geopandas.GeoDataFrame): the geology data - structure_data (geopandas.GeoDataFrame): the structure data - dtm_data (ggdal.Dataset): the dtm data Returns: list: the sorted unit names @@ -395,14 +388,14 @@ def sort( import networkx.algorithms.approximation as nx_app from shapely.geometry import LineString, Point - geol = geology_data.copy() + geol = self.geology_data.copy() if "INTRUSIVE" in geol.columns: geol = geol.drop(geol.index[geol["INTRUSIVE"]]) if "SILL" in geol.columns: geol = geol.drop(geol.index[geol["SILL"]]) - orientations = structure_data.copy() - inv_geotransform = gdal.InvGeoTransform(dtm_data.GetGeoTransform()) - dtm_array = np.array(dtm_data.GetRasterBand(1).ReadAsArray().T) + orientations = self.structure_data.copy() + inv_geotransform = gdal.InvGeoTransform(self.dtm_data.GetGeoTransform()) + dtm_array = np.array(self.dtm_data.GetRasterBand(1).ReadAsArray().T) # Create a map of maps to store younger/older observations ordered_unit_observations = [] @@ -504,9 +497,9 @@ def sort( g_undirected = g.to_undirected() for unit in unit_names: if len(list(g_undirected.neighbors(unit))) < 1: - mask1 = contacts["UNITNAME_1"] == unit - mask2 = contacts["UNITNAME_2"] == unit - for _, row in contacts[mask1 | mask2].iterrows(): + mask1 = self.contacts["UNITNAME_1"] == unit + mask2 = self.contacts["UNITNAME_2"] == unit + for _, row in self.contacts[mask1 | mask2].iterrows(): if unit == row["UNITNAME_1"]: g.add_edge(row["UNITNAME_2"], unit, weight=max_value * 10) else: From 9c5b2734651d2ce485a2ee7ff96a912021ca7f69 Mon Sep 17 00:00:00 2001 From: Lachlan Grose Date: Wed, 26 Nov 2025 16:56:09 +1030 Subject: [PATCH 02/16] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- map2loop/sorter.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/map2loop/sorter.py b/map2loop/sorter.py index 5a76bcf8..20d683e1 100644 --- a/map2loop/sorter.py +++ b/map2loop/sorter.py @@ -31,7 +31,7 @@ def __init__( dtm_data: Optional[gdal.Dataset] = None, ): """ - Initialiser of for Sorter + Initialiser for Sorter Args: unit_relationships (pandas.DataFrame): the relationships between units (columns must contain ["Index1", "Unitname1", "Index2", "Unitname2"]) @@ -138,6 +138,7 @@ def __init__( "SorterUseHint is deprecated in v3.2. Use SorterUseNetworkX instead" ) super().__init__(unit_relationships=unit_relationships) + @beartype.beartype def sort(self, units: pandas.DataFrame) -> list: raise NotImplementedError("SorterUseHint is deprecated in v3.2. Use SorterUseNetworkX instead") @@ -214,6 +215,8 @@ def sort(self, units: pandas.DataFrame) -> list: Returns: list: the sorted unit names """ + if self.contacts is None: + raise ValueError("contacts must be set (not None) before calling sort() in SorterAlpha.") import networkx as nx sorted_contacts = self.contacts.sort_values(by="length", ascending=False)[ @@ -393,7 +396,11 @@ def sort(self, units: pandas.DataFrame) -> list: geol = geol.drop(geol.index[geol["INTRUSIVE"]]) if "SILL" in geol.columns: geol = geol.drop(geol.index[geol["SILL"]]) + if self.structure_data is None: + raise ValueError("structure_data is required for sorting but is None.") orientations = self.structure_data.copy() + if self.dtm_data is None: + raise ValueError("DTM data (self.dtm_data) is not set. Cannot proceed with sorting.") inv_geotransform = gdal.InvGeoTransform(self.dtm_data.GetGeoTransform()) dtm_array = np.array(self.dtm_data.GetRasterBand(1).ReadAsArray().T) From 88122af80cf69b988859f8fd8820e98c55713762 Mon Sep 17 00:00:00 2001 From: Lachlan Grose Date: Thu, 27 Nov 2025 07:58:17 +1030 Subject: [PATCH 03/16] fix: resolving copilot review and adding class attribute with required arguments --- map2loop/sorter.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/map2loop/sorter.py b/map2loop/sorter.py index 5a76bcf8..589dfcf1 100644 --- a/map2loop/sorter.py +++ b/map2loop/sorter.py @@ -75,7 +75,7 @@ class SorterUseNetworkX(Sorter): """ Sorter class which returns a sorted list of units based on the unit relationships using a topological graph sorting algorithm """ - + required_arguments = 'unit_relationships' def __init__( self, *, @@ -102,7 +102,8 @@ def sort(self, units: pandas.DataFrame) -> list: list: the sorted unit names """ import networkx as nx - + if self.unit_relationships is None: + raise ValueError("SorterUseNetworkX requires 'unit_relationships' argument") graph = nx.DiGraph() name_to_index = {} for row in units.iterrows(): @@ -129,6 +130,7 @@ def sort(self, units: pandas.DataFrame) -> list: class SorterUseHint(SorterUseNetworkX): + required_arguments = 'unit_relationships' def __init__( self, *, @@ -146,7 +148,7 @@ class SorterAgeBased(Sorter): """ Sorter class which returns a sorted list of units based on the min and max ages of the units """ - + requried_arguments = None def __init__(self): """ Initialiser for age based sorter @@ -182,7 +184,7 @@ def sort(self, units: pandas.DataFrame) -> list: logger.info(f"{row['name']} - {row['minAge']} - {row['maxAge']}") return list(sorted_units["name"]) - + class SorterAlpha(Sorter): """ @@ -215,7 +217,8 @@ def sort(self, units: pandas.DataFrame) -> list: list: the sorted unit names """ import networkx as nx - + if self.contacts is None: + raise ValueError("SorterAlpha requires 'contacts' argument") sorted_contacts = self.contacts.sort_values(by="length", ascending=False)[ ["UNITNAME_1", "UNITNAME_2", "length"] ] @@ -301,7 +304,8 @@ def sort(self, units: pandas.DataFrame) -> list: """ import networkx as nx import networkx.algorithms.approximation as nx_app - + if self.contacts is None: + raise ValueError("SorterMaximiseContacts requires 'contacts' argument") sorted_contacts = self.contacts.sort_values(by="length", ascending=False) self.graph = nx.Graph() unit_names = list(units["name"].unique()) @@ -310,7 +314,7 @@ def sort(self, units: pandas.DataFrame) -> list: ## sorter crashes if ( unit not in sorted_contacts['UNITNAME_1'].values - and unit not in sorted_contacts['UNITNAME_2'].values + or unit not in sorted_contacts['UNITNAME_2'].values ): continue self.graph.add_node(unit, name=unit) @@ -349,7 +353,7 @@ class SorterObservationProjections(Sorter): Sorter class which returns a sorted list of units based on the adjacency of units using the direction of observations to predict which unit is adjacent to the current one """ - + required_arguments = ['contacts', 'geology_data', 'structure_data', 'dtm_data'] def __init__( self, *, @@ -387,7 +391,10 @@ def sort(self, units: pandas.DataFrame) -> list: import networkx as nx import networkx.algorithms.approximation as nx_app from shapely.geometry import LineString, Point - + if self.contacts is None: + raise ValueError("SorterObservationProjections requires 'contacts' argument") + if self.geology_data is None: + raise ValueError("SorterObservationProjections requires 'geology_data' argument") geol = self.geology_data.copy() if "INTRUSIVE" in geol.columns: geol = geol.drop(geol.index[geol["INTRUSIVE"]]) From 32bfc4e2f75f0348fa5ee9774c1d7d64a3b958da Mon Sep 17 00:00:00 2001 From: Lachlan Grose Date: Thu, 27 Nov 2025 08:39:48 +1030 Subject: [PATCH 04/16] fix: remove SorterUseHint call --- map2loop/project.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/map2loop/project.py b/map2loop/project.py index 97ac1862..7e05d0bf 100644 --- a/map2loop/project.py +++ b/map2loop/project.py @@ -561,9 +561,6 @@ def calculate_stratigraphic_order(self, take_best=False): self.contact_extractor.extract_all_contacts() if take_best: sorters = [ - SorterUseHint( - unit_relationships=self.topology.get_unit_unit_relationships(), - ), SorterAgeBased(), SorterAlpha( contacts=self.contact_extractor.contacts, From 11d01090c8147def0f7783a9a92102f701bbea19 Mon Sep 17 00:00:00 2001 From: Lachlan Grose Date: Thu, 27 Nov 2025 09:58:00 +1030 Subject: [PATCH 05/16] fix: remove arguments from ABC make all requrements a list make all arguments required --- map2loop/sorter.py | 101 +++++++++++++++++++++++---------------------- 1 file changed, 52 insertions(+), 49 deletions(-) diff --git a/map2loop/sorter.py b/map2loop/sorter.py index 4c187b40..4f57e6e7 100644 --- a/map2loop/sorter.py +++ b/map2loop/sorter.py @@ -3,10 +3,11 @@ import pandas import numpy as np import math -from typing import Union, Optional -from osgeo import gdal +import math +from typing import Union, Optional, List +from map2loop.topology import Topology import geopandas - +from osgeo import gdal from map2loop.utils import value_from_raster from .logging import getLogger @@ -22,14 +23,7 @@ class Sorter(ABC): """ def __init__( - self, - *, - unit_relationships: Optional[pandas.DataFrame] = None, - contacts: Optional[pandas.DataFrame] = None, - geology_data: Optional[geopandas.GeoDataFrame] = None, - structure_data: Optional[geopandas.GeoDataFrame] = None, - dtm_data: Optional[gdal.Dataset] = None, - ): + self): """ Initialiser for Sorter @@ -41,11 +35,7 @@ def __init__( dtm_data (gdal.Dataset): the dtm data """ self.sorter_label = "SorterBaseClass" - self.unit_relationships = unit_relationships - self.contacts = contacts - self.geology_data = geology_data - self.structure_data = structure_data - self.dtm_data = dtm_data + def type(self): """ @@ -75,11 +65,13 @@ class SorterUseNetworkX(Sorter): """ Sorter class which returns a sorted list of units based on the unit relationships using a topological graph sorting algorithm """ - required_arguments = 'unit_relationships' + required_arguments: List[str] = [ + 'geology_data' + ] + def __init__( self, - *, - unit_relationships: Optional[pandas.DataFrame] = None, + geology_data: geopandas.GeoDataFrame ): """ Initialiser for networkx graph sorter @@ -87,9 +79,12 @@ def __init__( Args: unit_relationships (pandas.DataFrame): the relationships between units """ - super().__init__(unit_relationships=unit_relationships) + super().__init__() self.sorter_label = "SorterUseNetworkX" - + if 'UNITNAME' not in geology_data.columns: + raise ValueError("geology_data must contain 'UNITNAME' column") + self.topology = Topology(geology_data=geology_data) + self.unit_relationships = self.topology.get_unit_unit_relationships() @beartype.beartype def sort(self, units: pandas.DataFrame) -> list: """ @@ -130,31 +125,31 @@ def sort(self, units: pandas.DataFrame) -> list: class SorterUseHint(SorterUseNetworkX): - required_arguments = 'unit_relationships' + required_arguments: List[str] = ['unit_relationships'] def __init__( self, *, - unit_relationships: Optional[pandas.DataFrame] = None, + geology_data: Optional[geopandas.GeoDataFrame] = None, ): - logger.info( - "SorterUseHint is deprecated in v3.2. Use SorterUseNetworkX instead" + logger.warning( + "SorterUseHint is deprecated in v3.2. Using SorterUseNetworkX instead" ) - super().__init__(unit_relationships=unit_relationships) - @beartype.beartype - def sort(self, units: pandas.DataFrame) -> list: - raise NotImplementedError("SorterUseHint is deprecated in v3.2. Use SorterUseNetworkX instead") + super().__init__(geology_data=geology_data) + class SorterAgeBased(Sorter): """ Sorter class which returns a sorted list of units based on the min and max ages of the units """ - requried_arguments = None - def __init__(self): + requried_arguments = ['min_age_column','max_age_column'] + def __init__(self, min_age_column:str, max_age_column:str): """ Initialiser for age based sorter """ super().__init__() + self.min_age_column = min_age_column + self.max_age_column = max_age_column self.sorter_label = "SorterAgeBased" def sort(self, units: pandas.DataFrame) -> list: @@ -169,13 +164,14 @@ def sort(self, units: pandas.DataFrame) -> list: """ logger.info("Calling age based sorter") sorted_units = units.copy() - if "minAge" in units.columns and "maxAge" in units.columns: + + if self.min_age_column in units.columns and self.max_age_column in units.columns: # print(sorted_units["minAge"], sorted_units["maxAge"]) sorted_units["meanAge"] = sorted_units.apply( - lambda row: (row["minAge"] + row["maxAge"]) / 2.0, axis=1 + lambda row: (row[self.min_age_column] + row[self.max_age_column]) / 2.0, axis=1 ) else: - sorted_units["meanAge"] = 0 + raise ValueError(f"Columns {self.min_age_column} and {self.max_age_column} must be present in units DataFrame") if "group" in units.columns: sorted_units = sorted_units.sort_values(by=["group", "meanAge"]) else: @@ -185,27 +181,29 @@ def sort(self, units: pandas.DataFrame) -> list: logger.info(f"{row['name']} - {row['minAge']} - {row['maxAge']}") return list(sorted_units["name"]) - + class SorterAlpha(Sorter): """ Sorter class which returns a sorted list of units based on the adjacency of units prioritising the units with lower number of contacting units """ - + required_arguments = ['contacts'] def __init__( self, - *, - contacts: Optional[pandas.DataFrame] = None, + contacts: geopandas.GeoDataFrame, ): """ Initialiser for adjacency based sorter Args: - contacts (pandas.DataFrame): unit contacts with length of the contacts in metres + contacts (geopandas.GeoDataFrame): unit contacts with length of the contacts in metres """ - super().__init__(contacts=contacts) + super().__init__() + self.contacts = contacts self.sorter_label = "SorterAlpha" + if 'UNITNAME_1' not in contacts.columns or 'UNITNAME_2' not in contacts.columns or 'length' not in contacts.columns: + raise ValueError("contacts GeoDataFrame must contain 'UNITNAME_1', 'UNITNAME_2' and 'length' columns") def sort(self, units: pandas.DataFrame) -> list: """ @@ -276,11 +274,10 @@ class SorterMaximiseContacts(Sorter): Sorter class which returns a sorted list of units based on the adjacency of units prioritising the maximum length of each contact """ - + required_arguments = ['contacts'] def __init__( self, - *, - contacts: Optional[pandas.DataFrame] = None, + contacts: geopandas.GeoDataFrame, ): """ Initialiser for adjacency based sorter @@ -288,12 +285,15 @@ def __init__( Args: contacts (pandas.DataFrame): unit contacts with length of the contacts in metres """ - super().__init__(contacts=contacts) + super().__init__() self.sorter_label = "SorterMaximiseContacts" # variables for visualising/interrogating the sorter self.graph = None self.route = None self.directed_graph = None + self.contacts = contacts + if 'UNITNAME_1' not in contacts.columns or 'UNITNAME_2' not in contacts.columns or 'length' not in contacts.columns: + raise ValueError("contacts GeoDataFrame must contain 'UNITNAME_1', 'UNITNAME_2' and 'length' columns") def sort(self, units: pandas.DataFrame) -> list: """ @@ -359,10 +359,9 @@ class SorterObservationProjections(Sorter): required_arguments = ['contacts', 'geology_data', 'structure_data', 'dtm_data'] def __init__( self, - *, - contacts: Optional[pandas.DataFrame] = None, - geology_data: Optional[geopandas.GeoDataFrame] = None, - structure_data: Optional[geopandas.GeoDataFrame] = None, + contacts: geopandas.GeoDataFrame, + geology_data: geopandas.GeoDataFrame, + structure_data: geopandas.GeoDataFrame, dtm_data: Optional[gdal.Dataset] = None, length: Union[float, int] = 1000 ): @@ -376,7 +375,11 @@ def __init__( dtm_data (gdal.Dataset): the dtm data length (int): the length of the projection in metres """ - super().__init__(contacts=contacts, geology_data=geology_data, structure_data=structure_data, dtm_data=dtm_data) + super().__init__() + self.contacts = contacts + self.geology_data = geology_data + self.structure_data = structure_data + self.dtm_data = dtm_data self.sorter_label = "SorterObservationProjections" self.length = length self.lines = [] From 5328fef0432ccc27788fee078d9dc761326fa36c Mon Sep 17 00:00:00 2001 From: lachlangrose <7371904+lachlangrose@users.noreply.github.com> Date: Wed, 26 Nov 2025 23:28:29 +0000 Subject: [PATCH 06/16] style: style fixes by ruff and autoformatting by black --- map2loop/sorter.py | 1 - 1 file changed, 1 deletion(-) diff --git a/map2loop/sorter.py b/map2loop/sorter.py index 4f57e6e7..412bd671 100644 --- a/map2loop/sorter.py +++ b/map2loop/sorter.py @@ -3,7 +3,6 @@ import pandas import numpy as np import math -import math from typing import Union, Optional, List from map2loop.topology import Topology import geopandas From 1568248cf491593530cd53361bf6c120a659a803 Mon Sep 17 00:00:00 2001 From: rabii-chaarani Date: Thu, 27 Nov 2025 11:03:57 +0930 Subject: [PATCH 07/16] fix: add check for geology_data --- map2loop/sorter.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/map2loop/sorter.py b/map2loop/sorter.py index 412bd671..f9bb8894 100644 --- a/map2loop/sorter.py +++ b/map2loop/sorter.py @@ -80,10 +80,16 @@ def __init__( """ super().__init__() self.sorter_label = "SorterUseNetworkX" + + if isinstance(geology_data, geopandas.GeoDataFrame) is False: + raise TypeError("geology_data must be a geopandas.GeoDataFrame") + if 'UNITNAME' not in geology_data.columns: raise ValueError("geology_data must contain 'UNITNAME' column") + self.topology = Topology(geology_data=geology_data) self.unit_relationships = self.topology.get_unit_unit_relationships() + @beartype.beartype def sort(self, units: pandas.DataFrame) -> list: """ From 3ea4a85d8cb74e453ff9a442810ff084c56ae720 Mon Sep 17 00:00:00 2001 From: rabii-chaarani Date: Thu, 27 Nov 2025 12:20:28 +0930 Subject: [PATCH 08/16] fix: not initiliase sorter in project.init --- map2loop/project.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/map2loop/project.py b/map2loop/project.py index 7e05d0bf..ec7f10f5 100644 --- a/map2loop/project.py +++ b/map2loop/project.py @@ -140,7 +140,7 @@ def __init__( self.set_default_samplers() self.bounding_box = bounding_box self.contact_extractor = None - self.sorter = SorterUseHint() + self.sorter = SorterUseHint self.throw_calculator = ThrowCalculatorAlpha() self.fault_orientation = FaultOrientationNearest() self.map_data = MapData(verbose_level=verbose_level) From 4327c2a74b75fd346c0fb8d49eb69cbdf4d49057 Mon Sep 17 00:00:00 2001 From: rabii-chaarani Date: Thu, 27 Nov 2025 12:30:26 +0930 Subject: [PATCH 09/16] fix: try use SorterUseNetworkX --- tests/project/test_plot_hamersley.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/project/test_plot_hamersley.py b/tests/project/test_plot_hamersley.py index 504c4585..1cb9f4f6 100644 --- a/tests/project/test_plot_hamersley.py +++ b/tests/project/test_plot_hamersley.py @@ -1,5 +1,6 @@ import pytest from map2loop.project import Project +from map2loop.sorter import SorterUseNetworkX from map2loop.m2l_enums import VerboseLevel from unittest.mock import patch from pyproj.exceptions import CRSError @@ -36,6 +37,7 @@ def test_project_execution(): except Exception: pytest.skip("Skipping the project test from server data due to loading failure") try: + proj.set_sorter(SorterUseNetworkX) proj.run_all(take_best=True) except requests.exceptions.ReadTimeout: pytest.skip( From acf926992b12c27e88bb86057cffa20bfb9b06f7 Mon Sep 17 00:00:00 2001 From: lachlangrose Date: Mon, 1 Dec 2025 14:06:07 +1100 Subject: [PATCH 10/16] fix: adding utils module and fixing tests --- map2loop/__init__.py | 6 +---- map2loop/logging.py | 18 +++++++++------ map2loop/utils/__init__.py | 18 +++++++++++++++ map2loop/utils/load_map2loop_data.py | 22 +++++++++++++++++++ .../{utils.py => utils/utility_functions.py} | 2 +- tests/project/test_plot_hamersley.py | 2 +- tests/project/test_thickness_calculations.py | 4 ++-- .../test_interpolated_structure.py | 2 +- .../test_ThicknessStructuralPoint.py | 2 +- .../test_ThicknessCalculatorAlpha.py | 2 +- 10 files changed, 59 insertions(+), 19 deletions(-) create mode 100644 map2loop/utils/__init__.py create mode 100644 map2loop/utils/load_map2loop_data.py rename map2loop/{utils.py => utils/utility_functions.py} (99%) diff --git a/map2loop/__init__.py b/map2loop/__init__.py index 8723f4ef..66a14c9a 100644 --- a/map2loop/__init__.py +++ b/map2loop/__init__.py @@ -1,10 +1,6 @@ import logging +from map2loop.logging import loggers, ch -loggers = {} -ch = logging.StreamHandler() -formatter = logging.Formatter("%(levelname)s: %(asctime)s: %(filename)s:%(lineno)d -- %(message)s") -ch.setFormatter(formatter) -ch.setLevel(logging.WARNING) from .project import Project from .version import __version__ diff --git a/map2loop/logging.py b/map2loop/logging.py index 2daaa0c2..816ac45e 100644 --- a/map2loop/logging.py +++ b/map2loop/logging.py @@ -1,7 +1,11 @@ import logging -import map2loop +loggers = {} +ch = ch = logging.StreamHandler() +formatter = logging.Formatter("%(levelname)s: %(asctime)s: %(filename)s:%(lineno)d -- %(message)s") +ch.setFormatter(formatter) +ch.setLevel(logging.WARNING) def get_levels(): """dict for converting to logger levels from string @@ -33,12 +37,12 @@ def getLogger(name: str): logging.Logger logger object """ - if name in map2loop.loggers: - return map2loop.loggers[name] + if name in loggers: + return loggers[name] logger = logging.getLogger(name) - logger.addHandler(map2loop.ch) + logger.addHandler(ch) logger.propagate = False - map2loop.loggers[name] = logger + loggers[name] = logger return logger @@ -56,9 +60,9 @@ def set_level(level: str): """ levels = get_levels() level = levels.get(level, logging.WARNING) - map2loop.ch.setLevel(level) + ch.setLevel(level) - for name in map2loop.loggers: + for name in loggers: logger = logging.getLogger(name) logger.setLevel(level) logger.info(f"Logging level set to {level}") diff --git a/map2loop/utils/__init__.py b/map2loop/utils/__init__.py new file mode 100644 index 00000000..c7fa49d6 --- /dev/null +++ b/map2loop/utils/__init__.py @@ -0,0 +1,18 @@ +from .utility_functions import ( + set_z_values_from_raster_df, + value_from_raster, + update_from_legacy_file, + preprocess_hjson_to_json, + read_hjson_with_json, + calculate_endpoints, + calculate_minimum_fault_length, + hex_to_rgb, + generate_random_hex_colors, + rebuild_sampled_basal_contacts, + multiline_to_line, + find_segment_strike_from_pt, + create_points, + normal_vector_to_dipdirection_dip, + strike_dip_vector, + generate_grid, +) diff --git a/map2loop/utils/load_map2loop_data.py b/map2loop/utils/load_map2loop_data.py new file mode 100644 index 00000000..c64cd5fc --- /dev/null +++ b/map2loop/utils/load_map2loop_data.py @@ -0,0 +1,22 @@ +import geopandas +import map2loop +import pathlib +from osgeo import gdal +gdal.UseExceptions() + +# Use the path of this file to locate the datasets directory +def map2loop_dir(folder) -> pathlib.Path: + path = pathlib.Path(__file__).parent.parent.parent / 'map2loop' / '_datasets' / 'geodata_files' / f'{folder}' + return path + +def load_hamersley_geology(): + path = map2loop_dir('hamersley') / "geology.geojson" + return geopandas.read_file(str(path)) + +def load_hamersley_structure(): + path = map2loop_dir('hamersley') / "structure.geojson" + return geopandas.read_file(str(path)) + +def load_hamersley_dtm(): + path = map2loop_dir('hamersley') / "dtm_rp.tif" + return gdal.Open(str(path)) diff --git a/map2loop/utils.py b/map2loop/utils/utility_functions.py similarity index 99% rename from map2loop/utils.py rename to map2loop/utils/utility_functions.py index 55e2e7b2..7c864fc0 100644 --- a/map2loop/utils.py +++ b/map2loop/utils/utility_functions.py @@ -9,7 +9,7 @@ import json from osgeo import gdal -from .logging import getLogger +from ..logging import getLogger logger = getLogger(__name__) diff --git a/tests/project/test_plot_hamersley.py b/tests/project/test_plot_hamersley.py index 1cb9f4f6..dadbb7d7 100644 --- a/tests/project/test_plot_hamersley.py +++ b/tests/project/test_plot_hamersley.py @@ -37,7 +37,7 @@ def test_project_execution(): except Exception: pytest.skip("Skipping the project test from server data due to loading failure") try: - proj.set_sorter(SorterUseNetworkX) + proj.set_sorter(SorterUseNetworkX()) proj.run_all(take_best=True) except requests.exceptions.ReadTimeout: pytest.skip( diff --git a/tests/project/test_thickness_calculations.py b/tests/project/test_thickness_calculations.py index 373687ed..961e1a9c 100644 --- a/tests/project/test_thickness_calculations.py +++ b/tests/project/test_thickness_calculations.py @@ -3,11 +3,10 @@ import geopandas import numpy -from map2loop._datasets.geodata_files import load_map2loop_data +from map2loop.utils import load_map2loop_data from map2loop.thickness_calculator import InterpolatedStructure, StructuralPoint from map2loop import Project - # 1. self.stratigraphic_column.stratigraphicUnits, st_units = pandas.DataFrame( { @@ -1705,6 +1704,7 @@ def test_calculate_unit_thicknesses(): ], "Default for thickness calculator not set" ## default is InterpolatedStructure # check set + print("****",StructuralPoint.__module__) project.set_thickness_calculator([StructuralPoint(dtm_data=dtm, bounding_box=bbox_3d), InterpolatedStructure(dtm_data=dtm, bounding_box=bbox_3d)]) assert project.get_thickness_calculator() == [ diff --git a/tests/thickness/InterpolatedStructure/test_interpolated_structure.py b/tests/thickness/InterpolatedStructure/test_interpolated_structure.py index 164ce2f7..bbda4e35 100644 --- a/tests/thickness/InterpolatedStructure/test_interpolated_structure.py +++ b/tests/thickness/InterpolatedStructure/test_interpolated_structure.py @@ -4,7 +4,7 @@ from map2loop.mapdata import MapData from map2loop.thickness_calculator import InterpolatedStructure -from map2loop._datasets.geodata_files.load_map2loop_data import ( +from map2loop.utils.load_map2loop_data import ( load_hamersley_geology, load_hamersley_dtm, ) diff --git a/tests/thickness/StructurePoint/test_ThicknessStructuralPoint.py b/tests/thickness/StructurePoint/test_ThicknessStructuralPoint.py index 74c7851d..0fbc1185 100644 --- a/tests/thickness/StructurePoint/test_ThicknessStructuralPoint.py +++ b/tests/thickness/StructurePoint/test_ThicknessStructuralPoint.py @@ -4,7 +4,7 @@ from map2loop.mapdata import MapData from map2loop.thickness_calculator import StructuralPoint -from map2loop._datasets.geodata_files.load_map2loop_data import load_hamersley_geology +from map2loop.utils.load_map2loop_data import load_hamersley_geology from map2loop.m2l_enums import Datatype #################################################################### diff --git a/tests/thickness/ThicknessCalculatorAlpha/test_ThicknessCalculatorAlpha.py b/tests/thickness/ThicknessCalculatorAlpha/test_ThicknessCalculatorAlpha.py index 521b633f..ba740a18 100644 --- a/tests/thickness/ThicknessCalculatorAlpha/test_ThicknessCalculatorAlpha.py +++ b/tests/thickness/ThicknessCalculatorAlpha/test_ThicknessCalculatorAlpha.py @@ -5,7 +5,7 @@ from map2loop.mapdata import MapData from map2loop.m2l_enums import Datatype from map2loop.thickness_calculator import ThicknessCalculatorAlpha -from map2loop._datasets.geodata_files.load_map2loop_data import load_hamersley_geology +from map2loop.utils.load_map2loop_data import load_hamersley_geology ######################################################### From e85954a9cc459ecbaa05a76257cc5c486cf5e6a8 Mon Sep 17 00:00:00 2001 From: lachlangrose <7371904+lachlangrose@users.noreply.github.com> Date: Mon, 1 Dec 2025 03:06:35 +0000 Subject: [PATCH 11/16] style: style fixes by ruff and autoformatting by black --- map2loop/utils/load_map2loop_data.py | 1 - 1 file changed, 1 deletion(-) diff --git a/map2loop/utils/load_map2loop_data.py b/map2loop/utils/load_map2loop_data.py index c64cd5fc..6cedb18b 100644 --- a/map2loop/utils/load_map2loop_data.py +++ b/map2loop/utils/load_map2loop_data.py @@ -1,5 +1,4 @@ import geopandas -import map2loop import pathlib from osgeo import gdal gdal.UseExceptions() From bb5d5c0dda8dbe6addecadd88072de0173ca45ca Mon Sep 17 00:00:00 2001 From: lachlangrose Date: Mon, 1 Dec 2025 14:35:22 +1100 Subject: [PATCH 12/16] fix: sorter arguments optional. Allow project to update attributes from map data --- map2loop/sorter.py | 39 +++++++++++++++++++++++++++++---------- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/map2loop/sorter.py b/map2loop/sorter.py index f9bb8894..39bac5fe 100644 --- a/map2loop/sorter.py +++ b/map2loop/sorter.py @@ -70,7 +70,8 @@ class SorterUseNetworkX(Sorter): def __init__( self, - geology_data: geopandas.GeoDataFrame + *, + geology_data: Optional[geopandas.GeoDataFrame] = None, ): """ Initialiser for networkx graph sorter @@ -80,7 +81,22 @@ def __init__( """ super().__init__() self.sorter_label = "SorterUseNetworkX" - + if geology_data is not None: + self.set_geology_data(geology_data) + else: + self.unit_relationships = None + def set_geology_data(self, geology_data: geopandas.GeoDataFrame): + """ + Set geology data and calculate topology and unit relationships + + Args: + geology_data (geopandas.GeoDataFrame): the geology data + """ + self._calculate_topology(geology_data) + def _calculate_topology(self, geology_data: geopandas.GeoDataFrame): + if not geology_data: + raise ValueError("geology_data is required") + if isinstance(geology_data, geopandas.GeoDataFrame) is False: raise TypeError("geology_data must be a geopandas.GeoDataFrame") @@ -89,7 +105,7 @@ def __init__( self.topology = Topology(geology_data=geology_data) self.unit_relationships = self.topology.get_unit_unit_relationships() - + @beartype.beartype def sort(self, units: pandas.DataFrame) -> list: """ @@ -147,8 +163,8 @@ class SorterAgeBased(Sorter): """ Sorter class which returns a sorted list of units based on the min and max ages of the units """ - requried_arguments = ['min_age_column','max_age_column'] - def __init__(self, min_age_column:str, max_age_column:str): + required_arguments = ['min_age_column','max_age_column'] + def __init__(self,*, min_age_column:Optional[str], max_age_column:Optional[str]): """ Initialiser for age based sorter """ @@ -196,7 +212,8 @@ class SorterAlpha(Sorter): required_arguments = ['contacts'] def __init__( self, - contacts: geopandas.GeoDataFrame, + *, + contacts: Optional[geopandas.GeoDataFrame] = None, ): """ Initialiser for adjacency based sorter @@ -282,7 +299,8 @@ class SorterMaximiseContacts(Sorter): required_arguments = ['contacts'] def __init__( self, - contacts: geopandas.GeoDataFrame, + *, + contacts: Optional[geopandas.GeoDataFrame] = None, ): """ Initialiser for adjacency based sorter @@ -364,9 +382,10 @@ class SorterObservationProjections(Sorter): required_arguments = ['contacts', 'geology_data', 'structure_data', 'dtm_data'] def __init__( self, - contacts: geopandas.GeoDataFrame, - geology_data: geopandas.GeoDataFrame, - structure_data: geopandas.GeoDataFrame, + *, + contacts: Optional[geopandas.GeoDataFrame] = None, + geology_data: Optional[geopandas.GeoDataFrame] = None, + structure_data: Optional[geopandas.GeoDataFrame] = None, dtm_data: Optional[gdal.Dataset] = None, length: Union[float, int] = 1000 ): From 313e2f672c7701cfb9d0de36171872b551562535 Mon Sep 17 00:00:00 2001 From: lachlangrose Date: Mon, 1 Dec 2025 14:40:23 +1100 Subject: [PATCH 13/16] fix: minage/maxage optional --- map2loop/sorter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/map2loop/sorter.py b/map2loop/sorter.py index 39bac5fe..397526a2 100644 --- a/map2loop/sorter.py +++ b/map2loop/sorter.py @@ -164,7 +164,7 @@ class SorterAgeBased(Sorter): Sorter class which returns a sorted list of units based on the min and max ages of the units """ required_arguments = ['min_age_column','max_age_column'] - def __init__(self,*, min_age_column:Optional[str], max_age_column:Optional[str]): + def __init__(self,*, min_age_column:Optional[str]=None, max_age_column:Optional[str]=None): """ Initialiser for age based sorter """ From 02bcae35a6667dbd9ae514de04e92e68c4e68ba4 Mon Sep 17 00:00:00 2001 From: lachlangrose Date: Mon, 1 Dec 2025 14:44:59 +1100 Subject: [PATCH 14/16] fix: add unit relations as optional argument --- map2loop/sorter.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/map2loop/sorter.py b/map2loop/sorter.py index 397526a2..2779a64c 100644 --- a/map2loop/sorter.py +++ b/map2loop/sorter.py @@ -71,6 +71,7 @@ class SorterUseNetworkX(Sorter): def __init__( self, *, + unit_relationships: Optional[pandas.DataFrame] = None, geology_data: Optional[geopandas.GeoDataFrame] = None, ): """ @@ -83,6 +84,8 @@ def __init__( self.sorter_label = "SorterUseNetworkX" if geology_data is not None: self.set_geology_data(geology_data) + elif unit_relationships is not None: + self.unit_relationships = unit_relationships else: self.unit_relationships = None def set_geology_data(self, geology_data: geopandas.GeoDataFrame): From d4d5f75a090ff9a178833573ca3c161643361c5e Mon Sep 17 00:00:00 2001 From: lachlangrose Date: Mon, 1 Dec 2025 15:04:57 +1100 Subject: [PATCH 15/16] fix: default min/max age --- map2loop/sorter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/map2loop/sorter.py b/map2loop/sorter.py index 2779a64c..e01b4fc5 100644 --- a/map2loop/sorter.py +++ b/map2loop/sorter.py @@ -167,7 +167,7 @@ class SorterAgeBased(Sorter): Sorter class which returns a sorted list of units based on the min and max ages of the units """ required_arguments = ['min_age_column','max_age_column'] - def __init__(self,*, min_age_column:Optional[str]=None, max_age_column:Optional[str]=None): + def __init__(self,*, min_age_column:Optional[str]='MIN_AGE', max_age_column:Optional[str]='MAX_AGE'): """ Initialiser for age based sorter """ From a46bef487d47eb3fa345566e4c45afb1b9ca5194 Mon Sep 17 00:00:00 2001 From: lachlangrose Date: Mon, 1 Dec 2025 15:12:12 +1100 Subject: [PATCH 16/16] fix: wrong defaults. --- map2loop/sorter.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/map2loop/sorter.py b/map2loop/sorter.py index e01b4fc5..7d1a7672 100644 --- a/map2loop/sorter.py +++ b/map2loop/sorter.py @@ -167,7 +167,7 @@ class SorterAgeBased(Sorter): Sorter class which returns a sorted list of units based on the min and max ages of the units """ required_arguments = ['min_age_column','max_age_column'] - def __init__(self,*, min_age_column:Optional[str]='MIN_AGE', max_age_column:Optional[str]='MAX_AGE'): + def __init__(self,*, min_age_column:Optional[str]='minAge', max_age_column:Optional[str]='maxAge'): """ Initialiser for age based sorter """ @@ -195,6 +195,8 @@ def sort(self, units: pandas.DataFrame) -> list: lambda row: (row[self.min_age_column] + row[self.max_age_column]) / 2.0, axis=1 ) else: + logger.error(f"Columns {self.min_age_column} and {self.max_age_column} must be present in units DataFrame") + logger.error(f"Available columns are: {units.columns.tolist()}") raise ValueError(f"Columns {self.min_age_column} and {self.max_age_column} must be present in units DataFrame") if "group" in units.columns: sorted_units = sorted_units.sort_values(by=["group", "meanAge"])