6060//! // `a` and `b` have been moved, no longer valid
6161//! ```
6262
63- use ndarray:: { s, Array , Array1 , Array2 , ArrayBase , Axis , Data , DataMut , Ix1 , Ix2 } ;
63+ use ndarray:: { s, Array , Array1 , Array2 , ArrayBase , Axis , Data , DataMut , Dimension , Ix0 , Ix1 , Ix2 } ;
6464
6565use crate :: error:: * ;
6666use crate :: lapack:: least_squares:: * ;
6767use crate :: layout:: * ;
6868use crate :: types:: * ;
6969
70- pub trait Ix1OrIx2 < E : Scalar > {
71- type ScalarOrArray1 ;
72- }
73-
74- impl < E : Scalar > Ix1OrIx2 < E > for Ix1 {
75- type ScalarOrArray1 = E :: Real ;
76- }
77-
78- impl < E : Scalar > Ix1OrIx2 < E > for Ix2 {
79- type ScalarOrArray1 = Array1 < E :: Real > ;
80- }
81-
8270/// Result of a LeastSquares computation
8371///
8472/// Takes two type parameters, `E`, the element type of the matrix
@@ -88,7 +76,7 @@ impl<E: Scalar> Ix1OrIx2<E> for Ix2 {
8876/// is a `m x 1` column vector. If `I` is `Ix2`, the RHS is a `n x k` matrix
8977/// (which can be seen as solving `Ax = b` k times for different b) and
9078/// the solution is a `m x k` matrix.
91- pub struct LeastSquaresResult < E : Scalar , I : Ix1OrIx2 < E > > {
79+ pub struct LeastSquaresResult < E : Scalar , I : Dimension > {
9280 /// The singular values of the matrix A in `Ax = b`
9381 pub singular_values : Array1 < E :: Real > ,
9482 /// The solution vector or matrix `x` which is the best
@@ -97,16 +85,16 @@ pub struct LeastSquaresResult<E: Scalar, I: Ix1OrIx2<E>> {
9785 /// The rank of the matrix A in `Ax = b`
9886 pub rank : i32 ,
9987 /// If n < m and rank(A) == n, the sum of squares
100- /// If b is a (m x 1) vector, this is a single value
101- /// If b is a m x k matrix, this is a k x 1 column vector
102- pub residual_sum_of_squares : Option < I :: ScalarOrArray1 > ,
88+ /// If b is a (m x 1) vector, this is a 0-dimensional array ( single value)
89+ /// If b is a ( m x k) matrix, this is a ( k x 1) column vector
90+ pub residual_sum_of_squares : Option < Array < E :: Real , I :: Smaller > > ,
10391}
10492/// Solve least squares for immutable references
10593pub trait LeastSquaresSvd < D , E , I >
10694where
10795 D : Data < Elem = E > ,
10896 E : Scalar + Lapack ,
109- I : Ix1OrIx2 < E > ,
97+ I : Dimension ,
11098{
11199 /// Solve a least squares problem of the form `Ax = rhs`
112100 /// by calling `A.least_squares(&rhs)`. `A` and `rhs`
@@ -123,7 +111,7 @@ pub trait LeastSquaresSvdInto<D, E, I>
123111where
124112 D : Data < Elem = E > ,
125113 E : Scalar + Lapack ,
126- I : Ix1OrIx2 < E > ,
114+ I : Dimension ,
127115{
128116 /// Solve a least squares problem of the form `Ax = rhs`
129117 /// by calling `A.least_squares(rhs)`, consuming both `A`
@@ -142,7 +130,7 @@ pub trait LeastSquaresSvdInPlace<D, E, I>
142130where
143131 D : Data < Elem = E > ,
144132 E : Scalar + Lapack ,
145- I : Ix1OrIx2 < E > ,
133+ I : Dimension ,
146134{
147135 /// Solve a least squares problem of the form `Ax = rhs`
148136 /// by calling `A.least_squares(&mut rhs)`, overwriting both `A`
@@ -328,11 +316,13 @@ fn compute_residual_scalar<E: Scalar, D: Data<Elem = E>>(
328316 n : usize ,
329317 rank : i32 ,
330318 b : & ArrayBase < D , Ix1 > ,
331- ) -> Option < E :: Real > {
319+ ) -> Option < Array < E :: Real , Ix0 > > {
332320 if m < n || n != rank as usize {
333321 return None ;
334322 }
335- Some ( b. slice ( s ! [ n..] ) . mapv ( |x| x. powi ( 2 ) . abs ( ) ) . sum ( ) )
323+ let mut arr: Array < E :: Real , Ix0 > = Array :: zeros ( ( ) ) ;
324+ arr[ ( ) ] = b. slice ( s ! [ n..] ) . mapv ( |x| x. powi ( 2 ) . abs ( ) ) . sum ( ) ;
325+ Some ( arr)
336326}
337327
338328/// Solve least squares for mutable references and a matrix
@@ -429,11 +419,10 @@ mod tests {
429419 use ndarray:: { ArcArray1 , ArcArray2 , Array1 , Array2 , CowArray } ;
430420 use num_complex:: Complex ;
431421
432- ///////////////////////////////////////////////////////////////////////////
433- /// Test cases taken from the scipy test suite for the scipy lstsq function
434- /// https://github.com/scipy/scipy/blob/v1.4.1/scipy/linalg/tests/test_basic.py
435- ///////////////////////////////////////////////////////////////////////////
436-
422+ //
423+ // Test cases taken from the scipy test suite for the scipy lstsq function
424+ // https://github.com/scipy/scipy/blob/v1.4.1/scipy/linalg/tests/test_basic.py
425+ //
437426 #[ test]
438427 fn scipy_test_simple_exact ( ) {
439428 let a = array ! [ [ 1. , 20. ] , [ -30. , 4. ] ] ;
@@ -463,10 +452,7 @@ mod tests {
463452 assert_eq ! ( res. rank, 2 ) ;
464453 let b_hat = a. dot ( & res. solution ) ;
465454 let rssq = ( & b - & b_hat) . mapv ( |x| x. powi ( 2 ) ) . sum ( ) ;
466- assert ! ( res
467- . residual_sum_of_squares
468- . unwrap( )
469- . abs_diff_eq( & rssq, 1e-12 ) ) ;
455+ assert ! ( res. residual_sum_of_squares. unwrap( ) [ ( ) ] . abs_diff_eq( & rssq, 1e-12 ) ) ;
470456 assert ! ( res
471457 . solution
472458 . abs_diff_eq( & array![ -0.428571428571429 , 0.85714285714285 ] , 1e-12 ) ) ;
@@ -480,10 +466,7 @@ mod tests {
480466 assert_eq ! ( res. rank, 2 ) ;
481467 let b_hat = a. dot ( & res. solution ) ;
482468 let rssq = ( & b - & b_hat) . mapv ( |x| x. powi ( 2 ) ) . sum ( ) ;
483- assert ! ( res
484- . residual_sum_of_squares
485- . unwrap( )
486- . abs_diff_eq( & rssq, 1e-6 ) ) ;
469+ assert ! ( res. residual_sum_of_squares. unwrap( ) [ ( ) ] . abs_diff_eq( & rssq, 1e-6 ) ) ;
487470 assert ! ( res
488471 . solution
489472 . abs_diff_eq( & array![ -0.428571428571429 , 0.85714285714285 ] , 1e-6 ) ) ;
@@ -505,10 +488,7 @@ mod tests {
505488 assert_eq ! ( res. rank, 2 ) ;
506489 let b_hat = a. dot ( & res. solution ) ;
507490 let rssq = ( & b_hat - & b) . mapv ( |x| x. powi ( 2 ) . abs ( ) ) . sum ( ) ;
508- assert ! ( res
509- . residual_sum_of_squares
510- . unwrap( )
511- . abs_diff_eq( & rssq, 1e-12 ) ) ;
491+ assert ! ( res. residual_sum_of_squares. unwrap( ) [ ( ) ] . abs_diff_eq( & rssq, 1e-12 ) ) ;
512492 assert ! ( res. solution. abs_diff_eq(
513493 & array![
514494 c( -0.4831460674157303 , 0.258426966292135 ) ,
@@ -546,18 +526,18 @@ mod tests {
546526 assert ! ( res. solution. abs_diff_eq( & expected, 1e-12 ) ) ;
547527 }
548528
549- ///////////////////////////////////////////////////////////////////////////
550- /// Test that the different lest squares traits work as intended on the
551- /// different array types.
552- ///
553- /// | least_squares | ls_into | ls_in_place |
554- /// --------------+---------------+---------+-------------+
555- /// Array | yes | yes | yes |
556- /// ArcArray | yes | no | no |
557- /// CowArray | yes | yes | yes |
558- /// ArrayView | yes | no | no |
559- /// ArrayViewMut | yes | no | yes |
560- ///////////////////////////////////////////////////////////////////////////
529+ //
530+ // Test that the different lest squares traits work as intended on the
531+ // different array types.
532+ //
533+ // | least_squares | ls_into | ls_in_place |
534+ // --------------+---------------+---------+-------------+
535+ // Array | yes | yes | yes |
536+ // ArcArray | yes | no | no |
537+ // CowArray | yes | yes | yes |
538+ // ArrayView | yes | no | no |
539+ // ArrayViewMut | yes | no | yes |
540+ //
561541
562542 fn assert_result < D : Data < Elem = f64 > > (
563543 a : & ArrayBase < D , Ix2 > ,
@@ -567,10 +547,7 @@ mod tests {
567547 assert_eq ! ( res. rank, 2 ) ;
568548 let b_hat = a. dot ( & res. solution ) ;
569549 let rssq = ( b - & b_hat) . mapv ( |x| x. powi ( 2 ) ) . sum ( ) ;
570- assert ! ( res
571- . residual_sum_of_squares
572- . unwrap( )
573- . abs_diff_eq( & rssq, 1e-12 ) ) ;
550+ assert ! ( res. residual_sum_of_squares. as_ref( ) . unwrap( ) [ ( ) ] . abs_diff_eq( & rssq, 1e-12 ) ) ;
574551 assert ! ( res
575552 . solution
576553 . abs_diff_eq( & array![ -0.428571428571429 , 0.85714285714285 ] , 1e-12 ) ) ;
@@ -674,10 +651,10 @@ mod tests {
674651 assert_result ( & a, & b, & res) ;
675652 }
676653
677- ///////////////////////////////////////////////////////////////////////////
678- /// Test cases taken from the netlib documentation at
679- /// https://www.netlib.org/lapack/lapacke.html#_calling_code_dgels_code
680- ///////////////////////////////////////////////////////////////////////////
654+ //
655+ // Test cases taken from the netlib documentation at
656+ // https://www.netlib.org/lapack/lapacke.html#_calling_code_dgels_code
657+ //
681658 #[ test]
682659 fn netlib_lapack_example_for_dgels_1 ( ) {
683660 let a: Array2 < f64 > = array ! [
@@ -694,7 +671,7 @@ mod tests {
694671
695672 let residual = b - a. dot ( & result. solution ) ;
696673 let resid_ssq = result. residual_sum_of_squares . unwrap ( ) ;
697- assert ! ( ( resid_ssq - residual. dot( & residual) ) . abs( ) < 1e-12 ) ;
674+ assert ! ( ( resid_ssq[ ( ) ] - residual. dot( & residual) ) . abs( ) < 1e-12 ) ;
698675 }
699676
700677 #[ test]
@@ -713,7 +690,7 @@ mod tests {
713690
714691 let residual = b - a. dot ( & result. solution ) ;
715692 let resid_ssq = result. residual_sum_of_squares . unwrap ( ) ;
716- assert ! ( ( resid_ssq - residual. dot( & residual) ) . abs( ) < 1e-12 ) ;
693+ assert ! ( ( resid_ssq[ ( ) ] - residual. dot( & residual) ) . abs( ) < 1e-12 ) ;
717694 }
718695
719696 #[ test]
@@ -738,9 +715,9 @@ mod tests {
738715 . abs_diff_eq( & residual_ssq, 1e-12 ) ) ;
739716 }
740717
741- ///////////////////////////////////////////////////////////////////////////
742- /// Testing error cases
743- ///////////////////////////////////////////////////////////////////////////
718+ //
719+ // Testing error cases
720+ //
744721 use crate :: layout:: MatrixLayout ;
745722 use ndarray:: ErrorKind ;
746723
0 commit comments