33#include " common_jpeg.h"
44#include " exif.h"
55
6+ #include < optional>
7+
68namespace vision {
79namespace image {
810
@@ -141,12 +143,23 @@ torch::Tensor decode_jpeg(
141143 struct jpeg_decompress_struct cinfo;
142144 struct torch_jpeg_error_mgr jerr;
143145
146+ // NOTE: libjpeg uses setjmp/longjmp for error handling. longjmp does not
147+ // unwind C++ stack frames, so destructors of objects created after setjmp
148+ // won't run. We use std::optional to declare tensors before setjmp while
149+ // deferring construction, and explicitly reset them on the error path.
150+ std::optional<torch::Tensor> tensor;
151+ std::optional<torch::Tensor> cmyk_line_tensor;
152+
144153 auto datap = data.data_ptr <uint8_t >();
145154 // Setup decompression structure
146155 cinfo.err = jpeg_std_error (&jerr.pub );
147156 jerr.pub .error_exit = torch_jpeg_error_exit;
148157 /* Establish the setjmp return context for my_error_exit to use. */
149158 if (setjmp (jerr.setjmp_buffer )) {
159+ // Release any tensors that may have been allocated after setjmp.
160+ cmyk_line_tensor.reset ();
161+ tensor.reset ();
162+
150163 /* If we get here, the JPEG code has signaled an error.
151164 * We need to clean up the JPEG object.
152165 */
@@ -209,10 +222,10 @@ torch::Tensor decode_jpeg(
209222 int width = cinfo.output_width ;
210223
211224 int stride = width * channels;
212- auto tensor =
225+ tensor =
213226 torch::empty ({int64_t (height), int64_t (width), channels}, torch::kU8 );
214- auto ptr = tensor. data_ptr <uint8_t >();
215- torch::Tensor cmyk_line_tensor;
227+ auto ptr = tensor-> data_ptr <uint8_t >();
228+
216229 if (cmyk_to_rgb_or_gray) {
217230 cmyk_line_tensor = torch::empty ({int64_t (width), 4 }, torch::kU8 );
218231 }
@@ -223,7 +236,7 @@ torch::Tensor decode_jpeg(
223236 * more than one scanline at a time if that's more convenient.
224237 */
225238 if (cmyk_to_rgb_or_gray) {
226- auto cmyk_line_ptr = cmyk_line_tensor. data_ptr <uint8_t >();
239+ auto cmyk_line_ptr = cmyk_line_tensor-> data_ptr <uint8_t >();
227240 jpeg_read_scanlines (&cinfo, &cmyk_line_ptr, 1 );
228241
229242 if (channels == 3 ) {
@@ -239,7 +252,7 @@ torch::Tensor decode_jpeg(
239252
240253 jpeg_finish_decompress (&cinfo);
241254 jpeg_destroy_decompress (&cinfo);
242- auto output = tensor. permute ({2 , 0 , 1 });
255+ auto output = tensor-> permute ({2 , 0 , 1 });
243256
244257 if (apply_exif_orientation) {
245258 return exif_orientation_transform (output, exif_orientation);
0 commit comments