diff --git a/data/pretrain_product_dataset.py b/data/pretrain_product_dataset.py index 2fcfce9..338ee09 100644 --- a/data/pretrain_product_dataset.py +++ b/data/pretrain_product_dataset.py @@ -43,7 +43,7 @@ def __len__(self): def __getitem__(self, index): item_id, caption, cate_name = self.data_list[index] - image_path = "{}/{}.png".format(self.image_path, item_id) + image_path = "{}/{}.jpg".format(self.image_path, item_id) try: image = Image.open(image_path).convert('RGB') except: