@@ -600,12 +600,13 @@ def symmetrize_wavefunctions(self,m):
600600 raise ImportError ("numpy is not available." )
601601 csize = len (self ._basis_functions )
602602 (d1 ,d2 ) = m .shape
603- if not (d1 == csize and d2 == csize and m .dtype == np .float64 ):
604- raise ValueError ("Must provide a " + str (csize ) + "x" + str (csize ) + " numpy.float64 array" )
603+ if not (d1 == csize and d2 == csize ):
604+ raise ValueError ("Must provide a " + str (csize ) + "x" + str (csize ))
605+ wf = np .ascontiguousarray (m , dtype = np .float64 )
605606 partner_functions = (PartnerFunction * csize )()
606607 species = np .zeros ((csize ),dtype = np .int32 )
607- self ._assert_success (_lib .msymSymmetrizeWavefunctions (self ._ctx ,csize ,m ,species ,partner_functions ))
608- return (m , species , partner_functions [0 :csize ])
608+ self ._assert_success (_lib .msymSymmetrizeWavefunctions (self ._ctx ,csize ,wf ,species ,partner_functions ))
609+ return (wf , species , partner_functions [0 :csize ])
609610
610611 def generate_elements (self , elements ):
611612 if not self ._ctx :
@@ -625,10 +626,11 @@ def generate_elements(self, elements):
625626 def symmetry_species_components (self , wf ):
626627
627628 wf_size = len (wf )
628- if not ( wf_size == len (self .basis_functions ) and wf . dtype == np . float64 ):
629- raise ValueError ("Must provide a numpy.float64 array of length " + str (len (self .basis_functions )))
629+ if not wf_size == len (self .basis_functions ):
630+ raise ValueError ("Must provide an array of length " + str (len (self .basis_functions )))
630631 species_size = self .character_table ._d
631632 species = np .zeros ((species_size ),dtype = np .float64 )
633+ wf = np .ascontiguousarray (wf , dtype = np .float64 )
632634 self ._assert_success (_lib .msymSymmetrySpeciesComponents (self ._ctx , wf_size , wf , species_size , species ))
633635 return species
634636
0 commit comments