Skip to content

Commit e655a5a

Browse files
committed
[ENH] Refactoring of nipype.interfaces.utility
Split utility.py into a module
1 parent 7d96581 commit e655a5a

File tree

4 files changed

+312
-256
lines changed

4 files changed

+312
-256
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# -*- coding: utf-8 -*-
2+
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
3+
# vi: set ft=python sts=4 ts=4 sw=4 et:
4+
"""
5+
Package contains interfaces for using existing functionality in other packages
6+
7+
Requires Packages to be installed
8+
"""
9+
from __future__ import print_function, division, unicode_literals, absolute_import
10+
__docformat__ = 'restructuredtext'
11+
12+
from .base import (IdentityInterface, Rename, Select, Split, Merge,
13+
AssertEqual)
14+
from .csv import CSVReader
15+
from .wrappers import Function
Lines changed: 7 additions & 256 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
>>> os.chdir(datadir)
1111
"""
1212
from __future__ import print_function, division, unicode_literals, absolute_import
13-
from builtins import zip, range, str, open
13+
from builtins import range
1414

1515
from future import standard_library
1616
standard_library.install_aliases()
@@ -20,22 +20,12 @@
2020
import numpy as np
2121
import nibabel as nb
2222

23-
from nipype import logging
24-
from .base import (traits, TraitedSpec, DynamicTraitedSpec, File,
25-
Undefined, isdefined, OutputMultiPath, runtime_profile,
26-
InputMultiPath, BaseInterface, BaseInterfaceInputSpec)
27-
from .io import IOBase, add_traits
28-
from ..utils.filemanip import (filename_to_list, copyfile, split_filename)
29-
from ..utils.misc import getsource, create_function_from_source
30-
31-
logger = logging.getLogger('interface')
32-
if runtime_profile:
33-
try:
34-
import psutil
35-
except ImportError as exc:
36-
logger.info('Unable to import packages needed for runtime profiling. '\
37-
'Turning off runtime profiler. Reason: %s' % exc)
38-
runtime_profile = False
23+
from ..base import (traits, TraitedSpec, DynamicTraitedSpec, File,
24+
Undefined, isdefined, OutputMultiPath, InputMultiPath,
25+
BaseInterface, BaseInterfaceInputSpec)
26+
from ..io import IOBase, add_traits
27+
from ...utils.filemanip import filename_to_list, copyfile, split_filename
28+
3929

4030
class IdentityInterface(IOBase):
4131
"""Basic interface class generates identity mappings
@@ -357,165 +347,6 @@ def _list_outputs(self):
357347
return outputs
358348

359349

