Skip to content

Commit 220aad2

Browse files
authored
Merge pull request #1596 from CMSgov/QPPA-11251-memory-leak-mitigation
QPPA-11251: Adds closure to s3 connections to mitigate memory leak issue
2 parents 8000248 + f512682 commit 220aad2

2 files changed

Lines changed: 70 additions & 21 deletions

File tree

rest-api/src/main/java/gov/cms/qpp/conversion/api/services/internal/StorageServiceImpl.java

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import gov.cms.qpp.conversion.api.model.Constants;
2121
import gov.cms.qpp.conversion.api.services.StorageService;
2222

23+
import java.io.FilterInputStream;
24+
import java.io.IOException;
2325
import java.io.InputStream;
2426
import java.util.Objects;
2527
import java.util.concurrent.CompletableFuture;
@@ -99,12 +101,11 @@ public InputStream getFileByLocationId(String fileLocationId) {
99101
API_LOG.info("Retrieving file {} from bucket {}", fileLocationId, bucketName);
100102

101103
GetObjectRequest getObjectRequest = new GetObjectRequest(bucketName, fileLocationId);
102-
103-
S3Object s3Object = amazonS3.getObject(getObjectRequest);
104+
InputStream objectContent = retrieveManagedContentStream(getObjectRequest);
104105

105106
API_LOG.info("Successfully retrieved file {} from S3 bucket {}", getObjectRequest.getKey(), getObjectRequest.getBucketName());
106107

107-
return s3Object.getObjectContent();
108+
return objectContent;
108109
}
109110

110111
/**
@@ -123,9 +124,7 @@ public InputStream getCpcPlusValidationFile() {
123124
API_LOG.info("Retrieving CPC+ validation file from bucket {}", bucketName);
124125

125126
GetObjectRequest getObjectRequest = new GetObjectRequest(bucketName, key);
126-
S3Object s3Object = amazonS3.getObject(getObjectRequest);
127-
128-
return s3Object.getObjectContent();
127+
return retrieveManagedContentStream(getObjectRequest);
129128
}
130129

131130
/**
@@ -143,9 +142,15 @@ public InputStream getApmValidationFile(String fileName) {
143142
API_LOG.info("Retrieving APM validation file from bucket {}", bucketName);
144143

145144
GetObjectRequest getObjectRequest = new GetObjectRequest(bucketName, fileName);
146-
S3Object s3Object = amazonS3.getObject(getObjectRequest);
145+
return retrieveManagedContentStream(getObjectRequest);
146+
}
147147

148-
return s3Object.getObjectContent();
148+
/**
149+
* Provides an {@link InputStream} whose close also closes the backing {@link S3Object}.
150+
*/
151+
private InputStream retrieveManagedContentStream(GetObjectRequest getObjectRequest) {
152+
S3Object s3Object = amazonS3.getObject(getObjectRequest);
153+
return new ManagedS3InputStream(s3Object);
149154
}
150155

