diff --git a/compiler/rustc_abi/src/callconv.rs b/compiler/rustc_abi/src/callconv.rs
index 7ad7088b30899..d6594e277f00c 100644
--- a/compiler/rustc_abi/src/callconv.rs
+++ b/compiler/rustc_abi/src/callconv.rs
@@ -75,10 +75,11 @@ impl<'a, Ty> TyAndLayout<'a, Ty> {
                 Ok(HomogeneousAggregate::Homogeneous(Reg { kind, size: self.size }))
             }
 
-            BackendRepr::SimdVector { .. } => {
+            BackendRepr::SimdVector { element, count: _ } => {
                 assert!(!self.is_zst());
+
                 Ok(HomogeneousAggregate::Homogeneous(Reg {
-                    kind: RegKind::Vector,
+                    kind: RegKind::Vector { hint_vector_elem: element.primitive() },
                     size: self.size,
                 }))
             }
diff --git a/compiler/rustc_abi/src/callconv/reg.rs b/compiler/rustc_abi/src/callconv/reg.rs
index 66d4dca00726f..881ebad619467 100644
--- a/compiler/rustc_abi/src/callconv/reg.rs
+++ b/compiler/rustc_abi/src/callconv/reg.rs
@@ -1,14 +1,21 @@
 #[cfg(feature = "nightly")]
 use rustc_macros::HashStable_Generic;
 
-use crate::{Align, HasDataLayout, Size};
+use crate::{Align, HasDataLayout, Integer, Primitive, Size};
 
 #[cfg_attr(feature = "nightly", derive(HashStable_Generic))]
 #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
 pub enum RegKind {
     Integer,
     Float,
-    Vector,
+    Vector {
+        /// The `hint_vector_elem` is strictly for optimization purposes and can be safely ignored (e.g.
+        /// by always picking i8) by codegen backends.
+        ///
+        /// The element kind is used to provide more accurate type information to the backend, which
+        /// helps with optimization (e.g. because it prevents extra bitcasts that obscure a pattern).
+        hint_vector_elem: Primitive,
+    },
 }
 
 #[cfg_attr(feature = "nightly", derive(HashStable_Generic))]
@@ -36,6 +43,12 @@ impl Reg {
     reg_ctor!(f32, Float, 32);
     reg_ctor!(f64, Float, 64);
     reg_ctor!(f128, Float, 128);
+
+    /// A vector of the given size with an unknown (and irrelevant) element type.
+    pub fn opaque_vector(size: Size) -> Reg {
+        // Default to an i8 vector of the given size.
+        Reg { kind: RegKind::Vector { hint_vector_elem: Primitive::Int(Integer::I8, true) }, size }
+    }
 }
 
 impl Reg {
@@ -58,7 +71,7 @@ impl Reg {
                 128 => dl.f128_align,
                 _ => panic!("unsupported float: {self:?}"),
             },
-            RegKind::Vector => dl.llvmlike_vector_align(self.size),
+            RegKind::Vector { .. } => dl.llvmlike_vector_align(self.size),
         }
     }
 }
diff --git a/compiler/rustc_codegen_cranelift/src/abi/pass_mode.rs b/compiler/rustc_codegen_cranelift/src/abi/pass_mode.rs
index 44b63aa95f83c..0283263cc6047 100644
--- a/compiler/rustc_codegen_cranelift/src/abi/pass_mode.rs
+++ b/compiler/rustc_codegen_cranelift/src/abi/pass_mode.rs
@@ -26,7 +26,9 @@ fn reg_to_abi_param(reg: Reg) -> AbiParam {
         (RegKind::Float, 4) => types::F32,
         (RegKind::Float, 8) => types::F64,
         (RegKind::Float, 16) => types::F128,
-        (RegKind::Vector, size) => types::I8.by(u32::try_from(size).unwrap()).unwrap(),
+        (RegKind::Vector { hint_vector_elem: _ }, size) => {
+            types::I8.by(u32::try_from(size).unwrap()).unwrap()
+        }
         _ => unreachable!("{:?}", reg),
     };
     AbiParam::new(clif_ty)
diff --git a/compiler/rustc_codegen_gcc/src/abi.rs b/compiler/rustc_codegen_gcc/src/abi.rs
index 4d1274a63d1fe..8277231f16a54 100644
--- a/compiler/rustc_codegen_gcc/src/abi.rs
+++ b/compiler/rustc_codegen_gcc/src/abi.rs
@@ -90,7 +90,9 @@ impl GccType for Reg {
                 64 => cx.type_f64(),
                 _ => bug!("unsupported float: {:?}", self),
             },
