Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions compiler/rustc_abi/src/callconv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,11 @@ impl<'a, Ty> TyAndLayout<'a, Ty> {
Ok(HomogeneousAggregate::Homogeneous(Reg { kind, size: self.size }))
}

BackendRepr::SimdVector { .. } => {
BackendRepr::SimdVector { element, count: _ } => {
assert!(!self.is_zst());

Ok(HomogeneousAggregate::Homogeneous(Reg {
kind: RegKind::Vector,
kind: RegKind::Vector { hint_vector_elem: element.primitive() },
size: self.size,
}))
}
Expand Down
19 changes: 16 additions & 3 deletions compiler/rustc_abi/src/callconv/reg.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
#[cfg(feature = "nightly")]
use rustc_macros::HashStable_Generic;

use crate::{Align, HasDataLayout, Size};
use crate::{Align, HasDataLayout, Integer, Primitive, Size};

#[cfg_attr(feature = "nightly", derive(HashStable_Generic))]
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
pub enum RegKind {
Integer,
Float,
Vector,
Vector {
/// The `hint_vector_elem` is strictly for optimization purposes and can be safely ignored (e.g.
/// by always picking i8) by codegen backends.
///
/// The element kind is used to provide more accurate type information to the backend, which
/// helps with optimization (e.g. because it prevents extra bitcasts that obscure a pattern).
hint_vector_elem: Primitive,
},
}

#[cfg_attr(feature = "nightly", derive(HashStable_Generic))]
Expand Down Expand Up @@ -36,6 +43,12 @@ impl Reg {
reg_ctor!(f32, Float, 32);
reg_ctor!(f64, Float, 64);
reg_ctor!(f128, Float, 128);

/// Creates a vector register of `size` whose element type is unknown (and irrelevant).
pub fn opaque_vector(size: Size) -> Reg {
    // No element type is known here, so use `i8` as the element hint.
    let kind = RegKind::Vector { hint_vector_elem: Primitive::Int(Integer::I8, true) };
    Reg { kind, size }
}
}

impl Reg {
Expand All @@ -58,7 +71,7 @@ impl Reg {
128 => dl.f128_align,
_ => panic!("unsupported float: {self:?}"),
},
RegKind::Vector => dl.llvmlike_vector_align(self.size),
RegKind::Vector { .. } => dl.llvmlike_vector_align(self.size),
}
}
}
4 changes: 3 additions & 1 deletion compiler/rustc_codegen_cranelift/src/abi/pass_mode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ fn reg_to_abi_param(reg: Reg) -> AbiParam {
(RegKind::Float, 4) => types::F32,
(RegKind::Float, 8) => types::F64,
(RegKind::Float, 16) => types::F128,
(RegKind::Vector, size) => types::I8.by(u32::try_from(size).unwrap()).unwrap(),
(RegKind::Vector { hint_vector_elem: _ }, size) => {
types::I8.by(u32::try_from(size).unwrap()).unwrap()
}
_ => unreachable!("{:?}", reg),
};
AbiParam::new(clif_ty)
Expand Down
4 changes: 3 additions & 1 deletion compiler/rustc_codegen_gcc/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ impl GccType for Reg {
64 => cx.type_f64(),
_ => bug!("unsupported float: {:?}", self),
},
RegKind::Vector => cx.type_vector(cx.type_i8(), self.size.bytes()),
RegKind::Vector { hint_vector_elem: _ } => {
cx.type_vector(cx.type_i8(), self.size.bytes())
}
}
}
}
Expand Down
23 changes: 22 additions & 1 deletion compiler/rustc_codegen_llvm/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,28 @@ impl LlvmType for Reg {
128 => cx.type_f128(),
_ => bug!("unsupported float: {:?}", self),
},
RegKind::Vector => cx.type_vector(cx.type_i8(), self.size.bytes()),
RegKind::Vector { hint_vector_elem } => {
// NOTE: it is valid to ignore the element type hint (and always pick i8).
// But providing a more accurate type means fewer casts in LLVM IR,
// which helps with optimization.
let ty = match hint_vector_elem {
Primitive::Int(integer, _) => match integer.size().bits() {
bits @ (8 | 16 | 32 | 64 | 128) => cx.type_ix(bits),
bits => panic!("unsupported vector integer element size: {bits}"),
},
Primitive::Float(float) => match float.size().bits() {
16 => cx.type_f16(),
32 => cx.type_f32(),
64 => cx.type_f64(),
128 => cx.type_f128(),
bits => panic!("unsupported vector float element size: {bits}"),
},
Primitive::Pointer(_) => cx.type_ptr(),
};

let len = self.size.bytes() / hint_vector_elem.size(cx).bytes();
cx.type_vector(ty, len)
}
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_codegen_ssa/src/mir/naked_asm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ fn wasm_type<'tcx>(signature: &mut String, arg_abi: &ArgAbi<'_, Ty<'tcx>>, ptr_t
..=8 => "f64",
_ => ptr_type,
},
RegKind::Vector => "v128",
RegKind::Vector { .. } => "v128",
};

