Skip to content

Commit e144b48

Browse files
committed
Updating predictor and cli to work with baseline based iteration
1 parent f3d8205 commit e144b48

2 files changed

Lines changed: 227 additions & 30 deletions

File tree

src/samrfi/cli.py

Lines changed: 66 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -419,37 +419,68 @@ def predict_command(args):
419419
# Determine if iterative
420420
num_iterations = args.iterations if args.iterations else 1
421421
is_iterative = num_iterations > 1
422-
423-
if is_iterative:
424-
print(f"\nMode: Iterative flagging ({num_iterations} passes)")
425-
flags = predictor.predict_iterative(
426-
ms_path=args.input,
427-
num_iterations=num_iterations,
428-
num_antennas=args.num_antennas,
429-
patch_size=args.patch_size,
430-
stretch=stretch,
431-
save_flags=not args.no_save,
432-
apply_existing_flags=args.apply_existing,
433-
threshold=threshold,
434-
)
422+
per_baseline = args.per_baseline
423+
424+
if per_baseline:
425+
# Per-baseline mode (low memory)
426+
if is_iterative:
427+
print(f"\nMode: Iterative per-baseline ({num_iterations} passes, low memory)")
428+
predictor.predict_iterative_per_baseline(
429+
ms_path=args.input,
430+
num_iterations=num_iterations,
431+
num_antennas=args.num_antennas,
432+
patch_size=args.patch_size,
433+
stretch=stretch,
434+
save_flags=not args.no_save,
435+
threshold=threshold,
436+
)
437+
else:
438+
print("\nMode: Per-baseline flagging (low memory)")
439+
predictor.predict_ms_per_baseline(
440+
ms_path=args.input,
441+
num_antennas=args.num_antennas,
442+
patch_size=args.patch_size,
443+
stretch=stretch,
444+
save_flags=not args.no_save,
445+
threshold=threshold,
446+
)
447+
print("\n" + "=" * 60)
448+
print("Prediction Complete!")
449+
print("=" * 60)
450+
if not args.no_save:
451+
print(f"Flags saved to: {args.input}")
435452
else:
436-
print("\nMode: Single-pass flagging")
437-
flags = predictor.predict_ms(
438-
ms_path=args.input,
439-
num_antennas=args.num_antennas,
440-
patch_size=args.patch_size,
441-
stretch=stretch,
442-
apply_existing_flags=args.apply_existing,
443-
save_flags=not args.no_save,
444-
threshold=threshold,
445-
)
446-
447-
print("\n" + "=" * 60)
448-
print("Prediction Complete!")
449-
print("=" * 60)
450-
print(f"Total flagged: {flags.sum()/flags.size*100:.2f}%")
451-
if not args.no_save:
452-
print(f"Flags saved to: {args.input}")
453+
# Original mode (greedy)
454+
if is_iterative:
455+
print(f"\nMode: Iterative flagging ({num_iterations} passes)")
456+
flags = predictor.predict_iterative(
457+
ms_path=args.input,
458+
num_iterations=num_iterations,
459+
num_antennas=args.num_antennas,
460+
patch_size=args.patch_size,
461+
stretch=stretch,
462+
save_flags=not args.no_save,
463+
apply_existing_flags=args.apply_existing,
464+
threshold=threshold,
465+
)
466+
else:
467+
print("\nMode: Single-pass flagging")
468+
flags = predictor.predict_ms(
469+
ms_path=args.input,
470+
num_antennas=args.num_antennas,
471+
patch_size=args.patch_size,
472+
stretch=stretch,
473+
apply_existing_flags=args.apply_existing,
474+
save_flags=not args.no_save,
475+
threshold=threshold,
476+
)
477+
478+
print("\n" + "=" * 60)
479+
print("Prediction Complete!")
480+
print("=" * 60)
481+
print(f"Total flagged: {flags.sum()/flags.size*100:.2f}%")
482+
if not args.no_save:
483+
print(f"Flags saved to: {args.input}")
453484

454485

455486
def evaluate_command(args):
@@ -676,6 +707,11 @@ def main():
676707
predict_parser.add_argument(
677708
"--no-save", action="store_true", help="Do not save flags to MS (prediction only)"
678709
)
710+
predict_parser.add_argument(
711+
"--per-baseline",
712+
action="store_true",
713+
help="Process one baseline at a time (low memory usage)",
714+
)
679715

