2020
2121from .. import config , logging
2222
23- from ..interfaces .base import (BaseInterface , traits , TraitedSpec , File ,
24- InputMultiPath , BaseInterfaceInputSpec ,
25- isdefined )
23+ from ..interfaces .base import (
24+ SimpleInterface , BaseInterface , traits , TraitedSpec , File ,
25+ InputMultiPath , BaseInterfaceInputSpec ,
26+ isdefined )
2627from ..interfaces .nipy .base import NipyBaseInterface
27- from ..utils import NUMPY_MMAP
2828
2929iflogger = logging .getLogger ('interface' )
3030
@@ -383,6 +383,7 @@ class FuzzyOverlapInputSpec(BaseInterfaceInputSpec):
383383 File (exists = True ),
384384 mandatory = True ,
385385 desc = 'Test image. Requires the same dimensions as in_ref.' )
386+ in_mask = File (exists = True , desc = 'calculate overlap only within mask' )
386387 weighting = traits .Enum (
387388 'none' ,
388389 'volume' ,
@@ -403,10 +404,6 @@ class FuzzyOverlapInputSpec(BaseInterfaceInputSpec):
403404class FuzzyOverlapOutputSpec (TraitedSpec ):
404405 jaccard = traits .Float (desc = 'Fuzzy Jaccard Index (fJI), all the classes' )
405406 dice = traits .Float (desc = 'Fuzzy Dice Index (fDI), all the classes' )
406- diff_file = File (
407- exists = True ,
408- desc =
409- 'resulting difference-map of all classes, using the chosen weighting' )
410407 class_fji = traits .List (
411408 traits .Float (),
412409 desc = 'Array containing the fJIs of each computed class' )
@@ -415,7 +412,7 @@ class FuzzyOverlapOutputSpec(TraitedSpec):
415412 desc = 'Array containing the fDIs of each computed class' )
416413
417414
418- class FuzzyOverlap (BaseInterface ):
415+ class FuzzyOverlap (SimpleInterface ):
419416 """Calculates various overlap measures between two maps, using the fuzzy
420417 definition proposed in: Crum et al., Generalized Overlap Measures for
421418 Evaluation and Validation in Medical Image Analysis, IEEE Trans. Med.
@@ -439,77 +436,75 @@ class FuzzyOverlap(BaseInterface):
439436 output_spec = FuzzyOverlapOutputSpec
440437
441438 def _run_interface (self , runtime ):
442- ncomp = len (self .inputs .in_ref )
443- assert (ncomp == len (self .inputs .in_tst ))
444- weights = np .ones (shape = ncomp )
445-
446- img_ref = np .array ([
447- nb .load (fname , mmap = NUMPY_MMAP ).get_data ()
448- for fname in self .inputs .in_ref
449- ])
450- img_tst = np .array ([
451- nb .load (fname , mmap = NUMPY_MMAP ).get_data ()
452- for fname in self .inputs .in_tst
453- ])
454-
455- msk = np .sum (img_ref , axis = 0 )
456- msk [msk > 0 ] = 1.0
457- tst_msk = np .sum (img_tst , axis = 0 )
458- tst_msk [tst_msk > 0 ] = 1.0
459-
460- # check that volumes are normalized
461- # img_ref[:][msk>0] = img_ref[:][msk>0] / (np.sum( img_ref, axis=0 ))[msk>0]
462- # img_tst[tst_msk>0] = img_tst[tst_msk>0] / np.sum( img_tst, axis=0 )[tst_msk>0]
463-
464- self ._jaccards = []
465- volumes = []
466-
467- diff_im = np .zeros (img_ref .shape )
468-
469- for ref_comp , tst_comp , diff_comp in zip (img_ref , img_tst , diff_im ):
470- num = np .minimum (ref_comp , tst_comp )
471- ddr = np .maximum (ref_comp , tst_comp )
472- diff_comp [ddr > 0 ] += 1.0 - (num [ddr > 0 ] / ddr [ddr > 0 ])
473- self ._jaccards .append (np .sum (num ) / np .sum (ddr ))
474- volumes .append (np .sum (ref_comp ))
475-
476- self ._dices = 2.0 * (np .array (self ._jaccards ) /
477- (np .array (self ._jaccards ) + 1.0 ))
439+ # Load data
440+ refdata = nb .concat_images (self .inputs .in_ref ).get_data ()
441+ tstdata = nb .concat_images (self .inputs .in_tst ).get_data ()
442+
443+ # Data must have same shape
444+ if not refdata .shape == tstdata .shape :
445+ raise RuntimeError (
446+ 'Size of "in_tst" %s must match that of "in_ref" %s.' %
447+ (tstdata .shape , refdata .shape ))
448+
449+ ncomp = refdata .shape [- 1 ]
478450
451+ # Load mask
452+ mask = np .ones_like (refdata , dtype = bool )
453+ if isdefined (self .inputs .in_mask ):
454+ mask = nb .load (self .inputs .in_mask ).get_data ()
455+ mask = mask > 0
456+ mask = np .repeat (mask [..., np .newaxis ], ncomp , - 1 )
457+ assert mask .shape == refdata .shape
458+
459+ # Drop data outside mask
460+ refdata = refdata [mask ]
461+ tstdata = tstdata [mask ]
462+
463+ if np .any (refdata < 0.0 ):
464+ iflogger .warning ('Negative values encountered in "in_ref" input, '
465+ 'taking absolute values.' )
466+ refdata = np .abs (refdata )
467+
468+ if np .any (tstdata < 0.0 ):
469+ iflogger .warning ('Negative values encountered in "in_tst" input, '
470+ 'taking absolute values.' )
471+ tstdata = np .abs (tstdata )
472+
473+ if np .any (refdata > 1.0 ):
474+ iflogger .warning ('Values greater than 1.0 found in "in_ref" input, '
475+ 'scaling values.' )
476+ refdata /= refdata .max ()
477+
478+ if np .any (tstdata > 1.0 ):
479+ iflogger .warning ('Values greater than 1.0 found in "in_tst" input, '
480+ 'scaling values.' )
481+ tstdata /= tstdata .max ()
482+
483+ numerators = np .atleast_2d (
484+ np .minimum (refdata , tstdata ).reshape ((- 1 , ncomp )))
485+ denominators = np .atleast_2d (
486+ np .maximum (refdata , tstdata ).reshape ((- 1 , ncomp )))
487+
488+ jaccards = numerators .sum (axis = 0 ) / denominators .sum (axis = 0 )
489+
490+ # Calculate weights
491+ weights = np .ones_like (jaccards , dtype = float )
479492 if self .inputs .weighting != "none" :
480- weights = 1.0 / np .array (volumes )
493+ volumes = np .sum ((refdata + tstdata ) > 0 , axis = 1 ).reshape ((- 1 , ncomp ))
494+ weights = 1.0 / volumes
481495 if self .inputs .weighting == "squared_vol" :
482496 weights = weights ** 2
483497
484498 weights = weights / np .sum (weights )
499+ dices = 2.0 * jaccards / (jaccards + 1.0 )
485500
486- setattr (self , '_jaccard' , np .sum (weights * self ._jaccards ))
487- setattr (self , '_dice' , np .sum (weights * self ._dices ))
488-
489- diff = np .zeros (diff_im [0 ].shape )
490-
491- for w , ch in zip (weights , diff_im ):
492- ch [msk == 0 ] = 0
493- diff += w * ch
494-
495- nb .save (
496- nb .Nifti1Image (diff ,
497- nb .load (self .inputs .in_ref [0 ]).affine ,
498- nb .load (self .inputs .in_ref [0 ]).header ),
499- self .inputs .out_file )
500-
501+ # Fill-in the results object
502+ self ._results ['jaccard' ] = float (weights .dot (jaccards ))
503+ self ._results ['dice' ] = float (weights .dot (dices ))
504+ self ._results ['class_fji' ] = [float (v ) for v in jaccards ]
505+ self ._results ['class_fdi' ] = [float (v ) for v in dices ]
501506 return runtime
502507
503- def _list_outputs (self ):
504- outputs = self ._outputs ().get ()
505- for method in ("dice" , "jaccard" ):
506- outputs [method ] = getattr (self , '_' + method )
507- # outputs['volume_difference'] = self._volume
508- outputs ['diff_file' ] = os .path .abspath (self .inputs .out_file )
509- outputs ['class_fji' ] = np .array (self ._jaccards ).astype (float ).tolist ()
510- outputs ['class_fdi' ] = self ._dices .astype (float ).tolist ()
511- return outputs
512-
513508
514509class ErrorMapInputSpec (BaseInterfaceInputSpec ):
515510 in_ref = File (
0 commit comments