signature.push_str(wrapped_wasm_type);
Expand Down
7 changes: 5 additions & 2 deletions compiler/rustc_monomorphize/src/mono_checks/abi_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,11 @@ fn passes_vectors_by_value(mode: &PassMode, repr: &BackendRepr) -> UsesVectorReg
match mode {
PassMode::Ignore | PassMode::Indirect { .. } => UsesVectorRegisters::No,
PassMode::Cast { pad_i32: _, cast }
if cast.prefix.iter().any(|r| r.is_some_and(|x| x.kind == RegKind::Vector))
|| cast.rest.unit.kind == RegKind::Vector =>
if cast
.prefix
.iter()
.any(|r| r.is_some_and(|x| matches!(x.kind, RegKind::Vector { .. })))
|| matches!(cast.rest.unit.kind, RegKind::Vector { .. }) =>
{
UsesVectorRegisters::FixedVector
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_target/src/callconv/aarch64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ where
// The softfloat ABI treats floats like integers, so they
// do not get homogeneous aggregate treatment.
RegKind::Float => cx.target_spec().rustc_abi != Some(RustcAbi::Softfloat),
RegKind::Vector => size.bits() == 64 || size.bits() == 128,
RegKind::Vector { .. } => size.bits() == 64 || size.bits() == 128,
};

valid_unit.then_some(Uniform::consecutive(unit, size))
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_target/src/callconv/arm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ where
let valid_unit = match unit.kind {
RegKind::Integer => false,
RegKind::Float => true,
RegKind::Vector => size.bits() == 64 || size.bits() == 128,
RegKind::Vector { .. } => size.bits() == 64 || size.bits() == 128,
};

valid_unit.then_some(Uniform::consecutive(unit, size))
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_target/src/callconv/powerpc64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ where
let valid_unit = match unit.kind {
RegKind::Integer => false,
RegKind::Float => true,
RegKind::Vector => arg.layout.size.bits() == 128,
RegKind::Vector { .. } => arg.layout.size.bits() == 128,
};

valid_unit.then_some(Uniform::consecutive(unit, arg.layout.size))
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_target/src/callconv/s390x.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

use rustc_abi::{BackendRepr, HasDataLayout, TyAbiInterface};

use crate::callconv::{ArgAbi, FnAbi, Reg, RegKind};
use crate::callconv::{ArgAbi, FnAbi, Reg};
use crate::spec::{Env, HasTargetSpec, Os};

fn classify_ret<Ty>(ret: &mut ArgAbi<'_, Ty>) {
Expand Down Expand Up @@ -51,7 +51,7 @@ where

if arg.layout.is_single_vector_element(cx, size) {
// pass non-transparent wrappers around a vector as `PassMode::Cast`
arg.cast_to(Reg { kind: RegKind::Vector, size });
arg.cast_to(Reg::opaque_vector(size));
return;
}
}
Expand Down
9 changes: 4 additions & 5 deletions compiler/rustc_target/src/callconv/x86.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use rustc_abi::{
AddressSpace, Align, BackendRepr, HasDataLayout, Primitive, Reg, RegKind, TyAbiInterface,
TyAndLayout,
AddressSpace, Align, BackendRepr, HasDataLayout, Primitive, Reg, RegKind, TyAndLayout,
};

use crate::callconv::{ArgAttribute, FnAbi, PassMode};
use crate::callconv::{ArgAttribute, FnAbi, PassMode, TyAbiInterface};
use crate::spec::{HasTargetSpec, RustcAbi};

#[derive(PartialEq)]
Expand Down Expand Up @@ -175,7 +174,7 @@ pub(crate) fn fill_inregs<'a, Ty, C>(
// At this point we know this must be a primitive of sorts.
let unit = arg.layout.homogeneous_aggregate(cx).unwrap().unit().unwrap();
assert_eq!(unit.size, arg.layout.size);
if matches!(unit.kind, RegKind::Float | RegKind::Vector) {
if matches!(unit.kind, RegKind::Float | RegKind::Vector { .. }) {
continue;
}

Expand Down Expand Up @@ -226,7 +225,7 @@ where
// This is a single scalar that fits into an SSE register, and the target uses the
// SSE ABI. We prefer this over integer registers as float scalars need to be in SSE
// registers for float operations, so that's the best place to pass them around.
fn_abi.ret.cast_to(Reg { kind: RegKind::Vector, size: fn_abi.ret.layout.size });
fn_abi.ret.cast_to(Reg::opaque_vector(fn_abi.ret.layout.size));
} else if fn_abi.ret.layout.size <= Primitive::Pointer(AddressSpace::ZERO).size(cx) {
// Same size or smaller than pointer, return in an integer register.
fn_abi.ret.cast_to(Reg { kind: RegKind::Integer, size: fn_abi.ret.layout.size });
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_target/src/callconv/x86_64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ fn reg_component(cls: &[Option<Class>], i: &mut usize, size: Size) -> Option<Reg
_ => Reg::f64(),
}
} else {
Reg { kind: RegKind::Vector, size: Size::from_bytes(8) * (vec_len as u64) }
Reg::opaque_vector(Size::from_bytes(8) * (vec_len as u64))
})
}
Some(c) => unreachable!("reg_component: unhandled class {:?}", c),
Expand Down
5 changes: 2 additions & 3 deletions compiler/rustc_target/src/callconv/x86_win64.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use rustc_abi::{BackendRepr, Float, Integer, Primitive, RegKind, Size, TyAbiInterface};
use rustc_abi::{BackendRepr, Float, Integer, Primitive, Size, TyAbiInterface};

