@@ -38,13 +38,35 @@ class WarpPointsOutputSpec(TraitedSpec):
3838
3939class WarpPoints (BaseInterface ):
4040 """
41- Applies a warp field to a point set in vtk
41+ Applies a displacement field to a point set in vtk
42+
43+ Example
44+ -------
45+
46+ >>> from nipype.algorithms.mesh import WarpPoints
47+ >>> wp = mesh.P2PDistance()
48+ >>> wp.inputs.points = 'surf1.vtk'
49+ >>> wp.inputs.warp = 'warpfield.nii'
50+ >>> res = wp.run() # doctest: +SKIP
4251 """
4352 input_spec = WarpPointsInputSpec
4453 output_spec = WarpPointsOutputSpec
4554
46- def _overload_extension (self , value , name ):
47- return value + '.vtk'
55+ def _gen_fname (self , in_file , suffix = 'generated' , ext = None ):
56+ import os .path as op
57+
58+ fname , fext = op .splitext (op .basename (in_file ))
59+
60+ if fext == '.gz' :
61+ fname , fext2 = op .splitext (fname )
62+ fext = fext2 + fext
63+
64+ if ext is None :
65+ ext = fext
66+
67+ if ext [0 ] == '.' :
68+ ext = ext [1 :]
69+ return op .abspath ('%s_%s.%s' % (fname , suffix , ext ))
4870
4971 def _run_interface (self , runtime ):
5072 from tvtk .api import tvtk
@@ -60,30 +82,41 @@ def _run_interface(self, runtime):
6082
6183 affine = warp_dims [0 ].get_affine ()
6284 voxsize = warp_dims [0 ].get_header ().get_zooms ()
63- R = np .linalg .inv (affine [0 :3 ,0 :3 ])
85+ vox2ras = affine [0 :3 ,0 :3 ]
86+ ras2vox = np .linalg .inv (vox2ras )
6487 origin = affine [0 :3 ,3 ]
65- points = points - origin [np .newaxis ,:]
66- points = np . array ([ np . dot ( R , p ) for p in points ])
88+ voxpoints = np . array ( [np .dot ( ras2vox ,
89+ ( p - origin ) ) for p in points ])
6790
6891 warps = []
6992 for axis in warp_dims :
7093 wdata = axis .get_data ()
7194 if np .any (wdata != 0 ):
72- warps .append ([ndimage .map_coordinates (wdata , points .transpose ())])
95+
96+ warp = ndimage .map_coordinates (wdata ,
97+ voxpoints .transpose ())
7398 else :
74- warps .append ([np .zeros ((points .shape [0 ],))])
75- warps = np .squeeze (np .array (warps )).reshape (- 1 ,3 )
76- print warps .shape
77- print points .shape
99+ warp = np .zeros ((points .shape [0 ],))
100+
101+ warps .append (warp )
78102
79- newpoints = [ p + d for p ,d in zip (points , warps )]
103+ disps = np .squeeze (np .dstack (warps ))
104+ newpoints = [p + d for p ,d in zip (points , disps )]
80105 mesh .points = newpoints
81106 w = tvtk .PolyDataWriter (input = mesh )
82- w .file_name = self ._filename_from_source ('out_points' )
107+ w .file_name = self ._gen_fname (self .inputs .points ,
108+ suffix = 'warped' ,
109+ ext = '.vtk' )
83110 w .write ()
84-
85111 return runtime
86112
113+ def _list_outputs (self ):
114+ outputs = self ._outputs ().get ()
115+ outputs ['out_points' ] = self ._gen_fname (self .inputs .points ,
116+ suffix = 'warped' ,
117+ ext = '.vtk' )
118+ return outputs
119+
87120class P2PDistanceInputSpec (BaseInterfaceInputSpec ):
88121 surface1 = File (exists = True , mandatory = True ,
89122 desc = ("Reference surface (vtk format) to which compute "
0 commit comments