-            RegKind::Vector => cx.type_vector(cx.type_i8(), self.size.bytes()),
+            RegKind::Vector { hint_vector_elem: _ } => {
+                cx.type_vector(cx.type_i8(), self.size.bytes())
+            }
         }
     }
 }
diff --git a/compiler/rustc_codegen_llvm/src/abi.rs b/compiler/rustc_codegen_llvm/src/abi.rs
index a6e841e440a21..4afa63920bb87 100644
--- a/compiler/rustc_codegen_llvm/src/abi.rs
+++ b/compiler/rustc_codegen_llvm/src/abi.rs
@@ -2,8 +2,8 @@ use std::cmp;
 
 use libc::c_uint;
 use rustc_abi::{
-    ArmCall, BackendRepr, CanonAbi, HasDataLayout, InterruptKind, Primitive, Reg, RegKind, Size,
-    X86Call,
+    ArmCall, BackendRepr, CanonAbi, Float, HasDataLayout, Integer, InterruptKind, Primitive, Reg,
+    RegKind, Size, X86Call,
 };
 use rustc_codegen_ssa::MemFlags;
 use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
@@ -137,7 +137,30 @@ impl LlvmType for Reg {
                 128 => cx.type_f128(),
                 _ => bug!("unsupported float: {:?}", self),
             },
-            RegKind::Vector => cx.type_vector(cx.type_i8(), self.size.bytes()),
+            RegKind::Vector { hint_vector_elem } => {
+                // NOTE: it is valid to ignore the element type hint (and always pick i8).
+                // But providing a more accurate type means fewer casts in LLVM IR,
+                // which helps with optimization.
+                let ty = match hint_vector_elem {
+                    Primitive::Int(integer, _) => match integer {
+                        Integer::I8 => cx.type_ix(8),
+                        Integer::I16 => cx.type_ix(16),
+                        Integer::I32 => cx.type_ix(32),
+                        Integer::I64 => cx.type_ix(64),
+                        Integer::I128 => cx.type_ix(128),
+                    },
+                    Primitive::Float(float) => match float {
+                        Float::F16 => cx.type_f16(),
+                        Float::F32 => cx.type_f32(),
+                        Float::F64 => cx.type_f64(),
+                        Float::F128 => cx.type_f128(),
+                    },
+                    Primitive::Pointer(_) => cx.type_ptr(),
+                };
+
+                let len = self.size.bytes() / hint_vector_elem.size(cx).bytes();
+                cx.type_vector(ty, len)
+            }
         }
     }
 }
diff --git a/compiler/rustc_codegen_ssa/src/mir/naked_asm.rs b/compiler/rustc_codegen_ssa/src/mir/naked_asm.rs
index ccd5fbcaec8b5..7a3128c028753 100644
--- a/compiler/rustc_codegen_ssa/src/mir/naked_asm.rs
+++ b/compiler/rustc_codegen_ssa/src/mir/naked_asm.rs
@@ -416,7 +416,7 @@ fn wasm_type<'tcx>(signature: &mut String, arg_abi: &ArgAbi<'_, Ty<'tcx>>, ptr_type: &str) {
             ..=8 => "f64",
             _ => ptr_type,
         },
-        RegKind::Vector => "v128",
+        RegKind::Vector { .. } => "v128",
     };
 
     signature.push_str(wrapped_wasm_type);
