33#include " common_jpeg.h"
44#include " exif.h"
55
6+ #include < optional>
7+
68namespace vision {
79namespace image {
810
@@ -141,12 +143,23 @@ torch::Tensor decode_jpeg(
141143 struct jpeg_decompress_struct cinfo;
142144 struct torch_jpeg_error_mgr jerr;
143145
146+ // NOTE: libjpeg uses setjmp/longjmp for error handling. longjmp does not
147+ // unwind C++ stack frames, so destructors of objects created after setjmp
148+ // won't run. We use std::optional to declare tensors before setjmp while
149+ // deferring construction, and explicitly reset them on the error path.
150+ std::optional<torch::Tensor> tensor;
151+ std::optional<torch::Tensor> cmyk_line_tensor;
152+
144153 auto datap = data.data_ptr <uint8_t >();
145154 // Setup decompression structure
146155 cinfo.err = jpeg_std_error (&jerr.pub );
147156 jerr.pub .error_exit = torch_jpeg_error_exit;
148157 /* Establish the setjmp return context for my_error_exit to use. */
149158 if (setjmp (jerr.setjmp_buffer )) {
159+ // Release any tensors that may have been allocated after setjmp.
160+ cmyk_line_tensor.reset ();
161+ tensor.reset ();
162+
150163 /* If we get here, the JPEG code has signaled an error.
151164 * We need to clean up the JPEG object.
152165 */
@@ -210,10 +223,10 @@ torch::Tensor decode_jpeg(
210223 int width = cinfo.output_width ;
211224
212225 int stride = width * channels;
213- auto tensor =
226+ tensor =
214227 torch::empty ({int64_t (height), int64_t (width), channels}, torch::kU8 );
215- auto ptr = tensor. data_ptr <uint8_t >();
216- torch::Tensor cmyk_line_tensor;
228+ auto ptr = tensor-> data_ptr <uint8_t >();
229+
217230 if (cmyk_to_rgb_or_gray) {
218231 cmyk_line_tensor = torch::empty ({int64_t (width), 4 }, torch::kU8 );
219232 }
@@ -224,7 +237,7 @@ torch::Tensor decode_jpeg(
224237 * more than one scanline at a time if that's more convenient.
225238 */
226239 if (cmyk_to_rgb_or_gray) {
227- auto cmyk_line_ptr = cmyk_line_tensor. data_ptr <uint8_t >();
240+ auto cmyk_line_ptr = cmyk_line_tensor-> data_ptr <uint8_t >();
228241 jpeg_read_scanlines (&cinfo, &cmyk_line_ptr, 1 );
229242
230243 if (channels == 3 ) {
@@ -240,7 +253,7 @@ torch::Tensor decode_jpeg(
240253
241254 jpeg_finish_decompress (&cinfo);
242255 jpeg_destroy_decompress (&cinfo);
243- auto output = tensor. permute ({2 , 0 , 1 });
256+ auto output = tensor-> permute ({2 , 0 , 1 });
244257
245258 if (apply_exif_orientation) {
246259 return exif_orientation_transform (output, exif_orientation);
0 commit comments