@@ -21,6 +21,59 @@ namespace opt {
 namespace {
 constexpr uint32_t kExtractCompositeIdInIdx = 0;
 
+// Returns the value obtained by extracting the |number_of_bits| least
+// significant bits from |value|, and sign-extending it to 64 bits.
+uint64_t SignExtendValue(uint64_t value, uint32_t number_of_bits) {
+  if (number_of_bits == 64) return value;
+
+  uint64_t mask_for_sign_bit = 1ull << (number_of_bits - 1);
+  uint64_t mask_for_significant_bits = (mask_for_sign_bit << 1) - 1ull;
+  if (value & mask_for_sign_bit) {
+    // Set the upper bits to 1.
+    value |= ~mask_for_significant_bits;
+  } else {
+    // Clear the upper bits.
+    value &= mask_for_significant_bits;
+  }
+  return value;
+}
+
+// Returns the value obtained by extracting the |number_of_bits| least
+// significant bits from |value|, and zero-extending it to 64 bits.
+uint64_t ZeroExtendValue(uint64_t value, uint32_t number_of_bits) {
+  if (number_of_bits == 64) return value;
+
+  uint64_t mask_for_first_bit_to_clear = 1ull << (number_of_bits);
+  uint64_t mask_for_bits_to_keep = mask_for_first_bit_to_clear - 1;
+  value &= mask_for_bits_to_keep;
+  return value;
+}
+
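As a sanity check on these two helpers, here is a minimal standalone sketch (the helpers are reproduced verbatim so it compiles on its own; the `main` driver is illustrative and not part of the patch):

```cpp
#include <cassert>
#include <cstdint>

// Copies of the helpers above, so this sketch is self-contained.
uint64_t SignExtendValue(uint64_t value, uint32_t number_of_bits) {
  if (number_of_bits == 64) return value;
  uint64_t mask_for_sign_bit = 1ull << (number_of_bits - 1);
  uint64_t mask_for_significant_bits = (mask_for_sign_bit << 1) - 1ull;
  if (value & mask_for_sign_bit) {
    value |= ~mask_for_significant_bits;  // sign bit set: fill upper bits with 1s
  } else {
    value &= mask_for_significant_bits;  // sign bit clear: fill upper bits with 0s
  }
  return value;
}

uint64_t ZeroExtendValue(uint64_t value, uint32_t number_of_bits) {
  if (number_of_bits == 64) return value;
  return value & ((1ull << number_of_bits) - 1);
}

int main() {
  // 0b101 read as a 3-bit signed value is -3; sign-extended to 64 bits.
  assert(SignExtendValue(0b101, 3) == 0xFFFFFFFFFFFFFFFDull);
  // 0b011 read as a 3-bit signed value is +3; the upper bits are cleared.
  assert(SignExtendValue(0b011, 3) == 3);
  // Zero extension masks off everything above the given width.
  assert(ZeroExtendValue(0xFF, 4) == 0xF);
}
```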
+// Returns a constant whose value is `result` and whose type is
+// `integer_type`. This constant will be generated by `const_mgr`. The type
+// must be a scalar integer type.
+const analysis::Constant* GenerateIntegerConstant(
+    const analysis::Integer* integer_type, uint64_t result,
+    analysis::ConstantManager* const_mgr) {
+  assert(integer_type != nullptr);
+
+  std::vector<uint32_t> words;
+  if (integer_type->width() == 64) {
+    // In the 64-bit case, two words are needed to represent the value.
+    words = {static_cast<uint32_t>(result),
+             static_cast<uint32_t>(result >> 32)};
+  } else {
+    // In all other cases, only a single word is needed.
+    assert(integer_type->width() <= 32);
+    if (integer_type->IsSigned()) {
+      result = SignExtendValue(result, integer_type->width());
+    } else {
+      result = ZeroExtendValue(result, integer_type->width());
+    }
+    words = {static_cast<uint32_t>(result)};
+  }
+  return const_mgr->GetConstant(integer_type, words);
+}
+
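The word packing above follows the SPIR-V literal layout: a 64-bit value is split low word first, and a narrower signed value is sign-extended into its single word. A small sketch of the arithmetic, assuming the `SignExtendValue` helper from the top of the patch is in scope:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // 64-bit constants are packed low word first, as GenerateIntegerConstant does.
  uint64_t v = 0x0000000100000002ull;
  std::vector<uint32_t> words = {static_cast<uint32_t>(v),
                                 static_cast<uint32_t>(v >> 32)};
  assert(words[0] == 2u && words[1] == 1u);

  // A 16-bit signed -1 occupies one word; after sign extension the stored
  // word is 0xFFFFFFFF, not 0x0000FFFF.
  // SignExtendValue is the helper defined earlier in this patch.
  uint64_t minus_one = SignExtendValue(0xFFFFull, 16);
  assert(static_cast<uint32_t>(minus_one) == 0xFFFFFFFFu);
}
```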
 // Returns a constant with the value NaN of the given type. Only works for
 // 32-bit and 64-bit floating point types. Returns |nullptr| if an error
 // occurs.
 const analysis::Constant* GetNan(const analysis::Type* type,
@@ -676,7 +729,6 @@ ConstantFoldingRule FoldUnaryOp(UnaryScalarFoldingRule scalar_rule) {
   return [scalar_rule](IRContext* context, Instruction* inst,
                        const std::vector<const analysis::Constant*>& constants)
              -> const analysis::Constant* {
-
     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
     analysis::TypeManager* type_mgr = context->get_type_mgr();
     const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
@@ -716,6 +768,64 @@ ConstantFoldingRule FoldUnaryOp(UnaryScalarFoldingRule scalar_rule) {
   };
 }
 
+// Returns a |ConstantFoldingRule| that folds binary scalar ops
+// using |scalar_rule| and binary vector ops by applying
+// |scalar_rule| to the elements of the vector. The folding rule assumes that
+// the op has two inputs. For regular instructions, those are in operands 0
+// and 1. For extended instructions, they are in operands 1 and 2. If an
+// element in |constants| is not nullptr, then the constant's type is |Float|,
+// |Integer|, or a |Vector| whose element type is |Float| or |Integer|.
+ConstantFoldingRule FoldBinaryOp(BinaryScalarFoldingRule scalar_rule) {
+  return [scalar_rule](IRContext* context, Instruction* inst,
+                       const std::vector<const analysis::Constant*>& constants)
+             -> const analysis::Constant* {
+    assert(constants.size() == inst->NumInOperands());
+    assert(constants.size() == (inst->opcode() == spv::Op::OpExtInst ? 3 : 2));
+    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+    analysis::TypeManager* type_mgr = context->get_type_mgr();
+    const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
+    const analysis::Vector* vector_type = result_type->AsVector();
+
+    const analysis::Constant* arg1 =
+        (inst->opcode() == spv::Op::OpExtInst) ? constants[1] : constants[0];
+    const analysis::Constant* arg2 =
+        (inst->opcode() == spv::Op::OpExtInst) ? constants[2] : constants[1];
+
+    if (arg1 == nullptr || arg2 == nullptr) {
+      return nullptr;
+    }
+
+    if (vector_type == nullptr) {
+      return scalar_rule(result_type, arg1, arg2, const_mgr);
+    }
+
+    std::vector<const analysis::Constant*> a_components;
+    std::vector<const analysis::Constant*> b_components;
+    std::vector<const analysis::Constant*> results_components;
+
+    a_components = arg1->GetVectorComponents(const_mgr);
+    b_components = arg2->GetVectorComponents(const_mgr);
+    assert(a_components.size() == b_components.size());
+
+    // Fold each component of the vector.
+    for (uint32_t i = 0; i < a_components.size(); ++i) {
+      results_components.push_back(scalar_rule(vector_type->element_type(),
+                                               a_components[i], b_components[i],
+                                               const_mgr));
+      if (results_components[i] == nullptr) {
+        return nullptr;
+      }
+    }
+
+    // Build the constant object and return it.
+    std::vector<uint32_t> ids;
+    for (const analysis::Constant* member : results_components) {
+      ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
+    }
+    return const_mgr->GetConstant(vector_type, ids);
+  };
+}
+
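The scalar-vs-vector dispatch here is the key pattern: a scalar result folds directly, a vector folds component by component, and any component that fails to fold abandons the whole fold. A minimal standalone sketch of that pattern, with plain integers standing in for `analysis::Constant` (no spvtools types are used):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <optional>
#include <vector>

using Scalar = uint64_t;
using ScalarRule = std::function<std::optional<Scalar>(Scalar, Scalar)>;

// Applies |rule| componentwise, mirroring FoldBinaryOp's vector path:
// if any component fails to fold, the whole fold is abandoned.
std::optional<std::vector<Scalar>> FoldVector(const ScalarRule& rule,
                                              const std::vector<Scalar>& a,
                                              const std::vector<Scalar>& b) {
  assert(a.size() == b.size());
  std::vector<Scalar> result;
  for (size_t i = 0; i < a.size(); ++i) {
    std::optional<Scalar> folded = rule(a[i], b[i]);
    if (!folded) return std::nullopt;  // one unfoldable component kills the fold
    result.push_back(*folded);
  }
  return result;
}

int main() {
  ScalarRule add = [](Scalar x, Scalar y) {
    return std::optional<Scalar>(x + y);
  };
  auto folded = FoldVector(add, {1, 2}, {3, 4});
  assert(folded && (*folded)[0] == 4 && (*folded)[1] == 6);
}
```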
 // Returns a |ConstantFoldingRule| that folds unary floating point scalar ops
 // using |scalar_rule| and unary floating point vector ops by applying
 // |scalar_rule| to the elements of the vector. The |ConstantFoldingRule|
@@ -1587,6 +1697,72 @@ BinaryScalarFoldingRule FoldFTranscendentalBinary(double (*fp)(double,
     return nullptr;
   };
 }
+
+enum Sign { Signed, Unsigned };
+
+// Returns a BinaryScalarFoldingRule that applies `op` to the scalars.
+// The `signedness` is used to determine if the operands should be interpreted
+// as signed or unsigned. If the operands are signed, the values will be sign
+// extended before being passed to `op`. Otherwise the values will be zero
+// extended.
+template <Sign signedness>
+BinaryScalarFoldingRule FoldBinaryIntegerOperation(uint64_t (*op)(uint64_t,
+                                                                  uint64_t)) {
+  return
+      [op](const analysis::Type* result_type, const analysis::Constant* a,
+           const analysis::Constant* b,
+           analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
+        assert(result_type != nullptr && a != nullptr && b != nullptr);
+        const analysis::Integer* integer_type = result_type->AsInteger();
+        assert(integer_type != nullptr);
+        assert(integer_type == a->type()->AsInteger());
+        assert(integer_type == b->type()->AsInteger());
+
+        // In SPIR-V, all operations support unsigned types, but the way they
+        // are interpreted depends on the opcode. This is why we use the
+        // template argument to determine how to interpret the operands.
+        uint64_t ia = (signedness == Signed ? a->GetSignExtendedValue()
+                                            : a->GetZeroExtendedValue());
+        uint64_t ib = (signedness == Signed ? b->GetSignExtendedValue()
+                                            : b->GetZeroExtendedValue());
+        uint64_t result = op(ia, ib);
+
+        const analysis::Constant* result_constant =
+            GenerateIntegerConstant(integer_type, result, const_mgr);
+        return result_constant;
+      };
+}
+
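The template argument matters because the same 64-bit payload means different things under the two instantiations. A standalone sketch of the difference, using plain `uint64_t` arithmetic in place of the `GetSignExtendedValue`/`GetZeroExtendedValue` accessors:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // The bit pattern a 32-bit -2 produces once sign-extended to 64 bits,
  // which is what GetSignExtendedValue() would hand to the lambda.
  uint64_t a_signed = 0xFFFFFFFFFFFFFFFEull;  // -2 as int64_t
  uint64_t b = 2;

  // Signed instantiation: reinterpret the payloads as int64_t, then divide.
  int64_t sdiv = static_cast<int64_t>(a_signed) / static_cast<int64_t>(b);
  assert(sdiv == -1);

  // Unsigned instantiation: GetZeroExtendedValue() for the same 32-bit
  // operand would instead produce 0x00000000FFFFFFFE, and the division
  // stays unsigned.
  uint64_t a_unsigned = 0x00000000FFFFFFFEull;
  assert(a_unsigned / b == 0x7FFFFFFFull);
}
```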
+// A scalar folding rule that folds OpSConvert.
+const analysis::Constant* FoldScalarSConvert(
+    const analysis::Type* result_type, const analysis::Constant* a,
+    analysis::ConstantManager* const_mgr) {
+  assert(result_type != nullptr);
+  assert(a != nullptr);
+  assert(const_mgr != nullptr);
+  const analysis::Integer* integer_type = result_type->AsInteger();
+  assert(integer_type && "The result type of an SConvert must be an integer.");
+  int64_t value = a->GetSignExtendedValue();
+  return GenerateIntegerConstant(integer_type, value, const_mgr);
+}
+
+// A scalar folding rule that folds OpUConvert.
+const analysis::Constant* FoldScalarUConvert(
+    const analysis::Type* result_type, const analysis::Constant* a,
+    analysis::ConstantManager* const_mgr) {
+  assert(result_type != nullptr);
+  assert(a != nullptr);
+  assert(const_mgr != nullptr);
+  const analysis::Integer* integer_type = result_type->AsInteger();
+  assert(integer_type && "The result type of a UConvert must be an integer.");
+  uint64_t value = a->GetZeroExtendedValue();
+
+  // If the operand was an unsigned value narrower than 32 bits, it would have
+  // been sign extended earlier, and we need to clear those upper bits.
+  auto* operand_type = a->type()->AsInteger();
+  value = ZeroExtendValue(value, operand_type->width());
+  return GenerateIntegerConstant(integer_type, value, const_mgr);
+}
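On concrete bit patterns the two conversions behave as follows. This sketch assumes the `SignExtendValue`/`ZeroExtendValue` helpers from the top of this patch are in scope; it exercises only their arithmetic, not the spvtools API:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // OpSConvert of an 8-bit -128 (0x80) to 32 bits: sign extend, then keep
  // the low 32 bits, giving 0xFFFFFF80.
  uint64_t sconv = SignExtendValue(0x80, 8);  // 0xFFFFFFFFFFFFFF80
  assert(ZeroExtendValue(sconv, 32) == 0xFFFFFF80ull);

  // OpUConvert of a 16-bit -1 whose stored words were sign extended
  // (0x...FFFF): clearing the bits above the operand width recovers 0xFFFF,
  // which then zero extends into the wider result type.
  uint64_t stored = 0xFFFFFFFFFFFFFFFFull;
  assert(ZeroExtendValue(stored, 16) == 0xFFFFull);
}
```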
 }  // namespace
 
 void ConstantFoldingRules::AddFoldingRules() {
@@ -1604,6 +1780,8 @@ void ConstantFoldingRules::AddFoldingRules() {
   rules_[spv::Op::OpConvertFToU].push_back(FoldFToI());
   rules_[spv::Op::OpConvertSToF].push_back(FoldIToF());
   rules_[spv::Op::OpConvertUToF].push_back(FoldIToF());
+  rules_[spv::Op::OpSConvert].push_back(FoldUnaryOp(FoldScalarSConvert));
+  rules_[spv::Op::OpUConvert].push_back(FoldUnaryOp(FoldScalarUConvert));
 
   rules_[spv::Op::OpDot].push_back(FoldOpDotWithConstants());
   rules_[spv::Op::OpFAdd].push_back(FoldFAdd());
@@ -1662,6 +1840,46 @@ void ConstantFoldingRules::AddFoldingRules() {
   rules_[spv::Op::OpSNegate].push_back(FoldSNegate());
   rules_[spv::Op::OpQuantizeToF16].push_back(FoldQuantizeToF16());
 
+  rules_[spv::Op::OpIAdd].push_back(
+      FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
+          [](uint64_t a, uint64_t b) { return a + b; })));
+  rules_[spv::Op::OpISub].push_back(
+      FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
+          [](uint64_t a, uint64_t b) { return a - b; })));
+  rules_[spv::Op::OpIMul].push_back(
+      FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
+          [](uint64_t a, uint64_t b) { return a * b; })));
+  rules_[spv::Op::OpUDiv].push_back(
+      FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
+          [](uint64_t a, uint64_t b) { return (b != 0 ? a / b : 0); })));
+  rules_[spv::Op::OpSDiv].push_back(FoldBinaryOp(
+      FoldBinaryIntegerOperation<Signed>([](uint64_t a, uint64_t b) {
+        return (b != 0 ? static_cast<uint64_t>(static_cast<int64_t>(a) /
+                                               static_cast<int64_t>(b))
+                       : 0);
+      })));
+  rules_[spv::Op::OpUMod].push_back(
+      FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
+          [](uint64_t a, uint64_t b) { return (b != 0 ? a % b : 0); })));
+
+  rules_[spv::Op::OpSRem].push_back(FoldBinaryOp(
+      FoldBinaryIntegerOperation<Signed>([](uint64_t a, uint64_t b) {
+        return (b != 0 ? static_cast<uint64_t>(static_cast<int64_t>(a) %
+                                               static_cast<int64_t>(b))
+                       : 0);
+      })));
+
+  rules_[spv::Op::OpSMod].push_back(FoldBinaryOp(
+      FoldBinaryIntegerOperation<Signed>([](uint64_t a, uint64_t b) {
+        if (b == 0) return static_cast<uint64_t>(0ull);
+
+        int64_t signed_a = static_cast<int64_t>(a);
+        int64_t signed_b = static_cast<int64_t>(b);
+        int64_t result = signed_a % signed_b;
+        if ((signed_b < 0) != (result < 0)) result += signed_b;
+        return static_cast<uint64_t>(result);
+      })));
+
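The SRem and SMod lambdas differ only in which operand's sign the remainder follows: SRem keeps C++'s truncated `%` (sign of the dividend), while SMod adjusts the result to take the sign of the divisor. A standalone check of that arithmetic, with the SMod adjustment copied from the lambda above:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  auto srem = [](int64_t a, int64_t b) { return a % b; };
  auto smod = [](int64_t a, int64_t b) {
    int64_t result = a % b;
    // SMod: shift the remainder so it takes the sign of the divisor b.
    if ((b < 0) != (result < 0)) result += b;
    return result;
  };

  assert(srem(-7, 3) == -1);  // sign follows the dividend
  assert(smod(-7, 3) == 2);   // sign follows the divisor
  assert(srem(7, -3) == 1);
  assert(smod(7, -3) == -2);
}
```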
   // Add rules for GLSLstd450
   FeatureManager* feature_manager = context_->get_feature_mgr();
   uint32_t ext_inst_glslstd450_id =