11"""Module for the Condition interface."""
22
3- from abc import ABCMeta
4- from torch_geometric .data import Data
5- from pina ._src .core .label_tensor import LabelTensor
6- from pina ._src .core .graph import Graph
3+ from abc import ABCMeta , abstractmethod
74
85
96class ConditionInterface (metaclass = ABCMeta ):
@@ -15,112 +12,46 @@ class ConditionInterface(metaclass=ABCMeta):
1512 description of all available conditions and how to instantiate them.
1613 """
1714
18- def __init__ (self ):
15+ @abstractmethod
16+ def __init__ (self , ** kwargs ):
1917 """
2018 Initialization of the :class:`ConditionInterface` class.
2119 """
22- self ._problem = None
2320
2421 @property
22+ @abstractmethod
2523 def problem (self ):
2624 """
2725 Return the problem associated with this condition.
2826
2927 :return: Problem associated with this condition.
3028 :rtype: ~pina.problem.abstract_problem.AbstractProblem
3129 """
32- return self ._problem
3330
3431 @problem .setter
32+ @abstractmethod
3533 def problem (self , value ):
3634 """
3735 Set the problem associated with this condition.
3836
3937 :param pina.problem.abstract_problem.AbstractProblem value: The problem
4038 to associate with this condition
4139 """
42- self ._problem = value
4340
44- @staticmethod
45- def _check_graph_list_consistency ( data_list ):
41+ @abstractmethod
42+ def __len__ ( self ):
4643 """
47- Check the consistency of the list of Data | Graph objects.
48- The following checks are performed:
44+ Return the number of data points in the condition.
4945
50- - All elements in the list must be of the same type (either
51- :class:`~torch_geometric.data.Data` or :class:`~pina.graph.Graph`).
52-
53- - All elements in the list must have the same keys.
54-
55- - The data type of each tensor must be consistent across all elements.
56-
57- - If a tensor is a :class:`~pina.label_tensor.LabelTensor`, its labels
58- must also be consistent across all elements.
59-
60- :param data_list: The list of Data | Graph objects to check.
61- :type data_list: list[Data] | list[Graph] | tuple[Data] | tuple[Graph]
62- :raises ValueError: If the input types are invalid.
63- :raises ValueError: If all elements in the list do not have the same
64- keys.
65- :raises ValueError: If the type of each tensor is not consistent across
66- all elements in the list.
67- :raises ValueError: If the labels of the LabelTensors are not consistent
68- across all elements in the list.
46+ :return: Number of data points.
47+ :rtype: int
6948 """
70- # If the data is a Graph or Data object, perform no checks
71- if isinstance (data_list , (Graph , Data )):
72- return
73-
74- # Check all elements in the list are of the same type
75- if not all (isinstance (i , (Graph , Data )) for i in data_list ):
76- raise ValueError (
77- "Invalid input. Please, provide either Data or Graph objects."
78- )
79-
80- # Store the keys, data types and labels of the first element
81- data = data_list [0 ]
82- keys = sorted (list (data .keys ()))
83- data_types = {name : tensor .__class__ for name , tensor in data .items ()}
84- labels = {
85- name : tensor .labels
86- for name , tensor in data .items ()
87- if isinstance (tensor , LabelTensor )
88- }
89-
90- # Iterate over the list of Data | Graph objects
91- for data in data_list [1 :]:
92-
93- # Check that all elements in the list have the same keys
94- if sorted (list (data .keys ())) != keys :
95- raise ValueError (
96- "All elements in the list must have the same keys."
97- )
98-
99- # Iterate over the tensors in the current element
100- for name , tensor in data .items ():
101- # Check that the type of each tensor is consistent
102- if tensor .__class__ is not data_types [name ]:
103- raise ValueError (
104- f"Data { name } must be a { data_types [name ]} , got "
105- f"{ tensor .__class__ } "
106- )
107-
108- # Check that the labels of each LabelTensor are consistent
109- if isinstance (tensor , LabelTensor ):
110- if tensor .labels != labels [name ]:
111- raise ValueError (
112- "LabelTensor must have the same labels"
113- )
11449
115- def __getattribute__ (self , name ):
50+ @abstractmethod
51+ def __getitem__ (self , idx ):
11652 """
117- Get an attribute from the object .
53+ Return the data point(s) at the specified index .
11854
119- :param str name: The name of the attribute to get.
120- :return: The requested attribute.
121- :rtype: Any
55+ :param int idx: Index of the data point(s) to retrieve.
56+ :return: Data point(s) at the specified index.
12257 """
123- to_return = super ().__getattribute__ (name )
124- if isinstance (to_return , (Graph , Data )):
125- to_return = [to_return ]
126- return to_return
0 commit comments