Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions include/PTO/IR/PTOAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,110 @@ def PTO_SaturationModeAttr : EnumAttr<PTO_Dialect, PTO_SaturationModeEnum, "satu
let summary = "saturation mode attribute";
}
//===----------------------------------------------------------------------===//
// TDIV template controls
//===----------------------------------------------------------------------===//

def PTO_DivPrecisionEnum : PTO_I32Enum<
"DivPrecision", "PTO TDIV precision mode", [
I32EnumAttrCase<"Default", 0, "default">,
I32EnumAttrCase<"HighPrecision", 1, "high_precision">
]>;

def PTO_DivPrecisionAttr : EnumAttr<PTO_Dialect, PTO_DivPrecisionEnum, "div_precision"> {
let summary = "TDIV precision mode attribute";
}
//===----------------------------------------------------------------------===//
// TEXP template controls
//===----------------------------------------------------------------------===//

def PTO_ExpPrecisionEnum : PTO_I32Enum<
"ExpPrecision", "PTO TEXP precision mode", [
I32EnumAttrCase<"Default", 0, "default">,
I32EnumAttrCase<"HighPrecision", 1, "high_precision">
]>;

def PTO_ExpPrecisionAttr : EnumAttr<PTO_Dialect, PTO_ExpPrecisionEnum, "exp_precision"> {
let summary = "TEXP precision mode attribute";
}
//===----------------------------------------------------------------------===//
// TLOG template controls
//===----------------------------------------------------------------------===//

def PTO_LogPrecisionEnum : PTO_I32Enum<
"LogPrecision", "PTO TLOG precision mode", [
I32EnumAttrCase<"Default", 0, "default">,
I32EnumAttrCase<"HighPrecision", 1, "high_precision">
]>;

def PTO_LogPrecisionAttr : EnumAttr<PTO_Dialect, PTO_LogPrecisionEnum, "log_precision"> {
let summary = "TLOG precision mode attribute";
}
//===----------------------------------------------------------------------===//
// TRECIP template controls
//===----------------------------------------------------------------------===//

def PTO_RecipPrecisionEnum : PTO_I32Enum<
"RecipPrecision", "PTO TRECIP precision mode", [
I32EnumAttrCase<"Default", 0, "default">,
I32EnumAttrCase<"HighPrecision", 1, "high_precision">
]>;

def PTO_RecipPrecisionAttr : EnumAttr<PTO_Dialect, PTO_RecipPrecisionEnum, "recip_precision"> {
let summary = "TRECIP precision mode attribute";
}
//===----------------------------------------------------------------------===//
// TREM template controls
//===----------------------------------------------------------------------===//

def PTO_RemPrecisionEnum : PTO_I32Enum<
"RemPrecision", "PTO TREM precision mode", [
I32EnumAttrCase<"Default", 0, "default">,
I32EnumAttrCase<"HighPrecision", 1, "high_precision">
]>;

def PTO_RemPrecisionAttr : EnumAttr<PTO_Dialect, PTO_RemPrecisionEnum, "rem_precision"> {
let summary = "TREM precision mode attribute";
}
//===----------------------------------------------------------------------===//
// TRSQRT template controls
//===----------------------------------------------------------------------===//

def PTO_RsqrtPrecisionEnum : PTO_I32Enum<
"RsqrtPrecision", "PTO TRSQRT precision mode", [
I32EnumAttrCase<"Default", 0, "default">,
I32EnumAttrCase<"HighPrecision", 1, "high_precision">
]>;

def PTO_RsqrtPrecisionAttr : EnumAttr<PTO_Dialect, PTO_RsqrtPrecisionEnum, "rsqrt_precision"> {
let summary = "TRSQRT precision mode attribute";
}
//===----------------------------------------------------------------------===//
// TSQRT template controls
//===----------------------------------------------------------------------===//

def PTO_SqrtPrecisionEnum : PTO_I32Enum<
"SqrtPrecision", "PTO TSQRT precision mode", [
I32EnumAttrCase<"Default", 0, "default">,
I32EnumAttrCase<"HighPrecision", 1, "high_precision">
]>;

def PTO_SqrtPrecisionAttr : EnumAttr<PTO_Dialect, PTO_SqrtPrecisionEnum, "sqrt_precision"> {
let summary = "TSQRT precision mode attribute";
}
//===----------------------------------------------------------------------===//
// TFMOD template controls
//===----------------------------------------------------------------------===//

