Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,9 @@ public interface Authenticator {
* @return a {@link CompletableFuture} that will complete with the access token
*/
CompletableFuture<String> asyncToken();

/**
* Returns the authentication scheme to be used in the Authorization header.
*/
String scheme();
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
* }</pre>
*/
public class CP4DAuthenticator implements Authenticator {
private static final String BEARER_SCHEME = "Bearer";
private static final String ZEN_API_SCHEME = "ZenApiKey";

private final URI baseUrl;
private final String username;
private final String password;
Expand Down Expand Up @@ -98,6 +101,11 @@ public CompletableFuture<String> asyncToken() {
});
}

@Override
public String scheme() {
return authMode == AuthMode.ZEN_API_KEY ? ZEN_API_SCHEME : BEARER_SCHEME;
}

/**
* Checks if the current authentication mode matches the provided {@link AuthMode}.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
* }</pre>
*/
public class IBMCloudAuthenticator implements Authenticator {
private static final String SCHEME = "Bearer";

private final URI baseUrl;
private final String apiKey;
private final String grantType;
Expand Down Expand Up @@ -83,6 +85,11 @@ public CompletableFuture<String> asyncToken() {
});
}

@Override
public String scheme() {
return SCHEME;
}

/**
* Returns a new {@link Builder} instance.
* <p>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import com.ibm.watsonx.ai.core.auth.Authenticator;
import com.ibm.watsonx.ai.core.http.AsyncHttpClient;
import com.ibm.watsonx.ai.core.http.SyncHttpClient;
import com.ibm.watsonx.ai.core.http.interceptors.BearerInterceptor;
import com.ibm.watsonx.ai.core.http.interceptors.AuthenticationInterceptor;
import com.ibm.watsonx.ai.core.http.interceptors.LoggerInterceptor;
import com.ibm.watsonx.ai.core.http.interceptors.LoggerInterceptor.LogMode;
import com.ibm.watsonx.ai.core.http.interceptors.RetryInterceptor;
Expand All @@ -22,7 +22,7 @@
* <ul>
* <li>{@link RetryInterceptor#ON_TOKEN_EXPIRED} – retry on expired authentication tokens</li>
* <li>{@link RetryInterceptor#ON_RETRYABLE_STATUS_CODES} – retry on retryable status codes (5xx, etc.)</li>
* <li>{@link BearerInterceptor} – attach an IAM or custom {@link Authenticator}</li>
* <li>{@link AuthenticationInterceptor} – attach an IAM or custom {@link Authenticator}</li>
* <li>{@link LoggerInterceptor} – optional request/response logging</li>
* </ul>
*/
Expand All @@ -45,7 +45,7 @@ public static SyncHttpClient createSync(Authenticator authenticator, HttpClient
builder.interceptor(RetryInterceptor.ON_TOKEN_EXPIRED);

if (nonNull(authenticator)) {
builder.interceptor(new BearerInterceptor(authenticator));
builder.interceptor(new AuthenticationInterceptor(authenticator));
}

builder.interceptor(RetryInterceptor.ON_RETRYABLE_STATUS_CODES);
Expand Down Expand Up @@ -77,7 +77,7 @@ public static AsyncHttpClient createAsync(Authenticator authenticator, HttpClien
builder.interceptor(RetryInterceptor.ON_TOKEN_EXPIRED);

if (nonNull(authenticator)) {
builder.interceptor(new BearerInterceptor(authenticator));
builder.interceptor(new AuthenticationInterceptor(authenticator));
}

builder.interceptor(RetryInterceptor.ON_RETRYABLE_STATUS_CODES);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,46 +10,42 @@
import java.net.http.HttpResponse.BodyHandler;
import java.util.concurrent.CompletableFuture;
import com.ibm.watsonx.ai.core.auth.Authenticator;
import com.ibm.watsonx.ai.core.auth.cp4d.AuthMode;
import com.ibm.watsonx.ai.core.auth.cp4d.CP4DAuthenticator;
import com.ibm.watsonx.ai.core.exception.WatsonxException;
import com.ibm.watsonx.ai.core.http.AsyncHttpInterceptor;
import com.ibm.watsonx.ai.core.http.SyncHttpInterceptor;

/**
* Interceptor that adds a Bearer token to outgoing requests.
* Interceptor that adds an Authorization header to outgoing HTTP requests.
*/
public final class BearerInterceptor implements SyncHttpInterceptor, AsyncHttpInterceptor {
public final class AuthenticationInterceptor implements SyncHttpInterceptor, AsyncHttpInterceptor {

private final Authenticator authenticator;

/**
* Constructs a new BearerInterceptor with the given authenticator.
* Constructs a new AuthenticationInterceptor with the given authenticator.
*
* @param authenticator the authenticator used to retrieve bearer tokens
* @param authenticator the authenticator used to retrieve authentication tokens
*/
public BearerInterceptor(Authenticator authenticator) {
public AuthenticationInterceptor(Authenticator authenticator) {
this.authenticator = authenticator;
}

@Override
public <T> CompletableFuture<HttpResponse<T>> intercept(HttpRequest request, BodyHandler<T> bodyHandler, int index, AsyncChain chain) {
return authenticator.asyncToken()
.thenCompose(token -> chain.proceed(requestWithBearer(request, token), bodyHandler));
.thenCompose(token -> chain.proceed(requestWithAuthHeader(request, token), bodyHandler));
}

@Override
public <T> HttpResponse<T> intercept(HttpRequest request, BodyHandler<T> bodyHandler, int index, Chain chain)
throws WatsonxException, IOException, InterruptedException {
var token = authenticator.token();
return chain.proceed(requestWithBearer(request, token), bodyHandler);
return chain.proceed(requestWithAuthHeader(request, token), bodyHandler);
}

// Creates a copy of the given request with the Authorization header set to use the Bearer token.
private HttpRequest requestWithBearer(HttpRequest request, String token) {
var authorization = authenticator instanceof CP4DAuthenticator auth && auth.isAuthMode(AuthMode.ZEN_API_KEY)
? "ZenApiKey %s".formatted(token)
: "Bearer %s".formatted(token);
// Creates a copy of the given request with the appropriate Authorization header.
private HttpRequest requestWithAuthHeader(HttpRequest request, String token) {
var authorization = "%s %s".formatted(authenticator.scheme(), token);
return HttpRequest.newBuilder(request, (key, value) -> true)
.header("Authorization", authorization)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,19 @@
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.MockedStatic;
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.junit.jupiter.MockitoSettings;
import org.mockito.quality.Strictness;
import com.ibm.watsonx.ai.core.auth.cp4d.AuthMode;
import com.ibm.watsonx.ai.core.auth.cp4d.CP4DAuthenticator;
import com.ibm.watsonx.ai.core.http.AsyncHttpClient;
import com.ibm.watsonx.ai.core.http.AsyncHttpInterceptor;
import com.ibm.watsonx.ai.core.http.SyncHttpClient;
import com.ibm.watsonx.ai.core.http.interceptors.BearerInterceptor;
import com.ibm.watsonx.ai.core.http.interceptors.AuthenticationInterceptor;
import com.ibm.watsonx.ai.core.provider.ExecutorProvider;

@ExtendWith(MockitoExtension.class)
public class BearerInterceptorTest extends AbstractWatsonxTest {
@MockitoSettings(strictness = Strictness.LENIENT)
public class AuthenticationInterceptorTest extends AbstractWatsonxTest {

@Nested
class Sync {
Expand All @@ -46,6 +49,7 @@ class Sync {
void should_send_request_with_bearer_token() throws Exception {

when(mockAuthenticator.token()).thenReturn("my_super_token");
when(mockAuthenticator.scheme()).thenReturn("Bearer");
when(mockHttpResponse.statusCode()).thenReturn(200);

withWatsonxServiceMock(() -> {
Expand All @@ -54,7 +58,7 @@ void should_send_request_with_bearer_token() throws Exception {

var client = SyncHttpClient.builder()
.httpClient(mockSecureHttpClient)
.interceptor(new BearerInterceptor(mockAuthenticator))
.interceptor(new AuthenticationInterceptor(mockAuthenticator))
.build();

try {
Expand All @@ -77,6 +81,7 @@ void should_send_request_with_zen_api_key() throws Exception {
var cp4dAuthenticatorMock = mock(CP4DAuthenticator.class);
when(cp4dAuthenticatorMock.token()).thenReturn("#1234");
when(cp4dAuthenticatorMock.isAuthMode(AuthMode.ZEN_API_KEY)).thenReturn(true);
when(cp4dAuthenticatorMock.scheme()).thenReturn("ZenApiKey");
when(mockHttpResponse.statusCode()).thenReturn(200);

withWatsonxServiceMock(() -> {
Expand All @@ -85,7 +90,7 @@ void should_send_request_with_zen_api_key() throws Exception {

var client = SyncHttpClient.builder()
.httpClient(mockSecureHttpClient)
.interceptor(new BearerInterceptor(cp4dAuthenticatorMock))
.interceptor(new AuthenticationInterceptor(cp4dAuthenticatorMock))
.build();

try {
Expand All @@ -111,7 +116,7 @@ void should_throw_exception_when_bearer_token_is_invalid() throws Exception {

var client = SyncHttpClient.builder()
.httpClient(mockSecureHttpClient)
.interceptor(new BearerInterceptor(mockAuthenticator))
.interceptor(new AuthenticationInterceptor(mockAuthenticator))
.build();

try {
Expand All @@ -136,13 +141,14 @@ class Async {
void should_send_request_with_bearer_token() throws Exception {

when(mockAuthenticator.asyncToken()).thenReturn(completedFuture("my_super_token"));
when(mockAuthenticator.scheme()).thenReturn("Bearer");

withWatsonxServiceMock(() -> {
mockHttpClientAsyncSend(mockHttpRequest.capture(), any());

var client = AsyncHttpClient.builder()
.httpClient(mockSecureHttpClient)
.interceptor(new BearerInterceptor(mockAuthenticator))
.interceptor(new AuthenticationInterceptor(mockAuthenticator))
.build();

var fakeRequest = HttpRequest.newBuilder(URI.create("http://test"))
Expand All @@ -159,6 +165,7 @@ void should_send_request_with_zen_api_key() throws Exception {

var cp4dAuthenticatorMock = mock(CP4DAuthenticator.class);
when(cp4dAuthenticatorMock.asyncToken()).thenReturn(completedFuture("#1234"));
when(cp4dAuthenticatorMock.scheme()).thenReturn("ZenApiKey");
when(cp4dAuthenticatorMock.isAuthMode(AuthMode.ZEN_API_KEY)).thenReturn(true);

withWatsonxServiceMock(() -> {
Expand All @@ -167,7 +174,7 @@ void should_send_request_with_zen_api_key() throws Exception {

var client = AsyncHttpClient.builder()
.httpClient(mockSecureHttpClient)
.interceptor(new BearerInterceptor(cp4dAuthenticatorMock))
.interceptor(new AuthenticationInterceptor(cp4dAuthenticatorMock))
.build();

var fakeRequest = HttpRequest.newBuilder(URI.create("http://test"))
Expand All @@ -188,7 +195,7 @@ void should_throw_exception_when_bearer_token_is_invalid() {

var client = AsyncHttpClient.builder()
.httpClient(mockSecureHttpClient)
.interceptor(new BearerInterceptor(mockAuthenticator))
.interceptor(new AuthenticationInterceptor(mockAuthenticator))
.build();

var fakeRequest = HttpRequest.newBuilder(URI.create("http://test"))
Expand Down Expand Up @@ -222,7 +229,7 @@ void should_execute_with_custom_executor() throws Exception {

var client = AsyncHttpClient.builder()
.httpClient(mockSecureHttpClient)
.interceptor(new BearerInterceptor(mockAuthenticator))
.interceptor(new AuthenticationInterceptor(mockAuthenticator))
.interceptor(new AsyncHttpInterceptor() {
@Override
public <T> CompletableFuture<HttpResponse<T>> intercept(HttpRequest request, BodyHandler<T> bodyHandler, int index,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -720,4 +720,36 @@ void should_use_the_correct_executors() throws Exception {
});
}
}

@Test
void should_return_the_correct_scheme() {

CP4DAuthenticator authenticator = CP4DAuthenticator.builder()
.username("username")
.baseUrl(URI.create("http://my-url"))
.apiKey("api_key")
.authMode(AuthMode.ZEN_API_KEY)
.build();

assertEquals("ZenApiKey", authenticator.scheme());

authenticator = CP4DAuthenticator.builder()
.baseUrl("http://my-url/")
.username("username")
.password("password")
.authMode(AuthMode.IAM)
.build();

assertEquals("Bearer", authenticator.scheme());

authenticator = CP4DAuthenticator.builder()
.username("username")
.baseUrl(URI.create("http://my-url"))
.apiKey("my_super_api_key")
.timeout(Duration.ofSeconds(10))
.authMode(AuthMode.LEGACY)
.build();

assertEquals("Bearer", authenticator.scheme());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -392,4 +392,10 @@ void should_use_the_correct_executors() throws Exception {
});
}
}

@Test
void should_return_the_correct_scheme() {
var authenticator = IBMCloudAuthenticator.withKey("api-key");
assertEquals("Bearer", authenticator.scheme());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,14 @@
import org.mockito.Mock;
import org.mockito.MockedStatic;
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.junit.jupiter.MockitoSettings;
import org.mockito.quality.Strictness;
import com.github.tomakehurst.wiremock.junit5.WireMockExtension;
import com.ibm.watsonx.ai.core.auth.Authenticator;
import com.ibm.watsonx.ai.core.provider.HttpClientProvider;

@ExtendWith(MockitoExtension.class)
@MockitoSettings(strictness = Strictness.LENIENT)
public abstract class AbstractWatsonxTest {

protected static final String ML_API_PATH = "/ml/v1";
Expand Down Expand Up @@ -53,6 +56,7 @@ public abstract class AbstractWatsonxTest {

@BeforeEach
void setUp() {
when(mockAuthenticator.scheme()).thenReturn("Bearer");
resetHttpClient();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ public class BatchServiceTest extends AbstractWatsonxTest {
@BeforeEach
void setUp() {
when(mockAuthenticator.token()).thenReturn("token");
when(mockAuthenticator.scheme()).thenReturn("Bearer");
BASE_URL = "http://localhost:%s".formatted(wireMock.getPort());

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ public class DeploymentServiceTest extends AbstractWatsonxTest {
@BeforeEach
void setup() {
when(mockAuthenticator.token()).thenReturn("token");
when(mockAuthenticator.scheme()).thenReturn("Bearer");
when(mockAuthenticator.asyncToken()).thenReturn(CompletableFuture.completedFuture("token"));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ void should_build_request_with_correct_parameters() throws Exception {
void should_detect_pii_and_hap_entities_request() {

when(mockAuthenticator.token()).thenReturn("token");
when(mockAuthenticator.scheme()).thenReturn("Bearer");

wireMock.stubFor(post("/ml/v1/text/detection?version=%s".formatted(API_VERSION))
.withRequestBody(equalToJson("""
Expand Down Expand Up @@ -269,6 +270,7 @@ void should_catch_watsonx_exception() throws Exception {
}""";

when(mockAuthenticator.token()).thenReturn("token");
when(mockAuthenticator.scheme()).thenReturn("Bearer");

wireMock.stubFor(post("/ml/v1/text/detection?version=%s".formatted(API_VERSION))
.withRequestBody(equalToJson("""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ public class FileServiceTest extends AbstractWatsonxTest {
@BeforeEach
void setUp() {
when(mockAuthenticator.token()).thenReturn("token");
when(mockAuthenticator.scheme()).thenReturn("Bearer");
BASE_URL = "http://localhost:%s".formatted(wireMock.getPort());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,7 @@ void should_stream_text_generation_correctly() throws Exception {


when(mockAuthenticator.asyncToken()).thenReturn(completedFuture("my-super-token"));
when(mockAuthenticator.scheme()).thenReturn("Bearer");

var textGenerationService = TextGenerationService.builder()
.authenticator(mockAuthenticator)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ public class ToolServiceTest extends AbstractWatsonxTest {
@BeforeEach
void setup() {
when(mockAuthenticator.token()).thenReturn("token");
when(mockAuthenticator.scheme()).thenReturn("Bearer");
}

@Test
Expand Down
Loading