diff --git a/README.md b/README.md index 54c1efa..2350e9b 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,8 @@ risc-v simulator homework for compiler class. Implementation supports rv32imf_zb - C++ compiler (with C++20 support) - CMake (3.21+) +- Python (3.12+) +- Graphviz (optional) ## Build @@ -18,3 +20,16 @@ cmake --build build ```bash ./build/riscv-sim.x ``` + +## Notes + +This project uses a generator to produce the **instruction decoder** from the ISA description +(`src/isa/include/isa/isa_ext.inc`, the `MNEMONIC(name, mask, match)` table). +During the CMake build, the Python generator (`decode/pygen/gen_decoder.py`) +builds a **decision tree** over instruction bits/slices and emits `decode.cpp` as nested `if`/`switch` code. +You can select the decoder backend at configure time: `-DDECODE_BACKEND=generated` (default) or `-DDECODE_BACKEND=linear`. +Optionally, add `-DDECODE_EMIT_DOT=ON` to produce a .dot representation of the decision tree. + +Example of a generated decision tree for rv32i: + +![generated decode tree](images/decode_tree.svg) diff --git a/images/decode_tree.svg b/images/decode_tree.svg new file mode 100644 index 0000000..12e9a43 --- /dev/null +++ b/images/decode_tree.svg @@ -0,0 +1,682 @@ + + + + + + +decode + + + +node0 + +0_6 + + + +node1 + +LUI + + + +node0->node1 + + +55 + + + +node2 + +AUIPC + + + +node0->node2 + + +23 + + + +node3 + +JAL + + + +node0->node3 + + +111 + + + +node4 + +JALR + + + +node0->node4 + + +103 + + + +node5 + +12_14 + + + +node0->node5 + + +99 + + + +node12 + +12_14 + + + +node0->node12 + + +3 + + + +node18 + +12_14 + + + +node0->node18 + + +35 + + + +node22 + +12_14 + + + +node0->node22 + + +19 + + + +node33 + +12_14 + + + +node0->node33 + + +51 + + + +node46 + +12 + + + +node0->node46 + + +15 + + + +node49 + +20 + + + +node0->node49 + + +115 + + + +node6 + +BEQ + + + +node5->node6 + + +0 + + + +node7 + +BNE + + + +node5->node7 + + +1 + + + +node8 + +BLT + + + +node5->node8 + + +4 + + + +node9 + +BGE + + + +node5->node9 + + +5 +
+ + +node10 + +BLTU + + + +node5->node10 + + +6 + + + +node11 + +BGEU + + + +node5->node11 + + +7 + + + +node13 + +LB + + + +node12->node13 + + +0 + + + +node14 + +LH + + + +node12->node14 + + +1 + + + +node15 + +LW + + + +node12->node15 + + +2 + + + +node16 + +LBU + + + +node12->node16 + + +4 + + + +node17 + +LHU + + + +node12->node17 + + +5 + + + +node19 + +SB + + + +node18->node19 + + +0 + + + +node20 + +SH + + + +node18->node20 + + +1 + + + +node21 + +SW + + + +node18->node21 + + +2 + + + +node23 + +ADDI + + + +node22->node23 + + +0 + + + +node24 + +SLTI + + + +node22->node24 + + +2 + + + +node25 + +SLTIU + + + +node22->node25 + + +3 + + + +node26 + +XORI + + + +node22->node26 + + +4 + + + +node27 + +ORI + + + +node22->node27 + + +6 + + + +node28 + +ANDI + + + +node22->node28 + + +7 + + + +node29 + +SLLI + + + +node22->node29 + + +1 + + + +node30 + +30 + + + +node22->node30 + + +5 + + + +node31 + +SRLI + + + +node30->node31 + + +0 + + + +node32 + +SRAI + + + +node30->node32 + + +1 + + + +node34 + +30 + + + +node33->node34 + + +0 + + + +node37 + +SLL + + + +node33->node37 + + +1 + + + +node38 + +SLT + + + +node33->node38 + + +2 + + + +node39 + +SLTU + + + +node33->node39 + + +3 + + + +node40 + +XOR + + + +node33->node40 + + +4 + + + +node41 + +30 + + + +node33->node41 + + +5 + + + +node44 + +OR + + + +node33->node44 + + +6 + + + +node45 + +AND + + + +node33->node45 + + +7 + + + +node35 + +ADD + + + +node34->node35 + + +0 + + + +node36 + +SUB + + + +node34->node36 + + +1 + + + +node42 + +SRL + + + +node41->node42 + + +0 + + + +node43 + +SRA + + + +node41->node43 + + +1 + + + +node47 + +FENCE + + + +node46->node47 + + +0 + + + +node48 + +FENCE_I + + + +node46->node48 + + +1 + + + +node50 + +ECALL + + + +node49->node50 + + +0 + + + +node51 + +EBREAK + + + +node49->node51 + + +1 + + + diff --git a/src/isa/include/isa/isa_hlp.hpp b/src/isa/include/isa/isa_hlp.hpp index 902c35d..16de072 100644 --- a/src/isa/include/isa/isa_hlp.hpp +++ b/src/isa/include/isa/isa_hlp.hpp 
@@ -46,6 +46,16 @@ constexpr bool IsBitSet(T val, size_t index) { return (val & one_hot) == one_hot; } +template +constexpr T GetBitSlice(T val, size_t low, size_t high) { + static_assert(std::is_unsigned_v, "T must be an unsigned type"); + assert(high < sizeof(T) * CHAR_BIT); + + auto width = high - low + 1; + T one = T{1}; + return (val >> low) & ((one << width) - 1); +} + // TruncHigh means save low bits template constexpr ToT TruncHigh(FromT x) { diff --git a/src/sim/decode/pygen/emit_cpp.py b/src/sim/decode/pygen/emit_cpp.py index 08f0bf0..eec8b6d 100644 --- a/src/sim/decode/pygen/emit_cpp.py +++ b/src/sim/decode/pygen/emit_cpp.py @@ -6,32 +6,48 @@ def _node_gen(node: DTreeNode, depth: int = 0) -> str: if isinstance(node, DTreeLeaf): return f"{indent}return isa::InsnMnemonic::{node.insn_name};\n" - - cond_if = f"{indent}if (isa::IsBitSet(insn, {node.bit_idx})) {{\n" - cond_else = f"{indent}}} else {{\n" - cond_end = f"{indent}}}\n" - - return cond_if + _node_gen(node.one, depth=depth + 1) + cond_else + _node_gen(node.zero, depth=depth + 1) + cond_end + elif isinstance(node, DTreeTestBit): + cond_if = f"{indent}if (isa::IsBitSet(insn, {node.bit_idx})) {{\n" + cond_else = f"{indent}}} else {{ // else bit {node.bit_idx}\n" + cond_end = f"{indent}}} // end if bit {node.bit_idx}\n" + + return cond_if + _node_gen(node.one, depth=depth + 1) + cond_else + _node_gen(node.zero, depth=depth + 1) + cond_end + else: # DTreeTestSlice + switch_begin = f"{indent}switch (isa::GetBitSlice(insn, {node.low_bit}, {node.high_bit})) {{\n" + switch_end = f"{indent} default:\n" + switch_end += f"{indent} return isa::InsnMnemonic::kInvalid;\n" + switch_end += f"{indent}}}\n" + switch_cases = "" + for backet in node.cases.items(): + case_begin = f"{indent} case {backet[0]}: {{\n" + case_end = f"{indent} break;\n" + case_end += f"{indent} }} // end case {backet[0]}\n" + + switch_cases += case_begin + _node_gen(backet[1], depth + 2) + case_end + + return switch_begin + switch_cases + 
switch_end def gen(op_dtree: DTree, namespace: str = "sim::decode") -> str: - prologue = f""" - // THIS CODE IS GENERATED - // by decode/gen_decoder.py - #include "decode/decode.hpp" - - #include "isa/isa_hlp.hpp" - #include "isa/mnemonics.hpp" - - namespace {namespace} {{ - - isa::InsnMnemonic Decode(isa::UndecodedInsn insn) {{ - """ - - epilogue = f""" - }} - - }} // namespace {namespace} - """ + # list to avoid indentations + prologue_list = [ + "// THIS CODE IS GENERATED", + "// by decode/gen_decoder.py", + "#include \"decode/decode.hpp\"", + "#include \"isa/isa_hlp.hpp\"", + "#include \"isa/mnemonics.hpp\"", + "", + f"namespace {namespace} {{", + "", + "isa::InsnMnemonic Decode(isa::UndecodedInsn insn) {\n", + ] + prologue = "\n".join(prologue_list) + + epilogue_list = [ + "}", + "", + f"}} // namespace {namespace}" + ] + epilogue = "\n".join(epilogue_list) decode_body = _node_gen(op_dtree.get_root(), depth=1) diff --git a/src/sim/decode/pygen/insn_mnem.py b/src/sim/decode/pygen/insn_mnem.py index e62f46b..f8224ff 100644 --- a/src/sim/decode/pygen/insn_mnem.py +++ b/src/sim/decode/pygen/insn_mnem.py @@ -1,5 +1,8 @@ from dataclasses import dataclass +INSN_BIT_SIZE = 32 +INSN_FULL_MASK = (1 << INSN_BIT_SIZE) - 1 + @dataclass() class Insn: name: str diff --git a/src/sim/decode/pygen/opcode_dtree.py b/src/sim/decode/pygen/opcode_dtree.py index 9df0f39..1f6de33 100644 --- a/src/sim/decode/pygen/opcode_dtree.py +++ b/src/sim/decode/pygen/opcode_dtree.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from typing import List, Tuple, Optional, Union, Set, Dict -from insn_mnem import Insn +from insn_mnem import INSN_BIT_SIZE, INSN_FULL_MASK, Insn @dataclass class DTreeLeaf: @@ -15,7 +15,24 @@ class DTreeTestBit: zero: "DTreeNode" one: "DTreeNode" -type DTreeNode = Union[DTreeLeaf, DTreeTestBit] +@dataclass +class DTreeTestSlice: + low_bit: int + high_bit: int + cases: Dict[int, "DTreeNode"] + +type DTreeNode = Union[DTreeLeaf, DTreeTestBit, DTreeTestSlice] + 
+@dataclass +class ChosenBit: + bit: int + +@dataclass +class ChosenSlice: + low: int + high: int + +ChosenTest = Union[ChosenBit, ChosenSlice] @dataclass class DTree: @@ -47,19 +64,26 @@ def _emit_node(node: DTreeNode): if isinstance(node, DTreeLeaf): lines += f' node{node_id} [shape=box, label={node.insn_name}];\n' - return - - lines += f' node{node_id} [shape=box, label={node.bit_idx}];\n' + elif isinstance(node, DTreeTestBit): + lines += f' node{node_id} [shape=box, label={node.bit_idx}];\n' + + # test bit node + _emit_node(node.zero) + _emit_node(node.one) - # test bit node - _emit_node(node.zero) - _emit_node(node.one) + zero_id = _get_id(node.zero) + one_id = _get_id(node.one) - zero_id = _get_id(node.zero) - one_id = _get_id(node.one) + lines += f' node{node_id} -> node{zero_id} [label="0"];\n' + lines += f' node{node_id} -> node{one_id} [label="1"];\n' + else: # DTreeTestSlice + lines += f' node{node_id} [shape=box, label="{node.low_bit}_{node.high_bit}"];\n' + + for bucket in node.cases.items(): + _emit_node(bucket[1]) + bucket_id = _get_id(bucket[1]) - lines += f' node{node_id} -> node{zero_id} [label="0"];\n' - lines += f' node{node_id} -> node{one_id} [label="1"];\n' + lines += f' node{node_id} -> node{bucket_id} [label="{bucket[0]}"];\n' _emit_node(self.root) @@ -73,10 +97,130 @@ def write_dot(self, dot_path: Path, *, name: str = "decode"): def _bit_by_idx(x: int, idx: int) -> int: return (x >> idx) & 1 -type Branch = List[Insn] +def _bits_to_mask(bits: List[int]) -> int: + mask = 0 + for bit_idx in bits: + mask |= (1 << bit_idx) + + return mask + +type Cands = List[Insn] +type Branch = Cands type Decision = Tuple[Branch, Branch, bool] -def _split_on_bit(cands: List[Insn], bit_idx: int) -> Decision: +# slice test + +def _common_mask(cands: Cands) -> int: + if not cands: + return 0 + common = INSN_FULL_MASK + for insn in cands: + common &= insn.mask & INSN_FULL_MASK + + return common + +def _do_slice(value: int, low: int, high: int) -> int: + width = 
high - low + 1 + return (value >> low) & ((1 << width) - 1) + +def _get_all_slices( + cands: Cands, + bits: List[int] +) -> List[Tuple[int, int]]: # lo hi + cands_mask = _common_mask(cands) + bits_mask = _bits_to_mask(bits) + allowed = cands_mask & bits_mask + + slices: List[Tuple[int, int]] = [] + + for low in range(INSN_BIT_SIZE): + if not _bit_by_idx(allowed, low): + continue + high = low + for _ in range(low, INSN_BIT_SIZE): + if not _bit_by_idx(allowed, high): + break + # NOTE maybe append all lenghts + if (high - low + 1) > 1: + slices.append((low, high)) + high += 1 + + return slices + +type SliceHeuristic = Tuple[int, int, int] # tree_size depth -width +type InsnBuckets = Dict[int, Cands] # slice mask +type Slice = Tuple[int, int] # low high + +def _slice_width(slice: Slice) -> int: + return slice[1] - slice[0] + 1 + +def _slice_heuristic(slice: Slice, insn_buckets: InsnBuckets) -> SliceHeuristic: + tree_size = sum( + [len(bucket) for bucket in insn_buckets.values()] + ) + depth = max( + [len(bucket) for bucket in insn_buckets.values()] + ) + width = _slice_width(slice) + heuristic = (tree_size, depth, -width) + # heuristic = (depth, tree_size, -width) + return heuristic + +def _split_on_slice( + cands: Cands, + low: int, + high: int +) -> Tuple[InsnBuckets, bool]: # buckets and differ + insn_buckets: InsnBuckets = {} + for insn in cands: + match_slice = _do_slice(insn.match, low, high) + # if match_slice in insn_buckets append insn else insert (match_slice, []) and then append + insn_buckets.setdefault(match_slice, []).append(insn) + + differ = len(insn_buckets) > 1 + return (insn_buckets, differ) + +def _choose_slice(cands: Cands, bits: List[int]) -> Optional[Tuple[Slice, SliceHeuristic]]: + best_slice: Optional[Slice] = None + best_heuristic: Optional[SliceHeuristic] = None + + for (low, high) in _get_all_slices(cands, bits): + insn_buckets, differ = _split_on_slice(cands, low, high) + if not differ: + continue + heuristic = _slice_heuristic((low, high), 
insn_buckets) + if best_heuristic is None or best_heuristic > heuristic: + best_slice = (low, high) + best_heuristic = heuristic + + if best_heuristic is None or best_slice is None: + return None + + return (best_slice, best_heuristic) + +# bit Test ------------------------------ + +def _significant_bits(cands: List[Insn]) -> List[int]: + union_mask = 0 + for insn in cands: + union_mask |= insn.mask & INSN_FULL_MASK + + return [b for b in range(INSN_BIT_SIZE) if _bit_by_idx(union_mask, b)] + +type BitHeuristic = Tuple[int, int, int] + +def _bit_heuristic( + branch0: Branch, + branch1: Branch, + bit: int +) -> BitHeuristic: + depth = max(len(branch0), len(branch1)) + tree_size = len(branch0) + len(branch1) + # heuristic = (depth, tree_size, bit) + heuristic = (tree_size, depth, bit) + return heuristic + +def _split_on_bit(cands: Cands, bit_idx: int) -> Decision: branch0: Branch = [] branch1: Branch = [] has_req0: bool = False @@ -99,31 +243,12 @@ def _split_on_bit(cands: List[Insn], bit_idx: int) -> Decision: return (branch0, branch1, has_req0 and has_req1) -def _significant_bits(cands: List[Insn]) -> List[int]: - union_mask = 0 - for insn in cands: - union_mask |= insn.mask & 0xFFFFFFFF - return [b for b in range(32) if (union_mask >> b) & 1] - -type Heuristic = Tuple[int, int, int] - -def _bit_heuristic( - branch0: Branch, - branch1: Branch, - bit: int -) -> Heuristic: - b_depth = max(len(branch0), len(branch1)) - summ = len(branch0) + len(branch1) - # heuristic = (b_depth, summ, bit) - heuristic = (summ, b_depth, bit) - return heuristic - def _choose_bit( cands: List[Insn], bits: List[int] -) -> Optional[int]: +) -> Optional[Tuple[int, BitHeuristic]]: # bit heuristic # b_depth, summ, bit - best_bit: Optional[Heuristic] = None + best_bit: Optional[BitHeuristic] = None for bit in bits: branch0, branch1, differ = _split_on_bit(cands, bit) @@ -133,18 +258,43 @@ def _choose_bit( heuristic = _bit_heuristic(branch0, branch1, bit) if best_bit is None or heuristic < 
best_bit: best_bit = heuristic + + if best_bit is None: + return None + + return (best_bit[-1], best_bit) + +# build tree ------------------------------------------------ + +def _choose_test( + bit_test: Optional[Tuple[int, BitHeuristic]], + slice_test: Optional[Tuple[Slice, SliceHeuristic]] +) -> Optional[ChosenTest]: + if bit_test is None and slice_test is None: + return None + if bit_test is not None and slice_test is None: + return ChosenBit(bit=bit_test[0]) + if bit_test is None and slice_test is not None: + return ChosenSlice(low=slice_test[0][0], high=slice_test[0][1]) - return best_bit[-1] if best_bit is not None else None + assert bit_test is not None + assert slice_test is not None + + bit_heuristic = bit_test[1] + slice_heuristic = slice_test[1] + + bit_heuristic_main = bit_heuristic[0:2] + slice_heuristic_main = slice_heuristic[0:2] + if bit_heuristic_main > slice_heuristic_main: # slice is better + return ChosenSlice(low=slice_test[0][0], high=slice_test[0][1]) + else: # bit <= slice because bit is prefered + return ChosenBit(bit=bit_test[0]) def _build_dtree_impl( cands: List[Insn], *, - bits: Optional[List[int]] = None, - used_bits: Optional[Set[int]] = None, + bits: List[int], ) -> DTreeNode: - if used_bits is None: - used_bits = set() - # check for depth if not cands: @@ -153,33 +303,58 @@ def _build_dtree_impl( if len(cands) == 1: return DTreeLeaf(cands[0].name) - if bits is None: - bits = _significant_bits(cands) + chosen_bit = _choose_bit(cands, bits) + chosen_slice = _choose_slice(cands, bits) - remained_bits = [b for b in bits if b not in used_bits] - - best_bit = _choose_bit(cands, remained_bits) - if best_bit is None: + chosen_test = _choose_test(chosen_bit, chosen_slice) + if chosen_test is None: raise RuntimeError(f"Cannot decode, ambiguity in isa: cands: {cands}") - - branch0, branch1, differ = _split_on_bit(cands, best_bit) - assert differ, "best_bit should always be differ one" + + if isinstance(chosen_test, ChosenBit): + assert 
chosen_bit is not None - updated_used_bits = used_bits | {best_bit} - - zero = _build_dtree_impl( - branch0, - bits=bits, - used_bits=updated_used_bits, - ) - one = _build_dtree_impl( - branch1, - bits=bits, - used_bits=updated_used_bits, - ) + best_bit = chosen_test.bit + + branch0, branch1, differ = _split_on_bit(cands, best_bit) + assert differ + + remained_bits = bits.copy() + remained_bits.remove(best_bit) + + zero = _build_dtree_impl( + branch0, + bits=remained_bits, + ) + one = _build_dtree_impl( + branch1, + bits=remained_bits, + ) - return DTreeTestBit(best_bit, zero, one) + return DTreeTestBit(best_bit, zero, one) + else: + assert chosen_slice is not None + + (low, high) = (chosen_test.low, chosen_test.high) + insn_buckets, differ = _split_on_slice(cands, low, high) + assert differ + + remained_bits = bits.copy() + for i in range(low, high + 1): + remained_bits.remove(i) + + slice_node = DTreeTestSlice(low, high, {}) + for bucket in insn_buckets.items(): + bucket_tree = _build_dtree_impl( + bucket[1], + bits=remained_bits + ) + slice_node.cases[bucket[0]] = bucket_tree + + return slice_node def build_dtree(cands: List[Insn]) -> DTree: - root = _build_dtree_impl(cands) + root = _build_dtree_impl( + cands, + bits=_significant_bits(cands) + ) return DTree(root=root)