use crate::callconv::{ArgAbi, FnAbi, Reg};
use crate::spec::{HasTargetSpec, RustcAbi};
Expand Down Expand Up @@ -33,8 +33,7 @@ where
} else {
// `i128` is returned in xmm0 by Clang and GCC
// FIXME(#134288): This may change for the `-msvc` targets in the future.
let reg = Reg { kind: RegKind::Vector, size: Size::from_bits(128) };
a.cast_to(reg);
a.cast_to(Reg::opaque_vector(Size::from_bits(128)));
}
} else if a.layout.size.bytes() > 8
&& !matches!(scalar.primitive(), Primitive::Float(Float::F128))
Expand Down
48 changes: 48 additions & 0 deletions tests/assembly-llvm/aarch64-vld2-s16.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
//@ assembly-output: emit-asm
//@ compile-flags: -Copt-level=3
//@ only-aarch64-unknown-linux-gnu
#![feature(repr_simd, portable_simd, core_intrinsics, f16, f128)]
#![crate_type = "lib"]
#![allow(non_camel_case_types)]

// Test `vld2_s16` can be implemented in a portable way (i.e. without using LLVM neon intrinsics).
// This relies on rust preserving the SIMD vector element type and using it to construct the
// LLVM type. Without this information, additional casts are needed that defeat the LLVM pattern
// matcher, see https://github.com/llvm/llvm-project/issues/181514.

use std::mem::transmute;
use std::simd::Simd;

/// Baseline: forwards `ptr` straight to the `std::arch::aarch64::vld2_s16` intrinsic.
///
/// The assembly assertions in the body expect a single `ld2` of two `4h` vectors from `x0`,
/// immediately followed by `ret`.
#[unsafe(no_mangle)]
#[target_feature(enable = "neon")]
unsafe extern "C" fn vld2_s16_old(ptr: *const i16) -> std::arch::aarch64::int16x4x2_t {
// CHECK-LABEL: vld2_s16_old
// CHECK: .cfi_startproc
// CHECK-NEXT: ld2 { v0.4h, v1.4h }, [x0]
// CHECK-NEXT: ret
std::arch::aarch64::vld2_s16(ptr)
}