151156
/**
@@ -176,4 +181,22 @@ protected String asynchronousAction(Supplier<PutObjectRequest> objectToActOn) {
176181
protected String getActionName() {
177182
return "Write to Storage";
178183
}
184+
185+
private static final class ManagedS3InputStream extends FilterInputStream {
186+
private final S3Object s3Object;
187+
188+
ManagedS3InputStream(S3Object s3Object) {
189+
super(s3Object.getObjectContent());
190+
this.s3Object = s3Object;
191+
}
192+
193+
@Override
194+
public void close() throws IOException {
195+
try {
196+
super.close();
197+
} finally {
198+
s3Object.close();
199+
}
200+
}
201+
}
179202
}

rest-api/src/test/java/gov/cms/qpp/conversion/api/services/internal/StorageServiceImplTest.java

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
import static org.mockito.Mockito.when;
1313

1414
import java.io.ByteArrayInputStream;
15+
import java.io.ByteArrayOutputStream;
16+
import java.io.IOException;
1517
import java.io.InputStream;
18+
import java.nio.charset.StandardCharsets;
1619
import java.util.concurrent.CompletableFuture;
1720
import java.util.concurrent.CompletionException;
1821

@@ -144,7 +147,8 @@ void noBucket() {
144147
@Test
145148
void envVariablesPresent() {
146149
S3Object s3ObjectMock = mock(S3Object.class);
147-
s3ObjectMock.setObjectContent(new ByteArrayInputStream("1234".getBytes()));
150+
S3ObjectInputStream objectContent = new S3ObjectInputStream(new ByteArrayInputStream("1234".getBytes(StandardCharsets.UTF_8)), null);
151+
when(s3ObjectMock.getObjectContent()).thenReturn(objectContent);
148152
Mockito.when(amazonS3Client.getObject(any(GetObjectRequest.class))).thenReturn(s3ObjectMock);
149153
Mockito.when(environment.getProperty(Constants.BUCKET_NAME_ENV_VARIABLE)).thenReturn("meep");
150154
underTest.getFileByLocationId("meep");
@@ -197,31 +201,53 @@ void test_getCpcPlusValidationFile_NPE() {
197201
}
198202

199203
@Test
200-
void test_getCpcPlusValidationFile() {
201-
S3ObjectInputStream expected = new S3ObjectInputStream(null, null);
204+
void test_getCpcPlusValidationFile() throws IOException {
205+
byte[] expectedBytes = "Mock Contents".getBytes(StandardCharsets.UTF_8);
206+
S3ObjectInputStream expectedStream = new S3ObjectInputStream(new ByteArrayInputStream(expectedBytes), null);
202207
S3Object mockS3Obj = mock(S3Object.class);
203-
Mockito.when(mockS3Obj.getObjectContent()).thenReturn(expected);
208+
Mockito.when(mockS3Obj.getObjectContent()).thenReturn(expectedStream);
204209

205210
Mockito.when(environment.getProperty(Constants.CPC_PLUS_BUCKET_NAME_VARIABLE)).thenReturn("Mock_Bucket");
206211
Mockito.when(environment.getProperty(Constants.CPC_PLUS_FILENAME_VARIABLE)).thenReturn("Mock_Filename");
207-
Mockito.when(amazonS3Client.getObject( any(GetObjectRequest.class) )).thenReturn(mockS3Obj);
212+
Mockito.when(amazonS3Client.getObject(any(GetObjectRequest.class))).thenReturn(mockS3Obj);
208213

209-
InputStream actual = underTest.getCpcPlusValidationFile();
214+
byte[] actualBytes;
215+
try (InputStream actual = underTest.getCpcPlusValidationFile()) {
216+
assertThat(actual).isNotNull();
217+
actualBytes = toByteArray(actual);
218+
}
210219

211-
assertThat(actual).isEqualTo(expected);
220+
assertThat(actualBytes).isEqualTo(expectedBytes);
221+
verify(mockS3Obj, times(1)).close();
212222
}
213223

214224
@Test
215-
void test_getApmValidationFile() {
216-
S3ObjectInputStream expected = new S3ObjectInputStream(null, null);
225+
void test_getApmValidationFile() throws IOException {
226+
byte[] expectedBytes = "APM".getBytes(StandardCharsets.UTF_8);
227+
S3ObjectInputStream expectedStream = new S3ObjectInputStream(new ByteArrayInputStream(expectedBytes), null);
217228
S3Object mockS3Obj = mock(S3Object.class);
218-
Mockito.when(mockS3Obj.getObjectContent()).thenReturn(expected);
229+
Mockito.when(mockS3Obj.getObjectContent()).thenReturn(expectedStream);
219230

220231
Mockito.when(environment.getProperty(Constants.BUCKET_NAME_ENV_VARIABLE)).thenReturn("Mock_Bucket");
221-
Mockito.when(amazonS3Client.getObject( any(GetObjectRequest.class) )).thenReturn(mockS3Obj);
232+
Mockito.when(amazonS3Client.getObject(any(GetObjectRequest.class))).thenReturn(mockS3Obj);
222233

223-
InputStream actual = underTest.getApmValidationFile(Constants.CPC_PLUS_APM_FILE_NAME_KEY);
234+
byte[] actualBytes;
235+
try (InputStream actual = underTest.getApmValidationFile(Constants.CPC_PLUS_APM_FILE_NAME_KEY)) {
236+
assertThat(actual).isNotNull();
237+
actualBytes = toByteArray(actual);
238+
}
224239

225-
assertThat(actual).isEqualTo(expected);
240+
assertThat(actualBytes).isEqualTo(expectedBytes);
241+
verify(mockS3Obj, times(1)).close();
242+
}
243+
244+
private byte[] toByteArray(InputStream inputStream) throws IOException {
245+
ByteArrayOutputStream buffer = new ByteArrayOutputStream();
246+
byte[] data = new byte[1024];
247+
int bytesRead;
248+
while ((bytesRead = inputStream.read(data)) != -1) {
249+
buffer.write(data, 0, bytesRead);
250+
}
251+
return buffer.toByteArray();
226252
}
227253
}

0 commit comments

Comments
 (0)