def PTO_FmodPrecisionEnum : PTO_I32Enum<
"FmodPrecision", "PTO TFMOD precision mode", [
I32EnumAttrCase<"Default", 0, "default">,
I32EnumAttrCase<"HighPrecision", 1, "high_precision">
]>;

def PTO_FmodPrecisionAttr : EnumAttr<PTO_Dialect, PTO_FmodPrecisionEnum, "fmod_precision"> {
let summary = "TFMOD precision mode attribute";
}
//===----------------------------------------------------------------------===//
// TStore template controls
//===----------------------------------------------------------------------===//

Expand Down
54 changes: 36 additions & 18 deletions include/PTO/IR/PTOOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,8 @@ def TMatmulBiasOp : PTO_TOp<"tmatmul.bias", [
PTODpsType:$a,
PTODpsType:$b,
PTODpsType:$bias,
PTODpsType:$dst
PTODpsType:$dst,
DefaultValuedAttr<PTO_AccPhaseAttr, "::mlir::pto::AccPhase::Unspecified">:$accPhase
);

let results = (outs Optional<AnyRankedTensor>:$result);
Expand Down Expand Up @@ -856,7 +857,8 @@ def TMatmulMxOp : PTO_TOp<"tmatmul.mx", [
PTODpsType:$a_scale,
PTODpsType:$b,
PTODpsType:$b_scale,
PTODpsType:$dst);
PTODpsType:$dst,
DefaultValuedAttr<PTO_AccPhaseAttr, "::mlir::pto::AccPhase::Unspecified">:$accPhase);

let results = (outs Optional<AnyRankedTensor>:$result);
let hasVerifier = 1;
Expand Down Expand Up @@ -892,7 +894,8 @@ def TMatmulMxAccOp : PTO_TOp<"tmatmul.mx.acc", [
PTODpsType:$a_scale,
PTODpsType:$b,
PTODpsType:$b_scale,
PTODpsType:$dst);
PTODpsType:$dst,
DefaultValuedAttr<PTO_AccPhaseAttr, "::mlir::pto::AccPhase::Unspecified">:$accPhase);