/// Portable variant: reads eight `i16` lanes with one unaligned load and de-interleaves them
/// with two shuffles instead of calling the intrinsic.
///
/// The assembly assertions in the body are the same as for `vld2_s16_old`: a single `ld2`
/// followed by `ret`.
#[unsafe(no_mangle)]
#[target_feature(enable = "neon")]
unsafe extern "C" fn vld2_s16_new(a: *const i16) -> std::arch::aarch64::int16x4x2_t {
// CHECK-LABEL: vld2_s16_new
// CHECK: .cfi_startproc
// CHECK-NEXT: ld2 { v0.4h, v1.4h }, [x0]
// CHECK-NEXT: ret

// `V` is one output vector (4 lanes); `W` is the combined input (8 lanes).
type V = Simd<i16, 4>;
type W = Simd<i16, 8>;

// Load all eight lanes at once; the pointer is only assumed to be valid for an unaligned read.
let w: W = std::ptr::read_unaligned(a as *const W);

// Index wrapper handed to `simd_shuffle` as a const argument.
#[repr(simd)]
pub(crate) struct SimdShuffleIdx<const LEN: usize>([u32; LEN]);

// Even-numbered lanes go to the first output vector ...
let v0: V =
std::intrinsics::simd::simd_shuffle(w, w, const { SimdShuffleIdx([0u32, 2, 4, 6]) });
// ... and odd-numbered lanes go to the second.
let v1: V =
std::intrinsics::simd::simd_shuffle(w, w, const { SimdShuffleIdx([1u32, 3, 5, 7]) });

// Reinterpret the pair of vectors as the declared return type.
transmute((v0, v1))
}
98 changes: 98 additions & 0 deletions tests/codegen-llvm/preserve-vec-element-types.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// ignore-tidy-linelength
//@ compile-flags: -Copt-level=3 -Zmerge-functions=disabled --target=aarch64-unknown-linux-gnu
//@ needs-llvm-components: aarch64
//@ add-minicore
#![feature(no_core, repr_simd, f16, f128)]
#![crate_type = "lib"]
#![no_std]
#![no_core]
#![allow(non_camel_case_types)]

// Test that the SIMD vector element type is preserved. This is not required for correctness, but
// useful for optimization. It prevents additional bitcasts that make LLVM patterns fail.

extern crate minicore;
use minicore::*;

/// Local `#[repr(simd)]` vector type wrapping `[T; N]`; the test functions in this file pass
/// it (inside `#[repr(C)]` aggregates) across `extern "C"` boundaries.
#[repr(simd)]
pub struct Simd<T, const N: usize>([T; N]);

/// `#[repr(C)]` aggregate of two fields of the same type `T`.
#[repr(C)]
struct Pair<T>(T, T);

/// `#[repr(C)]` aggregate of three fields of the same type `T`.
#[repr(C)]
struct Triple<T>(T, T, T);

/// `#[repr(C)]` aggregate of four fields of the same type `T`.
#[repr(C)]
struct Quad<T>(T, T, T, T);

