Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion src/chop/nn/quantized/functional/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,12 @@ def linearMXInt(
)
x_skip_first_dim = config.get("data_in_skip_first_dim", True)

out_width, out_exponent_width, out_block_size = (
config.get("data_in_width", x_width),
config.get("data_in_exponent_width", x_exponent_width),
config.get("data_in_block_size", x_block_size),
)

b_width, b_exponent_width, b_block_size = (
config["bias_width"],
config["bias_exponent_width"],
Expand All @@ -556,6 +562,13 @@ def linearMXInt(
block_size=x_block_size,
skip_first_dim=x_skip_first_dim,
)
out_quantizer = partial(
mxint_quantizer,
width=out_width,
exponent_width=out_exponent_width,
block_size=out_block_size,
skip_first_dim=x_skip_first_dim,
)
b_quantizer = partial(
mxint_quantizer,
width=b_width,
Expand All @@ -568,4 +581,4 @@ def linearMXInt(
weight = w_quantizer(weight)
bias = b_quantizer(bias) if bias is not None else None

return F.linear(x, weight, bias)
return out_quantizer(F.linear(x, weight, bias))
26 changes: 24 additions & 2 deletions src/chop/nn/quantized/functional/relu.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def relu_block_log(x, inplace=False, config=None):

def relu_mxint(x, inplace=False, config=None):
bypass = config.get("bypass", False)
if bypass:
if bypass or isinstance(x, torch.fx.proxy.Proxy):
return F.relu(x, inplace=inplace)
else:
x_width, x_exponent_width, x_block_size = (
Expand All @@ -213,6 +213,12 @@ def relu_mxint(x, inplace=False, config=None):
config["data_in_block_size"],
)

out_width, out_exponent_width, out_block_size = (
config.get("data_in_width", x_width),
config.get("data_in_exponent_width", x_exponent_width),
config.get("data_in_block_size", x_block_size),
)

x_more_than_2_dims = x.ndim > 2
x_quantizer = partial(
mxint_quantizer,
Expand All @@ -222,9 +228,25 @@ def relu_mxint(x, inplace=False, config=None):
skip_first_dim=x_more_than_2_dims,
)

out_quantizer = partial(
mxint_quantizer,
width=out_width,
exponent_width=out_exponent_width,
block_size=out_block_size,
skip_first_dim=x_more_than_2_dims,
)

x_shape = [i for i in x.shape]
if x_more_than_2_dims:
x = torch.flatten(x, start_dim=0, end_dim=-3)
x = x_quantizer(x)
x = torch.reshape(x, x_shape)
return F.relu(x, inplace=inplace)
relu_out = F.relu(x, inplace=inplace)
relu_out = (
torch.flatten(relu_out, start_dim=0, end_dim=-3)
if x_more_than_2_dims
else relu_out
)
relu_out_q = out_quantizer(relu_out)
relu_out_q = torch.reshape(relu_out_q, x_shape)
return relu_out_q
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,23 @@ class IpDescType(TypedDict):
"activation_layers/rtl/fixed_relu.sv",
],
},
"mxint": {
"name": "mxint_relu",
"dependence_files": [
"linear_layers/mxint_operators/rtl/mxint_relu.sv",
"linear_layers/mxint_operators/rtl/mxint_cast.sv",
"common/rtl/split2.sv",
"linear_layers/mxint_operators/rtl/log2_max_abs.sv",
"linear_layers/mxint_operators/rtl/or_tree.sv",
"linear_layers/mxint_operators/rtl/or_tree_layer.sv",
"common/rtl/register_slice.sv",
"linear_layers/mxint_operators/rtl/unpacked_mx_fifo.sv",
"memory/rtl/fifo.sv",
"memory/rtl/skid_buffer.sv",
"memory/rtl/simple_dual_port_ram.sv",
"common/rtl/join2.sv",
],
},
},
"hardshrink": {
"fixed": {
Expand Down
4 changes: 2 additions & 2 deletions src/chop/passes/graph/transforms/verilog/emit_tb.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def __init__(self, dut, fail_on_checks=True):
raise NotImplementedError(
f"Unsupported type format {t} for {node} {arg}"
)
self.input_drivers[arg].log.setLevel(logging.DEBUG)
self.input_drivers[arg].log.setLevel(logging.INFO)

for node in graph.nodes_out:
for result, result_info in node.meta["mase"]["common"][
Expand Down Expand Up @@ -320,7 +320,7 @@ def __init__(self, dut, fail_on_checks=True):
raise NotImplementedError(
f"Unsupported type format {t} for {node} {result}"
)
self.output_monitors[result].log.setLevel(logging.DEBUG)
self.output_monitors[result].log.setLevel(logging.INFO)

self.model = graph.model
self.input_precision = graph.meta["mase"]["common"]["args"]["data_in_0"][
Expand Down
66 changes: 39 additions & 27 deletions src/chop/passes/graph/transforms/verilog/emit_top.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,47 +146,47 @@ def module_interface_template(
)


def wiring_top_template(
def wiring_template(
type: str | None,
interface_signal: str,
internal_signal: str,
from_signal: str,
to_signal: str,
direction: Literal["input", "output"],
node_name: str,
):
match type:
case "fixed":
if direction == "input":
out = f"""
assign {internal_signal} = {interface_signal};"""
assign {to_signal} = {from_signal};"""
else:
out = f"""
assign {interface_signal} = {internal_signal};"""
assign {from_signal} = {to_signal};"""
case "mxint":
if direction == "input":
out = f"""
assign m_{internal_signal} = m_{interface_signal};
assign e_{internal_signal} = e_{interface_signal};"""
assign m_{to_signal} = m_{from_signal};
assign e_{to_signal} = e_{from_signal};"""
else:
out = f"""
assign m_{interface_signal} = m_{internal_signal};
assign e_{interface_signal} = e_{internal_signal};"""
assign m_{from_signal} = m_{to_signal};
assign e_{from_signal} = e_{to_signal};"""
case None:
raise ValueError(
f"Missing type information for {node_name} {interface_signal} {internal_signal}"
f"Missing type information for {node_name} {from_signal} {to_signal}"
)
case t:
raise NotImplementedError(
f"Unsupported type format {t} for {node_name} {interface_signal} {internal_signal}"
f"Unsupported type format {t} for {node_name} {from_signal} {to_signal}"
)
if direction == "input":
out += f"""
assign {interface_signal}_ready = {internal_signal}_ready;
assign {internal_signal}_valid = {interface_signal}_valid;
assign {from_signal}_ready = {to_signal}_ready;
assign {to_signal}_valid = {from_signal}_valid;
"""
else:
out += f"""
assign {internal_signal}_ready = {interface_signal}_ready;
assign {interface_signal}_valid = {internal_signal}_valid;
assign {to_signal}_ready = {from_signal}_ready;
assign {from_signal}_valid = {to_signal}_valid;
"""
return out

Expand Down Expand Up @@ -735,10 +735,10 @@ def _emit_top_wires(self):
node.meta["mase"].parameters["common"]["args"].items()
):
if is_real_input_arg(node, arg_idx):
wires += wiring_top_template(
wires += wiring_template(
arg_info.get("type", None),
interface_signal=f"data_in_{i}",
internal_signal=f"{node_name}_{arg}",
from_signal=f"data_in_{i}",
to_signal=f"{node_name}_{arg}",
node_name=node_name,
direction="input",
)
Expand All @@ -750,10 +750,10 @@ def _emit_top_wires(self):
node.meta["mase"].parameters["common"]["results"].items()
):
if "data_out" in result:
wires += wiring_top_template(
wires += wiring_template(
result_info.get("type", None),
interface_signal=f"data_out_{i}",
internal_signal=f"{node_name}_{result}",
from_signal=f"data_out_{i}",
to_signal=f"{node_name}_{result}",
node_name=node_name,
direction="output",
)
Expand Down Expand Up @@ -799,14 +799,26 @@ def _emit_node2node_wires(self):
continue

to_name = vf(node.name)

for i, node_in in enumerate(node.all_input_nodes):
to_type = node.meta["mase"]["common"]["args"][f"data_in_{i}"].get(
"type", None
)
from_type = (
node_in.meta["mase"]
.parameters["common"]["results"]["data_out_0"]
.get("type", None)
)
assert (
to_type == from_type
), f"Incongruent types {to_type=} {from_type=}"
from_name = vf(node_in.name)
wires += f"""
assign {from_name}_data_out_0_ready = {to_name}_data_in_{i}_ready;
assign {to_name}_data_in_{i}_valid = {from_name}_data_out_0_valid;
assign {to_name}_data_in_{i} = {from_name}_data_out_0;
"""
wires += wiring_template(
to_type,
from_signal=f"{from_name}_data_out_0",
to_signal=f"{to_name}_data_in_{i}",
node_name=node.name,
direction="input",
)
return wires

def emit(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ module mxint_cast #(
assert (IN_EXP_WIDTH > 2)
else $fatal("IN_EXP_WIDTH must be greater than 2");
assert (IN_MAN_WIDTH > 3)
else $fatal("IN_EXP_WIDTH must be greater than 2");
else $fatal("IN_MAN_WIDTH must be greater than 3");
assert (OUT_EXP_WIDTH > 2)
else $fatal("OUT_EXP_WIDTH must be greater than 2");
assert (OUT_MAN_WIDTH > 3)
else $fatal("IN_EXP_WIDTH must be greater than 2");
else $fatal("OUT_MAN_WIDTH must be greater than 3");
end

// =============================
Expand Down
Loading