Skip to content

Commit 1cc5693

Browse files
MPSFuzzpytorchstuNicolasHug
authored
Fix CPU decode_jpeg error-path leak on malformed JPEGs (setjmp/longjmp) (#9423)
Co-authored-by: MPSFuzz <2286770808@qq.com> Co-authored-by: Nicolas Hug <contact@nicolas-hug.com> Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
1 parent 48956e0 commit 1cc5693

4 files changed

Lines changed: 54 additions & 13 deletions

File tree

torchvision/csrc/io/image/cpu/decode_jpeg.cpp

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
#include "common_jpeg.h"
44
#include "exif.h"
55

6+
#include <optional>
7+
68
namespace vision {
79
namespace image {
810

@@ -141,12 +143,23 @@ torch::Tensor decode_jpeg(
141143
struct jpeg_decompress_struct cinfo;
142144
struct torch_jpeg_error_mgr jerr;
143145

146+
// NOTE: libjpeg uses setjmp/longjmp for error handling. longjmp does not
147+
// unwind C++ stack frames, so destructors of objects created after setjmp
148+
// won't run. We use std::optional to declare tensors before setjmp while
149+
// deferring construction, and explicitly reset them on the error path.
150+
std::optional<torch::Tensor> tensor;
151+
std::optional<torch::Tensor> cmyk_line_tensor;
152+
144153
auto datap = data.data_ptr<uint8_t>();
145154
// Setup decompression structure
146155
cinfo.err = jpeg_std_error(&jerr.pub);
147156
jerr.pub.error_exit = torch_jpeg_error_exit;
148157
/* Establish the setjmp return context for my_error_exit to use. */
149158
if (setjmp(jerr.setjmp_buffer)) {
159+
// Release any tensors that may have been allocated after setjmp.
160+
cmyk_line_tensor.reset();
161+
tensor.reset();
162+
150163
/* If we get here, the JPEG code has signaled an error.
151164
* We need to clean up the JPEG object.
152165
*/
@@ -210,10 +223,10 @@ torch::Tensor decode_jpeg(
210223
int width = cinfo.output_width;
211224

212225
int stride = width * channels;
213-
auto tensor =
226+
tensor =
214227
torch::empty({int64_t(height), int64_t(width), channels}, torch::kU8);
215-
auto ptr = tensor.data_ptr<uint8_t>();
216-
torch::Tensor cmyk_line_tensor;
228+
auto ptr = tensor->data_ptr<uint8_t>();
229+
217230
if (cmyk_to_rgb_or_gray) {
218231
cmyk_line_tensor = torch::empty({int64_t(width), 4}, torch::kU8);
219232
}
@@ -224,7 +237,7 @@ torch::Tensor decode_jpeg(
224237
* more than one scanline at a time if that's more convenient.
225238
*/
226239
if (cmyk_to_rgb_or_gray) {
227-
auto cmyk_line_ptr = cmyk_line_tensor.data_ptr<uint8_t>();
240+
auto cmyk_line_ptr = cmyk_line_tensor->data_ptr<uint8_t>();
228241
jpeg_read_scanlines(&cinfo, &cmyk_line_ptr, 1);
229242

230243
if (channels == 3) {
@@ -240,7 +253,7 @@ torch::Tensor decode_jpeg(
240253

241254
jpeg_finish_decompress(&cinfo);
242255
jpeg_destroy_decompress(&cinfo);
243-
auto output = tensor.permute({2, 0, 1});
256+
auto output = tensor->permute({2, 0, 1});
244257

245258
if (apply_exif_orientation) {
246259
return exif_orientation_transform(output, exif_orientation);

torchvision/csrc/io/image/cpu/decode_png.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
#include "common_png.h"
44
#include "exif.h"
55

6+
#include <optional>
7+
68
namespace vision {
79
namespace image {
810

@@ -45,7 +47,14 @@ torch::Tensor decode_png(
4547
auto datap = accessor.data();
4648
auto datap_len = accessor.size(0);
4749

50+
// NOTE: libpng uses setjmp/longjmp for error handling. longjmp does not
51+
// unwind C++ stack frames, so destructors of objects created after setjmp
52+
// won't run. We use std::optional to declare tensors before setjmp while
53+
// deferring construction, and explicitly reset them on the error path.
54+
std::optional<torch::Tensor> tensor;
55+
4856
if (setjmp(png_jmpbuf(png_ptr)) != 0) {
57+
tensor.reset();
4958
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
5059
STD_TORCH_CHECK(false, "Internal error.");
5160
}
@@ -197,19 +206,19 @@ torch::Tensor decode_png(
197206

198207
auto num_pixels_per_row = width * channels;
199208
auto is_16_bits = bit_depth == 16;
200-
auto tensor = torch::empty(
209+
tensor = torch::empty(
201210
{int64_t(height), int64_t(width), channels},
202211
is_16_bits ? at::kUInt16 : torch::kU8);
203212
if (is_little_endian()) {
204213
png_set_swap(png_ptr);
205214
}
206-
auto t_ptr = (uint8_t*)tensor.data_ptr();
215+
auto t_ptr = (uint8_t*)tensor->data_ptr();
207216
for (int pass = 0; pass < number_of_passes; pass++) {
208217
for (png_uint_32 i = 0; i < height; ++i) {
209218
png_read_row(png_ptr, t_ptr, nullptr);
210219
t_ptr += num_pixels_per_row * (is_16_bits ? 2 : 1);
211220
}
212-
t_ptr = (uint8_t*)tensor.data_ptr();
221+
t_ptr = (uint8_t*)tensor->data_ptr();
213222
}
214223

215224
int exif_orientation = -1;
@@ -219,7 +228,7 @@ torch::Tensor decode_png(
219228

220229
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
221230

222-
auto output = tensor.permute({2, 0, 1});
231+
auto output = tensor->permute({2, 0, 1});
223232
if (apply_exif_orientation) {
224233
return exif_orientation_transform(output, exif_orientation);
225234
}

torchvision/csrc/io/image/cpu/encode_jpeg.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
#include <torch/headeronly/util/Exception.h>
44

5+
#include <optional>
6+
57
#include "common_jpeg.h"
68

79
namespace vision {
@@ -37,6 +39,12 @@ torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) {
3739
JpegSizeType jpegSize = 0;
3840
uint8_t* jpegBuf = nullptr;
3941

42+
// NOTE: libjpeg uses setjmp/longjmp for error handling. longjmp does not
43+
// unwind C++ stack frames, so destructors of objects created after setjmp
44+
// won't run. We use std::optional to declare tensors before setjmp while
45+
// deferring construction, and explicitly reset them on the error path.
46+
std::optional<torch::Tensor> input;
47+
4048
cinfo.err = jpeg_std_error(&jerr.pub);
4149
jerr.pub.error_exit = torch_jpeg_error_exit;
4250

@@ -45,6 +53,7 @@ torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) {
4553
/* If we get here, the JPEG code has signaled an error.
4654
* We need to clean up the JPEG object and the buffer.
4755
*/
56+
input.reset();
4857
jpeg_destroy_compress(&cinfo);
4958
if (jpegBuf != nullptr) {
5059
free(jpegBuf);
@@ -69,7 +78,7 @@ torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) {
6978
int channels = data.size(0);
7079
int height = data.size(1);
7180
int width = data.size(2);
72-
auto input = data.permute({1, 2, 0}).contiguous();
81+
input = data.permute({1, 2, 0}).contiguous();
7382

7483
STD_TORCH_CHECK(
7584
channels == 1 || channels == 3,
@@ -95,7 +104,7 @@ torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) {
95104
jpeg_start_compress(&cinfo, TRUE);
96105

97106
auto stride = width * channels;
98-
auto ptr = input.data_ptr<uint8_t>();
107+
auto ptr = input->data_ptr<uint8_t>();
99108

100109
// Encode JPEG file
101110
while (cinfo.next_scanline < cinfo.image_height) {

torchvision/csrc/io/image/cpu/encode_png.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
#include <torch/headeronly/util/Exception.h>
44

5+
#include <optional>
6+
57
#include "common_png.h"
68

79
namespace vision {
@@ -78,11 +80,19 @@ torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) {
7880
buf_info.buffer = nullptr;
7981
buf_info.size = 0;
8082

83+
// NOTE: libpng uses setjmp/longjmp for error handling. longjmp does not
84+
// unwind C++ stack frames, so destructors of objects created after setjmp
85+
// won't run. We use std::optional to declare tensors before setjmp while
86+
// deferring construction, and explicitly reset them on the error path.
87+
std::optional<torch::Tensor> input;
88+
8189
/* Establish the setjmp return context for my_error_exit to use. */
8290
if (setjmp(err_ptr.setjmp_buffer)) {
8391
/* If we get here, the PNG code has signaled an error.
8492
* We need to clean up the PNG object and the buffer.
8593
*/
94+
input.reset();
95+
8696
if (info_ptr != nullptr) {
8797
png_destroy_info_struct(png_write, &info_ptr);
8898
}
@@ -119,7 +129,7 @@ torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) {
119129
int channels = data.size(0);
120130
int height = data.size(1);
121131
int width = data.size(2);
122-
auto input = data.permute({1, 2, 0}).contiguous();
132+
input = data.permute({1, 2, 0}).contiguous();
123133

124134
STD_TORCH_CHECK(
125135
channels == 1 || channels == 3,
@@ -155,7 +165,7 @@ torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) {
155165
png_write_info(png_write, info_ptr);
156166

157167
auto stride = width * channels;
158-
auto ptr = input.data_ptr<uint8_t>();
168+
auto ptr = input->data_ptr<uint8_t>();
159169

160170
// Encode PNG file
161171
for (int y = 0; y < height; ++y) {

0 commit comments

Comments
 (0)