diff --git a/compiler/rustc_monomorphize/src/mono_checks/abi_check.rs b/compiler/rustc_monomorphize/src/mono_checks/abi_check.rs
index 0921e57844b03..41904d7905ef7 100644
--- a/compiler/rustc_monomorphize/src/mono_checks/abi_check.rs
+++ b/compiler/rustc_monomorphize/src/mono_checks/abi_check.rs
@@ -25,8 +25,11 @@ fn passes_vectors_by_value(mode: &PassMode, repr: &BackendRepr) -> UsesVectorRegisters {
     match mode {
         PassMode::Ignore | PassMode::Indirect { .. } => UsesVectorRegisters::No,
         PassMode::Cast { pad_i32: _, cast }
-            if cast.prefix.iter().any(|r| r.is_some_and(|x| x.kind == RegKind::Vector))
-                || cast.rest.unit.kind == RegKind::Vector =>
+            if cast
+                .prefix
+                .iter()
+                .any(|r| r.is_some_and(|x| matches!(x.kind, RegKind::Vector { .. })))
+                || matches!(cast.rest.unit.kind, RegKind::Vector { .. }) =>
         {
             UsesVectorRegisters::FixedVector
         }
diff --git a/compiler/rustc_target/src/callconv/aarch64.rs b/compiler/rustc_target/src/callconv/aarch64.rs
index e9a19aa7024bb..ce69427cbdd59 100644
--- a/compiler/rustc_target/src/callconv/aarch64.rs
+++ b/compiler/rustc_target/src/callconv/aarch64.rs
@@ -35,7 +35,7 @@ where
             // The softfloat ABI treats floats like integers, so they
             // do not get homogeneous aggregate treatment.
             RegKind::Float => cx.target_spec().rustc_abi != Some(RustcAbi::Softfloat),
-            RegKind::Vector => size.bits() == 64 || size.bits() == 128,
+            RegKind::Vector { .. } => size.bits() == 64 || size.bits() == 128,
         };
 
         valid_unit.then_some(Uniform::consecutive(unit, size))
diff --git a/compiler/rustc_target/src/callconv/arm.rs b/compiler/rustc_target/src/callconv/arm.rs
index 4c1ff27aac509..41c3a0a0210fb 100644
--- a/compiler/rustc_target/src/callconv/arm.rs
+++ b/compiler/rustc_target/src/callconv/arm.rs
@@ -19,7 +19,7 @@ where
         let valid_unit = match unit.kind {
             RegKind::Integer => false,
             RegKind::Float => true,
-            RegKind::Vector => size.bits() == 64 || size.bits() == 128,
+            RegKind::Vector { .. } => size.bits() == 64 || size.bits() == 128,
         };
 
         valid_unit.then_some(Uniform::consecutive(unit, size))
diff --git a/compiler/rustc_target/src/callconv/powerpc64.rs b/compiler/rustc_target/src/callconv/powerpc64.rs
index d807617491d12..6a8a6841781c9 100644
--- a/compiler/rustc_target/src/callconv/powerpc64.rs
+++ b/compiler/rustc_target/src/callconv/powerpc64.rs
@@ -36,7 +36,7 @@ where
         let valid_unit = match unit.kind {
             RegKind::Integer => false,
             RegKind::Float => true,
-            RegKind::Vector => arg.layout.size.bits() == 128,
+            RegKind::Vector { .. } => arg.layout.size.bits() == 128,
         };
 
         valid_unit.then_some(Uniform::consecutive(unit, arg.layout.size))
diff --git a/compiler/rustc_target/src/callconv/s390x.rs b/compiler/rustc_target/src/callconv/s390x.rs
index a2ff6f5a3a03b..581c1e2e862c5 100644
--- a/compiler/rustc_target/src/callconv/s390x.rs
+++ b/compiler/rustc_target/src/callconv/s390x.rs
@@ -3,7 +3,7 @@
 
 use rustc_abi::{BackendRepr, HasDataLayout, TyAbiInterface};
 
-use crate::callconv::{ArgAbi, FnAbi, Reg, RegKind};
+use crate::callconv::{ArgAbi, FnAbi, Reg};
 use crate::spec::{Env, HasTargetSpec, Os};
 
 fn classify_ret<Ty>(ret: &mut ArgAbi<'_, Ty>) {
@@ -51,7 +51,7 @@ where
 
         if arg.layout.is_single_vector_element(cx, size) {
             // pass non-transparent wrappers around a vector as `PassMode::Cast`
-            arg.cast_to(Reg { kind: RegKind::Vector, size });
+            arg.cast_to(Reg::opaque_vector(size));
             return;
         }
     }
diff --git a/compiler/rustc_target/src/callconv/x86.rs b/compiler/rustc_target/src/callconv/x86.rs
index 9aaa411db6c05..81ff1a2a45900 100644
--- a/compiler/rustc_target/src/callconv/x86.rs
+++ b/compiler/rustc_target/src/callconv/x86.rs
@@ -1,9 +1,8 @@
 use rustc_abi::{
-    AddressSpace, Align, BackendRepr, HasDataLayout, Primitive, Reg, RegKind, TyAbiInterface,
-    TyAndLayout,
+    AddressSpace, Align, BackendRepr, HasDataLayout, Primitive, Reg, RegKind, TyAndLayout,
 };
 
-use crate::callconv::{ArgAttribute, FnAbi, PassMode};
+use crate::callconv::{ArgAttribute, FnAbi, PassMode, TyAbiInterface};
 use crate::spec::{HasTargetSpec, RustcAbi};
 
 #[derive(PartialEq)]
@@ -175,7 +174,7 @@ pub(crate) fn fill_inregs<'a, Ty, C>(
         // At this point we know this must be a primitive of sorts.
         let unit = arg.layout.homogeneous_aggregate(cx).unwrap().unit().unwrap();
         assert_eq!(unit.size, arg.layout.size);
-        if matches!(unit.kind, RegKind::Float | RegKind::Vector) {
+        if matches!(unit.kind, RegKind::Float | RegKind::Vector { .. }) {
             continue;
         }
 
@@ -226,7 +225,7 @@ where
             // This is a single scalar that fits into an SSE register, and the target uses the
             // SSE ABI. We prefer this over integer registers as float scalars need to be in SSE
             // registers for float operations, so that's the best place to pass them around.
-            fn_abi.ret.cast_to(Reg { kind: RegKind::Vector, size: fn_abi.ret.layout.size });
+            fn_abi.ret.cast_to(Reg::opaque_vector(fn_abi.ret.layout.size));
         } else if fn_abi.ret.layout.size <= Primitive::Pointer(AddressSpace::ZERO).size(cx) {
             // Same size or smaller than pointer, return in an integer register.
             fn_abi.ret.cast_to(Reg { kind: RegKind::Integer, size: fn_abi.ret.layout.size });
diff --git a/compiler/rustc_target/src/callconv/x86_64.rs b/compiler/rustc_target/src/callconv/x86_64.rs
index dc73907c0c18a..3055d18ffa014 100644
--- a/compiler/rustc_target/src/callconv/x86_64.rs
+++ b/compiler/rustc_target/src/callconv/x86_64.rs
@@ -151,7 +151,7 @@ fn reg_component(cls: &[Option<Class>], i: &mut usize, size: Size) -> Option<Reg> {
                     _ => Reg::f64(),
                 }
             } else {
-                Reg { kind: RegKind::Vector, size: Size::from_bytes(8) * (vec_len as u64) }
+                Reg::opaque_vector(Size::from_bytes(8) * (vec_len as u64))
             })
         }
         Some(c) => unreachable!("reg_component: unhandled class {:?}", c),