680716
# Evaluate parser
681717
evaluate_parser = subparsers.add_parser(

src/samrfi/inference/predictor.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,81 @@ def predict_ms(
632632

633633
return predicted_flags
634634

635+
def predict_ms_per_baseline(
    self,
    ms_path,
    num_antennas=None,
    patch_size=128,
    stretch="SQRT",
    save_flags=True,
    normalize_before_stretch=False,
    normalize_after_stretch=False,
    threshold=None,
    field_id=None,
):
    """
    Single-pass RFI prediction over a measurement set, one baseline at a time.

    Each baseline is loaded, predicted and (optionally) written back before
    the next baseline is loaded, so nothing is accumulated across baselines.

    Args:
        ms_path: Path to measurement set
        num_antennas: Number of antennas (None = all); forwarded to
            ``MSLoader.get_baseline_pairs``
        patch_size: Patch size for prediction
        stretch: Stretch function ('SQRT', 'LOG10', or None)
        save_flags: If True, write each baseline's flags back to the MS
        normalize_before_stretch: Normalize before stretch
        normalize_after_stretch: Normalize after stretch
        threshold: Probability threshold for RFI detection
        field_id: Optional FIELD_ID, passed to the loader and to every
            per-baseline load/save call

    Returns:
        None. Flags are only persisted through the loader when
        ``save_flags`` is True; otherwise the predictions are discarded.
    """
    logger.info(f"\n{'='*60}")
    logger.info("RFI Prediction - Per Baseline")
    logger.info(f"{'='*60}")

    # Reject an invalid preprocessing configuration before opening the MS.
    self._validate_preprocessing_params(
        patch_size, stretch, normalize_before_stretch, normalize_after_stretch
    )

    ms_loader = MSLoader(ms_path, field_id=field_id)
    antenna_pairs = ms_loader.get_baseline_pairs(num_antennas)

    logger.info(f"\nProcessing {len(antenna_pairs)} baselines")
    logger.info(f" Patch size: {patch_size}")
    logger.info(" Memory: ~1 baseline in RAM at a time")

    # These settings are identical for every baseline, so build them once.
    predict_kwargs = {
        "patch_size": patch_size,
        "stretch": stretch,
        "normalize_before_stretch": normalize_before_stretch,
        "normalize_after_stretch": normalize_after_stretch,
        "return_probabilities": False,
        "threshold": threshold,
    }

    for ant_a, ant_b in tqdm(antenna_pairs, desc="Baselines"):
        visibilities = ms_loader.load_baseline(ant_a, ant_b, mode="DATA", field_id=field_id)

        # predict_array is fed a leading baseline axis of length 1
        # (presumably (1, pols, channels, times) - confirm against the loader);
        # indexing [0] on the result strips that axis again.
        stacked_flags = self.predict_array(visibilities[np.newaxis, ...], **predict_kwargs)
        flags_for_pair = stacked_flags[0]

        if save_flags:
            ms_loader.save_baseline_flags(ant_a, ant_b, flags_for_pair, field_id=field_id)

    logger.info(f"\n{'='*60}")
    logger.info("✓ Prediction complete")
    logger.info(f"{'='*60}")
709+
635710
def predict_iterative(
636711
self,
637712
ms_path,
@@ -761,6 +836,92 @@ def predict_iterative(
761836

762837
return cumulative_flags
763838

839+
def predict_iterative_per_baseline(
    self,
    ms_path,
    num_iterations=3,
    num_antennas=None,
    patch_size=128,
    stretch="SQRT",
    save_flags=True,
    normalize_before_stretch=False,
    normalize_after_stretch=False,
    threshold=None,
    field_id=None,
):
    """
    Multi-pass RFI prediction over a measurement set, one baseline at a time.

    For every baseline the model is run ``num_iterations`` times. Before each
    pass, samples flagged by earlier passes are replaced with NaN in a copy of
    the baseline's data, and the new flags are OR-ed into a running mask.
    Only the final mask of each baseline is written back.

    Args:
        ms_path: Path to measurement set
        num_iterations: Number of flagging passes per baseline
        num_antennas: Number of antennas (None = all); forwarded to
            ``MSLoader.get_baseline_pairs``
        patch_size: Patch size for prediction
        stretch: Stretch function ('SQRT', 'LOG10', or None)
        save_flags: If True, write each baseline's final flags to the MS
        normalize_before_stretch: Normalize before stretch
        normalize_after_stretch: Normalize after stretch
        threshold: Probability threshold for RFI detection
        field_id: Optional FIELD_ID, passed to the loader and to every
            per-baseline load/save call

    Returns:
        None. Flags are only persisted through the loader when
        ``save_flags`` is True; otherwise the predictions are discarded.
    """
    logger.info(f"\n{'='*60}")
    logger.info(f"RFI Prediction - Iterative Per Baseline ({num_iterations} passes)")
    logger.info(f"{'='*60}")

    # Reject an invalid preprocessing configuration before opening the MS.
    self._validate_preprocessing_params(
        patch_size, stretch, normalize_before_stretch, normalize_after_stretch
    )

    ms_loader = MSLoader(ms_path, field_id=field_id)
    antenna_pairs = ms_loader.get_baseline_pairs(num_antennas)

    logger.info(f"\nProcessing {len(antenna_pairs)} baselines")
    logger.info(f" Iterations: {num_iterations}")
    logger.info(" Memory: ~1 baseline in RAM at a time")

    # These settings are identical for every pass of every baseline.
    predict_kwargs = {
        "patch_size": patch_size,
        "stretch": stretch,
        "normalize_before_stretch": normalize_before_stretch,
        "normalize_after_stretch": normalize_after_stretch,
        "return_probabilities": False,
        "threshold": threshold,
    }

    for ant_a, ant_b in tqdm(antenna_pairs, desc="Baselines"):
        raw_data = ms_loader.load_baseline(ant_a, ant_b, mode="DATA", field_id=field_id)

        # Running union of the flags found so far; starts with nothing flagged.
        flag_mask = np.zeros(raw_data.shape, dtype=bool)

        for _ in range(num_iterations):
            # Blank out already-flagged samples with NaN, then add the leading
            # length-1 baseline axis that predict_array is called with.
            model_input = np.where(flag_mask, np.nan, raw_data)[np.newaxis, ...]

            pass_flags = self.predict_array(model_input, **predict_kwargs)[0]

            # Flags only ever accumulate; a later pass cannot clear an earlier one.
            flag_mask = flag_mask | pass_flags

        if save_flags:
            ms_loader.save_baseline_flags(ant_a, ant_b, flag_mask, field_id=field_id)

    logger.info(f"\n{'='*60}")
    logger.info("✓ Iterative prediction complete")
    logger.info(f"{'='*60}")
924+
764925
def _predict_dataset(
765926
self, dataset, target_size=None, return_probabilities=False, threshold=None
766927
):

0 commit comments

Comments
 (0)