diff --git a/src/main/java/org/prebid/cache/handlers/cache/GetCacheHandler.java b/src/main/java/org/prebid/cache/handlers/cache/GetCacheHandler.java index 37371b7b..b4b7c795 100644 --- a/src/main/java/org/prebid/cache/handlers/cache/GetCacheHandler.java +++ b/src/main/java/org/prebid/cache/handlers/cache/GetCacheHandler.java @@ -3,6 +3,7 @@ import com.github.benmanes.caffeine.cache.Caffeine; import io.github.resilience4j.circuitbreaker.CircuitBreaker; import io.github.resilience4j.reactor.circuitbreaker.operator.CircuitBreakerOperator; +import io.netty.channel.ChannelOption; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.exception.ExceptionUtils; @@ -23,14 +24,15 @@ import org.springframework.beans.factory.annotation.Value; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; +import org.springframework.http.client.reactive.ReactorClientHttpConnector; import org.springframework.stereotype.Component; import org.springframework.web.reactive.function.client.ClientResponse; import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.server.ServerRequest; import org.springframework.web.reactive.function.server.ServerResponse; import reactor.core.publisher.Mono; -import reactor.core.publisher.SynchronousSink; import reactor.core.scheduler.Schedulers; +import reactor.netty.http.client.HttpClient; import java.time.Duration; import java.util.Map; @@ -118,39 +120,51 @@ private Mono processProxyRequest(final ServerRequest request, final String idKeyParam, final String cacheUrl) { - final WebClient webClient = clientsCache.computeIfAbsent(cacheUrl, WebClient::create); + final WebClient webClient = clientsCache.computeIfAbsent(cacheUrl, this::createWebClient); return webClient.get() .uri(uriBuilder -> uriBuilder.queryParam(ID_KEY, idKeyParam).build()) .headers(httpHeaders -> httpHeaders.addAll(request.headers().asHttpHeaders())) - .exchange() + .exchangeToMono(clientResponse -> { + updateProxyMetrics(clientResponse); + return fromClientResponse(clientResponse); + }) .transform(CircuitBreakerOperator.of(circuitBreaker)) .timeout(Duration.ofMillis(config.getTimeoutMs())) .subscribeOn(Schedulers.parallel()) - .handle(this::updateProxyMetrics) - .flatMap(GetCacheHandler::fromClientResponse) .doOnError(error -> { metricsRecorder.getProxyFailure().increment(); - log.info("Failed to send request: '{}', cause: '{}'", + log.error("Failed to send request: '{}', cause: '{}'", ExceptionUtils.getMessage(error), ExceptionUtils.getMessage(error)); }); } - private void updateProxyMetrics(final ClientResponse clientResponse, - final SynchronousSink sink) { + private WebClient createWebClient(String cacheUrl) { + HttpClient httpClient = HttpClient.create() + .responseTimeout(Duration.ofMillis(config.getTimeoutMs())) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, config.getTimeoutMs()); + + return WebClient.builder() + .baseUrl(cacheUrl) + .clientConnector(new ReactorClientHttpConnector(httpClient)) + .build(); + } + + private void updateProxyMetrics(final ClientResponse clientResponse) { if (HttpStatus.OK.equals(clientResponse.statusCode())) { metricsRecorder.getProxySuccess().increment(); } else { metricsRecorder.getProxyFailure().increment(); } - - sink.next(clientResponse); } private static Mono fromClientResponse(final ClientResponse clientResponse) { - return ServerResponse.status(clientResponse.statusCode()) - .headers(headerConsumer -> clientResponse.headers().asHttpHeaders().forEach(headerConsumer::addAll)) - .body(clientResponse.bodyToMono(String.class), String.class); + // This is a workaround to handle the race condition when the response body is consumed + // https://github.com/spring-projects/spring-boot/issues/15320 + return clientResponse.bodyToMono(String.class) + .flatMap(body -> ServerResponse.status(clientResponse.statusCode()) + .headers(headers -> clientResponse.headers().asHttpHeaders().forEach(headers::addAll)) + .body(Mono.just(body), String.class)); } private Mono processRequest(final ServerRequest request, final String keyIdParam) { diff --git a/src/main/java/org/prebid/cache/handlers/cache/PostCacheHandler.java b/src/main/java/org/prebid/cache/handlers/cache/PostCacheHandler.java index 83f09aaf..e027b143 100644 --- a/src/main/java/org/prebid/cache/handlers/cache/PostCacheHandler.java +++ b/src/main/java/org/prebid/cache/handlers/cache/PostCacheHandler.java @@ -4,6 +4,7 @@ import com.google.common.collect.ImmutableMap; import io.github.resilience4j.circuitbreaker.CircuitBreaker; import io.github.resilience4j.reactor.circuitbreaker.operator.CircuitBreakerOperator; +import io.netty.channel.ChannelOption; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.exception.ExceptionUtils; @@ -30,6 +31,7 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; +import org.springframework.http.client.reactive.ReactorClientHttpConnector; import org.springframework.stereotype.Component; import org.springframework.web.reactive.function.BodyExtractors; import org.springframework.web.reactive.function.client.WebClient; @@ -37,11 +39,13 @@ import org.springframework.web.reactive.function.server.ServerResponse; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.core.publisher.ParallelFlux; import reactor.core.publisher.SynchronousSink; import reactor.core.scheduler.Schedulers; +import reactor.netty.http.client.HttpClient; import java.io.IOException; -import java.util.ArrayList; +import java.time.Duration; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -80,7 +84,15 @@ public PostCacheHandler(final ReactiveRepository reposit this.repository = repository; this.config = config; if (config.getSecondaryUris() != null) { - config.getSecondaryUris().forEach(ip -> webClients.put(ip, WebClient.create(ip))); + config.getSecondaryUris().forEach(url -> { + HttpClient httpClient = HttpClient.create() + .responseTimeout(Duration.ofMillis(config.getSecondaryCacheTimeoutMs())) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, config.getSecondaryCacheTimeoutMs()); + webClients.put(url, WebClient.builder() + .baseUrl(url) + .clientConnector(new ReactorClientHttpConnector(httpClient)) + .build()); + }); } this.builder = builder; this.metricTagPrefix = "write"; @@ -202,33 +214,46 @@ private long adjustExpiry(Long expiry) { } private void sendRequestToSecondaryPrebidCacheHosts(List payloadWrappers, String secondaryCache) { - if (!"yes".equals(secondaryCache) && webClients.size() != 0) { - final List payloadTransfers = new ArrayList<>(); - for (PayloadWrapper payloadWrapper : payloadWrappers) { - payloadTransfers.add(wrapperToTransfer(payloadWrapper)); - } + if (!"yes".equals(secondaryCache) && !webClients.isEmpty()) { + Flux.fromIterable(payloadWrappers) + .map(this::wrapperToTransfer) + .collectList() + .flatMapMany(this::createSecondaryCacheRequests) + .subscribe(); + } + } + + private ParallelFlux createSecondaryCacheRequests(List payloadTransfers) { + return Flux.fromIterable(webClients.entrySet()) + .parallel() + .runOn(Schedulers.parallel()) + .flatMap(entry -> sendRequestToSecondaryCache(entry.getValue(), entry.getKey(), payloadTransfers)); + } - webClients.forEach((ip, webClient) -> webClient.post() - .uri(uriBuilder -> uriBuilder.path(config.getSecondaryCachePath()) - .queryParam("secondaryCache", "yes").build()) - .contentType(MediaType.APPLICATION_JSON) - .headers(enrichWithSecurityHeader()) - .bodyValue(RequestObject.of(payloadTransfers)) - .exchange() - .transform(CircuitBreakerOperator.of(circuitBreaker)) - .doOnError(throwable -> { + private Mono sendRequestToSecondaryCache(WebClient webClient, + String url, + List payloadTransfers) { + return webClient.post() + .uri(uriBuilder -> uriBuilder.path(config.getSecondaryCachePath()) + .queryParam("secondaryCache", "yes").build()) + .contentType(MediaType.APPLICATION_JSON) + .headers(enrichWithSecurityHeader()) + .bodyValue(RequestObject.of(payloadTransfers)) + .exchangeToMono(clientResponse -> { + if (clientResponse.statusCode() != HttpStatus.OK) { metricsRecorder.getSecondaryCacheWriteError().increment(); - log.info("Failed to send request: '{}', cause: '{}'", - ExceptionUtils.getMessage(throwable), ExceptionUtils.getMessage(throwable)); - }) - .subscribe(clientResponse -> { - if (clientResponse.statusCode() != HttpStatus.OK) { - metricsRecorder.getSecondaryCacheWriteError().increment(); - log.debug(clientResponse.statusCode().toString()); - log.info("Failed to write to remote address : {}", ip); - } - })); - } + log.debug(clientResponse.statusCode().toString()); + log.error("Failed to write to remote address: {}", url); + } + return clientResponse.releaseBody(); + }) + .transform(CircuitBreakerOperator.of(circuitBreaker)) + .doOnError(throwable -> { + metricsRecorder.getSecondaryCacheWriteError().increment(); + log.error("Failed to send request: '{}', cause: '{}'", + ExceptionUtils.getMessage(throwable), ExceptionUtils.getMessage(throwable)); + }) + .then(); } private Consumer enrichWithSecurityHeader() { diff --git a/src/main/java/org/prebid/cache/repository/CacheConfig.java b/src/main/java/org/prebid/cache/repository/CacheConfig.java index 7282c3e7..5bd16f2d 100644 --- a/src/main/java/org/prebid/cache/repository/CacheConfig.java +++ b/src/main/java/org/prebid/cache/repository/CacheConfig.java @@ -22,9 +22,9 @@ public class CacheConfig { private boolean allowExternalUUID; private List secondaryUris; private String secondaryCachePath; + private int secondaryCacheTimeoutMs; private int clientsCacheDuration; private int clientsCacheSize; private String allowedProxyHost; private String hostParamProtocol; } - diff --git a/src/test/java/org/prebid/cache/handlers/PostCacheHandlerTests.java b/src/test/java/org/prebid/cache/handlers/PostCacheHandlerTests.java index b32f292c..0503cb7e 100644 --- a/src/test/java/org/prebid/cache/handlers/PostCacheHandlerTests.java +++ b/src/test/java/org/prebid/cache/handlers/PostCacheHandlerTests.java @@ -51,13 +51,13 @@ @ExtendWith(SpringExtension.class) @ContextConfiguration(classes = { - PostCacheHandler.class, - PrebidServerResponseBuilder.class, - CacheConfig.class, - MetricsRecorderTest.class, - MetricsRecorder.class, - ApiConfig.class, - CircuitBreakerPropertyConfiguration.class + PostCacheHandler.class, + PrebidServerResponseBuilder.class, + CacheConfig.class, + MetricsRecorderTest.class, + MetricsRecorder.class, + ApiConfig.class, + CircuitBreakerPropertyConfiguration.class }) @EnableConfigurationProperties @SpringBootTest @@ -176,10 +176,19 @@ void testSecondaryCacheSuccess() { @Test void testExternalUUIDInvalid() { //given - final var cacheConfigLocal = new CacheConfig(cacheConfig.getPrefix(), cacheConfig.getExpirySec(), + final var cacheConfigLocal = new CacheConfig(cacheConfig.getPrefix(), + cacheConfig.getExpirySec(), cacheConfig.getTimeoutMs(), - cacheConfig.getMinExpiry(), cacheConfig.getMaxExpiry(), - false, Collections.emptyList(), cacheConfig.getSecondaryCachePath(), 100, 100, "example.com", "http"); + cacheConfig.getMinExpiry(), + cacheConfig.getMaxExpiry(), + false, + Collections.emptyList(), + cacheConfig.getSecondaryCachePath(), + 100, + 100, + 100, + "example.com", + "http"); final var handler = new PostCacheHandler(repository, cacheConfigLocal, metricsRecorder, builder, webClientCircuitBreaker, samplingRate, apiConfig); @@ -207,10 +216,19 @@ void testUUIDDuplication() { .willReturn(Mono.just(PAYLOAD_WRAPPER)) .willReturn(Mono.error(new DuplicateKeyException(""))); - final CacheConfig cacheConfigLocal = new CacheConfig(cacheConfig.getPrefix(), cacheConfig.getExpirySec(), + final CacheConfig cacheConfigLocal = new CacheConfig(cacheConfig.getPrefix(), + cacheConfig.getExpirySec(), cacheConfig.getTimeoutMs(), - 5, cacheConfig.getMaxExpiry(), cacheConfig.isAllowExternalUUID(), - Collections.emptyList(), cacheConfig.getSecondaryCachePath(), 100, 100, "example.com", "http"); + 5, + cacheConfig.getMaxExpiry(), + cacheConfig.isAllowExternalUUID(), + Collections.emptyList(), + cacheConfig.getSecondaryCachePath(), + 100, + 100, + 100, + "example.com", + "http"); final PostCacheHandler handler = new PostCacheHandler(repository, cacheConfigLocal, metricsRecorder, builder, webClientCircuitBreaker, samplingRate, apiConfig); diff --git a/src/test/kotlin/org/prebid/cache/functional/testcontainers/client/WebCacheContainerClient.kt b/src/test/kotlin/org/prebid/cache/functional/testcontainers/client/WebCacheContainerClient.kt index e33e1a6d..fba295e9 100644 --- a/src/test/kotlin/org/prebid/cache/functional/testcontainers/client/WebCacheContainerClient.kt +++ b/src/test/kotlin/org/prebid/cache/functional/testcontainers/client/WebCacheContainerClient.kt @@ -40,8 +40,11 @@ class WebCacheContainerClient(mockServerHost: String, mockServerPort: Int) { .withBody(body, mediaType) ) - fun getSecondaryCacheRecordedRequests(uuidKey: String): Array? = - mockServerClient.retrieveRecordedRequests(getSecondaryCacheRequest(uuidKey)) + fun getSecondaryCacheRecordedRequests(uuidKey: String): Array? { + val secondaryCacheRequest = getSecondaryCacheRequest(uuidKey) + waitUntil({ mockServerClient.retrieveRecordedRequests(secondaryCacheRequest)!!.isNotEmpty() }) + return mockServerClient.retrieveRecordedRequests(secondaryCacheRequest) + } fun initSecondaryCacheResponse(): Array? = mockServerClient.`when`(getSecondaryCacheRequest()) @@ -59,4 +62,15 @@ class WebCacheContainerClient(mockServerHost: String, mockServerPort: Int) { request().withMethod(POST.name()) .withPath("/$WEB_CACHE_PATH") .withBody(jsonPath("\$.puts[?(@.key == '$uuidKey')]")) + + private fun waitUntil(closure: () -> Boolean, timeoutMs: Long = 5000, pollInterval: Long = 100) { + val startTime = System.currentTimeMillis() + while (System.currentTimeMillis() - startTime <= timeoutMs) { + if (closure()) { + return + } + Thread.sleep(pollInterval) + } + throw IllegalStateException("Condition was not fulfilled within $timeoutMs ms.") + } }