diff --git a/.changes/next-release/bugfix-AWSSDKforJavav2-f9f830e.json b/.changes/next-release/bugfix-AWSSDKforJavav2-f9f830e.json new file mode 100644 index 000000000000..eab06274e421 --- /dev/null +++ b/.changes/next-release/bugfix-AWSSDKforJavav2-f9f830e.json @@ -0,0 +1,6 @@ +{ + "type": "bugfix", + "category": "AWS SDK for Java v2", + "contributor": "", + "description": "ApplyUserAgentStage will not overwrite the custom User-Agent" +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/ApplyUserAgentStage.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/ApplyUserAgentStage.java index 583634c7288c..692973f528ec 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/ApplyUserAgentStage.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/ApplyUserAgentStage.java @@ -25,6 +25,7 @@ import java.util.Collection; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Optional; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.ApiName; @@ -66,10 +67,38 @@ public ApplyUserAgentStage(HttpClientDependencies dependencies) { @Override public SdkHttpFullRequest.Builder execute(SdkHttpFullRequest.Builder request, RequestExecutionContext context) throws Exception { + + if (hasUserAgentInAdditionalHeaders() || hasUserAgentInRequestConfig(context)) { + return request; + } String headerValue = finalizeUserAgent(context); return request.putHeader(HEADER_USER_AGENT, headerValue); } + /** + * Checks if User-Agent header is present in ADDITIONAL_HTTP_HEADERS configuration. + * We skip adding user-agent in the ApplyUserAgentStage if user has set "User-Agent" header in additional header of client + */ + private boolean hasUserAgentInAdditionalHeaders() { + Map> additionalHeaders = clientConfig.option(SdkClientOption.ADDITIONAL_HTTP_HEADERS); + if (additionalHeaders == null) { + return false; + } + return additionalHeaders.containsKey(HEADER_USER_AGENT); + } + + /** + * Checks if User-Agent header is present in request override configs. + * We skip adding user-agent in the ApplyUserAgentStage if user has set "User-Agent" header at request level + */ + private boolean hasUserAgentInRequestConfig(RequestExecutionContext context) { + Map> requestHeaders = context.requestConfig().headers(); + if (requestHeaders == null) { + return false; + } + return requestHeaders.containsKey(HEADER_USER_AGENT); + } + /** * The final value sent in the user agent header consists of *
    diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/ApplyUserAgentStageTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/ApplyUserAgentStageTest.java index 4db0103b7e3c..cb7b1328322d 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/ApplyUserAgentStageTest.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/ApplyUserAgentStageTest.java @@ -25,7 +25,9 @@ import static software.amazon.awssdk.core.internal.useragent.UserAgentConstant.SPACE; import java.util.Arrays; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; import java.util.concurrent.CompletableFuture; import org.junit.Test; import org.junit.runner.RunWith; @@ -152,6 +154,94 @@ public void when_identityContainsProvider_authSourceIsPresent() throws Exception assertThat(userAgentHeaders.get(0)).contains("m/w"); } + @Test + public void when_userAgentHeaderAlreadyPresent_AndSdkOptionAdditionalHeaderNotPresent_doesNotOverwrite() throws Exception { + ApplyUserAgentStage stage = new ApplyUserAgentStage(dependencies(clientUserAgent())); + + String existingUserAgent = "CustomUserAgent/1.0"; + SdkHttpFullRequest.Builder requestWithExistingHeader = SdkHttpFullRequest.builder() + .putHeader(HEADER_USER_AGENT, existingUserAgent); + + RequestExecutionContext ctx = requestExecutionContext(executionAttributes(IDENTITY_WITHOUT_SOURCE), noOpRequest()); + SdkHttpFullRequest.Builder result = stage.execute(requestWithExistingHeader, ctx); + + List userAgentHeaders = result.headers().get(HEADER_USER_AGENT); + assertThat(userAgentHeaders).isNotNull().hasSize(1); + assertThat(userAgentHeaders.get(0)).startsWith("aws-sdk-java"); + } + + @Test + public void when_userAgentHeaderPresentButEmpty_sdkAddsUserAgent() throws Exception { + ApplyUserAgentStage stage = new ApplyUserAgentStage(dependencies(clientUserAgent())); + + SdkHttpFullRequest.Builder requestWithEmptyHeader = SdkHttpFullRequest.builder() + .putHeader(HEADER_USER_AGENT, ""); + + RequestExecutionContext ctx = requestExecutionContext(executionAttributes(IDENTITY_WITHOUT_SOURCE), noOpRequest()); + SdkHttpFullRequest.Builder result = stage.execute(requestWithEmptyHeader, ctx); + + List userAgentHeaders = result.headers().get(HEADER_USER_AGENT); + assertThat(userAgentHeaders).isNotNull().hasSize(1); + assertThat(userAgentHeaders.get(0)).startsWith("aws-sdk-java"); + } + + @Test + public void when_userAgentHeaderPresentButNull_sdkAddsHeader() throws Exception { + ApplyUserAgentStage stage = new ApplyUserAgentStage(dependencies(clientUserAgent())); + String headerValue = null; + SdkHttpFullRequest.Builder requestWithNullHeader = SdkHttpFullRequest.builder() + .putHeader(HEADER_USER_AGENT, headerValue); + + RequestExecutionContext ctx = requestExecutionContext(executionAttributes(IDENTITY_WITHOUT_SOURCE), noOpRequest()); + SdkHttpFullRequest.Builder result = stage.execute(requestWithNullHeader, ctx); + + List userAgentHeaders = result.headers().get(HEADER_USER_AGENT); + assertThat(userAgentHeaders).isNotNull().hasSize(1); + assertThat(userAgentHeaders.get(0)).startsWith("aws-sdk-java"); + } + + @Test + public void when_userAgentHeaderAbsent_sdkAddsHeader() throws Exception { + ApplyUserAgentStage stage = new ApplyUserAgentStage(dependencies(clientUserAgent())); + + SdkHttpFullRequest.Builder requestWithoutHeader = SdkHttpFullRequest.builder(); + + RequestExecutionContext ctx = requestExecutionContext(executionAttributes(IDENTITY_WITHOUT_SOURCE), noOpRequest()); + SdkHttpFullRequest.Builder result = stage.execute(requestWithoutHeader, ctx); + + List userAgentHeaders = result.headers().get(HEADER_USER_AGENT); + assertThat(userAgentHeaders).isNotNull().hasSize(1); + assertThat(userAgentHeaders.get(0)).startsWith("aws-sdk-java"); + } + + @Test + public void when_userAgentInAdditionalHeaders_doesNotOverwriteUserAgent() throws Exception { + Map> headerMap = new LinkedHashMap<>(); + headerMap.put(HEADER_USER_AGENT, Arrays.asList("CustomAgent/1.0", "AnotherAgent/2.0")); + + SdkClientConfiguration clientConfiguration = + SdkClientConfiguration.builder() + .option(SdkClientOption.CLIENT_USER_AGENT, clientUserAgent()) + .option(SdkClientOption.ADDITIONAL_HTTP_HEADERS, headerMap) + .build(); + HttpClientDependencies httpClientDependencies = HttpClientDependencies.builder() + .clientConfiguration(clientConfiguration) + .build(); + + ApplyUserAgentStage stage = new ApplyUserAgentStage(httpClientDependencies); + + SdkHttpFullRequest.Builder request = SdkHttpFullRequest.builder(); + + RequestExecutionContext ctx = requestExecutionContext(executionAttributes(IDENTITY_WITHOUT_SOURCE), noOpRequest()); + SdkHttpFullRequest.Builder result = stage.execute(request, ctx); + + // ApplyUserAgentStage should skip adding User-Agent since it's in ADDITIONAL_HTTP_HEADERS + // The actual merging happens in MergeCustomHeadersStage + List userAgentHeaders = result.headers().get(HEADER_USER_AGENT); + assertThat(userAgentHeaders).isNull(); + } + + private static HttpClientDependencies dependencies(String clientUserAgent) { return dependencies(clientUserAgent, null, null); } @@ -219,6 +309,5 @@ private RequestExecutionContext requestExecutionContext(ExecutionAttributes exec return RequestExecutionContext.builder() .executionContext(executionContext) .originalRequest(request).build(); - } } diff --git a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/useragent/CustomUserAgentHeaderTest.java b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/useragent/CustomUserAgentHeaderTest.java new file mode 100644 index 000000000000..4711b9906ce3 --- /dev/null +++ b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/useragent/CustomUserAgentHeaderTest.java @@ -0,0 +1,350 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.useragent; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.core.interceptor.Context; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; +import software.amazon.awssdk.http.SdkHttpRequest; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.restjsonendpointproviders.RestJsonEndpointProvidersClient; +import software.amazon.awssdk.services.restjsonendpointproviders.RestJsonEndpointProvidersClientBuilder; + +/** + * Functional tests verifying custom User-Agent header preservation. + * + *

    Tests ensure that User-Agent headers provided via + * {@link software.amazon.awssdk.core.client.config.ClientOverrideConfiguration.Builder#putHeader(String, String)} are preserved + * and not overwritten by SDK's default User-Agent generation logic. + */ +class CustomUserAgentHeaderTest { + + private static final String USER_AGENT_HEADER = "User-Agent"; + private static final String SDK_USER_AGENT_PREFIX = "aws-sdk-java"; + private static final String TEST_API_NAME = "TestApiName"; + private static final String TEST_API_VERSION = "1.0"; + private static final String INTERCEPTOR_STOP_MESSAGE = "stop"; + + private CapturingInterceptor interceptor; + + private static Stream customUserAgentValues() { + return Stream.of( + Arguments.of("CustomUserAgentHeaderValue"), + Arguments.of("MyApplication/1.0.0"), + Arguments.of("CustomClient/2.0 (Linux; x86_64)") + ); + } + + private static Stream customUserAgentListValues() { + return Stream.of( + Arguments.of(Arrays.asList("Agent1")), + Arguments.of(Arrays.asList("Agent1", "Agent2")), + Arguments.of(Arrays.asList("CustomClient/1.0", "MyApp/2.0")) + ); + } + + @BeforeEach + void setUp() { + interceptor = new CapturingInterceptor(); + } + + // ========== Default Behavior Tests ========== + + @Test + void executeRequest_withoutCustomUserAgent_shouldAddSdkDefaultUserAgent() { + RestJsonEndpointProvidersClient client = defaultClientBuilder().build(); + executeRequestExpectingInterception(client); + + assertUserAgentContains(SDK_USER_AGENT_PREFIX); + } + + // ========== Custom User-Agent Preservation Tests ========== + + @ParameterizedTest(name = "Custom User-Agent ''{0}'' should be preserved without SDK prefix") + @MethodSource("customUserAgentValues") + void executeRequest_withCustomUserAgent_shouldPreserveAndNotOverwrite(String customUserAgent) { + RestJsonEndpointProvidersClient client = clientWithCustomUserAgent(customUserAgent); + executeRequestExpectingInterception(client); + + String userAgent = getCapturedUserAgent(); + assertThat(userAgent) + .isEqualTo(customUserAgent) + .doesNotContain(SDK_USER_AGENT_PREFIX); + } + + @ParameterizedTest(name = "Custom User-Agent list {0} should be preserved") + @MethodSource("customUserAgentListValues") + void executeRequest_withCustomUserAgentList_shouldPreserveAllValues(List customUserAgentList) { + RestJsonEndpointProvidersClient client = clientWithCustomUserAgentList(customUserAgentList); + executeRequestExpectingInterception(client); + + List userAgentList = getCapturedUserAgentList(); + assertThat(userAgentList).isEqualTo(customUserAgentList); + } + + // ========== Request-Level User-Agent Tests ========== + + @ParameterizedTest(name = "Request-level User-Agent ''{0}'' should be preserved") + @MethodSource("customUserAgentValues") + void executeRequest_withRequestLevelCustomUserAgent_shouldPreserveAndNotOverwrite(String customUserAgent) { + RestJsonEndpointProvidersClient client = defaultClientBuilder().build(); + + assertThatThrownBy(() -> client.allTypes(r -> r + .overrideConfiguration(o -> o.putHeader(USER_AGENT_HEADER, customUserAgent)))) + .hasMessageContaining(INTERCEPTOR_STOP_MESSAGE); + + String userAgent = getCapturedUserAgent(); + assertThat(userAgent) + .isEqualTo(customUserAgent) + .doesNotContain(SDK_USER_AGENT_PREFIX); + } + + @ParameterizedTest(name = "Request-level User-Agent list {0} should be preserved") + @MethodSource("customUserAgentListValues") + void executeRequest_withRequestLevelCustomUserAgentList_shouldPreserveAllValues(List customUserAgentList) { + RestJsonEndpointProvidersClient client = defaultClientBuilder().build(); + + assertThatThrownBy(() -> client.allTypes(r -> r + .overrideConfiguration(o -> o.putHeader(USER_AGENT_HEADER, customUserAgentList)))) + .hasMessageContaining(INTERCEPTOR_STOP_MESSAGE); + + List userAgentList = getCapturedUserAgentList(); + assertThat(userAgentList).isEqualTo(customUserAgentList); + } + + @Test + void executeRequest_withRequestLevelCustomUserAgentAndApiName_shouldNotAppendApiName() { + String customUserAgent = "CustomUserAgentHeaderValue"; + RestJsonEndpointProvidersClient client = defaultClientBuilder().build(); + + assertThatThrownBy(() -> client.allTypes(r -> r + .overrideConfiguration(o -> o + .addApiName(api -> api.name(TEST_API_NAME).version(TEST_API_VERSION)) + .putHeader(USER_AGENT_HEADER, customUserAgent)))) + .hasMessageContaining(INTERCEPTOR_STOP_MESSAGE); + + String userAgent = getCapturedUserAgent(); + assertThat(userAgent) + .isEqualTo(customUserAgent) + .doesNotContain(TEST_API_NAME); + } + + @ParameterizedTest(name = "Request-level User-Agent list {0} with API name should not append API name") + @MethodSource("customUserAgentListValues") + void executeRequest_withRequestLevelCustomUserAgentListAndApiName_shouldNotAppendApiName(List customUserAgentList) { + RestJsonEndpointProvidersClient client = defaultClientBuilder().build(); + + assertThatThrownBy(() -> client.allTypes(r -> r + .overrideConfiguration(o -> o + .addApiName(api -> api.name(TEST_API_NAME).version(TEST_API_VERSION)) + .putHeader(USER_AGENT_HEADER, customUserAgentList)))) + .hasMessageContaining(INTERCEPTOR_STOP_MESSAGE); + + List userAgentList = getCapturedUserAgentList(); + assertThat(userAgentList).isEqualTo(customUserAgentList); + assertThat(String.join(" ", userAgentList)).doesNotContain(TEST_API_NAME); + } + + // ========== Header via Interceptors ========== + + @Test + void executeRequest_withInterceptorAddingUserAgent_shouldAddSdkDefaultUserAgent() { + RestJsonEndpointProvidersClient client = + defaultClientBuilder().overrideConfiguration(o -> o + .addExecutionInterceptor(interceptor) + .addExecutionInterceptor(new ExecutionInterceptor() { + @Override + public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, + ExecutionAttributes executionAttributes) { + return context.httpRequest().toBuilder() + .putHeader(USER_AGENT_HEADER, "custom-agent") + .build(); + } + })).build(); + + executeRequestExpectingInterception(client); + assertUserAgentContains(SDK_USER_AGENT_PREFIX); + } + + // ========== API Name Handling Tests ========== + + @Test + void executeRequest_withCustomUserAgentAndApiName_shouldNotAppendApiName() { + String customUserAgent = "CustomUserAgentHeaderValue"; + RestJsonEndpointProvidersClient client = clientWithCustomUserAgent(customUserAgent); + executeRequestWithApiName(client); + + String userAgent = getCapturedUserAgent(); + assertThat(userAgent) + .isEqualTo(customUserAgent) + .doesNotContain(TEST_API_NAME); + } + + @Test + void executeRequest_withoutCustomUserAgentAndWithApiName_shouldAppendApiName() { + RestJsonEndpointProvidersClient client = defaultClientBuilder().build(); + executeRequestWithApiName(client); + + assertUserAgentContains(TEST_API_NAME + "/" + TEST_API_VERSION); + } + + // ========== Edge Case Tests ========== + + @Test + void buildClient_withNullListUserAgent_shouldThrowNullPointerException() { + assertThatThrownBy(() -> clientWithCustomUserAgentList(null)) + .isInstanceOf(NullPointerException.class) + .hasMessageContaining("values must not be null"); + } + + @Test + void executeRequest_withEmptyListUserAgent_shouldResultInSdkUserAgentHeader() { + RestJsonEndpointProvidersClient client = clientWithCustomUserAgentList(Collections.emptyList()); + executeRequestExpectingInterception(client); + + List userAgentList = getCapturedUserAgentList(); + assertThat(userAgentList).isNull(); + } + + @Test + void executeRequest_withEmptyCustomUserAgent_shouldStoreSdkUserAgent() { + RestJsonEndpointProvidersClient client = clientWithCustomUserAgent(""); + executeRequestExpectingInterception(client); + + assertUserAgentContains(""); + } + + @Test + void executeRequest_withNullStringUserAgent_shouldStoreAsSdkUserAgent() { + RestJsonEndpointProvidersClient client = clientWithCustomUserAgent(null); + executeRequestExpectingInterception(client); + + List userAgentList = getCapturedUserAgentList(); + assertThat(userAgentList) + .hasSize(1) + .allSatisfy(ua -> { + assertThat(ua).isNull(); + }); + } + + @Test + void executeRequest_withRequestLevelEmptyCustomUserAgent_shouldStoreEmptyUserAgent() { + RestJsonEndpointProvidersClient client = defaultClientBuilder().build(); + + assertThatThrownBy(() -> client.allTypes(r -> r + .overrideConfiguration(o -> o.putHeader(USER_AGENT_HEADER, "")))) + .hasMessageContaining(INTERCEPTOR_STOP_MESSAGE); + + assertUserAgentContains(""); + } + + @Test + void executeRequest_withRequestLevelEmptyListUserAgent_shouldResultInNoUserAgent() { + RestJsonEndpointProvidersClient client = defaultClientBuilder().build(); + + assertThatThrownBy(() -> client.allTypes(r -> r + .overrideConfiguration(o -> o.putHeader(USER_AGENT_HEADER, Collections.emptyList())))) + .hasMessageContaining(INTERCEPTOR_STOP_MESSAGE); + + List userAgentList = getCapturedUserAgentList(); + assertThat(userAgentList).isNull(); + } + + // ========== Helper Methods ========== + + private void assertUserAgentContains(String expected) { + assertThat(getCapturedUserAgent()).contains(expected); + } + + private void executeRequestExpectingInterception(RestJsonEndpointProvidersClient client) { + assertThatThrownBy(() -> client.allTypes(r -> {})) + .hasMessageContaining(INTERCEPTOR_STOP_MESSAGE); + } + + private void executeRequestWithApiName(RestJsonEndpointProvidersClient client) { + assertThatThrownBy(() -> client.allTypes(r -> r + .overrideConfiguration(o -> o.addApiName(api -> api + .name(TEST_API_NAME) + .version(TEST_API_VERSION))))) + .hasMessageContaining(INTERCEPTOR_STOP_MESSAGE); + } + + private String getCapturedUserAgent() { + Map> headers = interceptor.context.httpRequest().headers(); + assertThat(headers).containsKey(USER_AGENT_HEADER); + return headers.get(USER_AGENT_HEADER).get(0); + } + + private List getCapturedUserAgentList() { + Map> headers = interceptor.context.httpRequest().headers(); + return headers.get(USER_AGENT_HEADER); + } + + private RestJsonEndpointProvidersClientBuilder defaultClientBuilder() { + return RestJsonEndpointProvidersClient.builder() + .region(Region.US_WEST_2) + .credentialsProvider(StaticCredentialsProvider.create( + AwsBasicCredentials.create("akid", "skid"))) + .overrideConfiguration(c -> c.addExecutionInterceptor(interceptor)); + } + + private RestJsonEndpointProvidersClient clientWithCustomUserAgent(String customUserAgent) { + return RestJsonEndpointProvidersClient.builder() + .region(Region.US_WEST_2) + .credentialsProvider(StaticCredentialsProvider.create( + AwsBasicCredentials.create("akid", "skid"))) + .overrideConfiguration(c -> c + .addExecutionInterceptor(interceptor) + .putHeader(USER_AGENT_HEADER, customUserAgent)) + .build(); + } + + private RestJsonEndpointProvidersClient clientWithCustomUserAgentList(List customUserAgentList) { + return RestJsonEndpointProvidersClient.builder() + .region(Region.US_WEST_2) + .credentialsProvider(StaticCredentialsProvider.create( + AwsBasicCredentials.create("akid", "skid"))) + .overrideConfiguration(c -> c + .addExecutionInterceptor(interceptor) + .putHeader(USER_AGENT_HEADER, customUserAgentList)) + .build(); + } + + private static class CapturingInterceptor implements ExecutionInterceptor { + private Context.BeforeTransmission context; + + @Override + public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttributes executionAttributes) { + this.context = context; + throw new RuntimeException(INTERCEPTOR_STOP_MESSAGE); + } + } +}