@@ -632,6 +632,81 @@ def predict_ms(
632632
633633 return predicted_flags
634634
def predict_ms_per_baseline(
    self,
    ms_path,
    num_antennas=None,
    patch_size=128,
    stretch="SQRT",
    save_flags=True,
    normalize_before_stretch=False,
    normalize_after_stretch=False,
    threshold=None,
    field_id=None,
):
    """
    Run RFI prediction over a measurement set, streaming baseline by baseline.

    Each baseline is loaded, passed through ``predict_array`` on its own and
    (optionally) written back before the next one is read, so only a single
    baseline's data needs to be held at a time.

    Args:
        ms_path: Path to measurement set
        num_antennas: Number of antennas used to build the baseline list
            (None = all)
        patch_size: Patch size for prediction
        stretch: Stretch function ('SQRT', 'LOG10', or None)
        save_flags: If True, write each baseline's flags back to the MS;
            if False the predicted flags are discarded
        normalize_before_stretch: Normalize before stretch
        normalize_after_stretch: Normalize after stretch
        threshold: Probability threshold, forwarded to ``predict_array``
        field_id: Optional FIELD_ID, forwarded to the loader calls

    Returns:
        None (flags saved to MS if save_flags=True)
    """
    logger.info(f"\n {'=' * 60} ")
    logger.info("RFI Prediction - Per Baseline")
    logger.info(f"{'=' * 60} ")

    # Check the preprocessing settings before opening the MS
    self._validate_preprocessing_params(
        patch_size, stretch, normalize_before_stretch, normalize_after_stretch
    )

    ms_loader = MSLoader(ms_path, field_id=field_id)
    pairs = ms_loader.get_baseline_pairs(num_antennas)

    logger.info(f"\n Processing {len(pairs)} baselines")
    logger.info(f" Patch size: {patch_size} ")
    logger.info(" Memory: ~1 baseline in RAM at a time")

    # These options are identical for every baseline, so assemble them once
    predict_options = {
        "patch_size": patch_size,
        "stretch": stretch,
        "normalize_before_stretch": normalize_before_stretch,
        "normalize_after_stretch": normalize_after_stretch,
        "return_probabilities": False,
        "threshold": threshold,
    }

    for first_ant, second_ant in tqdm(pairs, desc="Baselines"):
        visibilities = ms_loader.load_baseline(
            first_ant, second_ant, mode="DATA", field_id=field_id
        )

        # predict_array is fed a leading baseline axis:
        # (1, pols, channels, times)
        batch_flags = self.predict_array(
            visibilities[np.newaxis, ...], **predict_options
        )
        flags = batch_flags[0]  # drop the baseline axis again

        if not save_flags:
            continue
        ms_loader.save_baseline_flags(first_ant, second_ant, flags, field_id=field_id)

    logger.info(f"\n {'=' * 60} ")
    logger.info("✓ Prediction complete")
    logger.info(f"{'=' * 60} ")
709+
635710 def predict_iterative (
636711 self ,
637712 ms_path ,
@@ -761,6 +836,92 @@ def predict_iterative(
761836
762837 return cumulative_flags
763838
def predict_iterative_per_baseline(
    self,
    ms_path,
    num_iterations=3,
    num_antennas=None,
    patch_size=128,
    stretch="SQRT",
    save_flags=True,
    normalize_before_stretch=False,
    normalize_after_stretch=False,
    threshold=None,
    field_id=None,
):
    """
    Iterative prediction per baseline - low memory usage.

    For every baseline the data is loaded once and ``predict_array`` is run
    ``num_iterations`` times. Before each pass, samples flagged by earlier
    passes are replaced with NaN; the flags of all passes are OR-ed together
    and the combined result is written once per baseline.

    Args:
        ms_path: Path to measurement set
        num_iterations: Number of flagging passes (must be >= 1)
        num_antennas: Number of antennas (None = all)
        patch_size: Patch size for prediction
        stretch: Stretch function ('SQRT', 'LOG10', or None)
        save_flags: If True, save final flags to MS
        normalize_before_stretch: Normalize before stretch
        normalize_after_stretch: Normalize after stretch
        threshold: Probability threshold for RFI detection
        field_id: Optional FIELD_ID to load

    Returns:
        None (flags saved to MS if save_flags=True)

    Raises:
        ValueError: If num_iterations is less than 1.
    """
    # With zero (or negative) passes the inner loop would never run and an
    # all-False flag array would still be written for every baseline, so
    # reject that before touching the MS.
    if num_iterations < 1:
        raise ValueError(f"num_iterations must be >= 1, got {num_iterations}")

    logger.info(f"\n {'=' * 60} ")
    logger.info(f"RFI Prediction - Iterative Per Baseline ({num_iterations} passes)")
    logger.info(f"{'=' * 60} ")

    # Validate preprocessing parameters
    self._validate_preprocessing_params(
        patch_size, stretch, normalize_before_stretch, normalize_after_stretch
    )

    # Open MS
    loader = MSLoader(ms_path, field_id=field_id)
    baseline_pairs = loader.get_baseline_pairs(num_antennas)

    logger.info(f"\n Processing {len(baseline_pairs)} baselines")
    logger.info(f" Iterations: {num_iterations} ")
    logger.info(" Memory: ~1 baseline in RAM at a time")

    # Prediction options do not change between baselines or passes
    predict_kwargs = {
        "patch_size": patch_size,
        "stretch": stretch,
        "normalize_before_stretch": normalize_before_stretch,
        "normalize_after_stretch": normalize_after_stretch,
        "return_probabilities": False,
        "threshold": threshold,
    }

    # Process each baseline
    for ant1, ant2 in tqdm(baseline_pairs, desc="Baselines"):
        # Load original data
        original_data = loader.load_baseline(ant1, ant2, mode="DATA", field_id=field_id)

        # Initialize cumulative flags for this baseline
        cumulative_flags = np.zeros(original_data.shape, dtype=bool)

        # Iterative flagging for this baseline
        for _iteration in range(num_iterations):
            # Mask previously flagged data with NaN; the original data is
            # left untouched so every pass masks from the same source
            masked_data = np.where(cumulative_flags, np.nan, original_data)

            # Add the leading baseline dimension for predict_array and
            # strip it from the result
            iteration_flags = self.predict_array(
                masked_data[np.newaxis, ...], **predict_kwargs
            )[0]

            # Combine flags
            cumulative_flags = cumulative_flags | iteration_flags

        # Write final flags for this baseline
        if save_flags:
            loader.save_baseline_flags(ant1, ant2, cumulative_flags, field_id=field_id)

    logger.info(f"\n {'=' * 60} ")
    logger.info("✓ Iterative prediction complete")
    logger.info(f"{'=' * 60} ")
924+
764925 def _predict_dataset (
765926 self , dataset , target_size = None , return_probabilities = False , threshold = None
766927 ):
0 commit comments