Skip to content
36 changes: 20 additions & 16 deletions text_to_image/tools/coco_generate_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
logging.basicConfig(level=logging.INFO)
log = logging.getLogger("coco")


def get_args():
"""Parse commandline."""
parser = argparse.ArgumentParser()
Expand All @@ -41,6 +42,7 @@ def get_args():
args = parser.parse_args()
return args


def download_file(url: str, output_dir: Path, filename: str | None = None):
os.makedirs(str(output_dir), exist_ok=True)

Expand All @@ -65,6 +67,7 @@ def download_file(url: str, output_dir: Path, filename: str | None = None):

return output_path


if __name__ == "__main__":
args = get_args()
dataset_dir = os.path.abspath(args.dataset_dir)
Expand All @@ -80,21 +83,22 @@ def download_file(url: str, output_dir: Path, filename: str | None = None):
calibration_dir = Path(calibration_dir)

# Check if raw annotations file already exists
if not (dataset_dir / "raw" / "annotations" / "captions_train2014.json").exists():
# Download annotations
os.makedirs(str(dataset_dir / "raw"), exist_ok=True)
os.makedirs(str(dataset_dir / "download_aux"), exist_ok=True)
download_file(
url="http://images.cocodataset.org/annotations/annotations_trainval2014.zip",
output_dir=dataset_dir / "download_aux",
)
# Expected location of the downloaded archive
zipfile_path = dataset_dir / "download_aux" / "annotations_trainval2014.zip"
# Unzip file
with zipfile.ZipFile(
str(zipfile_path), "r"
) as zip_ref:
zip_ref.extractall(str(dataset_dir / "raw/"))
if not (dataset_dir / "raw" / "annotations" /
"captions_train2014.json").exists():
# Download annotations
os.makedirs(str(dataset_dir / "raw"), exist_ok=True)
os.makedirs(str(dataset_dir / "download_aux"), exist_ok=True)
download_file(
url="http://images.cocodataset.org/annotations/annotations_trainval2014.zip",
output_dir=dataset_dir / "download_aux",
)
# Expected location of the downloaded archive
zipfile_path = dataset_dir / "download_aux" / "annotations_trainval2014.zip"
# Unzip file
with zipfile.ZipFile(
str(zipfile_path), "r"
) as zip_ref:
zip_ref.extractall(str(dataset_dir / "raw/"))

# Convert to dataframe format and extract the relevant fields
with open(dataset_dir / "raw" / "annotations" / "captions_train2014.json") as f:
Expand Down Expand Up @@ -133,4 +137,4 @@ def download_file(url: str, output_dir: Path, filename: str | None = None):
s = "\n".join([str(_) for _ in df_annotations["id"].values])
f.write(s)
# Remove Folder
shutil.rmtree(dataset_dir)
shutil.rmtree(dataset_dir)
2 changes: 1 addition & 1 deletion tools/submission/repository_checks.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ if [ ${#FILES_GREATER_THAN_50} -gt 0 ] ||
[ ${#BAD_FOLDER_NAMES} -gt 0 ] ||
[ ${#SPACE_FILE_NAMES} -gt 0 ]
then
errors="ERRORS:\n;
errors="ERRORS:
FILES GREATER THAN 50MB:
${FILES_GREATER_THAN_50};
SYMBOLIC LINKS:
Expand Down
Loading