Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions kit-java-grpc/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@
<scope>provided</scope>
</dependency>

<!-- SLF4J API (for MDC) -->
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
</dependency>

<!-- gRPC -->
<dependency>
<groupId>io.grpc</groupId>
Expand Down Expand Up @@ -62,6 +68,17 @@
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-subclass</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
<version>1.5.12</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-testing</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package dev.suprim.kit.grpc;

import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ClientInterceptor;
import io.grpc.ForwardingClientCall;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;

import java.util.Optional;

/**
* gRPC client interceptor that forwards trace/request IDs from the current gRPC {@link io.grpc.Context}
* to outgoing call metadata.
* <p>
* This ensures trace context is propagated across service-to-service gRPC calls.
* <p>
* Registration example:
* <pre>
* ManagedChannel channel = ManagedChannelBuilder.forTarget("localhost:9090")
* .intercept(new ContextForwardingInterceptor())
* .build();
* </pre>
*/
public class ContextForwardingInterceptor implements ClientInterceptor {

@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
MethodDescriptor<ReqT, RespT> method,
CallOptions callOptions,
Channel next) {

return new ForwardingClientCall.SimpleForwardingClientCall<>(next.newCall(method, callOptions)) {
@Override
public void start(Listener<RespT> responseListener, Metadata headers) {
propagateContextToHeaders(headers);
super.start(responseListener, headers);
}
};
}