diff --git a/compiler/rustc_target/src/callconv/x86_win64.rs b/compiler/rustc_target/src/callconv/x86_win64.rs
index 624563c68e5b9..cece9d032b53a 100644
--- a/compiler/rustc_target/src/callconv/x86_win64.rs
+++ b/compiler/rustc_target/src/callconv/x86_win64.rs
@@ -1,4 +1,4 @@
-use rustc_abi::{BackendRepr, Float, Integer, Primitive, RegKind, Size, TyAbiInterface};
+use rustc_abi::{BackendRepr, Float, Integer, Primitive, Size, TyAbiInterface};
 
 use crate::callconv::{ArgAbi, FnAbi, Reg};
 use crate::spec::{HasTargetSpec, RustcAbi};
@@ -33,8 +33,7 @@ where
             } else {
                 // `i128` is returned in xmm0 by Clang and GCC
                 // FIXME(#134288): This may change for the `-msvc` targets in the future.
-                let reg = Reg { kind: RegKind::Vector, size: Size::from_bits(128) };
-                a.cast_to(reg);
+                a.cast_to(Reg::opaque_vector(Size::from_bits(128)));
             }
         } else if a.layout.size.bytes() > 8
             && !matches!(scalar.primitive(), Primitive::Float(Float::F128))
