diff --git a/pyproject.toml b/pyproject.toml index 5209ff3d..c091ff94 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ mlir-loop = "xtc.cli.mlir_loop:main" mlir-backend = "xtc.cli.mlir_backend:main" loop-display = "xtc.cli.display_results:main" loop-explore = "xtc.cli.explore:main" +db-results = "xtc.cli.db_results:main" [build-system] requires = ["setuptools>=64", "wheel"] diff --git a/scripts/explore/run-explore-variants.sh b/scripts/explore/run-explore-variants.sh index 39d0a1a2..5f65aed3 100755 --- a/scripts/explore/run-explore-variants.sh +++ b/scripts/explore/run-explore-variants.sh @@ -13,11 +13,16 @@ STRATEGIES="${STRATEGIES:-" \ mkdir -p "$outdir" rm -f "$outdir/*.csv" +op="matmul" +dims="512 1024 128" + t="$TRIALS" for s in $STRATEGIES; do for b in $BACKENDS; do echo "Testing backend $b with tiling strategy $s for $t trials..." - (set -x && loop-explore --backends "$b" --trials "$t" --jobs 1 --strategy "$s" --output "$outdir/results.b$b.s$s.t$t.csv") + (set -x && loop-explore --backends "$b" --operator $op --dims $dims --trials "$t" --jobs 1 --strategy "$s" --output "$outdir/results.b$b.s$s.t$t.csv" --db-file "$outdir/results.db") done done - +result="$(set -x && db-results --operator $op --dims $dims --db-file "$outdir/results.db")" +[ -n "$result" ] || { echo "ERROR: unexpected empty db" >&2 ; exit 1; } +echo "$result" diff --git a/src/xtc/cli/db_results.py b/src/xtc/cli/db_results.py new file mode 100644 index 00000000..d6e9bea9 --- /dev/null +++ b/src/xtc/cli/db_results.py @@ -0,0 +1,244 @@ +# +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2024-2026 The XTC Project Authors +# +from abc import ABC +from typing import TypeAlias, Any +import importlib.metadata +import json +import argparse +import sys +import logging + + +from xtc.graphs.xtc.operators import XTCOperator + +logger = logging.getLogger(__name__) + +VERSION = "v0.1" + +DBEntry: TypeAlias = dict[str, Any] + + +def get_native_platform(): + import platform + + node = platform.node().split(".")[0] + return [node, platform.system(), platform.machine()] + + +def get_node_platform(node: str, target: str): + return [node, "Linux", target] + + +class ResultsDB(ABC): + def __init__( + self, + db_file: str, + db_version: str = VERSION, + target: str = "native", + node: str = "", + ): + self._db_file = db_file + self._version = [db_version] + if target == "native": + self._platform = get_native_platform() + else: + assert node != "", f"node must be specified for non native target" + self._platform = get_node_platform(node, target) + self._xtc_version = "v" + importlib.metadata.version("xtc") + self._results = [] + self._reload() + + def _reload(self): + self._results = [] + with open(self._db_file) as inf: + for jsonl in inf.readlines(): + log = json.loads(jsonl) + self._results.append(log) + self._results + + def _default_match( + self, + log: DBEntry, + operator: list[Any], + target: str = "native", + threads: int = 1, + backend: str | None = None, + ) -> bool: + compiler = ["xtc", self._xtc_version, target, threads] + if backend is not None: + compiler.append(backend) + return ( + self._version == log["version"][: len(self._version)] + and operator == log["operator"][: len(operator)] + and self._platform == log["platform"][: len(self._platform)] + and compiler == log["compiler"][: len(compiler)] + ) + + def get_operation_results( + self, + xtc_op_signature: list[Any], + target: str = "native", + threads: int = 1, + backend: str | None = None, + errors: bool = False, + ) -> list[DBEntry]: + operator = ["xtc.operator", *xtc_op_signature] + results = [] + for log in self._results: + if not self._default_match( + log, + operator=operator, + target=target, + threads=threads, + backend=backend, + ): + continue + if not errors and log["results"][0] != 0: + continue + results.append(log) + return results + + def get_operation_best( + self, + xtc_op_signature: list[Any], + target: str = "native", + threads: int = 1, + backend: str | None = None, + ) -> DBEntry | None: + best_time = float("+inf") + best_log = None + logs = self.get_operation_results( + xtc_op_signature, + target, + threads, + backend, + errors=False, + ) + for log in logs: + time = min(log["results"][1]) + if time < best_time: + best_log = log + best_time = time + return best_log + + +def get_signature_from_args(op_type: str, *spec: Any) -> list[Any]: + match op_type: + case "conv2d": + n, h, w, f, r, s, c, SH, SW, dtype = spec + signature = XTCOperator.get_op_signature( + "conv2d", + n, + h, + w, + f, + r, + s, + c, + dtype, + stride=(SH, SW), + ) + case _: + signature = XTCOperator.get_op_signature(op_type, *spec) + logger.debug("matching for signature: %s", signature) + return signature + + +def main(): + default_dtype = "float32" + default_db_file = "xtc-operators-db.json" + default_target = "native" + default_threads = 1 + + parser = argparse.ArgumentParser( + description="Report DB results", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("--operator", required=True, type=str, help="operator to query") + parser.add_argument("--dims", nargs="+", type=int, required=True, help="dimensions") + parser.add_argument("--dtype", type=str, default=default_dtype, help="data type") + parser.add_argument("--target", type=str, default=default_target, help="target") + parser.add_argument( + "--threads", type=int, default=default_threads, help="threads for target" + ) + parser.add_argument("--backend", type=str, help="optional backend for filter") + parser.add_argument( + "--dump", + action=argparse.BooleanOptionalAction, + default=False, + help="dump all matched results", + ) + parser.add_argument( + "--best", + action=argparse.BooleanOptionalAction, + default=True, + help="get best result", + ) + parser.add_argument( + "--quiet", + action=argparse.BooleanOptionalAction, + default=False, + help="quiet mode, only results on output", + ) + parser.add_argument( + "--db-file", type=str, default=default_db_file, help="results json db" + ) + parser.add_argument( + "--debug", + action=argparse.BooleanOptionalAction, + default=False, + help="debug mode", + ) + args = parser.parse_args() + + logging.basicConfig() + if args.debug: + logger.setLevel(logging.DEBUG) + + spec = [*args.dims, args.dtype] + signature = get_signature_from_args(args.operator, *spec) + + db = ResultsDB(db_file=args.db_file) + + if args.dump: + results = db.get_operation_results( + xtc_op_signature=signature, + target=args.target, + threads=args.threads, + ) + num = len(results) + for idx, entry in enumerate(results): + print( + f"result {idx + 1}/{num}: compiler:", + entry["compiler"], + "results:", + entry["results"][1], + ) + if args.best: + log = db.get_operation_best( + xtc_op_signature=signature, + target=args.target, + threads=args.threads, + ) + if log is not None: + time = min(log["results"][1]) + print( + "best:", + time, + "strategy:", + log["strategy"], + "schedule:", + log["schedule"], + "compiler:", + log["compiler"], + ) + elif not args.quiet: + print("warning: no match found", file=sys.stderr) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/xtc/cli/explore.py b/src/xtc/cli/explore.py index b411480d..4e18f33c 100644 --- a/src/xtc/cli/explore.py +++ b/src/xtc/cli/explore.py @@ -15,6 +15,7 @@ """ +from abc import abstractmethod import sys import os import argparse @@ -33,7 +34,10 @@ from collections.abc import Sequence, Mapping from typing import Any, TypeAlias, cast from typing_extensions import override +import json +import platform +from xtc.graphs.xtc.graph import XTCGraph from xtc.itf.back import Backend from xtc.itf.graph import Graph from xtc.itf.comp import Module @@ -262,7 +266,8 @@ def load_and_evaluate_sample( result = (in_x, code, time, backend) if callbacks and "result" in callbacks: - callbacks["result"](result) + for callback in callbacks["result"]: + callback(result) return result @@ -530,6 +535,57 @@ def peak_time(args: NS) -> float: return time +class ResultCallBack: + @abstractmethod + def __call__(self, result: Sequence) -> None: ... + + +class DBCallback(ResultCallBack): + def __init__( + self, + dbfile: str, + target: str, + threads: int, + strategy: str, + ) -> None: + self._dbfile = dbfile + self._target = target + self._threads = threads + self._version = ["v0.1"] + self._platform = [platform.node(), platform.system(), platform.machine()] + self._strategy: list[Any] | None = None + self._operator: list[Any] | None = None + strategy_signature = [strategy_name(strategy)] # TODO: move to strategy + self._strategy = ["xtc.strategy", *strategy_signature] + + def set_graph(self, graph: XTCGraph): + assert len(graph.nodes) == 1, f"Only support recording of single node graph" + signature = graph.outputs_nodes[0].operation.signature + self._operator = ["xtc.operator", *signature] + + def _write_result(self, result: Sequence) -> None: + x, code, time, backend = result + if code != 0: + time = 0 + compiler = ["xtc", "v0.2.dev1", self._target, self._threads, backend] + log = dict( + version=self._version, + platform=self._platform, + compiler=compiler, + operator=self._operator, + strategy=self._strategy, + schedule=list(x), + results=[int(code), [float(time)]], + ) + log_json = json.dumps(log) + with open(self._dbfile, "a") as outf: + print(log_json, flush=True, file=outf) + + @override + def __call__(self, result: Sequence) -> None: + self._write_result(result) + + class CSVCallback: def __init__(self, fname: str, peak_time: float, sample_names: list[str]) -> None: self._fname = fname @@ -564,6 +620,7 @@ def _write_result(self, result: Sequence) -> None: logger.debug(f"Record row: {row}") self._write_row(row) + @override def __call__(self, result: Sequence) -> None: self._write_result(result) @@ -581,11 +638,8 @@ def search_some(strategy: Strategy, graph: Graph, args: NS): args.quiet, args.operator, ) - ptime = peak_time(args) - sample_names = strategy.sample_names - result_callback = CSVCallback(args.output, ptime, sample_names) callbacks = { - "result": result_callback, + "result": args.result_callbacks, "search": search_callback, } if args.search in ["exhaustive", "random"]: @@ -613,6 +667,8 @@ def optimize(args: NS): op_args = (*dims, dtype) graph = OPERATORS[args.operator]["operation"](*op_args, name=args.func_name) strategy = get_strategy(graph, args) + if args.db_file: + args.db_callback.set_graph(graph) # TODO: fix, not really clean if args.test or args.opt_level in [0, 1, 2, 3]: schedule = args.test if not schedule: @@ -625,18 +681,15 @@ def optimize(args: NS): args.quiet, args.operator, ) - ptime = peak_time(args) - sample_names = strategy.sample_names - result_callback = CSVCallback(args.output, ptime, sample_names) callbacks = { - "result": result_callback, + "result": args.result_callbacks, "search": search_callback, } evaluate_sample(strategy, schedule, graph, args, callbacks=callbacks) - for row in result_callback._rows: + for row in args.csv_callback._rows: in_x, time, peak, backend = row[-4:] tqdm.write( - f"Schedule: {backend}: {in_x}: time: {time * 1000:.2f} msecs, peak perf: {peak * 100:.2f}%" + f"Schedule: {backend}: {in_x}: time: {time * 1000:.3f} msecs, peak perf: {peak * 100:.2f}%" ) else: search_some(strategy, graph, args) @@ -787,6 +840,21 @@ def setup_args(args: NS): # otherwise the import of tvm breaks the MLIR python bindings args.backends = sorted(args.backends) + # Setup Callbacks + args.peak_time = peak_time(args) + sample_names = [] # TODO + args.csv_callback = CSVCallback(args.output, args.peak_time, sample_names) + result_callbacks: list[ResultCallBack] = [args.csv_callback] + if args.db_file: + args.db_callback = DBCallback( + args.db_file, + "native", + args.threads, + args.strategy, + ) + result_callbacks.append(args.db_callback) + args.result_callbacks = result_callbacks + def launch_child(argv: Sequence[str], args: NS): env = {} @@ -901,6 +969,11 @@ def main(): parser.add_argument( "--output", type=str, default="results.csv", help="output csv file for search" ) + parser.add_argument( + "--db-file", + type=str, + help="output json db, for instance: xtc-operators-db.json", + ) parser.add_argument( "--eval", type=str, choices=["eval"], default="eval", help="evaluation method" )