@@ -701,6 +701,33 @@ where
701701 }
702702}
703703
704+ /// Attempt to merge axes if possible, starting from the back
705+ ///
706+ /// Given axes [Axis(0), Axis(1), Axis(2), Axis(3)] this attempts
707+ /// to merge all axes one by one into Axis(3); when/if this fails,
708+ /// it attempts to merge the rest of the axes together into the next
709+ /// axis in line, for example a result could be:
710+ ///
711+ /// [1, Axis(0) + Axis(1), 1, Axis(2) + Axis(3)] where `+` would
712+ /// mean axes were merged.
713+ pub(crate) fn merge_axes_from_the_back<D>(dim: &mut D, strides: &mut D)
714+ where
715+ D: Dimension,
716+ {
717+ debug_assert_eq!(dim.ndim(), strides.ndim());
718+ match dim.ndim() {
719+ 0 | 1 => {}
720+ n => {
721+ let mut last = n - 1;
722+ for i in (0..last).rev() {
723+ if !merge_axes(dim, strides, Axis(i), Axis(last)) {
724+ last = i;
725+ }
726+ }
727+ }
728+ }
729+ }
730+
704731/// Move the axis which has the smallest absolute stride and a length
705732/// greater than one to be the last axis.
706733pub fn move_min_stride_axis_to_last<D>(dim: &mut D, strides: &mut D)
@@ -765,12 +792,40 @@ where
765792 *strides = new_strides;
766793}
767794
795+
796+ /// Sort axes to standard/row major order, i.e Axis(0) has biggest stride and Axis(n - 1) least
797+ /// stride
798+ ///
799+ /// The axes are sorted according to the .abs() of their stride.
800+ pub(crate) fn sort_axes_to_standard<D>(dim: &mut D, strides: &mut D)
801+ where
802+ D: Dimension,
803+ {
804+ debug_assert!(dim.ndim() > 1);
805+ debug_assert_eq!(dim.ndim(), strides.ndim());
806+ // bubble sort axes
807+ let mut changed = true;
808+ while changed {
809+ changed = false;
810+ for i in 0..dim.ndim() - 1 {
811+ // make sure higher stride axes sort before.
812+ if strides.get_stride(Axis(i)).abs() < strides.get_stride(Axis(i + 1)).abs() {
813+ changed = true;
814+ dim.slice_mut().swap(i, i + 1);
815+ strides.slice_mut().swap(i, i + 1);
816+ }
817+ }
818+ }
819+ }
820+
821+
768822#[cfg(test)]
769823mod test {
770824 use super::{
771825 arith_seq_intersect, can_index_slice, can_index_slice_not_custom, extended_gcd,
772826 max_abs_offset_check_overflow, slice_min_max, slices_intersect,
773827 solve_linear_diophantine_eq, IntoDimension, squeeze,
828+ merge_axes_from_the_back,
774829 };
775830 use crate::error::{from_kind, ErrorKind};
776831 use crate::slice::Slice;
@@ -1119,4 +1174,26 @@ mod test {
11191174 assert_eq!(d, dans);
11201175 assert_eq!(s, sans);
11211176 }
1177+
1178+ #[test]
1179+ fn test_merge_axes_from_the_back() {
1180+ let dyndim = Dim::<&[usize]>;
1181+
1182+ let mut d = Dim([3, 4, 5]);
1183+ let mut s = Dim([20, 5, 1]);
1184+ merge_axes_from_the_back(&mut d, &mut s);
1185+ assert_eq!(d, Dim([1, 1, 60]));
1186+ assert_eq!(s, Dim([20, 5, 1]));
1187+
1188+ let mut d = Dim([3, 4, 5, 2]);
1189+ let mut s = Dim([80, 20, 2, 1]);
1190+ merge_axes_from_the_back(&mut d, &mut s);
1191+ assert_eq!(d, Dim([1, 12, 1, 10]));
1192+ assert_eq!(s, Dim([80, 20, 2, 1]));
1193+ let mut d = d.into_dyn();
1194+ let mut s = s.into_dyn();
1195+ squeeze(&mut d, &mut s);
1196+ assert_eq!(d, dyndim(&[12, 10]));
1197+ assert_eq!(s, dyndim(&[20, 1]));
1198+ }
11221199}
0 commit comments