diff --git a/tests/assembly-llvm/aarch64-vld2-s16.rs b/tests/assembly-llvm/aarch64-vld2-s16.rs
new file mode 100644
index 0000000000000..137422f300199
--- /dev/null
+++ b/tests/assembly-llvm/aarch64-vld2-s16.rs
@@ -0,0 +1,46 @@
+//@ assembly-output: emit-asm
+//@ compile-flags: -Copt-level=3
+//@ only-aarch64-unknown-linux-gnu
+#![feature(repr_simd, portable_simd, core_intrinsics, f16, f128)]
+#![crate_type = "lib"]
+#![allow(non_camel_case_types)]
+
+// Test `vld_s16` can be implemented in a portable way (i.e. without using LLVM neon intrinsics).
+// This relies on rust preserving the SIMD vector element type and using it to construct the
+// LLVM type. Without this information, additional casts are needed that defeat the LLVM pattern
+// matcher, see https://github.com/llvm/llvm-project/issues/181514.
+
+use std::mem::transmute;
+use std::simd::Simd;
+
+#[unsafe(no_mangle)]
+#[target_feature(enable = "neon")]
+unsafe extern "C" fn vld2_s16_old(ptr: *const i16) -> std::arch::aarch64::int16x4x2_t {
+    // CHECK-LABEL: vld2_s16_old
+    // CHECK: .cfi_startproc
+    // CHECK-NEXT: ld2 { v0.4h, v1.4h }, [x0]
+    // CHECK-NEXT: ret
+    std::arch::aarch64::vld2_s16(ptr)
+}
+
+#[unsafe(no_mangle)]
+#[target_feature(enable = "neon")]
+unsafe extern "C" fn vld2_s16_new(a: *const i16) -> std::arch::aarch64::int16x4x2_t {
+    // CHECK-LABEL: vld2_s16_new
+    // CHECK: .cfi_startproc
+    // CHECK-NEXT: ld2 { v0.4h, v1.4h }, [x0]
+    // CHECK-NEXT: ret
+
+    type V = Simd<i16, 4>;
+    type W = Simd<i16, 8>;
+
+    let w: W = std::ptr::read_unaligned(a as *const W);
+
+    #[repr(simd)]
+    pub(crate) struct SimdShuffleIdx<const LEN: usize>([u32; LEN]);
+
+    let v0: V = std::intrinsics::simd::simd_shuffle(w, w, const { SimdShuffleIdx([0, 2, 4, 6]) });
+    let v1: V = std::intrinsics::simd::simd_shuffle(w, w, const { SimdShuffleIdx([1, 3, 5, 7]) });
+
+    transmute((v0, v1))
+}
diff --git a/tests/codegen-llvm/preserve-vec-element-types.rs b/tests/codegen-llvm/preserve-vec-element-types.rs
new file mode 100644
index 0000000000000..bfa45177a9521
--- /dev/null
+++ b/tests/codegen-llvm/preserve-vec-element-types.rs
@@ -0,0 +1,98 @@
+// ignore-tidy-linelength
+//@ compile-flags: -Copt-level=3 -Zmerge-functions=disabled --target=aarch64-unknown-linux-gnu
+//@ needs-llvm-components: aarch64
+//@ add-minicore
+#![feature(no_core, repr_simd, f16, f128)]
+#![crate_type = "lib"]
+#![no_std]
+#![no_core]
+#![allow(non_camel_case_types)]
+
+// Test that the SIMD vector element type is preserved. This is not required for correctness, but
+// useful for optimization. It prevents additional bitcasts that make LLVM patterns fail.
+
+extern crate minicore;
+use minicore::*;
+
+#[repr(simd)]
+pub struct Simd<T, const N: usize>([T; N]);
+
+#[repr(C)]
+struct Pair<T>(T, T);
+
+#[repr(C)]
+struct Triple<T>(T, T, T);
+
+#[repr(C)]
+struct Quad<T>(T, T, T, T);
+
+#[rustfmt::skip]
+mod tests {
+    use super::*;
+
+    // CHECK: define [2 x <8 x i8>] @pair_int8x8_t([2 x <8 x i8>] {{.*}} %0)
+    #[unsafe(no_mangle)] extern "C" fn pair_int8x8_t(x: Pair<Simd<i8, 8>>) -> Pair<Simd<i8, 8>> { x }
+
+    // CHECK: define [2 x <4 x i16>] @pair_int16x4_t([2 x <4 x i16>] {{.*}} %0)
+    #[unsafe(no_mangle)] extern "C" fn pair_int16x4_t(x: Pair<Simd<i16, 4>>) -> Pair<Simd<i16, 4>> { x }
+
+    // CHECK: define [2 x <2 x i32>] @pair_int32x2_t([2 x <2 x i32>] {{.*}} %0)
+    #[unsafe(no_mangle)] extern "C" fn pair_int32x2_t(x: Pair<Simd<i32, 2>>) -> Pair<Simd<i32, 2>> { x }
+
+    // CHECK: define [2 x <1 x i64>] @pair_int64x1_t([2 x <1 x i64>] {{.*}} %0)
+    #[unsafe(no_mangle)] extern "C" fn pair_int64x1_t(x: Pair<Simd<i64, 1>>) -> Pair<Simd<i64, 1>> { x }
+
+    // CHECK: define [2 x <4 x half>] @pair_float16x4_t([2 x <4 x half>] {{.*}} %0)
+    #[unsafe(no_mangle)] extern "C" fn pair_float16x4_t(x: Pair<Simd<f16, 4>>) -> Pair<Simd<f16, 4>> { x }
+
+    // CHECK: define [2 x <2 x float>] @pair_float32x2_t([2 x <2 x float>] {{.*}} %0)
+    #[unsafe(no_mangle)] extern "C" fn pair_float32x2_t(x: Pair<Simd<f32, 2>>) -> Pair<Simd<f32, 2>> { x }
+
+    // CHECK: define [2 x <1 x double>] @pair_float64x1_t([2 x <1 x double>] {{.*}} %0)
+    #[unsafe(no_mangle)] extern "C" fn pair_float64x1_t(x: Pair<Simd<f64, 1>>) -> Pair<Simd<f64, 1>> { x }
+
+    // CHECK: define [2 x <1 x ptr>] @pair_ptrx1_t([2 x <1 x ptr>] {{.*}} %0)
+    #[unsafe(no_mangle)] extern "C" fn pair_ptrx1_t(x: Pair<Simd<*const (), 1>>) -> Pair<Simd<*const (), 1>> { x }
+
+    // When it fits in a 128-bit register, it's passed directly.
+
+    // CHECK: define [4 x <4 x i8>] @quad_int8x4_t([4 x <4 x i8>] {{.*}} %0)
+    #[unsafe(no_mangle)] extern "C" fn quad_int8x4_t(x: Quad<Simd<i8, 4>>) -> Quad<Simd<i8, 4>> { x }
+
+    // CHECK: define [4 x <2 x i16>] @quad_int16x2_t([4 x <2 x i16>] {{.*}} %0)
+    #[unsafe(no_mangle)] extern "C" fn quad_int16x2_t(x: Quad<Simd<i16, 2>>) -> Quad<Simd<i16, 2>> { x }
+
+    // CHECK: define [4 x <1 x i32>] @quad_int32x1_t([4 x <1 x i32>] {{.*}} %0)
+    #[unsafe(no_mangle)] extern "C" fn quad_int32x1_t(x: Quad<Simd<i32, 1>>) -> Quad<Simd<i32, 1>> { x }
+
+    // CHECK: define [4 x <2 x half>] @quad_float16x2_t([4 x <2 x half>] {{.*}} %0)
+    #[unsafe(no_mangle)] extern "C" fn quad_float16x2_t(x: Quad<Simd<f16, 2>>) -> Quad<Simd<f16, 2>> { x }
+
+    // CHECK: define [4 x <1 x float>] @quad_float32x1_t([4 x <1 x float>] {{.*}} %0)
+    #[unsafe(no_mangle)] extern "C" fn quad_float32x1_t(x: Quad<Simd<f32, 1>>) -> Quad<Simd<f32, 1>> { x }
+
+    // When it doesn't quite fit, padding is added which does erase the type.
+
+    // CHECK: define [2 x i64] @triple_int8x4_t
+    #[unsafe(no_mangle)] extern "C" fn triple_int8x4_t(x: Triple<Simd<i8, 4>>) -> Triple<Simd<i8, 4>> { x }
+
+    // Other configurations are not passed by-value but indirectly.
+
+    // CHECK: define void @pair_int128x1_t
+    #[unsafe(no_mangle)] extern "C" fn pair_int128x1_t(x: Pair<Simd<i128, 1>>) -> Pair<Simd<i128, 1>> { x }
+
+    // CHECK: define void @pair_float128x1_t
+    #[unsafe(no_mangle)] extern "C" fn pair_float128x1_t(x: Pair<Simd<f128, 1>>) -> Pair<Simd<f128, 1>> { x }
+
+    // CHECK: define void @pair_int8x16_t
+    #[unsafe(no_mangle)] extern "C" fn pair_int8x16_t(x: Pair<Simd<i8, 16>>) -> Pair<Simd<i8, 16>> { x }
+
+    // CHECK: define void @pair_int16x8_t
+    #[unsafe(no_mangle)] extern "C" fn pair_int16x8_t(x: Pair<Simd<i16, 8>>) -> Pair<Simd<i16, 8>> { x }
+
+    // CHECK: define void @triple_int16x8_t
+    #[unsafe(no_mangle)] extern "C" fn triple_int16x8_t(x: Triple<Simd<i16, 8>>) -> Triple<Simd<i16, 8>> { x }
+
+    // CHECK: define void @quad_int16x8_t
+    #[unsafe(no_mangle)] extern "C" fn quad_int16x8_t(x: Quad<Simd<i16, 8>>) -> Quad<Simd<i16, 8>> { x }
+}