Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import java.util.HashMap;

import org.apache.sysds.common.Opcodes;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
Expand Down Expand Up @@ -51,7 +50,7 @@ public static MatrixVectorBinaryOOCInstruction parseInstruction(String str) {
InstructionUtils.checkNumFields(parts, 4);
String opcode = parts[0];
CPOperand in1 = new CPOperand(parts[1]); // the larger matrix (streamed)
CPOperand in2 = new CPOperand(parts[2]); // the small vector (in-memory)
CPOperand in2 = new CPOperand(parts[2]); // vector operand (may be OOC)
CPOperand out = new CPOperand(parts[3]);

AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
Expand All @@ -62,48 +61,54 @@ public static MatrixVectorBinaryOOCInstruction parseInstruction(String str) {

@Override
public void processInstruction( ExecutionContext ec ) {
// 1. Identify the inputs
MatrixObject min = ec.getMatrixObject(input1); // big matrix
MatrixBlock vin = ec.getMatrixObject(input2)
.acquireReadAndRelease(); // in-memory vector

// 2. Pre-partition the in-memory vector into a hashmap
HashMap<Long, MatrixBlock> partitionedVector = new HashMap<>();
int blksize = vin.getDataCharacteristics().getBlocksize();
if (blksize < 0)
blksize = ConfigurationManager.getBlocksize();
for (int i=0; i<vin.getNumRows(); i+=blksize) {
long key = (long) (i/blksize) + 1; // the key starts at 1
int end_row = Math.min(i + blksize, vin.getNumRows());
MatrixBlock vectorSlice = vin.slice(i, end_row - 1);
partitionedVector.put(key, vectorSlice);
}
// Fetch both inputs without assuming which one fits in memory
MatrixObject min = ec.getMatrixObject(input1);
MatrixObject vin = ec.getMatrixObject(input2);

// number of colBlocks for early block output
long emitThreshold = min.getDataCharacteristics().getNumColBlocks();
OOCMatrixBlockTracker aggTracker = new OOCMatrixBlockTracker(emitThreshold);

OOCStream<IndexedMatrixValue> qIn = min.getStreamHandle();
OOCStream<IndexedMatrixValue> qIn1 = min.getStreamHandle();
OOCStream<IndexedMatrixValue> qIn2 = vin.getStreamHandle(); // Stream handles for matrix and vector (both may be OOC)
OOCStream<IndexedMatrixValue> qOut = createWritableStream();
BinaryOperator plus = InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString());
ec.getMatrixObject(output).setStreamHandle(qOut);

submitOOCTask(() -> {
IndexedMatrixValue tmp = null;

try {
while((tmp = qIn.dequeue()) != LocalTaskQueue.NO_MORE_TASKS) {
// Cache vector blocks indexed by their block row id
// This removes the assumption that the vector is fully in-memory
HashMap<Long, MatrixBlock> vectorCache = new HashMap<>();

// Consume the entire vector stream and cache it block-wise
IndexedMatrixValue vecVal;
while ((vecVal = qIn2.dequeue()) != LocalTaskQueue.NO_MORE_TASKS) {
vectorCache.put(
vecVal.getIndexes().getRowIndex(),
(MatrixBlock) vecVal.getValue());
}

// Stream through matrix blocks and match them with vector blocks
IndexedMatrixValue tmp = null;
while((tmp = qIn1.dequeue()) != LocalTaskQueue.NO_MORE_TASKS) {
MatrixBlock matrixBlock = (MatrixBlock) tmp.getValue();
long rowIndex = tmp.getIndexes().getRowIndex();
long colIndex = tmp.getIndexes().getColumnIndex();
MatrixBlock vectorSlice = partitionedVector.get(colIndex);
MatrixBlock vectorSlice = vectorCache.get(colIndex);

// Fail fast if the corresponding vector block is missing
if (vectorSlice == null)
throw new DMLRuntimeException("Missing vector block for column block " + colIndex);

// Now, call the operation with the correct, specific operator.
MatrixBlock partialResult = matrixBlock.aggregateBinaryOperations(
matrixBlock, vectorSlice, new MatrixBlock(), (AggregateBinaryOperator) _optr);

// for a single column block, no aggregation is needed
if(emitThreshold == 1) {
qOut.enqueue(new IndexedMatrixValue(tmp.getIndexes(), partialResult));
qOut.enqueue(new IndexedMatrixValue(new MatrixIndexes(rowIndex, 1), partialResult));
}
else {
// aggregation
Expand All @@ -129,6 +134,6 @@ public void processInstruction( ExecutionContext ec ) {
finally {
qOut.closeInput();
}
}, qIn, qOut);
}, qIn1, qIn2, qOut);
}
}