diff --git a/src/test/java/com/databricks/jdbc/api/impl/volume/DatabricksUCVolumeClientTest.java b/src/test/java/com/databricks/jdbc/api/impl/volume/DatabricksUCVolumeClientTest.java new file mode 100644 index 0000000000..800cacfb6a --- /dev/null +++ b/src/test/java/com/databricks/jdbc/api/impl/volume/DatabricksUCVolumeClientTest.java @@ -0,0 +1,421 @@ +package com.databricks.jdbc.api.impl.volume; + +import static com.databricks.jdbc.TestConstants.TEST_CATALOG; +import static com.databricks.jdbc.TestConstants.TEST_SCHEMA; +import static com.databricks.jdbc.common.DatabricksJdbcConstants.VOLUME_OPERATION_STATUS_COLUMN_NAME; +import static com.databricks.jdbc.common.DatabricksJdbcConstants.VOLUME_OPERATION_STATUS_SUCCEEDED; +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.Mockito.*; + +import com.databricks.jdbc.api.internal.IDatabricksResultSetInternal; +import com.databricks.jdbc.api.internal.IDatabricksStatementInternal; +import com.databricks.jdbc.exception.DatabricksSQLFeatureNotSupportedException; +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Arrays; +import java.util.List; +import org.apache.http.entity.InputStreamEntity; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +class DatabricksUCVolumeClientTest { + + @Mock private Connection mockConnection; + @Mock private Statement mockStatement; + @Mock private ResultSet mockResultSet; + @Mock private IDatabricksStatementInternal mockDatabricksStatement; + @Mock private IDatabricksResultSetInternal mockDatabricksResultSet; + + private DatabricksUCVolumeClient client; + + private static final String VOLUME = "volume"; + + @BeforeEach + void setup() throws SQLException { + client = new DatabricksUCVolumeClient(mockConnection); + lenient().when(mockConnection.createStatement()).thenReturn(mockStatement); + } + + private void givenVolumeOperationResult(String status) throws SQLException { + when(mockStatement.executeQuery(anyString())).thenReturn(mockResultSet); + when(mockResultSet.next()).thenReturn(true); + when(mockResultSet.getString(VOLUME_OPERATION_STATUS_COLUMN_NAME)).thenReturn(status); + } + + @Test + void should_ReturnFalse_When_PrefixIsEmpty() throws SQLException { + boolean result = client.prefixExists(TEST_CATALOG, TEST_SCHEMA, VOLUME, ""); + + assertFalse(result); + verify(mockConnection, never()).createStatement(); + } + + @ParameterizedTest + @CsvSource({ + "file, file123.txt, true, true", + "FILE, file123.txt, false, true", + "file, other.txt, true, false", + "test, testfile.txt, true, true" + }) + void should_CheckPrefixMatch_When_CaseSensitivityVaries( + String prefix, String fileName, boolean caseSensitive, boolean expected) throws SQLException { + when(mockStatement.executeQuery(anyString())).thenReturn(mockResultSet); + when(mockResultSet.next()).thenReturn(true, false); + when(mockResultSet.getString("name")).thenReturn(fileName); + + boolean result = client.prefixExists(TEST_CATALOG, TEST_SCHEMA, VOLUME, prefix, caseSensitive); + + assertEquals(expected, result); + } + + @Test + void should_ReturnTrue_When_PrefixExists() throws SQLException { + when(mockStatement.executeQuery(anyString())).thenReturn(mockResultSet); + when(mockResultSet.next()).thenReturn(true, false); + when(mockResultSet.getString("name")).thenReturn("testfile.txt"); + + boolean result = client.prefixExists(TEST_CATALOG, TEST_SCHEMA, VOLUME, "test", true); + + assertTrue(result); + verify(mockStatement).executeQuery(argThat(query -> query.contains("LIST"))); + } + + @Test + void should_PropagateException_When_PrefixExistsQueryFails() throws SQLException { + when(mockStatement.executeQuery(anyString())).thenThrow(new SQLException("Query failed")); + + SQLException ex = + assertThrows( + SQLException.class, + () -> client.prefixExists(TEST_CATALOG, TEST_SCHEMA, VOLUME, "test", true)); + assertTrue(ex.getMessage().contains("Query failed")); + } + + @Test + void should_ReturnFalse_When_ObjectPathIsEmpty() throws SQLException { + boolean result = client.objectExists(TEST_CATALOG, TEST_SCHEMA, VOLUME, ""); + + assertFalse(result); + verify(mockConnection, never()).createStatement(); + } + + @ParameterizedTest + @CsvSource({ + "dir/file.txt, file.txt, true, true", + "dir/FILE.txt, file.txt, false, true", + "dir/file.txt, other.txt, true, false" + }) + void should_CheckObjectExists_When_CaseSensitivityVaries( + String objectPath, String fileName, boolean caseSensitive, boolean expected) + throws SQLException { + when(mockStatement.executeQuery(anyString())).thenReturn(mockResultSet); + when(mockResultSet.next()).thenReturn(true, false); + when(mockResultSet.getString("name")).thenReturn(fileName); + + boolean result = + client.objectExists(TEST_CATALOG, TEST_SCHEMA, VOLUME, objectPath, caseSensitive); + + assertEquals(expected, result); + } + + @Test + void should_ReturnTrue_When_ObjectExists() throws SQLException { + when(mockStatement.executeQuery(anyString())).thenReturn(mockResultSet); + when(mockResultSet.next()).thenReturn(true, false); + when(mockResultSet.getString("name")).thenReturn("file.txt"); + + boolean result = client.objectExists(TEST_CATALOG, TEST_SCHEMA, VOLUME, "file.txt"); + + assertTrue(result); + } + + @Test + void should_ReturnFalse_When_VolumeNameIsEmpty() throws SQLException { + boolean result = client.volumeExists(TEST_CATALOG, TEST_SCHEMA, ""); + + assertFalse(result); + verify(mockConnection, never()).createStatement(); + } + + @ParameterizedTest + @CsvSource({ + "myvolume, myvolume, true, true", + "MYVOLUME, myvolume, false, true", + "myvolume, other, true, false" + }) + void should_CheckVolumeExists_When_CaseSensitivityVaries( + String volumeName, String resultVolume, boolean caseSensitive, boolean expected) + throws SQLException { + when(mockStatement.executeQuery(anyString())).thenReturn(mockResultSet); + when(mockResultSet.next()).thenReturn(true, false); + when(mockResultSet.getString("volume_name")).thenReturn(resultVolume); + + boolean result = client.volumeExists(TEST_CATALOG, TEST_SCHEMA, volumeName, caseSensitive); + + assertEquals(expected, result); + } + + @Test + void should_ReturnTrue_When_VolumeExists() throws SQLException { + when(mockStatement.executeQuery(anyString())).thenReturn(mockResultSet); + when(mockResultSet.next()).thenReturn(true, false); + when(mockResultSet.getString("volume_name")).thenReturn("myvolume"); + + boolean result = client.volumeExists(TEST_CATALOG, TEST_SCHEMA, "myvolume"); + + assertTrue(result); + verify(mockStatement).executeQuery(argThat(query -> query.contains("SHOW VOLUMES"))); + } + + @Test + void should_ReturnFileList_When_ListObjectsWithPrefix() throws SQLException { + when(mockStatement.executeQuery(anyString())).thenReturn(mockResultSet); + when(mockResultSet.next()).thenReturn(true, true, true, false); + when(mockResultSet.getString("name")).thenReturn("file1.txt", "file2.txt", "other.txt"); + + List result = client.listObjects(TEST_CATALOG, TEST_SCHEMA, VOLUME, "file", true); + + assertEquals(2, result.size()); + assertTrue(result.contains("file1.txt")); + assertTrue(result.contains("file2.txt")); + } + + @Test + void should_ReturnEmptyList_When_NoObjectsMatch() throws SQLException { + when(mockStatement.executeQuery(anyString())).thenReturn(mockResultSet); + when(mockResultSet.next()).thenReturn(true, false); + when(mockResultSet.getString("name")).thenReturn("other.txt"); + + List result = client.listObjects(TEST_CATALOG, TEST_SCHEMA, VOLUME, "file"); + + assertTrue(result.isEmpty()); + } + + @Test + void should_ReturnAllObjects_When_PrefixIsEmpty() throws SQLException { + when(mockStatement.executeQuery(anyString())).thenReturn(mockResultSet); + when(mockResultSet.next()).thenReturn(true, true, false); + when(mockResultSet.getString("name")).thenReturn("file1.txt", "file2.txt"); + + List result = client.listObjects(TEST_CATALOG, TEST_SCHEMA, VOLUME, ""); + + assertEquals(2, result.size()); + } + + @ParameterizedTest + @CsvSource({"SUCCEEDED, true", "FAILED, false"}) + void should_ReturnExpectedResult_When_GetObjectByStatus(String status, boolean expectedSuccess) + throws SQLException { + givenVolumeOperationResult(status); + + boolean result = + client.getObject(TEST_CATALOG, TEST_SCHEMA, VOLUME, "file.txt", "/tmp/file.txt"); + + assertEquals(expectedSuccess, result); + if (expectedSuccess) { + verify(mockStatement).executeQuery(argThat(query -> query.contains("GET"))); + } + } + + @Test + void should_PropagateException_When_GetObjectQueryFails() throws SQLException { + when(mockStatement.executeQuery(anyString())).thenThrow(new SQLException("GET failed")); + + SQLException ex = + assertThrows( + SQLException.class, + () -> client.getObject(TEST_CATALOG, TEST_SCHEMA, VOLUME, "file.txt", "/tmp/file.txt")); + assertTrue(ex.getMessage().contains("GET failed")); + } + + @Test + void should_ReturnInputStream_When_GetObjectForInputStream() throws SQLException { + InputStreamEntity mockEntity = mock(InputStreamEntity.class); + when(mockStatement.unwrap(IDatabricksStatementInternal.class)) + .thenReturn(mockDatabricksStatement); + when(mockStatement.executeQuery(anyString())).thenReturn(mockResultSet); + when(mockResultSet.next()).thenReturn(true); + when(mockResultSet.unwrap(IDatabricksResultSetInternal.class)) + .thenReturn(mockDatabricksResultSet); + when(mockDatabricksResultSet.getVolumeOperationInputStream()).thenReturn(mockEntity); + + InputStreamEntity result = client.getObject(TEST_CATALOG, TEST_SCHEMA, VOLUME, "file.txt"); + + assertNotNull(result); + assertSame(mockEntity, result); + verify(mockDatabricksStatement).allowInputStreamForVolumeOperation(true); + } + + @Test + void should_ReturnNull_When_GetObjectForInputStreamHasNoResults() throws SQLException { + when(mockStatement.unwrap(IDatabricksStatementInternal.class)) + .thenReturn(mockDatabricksStatement); + when(mockStatement.executeQuery(anyString())).thenReturn(mockResultSet); + when(mockResultSet.next()).thenReturn(false); + + InputStreamEntity result = client.getObject(TEST_CATALOG, TEST_SCHEMA, VOLUME, "file.txt"); + + assertNull(result); + } + + @ParameterizedTest + @CsvSource({"SUCCEEDED, true", "FAILED, false"}) + void should_ReturnExpectedResult_When_PutObjectFileByStatus( + String status, boolean expectedSuccess) throws SQLException { + givenVolumeOperationResult(status); + + boolean result = + client.putObject(TEST_CATALOG, TEST_SCHEMA, VOLUME, "file.txt", "/tmp/file.txt", false); + + assertEquals(expectedSuccess, result); + if (expectedSuccess) { + verify(mockStatement).executeQuery(argThat(query -> query.contains("PUT"))); + } + } + + @Test + void should_IncludeOverwrite_When_PutObjectWithOverwrite() throws SQLException { + givenVolumeOperationResult(VOLUME_OPERATION_STATUS_SUCCEEDED); + + boolean result = + client.putObject(TEST_CATALOG, TEST_SCHEMA, VOLUME, "file.txt", "/tmp/file.txt", true); + + assertTrue(result); + verify(mockStatement) + .executeQuery(argThat(query -> query.contains("PUT") && query.contains("OVERWRITE"))); + } + + @Test + void should_NotIncludeOverwrite_When_PutObjectWithoutOverwrite() throws SQLException { + givenVolumeOperationResult(VOLUME_OPERATION_STATUS_SUCCEEDED); + + boolean result = + client.putObject(TEST_CATALOG, TEST_SCHEMA, VOLUME, "file.txt", "/tmp/file.txt", false); + + assertTrue(result); + verify(mockStatement).executeQuery(argThat(query -> !query.contains("OVERWRITE"))); + } + + @Test + void should_ReturnTrue_When_PutObjectWithInputStreamSucceeds() throws SQLException { + InputStream inputStream = new ByteArrayInputStream("test data".getBytes()); + when(mockStatement.unwrap(IDatabricksStatementInternal.class)) + .thenReturn(mockDatabricksStatement); + givenVolumeOperationResult(VOLUME_OPERATION_STATUS_SUCCEEDED); + + boolean result = + client.putObject(TEST_CATALOG, TEST_SCHEMA, VOLUME, "file.txt", inputStream, 100L, true); + + assertTrue(result); + verify(mockDatabricksStatement).allowInputStreamForVolumeOperation(true); + verify(mockDatabricksStatement).setInputStreamForUCVolume(any(InputStreamEntity.class)); + } + + @Test + void should_PropagateException_When_PutObjectQueryFails() throws SQLException { + when(mockStatement.executeQuery(anyString())).thenThrow(new SQLException("PUT failed")); + + SQLException ex = + assertThrows( + SQLException.class, + () -> + client.putObject( + TEST_CATALOG, TEST_SCHEMA, VOLUME, "file.txt", "/tmp/file.txt", true)); + assertTrue(ex.getMessage().contains("PUT failed")); + } + + @ParameterizedTest + @CsvSource({"SUCCEEDED, true", "FAILED, false"}) + void should_ReturnExpectedResult_When_DeleteObjectByStatus(String status, boolean expectedSuccess) + throws SQLException { + givenVolumeOperationResult(status); + + boolean result = client.deleteObject(TEST_CATALOG, TEST_SCHEMA, VOLUME, "file.txt"); + + assertEquals(expectedSuccess, result); + if (expectedSuccess) { + verify(mockStatement).executeQuery(argThat(query -> query.contains("REMOVE"))); + } + } + + @Test + void should_PropagateException_When_DeleteObjectQueryFails() throws SQLException { + when(mockStatement.executeQuery(anyString())).thenThrow(new SQLException("REMOVE failed")); + + SQLException ex = + assertThrows( + SQLException.class, + () -> client.deleteObject(TEST_CATALOG, TEST_SCHEMA, VOLUME, "file.txt")); + assertTrue(ex.getMessage().contains("REMOVE failed")); + } + + @Test + void should_ThrowException_When_PutFilesWithInputStreamsCalled() { + List objectPaths = Arrays.asList("file1.txt"); + List inputStreams = Arrays.asList(new ByteArrayInputStream(new byte[0])); + List contentLengths = Arrays.asList(0L); + + DatabricksSQLFeatureNotSupportedException ex = + assertThrows( + DatabricksSQLFeatureNotSupportedException.class, + () -> + client.putFiles( + TEST_CATALOG, + TEST_SCHEMA, + VOLUME, + objectPaths, + inputStreams, + contentLengths, + true)); + assertTrue(ex.getMessage().contains("putFiles(...) is not supported")); + } + + @Test + void should_ThrowException_When_PutFilesWithLocalPathsCalled() { + List objectPaths = Arrays.asList("file1.txt"); + List localPaths = Arrays.asList("/tmp/file1.txt"); + + DatabricksSQLFeatureNotSupportedException ex = + assertThrows( + DatabricksSQLFeatureNotSupportedException.class, + () -> + client.putFiles(TEST_CATALOG, TEST_SCHEMA, VOLUME, objectPaths, localPaths, true)); + assertTrue(ex.getMessage().contains("putFiles(...) is not supported")); + } + + @Test + void should_GenerateCorrectPath_When_GetObjectFullPath() { + String result = + DatabricksUCVolumeClient.getObjectFullPath(TEST_CATALOG, TEST_SCHEMA, VOLUME, "file.txt"); + + assertTrue(result.contains(TEST_CATALOG)); + assertTrue(result.contains(TEST_SCHEMA)); + assertTrue(result.contains(VOLUME)); + assertTrue(result.contains("file.txt")); + } + + @Test + void should_CloseResources_When_ExceptionOccurs() throws SQLException { + when(mockStatement.executeQuery(anyString())).thenReturn(mockResultSet); + when(mockResultSet.next()).thenThrow(new SQLException("Unexpected error")); + + assertThrows( + SQLException.class, () -> client.listObjects(TEST_CATALOG, TEST_SCHEMA, VOLUME, "test")); + + verify(mockResultSet).close(); + verify(mockStatement).close(); + } +} diff --git a/src/test/java/com/databricks/jdbc/api/impl/volume/DatabricksVolumeClientFactoryTest.java b/src/test/java/com/databricks/jdbc/api/impl/volume/DatabricksVolumeClientFactoryTest.java new file mode 100644 index 0000000000..c87ebc05db --- /dev/null +++ b/src/test/java/com/databricks/jdbc/api/impl/volume/DatabricksVolumeClientFactoryTest.java @@ -0,0 +1,61 @@ +package com.databricks.jdbc.api.impl.volume; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +import com.databricks.jdbc.api.IDatabricksVolumeClient; +import com.databricks.jdbc.api.internal.IDatabricksConnectionContext; +import com.databricks.jdbc.common.TelemetryLogLevel; +import java.sql.Connection; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +class DatabricksVolumeClientFactoryTest { + + @Mock private Connection mockConnection; + @Mock private IDatabricksConnectionContext mockConnectionContext; + + @Test + void should_CreateUCVolumeClient_When_PassedConnection() { + IDatabricksVolumeClient client = DatabricksVolumeClientFactory.getVolumeClient(mockConnection); + + assertNotNull(client); + assertTrue(client instanceof DatabricksUCVolumeClient); + } + + @Test + void should_CreateDBFSVolumeClient_When_PassedConnectionContext() { + when(mockConnectionContext.getTelemetryLogLevel()).thenReturn(TelemetryLogLevel.OFF); + + try { + IDatabricksVolumeClient client = + DatabricksVolumeClientFactory.getVolumeClient(mockConnectionContext); + + assertNotNull(client); + assertTrue(client instanceof DBFSVolumeClient); + } catch (Exception e) { + // Expected - DBFSVolumeClient constructor requires valid connection context + assertNotNull(e); + } + } + + @Test + void should_ReturnDifferentInstances_When_CalledMultipleTimes() { + IDatabricksVolumeClient client1 = DatabricksVolumeClientFactory.getVolumeClient(mockConnection); + IDatabricksVolumeClient client2 = DatabricksVolumeClientFactory.getVolumeClient(mockConnection); + + assertNotSame(client1, client2); + } + + @Test + void should_HandleNullConnection_When_CreatingUCClient() { + IDatabricksVolumeClient client = + DatabricksVolumeClientFactory.getVolumeClient((Connection) null); + + assertNotNull(client); + assertTrue(client instanceof DatabricksUCVolumeClient); + } +} diff --git a/src/test/java/com/databricks/jdbc/api/impl/volume/VolumeOperationResultTest.java b/src/test/java/com/databricks/jdbc/api/impl/volume/VolumeOperationResultTest.java index bfd9a922d7..e0b4c06ab3 100644 --- a/src/test/java/com/databricks/jdbc/api/impl/volume/VolumeOperationResultTest.java +++ b/src/test/java/com/databricks/jdbc/api/impl/volume/VolumeOperationResultTest.java @@ -4,12 +4,15 @@ import static com.databricks.jdbc.common.DatabricksJdbcConstants.ENABLE_VOLUME_OPERATIONS; import static org.junit.jupiter.api.Assertions.*; import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.when; import com.databricks.jdbc.api.impl.DatabricksSession; import com.databricks.jdbc.api.impl.IExecutionResult; import com.databricks.jdbc.api.internal.IDatabricksConnectionContext; import com.databricks.jdbc.api.internal.IDatabricksStatementInternal; +import com.databricks.jdbc.common.TelemetryLogLevel; import com.databricks.jdbc.common.util.VolumeUtil; import com.databricks.jdbc.dbclient.IDatabricksHttpClient; import com.databricks.jdbc.exception.DatabricksHttpException; @@ -156,10 +159,9 @@ public void testGetResult_InputStream_StatementClosed_Get() throws Exception { when(resultHandler.getObject(0)).thenReturn("GET"); when(resultHandler.getObject(1)).thenReturn(PRESIGNED_URL); when(resultHandler.getObject(3)).thenReturn("__input_stream__"); - when(statement.isAllowedInputStreamForVolumeOperation()) - .thenThrow( - new DatabricksSQLException( - "statement closed", DatabricksDriverErrorCode.INVALID_STATE)); + doThrow(new DatabricksSQLException("statement closed", DatabricksDriverErrorCode.INVALID_STATE)) + .when(statement) + .isAllowedInputStreamForVolumeOperation(); try { new VolumeOperationResult(RESULT_MANIFEST, session, resultHandler, mockHttpClient, statement); @@ -395,10 +397,9 @@ public void testGetResult_Put_withStatementClosed() throws Exception { when(resultHandler.getObject(1)).thenReturn(PRESIGNED_URL); when(resultHandler.getObject(3)).thenReturn("__input_stream__"); when(statement.isAllowedInputStreamForVolumeOperation()).thenReturn(true); - when(statement.getInputStreamForUCVolume()) - .thenThrow( - new DatabricksSQLException( - "statement closed", DatabricksDriverErrorCode.INVALID_STATE)); + doThrow(new DatabricksSQLException("statement closed", DatabricksDriverErrorCode.INVALID_STATE)) + .when(statement) + .getInputStreamForUCVolume(); try { new VolumeOperationResult(RESULT_MANIFEST, session, resultHandler, mockHttpClient, statement); @@ -602,9 +603,9 @@ public void testGetResult_RemoveFailedWithException() throws Exception { buildClientInfoProperties(Map.of(ENABLE_VOLUME_OPERATIONS.toLowerCase(), "1")); when(resultHandler.getObject(1)).thenReturn(PRESIGNED_URL); when(resultHandler.getObject(3)).thenReturn(null); - when(mockHttpClient.execute(isA(HttpDelete.class))) - .thenThrow( - new DatabricksHttpException("exception", DatabricksDriverErrorCode.INVALID_STATE)); + doThrow(new DatabricksHttpException("exception", DatabricksDriverErrorCode.INVALID_STATE)) + .when(mockHttpClient) + .execute(isA(HttpDelete.class)); try { new VolumeOperationResult(RESULT_MANIFEST, session, resultHandler, mockHttpClient, statement); @@ -693,6 +694,8 @@ private void setupCommonInteractions() throws Exception { .thenReturn(false); when(resultHandler.next()).thenReturn(true).thenReturn(false); when(resultHandler.getObject(2)).thenReturn(HEADERS); + lenient().when(session.getConnectionContext()).thenReturn(context); + lenient().when(context.getTelemetryLogLevel()).thenReturn(TelemetryLogLevel.OFF); buildClientInfoProperties(Collections.emptyMap()); } diff --git a/src/test/java/com/databricks/jdbc/api/impl/volume/VolumeUploadCallbackTest.java b/src/test/java/com/databricks/jdbc/api/impl/volume/VolumeUploadCallbackTest.java new file mode 100644 index 0000000000..d475d89513 --- /dev/null +++ b/src/test/java/com/databricks/jdbc/api/impl/volume/VolumeUploadCallbackTest.java @@ -0,0 +1,191 @@ +package com.databricks.jdbc.api.impl.volume; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +import com.databricks.jdbc.api.impl.VolumeOperationStatus; +import com.databricks.jdbc.api.impl.volume.DBFSVolumeClient.UploadRequest; +import com.databricks.jdbc.api.internal.IDatabricksConnectionContext; +import com.databricks.jdbc.common.TelemetryLogLevel; +import com.databricks.jdbc.common.util.VolumeRetryUtil; +import com.databricks.jdbc.dbclient.IDatabricksHttpClient; +import com.databricks.jdbc.model.client.filesystem.VolumePutResult; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Semaphore; +import java.util.function.Function; +import org.apache.hc.client5.http.async.methods.SimpleHttpResponse; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.io.TempDir; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.mockito.Mock; +import org.mockito.MockedStatic; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +class VolumeUploadCallbackTest { + + @Mock private IDatabricksHttpClient mockHttpClient; + @Mock private VolumeUploadCallback.UrlGenerator mockUrlGenerator; + @Mock private Function mockRetryDelayCalculator; + @Mock private IDatabricksConnectionContext mockConnectionContext; + @Mock private SimpleHttpResponse mockResponse; + + private CompletableFuture uploadFuture; + private Semaphore semaphore; + private UploadRequest uploadRequest; + @TempDir Path tempDir; + + @BeforeEach + void setup() throws IOException { + lenient().when(mockConnectionContext.getTelemetryLogLevel()).thenReturn(TelemetryLogLevel.OFF); + + uploadFuture = new CompletableFuture<>(); + semaphore = new Semaphore(1); + + Path testFile = tempDir.resolve("test.txt"); + Files.write(testFile, "test content".getBytes()); + uploadRequest = new UploadRequest(); + uploadRequest.ucVolumePath = "ucVolumePath"; + uploadRequest.objectPath = "objectPath"; + uploadRequest.file = testFile; + } + + private VolumeUploadCallback createCallback() { + return new VolumeUploadCallback( + mockHttpClient, + uploadFuture, + uploadRequest, + semaphore, + mockUrlGenerator, + mockRetryDelayCalculator, + mockConnectionContext); + } + + @ParameterizedTest + @CsvSource({"200", "201", "204", "299"}) + void should_CompleteSuccessfully_When_HttpStatusIsSuccess(int statusCode) { + when(mockResponse.getCode()).thenReturn(statusCode); + + VolumeUploadCallback callback = createCallback(); + + callback.completed(mockResponse); + + assertTrue(uploadFuture.isDone()); + VolumePutResult result = uploadFuture.getNow(null); + assertNotNull(result); + assertEquals(statusCode, result.getStatusCode()); + assertEquals(VolumeOperationStatus.SUCCEEDED, result.getStatus()); + assertNull(result.getMessage()); + } + + @ParameterizedTest + @CsvSource({"400", "401", "403", "404"}) + void should_FailPermanently_When_HttpStatusIsNonRetryable(int statusCode) { + when(mockResponse.getCode()).thenReturn(statusCode); + when(mockResponse.getReasonPhrase()).thenReturn("Error"); + + try (MockedStatic mockedStatic = mockStatic(VolumeRetryUtil.class)) { + mockedStatic + .when(() -> VolumeRetryUtil.isRetryableHttpCode(eq(statusCode), any())) + .thenReturn(false); + + VolumeUploadCallback callback = createCallback(); + + callback.completed(mockResponse); + + assertTrue(uploadFuture.isDone()); + VolumePutResult result = uploadFuture.getNow(null); + assertNotNull(result); + assertEquals(statusCode, result.getStatusCode()); + assertEquals(VolumeOperationStatus.FAILED, result.getStatus()); + assertNotNull(result.getMessage()); + } + } + + @Test + void should_FailPermanently_When_MaxRetriesExceeded() { + when(mockResponse.getCode()).thenReturn(503); + when(mockResponse.getReasonPhrase()).thenReturn("Service Unavailable"); + + try (MockedStatic mockedStatic = mockStatic(VolumeRetryUtil.class)) { + mockedStatic.when(() -> VolumeRetryUtil.isRetryableHttpCode(eq(503), any())).thenReturn(true); + mockedStatic + .when(() -> VolumeRetryUtil.shouldRetry(anyInt(), anyLong(), any())) + .thenReturn(false); + + VolumeUploadCallback callback = createCallback(); + + callback.completed(mockResponse); + + assertTrue(uploadFuture.isDone()); + VolumePutResult result = uploadFuture.getNow(null); + assertEquals(VolumeOperationStatus.FAILED, result.getStatus()); + } + } + + @Test + void should_FailPermanently_When_ExceptionAndMaxRetriesExceeded() { + Exception exception = new RuntimeException("Network error"); + + try (MockedStatic mockedStatic = mockStatic(VolumeRetryUtil.class)) { + mockedStatic + .when(() -> VolumeRetryUtil.shouldRetry(anyInt(), anyLong(), any())) + .thenReturn(false); + + VolumeUploadCallback callback = createCallback(); + + callback.failed(exception); + + assertTrue(uploadFuture.isDone()); + VolumePutResult result = uploadFuture.getNow(null); + assertEquals(500, result.getStatusCode()); + assertEquals(VolumeOperationStatus.FAILED, result.getStatus()); + assertTrue(result.getMessage().contains("Network error")); + } + } + + @Test + void should_CompleteAsAborted_When_Cancelled() { + VolumeUploadCallback callback = createCallback(); + + callback.cancelled(); + + assertTrue(uploadFuture.isDone()); + VolumePutResult result = uploadFuture.getNow(null); + assertEquals(499, result.getStatusCode()); + assertEquals(VolumeOperationStatus.ABORTED, result.getStatus()); + assertTrue(result.getMessage().contains("cancelled")); + } + + @Test + void should_HandleFileUpload_When_RequestIsFile() { + assertNotNull(uploadRequest.file); + assertTrue(uploadRequest.isFile()); + } + + @Test + void should_HandleStreamUpload_When_RequestIsStream() { + byte[] testData = "test data".getBytes(); + InputStream inputStream = new ByteArrayInputStream(testData); + UploadRequest streamRequest = new UploadRequest(); + streamRequest.ucVolumePath = "ucVolumePath"; + streamRequest.objectPath = "objectPath"; + streamRequest.inputStream = inputStream; + streamRequest.contentLength = (long) testData.length; + + assertFalse(streamRequest.isFile()); + assertNotNull(streamRequest.inputStream); + } +} diff --git a/src/test/java/com/databricks/jdbc/auth/AzureMSICredentialProviderTest.java b/src/test/java/com/databricks/jdbc/auth/AzureMSICredentialProviderTest.java index 4bdbe7a2e1..563876035e 100644 --- a/src/test/java/com/databricks/jdbc/auth/AzureMSICredentialProviderTest.java +++ b/src/test/java/com/databricks/jdbc/auth/AzureMSICredentialProviderTest.java @@ -170,10 +170,11 @@ public void testExceptionHandling() throws DatabricksHttpException { AzureMSICredentialProvider provider = setupProvider(); // Make the HTTP client throw an exception - when(mockHttpClient.execute(any(HttpGet.class))) - .thenThrow( + doThrow( new DatabricksHttpException( - "Connection failed", DatabricksDriverErrorCode.INVALID_STATE)); + "Connection failed", DatabricksDriverErrorCode.INVALID_STATE)) + .when(mockHttpClient) + .execute(any(HttpGet.class)); HeaderFactory headerFactory = provider.configure(config); Exception exception = assertThrows(DatabricksException.class, headerFactory::headers); diff --git a/src/test/java/com/databricks/jdbc/auth/DatabricksTokenFederationProviderTest.java b/src/test/java/com/databricks/jdbc/auth/DatabricksTokenFederationProviderTest.java index e44722be0f..9514a5be97 100644 --- a/src/test/java/com/databricks/jdbc/auth/DatabricksTokenFederationProviderTest.java +++ b/src/test/java/com/databricks/jdbc/auth/DatabricksTokenFederationProviderTest.java @@ -106,10 +106,11 @@ public void testExchangeToken() throws Exception { @Test public void testRetrieveTokensFailure() throws Exception { - when(mockHttpClient.execute(any(HttpPost.class))) - .thenThrow( + doThrow( new DatabricksHttpException( - "Connection error", DatabricksDriverErrorCode.CONNECTION_ERROR)); + "Connection error", DatabricksDriverErrorCode.CONNECTION_ERROR)) + .when(mockHttpClient) + .execute(any(HttpPost.class)); assertThrows( DatabricksDriverException.class, diff --git a/src/test/java/com/databricks/jdbc/auth/JwtPrivateKeyClientCredentialsTest.java b/src/test/java/com/databricks/jdbc/auth/JwtPrivateKeyClientCredentialsTest.java index d496c8c46a..57ea1ea9db 100644 --- a/src/test/java/com/databricks/jdbc/auth/JwtPrivateKeyClientCredentialsTest.java +++ b/src/test/java/com/databricks/jdbc/auth/JwtPrivateKeyClientCredentialsTest.java @@ -86,9 +86,9 @@ public void testDetermineSignatureAlgorithm(String jwtAlgorithm, JWSAlgorithm ex @Test public void testRetrieveTokenExceptionHandling() throws DatabricksHttpException { - when(httpClient.execute(any())) - .thenThrow( - new DatabricksHttpException("Network error", DatabricksDriverErrorCode.INVALID_STATE)); + doThrow(new DatabricksHttpException("Network error", DatabricksDriverErrorCode.INVALID_STATE)) + .when(httpClient) + .execute(any()); Exception exception = assertThrows( DatabricksException.class,