diff --git a/transact/build.gradle.kts b/transact/build.gradle.kts index 55d77ea9..e4a2a07f 100644 --- a/transact/build.gradle.kts +++ b/transact/build.gradle.kts @@ -35,6 +35,7 @@ dependencies { implementation("com.fasterxml.jackson.core:jackson-databind:2.20.1") // json implementation("com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.20.1") implementation("com.cronutils:cron-utils:9.2.1") // cron for scheduled wf + implementation("io.netty:netty-all:4.1.130.Final") // netty for websocket compileOnly("org.jspecify:jspecify:1.0.0") diff --git a/transact/src/main/java/dev/dbos/transact/DBOS.java b/transact/src/main/java/dev/dbos/transact/DBOS.java index 000d59f0..68b71a9c 100644 --- a/transact/src/main/java/dev/dbos/transact/DBOS.java +++ b/transact/src/main/java/dev/dbos/transact/DBOS.java @@ -696,6 +696,28 @@ public static void cancelWorkflow(@NonNull String workflowId) { return forkWorkflow(workflowId, startStep, new ForkOptions()); } + /** + * Deletes a workflow from the system. Does not delete child workflows. + * + * @param workflowId the unique identifier of the workflow to delete. Must not be null. + * @throws IllegalArgumentException if workflowId is null + */ + public static void deleteWorkflow(@NonNull String workflowId) { + deleteWorkflow(workflowId, false); + } + + /** + * Deletes a workflow and optionally its child workflows from the system. + * + * @param workflowId the unique identifier of the workflow to delete. Must not be null. + * @param deleteChildren if true, also deletes all child workflows associated with the specified + * workflow; if false, only deletes the specified workflow + * @throws IllegalArgumentException if workflowId is null + */ + public static void deleteWorkflow(@NonNull String workflowId, boolean deleteChildren) { + executor("deleteWorkflow").deleteWorkflow(workflowId, deleteChildren); + } +
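Note: a minimal usage sketch of the new deleteWorkflow overloads, for reference only. The workflow ID is illustrative (not from this diff) and this assumes a configured, launched DBOS runtime:

import dev.dbos.transact.DBOS;

public class DeleteWorkflowExample {
  public static void main(String[] args) {
    String wfid = "example-workflow-id"; // hypothetical ID, for illustration
    DBOS.deleteWorkflow(wfid);        // deletes only this workflow
    DBOS.deleteWorkflow(wfid, true);  // also deletes its child workflows
  }
}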
/** * Retrieve a handle to a workflow, given its ID. Note that a handle is always returned, whether * the workflow exists or not; getStatus() can be used to tell the difference diff --git a/transact/src/main/java/dev/dbos/transact/conductor/Conductor.java b/transact/src/main/java/dev/dbos/transact/conductor/Conductor.java index 6d3491d5..2dd5c236 100644 --- a/transact/src/main/java/dev/dbos/transact/conductor/Conductor.java +++ b/transact/src/main/java/dev/dbos/transact/conductor/Conductor.java @@ -4,6 +4,7 @@ import dev.dbos.transact.database.SystemDatabase; import dev.dbos.transact.execution.DBOSExecutor; import dev.dbos.transact.json.JSONUtil; +import dev.dbos.transact.workflow.ExportedWorkflow; import dev.dbos.transact.workflow.ForkOptions; import dev.dbos.transact.workflow.ListWorkflowsInput; import dev.dbos.transact.workflow.StepInfo; @@ -11,19 +12,19 @@ import dev.dbos.transact.workflow.WorkflowStatus; import dev.dbos.transact.workflow.internal.GetPendingWorkflowsOutput; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; import java.net.InetAddress; import java.net.URI; -import java.net.http.HttpClient; -import java.net.http.WebSocket; -import java.net.http.WebSocket.Listener; -import java.nio.ByteBuffer; -import java.time.Duration; +import java.util.Base64; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionStage; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; @@ -31,22 +32,56 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.BiFunction; import java.util.stream.Collectors; - +import java.util.zip.GZIPInputStream; +import java.util.zip.GZIPOutputStream; + +import com.fasterxml.jackson.core.type.TypeReference; +import io.netty.bootstrap.Bootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufInputStream; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.codec.MessageToMessageDecoder; +import io.netty.handler.codec.http.EmptyHttpHeaders; +import io.netty.handler.codec.http.HttpClientCodec; +import io.netty.handler.codec.http.HttpObjectAggregator; +import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; +import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame; +import io.netty.handler.codec.http.websocketx.PingWebSocketFrame; +import io.netty.handler.codec.http.websocketx.PongWebSocketFrame; +import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolConfig; +import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler; +import io.netty.handler.codec.http.websocketx.WebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketVersion; +import io.netty.handler.codec.json.JsonObjectDecoder; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; import 
org.slf4j.Logger; import org.slf4j.LoggerFactory; public class Conductor implements AutoCloseable { private static final Logger logger = LoggerFactory.getLogger(Conductor.class); - private static final Map<MessageType, BiFunction<Conductor, BaseMessage, BaseResponse>> + private static final Map< + MessageType, BiFunction<Conductor, BaseMessage, CompletableFuture<BaseResponse>>> dispatchMap; static { - Map<MessageType, BiFunction<Conductor, BaseMessage, BaseResponse>> map = + Map<MessageType, BiFunction<Conductor, BaseMessage, CompletableFuture<BaseResponse>>> map = new java.util.EnumMap<>(MessageType.class); map.put(MessageType.EXECUTOR_INFO, Conductor::handleExecutorInfo); map.put(MessageType.RECOVERY, Conductor::handleRecovery); map.put(MessageType.CANCEL, Conductor::handleCancel); + map.put(MessageType.DELETE, Conductor::handleDelete); map.put(MessageType.RESUME, Conductor::handleResume); map.put(MessageType.RESTART, Conductor::handleRestart); map.put(MessageType.FORK_WORKFLOW, Conductor::handleFork); @@ -57,13 +92,15 @@ public class Conductor implements AutoCloseable { map.put(MessageType.GET_WORKFLOW, Conductor::handleGetWorkflow); map.put(MessageType.RETENTION, Conductor::handleRetention); map.put(MessageType.GET_METRICS, Conductor::handleGetMetrics); + map.put(MessageType.IMPORT_WORKFLOW, Conductor::handleImportWorkflow); + map.put(MessageType.EXPORT_WORKFLOW, Conductor::handleExportWorkflow); + dispatchMap = Collections.unmodifiableMap(map); } private final int pingPeriodMs; private final int pingTimeoutMs; private final int reconnectDelayMs; - private final int connectTimeoutMs; private final String url; private final SystemDatabase systemDatabase; @@ -71,7 +108,9 @@ public class Conductor implements AutoCloseable { private final ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor(); private final AtomicBoolean isShutdown = new AtomicBoolean(false); - private WebSocket webSocket; + private Channel channel; + private EventLoopGroup group; + private NettyWebSocketHandler handler; private ScheduledFuture<?> pingInterval; private ScheduledFuture<?> pingTimeout; private ScheduledFuture<?> reconnectTimeout; @@ -106,7 +145,224 @@ private Conductor(Builder builder) { this.pingPeriodMs = builder.pingPeriodMs; this.pingTimeoutMs = builder.pingTimeoutMs; this.reconnectDelayMs = builder.reconnectDelayMs; - this.connectTimeoutMs = builder.connectTimeoutMs; + } + + private class NettyWebSocketHandler extends SimpleChannelInboundHandler<Object> { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) { + logger.info("Successfully established websocket connection to DBOS conductor at {}", url); + setPingInterval(ctx.channel()); + } else if (evt + == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_TIMEOUT) { + logger.error("Websocket handshake timeout with conductor at {}", url); + } + super.userEventTriggered(ctx, evt); + } + + @Override + protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg instanceof PingWebSocketFrame ping) { + logger.debug("Received ping from conductor"); + ctx.channel().writeAndFlush(new PongWebSocketFrame(ping.content().retain())); + } else if (msg instanceof PongWebSocketFrame) { + logger.debug("Received pong from conductor"); + if (pingTimeout != null) { + pingTimeout.cancel(false); + pingTimeout = null; + logger.debug("Cancelled ping timeout - connection is healthy"); + } else { + logger.debug("Received pong but no ping timeout was active"); + } + } else if (msg instanceof CloseWebSocketFrame closeFrame) { + logger.warn( + "Received close frame from conductor: status={}, reason='{}'", + closeFrame.statusCode(), 
closeFrame.reasonText()); + if (isShutdown.get()) { + logger.debug("Shutdown Conductor connection"); + } else if (reconnectTimeout == null) { + logger.warn("onClose: Connection to conductor lost. Reconnecting"); + resetWebSocket(); + } + } else if (msg instanceof ByteBuf content) { + int messageSize = content.readableBytes(); + logger.debug("Received {} bytes from Conductor {}", messageSize, msg.getClass().getName()); + + BaseMessage request; + try (InputStream is = new ByteBufInputStream(content)) { + request = JSONUtil.fromJson(is, BaseMessage.class); + } catch (Exception e) { + logger.error("Conductor JSON Parsing error for {} byte message", messageSize, e); + return; + } + + try { + long startTime = System.currentTimeMillis(); + logger.info( + "Processing conductor request: type={}, id={}", request.type, request.request_id); + + getResponseAsync(request) + .whenComplete( + (response, throwable) -> { + try { + long processingTime = System.currentTimeMillis() - startTime; + if (throwable != null) { + logger.error( + "Error processing request: type={}, id={}, duration={}ms", + request.type, + request.request_id, + processingTime, + throwable); + + // Create an error response + BaseResponse errorResponse = + new BaseResponse( + request.type, request.request_id, throwable.getMessage()); + writeFragmentedResponse(ctx, errorResponse); + } else { + logger.info( + "Completed processing request: type={}, id={}, duration={}ms", + request.type, + request.request_id, + processingTime); + writeFragmentedResponse(ctx, response); + } + } catch (Exception e) { + logger.error( + "Error writing response for request type={}, id={}", + request.type, + request.request_id, + e); + } + }); + } catch (Exception e) { + logger.error( + "Conductor Response error for request type={}, id={}", + request.type, + request.request_id, + e); + } + } + } + + private static void writeFragmentedResponse(ChannelHandlerContext ctx, BaseResponse response) + throws Exception { + int fragmentSize = 128 * 1024; // 128k + logger.debug( + "Starting to write fragmented response: type={}, id={}", + response.type, + response.request_id); + try (OutputStream out = new FragmentingOutputStream(ctx, fragmentSize)) { + JSONUtil.toJsonStream(response, out); + } + logger.debug( + "Completed writing fragmented response: type={}, id={}", + response.type, + response.request_id); + } + + private static class FragmentingOutputStream extends OutputStream { + private final ChannelHandlerContext ctx; + private final int fragmentSize; + private ByteBuf currentBuffer; + private boolean firstFrame = true; + private boolean closed = false; + + public FragmentingOutputStream(ChannelHandlerContext ctx, int fragmentSize) { + this.ctx = ctx; + this.fragmentSize = fragmentSize; + this.currentBuffer = ctx.alloc().buffer(fragmentSize); + logger.debug("Created FragmentingOutputStream with fragment size: {}", fragmentSize); + } + + @Override + public void write(int b) throws IOException { + currentBuffer.writeByte(b); + if (currentBuffer.readableBytes() == fragmentSize) { + flushBuffer(false); + } + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + while (len > 0) { + int toCopy = Math.min(len, fragmentSize - currentBuffer.readableBytes()); + currentBuffer.writeBytes(b, off, toCopy); + off += toCopy; + len -= toCopy; + if (currentBuffer.readableBytes() == fragmentSize) { + flushBuffer(false); + } + } + } + + private void flushBuffer(boolean last) { + if (currentBuffer.readableBytes() == 0 && !last) { + return; + } + + int 
frameSize = currentBuffer.readableBytes(); + WebSocketFrame frame; + if (firstFrame) { + frame = new TextWebSocketFrame(last, 0, currentBuffer); + firstFrame = false; + } else { + frame = new ContinuationWebSocketFrame(last, 0, currentBuffer); + } + + try { + ctx.channel() + .writeAndFlush(frame) + .addListener( + future -> { + if (!future.isSuccess()) { + logger.error( + "Failed to send websocket frame: {} bytes", frameSize, future.cause()); + } + }); + } catch (Exception e) { + logger.error("Exception while sending websocket frame: {} bytes", frameSize, e); + throw e; + } + + if (!last) { + currentBuffer = ctx.alloc().buffer(fragmentSize); + } else { + currentBuffer = null; + } + } + + @Override + public void close() throws IOException { + if (!closed) { + flushBuffer(true); + closed = true; + } + } + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + logger.warn( + "Unexpected exception in websocket connection to conductor. Channel active: {}, writable: {}", + ctx.channel().isActive(), + ctx.channel().isWritable(), + cause); + resetWebSocket(); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + logger.warn( + "Websocket channel became inactive. Shutdown: {}, reconnect pending: {}", + isShutdown.get(), + reconnectTimeout != null); + if (!isShutdown.get() && reconnectTimeout == null) { + logger.warn("Channel inactive: Connection to conductor lost. Reconnecting"); + resetWebSocket(); + } + } } public static class Builder { @@ -117,7 +373,6 @@ public static class Builder { private int pingPeriodMs = 20000; private int pingTimeoutMs = 15000; private int reconnectDelayMs = 1000; - private int connectTimeoutMs = 5000; public Builder(DBOSExecutor e, SystemDatabase s, String key) { systemDatabase = s; @@ -146,11 +401,6 @@ Builder reconnectDelayMs(int reconnectDelayMs) { return this; } - Builder connectTimeoutMs(int connectTimeoutMs) { - this.connectTimeoutMs = connectTimeoutMs; - return this; - } - public Conductor build() { return new Conductor(this); } @@ -163,7 +413,7 @@ public void close() { public void start() { logger.debug("start"); - dispatchLoop(); + connectWebSocket(); } public void stop() { @@ -181,14 +431,18 @@ public void stop() { scheduler.shutdownNow(); - if (webSocket != null) { - webSocket.sendClose(WebSocket.NORMAL_CLOSURE, ""); - webSocket = null; + if (channel != null) { + channel.close(); + channel = null; + } + if (group != null) { + group.shutdownGracefully(); + group = null; } } } - void setPingInterval(WebSocket webSocket) { + void setPingInterval(Channel channel) { logger.debug("setPingInterval"); if (pingInterval != null) { @@ -201,32 +455,29 @@ void setPingInterval(WebSocket webSocket) { return; } try { - // Check for null in case webSocket connects before webSocket variable is assigned - if (webSocket == null) { - logger.debug("webSocket null, NOT sending ping to conductor"); + if (channel == null || !channel.isActive()) { + logger.debug("channel not active, NOT sending ping to conductor"); return; } - if (webSocket.isOutputClosed()) { - logger.debug("webSocket closed, NOT sending ping to conductor"); - return; - } - - logger.debug("Sending ping to conductor"); - webSocket - .sendPing(ByteBuffer.allocate(0)) - .exceptionally( - ex -> { - logger.error("Failed to send ping to conductor", ex); - resetWebSocket(); - return null; + logger.debug("Sending ping to conductor (timeout in {}ms)", pingTimeoutMs); + channel + .writeAndFlush(new PingWebSocketFrame()) + .addListener( + future -> { + if 
(!future.isSuccess()) { + logger.error("Failed to send ping to conductor", future.cause()); + resetWebSocket(); + } }); pingTimeout = scheduler.schedule( () -> { if (!isShutdown.get()) { - logger.warn("pingTimeout: Connection to conductor lost. Reconnecting."); + logger.error( + "Ping timeout after {}ms - no pong received from conductor. Connection lost, reconnecting.", + pingTimeoutMs); resetWebSocket(); } }, @@ -242,6 +493,10 @@ void setPingInterval(WebSocket webSocket) { } void resetWebSocket() { + logger.info( + "Resetting websocket connection. Channel active: {}", + channel != null ? channel.isActive() : "null"); + if (pingInterval != null) { pingInterval.cancel(false); pingInterval = null; @@ -252,315 +507,504 @@ void resetWebSocket() { pingTimeout = null; } - if (webSocket != null) { - webSocket.abort(); - webSocket = null; + if (channel != null) { + channel.close(); + channel = null; + } + + if (group != null) { + group.shutdownGracefully(); + group = null; } if (isShutdown.get()) { + logger.debug("Not scheduling reconnection - conductor is shutting down"); return; } if (reconnectTimeout == null) { + logger.info("Scheduling websocket reconnection in {}ms", reconnectDelayMs); reconnectTimeout = scheduler.schedule( () -> { reconnectTimeout = null; - dispatchLoop(); + logger.info("Attempting websocket reconnection"); + connectWebSocket(); }, reconnectDelayMs, TimeUnit.MILLISECONDS); + } else { + logger.debug("Reconnection already scheduled"); } } - void dispatchLoop() { - if (webSocket != null) { - logger.warn("Conductor websocket already exists"); + void connectWebSocket() { + if (channel != null) { + logger.warn("Conductor channel already exists"); return; } if (isShutdown.get()) { - logger.debug("Not starting dispatch loop as conductor is shutting down"); + logger.debug("Not connecting web socket as conductor is shutting down"); return; } try { logger.debug("Connecting to conductor at {}", url); + URI uri = new URI(url); + String scheme = uri.getScheme() == null ? "ws" : uri.getScheme(); + final String host = uri.getHost() == null ? "127.0.0.1" : uri.getHost(); + final int port; + if (uri.getPort() == -1) { + if ("ws".equalsIgnoreCase(scheme)) { + port = 80; + } else if ("wss".equalsIgnoreCase(scheme)) { + port = 443; + } else { + port = -1; + } + } else { + port = uri.getPort(); + } - HttpClient client = HttpClient.newHttpClient(); - webSocket = - client - .newWebSocketBuilder() - .connectTimeout(Duration.ofMillis(connectTimeoutMs)) - .buildAsync( - URI.create(url), - new WebSocket.Listener() { - @Override - public void onOpen(WebSocket webSocket) { - logger.debug("Opened connection to DBOS conductor"); - webSocket.request(1); - setPingInterval(webSocket); - } - - @Override - public CompletionStage onPong(WebSocket webSocket, ByteBuffer message) { - logger.debug("Received pong from conductor"); - webSocket.request(1); - if (pingTimeout != null) { - pingTimeout.cancel(false); - pingTimeout = null; - } - return null; - } - - @Override - public CompletionStage onClose( - WebSocket webSocket, int statusCode, String reason) { - if (isShutdown.get()) { - logger.debug("Shutdown Conductor connection"); - } else if (reconnectTimeout == null) { - logger.warn("onClose: Connection to conductor lost. Reconnecting"); - resetWebSocket(); - } - return Listener.super.onClose(webSocket, statusCode, reason); - } - - @Override - public void onError(WebSocket webSocket, Throwable error) { - logger.warn( - "Unexpected exception in connection to conductor. 
Reconnecting", error); - resetWebSocket(); - } + if (!"ws".equalsIgnoreCase(scheme) && !"wss".equalsIgnoreCase(scheme)) { + logger.error("Only WS(S) is supported."); + return; + } - @Override - public CompletionStage onText( - WebSocket webSocket, CharSequence data, boolean last) { - BaseMessage request; - webSocket.request(1); - try { - request = JSONUtil.fromJson(data.toString(), BaseMessage.class); - } catch (Exception e) { - logger.error("Conductor JSON Parsing error", e); - return CompletableFuture.completedStage(null); - } + final boolean ssl = "wss".equalsIgnoreCase(scheme); + final SslContext sslCtx; + if (ssl) { + sslCtx = + SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .build(); + } else { + sslCtx = null; + } - String responseText; - try { - BaseResponse response = getResponse(request); - responseText = JSONUtil.toJson(response); - } catch (Exception e) { - logger.error("Conductor Response error", e); - return CompletableFuture.completedStage(null); - } + group = new NioEventLoopGroup(); + handler = new NettyWebSocketHandler(); + + Bootstrap b = new Bootstrap(); + b.group(group) + .channel(NioSocketChannel.class) + .handler( + new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel ch) { + var p = ch.pipeline(); + if (sslCtx != null) { + p.addLast(sslCtx.newHandler(ch.alloc(), host, port)); + } + p.addLast( + new HttpClientCodec(), + new HttpObjectAggregator(256 * 1024 * 1024), // 256MB max message size + new WebSocketClientProtocolHandler( + WebSocketClientProtocolConfig.newBuilder() + .webSocketUri(uri) + .version(WebSocketVersion.V13) + .subprotocol(null) + .allowExtensions(false) + .customHeaders(EmptyHttpHeaders.INSTANCE) + .dropPongFrames(false) + .handleCloseFrames(false) + .maxFramePayloadLength(256 * 1024 * 1024) + .build()), + new MessageToMessageDecoder() { + @Override + protected void decode( + ChannelHandlerContext ctx, WebSocketFrame frame, List out) { + if (frame instanceof TextWebSocketFrame + || frame instanceof ContinuationWebSocketFrame) { + out.add(frame.content().retain()); + } else { + out.add(frame.retain()); + } + } + }, + new JsonObjectDecoder(256 * 1024 * 1024) { + { + setCumulator(COMPOSITE_CUMULATOR); + } + }, + handler); + } + }); + + ChannelFuture future = b.connect(host, port); + channel = future.channel(); + future.addListener( + f -> { + if (f.isSuccess()) { + logger.info("Successfully connected to conductor at {}:{}", host, port); + } else { + logger.warn( + "Failed to connect to conductor at {}:{}. Reconnecting", host, port, f.cause()); + resetWebSocket(); + } + }); - return webSocket - .sendText(responseText, true) - .exceptionally( - ex -> { - logger.error("Conductor sendText error", ex); - return null; - }); - } - }) - .join(); } catch (Exception e) { logger.warn("Error in conductor loop. 
Reconnecting", e); resetWebSocket(); } } - BaseResponse getResponse(BaseMessage message) { - logger.debug("getResponse {}", message.type); + CompletableFuture getResponseAsync(BaseMessage message) { + logger.debug("getResponseAsync {}", message.type); MessageType messageType = MessageType.fromValue(message.type); - BiFunction func = dispatchMap.get(messageType); + BiFunction> func = + dispatchMap.get(messageType); if (func != null) { return func.apply(this, message); } else { logger.warn("Conductor unknown message type {}", message.type); - return new BaseResponse(message.type, message.request_id, "Unknown message type"); + return CompletableFuture.completedFuture( + new BaseResponse(message.type, message.request_id, "Unknown message type")); } } - static BaseResponse handleExecutorInfo(Conductor conductor, BaseMessage message) { - try { - String hostname = InetAddress.getLocalHost().getHostName(); - return new ExecutorInfoResponse( - message, - conductor.dbosExecutor.executorId(), - conductor.dbosExecutor.appVersion(), - hostname); - } catch (Exception e) { - return new ExecutorInfoResponse(message, e); - } + static CompletableFuture handleExecutorInfo( + Conductor conductor, BaseMessage message) { + return CompletableFuture.supplyAsync( + () -> { + try { + String hostname = InetAddress.getLocalHost().getHostName(); + return new ExecutorInfoResponse( + message, + conductor.dbosExecutor.executorId(), + conductor.dbosExecutor.appVersion(), + hostname); + } catch (Exception e) { + return new ExecutorInfoResponse(message, e); + } + }); } - static BaseResponse handleRecovery(Conductor conductor, BaseMessage message) { - RecoveryRequest request = (RecoveryRequest) message; - try { - conductor.dbosExecutor.recoverPendingWorkflows(request.executor_ids); - return new SuccessResponse(request, true); - } catch (Exception e) { - logger.error("Exception encountered when recovering pending workflows", e); - return new SuccessResponse(request, e); - } + static CompletableFuture handleRecovery(Conductor conductor, BaseMessage message) { + return CompletableFuture.supplyAsync( + () -> { + RecoveryRequest request = (RecoveryRequest) message; + try { + conductor.dbosExecutor.recoverPendingWorkflows(request.executor_ids); + return new SuccessResponse(request, true); + } catch (Exception e) { + logger.error("Exception encountered when recovering pending workflows", e); + return new SuccessResponse(request, e); + } + }); } - static BaseResponse handleCancel(Conductor conductor, BaseMessage message) { - CancelRequest request = (CancelRequest) message; - try { - conductor.dbosExecutor.cancelWorkflow(request.workflow_id); - return new SuccessResponse(request, true); - } catch (Exception e) { - logger.error("Exception encountered when cancelling workflow {}", request.workflow_id, e); - return new SuccessResponse(request, e); - } + static CompletableFuture handleCancel(Conductor conductor, BaseMessage message) { + return CompletableFuture.supplyAsync( + () -> { + CancelRequest request = (CancelRequest) message; + try { + conductor.dbosExecutor.cancelWorkflow(request.workflow_id); + return new SuccessResponse(request, true); + } catch (Exception e) { + logger.error( + "Exception encountered when cancelling workflow {}", request.workflow_id, e); + return new SuccessResponse(request, e); + } + }); } - static BaseResponse handleResume(Conductor conductor, BaseMessage message) { - ResumeRequest request = (ResumeRequest) message; - try { - conductor.dbosExecutor.resumeWorkflow(request.workflow_id); - return new 
SuccessResponse(request, true); - } catch (Exception e) { - logger.error("Exception encountered when resuming workflow {}", request.workflow_id, e); - return new SuccessResponse(request, e); - } + static CompletableFuture handleDelete(Conductor conductor, BaseMessage message) { + return CompletableFuture.supplyAsync( + () -> { + DeleteRequest request = (DeleteRequest) message; + try { + conductor.dbosExecutor.deleteWorkflow(request.workflow_id, request.delete_children); + return new SuccessResponse(request, true); + } catch (Exception e) { + logger.error("Exception encountered when deleting workflow {}", request.workflow_id, e); + return new SuccessResponse(request, e); + } + }); } - static BaseResponse handleRestart(Conductor conductor, BaseMessage message) { - RestartRequest request = (RestartRequest) message; - try { - ForkOptions options = new ForkOptions(); - conductor.dbosExecutor.forkWorkflow(request.workflow_id, 0, options); - return new SuccessResponse(request, true); - } catch (Exception e) { - logger.error("Exception encountered when restarting workflow {}", request.workflow_id, e); - return new SuccessResponse(request, e); - } + static CompletableFuture handleResume(Conductor conductor, BaseMessage message) { + return CompletableFuture.supplyAsync( + () -> { + ResumeRequest request = (ResumeRequest) message; + try { + conductor.dbosExecutor.resumeWorkflow(request.workflow_id); + return new SuccessResponse(request, true); + } catch (Exception e) { + logger.error("Exception encountered when resuming workflow {}", request.workflow_id, e); + return new SuccessResponse(request, e); + } + }); } - static BaseResponse handleFork(Conductor conductor, BaseMessage message) { - ForkWorkflowRequest request = (ForkWorkflowRequest) message; - if (request.body.workflow_id == null || request.body.start_step == null) { - return new ForkWorkflowResponse(request, null, "Invalid Fork Workflow Request"); - } - try { - var options = - new ForkOptions(request.body.new_workflow_id, request.body.application_version, null); - WorkflowHandle handle = - conductor.dbosExecutor.forkWorkflow( - request.body.workflow_id, request.body.start_step, options); - return new ForkWorkflowResponse(request, handle.workflowId()); - } catch (Exception e) { - logger.error("Exception encountered when forking workflow {}", request, e); - return new ForkWorkflowResponse(request, e); - } + static CompletableFuture handleRestart(Conductor conductor, BaseMessage message) { + return CompletableFuture.supplyAsync( + () -> { + RestartRequest request = (RestartRequest) message; + try { + ForkOptions options = new ForkOptions(); + conductor.dbosExecutor.forkWorkflow(request.workflow_id, 0, options); + return new SuccessResponse(request, true); + } catch (Exception e) { + logger.error( + "Exception encountered when restarting workflow {}", request.workflow_id, e); + return new SuccessResponse(request, e); + } + }); } - static BaseResponse handleListWorkflows(Conductor conductor, BaseMessage message) { - ListWorkflowsRequest request = (ListWorkflowsRequest) message; - try { - ListWorkflowsInput input = request.asInput(); - List statuses = conductor.dbosExecutor.listWorkflows(input); - List output = - statuses.stream().map(s -> new WorkflowsOutput(s)).collect(Collectors.toList()); - return new WorkflowOutputsResponse(request, output); - } catch (Exception e) { - logger.error("Exception encountered when listing workflows", e); - return new WorkflowOutputsResponse(request, e); - } + static CompletableFuture handleFork(Conductor conductor, 
BaseMessage message) { + return CompletableFuture.supplyAsync( + () -> { + ForkWorkflowRequest request = (ForkWorkflowRequest) message; + if (request.body.workflow_id == null || request.body.start_step == null) { + return new ForkWorkflowResponse(request, null, "Invalid Fork Workflow Request"); + } + try { + var options = + new ForkOptions( + request.body.new_workflow_id, request.body.application_version, null); + WorkflowHandle handle = + conductor.dbosExecutor.forkWorkflow( + request.body.workflow_id, request.body.start_step, options); + return new ForkWorkflowResponse(request, handle.workflowId()); + } catch (Exception e) { + logger.error("Exception encountered when forking workflow {}", request, e); + return new ForkWorkflowResponse(request, e); + } + }); } - static BaseResponse handleListQueuedWorkflows(Conductor conductor, BaseMessage message) { - ListQueuedWorkflowsRequest request = (ListQueuedWorkflowsRequest) message; - try { - ListWorkflowsInput input = request.asInput(); - List statuses = conductor.dbosExecutor.listWorkflows(input); - List output = - statuses.stream().map(s -> new WorkflowsOutput(s)).collect(Collectors.toList()); - return new WorkflowOutputsResponse(request, output); - } catch (Exception e) { - logger.error("Exception encountered when listing workflows", e); - return new WorkflowOutputsResponse(request, e); - } + static CompletableFuture handleListWorkflows( + Conductor conductor, BaseMessage message) { + return CompletableFuture.supplyAsync( + () -> { + ListWorkflowsRequest request = (ListWorkflowsRequest) message; + try { + ListWorkflowsInput input = request.asInput(); + List statuses = conductor.dbosExecutor.listWorkflows(input); + List output = + statuses.stream().map(s -> new WorkflowsOutput(s)).collect(Collectors.toList()); + return new WorkflowOutputsResponse(request, output); + } catch (Exception e) { + logger.error("Exception encountered when listing workflows", e); + return new WorkflowOutputsResponse(request, e); + } + }); } - static BaseResponse handleListSteps(Conductor conductor, BaseMessage message) { - ListStepsRequest request = (ListStepsRequest) message; - try { - List stepInfoList = conductor.dbosExecutor.listWorkflowSteps(request.workflow_id); - List steps = - stepInfoList.stream() - .map(i -> new ListStepsResponse.Step(i)) - .collect(Collectors.toList()); - return new ListStepsResponse(request, steps); - } catch (Exception e) { - logger.error("Exception encountered when listing steps {}", request.workflow_id, e); - return new ListStepsResponse(request, e); - } + static CompletableFuture handleListQueuedWorkflows( + Conductor conductor, BaseMessage message) { + return CompletableFuture.supplyAsync( + () -> { + ListQueuedWorkflowsRequest request = (ListQueuedWorkflowsRequest) message; + try { + ListWorkflowsInput input = request.asInput(); + List statuses = conductor.dbosExecutor.listWorkflows(input); + List output = + statuses.stream().map(s -> new WorkflowsOutput(s)).collect(Collectors.toList()); + return new WorkflowOutputsResponse(request, output); + } catch (Exception e) { + logger.error("Exception encountered when listing workflows", e); + return new WorkflowOutputsResponse(request, e); + } + }); } - static BaseResponse handleExistPendingWorkflows(Conductor conductor, BaseMessage message) { - ExistPendingWorkflowsRequest request = (ExistPendingWorkflowsRequest) message; - try { - List pending = - conductor.systemDatabase.getPendingWorkflows( - request.executor_id, request.application_version); - return new 
ExistPendingWorkflowsResponse(request, pending.size() > 0); - } catch (Exception e) { - logger.error("Exception encountered when checking for pending workflows", e); - return new ExistPendingWorkflowsResponse(request, e); - } + static CompletableFuture handleListSteps(Conductor conductor, BaseMessage message) { + return CompletableFuture.supplyAsync( + () -> { + ListStepsRequest request = (ListStepsRequest) message; + try { + List stepInfoList = + conductor.dbosExecutor.listWorkflowSteps(request.workflow_id); + List steps = + stepInfoList.stream() + .map(i -> new ListStepsResponse.Step(i)) + .collect(Collectors.toList()); + return new ListStepsResponse(request, steps); + } catch (Exception e) { + logger.error("Exception encountered when listing steps {}", request.workflow_id, e); + return new ListStepsResponse(request, e); + } + }); } - static BaseResponse handleGetWorkflow(Conductor conductor, BaseMessage message) { - GetWorkflowRequest request = (GetWorkflowRequest) message; - try { - var status = conductor.systemDatabase.getWorkflowStatus(request.workflow_id); - WorkflowsOutput output = status == null ? null : new WorkflowsOutput(status); - return new GetWorkflowResponse(request, output); - } catch (Exception e) { - logger.error("Exception encountered when getting workflow {}", request.workflow_id, e); - return new GetWorkflowResponse(request, e); - } + static CompletableFuture handleExistPendingWorkflows( + Conductor conductor, BaseMessage message) { + return CompletableFuture.supplyAsync( + () -> { + ExistPendingWorkflowsRequest request = (ExistPendingWorkflowsRequest) message; + try { + List pending = + conductor.systemDatabase.getPendingWorkflows( + request.executor_id, request.application_version); + return new ExistPendingWorkflowsResponse(request, pending.size() > 0); + } catch (Exception e) { + logger.error("Exception encountered when checking for pending workflows", e); + return new ExistPendingWorkflowsResponse(request, e); + } + }); } - static BaseResponse handleRetention(Conductor conductor, BaseMessage message) { - RetentionRequest request = (RetentionRequest) message; + static CompletableFuture handleGetWorkflow( + Conductor conductor, BaseMessage message) { + return CompletableFuture.supplyAsync( + () -> { + GetWorkflowRequest request = (GetWorkflowRequest) message; + try { + var status = conductor.systemDatabase.getWorkflowStatus(request.workflow_id); + WorkflowsOutput output = status == null ? 
null : new WorkflowsOutput(status); + return new GetWorkflowResponse(request, output); + } catch (Exception e) { + logger.error("Exception encountered when getting workflow {}", request.workflow_id, e); + return new GetWorkflowResponse(request, e); + } + }); + } - try { - conductor.systemDatabase.garbageCollect( - request.body.gc_cutoff_epoch_ms, request.body.gc_rows_threshold); - } catch (Exception e) { - logger.error("Exception encountered garbage collecting system database", e); - return new SuccessResponse(request, e); - } + static CompletableFuture handleRetention(Conductor conductor, BaseMessage message) { + return CompletableFuture.supplyAsync( + () -> { + RetentionRequest request = (RetentionRequest) message; + + try { + conductor.systemDatabase.garbageCollect( + request.body.gc_cutoff_epoch_ms, request.body.gc_rows_threshold); + } catch (Exception e) { + logger.error("Exception encountered garbage collecting system database", e); + return new SuccessResponse(request, e); + } + + try { + if (request.body.timeout_cutoff_epoch_ms != null) { + conductor.dbosExecutor.globalTimeout(request.body.timeout_cutoff_epoch_ms); + } + } catch (Exception e) { + logger.error("Exception encountered setting global timeout", e); + return new SuccessResponse(request, e); + } + + return new SuccessResponse(request, true); + }); + } - try { - if (request.body.timeout_cutoff_epoch_ms != null) { - conductor.dbosExecutor.globalTimeout(request.body.timeout_cutoff_epoch_ms); - } - } catch (Exception e) { - logger.error("Exception encountered setting global timeout", e); - return new SuccessResponse(request, e); - } + static CompletableFuture handleGetMetrics( + Conductor conductor, BaseMessage message) { + return CompletableFuture.supplyAsync( + () -> { + GetMetricsRequest request = (GetMetricsRequest) message; + + try { + if (request.metric_class.equals("workflow_step_count")) { + var metrics = + conductor.systemDatabase.getMetrics(request.startTime(), request.endTime()); + return new GetMetricsResponse(request, metrics); + } else { + logger.warn("Unexpected metric class {}", request.metric_class); + throw new RuntimeException( + "Unexpected metric class %s".formatted(request.metric_class)); + } + } catch (Exception e) { + return new GetMetricsResponse(request, e); + } + }); + } - return new SuccessResponse(request, true); + static CompletableFuture handleImportWorkflow( + Conductor conductor, BaseMessage message) { + return CompletableFuture.supplyAsync( + () -> { + ImportWorkflowRequest request = (ImportWorkflowRequest) message; + long startTime = System.currentTimeMillis(); + logger.info("Starting import workflow"); + + try { + var exportedWorkflows = deserializeExportedWorkflows(request.serialized_workflow); + logger.info("deserialization completed workflow count={}", exportedWorkflows.size()); + conductor.systemDatabase.importWorkflow(exportedWorkflows); + long duration = System.currentTimeMillis() - startTime; + logger.info( + "Database import completed: {} workflows imported, duration={}ms", + exportedWorkflows.size(), + duration); + return new SuccessResponse(request, true); + } catch (Exception e) { + logger.error("Exception encountered when importing workflow", e); + return new SuccessResponse(request, e); + } + }); } - static BaseResponse handleGetMetrics(Conductor conductor, BaseMessage message) { - GetMetricsRequest request = (GetMetricsRequest) message; + static CompletableFuture handleExportWorkflow( + Conductor conductor, BaseMessage message) { + return CompletableFuture.supplyAsync( + () -> 
{ + ExportWorkflowRequest request = (ExportWorkflowRequest) message; + long startTime = System.currentTimeMillis(); + logger.info( + "Starting export workflow: id={}, export_children={}", + request.workflow_id, + request.export_children); + + try { + var workflows = + conductor.systemDatabase.exportWorkflow( + request.workflow_id, request.export_children); + + logger.info( + "Database export completed: workflow_id={}, {} workflows retrieved", + request.workflow_id, + workflows.size()); + + var serializedWorkflow = serializeExportedWorkflows(workflows); + + long duration = System.currentTimeMillis() - startTime; + logger.info( + "Export workflow completed: id={}, workflows={}, serialized_size={} bytes, duration={}ms", + request.workflow_id, + workflows.size(), + serializedWorkflow.length(), + duration); + + return new ExportWorkflowResponse(message, serializedWorkflow); + } catch (Exception e) { + long duration = System.currentTimeMillis() - startTime; + var children = request.export_children ? "with children" : ""; + logger.error( + "Exception encountered when exporting workflow {} {} after {}ms", + request.workflow_id, + children, + duration, + e); + return new ExportWorkflowResponse(request, e); + } finally { + long totalDuration = System.currentTimeMillis() - startTime; + logger.info( + "handleExportWorkflow completed: id={}, total_duration={}ms", + request.workflow_id, + totalDuration); + } + }); + } - try { - if (request.metric_class.equals("workflow_step_count")) { - var metrics = conductor.systemDatabase.getMetrics(request.startTime(), request.endTime()); - return new GetMetricsResponse(request, metrics); - } else { - logger.warn("Unexpected metric class {}", request.metric_class); - throw new RuntimeException("Unexpected metric class %s".formatted(request.metric_class)); - } - } catch (Exception e) { - return new GetMetricsResponse(request, e); + static List deserializeExportedWorkflows(String serializedWorkflow) + throws IOException { + var compressed = Base64.getDecoder().decode(serializedWorkflow); + try (var gis = new GZIPInputStream(new ByteArrayInputStream(compressed))) { + var typeRef = new TypeReference>() {}; + return JSONUtil.fromJson(gis, typeRef); } } + + static String serializeExportedWorkflows(List workflows) throws IOException { + var out = new ByteArrayOutputStream(); + try (var gOut = new GZIPOutputStream(out)) { + JSONUtil.toJson(gOut, workflows); + } + + return Base64.getEncoder().encodeToString(out.toByteArray()); + } } diff --git a/transact/src/main/java/dev/dbos/transact/conductor/protocol/BaseMessage.java b/transact/src/main/java/dev/dbos/transact/conductor/protocol/BaseMessage.java index 5d1f6f83..e4175781 100644 --- a/transact/src/main/java/dev/dbos/transact/conductor/protocol/BaseMessage.java +++ b/transact/src/main/java/dev/dbos/transact/conductor/protocol/BaseMessage.java @@ -11,6 +11,7 @@ @JsonSubTypes.Type(value = ExecutorInfoRequest.class, name = "executor_info"), @JsonSubTypes.Type(value = RecoveryRequest.class, name = "recovery"), @JsonSubTypes.Type(value = CancelRequest.class, name = "cancel"), + @JsonSubTypes.Type(value = DeleteRequest.class, name = "delete"), @JsonSubTypes.Type(value = ResumeRequest.class, name = "resume"), @JsonSubTypes.Type(value = RestartRequest.class, name = "restart"), @JsonSubTypes.Type(value = ForkWorkflowRequest.class, name = "fork_workflow"), @@ -21,6 +22,8 @@ @JsonSubTypes.Type(value = ListStepsRequest.class, name = "list_steps"), @JsonSubTypes.Type(value = RetentionRequest.class, name = "retention"), 
@JsonSubTypes.Type(value = GetMetricsRequest.class, name = "get_metrics"), + @JsonSubTypes.Type(value = ExportWorkflowRequest.class, name = "export_workflow"), + @JsonSubTypes.Type(value = ImportWorkflowRequest.class, name = "import_workflow"), }) @JsonIgnoreProperties(ignoreUnknown = true) public abstract class BaseMessage { diff --git a/transact/src/main/java/dev/dbos/transact/conductor/protocol/DeleteRequest.java b/transact/src/main/java/dev/dbos/transact/conductor/protocol/DeleteRequest.java new file mode 100644 index 00000000..c43eb30a --- /dev/null +++ b/transact/src/main/java/dev/dbos/transact/conductor/protocol/DeleteRequest.java @@ -0,0 +1,15 @@ +package dev.dbos.transact.conductor.protocol; + +public class DeleteRequest extends BaseMessage { + public String workflow_id; + public boolean delete_children; + + public DeleteRequest() {} + + public DeleteRequest(String requestId, String workflowId, boolean deleteChildren) { + this.type = MessageType.DELETE.getValue(); + this.request_id = requestId; + this.workflow_id = workflowId; + this.delete_children = deleteChildren; + } +} diff --git a/transact/src/main/java/dev/dbos/transact/conductor/protocol/ExportWorkflowRequest.java b/transact/src/main/java/dev/dbos/transact/conductor/protocol/ExportWorkflowRequest.java new file mode 100644 index 00000000..55c40490 --- /dev/null +++ b/transact/src/main/java/dev/dbos/transact/conductor/protocol/ExportWorkflowRequest.java @@ -0,0 +1,15 @@ +package dev.dbos.transact.conductor.protocol; + +public class ExportWorkflowRequest extends BaseMessage { + public String workflow_id; + public boolean export_children; + + public ExportWorkflowRequest() {} + + public ExportWorkflowRequest(String requestId, String workflowId, boolean exportChildren) { + this.type = MessageType.EXPORT_WORKFLOW.getValue(); + this.request_id = requestId; + this.workflow_id = workflowId; + this.export_children = exportChildren; + } +} diff --git a/transact/src/main/java/dev/dbos/transact/conductor/protocol/ExportWorkflowResponse.java b/transact/src/main/java/dev/dbos/transact/conductor/protocol/ExportWorkflowResponse.java new file mode 100644 index 00000000..eeea692b --- /dev/null +++ b/transact/src/main/java/dev/dbos/transact/conductor/protocol/ExportWorkflowResponse.java @@ -0,0 +1,14 @@ +package dev.dbos.transact.conductor.protocol; + +public class ExportWorkflowResponse extends BaseResponse { + public String serialized_workflow; // optional + + public ExportWorkflowResponse(BaseMessage message, String serializedWorkflow) { + super(MessageType.EXPORT_WORKFLOW.getValue(), message.request_id); + this.serialized_workflow = serializedWorkflow; + } + + public ExportWorkflowResponse(BaseMessage message, Exception ex) { + super(message.type, message.request_id, ex.getMessage()); + } +} diff --git a/transact/src/main/java/dev/dbos/transact/conductor/protocol/ImportWorkflowRequest.java b/transact/src/main/java/dev/dbos/transact/conductor/protocol/ImportWorkflowRequest.java new file mode 100644 index 00000000..9e5dca0a --- /dev/null +++ b/transact/src/main/java/dev/dbos/transact/conductor/protocol/ImportWorkflowRequest.java @@ -0,0 +1,13 @@ +package dev.dbos.transact.conductor.protocol; + +public class ImportWorkflowRequest extends BaseMessage { + public String serialized_workflow; + + public ImportWorkflowRequest() {} + + public ImportWorkflowRequest(String requestId, String serializedWorkflow) { + this.type = MessageType.IMPORT_WORKFLOW.getValue(); + this.request_id = requestId; + this.serialized_workflow = serializedWorkflow; + } +}
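Note: the serialized_workflow field on these import/export messages carries gzip-compressed, Base64-encoded JSON, per serializeExportedWorkflows/deserializeExportedWorkflows in Conductor above. A self-contained sketch of that wire encoding using only JDK classes (the payload string is illustrative):

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

public class WireCodecExample {
  public static void main(String[] args) throws IOException {
    String json = "[{\"workflow_uuid\":\"wf-1\"}]"; // illustrative payload

    // Encode: gzip the JSON bytes, then Base64 the compressed bytes.
    var buffer = new ByteArrayOutputStream();
    try (var gz = new GZIPOutputStream(buffer)) {
      gz.write(json.getBytes(StandardCharsets.UTF_8));
    }
    String wire = Base64.getEncoder().encodeToString(buffer.toByteArray());

    // Decode: Base64-decode, then gunzip back to the original JSON.
    var compressed = Base64.getDecoder().decode(wire);
    try (var gis = new GZIPInputStream(new ByteArrayInputStream(compressed))) {
      String roundTripped = new String(gis.readAllBytes(), StandardCharsets.UTF_8);
      System.out.println(json.equals(roundTripped)); // prints: true
    }
  }
}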
diff --git a/transact/src/main/java/dev/dbos/transact/conductor/protocol/MessageType.java b/transact/src/main/java/dev/dbos/transact/conductor/protocol/MessageType.java index cda11911..508a4316 100644 --- a/transact/src/main/java/dev/dbos/transact/conductor/protocol/MessageType.java +++ b/transact/src/main/java/dev/dbos/transact/conductor/protocol/MessageType.java @@ -4,6 +4,7 @@ public enum MessageType { EXECUTOR_INFO("executor_info"), RECOVERY("recovery"), CANCEL("cancel"), + DELETE("delete"), LIST_WORKFLOWS("list_workflows"), LIST_QUEUED_WORKFLOWS("list_queued_workflows"), RESUME("resume"), @@ -13,7 +14,9 @@ public enum MessageType { LIST_STEPS("list_steps"), FORK_WORKFLOW("fork_workflow"), RETENTION("retention"), - GET_METRICS("get_metrics"); + GET_METRICS("get_metrics"), + EXPORT_WORKFLOW("export_workflow"), + IMPORT_WORKFLOW("import_workflow"); private final String value; diff --git a/transact/src/main/java/dev/dbos/transact/database/StepsDAO.java b/transact/src/main/java/dev/dbos/transact/database/StepsDAO.java index 1b1b1433..b67ebaeb 100644 --- a/transact/src/main/java/dev/dbos/transact/database/StepsDAO.java +++ b/transact/src/main/java/dev/dbos/transact/database/StepsDAO.java @@ -193,6 +193,12 @@ static StepResult checkStepExecutionTxn( } List<StepInfo> listWorkflowSteps(String workflowId) throws SQLException { + try (Connection connection = dataSource.getConnection()) { + return listWorkflowSteps(connection, workflowId); + } + } + + List<StepInfo> listWorkflowSteps(Connection connection, String workflowId) throws SQLException { final String sql = """ @@ -205,8 +211,7 @@ List<StepInfo> listWorkflowSteps(String workflowId) throws SQLException { List<StepInfo> steps = new ArrayList<>(); - try (Connection connection = dataSource.getConnection(); - PreparedStatement stmt = connection.prepareStatement(sql)) { + try (PreparedStatement stmt = connection.prepareStatement(sql)) { stmt.setString(1, workflowId); @@ -233,19 +238,7 @@ List<StepInfo> listWorkflowSteps(String workflowId) throws SQLException { } // Deserialize error if present - ErrorResult stepError = null; - if (errorData != null) { - Exception error = null; - try { - error = (Exception) JSONUtil.deserializeAppException(errorData); - } catch (Exception e) { - throw new RuntimeException( - "Failed to deserialize error for function " + functionId, e); - } - var errorWrapper = JSONUtil.deserializeAppExceptionWrapper(errorData); - stepError = new ErrorResult(errorWrapper.type, errorWrapper.message, errorData, error); - } - + ErrorResult stepError = ErrorResult.deserialize(errorData); Object outputVal = output != null ? output[0] : null; steps.add( new StepInfo(
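Note: ErrorResult.deserialize is referenced here (and again in WorkflowDAO below) but its body is not part of this diff. Presumably it centralizes the per-call-site logic deleted above; a hedged reconstruction, inferred from that deleted code rather than from the actual helper:

// Inferred from the deleted StepsDAO/WorkflowDAO code; the real helper may differ.
static ErrorResult deserialize(String serializedError) {
  if (serializedError == null) {
    return null; // matches the old default of a null error result
  }
  Throwable error;
  try {
    error = JSONUtil.deserializeAppException(serializedError);
  } catch (Exception e) {
    throw new RuntimeException("Failed to deserialize error", e);
  }
  var wrapper = JSONUtil.deserializeAppExceptionWrapper(serializedError);
  return new ErrorResult(wrapper.type, wrapper.message, serializedError, error);
}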
diff --git a/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java b/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java index 7de4b78a..c3fb87d4 100644 --- a/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java +++ b/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java @@ -3,11 +3,16 @@ import dev.dbos.transact.Constants; import dev.dbos.transact.config.DBOSConfig; import dev.dbos.transact.exceptions.*; +import dev.dbos.transact.json.JSONUtil; +import dev.dbos.transact.workflow.ExportedWorkflow; import dev.dbos.transact.workflow.ForkOptions; import dev.dbos.transact.workflow.ListWorkflowsInput; import dev.dbos.transact.workflow.Queue; import dev.dbos.transact.workflow.StepInfo; +import dev.dbos.transact.workflow.WorkflowEvent; +import dev.dbos.transact.workflow.WorkflowEventHistory; import dev.dbos.transact.workflow.WorkflowStatus; +import dev.dbos.transact.workflow.WorkflowStream; import dev.dbos.transact.workflow.internal.GetPendingWorkflowsOutput; import dev.dbos.transact.workflow.internal.StepResult; import dev.dbos.transact.workflow.internal.WorkflowStatusInternal; @@ -18,6 +23,7 @@ import java.time.Duration; import java.time.Instant; import java.util.*; +import java.util.stream.Stream; import javax.sql.DataSource; @@ -599,4 +605,332 @@ public boolean deprecatePatch(String workflowId, int functionId, String patchNam } }); } + + public void deleteWorkflows(String... workflowIds) { + if (workflowIds == null || workflowIds.length == 0) { + return; + } + + var sql = + """ + DELETE FROM %s.workflow_status + WHERE workflow_uuid = ANY(?); + """ + .formatted(this.schema); + + dbRetry( + () -> { + try (var conn = dataSource.getConnection(); + var stmt = conn.prepareStatement(sql)) { + var array = conn.createArrayOf("text", workflowIds); + stmt.setArray(1, array); + stmt.executeUpdate(); + } + return null; + }); + } + + public List<String> getWorkflowChildren(String workflowId) { + return dbRetry(() -> getWorkflowChildrenInternal(workflowId)); + } + + List<String> getWorkflowChildrenInternal(String workflowId) throws SQLException { + var children = new HashSet<String>(); + var toProcess = new ArrayDeque<String>(); + toProcess.add(workflowId); + + var sql = + """ + SELECT child_workflow_id + FROM %s.operation_outputs + WHERE workflow_uuid = ? AND child_workflow_id IS NOT NULL + """ + .formatted(this.schema); + + try (var conn = dataSource.getConnection()) { + while (!toProcess.isEmpty()) { + var wfid = toProcess.poll(); + + try (var stmt = conn.prepareStatement(sql)) { + stmt.setString(1, wfid); + + try (var rs = stmt.executeQuery()) { + while (rs.next()) { + var childWorkflowId = rs.getString(1); + if (!children.contains(childWorkflowId)) { + children.add(childWorkflowId); + toProcess.add(childWorkflowId); + } + } + } + } + } + } + return new ArrayList<>(children); + } + + List<WorkflowEvent> listWorkflowEvents(Connection conn, String workflowId) throws SQLException { + var sql = + """ + SELECT key, value + FROM %s.workflow_events + WHERE workflow_uuid = ? 
+ """ + .formatted(this.schema); + + var events = new ArrayList(); + try (var stmt = conn.prepareStatement(sql)) { + stmt.setString(1, workflowId); + try (var rs = stmt.executeQuery()) { + while (rs.next()) { + var key = rs.getString("key"); + var value = rs.getString("value"); + events.add(new WorkflowEvent(key, value)); + } + } + } + return events; + } + + List listWorkflowEventHistory(Connection conn, String workflowId) + throws SQLException { + var sql = + """ + SELECT key, value, function_id + FROM %s.workflow_events_history + WHERE workflow_uuid = ? + """ + .formatted(this.schema); + + var history = new ArrayList(); + try (var stmt = conn.prepareStatement(sql)) { + stmt.setString(1, workflowId); + try (var rs = stmt.executeQuery()) { + while (rs.next()) { + var key = rs.getString("key"); + var value = rs.getString("value"); + var stepId = rs.getInt("function_id"); + history.add(new WorkflowEventHistory(key, value, stepId)); + } + } + } + return history; + } + + List listWorkflowStreams(Connection conn, String workflowId) throws SQLException { + var sql = + """ + SELECT key, value, "offset", function_id + FROM %s.streams + WHERE workflow_uuid = ? + """ + .formatted(this.schema); + + var streams = new ArrayList(); + try (var stmt = conn.prepareStatement(sql)) { + stmt.setString(1, workflowId); + try (var rs = stmt.executeQuery()) { + while (rs.next()) { + var key = rs.getString("key"); + var value = rs.getString("value"); + var offset = rs.getInt("offset"); + var stepId = rs.getInt("function_id"); + streams.add(new WorkflowStream(key, value, offset, stepId)); + } + } + } + return streams; + } + + public List exportWorkflow(String workflowId, boolean exportChildren) { + return dbRetry( + () -> { + var workflowIds = + exportChildren + ? Stream.concat( + getWorkflowChildrenInternal(workflowId).stream(), + List.of(workflowId).stream()) + .toList() + : List.of(workflowId); + + var workflows = new ArrayList(); + for (var wfid : workflowIds) { + try (var conn = dataSource.getConnection()) { + var status = workflowDAO.getWorkflowStatus(conn, wfid); + var steps = stepsDAO.listWorkflowSteps(conn, wfid); + var events = listWorkflowEvents(conn, wfid); + var eventHistory = listWorkflowEventHistory(conn, wfid); + var streams = listWorkflowStreams(conn, wfid); + workflows.add(new ExportedWorkflow(status, steps, events, eventHistory, streams)); + } + } + return workflows; + }); + } + + public void importWorkflow(List workflows) { + var wfSQL = + """ + INSERT INTO %s.workflow_status ( + workflow_uuid, status, + name, class_name, config_name, + authenticated_user, assumed_role, authenticated_roles, + output, error, inputs, + executor_id, application_version, application_id, + created_at, updated_at, started_at_epoch_ms, + queue_name, deduplication_id, priority, queue_partition_key, + workflow_timeout_ms, workflow_deadline_epoch_ms, recovery_attempts, forked_from + ) VALUES ( + ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? + ) + """ + .formatted(this.schema); + + var stepSQL = + """ + INSERT INTO %s.operation_outputs ( + workflow_uuid, function_id, function_name, + output, error, child_workflow_id, + started_at_epoch_ms, completed_at_epoch_ms + ) VALUES ( + ?, ?, ?, ?, ?, ?, ?, ? + ) + """ + .formatted(this.schema); + + var eventSQL = + """ + INSERT INTO %s.workflow_events ( + workflow_uuid, key, value + ) VALUES ( + ?, ?, ? 
+ ) + """ + .formatted(this.schema); + + var eventHistorySQL = + """ + INSERT INTO %s.workflow_events_history ( + workflow_uuid, key, value, function_id + ) VALUES ( + ?, ?, ?, ? + ) + """ + .formatted(this.schema); + + var streamsSQL = + """ + INSERT INTO %s.streams ( + workflow_uuid, key, value, function_id, "offset" + ) VALUES ( + ?, ?, ?, ?, ? + ) + """ + .formatted(this.schema); + + dbRetry( + () -> { + try (var conn = dataSource.getConnection()) { + conn.setAutoCommit(false); + + try { + for (var workflow : workflows) { + + var status = workflow.status(); + try (var stmt = conn.prepareStatement(wfSQL)) { + stmt.setString(1, status.workflowId()); + stmt.setString(2, status.status().toString()); + stmt.setString(3, status.name()); + stmt.setString(4, status.className()); + stmt.setString(5, status.instanceName()); + stmt.setString(6, status.authenticatedUser()); + stmt.setString(7, status.assumedRole()); + stmt.setString( + 8, + status.authenticatedRoles() == null + ? null + : JSONUtil.serializeArray(status.authenticatedRoles())); + stmt.setString( + 9, status.output() == null ? null : JSONUtil.serialize(status.output())); + stmt.setString( + 10, status.error() == null ? null : status.error().serializedError()); + stmt.setString( + 11, status.input() == null ? null : JSONUtil.serializeArray(status.input())); + stmt.setString(12, status.executorId()); + stmt.setString(13, status.appVersion()); + stmt.setString(14, status.appId()); + stmt.setObject(15, status.createdAt()); + stmt.setObject(16, status.updatedAt()); + stmt.setObject(17, status.startedAtEpochMs()); + stmt.setString(18, status.queueName()); + stmt.setString(19, status.deduplicationId()); + stmt.setObject(20, status.priority()); + stmt.setString(21, status.queuePartitionKey()); + stmt.setObject(22, status.timeoutMs()); + stmt.setObject(23, status.deadlineEpochMs()); + stmt.setObject(24, status.recoveryAttempts()); + stmt.setString(25, status.forkedFrom()); + + stmt.executeUpdate(); + } + + for (var step : workflow.steps()) { + try (var stmt = conn.prepareStatement(stepSQL)) { + stmt.setString(1, status.workflowId()); + stmt.setInt(2, step.functionId()); + stmt.setString(3, step.functionName()); + stmt.setString( + 4, step.output() == null ? null : JSONUtil.serialize(step.output())); + stmt.setString(5, step.error() == null ? 
+    var streamsSQL =
+        """
+        INSERT INTO %s.streams (
+          workflow_uuid, key, value, function_id, "offset"
+        ) VALUES (
+          ?, ?, ?, ?, ?
+        )
+        """
+            .formatted(this.schema);
+
+    dbRetry(
+        () -> {
+          try (var conn = dataSource.getConnection()) {
+            conn.setAutoCommit(false);
+
+            try {
+              for (var workflow : workflows) {
+
+                var status = workflow.status();
+                try (var stmt = conn.prepareStatement(wfSQL)) {
+                  stmt.setString(1, status.workflowId());
+                  stmt.setString(2, status.status().toString());
+                  stmt.setString(3, status.name());
+                  stmt.setString(4, status.className());
+                  stmt.setString(5, status.instanceName());
+                  stmt.setString(6, status.authenticatedUser());
+                  stmt.setString(7, status.assumedRole());
+                  stmt.setString(
+                      8,
+                      status.authenticatedRoles() == null
+                          ? null
+                          : JSONUtil.serializeArray(status.authenticatedRoles()));
+                  stmt.setString(
+                      9, status.output() == null ? null : JSONUtil.serialize(status.output()));
+                  stmt.setString(
+                      10, status.error() == null ? null : status.error().serializedError());
+                  stmt.setString(
+                      11, status.input() == null ? null : JSONUtil.serializeArray(status.input()));
+                  stmt.setString(12, status.executorId());
+                  stmt.setString(13, status.appVersion());
+                  stmt.setString(14, status.appId());
+                  stmt.setObject(15, status.createdAt());
+                  stmt.setObject(16, status.updatedAt());
+                  stmt.setObject(17, status.startedAtEpochMs());
+                  stmt.setString(18, status.queueName());
+                  stmt.setString(19, status.deduplicationId());
+                  stmt.setObject(20, status.priority());
+                  stmt.setString(21, status.queuePartitionKey());
+                  stmt.setObject(22, status.timeoutMs());
+                  stmt.setObject(23, status.deadlineEpochMs());
+                  stmt.setObject(24, status.recoveryAttempts());
+                  stmt.setString(25, status.forkedFrom());
+
+                  stmt.executeUpdate();
+                }
+
+                for (var step : workflow.steps()) {
+                  try (var stmt = conn.prepareStatement(stepSQL)) {
+                    stmt.setString(1, status.workflowId());
+                    stmt.setInt(2, step.functionId());
+                    stmt.setString(3, step.functionName());
+                    stmt.setString(
+                        4, step.output() == null ? null : JSONUtil.serialize(step.output()));
+                    stmt.setString(5, step.error() == null ? null : step.error().serializedError());
+                    stmt.setString(6, step.childWorkflowId());
+                    stmt.setObject(7, step.startedAtEpochMs());
+                    stmt.setObject(8, step.completedAtEpochMs());
+
+                    stmt.executeUpdate();
+                  }
+                }
+
+                for (var event : workflow.events()) {
+                  try (var stmt = conn.prepareStatement(eventSQL)) {
+                    stmt.setString(1, status.workflowId());
+                    stmt.setString(2, event.key());
+                    stmt.setString(3, event.value());
+
+                    stmt.executeUpdate();
+                  }
+                }
+
+                for (var history : workflow.eventHistory()) {
+                  try (var stmt = conn.prepareStatement(eventHistorySQL)) {
+                    stmt.setString(1, status.workflowId());
+                    stmt.setString(2, history.key());
+                    stmt.setString(3, history.value());
+                    stmt.setInt(4, history.stepId());
+
+                    stmt.executeUpdate();
+                  }
+                }
+
+                for (var stream : workflow.streams()) {
+                  try (var stmt = conn.prepareStatement(streamsSQL)) {
+                    stmt.setString(1, status.workflowId());
+                    stmt.setString(2, stream.key());
+                    stmt.setString(3, stream.value());
+                    stmt.setInt(4, stream.stepId());
+                    stmt.setInt(5, stream.offset());
+
+                    stmt.executeUpdate();
+                  }
+                }
+              }
+              conn.commit();
+            } catch (SQLException e) {
+              conn.rollback();
+              throw e;
+            }
+          }
+
+          return null;
+        });
+  }
 }
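Taken together, exportWorkflow and importWorkflow give SystemDatabase a snapshot/restore path for a workflow and, optionally, its descendants. A minimal round-trip sketch, assuming source and target are SystemDatabase instances wired to two different system databases (both variable names are illustrative):

    // Export "wf-123" and every descendant workflow from the source database...
    List<ExportedWorkflow> bundle = source.exportWorkflow("wf-123", true);
    // ...then replay status, steps, events, event history, and streams into the target.
    target.importWorkflow(bundle);

Because importWorkflow disables auto-commit and commits once at the end, a failure on any row rolls the entire bundle back.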
rs.getString("error") : null; - ErrorResult err = null; - if (serializedError != null) { - var wrapper = JSONUtil.deserializeAppExceptionWrapper(serializedError); - Throwable throwable = null; - try { - throwable = JSONUtil.deserializeAppException(serializedError); - } catch (Exception e) { - throw new RuntimeException( - "Failed to deserialize error for workflow " + workflow_uuid, e); - } - err = new ErrorResult(wrapper.type, wrapper.message, serializedError, throwable); - } - WorkflowStatus info = - new WorkflowStatus( - workflow_uuid, - rs.getString("status"), - rs.getString("name"), - rs.getString("class_name"), - rs.getString("config_name"), - rs.getString("authenticated_user"), - rs.getString("assumed_role"), - (authenticatedRolesJson != null) - ? (String[]) JSONUtil.deserializeToArray(authenticatedRolesJson) - : null, - (serializedInput != null) ? JSONUtil.deserializeToArray(serializedInput) : null, - (serializedOutput != null) - ? JSONUtil.deserializeToArray(serializedOutput)[0] - : null, - err, - rs.getString("executor_id"), - rs.getObject("created_at", Long.class), - rs.getObject("updated_at", Long.class), - rs.getString("application_version"), - rs.getString("application_id"), - rs.getInt("recovery_attempts"), - rs.getString("queue_name"), - rs.getObject("workflow_timeout_ms", Long.class), - rs.getObject("workflow_deadline_epoch_ms", Long.class), - rs.getObject("started_at_epoch_ms", Long.class), - rs.getString("deduplication_id"), - rs.getObject("priority", Integer.class), - rs.getString("queue_partition_key"), - rs.getString("forked_from")); - + WorkflowStatus info = resultsToWorkflowStatus(rs, loadInput, loadOutput); workflows.add(info); } } @@ -531,6 +508,46 @@ List listWorkflows(ListWorkflowsInput input) throws SQLException return workflows; } + private static WorkflowStatus resultsToWorkflowStatus( + ResultSet rs, boolean loadInput, boolean loadOutput) throws SQLException { + var workflow_uuid = rs.getString("workflow_uuid"); + String authenticatedRolesJson = rs.getString("authenticated_roles"); + String serializedInput = loadInput ? rs.getString("inputs") : null; + String serializedOutput = loadOutput ? rs.getString("output") : null; + String serializedError = loadOutput ? rs.getString("error") : null; + ErrorResult err = ErrorResult.deserialize(serializedError); + WorkflowStatus info = + new WorkflowStatus( + workflow_uuid, + rs.getString("status"), + rs.getString("name"), + rs.getString("class_name"), + rs.getString("config_name"), + rs.getString("authenticated_user"), + rs.getString("assumed_role"), + (authenticatedRolesJson != null) + ? (String[]) JSONUtil.deserializeToArray(authenticatedRolesJson) + : null, + (serializedInput != null) ? JSONUtil.deserializeToArray(serializedInput) : null, + (serializedOutput != null) ? 
 rs.getString("error") : null;
+    ErrorResult err = ErrorResult.deserialize(serializedError);
+    WorkflowStatus info =
+        new WorkflowStatus(
+            workflow_uuid,
+            rs.getString("status"),
+            rs.getString("name"),
+            rs.getString("class_name"),
+            rs.getString("config_name"),
+            rs.getString("authenticated_user"),
+            rs.getString("assumed_role"),
+            (authenticatedRolesJson != null)
+                ? (String[]) JSONUtil.deserializeToArray(authenticatedRolesJson)
+                : null,
+            (serializedInput != null) ? JSONUtil.deserializeToArray(serializedInput) : null,
+            (serializedOutput != null) ? JSONUtil.deserializeToArray(serializedOutput)[0] : null,
+            err,
+            rs.getString("executor_id"),
+            rs.getObject("created_at", Long.class),
+            rs.getObject("updated_at", Long.class),
+            rs.getString("application_version"),
+            rs.getString("application_id"),
+            rs.getInt("recovery_attempts"),
+            rs.getString("queue_name"),
+            rs.getObject("workflow_timeout_ms", Long.class),
+            rs.getObject("workflow_deadline_epoch_ms", Long.class),
+            rs.getObject("started_at_epoch_ms", Long.class),
+            rs.getString("deduplication_id"),
+            rs.getObject("priority", Integer.class),
+            rs.getString("queue_partition_key"),
+            rs.getString("forked_from"));
+    return info;
+  }
+
 List<GetPendingWorkflowsOutput> getPendingWorkflows(String executorId, String appVersion)
     throws SQLException {
diff --git a/transact/src/main/java/dev/dbos/transact/execution/DBOSExecutor.java b/transact/src/main/java/dev/dbos/transact/execution/DBOSExecutor.java
index a5b5e470..20b290cf 100644
--- a/transact/src/main/java/dev/dbos/transact/execution/DBOSExecutor.java
+++ b/transact/src/main/java/dev/dbos/transact/execution/DBOSExecutor.java
@@ -56,6 +56,7 @@
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Supplier;
 import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -629,6 +630,27 @@ public void cancelWorkflow(String workflowId) {
         null);
   }
 
+  public void deleteWorkflow(String workflowId, boolean deleteChildren) {
+    Objects.requireNonNull(workflowId);
+    this.callFunctionAsStep(
+        () -> {
+          logger.info(
+              "Deleting workflow {}{}", workflowId, deleteChildren ? " and its children" : "");
+          if (deleteChildren) {
+            var children = systemDatabase.getWorkflowChildren(workflowId);
+            var array =
+                Stream.concat(Stream.of(workflowId), children.stream()).toArray(String[]::new);
+            systemDatabase.deleteWorkflows(array);
+          } else {
+            systemDatabase.deleteWorkflows(workflowId);
+          }
+
+          return null;
+        },
+        "DBOS.deleteWorkflow",
+        null);
+  }
+
 public WorkflowHandle forkWorkflow(
     String workflowId, int startStep, ForkOptions options) {
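deleteWorkflow runs the deletion through callFunctionAsStep, following the same shape as the surrounding operations. Note that getWorkflowChildren returns transitive descendants, not just direct children (the testGetChildWorkflows case later in this change expects grandchildren as well), so the deleteChildren branch removes the entire workflow subtree in one call. The equivalent standalone calls, with the workflow id illustrative:

    // Collect "wf-123" plus every descendant id, then delete them together;
    // deleteWorkflows takes the ids as a String varargs/array.
    var children = systemDatabase.getWorkflowChildren("wf-123");
    var ids = Stream.concat(Stream.of("wf-123"), children.stream()).toArray(String[]::new);
    systemDatabase.deleteWorkflows(ids);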
diff --git a/transact/src/main/java/dev/dbos/transact/json/JSONUtil.java b/transact/src/main/java/dev/dbos/transact/json/JSONUtil.java
index ec1f0927..18b6eb79 100644
--- a/transact/src/main/java/dev/dbos/transact/json/JSONUtil.java
+++ b/transact/src/main/java/dev/dbos/transact/json/JSONUtil.java
@@ -5,8 +5,10 @@
 import java.io.ByteArrayInputStream;
 import java.io.ByteArrayOutputStream;
 import java.io.IOException;
+import java.io.InputStream;
 import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
+import java.io.OutputStream;
 import java.io.StreamCorruptedException;
 import java.io.UncheckedIOException;
 import java.util.Arrays;
@@ -15,6 +17,8 @@
 import java.util.Map;
 
 import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.core.StreamReadConstraints;
+import com.fasterxml.jackson.core.type.TypeReference;
 import com.fasterxml.jackson.databind.ObjectMapper;
 import com.fasterxml.jackson.databind.SerializationFeature;
 import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
@@ -22,12 +26,19 @@
 import org.slf4j.LoggerFactory;
 
 public class JSONUtil {
+
+  static {
+    // extend max JSON string length to handle large workflow import/export JSON
+    StreamReadConstraints.overrideDefaultStreamReadConstraints(
+        StreamReadConstraints.builder().maxStringLength(1_000_000_000).build());
+  }
+
   private static final Logger logger = LoggerFactory.getLogger(Conductor.class);
 
   private static final ObjectMapper mapper = new ObjectMapper();
 
   public static class JsonRuntimeException extends RuntimeException {
-    public JsonRuntimeException(JsonProcessingException cause) {
+    public JsonRuntimeException(Exception cause) {
       super(cause.getMessage(), cause);
       setStackTrace(cause.getStackTrace());
       for (Throwable suppressed : cause.getSuppressed()) {
@@ -90,6 +101,14 @@ public static String toJson(Object obj) {
     }
   }
 
+  public static void toJson(OutputStream out, Object obj) {
+    try {
+      mapper.writeValue(out, obj);
+    } catch (IOException e) {
+      throw new JsonRuntimeException(e);
+    }
+  }
+
   public static <T> T fromJson(String content, Class<T> valueType) {
     try {
       return mapper.readValue(content, valueType);
@@ -98,6 +117,58 @@ public static <T> T fromJson(String content, Class<T> valueType) {
     }
   }
 
+  public static <T> T fromJson(InputStream stream, TypeReference<T> valueType) {
+    try {
+      return mapper.readValue(stream, valueType);
+    } catch (IOException e) {
+      if (e instanceof JsonProcessingException) {
+        throw new JsonRuntimeException((JsonProcessingException) e);
+      }
+      throw new RuntimeException(e);
+    }
+  }
+
+  public static <T> T fromJson(byte[] content, Class<T> valueType) {
+    try {
+      return mapper.readValue(content, valueType);
+    } catch (IOException e) {
+      if (e instanceof JsonProcessingException) {
+        throw new JsonRuntimeException((JsonProcessingException) e);
+      }
+      throw new RuntimeException(e);
+    }
+  }
+
+  public static <T> T fromJson(InputStream content, Class<T> valueType) {
+    try {
+      return mapper.readValue(content, valueType);
+    } catch (IOException e) {
+      if (e instanceof JsonProcessingException) {
+        throw new JsonRuntimeException((JsonProcessingException) e);
+      }
+      throw new RuntimeException(e);
+    }
+  }
+
+  public static byte[] toJsonBytes(Object obj) {
+    try {
+      return mapper.writeValueAsBytes(obj);
+    } catch (JsonProcessingException e) {
+      throw new JsonRuntimeException(e);
+    }
+  }
+
+  public static void toJsonStream(Object obj, OutputStream out) {
+    try {
+      mapper.writeValue(out, obj);
+    } catch (IOException e) {
+      if (e instanceof JsonProcessingException) {
+        throw new JsonRuntimeException((JsonProcessingException) e);
+      }
+      throw new RuntimeException(e);
+    }
+  }
+
   public static final class WireThrowable {
     public int v = 1;
     public String type;
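The stream overloads let callers serialize and deserialize large export payloads directly against streams instead of materializing an intermediate String (the static block above raises Jackson's default string-length ceiling for the same reason). A minimal round-trip sketch, assuming workflows is a List<ExportedWorkflow>:

    var buffer = new ByteArrayOutputStream();
    JSONUtil.toJsonStream(workflows, buffer);
    List<ExportedWorkflow> decoded =
        JSONUtil.fromJson(
            new ByteArrayInputStream(buffer.toByteArray()),
            new TypeReference<List<ExportedWorkflow>>() {});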
diff --git a/transact/src/main/java/dev/dbos/transact/workflow/ErrorResult.java b/transact/src/main/java/dev/dbos/transact/workflow/ErrorResult.java
index 6d8f9678..de57c32b 100644
--- a/transact/src/main/java/dev/dbos/transact/workflow/ErrorResult.java
+++ b/transact/src/main/java/dev/dbos/transact/workflow/ErrorResult.java
@@ -5,10 +5,22 @@
 public record ErrorResult(
     String className, String message, String serializedError, Throwable throwable) {
 
-  public static ErrorResult of(Throwable error) {
-    String errorString = JSONUtil.serializeAppException(error);
-    var wrapper = JSONUtil.deserializeAppExceptionWrapper(errorString);
-    Throwable throwable = JSONUtil.deserializeAppException(errorString);
-    return new ErrorResult(wrapper.type, wrapper.message, errorString, throwable);
+  public static ErrorResult fromThrowable(Throwable error) {
+    if (error != null) {
+      var serializedError = JSONUtil.serializeAppException(error);
+      return deserialize(serializedError);
+    } else {
+      return null;
+    }
+  }
+
+  public static ErrorResult deserialize(String serializedError) {
+    if (serializedError != null) {
+      var wrapper = JSONUtil.deserializeAppExceptionWrapper(serializedError);
+      Throwable throwable = JSONUtil.deserializeAppException(serializedError);
+      return new ErrorResult(wrapper.type, wrapper.message, serializedError, throwable);
+    } else {
+      return null;
+    }
   }
 }
diff --git a/transact/src/main/java/dev/dbos/transact/workflow/ExportedWorkflow.java b/transact/src/main/java/dev/dbos/transact/workflow/ExportedWorkflow.java
new file mode 100644
index 00000000..1b9f8a75
--- /dev/null
+++ b/transact/src/main/java/dev/dbos/transact/workflow/ExportedWorkflow.java
@@ -0,0 +1,10 @@
+package dev.dbos.transact.workflow;
+
+import java.util.List;
+
+public record ExportedWorkflow(
+    WorkflowStatus status,
+    List<StepInfo> steps,
+    List<WorkflowEvent> events,
+    List<WorkflowEventHistory> eventHistory,
+    List<WorkflowStream> streams) {}
diff --git a/transact/src/main/java/dev/dbos/transact/workflow/WorkflowEvent.java b/transact/src/main/java/dev/dbos/transact/workflow/WorkflowEvent.java
new file mode 100644
index 00000000..ef0ddd40
--- /dev/null
+++ b/transact/src/main/java/dev/dbos/transact/workflow/WorkflowEvent.java
@@ -0,0 +1,3 @@
+package dev.dbos.transact.workflow;
+
+public record WorkflowEvent(String key, String value) {}
diff --git a/transact/src/main/java/dev/dbos/transact/workflow/WorkflowEventHistory.java b/transact/src/main/java/dev/dbos/transact/workflow/WorkflowEventHistory.java
new file mode 100644
index 00000000..e5d3895c
--- /dev/null
+++ b/transact/src/main/java/dev/dbos/transact/workflow/WorkflowEventHistory.java
@@ -0,0 +1,3 @@
+package dev.dbos.transact.workflow;
+
+public record WorkflowEventHistory(String key, String value, int stepId) {}
diff --git a/transact/src/main/java/dev/dbos/transact/workflow/WorkflowStatus.java b/transact/src/main/java/dev/dbos/transact/workflow/WorkflowStatus.java
index a8a47bac..52c5cae7 100644
--- a/transact/src/main/java/dev/dbos/transact/workflow/WorkflowStatus.java
+++ b/transact/src/main/java/dev/dbos/transact/workflow/WorkflowStatus.java
@@ -48,4 +48,68 @@ public Duration timeout() {
 
     return null;
   }
+
+  @Override
+  public boolean equals(Object obj) {
+    if (this == obj) return true;
+    if (obj == null || getClass() != obj.getClass()) return false;
+
+    WorkflowStatus that = (WorkflowStatus) obj;
+
+    return java.util.Objects.equals(workflowId, that.workflowId)
+        && java.util.Objects.equals(status, that.status)
+        && java.util.Objects.equals(name, that.name)
+        && java.util.Objects.equals(className, that.className)
+        && java.util.Objects.equals(instanceName, that.instanceName)
+        && java.util.Objects.equals(authenticatedUser, that.authenticatedUser)
+        && java.util.Objects.equals(assumedRole, that.assumedRole)
+        && java.util.Arrays.equals(authenticatedRoles, that.authenticatedRoles)
+        && java.util.Arrays.deepEquals(input, that.input)
+        && java.util.Objects.equals(output, that.output)
+        && java.util.Objects.equals(error, that.error)
+        && java.util.Objects.equals(executorId, that.executorId)
+        && java.util.Objects.equals(createdAt, that.createdAt)
+        && java.util.Objects.equals(updatedAt, that.updatedAt)
+        && java.util.Objects.equals(appVersion, that.appVersion)
+        && java.util.Objects.equals(appId, that.appId)
+        && java.util.Objects.equals(recoveryAttempts, that.recoveryAttempts)
+        && java.util.Objects.equals(queueName, that.queueName)
+        && java.util.Objects.equals(timeoutMs, that.timeoutMs)
+        && java.util.Objects.equals(deadlineEpochMs, that.deadlineEpochMs)
+        && java.util.Objects.equals(startedAtEpochMs, that.startedAtEpochMs)
+        && java.util.Objects.equals(deduplicationId, that.deduplicationId)
+        && java.util.Objects.equals(priority, that.priority)
+        && java.util.Objects.equals(queuePartitionKey, that.queuePartitionKey)
+        && java.util.Objects.equals(forkedFrom, that.forkedFrom);
+  }
+
+  @Override
+  public int hashCode() {
+    return
java.util.Objects.hash( + workflowId, + status, + name, + className, + instanceName, + authenticatedUser, + assumedRole, + java.util.Arrays.hashCode(authenticatedRoles), + java.util.Arrays.deepHashCode(input), + output, + error, + executorId, + createdAt, + updatedAt, + appVersion, + appId, + recoveryAttempts, + queueName, + timeoutMs, + deadlineEpochMs, + startedAtEpochMs, + deduplicationId, + priority, + queuePartitionKey, + forkedFrom); + } } diff --git a/transact/src/main/java/dev/dbos/transact/workflow/WorkflowStream.java b/transact/src/main/java/dev/dbos/transact/workflow/WorkflowStream.java new file mode 100644 index 00000000..aed8d105 --- /dev/null +++ b/transact/src/main/java/dev/dbos/transact/workflow/WorkflowStream.java @@ -0,0 +1,3 @@ +package dev.dbos.transact.workflow; + +public record WorkflowStream(String key, String value, int offset, int stepId) {} diff --git a/transact/src/test/java/dev/dbos/transact/admin/AdminServerTest.java b/transact/src/test/java/dev/dbos/transact/admin/AdminServerTest.java index ed1929cd..e5f204f4 100644 --- a/transact/src/test/java/dev/dbos/transact/admin/AdminServerTest.java +++ b/transact/src/test/java/dev/dbos/transact/admin/AdminServerTest.java @@ -486,7 +486,7 @@ public void listSteps() throws IOException { } steps.add(new StepInfo(3, "step-3", null, null, "child-wfid-3", null, null)); var error = new RuntimeException("error-4"); - steps.add(new StepInfo(4, "step-4", null, ErrorResult.of(error), null, null, null)); + steps.add(new StepInfo(4, "step-4", null, ErrorResult.fromThrowable(error), null, null, null)); when(mockDB.listWorkflowSteps(any())).thenReturn(steps); diff --git a/transact/src/test/java/dev/dbos/transact/conductor/ConductorTest.java b/transact/src/test/java/dev/dbos/transact/conductor/ConductorTest.java index 85520668..9fd167ac 100644 --- a/transact/src/test/java/dev/dbos/transact/conductor/ConductorTest.java +++ b/transact/src/test/java/dev/dbos/transact/conductor/ConductorTest.java @@ -2,6 +2,7 @@ import static org.junit.jupiter.api.Assertions.*; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; @@ -16,22 +17,30 @@ import dev.dbos.transact.database.SystemDatabase; import dev.dbos.transact.execution.DBOSExecutor; import dev.dbos.transact.utils.WorkflowStatusBuilder; +import dev.dbos.transact.workflow.ExportedWorkflow; import dev.dbos.transact.workflow.ForkOptions; import dev.dbos.transact.workflow.ListWorkflowsInput; import dev.dbos.transact.workflow.StepInfo; +import dev.dbos.transact.workflow.WorkflowEvent; +import dev.dbos.transact.workflow.WorkflowEventHistory; import dev.dbos.transact.workflow.WorkflowHandle; import dev.dbos.transact.workflow.WorkflowState; import dev.dbos.transact.workflow.WorkflowStatus; +import dev.dbos.transact.workflow.WorkflowStream; import dev.dbos.transact.workflow.internal.GetPendingWorkflowsOutput; import java.net.InetAddress; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.time.Instant; import java.time.OffsetDateTime; import java.util.ArrayList; import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Random; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; @@ -41,12 +50,15 
@@ import com.fasterxml.jackson.databind.JsonNode;
 import com.fasterxml.jackson.databind.ObjectMapper;
 
 import org.java_websocket.WebSocket;
+import org.java_websocket.enums.Opcode;
 import org.java_websocket.framing.Framedata;
 import org.java_websocket.handshake.ClientHandshake;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junitpioneer.jupiter.RetryingTest;
 import org.mockito.ArgumentCaptor;
+import org.mockito.Captor;
+import org.mockito.MockitoAnnotations;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -77,6 +89,8 @@ void beforeEach() throws Exception {
     mockExec = mock(DBOSExecutor.class);
     when(mockExec.appName()).thenReturn("test-app-name");
     builder = new Conductor.Builder(mockExec, mockDB, "conductor-key").domain(domain);
+
+    MockitoAnnotations.openMocks(this);
   }
 
   @AfterEach
@@ -236,16 +250,135 @@ public void onMessage(WebSocket conn, String message) {
       messageLatch.countDown();
     }
 
-    public void send(MessageType type, String requestId, Map<String, Object> fields)
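+    // Splits a text message into WebSocket fragments of chunkSize bytes; for example,
+    // a 25-byte payload with chunkSize 10 goes out as a TEXT frame of 10 bytes, a
+    // continuation frame of 10, and a final continuation frame of 5 (fin set only on the last).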
jsonNode.get("hostname").asText()); + assertEquals("test-app-version", jsonNode.get("application_version").asText()); + assertEquals("test-executor-id", jsonNode.get("executor_id").asText()); + assertEquals("java", jsonNode.get("language").asText()); + assertEquals(DBOS.version(), jsonNode.get("dbos_version").asText()); + assertNull(jsonNode.get("error_message")); + } + } + + @RetryingTest(3) + public void testSendsFragmentedResponse() throws Exception { + class FragmentCountingListener extends MessageListener { + int frameCount = 0; + + @Override + public void onWebsocketMessage(WebSocket conn, Framedata frame) { + if (frame.getOpcode() == Opcode.TEXT || frame.getOpcode() == Opcode.CONTINUOUS) { + frameCount++; + } + } + } + + FragmentCountingListener listener = new FragmentCountingListener(); + testServer.setListener(listener); + + Random random = new Random(); + String characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; + + // Create a large list of steps to exceed 32KB + List steps = new ArrayList<>(); + for (int i = 0; i < 200; i++) { + var builder = new StringBuilder(1024); + builder.append("output_%d_".formatted(i)); + for (int j = 0; j < 1024; j++) { + builder.append(characters.charAt(random.nextInt(characters.length()))); + } + steps.add(new StepInfo(i, "function" + i, builder.toString(), null, null, null, null)); + } + when(mockExec.listWorkflowSteps("large-wf")).thenReturn(steps); + + try (Conductor conductor = builder.build()) { + conductor.start(); + assertTrue(listener.openLatch.await(5, TimeUnit.SECONDS), "open latch timed out"); + + Map message = Map.of("workflow_id", "large-wf"); + listener.send(MessageType.LIST_STEPS, "12345", message); + + assertTrue(listener.messageLatch.await(5, TimeUnit.SECONDS), "message latch timed out"); + + // Each StepInfo is roughly 100-200 bytes. 500 steps should be > 50KB. + // 32KB fragment size should result in at least 2 frames. 
+      assertTrue(
+          listener.frameCount > 1,
+          "Should have received more than one frame, but got " + listener.frameCount);
+
+      JsonNode jsonNode = mapper.readTree(listener.message);
+      assertEquals("list_steps", jsonNode.get("type").asText());
+      assertEquals(200, jsonNode.get("output").size());
     }
   }
 
@@ -401,6 +534,80 @@ public void canCancelThrows() throws Exception {
     }
   }
 
+  @RetryingTest(3)
+  public void canDelete() throws Exception {
+    MessageListener listener = new MessageListener();
+    testServer.setListener(listener);
+    String workflowId = "sample-wf-id";
+
+    try (Conductor conductor = builder.build()) {
+      conductor.start();
+
+      assertTrue(listener.openLatch.await(5, TimeUnit.SECONDS), "open latch timed out");
+
+      Map<String, Object> message =
+          Map.of(
+              "workflow_id",
+              workflowId,
+              "delete_children",
+              Boolean.TRUE,
+              "unknown-field",
+              "unknown-field-value");
+      listener.send(MessageType.DELETE, "12345", message);
+
+      assertTrue(listener.messageLatch.await(1, TimeUnit.SECONDS), "message latch timed out");
+
+      // Verify that deleteWorkflow was called with the correct argument
+      verify(mockExec).deleteWorkflow(workflowId, true);
+
+      JsonNode jsonNode = mapper.readTree(listener.message);
+      assertNotNull(jsonNode);
+      assertEquals("delete", jsonNode.get("type").asText());
+      assertEquals("12345", jsonNode.get("request_id").asText());
+      assertNull(jsonNode.get("error_message"));
+      assertTrue(jsonNode.get("success").asBoolean());
+    }
+  }
+
+  @RetryingTest(3)
+  public void canDeleteThrows() throws Exception {
+    MessageListener listener = new MessageListener();
+    testServer.setListener(listener);
+
+    String errorMessage = "canDeleteThrows error";
+    String workflowId = "sample-wf-id";
+
+    doThrow(new RuntimeException(errorMessage))
+        .when(mockExec)
+        .deleteWorkflow(anyString(), anyBoolean());
+
+    try (Conductor conductor = builder.build()) {
+      conductor.start();
+
+      assertTrue(listener.openLatch.await(5, TimeUnit.SECONDS), "open latch timed out");
+
+      Map<String, Object> message =
+          Map.of(
+              "workflow_id",
+              workflowId,
+              "delete_children",
+              Boolean.TRUE,
+              "unknown-field",
+              "unknown-field-value");
+      listener.send(MessageType.DELETE, "12345", message);
+
+      assertTrue(listener.messageLatch.await(1, TimeUnit.SECONDS), "message latch timed out");
+      verify(mockExec).deleteWorkflow(workflowId, true);
+
+      JsonNode jsonNode = mapper.readTree(listener.message);
+      assertNotNull(jsonNode);
+      assertEquals("delete", jsonNode.get("type").asText());
+      assertEquals("12345", jsonNode.get("request_id").asText());
+      assertEquals(errorMessage, jsonNode.get("error_message").asText());
+      assertFalse(jsonNode.get("success").asBoolean());
+    }
+  }
+
   @RetryingTest(3)
   public void canResume() throws Exception {
     MessageListener listener = new MessageListener();
@@ -1177,4 +1384,233 @@ public void canGetMetricsInvalidMetricThrows() throws Exception {
       assertEquals(errorMessage, jsonNode.get("error_message").asText());
     }
   }
+
+  @Captor ArgumentCaptor<List<ExportedWorkflow>> workflowListCaptor;
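+  // Captures the List<ExportedWorkflow> handed to SystemDatabase.importWorkflow so the
+  // round-tripped payload can be compared field-for-field with the original test data.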
+
+  @RetryingTest(3)
+  public void canImport() throws Exception {
+
+    MessageListener listener = new MessageListener();
+    testServer.setListener(listener);
+
+    var workflows = createTestExportedWorkflows();
+    var serialized = Conductor.serializeExportedWorkflows(workflows);
+
+    try (Conductor conductor = builder.build()) {
+      conductor.start();
+      assertTrue(listener.openLatch.await(5, TimeUnit.SECONDS), "open latch timed out");
+
+      Map<String, Object> message =
+          Map.of("serialized_workflow", serialized, "unknown-field", "unknown-field-value");
+      listener.send(MessageType.IMPORT_WORKFLOW, "12345", message);
+
+      assertTrue(listener.messageLatch.await(1, TimeUnit.SECONDS), "message latch timed out");
+
+      verify(mockDB).importWorkflow(workflowListCaptor.capture());
+      assertEquals(workflows, workflowListCaptor.getValue());
+
+      JsonNode jsonNode = mapper.readTree(listener.message);
+      assertNotNull(jsonNode);
+      assertEquals("import_workflow", jsonNode.get("type").asText());
+      assertEquals("12345", jsonNode.get("request_id").asText());
+      assertNull(jsonNode.get("error_message"));
+      assertTrue(jsonNode.get("success").asBoolean());
+    }
+  }
+
+  @RetryingTest(3)
+  public void canImportThrows() throws Exception {
+    MessageListener listener = new MessageListener();
+    testServer.setListener(listener);
+
+    String errorMessage = "canImportThrows error";
+    doThrow(new RuntimeException(errorMessage)).when(mockDB).importWorkflow(any());
+
+    var workflows = createTestExportedWorkflows();
+    var serialized = Conductor.serializeExportedWorkflows(workflows);
+
+    try (Conductor conductor = builder.build()) {
+      conductor.start();
+      assertTrue(listener.openLatch.await(5, TimeUnit.SECONDS), "open latch timed out");
+
+      Map<String, Object> message =
+          Map.of("serialized_workflow", serialized, "unknown-field", "unknown-field-value");
+      listener.send(MessageType.IMPORT_WORKFLOW, "12345", message);
+
+      assertTrue(listener.messageLatch.await(1, TimeUnit.SECONDS), "message latch timed out");
+
+      verify(mockDB).importWorkflow(any());
+
+      JsonNode jsonNode = mapper.readTree(listener.message);
+      assertNotNull(jsonNode);
+      assertEquals("import_workflow", jsonNode.get("type").asText());
+      assertEquals("12345", jsonNode.get("request_id").asText());
+      assertEquals(errorMessage, jsonNode.get("error_message").asText());
+      assertFalse(jsonNode.get("success").asBoolean());
+    }
+  }
+
+  @RetryingTest(3)
+  public void canExportThrows() throws Exception {
+    MessageListener listener = new MessageListener();
+    testServer.setListener(listener);
+
+    String errorMessage = "canExportThrows error";
+    doThrow(new RuntimeException(errorMessage))
+        .when(mockDB)
+        .exportWorkflow(anyString(), anyBoolean());
+
+    try (Conductor conductor = builder.build()) {
+      conductor.start();
+      assertTrue(listener.openLatch.await(5, TimeUnit.SECONDS), "open latch timed out");
+
+      Map<String, Object> message =
+          Map.of(
+              "workflow_id",
+              "abc-123",
+              "export_children",
+              true,
+              "unknown-field",
+              "unknown-field-value");
+      listener.send(MessageType.EXPORT_WORKFLOW, "12345", message);
+
+      assertTrue(listener.messageLatch.await(1, TimeUnit.SECONDS), "message latch timed out");
+      verify(mockDB).exportWorkflow("abc-123", true);
+
+      JsonNode jsonNode = mapper.readTree(listener.message);
+      assertNotNull(jsonNode);
+      assertEquals("export_workflow", jsonNode.get("type").asText());
+      assertEquals("12345", jsonNode.get("request_id").asText());
+      assertEquals(errorMessage, jsonNode.get("error_message").asText());
+      assertNull(jsonNode.get("serialized_workflow"));
+    }
+  }
+
+  @RetryingTest(3)
+  public void canExport() throws Exception {
+
+    MessageListener listener = new MessageListener();
+    testServer.setListener(listener);
+
+    var workflows = createTestExportedWorkflows();
+    var serialized = Conductor.serializeExportedWorkflows(workflows);
+
+    when(mockDB.exportWorkflow(anyString(), anyBoolean())).thenReturn(workflows);
+
+    try (Conductor conductor = builder.build()) {
+      conductor.start();
+      assertTrue(listener.openLatch.await(5, TimeUnit.SECONDS), "open latch timed out");
+
+      Map<String, Object> message =
+          Map.of(
+              "workflow_id",
+              "abc-123",
+              "export_children",
+              true,
+              "unknown-field",
"unknown-field-value"); + listener.send(MessageType.EXPORT_WORKFLOW, "12345", message); + + assertTrue(listener.messageLatch.await(1, TimeUnit.SECONDS), "message latch timed out"); + verify(mockDB).exportWorkflow("abc-123", true); + + JsonNode jsonNode = mapper.readTree(listener.message); + assertNotNull(jsonNode); + assertEquals("export_workflow", jsonNode.get("type").asText()); + assertEquals("12345", jsonNode.get("request_id").asText()); + assertNull(jsonNode.get("error_message")); + assertEquals(serialized, jsonNode.get("serialized_workflow").asText()); + } + } + + private static ExportedWorkflow createTestExportedWorkflow(int index) { + String suffix = index > 0 ? "-" + index : ""; + WorkflowStatus status = + new WorkflowStatusBuilder( + "test-workflow-id-%d%s".formatted(System.currentTimeMillis(), suffix)) + .status( + index > 0 + ? WorkflowState.values()[index % WorkflowState.values().length] + : WorkflowState.SUCCESS) + .name("TestWorkflow" + (index > 0 ? (index + 1) : "")) + .className("dev.dbos.transact.test.TestClass" + (index > 0 ? (index + 1) : "")) + .instanceName("test-instance" + (index > 0 ? "-" + (index + 1) : "")) + .authenticatedUser("test-user" + (index > 0 ? "-" + (index + 1) : "")) + .assumedRole("test-role" + (index > 0 ? "-" + (index + 1) : "")) + .authenticatedRoles(new String[] {"role1", "role2"}) + .input(new Object[] {"input1", "input2"}) + .output("test-output" + (index > 0 ? "-" + (index + 1) : "")) + .error(null) + .executorId("test-executor" + (index > 0 ? "-" + (index + 1) : "")) + .createdAt(System.currentTimeMillis() - (5000L * (index + 1))) + .updatedAt(System.currentTimeMillis() - (1000L * (index + 1))) + .appVersion(index > 0 ? "1." + index + ".0" : "1.0.0") + .appId("test-app" + (index > 0 ? "-" + (index + 1) : "")) + .recoveryAttempts(index) + .queueName("test-queue" + (index > 0 ? "-" + (index + 1) : "")) + .timeoutMs(30000L + (index * 5000L)) + .deadlineEpochMs(System.currentTimeMillis() + (60000L * (index + 1))) + .startedAtEpochMs(System.currentTimeMillis() - (index * 1000L)) + .deduplicationId("test-dedup-id" + (index > 0 ? "-" + (index + 1) : "")) + .priority(index + 1) + .partitionKey("test-partition" + (index > 0 ? "-" + (index + 1) : "")) + .forkedFrom(index > 0 ? "parent-workflow-" + index : null) + .build(); + + int stepCount = (int) (Math.random() * 8) + 2; + List steps = new ArrayList<>(); + long currentTime = System.currentTimeMillis() + (index * 10000L); + String prefix = index > 0 ? "wf" + (index + 1) + "_" : ""; + for (int i = 0; i < stepCount; i++) { + steps.add( + new StepInfo( + i, + prefix + "function" + (i + 1), + prefix + "result" + (i + 1), + null, + null, + currentTime + (i * 1000), + currentTime + ((i + 1) * 1000))); + } + + int eventCount = (int) (Math.random() * 8) + 2; + List events = new ArrayList<>(); + for (int i = 0; i < eventCount; i++) { + events.add(new WorkflowEvent(prefix + "event" + (i + 1), prefix + "value" + (i + 1))); + } + + int historyCount = (int) (Math.random() * 8) + 2; + List eventHistory = new ArrayList<>(); + for (int i = 0; i < historyCount; i++) { + int stepId = i % Math.max(1, stepCount); // Distribute across available steps + String eventKey = + eventCount > 0 ? 
prefix + "event" + ((i % eventCount) + 1) : prefix + "event" + (i + 1); + eventHistory.add( + new WorkflowEventHistory(eventKey, prefix + "historyvalue" + (i + 1), stepId)); + } + + int streamCount = (int) (Math.random() * 8) + 2; + List streams = new ArrayList<>(); + for (int i = 0; i < streamCount; i++) { + int stepId = i % Math.max(1, stepCount); // Distribute across available steps + int offset = i % 3; // Vary offset between 0-2 + String streamKey = prefix + "stream" + ((i % 3) + 1); // Use 3 different stream keys + streams.add(new WorkflowStream(streamKey, prefix + "streamvalue" + (i + 1), offset, stepId)); + } + + return new ExportedWorkflow(status, steps, events, eventHistory, streams); + } + + // Helper method to create multiple test ExportedWorkflow instances + private static List createTestExportedWorkflows() { + // Create a random number of workflows (1-5) + int workflowCount = (int) (Math.random() * 5) + 2; + List workflows = new ArrayList<>(); + + for (int i = 0; i < workflowCount; i++) { + workflows.add(createTestExportedWorkflow(i)); + } + + return workflows; + } } diff --git a/transact/src/test/java/dev/dbos/transact/conductor/TestWebSocketServer.java b/transact/src/test/java/dev/dbos/transact/conductor/TestWebSocketServer.java index f06a3d19..564d17c3 100644 --- a/transact/src/test/java/dev/dbos/transact/conductor/TestWebSocketServer.java +++ b/transact/src/test/java/dev/dbos/transact/conductor/TestWebSocketServer.java @@ -1,10 +1,14 @@ package dev.dbos.transact.conductor; import java.net.InetSocketAddress; +import java.util.Collections; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import org.java_websocket.WebSocket; +import org.java_websocket.WebSocketImpl; +import org.java_websocket.drafts.Draft_6455; +import org.java_websocket.exceptions.InvalidDataException; import org.java_websocket.framing.Framedata; import org.java_websocket.framing.PingFrame; import org.java_websocket.framing.PongFrame; @@ -28,6 +32,8 @@ default void onOpen(WebSocket conn, ClientHandshake handshake) {} default void onMessage(WebSocket conn, String message) {} + default void onWebsocketMessage(WebSocket conn, Framedata frame) {} + default void onClose(WebSocket conn, int code, String reason, boolean remote) {} } @@ -35,8 +41,33 @@ default void onClose(WebSocket conn, int code, String reason, boolean remote) {} private WebSocketTestListener listener; private Semaphore startEvent = new Semaphore(0); + private static class InterceptingDraft extends Draft_6455 { + TestWebSocketServer server; + + @Override + public void processFrame(WebSocketImpl webSocketImpl, Framedata frame) + throws InvalidDataException { + if (server != null && server.listener != null) { + server.listener.onWebsocketMessage(webSocketImpl, frame); + } + super.processFrame(webSocketImpl, frame); + } + + @Override + public Draft_6455 copyInstance() { + InterceptingDraft copy = new InterceptingDraft(); + copy.server = this.server; + return copy; + } + } + public TestWebSocketServer(int port) { - super(new InetSocketAddress(port)); + this(port, new InterceptingDraft()); + } + + private TestWebSocketServer(int port, InterceptingDraft draft) { + super(new InetSocketAddress(port), Collections.singletonList(draft)); + draft.server = this; } public void setListener(WebSocketTestListener listener) { diff --git a/transact/src/test/java/dev/dbos/transact/database/SystemDatabaseTest.java b/transact/src/test/java/dev/dbos/transact/database/SystemDatabaseTest.java index abd1223b..a77bdef9 100644 --- 
a/transact/src/test/java/dev/dbos/transact/database/SystemDatabaseTest.java +++ b/transact/src/test/java/dev/dbos/transact/database/SystemDatabaseTest.java @@ -52,6 +52,67 @@ void afterEachTest() throws Exception { sysdb.close(); } + @Test + public void testDeleteWorkflows() throws Exception { + for (var i = 0; i < 5; i++) { + var wfid = "wfid-%d".formatted(i); + var status = WorkflowStatusInternal.builder(wfid, WorkflowState.PENDING).build(); + sysdb.initWorkflowStatus(status, 5, false, false); + } + + var rows = DBUtils.getWorkflowRows(dataSource); + assertEquals(5, rows.size()); + + sysdb.deleteWorkflows("wfid-1", "wfid-3"); + + rows = DBUtils.getWorkflowRows(dataSource); + assertEquals(3, rows.size()); + + assertTrue(rows.stream().noneMatch(r -> r.workflowId().equals("wfid-1"))); + assertTrue(rows.stream().noneMatch(r -> r.workflowId().equals("wfid-3"))); + + assertTrue(rows.stream().anyMatch(r -> r.workflowId().equals("wfid-0"))); + assertTrue(rows.stream().anyMatch(r -> r.workflowId().equals("wfid-2"))); + assertTrue(rows.stream().anyMatch(r -> r.workflowId().equals("wfid-4"))); + } + + @Test + public void testGetChildWorkflows() throws Exception { + for (var i = 0; i < 5; i++) { + var wfid = "wfid-%d".formatted(i); + var status = WorkflowStatusInternal.builder(wfid, WorkflowState.PENDING).build(); + sysdb.initWorkflowStatus(status, 5, false, false); + } + + for (var i = 0; i < 5; i++) { + var parentWfId = "wfid-2"; + var wfid = "childwfid-%d".formatted(i); + var status = WorkflowStatusInternal.builder(wfid, WorkflowState.PENDING).build(); + sysdb.initWorkflowStatus(status, 5, false, false); + sysdb.recordChildWorkflow( + parentWfId, wfid, i, "step-%d".formatted(i), System.currentTimeMillis()); + } + + for (var i = 0; i < 5; i++) { + var parentWfId = "childwfid-%d".formatted(i); + var wfid = "grandchildwfid-%d".formatted(i); + var status = WorkflowStatusInternal.builder(wfid, WorkflowState.PENDING).build(); + sysdb.initWorkflowStatus(status, 5, false, false); + sysdb.recordChildWorkflow( + parentWfId, wfid, i, "step-%d".formatted(i), System.currentTimeMillis()); + } + + var children = sysdb.getWorkflowChildrenInternal("wfid-2"); + assertEquals(10, children.size()); + + for (var i = 0; i < 5; i++) { + var child = "childwfid-%d".formatted(i); + var grandchild = "grandchildwfid-%d".formatted(i); + assertTrue(children.stream().anyMatch(r -> r.equals(child))); + assertTrue(children.stream().anyMatch(r -> r.equals(grandchild))); + } + } + @Test public void testRetries() throws Exception { var wfid = "wfid-1"; diff --git a/transact/src/test/java/dev/dbos/transact/invocation/PatchTest.java b/transact/src/test/java/dev/dbos/transact/invocation/PatchTest.java index 1d7f090d..ec8c6c50 100644 --- a/transact/src/test/java/dev/dbos/transact/invocation/PatchTest.java +++ b/transact/src/test/java/dev/dbos/transact/invocation/PatchTest.java @@ -16,8 +16,6 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; interface PatchService { int workflow(); @@ -103,8 +101,6 @@ public int workflowB() { @org.junit.jupiter.api.Timeout(value = 2, unit = java.util.concurrent.TimeUnit.MINUTES) public class PatchTest { - private static final Logger logger = LoggerFactory.getLogger(PatchTest.class); - @AfterEach void afterEachTest() throws Exception { DBOS.shutdown(); @@ -113,12 +109,9 @@ void afterEachTest() throws Exception { @Test public void testPatch() throws Exception { - // Note, for this test we have to manually update the 
workflow name when forking across - // versions. This requires pausing and unpausing the queue service to ensure the forked - // workflow isn't executed until the workflow name is updated. - - // This hack is required because we can't have multiple service implementations with the same - // name the way you can in a dynamic programming language like python. + // Note, we are simulating the patch service changing over time. + // We have multiple implementations, each aliased to "PatchService" via @WorkflowClassName. + // This allows us to reinitialize and re-register workflows during the test. // In production, developers would be expected to be updating services in place, so they would // have the same workflow name across deployed versions. @@ -138,7 +131,6 @@ public void testPatch() throws Exception { DBOS.launch(); assertEquals("test-version", DBOSTestAccess.getDbosExecutor().appVersion()); - var queueService = DBOSTestAccess.getQueueService(); // Register and run the first version of a workflow var h1 = DBOS.startWorkflow(() -> proxy1.workflow(), new StartWorkflowOptions("impl1")); @@ -306,6 +298,7 @@ public void mulipleDefinitions() throws Exception { DBUtils.recreateDB(dbosConfig); DBOS.reinitialize(dbosConfig); + @SuppressWarnings("unused") var proxy5 = DBOS.registerWorkflows(PatchService.class, new PatchServiceImplFive()); assertThrows( IllegalStateException.class, diff --git a/transact/src/test/java/dev/dbos/transact/utils/WorkflowStatusBuilder.java b/transact/src/test/java/dev/dbos/transact/utils/WorkflowStatusBuilder.java index e6d76689..34be02b7 100644 --- a/transact/src/test/java/dev/dbos/transact/utils/WorkflowStatusBuilder.java +++ b/transact/src/test/java/dev/dbos/transact/utils/WorkflowStatusBuilder.java @@ -9,7 +9,6 @@ public class WorkflowStatusBuilder { private String workflowId; private String status; - private String forkedFrom; private String name; private String className; @@ -39,6 +38,7 @@ public class WorkflowStatusBuilder { private Long timeoutMs; private Long deadlineEpochMs; + private String forkedFrom; public WorkflowStatus build() { return new WorkflowStatus( @@ -83,11 +83,6 @@ public WorkflowStatusBuilder status(WorkflowState state) { return this; } - public WorkflowStatusBuilder forkedFrom(String forkedFrom) { - this.forkedFrom = forkedFrom; - return this; - } - public WorkflowStatusBuilder name(String name) { this.name = name; return this; @@ -114,7 +109,7 @@ public WorkflowStatusBuilder output(Object output) { } public WorkflowStatusBuilder error(Throwable error) { - this.error = ErrorResult.of(error); + this.error = ErrorResult.fromThrowable(error); return this; } @@ -197,4 +192,9 @@ public WorkflowStatusBuilder deadlineEpochMs(Long deadlineEpochMs) { this.deadlineEpochMs = deadlineEpochMs; return this; } + + public WorkflowStatusBuilder forkedFrom(String forkedFrom) { + this.forkedFrom = forkedFrom; + return this; + } }