From 5a532fa087e5e4b6020705cf2d24a5272f6bcd85 Mon Sep 17 00:00:00 2001 From: Colin McDonald Date: Wed, 8 Nov 2023 14:12:05 -0500 Subject: [PATCH 01/15] Working on dep graph simulation --- feynsum-sml/src/common/DepGraph.sml | 136 ++++++++++++ feynsum-sml/src/common/DepGraphScheduler.sml | 13 ++ .../DepGraphSchedulerGreedyBranching.sml | 25 +++ .../DepGraphSchedulerGreedyFinishQubit.sml | 146 +++++++++++++ .../DepGraphSchedulerGreedyNonBranching.sml | 26 +++ feynsum-sml/src/common/DepGraphUtil.sml | 68 ++++++ feynsum-sml/src/common/GateSchedulerOrder.sml | 20 ++ feynsum-sml/src/common/sources.mlb | 19 ++ feynsum-sml/src/main.sml | 199 +++++++++++++----- 9 files changed, 594 insertions(+), 58 deletions(-) create mode 100644 feynsum-sml/src/common/DepGraph.sml create mode 100644 feynsum-sml/src/common/DepGraphScheduler.sml create mode 100644 feynsum-sml/src/common/DepGraphSchedulerGreedyBranching.sml create mode 100644 feynsum-sml/src/common/DepGraphSchedulerGreedyFinishQubit.sml create mode 100644 feynsum-sml/src/common/DepGraphSchedulerGreedyNonBranching.sml create mode 100644 feynsum-sml/src/common/DepGraphUtil.sml create mode 100644 feynsum-sml/src/common/GateSchedulerOrder.sml diff --git a/feynsum-sml/src/common/DepGraph.sml b/feynsum-sml/src/common/DepGraph.sml new file mode 100644 index 0000000..5537cf0 --- /dev/null +++ b/feynsum-sml/src/common/DepGraph.sml @@ -0,0 +1,136 @@ +structure DepGraph: +sig + + type gate_idx = int + + type dep_graph = { + gates: GateDefn.t Seq.t, + deps: gate_idx Seq.t Seq.t, + indegree: int Seq.t, + numQubits: int + } + + type t = dep_graph + + val fromJSON: JSON.value -> dep_graph + val fromString: string -> dep_graph + val fromFile: string -> dep_graph + + (*val mkGateDefn: (string * real list option * int list) -> GateDefn.t*) + +end = +struct + + type gate_idx = int + + type dep_graph = { + gates: GateDefn.t Seq.t, + deps: gate_idx Seq.t Seq.t, + indegree: int Seq.t, + numQubits: int + } + + type t = dep_graph + + fun expect opt err = case opt of + NONE => raise Fail err + | SOME x => x + + fun mkGateDefn ("x", NONE, [q0]) = + GateDefn.X q0 + | mkGateDefn ("y", NONE, [q0]) = + GateDefn.PauliY q0 + | mkGateDefn ("z", NONE, [q0]) = + GateDefn.PauliZ q0 + | mkGateDefn ("h", NONE, [q0]) = + GateDefn.Hadamard q0 + | mkGateDefn ("sx", NONE, [q0]) = + GateDefn.SqrtX q0 + | mkGateDefn ("sxdg", NONE, [q0]) = + GateDefn.Sxdg q0 + | mkGateDefn ("s", NONE, [q0]) = + GateDefn.S q0 + | mkGateDefn ("sdg", NONE, [q0]) = + GateDefn.S q0 + | mkGateDefn ("t", NONE, [q0]) = + GateDefn.T q0 + | mkGateDefn ("tdg", NONE, [q0]) = + GateDefn.Tdg q0 + | mkGateDefn ("cx", NONE, [control, target]) = + GateDefn.CX {control = control, target = target} + | mkGateDefn ("cz", NONE, [control, target]) = + GateDefn.CZ {control = control, target = target} + | mkGateDefn ("ccx", NONE, [control1, control2, target]) = + GateDefn.CCX {control1 = control1, control2 = control2, target = target} + | mkGateDefn ("phase", SOME [rot], [target]) = + GateDefn.Phase {target = target, rot = rot} + | mkGateDefn ("cp", SOME [rot], [control, target]) = + GateDefn.CPhase {control = control, target = target, rot = rot} + | mkGateDefn ("fsim", SOME [theta, phi], [left, right]) = + GateDefn.FSim {left = left, right = right, theta = theta, phi = phi} + | mkGateDefn ("rx", SOME [rot], [target]) = + GateDefn.RX {rot = rot, target = target} + | mkGateDefn ("ry", SOME [rot], [target]) = + GateDefn.RY {rot = rot, target = target} + | mkGateDefn ("rz", SOME [rot], [target]) = + GateDefn.RZ {rot = rot, target = target} + | mkGateDefn ("swap", NONE, [target1, target2]) = + GateDefn.Swap {target1 = target1, target2 = target2} + | mkGateDefn ("cswap", NONE, [control, target1, target2]) = + GateDefn.CSwap {control = control, target1 = target1, target2 = target2} + | mkGateDefn ("u", SOME [theta, phi, lambda], [target]) = + GateDefn.U {target = target, theta = theta, phi = phi, lambda = lambda} + | mkGateDefn ("u2", SOME [phi, lambda], [target]) = + GateDefn.U {target = target, theta = Math.pi/2.0, phi = phi, lambda = lambda} + | mkGateDefn ("u1", SOME [lambda], [target]) = + GateDefn.U {target = target, theta = 0.0, phi = 0.0, lambda = lambda} + | mkGateDefn (name, params, qargs) = + raise Fail ("Unknown gate-params-qargs combination with name " ^ name) + (*fun mkGate (g) = G.fromGateDefn (mkGateDefn g)*) + + fun arrayToSeq a = Seq.tabulate (fn i => Array.sub (a, i)) (Array.length a) + + fun getDepsInDeg (edges, N) = + let val deps = Array.array (N, nil); + val indeg = Array.array (N, 0); + fun incDeg (i) = Array.update (indeg, i, 1 + Array.sub (indeg, i)) + fun go nil = () + | go (JSON.ARRAY [JSON.INT fm, JSON.INT to] :: edges) = + let val fm64 = IntInf.toInt fm + val to64 = IntInf.toInt to + val () = incDeg to64 + val () = Array.update (deps, fm64, (to64 :: Array.sub (deps, fm64))) + in + go edges + end + | go (_ :: edges) = raise Fail "Malformed edge in JSON" + val () = go edges; + in + (Seq.map Seq.fromList (arrayToSeq deps), arrayToSeq indeg) + end + + fun fromJSON (data) = + let fun to_gate g = + let val name = JSONUtil.asString (expect (JSONUtil.findField g "name") "Expected field 'name' in JSON"); + val params = Option.map (JSONUtil.arrayMap JSONUtil.asNumber) (JSONUtil.findField g "params"); + val qargs = JSONUtil.arrayMap JSONUtil.asInt (expect (JSONUtil.findField g "qargs") "Expected field 'qargs' in JSON"); + in + mkGateDefn (name, params, qargs) + end + val numqs = case JSONUtil.findField data "qubits" of + SOME (JSON.INT qs) => IntInf.toInt qs + | _ => raise Fail "Expected integer field 'qubits' in JSON" + val gates = case JSONUtil.findField data "nodes" of + SOME (JSON.ARRAY ns) => Seq.fromList (List.map to_gate ns) + | _ => raise Fail "Expected array field 'nodes' in JSON" + val edges = case JSONUtil.findField data "edges" of + SOME (JSON.ARRAY es) => es + | _ => raise Fail "Expected array field 'nodes' in JSON" + val (deps, indegree) = getDepsInDeg (edges, Seq.length gates) + in + { gates = gates, deps = deps, indegree = indegree, numQubits = numqs } + (*Seq.zipWith (fn (g, (d, i)) => {gate = g, deps = d, indegree = i}) (gates, Seq.zip (deps, indegree))*) + end + fun fromString (str) = fromJSON (JSONParser.parse (JSONParser.openString str)) + fun fromFile (file) = fromJSON (JSONParser.parseFile file) +end diff --git a/feynsum-sml/src/common/DepGraphScheduler.sml b/feynsum-sml/src/common/DepGraphScheduler.sml new file mode 100644 index 0000000..20718bb --- /dev/null +++ b/feynsum-sml/src/common/DepGraphScheduler.sml @@ -0,0 +1,13 @@ +structure DepGraphScheduler = +struct + + type gate_idx = int + + type args = + { depGraph: DepGraph.t + , gateIsBranching: gate_idx -> bool + } + + (* From a frontier, select which gate to apply next *) + type t = args -> (gate_idx Seq.t -> gate_idx) +end diff --git a/feynsum-sml/src/common/DepGraphSchedulerGreedyBranching.sml b/feynsum-sml/src/common/DepGraphSchedulerGreedyBranching.sml new file mode 100644 index 0000000..2ae809a --- /dev/null +++ b/feynsum-sml/src/common/DepGraphSchedulerGreedyBranching.sml @@ -0,0 +1,25 @@ +structure DepGraphSchedulerGreedyBranching: +sig + val scheduler: DepGraphScheduler.t +end = +struct + + type gate_idx = int + + type args = + { depGraph: DepGraph.t + , gateIsBranching: gate_idx -> bool + } + + fun pickBranching i branching gates = + if i < Seq.length gates then + (if branching i then + Seq.nth gates i + else + pickBranching (i + 1) branching gates) + else + Seq.nth gates 0 (* pick a non-branching gate *) + + (* From a frontier, select which gate to apply next *) + fun scheduler ({gateIsBranching = gib, ...} : args) gates = pickBranching 0 gib gates +end diff --git a/feynsum-sml/src/common/DepGraphSchedulerGreedyFinishQubit.sml b/feynsum-sml/src/common/DepGraphSchedulerGreedyFinishQubit.sml new file mode 100644 index 0000000..736837f --- /dev/null +++ b/feynsum-sml/src/common/DepGraphSchedulerGreedyFinishQubit.sml @@ -0,0 +1,146 @@ +functor DepGraphSchedulerGreedyFinishQubit + (val maxBranchingStride: int val disableFusion: bool): +sig + val scheduler: DepGraphScheduler.t + val ordered: DepGraph.t -> int list +end = +struct + + type gate_idx = int + + type args = + { depGraph: DepGraph.t + , gateIsBranching: gate_idx -> bool + } + + fun intDo (n : int) (f: int -> 'b) = + let fun next i = if i = n then () else (f i; next (i + 1)) + in + next 0 + end + + type OrderedIntSet = { + S: IntBinarySet.set ref, + A: int list ref + } + + fun popOIS ({S = S, A = A}: OrderedIntSet) = + case !A of + nil => NONE + | x :: xs => (A := xs; S := IntBinarySet.delete (!S, x); SOME x) + + fun pushFrontOIS ({S = S, A = A}: OrderedIntSet) (x: int) = + if IntBinarySet.member (!S, x) then + () + else + (A := x :: !A; S := IntBinarySet.add (!S, x)) + + fun pushBackOIS ({S = S, A = A}: OrderedIntSet) (x: int) = + if IntBinarySet.member (!S, x) then + () + else + (A := (!A) @ (x :: nil); S := IntBinarySet.add (!S, x)) + + fun newOIS () = {S = ref IntBinarySet.empty, A = ref nil} + + fun emptyOIS ({S = S, A = A}: OrderedIntSet) = List.null (!A) + + datatype TraversalOrder = BFS | DFS + + fun revTopologicalSort (dg: DepGraph.t) (tr: TraversalOrder) = + let val N = Seq.length (#gates dg) + val L = ref nil + val ois = newOIS () + (*fun outdegree i = Seq.length (Seq.nth (#deps dg) i)*) + val ind = Array.tabulate (N, Seq.nth (#indegree dg)) + fun decInd i = let val d = Array.sub (ind, i) in Array.update (ind, i, d - 1); d end + val push = case tr of BFS => pushBackOIS ois | DFS => pushFrontOIS ois + val _ = intDo N (fn i => if Seq.nth (#indegree dg) i = 0 then + push i else ()) + fun loop () = + case popOIS ois of + NONE => () + | SOME n => + (L := n :: !L; + intDo (Seq.length (Seq.nth (#deps dg) n)) + (fn m => if decInd m = 0 then push m else ()); + loop ()) + in + loop (); !L + end + + fun topologicalSort (dg: DepGraph.t) (tr: TraversalOrder) = + List.rev (revTopologicalSort dg tr) + + fun ordered (dg: DepGraph.t) = + topologicalSort (DepGraphUtil.redirect dg) DFS + + val gateDepths: int array option ref = ref NONE + + fun popIntBinarySet (S: IntBinarySet.set ref) = case IntBinarySet.find (fn _ => true) (!S) of + NONE => raise Fail "popIntBinarySet called on empty set" + | SOME elt => (S := IntBinarySet.delete (!S, elt); elt) + + fun revTopologicalSortOld (dg: DepGraph.t) = + let val N = Seq.length (#gates dg) + val L = ref nil + val S = ref IntBinarySet.empty + val A = ref nil + val ind = Array.tabulate (N, Seq.nth (#indegree dg)) + fun decInd i = let val d = Array.sub (ind, i) in Array.update (ind, i, d - 1); d end + val _ = intDo N (fn i => if Seq.nth (#indegree dg) i = 0 then + S := IntBinarySet.add (!S, i) else ()) + fun loop () = + if IntBinarySet.isEmpty (!S) then + () + else + let val n = popIntBinarySet S in + L := n :: !L; + intDo (Seq.length (Seq.nth (#deps dg) n)) + (fn m => if decInd m = 0 then S := IntBinarySet.add (!S, m) else ()); + loop () + end + in + loop (); !L + end + + fun computeGateDepths (dg: DepGraph.t) = + let val N = Seq.length (#gates dg) + val depths = Array.array (N, ~1) + fun gdep i = + Array.update (depths, i, 1 + Seq.reduce Int.max ~1 (Seq.map (fn j => Array.sub (depths, j)) (Seq.nth (#deps dg) i))) + (*case Array.sub (depths, i) of + ~1 => 1 + Seq.reduce Int.min ~1 (Seq.map gateDepth (Seq.nth (#deps dg) i)) + | d => d*) + in + List.foldl (fn (i, ()) => gdep i) () (revTopologicalSort dg DFS); depths + end + + fun gateDepth i dg = + case !gateDepths of + NONE => let val gd = computeGateDepths dg in + print "recompouting gate depths"; + gateDepths := SOME gd; + Array.sub (gd, i) + end + | SOME gd => Array.sub (gd, i) + + fun pickLeastDepth best_idx best_depth i gates dg = + if i = Seq.length gates then + best_idx + else + let val cur_idx = Seq.nth gates i + val cur_depth = gateDepth cur_idx dg in + if cur_depth < best_depth then + pickLeastDepth cur_idx cur_depth (i + 1) gates dg + else + pickLeastDepth best_idx best_depth (i + 1) gates dg + end + + (* From a frontier, select which gate to apply next *) + fun scheduler ({depGraph = dg, ...} : args) = + (gateDepths := SOME (computeGateDepths dg); + fn gates => let val g0 = Seq.nth gates 0 in + pickLeastDepth g0 (gateDepth g0 dg) 1 gates dg + end) +end diff --git a/feynsum-sml/src/common/DepGraphSchedulerGreedyNonBranching.sml b/feynsum-sml/src/common/DepGraphSchedulerGreedyNonBranching.sml new file mode 100644 index 0000000..258e3f0 --- /dev/null +++ b/feynsum-sml/src/common/DepGraphSchedulerGreedyNonBranching.sml @@ -0,0 +1,26 @@ +functor DepGraphSchedulerGreedyNonBranching + (val maxBranchingStride: int val disableFusion: bool): +sig + val scheduler: DepGraphScheduler.t +end = +struct + + type gate_idx = int + + type args = + { depGraph: DepGraph.t + , gateIsBranching: gate_idx -> bool + } + + fun pickNonBranching i branching gates = + if i < Seq.length gates then + (if branching i then + pickNonBranching (i + 1) branching gates + else + Seq.nth gates i) + else + Seq.nth gates 0 (* pick a branching gate *) + + (* From a frontier, select which gate to apply next *) + fun scheduler ({gateIsBranching = gib, ...} : args) gates = pickNonBranching 0 gib gates +end diff --git a/feynsum-sml/src/common/DepGraphUtil.sml b/feynsum-sml/src/common/DepGraphUtil.sml new file mode 100644 index 0000000..da62d92 --- /dev/null +++ b/feynsum-sml/src/common/DepGraphUtil.sml @@ -0,0 +1,68 @@ +structure DepGraphUtil :> +sig + + type gate_idx = int + + (* Traversal Automaton State *) + type state = { visited: bool array, indegree: int array } + + type dep_graph = DepGraph.t + + val visit: dep_graph -> gate_idx -> state -> unit + val frontier: state -> gate_idx Seq.t + val initState: dep_graph -> state + + (* Switches edge directions *) + val redirect: dep_graph -> dep_graph + + (* val gateIsBranching: dep_graph -> (gate_idx -> bool) *) +end = +struct + + type gate_idx = int + + type dep_graph = DepGraph.t + + fun redirect ({gates = gs, deps = ds, indegree = is, numQubits = qs}: dep_graph) = + let val ds2 = Array.array (qs, nil) + fun apply i = Seq.map (fn j => Array.update (ds2, j, i :: Array.sub (ds2, j))) (Seq.nth ds i) + val _ = Seq.tabulate apply qs + in + {gates = gs, + deps = Seq.tabulate (fn i => Seq.rev (Seq.fromList (Array.sub (ds2, i)))) qs, + indegree = Seq.tabulate (fn i => Seq.length (Seq.nth ds i)) qs, + numQubits = qs} + end + + type state = { visited: bool array, indegree: int array } + + fun visit {gates = _, deps = ds, indegree = _, numQubits = _} i {visited = vis, indegree = deg} = + ( + (* Set visited[i] = true *) + Array.update (vis, i, true); + (* Decrement indegree of each i dependency *) + Seq.map (fn j => Array.update (deg, j, Array.sub (deg, j) - 1)) (Seq.nth ds i); + () + ) + + fun frontier {visited = vis, indegree = deg} = + let val N = Array.length vis + fun iter i acc = + if i < 0 then + acc + else + iter (i - 1) (if (not (Array.sub (vis, i))) andalso + Array.sub (deg, i) = 0 + then i :: acc else acc) + in + Seq.fromList (iter (N - 1) nil) + end + + fun initState (graph: dep_graph) = + let val N = Seq.length (#gates graph) + val vis = Array.array (N, false) + val deg = Array.tabulate (N, Seq.nth (#indegree graph)) + in + { visited = vis, indegree = deg } + end +end diff --git a/feynsum-sml/src/common/GateSchedulerOrder.sml b/feynsum-sml/src/common/GateSchedulerOrder.sml new file mode 100644 index 0000000..1f32def --- /dev/null +++ b/feynsum-sml/src/common/GateSchedulerOrder.sml @@ -0,0 +1,20 @@ +structure GateSchedulerOrder: +sig + type gate_idx = int + val mkScheduler: gate_idx Seq.t -> GateScheduler.t +end = +struct + + type qubit_idx = int + type gate_idx = int + + fun mkScheduler (order: gate_idx Seq.t) args = + let val i = ref 0 + val N = Seq.length order + in + fn () => if !i >= N then + Seq.empty () + else + let val gi = !i in i := gi + 1; Seq.singleton gi end + end +end diff --git a/feynsum-sml/src/common/sources.mlb b/feynsum-sml/src/common/sources.mlb index fc6630e..d513dc1 100644 --- a/feynsum-sml/src/common/sources.mlb +++ b/feynsum-sml/src/common/sources.mlb @@ -5,6 +5,7 @@ local in functor RedBlackMapFn functor RedBlackSetFn + structure IntBinarySet end ../lib/github.com/mpllang/mpllib/sources.$(COMPILER).mlb @@ -20,6 +21,8 @@ local structure SMLQasmParser = Parser end + $(SML_LIB)/smlnj-lib/JSON/json-lib.mlb + HashTable.sml ApplyUntilFailure.sml @@ -61,6 +64,14 @@ local GateSchedulerGreedyFinishQubit.sml Fingerprint.sml + + DepGraph.sml + DepGraphUtil.sml + DepGraphScheduler.sml + DepGraphSchedulerGreedyBranching.sml + DepGraphSchedulerGreedyNonBranching.sml + DepGraphSchedulerGreedyFinishQubit.sml + GateSchedulerOrder.sml in structure HashTable structure ApplyUntilFailure @@ -116,4 +127,12 @@ in functor RedBlackSetFn functor Fingerprint + + structure DepGraph + structure DepGraphUtil + structure DepGraphScheduler + structure DepGraphSchedulerGreedyBranching + functor DepGraphSchedulerGreedyNonBranching + functor DepGraphSchedulerGreedyFinishQubit + structure GateSchedulerOrder end \ No newline at end of file diff --git a/feynsum-sml/src/main.sml b/feynsum-sml/src/main.sml index 9a0010b..963d8e9 100644 --- a/feynsum-sml/src/main.sml +++ b/feynsum-sml/src/main.sml @@ -29,6 +29,68 @@ val _ = print ("scheduler " ^ schedulerName ^ "\n") val inputName = CLA.parseString "input" "" val _ = print ("input " ^ inputName ^ "\n") +(* ========================================================================= + * parse input + *) + +val _ = print + ("-------------------------------\n\ + \--- input-specific specs\n\ + \-------------------------------\n") + +fun parseQasm () = + let + fun handleLexOrParseError exn = + let + val e = + case exn of + SMLQasmError.Error e => e + | other => raise other + in + TerminalColorString.print + (SMLQasmError.show + {highlighter = SOME SMLQasmSyntaxHighlighter.fuzzyHighlight} e); + OS.Process.exit OS.Process.failure + end + + val ast = SMLQasmParser.parseFromFile inputName + handle exn => handleLexOrParseError exn + + val simpleCirc = SMLQasmSimpleCircuit.fromAst ast + in + Circuit.fromSMLQasmSimpleCircuit simpleCirc + end + +val (circuit, optDepGraph) = + case inputName of + "" => Util.die ("missing: -input FILE.qasm") + + | _ => + if String.isSuffix ".qasm" inputName then + (parseQasm (), NONE) + else + let val dg = DepGraph.fromFile inputName + val circuit = {numQubits = #numQubits dg, gates = #gates dg} + in + (circuit, SOME dg) + end + +val _ = print ("-------------------------------\n") + +val _ = print ("gates " ^ Int.toString (Circuit.numGates circuit) ^ "\n") +val _ = print ("qubits " ^ Int.toString (Circuit.numQubits circuit) ^ "\n") + +val showCircuit = CLA.parseFlag "show-circuit" +val _ = print ("show-circuit? " ^ (if showCircuit then "yes" else "no") ^ "\n") +val _ = + if not showCircuit then + () + else + print + ("=========================================================\n" + ^ Circuit.toString circuit + ^ "=========================================================\n") + (* ======================================================================== * gate scheduling *) @@ -48,6 +110,16 @@ in (val maxBranchingStride = maxBranchingStride val disableFusion = disableFusion) + structure DGNB = + DepGraphSchedulerGreedyNonBranching + (val maxBranchingStride = maxBranchingStride + val disableFusion = disableFusion) + + structure DGFQ = + DepGraphSchedulerGreedyFinishQubit + (val maxBranchingStride = maxBranchingStride + val disableFusion = disableFusion) + fun print_sched_info () = let val _ = print @@ -66,19 +138,82 @@ in end end +type gate_idx = int +type Schedule = gate_idx Seq.t + +(*fun gate_to_schedule (ags: GateScheduler.args) (sched: GateScheduler.t) = + let val next = sched ags; + fun loadNext (acc: gate_idx Seq.t list) = + let val gs = next () in + if Seq.length gs = 0 then + Seq.flatten (Seq.rev (Seq.fromList acc)) + else + loadNext (gs :: acc) + end + in + loadNext nil + end*) + +fun dep_graph_to_schedule (ags: DepGraphScheduler.args) (sched: DepGraphScheduler.t) = + let val choose = sched ags + val dg = #depGraph ags + val st = DepGraphUtil.initState dg + fun loadNext (acc: gate_idx list) = + let val frntr = DepGraphUtil.frontier st in + if Seq.length frntr = 0 then + Seq.rev (Seq.fromList acc) + else + (let val next = choose frntr in + DepGraphUtil.visit dg next st; + loadNext (next :: acc) + end) + end + in + GateSchedulerOrder.mkScheduler (loadNext nil) + end + +structure Gate_branching = Gate (structure B = BasisIdxUnlimited + structure C = Complex64) + + +fun gate_branching ({ gates = gates, ...} : DepGraph.t) = + let val gates = Seq.map Gate_branching.fromGateDefn gates + fun gate i = Seq.nth gates i + in + (fn i => + case #action (gate i) of + Gate_branching.NonBranching _ => false + | _ => true) + end + +fun greedybranching () = + case optDepGraph of + NONE => GateSchedulerGreedyBranching.scheduler + | SOME dg => dep_graph_to_schedule { depGraph = dg, gateIsBranching = gate_branching dg } DepGraphSchedulerGreedyBranching.scheduler + +fun greedynonbranching () = + case optDepGraph of + NONE => GNB.scheduler + | SOME dg => dep_graph_to_schedule { depGraph = dg, gateIsBranching = gate_branching dg } DGNB.scheduler + +fun greedyfinishqubit () = + case optDepGraph of + NONE => GFQ.scheduler + | SOME dg => GateSchedulerOrder.mkScheduler (Seq.fromList (DGFQ.ordered dg)) (*dep_graph_to_schedule { depGraph = dg, gateIsBranching = gate_branching dg } DGFQ.scheduler*) + val gateScheduler = case schedulerName of "naive" => GateSchedulerNaive.scheduler - | "greedy-branching" => GateSchedulerGreedyBranching.scheduler - | "gb" => GateSchedulerGreedyBranching.scheduler + | "greedy-branching" => (print_sched_info (); greedybranching ()) + | "gb" => (print_sched_info (); greedybranching ()) - | "greedy-nonbranching" => (print_sched_info (); GNB.scheduler) - | "gnb" => (print_sched_info (); GNB.scheduler) + | "greedy-nonbranching" => (print_sched_info (); greedynonbranching ()) + | "gnb" => (print_sched_info (); greedynonbranching ()) - | "greedy-finish-qubit" => (print_sched_info (); GFQ.scheduler) - | "gfq" => (print_sched_info (); GFQ.scheduler) + | "greedy-finish-qubit" => (print_sched_info (); greedyfinishqubit ()) + | "gfq" => (print_sched_info (); greedyfinishqubit ()) | _ => Util.die @@ -86,58 +221,6 @@ val gateScheduler = ^ "; valid options are: naive, greedy-branching (gb), greedy-nonbranching (gnb), greedy-finish-qubit (gfq)") -(* ========================================================================= - * parse input - *) - -val _ = print - ("-------------------------------\n\ - \--- input-specific specs\n\ - \-------------------------------\n") - -val circuit = - case inputName of - "" => Util.die ("missing: -input FILE.qasm") - - | _ => - let - fun handleLexOrParseError exn = - let - val e = - case exn of - SMLQasmError.Error e => e - | other => raise other - in - TerminalColorString.print - (SMLQasmError.show - {highlighter = SOME SMLQasmSyntaxHighlighter.fuzzyHighlight} e); - OS.Process.exit OS.Process.failure - end - - val ast = SMLQasmParser.parseFromFile inputName - handle exn => handleLexOrParseError exn - - val simpleCirc = SMLQasmSimpleCircuit.fromAst ast - in - Circuit.fromSMLQasmSimpleCircuit simpleCirc - end - -val _ = print ("-------------------------------\n") - -val _ = print ("gates " ^ Int.toString (Circuit.numGates circuit) ^ "\n") -val _ = print ("qubits " ^ Int.toString (Circuit.numQubits circuit) ^ "\n") - -val showCircuit = CLA.parseFlag "show-circuit" -val _ = print ("show-circuit? " ^ (if showCircuit then "yes" else "no") ^ "\n") -val _ = - if not showCircuit then - () - else - print - ("=========================================================\n" - ^ Circuit.toString circuit - ^ "=========================================================\n") - (* ======================================================================== * mains: 32-bit and 64-bit *) From b82b2eb790209d2c9b515411f17b63e47655118b Mon Sep 17 00:00:00 2001 From: Colin McDonald Date: Wed, 8 Nov 2023 14:12:42 -0500 Subject: [PATCH 02/15] Add dependency-graph generating python script --- feynsum-sml/src/common/depgraph.py | 318 +++++++++++++++++++++++++++++ 1 file changed, 318 insertions(+) create mode 100644 feynsum-sml/src/common/depgraph.py diff --git a/feynsum-sml/src/common/depgraph.py b/feynsum-sml/src/common/depgraph.py new file mode 100644 index 0000000..3e9a454 --- /dev/null +++ b/feynsum-sml/src/common/depgraph.py @@ -0,0 +1,318 @@ +from numpy.typing import NDArray +import qiskit as Q +import rustworkx as rx +import numpy as np +import sys +import json +from typing import Any, Iterator, Optional +import time + +#import scipy +#def cluster(dag): +# adjmtx = rx.adjacency_matrix(dag) +# adjid = np.identity(adjmtx.shape[0]) +# whitened = scipy.cluster.vq.whiten(adjmtx + adjid) +# return scipy.cluster.vq.kmeans2(whitened, adjmtx.shape[0]//2, minit='points')[1] + +EPS = 1e-10 + +def circuit_to_dag(circ: Q.circuit.QuantumCircuit): + return Q.converters.circuit_to_dag(circ) + +def circuit_to_dep_graph(circ: Q.circuit.QuantumCircuit) -> Q.dagcircuit.DAGCircuit: + dag = circuit_to_dag(circ) + dag._multi_graph = dag_to_dep_graph(dag._multi_graph, circuit_find_qubit_dict(circ)) + return dag + +def circuit_find_qubit_dict(circ: Q.circuit.QuantumCircuit) -> dict[Q.circuit.Bit, int]: + return {x: circ.find_bit(x)[0] for x in circ.qubits} + +def indirect_reachability(dag): + "Returns a mapping from nodes to sets of nodes reachable from them, via 2 or more edges" + reaches = dict() + def visit(node): + nodes = {node} + reaches[node] = nodes + outs = dag.out_edges(node) + for _, to, _ in outs: + if to not in reaches: + visit(to) + nodes |= reaches[to] | ({i for _, i, _ in dag.out_edges(to)} - set(outs)) + for node in dag.node_indices(): + visit(node) + return reaches + +def trim_edges(dag): + todo = set(dag.edge_list()) + indirect_reaches = indirect_reachability(dag) + + def get_id(node): + if '_node_id' in dir(node): + return node._node_id + elif 'node_id' in dir(node): + return node.node_id + + while todo: + edge = todo.pop() + gai, gbi = edge + if any(gbi in indirect_reaches[get_id(gci)] and gbi != get_id(gci) for gci in dag.successors(gai)): + dag.remove_edge(gai, gbi) + +def is_op_node(node: Q.dagcircuit.DAGNode) -> bool: + return isinstance(node, Q.dagcircuit.DAGOpNode) + +def dag_to_dep_graph(dag: Q.dagcircuit.DAGCircuit, + find_qubit: dict[Q.circuit.Bit, int], + trim=False) -> rx.PyDAG: + assert isinstance(dag, rx.PyDAG), \ + f"dag_to_dep_graph expected a DAG of type rustworkx.PyDAG, but got {type(dag)} instead" + graph = rx.PyDAG(multigraph=False) + graph.add_nodes_from(dag.nodes()) + reaches = {i: {i} for i in dag.node_indices()} + todo = set() + for gai, gbi, qi in dag.weighted_edge_list(): + graph.add_edge(gai, gbi, None) + if is_op_node(graph.get_node_data(gai)) and is_op_node(graph.get_node_data(gbi)): + todo.add((gai, gbi)) + visited = set() + + def maybe_push_edge(ifm, ito, gfm, gto): + """ + Pushes an edge to `todo` if we need to visit it + Args: + ifm = index of source gate + ito = index of target gate + gfm = source gate + gto = target gate + """ + + opfm = is_op_node(gfm) + opto = is_op_node(gto) + + if opfm and opto: + shared = any(find_qubit[q1] == find_qubit[q2] + for q1 in gfm.qargs for q2 in gto.qargs) + if shared and (ifm, ito) not in todo \ + and (ifm, ito) not in visited \ + and (not trim or ito not in reaches[ifm]): + graph.add_edge(ifm, ito, None) + todo.add((ifm, ito)) + elif (opfm and any(find_qubit[q] == find_qubit[gto.wire] for q in gfm.qargs)) or \ + (opto and any(find_qubit[q] == find_qubit[gfm.wire] for q in gto.qargs)) or \ + (not opfm and not opto and gfm.wire == gto.wire): + graph.add_edge(ifm, ito, None) + + while todo: + gai, gbi = todo.pop() + visited.add((gai, gbi)) + ga = graph.get_node_data(gai) + gb = graph.get_node_data(gbi) + if commutes(ga, gb, find_qubit): + graph.remove_edge(gai, gbi) + # Propagate in-edges + for pidx, _, _ in graph.in_edges(gai): + maybe_push_edge(pidx, gbi, graph.get_node_data(pidx), gb) + # Propagate out-edges + for _, cidx, _ in graph.out_edges(gbi): + maybe_push_edge(gai, cidx, ga, graph.get_node_data(cidx)) + elif trim: + reaches[gai] |= reaches[gbi] + return graph + +def as_bits(n: int, nbits: Optional[int] = None) -> list[int]: + "Converts an integer into its bit representation" + return [int(bool(n & (1 << (i - 1)))) for i in range(nbits or n.bit_length(), 0, -1)] + +def rearrange_gate(mat: NDArray, old: list[int], new: list[int]) -> NDArray: + """ + Rearranges a gate's unitary matrix for application to a new set of qubits. + Assumes set(old) = set(new). + """ + old = list(reversed(old)) + new = list(reversed(new)) + qubits = len(new) + old_idx = {o:i for i, o in enumerate(old)} + new_idx = {n:i for i, n in enumerate(new)} + old2new_idx = {o:new_idx[o] for o in old} + new2old_idx = {n:old_idx[n] for n in new} + + size = 1 << qubits + I = np.identity(size, dtype=np.dtype('int64')) + + bit_map = np.array([I[new2old_idx[n]] for n in new]) + bit_mat = np.array([as_bits(i, qubits) for i in range(size)]) + mapped = bit_mat @ bit_map + for i in range(qubits - 1): + mapped[:, i] <<= qubits - i - 1 + reordered = mapped.sum(axis=1) + + mat1 = np.ndarray(mat.shape, dtype=mat.dtype) + mat2 = np.ndarray(mat.shape, dtype=mat.dtype) + + for i in range(size): + mat1[i, :] = mat[reordered[i], :] + for i in range(size): + mat2[:, i] = mat1[:, reordered[i]] + return mat2 + +def align_gates(ga, gb, find_qubit: dict[Q.circuit.Bit, int]): + """ + Aligns gates along the same qubits, returning a tuple of their new unitaries + """ + ma = ga.op.to_matrix() + mb = gb.op.to_matrix() + qas = [find_qubit[qi] for qi in ga.qargs] + qbs = [find_qubit[qi] for qi in gb.qargs] + qas_ins = list(set(qbs) - set(qas)) + qbs_ins = list(set(qas) - set(qbs)) + qas2 = qas + qas_ins + qbs2 = qbs + qbs_ins + # Add additional qubits from gb + ma2 = np.kron(np.identity(1 << len(qas_ins), dtype=ma.dtype), ma) + # Add additional qubits from ga + mb2 = np.kron(np.identity(1 << len(qbs_ins), dtype=mb.dtype), mb) + # Rearrange mb2 to match ma2's qubit order + mb3 = rearrange_gate(mb2, qbs2, qas2) + return (ma2, mb3) + +def hardcoded_commutes(ga: Q.circuit.Gate, gb: Q.circuit.Gate, find_qubit: dict[Q.circuit.Bit, int]) -> int: + "Returns 0 if N/A, 1 if commute, 2 if dependent" + NA, COMMUTE, DEPENDENT = 0, 1, 2 + gan, gbn = ga.op.name, gb.op.name + qas, qbs = [find_qubit[q] for q in ga.qargs], [find_qubit[q] for q in gb.qargs] + + def comm(gan, gbn, qas, qbs): + if gan == 'cx' and gbn in ['x', 'sx', 'rx']: + return qas[1] == qbs[0] + elif gan == 'cx' and gbn in ['z', 'rz']: + return qas[0] == qbs[0] + elif gan == 'cx' and gbn == 'cx': + return qas[0] != qbs[1] and qas[1] != qbs[0] + elif gan == 'ccx' and gbn == 'cx': + return qas[2] != qbs[0] and qbs[1] not in qas[:2] + elif gan == 'ccx' and gbn == 'ccx': + return qas[2] not in qbs[:2] and qbs[2] not in qas[:2] + elif gan == 'ccx' and gbn in ['x', 'sx', 'rx']: + return qbs[0] not in qas[:2] + elif gan == 'ccx' and gbn in ['z', 'rz']: + return qbs[0] in qas[:2] + + if gan == gbn and qas == qbs: + return COMMUTE + else: + c = comm(gan, gbn, qas, qbs) + return NA if c is None else 2 - c + +def commutes(ga: Q.circuit.Gate, gb: Q.circuit.Gate, + find_qubit: dict[Q.circuit.Bit, int]) -> bool: + hc = hardcoded_commutes(ga, gb, find_qubit) or hardcoded_commutes(gb, ga, find_qubit) + if hc == 0: # we don't have this case in the hardcoded rules + ma, mb = align_gates(ga, gb, find_qubit) + return (np.abs((ma @ mb) - (mb @ ma)) < EPS).all() + return bool(hc % 2) # we've have this case in the hardcoded rules + +def read_qasm(fh) -> Q.circuit.QuantumCircuit: + #return Q.QuantumCircuit.from_qasm_file(fh) + acc = [] + for line in fh: + if not line.startswith('//'): + acc.append(line) + return Q.QuantumCircuit.from_qasm_str(''.join(acc)) + +def write_dep_graph(num_qubits: int, graph, find_qubit: dict[Q.circuit.Bit, int], fh): + nodemap = dict() + numnodes = 0 + nodes = [] + edges = [] + for node in graph.node_indices(): + data = graph.get_node_data(node) + if is_op_node(data): + nodemap[node] = numnodes + numnodes += 1 + nodes.append(data) + for fm, to in graph.edge_list(): + if is_op_node(graph.get_node_data(fm)) and is_op_node(graph.get_node_data(to)): + edges.append((nodemap[fm], nodemap[to])) + + # TODO: handle cargs + node_data = [] + for node in nodes: + if node.op.params: + node_data.append({ + 'name': node.op.name, + 'params': node.op.params, + 'qargs': [find_qubit[qarg] for qarg in node.qargs], + #'cargs': [f'{}' for carg in node.cargs] + }) + else: + node_data.append({ + 'name': node.op.name, + 'qargs': [find_qubit[qarg] for qarg in node.qargs], + #'cargs': [f'{}' for carg in node.cargs] + }) + data = {'qubits': num_qubits, 'nodes': node_data, 'edges': edges} + json.dump(data, fh) + + +def usage(argv): + return f""" +Usage: + {argv[0]} [input.qasm] [output.json] +If either arg is omitted, read from stdin/stdout +""" + +def main(argv): + if len(argv) == 2: + ifh = open(argv[1]) + ofh = sys.stdout + elif len(argv) == 3: + ifh = open(argv[1]) + ofh = open(argv[2], 'w') + else: + print(usage(argv).strip(), file=sys.stderr) + return 1 + circuit = read_qasm(ifh) + dag = circuit_to_dag(circuit) + find_qubit = circuit_find_qubit_dict(circuit) + dg = dag_to_dep_graph(dag._multi_graph, find_qubit, trim=True) + trim_edges(dg) + write_dep_graph(circuit.num_qubits, dg, find_qubit, ofh) + +# def glen(generator: Iterator[Any]) -> int: +# return sum(1 for _ in generator) + +# def op_edges(g: Q.dagcircuit.DAGCircuit) -> int: +# return glen(filter(lambda e: isinstance(e[0], Q.dagcircuit.DAGOpNode) and isinstance(e[1], Q.dagcircuit.DAGOpNode), g.edges())) + +# def main(argv): +# if len(argv) == 2: +# circuit = read_qasm(argv[1]) +# my_dep = circuit_to_dag(circuit) +# qk_dag = circuit_to_dag(circuit) +# orig_depth = my_dep.depth() +# orig_edges = glen(my_dep.edges()) +# print(f"Original DAG: {orig_depth - 1} depth, {orig_edges} edges") + +# TRIM = False +# my_start = time.time() +# my_dep._multi_graph = dag_to_dep_graph(my_dep._multi_graph, circuit_find_qubit_dict(circuit), TRIM) +# my_end = time.time() + +# if TRIM: +# trim_start = time.time() +# trim_edges(my_dep._multi_graph) +# trim_end = time.time() +# print(f"My dependency graph: {my_dep.depth() - 1} depth, {op_edges(my_dep)} edges, {my_end - my_start:0.4f} sec + trimming for {trim_end - trim_start:0.4f} sec") +# else: +# print(f"My dependency graph: {my_dep.depth() - 1} depth, {op_edges(my_dep)} edges, {my_end - my_start:0.4f} sec") +# qk_start = time.time() +# qk_dep = Q.converters.dag_to_dagdependency(qk_dag) +# qk_end = time.time() + +# print(f"Qiskit dependency graph: {qk_dep.depth()} depth, {glen(qk_dep.get_all_edges())} edges, {qk_end - qk_start:0.4f} sec") +# else: +# print("Pass a .qasm file as arg", file=sys.stderr) + +if __name__ == '__main__': + exitcode = main(sys.argv) + exit(exitcode) From d59493ddf5d3764dfa615f79a9e18940cc9a2474 Mon Sep 17 00:00:00 2001 From: Colin McDonald Date: Wed, 8 Nov 2023 14:51:12 -0500 Subject: [PATCH 03/15] Bug gix --- .../DepGraphSchedulerGreedyFinishQubit.sml | 39 +++++++++++-------- feynsum-sml/src/common/DepGraphUtil.sml | 13 ++++--- 2 files changed, 29 insertions(+), 23 deletions(-) diff --git a/feynsum-sml/src/common/DepGraphSchedulerGreedyFinishQubit.sml b/feynsum-sml/src/common/DepGraphSchedulerGreedyFinishQubit.sml index 736837f..ef697f5 100644 --- a/feynsum-sml/src/common/DepGraphSchedulerGreedyFinishQubit.sml +++ b/feynsum-sml/src/common/DepGraphSchedulerGreedyFinishQubit.sml @@ -27,19 +27,19 @@ struct fun popOIS ({S = S, A = A}: OrderedIntSet) = case !A of nil => NONE - | x :: xs => (A := xs; S := IntBinarySet.delete (!S, x); SOME x) + | x :: xs => (print ("Pop " ^ Int.toString x ^ "\n"); A := xs; S := IntBinarySet.delete (!S, x); SOME x) fun pushFrontOIS ({S = S, A = A}: OrderedIntSet) (x: int) = if IntBinarySet.member (!S, x) then () else - (A := x :: !A; S := IntBinarySet.add (!S, x)) + (print ("Push front " ^ Int.toString x ^ "\n"); A := x :: !A; S := IntBinarySet.add (!S, x)) fun pushBackOIS ({S = S, A = A}: OrderedIntSet) (x: int) = if IntBinarySet.member (!S, x) then () else - (A := (!A) @ (x :: nil); S := IntBinarySet.add (!S, x)) + (print ("Push back " ^ Int.toString x ^ "\n"); A := (!A) @ (x :: nil); S := IntBinarySet.add (!S, x)) fun newOIS () = {S = ref IntBinarySet.empty, A = ref nil} @@ -49,31 +49,36 @@ struct fun revTopologicalSort (dg: DepGraph.t) (tr: TraversalOrder) = let val N = Seq.length (#gates dg) - val L = ref nil - val ois = newOIS () - (*fun outdegree i = Seq.length (Seq.nth (#deps dg) i)*) val ind = Array.tabulate (N, Seq.nth (#indegree dg)) - fun decInd i = let val d = Array.sub (ind, i) in Array.update (ind, i, d - 1); d end + fun decInd i = let val d = Array.sub (ind, i) in Array.update (ind, i, d - 1); d - 1 end + + (* + val ois = newOIS () val push = case tr of BFS => pushBackOIS ois | DFS => pushFrontOIS ois + fun pop () = popOIS ois + *) + val queue = ref nil + val push = case tr of BFS => (fn i => queue := (!queue) @ (i :: nil)) | DFS => (fn i => queue := i :: !queue) + fun pop () = case !queue of nil => NONE | x :: xs => (queue := xs; SOME x) val _ = intDo N (fn i => if Seq.nth (#indegree dg) i = 0 then push i else ()) - fun loop () = - case popOIS ois of - NONE => () - | SOME n => - (L := n :: !L; - intDo (Seq.length (Seq.nth (#deps dg) n)) - (fn m => if decInd m = 0 then push m else ()); - loop ()) + fun loop L = + (print ("loop " ^ Int.toString (List.length L) ^ "\n"); + case pop () of + NONE => L + | SOME n => + (Seq.map (fn m => if decInd m = 0 then push m else ()) + (Seq.nth (#deps dg) n); + loop (n :: L))) in - loop (); !L + loop nil end fun topologicalSort (dg: DepGraph.t) (tr: TraversalOrder) = List.rev (revTopologicalSort dg tr) fun ordered (dg: DepGraph.t) = - topologicalSort (DepGraphUtil.redirect dg) DFS + topologicalSort (DepGraphUtil.transpose dg) DFS val gateDepths: int array option ref = ref NONE diff --git a/feynsum-sml/src/common/DepGraphUtil.sml b/feynsum-sml/src/common/DepGraphUtil.sml index da62d92..7d181c5 100644 --- a/feynsum-sml/src/common/DepGraphUtil.sml +++ b/feynsum-sml/src/common/DepGraphUtil.sml @@ -13,7 +13,7 @@ sig val initState: dep_graph -> state (* Switches edge directions *) - val redirect: dep_graph -> dep_graph + val transpose: dep_graph -> dep_graph (* val gateIsBranching: dep_graph -> (gate_idx -> bool) *) end = @@ -23,14 +23,15 @@ struct type dep_graph = DepGraph.t - fun redirect ({gates = gs, deps = ds, indegree = is, numQubits = qs}: dep_graph) = - let val ds2 = Array.array (qs, nil) + fun transpose ({gates = gs, deps = ds, indegree = is, numQubits = qs}: dep_graph) = + let val N = Seq.length gs + val ds2 = Array.array (N, nil) fun apply i = Seq.map (fn j => Array.update (ds2, j, i :: Array.sub (ds2, j))) (Seq.nth ds i) - val _ = Seq.tabulate apply qs + val _ = Seq.tabulate apply N in {gates = gs, - deps = Seq.tabulate (fn i => Seq.rev (Seq.fromList (Array.sub (ds2, i)))) qs, - indegree = Seq.tabulate (fn i => Seq.length (Seq.nth ds i)) qs, + deps = Seq.tabulate (fn i => Seq.rev (Seq.fromList (Array.sub (ds2, i)))) N, + indegree = Seq.map Seq.length ds, numQubits = qs} end From 34f8e960577d4ece695656d5285e6de38ac1c572 Mon Sep 17 00:00:00 2001 From: Colin McDonald Date: Mon, 13 Nov 2023 11:04:34 -0500 Subject: [PATCH 04/15] Dep graph schedulers --- feynsum-sml/run.sh | 10 +- feynsum-sml/src/common/DepGraphScheduler.sml | 2 + .../DepGraphSchedulerGreedyFinishQubit.sml | 239 ++++++++++++------ .../DepGraphSchedulerGreedyNonBranching.sml | 30 ++- feynsum-sml/src/common/DepGraphUtil.sml | 75 ++++++ feynsum-sml/src/common/GateSchedulerOrder.sml | 17 +- feynsum-sml/src/common/depgraph.py | 17 +- feynsum-sml/src/common/sources.mlb | 2 +- feynsum-sml/src/main.sml | 141 +++++++---- 9 files changed, 389 insertions(+), 144 deletions(-) diff --git a/feynsum-sml/run.sh b/feynsum-sml/run.sh index 8fa783e..6c19aa3 100755 --- a/feynsum-sml/run.sh +++ b/feynsum-sml/run.sh @@ -1,5 +1,5 @@ -./main.mpl @mpl procs 72 set-affinity megablock-threshold 14 cc-threshold-ratio 1.1 collection-threshold-ratio 2.0 max-cc-depth 1 -- -scheduler gfq -input $@ -# cc-threshold-ratio 1.1 collection-threshold-ratio 2.0 - - -# ./all-main.mpl @mpl procs 72 set-affinity megablock-threshold 14 cc-threshold-ratio 1.25 max-cc-depth 1 -- -sim query-bfs -input $@ +./main.mpl @mpl procs 20 set-affinity megablock-threshold 14 cc-threshold-ratio 1.1 collection-threshold-ratio 2.0 max-cc-depth 1 -- -scheduler $1 -input $2 +# -scheduler-max-branching-stride 1 +# --scheduler-disable-fusion +# -dense-thresh 0.75 +# -pull-thresh 0.01 diff --git a/feynsum-sml/src/common/DepGraphScheduler.sml b/feynsum-sml/src/common/DepGraphScheduler.sml index 20718bb..1e0f1de 100644 --- a/feynsum-sml/src/common/DepGraphScheduler.sml +++ b/feynsum-sml/src/common/DepGraphScheduler.sml @@ -9,5 +9,7 @@ struct } (* From a frontier, select which gate to apply next *) + (* args visit gates, update frontier break fusion initial frontier gate batches *) + (*type t = args -> (gate_idx -> gate_idx Seq.t) -> (unit -> unit) -> gate_idx Seq.t -> gate_idx Seq.t Seq.t*) type t = args -> (gate_idx Seq.t -> gate_idx) end diff --git a/feynsum-sml/src/common/DepGraphSchedulerGreedyFinishQubit.sml b/feynsum-sml/src/common/DepGraphSchedulerGreedyFinishQubit.sml index ef697f5..0aaba0a 100644 --- a/feynsum-sml/src/common/DepGraphSchedulerGreedyFinishQubit.sml +++ b/feynsum-sml/src/common/DepGraphSchedulerGreedyFinishQubit.sml @@ -2,7 +2,11 @@ functor DepGraphSchedulerGreedyFinishQubit (val maxBranchingStride: int val disableFusion: bool): sig val scheduler: DepGraphScheduler.t - val ordered: DepGraph.t -> int list + val scheduler2: DepGraphScheduler.t + val scheduler3: DepGraphScheduler.t + val scheduler4: DepGraphScheduler.t + val scheduler5: DepGraphScheduler.t + val schedulerRandom: int -> DepGraphScheduler.t end = struct @@ -13,102 +17,34 @@ struct , gateIsBranching: gate_idx -> bool } - fun intDo (n : int) (f: int -> 'b) = - let fun next i = if i = n then () else (f i; next (i + 1)) - in - next 0 - end - - type OrderedIntSet = { - S: IntBinarySet.set ref, - A: int list ref - } - - fun popOIS ({S = S, A = A}: OrderedIntSet) = - case !A of - nil => NONE - | x :: xs => (print ("Pop " ^ Int.toString x ^ "\n"); A := xs; S := IntBinarySet.delete (!S, x); SOME x) + fun DFS ((new, old) : (int list * int list)) = new @ old + fun BFS ((new, old) : (int list * int list)) = old @ new - fun pushFrontOIS ({S = S, A = A}: OrderedIntSet) (x: int) = - if IntBinarySet.member (!S, x) then - () - else - (print ("Push front " ^ Int.toString x ^ "\n"); A := x :: !A; S := IntBinarySet.add (!S, x)) - - fun pushBackOIS ({S = S, A = A}: OrderedIntSet) (x: int) = - if IntBinarySet.member (!S, x) then - () - else - (print ("Push back " ^ Int.toString x ^ "\n"); A := (!A) @ (x :: nil); S := IntBinarySet.add (!S, x)) - - fun newOIS () = {S = ref IntBinarySet.empty, A = ref nil} - - fun emptyOIS ({S = S, A = A}: OrderedIntSet) = List.null (!A) - - datatype TraversalOrder = BFS | DFS - - fun revTopologicalSort (dg: DepGraph.t) (tr: TraversalOrder) = + fun revTopologicalSort (dg: DepGraph.t) (push: (int list * int list) -> int list) = let val N = Seq.length (#gates dg) val ind = Array.tabulate (N, Seq.nth (#indegree dg)) fun decInd i = let val d = Array.sub (ind, i) in Array.update (ind, i, d - 1); d - 1 end - - (* - val ois = newOIS () - val push = case tr of BFS => pushBackOIS ois | DFS => pushFrontOIS ois - fun pop () = popOIS ois - *) val queue = ref nil - val push = case tr of BFS => (fn i => queue := (!queue) @ (i :: nil)) | DFS => (fn i => queue := i :: !queue) + (*val push = case tr of BFS => (fn xs => queue := (!queue) @ xs) | DFS => (fn xs => queue := xs @ (!queue))*) fun pop () = case !queue of nil => NONE | x :: xs => (queue := xs; SOME x) - val _ = intDo N (fn i => if Seq.nth (#indegree dg) i = 0 then - push i else ()) + val _ = queue := push (List.filter (fn i => Seq.nth (#indegree dg) i = 0) (List.tabulate (N, fn i => i)), !queue) fun loop L = - (print ("loop " ^ Int.toString (List.length L) ^ "\n"); case pop () of NONE => L | SOME n => - (Seq.map (fn m => if decInd m = 0 then push m else ()) - (Seq.nth (#deps dg) n); - loop (n :: L))) + (let val ndeps = Seq.nth (#deps dg) n in + queue := push (List.filter (fn m => decInd m = 0) (List.tabulate (Seq.length ndeps, Seq.nth ndeps)), !queue) + end; + loop (n :: L)) in loop nil end - fun topologicalSort (dg: DepGraph.t) (tr: TraversalOrder) = - List.rev (revTopologicalSort dg tr) - - fun ordered (dg: DepGraph.t) = - topologicalSort (DepGraphUtil.transpose dg) DFS + fun topologicalSort (dg: DepGraph.t) (push: (int list * int list) -> int list) = + List.rev (revTopologicalSort dg push) val gateDepths: int array option ref = ref NONE - fun popIntBinarySet (S: IntBinarySet.set ref) = case IntBinarySet.find (fn _ => true) (!S) of - NONE => raise Fail "popIntBinarySet called on empty set" - | SOME elt => (S := IntBinarySet.delete (!S, elt); elt) - - fun revTopologicalSortOld (dg: DepGraph.t) = - let val N = Seq.length (#gates dg) - val L = ref nil - val S = ref IntBinarySet.empty - val A = ref nil - val ind = Array.tabulate (N, Seq.nth (#indegree dg)) - fun decInd i = let val d = Array.sub (ind, i) in Array.update (ind, i, d - 1); d end - val _ = intDo N (fn i => if Seq.nth (#indegree dg) i = 0 then - S := IntBinarySet.add (!S, i) else ()) - fun loop () = - if IntBinarySet.isEmpty (!S) then - () - else - let val n = popIntBinarySet S in - L := n :: !L; - intDo (Seq.length (Seq.nth (#deps dg) n)) - (fn m => if decInd m = 0 then S := IntBinarySet.add (!S, m) else ()); - loop () - end - in - loop (); !L - end - fun computeGateDepths (dg: DepGraph.t) = let val N = Seq.length (#gates dg) val depths = Array.array (N, ~1) @@ -121,6 +57,40 @@ struct List.foldl (fn (i, ()) => gdep i) () (revTopologicalSort dg DFS); depths end + fun sortList (lt: 'a * 'a -> bool) (xs: 'a list) = + let fun insert (x, xs) = + case xs of + nil => x :: nil + | x' :: xs => if lt (x, x') then x :: x' :: xs else x' :: insert (x, xs) + in + List.foldr insert nil xs + end + + (* Choose in reverse topological order, sorted by easiest qubit to finish *) + fun scheduler3 ({depGraph = dg, gateIsBranching = gib}: args) = + let val dgt = DepGraphUtil.transpose dg + val depths = computeGateDepths dg + fun lt (a, b) = Array.sub (depths, a) < Array.sub (depths, b) orelse (Array.sub (depths, a) = Array.sub (depths, b) andalso not (gib a) andalso gib b) + fun push (new, old) = DFS (sortList lt new, old) + val xs = revTopologicalSort dgt push + val N = Seq.length (#gates dg) + val ord = Array.array (N, ~1) + fun writeOrd i xs = case xs of nil => () | x :: xs' => (Array.update (ord, x, i); writeOrd (i + 1) xs') + val _ = writeOrd 0 xs + fun pickEarliestOrd best_idx best_ord i gates = + if i = Seq.length gates then + best_idx + else let val cur_idx = Seq.nth gates i + val cur_ord = Array.sub (ord, cur_idx) + in + if cur_ord < best_ord then pickEarliestOrd cur_idx cur_ord (i + 1) gates else pickEarliestOrd best_idx best_ord (i + 1) gates + end + in + fn gates => let val g0 = Seq.nth gates 0 in + pickEarliestOrd g0 (Array.sub (ord, g0)) 1 gates + end + end + fun gateDepth i dg = case !gateDepths of NONE => let val gd = computeGateDepths dg in @@ -142,10 +112,119 @@ struct pickLeastDepth best_idx best_depth (i + 1) gates dg end + fun pickGreatestDepth best_idx best_depth i gates dg = + if i = Seq.length gates then + best_idx + else + let val cur_idx = Seq.nth gates i + val cur_depth = gateDepth cur_idx dg in + if cur_depth > best_depth then + pickGreatestDepth cur_idx cur_depth (i + 1) gates dg + else + pickGreatestDepth best_idx best_depth (i + 1) gates dg + end + (* From a frontier, select which gate to apply next *) fun scheduler ({depGraph = dg, ...} : args) = (gateDepths := SOME (computeGateDepths dg); fn gates => let val g0 = Seq.nth gates 0 in pickLeastDepth g0 (gateDepth g0 dg) 1 gates dg end) + + (* Select gate with greatest number of descendants *) + fun scheduler4 ({depGraph = dg, ...} : args) = + (gateDepths := SOME (computeGateDepths dg); + fn gates => let val g0 = Seq.nth gates 0 in + pickGreatestDepth g0 (gateDepth g0 dg) 1 gates dg + end) + + structure G = Gate (structure B = BasisIdxUnlimited + structure C = Complex64) + + (* Hybrid of scheduler2 (avoid branching on unbranched qubits) and also scheduler3 (choose in reverse topological order, sorted by easiest qubit to finish) *) + fun scheduler5 ({depGraph = dg, gateIsBranching = gib} : args) = + let val gates = Seq.map G.fromGateDefn (#gates dg) + fun touches i = #touches (Seq.nth gates i) + fun branches i = case #action (Seq.nth gates i) of G.NonBranching _ => 0 | G.MaybeBranching _ => 1 | G.Branching _ => 2 + + val dgt = DepGraphUtil.transpose dg + val depths = computeGateDepths dg + fun lt (a, b) = Array.sub (depths, a) < Array.sub (depths, b) orelse (Array.sub (depths, a) = Array.sub (depths, b) andalso branches a < branches b) + fun push (new, old) = DFS (sortList lt new, old) + val xs = revTopologicalSort dgt push + val N = Seq.length (#gates dg) + val ord = Array.array (N, ~1) + fun writeOrd i xs = case xs of nil => () | x :: xs' => (Array.update (ord, x, i); writeOrd (i + 1) xs') + val _ = writeOrd 0 xs + val touched = Array.array (#numQubits dg, false) + fun touch i = Array.update (touched, i, true) + fun touchAll gidx = let val ts = touches gidx in List.tabulate (Seq.length ts, fn i => touch (Seq.nth ts i)); () end + fun newTouches i = + Seq.length (Seq.filter (fn j => not (Array.sub (touched, j))) (touches i)) + fun pickLeastNewTouches best_idx best_newTouches best_ord i gates = + if i = Seq.length gates then + ((* print ("Picked " ^ Int.toString best_idx ^ ", new touches " ^ Int.toString best_newTouches ^ "\n"); *) + best_idx) + else + let val cur_idx = Seq.nth gates i + val cur_newTouches = newTouches cur_idx + val cur_ord = Array.sub (ord, cur_idx) + in + if cur_newTouches < best_newTouches + orelse (cur_newTouches = best_newTouches + andalso cur_ord < best_ord) then + pickLeastNewTouches cur_idx cur_newTouches cur_ord (i + 1) gates + else + pickLeastNewTouches best_idx best_newTouches best_ord (i + 1) gates + end + in + fn gates => let val g0 = Seq.nth gates 0 + val next = pickLeastNewTouches g0 (newTouches g0) (Array.sub (ord, g0)) 1 gates + in + touchAll next; next + end + end + + (* Avoids branching on unbranched qubits *) + fun scheduler2 ({depGraph = dg, gateIsBranching = gib} : args) = + let val touched = Array.array (#numQubits dg, false) + val gates = Seq.map G.fromGateDefn (#gates dg) + fun touches i = #touches (Seq.nth gates i) + fun branches i = case #action (Seq.nth gates i) of G.NonBranching _ => 0 | G.MaybeBranching _ => 1 | G.Branching _ => 2 + fun touch i = Array.update (touched, i, true) + fun touchAll gidx = let val ts = touches gidx in List.tabulate (Seq.length ts, fn i => touch (Seq.nth ts i)); () end + fun newTouches i = + Seq.length (Seq.filter (fn j => not (Array.sub (touched, j))) (touches i)) + fun pickLeastNewTouches best_idx best_newTouches i gates = + if i = Seq.length gates then + ((* print ("Picked " ^ Int.toString best_idx ^ ", new touches " ^ Int.toString best_newTouches ^ "\n"); *) + best_idx) + else + let val cur_idx = Seq.nth gates i + val cur_newTouches = newTouches cur_idx + in + if cur_newTouches < best_newTouches + orelse (cur_newTouches = best_newTouches + andalso branches cur_idx < branches best_idx) then + pickLeastNewTouches cur_idx cur_newTouches (i + 1) gates + else + pickLeastNewTouches best_idx best_newTouches (i + 1) gates + end + in + fn gates => let val g0 = Seq.nth gates 0 + val next = pickLeastNewTouches g0 (newTouches g0) 1 gates + in + touchAll next; next + end + end + + val seed = Random.rand (50, 14125) + + fun schedulerRandom seedNum ({depGraph = dg, gateIsBranching = gib} : args) = + (*let val seed = Random.rand (seedNum, seedNum * seedNum) in*) + fn gates => let val r = Random.randRange (0, Seq.length gates - 1) seed in + (print ("Randomly chose " ^ Int.toString r ^ " from range [0, " ^ Int.toString (Seq.length gates) ^ ")\n"); + Seq.nth gates r) + end + (* end *) end diff --git a/feynsum-sml/src/common/DepGraphSchedulerGreedyNonBranching.sml b/feynsum-sml/src/common/DepGraphSchedulerGreedyNonBranching.sml index 258e3f0..bbbf810 100644 --- a/feynsum-sml/src/common/DepGraphSchedulerGreedyNonBranching.sml +++ b/feynsum-sml/src/common/DepGraphSchedulerGreedyNonBranching.sml @@ -12,15 +12,33 @@ struct , gateIsBranching: gate_idx -> bool } - fun pickNonBranching i branching gates = - if i < Seq.length gates then + fun pickNonBranching i branching ftr = + if i < Seq.length ftr then (if branching i then - pickNonBranching (i + 1) branching gates + pickNonBranching (i + 1) branching ftr else - Seq.nth gates i) + Seq.nth ftr i) else - Seq.nth gates 0 (* pick a branching gate *) + Seq.nth ftr 0 (* From a frontier, select which gate to apply next *) - fun scheduler ({gateIsBranching = gib, ...} : args) gates = pickNonBranching 0 gib gates + fun scheduler ({gateIsBranching = branching, ...} : args) ftr = + pickNonBranching 0 branching ftr + (*fun scheduler ({gateIsBranching = branching, ...} : args) updateFrontier breakFusion initialFrontier = + let fun updateAndBreak i = let val ftr = updateFrontier i in breakFusion (); ftr end + fun pickNonBranching i ftr = + if i < Seq.length gates then + (if branching i then + pickNonBranching (i + 1) gates + else + Seq.nth gates i) + else + (breakFusion (); Seq.nth gates 0) + fun sched ftr = if Seq.null ftr then + () + else + scheduler (updateFrontier (pickNonBranching 0 gates)) + in + sched initialFrontier + end*) end diff --git a/feynsum-sml/src/common/DepGraphUtil.sml b/feynsum-sml/src/common/DepGraphUtil.sml index 7d181c5..26a21b0 100644 --- a/feynsum-sml/src/common/DepGraphUtil.sml +++ b/feynsum-sml/src/common/DepGraphUtil.sml @@ -15,6 +15,11 @@ sig (* Switches edge directions *) val transpose: dep_graph -> dep_graph + val scheduleWithOracle: dep_graph -> (gate_idx -> bool) -> (gate_idx Seq.t -> gate_idx) -> int -> gate_idx Seq.t Seq.t + + val scheduleCost: gate_idx Seq.t Seq.t -> (gate_idx -> bool) -> real + val chooseSchedule: gate_idx Seq.t Seq.t Seq.t -> (gate_idx -> bool) -> gate_idx Seq.t Seq.t + (* val gateIsBranching: dep_graph -> (gate_idx -> bool) *) end = struct @@ -66,4 +71,74 @@ struct in { visited = vis, indegree = deg } end + + fun scheduleWithOracle (graph: dep_graph) (branching: gate_idx -> bool) (choose: gate_idx Seq.t -> gate_idx) (maxBranchingStride: int) = + let val st = initState graph + fun findNonBranching (i: int) (xs: gate_idx Seq.t) = + if i = Seq.length xs then + NONE + else if branching (Seq.nth xs i) then + findNonBranching (i + 1) xs + else + SOME (Seq.nth xs i) + fun loadNonBranching (acc: gate_idx list) = + case findNonBranching 0 (frontier st) of + NONE => acc + | SOME i => (visit graph i st; loadNonBranching (i :: acc)) + fun loadNext numBranchingSoFar thisStep (acc: gate_idx list list) = + let val ftr = frontier st in + if Seq.length ftr = 0 then + Seq.map Seq.fromList (Seq.rev (Seq.fromList (if List.null thisStep then acc else List.rev thisStep :: acc))) + else + (let val next = choose ftr in + visit graph next st; + if numBranchingSoFar + 1 >= maxBranchingStride then + loadNext 0 nil (List.rev (loadNonBranching (next :: thisStep)) :: acc) + else + loadNext (numBranchingSoFar + 1) (loadNonBranching (next :: thisStep)) acc + end) + end + in + loadNext 0 (loadNonBranching nil) nil + end + + (*fun scheduleCost2 (order: gate_idx Seq.t Seq.t) (branching: gate_idx -> bool) = + let val gates = Seq.flatten order + val N = Seq.length gates + fun iter i cost branchedQubits = + if i = N then + cost + else + iter (i + 1) (1.0 + cost + (if branching (Seq.nth gates i) then cost else 0.0)) + in + iter 0 0.0 (Vector.tabulate (N, fn _ => false)) + end*) + + + fun scheduleCost (order: gate_idx Seq.t Seq.t) (branching: gate_idx -> bool) = + let val gates = Seq.flatten order + val N = Seq.length gates + fun iter i cost = + if i = N then + cost + else + iter (i + 1) (1.0 + (if branching (Seq.nth gates i) then cost * 2.0 else cost)) + in + iter 0 0.0 + end + + fun chooseSchedule (orders: gate_idx Seq.t Seq.t Seq.t) (branching: gate_idx -> bool) = + let fun iter i best_i best_cost = + if i = Seq.length orders then + Seq.nth orders best_i + else + let val cost = scheduleCost (Seq.nth orders i) branching in + if cost < best_cost then + (print ("Reduced cost from " ^ Real.toString best_cost ^ " to " ^ Real.toString cost ^ "\n"); iter (i + 1) i cost) + else + (print ("Maintained cost " ^ Real.toString best_cost ^ " over " ^ Real.toString cost ^ "\n"); iter (i + 1) best_i best_cost) + end + in + iter 1 0 (scheduleCost (Seq.nth orders 0) branching) + end end diff --git a/feynsum-sml/src/common/GateSchedulerOrder.sml b/feynsum-sml/src/common/GateSchedulerOrder.sml index 1f32def..8235706 100644 --- a/feynsum-sml/src/common/GateSchedulerOrder.sml +++ b/feynsum-sml/src/common/GateSchedulerOrder.sml @@ -1,20 +1,31 @@ structure GateSchedulerOrder: sig type gate_idx = int - val mkScheduler: gate_idx Seq.t -> GateScheduler.t + val mkScheduler: gate_idx Seq.t Seq.t -> GateScheduler.t end = struct type qubit_idx = int type gate_idx = int - fun mkScheduler (order: gate_idx Seq.t) args = + fun mkScheduler (order: gate_idx Seq.t Seq.t) (args: GateScheduler.args) = let val i = ref 0 val N = Seq.length order in + print ("Schedule cost: " ^ Real.toString (DepGraphUtil.scheduleCost order (#gateIsBranching args)) ^ "\n"); fn () => if !i >= N then Seq.empty () else - let val gi = !i in i := gi + 1; Seq.singleton gi end + (i := !i + 1; Seq.nth order (!i - 1)) end + + (*fun mkSchedulerFusion (order: gate_idx Seq.t Seq.t) args = + let val i = ref 0 + val N = Seq.length order + in + fn () => if !i >= N then + Seq.empty () + else + let val gi = !i in i := gi + 1; Seq.singleton gi end + end*) end diff --git a/feynsum-sml/src/common/depgraph.py b/feynsum-sml/src/common/depgraph.py index 3e9a454..a550dbb 100644 --- a/feynsum-sml/src/common/depgraph.py +++ b/feynsum-sml/src/common/depgraph.py @@ -257,26 +257,35 @@ def write_dep_graph(num_qubits: int, graph, find_qubit: dict[Q.circuit.Bit, int] def usage(argv): return f""" Usage: - {argv[0]} [input.qasm] [output.json] + {argv[0]} [--dag] [input.qasm] [output.json] If either arg is omitted, read from stdin/stdout +If --dag, immediately output unprocessed DAG """ def main(argv): + just_dag = False if len(argv) == 2: ifh = open(argv[1]) ofh = sys.stdout elif len(argv) == 3: ifh = open(argv[1]) ofh = open(argv[2], 'w') + elif len(argv) == 4 and argv[1] == '--dag': + just_dag = True + ifh = open(argv[2]) + ofh = open(argv[3], 'w') else: print(usage(argv).strip(), file=sys.stderr) return 1 circuit = read_qasm(ifh) dag = circuit_to_dag(circuit) find_qubit = circuit_find_qubit_dict(circuit) - dg = dag_to_dep_graph(dag._multi_graph, find_qubit, trim=True) - trim_edges(dg) - write_dep_graph(circuit.num_qubits, dg, find_qubit, ofh) + if just_dag: + write_dep_graph(circuit.num_qubits, dag._multi_graph, find_qubit, ofh) + else: + dg = dag_to_dep_graph(dag._multi_graph, find_qubit, trim=True) + trim_edges(dg) + write_dep_graph(circuit.num_qubits, dg, find_qubit, ofh) # def glen(generator: Iterator[Any]) -> int: # return sum(1 for _ in generator) diff --git a/feynsum-sml/src/common/sources.mlb b/feynsum-sml/src/common/sources.mlb index d513dc1..4f1a503 100644 --- a/feynsum-sml/src/common/sources.mlb +++ b/feynsum-sml/src/common/sources.mlb @@ -5,7 +5,6 @@ local in functor RedBlackMapFn functor RedBlackSetFn - structure IntBinarySet end ../lib/github.com/mpllang/mpllib/sources.$(COMPILER).mlb @@ -22,6 +21,7 @@ local end $(SML_LIB)/smlnj-lib/JSON/json-lib.mlb + $(SML_LIB)/smlnj-lib/Util/smlnj-lib.mlb HashTable.sml ApplyUntilFailure.sml diff --git a/feynsum-sml/src/main.sml b/feynsum-sml/src/main.sml index 963d8e9..2325009 100644 --- a/feynsum-sml/src/main.sml +++ b/feynsum-sml/src/main.sml @@ -96,47 +96,44 @@ val _ = *) -local - val disableFusion = CLA.parseFlag "scheduler-disable-fusion" - val maxBranchingStride = CLA.parseInt "scheduler-max-branching-stride" 2 -in - structure GNB = - GateSchedulerGreedyNonBranching - (val maxBranchingStride = maxBranchingStride - val disableFusion = disableFusion) - - structure GFQ = - GateSchedulerGreedyFinishQubit - (val maxBranchingStride = maxBranchingStride - val disableFusion = disableFusion) - - structure DGNB = - DepGraphSchedulerGreedyNonBranching - (val maxBranchingStride = maxBranchingStride - val disableFusion = disableFusion) - - structure DGFQ = - DepGraphSchedulerGreedyFinishQubit - (val maxBranchingStride = maxBranchingStride - val disableFusion = disableFusion) - - fun print_sched_info () = - let - val _ = print - ("-------------------------------------\n\ - \--- scheduler-specific args\n\ - \-------------------------------------\n") - val _ = print - ("scheduler-max-branching-stride " ^ Int.toString maxBranchingStride - ^ "\n") - val _ = print - ("scheduler-disable-fusion? " ^ (if disableFusion then "yes" else "no") - ^ "\n") - val _ = print ("-------------------------------------\n") - in - () - end -end +val disableFusion = CLA.parseFlag "scheduler-disable-fusion" +val maxBranchingStride = CLA.parseInt "scheduler-max-branching-stride" 2 +structure GNB = + GateSchedulerGreedyNonBranching + (val maxBranchingStride = maxBranchingStride + val disableFusion = disableFusion) + +structure GFQ = + GateSchedulerGreedyFinishQubit + (val maxBranchingStride = maxBranchingStride + val disableFusion = disableFusion) + +structure DGNB = + DepGraphSchedulerGreedyNonBranching + (val maxBranchingStride = maxBranchingStride + val disableFusion = disableFusion) + +structure DGFQ = + DepGraphSchedulerGreedyFinishQubit + (val maxBranchingStride = maxBranchingStride + val disableFusion = disableFusion) + +fun print_sched_info () = + let + val _ = print + ("-------------------------------------\n\ + \--- scheduler-specific args\n\ + \-------------------------------------\n") + val _ = print + ("scheduler-max-branching-stride " ^ Int.toString maxBranchingStride + ^ "\n") + val _ = print + ("scheduler-disable-fusion? " ^ (if disableFusion then "yes" else "no") + ^ "\n") + val _ = print ("-------------------------------------\n") + in + () + end type gate_idx = int type Schedule = gate_idx Seq.t @@ -169,7 +166,7 @@ fun dep_graph_to_schedule (ags: DepGraphScheduler.args) (sched: DepGraphSchedule end) end in - GateSchedulerOrder.mkScheduler (loadNext nil) + GateSchedulerOrder.mkScheduler (Seq.map Seq.singleton (loadNext nil)) end structure Gate_branching = Gate (structure B = BasisIdxUnlimited @@ -186,21 +183,61 @@ fun gate_branching ({ gates = gates, ...} : DepGraph.t) = | _ => true) end +val maxBranchingStride' = if disableFusion then 1 else maxBranchingStride + fun greedybranching () = case optDepGraph of NONE => GateSchedulerGreedyBranching.scheduler - | SOME dg => dep_graph_to_schedule { depGraph = dg, gateIsBranching = gate_branching dg } DepGraphSchedulerGreedyBranching.scheduler + | SOME dg => GateSchedulerOrder.mkScheduler (DepGraphUtil.scheduleWithOracle dg (gate_branching dg) (DepGraphSchedulerGreedyBranching.scheduler { depGraph = dg, gateIsBranching = gate_branching dg }) maxBranchingStride') fun greedynonbranching () = case optDepGraph of NONE => GNB.scheduler - | SOME dg => dep_graph_to_schedule { depGraph = dg, gateIsBranching = gate_branching dg } DGNB.scheduler + | SOME dg => GateSchedulerOrder.mkScheduler (DepGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGNB.scheduler { depGraph = dg, gateIsBranching = gate_branching dg }) maxBranchingStride') fun greedyfinishqubit () = case optDepGraph of NONE => GFQ.scheduler - | SOME dg => GateSchedulerOrder.mkScheduler (Seq.fromList (DGFQ.ordered dg)) (*dep_graph_to_schedule { depGraph = dg, gateIsBranching = gate_branching dg } DGFQ.scheduler*) + | SOME dg => + GateSchedulerOrder.mkScheduler (DepGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.scheduler { depGraph = dg, gateIsBranching = gate_branching dg }) maxBranchingStride') +fun greedyfinishqubit2 () = + case optDepGraph of + NONE => GFQ.scheduler + | SOME dg => + (*GateSchedulerOrder.mkScheduler (Seq.fromList (DGFQ.ordered { depGraph = dg, gateIsBranching = gate_branching dg }))*) + (* dep_graph_to_schedule { depGraph = dg, gateIsBranching = gate_branching dg } DGFQ.scheduler2 *) + GateSchedulerOrder.mkScheduler (DepGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.scheduler2 { depGraph = dg, gateIsBranching = gate_branching dg }) maxBranchingStride') + +fun greedyfinishqubit3 () = + case optDepGraph of + NONE => GFQ.scheduler + | SOME dg => + GateSchedulerOrder.mkScheduler (DepGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.scheduler3 { depGraph = dg, gateIsBranching = gate_branching dg }) maxBranchingStride') + +fun greedyfinishqubit4 () = + case optDepGraph of + NONE => GFQ.scheduler + | SOME dg => + GateSchedulerOrder.mkScheduler (DepGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.scheduler4 { depGraph = dg, gateIsBranching = gate_branching dg }) maxBranchingStride') + +fun greedyfinishqubit5 () = + case optDepGraph of + NONE => GFQ.scheduler + | SOME dg => + GateSchedulerOrder.mkScheduler (DepGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.scheduler5 { depGraph = dg, gateIsBranching = gate_branching dg }) maxBranchingStride') + +fun randomsched (samples: int) = + case optDepGraph of + NONE => raise Fail "Need dep graph for random scheduler" + | SOME dg => + let val ags = { depGraph = dg, gateIsBranching = gate_branching dg } + (*val scheds = Seq.fromList (List.tabulate (samples, fn i => DepGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.schedulerRandom i ags) maxBranchingStride')) *) + val scheds = Seq.tabulate (fn i => DepGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.schedulerRandom i ags) maxBranchingStride') samples + val chosen = DepGraphUtil.chooseSchedule scheds (gate_branching dg) + in + GateSchedulerOrder.mkScheduler chosen + end val gateScheduler = case schedulerName of @@ -215,6 +252,20 @@ val gateScheduler = | "greedy-finish-qubit" => (print_sched_info (); greedyfinishqubit ()) | "gfq" => (print_sched_info (); greedyfinishqubit ()) + | "greedy-finish-qubit2" => (print_sched_info (); greedyfinishqubit2 ()) + | "gfq2" => (print_sched_info (); greedyfinishqubit2 ()) + + | "greedy-finish-qubit3" => (print_sched_info (); greedyfinishqubit3 ()) + | "gfq3" => (print_sched_info (); greedyfinishqubit3 ()) + + | "greedy-finish-qubit4" => (print_sched_info (); greedyfinishqubit4 ()) + | "gfq4" => (print_sched_info (); greedyfinishqubit4 ()) + + | "greedy-finish-qubit5" => (print_sched_info (); greedyfinishqubit5 ()) + | "gfq5" => (print_sched_info (); greedyfinishqubit5 ()) + + | "random" => randomsched 50 + | _ => Util.die ("unknown scheduler: " ^ schedulerName From f56db2200d8987f6e0f32a7c38e1058de453e5d9 Mon Sep 17 00:00:00 2001 From: Colin McDonald Date: Tue, 21 Nov 2023 12:17:02 -0500 Subject: [PATCH 05/15] More work on dep graph schedulers --- feynsum-sml/run.sh | 2 +- feynsum-sml/src/FullSimBFS.sml | 20 +++++++++++----- .../DepGraphSchedulerGreedyFinishQubit.sml | 2 +- feynsum-sml/src/common/DepGraphUtil.sml | 24 +++++++++++++++---- feynsum-sml/src/main.sml | 16 ++++++------- 5 files changed, 43 insertions(+), 21 deletions(-) diff --git a/feynsum-sml/run.sh b/feynsum-sml/run.sh index 6c19aa3..4971d0a 100755 --- a/feynsum-sml/run.sh +++ b/feynsum-sml/run.sh @@ -1,4 +1,4 @@ -./main.mpl @mpl procs 20 set-affinity megablock-threshold 14 cc-threshold-ratio 1.1 collection-threshold-ratio 2.0 max-cc-depth 1 -- -scheduler $1 -input $2 +./main.mpl @mpl procs 20 set-affinity megablock-threshold 14 cc-threshold-ratio 1.1 collection-threshold-ratio 2.0 max-cc-depth 1 -- -dense-thresh 1.1 --scheduler-disable-fusion -scheduler $1 -input $2 # -scheduler-max-branching-stride 1 # --scheduler-disable-fusion # -dense-thresh 0.75 diff --git a/feynsum-sml/src/FullSimBFS.sml b/feynsum-sml/src/FullSimBFS.sml index 393a507..38ea7fd 100644 --- a/feynsum-sml/src/FullSimBFS.sml +++ b/feynsum-sml/src/FullSimBFS.sml @@ -134,8 +134,15 @@ struct density end - - fun loop numGateApps gatesVisitedSoFar counts prevNonZeroSize state = + fun getNumZeros state = + case state of + Expander.Sparse sst => SST.zeroSize sst + | Expander.Dense ds => raise Fail "Can't do dense stuff!" + (*DS.unsafeViewContents ds, DS.nonZeroSize ds, TODO exception*) + | Expander.DenseKnownNonZeroSize (ds, nz) => raise Fail "Can't do dense stuff!" + (*DS.unsafeViewContents ds, nz, TODO exception*) + + fun loop numGateApps gatesVisitedSoFar counts prevNonZeroSize state totalDensity = if gatesVisitedSoFar >= depth then let val (nonZeros, numNonZeros) = @@ -150,9 +157,10 @@ struct (* val _ = dumpState numQubits state *) val density = - dumpDensity (gatesVisitedSoFar, numNonZeros, NONE, NONE) + dumpDensity (gatesVisitedSoFar, numNonZeros, SOME (getNumZeros state), NONE) in print "\n"; + print ("avg-density " ^ Real.toString (totalDensity / Real.fromInt gatesVisitedSoFar) ^ "\n"); (numGateApps, nonZeros, Seq.fromRevList (numNonZeros :: counts)) end @@ -187,7 +195,7 @@ struct val throughput = millions / seconds val throughputStr = Real.fmt (StringCvt.FIX (SOME 2)) throughput val density = - dumpDensity (gatesVisitedSoFar, numNonZeros, NONE, NONE) + dumpDensity (gatesVisitedSoFar, numNonZeros, SOME (getNumZeros state), NONE) val _ = print (" hop " ^ leftPad 3 (Int.toString numGatesVisitedHere) ^ " " ^ rightPad 11 method ^ " " @@ -195,14 +203,14 @@ struct ^ throughputStr ^ "\n") in loop (numGateApps + apps) (gatesVisitedSoFar + numGatesVisitedHere) - (numNonZeros :: counts) numNonZeros result + (numNonZeros :: counts) numNonZeros result (totalDensity + (Rat.approx density) * Real.fromInt numGatesVisitedHere) end val initialState = Expander.Sparse (SST.singleton {numQubits = numQubits} (B.zeros, C.defaultReal 1.0)) - val (numGateApps, finalState, counts) = loop 0 0 [] 1 initialState + val (numGateApps, finalState, counts) = loop 0 0 [] 1 initialState 0.0 val _ = print ("gate app count " ^ Int.toString numGateApps ^ "\n") in {result = finalState, counts = counts} diff --git a/feynsum-sml/src/common/DepGraphSchedulerGreedyFinishQubit.sml b/feynsum-sml/src/common/DepGraphSchedulerGreedyFinishQubit.sml index 0aaba0a..fe87cc9 100644 --- a/feynsum-sml/src/common/DepGraphSchedulerGreedyFinishQubit.sml +++ b/feynsum-sml/src/common/DepGraphSchedulerGreedyFinishQubit.sml @@ -223,7 +223,7 @@ struct fun schedulerRandom seedNum ({depGraph = dg, gateIsBranching = gib} : args) = (*let val seed = Random.rand (seedNum, seedNum * seedNum) in*) fn gates => let val r = Random.randRange (0, Seq.length gates - 1) seed in - (print ("Randomly chose " ^ Int.toString r ^ " from range [0, " ^ Int.toString (Seq.length gates) ^ ")\n"); + ((*print ("Randomly chose " ^ Int.toString r ^ " from range [0, " ^ Int.toString (Seq.length gates) ^ ")\n");*) Seq.nth gates r) end (* end *) diff --git a/feynsum-sml/src/common/DepGraphUtil.sml b/feynsum-sml/src/common/DepGraphUtil.sml index 26a21b0..d0af24d 100644 --- a/feynsum-sml/src/common/DepGraphUtil.sml +++ b/feynsum-sml/src/common/DepGraphUtil.sml @@ -15,7 +15,7 @@ sig (* Switches edge directions *) val transpose: dep_graph -> dep_graph - val scheduleWithOracle: dep_graph -> (gate_idx -> bool) -> (gate_idx Seq.t -> gate_idx) -> int -> gate_idx Seq.t Seq.t + val scheduleWithOracle: dep_graph -> (gate_idx -> bool) -> (gate_idx Seq.t -> gate_idx) -> bool -> int -> gate_idx Seq.t Seq.t val scheduleCost: gate_idx Seq.t Seq.t -> (gate_idx -> bool) -> real val chooseSchedule: gate_idx Seq.t Seq.t Seq.t -> (gate_idx -> bool) -> gate_idx Seq.t Seq.t @@ -72,7 +72,7 @@ struct { visited = vis, indegree = deg } end - fun scheduleWithOracle (graph: dep_graph) (branching: gate_idx -> bool) (choose: gate_idx Seq.t -> gate_idx) (maxBranchingStride: int) = + fun scheduleWithOracle (graph: dep_graph) (branching: gate_idx -> bool) (choose: gate_idx Seq.t -> gate_idx) (disableFusion: bool) (maxBranchingStride: int) = let val st = initState graph fun findNonBranching (i: int) (xs: gate_idx Seq.t) = if i = Seq.length xs then @@ -81,6 +81,7 @@ struct findNonBranching (i + 1) xs else SOME (Seq.nth xs i) + fun returnSeq thisStep acc = Seq.map Seq.fromList (Seq.rev (Seq.fromList (if List.null thisStep then acc else List.rev thisStep :: acc))) fun loadNonBranching (acc: gate_idx list) = case findNonBranching 0 (frontier st) of NONE => acc @@ -88,7 +89,7 @@ struct fun loadNext numBranchingSoFar thisStep (acc: gate_idx list list) = let val ftr = frontier st in if Seq.length ftr = 0 then - Seq.map Seq.fromList (Seq.rev (Seq.fromList (if List.null thisStep then acc else List.rev thisStep :: acc))) + returnSeq thisStep acc else (let val next = choose ftr in visit graph next st; @@ -98,8 +99,21 @@ struct loadNext (numBranchingSoFar + 1) (loadNonBranching (next :: thisStep)) acc end) end + fun loadNextNoFusion (acc: gate_idx list list) = + let val ftr = frontier st in + if Seq.length ftr = 0 then + returnSeq nil acc + else + (let val next = choose ftr in + visit graph next st; + loadNextNoFusion ((next :: nil) :: acc) + end) + end in - loadNext 0 (loadNonBranching nil) nil + if disableFusion then + loadNextNoFusion nil + else + loadNext 0 (loadNonBranching nil) nil end (*fun scheduleCost2 (order: gate_idx Seq.t Seq.t) (branching: gate_idx -> bool) = @@ -122,7 +136,7 @@ struct if i = N then cost else - iter (i + 1) (1.0 + (if branching (Seq.nth gates i) then cost * 2.0 else cost)) + iter (i + 1) (1.0 + (if branching (Seq.nth gates i) then cost * 1.67 else cost)) in iter 0 0.0 end diff --git a/feynsum-sml/src/main.sml b/feynsum-sml/src/main.sml index 2325009..8fc1f82 100644 --- a/feynsum-sml/src/main.sml +++ b/feynsum-sml/src/main.sml @@ -188,18 +188,18 @@ val maxBranchingStride' = if disableFusion then 1 else maxBranchingStride fun greedybranching () = case optDepGraph of NONE => GateSchedulerGreedyBranching.scheduler - | SOME dg => GateSchedulerOrder.mkScheduler (DepGraphUtil.scheduleWithOracle dg (gate_branching dg) (DepGraphSchedulerGreedyBranching.scheduler { depGraph = dg, gateIsBranching = gate_branching dg }) maxBranchingStride') + | SOME dg => GateSchedulerOrder.mkScheduler (DepGraphUtil.scheduleWithOracle dg (gate_branching dg) (DepGraphSchedulerGreedyBranching.scheduler { depGraph = dg, gateIsBranching = gate_branching dg }) disableFusion maxBranchingStride') fun greedynonbranching () = case optDepGraph of NONE => GNB.scheduler - | SOME dg => GateSchedulerOrder.mkScheduler (DepGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGNB.scheduler { depGraph = dg, gateIsBranching = gate_branching dg }) maxBranchingStride') + | SOME dg => GateSchedulerOrder.mkScheduler (DepGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGNB.scheduler { depGraph = dg, gateIsBranching = gate_branching dg }) disableFusion maxBranchingStride') fun greedyfinishqubit () = case optDepGraph of NONE => GFQ.scheduler | SOME dg => - GateSchedulerOrder.mkScheduler (DepGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.scheduler { depGraph = dg, gateIsBranching = gate_branching dg }) maxBranchingStride') + GateSchedulerOrder.mkScheduler (DepGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.scheduler { depGraph = dg, gateIsBranching = gate_branching dg }) disableFusion maxBranchingStride') fun greedyfinishqubit2 () = case optDepGraph of @@ -207,25 +207,25 @@ fun greedyfinishqubit2 () = | SOME dg => (*GateSchedulerOrder.mkScheduler (Seq.fromList (DGFQ.ordered { depGraph = dg, gateIsBranching = gate_branching dg }))*) (* dep_graph_to_schedule { depGraph = dg, gateIsBranching = gate_branching dg } DGFQ.scheduler2 *) - GateSchedulerOrder.mkScheduler (DepGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.scheduler2 { depGraph = dg, gateIsBranching = gate_branching dg }) maxBranchingStride') + GateSchedulerOrder.mkScheduler (DepGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.scheduler2 { depGraph = dg, gateIsBranching = gate_branching dg }) disableFusion maxBranchingStride') fun greedyfinishqubit3 () = case optDepGraph of NONE => GFQ.scheduler | SOME dg => - GateSchedulerOrder.mkScheduler (DepGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.scheduler3 { depGraph = dg, gateIsBranching = gate_branching dg }) maxBranchingStride') + GateSchedulerOrder.mkScheduler (DepGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.scheduler3 { depGraph = dg, gateIsBranching = gate_branching dg }) disableFusion maxBranchingStride') fun greedyfinishqubit4 () = case optDepGraph of NONE => GFQ.scheduler | SOME dg => - GateSchedulerOrder.mkScheduler (DepGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.scheduler4 { depGraph = dg, gateIsBranching = gate_branching dg }) maxBranchingStride') + GateSchedulerOrder.mkScheduler (DepGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.scheduler4 { depGraph = dg, gateIsBranching = gate_branching dg }) disableFusion maxBranchingStride') fun greedyfinishqubit5 () = case optDepGraph of NONE => GFQ.scheduler | SOME dg => - GateSchedulerOrder.mkScheduler (DepGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.scheduler5 { depGraph = dg, gateIsBranching = gate_branching dg }) maxBranchingStride') + GateSchedulerOrder.mkScheduler (DepGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.scheduler5 { depGraph = dg, gateIsBranching = gate_branching dg }) disableFusion maxBranchingStride') fun randomsched (samples: int) = case optDepGraph of @@ -233,7 +233,7 @@ fun randomsched (samples: int) = | SOME dg => let val ags = { depGraph = dg, gateIsBranching = gate_branching dg } (*val scheds = Seq.fromList (List.tabulate (samples, fn i => DepGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.schedulerRandom i ags) maxBranchingStride')) *) - val scheds = Seq.tabulate (fn i => DepGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.schedulerRandom i ags) maxBranchingStride') samples + val scheds = Seq.tabulate (fn i => DepGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.schedulerRandom i ags) disableFusion maxBranchingStride') samples val chosen = DepGraphUtil.chooseSchedule scheds (gate_branching dg) in GateSchedulerOrder.mkScheduler chosen From e4829dfaee8590c90d07e906031e1982d929f8b0 Mon Sep 17 00:00:00 2001 From: Colin McDonald Date: Tue, 21 Nov 2023 12:38:52 -0500 Subject: [PATCH 06/15] Minor bug fix --- feynsum-sml/src/common/ExpandState.sml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/feynsum-sml/src/common/ExpandState.sml b/feynsum-sml/src/common/ExpandState.sml index 93788ab..422df0c 100644 --- a/feynsum-sml/src/common/ExpandState.sml +++ b/feynsum-sml/src/common/ExpandState.sml @@ -205,7 +205,7 @@ struct case apply widx of G.OutputOne widx' => doGates (apps + 1) (widx', gatenum + 1) | G.OutputTwo (widx1, widx) => - doTwo (apps + 1) ((widx1, widx), gatenum + 1) + doTwo (apps + 2) ((widx1, widx), gatenum + 1) and doTwo apps ((widx1, widx2), gatenum) = case doGates apps (widx1, gatenum) of From 6ad360c4f21375278e14772a17f343eb7a383075 Mon Sep 17 00:00:00 2001 From: Colin McDonald Date: Mon, 27 Nov 2023 13:40:40 -0500 Subject: [PATCH 07/15] Dynamic scheduler: MVP --- feynsum-sml/src/FullSimBFS.sml | 183 +++++++++------- feynsum-sml/src/MkMain.sml | 22 +- .../src/common/DepGraphDynScheduler.sml | 205 ++++++++++++++++++ feynsum-sml/src/common/DepGraphUtil.sml | 42 ++-- feynsum-sml/src/common/ExpandState.sml | 4 +- feynsum-sml/src/common/sources.mlb | 5 + feynsum-sml/src/main.sml | 33 ++- 7 files changed, 371 insertions(+), 123 deletions(-) create mode 100644 feynsum-sml/src/common/DepGraphDynScheduler.sml diff --git a/feynsum-sml/src/FullSimBFS.sml b/feynsum-sml/src/FullSimBFS.sml index 38ea7fd..9b169f0 100644 --- a/feynsum-sml/src/FullSimBFS.sml +++ b/feynsum-sml/src/FullSimBFS.sml @@ -6,14 +6,16 @@ functor FullSimBFS sharing B = SST.B = G.B sharing C = SST.C = G.C + val disableFusion: bool + val maxBranchingStride: int + val gateScheduler: string val blockSize: int val maxload: real - val gateScheduler: GateScheduler.t val doMeasureZeros: bool val denseThreshold: real val pullThreshold: real): sig - val run: Circuit.t + val run: DepGraph.t -> {result: (B.t * C.t) option DelayedSeq.t, counts: int Seq.t} end = struct @@ -32,7 +34,6 @@ struct val maxload = maxload val pullThreshold = pullThreshold) - val bits = Seq.fromList [ (*"▏",*)"▎", "▍", "▌", "▊"] fun fillBar width x = @@ -76,22 +77,48 @@ struct end) end - - fun run {numQubits, gates} = + structure DGFQ = DynSchedFinishQubitWrapper + (structure B = B + structure C = C + structure SST = SST + structure DS = DS + val maxBranchingStride = maxBranchingStride + val disableFusion = disableFusion) + + structure DGI = DynSchedInterference + (structure B = B + structure C = C + structure SST = SST + structure DS = DS + val maxBranchingStride = maxBranchingStride + val disableFusion = disableFusion) + structure DGN = DynSchedNaive + (structure B = B + structure C = C + structure SST = SST + structure DS = DS + val maxBranchingStride = maxBranchingStride + val disableFusion = disableFusion) + + val gateSched = + case gateScheduler of + "naive" => DGN.choose + | "gfq" => DGFQ.choose + | "interference" => DGI.choose + | _ => raise Fail ("Unknown scheduler '" ^ gateScheduler ^ "'\n") + + fun run depgraph (*{numQubits, gates}*) = let - val gates = Seq.map G.fromGateDefn gates + val gates = Seq.map G.fromGateDefn (#gates depgraph) + val numQubits = #numQubits depgraph fun gate i = Seq.nth gates i val depth = Seq.length gates + val dgstate = DepGraphUtil.initState depgraph - val gateSchedulerPickNextGates = gateScheduler - { numQubits = numQubits - , numGates = depth - , gateTouches = #touches o gate - , gateIsBranching = (fn i => - case #action (gate i) of - G.NonBranching _ => false - | _ => true) - } + val pickNextGate = + let val f = gateSched depgraph in + fn (s, g) => f (s, g) + end (* val _ = if numQubits > 63 then raise Fail "whoops, too many qubits" else () *) @@ -142,77 +169,69 @@ struct | Expander.DenseKnownNonZeroSize (ds, nz) => raise Fail "Can't do dense stuff!" (*DS.unsafeViewContents ds, nz, TODO exception*) - fun loop numGateApps gatesVisitedSoFar counts prevNonZeroSize state totalDensity = - if gatesVisitedSoFar >= depth then - let - val (nonZeros, numNonZeros) = - case state of - Expander.Sparse sst => - (SST.unsafeViewContents sst, SST.nonZeroSize sst) - | Expander.Dense ds => - (DS.unsafeViewContents ds, DS.nonZeroSize ds) - | Expander.DenseKnownNonZeroSize (ds, nz) => - (DS.unsafeViewContents ds, nz) - - (* val _ = dumpState numQubits state *) - - val density = - dumpDensity (gatesVisitedSoFar, numNonZeros, SOME (getNumZeros state), NONE) - in - print "\n"; - print ("avg-density " ^ Real.toString (totalDensity / Real.fromInt gatesVisitedSoFar) ^ "\n"); - (numGateApps, nonZeros, Seq.fromRevList (numNonZeros :: counts)) - end - - else - let - (* val _ = dumpState numQubits state *) - - val theseGates = gateSchedulerPickNextGates () - val _ = - if Seq.length theseGates > 0 then - () - else - raise Fail "FullSimBFS: gate scheduler returned empty sequence" - - (* val _ = print - ("visiting: " ^ Seq.toString Int.toString theseGates ^ "\n") *) - - val theseGates = Seq.map (Seq.nth gates) theseGates - val numGatesVisitedHere = Seq.length theseGates - val ({result, method, numNonZeros, numGateApps = apps}, tm) = - Util.getTime (fn () => - Expander.expand - { gates = theseGates - , numQubits = numQubits - , maxNumStates = maxNumStates - , state = state - , prevNonZeroSize = prevNonZeroSize - }) - - val seconds = Time.toReal tm - val millions = Real.fromInt apps / 1e6 - val throughput = millions / seconds - val throughputStr = Real.fmt (StringCvt.FIX (SOME 2)) throughput - val density = - dumpDensity (gatesVisitedSoFar, numNonZeros, SOME (getNumZeros state), NONE) - val _ = print - (" hop " ^ leftPad 3 (Int.toString numGatesVisitedHere) ^ " " - ^ rightPad 11 method ^ " " - ^ Real.fmt (StringCvt.FIX (SOME 4)) seconds ^ "s throughput " - ^ throughputStr ^ "\n") - in - loop (numGateApps + apps) (gatesVisitedSoFar + numGatesVisitedHere) - (numNonZeros :: counts) numNonZeros result (totalDensity + (Rat.approx density) * Real.fromInt numGatesVisitedHere) - end - - val initialState = Expander.Sparse (SST.singleton {numQubits = numQubits} (B.zeros, C.defaultReal 1.0)) - val (numGateApps, finalState, counts) = loop 0 0 [] 1 initialState 0.0 + fun runloop () = + DepGraphUtil.scheduleWithOracle' + + (* dependency graph *) + depgraph + + (* gate is branching *) + (fn i => G.expectBranching (Seq.nth gates i)) + + (* select gate *) + (fn ((state, numGateApps, counts, gatesVisitedSoFar), gates) => + case state of + Expander.Sparse sst => pickNextGate (sst, gates) + | _ => Seq.nth gates 0) + + (* disable fusion? *) + disableFusion + + (* if fusion enabled, what's the max # of branching gates to fuse? *) + maxBranchingStride + + (* apply gate fusion seq, updating state *) + (fn ((state, numGateApps, counts, gatesVisitedSoFar), theseGates) => + let val numGatesVisitedHere = Seq.length theseGates + val ({result, method, numNonZeros, numGateApps = apps}, tm) = + Util.getTime (fn () => + Expander.expand + { gates = Seq.map (Seq.nth gates) theseGates + , numQubits = numQubits + , maxNumStates = maxNumStates + , state = state + , prevNonZeroSize = (case counts of h :: t => h | nil => 1) + }) + + val seconds = Time.toReal tm + val millions = Real.fromInt apps / 1e6 + val throughput = millions / seconds + val throughputStr = Real.fmt (StringCvt.FIX (SOME 2)) throughput + val density = + dumpDensity (gatesVisitedSoFar, numNonZeros, SOME (getNumZeros state), NONE) + val _ = print + (" hop " ^ leftPad 3 (Int.toString numGatesVisitedHere) ^ " " + ^ rightPad 11 method ^ " " + ^ Real.fmt (StringCvt.FIX (SOME 4)) seconds ^ "s throughput " + ^ throughputStr ^ "\n") + in + (result, numGateApps + apps, numNonZeros :: counts, gatesVisitedSoFar + numGatesVisitedHere) + end + ) + + (* initial state *) + (initialState, 0, [], 0) + + val (finalState, numGateApps, counts, gatesVisited) = runloop () + val nonZeros = case finalState of + Expander.Sparse sst => SST.unsafeViewContents sst + | Expander.Dense ds => DS.unsafeViewContents ds + | Expander.DenseKnownNonZeroSize (ds, nz) => DS.unsafeViewContents ds val _ = print ("gate app count " ^ Int.toString numGateApps ^ "\n") in - {result = finalState, counts = counts} + {result = nonZeros, counts = Seq.fromList counts} end end diff --git a/feynsum-sml/src/MkMain.sml b/feynsum-sml/src/MkMain.sml index 198fa13..2b2e68f 100644 --- a/feynsum-sml/src/MkMain.sml +++ b/feynsum-sml/src/MkMain.sml @@ -1,9 +1,11 @@ functor MkMain (structure C: COMPLEX structure B: BASIS_IDX + val disableFusion: bool + val maxBranchingStride: int val blockSize: int val maxload: real - val gateScheduler: GateScheduler.t + val gateScheduler: string val doMeasureZeros: bool val denseThreshold: real val pullThreshold: real) = @@ -20,6 +22,8 @@ struct structure SST = SparseStateTableLockedSlots (structure B = B structure C = C) structure G = G + val disableFusion = disableFusion + val maxBranchingStride = maxBranchingStride val blockSize = blockSize val maxload = maxload val gateScheduler = gateScheduler @@ -33,6 +37,8 @@ struct structure C = C structure SST = SparseStateTable (structure B = B structure C = C) structure G = G + val disableFusion = disableFusion + val maxBranchingStride = maxBranchingStride val blockSize = blockSize val maxload = maxload val gateScheduler = gateScheduler @@ -50,23 +56,23 @@ struct fun main (inputName, circuit) = let - val numQubits = Circuit.numQubits circuit + val numQubits = #numQubits circuit val impl = CLA.parseString "impl" "lockfree" val output = CLA.parseString "output" "" val outputDensities = CLA.parseString "output-densities" "" val _ = print ("impl " ^ impl ^ "\n") - val sim = + fun sim () = case impl of - "lockfree" => BFSLockfree.run - | "locked" => BFSLocked.run + "lockfree" => BFSLockfree.run circuit + | "locked" => BFSLocked.run circuit | _ => Util.die ("unknown impl " ^ impl ^ "; valid options are: locked, lockfree\n") - val {result, counts} = Benchmark.run "full-sim-bfs" (fn _ => sim circuit) + val {result, counts} = Benchmark.run "full-sim-bfs" (fn _ => sim ()) val counts = Seq.map IntInf.fromInt counts val maxNumStates = IntInf.pow (2, numQubits) @@ -146,8 +152,8 @@ struct print (String.concatWith "," [ name - , Int.toString (Circuit.numQubits circuit) - , Int.toString (Circuit.numGates circuit) + , Int.toString (#numQubits circuit) + , Int.toString (Seq.length (#gates circuit)) , Real.fmt (StringCvt.FIX (SOME 12)) (Rat.approx maxDensity) , Real.fmt (StringCvt.FIX (SOME 12)) (Rat.approx avgDensity) ] ^ "\n") diff --git a/feynsum-sml/src/common/DepGraphDynScheduler.sml b/feynsum-sml/src/common/DepGraphDynScheduler.sml new file mode 100644 index 0000000..9f7d115 --- /dev/null +++ b/feynsum-sml/src/common/DepGraphDynScheduler.sml @@ -0,0 +1,205 @@ +signature DEP_GRAPH_DYN_SCHEDULER = +sig + structure B: BASIS_IDX + structure C: COMPLEX + structure SST: SPARSE_STATE_TABLE + structure DS: DENSE_STATE + sharing B = SST.B = DS.B + sharing C = SST.C = DS.C + + type gate_idx = int + + type t = DepGraph.t -> (SST.t * gate_idx Seq.t -> gate_idx) + + val choose: t +end + +functor DynSchedFinishQubitWrapper + (structure B: BASIS_IDX + structure C: COMPLEX + structure SST: SPARSE_STATE_TABLE + structure DS: DENSE_STATE + sharing B = SST.B = DS.B + sharing C = SST.C = DS.C + val maxBranchingStride: int + val disableFusion: bool + ): DEP_GRAPH_DYN_SCHEDULER = +struct + structure B = B + structure C = C + structure SST = SST + structure DS = DS + + type gate_idx = int + + type t = DepGraph.t -> (SST.t * gate_idx Seq.t -> gate_idx) + + structure DGFQ = DepGraphSchedulerGreedyFinishQubit + (val maxBranchingStride = maxBranchingStride + val disableFusion = disableFusion) + + structure G = Gate + (structure B = B + structure C = C) + + fun choose (depgraph: DepGraph.t) = + let val branchSeq = Seq.map (fn g => G.expectBranching (G.fromGateDefn g)) (#gates depgraph) + fun branching i = Seq.nth branchSeq i + val f = DGFQ.scheduler5 {depGraph = depgraph, gateIsBranching = branching} in + fn (_, gates) => f gates + end + +end + +functor DynSchedNaive + (structure B: BASIS_IDX + structure C: COMPLEX + structure SST: SPARSE_STATE_TABLE + structure DS: DENSE_STATE + sharing B = SST.B = DS.B + sharing C = SST.C = DS.C + val maxBranchingStride: int + val disableFusion: bool + ): DEP_GRAPH_DYN_SCHEDULER = +struct + structure B = B + structure C = C + structure SST = SST + structure DS = DS + + type gate_idx = int + + type t = DepGraph.t -> (SST.t * gate_idx Seq.t -> gate_idx) + + fun choose (depgraph: DepGraph.t) = fn (_, gates) => Seq.nth gates 0 + +end + + +functor DynSchedInterference + (structure B: BASIS_IDX + structure C: COMPLEX + structure SST: SPARSE_STATE_TABLE + structure DS: DENSE_STATE + sharing B = SST.B = DS.B + sharing C = SST.C = DS.C + val maxBranchingStride: int + val disableFusion: bool + ): DEP_GRAPH_DYN_SCHEDULER = +struct + structure B = B + structure C = C + structure SST = SST + structure DS = DS + + type gate_idx = int + + type t = DepGraph.t -> (SST.t * gate_idx Seq.t -> gate_idx) + + structure DGFQ = DepGraphSchedulerGreedyFinishQubit + (val maxBranchingStride = maxBranchingStride + val disableFusion = disableFusion) + + structure G = Gate + (structure B = B + structure C = C) + + datatype Branched = Uninitialized | Zero | One | Superposition + + fun joinBranches' (br, b01) = + case (br, b01) of + (Uninitialized, false) => Zero + | (Uninitialized, true) => One + | (Zero, false) => Zero + | (Zero, true) => Superposition + | (One, false) => Superposition + | (One, true) => One + | (Superposition, _) => Superposition + + fun joinBranches (br, br') = + case (br, br') of + (Uninitialized, b) => b + | (a, Uninitialized) => a + | (Zero, Zero) => Zero + | (One, One) => One + | (_, _) => Superposition + + fun calculateBranchedQubits (numQubits, sst) = + let val nonZeros = DelayedSeq.mapOption (fn x => x) (SST.unsafeViewContents sst) + val branchedQubits = Seq.tabulate (fn qi => DelayedSeq.reduce joinBranches Uninitialized (DelayedSeq.map (fn (b, c) => if B.get b qi then One else Zero) nonZeros)) numQubits + fun isbranched b = + case b of + Uninitialized => raise Fail "Uninitialized in isbranched! This shouldn't happen" + | Zero => false + | One => false + | Superposition => true + in + Seq.map isbranched branchedQubits + end + + fun choose (depgraph: DepGraph.t) = + let val gates = Seq.map G.fromGateDefn (#gates depgraph) + val branchSeq = Seq.map G.expectBranching gates + fun branching i = Seq.nth branchSeq i + val numQubits = #numQubits depgraph + in + fn (sst, gidxs) => + (* if dense, doesn't really matter what we pick? *) + let val branchedQubits = calculateBranchedQubits (numQubits, sst) + fun getBranching i = + Seq.reduce (fn ((br, nbr), (br', nbr')) => (br + br', nbr + nbr')) (0, 0) + (Seq.map (fn qi => if Seq.nth branchedQubits qi then + (1, 0) else (0, 1)) + (#touches (Seq.nth gates i))) + fun pick i (best_g, best_diff) = + if i >= Seq.length gidxs then + best_g + else + let val g = Seq.nth gidxs i + val (br, nbr) = getBranching g + val diff = if branching g then br - nbr else nbr - br + in + pick (i + 1) (if diff > best_diff then + (g, diff) else (best_g, best_diff)) + end + val (br0, nbr0) = getBranching (Seq.nth gidxs 0) + in + pick 1 (Seq.nth gidxs 0, if branching 0 then br0 - nbr0 else nbr0 - br0) + end + end + +end + + +(*functor DepGraphDynScheduler + (structure B: BASIS_IDX + structure C: COMPLEX + structure SST: SPARSE_STATE_TABLE + structure DS: DENSE_STATE + structure G: GATE + sharing B = SST.B = DS.B = G.B + sharing C = SST.C = DS.C = G.C + val blockSize: int + val maxload: real + val denseThreshold: real + val pullThreshold: real): DEP_GRAPH_DYN_SCHEDULER = +struct + type gate_idx = int + structure Expander = + ExpandState + (structure B = B + structure C = C + structure SST = SST + structure DS = DS + structure G = G + val denseThreshold = denseThreshold + val blockSize = blockSize + val maxload = maxload + val pullThreshold = pullThreshold) + + ( * From a frontier, select which gate to apply next * ) + ( * args visit gates, update frontier break fusion initial frontier gate batches * ) + ( * type t = args -> (gate_idx -> gate_idx Seq.t) -> (unit -> unit) -> gate_idx Seq.t -> gate_idx Seq.t Seq.t* ) + type t = DepGraph.t -> (Expander.state * gate_idx Seq.t -> gate_idx) +end +*) diff --git a/feynsum-sml/src/common/DepGraphUtil.sml b/feynsum-sml/src/common/DepGraphUtil.sml index d0af24d..1fe759e 100644 --- a/feynsum-sml/src/common/DepGraphUtil.sml +++ b/feynsum-sml/src/common/DepGraphUtil.sml @@ -17,6 +17,8 @@ sig val scheduleWithOracle: dep_graph -> (gate_idx -> bool) -> (gate_idx Seq.t -> gate_idx) -> bool -> int -> gate_idx Seq.t Seq.t + val scheduleWithOracle': dep_graph -> (gate_idx -> bool) -> ('state * gate_idx Seq.t -> gate_idx) -> bool -> int -> ('state * gate_idx Seq.t -> 'state) -> 'state -> 'state + val scheduleCost: gate_idx Seq.t Seq.t -> (gate_idx -> bool) -> real val chooseSchedule: gate_idx Seq.t Seq.t Seq.t -> (gate_idx -> bool) -> gate_idx Seq.t Seq.t @@ -72,8 +74,8 @@ struct { visited = vis, indegree = deg } end - fun scheduleWithOracle (graph: dep_graph) (branching: gate_idx -> bool) (choose: gate_idx Seq.t -> gate_idx) (disableFusion: bool) (maxBranchingStride: int) = - let val st = initState graph + fun scheduleWithOracle' (graph: dep_graph) (branching: gate_idx -> bool) (choose: 'state * gate_idx Seq.t -> gate_idx) (disableFusion: bool) (maxBranchingStride: int) (apply: 'state * gate_idx Seq.t -> 'state) state = + let val dgst = initState graph fun findNonBranching (i: int) (xs: gate_idx Seq.t) = if i = Seq.length xs then NONE @@ -83,39 +85,41 @@ struct SOME (Seq.nth xs i) fun returnSeq thisStep acc = Seq.map Seq.fromList (Seq.rev (Seq.fromList (if List.null thisStep then acc else List.rev thisStep :: acc))) fun loadNonBranching (acc: gate_idx list) = - case findNonBranching 0 (frontier st) of + case findNonBranching 0 (frontier dgst) of NONE => acc - | SOME i => (visit graph i st; loadNonBranching (i :: acc)) - fun loadNext numBranchingSoFar thisStep (acc: gate_idx list list) = - let val ftr = frontier st in + | SOME i => (visit graph i dgst; loadNonBranching (i :: acc)) + fun loadNext numBranchingSoFar thisStep state = + let val ftr = frontier dgst in if Seq.length ftr = 0 then - returnSeq thisStep acc + state else - (let val next = choose ftr in - visit graph next st; + (let val next = choose (state, ftr) in + visit graph next dgst; if numBranchingSoFar + 1 >= maxBranchingStride then - loadNext 0 nil (List.rev (loadNonBranching (next :: thisStep)) :: acc) + loadNext 0 nil (apply (state, Seq.rev (Seq.fromList (loadNonBranching (next :: thisStep))))) else - loadNext (numBranchingSoFar + 1) (loadNonBranching (next :: thisStep)) acc + loadNext (numBranchingSoFar + 1) (loadNonBranching (next :: thisStep)) state end) end - fun loadNextNoFusion (acc: gate_idx list list) = - let val ftr = frontier st in + fun loadNextNoFusion state = + let val ftr = frontier dgst in if Seq.length ftr = 0 then - returnSeq nil acc + state else - (let val next = choose ftr in - visit graph next st; - loadNextNoFusion ((next :: nil) :: acc) + (let val next = choose (state, ftr) in + visit graph next dgst; + loadNextNoFusion (apply (state, Seq.singleton next)) end) end in if disableFusion then - loadNextNoFusion nil + loadNextNoFusion state else - loadNext 0 (loadNonBranching nil) nil + loadNext 0 (loadNonBranching nil) state end + fun scheduleWithOracle (graph: dep_graph) (branching: gate_idx -> bool) (choose: gate_idx Seq.t -> gate_idx) (disableFusion: bool) (maxBranchingStride: int) = Seq.rev (Seq.fromList (scheduleWithOracle' graph branching (fn (_, x) => choose x) disableFusion maxBranchingStride (fn (gs, g) => g :: gs) nil)) + (*fun scheduleCost2 (order: gate_idx Seq.t Seq.t) (branching: gate_idx -> bool) = let val gates = Seq.flatten order val N = Seq.length gates diff --git a/feynsum-sml/src/common/ExpandState.sml b/feynsum-sml/src/common/ExpandState.sml index 422df0c..f5afac7 100644 --- a/feynsum-sml/src/common/ExpandState.sml +++ b/feynsum-sml/src/common/ExpandState.sml @@ -55,7 +55,7 @@ struct val d = Char.ord (String.sub (digits, depth)) - Char.ord #"0" val _ = if 0 <= d andalso d <= 9 then () - else raise Fail "riMult: bad digit" + else raise Fail ("riMult: bad digit " ^ digits ^ ", " ^ Real.toString r ^ ", " ^ IntInf.toString i) val acc = acc + (i * IntInf.fromInt d) div (IntInf.pow (10, depth + 1)) in @@ -407,7 +407,7 @@ struct val (method, {result, numGateApps}) = if - expectedCost < riMult denseThreshold maxNumStates + denseThreshold >= 1.0 orelse expectedCost < riMult denseThreshold maxNumStates then ("push sparse", expandSparse args) diff --git a/feynsum-sml/src/common/sources.mlb b/feynsum-sml/src/common/sources.mlb index 4f1a503..791e890 100644 --- a/feynsum-sml/src/common/sources.mlb +++ b/feynsum-sml/src/common/sources.mlb @@ -71,6 +71,7 @@ local DepGraphSchedulerGreedyBranching.sml DepGraphSchedulerGreedyNonBranching.sml DepGraphSchedulerGreedyFinishQubit.sml + DepGraphDynScheduler.sml GateSchedulerOrder.sml in structure HashTable @@ -135,4 +136,8 @@ in functor DepGraphSchedulerGreedyNonBranching functor DepGraphSchedulerGreedyFinishQubit structure GateSchedulerOrder + signature DEP_GRAPH_DYN_SCHEDULER + functor DynSchedFinishQubitWrapper + functor DynSchedInterference + functor DynSchedNaive end \ No newline at end of file diff --git a/feynsum-sml/src/main.sml b/feynsum-sml/src/main.sml index 8fc1f82..cf567a9 100644 --- a/feynsum-sml/src/main.sml +++ b/feynsum-sml/src/main.sml @@ -61,18 +61,19 @@ fun parseQasm () = Circuit.fromSMLQasmSimpleCircuit simpleCirc end -val (circuit, optDepGraph) = +val (circuit, depGraph) = (*(circuit, optDepGraph)*) case inputName of "" => Util.die ("missing: -input FILE.qasm") | _ => if String.isSuffix ".qasm" inputName then - (parseQasm (), NONE) + raise Fail ".qasm no longer supported, use .json dependency graph" + (*parseQasm (), NONE*) else let val dg = DepGraph.fromFile inputName val circuit = {numQubits = #numQubits dg, gates = #gates dg} in - (circuit, SOME dg) + (circuit, dg) end val _ = print ("-------------------------------\n") @@ -185,7 +186,7 @@ fun gate_branching ({ gates = gates, ...} : DepGraph.t) = val maxBranchingStride' = if disableFusion then 1 else maxBranchingStride -fun greedybranching () = +(*fun greedybranching () = case optDepGraph of NONE => GateSchedulerGreedyBranching.scheduler | SOME dg => GateSchedulerOrder.mkScheduler (DepGraphUtil.scheduleWithOracle dg (gate_branching dg) (DepGraphSchedulerGreedyBranching.scheduler { depGraph = dg, gateIsBranching = gate_branching dg }) disableFusion maxBranchingStride') @@ -237,9 +238,9 @@ fun randomsched (samples: int) = val chosen = DepGraphUtil.chooseSchedule scheds (gate_branching dg) in GateSchedulerOrder.mkScheduler chosen - end + end*) -val gateScheduler = +(*val gateScheduler = case schedulerName of "naive" => GateSchedulerNaive.scheduler @@ -270,7 +271,7 @@ val gateScheduler = Util.die ("unknown scheduler: " ^ schedulerName ^ - "; valid options are: naive, greedy-branching (gb), greedy-nonbranching (gnb), greedy-finish-qubit (gfq)") + "; valid options are: naive, greedy-branching (gb), greedy-nonbranching (gnb), greedy-finish-qubit (gfq)")*) (* ======================================================================== * mains: 32-bit and 64-bit @@ -280,9 +281,11 @@ structure M_64_32 = MkMain (structure B = BasisIdx64 structure C = Complex32 + val maxBranchingStride = maxBranchingStride + val disableFusion = disableFusion val blockSize = blockSize val maxload = maxload - val gateScheduler = gateScheduler + val gateScheduler = schedulerName val doMeasureZeros = doMeasureZeros val denseThreshold = denseThreshold val pullThreshold = pullThreshold) @@ -291,9 +294,11 @@ structure M_64_64 = MkMain (structure B = BasisIdx64 structure C = Complex64 + val maxBranchingStride = maxBranchingStride + val disableFusion = disableFusion val blockSize = blockSize val maxload = maxload - val gateScheduler = gateScheduler + val gateScheduler = schedulerName val doMeasureZeros = doMeasureZeros val denseThreshold = denseThreshold val pullThreshold = pullThreshold) @@ -302,9 +307,11 @@ structure M_U_32 = MkMain (structure B = BasisIdxUnlimited structure C = Complex32 + val maxBranchingStride = maxBranchingStride + val disableFusion = disableFusion val blockSize = blockSize val maxload = maxload - val gateScheduler = gateScheduler + val gateScheduler = schedulerName val doMeasureZeros = doMeasureZeros val denseThreshold = denseThreshold val pullThreshold = pullThreshold) @@ -313,9 +320,11 @@ structure M_U_64 = MkMain (structure B = BasisIdxUnlimited structure C = Complex64 + val maxBranchingStride = maxBranchingStride + val disableFusion = disableFusion val blockSize = blockSize val maxload = maxload - val gateScheduler = gateScheduler + val gateScheduler = schedulerName val doMeasureZeros = doMeasureZeros val denseThreshold = denseThreshold val pullThreshold = pullThreshold) @@ -333,4 +342,4 @@ val main = (* ======================================================================== *) -val _ = main (inputName, circuit) +val _ = main (inputName, depGraph) From ac728a8a8e1f2944c7883e920ec6afd8412df29c Mon Sep 17 00:00:00 2001 From: Colin McDonald Date: Mon, 4 Dec 2023 10:35:38 -0500 Subject: [PATCH 08/15] Remove barriers in dep graph script, for now (though we probably don't want to allow commutations across barriers, eventually) --- feynsum-sml/src/common/depgraph.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/feynsum-sml/src/common/depgraph.py b/feynsum-sml/src/common/depgraph.py index a550dbb..6b7a052 100644 --- a/feynsum-sml/src/common/depgraph.py +++ b/feynsum-sml/src/common/depgraph.py @@ -1,5 +1,6 @@ from numpy.typing import NDArray import qiskit as Q +from qiskit.transpiler.passes import RemoveBarriers import rustworkx as rx import numpy as np import sys @@ -278,6 +279,9 @@ def main(argv): print(usage(argv).strip(), file=sys.stderr) return 1 circuit = read_qasm(ifh) + circuit.remove_final_measurements(True) + rb = RemoveBarriers() + circuit = rb(circuit) dag = circuit_to_dag(circuit) find_qubit = circuit_find_qubit_dict(circuit) if just_dag: From 76952b0eff84cc0850f03ff5d3bf591c9a928713 Mon Sep 17 00:00:00 2001 From: Colin McDonald Date: Mon, 4 Dec 2023 10:36:11 -0500 Subject: [PATCH 09/15] Add todo comment --- feynsum-sml/src/common/depgraph.py | 1 + 1 file changed, 1 insertion(+) diff --git a/feynsum-sml/src/common/depgraph.py b/feynsum-sml/src/common/depgraph.py index 6b7a052..d06e669 100644 --- a/feynsum-sml/src/common/depgraph.py +++ b/feynsum-sml/src/common/depgraph.py @@ -281,6 +281,7 @@ def main(argv): circuit = read_qasm(ifh) circuit.remove_final_measurements(True) rb = RemoveBarriers() + # TODO: prevent commuting across barriers circuit = rb(circuit) dag = circuit_to_dag(circuit) find_qubit = circuit_find_qubit_dict(circuit) From cc7bc8b4508ded1625531dfd6e42d5acf82956ed Mon Sep 17 00:00:00 2001 From: Colin McDonald Date: Thu, 7 Dec 2023 11:39:44 -0500 Subject: [PATCH 10/15] Make scheduler work for either greedy or dense state --- feynsum-sml/src/FullSimBFS.sml | 41 +++++++------- .../src/common/DepGraphDynScheduler.sml | 56 +++++++++---------- feynsum-sml/src/common/DepGraphScheduler.sml | 7 +-- .../DepGraphSchedulerGreedyBranching.sml | 10 ++-- .../DepGraphSchedulerGreedyFinishQubit.sml | 18 +++--- .../DepGraphSchedulerGreedyNonBranching.sml | 23 ++------ feynsum-sml/src/common/DepGraphUtil.sml | 13 ++++- feynsum-sml/src/common/ExpandState.sml | 56 ++++++++----------- feynsum-sml/src/common/sources.mlb | 6 ++ feynsum-sml/src/main.sml | 32 +---------- 10 files changed, 103 insertions(+), 159 deletions(-) diff --git a/feynsum-sml/src/FullSimBFS.sml b/feynsum-sml/src/FullSimBFS.sml index 9b169f0..39a5357 100644 --- a/feynsum-sml/src/FullSimBFS.sml +++ b/feynsum-sml/src/FullSimBFS.sml @@ -21,13 +21,16 @@ end = struct structure DS = DenseState (structure C = C structure B = B) + structure HS = HybridState (structure C = C + structure B = B + structure DS = DS + structure SST = SST) structure Expander = ExpandState (structure B = B structure C = C - structure SST = SST - structure DS = DS + structure HS = HS structure G = G val denseThreshold = denseThreshold val blockSize = blockSize @@ -62,9 +65,9 @@ struct let val ss = case s of - Expander.Sparse sst => SST.unsafeViewContents sst - | Expander.Dense ds => DS.unsafeViewContents ds - | Expander.DenseKnownNonZeroSize (ds, _) => DS.unsafeViewContents ds + HS.Sparse sst => SST.unsafeViewContents sst + | HS.Dense ds => DS.unsafeViewContents ds + | HS.DenseKnownNonZeroSize (ds, _) => DS.unsafeViewContents ds in Util.for (0, DelayedSeq.length ss) (fn i => case DelayedSeq.nth ss i of @@ -80,23 +83,20 @@ struct structure DGFQ = DynSchedFinishQubitWrapper (structure B = B structure C = C - structure SST = SST - structure DS = DS + structure HS = HS val maxBranchingStride = maxBranchingStride val disableFusion = disableFusion) structure DGI = DynSchedInterference (structure B = B structure C = C - structure SST = SST - structure DS = DS + structure HS = HS val maxBranchingStride = maxBranchingStride val disableFusion = disableFusion) structure DGN = DynSchedNaive (structure B = B structure C = C - structure SST = SST - structure DS = DS + structure HS = HS val maxBranchingStride = maxBranchingStride val disableFusion = disableFusion) @@ -163,13 +163,13 @@ struct fun getNumZeros state = case state of - Expander.Sparse sst => SST.zeroSize sst - | Expander.Dense ds => raise Fail "Can't do dense stuff!" + HS.Sparse sst => SST.zeroSize sst + | HS.Dense ds => raise Fail "Can't do dense stuff!" (*DS.unsafeViewContents ds, DS.nonZeroSize ds, TODO exception*) - | Expander.DenseKnownNonZeroSize (ds, nz) => raise Fail "Can't do dense stuff!" + | HS.DenseKnownNonZeroSize (ds, nz) => raise Fail "Can't do dense stuff!" (*DS.unsafeViewContents ds, nz, TODO exception*) - val initialState = Expander.Sparse + val initialState = HS.Sparse (SST.singleton {numQubits = numQubits} (B.zeros, C.defaultReal 1.0)) fun runloop () = @@ -182,10 +182,7 @@ struct (fn i => G.expectBranching (Seq.nth gates i)) (* select gate *) - (fn ((state, numGateApps, counts, gatesVisitedSoFar), gates) => - case state of - Expander.Sparse sst => pickNextGate (sst, gates) - | _ => Seq.nth gates 0) + (fn ((state, numGateApps, counts, gatesVisitedSoFar), gates) => pickNextGate (state, gates)) (* disable fusion? *) disableFusion @@ -227,9 +224,9 @@ struct val (finalState, numGateApps, counts, gatesVisited) = runloop () val nonZeros = case finalState of - Expander.Sparse sst => SST.unsafeViewContents sst - | Expander.Dense ds => DS.unsafeViewContents ds - | Expander.DenseKnownNonZeroSize (ds, nz) => DS.unsafeViewContents ds + HS.Sparse sst => SST.unsafeViewContents sst + | HS.Dense ds => DS.unsafeViewContents ds + | HS.DenseKnownNonZeroSize (ds, nz) => DS.unsafeViewContents ds val _ = print ("gate app count " ^ Int.toString numGateApps ^ "\n") in {result = nonZeros, counts = Seq.fromList counts} diff --git a/feynsum-sml/src/common/DepGraphDynScheduler.sml b/feynsum-sml/src/common/DepGraphDynScheduler.sml index 9f7d115..7ee0a37 100644 --- a/feynsum-sml/src/common/DepGraphDynScheduler.sml +++ b/feynsum-sml/src/common/DepGraphDynScheduler.sml @@ -2,14 +2,13 @@ signature DEP_GRAPH_DYN_SCHEDULER = sig structure B: BASIS_IDX structure C: COMPLEX - structure SST: SPARSE_STATE_TABLE - structure DS: DENSE_STATE - sharing B = SST.B = DS.B - sharing C = SST.C = DS.C + structure HS: HYBRID_STATE + sharing B = HS.B + sharing C = HS.C type gate_idx = int - type t = DepGraph.t -> (SST.t * gate_idx Seq.t -> gate_idx) + type t = DepGraph.t -> (HS.t * gate_idx Seq.t -> gate_idx) val choose: t end @@ -17,22 +16,20 @@ end functor DynSchedFinishQubitWrapper (structure B: BASIS_IDX structure C: COMPLEX - structure SST: SPARSE_STATE_TABLE - structure DS: DENSE_STATE - sharing B = SST.B = DS.B - sharing C = SST.C = DS.C + structure HS: HYBRID_STATE + sharing B = HS.B + sharing C = HS.C val maxBranchingStride: int val disableFusion: bool ): DEP_GRAPH_DYN_SCHEDULER = struct structure B = B structure C = C - structure SST = SST - structure DS = DS + structure HS = HS type gate_idx = int - type t = DepGraph.t -> (SST.t * gate_idx Seq.t -> gate_idx) + type t = DepGraph.t -> (HS.t * gate_idx Seq.t -> gate_idx) structure DGFQ = DepGraphSchedulerGreedyFinishQubit (val maxBranchingStride = maxBranchingStride @@ -43,9 +40,7 @@ struct structure C = C) fun choose (depgraph: DepGraph.t) = - let val branchSeq = Seq.map (fn g => G.expectBranching (G.fromGateDefn g)) (#gates depgraph) - fun branching i = Seq.nth branchSeq i - val f = DGFQ.scheduler5 {depGraph = depgraph, gateIsBranching = branching} in + let val f = DGFQ.scheduler5 depgraph in fn (_, gates) => f gates end @@ -54,22 +49,20 @@ end functor DynSchedNaive (structure B: BASIS_IDX structure C: COMPLEX - structure SST: SPARSE_STATE_TABLE - structure DS: DENSE_STATE - sharing B = SST.B = DS.B - sharing C = SST.C = DS.C + structure HS: HYBRID_STATE + sharing B = HS.B + sharing C = HS.C val maxBranchingStride: int val disableFusion: bool ): DEP_GRAPH_DYN_SCHEDULER = struct structure B = B structure C = C - structure SST = SST - structure DS = DS + structure HS = HS type gate_idx = int - type t = DepGraph.t -> (SST.t * gate_idx Seq.t -> gate_idx) + type t = DepGraph.t -> (HS.t * gate_idx Seq.t -> gate_idx) fun choose (depgraph: DepGraph.t) = fn (_, gates) => Seq.nth gates 0 @@ -79,22 +72,20 @@ end functor DynSchedInterference (structure B: BASIS_IDX structure C: COMPLEX - structure SST: SPARSE_STATE_TABLE - structure DS: DENSE_STATE - sharing B = SST.B = DS.B - sharing C = SST.C = DS.C + structure HS: HYBRID_STATE + sharing B = HS.B + sharing C = HS.C val maxBranchingStride: int val disableFusion: bool ): DEP_GRAPH_DYN_SCHEDULER = struct structure B = B structure C = C - structure SST = SST - structure DS = DS + structure HS = HS type gate_idx = int - type t = DepGraph.t -> (SST.t * gate_idx Seq.t -> gate_idx) + type t = DepGraph.t -> (HS.t * gate_idx Seq.t -> gate_idx) structure DGFQ = DepGraphSchedulerGreedyFinishQubit (val maxBranchingStride = maxBranchingStride @@ -125,7 +116,7 @@ struct | (_, _) => Superposition fun calculateBranchedQubits (numQubits, sst) = - let val nonZeros = DelayedSeq.mapOption (fn x => x) (SST.unsafeViewContents sst) + let val nonZeros = DelayedSeq.mapOption (fn x => x) (HS.SST.unsafeViewContents sst) val branchedQubits = Seq.tabulate (fn qi => DelayedSeq.reduce joinBranches Uninitialized (DelayedSeq.map (fn (b, c) => if B.get b qi then One else Zero) nonZeros)) numQubits fun isbranched b = case b of @@ -143,7 +134,10 @@ struct fun branching i = Seq.nth branchSeq i val numQubits = #numQubits depgraph in - fn (sst, gidxs) => + fn (hs, gidxs) => case hs of + HS.Dense d => Seq.nth gidxs 0 + | HS.DenseKnownNonZeroSize d => Seq.nth gidxs 0 + | HS.Sparse sst => (* if dense, doesn't really matter what we pick? *) let val branchedQubits = calculateBranchedQubits (numQubits, sst) fun getBranching i = diff --git a/feynsum-sml/src/common/DepGraphScheduler.sml b/feynsum-sml/src/common/DepGraphScheduler.sml index 1e0f1de..ccf01de 100644 --- a/feynsum-sml/src/common/DepGraphScheduler.sml +++ b/feynsum-sml/src/common/DepGraphScheduler.sml @@ -3,13 +3,8 @@ struct type gate_idx = int - type args = - { depGraph: DepGraph.t - , gateIsBranching: gate_idx -> bool - } - (* From a frontier, select which gate to apply next *) (* args visit gates, update frontier break fusion initial frontier gate batches *) (*type t = args -> (gate_idx -> gate_idx Seq.t) -> (unit -> unit) -> gate_idx Seq.t -> gate_idx Seq.t Seq.t*) - type t = args -> (gate_idx Seq.t -> gate_idx) + type t = DepGraph.t -> (gate_idx Seq.t -> gate_idx) end diff --git a/feynsum-sml/src/common/DepGraphSchedulerGreedyBranching.sml b/feynsum-sml/src/common/DepGraphSchedulerGreedyBranching.sml index 2ae809a..0a6d360 100644 --- a/feynsum-sml/src/common/DepGraphSchedulerGreedyBranching.sml +++ b/feynsum-sml/src/common/DepGraphSchedulerGreedyBranching.sml @@ -6,11 +6,6 @@ struct type gate_idx = int - type args = - { depGraph: DepGraph.t - , gateIsBranching: gate_idx -> bool - } - fun pickBranching i branching gates = if i < Seq.length gates then (if branching i then @@ -21,5 +16,8 @@ struct Seq.nth gates 0 (* pick a non-branching gate *) (* From a frontier, select which gate to apply next *) - fun scheduler ({gateIsBranching = gib, ...} : args) gates = pickBranching 0 gib gates + fun scheduler dg = + let val branching = DepGraphUtil.gateIsBranching dg in + fn gates => pickBranching 0 branching gates + end end diff --git a/feynsum-sml/src/common/DepGraphSchedulerGreedyFinishQubit.sml b/feynsum-sml/src/common/DepGraphSchedulerGreedyFinishQubit.sml index fe87cc9..2282433 100644 --- a/feynsum-sml/src/common/DepGraphSchedulerGreedyFinishQubit.sml +++ b/feynsum-sml/src/common/DepGraphSchedulerGreedyFinishQubit.sml @@ -12,11 +12,6 @@ struct type gate_idx = int - type args = - { depGraph: DepGraph.t - , gateIsBranching: gate_idx -> bool - } - fun DFS ((new, old) : (int list * int list)) = new @ old fun BFS ((new, old) : (int list * int list)) = old @ new @@ -67,9 +62,10 @@ struct end (* Choose in reverse topological order, sorted by easiest qubit to finish *) - fun scheduler3 ({depGraph = dg, gateIsBranching = gib}: args) = + fun scheduler3 dg = let val dgt = DepGraphUtil.transpose dg val depths = computeGateDepths dg + val gib = DepGraphUtil.gateIsBranching dg fun lt (a, b) = Array.sub (depths, a) < Array.sub (depths, b) orelse (Array.sub (depths, a) = Array.sub (depths, b) andalso not (gib a) andalso gib b) fun push (new, old) = DFS (sortList lt new, old) val xs = revTopologicalSort dgt push @@ -125,14 +121,14 @@ struct end (* From a frontier, select which gate to apply next *) - fun scheduler ({depGraph = dg, ...} : args) = + fun scheduler dg = (gateDepths := SOME (computeGateDepths dg); fn gates => let val g0 = Seq.nth gates 0 in pickLeastDepth g0 (gateDepth g0 dg) 1 gates dg end) (* Select gate with greatest number of descendants *) - fun scheduler4 ({depGraph = dg, ...} : args) = + fun scheduler4 dg = (gateDepths := SOME (computeGateDepths dg); fn gates => let val g0 = Seq.nth gates 0 in pickGreatestDepth g0 (gateDepth g0 dg) 1 gates dg @@ -142,7 +138,7 @@ struct structure C = Complex64) (* Hybrid of scheduler2 (avoid branching on unbranched qubits) and also scheduler3 (choose in reverse topological order, sorted by easiest qubit to finish) *) - fun scheduler5 ({depGraph = dg, gateIsBranching = gib} : args) = + fun scheduler5 (dg: DepGraph.t) = let val gates = Seq.map G.fromGateDefn (#gates dg) fun touches i = #touches (Seq.nth gates i) fun branches i = case #action (Seq.nth gates i) of G.NonBranching _ => 0 | G.MaybeBranching _ => 1 | G.Branching _ => 2 @@ -186,7 +182,7 @@ struct end (* Avoids branching on unbranched qubits *) - fun scheduler2 ({depGraph = dg, gateIsBranching = gib} : args) = + fun scheduler2 (dg: DepGraph.t) = let val touched = Array.array (#numQubits dg, false) val gates = Seq.map G.fromGateDefn (#gates dg) fun touches i = #touches (Seq.nth gates i) @@ -220,7 +216,7 @@ struct val seed = Random.rand (50, 14125) - fun schedulerRandom seedNum ({depGraph = dg, gateIsBranching = gib} : args) = + fun schedulerRandom seedNum dg = (*let val seed = Random.rand (seedNum, seedNum * seedNum) in*) fn gates => let val r = Random.randRange (0, Seq.length gates - 1) seed in ((*print ("Randomly chose " ^ Int.toString r ^ " from range [0, " ^ Int.toString (Seq.length gates) ^ ")\n");*) diff --git a/feynsum-sml/src/common/DepGraphSchedulerGreedyNonBranching.sml b/feynsum-sml/src/common/DepGraphSchedulerGreedyNonBranching.sml index bbbf810..3de8f1f 100644 --- a/feynsum-sml/src/common/DepGraphSchedulerGreedyNonBranching.sml +++ b/feynsum-sml/src/common/DepGraphSchedulerGreedyNonBranching.sml @@ -22,23 +22,8 @@ struct Seq.nth ftr 0 (* From a frontier, select which gate to apply next *) - fun scheduler ({gateIsBranching = branching, ...} : args) ftr = - pickNonBranching 0 branching ftr - (*fun scheduler ({gateIsBranching = branching, ...} : args) updateFrontier breakFusion initialFrontier = - let fun updateAndBreak i = let val ftr = updateFrontier i in breakFusion (); ftr end - fun pickNonBranching i ftr = - if i < Seq.length gates then - (if branching i then - pickNonBranching (i + 1) gates - else - Seq.nth gates i) - else - (breakFusion (); Seq.nth gates 0) - fun sched ftr = if Seq.null ftr then - () - else - scheduler (updateFrontier (pickNonBranching 0 gates)) - in - sched initialFrontier - end*) + fun scheduler dg = + let val branching = DepGraphUtil.gateIsBranching dg in + fn gates => pickNonBranching 0 branching gates + end end diff --git a/feynsum-sml/src/common/DepGraphUtil.sml b/feynsum-sml/src/common/DepGraphUtil.sml index 1fe759e..115aa0f 100644 --- a/feynsum-sml/src/common/DepGraphUtil.sml +++ b/feynsum-sml/src/common/DepGraphUtil.sml @@ -22,7 +22,7 @@ sig val scheduleCost: gate_idx Seq.t Seq.t -> (gate_idx -> bool) -> real val chooseSchedule: gate_idx Seq.t Seq.t Seq.t -> (gate_idx -> bool) -> gate_idx Seq.t Seq.t - (* val gateIsBranching: dep_graph -> (gate_idx -> bool) *) + val gateIsBranching: dep_graph -> (gate_idx -> bool) end = struct @@ -159,4 +159,15 @@ struct in iter 1 0 (scheduleCost (Seq.nth orders 0) branching) end + + (* B and C don't affect gate_branching, so pick arbitrarily *) + structure Gate_branching = Gate (structure B = BasisIdxUnlimited + structure C = Complex64) + + + fun gateIsBranching ({ gates = gates, ...} : dep_graph) = + let val branchSeq = Seq.map (fn g => Gate_branching.expectBranching (Gate_branching.fromGateDefn g)) gates in + fn i => Seq.nth branchSeq i + end + end diff --git a/feynsum-sml/src/common/ExpandState.sml b/feynsum-sml/src/common/ExpandState.sml index f5afac7..92e8374 100644 --- a/feynsum-sml/src/common/ExpandState.sml +++ b/feynsum-sml/src/common/ExpandState.sml @@ -1,34 +1,31 @@ functor ExpandState (structure B: BASIS_IDX structure C: COMPLEX - structure SST: SPARSE_STATE_TABLE - structure DS: DENSE_STATE + structure HS: HYBRID_STATE structure G: GATE - sharing B = SST.B = DS.B = G.B - sharing C = SST.C = DS.C = G.C + sharing B = HS.B = G.B + sharing C = HS.C = G.C val blockSize: int val maxload: real val denseThreshold: real val pullThreshold: real) :> sig - datatype state = - Sparse of SST.t - | Dense of DS.t - | DenseKnownNonZeroSize of DS.t * int - val expand: { gates: G.t Seq.t , numQubits: int , maxNumStates: IntInf.int - , state: state + , state: HS.state , prevNonZeroSize: int } - -> {result: state, method: string, numNonZeros: int, numGateApps: int} + -> {result: HS.state, method: string, numNonZeros: int, numGateApps: int} end = struct + structure SST = HS.SST + structure DS = HS.DS + (* 0 < r < 1 * * I wish this wasn't so difficult @@ -108,13 +105,6 @@ struct AllSucceeded | SomeFailed of {widx: B.t * C.t, gatenum: int} list - - datatype state = - Sparse of SST.t - | Dense of DS.t - | DenseKnownNonZeroSize of DS.t * int - - fun expandSparse {gates: G.t Seq.t, numQubits, state, expected} = let val numGates = Seq.length gates @@ -122,9 +112,9 @@ struct val stateSeq = case state of - Sparse sst => DelayedSeq.map SOME (SST.compact sst) - | Dense state => DS.unsafeViewContents state - | DenseKnownNonZeroSize (state, _) => DS.unsafeViewContents state + HS.Sparse sst => DelayedSeq.map SOME (SST.compact sst) + | HS.Dense state => DS.unsafeViewContents state + | HS.DenseKnownNonZeroSize (state, _) => DS.unsafeViewContents state (* number of initial elements *) val n = DelayedSeq.length stateSeq @@ -285,7 +275,7 @@ struct val (apps, output) = loop 0 initialBlocks initialTable in - {result = Sparse output, numGateApps = apps} + {result = HS.Sparse output, numGateApps = apps} end @@ -296,9 +286,9 @@ struct val stateSeq = case state of - Sparse sst => DelayedSeq.map SOME (SST.compact sst) - | Dense state => DS.unsafeViewContents state - | DenseKnownNonZeroSize (state, _) => DS.unsafeViewContents state + HS.Sparse sst => DelayedSeq.map SOME (SST.compact sst) + | HS.Dense state => DS.unsafeViewContents state + | HS.DenseKnownNonZeroSize (state, _) => DS.unsafeViewContents state (* number of initial elements *) val n = DelayedSeq.length stateSeq @@ -331,7 +321,7 @@ struct SOME widx => doGates 0 (widx, 0) | NONE => 0) in - {result = Dense output, numGateApps = numGateApps} + {result = HS.Dense output, numGateApps = numGateApps} end @@ -343,9 +333,9 @@ struct val lookup = case state of - Sparse sst => (fn bidx => Option.getOpt (SST.lookup sst bidx, C.zero)) - | Dense ds => DS.lookupDirect ds - | DenseKnownNonZeroSize (ds, _) => DS.lookupDirect ds + HS.Sparse sst => (fn bidx => Option.getOpt (SST.lookup sst bidx, C.zero)) + | HS.Dense ds => DS.lookupDirect ds + | HS.DenseKnownNonZeroSize (ds, _) => DS.lookupDirect ds fun doGates (bidx, gatenum) = if gatenum < 0 then @@ -375,7 +365,7 @@ struct DS.pull {numQubits = numQubits} (fn bidx => doGates (bidx, numGates - 1)) in - { result = DenseKnownNonZeroSize (result, nonZeroSize) + { result = HS.DenseKnownNonZeroSize (result, nonZeroSize) , numGateApps = totalCount } end @@ -385,9 +375,9 @@ struct let val nonZeroSize = case state of - Sparse sst => SST.nonZeroSize sst - | Dense ds => DS.nonZeroSize ds - | DenseKnownNonZeroSize (_, nz) => nz + HS.Sparse sst => SST.nonZeroSize sst + | HS.Dense ds => DS.nonZeroSize ds + | HS.DenseKnownNonZeroSize (_, nz) => nz val rate = Real.max (1.0, Real.fromInt nonZeroSize / Real.fromInt prevNonZeroSize) diff --git a/feynsum-sml/src/common/sources.mlb b/feynsum-sml/src/common/sources.mlb index 791e890..cd19c08 100644 --- a/feynsum-sml/src/common/sources.mlb +++ b/feynsum-sml/src/common/sources.mlb @@ -54,6 +54,9 @@ local DENSE_STATE.sml DenseState.sml + + HYBRID_STATE.sml + HybridState.sml ExpandState.sml @@ -115,6 +118,9 @@ in signature DENSE_STATE functor DenseState + + signature HYBRID_STATE + functor HybridState functor ExpandState diff --git a/feynsum-sml/src/main.sml b/feynsum-sml/src/main.sml index cf567a9..e9ea4dd 100644 --- a/feynsum-sml/src/main.sml +++ b/feynsum-sml/src/main.sml @@ -139,22 +139,8 @@ fun print_sched_info () = type gate_idx = int type Schedule = gate_idx Seq.t -(*fun gate_to_schedule (ags: GateScheduler.args) (sched: GateScheduler.t) = - let val next = sched ags; - fun loadNext (acc: gate_idx Seq.t list) = - let val gs = next () in - if Seq.length gs = 0 then - Seq.flatten (Seq.rev (Seq.fromList acc)) - else - loadNext (gs :: acc) - end - in - loadNext nil - end*) - -fun dep_graph_to_schedule (ags: DepGraphScheduler.args) (sched: DepGraphScheduler.t) = - let val choose = sched ags - val dg = #depGraph ags +fun dep_graph_to_schedule (dg: DepGraph.t) (sched: DepGraphScheduler.t) = + let val choose = sched dg val st = DepGraphUtil.initState dg fun loadNext (acc: gate_idx list) = let val frntr = DepGraphUtil.frontier st in @@ -170,20 +156,6 @@ fun dep_graph_to_schedule (ags: DepGraphScheduler.args) (sched: DepGraphSchedule GateSchedulerOrder.mkScheduler (Seq.map Seq.singleton (loadNext nil)) end -structure Gate_branching = Gate (structure B = BasisIdxUnlimited - structure C = Complex64) - - -fun gate_branching ({ gates = gates, ...} : DepGraph.t) = - let val gates = Seq.map Gate_branching.fromGateDefn gates - fun gate i = Seq.nth gates i - in - (fn i => - case #action (gate i) of - Gate_branching.NonBranching _ => false - | _ => true) - end - val maxBranchingStride' = if disableFusion then 1 else maxBranchingStride (*fun greedybranching () = From 1917b6db06fcd534fad48572a66a03e5ce16591a Mon Sep 17 00:00:00 2001 From: Colin McDonald Date: Thu, 7 Dec 2023 11:43:39 -0500 Subject: [PATCH 11/15] Reenable dense states --- feynsum-sml/src/FullSimBFS.sml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/feynsum-sml/src/FullSimBFS.sml b/feynsum-sml/src/FullSimBFS.sml index 39a5357..5114fe7 100644 --- a/feynsum-sml/src/FullSimBFS.sml +++ b/feynsum-sml/src/FullSimBFS.sml @@ -164,9 +164,9 @@ struct fun getNumZeros state = case state of HS.Sparse sst => SST.zeroSize sst - | HS.Dense ds => raise Fail "Can't do dense stuff!" + | HS.Dense ds => 0 (*raise Fail "Can't do dense stuff!"*) (*DS.unsafeViewContents ds, DS.nonZeroSize ds, TODO exception*) - | HS.DenseKnownNonZeroSize (ds, nz) => raise Fail "Can't do dense stuff!" + | HS.DenseKnownNonZeroSize (ds, nz) => 0 (*raise Fail "Can't do dense stuff!"*) (*DS.unsafeViewContents ds, nz, TODO exception*) val initialState = HS.Sparse From 3b2ac9c231693cfb145d84dcd5fcb6c9946f173e Mon Sep 17 00:00:00 2001 From: Colin McDonald Date: Tue, 19 Dec 2023 16:12:03 -0500 Subject: [PATCH 12/15] Refactor some code, and also convert qasm to DAG automatically (still need to convert DAG to dep graph, though) --- feynsum-sml/src/FullSimBFS.sml | 6 +- feynsum-sml/src/common/Circuit.sml | 100 +------- .../{DepGraph.sml => DataFlowGraph.sml} | 91 +++++-- ...uler.sml => DataFlowGraphDynScheduler.sml} | 34 +-- ...heduler.sml => DataFlowGraphScheduler.sml} | 4 +- ...DepGraphUtil.sml => DataFlowGraphUtil.sml} | 42 ++-- feynsum-sml/src/common/GateDefn.sml | 76 ++++++ feynsum-sml/src/common/GateScheduler.sml | 16 -- .../common/GateSchedulerGreedyBranching.sml | 157 ------------ .../common/GateSchedulerGreedyFinishQubit.sml | 178 -------------- .../GateSchedulerGreedyNonBranching.sml | 228 ------------------ feynsum-sml/src/common/GateSchedulerNaive.sml | 34 --- feynsum-sml/src/common/GateSchedulerOrder.sml | 31 --- feynsum-sml/src/common/HYBRID_STATE.sml | 19 ++ feynsum-sml/src/common/HybridState.sml | 20 ++ .../FinishQubit.sml} | 104 ++++---- .../GreedyBranching.sml} | 6 +- .../GreedyNonBranching.sml} | 8 +- feynsum-sml/src/common/sources.mlb | 52 ++-- feynsum-sml/src/main.sml | 79 +++--- 20 files changed, 353 insertions(+), 932 deletions(-) rename feynsum-sml/src/common/{DepGraph.sml => DataFlowGraph.sml} (60%) rename feynsum-sml/src/common/{DepGraphDynScheduler.sml => DataFlowGraphDynScheduler.sml} (85%) rename feynsum-sml/src/common/{DepGraphScheduler.sml => DataFlowGraphScheduler.sml} (78%) rename feynsum-sml/src/common/{DepGraphUtil.sml => DataFlowGraphUtil.sml} (75%) delete mode 100644 feynsum-sml/src/common/GateScheduler.sml delete mode 100644 feynsum-sml/src/common/GateSchedulerGreedyBranching.sml delete mode 100644 feynsum-sml/src/common/GateSchedulerGreedyFinishQubit.sml delete mode 100644 feynsum-sml/src/common/GateSchedulerGreedyNonBranching.sml delete mode 100644 feynsum-sml/src/common/GateSchedulerNaive.sml delete mode 100644 feynsum-sml/src/common/GateSchedulerOrder.sml create mode 100644 feynsum-sml/src/common/HYBRID_STATE.sml create mode 100644 feynsum-sml/src/common/HybridState.sml rename feynsum-sml/src/common/{DepGraphSchedulerGreedyFinishQubit.sml => schedulers/FinishQubit.sml} (75%) rename feynsum-sml/src/common/{DepGraphSchedulerGreedyBranching.sml => schedulers/GreedyBranching.sml} (75%) rename feynsum-sml/src/common/{DepGraphSchedulerGreedyNonBranching.sml => schedulers/GreedyNonBranching.sml} (75%) diff --git a/feynsum-sml/src/FullSimBFS.sml b/feynsum-sml/src/FullSimBFS.sml index 5114fe7..c0ad597 100644 --- a/feynsum-sml/src/FullSimBFS.sml +++ b/feynsum-sml/src/FullSimBFS.sml @@ -15,7 +15,7 @@ functor FullSimBFS val denseThreshold: real val pullThreshold: real): sig - val run: DepGraph.t + val run: DataFlowGraph.t -> {result: (B.t * C.t) option DelayedSeq.t, counts: int Seq.t} end = struct @@ -113,7 +113,7 @@ struct val numQubits = #numQubits depgraph fun gate i = Seq.nth gates i val depth = Seq.length gates - val dgstate = DepGraphUtil.initState depgraph + val dgstate = DataFlowGraphUtil.initState depgraph val pickNextGate = let val f = gateSched depgraph in @@ -173,7 +173,7 @@ struct (SST.singleton {numQubits = numQubits} (B.zeros, C.defaultReal 1.0)) fun runloop () = - DepGraphUtil.scheduleWithOracle' + DataFlowGraphUtil.scheduleWithOracle' (* dependency graph *) depgraph diff --git a/feynsum-sml/src/common/Circuit.sml b/feynsum-sml/src/common/Circuit.sml index 3020933..d45b3c6 100644 --- a/feynsum-sml/src/common/Circuit.sml +++ b/feynsum-sml/src/common/Circuit.sml @@ -3,7 +3,7 @@ sig type circuit = {numQubits: int, gates: GateDefn.t Seq.t} type t = circuit - val toString: circuit -> string + (*val toString: circuit -> string*) val numGates: circuit -> int val numQubits: circuit -> int @@ -214,102 +214,4 @@ struct {numQubits = numQubits, gates = Seq.map convertGate gates} end - - fun toString {numQubits, gates} = - let - val header = "qreg q[" ^ Int.toString numQubits ^ "];\n" - - fun qi i = - "q[" ^ Int.toString i ^ "]" - - fun doOther {name, params, args} = - let - val pstr = - if Seq.length params = 0 then - "()" - else - "(" - ^ - Seq.iterate (fn (acc, e) => acc ^ ", " ^ Real.toString e) - (Real.toString (Seq.nth params 0)) (Seq.drop params 1) ^ ")" - - val front = name ^ pstr - - val args = - if Seq.length args = 0 then - "" - else - Seq.iterate (fn (acc, i) => acc ^ ", " ^ qi i) - (qi (Seq.nth args 0)) (Seq.drop args 1) - in - front ^ " " ^ args - end - - fun gateToString gate = - case gate of - GateDefn.PauliY i => "y " ^ qi i - | GateDefn.PauliZ i => "z " ^ qi i - | GateDefn.Hadamard i => "h " ^ qi i - | GateDefn.T i => "t " ^ qi i - | GateDefn.Tdg i => "tdg " ^ qi i - | GateDefn.SqrtX i => "sx " ^ qi i - | GateDefn.Sxdg i => "sxdg " ^ qi i - | GateDefn.S i => "s " ^ qi i - | GateDefn.Sdg i => "sdg " ^ qi i - | GateDefn.X i => "x " ^ qi i - | GateDefn.CX {control, target} => "cx " ^ qi control ^ ", " ^ qi target - | GateDefn.CZ {control, target} => "cz " ^ qi control ^ ", " ^ qi target - | GateDefn.CCX {control1, control2, target} => - "ccx " ^ qi control1 ^ ", " ^ qi control2 ^ ", " ^ qi target - | GateDefn.Phase {target, rot} => - "phase(" ^ Real.toString rot ^ ") " ^ qi target - | GateDefn.CPhase {control, target, rot} => - "cphase(" ^ Real.toString rot ^ ") " ^ qi control ^ ", " ^ qi target - | GateDefn.FSim {left, right, theta, phi} => - "fsim(" ^ Real.toString theta ^ ", " ^ Real.toString phi ^ ") " - ^ qi left ^ ", " ^ qi right - | GateDefn.RZ {rot, target} => - "rz(" ^ Real.toString rot ^ ") " ^ qi target - | GateDefn.RY {rot, target} => - "ry(" ^ Real.toString rot ^ ") " ^ qi target - | GateDefn.RX {rot, target} => - "rx(" ^ Real.toString rot ^ ") " ^ qi target - | GateDefn.CSwap {control, target1, target2} => - "cswap " ^ qi control ^ ", " ^ qi target1 ^ ", " ^ qi target2 - | GateDefn.Swap {target1, target2} => - "swap " ^ qi target1 ^ ", " ^ qi target2 - - | GateDefn.U {target, theta, phi, lambda} => - doOther - { name = "u" - , params = Seq.fromList [theta, phi, lambda] - , args = Seq.singleton target - } - - | GateDefn.Other {name, params, args} => - let - val pstr = - if Seq.length params = 0 then - "()" - else - "(" - ^ - Seq.iterate (fn (acc, e) => acc ^ ", " ^ Real.toString e) - (Real.toString (Seq.nth params 0)) (Seq.drop params 1) ^ ")" - - val front = name ^ pstr - - val args = - if Seq.length args = 0 then - "" - else - Seq.iterate (fn (acc, i) => acc ^ ", " ^ qi i) - (qi (Seq.nth args 0)) (Seq.drop args 1) - in - front ^ " " ^ args - end - in - Seq.iterate op^ header (Seq.map (fn g => gateToString g ^ ";\n") gates) - end - end diff --git a/feynsum-sml/src/common/DepGraph.sml b/feynsum-sml/src/common/DataFlowGraph.sml similarity index 60% rename from feynsum-sml/src/common/DepGraph.sml rename to feynsum-sml/src/common/DataFlowGraph.sml index 5537cf0..54a95b1 100644 --- a/feynsum-sml/src/common/DepGraph.sml +++ b/feynsum-sml/src/common/DataFlowGraph.sml @@ -1,20 +1,25 @@ -structure DepGraph: +structure DataFlowGraph: sig type gate_idx = int - type dep_graph = { + type data_flow_graph = { gates: GateDefn.t Seq.t, - deps: gate_idx Seq.t Seq.t, - indegree: int Seq.t, + preds: gate_idx Seq.t Seq.t, + succs: gate_idx Seq.t Seq.t, numQubits: int } - type t = dep_graph + type t = data_flow_graph - val fromJSON: JSON.value -> dep_graph - val fromString: string -> dep_graph - val fromFile: string -> dep_graph + val fromJSON: JSON.value -> data_flow_graph + val fromJSONString: string -> data_flow_graph + val fromJSONFile: string -> data_flow_graph + val fromQasm: Circuit.t -> data_flow_graph + + val toString: data_flow_graph -> string + (*val fromQasmString: string -> data_flow_graph*) + (*val fromQasmFile: string -> data_flow_graph*) (*val mkGateDefn: (string * real list option * int list) -> GateDefn.t*) @@ -23,14 +28,14 @@ struct type gate_idx = int - type dep_graph = { + type data_flow_graph = { gates: GateDefn.t Seq.t, - deps: gate_idx Seq.t Seq.t, - indegree: int Seq.t, + preds: gate_idx Seq.t Seq.t, + succs: gate_idx Seq.t Seq.t, numQubits: int } - type t = dep_graph + type t = data_flow_graph fun expect opt err = case opt of NONE => raise Fail err @@ -90,23 +95,24 @@ struct fun arrayToSeq a = Seq.tabulate (fn i => Array.sub (a, i)) (Array.length a) - fun getDepsInDeg (edges, N) = - let val deps = Array.array (N, nil); - val indeg = Array.array (N, 0); - fun incDeg (i) = Array.update (indeg, i, 1 + Array.sub (indeg, i)) + fun getSuccsPredsFromJSON (edges, N) = + let val preds = Array.array (N, nil); + val succs = Array.array (N, nil); + fun consAt (a, i, x) = Array.update (a, i, (x :: Array.sub (a, i))) fun go nil = () | go (JSON.ARRAY [JSON.INT fm, JSON.INT to] :: edges) = let val fm64 = IntInf.toInt fm val to64 = IntInf.toInt to - val () = incDeg to64 - val () = Array.update (deps, fm64, (to64 :: Array.sub (deps, fm64))) + val _ = consAt (preds, to64, fm64) + val _ = consAt (succs, fm64, to64) in go edges end | go (_ :: edges) = raise Fail "Malformed edge in JSON" val () = go edges; + val toSeq = Seq.map (Seq.rev o Seq.fromList) o arrayToSeq in - (Seq.map Seq.fromList (arrayToSeq deps), arrayToSeq indeg) + (toSeq preds, toSeq succs) end fun fromJSON (data) = @@ -126,11 +132,50 @@ struct val edges = case JSONUtil.findField data "edges" of SOME (JSON.ARRAY es) => es | _ => raise Fail "Expected array field 'nodes' in JSON" - val (deps, indegree) = getDepsInDeg (edges, Seq.length gates) + val (preds, succs) = getSuccsPredsFromJSON (edges, Seq.length gates) in - { gates = gates, deps = deps, indegree = indegree, numQubits = numqs } + { gates = gates, + preds = preds, + succs = succs, + numQubits = numqs } (*Seq.zipWith (fn (g, (d, i)) => {gate = g, deps = d, indegree = i}) (gates, Seq.zip (deps, indegree))*) end - fun fromString (str) = fromJSON (JSONParser.parse (JSONParser.openString str)) - fun fromFile (file) = fromJSON (JSONParser.parseFile file) + fun fromJSONString (str) = fromJSON (JSONParser.parse (JSONParser.openString str)) + fun fromJSONFile (file) = fromJSON (JSONParser.parseFile file) + + (* TODO: convert to dependency graph *) + fun fromQasm {numQubits, gates} = + let val numGates = Seq.length gates + val qubitLastGate = Array.array (numQubits, ~1) + val preds = Array.array (numGates, nil) + val succs = Array.array (numGates, nil) + fun fillPreds gidx = + if gidx >= numGates then + () + else + let val gate = Seq.nth gates gidx + val args = GateDefn.getGateArgs gate + val lasts = List.filter (fn i => i >= 0) (List.map (fn qidx => Array.sub (qubitLastGate, qidx)) args) + val _ = Array.update (preds, gidx, lasts) + val _ = List.map (fn gidx' => Array.update (succs, gidx', gidx :: Array.sub (succs, gidx'))) lasts + val _ = List.map (fn qidx => Array.update (qubitLastGate, qidx, gidx)) args + in + fillPreds (gidx + 1) + end + val _ = fillPreds 0 + val predsSeq = Seq.map Seq.fromList (arrayToSeq preds) + val succsSeq = Seq.map (Seq.rev o Seq.fromList) (arrayToSeq succs) + in + { gates = gates, + preds = predsSeq, + succs = succsSeq, + numQubits = numQubits } + end + +fun toString {gates, preds, succs, numQubits} = + let val header = "qreg q[" ^ Int.toString numQubits ^ "];\n" + fun qi i = "q[" ^ Int.toString i ^ "]" + in + Seq.iterate op^ header (Seq.map (fn g => GateDefn.toString g qi ^ ";\n") gates) + end end diff --git a/feynsum-sml/src/common/DepGraphDynScheduler.sml b/feynsum-sml/src/common/DataFlowGraphDynScheduler.sml similarity index 85% rename from feynsum-sml/src/common/DepGraphDynScheduler.sml rename to feynsum-sml/src/common/DataFlowGraphDynScheduler.sml index 7ee0a37..de844c0 100644 --- a/feynsum-sml/src/common/DepGraphDynScheduler.sml +++ b/feynsum-sml/src/common/DataFlowGraphDynScheduler.sml @@ -1,4 +1,4 @@ -signature DEP_GRAPH_DYN_SCHEDULER = +signature DATA_FLOW_GRAPH_DYN_SCHEDULER = sig structure B: BASIS_IDX structure C: COMPLEX @@ -8,7 +8,7 @@ sig type gate_idx = int - type t = DepGraph.t -> (HS.t * gate_idx Seq.t -> gate_idx) + type t = DataFlowGraph.t -> (HS.t * gate_idx Seq.t -> gate_idx) val choose: t end @@ -21,7 +21,7 @@ functor DynSchedFinishQubitWrapper sharing C = HS.C val maxBranchingStride: int val disableFusion: bool - ): DEP_GRAPH_DYN_SCHEDULER = + ): DATA_FLOW_GRAPH_DYN_SCHEDULER = struct structure B = B structure C = C @@ -29,9 +29,9 @@ struct type gate_idx = int - type t = DepGraph.t -> (HS.t * gate_idx Seq.t -> gate_idx) + type t = DataFlowGraph.t -> (HS.t * gate_idx Seq.t -> gate_idx) - structure DGFQ = DepGraphSchedulerGreedyFinishQubit + structure FQS = FinishQubitScheduler (val maxBranchingStride = maxBranchingStride val disableFusion = disableFusion) @@ -39,8 +39,8 @@ struct (structure B = B structure C = C) - fun choose (depgraph: DepGraph.t) = - let val f = DGFQ.scheduler5 depgraph in + fun choose (depgraph: DataFlowGraph.t) = + let val f = FQS.scheduler5 depgraph in fn (_, gates) => f gates end @@ -54,7 +54,7 @@ functor DynSchedNaive sharing C = HS.C val maxBranchingStride: int val disableFusion: bool - ): DEP_GRAPH_DYN_SCHEDULER = + ): DATA_FLOW_GRAPH_DYN_SCHEDULER = struct structure B = B structure C = C @@ -62,9 +62,9 @@ struct type gate_idx = int - type t = DepGraph.t -> (HS.t * gate_idx Seq.t -> gate_idx) + type t = DataFlowGraph.t -> (HS.t * gate_idx Seq.t -> gate_idx) - fun choose (depgraph: DepGraph.t) = fn (_, gates) => Seq.nth gates 0 + fun choose (depgraph: DataFlowGraph.t) = fn (_, gates) => Seq.nth gates 0 end @@ -77,7 +77,7 @@ functor DynSchedInterference sharing C = HS.C val maxBranchingStride: int val disableFusion: bool - ): DEP_GRAPH_DYN_SCHEDULER = + ): DATA_FLOW_GRAPH_DYN_SCHEDULER = struct structure B = B structure C = C @@ -85,9 +85,9 @@ struct type gate_idx = int - type t = DepGraph.t -> (HS.t * gate_idx Seq.t -> gate_idx) + type t = DataFlowGraph.t -> (HS.t * gate_idx Seq.t -> gate_idx) - structure DGFQ = DepGraphSchedulerGreedyFinishQubit + structure FQS = FinishQubitScheduler (val maxBranchingStride = maxBranchingStride val disableFusion = disableFusion) @@ -128,7 +128,7 @@ struct Seq.map isbranched branchedQubits end - fun choose (depgraph: DepGraph.t) = + fun choose (depgraph: DataFlowGraph.t) = let val gates = Seq.map G.fromGateDefn (#gates depgraph) val branchSeq = Seq.map G.expectBranching gates fun branching i = Seq.nth branchSeq i @@ -165,7 +165,7 @@ struct end -(*functor DepGraphDynScheduler +(*functor DataFlowGraphDynScheduler (structure B: BASIS_IDX structure C: COMPLEX structure SST: SPARSE_STATE_TABLE @@ -176,7 +176,7 @@ end val blockSize: int val maxload: real val denseThreshold: real - val pullThreshold: real): DEP_GRAPH_DYN_SCHEDULER = + val pullThreshold: real): DATA_FLOW_GRAPH_DYN_SCHEDULER = struct type gate_idx = int structure Expander = @@ -194,6 +194,6 @@ struct ( * From a frontier, select which gate to apply next * ) ( * args visit gates, update frontier break fusion initial frontier gate batches * ) ( * type t = args -> (gate_idx -> gate_idx Seq.t) -> (unit -> unit) -> gate_idx Seq.t -> gate_idx Seq.t Seq.t* ) - type t = DepGraph.t -> (Expander.state * gate_idx Seq.t -> gate_idx) + type t = DataFlowGraph.t -> (Expander.state * gate_idx Seq.t -> gate_idx) end *) diff --git a/feynsum-sml/src/common/DepGraphScheduler.sml b/feynsum-sml/src/common/DataFlowGraphScheduler.sml similarity index 78% rename from feynsum-sml/src/common/DepGraphScheduler.sml rename to feynsum-sml/src/common/DataFlowGraphScheduler.sml index ccf01de..1712ef1 100644 --- a/feynsum-sml/src/common/DepGraphScheduler.sml +++ b/feynsum-sml/src/common/DataFlowGraphScheduler.sml @@ -1,4 +1,4 @@ -structure DepGraphScheduler = +structure Scheduler = struct type gate_idx = int @@ -6,5 +6,5 @@ struct (* From a frontier, select which gate to apply next *) (* args visit gates, update frontier break fusion initial frontier gate batches *) (*type t = args -> (gate_idx -> gate_idx Seq.t) -> (unit -> unit) -> gate_idx Seq.t -> gate_idx Seq.t Seq.t*) - type t = DepGraph.t -> (gate_idx Seq.t -> gate_idx) + type t = DataFlowGraph.t -> (gate_idx Seq.t -> gate_idx) end diff --git a/feynsum-sml/src/common/DepGraphUtil.sml b/feynsum-sml/src/common/DataFlowGraphUtil.sml similarity index 75% rename from feynsum-sml/src/common/DepGraphUtil.sml rename to feynsum-sml/src/common/DataFlowGraphUtil.sml index 115aa0f..ffdb5da 100644 --- a/feynsum-sml/src/common/DepGraphUtil.sml +++ b/feynsum-sml/src/common/DataFlowGraphUtil.sml @@ -1,4 +1,4 @@ -structure DepGraphUtil :> +structure DataFlowGraphUtil :> sig type gate_idx = int @@ -6,31 +6,37 @@ sig (* Traversal Automaton State *) type state = { visited: bool array, indegree: int array } - type dep_graph = DepGraph.t + type data_flow_graph = DataFlowGraph.t - val visit: dep_graph -> gate_idx -> state -> unit + val visit: data_flow_graph -> gate_idx -> state -> unit val frontier: state -> gate_idx Seq.t - val initState: dep_graph -> state + val initState: data_flow_graph -> state (* Switches edge directions *) - val transpose: dep_graph -> dep_graph + val transpose: data_flow_graph -> data_flow_graph - val scheduleWithOracle: dep_graph -> (gate_idx -> bool) -> (gate_idx Seq.t -> gate_idx) -> bool -> int -> gate_idx Seq.t Seq.t + val scheduleWithOracle: data_flow_graph -> (gate_idx -> bool) -> (gate_idx Seq.t -> gate_idx) -> bool -> int -> gate_idx Seq.t Seq.t - val scheduleWithOracle': dep_graph -> (gate_idx -> bool) -> ('state * gate_idx Seq.t -> gate_idx) -> bool -> int -> ('state * gate_idx Seq.t -> 'state) -> 'state -> 'state + val scheduleWithOracle': data_flow_graph -> (gate_idx -> bool) -> ('state * gate_idx Seq.t -> gate_idx) -> bool -> int -> ('state * gate_idx Seq.t -> 'state) -> 'state -> 'state val scheduleCost: gate_idx Seq.t Seq.t -> (gate_idx -> bool) -> real val chooseSchedule: gate_idx Seq.t Seq.t Seq.t -> (gate_idx -> bool) -> gate_idx Seq.t Seq.t - val gateIsBranching: dep_graph -> (gate_idx -> bool) + val gateIsBranching: data_flow_graph -> (gate_idx -> bool) end = struct type gate_idx = int - type dep_graph = DepGraph.t + type data_flow_graph = DataFlowGraph.t - fun transpose ({gates = gs, deps = ds, indegree = is, numQubits = qs}: dep_graph) = + fun transpose ({gates, preds, succs, numQubits}: data_flow_graph) = + { gates = gates, + preds = succs, + succs = preds, + numQubits = numQubits } + + (*fun transpose ({gates = gs, deps = ds, indegree = is, numQubits = qs}: data_flow_graph) = let val N = Seq.length gs val ds2 = Array.array (N, nil) fun apply i = Seq.map (fn j => Array.update (ds2, j, i :: Array.sub (ds2, j))) (Seq.nth ds i) @@ -40,16 +46,16 @@ struct deps = Seq.tabulate (fn i => Seq.rev (Seq.fromList (Array.sub (ds2, i)))) N, indegree = Seq.map Seq.length ds, numQubits = qs} - end + end*) type state = { visited: bool array, indegree: int array } - fun visit {gates = _, deps = ds, indegree = _, numQubits = _} i {visited = vis, indegree = deg} = + fun visit {succs = succs, ...} i {visited = vis, indegree = deg} = ( (* Set visited[i] = true *) Array.update (vis, i, true); (* Decrement indegree of each i dependency *) - Seq.map (fn j => Array.update (deg, j, Array.sub (deg, j) - 1)) (Seq.nth ds i); + Seq.map (fn j => Array.update (deg, j, Array.sub (deg, j) - 1)) (Seq.nth succs i); () ) @@ -66,15 +72,15 @@ struct Seq.fromList (iter (N - 1) nil) end - fun initState (graph: dep_graph) = + fun initState (graph: data_flow_graph) = let val N = Seq.length (#gates graph) val vis = Array.array (N, false) - val deg = Array.tabulate (N, Seq.nth (#indegree graph)) + val deg = Array.tabulate (N, Seq.length o Seq.nth (#preds graph)) in { visited = vis, indegree = deg } end - fun scheduleWithOracle' (graph: dep_graph) (branching: gate_idx -> bool) (choose: 'state * gate_idx Seq.t -> gate_idx) (disableFusion: bool) (maxBranchingStride: int) (apply: 'state * gate_idx Seq.t -> 'state) state = + fun scheduleWithOracle' (graph: data_flow_graph) (branching: gate_idx -> bool) (choose: 'state * gate_idx Seq.t -> gate_idx) (disableFusion: bool) (maxBranchingStride: int) (apply: 'state * gate_idx Seq.t -> 'state) state = let val dgst = initState graph fun findNonBranching (i: int) (xs: gate_idx Seq.t) = if i = Seq.length xs then @@ -118,7 +124,7 @@ struct loadNext 0 (loadNonBranching nil) state end - fun scheduleWithOracle (graph: dep_graph) (branching: gate_idx -> bool) (choose: gate_idx Seq.t -> gate_idx) (disableFusion: bool) (maxBranchingStride: int) = Seq.rev (Seq.fromList (scheduleWithOracle' graph branching (fn (_, x) => choose x) disableFusion maxBranchingStride (fn (gs, g) => g :: gs) nil)) + fun scheduleWithOracle (graph: data_flow_graph) (branching: gate_idx -> bool) (choose: gate_idx Seq.t -> gate_idx) (disableFusion: bool) (maxBranchingStride: int) = Seq.rev (Seq.fromList (scheduleWithOracle' graph branching (fn (_, x) => choose x) disableFusion maxBranchingStride (fn (gs, g) => g :: gs) nil)) (*fun scheduleCost2 (order: gate_idx Seq.t Seq.t) (branching: gate_idx -> bool) = let val gates = Seq.flatten order @@ -165,7 +171,7 @@ struct structure C = Complex64) - fun gateIsBranching ({ gates = gates, ...} : dep_graph) = + fun gateIsBranching ({ gates = gates, ...} : data_flow_graph) = let val branchSeq = Seq.map (fn g => Gate_branching.expectBranching (Gate_branching.fromGateDefn g)) gates in fn i => Seq.nth branchSeq i end diff --git a/feynsum-sml/src/common/GateDefn.sml b/feynsum-sml/src/common/GateDefn.sml index a2f7928..4387268 100644 --- a/feynsum-sml/src/common/GateDefn.sml +++ b/feynsum-sml/src/common/GateDefn.sml @@ -41,4 +41,80 @@ struct type gate = t + fun getGateArgs (g: gate) = case g of + PauliY i => [i] + | PauliZ i => [i] + | Hadamard i => [i] + | SqrtX i => [i] + | Sxdg i => [i] + | S i => [i] + | Sdg i => [i] + | X i => [i] + | T i => [i] + | Tdg i => [i] + | CX {control = i, target = j} => [i, j] + | CZ {control = i, target = j} => [i, j] + | CCX {control1 = i, control2 = j, target = k} => [i, j, k] + | Phase {target = i, ...} => [i] + | CPhase {control = i, target = j, ...} => [i, j] + | FSim {left = i, right = j, ...} => [i, j] + | RZ {target = i, ...} => [i] + | RY {target = i, ...} => [i] + | RX {target = i, ...} => [i] + | Swap {target1 = i, target2 = j} => [i, j] + | CSwap {control = i, target1 = j, target2 = k} => [i, j, k] + | U {target = i, ...} => [i] + | Other {args = args, ...} => Seq.toList args + + fun toString (g: gate) (qi: qubit_idx -> string) = case g of + PauliY i => "y " ^ qi i + | PauliZ i => "z " ^ qi i + | Hadamard i => "h " ^ qi i + | T i => "t " ^ qi i + | Tdg i => "tdg " ^ qi i + | SqrtX i => "sx " ^ qi i + | Sxdg i => "sxdg " ^ qi i + | S i => "s " ^ qi i + | Sdg i => "sdg " ^ qi i + | X i => "x " ^ qi i + | CX {control, target} => "cx " ^ qi control ^ ", " ^ qi target + | CZ {control, target} => "cz " ^ qi control ^ ", " ^ qi target + | CCX {control1, control2, target} => + "ccx " ^ qi control1 ^ ", " ^ qi control2 ^ ", " ^ qi target + | Phase {target, rot} => + "phase(" ^ Real.toString rot ^ ") " ^ qi target + | CPhase {control, target, rot} => + "cphase(" ^ Real.toString rot ^ ") " ^ qi control ^ ", " ^ qi target + | FSim {left, right, theta, phi} => + "fsim(" ^ Real.toString theta ^ ", " ^ Real.toString phi ^ ") " + ^ qi left ^ ", " ^ qi right + | RZ {rot, target} => + "rz(" ^ Real.toString rot ^ ") " ^ qi target + | RY {rot, target} => + "ry(" ^ Real.toString rot ^ ") " ^ qi target + | RX {rot, target} => + "rx(" ^ Real.toString rot ^ ") " ^ qi target + | CSwap {control, target1, target2} => + "cswap " ^ qi control ^ ", " ^ qi target1 ^ ", " ^ qi target2 + | Swap {target1, target2} => + "swap " ^ qi target1 ^ ", " ^ qi target2 + | U {target, theta, phi, lambda} => + "u(" ^ Real.toString theta ^ ", " ^ Real.toString phi ^ ", " ^ Real.toString lambda ^ ") " ^ qi target + | Other {name, params, args} => + let val pstr = + if Seq.length params = 0 then + "()" + else + "(" ^ Seq.iterate (fn (acc, e) => acc ^ ", " ^ Real.toString e) + (Real.toString (Seq.nth params 0)) (Seq.drop params 1) ^ ")" + val front = name ^ pstr + val args = + if Seq.length args = 0 then + "" + else + Seq.iterate (fn (acc, i) => acc ^ ", " ^ qi i) + (qi (Seq.nth args 0)) (Seq.drop args 1) + in + front ^ " " ^ args + end end diff --git a/feynsum-sml/src/common/GateScheduler.sml b/feynsum-sml/src/common/GateScheduler.sml deleted file mode 100644 index 9009653..0000000 --- a/feynsum-sml/src/common/GateScheduler.sml +++ /dev/null @@ -1,16 +0,0 @@ -structure GateScheduler = -struct - - type qubit_idx = int - type gate_idx = int - - type args = - { numQubits: int - , numGates: int - , gateTouches: gate_idx -> qubit_idx Seq.t - , gateIsBranching: gate_idx -> bool - } - - type t = args -> (unit -> gate_idx Seq.t) - -end diff --git a/feynsum-sml/src/common/GateSchedulerGreedyBranching.sml b/feynsum-sml/src/common/GateSchedulerGreedyBranching.sml deleted file mode 100644 index c23c303..0000000 --- a/feynsum-sml/src/common/GateSchedulerGreedyBranching.sml +++ /dev/null @@ -1,157 +0,0 @@ -structure GateSchedulerGreedyBranching: -sig - val scheduler: GateScheduler.t -end = -struct - - type qubit_idx = int - type gate_idx = int - - - datatype sched = - S of - { numQubits: int - , numGates: int - , gateTouches: gate_idx -> qubit_idx Seq.t - , gateIsBranching: gate_idx -> bool - (* each qubit keeps track of which gate is next *) - , frontier: gate_idx array - } - - - type t = sched - - - fun contains x xs = - Util.exists (0, Seq.length xs) (fn i => Seq.nth xs i = x) - - - fun nextTouch (xxx as {numGates, gateTouches}) qubit gidx = - if gidx >= numGates then numGates - else if contains qubit (gateTouches gidx) then gidx - else nextTouch xxx qubit (gidx + 1) - - - fun new {numQubits, numGates, gateTouches, gateIsBranching} = - S { numQubits = numQubits - , numGates = numGates - , gateTouches = gateTouches - , gateIsBranching = gateIsBranching - , frontier = SeqBasis.tabulate 100 (0, numQubits) (fn i => - nextTouch {numGates = numGates, gateTouches = gateTouches} i 0) - } - - - (* It's safe to visit a gate G if, for all qubits the gate touches, the - * qubit's next gate is G. - *) - fun okayToVisit - (S {numQubits, numGates, gateTouches, gateIsBranching, frontier}) gidx = - if gidx >= numGates then - false - else - let - val touches = gateTouches gidx - in - Util.all (0, Seq.length touches) (fn i => - Array.sub (frontier, Seq.nth touches i) = gidx) - end - - - fun tryVisit - (sched as S {numQubits, numGates, gateTouches, gateIsBranching, frontier}) - gidx = - if not (okayToVisit sched gidx) then - false - else - let - (* val _ = print - ("GateScheduler.tryVisit " ^ Int.toString gidx ^ " numGates " - ^ Int.toString numGates ^ " frontier " - ^ Seq.toString Int.toString (ArraySlice.full frontier) ^ "\n") *) - val touches = gateTouches gidx - in - ( Util.for (0, Seq.length touches) (fn i => - let - val qi = Seq.nth touches i - val next = - nextTouch {numGates = numGates, gateTouches = gateTouches} qi - (gidx + 1) - in - Array.update (frontier, qi, next) - end) - - ; true - ) - end - - - fun visitBranching - (sched as S {numQubits, numGates, gateTouches, gateIsBranching, frontier}) = - let - val possibles = - Seq.filter - (fn gi => - gi < numGates andalso gateIsBranching gi - andalso okayToVisit sched gi) (ArraySlice.full frontier) - in - if Seq.length possibles = 0 then - NONE - else - let - val gidx = Seq.nth possibles 0 - in - if tryVisit sched gidx then () - else raise Fail "GateSchedulerGreedyBranching.visitBranching: error"; - - SOME gidx - end - end - - - fun visitNonBranching - (sched as S {numQubits, numGates, gateTouches, gateIsBranching, frontier}) = - let - (* val _ = print - ("visitNonBranching frontier = " - ^ Seq.toString Int.toString (ArraySlice.full frontier) ^ "\n") *) - val possibles = - Seq.filter - (fn gi => - gi < numGates andalso not (gateIsBranching gi) - andalso okayToVisit sched gi) (ArraySlice.full frontier) - in - if Seq.length possibles = 0 then - NONE - else - let - val gidx = Seq.nth possibles 0 - in - (* print - ("visitNonBranching trying to visit " ^ Int.toString gidx ^ "\n"); *) - - if tryVisit sched gidx then - () - else - raise Fail "GateSchedulerGreedyBranching.visitNonBranching: error"; - - SOME gidx - end - end - - - fun pickNext sched = - case visitBranching sched of - SOME gidx => Seq.singleton gidx - | NONE => - case visitNonBranching sched of - SOME gidx => Seq.singleton gidx - | NONE => Seq.empty () - - - fun scheduler args = - let val sched = new args - in fn () => pickNext sched - end - -end diff --git a/feynsum-sml/src/common/GateSchedulerGreedyFinishQubit.sml b/feynsum-sml/src/common/GateSchedulerGreedyFinishQubit.sml deleted file mode 100644 index 7f1b579..0000000 --- a/feynsum-sml/src/common/GateSchedulerGreedyFinishQubit.sml +++ /dev/null @@ -1,178 +0,0 @@ -functor GateSchedulerGreedyFinishQubit - (val maxBranchingStride: int val disableFusion: bool): -sig - val scheduler: GateScheduler.t -end = -struct - - type qubit_idx = int - type gate_idx = int - - - datatype sched = - S of - { numQubits: int - , numGates: int - , gateTouches: gate_idx -> qubit_idx Seq.t - , gateIsBranching: gate_idx -> bool - (* each qubit keeps track of which gate is next *) - , frontier: gate_idx array - } - - - type t = sched - - - fun contains x xs = - Util.exists (0, Seq.length xs) (fn i => Seq.nth xs i = x) - - - fun nextTouch (xxx as {numGates, gateTouches}) qubit gidx = - if gidx >= numGates then numGates - else if contains qubit (gateTouches gidx) then gidx - else nextTouch xxx qubit (gidx + 1) - - - fun new {numQubits, numGates, gateTouches, gateIsBranching} = - S { numQubits = numQubits - , numGates = numGates - , gateTouches = gateTouches - , gateIsBranching = gateIsBranching - , frontier = SeqBasis.tabulate 100 (0, numQubits) (fn i => - nextTouch {numGates = numGates, gateTouches = gateTouches} i 0) - } - - - (* It's safe to visit a gate G if, for all qubits the gate touches, the - * qubit's next gate is G. - *) - fun okayToVisit - (S {numQubits, numGates, gateTouches, gateIsBranching, frontier}) gidx = - if gidx >= numGates then - false - else - let - val touches = gateTouches gidx - in - Util.all (0, Seq.length touches) (fn i => - Array.sub (frontier, Seq.nth touches i) = gidx) - end - - - fun tryVisit - (sched as S {numQubits, numGates, gateTouches, gateIsBranching, frontier}) - gidx = - if not (okayToVisit sched gidx) then - false - else - let - (* val _ = print - ("GateScheduler.tryVisit " ^ Int.toString gidx ^ " numGates " - ^ Int.toString numGates ^ " frontier " - ^ Seq.toString Int.toString (ArraySlice.full frontier) ^ "\n") *) - val touches = gateTouches gidx - in - ( Util.for (0, Seq.length touches) (fn i => - let - val qi = Seq.nth touches i - val next = - nextTouch {numGates = numGates, gateTouches = gateTouches} qi - (gidx + 1) - in - Array.update (frontier, qi, next) - end) - - ; true - ) - end - - - fun peekNext - (sched as S {numQubits, numGates, gateTouches, gateIsBranching, frontier}) = - let - fun makeProgressOnQubit qi = - let - val desiredGate = Array.sub (frontier, qi) - in - if okayToVisit sched desiredGate then - desiredGate - else - (* Find which qubit is not ready to do this gate yet, and make - * progress on that qubit (which might recursively need to - * make progress on a different qubit, etc.) - *) - let - val touches = gateTouches desiredGate - val dependency = - FindFirst.findFirstSerial (0, Seq.length touches) (fn i => - let val qj = Seq.nth touches i - in Array.sub (frontier, qj) < desiredGate - end) - in - case dependency of - SOME i => makeProgressOnQubit (Seq.nth touches i) - | NONE => - raise Fail - "GateSchedulerGreedyFinishQubit.peekNext.makeProgressOnQubit: error" - end - end - - val unfinishedQubit = FindFirst.findFirstSerial (0, numQubits) (fn qi => - Array.sub (frontier, qi) < numGates) - in - case unfinishedQubit of - NONE => NONE - | SOME qi => SOME (makeProgressOnQubit qi) - end - - - fun pickNextNoFusion sched = - case peekNext sched of - NONE => Seq.empty () - | SOME gidx => - ( if tryVisit sched gidx then - () - else - raise Fail - "GateSchedulerGreedyFinishQubit.pickNextNoFusion: visit failed (should be impossible)" - ; Seq.singleton gidx - ) - - - fun pickNext (sched as S {gateIsBranching, ...}) = - if disableFusion then - pickNextNoFusion sched - else - let - fun loop acc numBranchingSoFar = - if numBranchingSoFar >= maxBranchingStride then - acc - else - case peekNext sched of - NONE => acc - | SOME gidx => - let - val numBranchingSoFar' = - numBranchingSoFar + (if gateIsBranching gidx then 1 else 0) - in - if tryVisit sched gidx then - () - else - raise Fail - "GateSchedulerGreedyFinishQubit.pickNext.loop: visit failed (should be impossible)"; - - loop (gidx :: acc) numBranchingSoFar' - end - - val acc = loop [] 0 - in - Seq.fromRevList acc - end - - - fun scheduler args = - let val sched = new args - in fn () => pickNext sched - end - -end diff --git a/feynsum-sml/src/common/GateSchedulerGreedyNonBranching.sml b/feynsum-sml/src/common/GateSchedulerGreedyNonBranching.sml deleted file mode 100644 index 6b2e80a..0000000 --- a/feynsum-sml/src/common/GateSchedulerGreedyNonBranching.sml +++ /dev/null @@ -1,228 +0,0 @@ -functor GateSchedulerGreedyNonBranching - (val maxBranchingStride: int val disableFusion: bool): -sig - type sched - type t = sched - - type qubit_idx = int - type gate_idx = int - - val new: GateScheduler.args -> sched - - val tryVisit: sched -> gate_idx -> bool - val visitMaximalNonBranchingRun: sched -> gate_idx Seq.t - val visitBranching: sched -> gate_idx option - val visitNonBranching: sched -> gate_idx option - val pickNext: sched -> gate_idx Seq.t - - val scheduler: GateScheduler.t -end = -struct - - type qubit_idx = int - type gate_idx = int - - - datatype sched = - S of - { numQubits: int - , numGates: int - , gateTouches: gate_idx -> qubit_idx Seq.t - , gateIsBranching: gate_idx -> bool - (* each qubit keeps track of which gate is next *) - , frontier: gate_idx array - } - - - type t = sched - - - fun contains x xs = - Util.exists (0, Seq.length xs) (fn i => Seq.nth xs i = x) - - - fun nextTouch (xxx as {numGates, gateTouches}) qubit gidx = - if gidx >= numGates then numGates - else if contains qubit (gateTouches gidx) then gidx - else nextTouch xxx qubit (gidx + 1) - - - fun new {numQubits, numGates, gateTouches, gateIsBranching} = - S { numQubits = numQubits - , numGates = numGates - , gateTouches = gateTouches - , gateIsBranching = gateIsBranching - , frontier = SeqBasis.tabulate 100 (0, numQubits) (fn i => - nextTouch {numGates = numGates, gateTouches = gateTouches} i 0) - } - - - fun okayToVisit - (S {numQubits, numGates, gateTouches, gateIsBranching, frontier}) gidx = - if gidx >= numGates then - false - else - let - val touches = gateTouches gidx - in - Util.all (0, Seq.length touches) (fn i => - Array.sub (frontier, Seq.nth touches i) = gidx) - end - - - (* It's safe to visit a gate G if, for all qubits the gate touches, the - * qubit's next gate is G. - *) - fun tryVisit - (sched as S {numQubits, numGates, gateTouches, gateIsBranching, frontier}) - gidx = - if not (okayToVisit sched gidx) then - false - else - let - (* val _ = print - ("GateScheduler.tryVisit " ^ Int.toString gidx ^ " numGates " - ^ Int.toString numGates ^ " frontier " - ^ Seq.toString Int.toString (ArraySlice.full frontier) ^ "\n") *) - val touches = gateTouches gidx - in - ( Util.for (0, Seq.length touches) (fn i => - let - val qi = Seq.nth touches i - val next = - nextTouch {numGates = numGates, gateTouches = gateTouches} qi - (gidx + 1) - in - Array.update (frontier, qi, next) - end) - - ; true - ) - end - - - fun visitMaximalNonBranchingRun - (sched as S {numQubits, numGates, gateTouches, gateIsBranching, frontier}) = - let - (* visit as many non-branching gates on qubit qi as possible *) - fun loopQubit acc qi = - let - val nextg = Array.sub (frontier, qi) - in - if - nextg >= numGates orelse gateIsBranching nextg - orelse not (tryVisit sched nextg) - then acc - else loopQubit (nextg :: acc) qi - end - - fun loop acc = - let - val selection = Util.loop (0, numQubits) [] (fn (acc, qi) => - loopQubit acc qi) - in - case selection of - [] => acc - | _ => loop (selection @ acc) - end - in - Seq.fromRevList (loop []) - end - - - fun visitBranching - (sched as S {numQubits, numGates, gateTouches, gateIsBranching, frontier}) = - let - val possibles = - Seq.filter - (fn gi => - gi < numGates andalso gateIsBranching gi - andalso okayToVisit sched gi) (ArraySlice.full frontier) - in - if Seq.length possibles = 0 then - NONE - else - let - val gidx = Seq.nth possibles 0 - in - if tryVisit sched gidx then - () - else - raise Fail "GateSchedulerGreedyNonBranching.visitBranching: error"; - - SOME gidx - end - end - - - fun visitNonBranching - (sched as S {numQubits, numGates, gateTouches, gateIsBranching, frontier}) = - let - (* val _ = print - ("visitNonBranching frontier = " - ^ Seq.toString Int.toString (ArraySlice.full frontier) ^ "\n") *) - val possibles = - Seq.filter - (fn gi => - gi < numGates andalso not (gateIsBranching gi) - andalso okayToVisit sched gi) (ArraySlice.full frontier) - in - if Seq.length possibles = 0 then - NONE - else - let - val gidx = Seq.nth possibles 0 - in - (* print - ("visitNonBranching trying to visit " ^ Int.toString gidx ^ "\n"); *) - - if tryVisit sched gidx then - () - else - raise Fail - "GateSchedulerGreedyNonBranching.visitNonBranching: error"; - - SOME gidx - end - end - - - fun pickNextNoFusion sched = - case visitNonBranching sched of - SOME gidx => Seq.singleton gidx - | NONE => - case visitBranching sched of - SOME gidx => Seq.singleton gidx - | NONE => Seq.empty () - - - fun pickNext sched = - if disableFusion then - pickNextNoFusion sched - else - let - fun loop acc numBranchingSoFar = - if numBranchingSoFar >= maxBranchingStride then - acc - else - let - val nb = visitMaximalNonBranchingRun sched - in - case visitBranching sched of - NONE => nb :: acc - | SOME gidx => - loop (Seq.singleton gidx :: nb :: acc) (numBranchingSoFar + 1) - end - - val acc = loop [] 0 - in - Seq.flatten (Seq.fromRevList acc) - end - - - fun scheduler args = - let val sched = new args - in fn () => pickNext sched - end - -end diff --git a/feynsum-sml/src/common/GateSchedulerNaive.sml b/feynsum-sml/src/common/GateSchedulerNaive.sml deleted file mode 100644 index 917c476..0000000 --- a/feynsum-sml/src/common/GateSchedulerNaive.sml +++ /dev/null @@ -1,34 +0,0 @@ -(* A "naive" scheduler: execute in straightline order, i.e., as the gates are - * written in the input qasm. No fusion. - *) -structure GateSchedulerNaive: -sig - val scheduler: GateScheduler.t -end = -struct - - type qubit_idx = int - type gate_idx = int - - datatype sched = S of {numGates: int, next: int ref} - - type t = sched - - fun new - { numQubits: int - , numGates: int - , gateTouches: gate_idx -> qubit_idx Seq.t - , gateIsBranching: gate_idx -> bool - } = - S {numGates = numGates, next = ref 0} - - fun pickNext (S {numGates, next, ...}) = - if !next >= numGates then Seq.empty () - else let val gi = !next in next := gi + 1; Seq.singleton gi end - - fun scheduler args = - let val sched = new args - in fn () => pickNext sched - end - -end diff --git a/feynsum-sml/src/common/GateSchedulerOrder.sml b/feynsum-sml/src/common/GateSchedulerOrder.sml deleted file mode 100644 index 8235706..0000000 --- a/feynsum-sml/src/common/GateSchedulerOrder.sml +++ /dev/null @@ -1,31 +0,0 @@ -structure GateSchedulerOrder: -sig - type gate_idx = int - val mkScheduler: gate_idx Seq.t Seq.t -> GateScheduler.t -end = -struct - - type qubit_idx = int - type gate_idx = int - - fun mkScheduler (order: gate_idx Seq.t Seq.t) (args: GateScheduler.args) = - let val i = ref 0 - val N = Seq.length order - in - print ("Schedule cost: " ^ Real.toString (DepGraphUtil.scheduleCost order (#gateIsBranching args)) ^ "\n"); - fn () => if !i >= N then - Seq.empty () - else - (i := !i + 1; Seq.nth order (!i - 1)) - end - - (*fun mkSchedulerFusion (order: gate_idx Seq.t Seq.t) args = - let val i = ref 0 - val N = Seq.length order - in - fn () => if !i >= N then - Seq.empty () - else - let val gi = !i in i := gi + 1; Seq.singleton gi end - end*) -end diff --git a/feynsum-sml/src/common/HYBRID_STATE.sml b/feynsum-sml/src/common/HYBRID_STATE.sml new file mode 100644 index 0000000..9d1f3f9 --- /dev/null +++ b/feynsum-sml/src/common/HYBRID_STATE.sml @@ -0,0 +1,19 @@ +signature HYBRID_STATE = +sig + structure B: BASIS_IDX + structure C: COMPLEX + structure SST: SPARSE_STATE_TABLE + structure DS: DENSE_STATE + sharing B = SST.B = DS.B + sharing C = SST.C = DS.C + + (*type t + type state = t*) + datatype state = + Sparse of SST.t + | Dense of DS.t + | DenseKnownNonZeroSize of DS.t * int + + type t = state + +end diff --git a/feynsum-sml/src/common/HybridState.sml b/feynsum-sml/src/common/HybridState.sml new file mode 100644 index 0000000..17cdcc2 --- /dev/null +++ b/feynsum-sml/src/common/HybridState.sml @@ -0,0 +1,20 @@ +functor HybridState + (structure B: BASIS_IDX + structure C: COMPLEX + structure SST: SPARSE_STATE_TABLE + structure DS: DENSE_STATE + sharing B = SST.B = DS.B + sharing C = SST.C = DS.C): HYBRID_STATE = +struct + structure B = B + structure C = C + structure SST = SST + structure DS = DS + + datatype state = + Sparse of SST.t + | Dense of DS.t + | DenseKnownNonZeroSize of DS.t * int + + type t = state +end diff --git a/feynsum-sml/src/common/DepGraphSchedulerGreedyFinishQubit.sml b/feynsum-sml/src/common/schedulers/FinishQubit.sml similarity index 75% rename from feynsum-sml/src/common/DepGraphSchedulerGreedyFinishQubit.sml rename to feynsum-sml/src/common/schedulers/FinishQubit.sml index 2282433..a4ae826 100644 --- a/feynsum-sml/src/common/DepGraphSchedulerGreedyFinishQubit.sml +++ b/feynsum-sml/src/common/schedulers/FinishQubit.sml @@ -1,12 +1,12 @@ -functor DepGraphSchedulerGreedyFinishQubit +functor FinishQubitScheduler (val maxBranchingStride: int val disableFusion: bool): sig - val scheduler: DepGraphScheduler.t - val scheduler2: DepGraphScheduler.t - val scheduler3: DepGraphScheduler.t - val scheduler4: DepGraphScheduler.t - val scheduler5: DepGraphScheduler.t - val schedulerRandom: int -> DepGraphScheduler.t + val scheduler: Scheduler.t + val scheduler2: Scheduler.t + val scheduler3: Scheduler.t + val scheduler4: Scheduler.t + val scheduler5: Scheduler.t + val schedulerRandom: int -> Scheduler.t end = struct @@ -15,19 +15,19 @@ struct fun DFS ((new, old) : (int list * int list)) = new @ old fun BFS ((new, old) : (int list * int list)) = old @ new - fun revTopologicalSort (dg: DepGraph.t) (push: (int list * int list) -> int list) = - let val N = Seq.length (#gates dg) - val ind = Array.tabulate (N, Seq.nth (#indegree dg)) + fun revTopologicalSort (dfg: DataFlowGraph.t) (push: (int list * int list) -> int list) = + let val N = Seq.length (#gates dfg) + val ind = Array.tabulate (N, Seq.length o Seq.nth (#preds dfg)) fun decInd i = let val d = Array.sub (ind, i) in Array.update (ind, i, d - 1); d - 1 end val queue = ref nil (*val push = case tr of BFS => (fn xs => queue := (!queue) @ xs) | DFS => (fn xs => queue := xs @ (!queue))*) fun pop () = case !queue of nil => NONE | x :: xs => (queue := xs; SOME x) - val _ = queue := push (List.filter (fn i => Seq.nth (#indegree dg) i = 0) (List.tabulate (N, fn i => i)), !queue) + val _ = queue := push (List.filter (fn i => Seq.length (Seq.nth (#preds dfg) i) = 0) (List.tabulate (N, fn i => i)), !queue) fun loop L = case pop () of NONE => L | SOME n => - (let val ndeps = Seq.nth (#deps dg) n in + (let val ndeps = Seq.nth (#succs dfg) n in queue := push (List.filter (fn m => decInd m = 0) (List.tabulate (Seq.length ndeps, Seq.nth ndeps)), !queue) end; loop (n :: L)) @@ -35,21 +35,21 @@ struct loop nil end - fun topologicalSort (dg: DepGraph.t) (push: (int list * int list) -> int list) = - List.rev (revTopologicalSort dg push) + fun topologicalSort (dfg: DataFlowGraph.t) (push: (int list * int list) -> int list) = + List.rev (revTopologicalSort dfg push) val gateDepths: int array option ref = ref NONE - fun computeGateDepths (dg: DepGraph.t) = - let val N = Seq.length (#gates dg) + fun computeGateDepths (dfg: DataFlowGraph.t) = + let val N = Seq.length (#gates dfg) val depths = Array.array (N, ~1) fun gdep i = - Array.update (depths, i, 1 + Seq.reduce Int.max ~1 (Seq.map (fn j => Array.sub (depths, j)) (Seq.nth (#deps dg) i))) + Array.update (depths, i, 1 + Seq.reduce Int.max ~1 (Seq.map (fn j => Array.sub (depths, j)) (Seq.nth (#succs dfg) i))) (*case Array.sub (depths, i) of - ~1 => 1 + Seq.reduce Int.min ~1 (Seq.map gateDepth (Seq.nth (#deps dg) i)) + ~1 => 1 + Seq.reduce Int.min ~1 (Seq.map gateDepth (Seq.nth (#deps dfg) i)) | d => d*) in - List.foldl (fn (i, ()) => gdep i) () (revTopologicalSort dg DFS); depths + List.foldl (fn (i, ()) => gdep i) () (revTopologicalSort dfg DFS); depths end fun sortList (lt: 'a * 'a -> bool) (xs: 'a list) = @@ -62,14 +62,14 @@ struct end (* Choose in reverse topological order, sorted by easiest qubit to finish *) - fun scheduler3 dg = - let val dgt = DepGraphUtil.transpose dg - val depths = computeGateDepths dg - val gib = DepGraphUtil.gateIsBranching dg + fun scheduler3 dfg = + let val dfgt = DataFlowGraphUtil.transpose dfg + val depths = computeGateDepths dfg + val gib = DataFlowGraphUtil.gateIsBranching dfg fun lt (a, b) = Array.sub (depths, a) < Array.sub (depths, b) orelse (Array.sub (depths, a) = Array.sub (depths, b) andalso not (gib a) andalso gib b) fun push (new, old) = DFS (sortList lt new, old) - val xs = revTopologicalSort dgt push - val N = Seq.length (#gates dg) + val xs = revTopologicalSort dfgt push + val N = Seq.length (#gates dfg) val ord = Array.array (N, ~1) fun writeOrd i xs = case xs of nil => () | x :: xs' => (Array.update (ord, x, i); writeOrd (i + 1) xs') val _ = writeOrd 0 xs @@ -87,72 +87,72 @@ struct end end - fun gateDepth i dg = + fun gateDepth i dfg = case !gateDepths of - NONE => let val gd = computeGateDepths dg in + NONE => let val gd = computeGateDepths dfg in print "recompouting gate depths"; gateDepths := SOME gd; Array.sub (gd, i) end | SOME gd => Array.sub (gd, i) - fun pickLeastDepth best_idx best_depth i gates dg = + fun pickLeastDepth best_idx best_depth i gates dfg = if i = Seq.length gates then best_idx else let val cur_idx = Seq.nth gates i - val cur_depth = gateDepth cur_idx dg in + val cur_depth = gateDepth cur_idx dfg in if cur_depth < best_depth then - pickLeastDepth cur_idx cur_depth (i + 1) gates dg + pickLeastDepth cur_idx cur_depth (i + 1) gates dfg else - pickLeastDepth best_idx best_depth (i + 1) gates dg + pickLeastDepth best_idx best_depth (i + 1) gates dfg end - fun pickGreatestDepth best_idx best_depth i gates dg = + fun pickGreatestDepth best_idx best_depth i gates dfg = if i = Seq.length gates then best_idx else let val cur_idx = Seq.nth gates i - val cur_depth = gateDepth cur_idx dg in + val cur_depth = gateDepth cur_idx dfg in if cur_depth > best_depth then - pickGreatestDepth cur_idx cur_depth (i + 1) gates dg + pickGreatestDepth cur_idx cur_depth (i + 1) gates dfg else - pickGreatestDepth best_idx best_depth (i + 1) gates dg + pickGreatestDepth best_idx best_depth (i + 1) gates dfg end (* From a frontier, select which gate to apply next *) - fun scheduler dg = - (gateDepths := SOME (computeGateDepths dg); + fun scheduler dfg = + (gateDepths := SOME (computeGateDepths dfg); fn gates => let val g0 = Seq.nth gates 0 in - pickLeastDepth g0 (gateDepth g0 dg) 1 gates dg + pickLeastDepth g0 (gateDepth g0 dfg) 1 gates dfg end) (* Select gate with greatest number of descendants *) - fun scheduler4 dg = - (gateDepths := SOME (computeGateDepths dg); + fun scheduler4 dfg = + (gateDepths := SOME (computeGateDepths dfg); fn gates => let val g0 = Seq.nth gates 0 in - pickGreatestDepth g0 (gateDepth g0 dg) 1 gates dg + pickGreatestDepth g0 (gateDepth g0 dfg) 1 gates dfg end) structure G = Gate (structure B = BasisIdxUnlimited structure C = Complex64) (* Hybrid of scheduler2 (avoid branching on unbranched qubits) and also scheduler3 (choose in reverse topological order, sorted by easiest qubit to finish) *) - fun scheduler5 (dg: DepGraph.t) = - let val gates = Seq.map G.fromGateDefn (#gates dg) + fun scheduler5 (dfg: DataFlowGraph.t) = + let val gates = Seq.map G.fromGateDefn (#gates dfg) fun touches i = #touches (Seq.nth gates i) fun branches i = case #action (Seq.nth gates i) of G.NonBranching _ => 0 | G.MaybeBranching _ => 1 | G.Branching _ => 2 - val dgt = DepGraphUtil.transpose dg - val depths = computeGateDepths dg + val dfgt = DataFlowGraphUtil.transpose dfg + val depths = computeGateDepths dfg fun lt (a, b) = Array.sub (depths, a) < Array.sub (depths, b) orelse (Array.sub (depths, a) = Array.sub (depths, b) andalso branches a < branches b) fun push (new, old) = DFS (sortList lt new, old) - val xs = revTopologicalSort dgt push - val N = Seq.length (#gates dg) + val xs = revTopologicalSort dfgt push + val N = Seq.length (#gates dfg) val ord = Array.array (N, ~1) fun writeOrd i xs = case xs of nil => () | x :: xs' => (Array.update (ord, x, i); writeOrd (i + 1) xs') val _ = writeOrd 0 xs - val touched = Array.array (#numQubits dg, false) + val touched = Array.array (#numQubits dfg, false) fun touch i = Array.update (touched, i, true) fun touchAll gidx = let val ts = touches gidx in List.tabulate (Seq.length ts, fn i => touch (Seq.nth ts i)); () end fun newTouches i = @@ -182,9 +182,9 @@ struct end (* Avoids branching on unbranched qubits *) - fun scheduler2 (dg: DepGraph.t) = - let val touched = Array.array (#numQubits dg, false) - val gates = Seq.map G.fromGateDefn (#gates dg) + fun scheduler2 (dfg: DataFlowGraph.t) = + let val touched = Array.array (#numQubits dfg, false) + val gates = Seq.map G.fromGateDefn (#gates dfg) fun touches i = #touches (Seq.nth gates i) fun branches i = case #action (Seq.nth gates i) of G.NonBranching _ => 0 | G.MaybeBranching _ => 1 | G.Branching _ => 2 fun touch i = Array.update (touched, i, true) @@ -216,7 +216,7 @@ struct val seed = Random.rand (50, 14125) - fun schedulerRandom seedNum dg = + fun schedulerRandom seedNum dfg = (*let val seed = Random.rand (seedNum, seedNum * seedNum) in*) fn gates => let val r = Random.randRange (0, Seq.length gates - 1) seed in ((*print ("Randomly chose " ^ Int.toString r ^ " from range [0, " ^ Int.toString (Seq.length gates) ^ ")\n");*) diff --git a/feynsum-sml/src/common/DepGraphSchedulerGreedyBranching.sml b/feynsum-sml/src/common/schedulers/GreedyBranching.sml similarity index 75% rename from feynsum-sml/src/common/DepGraphSchedulerGreedyBranching.sml rename to feynsum-sml/src/common/schedulers/GreedyBranching.sml index 0a6d360..c8386b5 100644 --- a/feynsum-sml/src/common/DepGraphSchedulerGreedyBranching.sml +++ b/feynsum-sml/src/common/schedulers/GreedyBranching.sml @@ -1,6 +1,6 @@ -structure DepGraphSchedulerGreedyBranching: +structure GreedyBranchingScheduler: sig - val scheduler: DepGraphScheduler.t + val scheduler: Scheduler.t end = struct @@ -17,7 +17,7 @@ struct (* From a frontier, select which gate to apply next *) fun scheduler dg = - let val branching = DepGraphUtil.gateIsBranching dg in + let val branching = DataFlowGraphUtil.gateIsBranching dg in fn gates => pickBranching 0 branching gates end end diff --git a/feynsum-sml/src/common/DepGraphSchedulerGreedyNonBranching.sml b/feynsum-sml/src/common/schedulers/GreedyNonBranching.sml similarity index 75% rename from feynsum-sml/src/common/DepGraphSchedulerGreedyNonBranching.sml rename to feynsum-sml/src/common/schedulers/GreedyNonBranching.sml index 3de8f1f..bc2331b 100644 --- a/feynsum-sml/src/common/DepGraphSchedulerGreedyNonBranching.sml +++ b/feynsum-sml/src/common/schedulers/GreedyNonBranching.sml @@ -1,14 +1,14 @@ -functor DepGraphSchedulerGreedyNonBranching +functor GreedyNonBranchingScheduler (val maxBranchingStride: int val disableFusion: bool): sig - val scheduler: DepGraphScheduler.t + val scheduler: Scheduler.t end = struct type gate_idx = int type args = - { depGraph: DepGraph.t + { depGraph: DataFlowGraph.t , gateIsBranching: gate_idx -> bool } @@ -23,7 +23,7 @@ struct (* From a frontier, select which gate to apply next *) fun scheduler dg = - let val branching = DepGraphUtil.gateIsBranching dg in + let val branching = DataFlowGraphUtil.gateIsBranching dg in fn gates => pickNonBranching 0 branching gates end end diff --git a/feynsum-sml/src/common/sources.mlb b/feynsum-sml/src/common/sources.mlb index cd19c08..d4af809 100644 --- a/feynsum-sml/src/common/sources.mlb +++ b/feynsum-sml/src/common/sources.mlb @@ -60,22 +60,22 @@ local ExpandState.sml - GateScheduler.sml - GateSchedulerNaive.sml - GateSchedulerGreedyNonBranching.sml - GateSchedulerGreedyBranching.sml - GateSchedulerGreedyFinishQubit.sml + (*GateScheduler.sml*) + (*GateSchedulerNaive.sml*) + (*GateSchedulerGreedyNonBranching.sml*) + (*GateSchedulerGreedyBranching.sml*) + (*GateSchedulerGreedyFinishQubit.sml*) Fingerprint.sml - DepGraph.sml - DepGraphUtil.sml - DepGraphScheduler.sml - DepGraphSchedulerGreedyBranching.sml - DepGraphSchedulerGreedyNonBranching.sml - DepGraphSchedulerGreedyFinishQubit.sml - DepGraphDynScheduler.sml - GateSchedulerOrder.sml + DataFlowGraph.sml + DataFlowGraphUtil.sml + DataFlowGraphScheduler.sml + schedulers/GreedyBranching.sml + schedulers/GreedyNonBranching.sml + schedulers/FinishQubit.sml + DataFlowGraphDynScheduler.sml + (*GateSchedulerOrder.sml*) in structure HashTable structure ApplyUntilFailure @@ -124,25 +124,25 @@ in functor ExpandState - structure GateScheduler - structure GateSchedulerNaive - functor GateSchedulerGreedyNonBranching - structure GateSchedulerGreedyBranching - functor GateSchedulerGreedyFinishQubit + (*structure GateScheduler*) + (*structure GateSchedulerNaive*) + (*functor GateSchedulerGreedyNonBranching*) + (*structure GateSchedulerGreedyBranching*) + (*functor GateSchedulerGreedyFinishQubit*) functor RedBlackMapFn functor RedBlackSetFn functor Fingerprint - structure DepGraph - structure DepGraphUtil - structure DepGraphScheduler - structure DepGraphSchedulerGreedyBranching - functor DepGraphSchedulerGreedyNonBranching - functor DepGraphSchedulerGreedyFinishQubit - structure GateSchedulerOrder - signature DEP_GRAPH_DYN_SCHEDULER + structure DataFlowGraph + structure DataFlowGraphUtil + structure Scheduler + structure GreedyBranchingScheduler + functor GreedyNonBranchingScheduler + functor FinishQubitScheduler + (*structure GateSchedulerOrder*) + signature DATA_FLOW_GRAPH_DYN_SCHEDULER functor DynSchedFinishQubitWrapper functor DynSchedInterference functor DynSchedNaive diff --git a/feynsum-sml/src/main.sml b/feynsum-sml/src/main.sml index e9ea4dd..fbf0ef7 100644 --- a/feynsum-sml/src/main.sml +++ b/feynsum-sml/src/main.sml @@ -58,28 +58,24 @@ fun parseQasm () = val simpleCirc = SMLQasmSimpleCircuit.fromAst ast in - Circuit.fromSMLQasmSimpleCircuit simpleCirc + DataFlowGraph.fromQasm (Circuit.fromSMLQasmSimpleCircuit simpleCirc) end -val (circuit, depGraph) = (*(circuit, optDepGraph)*) +val dfg = case inputName of "" => Util.die ("missing: -input FILE.qasm") - | _ => if String.isSuffix ".qasm" inputName then - raise Fail ".qasm no longer supported, use .json dependency graph" - (*parseQasm (), NONE*) + parseQasm () + else if String.isSuffix ".json" inputName then + DataFlowGraph.fromJSONFile inputName else - let val dg = DepGraph.fromFile inputName - val circuit = {numQubits = #numQubits dg, gates = #gates dg} - in - (circuit, dg) - end + raise Fail "Unknown file suffix, use .qasm or .json" val _ = print ("-------------------------------\n") -val _ = print ("gates " ^ Int.toString (Circuit.numGates circuit) ^ "\n") -val _ = print ("qubits " ^ Int.toString (Circuit.numQubits circuit) ^ "\n") +val _ = print ("gates " ^ Int.toString (Seq.length (#gates dfg)) ^ "\n") +val _ = print ("qubits " ^ Int.toString (#numQubits dfg) ^ "\n") val showCircuit = CLA.parseFlag "show-circuit" val _ = print ("show-circuit? " ^ (if showCircuit then "yes" else "no") ^ "\n") @@ -89,7 +85,7 @@ val _ = else print ("=========================================================\n" - ^ Circuit.toString circuit + ^ DataFlowGraph.toString dfg ^ "=========================================================\n") (* ======================================================================== @@ -99,7 +95,7 @@ val _ = val disableFusion = CLA.parseFlag "scheduler-disable-fusion" val maxBranchingStride = CLA.parseInt "scheduler-max-branching-stride" 2 -structure GNB = +(*structure GNB = GateSchedulerGreedyNonBranching (val maxBranchingStride = maxBranchingStride val disableFusion = disableFusion) @@ -107,15 +103,15 @@ structure GNB = structure GFQ = GateSchedulerGreedyFinishQubit (val maxBranchingStride = maxBranchingStride - val disableFusion = disableFusion) + val disableFusion = disableFusion)*) structure DGNB = - DepGraphSchedulerGreedyNonBranching + GreedyNonBranchingScheduler (val maxBranchingStride = maxBranchingStride val disableFusion = disableFusion) structure DGFQ = - DepGraphSchedulerGreedyFinishQubit + FinishQubitScheduler (val maxBranchingStride = maxBranchingStride val disableFusion = disableFusion) @@ -139,75 +135,76 @@ fun print_sched_info () = type gate_idx = int type Schedule = gate_idx Seq.t -fun dep_graph_to_schedule (dg: DepGraph.t) (sched: DepGraphScheduler.t) = +(*fun dep_graph_to_schedule (dg: DataFlowGraph.t) (sched: Scheduler.t) = let val choose = sched dg - val st = DepGraphUtil.initState dg + val st = DataFlowGraphUtil.initState dg fun loadNext (acc: gate_idx list) = - let val frntr = DepGraphUtil.frontier st in + let val frntr = DataFlowGraphUtil.frontier st in if Seq.length frntr = 0 then Seq.rev (Seq.fromList acc) else (let val next = choose frntr in - DepGraphUtil.visit dg next st; + DataFlowGraphUtil.visit dg next st; loadNext (next :: acc) end) end in GateSchedulerOrder.mkScheduler (Seq.map Seq.singleton (loadNext nil)) end +*) val maxBranchingStride' = if disableFusion then 1 else maxBranchingStride (*fun greedybranching () = - case optDepGraph of + case optDataFlowGraph of NONE => GateSchedulerGreedyBranching.scheduler - | SOME dg => GateSchedulerOrder.mkScheduler (DepGraphUtil.scheduleWithOracle dg (gate_branching dg) (DepGraphSchedulerGreedyBranching.scheduler { depGraph = dg, gateIsBranching = gate_branching dg }) disableFusion maxBranchingStride') + | SOME dg => GateSchedulerOrder.mkScheduler (DataFlowGraphUtil.scheduleWithOracle dg (gate_branching dg) (DataFlowGraphSchedulerGreedyBranching.scheduler { depGraph = dg, gateIsBranching = gate_branching dg }) disableFusion maxBranchingStride') fun greedynonbranching () = - case optDepGraph of + case optDataFlowGraph of NONE => GNB.scheduler - | SOME dg => GateSchedulerOrder.mkScheduler (DepGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGNB.scheduler { depGraph = dg, gateIsBranching = gate_branching dg }) disableFusion maxBranchingStride') + | SOME dg => GateSchedulerOrder.mkScheduler (DataFlowGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGNB.scheduler { depGraph = dg, gateIsBranching = gate_branching dg }) disableFusion maxBranchingStride') fun greedyfinishqubit () = - case optDepGraph of + case optDataFlowGraph of NONE => GFQ.scheduler | SOME dg => - GateSchedulerOrder.mkScheduler (DepGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.scheduler { depGraph = dg, gateIsBranching = gate_branching dg }) disableFusion maxBranchingStride') + GateSchedulerOrder.mkScheduler (DataFlowGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.scheduler { depGraph = dg, gateIsBranching = gate_branching dg }) disableFusion maxBranchingStride') fun greedyfinishqubit2 () = - case optDepGraph of + case optDataFlowGraph of NONE => GFQ.scheduler | SOME dg => (*GateSchedulerOrder.mkScheduler (Seq.fromList (DGFQ.ordered { depGraph = dg, gateIsBranching = gate_branching dg }))*) (* dep_graph_to_schedule { depGraph = dg, gateIsBranching = gate_branching dg } DGFQ.scheduler2 *) - GateSchedulerOrder.mkScheduler (DepGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.scheduler2 { depGraph = dg, gateIsBranching = gate_branching dg }) disableFusion maxBranchingStride') + GateSchedulerOrder.mkScheduler (DataFlowGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.scheduler2 { depGraph = dg, gateIsBranching = gate_branching dg }) disableFusion maxBranchingStride') fun greedyfinishqubit3 () = - case optDepGraph of + case optDataFlowGraph of NONE => GFQ.scheduler | SOME dg => - GateSchedulerOrder.mkScheduler (DepGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.scheduler3 { depGraph = dg, gateIsBranching = gate_branching dg }) disableFusion maxBranchingStride') + GateSchedulerOrder.mkScheduler (DataFlowGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.scheduler3 { depGraph = dg, gateIsBranching = gate_branching dg }) disableFusion maxBranchingStride') fun greedyfinishqubit4 () = - case optDepGraph of + case optDataFlowGraph of NONE => GFQ.scheduler | SOME dg => - GateSchedulerOrder.mkScheduler (DepGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.scheduler4 { depGraph = dg, gateIsBranching = gate_branching dg }) disableFusion maxBranchingStride') + GateSchedulerOrder.mkScheduler (DataFlowGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.scheduler4 { depGraph = dg, gateIsBranching = gate_branching dg }) disableFusion maxBranchingStride') fun greedyfinishqubit5 () = - case optDepGraph of + case optDataFlowGraph of NONE => GFQ.scheduler | SOME dg => - GateSchedulerOrder.mkScheduler (DepGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.scheduler5 { depGraph = dg, gateIsBranching = gate_branching dg }) disableFusion maxBranchingStride') + GateSchedulerOrder.mkScheduler (DataFlowGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.scheduler5 { depGraph = dg, gateIsBranching = gate_branching dg }) disableFusion maxBranchingStride') fun randomsched (samples: int) = - case optDepGraph of + case optDataFlowGraph of NONE => raise Fail "Need dep graph for random scheduler" | SOME dg => let val ags = { depGraph = dg, gateIsBranching = gate_branching dg } - (*val scheds = Seq.fromList (List.tabulate (samples, fn i => DepGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.schedulerRandom i ags) maxBranchingStride')) *) - val scheds = Seq.tabulate (fn i => DepGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.schedulerRandom i ags) disableFusion maxBranchingStride') samples - val chosen = DepGraphUtil.chooseSchedule scheds (gate_branching dg) + (*val scheds = Seq.fromList (List.tabulate (samples, fn i => DataFlowGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.schedulerRandom i ags) maxBranchingStride')) *) + val scheds = Seq.tabulate (fn i => DataFlowGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.schedulerRandom i ags) disableFusion maxBranchingStride') samples + val chosen = DataFlowGraphUtil.chooseSchedule scheds (gate_branching dg) in GateSchedulerOrder.mkScheduler chosen end*) @@ -301,7 +298,7 @@ structure M_U_64 = val denseThreshold = denseThreshold val pullThreshold = pullThreshold) -val basisIdx64Okay = Circuit.numQubits circuit <= 62 +val basisIdx64Okay = #numQubits dfg <= 62 val main = case precision of @@ -314,4 +311,4 @@ val main = (* ======================================================================== *) -val _ = main (inputName, depGraph) +val _ = main (inputName, dfg) From f750c78dc37d376951ae43b61a27771703f9f4d9 Mon Sep 17 00:00:00 2001 From: Colin McDonald Date: Tue, 19 Dec 2023 16:49:30 -0500 Subject: [PATCH 13/15] More refactoring --- feynsum-sml/src/FullSimBFS.sml | 14 +- feynsum-sml/src/common/ApplyUntilFailure.sml | 40 ----- ...hDynScheduler.sml => DynGateScheduler.sml} | 10 +- ...owGraphScheduler.sml => GateScheduler.sml} | 2 +- feynsum-sml/src/common/HashSet.sml | 142 ++++++++++++++++++ .../src/common/{ => basis}/BASIS_IDX.sml | 0 .../src/common/{ => basis}/BasisIdx64.sml | 0 .../common/{ => basis}/BasisIdxUnlimited.sml | 0 .../src/common/{ => complex}/Complex.sml | 0 .../src/common/{ => complex}/Complex32.sml | 0 .../src/common/{ => complex}/Complex64.sml | 0 .../src/common/{ => complex}/MkComplex.sml | 0 .../src/common/schedulers/FinishQubit.sml | 12 +- .../src/common/schedulers/GreedyBranching.sml | 2 +- .../common/schedulers/GreedyNonBranching.sml | 2 +- feynsum-sml/src/common/sources.mlb | 55 +++---- .../src/common/{ => state}/DENSE_STATE.sml | 0 .../src/common/{ => state}/DenseState.sml | 0 .../src/common/{ => state}/ExpandState.sml | 0 .../src/common/{ => state}/HYBRID_STATE.sml | 0 .../src/common/{ => state}/HybridState.sml | 0 .../common/{ => state}/SPARSE_STATE_TABLE.sml | 0 .../src/common/{ => state}/SparseState.sml | 0 .../common/{ => state}/SparseStateTable.sml | 0 .../SparseStateTableLockedSlots.sml | 0 feynsum-sml/src/main.sml | 105 ------------- 26 files changed, 184 insertions(+), 200 deletions(-) delete mode 100644 feynsum-sml/src/common/ApplyUntilFailure.sml rename feynsum-sml/src/common/{DataFlowGraphDynScheduler.sml => DynGateScheduler.sml} (96%) rename feynsum-sml/src/common/{DataFlowGraphScheduler.sml => GateScheduler.sml} (93%) create mode 100644 feynsum-sml/src/common/HashSet.sml rename feynsum-sml/src/common/{ => basis}/BASIS_IDX.sml (100%) rename feynsum-sml/src/common/{ => basis}/BasisIdx64.sml (100%) rename feynsum-sml/src/common/{ => basis}/BasisIdxUnlimited.sml (100%) rename feynsum-sml/src/common/{ => complex}/Complex.sml (100%) rename feynsum-sml/src/common/{ => complex}/Complex32.sml (100%) rename feynsum-sml/src/common/{ => complex}/Complex64.sml (100%) rename feynsum-sml/src/common/{ => complex}/MkComplex.sml (100%) rename feynsum-sml/src/common/{ => state}/DENSE_STATE.sml (100%) rename feynsum-sml/src/common/{ => state}/DenseState.sml (100%) rename feynsum-sml/src/common/{ => state}/ExpandState.sml (100%) rename feynsum-sml/src/common/{ => state}/HYBRID_STATE.sml (100%) rename feynsum-sml/src/common/{ => state}/HybridState.sml (100%) rename feynsum-sml/src/common/{ => state}/SPARSE_STATE_TABLE.sml (100%) rename feynsum-sml/src/common/{ => state}/SparseState.sml (100%) rename feynsum-sml/src/common/{ => state}/SparseStateTable.sml (100%) rename feynsum-sml/src/common/{ => state}/SparseStateTableLockedSlots.sml (100%) diff --git a/feynsum-sml/src/FullSimBFS.sml b/feynsum-sml/src/FullSimBFS.sml index c0ad597..2c5503d 100644 --- a/feynsum-sml/src/FullSimBFS.sml +++ b/feynsum-sml/src/FullSimBFS.sml @@ -107,16 +107,16 @@ struct | "interference" => DGI.choose | _ => raise Fail ("Unknown scheduler '" ^ gateScheduler ^ "'\n") - fun run depgraph (*{numQubits, gates}*) = + fun run dfg (*{numQubits, gates}*) = let - val gates = Seq.map G.fromGateDefn (#gates depgraph) - val numQubits = #numQubits depgraph + val gates = Seq.map G.fromGateDefn (#gates dfg) + val numQubits = #numQubits dfg fun gate i = Seq.nth gates i val depth = Seq.length gates - val dgstate = DataFlowGraphUtil.initState depgraph + val dgstate = DataFlowGraphUtil.initState dfg val pickNextGate = - let val f = gateSched depgraph in + let val f = gateSched dfg in fn (s, g) => f (s, g) end @@ -175,8 +175,8 @@ struct fun runloop () = DataFlowGraphUtil.scheduleWithOracle' - (* dependency graph *) - depgraph + (* data flow graph *) + dfg (* gate is branching *) (fn i => G.expectBranching (Seq.nth gates i)) diff --git a/feynsum-sml/src/common/ApplyUntilFailure.sml b/feynsum-sml/src/common/ApplyUntilFailure.sml deleted file mode 100644 index d12c98f..0000000 --- a/feynsum-sml/src/common/ApplyUntilFailure.sml +++ /dev/null @@ -1,40 +0,0 @@ -structure ApplyUntilFailure: -sig - datatype 'a element_result = Success | Failure of 'a - - val doPrefix: {grain: int, acceleration: real} - -> int * int - -> (int -> 'a element_result) - -> {numApplied: int, failed: 'a Seq.t} -end = -struct - - datatype 'a element_result = Success | Failure of 'a - - fun doPrefix {grain, acceleration} (start, stop) doElem = - let - fun loop numApplied (lo, hi) = - if lo >= hi then - {numApplied = numApplied, failed = Seq.empty ()} - else - let - val resultHere = SeqBasis.tabFilter grain (lo, hi) (fn i => - case doElem i of - Success => NONE - | Failure x => SOME x) - - val widthHere = hi - lo - val numApplied' = numApplied + widthHere - - val desiredWidth = Real.ceil (acceleration * Real.fromInt widthHere) - val lo' = hi - val hi' = Int.min (hi + desiredWidth, stop) - in - if Array.length resultHere = 0 then loop numApplied' (lo', hi') - else {numApplied = numApplied', failed = ArraySlice.full resultHere} - end - in - loop 0 (start, stop) - end - -end diff --git a/feynsum-sml/src/common/DataFlowGraphDynScheduler.sml b/feynsum-sml/src/common/DynGateScheduler.sml similarity index 96% rename from feynsum-sml/src/common/DataFlowGraphDynScheduler.sml rename to feynsum-sml/src/common/DynGateScheduler.sml index de844c0..d63f4cb 100644 --- a/feynsum-sml/src/common/DataFlowGraphDynScheduler.sml +++ b/feynsum-sml/src/common/DynGateScheduler.sml @@ -1,4 +1,4 @@ -signature DATA_FLOW_GRAPH_DYN_SCHEDULER = +signature DYN_GATE_SCHEDULER = sig structure B: BASIS_IDX structure C: COMPLEX @@ -21,7 +21,7 @@ functor DynSchedFinishQubitWrapper sharing C = HS.C val maxBranchingStride: int val disableFusion: bool - ): DATA_FLOW_GRAPH_DYN_SCHEDULER = + ): DYN_GATE_SCHEDULER = struct structure B = B structure C = C @@ -54,7 +54,7 @@ functor DynSchedNaive sharing C = HS.C val maxBranchingStride: int val disableFusion: bool - ): DATA_FLOW_GRAPH_DYN_SCHEDULER = + ): DYN_GATE_SCHEDULER = struct structure B = B structure C = C @@ -77,7 +77,7 @@ functor DynSchedInterference sharing C = HS.C val maxBranchingStride: int val disableFusion: bool - ): DATA_FLOW_GRAPH_DYN_SCHEDULER = + ): DYN_GATE_SCHEDULER = struct structure B = B structure C = C @@ -176,7 +176,7 @@ end val blockSize: int val maxload: real val denseThreshold: real - val pullThreshold: real): DATA_FLOW_GRAPH_DYN_SCHEDULER = + val pullThreshold: real): DYN_GATE_SCHEDULER = struct type gate_idx = int structure Expander = diff --git a/feynsum-sml/src/common/DataFlowGraphScheduler.sml b/feynsum-sml/src/common/GateScheduler.sml similarity index 93% rename from feynsum-sml/src/common/DataFlowGraphScheduler.sml rename to feynsum-sml/src/common/GateScheduler.sml index 1712ef1..b07bc03 100644 --- a/feynsum-sml/src/common/DataFlowGraphScheduler.sml +++ b/feynsum-sml/src/common/GateScheduler.sml @@ -1,4 +1,4 @@ -structure Scheduler = +structure GateScheduler = struct type gate_idx = int diff --git a/feynsum-sml/src/common/HashSet.sml b/feynsum-sml/src/common/HashSet.sml new file mode 100644 index 0000000..8a30e01 --- /dev/null +++ b/feynsum-sml/src/common/HashSet.sml @@ -0,0 +1,142 @@ +structure HashSet :> +sig + type 'a t + type 'a table = 'a t + + exception Full + + val make: {hash: 'a -> int, eq: 'a * 'a -> bool, capacity: int, maxload: real} -> 'a table + val size: 'a table -> int + val capacity: 'a table -> int + val resize: 'a table -> 'a table + val increaseCapacityTo: int -> 'a table -> 'a table + val insertIfNotPresent: 'a table -> 'a -> bool + val lookup: 'a table -> 'a -> bool + val compact: 'a table -> 'a Seq.t + + (* Unsafe because underlying array is shared. If the table is mutated, + * then the Seq would not appear to be immutable. + * + * Could also imagine a function `freezeViewContents` which marks the + * table as immutable (preventing further inserts). That would be a safer + * version of this function. + *) + val unsafeViewContents: 'a table -> 'a option Seq.t +end = +struct + + datatype 'a t = + T of + { data: 'a option array + , hash: 'a -> int + , eq: 'a * 'a -> bool + , maxload: real + } + + exception Full + + type 'a table = 'a t + + + fun make {hash, eq, capacity, maxload} = + if capacity = 0 then + raise Fail "HashTable.make: capacity 0" + else + let val data = SeqBasis.tabulate 5000 (0, capacity) (fn _ => NONE) + in T {data = data, hash = hash, eq = eq, maxload = maxload} + end + + + fun unsafeViewContents (T {data, ...}) = ArraySlice.full data + + + fun bcas (arr, i) (old, new) = + MLton.eq (old, Concurrency.casArray (arr, i) (old, new)) + + + fun size (T {data, ...}) = + SeqBasis.reduce 10000 op+ 0 (0, Array.length data) (fn i => + if Option.isSome (Array.sub (data, i)) then 1 else 0) + + + fun capacity (T {data, ...}) = Array.length data + + + fun insertIfNotPresent' (input as T {data, hash, eq, maxload}) x force = + let + val n = Array.length data + val tolerance = 20 * Real.ceil (1.0 / (1.0 - maxload)) + + fun loop i probes = + if not force andalso probes >= tolerance then + raise Full + else if i >= n then + loop 0 probes + else + let + val current = Array.sub (data, i) + in + case current of + SOME y => + if eq (x, y) then false else loop (i + 1) (probes + 1) + | NONE => + if bcas (data, i) (current, SOME x) then + (* (Concurrency.faa sz 1; true) *) + true + else + loop i probes + end + + val start = (hash x) mod (Array.length data) + in + loop start 0 + end + + + fun insertIfNotPresent s x = + insertIfNotPresent' s x false + + + fun lookup (T {data, hash, eq, ...}) x = + let + val n = Array.length data + val start = (hash x) mod n + + fun loop i = + case Array.sub (data, i) of + NONE => false + | SOME y => eq (x, y) orelse loopCheck (i + 1) + + and loopCheck i = + if i >= n then loopCheck 0 else (i <> start andalso loop i) + in + n <> 0 andalso loop start + end + + + fun increaseCapacityTo newcap (input as T {data, hash, eq, maxload}) = + if newcap < capacity input then + raise Fail "HashTable.increaseCapacityTo: new cap is too small" + else + let + val new = make + {hash = hash, eq = eq, capacity = newcap, maxload = maxload} + in + ForkJoin.parfor 1000 (0, Array.length data) (fn i => + case Array.sub (data, i) of + NONE => () + | SOME x => (insertIfNotPresent' new x true; ())); + + new + end + + + fun resize x = + increaseCapacityTo (2 * capacity x) x + + + fun compact (T {data, ...}) = + ArraySlice.full (SeqBasis.tabFilter 2000 (0, Array.length data) (fn i => + Array.sub (data, i))) + +end diff --git a/feynsum-sml/src/common/BASIS_IDX.sml b/feynsum-sml/src/common/basis/BASIS_IDX.sml similarity index 100% rename from feynsum-sml/src/common/BASIS_IDX.sml rename to feynsum-sml/src/common/basis/BASIS_IDX.sml diff --git a/feynsum-sml/src/common/BasisIdx64.sml b/feynsum-sml/src/common/basis/BasisIdx64.sml similarity index 100% rename from feynsum-sml/src/common/BasisIdx64.sml rename to feynsum-sml/src/common/basis/BasisIdx64.sml diff --git a/feynsum-sml/src/common/BasisIdxUnlimited.sml b/feynsum-sml/src/common/basis/BasisIdxUnlimited.sml similarity index 100% rename from feynsum-sml/src/common/BasisIdxUnlimited.sml rename to feynsum-sml/src/common/basis/BasisIdxUnlimited.sml diff --git a/feynsum-sml/src/common/Complex.sml b/feynsum-sml/src/common/complex/Complex.sml similarity index 100% rename from feynsum-sml/src/common/Complex.sml rename to feynsum-sml/src/common/complex/Complex.sml diff --git a/feynsum-sml/src/common/Complex32.sml b/feynsum-sml/src/common/complex/Complex32.sml similarity index 100% rename from feynsum-sml/src/common/Complex32.sml rename to feynsum-sml/src/common/complex/Complex32.sml diff --git a/feynsum-sml/src/common/Complex64.sml b/feynsum-sml/src/common/complex/Complex64.sml similarity index 100% rename from feynsum-sml/src/common/Complex64.sml rename to feynsum-sml/src/common/complex/Complex64.sml diff --git a/feynsum-sml/src/common/MkComplex.sml b/feynsum-sml/src/common/complex/MkComplex.sml similarity index 100% rename from feynsum-sml/src/common/MkComplex.sml rename to feynsum-sml/src/common/complex/MkComplex.sml diff --git a/feynsum-sml/src/common/schedulers/FinishQubit.sml b/feynsum-sml/src/common/schedulers/FinishQubit.sml index a4ae826..5b27449 100644 --- a/feynsum-sml/src/common/schedulers/FinishQubit.sml +++ b/feynsum-sml/src/common/schedulers/FinishQubit.sml @@ -1,12 +1,12 @@ functor FinishQubitScheduler (val maxBranchingStride: int val disableFusion: bool): sig - val scheduler: Scheduler.t - val scheduler2: Scheduler.t - val scheduler3: Scheduler.t - val scheduler4: Scheduler.t - val scheduler5: Scheduler.t - val schedulerRandom: int -> Scheduler.t + val scheduler: GateScheduler.t + val scheduler2: GateScheduler.t + val scheduler3: GateScheduler.t + val scheduler4: GateScheduler.t + val scheduler5: GateScheduler.t + val schedulerRandom: int -> GateScheduler.t end = struct diff --git a/feynsum-sml/src/common/schedulers/GreedyBranching.sml b/feynsum-sml/src/common/schedulers/GreedyBranching.sml index c8386b5..3d544a6 100644 --- a/feynsum-sml/src/common/schedulers/GreedyBranching.sml +++ b/feynsum-sml/src/common/schedulers/GreedyBranching.sml @@ -1,6 +1,6 @@ structure GreedyBranchingScheduler: sig - val scheduler: Scheduler.t + val scheduler: GateScheduler.t end = struct diff --git a/feynsum-sml/src/common/schedulers/GreedyNonBranching.sml b/feynsum-sml/src/common/schedulers/GreedyNonBranching.sml index bc2331b..1206fbd 100644 --- a/feynsum-sml/src/common/schedulers/GreedyNonBranching.sml +++ b/feynsum-sml/src/common/schedulers/GreedyNonBranching.sml @@ -1,7 +1,7 @@ functor GreedyNonBranchingScheduler (val maxBranchingStride: int val disableFusion: bool): sig - val scheduler: Scheduler.t + val scheduler: GateScheduler.t end = struct diff --git a/feynsum-sml/src/common/sources.mlb b/feynsum-sml/src/common/sources.mlb index d4af809..99d75a8 100644 --- a/feynsum-sml/src/common/sources.mlb +++ b/feynsum-sml/src/common/sources.mlb @@ -24,22 +24,22 @@ local $(SML_LIB)/smlnj-lib/Util/smlnj-lib.mlb HashTable.sml - ApplyUntilFailure.sml + HashSet.sml Rat.sml QubitIdx.sml Constants.sml - MkComplex.sml - Complex32.sml - Complex64.sml - Complex.sml + complex/MkComplex.sml + complex/Complex32.sml + complex/Complex64.sml + complex/Complex.sml - BASIS_IDX.sml + basis/BASIS_IDX.sml ann "allowExtendedTextConsts true" in - BasisIdx64.sml - BasisIdxUnlimited.sml + basis/BasisIdx64.sml + basis/BasisIdxUnlimited.sml end ann "allowExtendedTextConsts true" in @@ -48,37 +48,30 @@ local Circuit.sml end - SPARSE_STATE_TABLE.sml - SparseStateTable.sml - SparseStateTableLockedSlots.sml + state/SPARSE_STATE_TABLE.sml + state/SparseStateTable.sml + state/SparseStateTableLockedSlots.sml - DENSE_STATE.sml - DenseState.sml + state/DENSE_STATE.sml + state/DenseState.sml - HYBRID_STATE.sml - HybridState.sml + state/HYBRID_STATE.sml + state/HybridState.sml - ExpandState.sml - - (*GateScheduler.sml*) - (*GateSchedulerNaive.sml*) - (*GateSchedulerGreedyNonBranching.sml*) - (*GateSchedulerGreedyBranching.sml*) - (*GateSchedulerGreedyFinishQubit.sml*) + state/ExpandState.sml Fingerprint.sml DataFlowGraph.sml DataFlowGraphUtil.sml - DataFlowGraphScheduler.sml + GateScheduler.sml schedulers/GreedyBranching.sml schedulers/GreedyNonBranching.sml schedulers/FinishQubit.sml - DataFlowGraphDynScheduler.sml - (*GateSchedulerOrder.sml*) + DynGateScheduler.sml in structure HashTable - structure ApplyUntilFailure + structure HashSet structure Rat @@ -124,12 +117,6 @@ in functor ExpandState - (*structure GateScheduler*) - (*structure GateSchedulerNaive*) - (*functor GateSchedulerGreedyNonBranching*) - (*structure GateSchedulerGreedyBranching*) - (*functor GateSchedulerGreedyFinishQubit*) - functor RedBlackMapFn functor RedBlackSetFn @@ -137,12 +124,12 @@ in structure DataFlowGraph structure DataFlowGraphUtil - structure Scheduler + structure GateScheduler structure GreedyBranchingScheduler functor GreedyNonBranchingScheduler functor FinishQubitScheduler (*structure GateSchedulerOrder*) - signature DATA_FLOW_GRAPH_DYN_SCHEDULER + signature DYN_GATE_SCHEDULER functor DynSchedFinishQubitWrapper functor DynSchedInterference functor DynSchedNaive diff --git a/feynsum-sml/src/common/DENSE_STATE.sml b/feynsum-sml/src/common/state/DENSE_STATE.sml similarity index 100% rename from feynsum-sml/src/common/DENSE_STATE.sml rename to feynsum-sml/src/common/state/DENSE_STATE.sml diff --git a/feynsum-sml/src/common/DenseState.sml b/feynsum-sml/src/common/state/DenseState.sml similarity index 100% rename from feynsum-sml/src/common/DenseState.sml rename to feynsum-sml/src/common/state/DenseState.sml diff --git a/feynsum-sml/src/common/ExpandState.sml b/feynsum-sml/src/common/state/ExpandState.sml similarity index 100% rename from feynsum-sml/src/common/ExpandState.sml rename to feynsum-sml/src/common/state/ExpandState.sml diff --git a/feynsum-sml/src/common/HYBRID_STATE.sml b/feynsum-sml/src/common/state/HYBRID_STATE.sml similarity index 100% rename from feynsum-sml/src/common/HYBRID_STATE.sml rename to feynsum-sml/src/common/state/HYBRID_STATE.sml diff --git a/feynsum-sml/src/common/HybridState.sml b/feynsum-sml/src/common/state/HybridState.sml similarity index 100% rename from feynsum-sml/src/common/HybridState.sml rename to feynsum-sml/src/common/state/HybridState.sml diff --git a/feynsum-sml/src/common/SPARSE_STATE_TABLE.sml b/feynsum-sml/src/common/state/SPARSE_STATE_TABLE.sml similarity index 100% rename from feynsum-sml/src/common/SPARSE_STATE_TABLE.sml rename to feynsum-sml/src/common/state/SPARSE_STATE_TABLE.sml diff --git a/feynsum-sml/src/common/SparseState.sml b/feynsum-sml/src/common/state/SparseState.sml similarity index 100% rename from feynsum-sml/src/common/SparseState.sml rename to feynsum-sml/src/common/state/SparseState.sml diff --git a/feynsum-sml/src/common/SparseStateTable.sml b/feynsum-sml/src/common/state/SparseStateTable.sml similarity index 100% rename from feynsum-sml/src/common/SparseStateTable.sml rename to feynsum-sml/src/common/state/SparseStateTable.sml diff --git a/feynsum-sml/src/common/SparseStateTableLockedSlots.sml b/feynsum-sml/src/common/state/SparseStateTableLockedSlots.sml similarity index 100% rename from feynsum-sml/src/common/SparseStateTableLockedSlots.sml rename to feynsum-sml/src/common/state/SparseStateTableLockedSlots.sml diff --git a/feynsum-sml/src/main.sml b/feynsum-sml/src/main.sml index fbf0ef7..d1c1c2d 100644 --- a/feynsum-sml/src/main.sml +++ b/feynsum-sml/src/main.sml @@ -135,113 +135,8 @@ fun print_sched_info () = type gate_idx = int type Schedule = gate_idx Seq.t -(*fun dep_graph_to_schedule (dg: DataFlowGraph.t) (sched: Scheduler.t) = - let val choose = sched dg - val st = DataFlowGraphUtil.initState dg - fun loadNext (acc: gate_idx list) = - let val frntr = DataFlowGraphUtil.frontier st in - if Seq.length frntr = 0 then - Seq.rev (Seq.fromList acc) - else - (let val next = choose frntr in - DataFlowGraphUtil.visit dg next st; - loadNext (next :: acc) - end) - end - in - GateSchedulerOrder.mkScheduler (Seq.map Seq.singleton (loadNext nil)) - end -*) - val maxBranchingStride' = if disableFusion then 1 else maxBranchingStride -(*fun greedybranching () = - case optDataFlowGraph of - NONE => GateSchedulerGreedyBranching.scheduler - | SOME dg => GateSchedulerOrder.mkScheduler (DataFlowGraphUtil.scheduleWithOracle dg (gate_branching dg) (DataFlowGraphSchedulerGreedyBranching.scheduler { depGraph = dg, gateIsBranching = gate_branching dg }) disableFusion maxBranchingStride') - -fun greedynonbranching () = - case optDataFlowGraph of - NONE => GNB.scheduler - | SOME dg => GateSchedulerOrder.mkScheduler (DataFlowGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGNB.scheduler { depGraph = dg, gateIsBranching = gate_branching dg }) disableFusion maxBranchingStride') - -fun greedyfinishqubit () = - case optDataFlowGraph of - NONE => GFQ.scheduler - | SOME dg => - GateSchedulerOrder.mkScheduler (DataFlowGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.scheduler { depGraph = dg, gateIsBranching = gate_branching dg }) disableFusion maxBranchingStride') - -fun greedyfinishqubit2 () = - case optDataFlowGraph of - NONE => GFQ.scheduler - | SOME dg => - (*GateSchedulerOrder.mkScheduler (Seq.fromList (DGFQ.ordered { depGraph = dg, gateIsBranching = gate_branching dg }))*) - (* dep_graph_to_schedule { depGraph = dg, gateIsBranching = gate_branching dg } DGFQ.scheduler2 *) - GateSchedulerOrder.mkScheduler (DataFlowGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.scheduler2 { depGraph = dg, gateIsBranching = gate_branching dg }) disableFusion maxBranchingStride') - -fun greedyfinishqubit3 () = - case optDataFlowGraph of - NONE => GFQ.scheduler - | SOME dg => - GateSchedulerOrder.mkScheduler (DataFlowGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.scheduler3 { depGraph = dg, gateIsBranching = gate_branching dg }) disableFusion maxBranchingStride') - -fun greedyfinishqubit4 () = - case optDataFlowGraph of - NONE => GFQ.scheduler - | SOME dg => - GateSchedulerOrder.mkScheduler (DataFlowGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.scheduler4 { depGraph = dg, gateIsBranching = gate_branching dg }) disableFusion maxBranchingStride') - -fun greedyfinishqubit5 () = - case optDataFlowGraph of - NONE => GFQ.scheduler - | SOME dg => - GateSchedulerOrder.mkScheduler (DataFlowGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.scheduler5 { depGraph = dg, gateIsBranching = gate_branching dg }) disableFusion maxBranchingStride') - -fun randomsched (samples: int) = - case optDataFlowGraph of - NONE => raise Fail "Need dep graph for random scheduler" - | SOME dg => - let val ags = { depGraph = dg, gateIsBranching = gate_branching dg } - (*val scheds = Seq.fromList (List.tabulate (samples, fn i => DataFlowGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.schedulerRandom i ags) maxBranchingStride')) *) - val scheds = Seq.tabulate (fn i => DataFlowGraphUtil.scheduleWithOracle dg (gate_branching dg) (DGFQ.schedulerRandom i ags) disableFusion maxBranchingStride') samples - val chosen = DataFlowGraphUtil.chooseSchedule scheds (gate_branching dg) - in - GateSchedulerOrder.mkScheduler chosen - end*) - -(*val gateScheduler = - case schedulerName of - "naive" => GateSchedulerNaive.scheduler - - | "greedy-branching" => (print_sched_info (); greedybranching ()) - | "gb" => (print_sched_info (); greedybranching ()) - - | "greedy-nonbranching" => (print_sched_info (); greedynonbranching ()) - | "gnb" => (print_sched_info (); greedynonbranching ()) - - | "greedy-finish-qubit" => (print_sched_info (); greedyfinishqubit ()) - | "gfq" => (print_sched_info (); greedyfinishqubit ()) - - | "greedy-finish-qubit2" => (print_sched_info (); greedyfinishqubit2 ()) - | "gfq2" => (print_sched_info (); greedyfinishqubit2 ()) - - | "greedy-finish-qubit3" => (print_sched_info (); greedyfinishqubit3 ()) - | "gfq3" => (print_sched_info (); greedyfinishqubit3 ()) - - | "greedy-finish-qubit4" => (print_sched_info (); greedyfinishqubit4 ()) - | "gfq4" => (print_sched_info (); greedyfinishqubit4 ()) - - | "greedy-finish-qubit5" => (print_sched_info (); greedyfinishqubit5 ()) - | "gfq5" => (print_sched_info (); greedyfinishqubit5 ()) - - | "random" => randomsched 50 - - | _ => - Util.die - ("unknown scheduler: " ^ schedulerName - ^ - "; valid options are: naive, greedy-branching (gb), greedy-nonbranching (gnb), greedy-finish-qubit (gfq)")*) - (* ======================================================================== * mains: 32-bit and 64-bit *) From f0bb51ae665847e1c54cc689854bc8d9862bdfbd Mon Sep 17 00:00:00 2001 From: Colin McDonald Date: Tue, 19 Dec 2023 17:23:31 -0500 Subject: [PATCH 14/15] Tweak hash set --- feynsum-sml/src/common/HashSet.sml | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/feynsum-sml/src/common/HashSet.sml b/feynsum-sml/src/common/HashSet.sml index 8a30e01..05c0546 100644 --- a/feynsum-sml/src/common/HashSet.sml +++ b/feynsum-sml/src/common/HashSet.sml @@ -10,7 +10,7 @@ sig val capacity: 'a table -> int val resize: 'a table -> 'a table val increaseCapacityTo: int -> 'a table -> 'a table - val insertIfNotPresent: 'a table -> 'a -> bool + val insert: 'a table -> 'a -> unit val lookup: 'a table -> 'a -> bool val compact: 'a table -> 'a Seq.t @@ -62,7 +62,7 @@ struct fun capacity (T {data, ...}) = Array.length data - fun insertIfNotPresent' (input as T {data, hash, eq, maxload}) x force = + fun insert' (T {data, hash, eq, maxload}) x force = let val n = Array.length data val tolerance = 20 * Real.ceil (1.0 / (1.0 - maxload)) @@ -73,16 +73,13 @@ struct else if i >= n then loop 0 probes else - let - val current = Array.sub (data, i) - in + let val current = Array.sub (data, i) in case current of SOME y => - if eq (x, y) then false else loop (i + 1) (probes + 1) + if eq (x, y) then () else loop (i + 1) (probes + 1) | NONE => - if bcas (data, i) (current, SOME x) then - (* (Concurrency.faa sz 1; true) *) - true + if bcas (data, i) (NONE, SOME x) then + () else loop i probes end @@ -93,8 +90,7 @@ struct end - fun insertIfNotPresent s x = - insertIfNotPresent' s x false + fun insert s x = insert' s x false fun lookup (T {data, hash, eq, ...}) x = @@ -125,7 +121,7 @@ struct ForkJoin.parfor 1000 (0, Array.length data) (fn i => case Array.sub (data, i) of NONE => () - | SOME x => (insertIfNotPresent' new x true; ())); + | SOME x => (insert' new x true; ())); new end From 7c718fee299e0313c9b90e17a2527b7a71e1cb82 Mon Sep 17 00:00:00 2001 From: Colin McDonald Date: Tue, 19 Dec 2023 17:31:04 -0500 Subject: [PATCH 15/15] Refactoring --- feynsum-sml/src/FullSimBFS.sml | 28 +++++++++++----------------- feynsum-sml/src/MkMain.sml | 17 ++++++++++++++--- feynsum-sml/src/main.sml | 9 --------- 3 files changed, 25 insertions(+), 29 deletions(-) diff --git a/feynsum-sml/src/FullSimBFS.sml b/feynsum-sml/src/FullSimBFS.sml index 2c5503d..ecaecbd 100644 --- a/feynsum-sml/src/FullSimBFS.sml +++ b/feynsum-sml/src/FullSimBFS.sml @@ -1,10 +1,10 @@ functor FullSimBFS (structure B: BASIS_IDX structure C: COMPLEX - structure SST: SPARSE_STATE_TABLE + structure HS: HYBRID_STATE structure G: GATE - sharing B = SST.B = G.B - sharing C = SST.C = G.C + sharing B = HS.B = G.B + sharing C = HS.C = G.C val disableFusion: bool val maxBranchingStride: int @@ -20,12 +20,6 @@ sig end = struct - structure DS = DenseState (structure C = C structure B = B) - structure HS = HybridState (structure C = C - structure B = B - structure DS = DS - structure SST = SST) - structure Expander = ExpandState (structure B = B @@ -65,9 +59,9 @@ struct let val ss = case s of - HS.Sparse sst => SST.unsafeViewContents sst - | HS.Dense ds => DS.unsafeViewContents ds - | HS.DenseKnownNonZeroSize (ds, _) => DS.unsafeViewContents ds + HS.Sparse sst => HS.SST.unsafeViewContents sst + | HS.Dense ds => HS.DS.unsafeViewContents ds + | HS.DenseKnownNonZeroSize (ds, _) => HS.DS.unsafeViewContents ds in Util.for (0, DelayedSeq.length ss) (fn i => case DelayedSeq.nth ss i of @@ -163,14 +157,14 @@ struct fun getNumZeros state = case state of - HS.Sparse sst => SST.zeroSize sst + HS.Sparse sst => HS.SST.zeroSize sst | HS.Dense ds => 0 (*raise Fail "Can't do dense stuff!"*) (*DS.unsafeViewContents ds, DS.nonZeroSize ds, TODO exception*) | HS.DenseKnownNonZeroSize (ds, nz) => 0 (*raise Fail "Can't do dense stuff!"*) (*DS.unsafeViewContents ds, nz, TODO exception*) val initialState = HS.Sparse - (SST.singleton {numQubits = numQubits} (B.zeros, C.defaultReal 1.0)) + (HS.SST.singleton {numQubits = numQubits} (B.zeros, C.defaultReal 1.0)) fun runloop () = DataFlowGraphUtil.scheduleWithOracle' @@ -224,9 +218,9 @@ struct val (finalState, numGateApps, counts, gatesVisited) = runloop () val nonZeros = case finalState of - HS.Sparse sst => SST.unsafeViewContents sst - | HS.Dense ds => DS.unsafeViewContents ds - | HS.DenseKnownNonZeroSize (ds, nz) => DS.unsafeViewContents ds + HS.Sparse sst => HS.SST.unsafeViewContents sst + | HS.Dense ds => HS.DS.unsafeViewContents ds + | HS.DenseKnownNonZeroSize (ds, nz) => HS.DS.unsafeViewContents ds val _ = print ("gate app count " ^ Int.toString numGateApps ^ "\n") in {result = nonZeros, counts = Seq.fromList counts} diff --git a/feynsum-sml/src/MkMain.sml b/feynsum-sml/src/MkMain.sml index 2b2e68f..f8667c3 100644 --- a/feynsum-sml/src/MkMain.sml +++ b/feynsum-sml/src/MkMain.sml @@ -15,12 +15,23 @@ struct structure G = Gate (structure B = B structure C = C) + structure SSTLocked = SparseStateTableLockedSlots (structure B = B structure C = C) + structure SSTLockFree = SparseStateTable (structure B = B structure C = C) + structure DS = DenseState (structure B = B structure C = C) + structure HSLocked = HybridState (structure B = B + structure C = C + structure SST = SSTLocked + structure DS = DS) + structure HSLockFree = HybridState (structure B = B + structure C = C + structure SST = SSTLockFree + structure DS = DS) + structure BFSLocked = FullSimBFS (structure B = B structure C = C - structure SST = - SparseStateTableLockedSlots (structure B = B structure C = C) + structure HS = HSLocked structure G = G val disableFusion = disableFusion val maxBranchingStride = maxBranchingStride @@ -35,7 +46,7 @@ struct FullSimBFS (structure B = B structure C = C - structure SST = SparseStateTable (structure B = B structure C = C) + structure HS = HSLockFree structure G = G val disableFusion = disableFusion val maxBranchingStride = maxBranchingStride diff --git a/feynsum-sml/src/main.sml b/feynsum-sml/src/main.sml index d1c1c2d..ece05ec 100644 --- a/feynsum-sml/src/main.sml +++ b/feynsum-sml/src/main.sml @@ -95,15 +95,6 @@ val _ = val disableFusion = CLA.parseFlag "scheduler-disable-fusion" val maxBranchingStride = CLA.parseInt "scheduler-max-branching-stride" 2 -(*structure GNB = - GateSchedulerGreedyNonBranching - (val maxBranchingStride = maxBranchingStride - val disableFusion = disableFusion) - -structure GFQ = - GateSchedulerGreedyFinishQubit - (val maxBranchingStride = maxBranchingStride - val disableFusion = disableFusion)*) structure DGNB = GreedyNonBranchingScheduler