@@ -119,13 +119,11 @@ void checkRelativeError( real64 const v1, real64 const v2, real64 const relTol,
119119 EXPECT_PRED_FORMAT4 ( checkRelativeErrorFormat, v1, v2, relTol, absTol );
120120}
121121
122- template < typename ROW_INDEX, typename COL_INDEX, typename VALUE >
123- void compareMatrixRow ( ROW_INDEX const rowNumber, VALUE const relTol, VALUE const absTol,
122+ template < typename COL_INDEX, typename VALUE >
123+ void compareMatrixRow ( VALUE const relTol, VALUE const absTol,
124124 localIndex const length1, COL_INDEX const * const indices1, VALUE const * const values1,
125125 localIndex const length2, COL_INDEX const * const indices2, VALUE const * const values2 )
126126{
127- SCOPED_TRACE ( " Row " + std::to_string ( rowNumber ));
128-
129127 EXPECT_EQ ( length1, length2 );
130128
131129 for ( localIndex j1 = 0 , j2 = 0 ; j1 < length1 && j2 < length2; ++j1, ++j2 )
@@ -150,17 +148,31 @@ void compareMatrixRow( ROW_INDEX const rowNumber, VALUE const relTol, VALUE cons
150148 }
151149}
152150
153- template < typename ROW_INDEX, typename COL_INDEX, typename VALUE >
154- void compareMatrixRow ( ROW_INDEX const rowNumber, VALUE const relTol, VALUE const absTol,
155- arraySlice1d< COL_INDEX const > indices1, arraySlice1d< VALUE const > values1,
156- arraySlice1d< COL_INDEX const > indices2, arraySlice1d< VALUE const > values2 )
151+ template < typename T, typename COL_INDEX >
152+ void compareLocalMatrices ( CRSMatrixView< T const , COL_INDEX const > const & matrix1,
153+ CRSMatrixView< T const , COL_INDEX const > const & matrix2,
154+ real64 const relTol = DEFAULT_REL_TOL,
155+ real64 const absTol = DEFAULT_ABS_TOL,
156+ globalIndex const rowOffset = 0 )
157157{
158- ASSERT_EQ ( indices1.size (), values1.size () );
159- ASSERT_EQ ( indices2.size (), values2.size () );
158+ ASSERT_EQ ( matrix1.numRows (), matrix2.numRows () );
159+ ASSERT_EQ ( matrix1.numColumns (), matrix2.numColumns () );
160+
161+ matrix1.move ( LvArray::MemorySpace::host, false );
162+ matrix2.move ( LvArray::MemorySpace::host, false );
160163
161- compareMatrixRow ( rowNumber, relTol, absTol,
162- indices1.size (), indices1.dataIfContiguous (), values1.dataIfContiguous (),
163- indices2.size (), indices2.dataIfContiguous (), values2.dataIfContiguous () );
164+ // check the accuracy across local rows
165+ for ( localIndex i = 0 ; i < matrix1.numRows (); ++i )
166+ {
167+ SCOPED_TRACE ( GEOS_FMT ( " Row {}" , i + rowOffset ) );
168+ compareMatrixRow ( relTol, absTol,
169+ matrix1.numNonZeros ( i ),
170+ matrix1.getColumns ( i ).dataIfContiguous (),
171+ matrix1.getEntries ( i ).dataIfContiguous (),
172+ matrix2.numNonZeros ( i ),
173+ matrix2.getColumns ( i ).dataIfContiguous (),
174+ matrix2.getEntries ( i ).dataIfContiguous () );
175+ }
164176}
165177
166178template < typename MATRIX >
@@ -175,45 +187,10 @@ void compareMatrices( MATRIX const & matrix1,
175187 ASSERT_EQ ( matrix1.numLocalRows (), matrix2.numLocalRows () );
176188 ASSERT_EQ ( matrix1.numLocalCols (), matrix2.numLocalCols () );
177189
178- array1d< globalIndex > indices1, indices2;
179- array1d< real64 > values1, values2;
180-
181- // check the accuracy across local rows
182- for ( globalIndex i = matrix1.ilower (); i < matrix1.iupper (); ++i )
183- {
184- indices1.resize ( matrix1.rowLength ( i ) );
185- values1.resize ( matrix1.rowLength ( i ) );
186- matrix1.getRowCopy ( i, indices1, values1 );
187-
188- indices2.resize ( matrix2.rowLength ( i ) );
189- values2.resize ( matrix2.rowLength ( i ) );
190- matrix2.getRowCopy ( i, indices2, values2 );
190+ CRSMatrix< real64, globalIndex > const mat1 = matrix1.extract ();
191+ CRSMatrix< real64, globalIndex > const mat2 = matrix2.extract ();
191192
192- compareMatrixRow ( i, relTol, absTol,
193- indices1.size (), indices1.data (), values1.data (),
194- indices2.size (), indices2.data (), values2.data () );
195- }
196- }
197-
198- template < typename T, typename COL_INDEX >
199- void compareLocalMatrices ( CRSMatrixView< T const , COL_INDEX const > const & matrix1,
200- CRSMatrixView< T const , COL_INDEX const > const & matrix2,
201- real64 const relTol = DEFAULT_REL_TOL,
202- real64 const absTol = DEFAULT_ABS_TOL )
203- {
204- ASSERT_EQ ( matrix1.numRows (), matrix2.numRows () );
205- ASSERT_EQ ( matrix1.numColumns (), matrix2.numColumns () );
206-
207- matrix1.move ( LvArray::MemorySpace::host, false );
208- matrix2.move ( LvArray::MemorySpace::host, false );
209-
210- // check the accuracy across local rows
211- for ( localIndex i = 0 ; i < matrix1.numRows (); ++i )
212- {
213- compareMatrixRow ( i, relTol, absTol,
214- matrix1.getColumns ( i ), matrix1.getEntries ( i ),
215- matrix2.getColumns ( i ), matrix2.getEntries ( i ) );
216- }
193+ compareLocalMatrices ( mat1.toViewConst (), mat2.toViewConst (), relTol, absTol, matrix1.ilower () );
217194}
218195
219196} // namespace testing
0 commit comments