From 8f03bb4d29b32f75e23021e221d14a5db8d5a6c3 Mon Sep 17 00:00:00 2001 From: Jannik Lindemann Date: Mon, 17 Nov 2025 11:52:55 +0100 Subject: [PATCH 1/5] Allow Cache Deletes and More Robust Implementation of SubscribableQueues to avoid Race Conditions and Unclosed Streams --- .../java/org/apache/sysds/api/DMLScript.java | 3 + .../sysds/hops/rewrite/ProgramRewriter.java | 3 +- .../hops/rewrite/RewriteInjectOOCTee.java | 212 +++++++++++++----- .../controlprogram/caching/CacheableData.java | 15 +- .../controlprogram/parfor/LocalTaskQueue.java | 2 +- .../cp/VariableCPInstruction.java | 7 + .../instructions/ooc/CachingStream.java | 161 ++++++++++--- .../ooc/IndexingOOCInstruction.java | 28 +++ .../ooc/MatrixIndexingOOCInstruction.java | 86 +++---- .../instructions/ooc/OOCEvictionManager.java | 36 ++- .../instructions/ooc/OOCInstruction.java | 203 +++++++++++------ .../runtime/instructions/ooc/OOCStream.java | 26 +++ .../instructions/ooc/OOCStreamable.java | 2 - .../runtime/instructions/ooc/OOCWatchdog.java | 76 +++++++ .../ParameterizedBuiltinOOCInstruction.java | 27 +-- .../instructions/ooc/PlaybackStream.java | 66 +++++- .../ooc/SubscribableTaskQueue.java | 168 ++++++++++---- .../instructions/ooc/TeeOOCInstruction.java | 42 +++- .../sysds/test/functions/ooc/PCATest.java | 123 ++++++++++ src/test/scripts/functions/ooc/PCA.dml | 28 +++ 20 files changed, 1043 insertions(+), 271 deletions(-) create mode 100644 src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCWatchdog.java create mode 100644 src/test/java/org/apache/sysds/test/functions/ooc/PCATest.java create mode 100644 src/test/scripts/functions/ooc/PCA.dml diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java b/src/main/java/org/apache/sysds/api/DMLScript.java index 65805b5c2ed..748a0c43ac0 100644 --- a/src/main/java/org/apache/sysds/api/DMLScript.java +++ b/src/main/java/org/apache/sysds/api/DMLScript.java @@ -71,6 +71,7 @@ import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.CoordinatorModel; import org.apache.sysds.runtime.controlprogram.parfor.util.IDHandler; import org.apache.sysds.runtime.instructions.gpu.context.GPUContextPool; +import org.apache.sysds.runtime.instructions.ooc.OOCEvictionManager; import org.apache.sysds.runtime.io.IOUtilFunctions; import org.apache.sysds.runtime.lineage.LineageCacheConfig; import org.apache.sysds.runtime.lineage.LineageCacheConfig.LineageCachePolicy; @@ -497,6 +498,8 @@ private static void execute(String dmlScriptStr, String fnameOptConfig, MapApply Rewrites (Modification): Iterate over the collected candidate and put * {@code TeeOp}, and safely rewire the graph. */ -public class RewriteInjectOOCTee extends HopRewriteRule { +public class RewriteInjectOOCTee extends StatementBlockRewriteRule { public static boolean APPLY_ONLY_XtX_PATTERN = false; + + private static final Map _transientVars = new HashMap<>(); + private static final Map> _transientHops = new HashMap<>(); + private static final Set teeTransientVars = new HashSet<>(); private static final Set rewrittenHops = new HashSet<>(); private static final Map handledHop = new HashMap<>(); // Maintain a list of candidates to rewrite in the second pass private final List rewriteCandidates = new ArrayList<>(); - - /** - * Handle a generic (last-level) hop DAG with multiple roots. 
- * - * @param roots high-level operator roots - * @param state program rewrite status - * @return list of high-level operators - */ - @Override - public ArrayList rewriteHopDAGs(ArrayList roots, ProgramRewriteStatus state) { - if (roots == null) { - return null; - } - - // Clear candidates for this pass - rewriteCandidates.clear(); - - // PASS 1: Identify candidates without modifying the graph - for (Hop root : roots) { - root.resetVisitStatus(); - findRewriteCandidates(root); - } - - // PASS 2: Apply rewrites to identified candidates - for (Hop candidate : rewriteCandidates) { - applyTopDownTeeRewrite(candidate); - } - - return roots; - } - - /** - * Handle a predicate hop DAG with exactly one root. - * - * @param root high-level operator root - * @param state program rewrite status - * @return high-level operator - */ - @Override - public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) { - if (root == null) { - return null; - } - - // Clear candidates for this pass - rewriteCandidates.clear(); - - // PASS 1: Identify candidates without modifying the graph - root.resetVisitStatus(); - findRewriteCandidates(root); - - // PASS 2: Apply rewrites to identified candidates - for (Hop candidate : rewriteCandidates) { - applyTopDownTeeRewrite(candidate); - } - - return root; - } + private boolean forceTee = false; /** * First pass: Find candidates for rewrite without modifying the graph. @@ -137,6 +85,35 @@ private void findRewriteCandidates(Hop hop) { findRewriteCandidates(input); } + boolean isRewriteCandidate = DMLScript.USE_OOC + && hop.getDataType().isMatrix() + && !HopRewriteUtils.isData(hop, OpOpData.TEE) + && hop.getParent().size() > 1 + && (!APPLY_ONLY_XtX_PATTERN || isSelfTranposePattern(hop)); + + if (HopRewriteUtils.isData(hop, OpOpData.TRANSIENTREAD) && hop.getDataType().isMatrix()) { + _transientVars.compute(hop.getName(), (key, ctr) -> { + int incr = (isRewriteCandidate || forceTee) ? 2 : 1; + + int ret = ctr == null ? 
0 : ctr; + ret += incr; + + if (ret > 1) + teeTransientVars.add(hop.getName()); + + return ret; + }); + + _transientHops.compute(hop.getName(), (key, hops) -> { + if (hops == null) + return new ArrayList<>(List.of(hop)); + hops.add(hop); + return hops; + }); + + return; // We do not tee transient reads but rather inject before TWrite or PRead as caching stream + } + // Check if this hop is a candidate for OOC Tee injection if (DMLScript.USE_OOC && hop.getDataType().isMatrix() @@ -160,11 +137,17 @@ private void applyTopDownTeeRewrite(Hop sharedInput) { return; } + int consumerCount = sharedInput.getParent().size(); + if (LOG.isDebugEnabled()) { + LOG.debug("Inject tee for hop " + sharedInput.getHopID() + " (" + + sharedInput.getName() + "), consumers=" + consumerCount); + } + // Take a defensive copy of consumers before modifying the graph ArrayList consumers = new ArrayList<>(sharedInput.getParent()); // Create the new TeeOp with the original hop as input - DataOp teeOp = new DataOp("tee_out_" + sharedInput.getName(), + DataOp teeOp = new DataOp("tee_out_" + sharedInput.getName(), sharedInput.getDataType(), sharedInput.getValueType(), Types.OpOpData.TEE, null, sharedInput.getDim1(), sharedInput.getDim2(), sharedInput.getNnz(), sharedInput.getBlocksize()); HopRewriteUtils.addChildReference(teeOp, sharedInput); @@ -177,6 +160,11 @@ private void applyTopDownTeeRewrite(Hop sharedInput) { // Record that we've handled this hop handledHop.put(sharedInput.getHopID(), teeOp); rewrittenHops.add(sharedInput.getHopID()); + + if (LOG.isDebugEnabled()) { + LOG.debug("Created tee hop " + teeOp.getHopID() + " -> " + + teeOp.getName()); + } } @SuppressWarnings("unused") @@ -196,4 +184,108 @@ else if (HopRewriteUtils.isMatrixMultiply(parent)) { } return hasTransposeConsumer && hasMatrixMultiplyConsumer; } + + @Override + public boolean createsSplitDag() { + return false; + } + + @Override + public List rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) { + if (!DMLScript.USE_OOC) + return List.of(sb); + + rewriteSB(sb, state); + + for (String tVar : teeTransientVars) { + List tHops = _transientHops.get(tVar); + + if (tHops == null) + continue; + + for (Hop affectedHops : tHops) { + applyTopDownTeeRewrite(affectedHops); + } + + tHops.clear(); + } + + removeRedundantTeeChains(sb); + + return List.of(sb); + } + + @Override + public List rewriteStatementBlocks(List sbs, ProgramRewriteStatus state) { + if (!DMLScript.USE_OOC) + return sbs; + + for (StatementBlock sb : sbs) + rewriteSB(sb, state); + + for (String tVar : teeTransientVars) { + List tHops = _transientHops.get(tVar); + + if (tHops == null) + continue; + + for (Hop affectedHops : tHops) { + applyTopDownTeeRewrite(affectedHops); + } + } + + for (StatementBlock sb : sbs) + removeRedundantTeeChains(sb); + + return sbs; + } + + private void rewriteSB(StatementBlock sb, ProgramRewriteStatus state) { + rewriteCandidates.clear(); + + if (sb.getHops() != null) { + for(Hop hop : sb.getHops()) { + hop.resetVisitStatus(); + findRewriteCandidates(hop); + } + } + + for (Hop candidate : rewriteCandidates) { + applyTopDownTeeRewrite(candidate); + } + } + + private void removeRedundantTeeChains(StatementBlock sb) { + if (sb == null || sb.getHops() == null) + return; + + Hop.resetVisitStatus(sb.getHops()); + for (Hop hop : sb.getHops()) + removeRedundantTeeChains(hop); + Hop.resetVisitStatus(sb.getHops()); + } + + private void removeRedundantTeeChains(Hop hop) { + if (hop.isVisited()) + return; + + ArrayList inputs = new ArrayList<>(hop.getInput()); 
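+		// Iterate over a defensive copy: collapsing redundant tees below rewires
+		// parent/child references, which mutates hop.getInput() during recursion.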
+ for (Hop in : inputs) + removeRedundantTeeChains(in); + + if (HopRewriteUtils.isData(hop, OpOpData.TEE) && hop.getInput().size() == 1) { + Hop teeInput = hop.getInput().get(0); + if (HopRewriteUtils.isData(teeInput, OpOpData.TEE)) { + if (LOG.isDebugEnabled()) { + LOG.debug("Remove redundant tee hop " + hop.getHopID() + + " (" + hop.getName() + ") -> " + teeInput.getHopID() + + " (" + teeInput.getName() + ")"); + } + HopRewriteUtils.rewireAllParentChildReferences(hop, teeInput); + HopRewriteUtils.removeAllChildReferences(hop); + } + } + + hop.setVisited(); + } } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java index 34a8aa18631..d826af89c0e 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java @@ -471,12 +471,12 @@ public boolean hasBroadcastHandle() { return _bcHandle != null && _bcHandle.hasBackReference(); } - public OOCStream getStreamHandle() { + public synchronized OOCStream getStreamHandle() { if( !hasStreamHandle() ) { final SubscribableTaskQueue _mStream = new SubscribableTaskQueue<>(); - _streamHandle = _mStream; DataCharacteristics dc = getDataCharacteristics(); MatrixBlock src = (MatrixBlock)acquireReadAndRelease(); + _streamHandle = _mStream; LongStream.range(0, dc.getNumBlocks()) .mapToObj(i -> UtilFunctions.createIndexedMatrixBlock(src, dc, i)) .forEach( blk -> { @@ -489,7 +489,14 @@ public OOCStream getStreamHandle() { _mStream.closeInput(); } - return _streamHandle.getReadStream(); + OOCStream stream = _streamHandle.getReadStream(); + if (!stream.hasStreamCache()) + _streamHandle = null; // To ensure read once + return stream; + } + + public OOCStreamable getStreamable() { + return _streamHandle; } /** @@ -499,7 +506,7 @@ public OOCStream getStreamHandle() { * @return true if existing, false otherwise */ public boolean hasStreamHandle() { - return _streamHandle != null && !_streamHandle.isProcessed(); + return _streamHandle != null; } @SuppressWarnings({ "rawtypes", "unchecked" }) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java index 783981e0f12..50143cd0ad7 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java @@ -45,7 +45,7 @@ public class LocalTaskQueue protected LinkedList _data = null; protected boolean _closedInput = false; - private DMLRuntimeException _failure = null; + protected DMLRuntimeException _failure = null; private static final Log LOG = LogFactory.getLog(LocalTaskQueue.class.getName()); public LocalTaskQueue() diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java index 5dd8e55e821..83421bf5d82 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java @@ -46,6 +46,8 @@ import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.instructions.Instruction; import org.apache.sysds.runtime.instructions.InstructionUtils; +import org.apache.sysds.runtime.instructions.ooc.CachingStream; 
+import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction; import org.apache.sysds.runtime.io.FileFormatProperties; import org.apache.sysds.runtime.io.FileFormatPropertiesCSV; import org.apache.sysds.runtime.io.FileFormatPropertiesHDF5; @@ -1026,6 +1028,9 @@ private void processCopyInstruction(ExecutionContext ec) { if ( dd == null ) throw new DMLRuntimeException("Unexpected error: could not find a data object for variable name:" + getInput1().getName() + ", while processing instruction " +this.toString()); + if (DMLScript.USE_OOC && dd instanceof MatrixObject) + TeeOOCInstruction.incrRef(((MatrixObject)dd).getStreamable(), 1); + // remove existing variable bound to target name Data input2_data = ec.removeVariable(getInput2().getName()); @@ -1117,6 +1122,8 @@ private void processSetFileNameInstruction(ExecutionContext ec){ public static void processRmvarInstruction( ExecutionContext ec, String varname ) { // remove variable from symbol table Data dat = ec.removeVariable(varname); + if (DMLScript.USE_OOC && dat instanceof MatrixObject) + TeeOOCInstruction.incrRef(((MatrixObject) dat).getStreamable(), -1); //cleanup matrix data on fs/hdfs (if necessary) if( dat != null ) ec.cleanupDataObject(dat); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java index d7c80e4de3c..cdc23911516 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java @@ -24,6 +24,7 @@ import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import org.apache.sysds.runtime.matrix.data.MatrixIndexes; +import shaded.parquet.it.unimi.dsi.fastutil.ints.IntArrayList; import java.util.HashMap; import java.util.Map; @@ -39,6 +40,7 @@ public class CachingStream implements OOCStreamable { // original live stream private final OOCStream _source; + private final IntArrayList _consumptionCounts = new IntArrayList(); // stream identifier private final long _streamId; @@ -54,6 +56,10 @@ public class CachingStream implements OOCStreamable { private DMLRuntimeException _failure; + private boolean deletable = false; + private int maxConsumptionCount = 0; + private int cachePins = 0; + public CachingStream(OOCStream source) { this(source, _streamSeq.getNextID()); } @@ -61,23 +67,43 @@ public CachingStream(OOCStream source) { public CachingStream(OOCStream source, long streamId) { _source = source; _streamId = streamId; - source.setSubscriber(() -> { + source.setSubscriber(tmp -> { try { - boolean closed = fetchFromStream(); - Runnable[] mSubscribers = _subscribers; + final IndexedMatrixValue task = tmp.get(); + int blk; + Runnable[] mSubscribers; + + synchronized (this) { + if(task != LocalTaskQueue.NO_MORE_TASKS) { + if (!_cacheInProgress) + throw new DMLRuntimeException("Stream is closed"); + OOCEvictionManager.put(_streamId, _numBlocks, task); + if (_index != null) + _index.put(task.getIndexes(), _numBlocks); + blk = _numBlocks; + _numBlocks++; + _consumptionCounts.add(0); + notifyAll(); + } + else { + _cacheInProgress = false; // caching is complete + notifyAll(); + blk = -1; + } + + mSubscribers = _subscribers; + } if(mSubscribers != null) { for(Runnable mSubscriber : mSubscribers) mSubscriber.run(); - if (closed) { + if (blk == -1) { synchronized (this) { _subscribers = null; } } } - } catch (InterruptedException e) { 
- throw new DMLRuntimeException(e); } catch (DMLRuntimeException e) { // Propagate failure to subscribers _failure = e; @@ -98,25 +124,28 @@ public CachingStream(OOCStream source, long streamId) { }); } - private synchronized boolean fetchFromStream() throws InterruptedException { - if(!_cacheInProgress) - throw new DMLRuntimeException("Stream is closed"); + public synchronized void scheduleDeletion() { + deletable = true; + if (_cacheInProgress && maxConsumptionCount == 0) + throw new DMLRuntimeException("Cannot have a caching stream with no listeners"); + for (int i = 0; i < _consumptionCounts.size(); i++) { + tryDeleteBlock(i); + } + } - IndexedMatrixValue task = _source.dequeue(); + public String toString() { + return "CachingStream@" + _streamId; + } - if(task != LocalTaskQueue.NO_MORE_TASKS) { - OOCEvictionManager.put(_streamId, _numBlocks, task); - if (_index != null) - _index.put(task.getIndexes(), _numBlocks); - _numBlocks++; - notifyAll(); - return false; - } - else { - _cacheInProgress = false; // caching is complete - notifyAll(); - return true; - } + private synchronized void tryDeleteBlock(int i) { + if (cachePins > 0) + return; // Block deletion is prevented + + int count = _consumptionCounts.getInt(i); + if (count > maxConsumptionCount) + throw new DMLRuntimeException("Cannot have more than " + maxConsumptionCount + " consumptions."); + if (count == maxConsumptionCount) + OOCEvictionManager.forget(_streamId, i); } public synchronized IndexedMatrixValue get(int idx) throws InterruptedException { @@ -129,6 +158,16 @@ else if (idx < _numBlocks) { if (_index != null) // Ensure index is up to date _index.putIfAbsent(out.getIndexes(), idx); + int newCount = _consumptionCounts.getInt(idx)+1; + + if (newCount > maxConsumptionCount) + throw new DMLRuntimeException("Consumer overflow! Expected: " + maxConsumptionCount); + + _consumptionCounts.set(idx, newCount); + + if (deletable) + tryDeleteBlock(idx); + return out; } else if (!_cacheInProgress) return (IndexedMatrixValue)LocalTaskQueue.NO_MORE_TASKS; @@ -137,8 +176,31 @@ else if (idx < _numBlocks) { } } + public synchronized int findCachedIndex(MatrixIndexes idx) { + return _index.get(idx); + } + public synchronized IndexedMatrixValue findCached(MatrixIndexes idx) { - return OOCEvictionManager.get(_streamId, _index.get(idx)); + int mIdx = _index.get(idx); + int newCount = _consumptionCounts.getInt(mIdx)+1; + if (newCount > maxConsumptionCount) + throw new DMLRuntimeException("Consumer overflow in " + _streamId + "_" + mIdx + ". Expected: " + maxConsumptionCount); + _consumptionCounts.set(mIdx, newCount); + + IndexedMatrixValue imv = OOCEvictionManager.get(_streamId, mIdx); + + if (deletable) + tryDeleteBlock(mIdx); + + return imv; + } + + /** + * Finds a cached item without counting it as a consumption. 
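+	 * This is used when a block may be read multiple times (e.g., while slicing in
+	 * matrix indexing) and its consumption is accounted for separately via
+	 * incrProcessingCount.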
+ */ + public synchronized IndexedMatrixValue peekCached(MatrixIndexes idx) { + int mIdx = _index.get(idx); + return OOCEvictionManager.get(_streamId, mIdx); } public synchronized void activateIndexing() { @@ -161,12 +223,18 @@ public boolean isProcessed() { return false; } - @Override - public void setSubscriber(Runnable subscriber) { + public void setSubscriber(Runnable subscriber, boolean incrConsumers) { + if (deletable) + throw new DMLRuntimeException("Cannot register a new subscriber on " + this + " because has been flagged for deletion"); + int mNumBlocks; + boolean cacheInProgress; synchronized (this) { mNumBlocks = _numBlocks; - if (_cacheInProgress) { + cacheInProgress = _cacheInProgress; + if (incrConsumers) + maxConsumptionCount++; + if (cacheInProgress) { int newLen = _subscribers == null ? 1 : _subscribers.length + 1; Runnable[] newSubscribers = new Runnable[newLen]; @@ -181,7 +249,44 @@ public void setSubscriber(Runnable subscriber) { for (int i = 0; i < mNumBlocks; i++) subscriber.run(); - if (!_cacheInProgress) + if (!cacheInProgress) subscriber.run(); // To fetch the NO_MORE_TASK element } + + /** + * Artificially increase subscriber count. + * Only use if certain blocks are accessed more than once. + */ + public synchronized void incrSubscriberCount(int count) { + maxConsumptionCount += count; + } + + /** + * Artificially increase the processing count of a block. + */ + public synchronized void incrProcessingCount(int i, int count) { + _consumptionCounts.set(i, _consumptionCounts.getInt(i)+count); + + if (deletable) + tryDeleteBlock(i); + } + + /** + * Force pins blocks in the cache to not be subject to block deletion. + */ + public synchronized void pinStream() { + cachePins++; + } + + /** + * Unpins the stream, allowing blocks to be deleted from cache. 
+ */ + public synchronized void unpinStream() { + cachePins--; + + if (cachePins == 0) { + for (int i = 0; i < _consumptionCounts.size(); i++) + tryDeleteBlock(i); + } + } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/IndexingOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/IndexingOOCInstruction.java index 1d555da8d6c..175d81d6e06 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/IndexingOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/IndexingOOCInstruction.java @@ -115,6 +115,34 @@ public boolean isAligned() { return (_indexRange.rowStart % _blocksize) == 0 && (_indexRange.colStart % _blocksize) == 0; } + public int getNumConsumptions(MatrixIndexes index) { + long blockRow = index.getRowIndex() - 1; + long blockCol = index.getColumnIndex() - 1; + + if(!_blockRange.isWithin(blockRow, blockCol)) + return 0; + + long blockRowStart = blockRow * _blocksize; + long blockRowEnd = blockRowStart + _blocksize - 1; + long blockColStart = blockCol * _blocksize; + long blockColEnd = blockColStart + _blocksize - 1; + + long overlapRowStart = Math.max(_indexRange.rowStart, blockRowStart); + long overlapRowEnd = Math.min(_indexRange.rowEnd, blockRowEnd); + long overlapColStart = Math.max(_indexRange.colStart, blockColStart); + long overlapColEnd = Math.min(_indexRange.colEnd, blockColEnd); + + if(overlapRowStart > overlapRowEnd || overlapColStart > overlapColEnd) + return 0; + + int outRowStart = (int) ((overlapRowStart - _indexRange.rowStart) / _blocksize); + int outRowEnd = (int) ((overlapRowEnd - _indexRange.rowStart) / _blocksize); + int outColStart = (int) ((overlapColStart - _indexRange.colStart) / _blocksize); + int outColEnd = (int) ((overlapColEnd - _indexRange.colStart) / _blocksize); + + return (outRowEnd - outRowStart + 1) * (outColEnd - outColStart + 1); + } + public boolean putNext(MatrixIndexes index, T data, BiConsumer> emitter) { long blockRow = index.getRowIndex() - 1; long blockCol = index.getColumnIndex() - 1; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixIndexingOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixIndexingOOCInstruction.java index a04a77677cd..33c6675051e 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixIndexingOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixIndexingOOCInstruction.java @@ -33,6 +33,7 @@ import org.apache.sysds.runtime.util.IndexRange; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; @@ -43,10 +44,10 @@ public MatrixIndexingOOCInstruction(CPOperand in, CPOperand rl, CPOperand ru, CP super(in, rl, ru, cl, cu, out, opcode, istr); } - protected MatrixIndexingOOCInstruction(CPOperand lhsInput, CPOperand rhsInput, CPOperand rl, CPOperand ru, - CPOperand cl, CPOperand cu, CPOperand out, String opcode, String istr) { - super(lhsInput, rhsInput, rl, ru, cl, cu, out, opcode, istr); - } +// protected MatrixIndexingOOCInstruction(CPOperand lhsInput, CPOperand rhsInput, CPOperand rl, CPOperand ru, +// CPOperand cl, CPOperand cu, CPOperand out, String opcode, String istr) { +// super(lhsInput, rhsInput, rl, ru, cl, cu, out, opcode, istr); +// } @Override public void processInstruction(ExecutionContext ec) { @@ -96,8 +97,9 @@ public void processInstruction(ExecutionContext ec) { final int 
outBlockCols = (int) Math.ceil((double) (ix.colSpan() + 1) / blocksize); final int totalBlocks = outBlockRows * outBlockCols; final AtomicInteger producedBlocks = new AtomicInteger(0); + CompletableFuture future = new CompletableFuture<>(); - CompletableFuture future = filterOOC(qIn, tmp -> { + filterOOC(qIn, tmp -> { MatrixIndexes inIdx = tmp.getIndexes(); long blockRow = inIdx.getRowIndex() - 1; long blockCol = inIdx.getColumnIndex() - 1; @@ -124,12 +126,12 @@ public void processInstruction(ExecutionContext ec) { long outBlockCol = blockCol - firstBlockCol + 1; qOut.enqueue(new IndexedMatrixValue(new MatrixIndexes(outBlockRow, outBlockCol), outBlock)); - if(producedBlocks.incrementAndGet() >= totalBlocks) { - CompletableFuture f = futureRef.get(); - if(f != null) - f.cancel(true); - } + if(producedBlocks.incrementAndGet() >= totalBlocks) + future.complete(null); }, tmp -> { + if (future.isDone()) // Then we may skip blocks and avoid submitting tasks + return false; + long blockRow = tmp.getIndexes().getRowIndex() - 1; long blockCol = tmp.getIndexes().getColumnIndex() - 1; return blockRow >= firstBlockRow && blockRow <= lastBlockRow && blockCol >= firstBlockCol && @@ -139,20 +141,23 @@ public void processInstruction(ExecutionContext ec) { return; } - final BlockAligner aligner = new BlockAligner<>(ix, blocksize); + final BlockAligner aligner = new BlockAligner<>(ix, blocksize); + final ConcurrentHashMap consumptionCounts = new ConcurrentHashMap<>(); // We may need to construct our own intermediate stream to properly manage the cached items boolean hasIntermediateStream = !qIn.hasStreamCache(); final CachingStream cachedStream = hasIntermediateStream ? new CachingStream(new SubscribableTaskQueue<>()) : qOut.getStreamCache(); cachedStream.activateIndexing(); + cachedStream.incrSubscriberCount(1); // We may require re-consumption of blocks (up to 4 times) + final CompletableFuture future = new CompletableFuture<>(); - CompletableFuture future = filterOOC(qIn.getReadStream(), tmp -> { + filterOOC(qIn.getReadStream(), tmp -> { if (hasIntermediateStream) { // We write to an intermediate stream to ensure that these matrix blocks are properly cached cachedStream.getWriteStream().enqueue(tmp); } - boolean completed = aligner.putNext(tmp.getIndexes(), new IndexedBlockMeta(tmp), (idx, sector) -> { + boolean completed = aligner.putNext(tmp.getIndexes(), tmp.getIndexes(), (idx, sector) -> { int targetBlockRow = (int) (idx.getRowIndex() - 1); int targetBlockCol = (int) (idx.getColumnIndex() - 1); @@ -176,18 +181,18 @@ public void processInstruction(ExecutionContext ec) { for(int r = 0; r < rowSegments; r++) { for(int c = 0; c < colSegments; c++) { - IndexedBlockMeta ibm = sector.get(r, c); - if(ibm == null) + MatrixIndexes mIdx = sector.get(r, c); + if(mIdx == null) continue; - IndexedMatrixValue mv = cachedStream.findCached(ibm.idx); + IndexedMatrixValue mv = cachedStream.peekCached(mIdx); MatrixBlock srcBlock = (MatrixBlock) mv.getValue(); if(target == null) target = new MatrixBlock(nRows, nCols, srcBlock.isInSparseFormat()); - long srcBlockRowStart = (ibm.idx.getRowIndex() - 1) * blocksize; - long srcBlockColStart = (ibm.idx.getColumnIndex() - 1) * blocksize; + long srcBlockRowStart = (mIdx.getRowIndex() - 1) * blocksize; + long srcBlockColStart = (mIdx.getColumnIndex() - 1) * blocksize; long sliceRowStartGlobal = Math.max(targetRowStartGlobal, srcBlockRowStart); long sliceRowEndGlobal = Math.min(targetRowEndGlobal, srcBlockRowStart + srcBlock.getNumRows() - 1); @@ -205,21 +210,31 @@ public void 
processInstruction(ExecutionContext ec) { MatrixBlock sliced = srcBlock.slice(sliceRowStart, sliceRowEnd, sliceColStart, sliceColEnd); sliced.putInto(target, targetRowOffset, targetColOffset, true); + final int maxConsumptions = aligner.getNumConsumptions(mIdx); + + Integer con = consumptionCounts.compute(mIdx, (k, v) -> { + if (v == null) + v = 0; + v = v+1; + if (v == maxConsumptions) + return null; + return v; + }); + + if (con == null) + cachedStream.incrProcessingCount(cachedStream.findCachedIndex(mIdx), 1); } } qOut.enqueue(new IndexedMatrixValue(idx, target)); }); - if(completed) { - // All blocks have been processed; we can cancel the future - // Currently, this does not affect processing (predicates prevent task submission anyway). - // However, a cancelled future may allow early file read aborts once implemented. - CompletableFuture f = futureRef.get(); - if(f != null) - f.cancel(true); - } + if(completed) + future.complete(null); }, tmp -> { + if (future.isDone()) // Then we may skip blocks and avoid submitting tasks + return false; + // Pre-filter incoming blocks to avoid unnecessary task submission long blockRow = tmp.getIndexes().getRowIndex() - 1; long blockCol = tmp.getIndexes().getColumnIndex() - 1; @@ -228,8 +243,15 @@ public void processInstruction(ExecutionContext ec) { }, () -> { aligner.close(); qOut.closeInput(); + }, tmp -> { + // If elements are not processed in an existing caching stream, we increment the process counter to allow block deletion + if (!hasIntermediateStream) + cachedStream.incrProcessingCount(cachedStream.findCachedIndex(tmp.getIndexes()), 1); }); futureRef.set(future); + + if (hasIntermediateStream) + cachedStream.scheduleDeletion(); // We can immediately delete blocks after consumption } //left indexing else if(opcode.equalsIgnoreCase(Opcodes.LEFT_INDEX.toString())) { @@ -239,16 +261,4 @@ else if(opcode.equalsIgnoreCase(Opcodes.LEFT_INDEX.toString())) { throw new DMLRuntimeException( "Invalid opcode (" + opcode + ") encountered in MatrixIndexingOOCInstruction."); } - - private static class IndexedBlockMeta { - public final MatrixIndexes idx; - ////public final long nrows; - //public final long ncols; - - public IndexedBlockMeta(IndexedMatrixValue mv) { - this.idx = mv.getIndexes(); - //this.nrows = mv.getValue().getNumRows(); - //this.ncols = mv.getValue().getNumColumns(); - } - } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCEvictionManager.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCEvictionManager.java index f5ae7573b0a..dace1ab9e53 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCEvictionManager.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCEvictionManager.java @@ -84,7 +84,7 @@ public class OOCEvictionManager { // Configuration: OOC buffer limit as percentage of heap - private static final double OOC_BUFFER_PERCENTAGE = 0.15 * 0.01 * 2; // 15% of heap + private static final double OOC_BUFFER_PERCENTAGE = 0.15; // 15% of heap private static final double PARTITION_EVICTION_SIZE = 64 * 1024 * 1024; // 64 MB @@ -170,6 +170,40 @@ private static class BlockEntry { LocalFileUtils.createLocalFileIfNotExist(_spillDir); } + public static void reset() { + TeeOOCInstruction.reset(); + if (!_cache.isEmpty()) { + System.err.println("There are dangling elements in the OOC Eviction cache: " + _cache.size()); + } + _size.set(0); + _cache.clear(); + _spillLocations.clear(); + _partitions.clear(); + _partitionCounter.set(0); + _streamPartitions.clear(); + } + + /** 
+	 * Removes a block from the cache without setting its data to null.
+	 */
+	public static void forget(long streamId, int blockId) {
+		BlockEntry e;
+		synchronized (_cacheLock) {
+			e = _cache.remove(streamId + "_" + blockId);
+		}
+
+		if (e != null) {
+			e.lock.lock();
+			try {
+				if (e.state == BlockState.HOT)
+					_size.addAndGet(-e.size);
+			} finally {
+				e.lock.unlock();
+			}
+			System.out.println("Removed block " + streamId + "_" + blockId + " from cache (idx: " + (e.value != null ? e.value.getIndexes() : "?") + ")");
+		}
+	}
+
 	/**
 	 * Store a block in the OOC cache (serialize once)
 	 */
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
index ca13cfdb2c3..1b6862361ed 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
@@ -34,6 +34,7 @@
 import org.apache.sysds.runtime.util.OOCJoin;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -131,11 +132,15 @@ protected OOCStream<IndexedMatrixValue> createWritableStream() {
 		return new SubscribableTaskQueue<>();
 	}
 
-	protected <T> CompletableFuture<Void> filterOOC(OOCStream<T> qIn, Consumer<T> processor, Function<T, Boolean> predicate, Runnable finalizer) {
+	protected <T> CompletableFuture<Void> filterOOC(OOCStream<T> qIn, Consumer<T> processor, Function<T, Boolean> predicate, Runnable finalizer) {
+		return filterOOC(qIn, processor, predicate, finalizer, null);
+	}
+
+	protected <T> CompletableFuture<Void> filterOOC(OOCStream<T> qIn, Consumer<T> processor, Function<T, Boolean> predicate, Runnable finalizer, Consumer<T> onNotProcessed) {
 		if (_inQueues == null || _outQueues == null)
 			throw new NotImplementedException("filterOOC requires manual specification of all input and output streams for error propagation");
 
-		return submitOOCTasks(qIn, processor, finalizer, predicate);
+		return submitOOCTasks(qIn, processor, finalizer, predicate, onNotProcessed != null ?
(i, tmp) -> onNotProcessed.accept(tmp) : null); } protected CompletableFuture mapOOC(OOCStream qIn, OOCStream qOut, Function mapper) { @@ -163,10 +168,16 @@ protected CompletableFuture broadcastJoinOOC(OOCStream> availableLeftInput = new ConcurrentHashMap<>(); Map availableBroadcastInput = new ConcurrentHashMap<>(); - return submitOOCTasks(List.of(qIn, broadcast), (i, tmp) -> { + CompletableFuture future = submitOOCTasks(List.of(qIn, broadcast), (i, tmp) -> { P key = on.apply(tmp); if (i == 0) { // qIn stream @@ -184,11 +195,22 @@ protected CompletableFuture broadcastJoinOOC(OOCStream CompletableFuture broadcastJoinOOC(OOCStream { + availableBroadcastInput.forEach((k, v) -> { + rightCache.incrProcessingCount(rightCache.findCachedIndex(v.idx), 1); + }); + availableBroadcastInput.clear(); + qOut.closeInput(); + }); + + if (explicitLeftCaching) + leftCache.scheduleDeletion(); + if (explicitRightCaching) + rightCache.scheduleDeletion(); + + return future; } protected static class BroadcastedElement { @@ -244,7 +283,7 @@ public MatrixIndexes getIndex() { public IndexedMatrixValue getValue() { return value; } - }; + } protected CompletableFuture joinOOC(OOCStream qIn1, OOCStream qIn2, OOCStream qOut, BiFunction mapper, Function on) { return joinOOC(qIn1, qIn2, qOut, mapper, on, on); @@ -257,12 +296,18 @@ protected CompletableFuture joinOOC(OOCStream qIn1, OOCStream final CompletableFuture future = new CompletableFuture<>(); + boolean explicitLeftCaching = !qIn1.hasStreamCache(); + boolean explicitRightCaching = !qIn2.hasStreamCache(); + // We need to construct our own stream to properly manage the cached items in the hash join - CachingStream leftCache = qIn1.hasStreamCache() ? qIn1.getStreamCache() : new CachingStream((SubscribableTaskQueue)qIn1); // We have to assume this generic type for now - CachingStream rightCache = qIn2.hasStreamCache() ? qIn2.getStreamCache() : new CachingStream((SubscribableTaskQueue)qIn2); // We have to assume this generic type for now + CachingStream leftCache = explicitLeftCaching ? new CachingStream((OOCStream) qIn1) : qIn1.getStreamCache(); + CachingStream rightCache = explicitRightCaching ? 
new CachingStream((OOCStream) qIn2) : qIn2.getStreamCache(); leftCache.activateIndexing(); rightCache.activateIndexing(); + leftCache.incrSubscriberCount(1); + rightCache.incrSubscriberCount(1); + final OOCJoin join = new OOCJoin<>((idx, left, right) -> { T leftObj = (T) leftCache.findCached(left); T rightObj = (T) rightCache.findCached(right); @@ -280,36 +325,40 @@ protected CompletableFuture joinOOC(OOCStream qIn1, OOCStream future.complete(null); }); + if (explicitLeftCaching) + leftCache.scheduleDeletion(); + if (explicitRightCaching) + rightCache.scheduleDeletion(); + return future; } protected CompletableFuture submitOOCTasks(final List> queues, BiConsumer consumer, Runnable finalizer) { + return submitOOCTasks(queues, consumer, finalizer, null); + } + + protected CompletableFuture submitOOCTasks(final List> queues, BiConsumer consumer, Runnable finalizer, BiConsumer onNotProcessed) { List> futures = new ArrayList<>(queues.size()); for (int i = 0; i < queues.size(); i++) futures.add(new CompletableFuture<>()); - return submitOOCTasks(queues, consumer, finalizer, futures, null); + return submitOOCTasks(queues, consumer, finalizer, futures, null, onNotProcessed); } - protected CompletableFuture submitOOCTasks(final List> queues, BiConsumer consumer, Runnable finalizer, List> futures, BiFunction predicate) { + protected CompletableFuture submitOOCTasks(final List> queues, BiConsumer consumer, Runnable finalizer, List> futures, BiFunction predicate, BiConsumer onNotProcessed) { addInStream(queues.toArray(OOCStream[]::new)); ExecutorService pool = CommonThreadPool.get(); final List activeTaskCtrs = new ArrayList<>(queues.size()); - final List streamsClosed = new ArrayList<>(queues.size()); - for (int i = 0; i < queues.size(); i++) { - activeTaskCtrs.add(new AtomicInteger(0)); - streamsClosed.add(new AtomicBoolean(false)); - } + for (int i = 0; i < queues.size(); i++) + activeTaskCtrs.add(new AtomicInteger(1)); - final AtomicInteger globalTaskCtr = new AtomicInteger(0); final CompletableFuture globalFuture = CompletableFuture.allOf(futures.toArray(CompletableFuture[]::new)); if (_outQueues == null) _outQueues = Collections.emptySet(); final Runnable oocFinalizer = oocTask(finalizer, null, Stream.concat(_outQueues.stream(), _inQueues.stream()).toArray(OOCStream[]::new)); - final Object globalLock = new Object(); int i = 0; @SuppressWarnings("unused") @@ -319,84 +368,67 @@ protected CompletableFuture submitOOCTasks(final List> qu for (OOCStream queue : queues) { final int k = i; final AtomicInteger localTaskCtr = activeTaskCtrs.get(k); - final AtomicBoolean localStreamClosed = streamsClosed.get(k); final CompletableFuture localFuture = futures.get(k); + final AtomicBoolean closeRaceWatchdog = new AtomicBoolean(false); //System.out.println("Substream (k " + k + ", id " + streamId + ", type '" + queue.getClass().getSimpleName() + "', stream_id " + queue.hashCode() + ")"); - queue.setSubscriber(oocTask(() -> { - final T item = queue.dequeue(); + queue.setSubscriber(oocTask(callback -> { + final T item = callback.get(); - if (predicate != null && item != null && !predicate.apply(k, item)) // Can get closed due to cancellation - return; + if(item == null) { + if(!closeRaceWatchdog.compareAndSet(false, true)) + throw new DMLRuntimeException("Race condition observed: NO_MORE_TASKS callback has been triggered more than once"); - synchronized (globalLock) { - if (localFuture.isDone()) - return; + if(localTaskCtr.decrementAndGet() == 0) { + // Then we can run the finalization procedure already + 
localFuture.complete(null); + } + return; + } - globalTaskCtr.incrementAndGet(); + if(predicate != null && !predicate.apply(k, item)) { // Can get closed due to cancellation + if(onNotProcessed != null) + onNotProcessed.accept(k, item); + return; } - localTaskCtr.incrementAndGet(); + if(localFuture.isDone()) { + if(onNotProcessed != null) + onNotProcessed.accept(k, item); + return; + } + else { + localTaskCtr.incrementAndGet(); + } pool.submit(oocTask(() -> { - if(item != null) { - //System.out.println("Accept" + ((IndexedMatrixValue)item).getIndexes() + " (k " + k + ", id " + streamId + ")"); - consumer.accept(k, item); - } - else { - //System.out.println("Close substream (k " + k + ", id " + streamId + ")"); - localStreamClosed.set(true); - } - - boolean runFinalizer = false; - - synchronized (globalLock) { - int localTasks = localTaskCtr.decrementAndGet(); - boolean finalizeStream = localTasks == 0 && localStreamClosed.get(); - - int globalTasks = globalTaskCtr.get() - 1; - - if (finalizeStream || (globalFuture.isDone() && localTasks == 0)) { - localFuture.complete(null); + // TODO For caching streams, we have no guarantee that item is still in memory -> NullPointer possible + consumer.accept(k, item); - if (globalFuture.isDone() && globalTasks == 0) - runFinalizer = true; - } - - globalTaskCtr.decrementAndGet(); - } - - if (runFinalizer) - oocFinalizer.run(); + if(localTaskCtr.decrementAndGet() == 0) + localFuture.complete(null); }, localFuture, Stream.concat(_outQueues.stream(), _inQueues.stream()).toArray(OOCStream[]::new))); + + if(closeRaceWatchdog.get()) // Sanity check + throw new DMLRuntimeException("Race condition observed"); }, null, Stream.concat(_outQueues.stream(), _inQueues.stream()).toArray(OOCStream[]::new))); i++; } - pool.shutdown(); - globalFuture.whenComplete((res, e) -> { - if (globalFuture.isCancelled() || globalFuture.isCompletedExceptionally()) + if (globalFuture.isCancelled() || globalFuture.isCompletedExceptionally()) { futures.forEach(f -> { - if (!f.isDone()) { - if (globalFuture.isCancelled() || globalFuture.isCompletedExceptionally()) + if(!f.isDone()) { + if(globalFuture.isCancelled() || globalFuture.isCompletedExceptionally()) f.cancel(true); else f.complete(null); } }); - - boolean runFinalizer; - - synchronized (globalLock) { - runFinalizer = globalTaskCtr.get() == 0; } - if (runFinalizer) - oocFinalizer.run(); - - //System.out.println("Shutdown (id " + streamId + ")"); + oocFinalizer.run(); }); return globalFuture; } @@ -405,8 +437,8 @@ protected CompletableFuture submitOOCTasks(OOCStream queue, Consume return submitOOCTasks(List.of(queue), (i, tmp) -> consumer.accept(tmp), finalizer); } - protected CompletableFuture submitOOCTasks(OOCStream queue, Consumer consumer, Runnable finalizer, Function predicate) { - return submitOOCTasks(List.of(queue), (i, tmp) -> consumer.accept(tmp), finalizer, List.of(new CompletableFuture()), (i, tmp) -> predicate.apply(tmp)); + protected CompletableFuture submitOOCTasks(OOCStream queue, Consumer consumer, Runnable finalizer, Function predicate, BiConsumer onNotProcessed) { + return submitOOCTasks(List.of(queue), (i, tmp) -> consumer.accept(tmp), finalizer, List.of(new CompletableFuture()), (i, tmp) -> predicate.apply(tmp), onNotProcessed); } protected CompletableFuture submitOOCTask(Runnable r, OOCStream... queues) { @@ -450,6 +482,31 @@ private Runnable oocTask(Runnable r, CompletableFuture future, OOCStream< }; } + private Consumer> oocTask(Consumer> c, CompletableFuture future, OOCStream... 
queues) {
+		return callback -> {
+			try {
+				c.accept(callback);
+			}
+			catch (Exception ex) {
+				DMLRuntimeException re = ex instanceof DMLRuntimeException ? (DMLRuntimeException) ex : new DMLRuntimeException(ex);
+
+				if (_failed) // To avoid infinite cycles
+					throw re;
+
+				_failed = true;
+
+				for (OOCStream q : queues)
+					q.propagateFailure(re);
+
+				if (future != null)
+					future.completeExceptionally(re);
+
+				// Rethrow to ensure proper future handling
+				throw re;
+			}
+		};
+	}
+
 	/**
 	 * Tracks blocks and their counts to enable early emission
 	 * once all blocks for a given index are processed.
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java
index 1a12cb138b7..f02c847e055 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java
@@ -22,6 +22,8 @@
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
 
+import java.util.function.Consumer;
+
 public interface OOCStream<T> extends OOCStreamable<T> {
 
 	void enqueue(T t);
@@ -36,4 +38,28 @@ public interface OOCStream<T> extends OOCStreamable<T> {
 	boolean hasStreamCache();
 
 	CachingStream getStreamCache();
+
+	/**
+	 * Registers a new subscriber that consumes the stream.
+	 * While there is no ordering guarantee for regular items, the callback for the closing item
+	 * LocalTaskQueue.NO_MORE_TASKS is guaranteed to be invoked only after every other item has
+	 * finished processing. Thus, the NO_MORE_TASKS callback can be used to free dependent
+	 * resources and close output streams.
+	 */
+	void setSubscriber(Consumer<QueueCallback<T>> subscriber);
+
+	class QueueCallback<T> {
+		private final T _result;
+		private final DMLRuntimeException _failure;
+
+		public QueueCallback(T result, DMLRuntimeException failure) {
+			_result = result;
+			_failure = failure;
+		}
+
+		public T get() {
+			if (_failure != null)
+				throw _failure;
+			return _result;
+		}
+	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStreamable.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStreamable.java
index bdc4086bdcd..af2c0afa660 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStreamable.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStreamable.java
@@ -25,6 +25,4 @@ public interface OOCStreamable<T> {
 	OOCStream<T> getWriteStream();
 
 	boolean isProcessed();
-
-	void setSubscriber(Runnable subscriber);
 }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCWatchdog.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCWatchdog.java
new file mode 100644
index 00000000000..302f57b63ba
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCWatchdog.java
@@ -0,0 +1,76 @@
+package org.apache.sysds.runtime.instructions.ooc;
+
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * Watchdog to help debug OOC streams/tasks that never close.
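+ * Disabled by default via the WATCH flag; when enabled, a daemon thread
+ * periodically scans all registered streams and reports entries that have
+ * remained open beyond the staleness threshold.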
+ */ +public final class OOCWatchdog { + public static final boolean WATCH = false; + private static final ConcurrentHashMap OPEN = new ConcurrentHashMap<>(); + private static final ScheduledExecutorService EXEC = + Executors.newSingleThreadScheduledExecutor(r -> { + Thread t = new Thread(r, "TemporaryWatchdog"); + t.setDaemon(true); + return t; + }); + + private static final long STALE_MS = TimeUnit.SECONDS.toMillis(10); + private static final long SCAN_INTERVAL_MS = TimeUnit.SECONDS.toMillis(10); + + static { + EXEC.scheduleAtFixedRate(OOCWatchdog::scan, SCAN_INTERVAL_MS, SCAN_INTERVAL_MS, TimeUnit.MILLISECONDS); + } + + private OOCWatchdog() { + // no-op + } + + public static void registerOpen(String id, String desc, String context, OOCStream stream) { + OPEN.put(id, new Entry(desc, context, System.currentTimeMillis(), stream)); + } + + public static void addEvent(String id, String eventMsg) { + Entry e = OPEN.get(id); + if (e != null) + e.events.add(eventMsg); + } + + public static void registerClose(String id) { + OPEN.remove(id); + } + + private static void scan() { + long now = System.currentTimeMillis(); + for (Map.Entry e : OPEN.entrySet()) { + if (now - e.getValue().openedAt >= STALE_MS) { + if (e.getValue().events.isEmpty()) + continue; // Probably just a stream that has no consumer (remains to be checked why this can happen) + System.err.println("[TemporaryWatchdog] Still open after " + (now - e.getValue().openedAt) + "ms: " + + e.getKey() + " (" + e.getValue().desc + ")" + + (e.getValue().context != null ? " ctx=" + e.getValue().context : "")); + } + } + } + + private static class Entry { + final String desc; + final String context; + final long openedAt; + final OOCStream stream; + ConcurrentLinkedQueue events; + + Entry(String desc, String context, long openedAt, OOCStream stream) { + this.desc = desc; + this.context = context; + this.openedAt = openedAt; + this.stream = stream; + this.events = new ConcurrentLinkedQueue<>(); + } + } +} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ParameterizedBuiltinOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ParameterizedBuiltinOOCInstruction.java index e56d32e4401..d70fc3ccb94 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ParameterizedBuiltinOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ParameterizedBuiltinOOCInstruction.java @@ -20,7 +20,6 @@ package org.apache.sysds.runtime.instructions.ooc; import org.apache.commons.lang3.NotImplementedException; -import org.apache.commons.lang3.mutable.MutableObject; import org.apache.sysds.common.Opcodes; import org.apache.sysds.common.Types; import org.apache.sysds.runtime.DMLRuntimeException; @@ -43,7 +42,6 @@ import java.util.LinkedHashMap; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; -import java.util.concurrent.atomic.AtomicBoolean; public class ParameterizedBuiltinOOCInstruction extends ComputationOOCInstruction { @@ -110,29 +108,26 @@ else if(instOpcode.equalsIgnoreCase(Opcodes.CONTAINS.toString())) { Data finalPattern = pattern; - AtomicBoolean found = new AtomicBoolean(false); + addInStream(qIn); + addOutStream(); // This instruction has no output stream - MutableObject> futureRef = new MutableObject<>(); - CompletableFuture future = submitOOCTasks(qIn, tmp -> { - boolean contains = ((MatrixBlock)tmp.getValue()).containsValue(((ScalarObject)finalPattern).getDoubleValue()); + CompletableFuture future = new CompletableFuture<>(); 
- if (contains) { - found.set(true); + filterOOC(qIn, tmp -> { + boolean contains = ((MatrixBlock)tmp.getValue()).containsValue(((ScalarObject)finalPattern).getDoubleValue()); - // Now we may complete the future - if (futureRef.getValue() != null) - futureRef.getValue().complete(null); - } - }, () -> {}); - futureRef.setValue(future); + if (contains) + future.complete(true); + }, tmp -> !future.isDone(), // Don't start a separate worker if result already known + () -> future.complete(false)); // Then the pattern was not found + boolean ret; try { - futureRef.getValue().get(); + ret = future.get(); } catch (InterruptedException | ExecutionException e) { throw new DMLRuntimeException(e); } - boolean ret = found.get(); ec.setScalarOutput(output.getName(), new BooleanObject(ret)); } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/PlaybackStream.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/PlaybackStream.java index 6edc4ecf270..5b996da0dbe 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/PlaybackStream.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/PlaybackStream.java @@ -23,13 +23,22 @@ import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Consumer; + public class PlaybackStream implements OOCStream, OOCStreamable { private final CachingStream _streamCache; - private int _streamIdx; + private final AtomicInteger _streamIdx; + private final AtomicInteger _taskCtr; + private final AtomicBoolean _subscriberSet; public PlaybackStream(CachingStream streamCache) { this._streamCache = streamCache; - this._streamIdx = 0; + this._streamIdx = new AtomicInteger(0); + this._taskCtr = new AtomicInteger(1); + this._subscriberSet = new AtomicBoolean(false); + streamCache.incrSubscriberCount(1); } @Override @@ -44,15 +53,29 @@ public void closeInput() { @Override public LocalTaskQueue toLocalTaskQueue() { - final SubscribableTaskQueue q = new SubscribableTaskQueue<>(); - setSubscriber(() -> q.enqueue(dequeue())); + final LocalTaskQueue q = new LocalTaskQueue<>(); + setSubscriber(val -> { + if (val.get() == null) { + q.closeInput(); + return; + } + try { + q.enqueueTask(val.get()); + } + catch(InterruptedException e) { + throw new RuntimeException(e); + } + }); return q; } @Override - public synchronized IndexedMatrixValue dequeue() { + public IndexedMatrixValue dequeue() { + if (_subscriberSet.get()) + throw new IllegalStateException("Cannot dequeue from a playback stream if a subscriber has been set"); + try { - return _streamCache.get(_streamIdx++); + return _streamCache.get(_streamIdx.getAndIncrement()); } catch (InterruptedException e) { throw new DMLRuntimeException(e); } @@ -74,8 +97,35 @@ public boolean isProcessed() { } @Override - public void setSubscriber(Runnable subscriber) { - _streamCache.setSubscriber(subscriber); + public void setSubscriber(Consumer> subscriber) { + if (!_subscriberSet.compareAndSet(false, true)) + throw new IllegalArgumentException("Subscriber cannot be set multiple times"); + + /** + * To guarantee that NO_MORE_TASKS is invoked after all subscriber calls + * finished, we keep track of running tasks using a task counter. 
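+		 * The counter starts at one and every delivery is balanced (increment
+		 * before, decrement after), while the closing element decrements twice,
+		 * so the terminal NO_MORE_TASKS callback fires exactly once after all
+		 * deliveries have drained.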
+ */ + _streamCache.setSubscriber(() -> { + try { + _taskCtr.incrementAndGet(); + + IndexedMatrixValue val; + + try { + val = _streamCache.get(_streamIdx.getAndIncrement()); + } catch (InterruptedException e) { + throw new DMLRuntimeException(e); + } + + if (val != null) + subscriber.accept(new QueueCallback<>(val, null)); + + if (_taskCtr.addAndGet(val == null ? -2 : -1) == 0) + subscriber.accept(new QueueCallback<>(null, null)); + } catch (DMLRuntimeException e) { + subscriber.accept(new QueueCallback<>(null, e)); + } + }, false); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java index f136ffc2bb6..7563d8471b6 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java @@ -22,80 +22,172 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; +import java.util.LinkedList; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Consumer; + public class SubscribableTaskQueue extends LocalTaskQueue implements OOCStream { - private Runnable _subscriber; - @Override - public synchronized void enqueue(T t) { - try { - super.enqueueTask(t); - } - catch (InterruptedException e) { - throw new DMLRuntimeException(e); + private final AtomicInteger _availableCtr = new AtomicInteger(1); + private final AtomicBoolean _closed = new AtomicBoolean(false); + private volatile Consumer> _subscriber = null; + private String _watchdogId; + + public SubscribableTaskQueue() { + if (OOCWatchdog.WATCH) { + _watchdogId = "STQ-" + hashCode(); + // Capture a short context to help identify origin + OOCWatchdog.registerOpen(_watchdogId, "SubscribableTaskQueue@" + hashCode(), getCtxMsg(), this); } + } - if(_subscriber != null) - _subscriber.run(); + private String getCtxMsg() { + StackTraceElement[] st = new Exception().getStackTrace(); + // Skip the first few frames (constructor, createWritableStream, etc.) 
+ StringBuilder sb = new StringBuilder(); + int limit = Math.min(st.length, 7); + for(int i = 2; i < limit; i++) { + sb.append(st[i].getClassName()).append(".").append(st[i].getMethodName()).append(":") + .append(st[i].getLineNumber()); + if(i < limit - 1) + sb.append(" <- "); + } + return sb.toString(); } @Override - public T dequeue() { - try { - return super.dequeueTask(); + public void enqueue(T t) { + if (t == NO_MORE_TASKS) + throw new DMLRuntimeException("Cannot enqueue NO_MORE_TASKS item"); + + int cnt = _availableCtr.incrementAndGet(); + + if (cnt <= 1) { // Then the queue was already closed and we disallow further enqueues + _availableCtr.decrementAndGet(); // Undo increment + throw new DMLRuntimeException("Cannot enqueue into closed SubscribableTaskQueue"); } - catch (InterruptedException e) { - throw new DMLRuntimeException(e); + + Consumer> s = _subscriber; + + if (s != null) { + s.accept(new QueueCallback<>(t, _failure)); + onDeliveryFinished(); + return; + } + + synchronized (this) { + // Re-check that subscriber is really null to avoid race conditions + if (_subscriber == null) { + try { + super.enqueueTask(t); + } + catch(InterruptedException e) { + throw new DMLRuntimeException(e); + } + return; + } + // Otherwise do not insert and re-schedule subscriber invocation + s = _subscriber; } + + // Last case if due to race a subscriber has been set + s.accept(new QueueCallback<>(t, _failure)); + onDeliveryFinished(); } @Override - public synchronized void closeInput() { - super.closeInput(); - - if(_subscriber != null) { - _subscriber.run(); - _subscriber = null; - } + public void enqueueTask(T t) { + enqueue(t); } @Override - public LocalTaskQueue toLocalTaskQueue() { - return this; + public T dequeue() { + try { + if (OOCWatchdog.WATCH) + OOCWatchdog.addEvent(_watchdogId, "dequeue -- " + getCtxMsg()); + T deq = super.dequeueTask(); + if (deq != NO_MORE_TASKS) + onDeliveryFinished(); + return deq; + } + catch(InterruptedException e) { + throw new DMLRuntimeException(e); + } } @Override - public OOCStream getReadStream() { - return this; + public T dequeueTask() { + return dequeue(); } @Override - public OOCStream getWriteStream() { - return this; + public void closeInput() { + if (_closed.compareAndSet(false, true)) { + super.closeInput(); + onDeliveryFinished(); + } else { + throw new IllegalStateException("Multiple close input calls"); + } } @Override - public void setSubscriber(Runnable subscriber) { - int queueSize; + public void setSubscriber(Consumer> subscriber) { + if(subscriber == null) + throw new IllegalArgumentException("Cannot set subscriber to null"); - synchronized (this) { + LinkedList data; + + synchronized(this) { if(_subscriber != null) throw new DMLRuntimeException("Cannot set multiple subscribers"); - _subscriber = subscriber; - queueSize = _data.size(); - queueSize += _closedInput ? 
1 : 0; // To trigger the NO_MORE_TASK element + if(_failure != null) + throw _failure; + data = _data; + _data = new LinkedList<>(); + } + + for (T t : data) { + subscriber.accept(new QueueCallback<>(t, _failure)); + onDeliveryFinished(); } + } - for (int i = 0; i < queueSize; i++) - subscriber.run(); + private void onDeliveryFinished() { + int ctr = _availableCtr.decrementAndGet(); + + if (ctr == 0) { + Consumer> s = _subscriber; + if (s != null) + s.accept(new QueueCallback<>((T) LocalTaskQueue.NO_MORE_TASKS, _failure)); + + if (OOCWatchdog.WATCH) + OOCWatchdog.registerClose(_watchdogId); + } } @Override public synchronized void propagateFailure(DMLRuntimeException re) { super.propagateFailure(re); + Consumer> s = _subscriber; + if(s != null) + s.accept(new QueueCallback<>(null, re)); + } + + @Override + public LocalTaskQueue toLocalTaskQueue() { + return this; + } + + @Override + public OOCStream getReadStream() { + return this; + } - if(_subscriber != null) - _subscriber.run(); + @Override + public OOCStream getWriteStream() { + return this; } @Override diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java index fd80b4e6e90..aba36297e7f 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java @@ -25,8 +25,37 @@ import org.apache.sysds.runtime.instructions.cp.CPOperand; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import java.util.concurrent.ConcurrentHashMap; + public class TeeOOCInstruction extends ComputationOOCInstruction { + private static final ConcurrentHashMap refCtr = new ConcurrentHashMap<>(); + + public static void reset() { + if (!refCtr.isEmpty()) { + System.err.println("There are some dangling streams still in the cache: " + refCtr); + refCtr.clear(); + } + } + + /** + * Increments the reference counter of a stream by the set amount. + */ + public static void incrRef(OOCStreamable stream, int incr) { + if (!(stream instanceof CachingStream)) + return; + + Integer ref = refCtr.compute((CachingStream)stream, (k, v) -> { + if (v == null) + v = 0; + v += incr; + return v <= 0 ? null : v; + }); + + if (ref == null) + ((CachingStream)stream).scheduleDeletion(); + } + protected TeeOOCInstruction(OOCType type, CPOperand in1, CPOperand out, String opcode, String istr) { super(type, null, in1, out, opcode, istr); } @@ -45,9 +74,20 @@ public void processInstruction( ExecutionContext ec ) { MatrixObject min = ec.getMatrixObject(input1); OOCStream qIn = min.getStreamHandle(); + CachingStream handle = qIn.hasStreamCache() ? 
qIn.getStreamCache() : new CachingStream(qIn); + + if (!qIn.hasStreamCache()) { + // We also set the input stream handle + min.setStreamHandle(handle); + incrRef(handle, 2); + } + else { + incrRef(handle, 1); + } + //get output and create new resettable stream MatrixObject mo = ec.getMatrixObject(output); - mo.setStreamHandle(new CachingStream(qIn)); + mo.setStreamHandle(handle); mo.setMetaData(min.getMetaData()); } } diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/PCATest.java b/src/test/java/org/apache/sysds/test/functions/ooc/PCATest.java new file mode 100644 index 00000000000..e20b7ec4269 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/ooc/PCATest.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.ooc; + +import org.apache.sysds.common.Types; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.io.MatrixWriter; +import org.apache.sysds.runtime.io.MatrixWriterFactory; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.runtime.util.HDFSTool; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +import java.io.IOException; + +public class PCATest extends AutomatedTestBase { + private final static String TEST_NAME1 = "PCA"; + private final static String TEST_DIR = "functions/ooc/"; + private final static String TEST_CLASS_DIR = TEST_DIR + PCATest.class.getSimpleName() + "/"; + //private final static double eps = 1e-8; + private static final String INPUT_NAME_1 = "X"; + private static final String OUTPUT_NAME_1 = "PC"; + private static final String OUTPUT_NAME_2 = "V"; + + private final static int rows = 50000; + private final static int cols = 1000; + private final static int maxVal = 2; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1); + addTestConfiguration(TEST_NAME1, config); + } + + @Test + public void testPCA() { + boolean allow_opfusion = OptimizerUtils.ALLOW_OPERATOR_FUSION; + OptimizerUtils.ALLOW_OPERATOR_FUSION = false; // some fused ops are not implemented yet + runPCATest(16); + OptimizerUtils.ALLOW_OPERATOR_FUSION = allow_opfusion; + } + + private void runPCATest(int k) { + Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE); + + try { + getAndLoadTestConfiguration(TEST_NAME1); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME1 + ".dml"; + programArgs = new String[] {"-explain", "hops", 
"-stats", "-ooc", "-args", input(INPUT_NAME_1), Integer.toString(k), output(OUTPUT_NAME_1), output(OUTPUT_NAME_2)}; + + // 1. Generate the data in-memory as MatrixBlock objects + double[][] X_data = getRandomMatrix(rows, cols, 0, maxVal, 1, 7); + + // 2. Convert the double arrays to MatrixBlock objects + MatrixBlock X_mb = DataConverter.convertToMatrixBlock(X_data); + + // 3. Create a binary matrix writer + MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY); + + // 4. Write matrix A to a binary SequenceFile + writer.writeMatrixToHDFS(X_mb, input(INPUT_NAME_1), rows, cols, 1000, X_mb.getNonZeros()); + HDFSTool.writeMetaDataFile(input(INPUT_NAME_1 + ".mtd"), Types.ValueType.FP64, + new MatrixCharacteristics(rows, cols, 1000, X_mb.getNonZeros()), Types.FileFormat.BINARY); + X_data = null; + X_mb = null; + + runTest(true, false, null, -1); + + //check replace OOC op + //Assert.assertTrue("OOC wasn't used for replacement", + // heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.REPLACE)); + + //compare results + + // rerun without ooc flag + programArgs = new String[] {"-explain", "hops", "-stats", "-args", input(INPUT_NAME_1), Integer.toString(k), output(OUTPUT_NAME_1 + "_target"), output(OUTPUT_NAME_2 + "_target")}; + runTest(true, false, null, -1); + + // compare matrices + /*MatrixBlock ret1 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME_1), + Types.FileFormat.BINARY, rows, cols, 1000); + MatrixBlock ret2 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME_1 + "_target"), + Types.FileFormat.BINARY, rows, cols, 1000); + TestUtils.compareMatrices(ret1, ret2, eps); + + MatrixBlock ret2_1 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME_2), + Types.FileFormat.BINARY, rows, cols, 1000); + MatrixBlock ret2_2 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME_2 + "_target"), + Types.FileFormat.BINARY, rows, cols, 1000); + TestUtils.compareMatrices(ret2_1, ret2_2, eps);*/ + } + catch(IOException e) { + throw new RuntimeException(e); + } + finally { + resetExecMode(platformOld); + } + } +} diff --git a/src/test/scripts/functions/ooc/PCA.dml b/src/test/scripts/functions/ooc/PCA.dml new file mode 100644 index 00000000000..567d701ec06 --- /dev/null +++ b/src/test/scripts/functions/ooc/PCA.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +X = read($1); +k = $2; + +[PC, V] = pca(X=X, K=k) + +write(PC, $3, format="binary"); +write(V, $4, format="binary"); From f84236f82cb7f60e1dcce171f637beb580558822 Mon Sep 17 00:00:00 2001 From: Jannik Lindemann Date: Thu, 11 Dec 2025 14:08:05 +0100 Subject: [PATCH 2/5] Add Missing License --- .../runtime/instructions/ooc/OOCWatchdog.java | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCWatchdog.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCWatchdog.java index 302f57b63ba..b7a16778ab7 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCWatchdog.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCWatchdog.java @@ -1,3 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + package org.apache.sysds.runtime.instructions.ooc; import java.util.Map; From 30ff515fa7bc67fb750d28651fa49a4ebc5f9d10 Mon Sep 17 00:00:00 2001 From: Jannik Lindemann Date: Mon, 1 Dec 2025 10:41:51 +0100 Subject: [PATCH 3/5] [Major] Refactor Cache Manager / Performance Improvements / Better OOC Statistics --- .../java/org/apache/sysds/api/DMLOptions.java | 51 +- .../java/org/apache/sysds/api/DMLScript.java | 17 +- .../apache/sysds/api/ScriptExecutorUtils.java | 4 + .../sysds/hops/rewrite/ProgramRewriter.java | 4 +- .../runtime/compress/io/WriterCompressed.java | 4 +- .../controlprogram/caching/CacheableData.java | 4 +- .../controlprogram/caching/FrameObject.java | 3 +- .../controlprogram/caching/MatrixObject.java | 8 +- .../controlprogram/caching/TensorObject.java | 4 +- .../controlprogram/parfor/LocalTaskQueue.java | 4 + .../instructions/OOCInstructionParser.java | 6 + .../instructions/cp/CPInstruction.java | 8 + .../ooc/AggregateTernaryOOCInstruction.java | 231 +++++++ .../ooc/AggregateUnaryOOCInstruction.java | 128 ++-- .../instructions/ooc/CachingStream.java | 244 +++++-- .../ooc/MatrixIndexingOOCInstruction.java | 49 +- .../ooc/MatrixVectorBinaryOOCInstruction.java | 44 +- .../instructions/ooc/OOCEvictionManager.java | 491 -------------- .../instructions/ooc/OOCInstruction.java | 366 ++++++++--- .../runtime/instructions/ooc/OOCStream.java | 50 +- .../runtime/instructions/ooc/OOCWatchdog.java | 11 +- .../instructions/ooc/PlaybackStream.java | 58 +- .../ooc/SubscribableTaskQueue.java | 15 +- .../ooc/TernaryOOCInstruction.java | 204 ++++++ .../apache/sysds/runtime/io/MatrixWriter.java | 4 +- .../sysds/runtime/io/WriterBinaryBlock.java | 5 +- .../apache/sysds/runtime/io/WriterHDF5.java | 4 +- .../sysds/runtime/io/WriterMatrixMarket.java | 4 +- .../sysds/runtime/io/WriterTextCSV.java | 4 +- .../sysds/runtime/io/WriterTextCell.java | 4 +- 
 .../sysds/runtime/io/WriterTextLIBSVM.java | 4 +-
 .../sysds/runtime/ooc/cache/BlockEntry.java | 107 ++++
 .../sysds/runtime/ooc/cache/BlockKey.java | 67 ++
 .../sysds/runtime/ooc/cache/BlockState.java | 49 ++
 .../runtime/ooc/cache/CloseableQueue.java | 98 +++
 .../runtime/ooc/cache/OOCCacheManager.java | 194 ++++++
 .../runtime/ooc/cache/OOCCacheScheduler.java | 81 +++
 .../sysds/runtime/ooc/cache/OOCIOHandler.java | 32 +
 .../ooc/cache/OOCLRUCacheScheduler.java | 605 ++++++++++++++++++
 .../runtime/ooc/cache/OOCMatrixIOHandler.java | 359 +++++++++++
 .../sysds/runtime/ooc/stats/OOCEventLog.java | 179 ++++++
 .../sysds/runtime/util/LocalFileUtils.java | 1 -
 .../org/apache/sysds/utils/Statistics.java | 157 +++++
 .../sysds/test/functions/ooc/LmCGTest.java | 126 ++++
 .../sysds/test/functions/ooc/PCATest.java | 7 +-
 src/test/scripts/functions/ooc/lmCG.dml | 26 +
 46 files changed, 3255 insertions(+), 870 deletions(-)
 create mode 100644 src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateTernaryOOCInstruction.java
 delete mode 100644 src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCEvictionManager.java
 create mode 100644 src/main/java/org/apache/sysds/runtime/instructions/ooc/TernaryOOCInstruction.java
 create mode 100644 src/main/java/org/apache/sysds/runtime/ooc/cache/BlockEntry.java
 create mode 100644 src/main/java/org/apache/sysds/runtime/ooc/cache/BlockKey.java
 create mode 100644 src/main/java/org/apache/sysds/runtime/ooc/cache/BlockState.java
 create mode 100644 src/main/java/org/apache/sysds/runtime/ooc/cache/CloseableQueue.java
 create mode 100644 src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheManager.java
 create mode 100644 src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheScheduler.java
 create mode 100644 src/main/java/org/apache/sysds/runtime/ooc/cache/OOCIOHandler.java
 create mode 100644 src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java
 create mode 100644 src/main/java/org/apache/sysds/runtime/ooc/cache/OOCMatrixIOHandler.java
 create mode 100644 src/main/java/org/apache/sysds/runtime/ooc/stats/OOCEventLog.java
 create mode 100644 src/test/java/org/apache/sysds/test/functions/ooc/LmCGTest.java
 create mode 100644 src/test/scripts/functions/ooc/lmCG.dml

diff --git a/src/main/java/org/apache/sysds/api/DMLOptions.java b/src/main/java/org/apache/sysds/api/DMLOptions.java
index 917aecc4ab3..10c41e3d0a8 100644
--- a/src/main/java/org/apache/sysds/api/DMLOptions.java
+++ b/src/main/java/org/apache/sysds/api/DMLOptions.java
@@ -19,6 +19,10 @@
 package org.apache.sysds.api;
 
+import java.nio.file.Files;
+import java.nio.file.InvalidPathException;
+import java.nio.file.Path;
+import java.nio.file.Paths;
 import java.util.HashMap;
 import java.util.Map;
 
@@ -66,6 +70,10 @@ public class DMLOptions {
 	public boolean gpu = false; // Whether to use the GPU
 	public boolean forceGPU = false; // Whether to ignore memory & estimates and always use the GPU
 	public boolean ooc = false; // Whether to use the OOC backend
+	public boolean oocLogEvents = false; // Whether to record I/O and task compute events (fine-grained; may impact performance on many small tasks)
+	public String oocLogPath = "./"; // The directory where to save the recorded event logs (csv)
+	public boolean oocStats = false; // Whether to record and print coarse-grained ooc statistics
+	public int oocStatsCount = 10; // Default ooc statistics count
 	public boolean debug = false; // to go into debug mode to be able to step through a program
 	public String filePath = null; // path to script
 	public String script = null; // the script itself
@@ -105,7 +113,11 @@ public String toString() {
 			", fedStats=" + fedStats +
 			", fedStatsCount=" + fedStatsCount +
 			", fedMonitoring=" + fedMonitoring +
-			", fedMonitoringAddress" + fedMonitoringAddress +
+			", fedMonitoringAddress=" + fedMonitoringAddress +
+			", oocStats=" + oocStats +
+			", oocStatsCount=" + oocStatsCount +
+			", oocLogEvents=" + oocLogEvents +
+			", oocLogPath=" + oocLogPath +
 			", memStats=" + memStats +
 			", explainType=" + explainType +
 			", execMode=" + execMode +
@@ -193,7 +205,7 @@ else if (lineageType.equalsIgnoreCase("debugger"))
 			else if (execMode.equalsIgnoreCase("hybrid")) dmlOptions.execMode = ExecMode.HYBRID;
 			else if (execMode.equalsIgnoreCase("spark")) dmlOptions.execMode = ExecMode.SPARK;
 			else throw new org.apache.commons.cli.ParseException("Invalid argument specified for -exec option, must be one of [hadoop, singlenode, hybrid, HYBRID, spark]");
-		}
+		}
 		if (line.hasOption("explain")) {
 			dmlOptions.explainType = ExplainType.RUNTIME;
 			String explainType = line.getOptionValue("explain");
@@ -259,6 +271,33 @@ else if (lineageType.equalsIgnoreCase("debugger"))
 			}
 		}
 
+		dmlOptions.oocStats = line.hasOption("oocStats");
+		if (dmlOptions.oocStats) {
+			String oocStatsCount = line.getOptionValue("oocStats");
+			if (oocStatsCount != null) {
+				try {
+					dmlOptions.oocStatsCount = Integer.parseInt(oocStatsCount);
+				} catch (NumberFormatException e) {
+					throw new org.apache.commons.cli.ParseException("Invalid argument specified for -oocStats option, must be a valid integer");
+				}
+			}
+		}
+
+		dmlOptions.oocLogEvents = line.hasOption("oocLogEvents");
+		if (dmlOptions.oocLogEvents) {
+			String eventLogPath = line.getOptionValue("oocLogEvents");
+			if (eventLogPath != null) {
+				try {
+					Path p = Paths.get(eventLogPath);
+					if (!Files.isDirectory(p))
+						throw new org.apache.commons.cli.ParseException("Invalid argument specified for -oocLogEvents option, must be a valid directory");
+				} catch (InvalidPathException e) {
+					throw new org.apache.commons.cli.ParseException("Invalid argument specified for -oocLogEvents option, must be a valid path");
+				}
+				dmlOptions.oocLogPath = eventLogPath;
+			}
+		}
+
 		dmlOptions.memStats = line.hasOption("mem");
 		dmlOptions.clean = line.hasOption("clean");
@@ -387,6 +426,12 @@ private static Options createCLIOptions() {
 		Option fedStatsOpt = OptionBuilder.withArgName("count")
 			.withDescription("monitors and reports summary execution statistics of federated workers; heavy hitter is 10 unless overridden; default off")
 			.hasOptionalArg().create("fedStats");
+		Option oocStatsOpt = OptionBuilder
+			.withDescription("monitors and reports summary execution statistics of ooc operators and tasks; heavy hitter is 10 unless overridden; default off")
+			.hasOptionalArg().create("oocStats");
+		Option oocLogEventsOpt = OptionBuilder
+			.withDescription("records fine-grained events of compute tasks, I/O, and cache; -oocLogEvents [dir='./']")
+			.hasOptionalArg().create("oocLogEvents");
 		Option memOpt = OptionBuilder.withDescription("monitors and reports max memory consumption in CP; default off")
 			.create("mem");
 		Option explainOpt = OptionBuilder.withArgName("level")
@@ -452,6 +497,8 @@ private static Options createCLIOptions() {
 		options.addOption(statsOpt);
 		options.addOption(ngramsOpt);
 		options.addOption(fedStatsOpt);
+		options.addOption(oocStatsOpt);
+		options.addOption(oocLogEventsOpt);
 		options.addOption(memOpt);
 		options.addOption(explainOpt);
 		options.addOption(execOpt);
diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java
b/src/main/java/org/apache/sysds/api/DMLScript.java index 748a0c43ac0..ce9bab53e7f 100644 --- a/src/main/java/org/apache/sysds/api/DMLScript.java +++ b/src/main/java/org/apache/sysds/api/DMLScript.java @@ -71,11 +71,11 @@ import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.CoordinatorModel; import org.apache.sysds.runtime.controlprogram.parfor.util.IDHandler; import org.apache.sysds.runtime.instructions.gpu.context.GPUContextPool; -import org.apache.sysds.runtime.instructions.ooc.OOCEvictionManager; import org.apache.sysds.runtime.io.IOUtilFunctions; import org.apache.sysds.runtime.lineage.LineageCacheConfig; import org.apache.sysds.runtime.lineage.LineageCacheConfig.LineageCachePolicy; import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType; +import org.apache.sysds.runtime.ooc.stats.OOCEventLog; import org.apache.sysds.runtime.util.CommonThreadPool; import org.apache.sysds.runtime.util.HDFSTool; import org.apache.sysds.runtime.util.LocalFileUtils; @@ -150,6 +150,12 @@ public class DMLScript public static boolean SYNCHRONIZE_GPU = true; // Set OOC backend public static boolean USE_OOC = DMLOptions.defaultOptions.ooc; + // Record and print OOC statistics + public static boolean OOC_STATISTICS = DMLOptions.defaultOptions.oocStats; + public static int OOC_STATISTICS_COUNT = DMLOptions.defaultOptions.oocStatsCount; + // Record and save fine grained OOC event logs as csv to the specified dir + public static boolean OOC_LOG_EVENTS = DMLOptions.defaultOptions.oocLogEvents; + public static String OOC_LOG_PATH = DMLOptions.defaultOptions.oocLogPath; // Enable eager CUDA free on rmvar public static boolean EAGER_CUDA_FREE = false; @@ -273,6 +279,10 @@ public static boolean executeScript( String[] args ) USE_ACCELERATOR = dmlOptions.gpu; FORCE_ACCELERATOR = dmlOptions.forceGPU; USE_OOC = dmlOptions.ooc; + OOC_STATISTICS = dmlOptions.oocStats; + OOC_STATISTICS_COUNT = dmlOptions.oocStatsCount; + OOC_LOG_EVENTS = dmlOptions.oocLogEvents; + OOC_LOG_PATH = dmlOptions.oocLogPath; EXPLAIN = dmlOptions.explainType; EXEC_MODE = dmlOptions.execMode; LINEAGE = dmlOptions.lineage; @@ -324,6 +334,9 @@ public static boolean executeScript( String[] args ) LineageCacheConfig.setCachePolicy(LINEAGE_POLICY); LineageCacheConfig.setEstimator(LINEAGE_ESTIMATE); + if (dmlOptions.oocLogEvents) + OOCEventLog.setup(100000); + String dmlScriptStr = readDMLScript(isFile, fileOrScript); Map argVals = dmlOptions.argVals; @@ -498,8 +511,6 @@ private static void execute(String dmlScriptStr, String fnameOptConfig, Map(); - - + + //STATIC REWRITES (which do not rely on size information) if( staticRewrites ) { diff --git a/src/main/java/org/apache/sysds/runtime/compress/io/WriterCompressed.java b/src/main/java/org/apache/sysds/runtime/compress/io/WriterCompressed.java index c4d9db367bb..1ba0ba61d10 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/io/WriterCompressed.java +++ b/src/main/java/org/apache/sysds/runtime/compress/io/WriterCompressed.java @@ -46,7 +46,7 @@ import org.apache.sysds.runtime.compress.lib.CLALibSeparator; import org.apache.sysds.runtime.compress.lib.CLALibSeparator.SeparatedGroups; import org.apache.sysds.runtime.compress.lib.CLALibSlice; -import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; +import org.apache.sysds.runtime.instructions.ooc.OOCStream; import org.apache.sysds.runtime.instructions.spark.CompressionSPInstruction.CompressionFunction; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import 
org.apache.sysds.runtime.instructions.spark.utils.RDDConverterUtils; @@ -410,7 +410,7 @@ public Object call() throws Exception { } @Override - public long writeMatrixFromStream(String fname, LocalTaskQueue stream, long rlen, long clen, int blen) { + public long writeMatrixFromStream(String fname, OOCStream stream, long rlen, long clen, int blen) { throw new UnsupportedOperationException("Writing from an OOC stream is not supported for the HDF5 format."); }; diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java index d826af89c0e..2759ed33925 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java @@ -633,7 +633,7 @@ && getRDDHandle() == null) ) { _requiresLocalWrite = false; } else if( hasStreamHandle() ) { - _data = readBlobFromStream( getStreamHandle().toLocalTaskQueue() ); + _data = readBlobFromStream( getStreamHandle() ); } else if( getRDDHandle()==null || getRDDHandle().allowsShortCircuitRead() ) { if( DMLScript.STATISTICS ) @@ -1168,7 +1168,7 @@ protected abstract T readBlobFromHDFS(String fname, long[] dims) protected abstract T readBlobFromRDD(RDDObject rdd, MutableBoolean status) throws IOException; - protected abstract T readBlobFromStream(LocalTaskQueue stream) + protected abstract T readBlobFromStream(OOCStream stream) throws IOException; // Federated read diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java index f4d20bb55a0..8c5a0793467 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java @@ -35,6 +35,7 @@ import org.apache.sysds.runtime.controlprogram.federated.FederationMap; import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.instructions.ooc.OOCStream; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import org.apache.sysds.runtime.instructions.spark.data.RDDObject; import org.apache.sysds.runtime.io.FileFormatProperties; @@ -316,7 +317,7 @@ protected void writeBlobFromRDDtoHDFS(RDDObject rdd, String fname, String ofmt) } @Override - protected FrameBlock readBlobFromStream(LocalTaskQueue stream) throws IOException { + protected FrameBlock readBlobFromStream(OOCStream stream) throws IOException { // TODO Auto-generated method stub return null; } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java index 8191040eb18..b633c4007c7 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java @@ -45,6 +45,7 @@ import org.apache.sysds.runtime.controlprogram.federated.FederationUtils; import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction; +import org.apache.sysds.runtime.instructions.ooc.OOCStream; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import org.apache.sysds.runtime.instructions.spark.data.RDDObject; import 
org.apache.sysds.runtime.io.FileFormatProperties; @@ -528,17 +529,16 @@ protected MatrixBlock readBlobFromRDD(RDDObject rdd, MutableBoolean writeStatus) @Override - protected MatrixBlock readBlobFromStream(LocalTaskQueue stream) throws IOException { + protected MatrixBlock readBlobFromStream(OOCStream stream) throws IOException { boolean dimsUnknown = getNumRows() < 0 || getNumColumns() < 0; int nrows = (int)getNumRows(); int ncols = (int)getNumColumns(); MatrixBlock ret = dimsUnknown ? null : new MatrixBlock((int)getNumRows(), (int)getNumColumns(), false); - // TODO if stream is CachingStream, block parts might be evicted resulting in null pointer exceptions List blockCache = dimsUnknown ? new ArrayList<>() : null; IndexedMatrixValue tmp = null; try { int blen = getBlocksize(), lnnz = 0; - while( (tmp = stream.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS ) { + while( (tmp = stream.dequeue()) != LocalTaskQueue.NO_MORE_TASKS ) { // compute row/column block offsets final int row_offset = (int) (tmp.getIndexes().getRowIndex() - 1) * blen; final int col_offset = (int) (tmp.getIndexes().getColumnIndex() - 1) * blen; @@ -636,7 +636,7 @@ protected long writeStreamToHDFS(String fname, String ofmt, int rep, FileFormatP MetaDataFormat iimd = (MetaDataFormat) _metaData; FileFormat fmt = (ofmt != null ? FileFormat.safeValueOf(ofmt) : iimd.getFileFormat()); MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(fmt, rep, fprop); - return writer.writeMatrixFromStream(fname, getStreamHandle().toLocalTaskQueue(), + return writer.writeMatrixFromStream(fname, getStreamHandle(), getNumRows(), getNumColumns(), ConfigurationManager.getBlocksize()); } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/TensorObject.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/TensorObject.java index d0111a34300..474db9a65fe 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/TensorObject.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/TensorObject.java @@ -30,9 +30,9 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysds.runtime.controlprogram.federated.FederationMap; -import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; import org.apache.sysds.runtime.data.TensorBlock; import org.apache.sysds.runtime.data.TensorIndexes; +import org.apache.sysds.runtime.instructions.ooc.OOCStream; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import org.apache.sysds.runtime.instructions.spark.data.RDDObject; import org.apache.sysds.runtime.io.FileFormatProperties; @@ -210,7 +210,7 @@ protected void writeBlobFromRDDtoHDFS(RDDObject rdd, String fname, String ofmt) @Override - protected TensorBlock readBlobFromStream(LocalTaskQueue stream) throws IOException { + protected TensorBlock readBlobFromStream(OOCStream stream) throws IOException { // TODO Auto-generated method stub return null; } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java index 50143cd0ad7..1849ad066b3 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java @@ -103,6 +103,10 @@ public synchronized T dequeueTask() return t; } + public synchronized boolean hasNext() { + return 
!_data.isEmpty() || _closedInput; + } + /** * Synchronized (logical) insert of a NO_MORE_TASKS symbol at the end of the FIFO queue in order to * mark that no more tasks will be inserted into the queue. diff --git a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java index feefe5f63d6..a2e64dd0bac 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java @@ -23,6 +23,7 @@ import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.InstructionType; import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.instructions.ooc.AggregateTernaryOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.AggregateUnaryOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.BinaryOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.CSVReblockOOCInstruction; @@ -33,6 +34,7 @@ import org.apache.sysds.runtime.instructions.ooc.OOCInstruction; import org.apache.sysds.runtime.instructions.ooc.ParameterizedBuiltinOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction; +import org.apache.sysds.runtime.instructions.ooc.TernaryOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.TSMMOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.UnaryOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.MatrixVectorBinaryOOCInstruction; @@ -64,10 +66,14 @@ public static OOCInstruction parseSingleInstruction(InstructionType ooctype, Str return CSVReblockOOCInstruction.parseInstruction(str); case AggregateUnary: return AggregateUnaryOOCInstruction.parseInstruction(str); + case AggregateTernary: + return AggregateTernaryOOCInstruction.parseInstruction(str); case Unary: return UnaryOOCInstruction.parseInstruction(str); case Binary: return BinaryOOCInstruction.parseInstruction(str); + case Ternary: + return TernaryOOCInstruction.parseInstruction(str); case AggregateBinary: case MAPMM: return MatrixVectorBinaryOOCInstruction.parseInstruction(str); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java index b35ca55dab6..b8d84ca3898 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java @@ -34,6 +34,7 @@ import org.apache.sysds.runtime.instructions.Instruction; import org.apache.sysds.runtime.instructions.fed.FEDInstructionUtils; import org.apache.sysds.runtime.matrix.operators.Operator; +import org.apache.sysds.runtime.ooc.stats.OOCEventLog; public abstract class CPInstruction extends Instruction { protected static final Log LOG = LogFactory.getLog(CPInstruction.class.getName()); @@ -52,6 +53,7 @@ public enum CPType { protected final CPType _cptype; protected final boolean _requiresLabelUpdate; + private long nanoTime; protected CPInstruction(CPType type, String opcode, String istr) { this(type, null, opcode, istr); @@ -88,6 +90,8 @@ public String getGraphString() { @Override public Instruction preprocessInstruction(ExecutionContext ec) { + if (DMLScript.OOC_LOG_EVENTS) + nanoTime = System.nanoTime(); //default preprocess behavior (e.g., debug state, lineage) Instruction tmp = super.preprocessInstruction(ec); @@ -118,6 +122,10 @@ public Instruction 
preprocessInstruction(ExecutionContext ec) { public void postprocessInstruction(ExecutionContext ec) { if (DMLScript.LINEAGE_DEBUGGER) ec.maintainLineageDebuggerInfo(this); + if (DMLScript.OOC_LOG_EVENTS) { + int callerId = OOCEventLog.registerCaller(getExtendedOpcode() + "_" + hashCode()); + OOCEventLog.onComputeEvent(callerId, nanoTime, System.nanoTime()); + } } /** diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateTernaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateTernaryOOCInstruction.java new file mode 100644 index 00000000000..40616ef8880 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateTernaryOOCInstruction.java @@ -0,0 +1,231 @@ +package org.apache.sysds.runtime.instructions.ooc; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; +import org.apache.sysds.runtime.functionobjects.KahanPlus; +import org.apache.sysds.runtime.functionobjects.Multiply; +import org.apache.sysds.runtime.functionobjects.ReduceAll; +import org.apache.sysds.runtime.instructions.InstructionUtils; +import org.apache.sysds.runtime.instructions.cp.CPOperand; +import org.apache.sysds.runtime.instructions.cp.DoubleObject; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.data.MatrixIndexes; +import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues; +import org.apache.sysds.runtime.matrix.operators.AggregateTernaryOperator; +import org.apache.sysds.runtime.matrix.operators.Operator; +import org.apache.sysds.runtime.meta.DataCharacteristics; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.function.Function; + +public class AggregateTernaryOOCInstruction extends ComputationOOCInstruction { + + private static final Log LOG = LogFactory.getLog(AggregateTernaryOOCInstruction.class.getName()); + + private AggregateTernaryOOCInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, + String opcode, String istr) { + super(OOCInstruction.OOCType.AggregateTernary, op, in1, in2, in3, out, opcode, istr); + } + + public static AggregateTernaryOOCInstruction parseInstruction(String str) { + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); + String opcode = parts[0]; + + if(opcode.equalsIgnoreCase(Opcodes.TAKPM.toString()) || opcode.equalsIgnoreCase(Opcodes.TACKPM.toString())) { + InstructionUtils.checkNumFields(parts , 4, 5); + + CPOperand in1 = new CPOperand(parts[1]); + CPOperand in2 = new CPOperand(parts[2]); + CPOperand in3 = new CPOperand(parts[3]); + CPOperand out = new CPOperand(parts[4]); + //int numThreads = parts.length == 6 ? 
Integer.parseInt(parts[5]) : 1; + + AggregateTernaryOperator op = InstructionUtils.parseAggregateTernaryOperator(opcode, 1); + return new AggregateTernaryOOCInstruction(op, in1, in2, in3, out, opcode, str); + } + throw new DMLRuntimeException("AggregateTernaryOOCInstruction.parseInstruction():: Unknown opcode " + opcode); + } + + @Override + public void processInstruction(ExecutionContext ec) { + MatrixObject m1 = ec.getMatrixObject(input1); + MatrixObject m2 = ec.getMatrixObject(input2); + MatrixObject m3 = input3.isLiteral() ? null : ec.getMatrixObject(input3); + + AggregateTernaryOperator abOp = (AggregateTernaryOperator) _optr; + validateInput(m1, m2, m3, abOp, input1.getName(), input2.getName(), input3.getName()); + + boolean isReduceAll = abOp.indexFn instanceof ReduceAll; + + OOCStream qIn1 = m1.getStreamHandle(); + OOCStream qIn2 = m2.getStreamHandle(); + OOCStream qIn3 = m3 == null ? null : m3.getStreamHandle(); + + if(isReduceAll) + processReduceAll(ec, abOp, qIn1, qIn2, qIn3); + else + processReduceRow(ec, abOp, qIn1, qIn2, qIn3, m1.getDataCharacteristics()); + } + + private void processReduceAll(ExecutionContext ec, AggregateTernaryOperator abOp, + OOCStream qIn1, OOCStream qIn2, OOCStream qIn3) { + + final int extra = abOp.aggOp.correction.getNumRemovedRowsColumns(); + final MatrixBlock agg = new MatrixBlock(1, 1 + extra, false); + final MatrixBlock corr = new MatrixBlock(1, 1 + extra, false); + + OOCStream qMid = createWritableStream(); + + List> streams = new ArrayList<>(); + streams.add(qIn1); + streams.add(qIn2); + if(qIn3 != null) + streams.add(qIn3); + + List> keyFns = new ArrayList<>(); + for(int i = 0; i < streams.size(); i++) + keyFns.add(IndexedMatrixValue::getIndexes); + + CompletableFuture fut = joinOOC(streams, qMid, blocks -> { + MatrixBlock b1 = (MatrixBlock) blocks.get(0).getValue(); + MatrixBlock b2 = (MatrixBlock) blocks.get(1).getValue(); + MatrixBlock b3 = blocks.size() == 3 ? (MatrixBlock) blocks.get(2).getValue() : null; + MatrixBlock partial = MatrixBlock.aggregateTernaryOperations(b1, b2, b3, new MatrixBlock(), abOp, false); + return new IndexedMatrixValue(blocks.get(0).getIndexes(), partial); + }, keyFns); + + try { + IndexedMatrixValue imv; + while((imv = qMid.dequeue()) != LocalTaskQueue.NO_MORE_TASKS) { + MatrixBlock partial = (MatrixBlock) imv.getValue(); + OperationsOnMatrixValues.incrementalAggregation(agg, + abOp.aggOp.existsCorrection() ? 
corr : null, partial, abOp.aggOp, true); + } + fut.join(); + } + catch(Exception ex) { + throw new DMLRuntimeException(ex); + } + + agg.dropLastRowsOrColumns(abOp.aggOp.correction); + ec.setScalarOutput(output.getName(), new DoubleObject(agg.get(0, 0))); + } + + private void processReduceRow(ExecutionContext ec, AggregateTernaryOperator abOp, + OOCStream qIn1, OOCStream qIn2, OOCStream qIn3, + DataCharacteristics dc) { + + long emitThreshold = dc.getNumRowBlocks(); + if(emitThreshold <= 0) + throw new DMLRuntimeException("Unknown number of row blocks for out-of-core aggregate ternary."); + + OOCStream qOut = createWritableStream(); + ec.getMatrixObject(output).setStreamHandle(qOut); + + OOCStream qMid = createWritableStream(); + + List> streams = new ArrayList<>(); + streams.add(qIn1); + streams.add(qIn2); + if(qIn3 != null) + streams.add(qIn3); + + List> keyFns = new ArrayList<>(); + for(int i = 0; i < streams.size(); i++) + keyFns.add(IndexedMatrixValue::getIndexes); + + CompletableFuture fut = joinOOC(streams, qMid, blocks -> { + MatrixBlock b1 = (MatrixBlock) blocks.get(0).getValue(); + MatrixBlock b2 = (MatrixBlock) blocks.get(1).getValue(); + MatrixBlock b3 = blocks.size() == 3 ? (MatrixBlock) blocks.get(2).getValue() : null; + MatrixBlock partial = MatrixBlock.aggregateTernaryOperations(b1, b2, b3, new MatrixBlock(), abOp, false); + return new IndexedMatrixValue(blocks.get(0).getIndexes(), partial); + }, keyFns); + + final Map aggMap = new HashMap<>(); + final Map corrMap = new HashMap<>(); + final Map cntMap = new HashMap<>(); + + try { + IndexedMatrixValue imv; + while((imv = qMid.dequeue()) != LocalTaskQueue.NO_MORE_TASKS) { + MatrixIndexes idx = imv.getIndexes(); + long colIx = idx.getColumnIndex(); + MatrixBlock partial = (MatrixBlock) imv.getValue(); + + MatrixBlock curAgg = aggMap.get(colIx); + MatrixBlock curCorr = corrMap.get(colIx); + if(curAgg == null) { + aggMap.put(colIx, partial); + curCorr = new MatrixBlock(partial.getNumRows(), partial.getNumColumns(), false); + corrMap.put(colIx, curCorr); + cntMap.put(colIx, 1); + } + else { + OperationsOnMatrixValues.incrementalAggregation(curAgg, abOp.aggOp.existsCorrection() ? curCorr : null, + partial, abOp.aggOp, true); + cntMap.put(colIx, cntMap.get(colIx) + 1); + } + + if(cntMap.get(colIx) >= emitThreshold) { + MatrixBlock finalAgg = aggMap.remove(colIx); + corrMap.remove(colIx); + cntMap.remove(colIx); + + finalAgg.dropLastRowsOrColumns(abOp.aggOp.correction); + MatrixIndexes outIdx = new MatrixIndexes(1, colIx); + qOut.enqueue(new IndexedMatrixValue(outIdx, finalAgg)); + } + } + fut.join(); + } + catch(Exception ex) { + throw new DMLRuntimeException(ex); + } + finally { + qOut.closeInput(); + } + } + + private static void validateInput(MatrixObject m1, MatrixObject m2, MatrixObject m3, AggregateTernaryOperator op, + String name1, String name2, String name3) { + + DataCharacteristics c1 = m1.getDataCharacteristics(); + DataCharacteristics c2 = m2.getDataCharacteristics(); + DataCharacteristics c3 = m3 == null ? 
c2 : m3.getDataCharacteristics(); + + long m1r = c1.getRows(); + long m2r = c2.getRows(); + long m3r = c3.getRows(); + long m1c = c1.getCols(); + long m2c = c2.getCols(); + long m3c = c3.getCols(); + + if(m1r <= 0 || m2r <= 0 || m3r <= 0 || m1c <= 0 || m2c <= 0 || m3c <= 0) + throw new DMLRuntimeException("Unknown dimensions for aggregate ternary inputs."); + + if(m1r != m2r || m1c != m2c || m2r != m3r || m2c != m3c){ + if(LOG.isTraceEnabled()){ + LOG.trace("matBlock1:" + name1 + " (" + m1r + "x" + m1c + ")"); + LOG.trace("matBlock2:" + name2 + " (" + m2r + "x" + m2c + ")"); + LOG.trace("matBlock3:" + name3 + " (" + m3r + "x" + m3c + ")"); + } + throw new DMLRuntimeException("Invalid dimensions for aggregate ternary (" + m1r + "x" + m1c + ", " + + m2r + "x" + m2c + ", " + m3r + "x" + m3c + ")."); + } + + if(!(op.aggOp.increOp.fn instanceof KahanPlus && op.binaryFn instanceof Multiply)) + throw new DMLRuntimeException("Unsupported operator for aggregate ternary operations."); + + } +} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java index 2a53c5400ae..54d87dd3f2d 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java @@ -21,7 +21,6 @@ import org.apache.sysds.common.Types.CorrectionLocationType; import org.apache.sysds.conf.ConfigurationManager; -import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; @@ -76,7 +75,7 @@ public void processInstruction( ExecutionContext ec ) { //setup operators and input queue AggregateUnaryOperator aggun = (AggregateUnaryOperator) getOperator(); MatrixObject min = ec.getMatrixObject(input1); - OOCStream q = min.getStreamHandle(); + OOCStream qIn = min.getStreamHandle(); int blen = ConfigurationManager.getBlocksize(); if (aggun.isRowAggregate() || aggun.isColAggregate()) { @@ -87,89 +86,70 @@ public void processInstruction( ExecutionContext ec ) { HashMap corrs = new HashMap<>(); // correction blocks OOCStream qOut = createWritableStream(); + OOCStream qLocal = createWritableStream(); + ec.getMatrixObject(output).setStreamHandle(qOut); + // per-block aggregation (parallel map) + mapOOC(qIn, qLocal, tmp -> { + MatrixIndexes midx = aggun.isRowAggregate() ? + new MatrixIndexes(tmp.getIndexes().getRowIndex(), 1) : + new MatrixIndexes(1, tmp.getIndexes().getColumnIndex()); + + MatrixBlock ltmp = (MatrixBlock) ((MatrixBlock) tmp.getValue()) + .aggregateUnaryOperations(aggun, new MatrixBlock(), blen, tmp.getIndexes()); + return new IndexedMatrixValue(midx, ltmp); + }); + + // global reduce submitOOCTask(() -> { - IndexedMatrixValue tmp = null; - try { - while((tmp = q.dequeue()) != LocalTaskQueue.NO_MORE_TASKS) { - long idx = aggun.isRowAggregate() ? - tmp.getIndexes().getRowIndex() : tmp.getIndexes().getColumnIndex(); - MatrixBlock ret = aggTracker.get(idx); - if(ret != null) { - MatrixBlock corr = corrs.get(idx); - - // aggregation - MatrixBlock ltmp = (MatrixBlock) ((MatrixBlock) tmp.getValue()) - .aggregateUnaryOperations(aggun, new MatrixBlock(), blen, tmp.getIndexes()); - OperationsOnMatrixValues.incrementalAggregation(ret, - _aop.existsCorrection() ? 
corr : null, ltmp, _aop, true); - - if (!aggTracker.putAndIncrementCount(idx, ret)){ - corrs.replace(idx, corr); - continue; - } - } - else { - // first block for this idx - init aggregate and correction - // TODO avoid corr block for inplace incremental aggregation - int rows = tmp.getValue().getNumRows(); - int cols = tmp.getValue().getNumColumns(); - int extra = _aop.correction.getNumRemovedRowsColumns(); - ret = aggun.isRowAggregate()? new MatrixBlock(rows, 1 + extra, false) : new MatrixBlock(1 + extra, cols, false); - MatrixBlock corr = aggun.isRowAggregate()? new MatrixBlock(rows, 1 + extra, false) : new MatrixBlock(1 + extra, cols, false); - - // aggregation - MatrixBlock ltmp = (MatrixBlock) ((MatrixBlock) tmp.getValue()).aggregateUnaryOperations( - aggun, new MatrixBlock(), blen, tmp.getIndexes()); - OperationsOnMatrixValues.incrementalAggregation(ret, - _aop.existsCorrection() ? corr : null, ltmp, _aop, true); - - if(emitThreshold > 1){ - aggTracker.putAndIncrementCount(idx, ret); - corrs.put(idx, corr); - continue; - } - } - - // all input blocks for this idx processed - emit aggregated block - ret.dropLastRowsOrColumns(_aop.correction); - MatrixIndexes midx = aggun.isRowAggregate() ? - new MatrixIndexes(tmp.getIndexes().getRowIndex(), 1) : - new MatrixIndexes(1, tmp.getIndexes().getColumnIndex()); - IndexedMatrixValue tmpOut = new IndexedMatrixValue(midx, ret); - - qOut.enqueue(tmpOut); - // drop intermediate states - aggTracker.remove(idx); - corrs.remove(idx); - } - qOut.closeInput(); + IndexedMatrixValue partial; + while ((partial = qLocal.dequeue()) != LocalTaskQueue.NO_MORE_TASKS) { + long idx = aggun.isRowAggregate() ? partial.getIndexes().getRowIndex() : partial.getIndexes() + .getColumnIndex(); + + MatrixBlock ret = aggTracker.get(idx); + boolean ready; + if(ret != null) { + MatrixBlock corr = corrs.get(idx); + OperationsOnMatrixValues.incrementalAggregation(ret, + _aop.existsCorrection() ? corr : null, (MatrixBlock) partial.getValue(), _aop, + true); + ready = aggTracker.incrementCount(idx); } - catch(Exception ex) { - throw new DMLRuntimeException(ex); + else { + ret = (MatrixBlock) partial.getValue(); + MatrixBlock corr = _aop.existsCorrection() ? new MatrixBlock(ret.getNumRows(), + ret.getNumColumns(), false) : null; + ready = aggTracker.putAndIncrementCount(idx, ret); + if(!ready && _aop.existsCorrection()) + corrs.put(idx, corr); + } + + if(ready) { + ret.dropLastRowsOrColumns(_aop.correction); + qOut.enqueue(new IndexedMatrixValue(partial.getIndexes(), ret)); + aggTracker.remove(idx); + corrs.remove(idx); } - }, q, qOut); + } + qOut.closeInput(); + }); } // full aggregation else { - IndexedMatrixValue tmp = null; - //read blocks and aggregate immediately into result + OOCStream qLocal = createWritableStream(); + + mapOOC(qIn, qLocal, tmp -> (MatrixBlock) ((MatrixBlock) tmp.getValue()) + .aggregateUnaryOperations(aggun, new MatrixBlock(), blen, tmp.getIndexes())); + + MatrixBlock ltmp; int extra = _aop.correction.getNumRemovedRowsColumns(); MatrixBlock ret = new MatrixBlock(1,1+extra,false); MatrixBlock corr = new MatrixBlock(1,1+extra,false); - try { - while((tmp = q.dequeue()) != LocalTaskQueue.NO_MORE_TASKS) { - //block aggregation - MatrixBlock ltmp = (MatrixBlock) ((MatrixBlock) tmp.getValue()) - .aggregateUnaryOperations(aggun, new MatrixBlock(), blen, tmp.getIndexes()); - //accumulation into final result - OperationsOnMatrixValues.incrementalAggregation( - ret, _aop.existsCorrection() ? 
corr : null, ltmp, _aop, true); - } - } - catch(Exception ex) { - throw new DMLRuntimeException(ex); + while((ltmp = qLocal.dequeue()) != LocalTaskQueue.NO_MORE_TASKS) { + OperationsOnMatrixValues.incrementalAggregation( + ret, _aop.existsCorrection() ? corr : null, ltmp, _aop, true); } //create scalar output diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java index cdc23911516..9cda04e0c77 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java @@ -24,10 +24,14 @@ import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import org.apache.sysds.runtime.matrix.data.MatrixIndexes; +import org.apache.sysds.runtime.ooc.cache.BlockKey; +import org.apache.sysds.runtime.ooc.cache.OOCCacheManager; import shaded.parquet.it.unimi.dsi.fastutil.ints.IntArrayList; import java.util.HashMap; import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.function.Consumer; /** * A wrapper around LocalTaskQueue to consume the source stream and reset to @@ -41,6 +45,7 @@ public class CachingStream implements OOCStreamable { // original live stream private final OOCStream _source; private final IntArrayList _consumptionCounts = new IntArrayList(); + private final IntArrayList _consumerConsumptionCounts = new IntArrayList(); // stream identifier private final long _streamId; @@ -48,7 +53,7 @@ public class CachingStream implements OOCStreamable { // block counter private int _numBlocks = 0; - private Runnable[] _subscribers; + private Consumer>[] _subscribers; // state flags private boolean _cacheInProgress = true; // caching in progress, in the first pass. 
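[Editor's note] The reducers in the two aggregate instructions above carry a
separate correction block ("corr") next to every partial aggregate and only
drop it at the very end via dropLastRowsOrColumns(). The reason, for sum-like
aggregates, is compensated (Kahan) summation: per cell, the correction
accumulates the low-order bits that plain floating-point addition loses. A
minimal standalone sketch of the per-cell arithmetic, with illustrative names
that are not SystemDS API:

	class KahanSketch {
		double sum = 0, corr = 0;

		void add(double v) {
			double corrected = v + corr;       // re-inject previously lost bits
			double newSum = sum + corrected;
			corr = corrected - (newSum - sum); // bits lost by this addition
			sum = newSum;
		}

		public static void main(String[] args) {
			KahanSketch k = new KahanSketch();
			double plain = 0;
			for (int i = 0; i < 10_000_000; i++) {
				k.add(1e-9);
				plain += 1e-9;
			}
			// k.sum is typically much closer to 0.01 than the plain sum
			System.out.println(k.sum + " vs " + plain);
		}
	}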
@@ -58,7 +63,7 @@ public class CachingStream implements OOCStreamable { private boolean deletable = false; private int maxConsumptionCount = 0; - private int cachePins = 0; + private String _watchdogId = null; public CachingStream(OOCStream source) { this(source, _streamSeq.getNextID()); @@ -67,17 +72,27 @@ public CachingStream(OOCStream source) { public CachingStream(OOCStream source, long streamId) { _source = source; _streamId = streamId; + if (OOCWatchdog.WATCH) { + _watchdogId = "CS-" + hashCode(); + // Capture a short context to help identify origin + OOCWatchdog.registerOpen(_watchdogId, "CachingStream@" + hashCode(), getCtxMsg(), this); + } source.setSubscriber(tmp -> { - try { + try (tmp) { final IndexedMatrixValue task = tmp.get(); int blk; - Runnable[] mSubscribers; + Consumer>[] mSubscribers; + OOCStream.QueueCallback mCallback = null; synchronized (this) { + mSubscribers = _subscribers; if(task != LocalTaskQueue.NO_MORE_TASKS) { if (!_cacheInProgress) throw new DMLRuntimeException("Stream is closed"); - OOCEvictionManager.put(_streamId, _numBlocks, task); + if (mSubscribers == null || mSubscribers.length == 0) + OOCCacheManager.put(_streamId, _numBlocks, task); + else + mCallback = OOCCacheManager.putAndPin(_streamId, _numBlocks, task); if (_index != null) _index.put(task.getIndexes(), _numBlocks); blk = _numBlocks; @@ -87,20 +102,32 @@ public CachingStream(OOCStream source, long streamId) { } else { _cacheInProgress = false; // caching is complete + if (OOCWatchdog.WATCH) + OOCWatchdog.registerClose(_watchdogId); notifyAll(); blk = -1; } - - mSubscribers = _subscribers; } - if(mSubscribers != null) { - for(Runnable mSubscriber : mSubscribers) - mSubscriber.run(); - - if (blk == -1) { - synchronized (this) { - _subscribers = null; + if(mSubscribers != null && mSubscribers.length > 0) { + final OOCStream.QueueCallback finalCallback = mCallback; + try(finalCallback) { + if(blk != -1) { + for(int i = 0; i < mSubscribers.length; i++) { + OOCStream.QueueCallback localCallback = finalCallback.keepOpen(); + try(localCallback) { + mSubscribers[i].accept(localCallback); + } + if(onConsumed(blk, i)) + mSubscribers[i].accept(OOCStream.eos(_failure)); + } + } + else { + OOCStream.QueueCallback cb = OOCStream.eos(_failure); + for(int i = 0; i < mSubscribers.length; i++) { + if(onNoMoreTasks(i)) + mSubscribers[i].accept(cb); + } } } } @@ -111,11 +138,12 @@ public CachingStream(OOCStream source, long streamId) { notifyAll(); } - Runnable[] mSubscribers = _subscribers; + Consumer>[] mSubscribers = _subscribers; + OOCStream.QueueCallback err = OOCStream.eos( _failure); if(mSubscribers != null) { - for(Runnable mSubscriber : mSubscribers) { + for(Consumer> mSubscriber : mSubscribers) { try { - mSubscriber.run(); + mSubscriber.accept(err); } catch (Exception ignored) { } } @@ -124,10 +152,28 @@ public CachingStream(OOCStream source, long streamId) { }); } + private String getCtxMsg() { + StackTraceElement[] st = new Exception().getStackTrace(); + // Skip the first few frames (constructor, createWritableStream, etc.) 
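	// [Editor's sketch] The onConsumed()/maxConsumptionCount bookkeeping above
	// boils down to per-block reference counting: a cached block may only be
	// dropped once every registered consumer has read it. A self-contained
	// illustration with hypothetical names (not the patch's API):
	class BlockReleaseSketch {
		final int expectedConsumers; // analogous to maxConsumptionCount
		final int[] seen;            // per-block consumption counts
		final Runnable[] release;    // per-block release actions (e.g., cache forget)

		BlockReleaseSketch(int numBlocks, int expectedConsumers, Runnable[] release) {
			this.expectedConsumers = expectedConsumers;
			this.seen = new int[numBlocks];
			this.release = release;
		}

		synchronized void onConsumed(int block) {
			if (++seen[block] > expectedConsumers)
				throw new IllegalStateException("consumer overflow for block " + block);
			if (seen[block] == expectedConsumers)
				release[block].run(); // last reader done -> safe to evict/delete
		}
	}
	// Releasing earlier would break slower readers; releasing later (or never)
	// leaks cache space -- exactly the failure mode the OOCWatchdog
	// registration above is meant to surface.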
+ StringBuilder sb = new StringBuilder(); + int limit = Math.min(st.length, 7); + for(int i = 2; i < limit; i++) { + sb.append(st[i].getClassName()).append(".").append(st[i].getMethodName()).append(":") + .append(st[i].getLineNumber()); + if(i < limit - 1) + sb.append(" <- "); + } + return sb.toString(); + } + public synchronized void scheduleDeletion() { - deletable = true; + if (deletable) + return; // Deletion already scheduled + if (_cacheInProgress && maxConsumptionCount == 0) throw new DMLRuntimeException("Cannot have a caching stream with no listeners"); + + deletable = true; for (int i = 0; i < _consumptionCounts.size(); i++) { tryDeleteBlock(i); } @@ -138,31 +184,47 @@ public String toString() { } private synchronized void tryDeleteBlock(int i) { - if (cachePins > 0) - return; // Block deletion is prevented + int cnt = _consumptionCounts.getInt(i); + if (cnt > maxConsumptionCount) + throw new DMLRuntimeException("Cannot have more than " + maxConsumptionCount + " consumptions."); + if (cnt == maxConsumptionCount) + OOCCacheManager.forget(_streamId, i); + } - int count = _consumptionCounts.getInt(i); - if (count > maxConsumptionCount) + private synchronized boolean onConsumed(int blockIdx, int consumerIdx) { + int newCount = _consumptionCounts.getInt(blockIdx)+1; + if (newCount > maxConsumptionCount) throw new DMLRuntimeException("Cannot have more than " + maxConsumptionCount + " consumptions."); - if (count == maxConsumptionCount) - OOCEvictionManager.forget(_streamId, i); + _consumptionCounts.set(blockIdx, newCount); + int newConsumerCount = _consumerConsumptionCounts.getInt(consumerIdx)+1; + _consumerConsumptionCounts.set(consumerIdx, newConsumerCount); + + if (deletable) + tryDeleteBlock(blockIdx); + + return !_cacheInProgress && newConsumerCount == _numBlocks + 1; + } + + private synchronized boolean onNoMoreTasks(int consumerIdx) { + int newConsumerCount = _consumerConsumptionCounts.getInt(consumerIdx)+1; + _consumerConsumptionCounts.set(consumerIdx, newConsumerCount); + return !_cacheInProgress && newConsumerCount == _numBlocks + 1; } - public synchronized IndexedMatrixValue get(int idx) throws InterruptedException { + public synchronized OOCStream.QueueCallback get(int idx) throws InterruptedException, + ExecutionException { while (true) { if (_failure != null) throw _failure; else if (idx < _numBlocks) { - IndexedMatrixValue out = OOCEvictionManager.get(_streamId, idx); + OOCStream.QueueCallback out = OOCCacheManager.requestBlock(_streamId, idx).get(); if (_index != null) // Ensure index is up to date - _index.putIfAbsent(out.getIndexes(), idx); + _index.putIfAbsent(out.get().getIndexes(), idx); int newCount = _consumptionCounts.getInt(idx)+1; - if (newCount > maxConsumptionCount) throw new DMLRuntimeException("Consumer overflow! 
Expected: " + maxConsumptionCount); - _consumptionCounts.set(idx, newCount); if (deletable) @@ -170,7 +232,7 @@ else if (idx < _numBlocks) { return out; } else if (!_cacheInProgress) - return (IndexedMatrixValue)LocalTaskQueue.NO_MORE_TASKS; + return new OOCStream.SimpleQueueCallback<>(null, null); wait(); } @@ -180,27 +242,78 @@ public synchronized int findCachedIndex(MatrixIndexes idx) { return _index.get(idx); } - public synchronized IndexedMatrixValue findCached(MatrixIndexes idx) { + public synchronized BlockKey peekCachedBlockKey(MatrixIndexes idx) { + return new BlockKey(_streamId, _index.get(idx)); + } + + public synchronized OOCStream.QueueCallback findCached(MatrixIndexes idx) { int mIdx = _index.get(idx); int newCount = _consumptionCounts.getInt(mIdx)+1; if (newCount > maxConsumptionCount) throw new DMLRuntimeException("Consumer overflow in " + _streamId + "_" + mIdx + ". Expected: " + maxConsumptionCount); + _consumptionCounts.set(mIdx, newCount); - IndexedMatrixValue imv = OOCEvictionManager.get(_streamId, mIdx); + try { + return OOCCacheManager.requestBlock(_streamId, mIdx).get(); + } catch (InterruptedException | ExecutionException e) { + return new OOCStream.SimpleQueueCallback<>(null, new DMLRuntimeException(e)); + } finally { + if (deletable) + tryDeleteBlock(mIdx); + } + } + + public void findCachedAsync(MatrixIndexes idx, Consumer> callback) { + int mIdx; + synchronized(this) { + mIdx = _index.get(idx); + int newCount = _consumptionCounts.getInt(mIdx)+1; + if (newCount > maxConsumptionCount) + throw new DMLRuntimeException("Consumer overflow in " + _streamId + "_" + mIdx + ". Expected: " + maxConsumptionCount); + } + OOCCacheManager.requestBlock(_streamId, mIdx).whenComplete((cb, r) -> { + try (cb) { + synchronized(CachingStream.this) { + int newCount = _consumptionCounts.getInt(mIdx) + 1; + if(newCount > maxConsumptionCount) { + _failure = new DMLRuntimeException( + "Consumer overflow in " + _streamId + "_" + mIdx + ". Expected: " + maxConsumptionCount); + cb.fail(_failure); + } + else + _consumptionCounts.set(mIdx, newCount); + } - if (deletable) - tryDeleteBlock(mIdx); + callback.accept(cb); + } + }); + } - return imv; + /** + * Finds a cached item asynchronously without counting it as a consumption. + */ + public void peekCachedAsync(MatrixIndexes idx, Consumer> callback) { + int mIdx; + synchronized(this) { + mIdx = _index.get(idx); + } + OOCCacheManager.requestBlock(_streamId, mIdx).whenComplete((cb, r) -> callback.accept(cb)); } /** * Finds a cached item without counting it as a consumption. 
*/ - public synchronized IndexedMatrixValue peekCached(MatrixIndexes idx) { - int mIdx = _index.get(idx); - return OOCEvictionManager.get(_streamId, mIdx); + public OOCStream.QueueCallback peekCached(MatrixIndexes idx) { + int mIdx; + synchronized(this) { + mIdx = _index.get(idx); + } + try { + return OOCCacheManager.requestBlock(_streamId, mIdx).get(); + } catch (InterruptedException | ExecutionException e) { + return new OOCStream.SimpleQueueCallback<>(null, new DMLRuntimeException(e)); + } } public synchronized void activateIndexing() { @@ -223,20 +336,25 @@ public boolean isProcessed() { return false; } - public void setSubscriber(Runnable subscriber, boolean incrConsumers) { + public void setSubscriber(Consumer> subscriber, boolean incrConsumers) { if (deletable) throw new DMLRuntimeException("Cannot register a new subscriber on " + this + " because has been flagged for deletion"); + if (_failure != null) + throw _failure; int mNumBlocks; boolean cacheInProgress; + int consumerIdx; synchronized (this) { mNumBlocks = _numBlocks; cacheInProgress = _cacheInProgress; + consumerIdx = _consumerConsumptionCounts.size(); + _consumerConsumptionCounts.add(0); if (incrConsumers) maxConsumptionCount++; if (cacheInProgress) { int newLen = _subscribers == null ? 1 : _subscribers.length + 1; - Runnable[] newSubscribers = new Runnable[newLen]; + Consumer>[] newSubscribers = new Consumer[newLen]; if(newLen > 1) System.arraycopy(_subscribers, 0, newSubscribers, 0, newLen - 1); @@ -246,11 +364,24 @@ public void setSubscriber(Runnable subscriber, boolean incrConsumers) { } } - for (int i = 0; i < mNumBlocks; i++) - subscriber.run(); + for (int i = 0; i < mNumBlocks; i++) { + final int idx = i; + OOCCacheManager.requestBlock(_streamId, i).whenComplete((cb, r) -> { + try (cb) { + synchronized(CachingStream.this) { + if(_index != null) + _index.put(cb.get().getIndexes(), idx); + } + subscriber.accept(cb); + + if (onConsumed(idx, consumerIdx)) + subscriber.accept(OOCStream.eos(_failure)); // NO_MORE_TASKS + } + }); + } - if (!cacheInProgress) - subscriber.run(); // To fetch the NO_MORE_TASK element + if (!cacheInProgress && onNoMoreTasks(consumerIdx)) + subscriber.accept(OOCStream.eos(_failure)); // NO_MORE_TASKS } /** @@ -258,6 +389,9 @@ public void setSubscriber(Runnable subscriber, boolean incrConsumers) { * Only use if certain blocks are accessed more than once. */ public synchronized void incrSubscriberCount(int count) { + if (deletable) + throw new IllegalStateException("Cannot increment the subscriber count if flagged for deletion"); + maxConsumptionCount += count; } @@ -265,28 +399,10 @@ public synchronized void incrSubscriberCount(int count) { * Artificially increase the processing count of a block. */ public synchronized void incrProcessingCount(int i, int count) { - _consumptionCounts.set(i, _consumptionCounts.getInt(i)+count); + int cnt = _consumptionCounts.getInt(i)+count; + _consumptionCounts.set(i, cnt); if (deletable) tryDeleteBlock(i); } - - /** - * Force pins blocks in the cache to not be subject to block deletion. - */ - public synchronized void pinStream() { - cachePins++; - } - - /** - * Unpins the stream, allowing blocks to be deleted from cache. 
- */ - public synchronized void unpinStream() { - cachePins--; - - if (cachePins == 0) { - for (int i = 0; i < _consumptionCounts.size(); i++) - tryDeleteBlock(i); - } - } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixIndexingOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixIndexingOOCInstruction.java index 33c6675051e..d506825e140 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixIndexingOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixIndexingOOCInstruction.java @@ -35,7 +35,6 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; public class MatrixIndexingOOCInstruction extends IndexingOOCInstruction { @@ -89,8 +88,6 @@ public void processInstruction(ExecutionContext ec) { throw new DMLRuntimeException("Desired block not found"); } - final AtomicReference> futureRef = new AtomicReference<>(); - if(ix.rowStart % blocksize == 0 && ix.colStart % blocksize == 0) { // Aligned case: interior blocks can be forwarded directly, borders may require slicing final int outBlockRows = (int) Math.ceil((double) (ix.rowSpan() + 1) / blocksize); @@ -137,7 +134,6 @@ public void processInstruction(ExecutionContext ec) { return blockRow >= firstBlockRow && blockRow <= lastBlockRow && blockCol >= firstBlockCol && blockCol <= lastBlockCol; }, qOut::closeInput); - futureRef.set(future); return; } @@ -185,31 +181,35 @@ public void processInstruction(ExecutionContext ec) { if(mIdx == null) continue; - IndexedMatrixValue mv = cachedStream.peekCached(mIdx); - MatrixBlock srcBlock = (MatrixBlock) mv.getValue(); + try (OOCStream.QueueCallback cb = cachedStream.peekCached(mIdx)) { + IndexedMatrixValue mv = cb.get(); + MatrixBlock srcBlock = (MatrixBlock) mv.getValue(); + + if(target == null) + target = new MatrixBlock(nRows, nCols, srcBlock.isInSparseFormat()); - if(target == null) - target = new MatrixBlock(nRows, nCols, srcBlock.isInSparseFormat()); + long srcBlockRowStart = (mIdx.getRowIndex() - 1) * blocksize; + long srcBlockColStart = (mIdx.getColumnIndex() - 1) * blocksize; + long sliceRowStartGlobal = Math.max(targetRowStartGlobal, srcBlockRowStart); + long sliceRowEndGlobal = Math.min(targetRowEndGlobal, + srcBlockRowStart + srcBlock.getNumRows() - 1); + long sliceColStartGlobal = Math.max(targetColStartGlobal, srcBlockColStart); + long sliceColEndGlobal = Math.min(targetColEndGlobal, + srcBlockColStart + srcBlock.getNumColumns() - 1); - long srcBlockRowStart = (mIdx.getRowIndex() - 1) * blocksize; - long srcBlockColStart = (mIdx.getColumnIndex() - 1) * blocksize; - long sliceRowStartGlobal = Math.max(targetRowStartGlobal, srcBlockRowStart); - long sliceRowEndGlobal = Math.min(targetRowEndGlobal, - srcBlockRowStart + srcBlock.getNumRows() - 1); - long sliceColStartGlobal = Math.max(targetColStartGlobal, srcBlockColStart); - long sliceColEndGlobal = Math.min(targetColEndGlobal, - srcBlockColStart + srcBlock.getNumColumns() - 1); + int sliceRowStart = (int) (sliceRowStartGlobal - srcBlockRowStart); + int sliceRowEnd = (int) (sliceRowEndGlobal - srcBlockRowStart); + int sliceColStart = (int) (sliceColStartGlobal - srcBlockColStart); + int sliceColEnd = (int) (sliceColEndGlobal - srcBlockColStart); - int sliceRowStart = (int) (sliceRowStartGlobal - srcBlockRowStart); - int sliceRowEnd = (int) (sliceRowEndGlobal - srcBlockRowStart); - int 
sliceColStart = (int) (sliceColStartGlobal - srcBlockColStart); - int sliceColEnd = (int) (sliceColEndGlobal - srcBlockColStart); + int targetRowOffset = (int) (sliceRowStartGlobal - targetRowStartGlobal); + int targetColOffset = (int) (sliceColStartGlobal - targetColStartGlobal); - int targetRowOffset = (int) (sliceRowStartGlobal - targetRowStartGlobal); - int targetColOffset = (int) (sliceColStartGlobal - targetColStartGlobal); + MatrixBlock sliced = srcBlock.slice(sliceRowStart, sliceRowEnd, sliceColStart, + sliceColEnd); + sliced.putInto(target, targetRowOffset, targetColOffset, true); + } - MatrixBlock sliced = srcBlock.slice(sliceRowStart, sliceRowEnd, sliceColStart, sliceColEnd); - sliced.putInto(target, targetRowOffset, targetColOffset, true); final int maxConsumptions = aligner.getNumConsumptions(mIdx); Integer con = consumptionCounts.compute(mIdx, (k, v) -> { @@ -248,7 +248,6 @@ public void processInstruction(ExecutionContext ec) { if (!hasIntermediateStream) cachedStream.incrProcessingCount(cachedStream.findCachedIndex(tmp.getIndexes()), 1); }); - futureRef.set(future); if (hasIntermediateStream) cachedStream.scheduleDeletion(); // We can immediately delete blocks after consumption diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java index 38586428e1e..b3eac0b6478 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java @@ -23,10 +23,8 @@ import org.apache.sysds.common.Opcodes; import org.apache.sysds.conf.ConfigurationManager; -import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; -import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; import org.apache.sysds.runtime.functionobjects.Multiply; import org.apache.sysds.runtime.functionobjects.Plus; import org.apache.sysds.runtime.instructions.InstructionUtils; @@ -87,8 +85,46 @@ public void processInstruction( ExecutionContext ec ) { OOCStream qOut = createWritableStream(); BinaryOperator plus = InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString()); ec.getMatrixObject(output).setStreamHandle(qOut); + final Object lock = new Object(); + + submitOOCTasks(qIn, cb -> { + try(cb) { + IndexedMatrixValue tmp = cb.get(); + MatrixBlock matrixBlock = (MatrixBlock) tmp.getValue(); + long rowIndex = tmp.getIndexes().getRowIndex(); + long colIndex = tmp.getIndexes().getColumnIndex(); + MatrixBlock vectorSlice = partitionedVector.get(colIndex); + + // Now, call the operation with the correct, specific operator. 
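The lambda continues below with exactly that call and then merges each partial product into a running per-row aggregate, emitting a row as soon as all emitThreshold column blocks have arrived. In isolation, the accumulation scheme reduces to this toy sketch (double[] rows instead of MatrixBlocks; all names illustrative):

	import java.util.HashMap;
	import java.util.Map;

	// Toy model of the per-row partial-sum aggregation; in the real code,
	// aggTracker.putAndIncrementCount plays the role of add() signaling completion.
	final class RowAggSketch {
		private final Map<Long, double[]> agg = new HashMap<>();
		private final Map<Long, Integer> cnt = new HashMap<>();
		private final int emitThreshold; // number of column blocks per block row

		RowAggSketch(int emitThreshold) { this.emitThreshold = emitThreshold; }

		/** Merge one partial product; returns the finished row sum, else null. */
		synchronized double[] add(long row, double[] partial) {
			double[] cur = agg.merge(row, partial.clone(), (a, b) -> {
				for(int i = 0; i < a.length; i++)
					a[i] += b[i];
				return a;
			});
			if(cnt.merge(row, 1, Integer::sum) < emitThreshold)
				return null;
			agg.remove(row);
			cnt.remove(row);
			return cur; // all partials seen: emit early and free the state
		}
	}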
+					MatrixBlock partialResult = matrixBlock.aggregateBinaryOperations(matrixBlock, vectorSlice,
+						new MatrixBlock(), (AggregateBinaryOperator) _optr);
+
+					// for single column block, no aggregation needed
+					if(emitThreshold == 1) {
+						qOut.enqueue(new IndexedMatrixValue(tmp.getIndexes(), partialResult));
+					}
+					else {
+						// aggregation
+						synchronized(lock) {
+							MatrixBlock currAgg = aggTracker.get(rowIndex);
+							if(currAgg == null) {
+								aggTracker.putAndIncrementCount(rowIndex, partialResult);
+							}
+							else {
+								currAgg = currAgg.binaryOperations(plus, partialResult);
+								if(aggTracker.putAndIncrementCount(rowIndex, currAgg)) {
+									// early block output: emit aggregated block
+									MatrixIndexes idx = new MatrixIndexes(rowIndex, 1L);
+									qOut.enqueue(new IndexedMatrixValue(idx, currAgg));
+									aggTracker.remove(rowIndex);
+								}
+							}
+						}
+					}
+				}
+			}, qOut::closeInput);
-		submitOOCTask(() -> {
+		/*submitOOCTask(() -> {
 			IndexedMatrixValue tmp = null;
 			try {
 				while((tmp = qIn.dequeue()) != LocalTaskQueue.NO_MORE_TASKS) {
@@ -129,6 +165,6 @@ public void processInstruction( ExecutionContext ec ) {
 		finally {
 			qOut.closeInput();
 		}
-	}, qIn, qOut);
+	}, qIn, qOut);*/
 	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCEvictionManager.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCEvictionManager.java
deleted file mode 100644
index dace1ab9e53..00000000000
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCEvictionManager.java
+++ /dev/null
@@ -1,491 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */ - -package org.apache.sysds.runtime.instructions.ooc; - -import org.apache.sysds.runtime.DMLRuntimeException; -import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; -import org.apache.sysds.runtime.io.IOUtilFunctions; -import org.apache.sysds.runtime.matrix.data.MatrixBlock; -import org.apache.sysds.runtime.matrix.data.MatrixIndexes; -import org.apache.sysds.runtime.util.FastBufferedDataOutputStream; -import org.apache.sysds.runtime.util.LocalFileUtils; - -import java.io.DataInputStream; -import java.io.File; -import java.io.FileOutputStream; -import java.io.IOException; -import java.io.RandomAccessFile; -import java.nio.channels.Channels; -import java.util.ArrayList; -import java.util.Iterator; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.locks.Condition; -import java.util.concurrent.locks.ReentrantLock; - -/** - * Eviction Manager for the Out-Of-Core stream cache - * This is the base implementation for LRU, FIFO - * - * Design choice 1: Pure JVM-memory cache - * What: Store MatrixBlock objects in a synchronized in-memory cache - * (Map + Deque for LRU/FIFO). Spill to disk by serializing MatrixBlock - * only when evicting. - * Pros: Simple to implement; no off-heap management; easy to debug; - * no serialization race since you serialize only when evicting; - * fast cache hits (direct object access). - * Cons: Heap usage counted roughly via serialized-size estimate — actual - * JVM object overhead not accounted; risk of GC pressure and OOM if - * estimates are off or if many small objects cause fragmentation; - * eviction may be more expensive (serialize on eviction). - *
- * Design choice 2: - *
- * This manager runtime memory management by caching serialized - * ByteBuffers and spilling them to disk when needed. - *
- * * core function: Caches ByteBuffers (off-heap/direct) and - * spills them to disk - * * Eviction: Evicts a ByteBuffer by writing its contents to a file - * * Granularity: Evicts one IndexedMatrixValue block at a time - * * Data replay: get() will always return the data either from memory or - * by falling back to the disk - * * Memory: Since the datablocks are off-heap (in ByteBuffer) or disk, - * there won't be OOM. - * - * Pros: Avoids heap OOM by keeping large data off-heap; predictable - * memory usage; good for very large blocks. - * Cons: More complex synchronization; need robust off-heap allocator/free; - * must ensure serialization finishes before adding to queue or make evict - * wait on serialization; careful with native memory leaks. - */ -public class OOCEvictionManager { - - // Configuration: OOC buffer limit as percentage of heap - private static final double OOC_BUFFER_PERCENTAGE = 0.15; // 15% of heap - - private static final double PARTITION_EVICTION_SIZE = 64 * 1024 * 1024; // 64 MB - - // Memory limit for ByteBuffers - private static long _limit; - private static final AtomicLong _size = new AtomicLong(0); - - // Cache structures: map key -> MatrixBlock and eviction deque (head=oldest block) - private static LinkedHashMap _cache = new LinkedHashMap<>(); - - // Spill related structures - private static ConcurrentHashMap _spillLocations = new ConcurrentHashMap<>(); - private static ConcurrentHashMap _partitions = new ConcurrentHashMap<>(); - private static final AtomicInteger _partitionCounter = new AtomicInteger(0); - - // Track which partitions belong to which stream (for cleanup) - private static final ConcurrentHashMap> _streamPartitions = new ConcurrentHashMap<>(); - - - // Cache level lock - private static final Object _cacheLock = new Object(); - - // Spill directory for evicted blocks - private static String _spillDir; - - public enum RPolicy { - FIFO, LRU - } - private static RPolicy _policy = RPolicy.FIFO; - - private enum BlockState { - HOT, // In-memory - EVICTING, // Being written to disk (transition state) - COLD // On disk - } - - private static class spillLocation { - // structure of spillLocation: file, offset - final int partitionId; - final long offset; - - spillLocation(int partitionId, long offset) { - - this.partitionId = partitionId; - this.offset = offset; - } - } - - private static class partitionFile { - final String filePath; - //final long streamId; - - - private partitionFile(String filePath, long streamId) { - this.filePath = filePath; - //this.streamId = streamId; - } - } - - // Per-block state container with own lock. 
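The BlockEntry container below (removed together with this file) pairs each cached block with its own ReentrantLock/Condition, so readers only ever block for the short EVICTING transition rather than on the global cache lock. Reduced to its essentials, that handshake is the following; an isolated sketch of the pattern, with class and method names that are illustrative only:

	import java.util.concurrent.locks.Condition;
	import java.util.concurrent.locks.ReentrantLock;

	final class BlockStateSketch {
		enum State { HOT, EVICTING, COLD }

		private final ReentrantLock lock = new ReentrantLock();
		private final Condition changed = lock.newCondition();
		private State state = State.HOT;

		/** Reader side: wait out an in-flight eviction, then report residency. */
		boolean awaitResident() throws InterruptedException {
			lock.lock();
			try {
				while(state == State.EVICTING)
					changed.await();
				return state == State.HOT; // COLD means the caller reloads from disk
			}
			finally {
				lock.unlock();
			}
		}

		/** Evictor side: publish the COLD state and wake any waiting readers. */
		void markCold() {
			lock.lock();
			try {
				state = State.COLD;
				changed.signalAll();
			}
			finally {
				lock.unlock();
			}
		}
	}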
- private static class BlockEntry { - private final ReentrantLock lock = new ReentrantLock(); - private final Condition stateUpdate = lock.newCondition(); - - private BlockState state = BlockState.HOT; - private IndexedMatrixValue value; - private final long streamId; - //private final int blockId; - private final long size; - - BlockEntry(IndexedMatrixValue value, long streamId, int blockId, long size) { - this.value = value; - this.streamId = streamId; - //this.blockId = blockId; - this.size = size; - } - } - - static { - _limit = (long)(Runtime.getRuntime().maxMemory() * OOC_BUFFER_PERCENTAGE); // e.g., 20% of heap - _size.set(0); - _spillDir = LocalFileUtils.getUniqueWorkingDir("ooc_stream"); - LocalFileUtils.createLocalFileIfNotExist(_spillDir); - } - - public static void reset() { - TeeOOCInstruction.reset(); - if (!_cache.isEmpty()) { - System.err.println("There are dangling elements in the OOC Eviction cache: " + _cache.size()); - } - _size.set(0); - _cache.clear(); - _spillLocations.clear(); - _partitions.clear(); - _partitionCounter.set(0); - _streamPartitions.clear(); - } - - /** - * Removes a block from the cache without setting its data to null. - */ - public static void forget(long streamId, int blockId) { - BlockEntry e; - synchronized (_cacheLock) { - e = _cache.remove(streamId + "_" + blockId); - } - - if (e != null) { - e.lock.lock(); - try { - if (e.state == BlockState.HOT) - _size.addAndGet(-e.size); - } finally { - e.lock.unlock(); - } - System.out.println("Removed block " + streamId + "_" + blockId + " from cache (idx: " + (e.value != null ? e.value.getIndexes() : "?") + ")"); - } - } - - /** - * Store a block in the OOC cache (serialize once) - */ - public static void put(long streamId, int blockId, IndexedMatrixValue value) { - MatrixBlock mb = (MatrixBlock) value.getValue(); - long size = estimateSerializedSize(mb); - String key = streamId + "_" + blockId; - - BlockEntry newEntry = new BlockEntry(value, streamId, blockId, size); - BlockEntry old; - synchronized (_cacheLock) { - old = _cache.put(key, newEntry); // remove old value, put new value - } - - // Handle replacement with a new lock - if (old != null) { - old.lock.lock(); - try { - if (old.state == BlockState.HOT) { - _size.addAndGet(-old.size); // read and update size in atomic operation - } - } finally { - old.lock.unlock(); - } - } - - _size.addAndGet(size); - //make room if needed - evict(); - } - - /** - * Get a block from the OOC cache (deserialize on read) - */ - public static IndexedMatrixValue get(long streamId, int blockId) { - String key = streamId + "_" + blockId; - BlockEntry imv; - - synchronized (_cacheLock) { - imv = _cache.get(key); - System.err.println( "value of imv: " + imv); - if (imv != null && _policy == RPolicy.LRU) { - _cache.remove(key); - _cache.put(key, imv); //add last semantic - } - } - - if (imv == null) { - throw new DMLRuntimeException("Block not found in cache: " + key); - } - // use lock and check state - imv.lock.lock(); - try { - // 1. wait for eviction to complete - while (imv.state == BlockState.EVICTING) { - try { - imv.stateUpdate.await(); - } catch (InterruptedException e) { - - throw new DMLRuntimeException(e); - } - } - - // 2. 
check if the block is in HOT - if (imv.state == BlockState.HOT) { - return imv.value; - } - - } finally { - imv.lock.unlock(); - } - - // restore, since the block is COLD - return loadFromDisk(streamId, blockId); - } - - /** - * Evict ByteBuffers to disk - */ - private static void evict() { - long currentSize = _size.get(); - if (_size.get() <= _limit) { // only trigger eviction, if filled. - System.err.println("Evicting condition: " + _size.get() + "/" + _limit); - return; - } - - // --- 1. COLLECTION PHASE --- - long totalFreedSize = 0; - // list of eviction candidates - List> candidates = new ArrayList<>(); - long targetFreedSize = Math.max(currentSize - _limit, (long) PARTITION_EVICTION_SIZE); - - synchronized (_cacheLock) { - - //move iterator to first entry - Iterator> iter = _cache.entrySet().iterator(); - - while (iter.hasNext() && totalFreedSize < targetFreedSize) { - Map.Entry e = iter.next(); - BlockEntry entry = e.getValue(); - - if (entry.lock.tryLock()) { - try { - if (entry.state == BlockState.HOT) { - entry.state = BlockState.EVICTING; - candidates.add(e); - totalFreedSize += entry.size; - - //remove current iterator entry -// iter.remove(); - } - } finally { - entry.lock.unlock(); - } - } // if tryLock() fails, it means a thread is loading/reading this block. we shall skip it. - } - - } - - if (candidates.isEmpty()) { return; } // no eviction candidates found - - // --- 2. WRITE PHASE --- - // write to partition file - // 1. generate a new ID for the present "partition" (file) - int partitionId = _partitionCounter.getAndIncrement(); - - // Spill to disk - String filename = _spillDir + "/stream_batch_part_" + partitionId; - File spillDirFile = new File(_spillDir); - if (!spillDirFile.exists()) { - spillDirFile.mkdirs(); - } - - // 2. create the partition file metadata - partitionFile partFile = new partitionFile(filename, 0); - _partitions.put(partitionId, partFile); - - FileOutputStream fos = null; - FastBufferedDataOutputStream dos = null; - try { - fos = new FileOutputStream(filename); - dos = new FastBufferedDataOutputStream(fos); - - - // loop over the list of blocks we collected - for (Map.Entry tmp : candidates) { - BlockEntry entry = tmp.getValue(); - - // 1. get the current file position. this is the offset. - // flush any buffered data to the file - dos.flush(); - long offset = fos.getChannel().position(); - - // 2. write indexes and block - entry.value.getIndexes().write(dos); // write Indexes - entry.value.getValue().write(dos); - System.out.println("written, partition id: " + _partitions.get(partitionId) + ", offset: " + offset); - - // 3. create the spillLocation - spillLocation sloc = new spillLocation(partitionId, offset); - _spillLocations.put(tmp.getKey(), sloc); - - // 4. track file for cleanup - _streamPartitions - .computeIfAbsent(entry.streamId, k -> ConcurrentHashMap.newKeySet()) - .add(filename); - - // 5. change state to COLD - entry.lock.lock(); - try { - entry.value = null; // only release ref, don't mutate object - entry.state = BlockState.COLD; // set state to cold, since writing to disk - entry.stateUpdate.signalAll(); // wake up any "get()" threads - } finally { - entry.lock.unlock(); - } - - synchronized (_cacheLock) { - _cache.put(tmp.getKey(), entry); // add last semantic - } - } - } - catch(IOException ex) { - throw new DMLRuntimeException(ex); - } finally { - IOUtilFunctions.closeSilently(dos); - IOUtilFunctions.closeSilently(fos); - } - - // --- 3. 
ACCOUNTING PHASE --- - if (totalFreedSize > 0) { // note the size, without evicted blocks - _size.addAndGet(-totalFreedSize); - } - } - - /** - * Load block from spill file - */ - private static IndexedMatrixValue loadFromDisk(long streamId, int blockId) { - String key = streamId + "_" + blockId; - - // 1. find the blocks address (spill location) - spillLocation sloc = _spillLocations.get(key); - if (sloc == null) { - throw new DMLRuntimeException("Failed to load spill location for: " + key); - } - - partitionFile partFile = _partitions.get(sloc.partitionId); - if (partFile == null) { - throw new DMLRuntimeException("Failed to load partition for: " + sloc.partitionId); - } - - String filename = partFile.filePath; - - // Create an empty object to read data into. - MatrixIndexes ix = new MatrixIndexes(); - MatrixBlock mb = new MatrixBlock(); - - try (RandomAccessFile raf = new RandomAccessFile(filename, "r")) { - raf.seek(sloc.offset); - - try { - DataInputStream dis = new DataInputStream(Channels.newInputStream(raf.getChannel())); - ix.readFields(dis); // 1. Read Indexes - mb.readFields(dis); // 2. Read Block - } catch (IOException ex) { - throw new DMLRuntimeException("Failed to load block " + key + " from " + filename, ex); - } - } catch (IOException e) { - throw new RuntimeException(e); - } - // Read from disk and put into original indexed matrix value - BlockEntry imvCacheEntry; - synchronized (_cacheLock) { - imvCacheEntry = _cache.get(key); - } - - // 2. Check if it's null (the bug you helped fix before) - if(imvCacheEntry == null) { - throw new DMLRuntimeException("Block entry " + key + " was not in cache during load."); - } - - imvCacheEntry.lock.lock(); - try { - if (imvCacheEntry.state == BlockState.COLD) { - imvCacheEntry.value = new IndexedMatrixValue(ix, mb); - imvCacheEntry.state = BlockState.HOT; - _size.addAndGet(imvCacheEntry.size); - - synchronized (_cacheLock) { - _cache.remove(key); - _cache.put(key, imvCacheEntry); - } - } - -// evict(); // when we add the block, we shall check for limit. 
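The spill path being removed here rests on one invariant: flush the buffered stream, sample the channel position as the record offset, write indexes then block, and later seek back and read in exactly the same order. A toy version of that round trip (byte[] payloads instead of MatrixBlocks; all names illustrative, and dos must wrap fos so both see the same channel):

	import java.io.DataOutputStream;
	import java.io.File;
	import java.io.FileOutputStream;
	import java.io.IOException;
	import java.io.RandomAccessFile;

	final class SpillRecordSketch {
		static long append(FileOutputStream fos, DataOutputStream dos,
				long rowIdx, long colIdx, byte[] payload) throws IOException {
			dos.flush(); // push buffered bytes so the channel position is exact
			long offset = fos.getChannel().position();
			dos.writeLong(rowIdx);
			dos.writeLong(colIdx);
			dos.writeInt(payload.length);
			dos.write(payload);
			return offset;
		}

		static byte[] read(File f, long offset, long[] idxOut) throws IOException {
			try(RandomAccessFile raf = new RandomAccessFile(f, "r")) {
				raf.seek(offset);
				idxOut[0] = raf.readLong(); // read in exactly the write order
				idxOut[1] = raf.readLong();
				byte[] payload = new byte[raf.readInt()];
				raf.readFully(payload);
				return payload;
			}
		}
	}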
- } finally { - imvCacheEntry.lock.unlock(); - } - - return imvCacheEntry.value; - } - - private static long estimateSerializedSize(MatrixBlock mb) { - return mb.getExactSerializedSize(); - } - - @SuppressWarnings("unused") - private static Map.Entry removeFirstFromCache() { - synchronized (_cacheLock) { - - if (_cache.isEmpty()) { - return null; - } - //move iterator to first entry - Iterator> iter = _cache.entrySet().iterator(); - Map.Entry entry = iter.next(); - - //remove current iterator entry - iter.remove(); - - return entry; - } - } -} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java index 1b6862361ed..a9d50cbd489 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java @@ -30,8 +30,12 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.data.MatrixIndexes; import org.apache.sysds.runtime.matrix.operators.Operator; +import org.apache.sysds.runtime.ooc.cache.OOCCacheManager; +import org.apache.sysds.runtime.ooc.stats.OOCEventLog; import org.apache.sysds.runtime.util.CommonThreadPool; import org.apache.sysds.runtime.util.OOCJoin; +import org.apache.sysds.utils.Statistics; +import scala.Tuple4; import java.util.ArrayList; import java.util.Arrays; @@ -46,6 +50,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.LongAdder; import java.util.function.BiConsumer; import java.util.function.BiFunction; import java.util.function.Consumer; @@ -55,9 +60,10 @@ public abstract class OOCInstruction extends Instruction { protected static final Log LOG = LogFactory.getLog(OOCInstruction.class.getName()); private static final AtomicInteger nextStreamId = new AtomicInteger(0); + private long nanoTime; public enum OOCType { - Reblock, Tee, Binary, Unary, AggregateUnary, AggregateBinary, MAPMM, MMTSJ, + Reblock, Tee, Binary, Ternary, Unary, AggregateUnary, AggregateBinary, AggregateTernary, MAPMM, MMTSJ, Reorg, CM, Ctable, MatrixIndexing, ParameterizedBuiltin, Rand } @@ -66,6 +72,8 @@ public enum OOCType { protected Set> _inQueues; protected Set> _outQueues; private boolean _failed; + private LongAdder _localStatisticsAdder; + public final int _callerId; protected OOCInstruction(OOCInstruction.OOCType type, String opcode, String istr) { this(type, null, opcode, istr); @@ -79,6 +87,10 @@ protected OOCInstruction(OOCInstruction.OOCType type, Operator op, String opcode _requiresLabelUpdate = super.requiresLabelUpdate(); _failed = false; + + if (DMLScript.STATISTICS) + _localStatisticsAdder = new LongAdder(); + _callerId = DMLScript.OOC_LOG_EVENTS ? 
OOCEventLog.registerCaller(getExtendedOpcode() + "_" + hashCode()) : 0; } @Override @@ -102,6 +114,8 @@ public String getGraphString() { @Override public Instruction preprocessInstruction(ExecutionContext ec) { + if (DMLScript.OOC_LOG_EVENTS) + nanoTime = System.nanoTime(); // TODO return super.preprocessInstruction(ec); } @@ -113,6 +127,8 @@ public Instruction preprocessInstruction(ExecutionContext ec) { public void postprocessInstruction(ExecutionContext ec) { if(DMLScript.LINEAGE_DEBUGGER) ec.maintainLineageDebuggerInfo(this); + if (DMLScript.OOC_LOG_EVENTS) + OOCEventLog.onComputeEvent(_callerId, nanoTime, System.nanoTime()); } protected void addInStream(OOCStream... queue) { @@ -140,7 +156,7 @@ protected CompletableFuture filterOOC(OOCStream qIn, Consumer pr if (_inQueues == null || _outQueues == null) throw new NotImplementedException("filterOOC requires manual specification of all input and output streams for error propagation"); - return submitOOCTasks(qIn, processor, finalizer, predicate, onNotProcessed != null ? (i, tmp) -> onNotProcessed.accept(tmp) : null); + return submitOOCTasks(qIn, c -> processor.accept(c.get()), finalizer, p -> predicate.apply(p.get()), onNotProcessed != null ? (i, tmp) -> onNotProcessed.accept(tmp.get()) : null); } protected CompletableFuture mapOOC(OOCStream qIn, OOCStream qOut, Function mapper) { @@ -148,8 +164,8 @@ protected CompletableFuture mapOOC(OOCStream qIn, OOCStream q addOutStream(qOut); return submitOOCTasks(qIn, tmp -> { - try { - R r = mapper.apply(tmp); + try (tmp) { + R r = mapper.apply(tmp.get()); qOut.enqueue(r); } catch (Exception e) { throw e instanceof DMLRuntimeException ? (DMLRuntimeException) e : new DMLRuntimeException(e); @@ -177,66 +193,108 @@ protected CompletableFuture broadcastJoinOOC(OOCStream> availableLeftInput = new ConcurrentHashMap<>(); Map availableBroadcastInput = new ConcurrentHashMap<>(); - CompletableFuture future = submitOOCTasks(List.of(qIn, broadcast), (i, tmp) -> { - P key = on.apply(tmp); - - if (i == 0) { // qIn stream - BroadcastedElement b = availableBroadcastInput.get(key); - - if (b == null) { - // Matching broadcast element is not available -> cache element - if (explicitLeftCaching) - leftCache.getWriteStream().enqueue(tmp); - - availableLeftInput.compute(key, (k, v) -> { - if (v == null) - v = new ArrayList<>(); - v.add(tmp.getIndexes()); - return v; - }); - } else { - if (!explicitLeftCaching) - leftCache.incrProcessingCount(leftCache.findCachedIndex(tmp.getIndexes()), 1); // Correct for incremented subscriber count to allow block deletion + OOCStream, OOCStream.QueueCallback, BroadcastedElement>> broadcastingQueue = createWritableStream(); + AtomicInteger waitCtr = new AtomicInteger(1); + CompletableFuture fut1 = new CompletableFuture<>(); - b.value = rightCache.peekCached(b.idx); + submitOOCTasks(List.of(qIn, broadcast), (i, tmp) -> { + try (tmp) { + P key = on.apply(tmp.get()); - // Directly emit - qOut.enqueue(mapper.apply(tmp, b)); + if(i == 0) { // qIn stream + BroadcastedElement b = availableBroadcastInput.get(key); - b.value = null; + if(b == null) { + // Matching broadcast element is not available -> cache element + availableLeftInput.compute(key, (k, v) -> { + if(v == null) + v = new ArrayList<>(); + v.add(tmp.get().getIndexes()); + return v; + }); - if (b.canRelease()) { - availableBroadcastInput.remove(key); - - if (!explicitRightCaching) - rightCache.incrProcessingCount(rightCache.findCachedIndex(b.idx), 1); // Correct for incremented subscriber count to allow block deletion + 
if(explicitLeftCaching) + leftCache.getWriteStream().enqueue(tmp.get()); + } + else { + waitCtr.incrementAndGet(); + + OOCCacheManager.requestManyBlocks( + List.of(leftCache.peekCachedBlockKey(tmp.get().getIndexes()), rightCache.peekCachedBlockKey(b.idx))) + .whenComplete((items, err) -> { + try { + broadcastingQueue.enqueue(new Tuple4<>(key, items.get(0).keepOpen(), items.get(1).keepOpen(), b)); + } finally { + items.forEach(OOCStream.QueueCallback::close); + } + }); + } + } + else { // broadcast stream + if(explicitRightCaching) + rightCache.getWriteStream().enqueue(tmp.get()); + + BroadcastedElement b = new BroadcastedElement(tmp.get().getIndexes()); + availableBroadcastInput.put(key, b); + + List queued = availableLeftInput.remove(key); + + if(queued != null) { + for(MatrixIndexes idx : queued) { + waitCtr.incrementAndGet(); + + OOCCacheManager.requestManyBlocks( + List.of(leftCache.peekCachedBlockKey(idx), rightCache.peekCachedBlockKey(tmp.get().getIndexes()))) + .whenComplete((items, err) -> { + try{ + broadcastingQueue.enqueue(new Tuple4<>(key, items.get(0).keepOpen(), items.get(1).keepOpen(), b)); + } finally { + items.forEach(OOCStream.QueueCallback::close); + } + }); + } } } - } else { // broadcast stream - if (explicitRightCaching) - rightCache.getWriteStream().enqueue(tmp); + } + }, () -> { + fut1.complete(null); + if (waitCtr.decrementAndGet() == 0) + broadcastingQueue.closeInput(); + }); + + CompletableFuture fut2 = new CompletableFuture<>(); + submitOOCTasks(List.of(broadcastingQueue), (i, tpl) -> { + try (tpl) { + final BroadcastedElement b = tpl.get()._4(); + final OOCStream.QueueCallback lValue = tpl.get()._2(); + final OOCStream.QueueCallback bValue = tpl.get()._3(); - BroadcastedElement b = new BroadcastedElement(tmp.getIndexes()); - availableBroadcastInput.put(key, b); + try (lValue; bValue) { + b.value = bValue.get(); + qOut.enqueue(mapper.apply(lValue.get(), b)); + leftCache.incrProcessingCount(leftCache.findCachedIndex(lValue.get().getIndexes()), 1); - List queued = availableLeftInput.remove(key); + if(b.canRelease()) { + availableBroadcastInput.remove(tpl.get()._1()); - if (queued != null) { - for(MatrixIndexes idx : queued) { - b.value = rightCache.peekCached(b.idx); // Only peek to prevent block deletion - qOut.enqueue(mapper.apply(leftCache.findCached(idx), b)); - b.value = null; + if(!explicitRightCaching) + rightCache.incrProcessingCount(rightCache.findCachedIndex(b.idx), + 1); // Correct for incremented subscriber count to allow block deletion } } - if (b.canRelease()) { - availableBroadcastInput.remove(key); - - if (!explicitRightCaching) - rightCache.incrProcessingCount(rightCache.findCachedIndex(tmp.getIndexes()), 1); // Correct for incremented subscriber count to allow block deletion - } + if(waitCtr.decrementAndGet() == 0) + broadcastingQueue.closeInput(); } - }, () -> { + }, () -> fut2.complete(null)); + + if (explicitLeftCaching) + leftCache.scheduleDeletion(); + if (explicitRightCaching) + rightCache.scheduleDeletion(); + + CompletableFuture fut = CompletableFuture.allOf(fut1, fut2); + fut.whenComplete((res, t) -> { availableBroadcastInput.forEach((k, v) -> { rightCache.incrProcessingCount(rightCache.findCachedIndex(v.idx), 1); }); @@ -244,12 +302,7 @@ protected CompletableFuture broadcastJoinOOC(OOCStream CompletableFuture joinOOC(OOCStream qIn1, OOCStream return joinOOC(qIn1, qIn2, qOut, mapper, on, on); } + @SuppressWarnings("unchecked") + protected CompletableFuture joinOOC(List> qIn, OOCStream qOut, Function, R> mapper, List> on) { + if (qIn == 
null || on == null || qIn.size() != on.size()) + throw new DMLRuntimeException("joinOOC(list) requires the same number of streams and key functions."); + + addInStream(qIn.toArray(OOCStream[]::new)); + addOutStream(qOut); + + final int n = qIn.size(); + + CachingStream[] caches = new CachingStream[n]; + boolean[] explicitCaching = new boolean[n]; + + for (int i = 0; i < n; i++) { + OOCStream s = qIn.get(i); + explicitCaching[i] = !s.hasStreamCache(); + caches[i] = explicitCaching[i] ? new CachingStream((OOCStream) s) : s.getStreamCache(); + caches[i].activateIndexing(); + // One additional consumption for the materialization when emitting + caches[i].incrSubscriberCount(1); + } + + Map seen = new ConcurrentHashMap<>(); + + CompletableFuture future = submitOOCTasks( + Arrays.stream(caches).map(CachingStream::getReadStream).collect(java.util.stream.Collectors.toList()), + (i, tmp) -> { + Function keyFn = on.get(i); + P key = keyFn.apply((T)tmp.get()); + MatrixIndexes idx = tmp.get().getIndexes(); + + MatrixIndexes[] arr = seen.computeIfAbsent(key, k -> new MatrixIndexes[n]); + boolean ready; + synchronized (arr) { + arr[i] = idx; + ready = true; + for (MatrixIndexes ix : arr) { + if (ix == null) { + ready = false; + break; + } + } + } + + if (!ready || !seen.remove(key, arr)) + return; + + List> values = new java.util.ArrayList<>(n); + try { + for(int j = 0; j < n; j++) + values.add((OOCStream.QueueCallback) caches[j].findCached(arr[j])); + + qOut.enqueue(mapper.apply(values.stream().map(OOCStream.QueueCallback::get).toList())); + } finally { + values.forEach(OOCStream.QueueCallback::close); + } + }, qOut::closeInput); + + for (int i = 0; i < n; i++) { + if (explicitCaching[i]) + caches[i].scheduleDeletion(); + } + + return future; + } + @SuppressWarnings("unchecked") protected CompletableFuture joinOOC(OOCStream qIn1, OOCStream qIn2, OOCStream qOut, BiFunction mapper, Function onLeft, Function onRight) { addInStream(qIn1, qIn2); @@ -309,16 +428,20 @@ protected CompletableFuture joinOOC(OOCStream qIn1, OOCStream rightCache.incrSubscriberCount(1); final OOCJoin join = new OOCJoin<>((idx, left, right) -> { - T leftObj = (T) leftCache.findCached(left); - T rightObj = (T) rightCache.findCached(right); - qOut.enqueue(mapper.apply(leftObj, rightObj)); + OOCStream.QueueCallback leftObj = (OOCStream.QueueCallback) leftCache.findCached(left); + OOCStream.QueueCallback rightObj = (OOCStream.QueueCallback) rightCache.findCached(right); + try (leftObj; rightObj) { + qOut.enqueue(mapper.apply(leftObj.get(), rightObj.get())); + } }); submitOOCTasks(List.of(leftCache.getReadStream(), rightCache.getReadStream()), (i, tmp) -> { - if (i == 0) - join.addLeft(onLeft.apply((T)tmp), ((IndexedMatrixValue) tmp).getIndexes()); - else - join.addRight(onRight.apply((T)tmp), ((IndexedMatrixValue) tmp).getIndexes()); + try (tmp) { + if(i == 0) + join.addLeft(onLeft.apply((T) tmp.get()), tmp.get().getIndexes()); + else + join.addRight(onRight.apply((T) tmp.get()), tmp.get().getIndexes()); + } }, () -> { join.close(); qOut.closeInput(); @@ -333,11 +456,11 @@ protected CompletableFuture joinOOC(OOCStream qIn1, OOCStream return future; } - protected CompletableFuture submitOOCTasks(final List> queues, BiConsumer consumer, Runnable finalizer) { + protected CompletableFuture submitOOCTasks(final List> queues, BiConsumer> consumer, Runnable finalizer) { return submitOOCTasks(queues, consumer, finalizer, null); } - protected CompletableFuture submitOOCTasks(final List> queues, BiConsumer consumer, Runnable finalizer, BiConsumer 
onNotProcessed) { + protected CompletableFuture submitOOCTasks(final List> queues, BiConsumer> consumer, Runnable finalizer, BiConsumer> onNotProcessed) { List> futures = new ArrayList<>(queues.size()); for (int i = 0; i < queues.size(); i++) @@ -346,7 +469,7 @@ protected CompletableFuture submitOOCTasks(final List> qu return submitOOCTasks(queues, consumer, finalizer, futures, null, onNotProcessed); } - protected CompletableFuture submitOOCTasks(final List> queues, BiConsumer consumer, Runnable finalizer, List> futures, BiFunction predicate, BiConsumer onNotProcessed) { + protected CompletableFuture submitOOCTasks(final List> queues, BiConsumer> consumer, Runnable finalizer, List> futures, BiFunction, Boolean> predicate, BiConsumer> onNotProcessed) { addInStream(queues.toArray(OOCStream[]::new)); ExecutorService pool = CommonThreadPool.get(); @@ -373,44 +496,69 @@ protected CompletableFuture submitOOCTasks(final List> qu //System.out.println("Substream (k " + k + ", id " + streamId + ", type '" + queue.getClass().getSimpleName() + "', stream_id " + queue.hashCode() + ")"); queue.setSubscriber(oocTask(callback -> { - final T item = callback.get(); - - if(item == null) { - if(!closeRaceWatchdog.compareAndSet(false, true)) - throw new DMLRuntimeException("Race condition observed: NO_MORE_TASKS callback has been triggered more than once"); + long startTime = DMLScript.STATISTICS ? System.nanoTime() : 0; + try (callback) { + if(callback.isEos()) { + if(!closeRaceWatchdog.compareAndSet(false, true)) + throw new DMLRuntimeException( + "Race condition observed: NO_MORE_TASKS callback has been triggered more than once"); + + if(localTaskCtr.decrementAndGet() == 0) { + // Then we can run the finalization procedure already + localFuture.complete(null); + } + return; + } - if(localTaskCtr.decrementAndGet() == 0) { - // Then we can run the finalization procedure already - localFuture.complete(null); + if(predicate != null && !predicate.apply(k, callback)) { // Can get closed due to cancellation + if(onNotProcessed != null) + onNotProcessed.accept(k, callback); + return; } - return; - } - if(predicate != null && !predicate.apply(k, item)) { // Can get closed due to cancellation - if(onNotProcessed != null) - onNotProcessed.accept(k, item); - return; - } + if(localFuture.isDone()) { + if(onNotProcessed != null) + onNotProcessed.accept(k, callback); + return; + } + else { + localTaskCtr.incrementAndGet(); + } - if(localFuture.isDone()) { - if(onNotProcessed != null) - onNotProcessed.accept(k, item); - return; - } - else { - localTaskCtr.incrementAndGet(); + // The item needs to be pinned in memory to be accessible in the executor thread + final OOCStream.QueueCallback pinned = callback.keepOpen(); + + pool.submit(oocTask(() -> { + long taskStartTime = DMLScript.STATISTICS ? 
System.nanoTime() : 0; + try (pinned) { + consumer.accept(k, pinned); + + if(localTaskCtr.decrementAndGet() == 0) + localFuture.complete(null); + } finally { + if (DMLScript.STATISTICS) { + _localStatisticsAdder.add(System.nanoTime() - taskStartTime); + if (globalFuture.isDone()) { + Statistics.maintainOOCHeavyHitter(getExtendedOpcode(), _localStatisticsAdder.sum()); + _localStatisticsAdder.reset(); + } + if (DMLScript.OOC_LOG_EVENTS) + OOCEventLog.onComputeEvent(_callerId, taskStartTime, System.nanoTime()); + } + } + }, localFuture, Stream.concat(_outQueues.stream(), _inQueues.stream()).toArray(OOCStream[]::new))); + + if(closeRaceWatchdog.get()) // Sanity check + throw new DMLRuntimeException("Race condition observed"); + } finally { + if (DMLScript.STATISTICS) { + _localStatisticsAdder.add(System.nanoTime() - startTime); + if (globalFuture.isDone()) { + Statistics.maintainOOCHeavyHitter(getExtendedOpcode(), _localStatisticsAdder.sum()); + _localStatisticsAdder.reset(); + } + } } - - pool.submit(oocTask(() -> { - // TODO For caching streams, we have no guarantee that item is still in memory -> NullPointer possible - consumer.accept(k, item); - - if(localTaskCtr.decrementAndGet() == 0) - localFuture.complete(null); - }, localFuture, Stream.concat(_outQueues.stream(), _inQueues.stream()).toArray(OOCStream[]::new))); - - if(closeRaceWatchdog.get()) // Sanity check - throw new DMLRuntimeException("Race condition observed"); }, null, Stream.concat(_outQueues.stream(), _inQueues.stream()).toArray(OOCStream[]::new))); i++; @@ -433,11 +581,11 @@ protected CompletableFuture submitOOCTasks(final List> qu return globalFuture; } - protected CompletableFuture submitOOCTasks(OOCStream queue, Consumer consumer, Runnable finalizer) { + protected CompletableFuture submitOOCTasks(OOCStream queue, Consumer> consumer, Runnable finalizer) { return submitOOCTasks(List.of(queue), (i, tmp) -> consumer.accept(tmp), finalizer); } - protected CompletableFuture submitOOCTasks(OOCStream queue, Consumer consumer, Runnable finalizer, Function predicate, BiConsumer onNotProcessed) { + protected CompletableFuture submitOOCTasks(OOCStream queue, Consumer> consumer, Runnable finalizer, Function, Boolean> predicate, BiConsumer> onNotProcessed) { return submitOOCTasks(List.of(queue), (i, tmp) -> consumer.accept(tmp), finalizer, List.of(new CompletableFuture()), (i, tmp) -> predicate.apply(tmp), onNotProcessed); } @@ -445,7 +593,15 @@ protected CompletableFuture submitOOCTask(Runnable r, OOCStream... queu ExecutorService pool = CommonThreadPool.get(); final CompletableFuture future = new CompletableFuture<>(); try { - pool.submit(oocTask(() -> {r.run();future.complete(null);}, future, queues)); + pool.submit(oocTask(() -> { + long startTime = DMLScript.STATISTICS || DMLScript.OOC_LOG_EVENTS ? System.nanoTime() : 0; + r.run(); + future.complete(null); + if (DMLScript.STATISTICS) + Statistics.maintainOOCHeavyHitter(getExtendedOpcode(), System.nanoTime() - startTime); + if (DMLScript.OOC_LOG_EVENTS) + OOCEventLog.onComputeEvent(_callerId, startTime, System.nanoTime()); + }, future, queues)); } catch (Exception ex) { throw new DMLRuntimeException(ex); @@ -459,6 +615,7 @@ protected CompletableFuture submitOOCTask(Runnable r, OOCStream... queu private Runnable oocTask(Runnable r, CompletableFuture future, OOCStream... queues) { return () -> { + long startTime = DMLScript.STATISTICS ? 
System.nanoTime() : 0; try { r.run(); } @@ -478,6 +635,9 @@ private Runnable oocTask(Runnable r, CompletableFuture future, OOCStream< // Rethrow to ensure proper future handling throw re; + } finally { + if (DMLScript.STATISTICS) + _localStatisticsAdder.add(System.nanoTime() - startTime); } }; } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java index f02c847e055..c70f9cbb8db 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java @@ -25,14 +25,16 @@ import java.util.function.Consumer; public interface OOCStream extends OOCStreamable { + static QueueCallback eos(DMLRuntimeException e) { + return new SimpleQueueCallback<>(null, e); + } + void enqueue(T t); T dequeue(); void closeInput(); - LocalTaskQueue toLocalTaskQueue(); - void propagateFailure(DMLRuntimeException re); boolean hasStreamCache(); @@ -47,19 +49,53 @@ public interface OOCStream extends OOCStreamable { */ void setSubscriber(Consumer> subscriber); - class QueueCallback { + interface QueueCallback extends AutoCloseable { + T get(); + + /** + * Keeps the callback item pinned in memory until the returned callback is also closed. + */ + QueueCallback keepOpen(); + + void close(); + + void fail(DMLRuntimeException failure); + + boolean isEos(); + } + + class SimpleQueueCallback implements QueueCallback { private final T _result; - private final DMLRuntimeException _failure; + private DMLRuntimeException _failure; - public QueueCallback(T result, DMLRuntimeException failure) { - _result = result; - _failure = failure; + public SimpleQueueCallback(T result, DMLRuntimeException failure) { + this._result = result; + this._failure = failure; } + @Override public T get() { if (_failure != null) throw _failure; return _result; } + + @Override + public QueueCallback keepOpen() { + return this; + } + + @Override + public void fail(DMLRuntimeException failure) { + this._failure = failure; + } + + @Override + public void close() {} + + @Override + public boolean isEos() { + return get() == null; + } } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCWatchdog.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCWatchdog.java index b7a16778ab7..69e669a40b7 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCWatchdog.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCWatchdog.java @@ -43,14 +43,15 @@ public final class OOCWatchdog { private static final long SCAN_INTERVAL_MS = TimeUnit.SECONDS.toMillis(10); static { - EXEC.scheduleAtFixedRate(OOCWatchdog::scan, SCAN_INTERVAL_MS, SCAN_INTERVAL_MS, TimeUnit.MILLISECONDS); + if (WATCH) + EXEC.scheduleAtFixedRate(OOCWatchdog::scan, SCAN_INTERVAL_MS, SCAN_INTERVAL_MS, TimeUnit.MILLISECONDS); } private OOCWatchdog() { // no-op } - public static void registerOpen(String id, String desc, String context, OOCStream stream) { + public static void registerOpen(String id, String desc, String context, OOCStreamable stream) { OPEN.put(id, new Entry(desc, context, System.currentTimeMillis(), stream)); } @@ -68,7 +69,7 @@ private static void scan() { long now = System.currentTimeMillis(); for (Map.Entry e : OPEN.entrySet()) { if (now - e.getValue().openedAt >= STALE_MS) { - if (e.getValue().events.isEmpty()) + if (e.getValue().events.isEmpty() && !(e.getValue().stream instanceof CachingStream)) continue; // Probably just a stream 
that has no consumer (remains to be checked why this can happen) System.err.println("[TemporaryWatchdog] Still open after " + (now - e.getValue().openedAt) + "ms: " + e.getKey() + " (" + e.getValue().desc + ")" @@ -81,10 +82,10 @@ private static class Entry { final String desc; final String context; final long openedAt; - final OOCStream stream; + final OOCStreamable stream; ConcurrentLinkedQueue events; - Entry(String desc, String context, long openedAt, OOCStream stream) { + Entry(String desc, String context, long openedAt, OOCStreamable stream) { this.desc = desc; this.context = context; this.openedAt = openedAt; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/PlaybackStream.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/PlaybackStream.java index 5b996da0dbe..bd725e5dd44 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/PlaybackStream.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/PlaybackStream.java @@ -20,9 +20,9 @@ package org.apache.sysds.runtime.instructions.ooc; import org.apache.sysds.runtime.DMLRuntimeException; -import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; @@ -30,13 +30,12 @@ public class PlaybackStream implements OOCStream, OOCStreamable { private final CachingStream _streamCache; private final AtomicInteger _streamIdx; - private final AtomicInteger _taskCtr; private final AtomicBoolean _subscriberSet; + private QueueCallback _lastDequeue; public PlaybackStream(CachingStream streamCache) { this._streamCache = streamCache; this._streamIdx = new AtomicInteger(0); - this._taskCtr = new AtomicInteger(1); this._subscriberSet = new AtomicBoolean(false); streamCache.incrSubscriberCount(1); } @@ -52,31 +51,16 @@ public void closeInput() { } @Override - public LocalTaskQueue toLocalTaskQueue() { - final LocalTaskQueue q = new LocalTaskQueue<>(); - setSubscriber(val -> { - if (val.get() == null) { - q.closeInput(); - return; - } - try { - q.enqueueTask(val.get()); - } - catch(InterruptedException e) { - throw new RuntimeException(e); - } - }); - return q; - } - - @Override - public IndexedMatrixValue dequeue() { + public synchronized IndexedMatrixValue dequeue() { if (_subscriberSet.get()) throw new IllegalStateException("Cannot dequeue from a playback stream if a subscriber has been set"); try { - return _streamCache.get(_streamIdx.getAndIncrement()); - } catch (InterruptedException e) { + if (_lastDequeue != null) + _lastDequeue.close(); + _lastDequeue = _streamCache.get(_streamIdx.getAndIncrement()); + return _lastDequeue.get(); + } catch (InterruptedException | ExecutionException e) { throw new DMLRuntimeException(e); } } @@ -101,31 +85,7 @@ public void setSubscriber(Consumer> subscriber if (!_subscriberSet.compareAndSet(false, true)) throw new IllegalArgumentException("Subscriber cannot be set multiple times"); - /** - * To guarantee that NO_MORE_TASKS is invoked after all subscriber calls - * finished, we keep track of running tasks using a task counter. 
- */ - _streamCache.setSubscriber(() -> { - try { - _taskCtr.incrementAndGet(); - - IndexedMatrixValue val; - - try { - val = _streamCache.get(_streamIdx.getAndIncrement()); - } catch (InterruptedException e) { - throw new DMLRuntimeException(e); - } - - if (val != null) - subscriber.accept(new QueueCallback<>(val, null)); - - if (_taskCtr.addAndGet(val == null ? -2 : -1) == 0) - subscriber.accept(new QueueCallback<>(null, null)); - } catch (DMLRuntimeException e) { - subscriber.accept(new QueueCallback<>(null, e)); - } - }, false); + _streamCache.setSubscriber(subscriber, false); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java index 7563d8471b6..51d0d70c865 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java @@ -71,7 +71,7 @@ public void enqueue(T t) { Consumer> s = _subscriber; if (s != null) { - s.accept(new QueueCallback<>(t, _failure)); + s.accept(new SimpleQueueCallback<>(t, _failure)); onDeliveryFinished(); return; } @@ -92,7 +92,7 @@ public void enqueue(T t) { } // Last case if due to race a subscriber has been set - s.accept(new QueueCallback<>(t, _failure)); + s.accept(new SimpleQueueCallback<>(t, _failure)); onDeliveryFinished(); } @@ -149,7 +149,7 @@ public void setSubscriber(Consumer> subscriber) { } for (T t : data) { - subscriber.accept(new QueueCallback<>(t, _failure)); + subscriber.accept(new SimpleQueueCallback<>(t, _failure)); onDeliveryFinished(); } } @@ -160,7 +160,7 @@ private void onDeliveryFinished() { if (ctr == 0) { Consumer> s = _subscriber; if (s != null) - s.accept(new QueueCallback<>((T) LocalTaskQueue.NO_MORE_TASKS, _failure)); + s.accept(new SimpleQueueCallback<>((T) LocalTaskQueue.NO_MORE_TASKS, _failure)); if (OOCWatchdog.WATCH) OOCWatchdog.registerClose(_watchdogId); @@ -172,12 +172,7 @@ public synchronized void propagateFailure(DMLRuntimeException re) { super.propagateFailure(re); Consumer> s = _subscriber; if(s != null) - s.accept(new QueueCallback<>(null, re)); - } - - @Override - public LocalTaskQueue toLocalTaskQueue() { - return this; + s.accept(new SimpleQueueCallback<>(null, re)); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TernaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TernaryOOCInstruction.java new file mode 100644 index 00000000000..da5c37c50ef --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TernaryOOCInstruction.java @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.runtime.instructions.ooc; + +import java.util.List; + +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.functionobjects.IfElse; +import org.apache.sysds.runtime.instructions.InstructionUtils; +import org.apache.sysds.runtime.instructions.cp.CPOperand; +import org.apache.sysds.runtime.instructions.cp.ScalarObject; +import org.apache.sysds.runtime.instructions.cp.ScalarObjectFactory; +import org.apache.sysds.runtime.instructions.cp.StringObject; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.data.MatrixIndexes; +import org.apache.sysds.runtime.matrix.operators.Operator; +import org.apache.sysds.runtime.matrix.operators.TernaryOperator; + +public class TernaryOOCInstruction extends ComputationOOCInstruction { + + protected TernaryOOCInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, + String opcode, String istr) { + super(OOCType.Ternary, op, in1, in2, in3, out, opcode, istr); + } + + public static TernaryOOCInstruction parseInstruction(String str) { + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); + InstructionUtils.checkNumFields(parts, 4, 5); + String opcode = parts[0]; + CPOperand in1 = new CPOperand(parts[1]); + CPOperand in2 = new CPOperand(parts[2]); + CPOperand in3 = new CPOperand(parts[3]); + CPOperand out = new CPOperand(parts[4]); + int numThreads = parts.length > 5 ? Integer.parseInt(parts[5]) : 1; + TernaryOperator op = InstructionUtils.parseTernaryOperator(opcode, numThreads); + return new TernaryOOCInstruction(op, in1, in2, in3, out, opcode, str); + } + + @Override + public void processInstruction(ExecutionContext ec) { + boolean m1 = input1.isMatrix(); + boolean m2 = input2.isMatrix(); + boolean m3 = input3.isMatrix(); + + if(!m1 && !m2 && !m3) { + processScalarInstruction(ec); + return; + } + + if(m1 && m2 && m3) + processThreeMatrixInstruction(ec); + else if(m1 && m2) + processTwoMatrixInstruction(ec, 1, 2); + else if(m1 && m3) + processTwoMatrixInstruction(ec, 1, 3); + else if(m2 && m3) + processTwoMatrixInstruction(ec, 2, 3); + else if(m1) + processSingleMatrixInstruction(ec, 1); + else if(m2) + processSingleMatrixInstruction(ec, 2); + else + processSingleMatrixInstruction(ec, 3); + } + + private void processScalarInstruction(ExecutionContext ec) { + TernaryOperator op = (TernaryOperator) _optr; + if(op.fn instanceof IfElse && output.getValueType() == ValueType.STRING) { + String value = (ec.getScalarInput(input1).getDoubleValue() != 0 ? + ec.getScalarInput(input2) : ec.getScalarInput(input3)).getStringValue(); + ec.setScalarOutput(output.getName(), new StringObject(value)); + } + else { + double value = op.fn.execute( + ec.getScalarInput(input1).getDoubleValue(), + ec.getScalarInput(input2).getDoubleValue(), + ec.getScalarInput(input3).getDoubleValue()); + ec.setScalarOutput(output.getName(), ScalarObjectFactory + .createScalarObject(output.getValueType(), value)); + } + } + + private void processSingleMatrixInstruction(ExecutionContext ec, int matrixPos) { + MatrixObject mo = getMatrixObject(ec, matrixPos); + MatrixBlock s1 = input1.isMatrix() ? null : getScalarInputBlock(ec, input1); + MatrixBlock s2 = input2.isMatrix() ? 
null : getScalarInputBlock(ec, input2); + MatrixBlock s3 = input3.isMatrix() ? null : getScalarInputBlock(ec, input3); + + OOCStream qIn = mo.getStreamHandle(); + OOCStream qOut = createWritableStream(); + ec.getMatrixObject(output).setStreamHandle(qOut); + + mapOOC(qIn, qOut, tmp -> { + IndexedMatrixValue outVal = new IndexedMatrixValue(); + MatrixBlock op1 = resolveOperandBlock(1, tmp, null, matrixPos, -1, s1, s2, s3); + MatrixBlock op2 = resolveOperandBlock(2, tmp, null, matrixPos, -1, s1, s2, s3); + MatrixBlock op3 = resolveOperandBlock(3, tmp, null, matrixPos, -1, s1, s2, s3); + outVal.set(tmp.getIndexes(), + op1.ternaryOperations((TernaryOperator)_optr, op2, op3, new MatrixBlock())); + return outVal; + }); + } + + private void processTwoMatrixInstruction(ExecutionContext ec, int leftPos, int rightPos) { + MatrixObject left = getMatrixObject(ec, leftPos); + MatrixObject right = getMatrixObject(ec, rightPos); + + MatrixBlock s1 = input1.isMatrix() ? null : getScalarInputBlock(ec, input1); + MatrixBlock s2 = input2.isMatrix() ? null : getScalarInputBlock(ec, input2); + MatrixBlock s3 = input3.isMatrix() ? null : getScalarInputBlock(ec, input3); + + OOCStream qOut = createWritableStream(); + ec.getMatrixObject(output).setStreamHandle(qOut); + + joinOOC(left.getStreamHandle(), right.getStreamHandle(), qOut, (l, r) -> { + IndexedMatrixValue outVal = new IndexedMatrixValue(); + MatrixBlock op1 = resolveOperandBlock(1, l, r, leftPos, rightPos, s1, s2, s3); + MatrixBlock op2 = resolveOperandBlock(2, l, r, leftPos, rightPos, s1, s2, s3); + MatrixBlock op3 = resolveOperandBlock(3, l, r, leftPos, rightPos, s1, s2, s3); + outVal.set(l.getIndexes(), + op1.ternaryOperations((TernaryOperator)_optr, op2, op3, new MatrixBlock())); + return outVal; + }, IndexedMatrixValue::getIndexes); + } + + private void processThreeMatrixInstruction(ExecutionContext ec) { + MatrixObject m1 = ec.getMatrixObject(input1); + MatrixObject m2 = ec.getMatrixObject(input2); + MatrixObject m3 = ec.getMatrixObject(input3); + + OOCStream qOut = createWritableStream(); + ec.getMatrixObject(output).setStreamHandle(qOut); + + List> streams = List.of( + m1.getStreamHandle(), m2.getStreamHandle(), m3.getStreamHandle()); + + List> keyFns = + List.of(IndexedMatrixValue::getIndexes, IndexedMatrixValue::getIndexes, IndexedMatrixValue::getIndexes); + + joinOOC(streams, qOut, blocks -> { + IndexedMatrixValue b1 = blocks.get(0); + IndexedMatrixValue b2 = blocks.get(1); + IndexedMatrixValue b3 = blocks.get(2); + IndexedMatrixValue outVal = new IndexedMatrixValue(); + outVal.set(b1.getIndexes(), + ((MatrixBlock)b1.getValue()).ternaryOperations((TernaryOperator)_optr, (MatrixBlock)b2.getValue(), (MatrixBlock)b3.getValue(), new MatrixBlock())); + return outVal; + }, keyFns); + } + + private MatrixObject getMatrixObject(ExecutionContext ec, int pos) { + if(pos == 1) + return ec.getMatrixObject(input1); + else if(pos == 2) + return ec.getMatrixObject(input2); + else if(pos == 3) + return ec.getMatrixObject(input3); + else + throw new DMLRuntimeException("Invalid matrix position: " + pos); + } + + private MatrixBlock getScalarInputBlock(ExecutionContext ec, CPOperand operand) { + ScalarObject scalar = ec.getScalarInput(operand); + return new MatrixBlock(scalar.getDoubleValue()); + } + + private MatrixBlock resolveOperandBlock(int operandPos, IndexedMatrixValue left, IndexedMatrixValue right, + int leftPos, int rightPos, MatrixBlock s1, MatrixBlock s2, MatrixBlock s3) { + if(operandPos == leftPos && left != null) + return (MatrixBlock) 
left.getValue(); + if(operandPos == rightPos && right != null) + return (MatrixBlock) right.getValue(); + + if(operandPos == 1) + return s1; + else if(operandPos == 2) + return s2; + else if(operandPos == 3) + return s3; + else + throw new DMLRuntimeException("Invalid operand position: " + operandPos); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/io/MatrixWriter.java b/src/main/java/org/apache/sysds/runtime/io/MatrixWriter.java index 1844cc1af79..8681a91c7e0 100644 --- a/src/main/java/org/apache/sysds/runtime/io/MatrixWriter.java +++ b/src/main/java/org/apache/sysds/runtime/io/MatrixWriter.java @@ -23,7 +23,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; +import org.apache.sysds.runtime.instructions.ooc.OOCStream; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -56,7 +56,7 @@ public abstract void writeMatrixToHDFS( MatrixBlock src, String fname, long rlen * @param blen The block size * @throws IOException if an I/O error occurs */ - public abstract long writeMatrixFromStream(String fname, LocalTaskQueue stream, + public abstract long writeMatrixFromStream(String fname, OOCStream stream, long rlen, long clen, int blen) throws IOException; public void setForcedParallel(boolean par) { diff --git a/src/main/java/org/apache/sysds/runtime/io/WriterBinaryBlock.java b/src/main/java/org/apache/sysds/runtime/io/WriterBinaryBlock.java index 82c994eb7a8..69fd386c5ef 100644 --- a/src/main/java/org/apache/sysds/runtime/io/WriterBinaryBlock.java +++ b/src/main/java/org/apache/sysds/runtime/io/WriterBinaryBlock.java @@ -32,6 +32,7 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; +import org.apache.sysds.runtime.instructions.ooc.OOCStream; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.data.MatrixIndexes; @@ -234,7 +235,7 @@ protected final void writeDiagBinaryBlockMatrixToHDFS(Path path, JobConf job, M } @Override - public long writeMatrixFromStream(String fname, LocalTaskQueue stream, long rlen, long clen, int blen) throws IOException { + public long writeMatrixFromStream(String fname, OOCStream stream, long rlen, long clen, int blen) throws IOException { Path path = new Path(fname); SequenceFile.Writer writer = null; @@ -245,7 +246,7 @@ public long writeMatrixFromStream(String fname, LocalTaskQueue stream, long rlen, long clen, int blen) { + public long writeMatrixFromStream(String fname, OOCStream stream, long rlen, long clen, int blen) { throw new UnsupportedOperationException("Writing from an OOC stream is not supported for the HDF5 format."); }; } diff --git a/src/main/java/org/apache/sysds/runtime/io/WriterMatrixMarket.java b/src/main/java/org/apache/sysds/runtime/io/WriterMatrixMarket.java index 39855968202..5483dc28ab9 100644 --- a/src/main/java/org/apache/sysds/runtime/io/WriterMatrixMarket.java +++ b/src/main/java/org/apache/sysds/runtime/io/WriterMatrixMarket.java @@ -35,8 +35,8 @@ import org.apache.hadoop.mapred.JobConf; import org.apache.sysds.conf.ConfigurationManager; import org.apache.sysds.runtime.DMLRuntimeException; -import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; import 
org.apache.sysds.runtime.data.DenseBlock; +import org.apache.sysds.runtime.instructions.ooc.OOCStream; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import org.apache.sysds.runtime.matrix.data.IJV; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -224,7 +224,7 @@ public static void mergeTextcellToMatrixMarket( String srcFileName, String fileN } @Override - public long writeMatrixFromStream(String fname, LocalTaskQueue stream, long rlen, long clen, int blen) { + public long writeMatrixFromStream(String fname, OOCStream stream, long rlen, long clen, int blen) { throw new UnsupportedOperationException("Writing from an OOC stream is not supported for the MatrixMarket format."); }; } diff --git a/src/main/java/org/apache/sysds/runtime/io/WriterTextCSV.java b/src/main/java/org/apache/sysds/runtime/io/WriterTextCSV.java index 9bc1edace9d..e96278b7801 100644 --- a/src/main/java/org/apache/sysds/runtime/io/WriterTextCSV.java +++ b/src/main/java/org/apache/sysds/runtime/io/WriterTextCSV.java @@ -35,9 +35,9 @@ import org.apache.hadoop.mapred.JobConf; import org.apache.sysds.conf.ConfigurationManager; import org.apache.sysds.runtime.DMLRuntimeException; -import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.runtime.instructions.ooc.OOCStream; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.util.HDFSTool; @@ -345,7 +345,7 @@ public final void addHeaderToCSV(String srcFileName, String destFileName, long r } @Override - public long writeMatrixFromStream(String fname, LocalTaskQueue stream, long rlen, long clen, int blen) { + public long writeMatrixFromStream(String fname, OOCStream stream, long rlen, long clen, int blen) { throw new UnsupportedOperationException("Writing from an OOC stream is not supported for the TextCSV format."); }; } diff --git a/src/main/java/org/apache/sysds/runtime/io/WriterTextCell.java b/src/main/java/org/apache/sysds/runtime/io/WriterTextCell.java index b876f21752b..ad216bf9406 100644 --- a/src/main/java/org/apache/sysds/runtime/io/WriterTextCell.java +++ b/src/main/java/org/apache/sysds/runtime/io/WriterTextCell.java @@ -30,8 +30,8 @@ import org.apache.hadoop.mapred.JobConf; import org.apache.sysds.conf.ConfigurationManager; import org.apache.sysds.runtime.DMLRuntimeException; -import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; import org.apache.sysds.runtime.data.DenseBlock; +import org.apache.sysds.runtime.instructions.ooc.OOCStream; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import org.apache.sysds.runtime.matrix.data.IJV; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -141,7 +141,7 @@ protected static void writeTextCellMatrixToFile( Path path, JobConf job, FileSys } @Override - public long writeMatrixFromStream(String fname, LocalTaskQueue stream, long rlen, long clen, int blen) { + public long writeMatrixFromStream(String fname, OOCStream stream, long rlen, long clen, int blen) { throw new UnsupportedOperationException("Writing from an OOC stream is not supported for the TextCell format."); }; } diff --git a/src/main/java/org/apache/sysds/runtime/io/WriterTextLIBSVM.java b/src/main/java/org/apache/sysds/runtime/io/WriterTextLIBSVM.java index 4a97abefc55..450a20979c4 100644 --- 
a/src/main/java/org/apache/sysds/runtime/io/WriterTextLIBSVM.java +++ b/src/main/java/org/apache/sysds/runtime/io/WriterTextLIBSVM.java @@ -28,9 +28,9 @@ import org.apache.hadoop.mapred.JobConf; import org.apache.sysds.conf.ConfigurationManager; import org.apache.sysds.runtime.DMLRuntimeException; -import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.runtime.instructions.ooc.OOCStream; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.util.HDFSTool; @@ -160,7 +160,7 @@ protected static void appendIndexValLibsvm(StringBuilder sb, int index, double v } @Override - public long writeMatrixFromStream(String fname, LocalTaskQueue stream, long rlen, long clen, int blen) { + public long writeMatrixFromStream(String fname, OOCStream stream, long rlen, long clen, int blen) { throw new UnsupportedOperationException("Writing from an OOC stream is not supported for the LIBSVM format."); }; } diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/BlockEntry.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/BlockEntry.java new file mode 100644 index 00000000000..eea76c808a2 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/BlockEntry.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.runtime.ooc.cache; + +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; + +public final class BlockEntry { + private final BlockKey _key; + private final long _size; + private volatile int _pinCount; + private volatile BlockState _state; + private Object _data; + + BlockEntry(BlockKey key, long size, Object data) { + this._key = key; + this._size = size; + this._pinCount = 0; + this._state = BlockState.HOT; + this._data = data; + } + + public BlockKey getKey() { + return _key; + } + + public long getSize() { + return _size; + } + + public Object getData() { + if (_pinCount > 0) + return _data; + throw new IllegalStateException("Cannot get the data of an unpinned entry"); + } + + Object getDataUnsafe() { + return _data; + } + + void setDataUnsafe(Object data) { + _data = data; + } + + public BlockState getState() { + return _state; + } + + public boolean isPinned() { + return _pinCount > 0; + } + + synchronized void setState(BlockState state) { + _state = state; + } + + /** + * Tries to clear the underlying data if it is not pinned + * @return the number of cleared bytes (or 0 if could not clear or data was already cleared) + */ + synchronized long clear() { + if (_pinCount != 0 || _data == null) + return 0; + if (_data instanceof IndexedMatrixValue) + ((IndexedMatrixValue)_data).setValue(null); // Explicitly clear + _data = null; + return _size; + } + + /** + * Pins the underlying data in memory + * @return the new number of pins (0 if pin was unsuccessful) + */ + synchronized int pin() { + if (_data == null) + return 0; + _pinCount++; + return _pinCount; + } + + /** + * Unpins the underlying data + * @return true if the data is now unpinned + */ + synchronized boolean unpin() { + if (_pinCount <= 0) + throw new IllegalStateException("Cannot unpin data if it was not pinned"); + _pinCount--; + return _pinCount == 0; + } +} diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/BlockKey.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/BlockKey.java new file mode 100644 index 00000000000..c6435672462 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/BlockKey.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package org.apache.sysds.runtime.ooc.cache;
+
+import org.jetbrains.annotations.NotNull;
+
+public class BlockKey implements Comparable<BlockKey> {
+	private final long _streamId;
+	private final long _sequenceNumber;
+
+	public BlockKey(long streamId, long sequenceNumber) {
+		this._streamId = streamId;
+		this._sequenceNumber = sequenceNumber;
+	}
+
+	public long getStreamId() {
+		return _streamId;
+	}
+
+	public long getSequenceNumber() {
+		return _sequenceNumber;
+	}
+
+	@Override
+	public int compareTo(@NotNull BlockKey blockKey) {
+		int cmp = Long.compare(_streamId, blockKey._streamId);
+		if (cmp != 0)
+			return cmp;
+		return Long.compare(_sequenceNumber, blockKey._sequenceNumber);
+	}
+
+	@Override
+	public int hashCode() {
+		return 31 * Long.hashCode(_streamId) + Long.hashCode(_sequenceNumber);
+	}
+
+	@Override
+	public boolean equals(Object obj) {
+		return obj instanceof BlockKey && ((BlockKey)obj)._streamId == _streamId && ((BlockKey)obj)._sequenceNumber == _sequenceNumber;
+	}
+
+	@Override
+	public String toString() {
+		return "BlockKey(" + _streamId + ", " + _sequenceNumber + ")";
+	}
+
+	public String toFileKey() {
+		return _streamId + "_" + _sequenceNumber;
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/BlockState.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/BlockState.java
new file mode 100644
index 00000000000..30013f736e7
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/BlockState.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.ooc.cache;
+
+public enum BlockState {
+	HOT,
+	WARM,
+	EVICTING,
+	READING,
+	//DEFERRED_READ, // Deferred read
+	COLD,
+	REMOVED; // Removed state means that it is not owned by the cache anymore. It doesn't mean the object is dereferenced.
+
+	public boolean isAvailable() {
+		return this == HOT || this == WARM || this == EVICTING || this == REMOVED;
+	}
+
+	public boolean isUnavailable() {
+		return this == COLD || this == READING;
+	}
+
+	public boolean readScheduled() {
+		return this == READING;
+	}
+
+	public boolean isBackedByDisk() {
+		return switch(this) {
+			case WARM, COLD, READING -> true;
+			default -> false;
+		};
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/CloseableQueue.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/CloseableQueue.java
new file mode 100644
index 00000000000..b8c312d2a3d
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/CloseableQueue.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.ooc.cache;
+
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+
+public class CloseableQueue<T> {
+	private final BlockingQueue<Object> queue = new LinkedBlockingQueue<>();
+	private final Object POISON = new Object(); // sentinel
+	private volatile boolean closed = false;
+
+	public CloseableQueue() { }
+
+	/**
+	 * Enqueue if the queue is not closed.
+	 * @return false if already closed
+	 */
+	public boolean enqueueIfOpen(T task) throws InterruptedException {
+		if (task == null)
+			throw new IllegalArgumentException("null tasks not allowed");
+		synchronized (this) {
+			if (closed)
+				return false;
+			queue.put(task);
+		}
+		return true;
+	}
+
+	@SuppressWarnings("unchecked")
+	public T take() throws InterruptedException {
+		if (closed && queue.isEmpty())
+			return null;
+
+		Object x = queue.take();
+
+		if (x == POISON)
+			return null;
+
+		return (T) x;
+	}
+
+	/**
+	 * Poll with max timeout.
+	 * @return item, or null if:
+	 *   - timeout, or
+	 *   - queue has been closed and this consumer reached its poison pill
+	 */
+	@SuppressWarnings("unchecked")
+	public T poll(long timeout, TimeUnit unit) throws InterruptedException {
+		if (closed && queue.isEmpty())
+			return null;
+
+		Object x = queue.poll(timeout, unit);
+		if (x == null)
+			return null; // timeout
+
+		if (x == POISON)
+			return null;
+
+		return (T) x;
+	}
+
+	/**
+	 * Closes the queue for its single consumer. Exactly one poison pill is
+	 * enqueued, so the consumer drains all remaining items, receives the pill
+	 * once, and then stops.
+	 * @return true on the first call, false if the queue was already closed
+	 */
+	public boolean close() throws InterruptedException {
+		synchronized (this) {
+			if (closed)
+				return false; // idempotent
+			closed = true;
+		}
+		queue.put(POISON);
+		return true;
+	}
+
+	public synchronized boolean isFinished() {
+		return closed && queue.isEmpty();
+	}
+}
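
CloseableQueue combines a blocking queue with a one-shot poison pill, so a producer can signal end-of-stream without a per-item flag. A minimal sketch of the intended single-producer, single-consumer handshake (hypothetical payloads and method, not part of this patch):

    static void demo() throws InterruptedException {
        CloseableQueue<String> q = new CloseableQueue<>();
        Thread consumer = new Thread(() -> {
            try {
                String item;
                while((item = q.take()) != null) // null: pill reached or drained after close
                    System.out.println(item);
            }
            catch(InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        });
        consumer.start();
        q.enqueueIfOpen("block-1");
        q.enqueueIfOpen("block-2");
        q.close(); // both items are delivered first, then the pill
        consumer.join();
    }
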
diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheManager.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheManager.java
new file mode 100644
index 00000000000..50b5cf78218
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheManager.java
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.ooc.cache;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.instructions.ooc.OOCStream;
+import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.ooc.stats.OOCEventLog;
+import org.apache.sysds.utils.Statistics;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+
+public class OOCCacheManager {
+	private static final double OOC_BUFFER_PERCENTAGE = 0.2;
+	private static final double OOC_BUFFER_PERCENTAGE_HARD = 0.3;
+	private static final long _evictionLimit;
+	private static final long _hardLimit;
+
+	private static final AtomicReference<OOCIOHandler> _ioHandler;
+	private static final AtomicReference<OOCCacheScheduler> _scheduler;
+
+	static {
+		_evictionLimit = (long)(Runtime.getRuntime().maxMemory() * OOC_BUFFER_PERCENTAGE);
+		_hardLimit = (long)(Runtime.getRuntime().maxMemory() * OOC_BUFFER_PERCENTAGE_HARD);
+		_ioHandler = new AtomicReference<>();
+		_scheduler = new AtomicReference<>();
+	}
+
+	public static void reset() {
+		OOCIOHandler ioHandler = _ioHandler.getAndSet(null);
+		OOCCacheScheduler cacheScheduler = _scheduler.getAndSet(null);
+		if (ioHandler != null)
+			ioHandler.shutdown();
+		if (cacheScheduler != null)
+			cacheScheduler.shutdown();
+
+		if (DMLScript.OOC_STATISTICS)
+			Statistics.resetOOCEvictionStats();
+
+		if (DMLScript.OOC_LOG_EVENTS) {
+			try {
+				String csv = OOCEventLog.getComputeEventsCSV();
+				Files.writeString(Path.of(DMLScript.OOC_LOG_PATH, "ComputeEventLog.csv"), csv);
+				csv = OOCEventLog.getDiskReadEventsCSV();
+				Files.writeString(Path.of(DMLScript.OOC_LOG_PATH, "DiskReadEventLog.csv"), csv);
+				csv = OOCEventLog.getDiskWriteEventsCSV();
+				Files.writeString(Path.of(DMLScript.OOC_LOG_PATH, "DiskWriteEventLog.csv"), csv);
+				csv = OOCEventLog.getCacheSizeEventsCSV();
+				Files.writeString(Path.of(DMLScript.OOC_LOG_PATH, "CacheSizeEventLog.csv"), csv);
+				csv = OOCEventLog.getRunSettingsCSV();
+				Files.writeString(Path.of(DMLScript.OOC_LOG_PATH, "RunSettings.csv"), csv);
+				System.out.println("Event logs written to: " + DMLScript.OOC_LOG_PATH);
+			}
+			catch(IOException e) {
+				System.err.println("Could not write event logs: " + e.getMessage());
+			}
+			OOCEventLog.clear();
+		}
+	}
+
+	public static OOCCacheScheduler getCache() {
+		while (true) {
+			OOCCacheScheduler scheduler = _scheduler.get();
+
+			if(scheduler != null)
+				return scheduler;
+
+			OOCIOHandler ioHandler = new OOCMatrixIOHandler();
+			scheduler = new OOCLRUCacheScheduler(ioHandler, _evictionLimit, _hardLimit);
+
+			if(_scheduler.compareAndSet(null, scheduler)) {
+				_ioHandler.set(ioHandler);
+				return scheduler;
+			}
+
+			// Lost the initialization race: release the speculatively created
+			// instances (the IO handler owns thread pools) before retrying.
+			scheduler.shutdown();
+			ioHandler.shutdown();
+		}
+	}
+
+	/**
+	 * Removes a block from the cache without setting its data to null.
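+	 * If the block is backed by disk, its on-disk copy is scheduled for
+	 * deletion through the IO handler as well.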
+	 */
+	public static void forget(long streamId, int blockId) {
+		BlockKey key = new BlockKey(streamId, blockId);
+		getCache().forget(key);
+	}
+
+	/**
+	 * Store a block in the OOC cache (serialize once)
+	 */
+	public static void put(long streamId, int blockId, IndexedMatrixValue value) {
+		BlockKey key = new BlockKey(streamId, blockId);
+		getCache().put(key, value, ((MatrixBlock)value.getValue()).getExactSerializedSize());
+	}
+
+	public static OOCStream.QueueCallback<IndexedMatrixValue> putAndPin(long streamId, int blockId, IndexedMatrixValue value) {
+		BlockKey key = new BlockKey(streamId, blockId);
+		return new CachedQueueCallback<>(getCache().putAndPin(key, value, ((MatrixBlock)value.getValue()).getExactSerializedSize()), null);
+	}
+
+	public static CompletableFuture<OOCStream.QueueCallback<IndexedMatrixValue>> requestBlock(long streamId, long blockId) {
+		BlockKey key = new BlockKey(streamId, blockId);
+		return getCache().request(key).thenApply(e -> new CachedQueueCallback<>(e, null));
+	}
+
+	public static CompletableFuture<List<OOCStream.QueueCallback<IndexedMatrixValue>>> requestManyBlocks(List<BlockKey> keys) {
+		return getCache().request(keys).thenApply(
+			l -> l.stream().map(e -> (OOCStream.QueueCallback<IndexedMatrixValue>)new CachedQueueCallback<IndexedMatrixValue>(e, null)).toList());
+	}
+
+	private static void pin(BlockEntry entry) {
+		getCache().pin(entry);
+	}
+
+	private static void unpin(BlockEntry entry) {
+		getCache().unpin(entry);
+	}
+
+	static class CachedQueueCallback<T> implements OOCStream.QueueCallback<T> {
+		private final BlockEntry _result;
+		private DMLRuntimeException _failure;
+		private final AtomicBoolean _pinned;
+
+		CachedQueueCallback(BlockEntry result, DMLRuntimeException failure) {
+			this._result = result;
+			this._failure = failure;
+			this._pinned = new AtomicBoolean(true);
+		}
+
+		@SuppressWarnings("unchecked")
+		@Override
+		public T get() {
+			if (_failure != null)
+				throw _failure;
+			if (!_pinned.get())
+				throw new IllegalStateException("Cannot get cached item of a closed callback");
+			T ret = (T)_result.getData();
+			if (ret == null)
+				throw new IllegalStateException("Cannot get a cached item if it is not pinned in memory: " + _result.getState());
+			return ret;
+		}
+
+		@Override
+		public OOCStream.QueueCallback<T> keepOpen() {
+			if (!_pinned.get())
+				throw new IllegalStateException("Cannot keep open an already closed callback");
+			pin(_result);
+			return new CachedQueueCallback<>(_result, _failure);
+		}
+
+		@Override
+		public void fail(DMLRuntimeException failure) {
+			this._failure = failure;
+		}
+
+		@Override
+		public boolean isEos() {
+			return get() == null;
+		}
+
+		@Override
+		public void close() {
+			if (_pinned.compareAndSet(true, false)) {
+				unpin(_result);
+			}
+		}
+	}
+}
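
Taken together, put/requestBlock/forget give the OOC instructions a pin-based lifecycle for stream blocks. A sketch of the intended call sequence (hypothetical ids and processing, not part of this patch):

    // producer: hand the block over to the cache (ownership transfer)
    OOCCacheManager.put(streamId, 0, imv);

    // consumer: the request pins the block, re-reading it from disk if needed
    OOCStream.QueueCallback<IndexedMatrixValue> cb =
        OOCCacheManager.requestBlock(streamId, 0).join();
    process(cb.get()); // safe while the callback is open
    cb.close();        // unpins; the block becomes evictable again

    // once no consumer will request the block anymore
    OOCCacheManager.forget(streamId, 0);
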
diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheScheduler.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheScheduler.java
new file mode 100644
index 00000000000..5346b819cfe
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheScheduler.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.ooc.cache;
+
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+
+public interface OOCCacheScheduler {
+
+	/**
+	 * Requests a single block from the cache.
+	 * @param key the key associated with the requested block
+	 * @return the available BlockEntry
+	 */
+	CompletableFuture<BlockEntry> request(BlockKey key);
+
+	/**
+	 * Requests a list of blocks from the cache that must be available at the same time.
+	 * @param keys the keys associated with the requested blocks
+	 * @return the list of available BlockEntries
+	 */
+	CompletableFuture<List<BlockEntry>> request(List<BlockKey> keys);
+
+	/**
+	 * Places a new block in the cache. Note that objects are immutable and cannot be overwritten.
+	 * The object data should from now on only be accessed via the cache, as ownership has been transferred.
+	 * @param key the associated key of the block
+	 * @param data the block data
+	 * @param size the size of the data
+	 */
+	void put(BlockKey key, Object data, long size);
+
+	/**
+	 * Places a new block in the cache and returns a pinned handle.
+	 * Note that objects are immutable and cannot be overwritten.
+	 * @param key the associated key of the block
+	 * @param data the block data
+	 * @param size the size of the data
+	 * @return the pinned block entry
+	 */
+	BlockEntry putAndPin(BlockKey key, Object data, long size);
+
+	/**
+	 * Forgets a block, i.e., removes it from the cache.
+	 * @param key the associated key of the block
+	 */
+	void forget(BlockKey key);
+
+	/**
+	 * Pins a BlockEntry in cache to prevent eviction.
+	 * @param entry the entry to be pinned
+	 */
+	void pin(BlockEntry entry);
+
+	/**
+	 * Unpins a pinned block.
+	 * @param entry the entry to be unpinned
+	 */
+	void unpin(BlockEntry entry);
+
+	/**
+	 * Shuts down the cache scheduler.
+	 */
+	void shutdown();
+}
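
The pin accounting is the core contract of this interface: entries returned by request() arrive already pinned, and every additional pin() must be balanced by exactly one unpin() before an entry becomes evictable again. A sketch of the invariant (hypothetical scheduler and key):

    BlockEntry e = scheduler.request(key).join(); // arrives pinned
    scheduler.pin(e);   // second logical reader
    scheduler.unpin(e); // second reader done
    scheduler.unpin(e); // releases the request's pin; now evictable
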
diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCIOHandler.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCIOHandler.java
new file mode 100644
index 00000000000..dbfda4e56d7
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCIOHandler.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.ooc.cache;
+
+import java.util.concurrent.CompletableFuture;
+
+public interface OOCIOHandler {
+	void shutdown();
+
+	CompletableFuture<BlockEntry> scheduleEviction(BlockEntry block);
+
+	CompletableFuture<BlockEntry> scheduleRead(BlockEntry block);
+
+	CompletableFuture<Boolean> scheduleDeletion(BlockEntry block);
+}
diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java
new file mode 100644
index 00000000000..1dbba2e3d8f
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java
@@ -0,0 +1,605 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.ooc.cache;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.runtime.ooc.stats.OOCEventLog;
+import org.apache.sysds.utils.Statistics;
+import scala.Tuple2;
+
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Deque;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.atomic.AtomicInteger;
+
+public class OOCLRUCacheScheduler implements OOCCacheScheduler {
+	private static final boolean SANITY_CHECKS = false;
+
+	private final OOCIOHandler _ioHandler;
+	private final LinkedHashMap<BlockKey, BlockEntry> _cache;
+	private final HashMap<BlockKey, BlockEntry> _evictionCache;
+	private final Deque<DeferredReadRequest> _deferredReadRequests;
+	private final Deque<DeferredReadRequest> _processingReadRequests;
+	private final long _hardLimit;
+	private final long _evictionLimit;
+	private final int _callerId;
+	private long _cacheSize;
+	private long _bytesUpForEviction;
+	private volatile boolean _running;
+	private boolean _warnThrottling;
+
+	public OOCLRUCacheScheduler(OOCIOHandler ioHandler, long evictionLimit, long hardLimit) {
+		this._ioHandler = ioHandler;
+		this._cache = new LinkedHashMap<>(1024, 0.75f, true);
+		this._evictionCache = new HashMap<>();
+		this._deferredReadRequests = new ArrayDeque<>();
+		this._processingReadRequests = new ArrayDeque<>();
+		this._hardLimit = hardLimit;
+		this._evictionLimit = evictionLimit;
+		this._cacheSize = 0;
+		this._bytesUpForEviction = 0;
+		this._running = true;
+		this._warnThrottling = false;
+		this._callerId = DMLScript.OOC_LOG_EVENTS ? OOCEventLog.registerCaller("LRUCacheScheduler") : 0;
+
+		if (DMLScript.OOC_LOG_EVENTS) {
+			OOCEventLog.putRunSetting("CacheEvictionLimit", _evictionLimit);
+			OOCEventLog.putRunSetting("CacheHardLimit", _hardLimit);
+		}
+	}
+
+	@Override
+	public CompletableFuture<BlockEntry> request(BlockKey key) {
+		if (!this._running)
+			throw new IllegalStateException("Cache scheduler has been shut down.");
+
+		Statistics.incrementOOCEvictionGet();
+
+		BlockEntry entry;
+		boolean couldPin = false;
+		synchronized(this) {
+			entry = _cache.get(key);
+			if (entry == null)
+				entry = _evictionCache.get(key);
+			if (entry == null)
+				throw new IllegalArgumentException("Could not find requested block with key " + key);
+
+			synchronized(entry) {
+				if (entry.getState().isAvailable()) {
+					if (entry.pin() == 0)
+						throw new IllegalStateException();
+					couldPin = true;
+				}
+			}
+		}
+
+		if (couldPin) {
+			// Then we could pin the required entry and can terminate
+			return CompletableFuture.completedFuture(entry);
+		}
+
+		//System.out.println("Requesting deferred: " + key);
+		// Schedule deferred read otherwise
+		final CompletableFuture<BlockEntry> future = new CompletableFuture<>();
+		final CompletableFuture<List<BlockEntry>> requestFuture = new CompletableFuture<>();
+		requestFuture.whenComplete((r, t) -> {
+			// propagate failures instead of dereferencing a null result
+			if (t != null)
+				future.completeExceptionally(t);
+			else
+				future.complete(r.get(0));
+		});
+		scheduleDeferredRead(new DeferredReadRequest(requestFuture, Collections.singletonList(entry)));
+		return future;
+	}
+
+	@Override
+	public CompletableFuture<List<BlockEntry>> request(List<BlockKey> keys) {
+		if (!this._running)
+			throw new IllegalStateException("Cache scheduler has been shut down.");
+
+		Statistics.incrementOOCEvictionGet(keys.size());
+
+		List<BlockEntry> entries = new ArrayList<>(keys.size());
+		boolean couldPinAll = true;
+
+		synchronized(this) {
+			for (BlockKey key : keys) {
+				BlockEntry entry = _cache.get(key);
+				if (entry == null)
+					entry = _evictionCache.get(key);
+				if (entry == null)
+					throw new IllegalArgumentException("Could not find requested block with key " + key);
+
+				if (couldPinAll) {
+					synchronized(entry) {
+						if(entry.getState().isAvailable()) {
+							if(entry.pin() == 0)
+								throw new IllegalStateException();
+						}
+						else {
+							couldPinAll = false;
+						}
+					}
+
+					if (!couldPinAll) {
+						// Undo pin for all previous entries
+						for (BlockEntry e : entries)
+							e.unpin(); // Do not unpin using unpin(...)
method to avoid explicit eviction on memory pressure + } + } + entries.add(entry); + } + } + + if (couldPinAll) { + // Then we could pin all entries + return CompletableFuture.completedFuture(entries); + } + + // Schedule deferred read otherwise + final CompletableFuture> future = new CompletableFuture<>(); + scheduleDeferredRead(new DeferredReadRequest(future, entries)); + return future; + } + + private void scheduleDeferredRead(DeferredReadRequest deferredReadRequest) { + synchronized(this) { + _deferredReadRequests.add(deferredReadRequest); + } + onCacheSizeChanged(false); // To schedule deferred reads if possible + } + + @Override + public void put(BlockKey key, Object data, long size) { + put(key, data, size, false); + } + + @Override + public BlockEntry putAndPin(BlockKey key, Object data, long size) { + return put(key, data, size, true); + } + + private BlockEntry put(BlockKey key, Object data, long size, boolean pin) { + if (!this._running) + throw new IllegalStateException(); + if (data == null) + throw new IllegalArgumentException(); + + Statistics.incrementOOCEvictionPut(); + BlockEntry entry = new BlockEntry(key, size, data); + if (pin) + entry.pin(); + synchronized(this) { + BlockEntry avail = _cache.putIfAbsent(key, entry); + if (avail != null || _evictionCache.containsKey(key)) + throw new IllegalStateException("Cannot overwrite existing entries: " + key); + _cacheSize += size; + } + onCacheSizeChanged(true); + return entry; + } + + @Override + public void forget(BlockKey key) { + if (!this._running) + return; + BlockEntry entry; + boolean shouldScheduleDeletion = false; + long cacheSizeDelta = 0; + synchronized(this) { + entry = _cache.remove(key); + + if (entry == null) + entry = _evictionCache.remove(key); + + if (entry != null) { + synchronized(entry) { + shouldScheduleDeletion = entry.getState().isBackedByDisk() + || entry.getState() == BlockState.EVICTING; + cacheSizeDelta = transitionMemState(entry, BlockState.REMOVED); + } + + } + } + if (cacheSizeDelta != 0) + onCacheSizeChanged(cacheSizeDelta > 0); + if (shouldScheduleDeletion) + _ioHandler.scheduleDeletion(entry); + } + + @Override + public void pin(BlockEntry entry) { + if (!this._running) + throw new IllegalStateException("Cache scheduler has been shut down."); + + int pinCount = entry.pin(); + if (pinCount == 0) + throw new IllegalStateException("Could not pin the requested entry: " + entry.getKey()); + synchronized(this) { + // Access element in cache for Lru + _cache.get(entry.getKey()); + } + } + + @Override + public void unpin(BlockEntry entry) { + boolean couldFree = entry.unpin(); + + if (couldFree) { + long cacheSizeDelta = 0; + synchronized(this) { + if (_cacheSize <= _evictionLimit) + return; // Nothing to do + + synchronized(entry) { + if (entry.isPinned()) + return; // Pin state changed so we cannot evict + + if (entry.getState().isAvailable() && entry.getState().isBackedByDisk()) { + cacheSizeDelta = transitionMemState(entry, BlockState.COLD); + long cleared = entry.clear(); + if (cleared != entry.getSize()) + throw new IllegalStateException(); + _cache.remove(entry.getKey()); + _evictionCache.put(entry.getKey(), entry); + } else if (entry.getState() == BlockState.HOT) { + cacheSizeDelta = onUnpinnedHotBlockUnderMemoryPressure(entry); + } + } + } + if (cacheSizeDelta != 0) + onCacheSizeChanged(cacheSizeDelta > 0); + } + } + + @Override + public synchronized void shutdown() { + this._running = false; + _cache.clear(); + _evictionCache.clear(); + _processingReadRequests.clear(); + 
_deferredReadRequests.clear(); + _cacheSize = 0; + _bytesUpForEviction = 0; + } + + /** + * Must be called while this cache and the corresponding entry are locked + */ + private long onUnpinnedHotBlockUnderMemoryPressure(BlockEntry entry) { + long cacheSizeDelta = transitionMemState(entry, BlockState.EVICTING); + evict(entry); + return cacheSizeDelta; + } + + private void onCacheSizeChanged(boolean incr) { + if (incr) + onCacheSizeIncremented(); + else { + while(onCacheSizeDecremented()) {} + } + if (DMLScript.OOC_LOG_EVENTS) + OOCEventLog.onCacheSizeChangedEvent(_callerId, System.nanoTime(), _cacheSize, _bytesUpForEviction); + } + + private synchronized void sanityCheck() { + if (_cacheSize > _hardLimit) { + if (!_warnThrottling) { + _warnThrottling = true; + System.out.println("[INFO] Throttling: " + _cacheSize/1000 + "KB - " + _bytesUpForEviction/1000 + "KB > " + _hardLimit/1000 + "KB"); + } + } + else if (_warnThrottling) { + _warnThrottling = false; + System.out.println("[INFO] No more throttling: " + _cacheSize/1000 + "KB - " + _bytesUpForEviction/1000 + "KB <= " + _hardLimit/1000 + "KB"); + } + + if (!SANITY_CHECKS) + return; + + int pinned = 0; + int backedByDisk = 0; + int evicting = 0; + int total = 0; + long actualCacheSize = 0; + long upForEviction = 0; + for (BlockEntry entry : _cache.values()) { + if (entry.isPinned()) + pinned++; + if (entry.getState().isBackedByDisk()) + backedByDisk++; + if (entry.getState() == BlockState.EVICTING) { + evicting++; + upForEviction += entry.getSize(); + } + if (!entry.getState().isAvailable()) + throw new IllegalStateException(); + total++; + actualCacheSize += entry.getSize(); + } + for (BlockEntry entry : _evictionCache.values()) { + if (entry.getState().isAvailable()) + throw new IllegalStateException("Invalid eviction state: " + entry.getState()); + if (entry.getState() == BlockState.READING) + actualCacheSize += entry.getSize(); + } + if (actualCacheSize != _cacheSize) + throw new IllegalStateException(actualCacheSize + " != " + _cacheSize); + if (upForEviction != _bytesUpForEviction) + throw new IllegalStateException(upForEviction + " != " + _bytesUpForEviction); + System.out.println("=========="); + System.out.println("Limit: " + _evictionLimit/1000 + "KB"); + System.out.println("Memory: (" + _cacheSize/1000 + "KB - " + _bytesUpForEviction/1000 + "KB) / " + _hardLimit/1000 + "KB"); + System.out.println("Pinned: " + pinned + " / " + total); + System.out.println("Disk backed: " + backedByDisk + " / " + total); + System.out.println("Evicting: " + evicting + " / " + total); + } + + private void onCacheSizeIncremented() { + long cacheSizeDelta = 0; + List upForEviction; + synchronized(this) { + if(_cacheSize - _bytesUpForEviction <= _evictionLimit) + return; // Nothing to do + + // Scan for values that can be evicted + Collection entries = _cache.values(); + List toRemove = new ArrayList<>(); + upForEviction = new ArrayList<>(); + + for(BlockEntry entry : entries) { + if(_cacheSize - _bytesUpForEviction <= _evictionLimit) + break; + + synchronized(entry) { + if(!entry.isPinned() && entry.getState().isBackedByDisk()) { + cacheSizeDelta += transitionMemState(entry, BlockState.COLD); + entry.clear(); + toRemove.add(entry); + } + else if(entry.getState() != BlockState.EVICTING && !entry.getState().isBackedByDisk()) { + cacheSizeDelta += transitionMemState(entry, BlockState.EVICTING); + upForEviction.add(entry); + } + } + } + + for(BlockEntry entry : toRemove) { + _cache.remove(entry.getKey()); + _evictionCache.put(entry.getKey(), entry); + } + 
+ sanityCheck(); + } + + for (BlockEntry entry : upForEviction) { + evict(entry); + } + + if (cacheSizeDelta != 0) + onCacheSizeChanged(cacheSizeDelta > 0); + } + + private boolean onCacheSizeDecremented() { + boolean allReserved = true; + List> toRead; + DeferredReadRequest req; + synchronized(this) { + if(_cacheSize >= _hardLimit || _deferredReadRequests.isEmpty()) + return false; // Nothing to do + + // Try to schedule the next disk read + req = _deferredReadRequests.peek(); + toRead = new ArrayList<>(req.getEntries().size()); + + for(int idx = 0; idx < req.getEntries().size(); idx++) { + if(!req.actionRequired(idx)) + continue; + + BlockEntry entry = req.getEntries().get(idx); + synchronized(entry) { + if(entry.getState().isAvailable()) { + if(entry.pin() == 0) + throw new IllegalStateException(); + req.setPinned(idx); + } + else { + if(_cacheSize + entry.getSize() <= _hardLimit) { + transitionMemState(entry, BlockState.READING); + toRead.add(new Tuple2<>(idx, entry)); + req.schedule(idx); + } + else { + allReserved = false; + } + } + } + } + + if (allReserved) { + _deferredReadRequests.poll(); + if (!toRead.isEmpty()) + _processingReadRequests.add(req); + } + + sanityCheck(); + } + + if (allReserved && toRead.isEmpty()) { + req.getFuture().complete(req.getEntries()); + return true; + } + + for (Tuple2 tpl : toRead) { + final int idx = tpl._1; + final BlockEntry entry = tpl._2; + CompletableFuture future = _ioHandler.scheduleRead(entry); + future.whenComplete((r, t) -> { + boolean allAvailable; + synchronized(this) { + synchronized(r) { + transitionMemState(r, BlockState.WARM); + if (r.pin() == 0) + throw new IllegalStateException(); + _evictionCache.remove(r.getKey()); + _cache.put(r.getKey(), r); + allAvailable = req.setPinned(idx); + } + + if (allAvailable) { + _processingReadRequests.remove(req); + } + + sanityCheck(); + } + if (allAvailable) { + req.getFuture().complete(req.getEntries()); + } + }); + } + + return false; + } + + private void evict(final BlockEntry entry) { + CompletableFuture future = _ioHandler.scheduleEviction(entry); + future.whenComplete((r, e) -> onEvicted(entry)); + } + + private void onEvicted(final BlockEntry entry) { + long cacheSizeDelta; + synchronized(this) { + synchronized(entry) { + if(entry.isPinned()) { + transitionMemState(entry, BlockState.WARM); + return; // Then we cannot clear the data + } + cacheSizeDelta = transitionMemState(entry, BlockState.COLD); + entry.clear(); + } + BlockEntry tmp = _cache.remove(entry.getKey()); + if(tmp != null && tmp != entry) + throw new IllegalStateException(); + tmp = _evictionCache.put(entry.getKey(), entry); + if (tmp != null) + throw new IllegalStateException(); + sanityCheck(); + } + if (cacheSizeDelta != 0) + onCacheSizeChanged(cacheSizeDelta > 0); + } + + /** + * Cleanly transitions state of a BlockEntry and handles accounting. 
+ * Requires both the scheduler object and the entry to be locked: + */ + private long transitionMemState(BlockEntry entry, BlockState newState) { + BlockState oldState = entry.getState(); + if (oldState == newState) + return 0; + + long sz = entry.getSize(); + long oldCacheSize = _cacheSize; + + // Remove old contribution + switch (oldState) { + case REMOVED: + throw new IllegalStateException(); + case HOT: + case WARM: + _cacheSize -= sz; + break; + case EVICTING: + _cacheSize -= sz; + _bytesUpForEviction -= sz; + break; + case READING: + _cacheSize -= sz; + break; + case COLD: + break; + } + + // Add new contribution + switch (newState) { + case REMOVED: + case COLD: + break; + case HOT: + case WARM: + _cacheSize += sz; + break; + case EVICTING: + _cacheSize += sz; + _bytesUpForEviction += sz; + break; + case READING: + _cacheSize += sz; + break; + } + + entry.setState(newState); + return _cacheSize - oldCacheSize; + } + + + + private static class DeferredReadRequest { + private static final short NOT_SCHEDULED = 0; + private static final short SCHEDULED = 1; + private static final short PINNED = 2; + + private final CompletableFuture> _future; + private final List _entries; + private final short[] _pinned; + private final AtomicInteger _availableCount; + + DeferredReadRequest(CompletableFuture> future, List entries) { + this._future = future; + this._entries = entries; + this._pinned = new short[entries.size()]; + this._availableCount = new AtomicInteger(0); + } + + CompletableFuture> getFuture() { + return _future; + } + + List getEntries() { + return _entries; + } + + public synchronized boolean actionRequired(int idx) { + return _pinned[idx] == NOT_SCHEDULED; + } + + public synchronized boolean setPinned(int idx) { + if (_pinned[idx] == PINNED) + return false; // already pinned + _pinned[idx] = PINNED; + return _availableCount.incrementAndGet() == _entries.size(); + } + + public synchronized void schedule(int idx) { + _pinned[idx] = SCHEDULED; + } + } +} diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCMatrixIOHandler.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCMatrixIOHandler.java new file mode 100644 index 00000000000..3cd16272d2b --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCMatrixIOHandler.java @@ -0,0 +1,359 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.runtime.ooc.cache; + +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import org.apache.sysds.runtime.io.IOUtilFunctions; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.data.MatrixIndexes; +import org.apache.sysds.runtime.ooc.stats.OOCEventLog; +import org.apache.sysds.runtime.util.FastBufferedDataInputStream; +import org.apache.sysds.runtime.util.FastBufferedDataOutputStream; +import org.apache.sysds.runtime.util.LocalFileUtils; +import org.apache.sysds.utils.Statistics; +import scala.Tuple2; +import scala.Tuple3; + +import java.io.DataInput; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.io.RandomAccessFile; +import java.nio.channels.Channels; +import java.nio.channels.ClosedByInterruptException; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedDeque; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; + +public class OOCMatrixIOHandler implements OOCIOHandler { + private static final int WRITER_SIZE = 2; + private static final long OVERFLOW = 8192 * 1024; + private static final long MAX_PARTITION_SIZE = 8192 * 8192; + + private final String _spillDir; + private final ThreadPoolExecutor _writeExec; + private final ThreadPoolExecutor _readExec; + + // Spill related structures + private final ConcurrentHashMap _spillLocations = new ConcurrentHashMap<>(); + private final ConcurrentHashMap _partitions = new ConcurrentHashMap<>(); + private final AtomicInteger _partitionCounter = new AtomicInteger(0); + private final CloseableQueue>>[] _q; + private final AtomicLong _wCtr; + private final AtomicBoolean _started; + + private final int _evictCallerId = OOCEventLog.registerCaller("write"); + private final int _readCallerId = OOCEventLog.registerCaller("read"); + + @SuppressWarnings("unchecked") + public OOCMatrixIOHandler() { + this._spillDir = LocalFileUtils.getUniqueWorkingDir("ooc_stream"); + _writeExec = new ThreadPoolExecutor( + WRITER_SIZE, + WRITER_SIZE, + 0L, + TimeUnit.MILLISECONDS, + new ArrayBlockingQueue<>(100000)); + _readExec = new ThreadPoolExecutor( + 5, + 5, + 0L, + TimeUnit.MILLISECONDS, + new ArrayBlockingQueue<>(100000)); + _q = new CloseableQueue[WRITER_SIZE]; + _wCtr = new AtomicLong(0); + _started = new AtomicBoolean(false); + } + + private synchronized void start() { + if (_started.compareAndSet(false, true)) { + for (int i = 0; i < WRITER_SIZE; i++) { + final int finalIdx = i; + _q[i] = new CloseableQueue<>(); + _writeExec.submit(() -> evictTask(_q[finalIdx])); + } + } + } + + @Override + public void shutdown() { + boolean started = _started.get(); + if (started) { + try { + for(int i = 0; i < WRITER_SIZE; i++) { + _q[i].close(); + } + } + catch(InterruptedException ignored) { + } + } + _writeExec.getQueue().clear(); + _writeExec.shutdownNow(); + _readExec.getQueue().clear(); + _readExec.shutdownNow(); + _spillLocations.clear(); + _partitions.clear(); + if (started) + LocalFileUtils.deleteFileIfExists(_spillDir); + } + + @Override + 
+	@Override
+	public CompletableFuture<Void> scheduleEviction(BlockEntry block) {
+		start();
+		CompletableFuture<Void> future = new CompletableFuture<>();
+		try {
+			long q = _wCtr.getAndAdd(block.getSize()) / OVERFLOW;
+			int i = (int)(q % WRITER_SIZE);
+			_q[i].enqueueIfOpen(new Tuple2<>(block, future));
+		}
+		catch(InterruptedException e) {
+			Thread.currentThread().interrupt();
+			future.completeExceptionally(e); // do not leave the caller waiting on an uncompleted future
+		}
+
+		return future;
+	}
+
+	@Override
+	public CompletableFuture<BlockEntry> scheduleRead(final BlockEntry block) {
+		final CompletableFuture<BlockEntry> future = new CompletableFuture<>();
+		try {
+			_readExec.submit(() -> {
+				try {
+					long ioStart = DMLScript.OOC_LOG_EVENTS ? System.nanoTime() : 0;
+					loadFromDisk(block);
+					if (DMLScript.OOC_LOG_EVENTS)
+						OOCEventLog.onDiskReadEvent(_readCallerId, ioStart, System.nanoTime(), block.getSize());
+					future.complete(block);
+				} catch (Throwable e) {
+					future.completeExceptionally(e);
+				}
+			});
+		} catch (RejectedExecutionException e) {
+			future.completeExceptionally(e);
+		}
+		return future;
+	}
+
+	@Override
+	public CompletableFuture<Boolean> scheduleDeletion(BlockEntry block) {
+		// TODO
+		return CompletableFuture.completedFuture(true);
+	}
+
+	private void loadFromDisk(BlockEntry block) {
+		String key = block.getKey().toFileKey();
+
+		long ioDuration = 0;
+		// 1. find the block's address (spill location)
+		SpillLocation sloc = _spillLocations.get(key);
+		if (sloc == null)
+			throw new DMLRuntimeException("Failed to load spill location for: " + key);
+
+		PartitionFile partFile = _partitions.get(sloc.partitionId);
+		if (partFile == null)
+			throw new DMLRuntimeException("Failed to load partition for: " + sloc.partitionId);
+
+		String filename = partFile.filePath;
+
+		// Create an empty object to read data into.
+		MatrixIndexes ix = new MatrixIndexes();
+		MatrixBlock mb = new MatrixBlock();
+
+		try (RandomAccessFile raf = new RandomAccessFile(filename, "r")) {
+			raf.seek(sloc.offset);
+
+			DataInput dis = new FastBufferedDataInputStream(Channels.newInputStream(raf.getChannel()));
+			long ioStart = DMLScript.STATISTICS ? System.nanoTime() : 0;
+			ix.readFields(dis); // 1. Read Indexes
+			mb.readFields(dis); // 2. Read Block
+			if (DMLScript.STATISTICS)
+				ioDuration = System.nanoTime() - ioStart;
+		} catch (ClosedByInterruptException e) {
+			Thread.currentThread().interrupt(); // interrupted (e.g., on shutdown): abort without publishing a partial block
+			return;
+		} catch (IOException e) {
+			throw new DMLRuntimeException(e);
+		}
+
+		block.setDataUnsafe(new IndexedMatrixValue(ix, mb));
+
+		if (DMLScript.STATISTICS) {
+			Statistics.incrementOOCLoadFromDisk();
+			Statistics.accumulateOOCLoadFromDiskTime(ioDuration);
+		}
+	}
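+
+	/**
+	 * Writer loop: drains one queue into a sequence of partition files. The
+	 * current partition is rotated once MAX_PARTITION_SIZE bytes have been
+	 * written; when the queue is closed, the remaining entries are drained
+	 * into the current partition before the loop exits.
+	 */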
+	private void evictTask(CloseableQueue<Tuple2<BlockEntry, CompletableFuture<Void>>> q) {
+		long byteCtr = 0;
+
+		while (!q.isFinished()) {
+			// --- 1. WRITE PHASE ---
+			int partitionId = _partitionCounter.getAndIncrement();
+
+			LocalFileUtils.createLocalFileIfNotExist(_spillDir);
+
+			String filename = _spillDir + "/stream_batch_part_" + partitionId;
+
+			PartitionFile partFile = new PartitionFile(filename);
+			_partitions.put(partitionId, partFile);
+
+			FileOutputStream fos = null;
+			CountableFastBufferedDataOutputStream dos = null;
+			ConcurrentLinkedDeque<Tuple3<Long, Long, CompletableFuture<Void>>> waitingForFlush = null;
+
+			try {
+				fos = new FileOutputStream(filename);
+				dos = new CountableFastBufferedDataOutputStream(fos);
+
+				Tuple2<BlockEntry, CompletableFuture<Void>> tpl;
+				waitingForFlush = new ConcurrentLinkedDeque<>();
+				boolean closePartition = false;
+
+				while((tpl = q.take()) != null) {
+					long ioStart = DMLScript.STATISTICS || DMLScript.OOC_LOG_EVENTS ? System.nanoTime() : 0;
+					BlockEntry entry = tpl._1;
+					CompletableFuture<Void> future = tpl._2;
+					long wrote = writeOut(partitionId, entry, future, fos, dos, waitingForFlush);
+
+					if(DMLScript.STATISTICS && wrote > 0) {
+						Statistics.incrementOOCEvictionWrite();
+						Statistics.accumulateOOCEvictionWriteTime(System.nanoTime() - ioStart);
+					}
+
+					byteCtr += wrote;
+					if (byteCtr >= MAX_PARTITION_SIZE) {
+						closePartition = true;
+						byteCtr = 0;
+						break;
+					}
+
+					if (DMLScript.OOC_LOG_EVENTS)
+						OOCEventLog.onDiskWriteEvent(_evictCallerId, ioStart, System.nanoTime(), wrote);
+				}
+
+				if (!closePartition && q.close()) {
+					// drain the remaining entries after the queue was closed
+					while((tpl = q.take()) != null) {
+						long ioStart = DMLScript.STATISTICS || DMLScript.OOC_LOG_EVENTS ? System.nanoTime() : 0;
+						BlockEntry entry = tpl._1;
+						CompletableFuture<Void> future = tpl._2;
+						long wrote = writeOut(partitionId, entry, future, fos, dos, waitingForFlush);
+						byteCtr += wrote;
+
+						if(DMLScript.STATISTICS && wrote > 0) {
+							Statistics.incrementOOCEvictionWrite();
+							Statistics.accumulateOOCEvictionWriteTime(System.nanoTime() - ioStart);
+						}
+
+						if (DMLScript.OOC_LOG_EVENTS)
+							OOCEventLog.onDiskWriteEvent(_evictCallerId, ioStart, System.nanoTime(), wrote);
+					}
+				}
+			}
+			catch(IOException | InterruptedException ex) {
+				throw new DMLRuntimeException(ex);
+			}
+			catch(Exception e) {
+				// TODO: dedicated handling; do not silently swallow writer failures
+				throw new DMLRuntimeException(e);
+			}
+			finally {
+				IOUtilFunctions.closeSilently(dos);
+				IOUtilFunctions.closeSilently(fos);
+				if(waitingForFlush != null)
+					flushQueue(Long.MAX_VALUE, waitingForFlush);
+			}
+		}
+	}
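+
+	/**
+	 * Serializes a block to the current partition file unless it was already
+	 * spilled. Offsets are computed as channel position plus bytes still
+	 * buffered in dos, so a recorded spill location becomes valid once the
+	 * buffer reaches the file; completion is deferred via the flush queue.
+	 */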
+	private long writeOut(int partitionId, BlockEntry entry, CompletableFuture<Void> future, FileOutputStream fos,
+		CountableFastBufferedDataOutputStream dos, ConcurrentLinkedDeque<Tuple3<Long, Long, CompletableFuture<Void>>> flushQueue) throws IOException {
+		String key = entry.getKey().toFileKey();
+		boolean alreadySpilled = _spillLocations.containsKey(key);
+
+		if (!alreadySpilled) {
+			// 1. get the current file position; this is the offset
+			// (account for data still buffered in dos instead of flushing it)
+			long offsetBefore = fos.getChannel().position() + dos.getCount();
+
+			// 2. write indexes and block
+			IndexedMatrixValue imv = (IndexedMatrixValue) entry.getDataUnsafe(); // Get data without requiring pin
+			imv.getIndexes().write(dos); // write Indexes
+			imv.getValue().write(dos); // write Block
+
+			long offsetAfter = fos.getChannel().position() + dos.getCount();
+			flushQueue.offer(new Tuple3<>(offsetBefore, offsetAfter, future));
+
+			// 3. create the spill location
+			SpillLocation sloc = new SpillLocation(partitionId, offsetBefore);
+			_spillLocations.put(key, sloc);
+			flushQueue(fos.getChannel().position(), flushQueue);
+
+			return offsetAfter - offsetBefore;
+		}
+		return 0;
+	}
+
+	private void flushQueue(long offset, ConcurrentLinkedDeque<Tuple3<Long, Long, CompletableFuture<Void>>> flushQueue) {
+		Tuple3<Long, Long, CompletableFuture<Void>> tmp;
+		while ((tmp = flushQueue.peek()) != null && tmp._2() < offset) {
+			flushQueue.poll();
+			tmp._3().complete(null);
+		}
+	}
+
+	private static class SpillLocation {
+		// structure of spillLocation: file, offset
+		final int partitionId;
+		final long offset;
+
+		SpillLocation(int partitionId, long offset) {
+			this.partitionId = partitionId;
+			this.offset = offset;
+		}
+	}
+
+	private static class PartitionFile {
+		final String filePath;
+
+		PartitionFile(String filePath) {
+			this.filePath = filePath;
+		}
+	}
+
+	private static class CountableFastBufferedDataOutputStream extends FastBufferedDataOutputStream {
+		public CountableFastBufferedDataOutputStream(OutputStream out) {
+			super(out);
+		}
+
+		public int getCount() {
+			return _count;
+		}
+	}
+}
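
A minimal, self-contained model of the completion protocol in writeOut()/flushQueue()
above (sketch only; class and method names are hypothetical and not part of this
patch). Each write registers its [start, end) byte range, and its future completes
only once the flushed file position has passed the range's end offset:

	import java.util.ArrayDeque;
	import java.util.Deque;
	import java.util.concurrent.CompletableFuture;

	public class FlushProtocolSketch {
		private static final class Pending {
			final long end;
			final CompletableFuture<Void> done = new CompletableFuture<>();
			Pending(long end) { this.end = end; }
		}

		private final Deque<Pending> pending = new ArrayDeque<>();

		// register a write of byte range [start, end); start is kept only
		// for parity with the Tuple3 used in the patch
		CompletableFuture<Void> write(long start, long end) {
			Pending p = new Pending(end);
			pending.offer(p);
			return p.done;
		}

		// called with the position up to which the file is known to be written
		void onFlushed(long flushedUpTo) {
			Pending p;
			while ((p = pending.peek()) != null && p.end < flushedUpTo) {
				pending.poll();
				p.done.complete(null);
			}
		}

		public static void main(String[] args) {
			FlushProtocolSketch s = new FlushProtocolSketch();
			CompletableFuture<Void> a = s.write(0, 100);
			CompletableFuture<Void> b = s.write(100, 250);
			s.onFlushed(120); // only the first write is fully on disk
			System.out.println(a.isDone() + " " + b.isDone()); // true false
		}
	}
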
diff --git a/src/main/java/org/apache/sysds/runtime/ooc/stats/OOCEventLog.java b/src/main/java/org/apache/sysds/runtime/ooc/stats/OOCEventLog.java
new file mode 100644
index 00000000000..0df22c9a851
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/ooc/stats/OOCEventLog.java
@@ -0,0 +1,179 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.ooc.stats;
+
+import org.apache.sysds.api.DMLScript;
+
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicInteger;
+
+public class OOCEventLog {
+	private static final AtomicInteger _callerCtr = new AtomicInteger(0);
+	private static final ConcurrentHashMap<Integer, String> _callerNames = new ConcurrentHashMap<>();
+	private static final ConcurrentHashMap<String, Object> _runSettings = new ConcurrentHashMap<>();
+
+	private static final AtomicInteger _logCtr = new AtomicInteger(0);
+	private static EventType[] _eventTypes;
+	private static long[] _startTimestamps;
+	private static long[] _endTimestamps;
+	private static int[] _callerIds;
+	private static long[] _threadIds;
+	private static long[] _data;
+
+	public static void setup(int maxNumEvents) {
+		_eventTypes = DMLScript.OOC_LOG_EVENTS ? new EventType[maxNumEvents] : null;
+		_startTimestamps = DMLScript.OOC_LOG_EVENTS ? new long[maxNumEvents] : null;
+		_endTimestamps = DMLScript.OOC_LOG_EVENTS ? new long[maxNumEvents] : null;
+		_callerIds = DMLScript.OOC_LOG_EVENTS ? new int[maxNumEvents] : null;
+		_threadIds = DMLScript.OOC_LOG_EVENTS ? new long[maxNumEvents] : null;
+		_data = DMLScript.OOC_LOG_EVENTS ? new long[maxNumEvents] : null;
+	}
+
+	public static int registerCaller(String callerName) {
+		int callerId = _callerCtr.incrementAndGet();
+		_callerNames.put(callerId, callerName);
+		return callerId;
+	}
+
+	public static void onComputeEvent(int callerId, long startTimestamp, long endTimestamp) {
+		int idx = _logCtr.getAndIncrement();
+		_eventTypes[idx] = EventType.COMPUTE;
+		_startTimestamps[idx] = startTimestamp;
+		_endTimestamps[idx] = endTimestamp;
+		_callerIds[idx] = callerId;
+		_threadIds[idx] = Thread.currentThread().getId();
+	}
+
+	public static void onDiskWriteEvent(int callerId, long startTimestamp, long endTimestamp, long size) {
+		int idx = _logCtr.getAndIncrement();
+		_eventTypes[idx] = EventType.DISK_WRITE;
+		_startTimestamps[idx] = startTimestamp;
+		_endTimestamps[idx] = endTimestamp;
+		_callerIds[idx] = callerId;
+		_threadIds[idx] = Thread.currentThread().getId();
+		_data[idx] = size;
+	}
+
+	public static void onDiskReadEvent(int callerId, long startTimestamp, long endTimestamp, long size) {
+		int idx = _logCtr.getAndIncrement();
+		_eventTypes[idx] = EventType.DISK_READ;
+		_startTimestamps[idx] = startTimestamp;
+		_endTimestamps[idx] = endTimestamp;
+		_callerIds[idx] = callerId;
+		_threadIds[idx] = Thread.currentThread().getId();
+		_data[idx] = size;
+	}
+
+	public static void onCacheSizeChangedEvent(int callerId, long timestamp, long cacheSize, long bytesToEvict) {
+		int idx = _logCtr.getAndIncrement();
+		_eventTypes[idx] = EventType.CACHESIZE_CHANGE;
+		_startTimestamps[idx] = timestamp;
+		_endTimestamps[idx] = bytesToEvict; // end-timestamp slot is reused for the scheduled eviction size
+		_callerIds[idx] = callerId;
+		_threadIds[idx] = Thread.currentThread().getId();
+		_data[idx] = cacheSize;
+	}
+
+	public static void putRunSetting(String setting, Object data) {
+		_runSettings.put(setting, data);
+	}
+
+	public static String getComputeEventsCSV() {
+		return getFilteredCSV("ThreadID,CallerID,StartNanos,EndNanos\n", EventType.COMPUTE, false);
+	}
+
+	public static String getDiskReadEventsCSV() {
+		return getFilteredCSV("ThreadID,CallerID,StartNanos,EndNanos,NumBytes\n", EventType.DISK_READ, true);
+	}
+
+	public static String getDiskWriteEventsCSV() {
+		return getFilteredCSV("ThreadID,CallerID,StartNanos,EndNanos,NumBytes\n", EventType.DISK_WRITE, true);
+	}
+
+	public static String getCacheSizeEventsCSV() {
+		return getFilteredCSV("ThreadID,CallerID,Timestamp,ScheduledEvictionSize,CacheSize\n", EventType.CACHESIZE_CHANGE, true);
+	}
+
+	private static String getFilteredCSV(String header, EventType filter, boolean data) {
+		StringBuilder sb = new StringBuilder();
+		sb.append(header);
+
+		int maxIdx = _logCtr.get();
+		for (int i = 0; i < maxIdx; i++) {
+			if (_eventTypes[i] != filter)
+				continue;
+			sb.append(_threadIds[i]);
+			sb.append(',');
+			sb.append(_callerNames.get(_callerIds[i]));
+			sb.append(',');
+			sb.append(_startTimestamps[i]);
+			sb.append(',');
+			sb.append(_endTimestamps[i]);
+			if (data) {
+				sb.append(',');
+				sb.append(_data[i]);
+			}
+			sb.append('\n');
+		}
+
+		return sb.toString();
+	}
+
+	public static String getRunSettingsCSV() {
+		StringBuilder keys = new StringBuilder();
+		StringBuilder values = new StringBuilder();
+
+		// single pass to guarantee that the key row and the value row stay aligned
+		for (Map.Entry<String, Object> entry : _runSettings.entrySet()) {
+			if (keys.length() > 0) {
+				keys.append(',');
+				values.append(',');
+			}
+			keys.append(entry.getKey());
+			values.append(entry.getValue());
+		}
+
+		return keys.append('\n').append(values).toString();
+	}
+
+	public static void clear() {
+		_callerCtr.set(0);
+		_logCtr.set(0);
+		_callerNames.clear();
+		_runSettings.clear();
+	}
+
+	public enum EventType {
+		COMPUTE,
+		DISK_WRITE,
+		DISK_READ,
+		CACHESIZE_CHANGE
+	}
+}
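
A minimal usage sketch of the event log above (hypothetical harness; it assumes the
run was started with -oocLogEvents, since otherwise setup() leaves the event arrays
null and any on*Event() call would fail):

	import org.apache.sysds.runtime.ooc.stats.OOCEventLog;

	public class EventLogUsageSketch {
		public static void main(String[] args) {
			// setup() must run before any event is logged; maxNumEvents bounds
			// the pre-allocated arrays (no resizing, no bounds check on overflow)
			OOCEventLog.setup(1_000_000);
			int readId = OOCEventLog.registerCaller("read");

			long start = System.nanoTime();
			long size = 4096; // pretend a 4KB block was just read from disk
			OOCEventLog.onDiskReadEvent(readId, start, System.nanoTime(), size);

			// at the end of a run, dump the per-event CSVs for offline analysis
			System.out.println(OOCEventLog.getDiskReadEventsCSV());
		}
	}
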
diff --git a/src/main/java/org/apache/sysds/runtime/util/LocalFileUtils.java b/src/main/java/org/apache/sysds/runtime/util/LocalFileUtils.java
index c7e2f4b0404..7d5be41c261 100644
--- a/src/main/java/org/apache/sysds/runtime/util/LocalFileUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/util/LocalFileUtils.java
@@ -496,7 +496,6 @@ public static String getUniqueWorkingDir(String category) {
 		createWorkingDirectory();
 		StringBuilder sb = new StringBuilder();
 		sb.append( _workingDir );
-		sb.append( Lop.FILE_SEPARATOR );
 		sb.append( category );
 		sb.append( Lop.FILE_SEPARATOR );
 		sb.append( "tmp" );
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java b/src/main/java/org/apache/sysds/utils/Statistics.java
index e6fdf5db3cd..9ec94b1025c 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -65,6 +65,7 @@
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.DoubleAdder;
+import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.atomic.LongAdder;
 import java.util.function.Consumer;
 
@@ -222,6 +223,16 @@ public Object getMeta(String key) {
 
 	public static boolean allowWorkerStatistics = true;
 
+	// Out-of-core eviction metrics
+	private static final ConcurrentHashMap<String, LongAdder> oocHeavyHitters = new ConcurrentHashMap<>();
+	private static final LongAdder oocGetCalls = new LongAdder();
+	private static final LongAdder oocPutCalls = new LongAdder();
+	private static final LongAdder oocLoadFromDiskCalls = new LongAdder();
+	private static final LongAdder oocLoadFromDiskTimeNanos = new LongAdder();
+	private static final LongAdder oocEvictionWriteCalls = new LongAdder();
+	private static final LongAdder oocEvictionWriteTimeNanos = new LongAdder();
+	private static final AtomicLong oocStatsStartTime = new AtomicLong(System.nanoTime());
+
 	public static long getNoOfExecutedSPInst() {
 		return numExecutedSPInst.longValue();
 	}
@@ -338,6 +349,146 @@ public static void stopRunTimer() {
 	public static long getRunTime() {
 		return execEndTime - execStartTime;
 	}
+
+	public static void resetOOCEvictionStats() {
+		oocHeavyHitters.clear();
+		oocGetCalls.reset();
+		oocPutCalls.reset();
+		oocLoadFromDiskCalls.reset();
+		oocLoadFromDiskTimeNanos.reset();
+		oocEvictionWriteCalls.reset();
+		oocEvictionWriteTimeNanos.reset();
+		oocStatsStartTime.set(System.nanoTime());
+	}
+
+	public static String getOOCHeavyHitters(int num) {
+		if (num <= 0 || oocHeavyHitters.isEmpty())
+			return "-";
+
+		@SuppressWarnings("unchecked")
+		Map.Entry<String, LongAdder>[] tmp =
+			oocHeavyHitters.entrySet().toArray(new Map.Entry[0]);
+
+		// sort ascending by accumulated time; rows are emitted from the end
+		Arrays.sort(tmp, (e1, e2) ->
+			Long.compare(e1.getValue().longValue(), e2.getValue().longValue())
+		);
+
+		final String numCol = "#";
+		final String instCol = "Instruction";
+		final String timeCol = "Time(s)";
+
+		DecimalFormat sFormat = new DecimalFormat("#,##0.000");
+
+		StringBuilder sb = new StringBuilder();
+		int len = tmp.length;
+		int numHittersToDisplay = Math.min(num, len);
+
+		int maxNumLen = String.valueOf(numHittersToDisplay).length();
+		int maxInstLen = instCol.length();
+		int maxTimeLen = timeCol.length();
+
+		// first pass: compute column widths
+		for (int i = 0; i < numHittersToDisplay; i++) {
+			Map.Entry<String, LongAdder> hh = tmp[len - 1 - i];
+
+			String instruction = hh.getKey();
+			double timeS = hh.getValue().longValue() / 1_000_000_000d;
+			String timeStr = sFormat.format(timeS);
+
+			maxInstLen = Math.max(maxInstLen, instruction.length());
+			maxTimeLen = Math.max(maxTimeLen, timeStr.length());
+		}
+
+		maxInstLen = Math.min(maxInstLen, DMLScript.STATISTICS_MAX_WRAP_LEN);
+
+		// header
+		sb.append(String.format(
+			" %" + maxNumLen + "s  %-" + maxInstLen + "s  %" + maxTimeLen + "s",
+			numCol, instCol, timeCol));
+		sb.append("\n");
+
+		// rows
+		for (int i = 0; i < numHittersToDisplay; i++) {
+			Map.Entry<String, LongAdder> hh = tmp[len - 1 - i];
+
+			String instruction = hh.getKey();
+			double timeS = hh.getValue().longValue() / 1_000_000_000d;
+			String timeStr = sFormat.format(timeS);
+
+			String[] wrappedInstruction = wrap(instruction, maxInstLen);
+
+			for (int w = 0; w < wrappedInstruction.length; w++) {
+				if (w == 0) {
+					sb.append(String.format(
+						" %" + maxNumLen + "d  %-" + maxInstLen + "s  %"
+							+ maxTimeLen + "s",
+						(i + 1), wrappedInstruction[w], timeStr));
+				} else {
+					sb.append(String.format(
+						" %" + maxNumLen + "s  %-" + maxInstLen + "s  %"
+							+ maxTimeLen + "s",
+						"", wrappedInstruction[w], ""));
+				}
+				sb.append("\n");
+			}
+		}
+
+		return sb.toString();
+	}
+
+	public static void maintainOOCHeavyHitter(String op, long timeNanos) {
+		LongAdder adder = oocHeavyHitters.computeIfAbsent(op, k -> new LongAdder());
+		adder.add(timeNanos);
+	}
+
+	public static void incrementOOCEvictionGet() {
+		oocGetCalls.increment();
+	}
+
+	public static void incrementOOCEvictionGet(int incr) {
+		oocGetCalls.add(incr);
+	}
+
+	public static void incrementOOCEvictionPut() {
+		oocPutCalls.increment();
+	}
+
+	public static void incrementOOCLoadFromDisk() {
+		oocLoadFromDiskCalls.increment();
+	}
+
+	public static void incrementOOCEvictionWrite() {
+		oocEvictionWriteCalls.increment();
+	}
+
+	public static void accumulateOOCLoadFromDiskTime(long nanos) {
+		oocLoadFromDiskTimeNanos.add(nanos);
+	}
+
+	public static void accumulateOOCEvictionWriteTime(long nanos) {
+		oocEvictionWriteTimeNanos.add(nanos);
+	}
+
+	public static String displayOOCEvictionStats() {
+		long elapsedNanos = Math.max(1, System.nanoTime() - oocStatsStartTime.get());
+		double elapsedSeconds = elapsedNanos / 1e9;
+		double getThroughput = oocGetCalls.longValue() / elapsedSeconds;
+		double putThroughput = oocPutCalls.longValue() / elapsedSeconds;
+
+		StringBuilder sb = new StringBuilder();
+		sb.append("OOC heavy hitters:\n");
+		sb.append(getOOCHeavyHitters(DMLScript.OOC_STATISTICS_COUNT));
+		sb.append('\n');
+		sb.append(String.format(Locale.US, " get calls:\t\t%d (%.2f/sec)\n",
+			oocGetCalls.longValue(), getThroughput));
+		sb.append(String.format(Locale.US, " put calls:\t\t%d (%.2f/sec)\n",
+			oocPutCalls.longValue(), putThroughput));
+		sb.append(String.format(Locale.US, " loadFromDisk:\t\t%d (time %.3f sec)\n",
+			oocLoadFromDiskCalls.longValue(), oocLoadFromDiskTimeNanos.longValue() / 1e9));
+		sb.append(String.format(Locale.US, " evict writes:\t\t%d (time %.3f sec)\n",
+			oocEvictionWriteCalls.longValue(), oocEvictionWriteTimeNanos.longValue() / 1e9));
+		return sb.toString();
+	}
@@ -358,6 +509,7 @@ public static void reset()
 
 		CacheStatistics.reset();
 		LineageCacheStatistics.reset();
+		resetOOCEvictionStats();
 
 		resetJITCompileTime();
 		resetJVMgcTime();
@@ -1126,6 +1278,11 @@ public static String display(int maxHeavyHitters)
 			sb.append(ParamServStatistics.displayFloStatistics());
 		}
 
+		if (DMLScript.OOC_STATISTICS) {
+			sb.append('\n');
+			sb.append(displayOOCEvictionStats());
+		}
+
 		return sb.toString();
 	}
 }
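
A sketch of how these counters are meant to be fed by OOC operators (hypothetical
opcode key and workload; the real call sites are the OOC instructions and the
eviction manager, and the summary is printed when running with -ooc and -oocStats):

	import org.apache.sysds.utils.Statistics;

	public class OOCStatsUsageSketch {
		public static void main(String[] args) {
			long t0 = System.nanoTime();
			double sum = 0; // stand-in for processing one block of a stream
			for (int i = 0; i < 1_000; i++)
				sum += Math.sqrt(i);
			// accumulate per-opcode wall time under a (hypothetical) opcode key
			Statistics.maintainOOCHeavyHitter("ooc_ba+*", System.nanoTime() - t0);
			Statistics.incrementOOCEvictionGet(); // one cache get
			Statistics.incrementOOCEvictionPut(); // one cache put
			System.out.println(sum); // keep the loop observable
			System.out.println(Statistics.displayOOCEvictionStats());
		}
	}
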
diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/LmCGTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/LmCGTest.java
new file mode 100644
index 00000000000..ae5ec5ba2aa
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/ooc/LmCGTest.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.ooc;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.io.MatrixWriter;
+import org.apache.sysds.runtime.io.MatrixWriterFactory;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.io.IOException;
+
+public class LmCGTest extends AutomatedTestBase {
+	private final static String TEST_NAME1 = "lmCG";
+	private final static String TEST_DIR = "functions/ooc/";
+	private final static String TEST_CLASS_DIR = TEST_DIR + LmCGTest.class.getSimpleName() + "/";
+	private final static double eps = 1e-8;
+	private static final String INPUT_NAME_1 = "X";
+	private static final String INPUT_NAME_2 = "y";
+	private static final String OUTPUT_NAME = "res";
+
+	private final static int rows = 10000;
+	private final static int cols = 500;
+	private final static int maxVal = 2;
+	private final static double sparsity1 = 1;
+	private final static double sparsity2 = 0.05;
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1);
+		addTestConfiguration(TEST_NAME1, config);
+	}
+
+	@Test
+	public void testLmCGDense() {
+		runLmCGTest(false);
+	}
+
+	@Test
+	public void testLmCGSparse() {
+		runLmCGTest(true);
+	}
+
+	private void runLmCGTest(boolean sparse) {
+		Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE);
+
+		try {
+			getAndLoadTestConfiguration(TEST_NAME1);
+
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
+			programArgs = new String[] {"-explain", /*"hops",*/ "-stats", "-ooc", "-args", input(INPUT_NAME_1), input(INPUT_NAME_2), output(OUTPUT_NAME)};
+
+			// 1. Generate the data in-memory as double arrays
+			double[][] X_data = getRandomMatrix(rows, cols, 0, maxVal, sparse ? sparsity2 : sparsity1, 7);
+			double[][] y_data = getRandomMatrix(rows, 1, 0, 1, 1.0, 3);
+
+			// 2. Convert the double arrays to MatrixBlock objects
+			MatrixBlock X_mb = DataConverter.convertToMatrixBlock(X_data);
+			MatrixBlock y_mb = DataConverter.convertToMatrixBlock(y_data);
+
+			// 3. Create a binary matrix writer
+			MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
+
+			// 4. Write matrix X to a binary SequenceFile
+			writer.writeMatrixToHDFS(X_mb, input(INPUT_NAME_1), rows, cols, 1000, X_mb.getNonZeros());
+			HDFSTool.writeMetaDataFile(input(INPUT_NAME_1 + ".mtd"), Types.ValueType.FP64,
+				new MatrixCharacteristics(rows, cols, 1000, X_mb.getNonZeros()), Types.FileFormat.BINARY);
+
+			// 5. Write vector y to a binary SequenceFile
+			writer.writeMatrixToHDFS(y_mb, input(INPUT_NAME_2), rows, 1, 1000, y_mb.getNonZeros());
+			HDFSTool.writeMetaDataFile(input(INPUT_NAME_2 + ".mtd"), Types.ValueType.FP64,
+				new MatrixCharacteristics(rows, 1, 1000, y_mb.getNonZeros()), Types.FileFormat.BINARY);
+
+			runTest(true, false, null, -1);
+
+			// rerun without the ooc flag to produce the reference result
+			programArgs = new String[] {"-explain", "-stats", "-args", input(INPUT_NAME_1), input(INPUT_NAME_2), output(OUTPUT_NAME + "_target")};
+			runTest(true, false, null, -1);
+
+			// compare the OOC result against the reference
+			MatrixBlock ret1 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME),
+				Types.FileFormat.BINARY, cols, 1, 1000);
+			MatrixBlock ret2 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME + "_target"),
+				Types.FileFormat.BINARY, cols, 1, 1000);
+			TestUtils.compareMatrices(ret1, ret2, eps);
+		}
+		catch(IOException e) {
+			throw new RuntimeException(e);
+		}
+		finally {
+			resetExecMode(platformOld);
+		}
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/PCATest.java b/src/test/java/org/apache/sysds/test/functions/ooc/PCATest.java
index e20b7ec4269..8147490cdab 100644
--- a/src/test/java/org/apache/sysds/test/functions/ooc/PCATest.java
+++ b/src/test/java/org/apache/sysds/test/functions/ooc/PCATest.java
@@ -43,7 +43,7 @@ public class PCATest extends AutomatedTestBase {
 	private static final String OUTPUT_NAME_1 = "PC";
 	private static final String OUTPUT_NAME_2 = "V";
 
-	private final static int rows = 50000;
+	private final static int rows = 25000;
 	private final static int cols = 1000;
 	private final static int maxVal = 2;
 
@@ -56,10 +56,7 @@ public void setUp() {
 
 	@Test
 	public void testPCA() {
-		boolean allow_opfusion = OptimizerUtils.ALLOW_OPERATOR_FUSION;
-		OptimizerUtils.ALLOW_OPERATOR_FUSION = false; // some fused ops are not implemented yet
 		runPCATest(16);
-		OptimizerUtils.ALLOW_OPERATOR_FUSION = allow_opfusion;
 	}
 
 	private void runPCATest(int k) {
@@ -70,7 +67,7 @@ private void runPCATest(int k) {
 
 		String HOME = SCRIPT_DIR + TEST_DIR;
 		fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
-		programArgs = new String[] {"-explain", "hops", "-stats", "-ooc", "-args", input(INPUT_NAME_1), Integer.toString(k), output(OUTPUT_NAME_1), output(OUTPUT_NAME_2)};
+		programArgs = new String[] {"-explain", "hops", "-stats", "-ooc", "-oocStats", "5", "-oocLogEvents", output(""), "-args", input(INPUT_NAME_1), Integer.toString(k), output(OUTPUT_NAME_1), output(OUTPUT_NAME_2)};
 
 		// 1. Generate the data in-memory as MatrixBlock objects
 		double[][] X_data = getRandomMatrix(rows, cols, 0, maxVal, 1, 7);
diff --git a/src/test/scripts/functions/ooc/lmCG.dml b/src/test/scripts/functions/ooc/lmCG.dml
new file mode 100644
index 00000000000..3c5cee73594
--- /dev/null
+++ b/src/test/scripts/functions/ooc/lmCG.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = read($1)
+y = read($2)
+C = lmCG(X = X, y = y, reg = 1e-12)
+write(C, $3, format="binary")
+

From e34f1bd981eaf734401d990531c22028abb48075 Mon Sep 17 00:00:00 2001
From: Jannik Lindemann
Date: Wed, 17 Dec 2025 13:38:07 +0100
Subject: [PATCH 4/5] Remove Unnecessary Imports

---
 src/test/java/org/apache/sysds/test/functions/ooc/LmCGTest.java | 1 -
 src/test/java/org/apache/sysds/test/functions/ooc/PCATest.java  | 1 -
 2 files changed, 2 deletions(-)

diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/LmCGTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/LmCGTest.java
index ae5ec5ba2aa..72650daf9c2 100644
--- a/src/test/java/org/apache/sysds/test/functions/ooc/LmCGTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/ooc/LmCGTest.java
@@ -20,7 +20,6 @@
 package org.apache.sysds.test.functions.ooc;
 
 import org.apache.sysds.common.Types;
-import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.runtime.io.MatrixWriter;
 import org.apache.sysds.runtime.io.MatrixWriterFactory;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/PCATest.java b/src/test/java/org/apache/sysds/test/functions/ooc/PCATest.java
index 8147490cdab..202c6b988ef 100644
--- a/src/test/java/org/apache/sysds/test/functions/ooc/PCATest.java
+++ b/src/test/java/org/apache/sysds/test/functions/ooc/PCATest.java
@@ -20,7 +20,6 @@
 package org.apache.sysds.test.functions.ooc;
 
 import org.apache.sysds.common.Types;
-import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.runtime.io.MatrixWriter;
 import org.apache.sysds.runtime.io.MatrixWriterFactory;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;

From 5f7bd1d3d23d49da54241272b49225f06c0b98cc Mon Sep 17 00:00:00 2001
From: Jannik Lindemann
Date: Thu, 18 Dec 2025 15:49:49 +0100
Subject: [PATCH 5/5] Bugfix: Require explicit addOutStream() on OOC task
 submission to avoid error propagation failures

---
 .../ooc/MatrixVectorBinaryOOCInstruction.java |  1 +
 .../instructions/ooc/OOCInstruction.java      | 44 ++++++++++++++-----
 2 files changed, 33 insertions(+), 12 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java
index b3eac0b6478..b0c08db2dca 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java
@@ -84,6 +84,7 @@ public void processInstruction( ExecutionContext ec ) {
 		OOCStream qIn = min.getStreamHandle();
 		OOCStream qOut = createWritableStream();
 		BinaryOperator plus = InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString());
+		addOutStream(qOut);
 		ec.getMatrixObject(output).setStreamHandle(qOut);
 
 		final Object lock = new Object();
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
index a9d50cbd489..5a4ae19b613 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
@@ -138,8 +138,12 @@ protected void addInStream(OOCStream... queue) {
 	}
 
 	protected void addOutStream(OOCStream... queue) {
-		// Currently same behavior as addInQueue
-		if (_outQueues == null)
+		if (queue.length == 0 && _outQueues == null) {
+			_outQueues = Collections.emptySet();
+			return;
+		}
+
+		if (_outQueues == null || _outQueues.isEmpty())
 			_outQueues = new HashSet<>();
 		_outQueues.addAll(List.of(queue));
 	}
@@ -471,6 +475,8 @@ protected CompletableFuture submitOOCTasks(final List> qu
 	protected CompletableFuture submitOOCTasks(final List> queues, BiConsumer> consumer, Runnable finalizer,
 		List> futures, BiFunction, Boolean> predicate, BiConsumer> onNotProcessed) {
 		addInStream(queues.toArray(OOCStream[]::new));
+		if (_outQueues == null)
+			throw new IllegalArgumentException("Explicit specification of all output streams is required before submitting tasks. If no output streams are present, call addOutStream() with no arguments.");
 		ExecutorService pool = CommonThreadPool.get();
 
 		final List activeTaskCtrs = new ArrayList<>(queues.size());
@@ -622,13 +628,20 @@ private Runnable oocTask(Runnable r, CompletableFuture future, OOCStream<
 			catch (Exception ex) {
 				DMLRuntimeException re = ex instanceof DMLRuntimeException ?
 					(DMLRuntimeException) ex : new DMLRuntimeException(ex);
-				if (_failed) // Do avoid infinite cycles
-					throw re;
+				synchronized(this) {
+					if(_failed) // To avoid infinite cycles
+						throw re;
 
-				_failed = true;
+					_failed = true;
+				}
 
-				for (OOCStream q : queues)
-					q.propagateFailure(re);
+				for(OOCStream<?> q : queues) {
+					try {
+						q.propagateFailure(re);
+					} catch(Throwable ignored) {
+						// Should not happen, but catch just in case
+					}
+				}
 
 				if (future != null)
 					future.completeExceptionally(re);
@@ -650,13 +663,20 @@ private Consumer> oocTask(Consumer
 			catch (Exception ex) {
 				DMLRuntimeException re = ex instanceof DMLRuntimeException ?
 					(DMLRuntimeException) ex : new DMLRuntimeException(ex);
-				if (_failed) // Do avoid infinite cycles
-					throw re;
+				synchronized(this) {
+					if(_failed) // To avoid infinite cycles
+						throw re;
 
-				_failed = true;
+					_failed = true;
+				}
 
-				for (OOCStream q : queues)
-					q.propagateFailure(re);
+				for(OOCStream<?> q : queues) {
+					try {
+						q.propagateFailure(re);
+					} catch(Throwable ignored) {
+						// Should not happen, but catch just in case
+					}
+				}
 
 				if (future != null)
 					future.completeExceptionally(re);
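
The contract introduced by this commit, as a minimal sketch (hypothetical operator
body; the identifiers follow the MatrixVectorBinaryOOCInstruction hunk above):

	// inside processInstruction(ExecutionContext ec) of an OOC instruction
	OOCStream qIn = ec.getMatrixObject(input1).getStreamHandle();
	OOCStream qOut = createWritableStream();

	// register every output stream BEFORE submitting tasks: submitOOCTasks now
	// rejects submissions while no out-streams were declared, and a failing task
	// propagates its DMLRuntimeException exactly once to all registered streams
	addOutStream(qOut);
	ec.getMatrixObject(output).setStreamHandle(qOut);

	// operators that produce no stream output still must call addOutStream()
	// with no arguments, which records "no outputs" as an empty set instead of null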