@@ -48,10 +48,10 @@ inline int64_t check_boxing(int64_t a) {
4848 return nan_box (0x7fc00000 ); // NaN
4949}
5050
51- void Emulator::fetch_registers (std::vector<reg_data_t >& out, uint32_t wid, uint32_t src_index, const RegOpd& reg) {
51+ void Emulator::fetch_registers (std::vector<reg_data_t >& out, uint32_t wid, uint32_t src_index, const RegOpd& reg, const ThreadMask& tmask ) {
5252 __unused (src_index);
5353 auto & warp = warps_.at (wid);
54- uint32_t num_threads = warp. tmask .size ();
54+ uint32_t num_threads = tmask.size ();
5555 out.resize (num_threads);
5656 switch (reg.type ) {
5757 case RegType::None:
@@ -60,7 +60,7 @@ void Emulator::fetch_registers(std::vector<reg_data_t>& out, uint32_t wid, uint3
6060 DPH (2 , " Src" << src_index << " Reg: " << reg << " ={" );
6161 for (uint32_t t = 0 ; t < num_threads; ++t) {
6262 if (t) DPN (2 , " , " );
63- if (!warp. tmask .test (t)) {
63+ if (!tmask.test (t)) {
6464 DPN (2 , " -" );
6565 continue ;
6666 }
@@ -74,7 +74,7 @@ void Emulator::fetch_registers(std::vector<reg_data_t>& out, uint32_t wid, uint3
7474 auto & reg_data = warp.ireg_file .at (reg.idx );
7575 for (uint32_t t = 0 ; t < num_threads; ++t) {
7676 if (t) DPN (2 , " , " );
77- if (!warp. tmask .test (t)) {
77+ if (!tmask.test (t)) {
7878 DPN (2 , " -" );
7979 continue ;
8080 }
@@ -89,7 +89,7 @@ void Emulator::fetch_registers(std::vector<reg_data_t>& out, uint32_t wid, uint3
8989 auto & reg_data = warp.freg_file .at (reg.idx );
9090 for (uint32_t t = 0 ; t < num_threads; ++t) {
9191 if (t) DPN (2 , " , " );
92- if (!warp. tmask .test (t)) {
92+ if (!tmask.test (t)) {
9393 DPN (2 , " -" );
9494 continue ;
9595 }
@@ -124,6 +124,8 @@ instr_trace_t* Emulator::execute(const Instr &instr, uint32_t wid) {
124124 auto rsrc2 = instr.getSrcReg (2 );
125125
126126 auto num_threads = arch_.num_threads ();
127+ auto exec_tmask = instr.hasTmask () ? (warp.tmask & instr.getTmask ()) : warp.tmask ;
128+ auto operand_tmask = warp.tmask ;
127129
128130 // create instruction trace
129131 auto trace_alloc = core_->trace_pool ().allocate (1 );
@@ -133,7 +135,7 @@ instr_trace_t* Emulator::execute(const Instr &instr, uint32_t wid) {
133135 trace->cid = core_->id ();
134136 trace->wid = wid;
135137 trace->PC = warp.PC ;
136- trace->tmask = warp. tmask ;
138+ trace->tmask = exec_tmask ;
137139 trace->dst_reg = rdest;
138140 trace->src_regs = {rsrc0, rsrc1, rsrc2};
139141
@@ -143,27 +145,27 @@ instr_trace_t* Emulator::execute(const Instr &instr, uint32_t wid) {
143145 std::vector<reg_data_t > rs3_data;
144146
145147 if (instr.is_uop ()) {
146- DP (1 , " Instr: " << instr << " , cid=" << core_->id () << " , wid=" << wid << " , tmask=" << warp. tmask
148+ DP (1 , " Instr: " << instr << " , cid=" << core_->id () << " , wid=" << wid << " , tmask=" << exec_tmask
147149 << " , PC=0x" << std::hex << warp.PC << std::dec << " , parent=#" << instr.getParentUUID () << " (#" << instr.getUUID () << " )" );
148150 } else {
149- DP (1 , " Instr: " << instr << " , cid=" << core_->id () << " , wid=" << wid << " , tmask=" << warp. tmask
151+ DP (1 , " Instr: " << instr << " , cid=" << core_->id () << " , wid=" << wid << " , tmask=" << exec_tmask
150152 << " , PC=0x" << std::hex << warp.PC << std::dec << " (#" << instr.getUUID () << " )" );
151153 }
152154
153155 // fetch register values
154- if (rsrc0.type != RegType::None) fetch_registers (rs1_data, wid, 0 , rsrc0);
155- if (rsrc1.type != RegType::None) fetch_registers (rs2_data, wid, 1 , rsrc1);
156- if (rsrc2.type != RegType::None) fetch_registers (rs3_data, wid, 2 , rsrc2);
156+ if (rsrc0.type != RegType::None) fetch_registers (rs1_data, wid, 0 , rsrc0, operand_tmask );
157+ if (rsrc1.type != RegType::None) fetch_registers (rs2_data, wid, 1 , rsrc1, operand_tmask );
158+ if (rsrc2.type != RegType::None) fetch_registers (rs3_data, wid, 2 , rsrc2, operand_tmask );
157159
158160 uint32_t thread_start = 0 ;
159161 for (; thread_start < num_threads; ++thread_start) {
160- if (warp. tmask .test (thread_start))
162+ if (exec_tmask .test (thread_start))
161163 break ;
162164 }
163165
164166 int32_t thread_last = num_threads - 1 ;
165167 for (; thread_last >= 0 ; --thread_last) {
166- if (warp. tmask .test (thread_last))
168+ if (exec_tmask .test (thread_last))
167169 break ;
168170 }
169171
@@ -1601,21 +1603,22 @@ instr_trace_t* Emulator::execute(const Instr &instr, uint32_t wid) {
16011603 case TcuType::WMMA: {
16021604 auto trace_data = std::make_shared<TensorUnit::ExeTraceData>();
16031605 trace->data = trace_data;
1604- assert (warp. tmask .count () == num_threads);
1606+ assert (operand_tmask .count () == num_threads);
16051607 core_->tensor_unit ()->wmma (wid, tpuArgs.fmt_s , tpuArgs.fmt_d , tpuArgs.step_m , tpuArgs.step_n , tpuArgs.step_k , rs1_data, rs2_data, rs3_data, rd_data, trace_data.get ());
16061608 rd_write = true ;
16071609 } break ;
16081610 case TcuType::WMMA_SP: {
16091611 auto trace_data = std::make_shared<TensorUnit::ExeTraceData>();
16101612 trace->data = trace_data;
1611- assert (warp.tmask .count () == num_threads);
1613+ assert (operand_tmask.count () == num_threads);
1614+ assert (exec_tmask.any ());
16121615 core_->tensor_unit ()->wmma_sp (wid, tpuArgs.fmt_s , tpuArgs.fmt_d , tpuArgs.step_m , tpuArgs.step_n , tpuArgs.step_k , rs1_data, rs2_data, rs3_data, rd_data, trace_data.get ());
16131616 rd_write = true ;
16141617 } break ;
16151618 case TcuType::META_STORE: {
16161619 auto trace_data = std::make_shared<TensorUnit::ExeTraceData>();
16171620 trace->data = trace_data;
1618- assert (warp. tmask .count () == num_threads);
1621+ assert (operand_tmask .count () == num_threads);
16191622 core_->tensor_unit ()->meta_store (wid, tpuArgs.fmt_s , tpuArgs.fmt_d , rs1_data, trace_data.get ());
16201623 } break ;
16211624 default :
@@ -1635,7 +1638,7 @@ instr_trace_t* Emulator::execute(const Instr &instr, uint32_t wid) {
16351638 DPH (2 , " Dest Reg: " << rdest << " ={" );
16361639 for (uint32_t t = 0 ; t < num_threads; ++t) {
16371640 if (t) DPN (2 , " , " );
1638- if (!warp. tmask .test (t)) {
1641+ if (!exec_tmask .test (t)) {
16391642 DPN (2 , " -" );
16401643 continue ;
16411644 }
@@ -1652,7 +1655,7 @@ instr_trace_t* Emulator::execute(const Instr &instr, uint32_t wid) {
16521655 DPH (2 , " Dest Reg: " << rdest << " ={" );
16531656 for (uint32_t t = 0 ; t < num_threads; ++t) {
16541657 if (t) DPN (2 , " , " );
1655- if (!warp. tmask .test (t)) {
1658+ if (!exec_tmask .test (t)) {
16561659 DPN (2 , " -" );
16571660 continue ;
16581661 }
0 commit comments