Skip to content

Commit c9ffdfe

Browse files
MPSFuzz, pytorchstu, and NicolasHug
committed
Fix CPU decode_jpeg error-path leak on malformed JPEGs (setjmp/longjmp) (pytorch#9423)
Co-authored-by: MPSFuzz <2286770808@qq.com>
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
1 parent 4fe736f commit c9ffdfe

4 files changed

Lines changed: 53 additions & 13 deletions

File tree

torchvision/csrc/io/image/cpu/decode_jpeg.cpp

Lines changed: 18 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,8 @@
33
#include "common_jpeg.h"
44
#include "exif.h"
55

6+
#include <optional>
7+
68
namespace vision {
79
namespace image {
810

@@ -141,12 +143,23 @@ torch::Tensor decode_jpeg(
141143
struct jpeg_decompress_struct cinfo;
142144
struct torch_jpeg_error_mgr jerr;
143145

146+
// NOTE: libjpeg uses setjmp/longjmp for error handling. longjmp does not
147+
// unwind C++ stack frames, so destructors of objects created after setjmp
148+
// won't run. We use std::optional to declare tensors before setjmp while
149+
// deferring construction, and explicitly reset them on the error path.
150+
std::optional<torch::Tensor> tensor;
151+
std::optional<torch::Tensor> cmyk_line_tensor;
152+
144153
auto datap = data.data_ptr<uint8_t>();
145154
// Setup decompression structure
146155
cinfo.err = jpeg_std_error(&jerr.pub);
147156
jerr.pub.error_exit = torch_jpeg_error_exit;
148157
/* Establish the setjmp return context for torch_jpeg_error_exit to use. */
149158
if (setjmp(jerr.setjmp_buffer)) {
159+
// Release any tensors that may have been allocated after setjmp.
160+
cmyk_line_tensor.reset();
161+
tensor.reset();
162+
150163
/* If we get here, the JPEG code has signaled an error.
151164
* We need to clean up the JPEG object.
152165
*/
@@ -209,10 +222,10 @@ torch::Tensor decode_jpeg(
209222
int width = cinfo.output_width;
210223

211224
int stride = width * channels;
212-
auto tensor =
225+
tensor =
213226
torch::empty({int64_t(height), int64_t(width), channels}, torch::kU8);
214-
auto ptr = tensor.data_ptr<uint8_t>();
215-
torch::Tensor cmyk_line_tensor;
227+
auto ptr = tensor->data_ptr<uint8_t>();
228+
216229
if (cmyk_to_rgb_or_gray) {
217230
cmyk_line_tensor = torch::empty({int64_t(width), 4}, torch::kU8);
218231
}
@@ -223,7 +236,7 @@ torch::Tensor decode_jpeg(
223236
* more than one scanline at a time if that's more convenient.
224237
*/
225238
if (cmyk_to_rgb_or_gray) {
226-
auto cmyk_line_ptr = cmyk_line_tensor.data_ptr<uint8_t>();
239+
auto cmyk_line_ptr = cmyk_line_tensor->data_ptr<uint8_t>();
227240
jpeg_read_scanlines(&cinfo, &cmyk_line_ptr, 1);
228241

229242
if (channels == 3) {
@@ -239,7 +252,7 @@ torch::Tensor decode_jpeg(
239252

240253
jpeg_finish_decompress(&cinfo);
241254
jpeg_destroy_decompress(&cinfo);
242-
auto output = tensor.permute({2, 0, 1});
255+
auto output = tensor->permute({2, 0, 1});
243256

244257
if (apply_exif_orientation) {
245258
return exif_orientation_transform(output, exif_orientation);

torchvision/csrc/io/image/cpu/decode_png.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
#include "common_png.h"
44
#include "exif.h"
55

6+
#include <optional>
7+
68
namespace vision {
79
namespace image {
810

@@ -45,7 +47,14 @@ torch::Tensor decode_png(
4547
auto datap = accessor.data();
4648
auto datap_len = accessor.size(0);
4749

50+
// NOTE: libpng uses setjmp/longjmp for error handling. longjmp does not
51+
// unwind C++ stack frames, so destructors of objects created after setjmp
52+
// won't run. We use std::optional to declare tensors before setjmp while
53+
// deferring construction, and explicitly reset them on the error path.
54+
std::optional<torch::Tensor> tensor;
55+
4856
if (setjmp(png_jmpbuf(png_ptr)) != 0) {
57+
tensor.reset();
4958
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
5059
TORCH_CHECK(false, "Internal error.");
5160
}
@@ -196,19 +205,19 @@ torch::Tensor decode_png(
196205

197206
auto num_pixels_per_row = width * channels;
198207
auto is_16_bits = bit_depth == 16;
199-
auto tensor = torch::empty(
208+
tensor = torch::empty(
200209
{int64_t(height), int64_t(width), channels},
201210
is_16_bits ? at::kUInt16 : torch::kU8);
202211
if (is_little_endian()) {
203212
png_set_swap(png_ptr);
204213
}
205-
auto t_ptr = (uint8_t*)tensor.data_ptr();
214+
auto t_ptr = (uint8_t*)tensor->data_ptr();
206215
for (int pass = 0; pass < number_of_passes; pass++) {
207216
for (png_uint_32 i = 0; i < height; ++i) {
208217
png_read_row(png_ptr, t_ptr, nullptr);
209218
t_ptr += num_pixels_per_row * (is_16_bits ? 2 : 1);
210219
}
211-
t_ptr = (uint8_t*)tensor.data_ptr();
220+
t_ptr = (uint8_t*)tensor->data_ptr();
212221
}
213222

214223
int exif_orientation = -1;
@@ -218,7 +227,7 @@ torch::Tensor decode_png(
218227

219228
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
220229

221-
auto output = tensor.permute({2, 0, 1});
230+
auto output = tensor->permute({2, 0, 1});
222231
if (apply_exif_orientation) {
223232
return exif_orientation_transform(output, exif_orientation);
224233
}

torchvision/csrc/io/image/cpu/encode_jpeg.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "encode_jpeg.h"
22

3+
#include <optional>
34
#include "common_jpeg.h"
45

56
namespace vision {
@@ -35,6 +36,12 @@ torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) {
3536
JpegSizeType jpegSize = 0;
3637
uint8_t* jpegBuf = nullptr;
3738

39+
// NOTE: libjpeg uses setjmp/longjmp for error handling. longjmp does not
40+
// unwind C++ stack frames, so destructors of objects created after setjmp
41+
// won't run. We use std::optional to declare tensors before setjmp while
42+
// deferring construction, and explicitly reset them on the error path.
43+
std::optional<torch::Tensor> input;
44+
3845
cinfo.err = jpeg_std_error(&jerr.pub);
3946
jerr.pub.error_exit = torch_jpeg_error_exit;
4047

@@ -43,6 +50,7 @@ torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) {
4350
/* If we get here, the JPEG code has signaled an error.
4451
* We need to clean up the JPEG object and the buffer.
4552
*/
53+
input.reset();
4654
jpeg_destroy_compress(&cinfo);
4755
if (jpegBuf != nullptr) {
4856
free(jpegBuf);
@@ -64,7 +72,7 @@ torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) {
6472
int channels = data.size(0);
6573
int height = data.size(1);
6674
int width = data.size(2);
67-
auto input = data.permute({1, 2, 0}).contiguous();
75+
input = data.permute({1, 2, 0}).contiguous();
6876

6977
TORCH_CHECK(
7078
channels == 1 || channels == 3,
@@ -90,7 +98,7 @@ torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) {
9098
jpeg_start_compress(&cinfo, TRUE);
9199

92100
auto stride = width * channels;
93-
auto ptr = input.data_ptr<uint8_t>();
101+
auto ptr = input->data_ptr<uint8_t>();
94102

95103
// Encode JPEG file
96104
while (cinfo.next_scanline < cinfo.image_height) {

torchvision/csrc/io/image/cpu/encode_png.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#include "encode_jpeg.h"
22

3+
#include <optional>
4+
35
#include "common_png.h"
46

57
namespace vision {
@@ -76,11 +78,19 @@ torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) {
7678
buf_info.buffer = nullptr;
7779
buf_info.size = 0;
7880

81+
// NOTE: libpng uses setjmp/longjmp for error handling. longjmp does not
82+
// unwind C++ stack frames, so destructors of objects created after setjmp
83+
// won't run. We use std::optional to declare tensors before setjmp while
84+
// deferring construction, and explicitly reset them on the error path.
85+
std::optional<torch::Tensor> input;
86+
7987
/* Establish the setjmp return context for my_error_exit to use. */
8088
if (setjmp(err_ptr.setjmp_buffer)) {
8189
/* If we get here, the PNG code has signaled an error.
8290
* We need to clean up the PNG object and the buffer.
8391
*/
92+
input.reset();
93+
8494
if (info_ptr != nullptr) {
8595
png_destroy_info_struct(png_write, &info_ptr);
8696
}
@@ -114,7 +124,7 @@ torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) {
114124
int channels = data.size(0);
115125
int height = data.size(1);
116126
int width = data.size(2);
117-
auto input = data.permute({1, 2, 0}).contiguous();
127+
input = data.permute({1, 2, 0}).contiguous();
118128

119129
TORCH_CHECK(
120130
channels == 1 || channels == 3,
@@ -150,7 +160,7 @@ torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) {
150160
png_write_info(png_write, info_ptr);
151161

152162
auto stride = width * channels;
153-
auto ptr = input.data_ptr<uint8_t>();
163+
auto ptr = input->data_ptr<uint8_t>();
154164

155165
// Encode PNG file
156166
for (int y = 0; y < height; ++y) {

0 commit comments

Comments
 (0)