diff --git a/kit-java-grpc/pom.xml b/kit-java-grpc/pom.xml index cde1e91..df7ef42 100644 --- a/kit-java-grpc/pom.xml +++ b/kit-java-grpc/pom.xml @@ -23,6 +23,12 @@ provided + + + org.slf4j + slf4j-api + + io.grpc @@ -62,6 +68,17 @@ mockito-core test + + org.mockito + mockito-subclass + test + + + ch.qos.logback + logback-classic + 1.5.12 + test + io.grpc grpc-testing diff --git a/kit-java-grpc/src/main/java/dev/suprim/kit/grpc/ContextForwardingInterceptor.java b/kit-java-grpc/src/main/java/dev/suprim/kit/grpc/ContextForwardingInterceptor.java new file mode 100644 index 0000000..f515971 --- /dev/null +++ b/kit-java-grpc/src/main/java/dev/suprim/kit/grpc/ContextForwardingInterceptor.java @@ -0,0 +1,56 @@ +package dev.suprim.kit.grpc; + +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.ForwardingClientCall; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; + +import java.util.Optional; + +/** + * gRPC client interceptor that forwards trace/request IDs from the current gRPC {@link io.grpc.Context} + * to outgoing call metadata. + *

+ * This ensures trace context is propagated across service-to-service gRPC calls. + *

+ * Registration example: + *

+ * ManagedChannel channel = ManagedChannelBuilder.forTarget("localhost:9090")
+ *     .intercept(new ContextForwardingInterceptor())
+ *     .build();
+ * 
+ */ +public class ContextForwardingInterceptor implements ClientInterceptor { + + @Override + public ClientCall interceptCall( + MethodDescriptor method, + CallOptions callOptions, + Channel next) { + + return new ForwardingClientCall.SimpleForwardingClientCall<>(next.newCall(method, callOptions)) { + @Override + public void start(Listener responseListener, Metadata headers) { + propagateContextToHeaders(headers); + super.start(responseListener, headers); + } + }; + } + + private void propagateContextToHeaders(Metadata headers) { + Optional.ofNullable(GrpcContext.getTraceId()) + .ifPresent(value -> headers.put(MetadataUtils.TRACE_ID, value)); + + Optional.ofNullable(GrpcContext.getRequestId()) + .ifPresent(value -> headers.put(MetadataUtils.REQUEST_ID, value)); + + Optional.ofNullable(GrpcContext.getUserId()) + .ifPresent(value -> headers.put(MetadataUtils.USER_ID, value)); + + Optional.ofNullable(GrpcContext.getTenantId()) + .ifPresent(value -> headers.put(MetadataUtils.TENANT_ID, value)); + } +} diff --git a/kit-java-grpc/src/main/java/dev/suprim/kit/grpc/ContextPropagationInterceptor.java b/kit-java-grpc/src/main/java/dev/suprim/kit/grpc/ContextPropagationInterceptor.java new file mode 100644 index 0000000..6ee5b29 --- /dev/null +++ b/kit-java-grpc/src/main/java/dev/suprim/kit/grpc/ContextPropagationInterceptor.java @@ -0,0 +1,107 @@ +package dev.suprim.kit.grpc; + +import io.grpc.Context; +import io.grpc.Contexts; +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import org.slf4j.MDC; + +import java.util.Objects; +import java.util.Optional; +import java.util.UUID; + +/** + * gRPC server interceptor that extracts trace/request IDs from incoming metadata, + * attaches them to the gRPC {@link Context}, and sets SLF4J MDC for log correlation. + *

+ * If trace ID or request ID is absent in metadata, a new UUID v4 is generated. + *

+ * Registration example: + *

+ * Server server = ServerBuilder.forPort(9090)
+ *     .addService(new MyServiceImpl())
+ *     .intercept(new ContextPropagationInterceptor())
+ *     .intercept(new ExceptionInterceptor())
+ *     .build();
+ * 
+ */ +public class ContextPropagationInterceptor implements ServerInterceptor { + + public static final String MDC_TRACE_ID = "traceId"; + public static final String MDC_REQUEST_ID = "requestId"; + + @Override + public ServerCall.Listener interceptCall( + ServerCall call, + Metadata headers, + ServerCallHandler next) { + + Objects.requireNonNull(headers, "headers"); + + String traceId = resolveFromMetadata(headers, MetadataUtils.TRACE_ID); + String requestId = resolveFromMetadata(headers, MetadataUtils.REQUEST_ID); + + Context context = Context.current() + .withValue(GrpcContext.TRACE_ID, traceId) + .withValue(GrpcContext.REQUEST_ID, requestId); + + // Also extract user/tenant if present + context = attachIfPresent(context, headers, MetadataUtils.USER_ID, GrpcContext.USER_ID); + context = attachIfPresent(context, headers, MetadataUtils.TENANT_ID, GrpcContext.TENANT_ID); + + // Set MDC for log correlation within this call + MDC.put(MDC_TRACE_ID, traceId); + MDC.put(MDC_REQUEST_ID, requestId); + + return new MdcCleanupListener<>(Contexts.interceptCall(context, call, headers, next)); + } + + private String resolveFromMetadata(Metadata headers, Metadata.Key key) { + return Optional.ofNullable(MetadataUtils.getString(headers, key)) + .filter(value -> !value.isBlank()) + .orElseGet(() -> UUID.randomUUID().toString()); + } + + private Context attachIfPresent(Context context, Metadata headers, + Metadata.Key metadataKey, Context.Key contextKey) { + return Optional.ofNullable(MetadataUtils.getString(headers, metadataKey)) + .map(value -> context.withValue(contextKey, value)) + .orElse(context); + } + + /** + * Listener wrapper that cleans up MDC when the call completes or is cancelled. + */ + private static class MdcCleanupListener + extends io.grpc.ForwardingServerCallListener.SimpleForwardingServerCallListener { + + MdcCleanupListener(ServerCall.Listener delegate) { + super(delegate); + } + + @Override + public void onComplete() { + try { + super.onComplete(); + } finally { + clearMdc(); + } + } + + @Override + public void onCancel() { + try { + super.onCancel(); + } finally { + clearMdc(); + } + } + + private static void clearMdc() { + MDC.remove(MDC_TRACE_ID); + MDC.remove(MDC_REQUEST_ID); + } + } +} diff --git a/kit-java-grpc/src/test/java/dev/suprim/kit/grpc/ContextForwardingInterceptorTest.java b/kit-java-grpc/src/test/java/dev/suprim/kit/grpc/ContextForwardingInterceptorTest.java new file mode 100644 index 0000000..8e4ccaf --- /dev/null +++ b/kit-java-grpc/src/test/java/dev/suprim/kit/grpc/ContextForwardingInterceptorTest.java @@ -0,0 +1,154 @@ +package dev.suprim.kit.grpc; + +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.Context; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; + +import static org.junit.jupiter.api.Assertions.*; + +class ContextForwardingInterceptorTest { + + private ContextForwardingInterceptor interceptor; + private MethodDescriptor methodDescriptor; + private CapturingChannel channel; + + @BeforeEach + void setUp() { + interceptor = new ContextForwardingInterceptor(); + methodDescriptor = MethodDescriptor.newBuilder() + .setType(MethodDescriptor.MethodType.UNARY) + .setFullMethodName("test/method") + .setRequestMarshaller(new StringMarshaller()) + .setResponseMarshaller(new StringMarshaller()) + .build(); + channel = new CapturingChannel(); + } + + @Test + void shouldForwardTraceIdFromContext() { + Context context = Context.current() + .withValue(GrpcContext.TRACE_ID, "forwarded-trace") + .withValue(GrpcContext.REQUEST_ID, "forwarded-req"); + + Context previous = context.attach(); + try { + ClientCall call = interceptor.interceptCall( + methodDescriptor, CallOptions.DEFAULT, channel); + + Metadata headers = new Metadata(); + call.start(new NoopListener<>(), headers); + + assertEquals("forwarded-trace", headers.get(MetadataUtils.TRACE_ID)); + assertEquals("forwarded-req", headers.get(MetadataUtils.REQUEST_ID)); + } finally { + context.detach(previous); + } + } + + @Test + void shouldForwardUserAndTenantFromContext() { + Context context = Context.current() + .withValue(GrpcContext.TRACE_ID, "trace") + .withValue(GrpcContext.REQUEST_ID, "req") + .withValue(GrpcContext.USER_ID, "user-123") + .withValue(GrpcContext.TENANT_ID, "tenant-456"); + + Context previous = context.attach(); + try { + ClientCall call = interceptor.interceptCall( + methodDescriptor, CallOptions.DEFAULT, channel); + + Metadata headers = new Metadata(); + call.start(new NoopListener<>(), headers); + + assertEquals("user-123", headers.get(MetadataUtils.USER_ID)); + assertEquals("tenant-456", headers.get(MetadataUtils.TENANT_ID)); + } finally { + context.detach(previous); + } + } + + @Test + void shouldNotSetHeadersWhenContextEmpty() { + ClientCall call = interceptor.interceptCall( + methodDescriptor, CallOptions.DEFAULT, channel); + + Metadata headers = new Metadata(); + call.start(new NoopListener<>(), headers); + + assertNull(headers.get(MetadataUtils.TRACE_ID)); + assertNull(headers.get(MetadataUtils.REQUEST_ID)); + assertNull(headers.get(MetadataUtils.USER_ID)); + assertNull(headers.get(MetadataUtils.TENANT_ID)); + } + + private static class CapturingChannel extends Channel { + @Override + public ClientCall newCall( + MethodDescriptor methodDescriptor, CallOptions callOptions) { + return new NoopClientCall<>(); + } + + @Override + public String authority() { + return "test-authority"; + } + } + + private static class NoopClientCall extends ClientCall { + @Override + public void start(Listener responseListener, Metadata headers) { + // no-op + } + + @Override + public void request(int numMessages) { + // no-op + } + + @Override + public void cancel(String message, Throwable cause) { + // no-op + } + + @Override + public void halfClose() { + // no-op + } + + @Override + public void sendMessage(ReqT message) { + // no-op + } + } + + private static class NoopListener extends ClientCall.Listener { + // no-op + } + + private static class StringMarshaller implements MethodDescriptor.Marshaller { + @Override + public InputStream stream(String value) { + return new ByteArrayInputStream(value.getBytes(StandardCharsets.UTF_8)); + } + + @Override + public String parse(InputStream stream) { + try { + return new String(stream.readAllBytes(), StandardCharsets.UTF_8); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + } +} diff --git a/kit-java-grpc/src/test/java/dev/suprim/kit/grpc/ContextPropagationInterceptorTest.java b/kit-java-grpc/src/test/java/dev/suprim/kit/grpc/ContextPropagationInterceptorTest.java new file mode 100644 index 0000000..99b94fa --- /dev/null +++ b/kit-java-grpc/src/test/java/dev/suprim/kit/grpc/ContextPropagationInterceptorTest.java @@ -0,0 +1,151 @@ +package dev.suprim.kit.grpc; + +import io.grpc.Context; +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.Status; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.slf4j.MDC; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +class ContextPropagationInterceptorTest { + + private ContextPropagationInterceptor interceptor; + + @Mock + private ServerCall call; + + @Mock + private ServerCallHandler next; + + @Mock + private ServerCall.Listener listener; + + private AutoCloseable mocks; + + @BeforeEach + void setUp() { + mocks = MockitoAnnotations.openMocks(this); + interceptor = new ContextPropagationInterceptor(); + when(next.startCall(any(), any())).thenReturn(listener); + } + + @AfterEach + void tearDown() throws Exception { + MDC.clear(); + mocks.close(); + } + + @Test + void shouldExtractTraceIdFromMetadata() { + Metadata headers = new Metadata(); + headers.put(MetadataUtils.TRACE_ID, "grpc-trace-123"); + headers.put(MetadataUtils.REQUEST_ID, "grpc-req-456"); + + interceptor.interceptCall(call, headers, next); + + assertEquals("grpc-trace-123", MDC.get(ContextPropagationInterceptor.MDC_TRACE_ID)); + assertEquals("grpc-req-456", MDC.get(ContextPropagationInterceptor.MDC_REQUEST_ID)); + } + + @Test + void shouldGenerateIdsWhenMetadataAbsent() { + Metadata headers = new Metadata(); + + interceptor.interceptCall(call, headers, next); + + String traceId = MDC.get(ContextPropagationInterceptor.MDC_TRACE_ID); + String requestId = MDC.get(ContextPropagationInterceptor.MDC_REQUEST_ID); + + assertNotNull(traceId); + assertNotNull(requestId); + assertFalse(traceId.isBlank()); + assertFalse(requestId.isBlank()); + assertNotEquals(traceId, requestId); + } + + @Test + void shouldGenerateIdsWhenMetadataBlank() { + Metadata headers = new Metadata(); + headers.put(MetadataUtils.TRACE_ID, " "); + headers.put(MetadataUtils.REQUEST_ID, ""); + + interceptor.interceptCall(call, headers, next); + + String traceId = MDC.get(ContextPropagationInterceptor.MDC_TRACE_ID); + assertNotNull(traceId); + assertNotEquals(" ", traceId); + } + + @Test + void shouldAttachUserAndTenantToContext() { + Metadata headers = new Metadata(); + headers.put(MetadataUtils.TRACE_ID, "trace"); + headers.put(MetadataUtils.REQUEST_ID, "req"); + headers.put(MetadataUtils.USER_ID, "user-789"); + headers.put(MetadataUtils.TENANT_ID, "tenant-abc"); + + // Capture the context that next.startCall receives + when(next.startCall(any(), any())).thenAnswer(invocation -> { + assertEquals("user-789", GrpcContext.USER_ID.get()); + assertEquals("tenant-abc", GrpcContext.TENANT_ID.get()); + assertEquals("trace", GrpcContext.TRACE_ID.get()); + assertEquals("req", GrpcContext.REQUEST_ID.get()); + return listener; + }); + + interceptor.interceptCall(call, headers, next); + } + + @Test + void shouldRejectNullHeaders() { + assertThrows(NullPointerException.class, () -> interceptor.interceptCall(call, null, next)); + } + + @Test + void shouldCleanMdcOnComplete() { + Metadata headers = new Metadata(); + headers.put(MetadataUtils.TRACE_ID, "trace-complete"); + headers.put(MetadataUtils.REQUEST_ID, "req-complete"); + + ServerCall.Listener wrappedListener = interceptor.interceptCall(call, headers, next); + + // MDC should be set after interceptCall + assertEquals("trace-complete", MDC.get(ContextPropagationInterceptor.MDC_TRACE_ID)); + assertEquals("req-complete", MDC.get(ContextPropagationInterceptor.MDC_REQUEST_ID)); + + // Simulate call completion + wrappedListener.onComplete(); + + // MDC should be cleaned + assertNull(MDC.get(ContextPropagationInterceptor.MDC_TRACE_ID)); + assertNull(MDC.get(ContextPropagationInterceptor.MDC_REQUEST_ID)); + } + + @Test + void shouldCleanMdcOnCancel() { + Metadata headers = new Metadata(); + headers.put(MetadataUtils.TRACE_ID, "trace-cancel"); + headers.put(MetadataUtils.REQUEST_ID, "req-cancel"); + + ServerCall.Listener wrappedListener = interceptor.interceptCall(call, headers, next); + + // MDC should be set after interceptCall + assertEquals("trace-cancel", MDC.get(ContextPropagationInterceptor.MDC_TRACE_ID)); + assertEquals("req-cancel", MDC.get(ContextPropagationInterceptor.MDC_REQUEST_ID)); + + // Simulate call cancellation + wrappedListener.onCancel(); + + // MDC should be cleaned + assertNull(MDC.get(ContextPropagationInterceptor.MDC_TRACE_ID)); + assertNull(MDC.get(ContextPropagationInterceptor.MDC_REQUEST_ID)); + } +} diff --git a/kit-java-web/pom.xml b/kit-java-web/pom.xml index aeeda50..7ab48a6 100644 --- a/kit-java-web/pom.xml +++ b/kit-java-web/pom.xml @@ -28,6 +28,12 @@ kit-java-exception + + + org.slf4j + slf4j-api + + jakarta.servlet @@ -65,6 +71,12 @@ + + ch.qos.logback + logback-classic + 1.5.12 + test + org.junit.jupiter junit-jupiter diff --git a/kit-java-web/src/main/java/dev/suprim/kit/web/context/ContextPropagation.java b/kit-java-web/src/main/java/dev/suprim/kit/web/context/ContextPropagation.java new file mode 100644 index 0000000..1ba2dcb --- /dev/null +++ b/kit-java-web/src/main/java/dev/suprim/kit/web/context/ContextPropagation.java @@ -0,0 +1,78 @@ +package dev.suprim.kit.web.context; + +import org.slf4j.MDC; + +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.Callable; + +/** + * Utilities for propagating {@link RequestContext} and MDC across async boundaries. + *

+ * When spawning async tasks (thread pools, CompletableFuture, etc.), the trace context + * is lost because it lives in ThreadLocal. These wrappers capture and restore it. + *

+ * Usage: + *

+ *   executor.submit(ContextPropagation.wrap(() -> {
+ *       // RequestContext and MDC are available here
+ *       String traceId = RequestContext.getTraceId().orElse("unknown");
+ *   }));
+ * 
+ */ +public final class ContextPropagation { + + private ContextPropagation() { + throw new UnsupportedOperationException("Utility class cannot be instantiated"); + } + + /** + * Wraps a Runnable to propagate the current RequestContext and MDC to the executing thread. + * + * @param task the task to wrap, must not be null + * @return a context-aware Runnable + */ + public static Runnable wrap(Runnable task) { + Objects.requireNonNull(task, "task"); + Map contextSnapshot = RequestContext.snapshot(); + Map mdcSnapshot = MDC.getCopyOfContextMap(); + return () -> { + RequestContext.restore(contextSnapshot); + setMdcContext(mdcSnapshot); + try { + task.run(); + } finally { + RequestContext.clear(); + MDC.clear(); + } + }; + } + + /** + * Wraps a Callable to propagate the current RequestContext and MDC to the executing thread. + * + * @param task the task to wrap, must not be null + * @return a context-aware Callable + */ + public static Callable wrap(Callable task) { + Objects.requireNonNull(task, "task"); + Map contextSnapshot = RequestContext.snapshot(); + Map mdcSnapshot = MDC.getCopyOfContextMap(); + return () -> { + RequestContext.restore(contextSnapshot); + setMdcContext(mdcSnapshot); + try { + return task.call(); + } finally { + RequestContext.clear(); + MDC.clear(); + } + }; + } + + private static void setMdcContext(Map mdcSnapshot) { + Optional.ofNullable(mdcSnapshot) + .ifPresentOrElse(MDC::setContextMap, MDC::clear); + } +} diff --git a/kit-java-web/src/main/java/dev/suprim/kit/web/context/RequestContext.java b/kit-java-web/src/main/java/dev/suprim/kit/web/context/RequestContext.java new file mode 100644 index 0000000..96ca697 --- /dev/null +++ b/kit-java-web/src/main/java/dev/suprim/kit/web/context/RequestContext.java @@ -0,0 +1,92 @@ +package dev.suprim.kit.web.context; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +/** + * Holds request-scoped context (trace ID, request ID, etc.) via ThreadLocal. + *

+ * Lifecycle is managed by {@link RequestContextFilter} for HTTP requests. + * For gRPC, use the corresponding interceptor in kit-java-grpc. + *

+ * Usage: + *

+ *   String traceId = RequestContext.getTraceId().orElse("unknown");
+ * 
+ */ +public final class RequestContext { + + public static final String KEY_TRACE_ID = "traceId"; + public static final String KEY_REQUEST_ID = "requestId"; + + private static final ThreadLocal> CONTEXT = ThreadLocal.withInitial(HashMap::new); + + private RequestContext() { + throw new UnsupportedOperationException("Utility class cannot be instantiated"); + } + + /** + * Sets a value in the current request context. + * + * @param key context key, must not be null + * @param value context value, must not be null + */ + public static void set(String key, String value) { + Objects.requireNonNull(key, "key"); + Objects.requireNonNull(value, "value"); + CONTEXT.get().put(key, value); + } + + /** + * Gets a value from the current request context. + */ + public static Optional get(String key) { + Objects.requireNonNull(key, "key"); + return Optional.ofNullable(CONTEXT.get().get(key)); + } + + /** + * Gets the trace ID for the current request. + */ + public static Optional getTraceId() { + return get(KEY_TRACE_ID); + } + + /** + * Gets the request ID for the current request. + */ + public static Optional getRequestId() { + return get(KEY_REQUEST_ID); + } + + /** + * Returns an unmodifiable snapshot of the current context. + * Useful for propagating context to async tasks. + */ + public static Map snapshot() { + return Collections.unmodifiableMap(new HashMap<>(CONTEXT.get())); + } + + /** + * Restores a previously captured context snapshot. + * Typically used in async task execution. + * + * @param contextSnapshot snapshot to restore, must not be null + */ + public static void restore(Map contextSnapshot) { + Objects.requireNonNull(contextSnapshot, "contextSnapshot"); + clear(); + CONTEXT.get().putAll(contextSnapshot); + } + + /** + * Clears the current request context. Must be called at the end of request processing + * to prevent memory leaks. + */ + public static void clear() { + CONTEXT.remove(); + } +} diff --git a/kit-java-web/src/main/java/dev/suprim/kit/web/context/RequestContextFilter.java b/kit-java-web/src/main/java/dev/suprim/kit/web/context/RequestContextFilter.java new file mode 100644 index 0000000..915020a --- /dev/null +++ b/kit-java-web/src/main/java/dev/suprim/kit/web/context/RequestContextFilter.java @@ -0,0 +1,91 @@ +package dev.suprim.kit.web.context; + +import dev.suprim.kit.core.UUIDUtils; +import dev.suprim.kit.web.HttpConstants; +import jakarta.servlet.Filter; +import jakarta.servlet.FilterChain; +import jakarta.servlet.ServletException; +import jakarta.servlet.ServletRequest; +import jakarta.servlet.ServletResponse; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.slf4j.MDC; + +import java.io.IOException; +import java.util.Optional; + +/** + * Servlet filter that manages trace/request ID propagation for HTTP requests. + *

+ * Behavior: + *

    + *
  1. Extracts trace ID from incoming {@code X-Trace-ID} header, or generates a new UUID v7 if absent
  2. + *
  3. Extracts request ID from {@code X-Request-ID} header, or generates a new UUID v7 if absent
  4. + *
  5. Stores both in {@link RequestContext} (ThreadLocal) and SLF4J MDC
  6. + *
  7. Sets {@code X-Trace-ID} and {@code X-Request-ID} response headers
  8. + *
  9. Cleans up ThreadLocal and MDC after request completes
  10. + *
+ *

+ * Registration example (Spring Boot): + *

+ * @Bean
+ * public FilterRegistrationBean<RequestContextFilter> requestContextFilter() {
+ *     FilterRegistrationBean<RequestContextFilter> registration = new FilterRegistrationBean<>(new RequestContextFilter());
+ *     registration.setOrder(Ordered.HIGHEST_PRECEDENCE);
+ *     return registration;
+ * }
+ * 
+ */ +public class RequestContextFilter implements Filter { + + public static final String MDC_TRACE_ID = "traceId"; + public static final String MDC_REQUEST_ID = "requestId"; + + @Override + public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) + throws IOException, ServletException { + + if (!(request instanceof HttpServletRequest httpRequest)) { + chain.doFilter(request, response); + return; + } + + if (!(response instanceof HttpServletResponse httpResponse)) { + chain.doFilter(request, response); + return; + } + + try { + String traceId = resolveHeader(httpRequest, HttpConstants.HEADER_X_TRACE_ID); + String requestId = resolveHeader(httpRequest, HttpConstants.HEADER_X_REQUEST_ID); + + // Store in ThreadLocal context + RequestContext.set(RequestContext.KEY_TRACE_ID, traceId); + RequestContext.set(RequestContext.KEY_REQUEST_ID, requestId); + + // Store in MDC for log correlation + MDC.put(MDC_TRACE_ID, traceId); + MDC.put(MDC_REQUEST_ID, requestId); + + // Set response headers so client can reference them + httpResponse.setHeader(HttpConstants.HEADER_X_TRACE_ID, traceId); + httpResponse.setHeader(HttpConstants.HEADER_X_REQUEST_ID, requestId); + + chain.doFilter(request, response); + } finally { + RequestContext.clear(); + MDC.remove(MDC_TRACE_ID); + MDC.remove(MDC_REQUEST_ID); + } + } + + /** + * Extracts header value from request, or generates a new UUID v7 if absent/blank. + */ + private String resolveHeader(HttpServletRequest request, String headerName) { + return Optional.ofNullable(request.getHeader(headerName)) + .filter(value -> !value.isBlank()) + .map(String::trim) + .orElseGet(() -> UUIDUtils.v7().toString()); + } +} diff --git a/kit-java-web/src/test/java/dev/suprim/kit/web/context/ContextPropagationTest.java b/kit-java-web/src/test/java/dev/suprim/kit/web/context/ContextPropagationTest.java new file mode 100644 index 0000000..ed80f61 --- /dev/null +++ b/kit-java-web/src/test/java/dev/suprim/kit/web/context/ContextPropagationTest.java @@ -0,0 +1,133 @@ +package dev.suprim.kit.web.context; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.slf4j.MDC; + +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.jupiter.api.Assertions.*; + +class ContextPropagationTest { + + @AfterEach + void tearDown() { + RequestContext.clear(); + MDC.clear(); + } + + @Test + void wrapRunnable_shouldRejectNull() { + assertThrows(NullPointerException.class, () -> ContextPropagation.wrap((Runnable) null)); + } + + @Test + void wrapCallable_shouldRejectNull() { + assertThrows(NullPointerException.class, () -> ContextPropagation.wrap((java.util.concurrent.Callable) null)); + } + + @Test + void wrapRunnable_shouldPropagateContext() throws Exception { + RequestContext.set(RequestContext.KEY_TRACE_ID, "trace-abc"); + MDC.put("traceId", "trace-abc"); + + AtomicReference capturedTraceId = new AtomicReference<>(); + AtomicReference capturedMdc = new AtomicReference<>(); + + Runnable wrapped = ContextPropagation.wrap(() -> { + capturedTraceId.set(RequestContext.getTraceId().orElse("missing")); + capturedMdc.set(MDC.get("traceId")); + }); + + ExecutorService executor = Executors.newSingleThreadExecutor(); + Future future = executor.submit(wrapped); + future.get(); + executor.shutdown(); + + assertEquals("trace-abc", capturedTraceId.get()); + assertEquals("trace-abc", capturedMdc.get()); + } + + @Test + void wrapCallable_shouldPropagateContext() throws Exception { + RequestContext.set(RequestContext.KEY_TRACE_ID, "trace-xyz"); + MDC.put("traceId", "trace-xyz"); + + java.util.concurrent.Callable wrapped = ContextPropagation.wrap(() -> + RequestContext.getTraceId().orElse("missing") + ); + + ExecutorService executor = Executors.newSingleThreadExecutor(); + Future future = executor.submit(wrapped); + String result = future.get(); + executor.shutdown(); + + assertEquals("trace-xyz", result); + } + + @Test + void wrapRunnable_shouldCleanupAfterExecution() throws Exception { + RequestContext.set(RequestContext.KEY_TRACE_ID, "trace-cleanup"); + MDC.put("traceId", "trace-cleanup"); + + CountDownLatch taskDone = new CountDownLatch(1); + AtomicReference afterTraceId = new AtomicReference<>(); + AtomicReference afterMdc = new AtomicReference<>(); + + ExecutorService executor = Executors.newSingleThreadExecutor(); + + // Run wrapped task first + Future future = executor.submit(ContextPropagation.wrap(() -> { + // context available here + })); + future.get(); + + // Run plain task on same thread to verify cleanup + Future verifyFuture = executor.submit(() -> { + afterTraceId.set(RequestContext.getTraceId().orElse("empty")); + afterMdc.set(MDC.get("traceId")); + taskDone.countDown(); + }); + verifyFuture.get(); + executor.shutdown(); + + taskDone.await(); + assertEquals("empty", afterTraceId.get()); + assertNull(afterMdc.get()); + } + + @Test + void wrapRunnable_shouldHandleNullMdcSnapshot() throws Exception { + // MDC is empty (getCopyOfContextMap returns null for some implementations) + MDC.clear(); + RequestContext.set(RequestContext.KEY_TRACE_ID, "trace-no-mdc"); + + AtomicReference capturedTraceId = new AtomicReference<>(); + + Runnable wrapped = ContextPropagation.wrap(() -> + capturedTraceId.set(RequestContext.getTraceId().orElse("missing")) + ); + + ExecutorService executor = Executors.newSingleThreadExecutor(); + Future future = executor.submit(wrapped); + future.get(); + executor.shutdown(); + + assertEquals("trace-no-mdc", capturedTraceId.get()); + } + + @Test + void constructor_shouldThrowUnsupportedOperationException() throws Exception { + Constructor constructor = ContextPropagation.class.getDeclaredConstructor(); + constructor.setAccessible(true); + + InvocationTargetException exception = assertThrows(InvocationTargetException.class, constructor::newInstance); + assertInstanceOf(UnsupportedOperationException.class, exception.getCause()); + } +} diff --git a/kit-java-web/src/test/java/dev/suprim/kit/web/context/RequestContextFilterTest.java b/kit-java-web/src/test/java/dev/suprim/kit/web/context/RequestContextFilterTest.java new file mode 100644 index 0000000..fd5c0e4 --- /dev/null +++ b/kit-java-web/src/test/java/dev/suprim/kit/web/context/RequestContextFilterTest.java @@ -0,0 +1,166 @@ +package dev.suprim.kit.web.context; + +import dev.suprim.kit.web.HttpConstants; +import jakarta.servlet.FilterChain; +import jakarta.servlet.ServletException; +import jakarta.servlet.ServletRequest; +import jakarta.servlet.ServletResponse; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.slf4j.MDC; + +import java.io.IOException; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +class RequestContextFilterTest { + + private RequestContextFilter filter; + + @Mock + private HttpServletRequest request; + + @Mock + private HttpServletResponse response; + + @Mock + private FilterChain chain; + + private AutoCloseable mocks; + + @BeforeEach + void setUp() { + mocks = MockitoAnnotations.openMocks(this); + filter = new RequestContextFilter(); + } + + @AfterEach + void tearDown() throws Exception { + RequestContext.clear(); + MDC.clear(); + mocks.close(); + } + + @Test + void shouldExtractTraceIdFromHeader() throws IOException, ServletException { + when(request.getHeader(HttpConstants.HEADER_X_TRACE_ID)).thenReturn("incoming-trace-id"); + when(request.getHeader(HttpConstants.HEADER_X_REQUEST_ID)).thenReturn("incoming-request-id"); + + doAnswer(invocation -> { + assertEquals("incoming-trace-id", RequestContext.getTraceId().orElseThrow()); + assertEquals("incoming-request-id", RequestContext.getRequestId().orElseThrow()); + assertEquals("incoming-trace-id", MDC.get(RequestContextFilter.MDC_TRACE_ID)); + assertEquals("incoming-request-id", MDC.get(RequestContextFilter.MDC_REQUEST_ID)); + return null; + }).when(chain).doFilter(request, response); + + filter.doFilter(request, response, chain); + + verify(response).setHeader(HttpConstants.HEADER_X_TRACE_ID, "incoming-trace-id"); + verify(response).setHeader(HttpConstants.HEADER_X_REQUEST_ID, "incoming-request-id"); + } + + @Test + void shouldGenerateIdsWhenHeadersAbsent() throws IOException, ServletException { + when(request.getHeader(HttpConstants.HEADER_X_TRACE_ID)).thenReturn(null); + when(request.getHeader(HttpConstants.HEADER_X_REQUEST_ID)).thenReturn(null); + + doAnswer(invocation -> { + assertTrue(RequestContext.getTraceId().isPresent()); + assertTrue(RequestContext.getRequestId().isPresent()); + assertNotEquals( + RequestContext.getTraceId().orElseThrow(), + RequestContext.getRequestId().orElseThrow() + ); + return null; + }).when(chain).doFilter(request, response); + + filter.doFilter(request, response, chain); + + verify(response).setHeader(eq(HttpConstants.HEADER_X_TRACE_ID), anyString()); + verify(response).setHeader(eq(HttpConstants.HEADER_X_REQUEST_ID), anyString()); + } + + @Test + void shouldGenerateIdsWhenHeadersBlank() throws IOException, ServletException { + when(request.getHeader(HttpConstants.HEADER_X_TRACE_ID)).thenReturn(" "); + when(request.getHeader(HttpConstants.HEADER_X_REQUEST_ID)).thenReturn(""); + + doAnswer(invocation -> { + assertTrue(RequestContext.getTraceId().isPresent()); + assertTrue(RequestContext.getRequestId().isPresent()); + return null; + }).when(chain).doFilter(request, response); + + filter.doFilter(request, response, chain); + } + + @Test + void shouldTrimHeaderValues() throws IOException, ServletException { + when(request.getHeader(HttpConstants.HEADER_X_TRACE_ID)).thenReturn(" trace-with-spaces "); + when(request.getHeader(HttpConstants.HEADER_X_REQUEST_ID)).thenReturn("req-id"); + + doAnswer(invocation -> { + assertEquals("trace-with-spaces", RequestContext.getTraceId().orElseThrow()); + return null; + }).when(chain).doFilter(request, response); + + filter.doFilter(request, response, chain); + + verify(response).setHeader(HttpConstants.HEADER_X_TRACE_ID, "trace-with-spaces"); + } + + @Test + void shouldCleanupContextAfterRequest() throws IOException, ServletException { + when(request.getHeader(HttpConstants.HEADER_X_TRACE_ID)).thenReturn("trace-id"); + when(request.getHeader(HttpConstants.HEADER_X_REQUEST_ID)).thenReturn("request-id"); + + filter.doFilter(request, response, chain); + + assertTrue(RequestContext.getTraceId().isEmpty()); + assertTrue(RequestContext.getRequestId().isEmpty()); + assertNull(MDC.get(RequestContextFilter.MDC_TRACE_ID)); + assertNull(MDC.get(RequestContextFilter.MDC_REQUEST_ID)); + } + + @Test + void shouldCleanupContextEvenOnException() throws IOException, ServletException { + when(request.getHeader(HttpConstants.HEADER_X_TRACE_ID)).thenReturn("trace-id"); + when(request.getHeader(HttpConstants.HEADER_X_REQUEST_ID)).thenReturn("request-id"); + doThrow(new ServletException("boom")).when(chain).doFilter(request, response); + + assertThrows(ServletException.class, () -> filter.doFilter(request, response, chain)); + + assertTrue(RequestContext.getTraceId().isEmpty()); + assertNull(MDC.get(RequestContextFilter.MDC_TRACE_ID)); + } + + @Test + void shouldPassThroughNonHttpRequest() throws IOException, ServletException { + ServletRequest nonHttpRequest = mock(ServletRequest.class); + ServletResponse nonHttpResponse = mock(ServletResponse.class); + + filter.doFilter(nonHttpRequest, nonHttpResponse, chain); + + verify(chain).doFilter(nonHttpRequest, nonHttpResponse); + assertTrue(RequestContext.getTraceId().isEmpty()); + assertNull(MDC.get(RequestContextFilter.MDC_TRACE_ID)); + } + + @Test + void shouldPassThroughWhenResponseNotHttp() throws IOException, ServletException { + ServletResponse nonHttpResponse = mock(ServletResponse.class); + + filter.doFilter(request, nonHttpResponse, chain); + + verify(chain).doFilter(request, nonHttpResponse); + assertTrue(RequestContext.getTraceId().isEmpty()); + assertNull(MDC.get(RequestContextFilter.MDC_TRACE_ID)); + } +} diff --git a/kit-java-web/src/test/java/dev/suprim/kit/web/context/RequestContextTest.java b/kit-java-web/src/test/java/dev/suprim/kit/web/context/RequestContextTest.java new file mode 100644 index 0000000..bbc0b6f --- /dev/null +++ b/kit-java-web/src/test/java/dev/suprim/kit/web/context/RequestContextTest.java @@ -0,0 +1,131 @@ +package dev.suprim.kit.web.context; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; + +class RequestContextTest { + + @AfterEach + void tearDown() { + RequestContext.clear(); + } + + @Test + void set_shouldStoreValue() { + RequestContext.set("key", "value"); + + assertEquals("value", RequestContext.get("key").orElseThrow()); + } + + @Test + void set_shouldRejectNullKey() { + assertThrows(NullPointerException.class, () -> RequestContext.set(null, "value")); + } + + @Test + void set_shouldRejectNullValue() { + assertThrows(NullPointerException.class, () -> RequestContext.set("key", null)); + } + + @Test + void get_shouldReturnEmptyForMissingKey() { + assertTrue(RequestContext.get("nonexistent").isEmpty()); + } + + @Test + void get_shouldRejectNullKey() { + assertThrows(NullPointerException.class, () -> RequestContext.get(null)); + } + + @Test + void getTraceId_shouldReturnStoredTraceId() { + RequestContext.set(RequestContext.KEY_TRACE_ID, "trace-123"); + + assertEquals("trace-123", RequestContext.getTraceId().orElseThrow()); + } + + @Test + void getRequestId_shouldReturnStoredRequestId() { + RequestContext.set(RequestContext.KEY_REQUEST_ID, "req-456"); + + assertEquals("req-456", RequestContext.getRequestId().orElseThrow()); + } + + @Test + void snapshot_shouldReturnUnmodifiableCopy() { + RequestContext.set("key1", "val1"); + RequestContext.set("key2", "val2"); + + Map snapshot = RequestContext.snapshot(); + + assertEquals(2, snapshot.size()); + assertEquals("val1", snapshot.get("key1")); + assertThrows(UnsupportedOperationException.class, () -> snapshot.put("key3", "val3")); + } + + @Test + void snapshot_shouldBeIndependentFromOriginal() { + RequestContext.set("key1", "val1"); + Map snapshot = RequestContext.snapshot(); + + RequestContext.set("key2", "val2"); + + assertFalse(snapshot.containsKey("key2")); + } + + @Test + void restore_shouldReplaceCurrentContext() { + RequestContext.set("old", "value"); + + RequestContext.restore(Map.of("new", "restored")); + + assertTrue(RequestContext.get("old").isEmpty()); + assertEquals("restored", RequestContext.get("new").orElseThrow()); + } + + @Test + void restore_shouldRejectNull() { + assertThrows(NullPointerException.class, () -> RequestContext.restore(null)); + } + + @Test + void clear_shouldRemoveAllValues() { + RequestContext.set("key1", "val1"); + RequestContext.set("key2", "val2"); + + RequestContext.clear(); + + assertTrue(RequestContext.get("key1").isEmpty()); + assertTrue(RequestContext.get("key2").isEmpty()); + } + + @Test + void context_shouldBeThreadIsolated() throws InterruptedException { + RequestContext.set("main", "mainValue"); + + Thread otherThread = new Thread(() -> { + assertTrue(RequestContext.get("main").isEmpty()); + RequestContext.set("other", "otherValue"); + }); + otherThread.start(); + otherThread.join(); + + assertTrue(RequestContext.get("other").isEmpty()); + assertEquals("mainValue", RequestContext.get("main").orElseThrow()); + } + + @Test + void constructor_shouldThrowUnsupportedOperationException() throws Exception { + Constructor constructor = RequestContext.class.getDeclaredConstructor(); + constructor.setAccessible(true); + + InvocationTargetException exception = assertThrows(InvocationTargetException.class, constructor::newInstance); + assertInstanceOf(UnsupportedOperationException.class, exception.getCause()); + } +} diff --git a/pom.xml b/pom.xml index f2f9c72..79ccadc 100644 --- a/pom.xml +++ b/pom.xml @@ -30,6 +30,7 @@ UTF-8 + 2.0.16 5.1.0 2.18.2 42.7.4 @@ -122,6 +123,13 @@ ${postgresql.version}
+ + + org.slf4j + slf4j-api + ${slf4j.version} + + jakarta.servlet