From 1df166b8a96dfbd6dc204fce527d6f0647d55298 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 19 Feb 2026 05:06:38 -0800 Subject: [PATCH] refactor: GcsServiceArtifact becomes async Adding a bunch of tests for GcsServiceArtifact PiperOrigin-RevId: 872341848 --- .../adk/artifacts/GcsArtifactService.java | 222 ++++++++++-------- .../adk/artifacts/GcsArtifactServiceTest.java | 145 ++++++++++++ 2 files changed, 270 insertions(+), 97 deletions(-) diff --git a/core/src/main/java/com/google/adk/artifacts/GcsArtifactService.java b/core/src/main/java/com/google/adk/artifacts/GcsArtifactService.java index b9bc49a02..e31d50327 100644 --- a/core/src/main/java/com/google/adk/artifacts/GcsArtifactService.java +++ b/core/src/main/java/com/google/adk/artifacts/GcsArtifactService.java @@ -28,12 +28,12 @@ import com.google.common.base.Splitter; import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Streams; import com.google.genai.types.FileData; import com.google.genai.types.Part; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; -import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Optional; @@ -135,22 +135,25 @@ public Maybe loadArtifact( .flatMapMaybe( versions -> versions.isEmpty() ? Maybe.empty() : Maybe.just(max(versions)))) .flatMap( - versionToLoad -> { - String blobName = getBlobName(appName, userId, sessionId, filename, versionToLoad); - BlobId blobId = BlobId.of(bucketName, blobName); + versionToLoad -> + Maybe.fromCallable( + () -> { + String blobName = + getBlobName(appName, userId, sessionId, filename, versionToLoad); + BlobId blobId = BlobId.of(bucketName, blobName); - try { - Blob blob = storageClient.get(blobId); - if (blob == null || !blob.exists()) { - return Maybe.empty(); - } - byte[] data = blob.getContent(); - String mimeType = blob.getContentType(); - return Maybe.just(Part.fromBytes(data, mimeType)); - } catch (StorageException e) { - return Maybe.empty(); - } - }); + try { + Blob blob = storageClient.get(blobId); + if (blob == null || !blob.exists()) { + return null; + } + byte[] data = blob.getContent(); + String mimeType = blob.getContentType(); + return Part.fromBytes(data, mimeType); + } catch (StorageException e) { + return null; + } + })); } /** @@ -164,34 +167,38 @@ public Maybe loadArtifact( @Override public Single listArtifactKeys( String appName, String userId, String sessionId) { - Set filenames = new HashSet<>(); + return Single.fromCallable( + () -> { + Set filenames = new HashSet<>(); - // List session-specific files - String sessionPrefix = String.format("%s/%s/%s/", appName, userId, sessionId); - try { - for (Blob blob : - storageClient.list(bucketName, BlobListOption.prefix(sessionPrefix)).iterateAll()) { - List parts = Splitter.on('/').splitToList(blob.getName()); - filenames.add(parts.get(3)); // appName/userId/sessionId/filename/version - } - } catch (StorageException e) { - throw new VerifyException("Failed to list session artifacts from GCS", e); - } + // List session-specific files + String sessionPrefix = String.format("%s/%s/%s/", appName, userId, sessionId); + try { + for (Blob blob : + storageClient.list(bucketName, BlobListOption.prefix(sessionPrefix)).iterateAll()) { + List parts = Splitter.on('/').splitToList(blob.getName()); + filenames.add(parts.get(3)); // appName/userId/sessionId/filename/version + } + } catch (StorageException e) { + throw new VerifyException("Failed to list session artifacts from GCS", e); + } - // List user-namespace files - String userPrefix = String.format("%s/%s/user/", appName, userId); - try { - for (Blob blob : - storageClient.list(bucketName, BlobListOption.prefix(userPrefix)).iterateAll()) { - List parts = Splitter.on('/').splitToList(blob.getName()); - filenames.add(parts.get(3)); // appName/userId/user/filename/version - } - } catch (StorageException e) { - throw new VerifyException("Failed to list user artifacts from GCS", e); - } + // List user-namespace files + String userPrefix = String.format("%s/%s/user/", appName, userId); + try { + for (Blob blob : + storageClient.list(bucketName, BlobListOption.prefix(userPrefix)).iterateAll()) { + List parts = Splitter.on('/').splitToList(blob.getName()); + filenames.add(parts.get(3)); // appName/userId/user/filename/version + } + } catch (StorageException e) { + throw new VerifyException("Failed to list user artifacts from GCS", e); + } - return Single.just( - ListArtifactsResponse.builder().filenames(ImmutableList.sortedCopyOf(filenames)).build()); + return ListArtifactsResponse.builder() + .filenames(ImmutableList.sortedCopyOf(filenames)) + .build(); + }); } /** @@ -206,22 +213,30 @@ public Single listArtifactKeys( @Override public Completable deleteArtifact( String appName, String userId, String sessionId, String filename) { - ImmutableList versions = - listVersions(appName, userId, sessionId, filename).blockingGet(); - List blobIdsToDelete = new ArrayList<>(); - for (int version : versions) { - String blobName = getBlobName(appName, userId, sessionId, filename, version); - blobIdsToDelete.add(BlobId.of(bucketName, blobName)); - } + return listVersions(appName, userId, sessionId, filename) + .flatMapCompletable( + versions -> { + if (versions.isEmpty()) { + return Completable.complete(); + } + ImmutableList blobIdsToDelete = + versions.stream() + .map( + version -> + BlobId.of( + bucketName, + getBlobName(appName, userId, sessionId, filename, version))) + .collect(ImmutableList.toImmutableList()); - if (!blobIdsToDelete.isEmpty()) { - try { - var unused = storageClient.delete(blobIdsToDelete); - } catch (StorageException e) { - throw new VerifyException("Failed to delete artifact versions from GCS", e); - } - } - return Completable.complete(); + return Completable.fromAction( + () -> { + try { + var unused = storageClient.delete(blobIdsToDelete); + } catch (StorageException e) { + throw new VerifyException("Failed to delete artifact versions from GCS", e); + } + }); + }); } /** @@ -236,20 +251,29 @@ public Completable deleteArtifact( @Override public Single> listVersions( String appName, String userId, String sessionId, String filename) { - String prefix = getBlobPrefix(appName, userId, sessionId, filename); - List versions = new ArrayList<>(); - try { - for (Blob blob : storageClient.list(bucketName, BlobListOption.prefix(prefix)).iterateAll()) { - String name = blob.getName(); - int versionDelimiterIndex = name.lastIndexOf('/'); // immediately before the version number - if (versionDelimiterIndex != -1 && versionDelimiterIndex < name.length() - 1) { - versions.add(Integer.parseInt(name.substring(versionDelimiterIndex + 1))); - } - } - return Single.just(ImmutableList.sortedCopyOf(versions)); - } catch (StorageException e) { - return Single.just(ImmutableList.of()); - } + return Single.fromCallable( + () -> { + String prefix = getBlobPrefix(appName, userId, sessionId, filename); + try { + return Streams.stream( + storageClient.list(bucketName, BlobListOption.prefix(prefix)).iterateAll()) + .map(Blob::getName) + .map( + name -> { + int versionDelimiterIndex = name.lastIndexOf('/'); + return versionDelimiterIndex != -1 + && versionDelimiterIndex < name.length() - 1 + ? Optional.of(name.substring(versionDelimiterIndex + 1)) + : Optional.empty(); + }) + .flatMap(Optional::stream) + .map(Integer::parseInt) + .sorted() + .collect(ImmutableList.toImmutableList()); + } catch (StorageException e) { + return ImmutableList.of(); + } + }); } @Override @@ -291,35 +315,39 @@ private Single saveArtifactAndReturnBlob( String appName, String userId, String sessionId, String filename, Part artifact) { return listVersions(appName, userId, sessionId, filename) .map(versions -> versions.isEmpty() ? 0 : max(versions) + 1) - .map( - nextVersion -> { - if (artifact.inlineData().isEmpty()) { - throw new IllegalArgumentException("Saveable artifact must have inline data."); - } + .flatMap( + nextVersion -> + Single.fromCallable( + () -> { + if (artifact.inlineData().isEmpty()) { + throw new IllegalArgumentException( + "Saveable artifact must have inline data."); + } - String blobName = getBlobName(appName, userId, sessionId, filename, nextVersion); - BlobId blobId = BlobId.of(bucketName, blobName); + String blobName = + getBlobName(appName, userId, sessionId, filename, nextVersion); + BlobId blobId = BlobId.of(bucketName, blobName); - BlobInfo blobInfo = - BlobInfo.newBuilder(blobId) - .setContentType(artifact.inlineData().get().mimeType().orElse(null)) - .build(); + BlobInfo blobInfo = + BlobInfo.newBuilder(blobId) + .setContentType(artifact.inlineData().get().mimeType().orElse(null)) + .build(); - try { - byte[] dataToSave = - artifact - .inlineData() - .get() - .data() - .orElseThrow( - () -> - new IllegalArgumentException( - "Saveable artifact data must be non-empty.")); - Blob blob = storageClient.create(blobInfo, dataToSave); - return SaveResult.create(blob, nextVersion); - } catch (StorageException e) { - throw new VerifyException("Failed to save artifact to GCS", e); - } - }); + try { + byte[] dataToSave = + artifact + .inlineData() + .get() + .data() + .orElseThrow( + () -> + new IllegalArgumentException( + "Saveable artifact data must be non-empty.")); + Blob blob = storageClient.create(blobInfo, dataToSave); + return SaveResult.create(blob, nextVersion); + } catch (StorageException e) { + throw new VerifyException("Failed to save artifact to GCS", e); + } + })); } } diff --git a/core/src/test/java/com/google/adk/artifacts/GcsArtifactServiceTest.java b/core/src/test/java/com/google/adk/artifacts/GcsArtifactServiceTest.java index 40493bf3a..88abd60c4 100644 --- a/core/src/test/java/com/google/adk/artifacts/GcsArtifactServiceTest.java +++ b/core/src/test/java/com/google/adk/artifacts/GcsArtifactServiceTest.java @@ -16,6 +16,7 @@ package com.google.adk.artifacts; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; @@ -28,6 +29,8 @@ import com.google.cloud.storage.BlobInfo; import com.google.cloud.storage.Storage; import com.google.cloud.storage.Storage.BlobListOption; +import com.google.cloud.storage.StorageException; +import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableList; import com.google.genai.types.Part; import io.reactivex.rxjava3.core.Maybe; @@ -41,6 +44,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; import org.mockito.Captor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; @@ -233,8 +237,10 @@ public void list_noFiles_returnsEmpty() { String sessionPrefix = String.format("%s/%s/%s/", APP_NAME, USER_ID, SESSION_ID); String userPrefix = String.format("%s/%s/user/", APP_NAME, USER_ID); + // Mocking generic Page class requires unchecked suppression. @SuppressWarnings("unchecked") Page mockSessionPage = mock(Page.class); + // Mocking generic Page class requires unchecked suppression. @SuppressWarnings("unchecked") Page mockUserPage = mock(Page.class); when(mockStorage.list(BUCKET_NAME, BlobListOption.prefix(sessionPrefix))) @@ -262,8 +268,10 @@ public void list_withFiles_returnsCorrectFilenames() { Blob blobS2V0 = mockBlob(sessionPrefix + sessionFile2 + "/0", "text/log", new byte[0]); Blob blobU1V0 = mockBlob(userPrefix + userFile1 + "/0", "app/json", new byte[0]); + // Mocking generic Page class requires unchecked suppression. @SuppressWarnings("unchecked") Page mockSessionPage = mock(Page.class); + // Mocking generic Page class requires unchecked suppression. @SuppressWarnings("unchecked") Page mockUserPage = mock(Page.class); when(mockStorage.list(BUCKET_NAME, BlobListOption.prefix(sessionPrefix))) @@ -363,6 +371,143 @@ public void saveAndReloadArtifact_savesAndReturnsFileData() { verify(mockStorage).create(eq(expectedBlobInfo), eq(new byte[] {1, 2, 3})); } + @Test + public void save_noInlineData_throwsException() { + Part artifact = Part.builder().build(); // No inline data + assertThrows( + IllegalArgumentException.class, + () -> + service.saveArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact).blockingGet()); + } + + @Test + public void save_storageException_throwsVerifyException() { + Part artifact = Part.fromBytes(new byte[] {1}, "text/plain"); + when(mockBlobPage.iterateAll()).thenReturn(ImmutableList.of()); + when(mockStorage.create(any(BlobInfo.class), any(byte[].class))) + .thenThrow(new StorageException(500, "Induced error")); + + assertThrows( + VerifyException.class, + () -> + service.saveArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact).blockingGet()); + } + + @Test + public void load_storageException_returnsEmpty() { + String blobNameV0 = String.format("%s/%s/%s/%s/0", APP_NAME, USER_ID, SESSION_ID, FILENAME); + BlobId blobIdV0 = BlobId.of(BUCKET_NAME, blobNameV0); + when(mockStorage.get(blobIdV0)).thenThrow(new StorageException(500, "Induced error")); + + Optional loadedArtifact = + asOptional(service.loadArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, Optional.of(0))); + + assertThat(loadedArtifact).isEmpty(); + } + + @Test + public void list_sessionStorageException_throwsVerifyException() { + String sessionPrefix = String.format("%s/%s/%s/", APP_NAME, USER_ID, SESSION_ID); + when(mockStorage.list(BUCKET_NAME, BlobListOption.prefix(sessionPrefix))) + .thenThrow(new StorageException(500, "Induced error")); + + assertThrows( + VerifyException.class, + () -> service.listArtifactKeys(APP_NAME, USER_ID, SESSION_ID).blockingGet()); + } + + @Test + public void list_userStorageException_throwsVerifyException() { + String sessionPrefix = String.format("%s/%s/%s/", APP_NAME, USER_ID, SESSION_ID); + String userPrefix = String.format("%s/%s/user/", APP_NAME, USER_ID); + + // Mocking generic Page class requires unchecked suppression. + @SuppressWarnings("unchecked") + Page mockSessionPage = mock(Page.class); + when(mockStorage.list(BUCKET_NAME, BlobListOption.prefix(sessionPrefix))) + .thenReturn(mockSessionPage); + when(mockSessionPage.iterateAll()).thenReturn(ImmutableList.of()); + + when(mockStorage.list(BUCKET_NAME, BlobListOption.prefix(userPrefix))) + .thenThrow(new StorageException(500, "Induced error")); + + assertThrows( + VerifyException.class, + () -> service.listArtifactKeys(APP_NAME, USER_ID, SESSION_ID).blockingGet()); + } + + @Test + public void delete_storageException_throwsVerifyException() { + String blobNameV0 = String.format("%s/%s/%s/%s/0", APP_NAME, USER_ID, SESSION_ID, FILENAME); + Blob blobV0 = mockBlob(blobNameV0, "text/plain", new byte[] {1}); + + when(mockBlobPage.iterateAll()).thenReturn(Collections.singletonList(blobV0)); + when(mockStorage.delete(ArgumentMatchers.>any())) + .thenThrow(new StorageException(500, "Induced error")); + + assertThrows( + VerifyException.class, + () -> service.deleteArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME).blockingAwait()); + } + + @Test + public void listVersions_storageException_returnsEmptyList() { + String prefix = String.format("%s/%s/%s/%s/", APP_NAME, USER_ID, SESSION_ID, FILENAME); + when(mockStorage.list(BUCKET_NAME, BlobListOption.prefix(prefix))) + .thenThrow(new StorageException(500, "Induced error")); + + ImmutableList versions = + service.listVersions(APP_NAME, USER_ID, SESSION_ID, FILENAME).blockingGet(); + + assertThat(versions).isEmpty(); + } + + @Test + public void saveAndReload_noContentTypeAnywhere_defaultsToOctetStream() { + // Artifact with no mime type + Part artifact = + Part.builder() + .inlineData(com.google.genai.types.Blob.builder().data(new byte[] {1}).build()) + .build(); + String expectedBlobName = + String.format("%s/%s/%s/%s/0", APP_NAME, USER_ID, SESSION_ID, FILENAME); + + when(mockBlobPage.iterateAll()).thenReturn(ImmutableList.of()); + Blob savedBlob = mock(Blob.class); + when(savedBlob.getName()).thenReturn(expectedBlobName); + when(savedBlob.getBucket()).thenReturn(BUCKET_NAME); + when(savedBlob.getContentType()).thenReturn(null); + when(mockStorage.create(any(BlobInfo.class), any(byte[].class))).thenReturn(savedBlob); + + Part result = + service + .saveAndReloadArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact) + .blockingGet(); + + assertThat(result.fileData().get().mimeType()).hasValue("application/octet-stream"); + } + + @Test + public void saveAndReload_blobMissingContentType_usesArtifactContentType() { + Part artifact = Part.fromBytes(new byte[] {1}, "application/pdf"); + String expectedBlobName = + String.format("%s/%s/%s/%s/0", APP_NAME, USER_ID, SESSION_ID, FILENAME); + + when(mockBlobPage.iterateAll()).thenReturn(ImmutableList.of()); + Blob savedBlob = mock(Blob.class); + when(savedBlob.getName()).thenReturn(expectedBlobName); + when(savedBlob.getBucket()).thenReturn(BUCKET_NAME); + when(savedBlob.getContentType()).thenReturn(null); + when(mockStorage.create(any(BlobInfo.class), any(byte[].class))).thenReturn(savedBlob); + + Part result = + service + .saveAndReloadArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact) + .blockingGet(); + + assertThat(result.fileData().get().mimeType()).hasValue("application/pdf"); + } + private static Optional asOptional(Maybe maybe) { return maybe.map(Optional::of).defaultIfEmpty(Optional.empty()).blockingGet(); }