Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
256 changes: 256 additions & 0 deletions src/cpu/x64/jit_generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible {

private:
const size_t xmm_len = 16;
const size_t ymm_len = 32;
const size_t zmm_len = 64;
#ifdef _WIN32
const size_t xmm_to_preserve_start = 6;
const size_t xmm_to_preserve = 10;
Expand All @@ -162,6 +164,26 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible {
= num_abi_save_gpr_regs * rax.getBit() / 8
+ xmm_to_preserve * xmm_len;

template<typename TOutReg, typename TInReg>
inline TOutReg get_free_reg(std::vector<int>& reg_idxs,
std::vector<TInReg>& not_available) {
std::vector<int> not_available_idx(not_available.size());
std::transform(not_available.begin(), not_available.end(), not_available_idx.begin(),
[](const Xbyak::Reg& reg) {
return reg.getIdx();
});
auto removed = std::remove_if(reg_idxs.begin(), reg_idxs.end(),
[&not_available_idx](const int& reg_idx) {
return not_available_idx.end() != std::find(not_available_idx.begin(),
not_available_idx.end(),
reg_idx);
});
reg_idxs.erase(removed, reg_idxs.end());
TOutReg alloc_reg{reg_idxs.front()};
not_available.push_back(alloc_reg);
return alloc_reg;
}

public:
enum {
_cmp_eq_oq = 0u,
Expand All @@ -182,6 +204,81 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible {

inline size_t get_size_of_abi_save_regs() { return size_of_abi_save_regs; }

using Xbyak::CodeGenerator::push;
using Xbyak::CodeGenerator::pop;

inline void push(const Xbyak::Xmm &xmm) {
sub(rsp, xmm_len);
uni_vmovdqu(ptr[rsp], xmm);
}

inline void push(const std::vector<Xbyak::Xmm> &xmms) {
sub(rsp, xmms.size() * xmm_len);
for (size_t i = 0; i < xmms.size(); ++i) {
uni_vmovdqu(ptr[rsp + i * xmm_len], xmms[i]);
}
}

inline void push(const Xbyak::Ymm &ymm) {
sub(rsp, ymm_len);
uni_vmovdqu(ptr[rsp], ymm);
}

inline void push(const std::vector<Xbyak::Ymm> &ymms) {
sub(rsp, ymms.size() * ymm_len);
for (size_t i = 0; i < ymms.size(); ++i) {
uni_vmovdqu(ptr[rsp + i * ymm_len], ymms[i]);
}
}

inline void push(const Xbyak::Zmm &zmm) {
sub(rsp, zmm_len);
uni_vmovdqu(ptr[rsp], zmm);
}

inline void push(const std::vector<Xbyak::Zmm> &zmms) {
sub(rsp, zmms.size() * zmm_len);
for (size_t i = 0; i < zmms.size(); ++i) {
uni_vmovdqu(ptr[rsp + i * zmm_len], zmms[i]);
}
}

inline void pop(const Xbyak::Xmm &xmm) {
uni_vmovdqu(xmm, ptr[rsp]);
add(rsp, xmm_len);
}

inline void pop(const std::vector<Xbyak::Xmm> &xmms) {
for (size_t i = 0; i < xmms.size(); ++i) {
uni_vmovdqu(xmms[i], ptr[rsp + i * xmm_len]);
}
sub(rsp, xmms.size() * xmm_len);
}

inline void pop(const Xbyak::Ymm &ymm) {
uni_vmovdqu(ymm, ptr[rsp]);
add(rsp, ymm_len);
}

inline void pop(const std::vector<Xbyak::Ymm> &ymms) {
for (size_t i = 0; i < ymms.size(); ++i) {
uni_vmovdqu(ymms[i], ptr[rsp + i * ymm_len]);
}
sub(rsp, ymms.size() * ymm_len);
}

inline void pop(const Xbyak::Zmm &zmm) {
uni_vmovdqu(zmm, ptr[rsp]);
add(rsp, zmm_len);
}

inline void pop(const std::vector<Xbyak::Zmm> &zmms) {
for (size_t i = 0; i < zmms.size(); ++i) {
uni_vmovdqu(zmms[i], ptr[rsp + i * zmm_len]);
}
sub(rsp, zmms.size() * zmm_len);
}

void preamble() {
if (xmm_to_preserve) {
sub(rsp, xmm_to_preserve * xmm_len);
Expand Down Expand Up @@ -339,6 +436,165 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible {
vpxord(x1, x2, op);
}

template<typename TReg>
inline TReg get_free_reg(std::vector<Xbyak::Reg>& not_available) {
static_assert(std::is_base_of<Xbyak::Reg, TReg>::value, "Xbyak::Reg should be base of Tmm");
const size_t regsNumber = 16;
std::vector<int> reg_idxs;
reg_idxs.reserve(regsNumber);
for (int i = 0; i < static_cast<int>(regsNumber); ++i) {
// NOTE: We should avoid allocation rsp, otherwise we could write in
// wrong stack and crash application
if (rsp.getIdx() != i) {
reg_idxs.push_back(i);
}
}
return get_free_reg<TReg>(reg_idxs, not_available);
}

template<typename TVmm>
inline TVmm get_free_reg(std::vector<Xbyak::Xmm>& not_available) {
static_assert(std::is_base_of<Xbyak::Xmm, TVmm>::value, "Xbyak::Xmm should be base of TVmm");
std::vector<int> xmm_idxs(8);
#ifdef XBYAK64
size_t simdNumber = 0;
if (is_valid_isa(cpu_isa_t::avx512_core)) {
simdNumber = x64::cpu_isa_traits<cpu_isa_t::avx512_core>::vlen;
} else if (is_valid_isa(cpu_isa_t::avx2)) {
simdNumber = x64::cpu_isa_traits<cpu_isa_t::avx2>::vlen;
} else {
simdNumber = x64::cpu_isa_traits<cpu_isa_t::sse41>::vlen;
}
xmm_idxs.reserve(simdNumber);
for (int i = 0; i < static_cast<int>(simdNumber); ++i) {
xmm_idxs.push_back(i);
}
#endif
return get_free_reg<TVmm>(xmm_idxs, not_available);
}

inline void uni_vgatherdps(const Xbyak::Xmm &xmm_val,
const Xbyak::Reg64 &reg_addr,
const Xbyak::Xmm &xmm_index,
const int &scale,
const int &disp,
const Xbyak::Reg &reg_mask) {
const size_t kDataTypeSize = sizeof(float);
if (is_valid_isa(cpu_isa_t::avx512_core)) {
assert(reg_mask.isOPMASK());
vgatherdps(xmm_val, ptr[reg_addr + xmm_index * scale + disp]);
Copy link
Copy Markdown

@avoskoboinyk-lohika avoskoboinyk-lohika Sep 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we apply reg_mask here, in addition to xmm_val in vgatherdps()?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, according the documentation the reg_mask will be applied implicitly, see section:
https://www.felixcloutier.com/x86/vgatherdps:vgatherdpd#instruction-operand-encoding

} else if (is_valid_isa(cpu_isa_t::avx2)) {
assert(reg_mask.isYMM());
Xbyak::Ymm ymm_mask{reg_mask.getIdx()};
vgatherdps(xmm_val, ptr[reg_addr + xmm_index * scale + disp], ymm_mask);
} else {
const size_t kSimdWidth = x64::cpu_isa_traits<cpu_isa_t::sse41>::vlen / kDataTypeSize;
assert(reg_mask.isXMM());
Xbyak::Xmm xmm_mask{reg_mask.getIdx()};
assert(xmm_val.getKind() == xmm_index.getKind());
assert(xmm_index.getKind() == xmm_mask.getKind());

std::vector<Xbyak::Reg> not_available_reg{reg_addr};
const Xbyak::Reg64 idx = this->get_free_reg<Xbyak::Reg64>(not_available_reg);
const Xbyak::Reg64 mask = this->get_free_reg<Xbyak::Reg64>(not_available_reg);

push(idx);
push(mask);
xor_(idx, idx);
xor_(mask, mask);

for (int i = 0; i < static_cast<int>(kSimdWidth); i++) {
Xbyak::Label gather_end;
uni_vpextrd(mask.cvt32(), xmm_mask, i);
cmp(mask.cvt32(), 0xFFFFFFFF);
jne(gather_end, T_NEAR);
uni_vpextrd(idx.cvt32(), xmm_index, i);
Xbyak::Address addr = ptr[reg_addr + idx * scale + disp];
uni_vpinsrd(xmm_val, xmm_val, addr, i);
L(gather_end);
}
pop(mask);
pop(idx);
}
}

inline void uni_vscatterdps(const Xbyak::Reg64& reg_addr,
const Xbyak::Xmm& xmm_index,
const int scale,
const int disp,
const Xbyak::Xmm& xmm_val,
const Xbyak::Reg& reg_mask) {
const size_t kDataTypeSize = sizeof(float);
if (is_valid_isa(cpu_isa_t::avx512_core)) {
assert(reg_mask.isOPMASK());
vscatterdps(ptr[reg_addr + xmm_index * scale + disp], xmm_val);
} else {
assert(reg_mask.isXMM() || reg_mask.isYMM());
const size_t kXmmSimdWidth = x64::cpu_isa_traits<cpu_isa_t::sse41>::vlen / kDataTypeSize;
const size_t kYmmSimdWidth = x64::cpu_isa_traits<cpu_isa_t::avx2>::vlen / kDataTypeSize;
Xbyak::Xmm xmm_mask{reg_mask.getIdx(), reg_mask.getKind(), static_cast<int>(reg_mask.getBit())};
assert(xmm_val.getKind() == xmm_index.getKind());
assert(xmm_index.getKind() == xmm_mask.getKind());

std::vector<Xbyak::Reg> not_available_reg{reg_addr};
std::vector<Xbyak::Xmm> not_available_xmm{xmm_index, xmm_val, xmm_mask};
const Xbyak::Reg64 idx = this->get_free_reg<Xbyak::Reg64>(not_available_reg);
const Xbyak::Reg64 mask = this->get_free_reg<Xbyak::Reg64>(not_available_reg);
const Xbyak::Reg64 val = this->get_free_reg<Xbyak::Reg64>(not_available_reg);
const Xbyak::Xmm xmm_mask_temp = this->get_free_reg<Xbyak::Xmm>(not_available_xmm);
const Xbyak::Xmm xmm_index_temp = this->get_free_reg<Xbyak::Xmm>(not_available_xmm);
const Xbyak::Xmm xmm_val_temp = this->get_free_reg<Xbyak::Xmm>(not_available_xmm);

push(idx);
push(mask);
push(val);
if (is_valid_isa(cpu_isa_t::avx2)) {
push(Xbyak::Ymm{xmm_mask_temp.getIdx()});
push(Xbyak::Ymm{xmm_index_temp.getIdx()});
push(Xbyak::Ymm{xmm_val_temp.getIdx()});
}
xor_(idx, idx);
xor_(mask, mask);
xor_(val, val);

auto store_xmm = [&](const Xbyak::Xmm& xmm_mask,
const Xbyak::Xmm& xmm_index,
const Xbyak::Xmm& xmm_val) {
for (int i = 0; i < static_cast<int>(kXmmSimdWidth); i++) {
Xbyak::Label scatter_end;
uni_vpextrd(mask.cvt32(), xmm_mask, i);
cmp(mask.cvt32(), 0xFFFFFFFF);
jne(scatter_end, T_NEAR);
uni_vpextrd(idx.cvt32(), xmm_index, i);
Xbyak::Address addr = ptr[reg_addr + idx * scale];
uni_vpextrd(val.cvt32(), xmm_val, i);
mov(addr, val.cvt32());
L(scatter_end);
}
};

if (is_valid_isa(cpu_isa_t::avx2)) {
for (int i = 0; i < static_cast<int>(kYmmSimdWidth / kXmmSimdWidth); i++) {
vextracti128(xmm_mask_temp, Xbyak::Ymm{xmm_mask.getIdx()}, i);
vextracti128(xmm_index_temp, Xbyak::Ymm{xmm_index.getIdx()}, i);
vextracti128(xmm_val_temp, Xbyak::Ymm{xmm_val.getIdx()}, i);
store_xmm(xmm_mask_temp, xmm_index_temp, xmm_val_temp);
}
} else {
store_xmm(xmm_mask, xmm_index, xmm_val);
}

if (is_valid_isa(cpu_isa_t::avx2)) {
pop(Xbyak::Ymm{xmm_val_temp.getIdx()});
pop(Xbyak::Ymm{xmm_index_temp.getIdx()});
pop(Xbyak::Ymm{xmm_mask_temp.getIdx()});
}
pop(val);
pop(mask);
pop(idx);
}
}

void uni_vmovss(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
if (is_valid_isa(avx))
vmovss(addr, x);
Expand Down