private void propagateContextToHeaders(Metadata headers) {
Optional.ofNullable(GrpcContext.getTraceId())
.ifPresent(value -> headers.put(MetadataUtils.TRACE_ID, value));

Optional.ofNullable(GrpcContext.getRequestId())
.ifPresent(value -> headers.put(MetadataUtils.REQUEST_ID, value));

Optional.ofNullable(GrpcContext.getUserId())
.ifPresent(value -> headers.put(MetadataUtils.USER_ID, value));

Optional.ofNullable(GrpcContext.getTenantId())
.ifPresent(value -> headers.put(MetadataUtils.TENANT_ID, value));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package dev.suprim.kit.grpc;

import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import org.slf4j.MDC;

import java.util.Objects;
import java.util.Optional;
import java.util.UUID;

/**
* gRPC server interceptor that extracts trace/request IDs from incoming metadata,
* attaches them to the gRPC {@link Context}, and sets SLF4J MDC for log correlation.
* <p>
* If trace ID or request ID is absent in metadata, a new UUID v4 is generated.
* <p>
* Registration example:
* <pre>
* Server server = ServerBuilder.forPort(9090)
* .addService(new MyServiceImpl())
* .intercept(new ContextPropagationInterceptor())
* .intercept(new ExceptionInterceptor())
* .build();
* </pre>
*/
public class ContextPropagationInterceptor implements ServerInterceptor {

public static final String MDC_TRACE_ID = "traceId";
public static final String MDC_REQUEST_ID = "requestId";

@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
ServerCall<ReqT, RespT> call,
Metadata headers,
ServerCallHandler<ReqT, RespT> next) {

Objects.requireNonNull(headers, "headers");

String traceId = resolveFromMetadata(headers, MetadataUtils.TRACE_ID);
String requestId = resolveFromMetadata(headers, MetadataUtils.REQUEST_ID);

Context context = Context.current()
.withValue(GrpcContext.TRACE_ID, traceId)
.withValue(GrpcContext.REQUEST_ID, requestId);

// Also extract user/tenant if present
context = attachIfPresent(context, headers, MetadataUtils.USER_ID, GrpcContext.USER_ID);
context = attachIfPresent(context, headers, MetadataUtils.TENANT_ID, GrpcContext.TENANT_ID);

// Set MDC for log correlation within this call
MDC.put(MDC_TRACE_ID, traceId);
MDC.put(MDC_REQUEST_ID, requestId);

return new MdcCleanupListener<>(Contexts.interceptCall(context, call, headers, next));
}

private String resolveFromMetadata(Metadata headers, Metadata.Key<String> key) {
return Optional.ofNullable(MetadataUtils.getString(headers, key))
.filter(value -> !value.isBlank())
.orElseGet(() -> UUID.randomUUID().toString());
}

private Context attachIfPresent(Context context, Metadata headers,
Metadata.Key<String> metadataKey, Context.Key<String> contextKey) {
return Optional.ofNullable(MetadataUtils.getString(headers, metadataKey))
.map(value -> context.withValue(contextKey, value))
.orElse(context);
}

/**
* Listener wrapper that cleans up MDC when the call completes or is cancelled.
*/
private static class MdcCleanupListener<ReqT>
extends io.grpc.ForwardingServerCallListener.SimpleForwardingServerCallListener<ReqT> {

MdcCleanupListener(ServerCall.Listener<ReqT> delegate) {
super(delegate);
}

@Override
public void onComplete() {
try {
super.onComplete();
} finally {
clearMdc();
}
}

@Override
public void onCancel() {
try {
super.onCancel();
} finally {
clearMdc();
}
}

private static void clearMdc() {
MDC.remove(MDC_TRACE_ID);
MDC.remove(MDC_REQUEST_ID);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
package dev.suprim.kit.grpc;

import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.Context;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;

import static org.junit.jupiter.api.Assertions.*;

class ContextForwardingInterceptorTest {

private ContextForwardingInterceptor interceptor;
private MethodDescriptor<String, String> methodDescriptor;
private CapturingChannel channel;

@BeforeEach
void setUp() {
interceptor = new ContextForwardingInterceptor();
methodDescriptor = MethodDescriptor.<String, String>newBuilder()
.setType(MethodDescriptor.MethodType.UNARY)
.setFullMethodName("test/method")
.setRequestMarshaller(new StringMarshaller())
.setResponseMarshaller(new StringMarshaller())
.build();
channel = new CapturingChannel();
}

@Test
void shouldForwardTraceIdFromContext() {
Context context = Context.current()
.withValue(GrpcContext.TRACE_ID, "forwarded-trace")
.withValue(GrpcContext.REQUEST_ID, "forwarded-req");

Context previous = context.attach();
try {
ClientCall<String, String> call = interceptor.interceptCall(
methodDescriptor, CallOptions.DEFAULT, channel);

Metadata headers = new Metadata();
call.start(new NoopListener<>(), headers);

assertEquals("forwarded-trace", headers.get(MetadataUtils.TRACE_ID));
assertEquals("forwarded-req", headers.get(MetadataUtils.REQUEST_ID));
} finally {
context.detach(previous);
}
}

@Test
void shouldForwardUserAndTenantFromContext() {
Context context = Context.current()
.withValue(GrpcContext.TRACE_ID, "trace")
.withValue(GrpcContext.REQUEST_ID, "req")
.withValue(GrpcContext.USER_ID, "user-123")
.withValue(GrpcContext.TENANT_ID, "tenant-456");

Context previous = context.attach();
try {
ClientCall<String, String> call = interceptor.interceptCall(
methodDescriptor, CallOptions.DEFAULT, channel);

Metadata headers = new Metadata();
call.start(new NoopListener<>(), headers);

assertEquals("user-123", headers.get(MetadataUtils.USER_ID));
assertEquals("tenant-456", headers.get(MetadataUtils.TENANT_ID));
} finally {
context.detach(previous);
}
}

@Test
void shouldNotSetHeadersWhenContextEmpty() {
ClientCall<String, String> call = interceptor.interceptCall(
methodDescriptor, CallOptions.DEFAULT, channel);

Metadata headers = new Metadata();
call.start(new NoopListener<>(), headers);

assertNull(headers.get(MetadataUtils.TRACE_ID));
assertNull(headers.get(MetadataUtils.REQUEST_ID));
assertNull(headers.get(MetadataUtils.USER_ID));
assertNull(headers.get(MetadataUtils.TENANT_ID));
}

private static class CapturingChannel extends Channel {
@Override
public <RequestT, ResponseT> ClientCall<RequestT, ResponseT> newCall(
MethodDescriptor<RequestT, ResponseT> methodDescriptor, CallOptions callOptions) {
return new NoopClientCall<>();
}

@Override
public String authority() {
return "test-authority";
}
}

private static class NoopClientCall<ReqT, RespT> extends ClientCall<ReqT, RespT> {
@Override
public void start(Listener<RespT> responseListener, Metadata headers) {
// no-op
}

@Override
public void request(int numMessages) {
// no-op
}

@Override
public void cancel(String message, Throwable cause) {
// no-op
}

@Override
public void halfClose() {
// no-op
}

@Override
public void sendMessage(ReqT message) {
// no-op
}
}

private static class NoopListener<T> extends ClientCall.Listener<T> {
// no-op
}

private static class StringMarshaller implements MethodDescriptor.Marshaller<String> {
@Override
public InputStream stream(String value) {
return new ByteArrayInputStream(value.getBytes(StandardCharsets.UTF_8));
}

@Override
public String parse(InputStream stream) {
try {
return new String(stream.readAllBytes(), StandardCharsets.UTF_8);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
}
}
Loading
Loading