Skip to content

Commit 96ab03b

Browse files
committed
update
1 parent 1acf667 commit 96ab03b

5 files changed

Lines changed: 97 additions & 103 deletions

File tree

llvm/lib/CodeGen/GlobalISel/PatternGen.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ std::string lltToRegTypeStr(LLT Type, bool IsFloat) {
199199
return "GPR32V2";
200200
abort();
201201
} else
202-
return IsFloat ? "FPR32" : "GPR";
202+
return IsFloat ? ("FPR" + std::to_string(PatternGenArgs::Args.FLen)) : "GPR";
203203
}
204204
assert(0 && "invalid type");
205205
return "invalid";
@@ -346,6 +346,7 @@ struct TernopNode : public PatternNode {
346346
static const std::unordered_map<int, std::string> TernopStr = {
347347
{TargetOpcode::G_FSHL, "fshl"},
348348
{TargetOpcode::G_FSHR, "fshr"},
349+
{TargetOpcode::G_FMA, "any_fma"},
349350
{TargetOpcode::G_INSERT_VECTOR_ELT, "vector_insert"},
350351
{TargetOpcode::G_SELECT, "select"}};
351352

@@ -619,8 +620,8 @@ struct ConstantNode : public PatternNode {
619620
};
620621

621622
struct ConstantFPNode : public PatternNode {
622-
float Constant;
623-
ConstantFPNode(LLT Type, float Const)
623+
double Constant;
624+
ConstantFPNode(LLT Type, double Const)
624625
: PatternNode(PN_ConstantFP, Type, true), Constant(Const) {}
625626

626627
std::string patternString() override {
@@ -671,7 +672,7 @@ struct RegisterNode : public PatternNode {
671672
if ((uint64_t)Size == XLen) {
672673
std::string Str;
673674
if ((Type.isScalar() && Type.getSizeInBits() == XLen) || Type.isPointer())
674-
Str = (IsFloat ? "FPR32:$" : "GPR:$") + std::string(Name);
675+
Str = (IsFloat ? ("FPR" + std::to_string(PatternGenArgs::Args.FLen) + ":$") : "GPR:$") + std::string(Name);
675676
if (PrintType)
676677
return "(" + TypeStr + " " + Str + ")";
677678
return Str;
@@ -1009,7 +1010,7 @@ static PatternOrError traverseRegLoad(MachineRegisterInfo &MRI,
10091010
assert(Cur.getOperand(0).isReg() && "expected register");
10101011
auto Node = std::make_unique<RegisterNode>(
10111012
MRI.getType(Cur.getOperand(0).getReg()), Field->ident, Idx, false,
1012-
ReadOffset, ReadSize, false, false);
1013+
ReadOffset, ReadSize, false, Field->type & CDSLInstr::FREG);
10131014

10141015
return PPattern(std::move(Node));
10151016
}
@@ -1145,7 +1146,7 @@ static PatternOrError traverse(MachineRegisterInfo &MRI, MachineInstr &Cur) {
11451146
assert(Cur.getOperand(0).isReg() && "expected register");
11461147
return std::make_pair(SUCCESS, std::make_unique<ConstantFPNode>(
11471148
MRI.getType(Cur.getOperand(0).getReg()),
1148-
Imm->getValueAPF().convertToFloat()));
1149+
Imm->getValueAPF().convertToDouble()));
11491150
}
11501151
case TargetOpcode::G_IMPLICIT_DEF: {
11511152
assert(Cur.getOperand(0).isReg() && "expected register");
@@ -1219,6 +1220,7 @@ static PatternOrError traverse(MachineRegisterInfo &MRI, MachineInstr &Cur) {
12191220
}
12201221
case TargetOpcode::G_FSHL:
12211222
case TargetOpcode::G_FSHR:
1223+
case TargetOpcode::G_FMA:
12221224
case TargetOpcode::G_SELECT:
12231225
case TargetOpcode::G_INSERT_VECTOR_ELT: {
12241226
auto [Err, NodeFirst, NodeSecond, NodeThird] =

llvm/tools/pattern-gen/Main.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ static cl::opt<bool> NoExtend(
6262

6363
static cl::opt<int> XLen("riscv-xlen", cl::desc("RISC-V XLEN (32 or 64 bit)"),
6464
cl::init(32), cl::cat(ToolOptions));
65+
static cl::opt<int> FLen("riscv-flen", cl::desc("RISC-V FLEN (32 or 64 bit)"),
66+
cl::init(32), cl::cat(ToolOptions));
6567

6668
// Determine optimization level.
6769
static cl::opt<char>
@@ -126,7 +128,7 @@ int main(int argc, char **argv) {
126128
TokenStream Ts(InputFilename.c_str());
127129
LLVMContext Ctx;
128130
auto Mod = std::make_unique<Module>("mod", Ctx);
129-
auto Instrs = ParseCoreDSL2(Ts, (XLen == 64), Mod.get(), NoExtend);
131+
auto Instrs = ParseCoreDSL2(Ts, (XLen == 64), Mod.get(), NoExtend, FLen);
130132

131133
if (irOut) {
132134
std::string Str;
@@ -165,7 +167,8 @@ int main(int argc, char **argv) {
165167
PGArgsStruct Args{.Mattr = "",
166168
.OptLevel = Opt,
167169
.Predicates = Predicates,
168-
.Is64Bit = (XLen == 64)};
170+
.Is64Bit = (XLen == 64),
171+
.FLen = FLen};
169172

170173
optimizeBehavior(Mod.get(), Instrs, irOut, Args);
171174
if (PrintIR)

llvm/tools/pattern-gen/PatternGen.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ struct PGArgsStruct
1010
llvm::CodeGenOptLevel OptLevel;
1111
std::string Predicates;
1212
bool Is64Bit;
13+
int FLen;
1314
};
1415

1516
int optimizeBehavior(llvm::Module* M, std::vector<CDSLInstr> const& Instrs, std::ostream& OstreamIR, PGArgsStruct Args);

llvm/tools/pattern-gen/lib/Parser.cpp

Lines changed: 82 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,11 @@ static llvm::BasicBlock *entry;
7272
static CDSLInstr *curInstr;
7373

7474
static int xlen;
75+
static int flen;
7576
static bool NoExtend_;
7677
static llvm::Type *regT;
78+
static llvm::Type *regT2;
79+
static llvm::Type *fregT;
7780

7881
static void reset_globals() {
7982
variables.clear();
@@ -782,6 +785,73 @@ static auto find_var(uint32_t identIdx) {
782785
[identIdx](CDSLInstr::Field &f) { return f.identIdx == identIdx; });
783786
}
784787

788+
std::vector<Value> ParseFuncCallArgs(TokenStream &ts, llvm::Function *func, llvm::IRBuilder<> &build) {
789+
std::vector<Value> args;
790+
pop_cur(ts, RBrOpen);
791+
if (ts.Peek().type != RBrClose) {
792+
do {
793+
auto expr = ParseExpression(ts, func, build);
794+
promote_lvalue(build, expr);
795+
args.push_back(expr);
796+
} while (pop_cur_if(ts, Comma));
797+
}
798+
pop_cur(ts, RBrClose);
799+
return args;
800+
}
801+
802+
Value ParseLLVMFuncCall(TokenStream &ts, llvm::Function *func,
803+
llvm::IRBuilder<> &build, std::string func_name) {
804+
auto args = ParseFuncCallArgs(ts, func, build);
805+
if (func_name == "llvm_fmuladd_f32" || func_name == "llvm_fmuladd_f64") {
806+
assert(flen == std::stoi(func_name.substr(func_name.size() - 2)));
807+
assert(args.size() == 3);
808+
auto A_ = build.CreateBitCast(args[0].ll, fregT);
809+
auto B_ = build.CreateBitCast(args[1].ll, fregT);
810+
auto M_ = build.CreateBitCast(args[2].ll, fregT);
811+
auto temp = build.CreateIntrinsic(fregT, llvm::Intrinsic::fmuladd, llvm::ArrayRef<llvm::Value *>{M_, A_, B_}, nullptr);
812+
auto temp_ = build.CreateBitCast(temp, regT2);
813+
Value v = {temp_, false};
814+
v.bitWidth = xlen; // TODO
815+
return v;
816+
} else if (func_name == "llvm_fdiv_fp32" || func_name == "llvm_fdiv_fp64") {
817+
assert(flen == std::stoi(func_name.substr(func_name.size() - 2)));
818+
assert(args.size() == 2);
819+
auto A_ = build.CreateBitCast(args[0].ll, fregT);
820+
auto B_ = build.CreateBitCast(args[1].ll, fregT);
821+
auto temp = build.CreateFDiv(A_, B_);
822+
temp = build.CreateBitCast(temp, regT2);
823+
Value v = {temp, false};
824+
v.bitWidth = xlen; // TODO
825+
return v;
826+
} else if (func_name == "llvm_fadd_fp32" || func_name == "llvm_fadd_fp64") {
827+
assert(flen == std::stoi(func_name.substr(func_name.size() - 2)));
828+
assert(args.size() == 2);
829+
auto A_ = build.CreateBitCast(args[0].ll, fregT);
830+
auto B_ = build.CreateBitCast(args[1].ll, fregT);
831+
auto temp = build.CreateFAdd(A_, B_);
832+
temp = build.CreateBitCast(temp, regT2);
833+
Value v = {temp, false};
834+
v.bitWidth = xlen; // TODO
835+
return v;
836+
} else if (func_name == "llvm_uitofp_fp32" || func_name == "llvm_uitofp_fp64") {
837+
assert(flen == std::stoi(func_name.substr(func_name.size() - 2)));
838+
assert(args.size() == 1);
839+
auto A = args[0];
840+
if (A.bitWidth != xlen) {
841+
A.isSigned = false;
842+
A.bitWidth = xlen;
843+
fit_to_size(A, build);
844+
}
845+
auto A_ = A.ll;
846+
auto temp = build.CreateUIToFP(A_, fregT);
847+
temp = build.CreateBitCast(temp, regT2);
848+
Value v = {temp, false};
849+
v.bitWidth = xlen; // TODO
850+
return v;
851+
}
852+
error(("undefined llvm function: " + func_name).c_str(), ts);
853+
}
854+
785855
Value ParseExpressionTerminal(TokenStream &ts, llvm::Function *func,
786856
llvm::IRBuilder<> &build) {
787857
auto &ctx = func->getContext();
@@ -823,6 +893,8 @@ Value ParseExpressionTerminal(TokenStream &ts, llvm::Function *func,
823893
.c_str(),
824894
ts);
825895
sizeIs32 |= (match->type & CDSLInstr::IS_32_BIT);
896+
if (isFloat) // TODO: assert not REG
897+
match->type = (CDSLInstr::FieldType)(match->type | CDSLInstr::FieldType::FREG);
826898
return Value{func->getArg(match - curInstr->fields.begin()),
827899
sizeIs32 ? 32 : xlen,
828900
(bool)(match->type & CDSLInstr::SIGNED_REG)};
@@ -846,99 +918,9 @@ Value ParseExpressionTerminal(TokenStream &ts, llvm::Function *func,
846918
}
847919
// TODO: check if float reg and get flen
848920
// TODO: parse funtion call util
849-
if (t.ident.str == "llvm_fmuladd_f32") {
850-
pop_cur(ts, RBrOpen);
851-
auto A = ParseExpression(ts, func, build);
852-
promote_lvalue(build, A);
853-
pop_cur(ts, Comma);
854-
auto B = ParseExpression(ts, func, build);
855-
promote_lvalue(build, B);
856-
pop_cur(ts, Comma);
857-
auto M = ParseExpression(ts, func, build);
858-
promote_lvalue(build, M);
859-
pop_cur(ts, RBrClose);
860-
// build.CreateIntrinsic(M.ll->getType(), llvm::Intrinsic::fmuladd, llvm::ArrayRef<llvm::Value *>{M.ll, A.ll, B.ll}, nullptr, "?");
861-
// build.CreateIntrinsic(build.getFloatTy(), llvm::Intrinsic::fmuladd, llvm::ArrayRef<llvm::Value *>{M.ll, A.ll, B.ll}, nullptr, "?");
862-
llvm::Type *FloatTy = build.getFloatTy();
863-
auto A_ = build.CreateBitCast(A.ll, FloatTy);
864-
auto B_ = build.CreateBitCast(B.ll, FloatTy);
865-
auto M_ = build.CreateBitCast(M.ll, FloatTy);
866-
// auto M_ = build.CreateUIToFP(M.ll, FloatTy);;
867-
// build.CreateIntrinsic(FloatTy, llvm::Intrinsic::fmuladd, llvm::ArrayRef<llvm::Value *>{M.ll, A.ll, B.ll}, nullptr, "?");
868-
auto temp = build.CreateIntrinsic(FloatTy, llvm::Intrinsic::fmuladd, llvm::ArrayRef<llvm::Value *>{M_, A_, B_}, nullptr);
869-
// bool doBitcast = true;
870-
// if (doBitcast) {
871-
auto temp_ = build.CreateBitCast(temp, regT);
872-
Value v = {temp_, false};
873-
// v.bitWidth = xlen; // TODO
874-
// return v;
875-
// }
876-
// Value v = {temp, false, true};
877-
// error("not implemented: llvm_fmuladd_f32", ts);
878-
v.bitWidth = xlen; // TODO
879-
return v;
880-
} else if (t.ident.str == "llvm_fdiv_fp32") {
881-
pop_cur(ts, RBrOpen);
882-
auto A = ParseExpression(ts, func, build);
883-
promote_lvalue(build, A);
884-
pop_cur(ts, Comma);
885-
auto B = ParseExpression(ts, func, build);
886-
promote_lvalue(build, B);
887-
pop_cur(ts, RBrClose);
888-
llvm::Type *FloatTy = build.getFloatTy();
889-
auto A_ = build.CreateBitCast(A.ll, FloatTy);
890-
auto B_ = build.CreateBitCast(B.ll, FloatTy);
891-
auto temp = build.CreateFDiv(A_, B_);
892-
// bool doBitcast = true;
893-
// if (doBitcast) {
894-
temp = build.CreateBitCast(temp, regT);
895-
// }
896-
// Value v = {temp, false, !doBitcast};
897-
Value v = {temp, false};
898-
v.bitWidth = xlen; // TODO
899-
return v;
900-
} else if (t.ident.str == "llvm_fadd_fp32") {
901-
pop_cur(ts, RBrOpen);
902-
auto A = ParseExpression(ts, func, build);
903-
promote_lvalue(build, A);
904-
pop_cur(ts, Comma);
905-
auto B = ParseExpression(ts, func, build);
906-
promote_lvalue(build, B);
907-
pop_cur(ts, RBrClose);
908-
llvm::Type *FloatTy = build.getFloatTy();
909-
auto A_ = build.CreateBitCast(A.ll, FloatTy);
910-
auto B_ = build.CreateBitCast(B.ll, FloatTy);
911-
auto temp = build.CreateFAdd(A_, B_);
912-
// bool doBitcast = true;
913-
// if (doBitcast) {
914-
temp = build.CreateBitCast(temp, regT);
915-
// }
916-
// Value v = {temp, false, !doBitcast};
917-
Value v = {temp, false};
918-
v.bitWidth = xlen; // TODO
919-
return v;
920-
} else if (t.ident.str == "llvm_uitofp_fp32") {
921-
pop_cur(ts, RBrOpen);
922-
auto A = ParseExpression(ts, func, build);
923-
promote_lvalue(build, A);
924-
pop_cur(ts, RBrClose);
925-
if (A.bitWidth != xlen) {
926-
A.isSigned = false;
927-
A.bitWidth = xlen;
928-
fit_to_size(A, build);
929-
}
930-
llvm::Type *FloatTy = build.getFloatTy();
931-
// auto A_ = build.CreateBitCast(A.ll, FloatTy);
932-
auto A_ = A.ll;
933-
auto temp = build.CreateUIToFP(A_, FloatTy);
934-
// bool doBitcast = true;
935-
// if (doBitcast) {
936-
temp = build.CreateBitCast(temp, regT);
937-
// }
938-
// Value v = {temp, false, !doBitcast};
939-
Value v = {temp, false};
940-
v.bitWidth = xlen; // TODO
941-
return v;
921+
llvm::outs() << "t.ident.str=" << t.ident.str;
922+
if (t.ident.str.rfind("llvm_", 0) == 0) {
923+
return ParseLLVMFuncCall(ts, func, build, std::string(t.ident.str));
942924
}
943925

944926
auto iter = variables.find(t.ident.idx);
@@ -1282,6 +1264,7 @@ void ParseOperands(TokenStream &ts, CDSLInstr &instr) {
12821264
{"is_signed", {FieldType::SIGNED_REG, 0}},
12831265
{"is_imm", {FieldType::IMM, 0}},
12841266
{"is_reg", {FieldType::REG, 0}},
1267+
{"is_freg", {FieldType::FREG, 0}},
12851268
{"in", {FieldType::IN, 0}},
12861269
{"out", {FieldType::OUT, 0}},
12871270
{"inout", {(FieldType::IN | FieldType::OUT), 0}},
@@ -1539,11 +1522,14 @@ void ParseBehaviour(TokenStream &ts, CDSLInstr &instr, llvm::Module *mod,
15391522
}
15401523

15411524
std::vector<CDSLInstr> ParseCoreDSL2(TokenStream &ts, bool is64Bit,
1542-
llvm::Module *mod, bool NoExtend) {
1525+
llvm::Module *mod, bool NoExtend, int Flen) {
15431526
std::vector<CDSLInstr> instrs;
15441527
xlen = is64Bit ? 64 : 32;
1528+
flen = Flen; // TODO: allow 0?
15451529
NoExtend_ = NoExtend;
15461530
regT = llvm::Type::getIntNTy(mod->getContext(), xlen);
1531+
regT2 = llvm::Type::getIntNTy(mod->getContext(), flen);
1532+
fregT = flen == 32 ? llvm::Type::getFloatTy(mod->getContext()) : llvm::Type::getDoubleTy(mod->getContext());
15471533

15481534
while (ts.Peek().type != None) {
15491535
bool parseBoilerplate =
@@ -1565,6 +1551,8 @@ std::vector<CDSLInstr> ParseCoreDSL2(TokenStream &ts, bool is64Bit,
15651551
// add XLEN and RFS as constants for now.
15661552
add_variable(ts, ts.GetIdentIdx("XLEN"),
15671553
Value{llvm::ConstantInt::get(regT, xlen)});
1554+
add_variable(ts, ts.GetIdentIdx("FLEN"),
1555+
Value{llvm::ConstantInt::get(regT, flen)});
15681556
add_variable(ts, ts.GetIdentIdx("RFS"),
15691557
Value{llvm::ConstantInt::get(regT, 32)});
15701558
++PatternGenNumInstructionsParsed;

llvm/tools/pattern-gen/lib/Parser.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44
#include <llvm/IR/Module.h>
55

66
std::vector<CDSLInstr> ParseCoreDSL2(TokenStream &ts, bool is64Bit,
7-
llvm::Module *mod, bool NoExtend);
7+
llvm::Module *mod, bool NoExtend, int FLen);

0 commit comments

Comments
 (0)