let results = (outs Optional<AnyRankedTensor>:$result);
let hasVerifier = 1;
Expand Down Expand Up @@ -1031,7 +1034,8 @@ def TGemvOp : PTO_TOp<"tgemv", [
let arguments = (ins
PTODpsType:$lhs,
PTODpsType:$rhs,
PTODpsType:$dst
PTODpsType:$dst,
DefaultValuedAttr<PTO_AccPhaseAttr, "::mlir::pto::AccPhase::Unspecified">:$accPhase
);

let results = (outs
Expand Down Expand Up @@ -1067,7 +1071,8 @@ def TGemvAccOp : PTO_TOp<"tgemv.acc", [
PTODpsType:$acc_in,
PTODpsType:$lhs,
PTODpsType:$rhs,
PTODpsType:$dst
PTODpsType:$dst,
DefaultValuedAttr<PTO_AccPhaseAttr, "::mlir::pto::AccPhase::Unspecified">:$accPhase
);

let results = (outs
Expand Down Expand Up @@ -1102,7 +1107,8 @@ def TGemvBiasOp : PTO_TOp<"tgemv.bias", [
PTODpsType :$a,
PTODpsType :$b,
PTODpsType :$bias,
PTODpsType :$dst
PTODpsType :$dst,
DefaultValuedAttr<PTO_AccPhaseAttr, "::mlir::pto::AccPhase::Unspecified">:$accPhase
);

let results = (outs Optional<AnyRankedTensor>:$result);
Expand Down Expand Up @@ -1137,7 +1143,8 @@ def TGemvMxOp : PTO_TOp<"tgemv.mx", [
PTODpsType:$a_scale,
PTODpsType:$b,
PTODpsType:$b_scale,
PTODpsType:$dst
PTODpsType:$dst,
DefaultValuedAttr<PTO_AccPhaseAttr, "::mlir::pto::AccPhase::Unspecified">:$accPhase
);

let results = (outs Optional<AnyRankedTensor>:$result);
Expand Down Expand Up @@ -1171,7 +1178,8 @@ def TGemvMxAccOp : PTO_TOp<"tgemv.mx.acc", [
PTODpsType:$a_scale,
PTODpsType:$b,
PTODpsType:$b_scale,
PTODpsType:$dst
PTODpsType:$dst,
DefaultValuedAttr<PTO_AccPhaseAttr, "::mlir::pto::AccPhase::Unspecified">:$accPhase
);

let results = (outs Optional<AnyRankedTensor>:$result);
Expand Down Expand Up @@ -3425,7 +3433,8 @@ def TColExpandDivOp : PTO_TOp<"tcolexpanddiv", [
let arguments = (ins
PTODpsType:$src0,
PTODpsType:$src1,
PTODpsType:$dst
PTODpsType:$dst,
DefaultValuedAttr<PTO_DivPrecisionAttr, "::mlir::pto::DivPrecision::Default">:$precisionType
);

let results = (outs);
Expand Down Expand Up @@ -3804,7 +3813,8 @@ def TDivOp : PTO_TOp<"tdiv", [
let arguments = (ins
PTODpsType:$src0,
PTODpsType:$src1,
PTODpsType:$dst
PTODpsType:$dst,
DefaultValuedAttr<PTO_DivPrecisionAttr, "::mlir::pto::DivPrecision::Default">:$precisionType
);

let results = (outs);
Expand Down Expand Up @@ -3860,7 +3870,8 @@ def TFModOp : PTO_TOp<"tfmod", [
let arguments = (ins
PTODpsType:$src0,
PTODpsType:$src1,
PTODpsType:$dst
PTODpsType:$dst,
DefaultValuedAttr<PTO_FmodPrecisionAttr, "::mlir::pto::FmodPrecision::Default">:$precisionType
);

let results = (outs);
Expand Down Expand Up @@ -3917,7 +3928,8 @@ def TExpOp : PTO_TOp<"texp", [

let arguments = (ins
PTODpsType:$src,
PTODpsType:$dst
PTODpsType:$dst,
DefaultValuedAttr<PTO_ExpPrecisionAttr, "::mlir::pto::ExpPrecision::Default">:$precisionType
);

let results = (outs);
Expand Down Expand Up @@ -4355,7 +4367,8 @@ def TLogOp : PTO_TOp<"tlog", [

let arguments = (ins
PTODpsType:$src,
PTODpsType:$dst
PTODpsType:$dst,
DefaultValuedAttr<PTO_LogPrecisionAttr, "::mlir::pto::LogPrecision::Default">:$precisionType
);

let results = (outs);
Expand Down Expand Up @@ -5118,7 +5131,8 @@ def TRecipOp: PTO_TOp<"trecip", [

let arguments = (ins
PTODpsType:$src,
PTODpsType:$dst
PTODpsType:$dst,
DefaultValuedAttr<PTO_RecipPrecisionAttr, "::mlir::pto::RecipPrecision::Default">:$precisionType
);

let results = (outs);
Expand Down Expand Up @@ -5182,7 +5196,8 @@ def TRemOp: PTO_TOp<"trem", [
PTODpsType:$src0,
PTODpsType:$src1,
PTODpsType:$tmp,
PTODpsType:$dst
PTODpsType:$dst,
DefaultValuedAttr<PTO_RemPrecisionAttr, "::mlir::pto::RemPrecision::Default">:$precisionType
);

let results = (outs);
Expand Down Expand Up @@ -5311,7 +5326,8 @@ def TRowExpandDivOp: PTO_TOp<"trowexpanddiv", [
PTODpsType:$src0,
PTODpsType:$src1,
Optional<PTODpsType>:$tmp,
PTODpsType:$dst
PTODpsType:$dst,
DefaultValuedAttr<PTO_DivPrecisionAttr, "::mlir::pto::DivPrecision::Default">:$precisionType
);

let results = (outs);
Expand Down Expand Up @@ -5702,7 +5718,8 @@ def TRsqrtOp: PTO_TOp<"trsqrt", [
let arguments = (ins
PTODpsType:$src,
PTODpsType:$dst,
Optional<PTODpsType>:$tmp
Optional<PTODpsType>:$tmp,
DefaultValuedAttr<PTO_RsqrtPrecisionAttr, "::mlir::pto::RsqrtPrecision::Default">:$precisionType
);

let results = (outs);
Expand Down Expand Up @@ -6012,7 +6029,8 @@ def TSqrtOp: PTO_TOp<"tsqrt", [

let arguments = (ins
PTODpsType:$src,
PTODpsType:$dst
PTODpsType:$dst,
DefaultValuedAttr<PTO_SqrtPrecisionAttr, "::mlir::pto::SqrtPrecision::Default">:$precisionType
);

let results = (outs);
Expand Down
Loading