@@ -99,10 +99,13 @@ def __init__(
9999 # Move domain discretisation into conditions subsets
100100 self .problem .move_discretisation_into_conditions ()
101101
102- # Verify which splits are zero
103- self ._has_train = train_size > 0
104- self ._has_val = val_size > 0
105- self ._has_test = test_size > 0
102+ # If no splits are defined, use the default dataloaders
103+ if train_size == 0 :
104+ self .train_dataloader = super ().train_dataloader
105+ if val_size == 0 :
106+ self .val_dataloader = super ().val_dataloader
107+ if test_size == 0 :
108+ self .test_dataloader = super ().test_dataloader
106109
107110 # Otherwise, create the condition splits and initialize the creator
108111 self ._create_condition_splits (train_size , test_size )
@@ -244,14 +247,6 @@ def train_dataloader(self):
244247 dataloaders.
245248 :rtype: _Aggregator
246249 """
247- # If no training split is defined, return the default dataloader
248- if not self ._has_train :
249- return super ().train_dataloader ()
250-
251- # If the training dataloaders have not been created yet, call setup
252- if not hasattr (self , "train_datasets" ):
253- self .setup ("fit" )
254-
255250 return _Aggregator (
256251 self .creator (self .train_datasets ),
257252 batching_mode = self .batching_mode ,
@@ -265,14 +260,6 @@ def val_dataloader(self):
265260 dataloaders.
266261 :rtype: _Aggregator
267262 """
268- # If no validation split is defined, return the default dataloader
269- if not self ._has_val :
270- return super ().val_dataloader ()
271-
272- # If the validation dataloaders have not been created yet, call setup
273- if not hasattr (self , "val_datasets" ):
274- self .setup ("fit" )
275-
276263 return _Aggregator (
277264 self .creator (self .val_datasets ), batching_mode = self .batching_mode
278265 )
@@ -285,14 +272,6 @@ def test_dataloader(self):
285272 dataloaders.
286273 :rtype: _Aggregator
287274 """
288- # If no test split is defined, return the default dataloader
289- if not self ._has_test :
290- return super ().test_dataloader ()
291-
292- # If the test dataloaders have not been created yet, call setup
293- if not hasattr (self , "test_datasets" ):
294- self .setup ("test" )
295-
296275 return _Aggregator (
297276 self .creator (self .test_datasets ),
298277 batching_mode = self .batching_mode ,
0 commit comments