From fcf5c6f82eec64edbc03f1cc691ddc2df5d53448 Mon Sep 17 00:00:00 2001 From: Parth Date: Tue, 16 Dec 2025 16:01:03 +0530 Subject: [PATCH] [SYSTEMDS-3933] Generalize OOC matrix-vector binary to support streamed vector input --- .../ooc/MatrixVectorBinaryOOCInstruction.java | 53 ++++++++++--------- 1 file changed, 29 insertions(+), 24 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java index 38586428e1e..731ad8fa810 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java @@ -22,7 +22,6 @@ import java.util.HashMap; import org.apache.sysds.common.Opcodes; -import org.apache.sysds.conf.ConfigurationManager; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; @@ -51,7 +50,7 @@ public static MatrixVectorBinaryOOCInstruction parseInstruction(String str) { InstructionUtils.checkNumFields(parts, 4); String opcode = parts[0]; CPOperand in1 = new CPOperand(parts[1]); // the larget matrix (streamed) - CPOperand in2 = new CPOperand(parts[2]); // the small vector (in-memory) + CPOperand in2 = new CPOperand(parts[2]); // vector operand (may be OOC) CPOperand out = new CPOperand(parts[3]); AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); @@ -62,40 +61,46 @@ public static MatrixVectorBinaryOOCInstruction parseInstruction(String str) { @Override public void processInstruction( ExecutionContext ec ) { - // 1. Identify the inputs - MatrixObject min = ec.getMatrixObject(input1); // big matrix - MatrixBlock vin = ec.getMatrixObject(input2) - .acquireReadAndRelease(); // in-memory vector - - // 2. Pre-partition the in-memory vector into a hashmap - HashMap partitionedVector = new HashMap<>(); - int blksize = vin.getDataCharacteristics().getBlocksize(); - if (blksize < 0) - blksize = ConfigurationManager.getBlocksize(); - for (int i=0; i qIn = min.getStreamHandle(); + OOCStream qIn1 = min.getStreamHandle(); + OOCStream qIn2 = vin.getStreamHandle(); // Stream handles for matrix and vector (both may be OOC) OOCStream qOut = createWritableStream(); BinaryOperator plus = InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString()); ec.getMatrixObject(output).setStreamHandle(qOut); submitOOCTask(() -> { - IndexedMatrixValue tmp = null; + try { - while((tmp = qIn.dequeue()) != LocalTaskQueue.NO_MORE_TASKS) { + // Cache vector blocks indexed by their block row id + // This removes the assumption that the vector is fully in-memory + HashMap vectorCache = new HashMap<>(); + + // Consume the entire vector stream and cache it block-wise + IndexedMatrixValue vecVal; + while ((vecVal = qIn2.dequeue()) != LocalTaskQueue.NO_MORE_TASKS) { + vectorCache.put( + vecVal.getIndexes().getRowIndex(), + (MatrixBlock) vecVal.getValue()); + } + + // Stream through matrix blocks and match them with vector blocks + IndexedMatrixValue tmp = null; + while((tmp = qIn1.dequeue()) != LocalTaskQueue.NO_MORE_TASKS) { MatrixBlock matrixBlock = (MatrixBlock) tmp.getValue(); long rowIndex = tmp.getIndexes().getRowIndex(); long colIndex = tmp.getIndexes().getColumnIndex(); - MatrixBlock vectorSlice = partitionedVector.get(colIndex); + MatrixBlock vectorSlice = vectorCache.get(colIndex); + + // Fail fast if the corresponding vector block is missing + if (vectorSlice == null) + throw new DMLRuntimeException("Missing vector block for column block " + colIndex); // Now, call the operation with the correct, specific operator. MatrixBlock partialResult = matrixBlock.aggregateBinaryOperations( @@ -103,7 +108,7 @@ public void processInstruction( ExecutionContext ec ) { // for single column block, no aggregation neeeded if(emitThreshold == 1) { - qOut.enqueue(new IndexedMatrixValue(tmp.getIndexes(), partialResult)); + qOut.enqueue(new IndexedMatrixValue(new MatrixIndexes(rowIndex, 1), partialResult)); } else { // aggregation @@ -129,6 +134,6 @@ public void processInstruction( ExecutionContext ec ) { finally { qOut.closeInput(); } - }, qIn, qOut); + }, qIn1, qIn2, qOut); } }