-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathpreprocessing.py
More file actions
137 lines (105 loc) · 4.28 KB
/
preprocessing.py
File metadata and controls
137 lines (105 loc) · 4.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
import pandas as pd
import argparse
import yaml
from deepface import DeepFace
from PIL import Image
from torchvision.transforms import v2
from tqdm import tqdm
from utils import configure_logging
# Module-level logger obtained from the project-local ``utils.configure_logging``
# helper; used below through ``logger.info`` / ``logger.warning`` / ``logger.error``.
logger = configure_logging(__name__)
# File extensions accepted for processing. ``main`` lower-cases each filename's
# extension before checking membership, so entries here must be lower-case.
EXTS = [".jpg", ".jpeg", ".png"]
def load_yaml_params():
    """Return the ``prepare`` section of ``params.yaml``.

    The file is resolved relative to the current working directory.

    Returns:
        Whatever value is stored under the top-level ``prepare`` key.

    Raises:
        FileNotFoundError: if ``params.yaml`` cannot be opened.
        KeyError: if the parsed mapping has no ``prepare`` key.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    with open("params.yaml", "r", encoding="utf-8") as f:
        return yaml.safe_load(f)["prepare"]
def _parse_img_size(value):
    """Convert a command-line size string into a tuple of ints.

    Accepts comma-separated integers with optional surrounding parentheses
    and whitespace, e.g. ``"224,224"``, ``"(224, 224)"`` or ``"(224,)"``.

    Raises:
        argparse.ArgumentTypeError: if the string contains no integers or any
            component is not an integer.
    """
    parts = [part.strip() for part in value.strip().strip("()").split(",")]
    # A single trailing comma, as in the tuple repr "(224,)", is harmless.
    if parts[-1] == "":
        parts.pop()
    try:
        size = tuple(int(part) for part in parts)
    except ValueError:
        size = ()
    if not size:
        raise argparse.ArgumentTypeError(
            f"expected comma-separated integers such as '224,224', got {value!r}"
        )
    return size


def parse_args(argv=None):
    """Build the command-line interface for this script and parse it.

    The default of every option is taken from the mapping returned by
    ``load_yaml_params()``; flags given on the command line override them.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``input_dir``, ``input_metadata_file``,
        ``output_dir``, ``output_metadata_file`` and ``img_size`` (a tuple
        of ints).
    """
    yaml_params = load_yaml_params()
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", type=str, default=yaml_params["input_dir"])
    parser.add_argument(
        "--input_metadata_file", type=str, default=yaml_params["input_metadata_file"]
    )
    parser.add_argument("--output_dir", type=str, default=yaml_params["output_dir"])
    parser.add_argument(
        "--output_metadata_file", type=str, default=yaml_params["output_metadata_file"]
    )
    parser.add_argument(
        "--img_size",
        # A named converter makes argparse report a readable error instead of
        # "invalid <lambda> value".
        type=_parse_img_size,
        default=tuple(yaml_params["img_size"]),
    )
    return parser.parse_args(argv)
def _dominant_label(analysis, attribute):
    """Return the dominant class DeepFace reported for *attribute*, or "unknown".

    DeepFace.analyze reports the winning class under the top-level key
    ``dominant_<attribute>``, while ``<attribute>`` itself holds the per-class
    scores. Other layouts (the value nested inside ``<attribute>``, or
    ``<attribute>`` being the label string itself) are tolerated as a fallback.
    """
    dominant_key = f"dominant_{attribute}"
    value = analysis.get(dominant_key)
    if value is None:
        nested = analysis.get(attribute)
        if isinstance(nested, dict):
            value = nested.get(dominant_key)
        elif isinstance(nested, str):
            value = nested
    return value if value else "unknown"


def main():
    """Resize and label the images listed in the input metadata CSV.

    For every row of the input metadata whose ``filename`` has an extension
    in ``EXTS``:

    1. run ``DeepFace.analyze`` for the ``gender`` and ``race`` actions;
    2. load the image as RGB, resize it to ``img_size`` and save it under
       ``<output_dir>/real`` when ``target == 1``, otherwise ``<output_dir>/fake``;
    3. record filename, relative output path, target, gender and race.

    The collected records are written to the output metadata CSV (appended
    to it when the file already exists). Only after the CSV has been written
    are the successfully processed source images deleted from the input
    directory.

    Raises:
        ValueError: if the input directory or the input metadata file is missing.
    """
    args = parse_args()
    input_dir = args.input_dir
    input_metadata_file = args.input_metadata_file
    output_dir = args.output_dir
    output_metadata_file = args.output_metadata_file
    img_size = args.img_size

    if not os.path.isdir(input_dir):
        raise ValueError("Input directory does not exist!")
    if not os.path.exists(input_metadata_file):
        raise ValueError("Input metadata file does not exist!")

    logger.info("Reading input metadata from %s", input_metadata_file)
    input_metadata = pd.read_csv(input_metadata_file)

    os.makedirs(output_dir, exist_ok=True)
    transforms = v2.Compose([v2.PILToTensor(), v2.Resize(img_size), v2.ToPILImage()])

    output_data = []
    imgs_to_remove = []

    logger.info("Processing images...")
    for _, row in tqdm(
        input_metadata.iterrows(),
        total=len(input_metadata),
        desc="Processing",
        unit="img",
    ):
        filename = row["filename"]
        ext = os.path.splitext(filename)[-1].lower()
        if ext not in EXTS:
            logger.warning("Skipping unsupported file: %s", filename)
            continue

        img_path = os.path.join(input_dir, filename)

        try:
            objs = DeepFace.analyze(
                img_path=img_path, actions=["gender", "race"], enforce_detection=False
            )
            # Only the first analysed face is used when a list is returned.
            obj = objs[0] if isinstance(objs, list) else objs
        except Exception as e:
            logger.error("DeepFace failed for %s: %s", filename, str(e))
            continue

        try:
            # Close the source file handle as soon as the pixels are copied.
            with Image.open(img_path) as src_img:
                pil_img = src_img.convert("RGB")
            processed_img = transforms(pil_img)

            label = "real" if row["target"] == 1 else "fake"
            output_img_path = os.path.join(output_dir, label, filename)
            # Also covers filenames that contain a sub-directory component.
            os.makedirs(os.path.dirname(output_img_path), exist_ok=True)

            # Build the record before saving so a metadata problem can never
            # leave an image that is saved and queued for deletion but has
            # no metadata row.
            record = {
                "filename": filename,
                "img_path": os.path.join(label, filename),
                "target": row["target"],
                "gender": _dominant_label(obj, "gender"),
                "race": _dominant_label(obj, "race"),
            }
            processed_img.save(output_img_path)
        except Exception as e:
            logger.error("Failed to process image %s: %s", filename, str(e))
            continue

        output_data.append(record)
        imgs_to_remove.append(img_path)

    if not output_data:
        logger.warning("No data to save!")
        return

    new_df = pd.DataFrame(output_data)
    if os.path.exists(output_metadata_file):
        logger.info("Appending to existing metadata file: %s", output_metadata_file)
        existing_df = pd.read_csv(output_metadata_file)
        updated_df = pd.concat([existing_df, new_df], ignore_index=True)
    else:
        logger.info("Creating new metadata file: %s", output_metadata_file)
        updated_df = new_df

    # Make sure the CSV's parent directory exists so the results of the whole
    # run are not lost at the very last step.
    metadata_dir = os.path.dirname(output_metadata_file)
    if metadata_dir:
        os.makedirs(metadata_dir, exist_ok=True)
    updated_df.to_csv(output_metadata_file, index=False)
    logger.info("Metadata saved to: %s", output_metadata_file)

    # Source images are deleted only after the metadata is safely on disk.
    # One failed removal (e.g. a filename listed twice) must not abort the
    # clean-up of the remaining files.
    for img_path in imgs_to_remove:
        try:
            os.remove(img_path)
        except OSError as e:
            logger.error("Could not remove %s: %s", img_path, str(e))
# Run the preprocessing pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()