Skip to content

Commit 28589ba

Browse files
committed
Allow list of str for algo categories as well
1 parent aa2898c commit 28589ba

File tree

3 files changed

+37
-13
lines changed

3 files changed

+37
-13
lines changed

graphdatascience/session/dedicated_sessions.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,17 @@ def estimate(
2929
self,
3030
node_count: int,
3131
relationship_count: int,
32-
algorithm_categories: list[AlgorithmCategory] | None = None,
32+
algorithm_categories: list[AlgorithmCategory] | list[str] | None = None,
3333
node_label_count: int = 0,
3434
node_property_count: int = 0,
3535
relationship_property_count: int = 0,
3636
) -> SessionMemory:
3737
if algorithm_categories is None:
3838
algorithm_categories = []
39+
else:
40+
algorithm_categories = [
41+
AlgorithmCategory(cat) if isinstance(cat, str) else cat for cat in algorithm_categories
42+
]
3943
estimation = self._aura_api.estimate_size(
4044
node_count=node_count,
4145
node_label_count=node_label_count,

graphdatascience/session/gds_sessions.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,10 @@ def __init__(self, api_credentials: AuraAPICredentials) -> None:
5555
"""
5656
Initializes a new instance of the GdsSessions class.
5757
58-
Args:
59-
api_credentials (AuraAPICredentials): The Aura API credentials used for establishing a connection.
58+
Parameters
59+
----------
60+
api_credentials
61+
The Aura API credentials used for establishing a connection.
6062
"""
6163
aura_env = os.environ.get("AURA_ENV")
6264
aura_api = AuraApi(
@@ -71,23 +73,34 @@ def estimate(
7173
self,
7274
node_count: int,
7375
relationship_count: int,
74-
algorithm_categories: Optional[list[AlgorithmCategory]] = None,
76+
algorithm_categories: list[AlgorithmCategory] | list[str] | None = None,
7577
node_label_count: int = 0,
7678
node_property_count: int = 0,
7779
relationship_property_count: int = 0,
7880
) -> SessionMemory:
7981
"""
8082
Estimates the memory required for a session with the given node and relationship counts.
8183
82-
Args:
83-
node_count (int): The number of nodes.
84-
relationship_count (int): The number of relationships.
85-
algorithm_categories (Optional[list[AlgorithmCategory]]): The algorithm categories to consider.
86-
node_label_count (int): The number of node labels.
87-
node_property_count (int): The number of node properties.
88-
relationship_property_count (int): The number of relationship properties.
89-
Returns:
90-
SessionMemory: The estimated memory required for the session.
84+
Parameters
85+
----------
86+
node_count
87+
Number of nodes.
88+
relationship_count
89+
Number of relationships.
90+
algorithm_categories
91+
The algorithm categories to consider.
92+
node_label_count
93+
Number of node labels.
94+
node_property_count
95+
Number of node properties.
96+
relationship_property_count
97+
Number of relationship properties.
98+
99+
100+
Returns
101+
-------
102+
SessionMemory
103+
The estimated memory required for the session.
91104
"""
92105
if algorithm_categories is None:
93106
algorithm_categories = []

graphdatascience/tests/unit/test_dedicated_sessions.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -779,6 +779,13 @@ def test_estimate_size() -> None:
779779
assert sessions.estimate(1, 1, [AlgorithmCategory.CENTRALITY]) == SessionMemory.m_8GB
780780

781781

782+
def test_estimate_str_categories_size() -> None:
783+
aura_api = FakeAuraApi(size_estimation=EstimationDetails("1GB", "8GB"))
784+
sessions = DedicatedSessions(aura_api)
785+
786+
assert sessions.estimate(1, 1, ["centrality"]) == SessionMemory.m_8GB
787+
788+
782789
def test_estimate_size_exceeds() -> None:
783790
aura_api = FakeAuraApi(size_estimation=EstimationDetails("16GB", "8GB"))
784791
sessions = DedicatedSessions(aura_api)

0 commit comments

Comments
 (0)