360-
class FunctionInputSpec(DynamicTraitedSpec, BaseInterfaceInputSpec):
361-
function_str = traits.Str(mandatory=True, desc='code for function')
362-
363-
364-
class Function(IOBase):
365-
"""Runs arbitrary function as an interface
366-
367-
Examples
368-
--------
369-
370-
>>> func = 'def func(arg1, arg2=5): return arg1 + arg2'
371-
>>> fi = Function(input_names=['arg1', 'arg2'], output_names=['out'])
372-
>>> fi.inputs.function_str = func
373-
>>> res = fi.run(arg1=1)
374-
>>> res.outputs.out
375-
6
376-
377-
"""
378-
379-
input_spec = FunctionInputSpec
380-
output_spec = DynamicTraitedSpec
381-
382-
def __init__(self, input_names, output_names, function=None, imports=None,
383-
**inputs):
384-
"""
385-
386-
Parameters
387-
----------
388-
389-
input_names: single str or list
390-
names corresponding to function inputs
391-
output_names: single str or list
392-
names corresponding to function outputs.
393-
has to match the number of outputs
394-
function : callable
395-
callable python object. must be able to execute in an
396-
isolated namespace (possibly in concert with the ``imports``
397-
parameter)
398-
imports : list of strings
399-
list of import statements that allow the function to execute
400-
in an otherwise empty namespace
401-
"""
402-
403-
super(Function, self).__init__(**inputs)
404-
if function:
405-
if hasattr(function, '__call__'):
406-
try:
407-
self.inputs.function_str = getsource(function)
408-
except IOError:
409-
raise Exception('Interface Function does not accept '
410-
'function objects defined interactively '
411-
'in a python session')
412-
elif isinstance(function, (str, bytes)):
413-
self.inputs.function_str = function
414-
else:
415-
raise Exception('Unknown type of function')
416-
self.inputs.on_trait_change(self._set_function_string,
417-
'function_str')
418-
self._input_names = filename_to_list(input_names)
419-
self._output_names = filename_to_list(output_names)
420-
add_traits(self.inputs, [name for name in self._input_names])
421-
self.imports = imports
422-
self._out = {}
423-
for name in self._output_names:
424-
self._out[name] = None
425-
426-
def _set_function_string(self, obj, name, old, new):
427-
if name == 'function_str':
428-
if hasattr(new, '__call__'):
429-
function_source = getsource(new)
430-
elif isinstance(new, (str, bytes)):
431-
function_source = new
432-
self.inputs.trait_set(trait_change_notify=False,
433-
**{'%s' % name: function_source})
434-
435-
def _add_output_traits(self, base):
436-
undefined_traits = {}
437-
for key in self._output_names:
438-
base.add_trait(key, traits.Any)
439-
undefined_traits[key] = Undefined
440-
base.trait_set(trait_change_notify=False, **undefined_traits)
441-
return base
442-
443-
def _run_interface(self, runtime):
444-
# Get workflow logger for runtime profile error reporting
445-
from nipype import logging
446-
logger = logging.getLogger('workflow')
447-
448-
# Create function handle
449-
function_handle = create_function_from_source(self.inputs.function_str,
450-
self.imports)
451-
452-
# Wrapper for running function handle in multiprocessing.Process
453-
# Can catch exceptions and report output via multiprocessing.Queue
454-
def _function_handle_wrapper(queue, **kwargs):
455-
try:
456-
out = function_handle(**kwargs)
457-
queue.put(out)
458-
except Exception as exc:
459-
queue.put(exc)
460-
461-
# Get function args
462-
args = {}
463-
for name in self._input_names:
464-
value = getattr(self.inputs, name)
465-
if isdefined(value):
466-
args[name] = value
467-
468-
# Profile resources if set
469-
if runtime_profile:
470-
from nipype.interfaces.base import get_max_resources_used
471-
import multiprocessing
472-
# Init communication queue and proc objs
473-
queue = multiprocessing.Queue()
474-
proc = multiprocessing.Process(target=_function_handle_wrapper,
475-
args=(queue,), kwargs=args)
476-
477-
# Init memory and threads before profiling
478-
mem_mb = 0
479-
num_threads = 0
480-
481-
# Start process and profile while it's alive
482-
proc.start()
483-
while proc.is_alive():
484-
mem_mb, num_threads = \
485-
get_max_resources_used(proc.pid, mem_mb, num_threads,
486-
pyfunc=True)
487-
488-
# Get result from process queue
489-
out = queue.get()
490-
# If it is an exception, raise it
491-
if isinstance(out, Exception):
492-
raise out
493-
494-
# Function ran successfully, populate runtime stats
495-
setattr(runtime, 'runtime_memory_gb', mem_mb / 1024.0)
496-
setattr(runtime, 'runtime_threads', num_threads)
497-
else:
498-
out = function_handle(**args)
499-
500-
if len(self._output_names) == 1:
501-
self._out[self._output_names[0]] = out
502-
else:
503-
if isinstance(out, tuple) and (len(out) != len(self._output_names)):
504-
raise RuntimeError('Mismatch in number of expected outputs')
505-
506-
else:
507-
for idx, name in enumerate(self._output_names):
508-
self._out[name] = out[idx]
509-
510-
return runtime
511-
512-
def _list_outputs(self):
513-
outputs = self._outputs().get()
514-
for key in self._output_names:
515-
outputs[key] = self._out[key]
516-
return outputs
517-
518-
519350
class AssertEqualInputSpec(BaseInterfaceInputSpec):
520351
volume1 = File(exists=True, mandatory=True)
521352
volume2 = File(exists=True, mandatory=True)
@@ -532,83 +363,3 @@ def _run_interface(self, runtime):
532363
if not np.all(data1 == data2):
533364
raise RuntimeError('Input images are not exactly equal')
534365
return runtime
535-
536-
537-
class CSVReaderInputSpec(DynamicTraitedSpec, TraitedSpec):
538-
in_file = File(exists=True, mandatory=True, desc='Input comma-seperated value (CSV) file')
539-
header = traits.Bool(False, usedefault=True, desc='True if the first line is a column header')
540-
541-
542-
class CSVReader(BaseInterface):
543-
"""
544-
Examples
545-
--------
546-
547-
>>> reader = CSVReader() # doctest: +SKIP
548-
>>> reader.inputs.in_file = 'noHeader.csv' # doctest: +SKIP
549-
>>> out = reader.run() # doctest: +SKIP
550-
>>> out.outputs.column_0 == ['foo', 'bar', 'baz'] # doctest: +SKIP
551-
True
552-
>>> out.outputs.column_1 == ['hello', 'world', 'goodbye'] # doctest: +SKIP
553-
True
554-
>>> out.outputs.column_2 == ['300.1', '5', '0.3'] # doctest: +SKIP
555-
True
556-
557-
>>> reader = CSVReader() # doctest: +SKIP
558-
>>> reader.inputs.in_file = 'header.csv' # doctest: +SKIP
559-
>>> reader.inputs.header = True # doctest: +SKIP
560-
>>> out = reader.run() # doctest: +SKIP
561-
>>> out.outputs.files == ['foo', 'bar', 'baz'] # doctest: +SKIP
562-
True
563-
>>> out.outputs.labels == ['hello', 'world', 'goodbye'] # doctest: +SKIP
564-
True
565-
>>> out.outputs.erosion == ['300.1', '5', '0.3'] # doctest: +SKIP
566-
True
567-
568-
"""
569-
input_spec = CSVReaderInputSpec
570-
output_spec = DynamicTraitedSpec
571-
_always_run = True
572-
573-
def _append_entry(self, outputs, entry):
574-
for key, value in zip(self._outfields, entry):
575-
outputs[key].append(value)
576-
return outputs
577-
578-
def _parse_line(self, line):
579-
line = line.replace('\n', '')
580-
entry = [x.strip() for x in line.split(',')]
581-
return entry
582-
583-
def _get_outfields(self):
584-
with open(self.inputs.in_file, 'r') as fid:
585-
entry = self._parse_line(fid.readline())
586-
if self.inputs.header:
587-
self._outfields = tuple(entry)
588-
else:
589-
self._outfields = tuple(['column_' + str(x) for x in range(len(entry))])
590-
return self._outfields
591-
592-
def _run_interface(self, runtime):
593-
self._get_outfields()
594-
return runtime
595-
596-
def _outputs(self):
597-
return self._add_output_traits(super(CSVReader, self)._outputs())
598-
599-
def _add_output_traits(self, base):
600-
return add_traits(base, self._get_outfields())
601-
602-
def _list_outputs(self):
603-
outputs = self.output_spec().get()
604-
isHeader = True
605-
for key in self._outfields:
606-
outputs[key] = [] # initialize outfields
607-
with open(self.inputs.in_file, 'r') as fid:
608-
for line in fid.readlines():
609-
if self.inputs.header and isHeader: # skip header line
610-
isHeader = False
611-
continue
612-
entry = self._parse_line(line)
613-
outputs = self._append_entry(outputs, entry)
614-
return outputs

0 commit comments

Comments
 (0)