#[rustfmt::skip]
mod tests {
use super::*;

// CHECK: define [2 x <8 x i8>] @pair_int8x8_t([2 x <8 x i8>] {{.*}} %0)
#[unsafe(no_mangle)] extern "C" fn pair_int8x8_t(x: Pair<Simd<i8, 8>>) -> Pair<Simd<i8, 8>> { x }

// CHECK: define [2 x <4 x i16>] @pair_int16x4_t([2 x <4 x i16>] {{.*}} %0)
#[unsafe(no_mangle)] extern "C" fn pair_int16x4_t(x: Pair<Simd<i16, 4>>) -> Pair<Simd<i16, 4>> { x }

// CHECK: define [2 x <2 x i32>] @pair_int32x2_t([2 x <2 x i32>] {{.*}} %0)
#[unsafe(no_mangle)] extern "C" fn pair_int32x2_t(x: Pair<Simd<i32, 2>>) -> Pair<Simd<i32, 2>> { x }

// CHECK: define [2 x <1 x i64>] @pair_int64x1_t([2 x <1 x i64>] {{.*}} %0)
#[unsafe(no_mangle)] extern "C" fn pair_int64x1_t(x: Pair<Simd<i64, 1>>) -> Pair<Simd<i64, 1>> { x }

// CHECK: define [2 x <4 x half>] @pair_float16x4_t([2 x <4 x half>] {{.*}} %0)
#[unsafe(no_mangle)] extern "C" fn pair_float16x4_t(x: Pair<Simd<f16, 4>>) -> Pair<Simd<f16, 4>> { x }

// CHECK: define [2 x <2 x float>] @pair_float32x2_t([2 x <2 x float>] {{.*}} %0)
#[unsafe(no_mangle)] extern "C" fn pair_float32x2_t(x: Pair<Simd<f32, 2>>) -> Pair<Simd<f32, 2>> { x }

// CHECK: define [2 x <1 x double>] @pair_float64x1_t([2 x <1 x double>] {{.*}} %0)
#[unsafe(no_mangle)] extern "C" fn pair_float64x1_t(x: Pair<Simd<f64, 1>>) -> Pair<Simd<f64, 1>> { x }

// CHECK: define [2 x <1 x ptr>] @pair_ptrx1_t([2 x <1 x ptr>] {{.*}} %0)
#[unsafe(no_mangle)] extern "C" fn pair_ptrx1_t(x: Pair<Simd<*const (), 1>>) -> Pair<Simd<*const (), 1>> { x }

// When it fits in a 128-bit register, it's passed directly.
Copy link
Copy Markdown
Contributor Author

@folkertdev folkertdev Apr 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

View changes since the review

the types above are actually used by neon intrinsics. Below are a couple that technically work but are unlikely to come up practically.

Then, for any type that is smaller than 128 bits, padding is added, which means the type information is lost (but I think that is needed for ABI reasons?). Larger types are passed indirectly, so the type information is not needed there (but we do still technically provide it; maybe it's useful elsewhere).


// CHECK: define [4 x <4 x i8>] @quad_int8x4_t([4 x <4 x i8>] {{.*}} %0)
#[unsafe(no_mangle)] extern "C" fn quad_int8x4_t(x: Quad<Simd<i8, 4>>) -> Quad<Simd<i8, 4>> { x }

// CHECK: define [4 x <2 x i16>] @quad_int16x2_t([4 x <2 x i16>] {{.*}} %0)
#[unsafe(no_mangle)] extern "C" fn quad_int16x2_t(x: Quad<Simd<i16, 2>>) -> Quad<Simd<i16, 2>> { x }

// CHECK: define [4 x <1 x i32>] @quad_int32x1_t([4 x <1 x i32>] {{.*}} %0)
#[unsafe(no_mangle)] extern "C" fn quad_int32x1_t(x: Quad<Simd<i32, 1>>) -> Quad<Simd<i32, 1>> { x }

// CHECK: define [4 x <2 x half>] @quad_float16x2_t([4 x <2 x half>] {{.*}} %0)
#[unsafe(no_mangle)] extern "C" fn quad_float16x2_t(x: Quad<Simd<f16, 2>>) -> Quad<Simd<f16, 2>> { x }

// CHECK: define [4 x <1 x float>] @quad_float32x1_t([4 x <1 x float>] {{.*}} %0)
#[unsafe(no_mangle)] extern "C" fn quad_float32x1_t(x: Quad<Simd<f32, 1>>) -> Quad<Simd<f32, 1>> { x }

// When it doesn't quite fit, padding is added which does erase the type.

// CHECK: define [2 x i64] @triple_int8x4_t
#[unsafe(no_mangle)] extern "C" fn triple_int8x4_t(x: Triple<Simd<i8, 4>>) -> Triple<Simd<i8, 4>> { x }

// Other configurations are not passed by-value but indirectly.

// CHECK: define void @pair_int128x1_t
#[unsafe(no_mangle)] extern "C" fn pair_int128x1_t(x: Pair<Simd<i128, 1>>) -> Pair<Simd<i128, 1>> { x }

// CHECK: define void @pair_float128x1_t
#[unsafe(no_mangle)] extern "C" fn pair_float128x1_t(x: Pair<Simd<f128, 1>>) -> Pair<Simd<f128, 1>> { x }

// CHECK: define void @pair_int8x16_t
#[unsafe(no_mangle)] extern "C" fn pair_int8x16_t(x: Pair<Simd<i8, 16>>) -> Pair<Simd<i8, 16>> { x }

// CHECK: define void @pair_int16x8_t
#[unsafe(no_mangle)] extern "C" fn pair_int16x8_t(x: Pair<Simd<i16, 8>>) -> Pair<Simd<i16, 8>> { x }

// CHECK: define void @triple_int16x8_t
#[unsafe(no_mangle)] extern "C" fn triple_int16x8_t(x: Triple<Simd<i16, 8>>) -> Triple<Simd<i16, 8>> { x }

// CHECK: define void @quad_int16x8_t
#[unsafe(no_mangle)] extern "C" fn quad_int16x8_t(x: Quad<Simd<i16, 8>>) -> Quad<Simd<i16, 8>> { x }
}
Loading