-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
139 lines (106 loc) · 4.02 KB
/
main.py
File metadata and controls
139 lines (106 loc) · 4.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import argparse
import sys
from pathlib import Path
# Make the in-repo ``src`` layout importable when this file is executed
# directly, without installing the package first.
PROJECT_ROOT = Path(__file__).resolve().parent
SRC_DIR = PROJECT_ROOT.joinpath("src")
if str(SRC_DIR) not in sys.path:
    # Position 0 puts the local checkout ahead of every other sys.path entry.
    sys.path.insert(0, str(SRC_DIR))
from warpless_docs.shadow_removal import DocShadowONNXRemover
# Lower-case file suffixes (leading dot included) accepted as input images.
SUPPORTED_EXTENSIONS = {
    ".bmp",
    ".jpeg",
    ".jpg",
    ".png",
    ".tif",
    ".tiff",
    ".webp",
}
def _positive_int(text: str) -> int:
    """Argparse ``type=`` callable that accepts only integers greater than zero.

    Used for ``--size`` so a zero or negative inference size is rejected with a
    usage error at parse time instead of being handed to the model.
    """
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid integer value: {text!r}") from None
    if value <= 0:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line options for the shadow-removal CLI.

    Args:
        argv: Arguments to parse, without the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        Namespace with ``input``, ``samples_dir``, ``fallback_samples_dir``,
        ``output_dir``, ``model``, ``size`` (``[WIDTH, HEIGHT]``, both > 0)
        and ``cpu``.

    Exits with status 2 (argparse's usage-error behaviour) on invalid options.
    """
    parser = argparse.ArgumentParser(description="WarpLess Docs - ML document shadow removal")
    parser.add_argument(
        "--input",
        default=None,
        help="Single input image path. If empty, framed outputs are processed first.",
    )
    parser.add_argument(
        "--samples-dir",
        default="outputs/framed",
        help="Preferred directory containing frame-corrected document photos.",
    )
    parser.add_argument(
        "--fallback-samples-dir",
        default="input/samples",
        # The preferred directory is configurable, so refer to the option
        # rather than to its default value.
        help="Fallback directory containing raw document photos if --samples-dir has no images.",
    )
    parser.add_argument(
        "--output-dir",
        default="outputs/deshadowed",
        help="Directory to save shadow-removed images.",
    )
    parser.add_argument(
        "--model",
        default="models/docshadow_sd7k.onnx",
        help="Path to DocShadow ONNX model.",
    )
    parser.add_argument(
        "--size",
        nargs=2,
        type=_positive_int,
        default=[768, 768],
        metavar=("WIDTH", "HEIGHT"),
        help="Inference resize size. Try 1024 1024 for higher quality.",
    )
    parser.add_argument("--cpu", action="store_true", help="Force CPU provider.")
    return parser.parse_args(argv)
def is_supported_image(path: Path) -> bool:
    """Tell whether *path* is a file whose suffix is a supported image type.

    The suffix comparison is case-insensitive, so ``scan.JPG`` is accepted.
    Directories and non-existent paths are rejected.
    """
    if not path.is_file():
        return False
    return path.suffix.lower() in SUPPORTED_EXTENSIONS
def find_images(directory: Path) -> list[Path]:
    """Collect every supported image below *directory*, recursing into subfolders.

    Returns the matches in sorted order, for a deterministic processing order,
    or an empty list when *directory* does not exist.
    """
    matches: list[Path] = []
    if directory.exists():
        matches = [candidate for candidate in directory.rglob("*") if is_supported_image(candidate)]
        matches.sort()
    return matches
def resolve_inputs(args: argparse.Namespace) -> list[Path]:
    """Decide which images to process, in priority order.

    1. ``args.input`` when it is set (non-empty); only existence is checked.
    2. Otherwise the images found under ``args.samples_dir``.
    3. Otherwise the images found under ``args.fallback_samples_dir``.

    Returns an empty list when none of these yields anything.

    Raises:
        FileNotFoundError: ``args.input`` is set but the path does not exist.
    """
    if args.input:
        single = Path(args.input)
        if single.exists():
            return [single]
        raise FileNotFoundError(f"Input image not found: {single}")

    preferred = find_images(Path(args.samples_dir))
    if preferred:
        print(f"Using frame-corrected images from: {args.samples_dir}")
        return preferred

    raw = find_images(Path(args.fallback_samples_dir))
    if not raw:
        return []
    print(f"No framed outputs found. Falling back to raw samples from: {args.fallback_samples_dir}")
    return raw
def build_output_path(input_path: Path, output_dir: Path) -> Path:
    """Map an input image to ``<output_dir>/<stem>_deshadowed.png``.

    Only the file stem survives: the source directory and original extension
    are dropped, and the output is always named as a PNG.
    """
    file_name = input_path.stem + "_deshadowed.png"
    return output_dir.joinpath(file_name)
def _unique_output_path(candidate: Path, taken: set[Path]) -> Path:
    """Return *candidate*, or a ``_2``, ``_3``, ... variant of it that is not in *taken*.

    ``build_output_path`` derives the output name from the input's stem only.
    Distinct inputs such as ``a/page.jpg`` and ``b/page.jpg`` (``find_images``
    searches recursively), or ``page.jpg`` and ``page.png``, therefore map to the
    same file. Without this, a later image would silently overwrite an earlier
    result from the same run.
    """
    if candidate not in taken:
        return candidate
    counter = 2
    while True:
        alternative = candidate.with_name(f"{candidate.stem}_{counter}{candidate.suffix}")
        if alternative not in taken:
            return alternative
        counter += 1


def main() -> None:
    """Run ML shadow removal over the resolved input images.

    Inputs come from ``--input`` or, when it is omitted, from the framed or
    fallback sample directories (see ``resolve_inputs``). Results are written
    into ``--output-dir``. A failure on one image is reported and the batch
    continues. A saved/failed/skipped summary is printed at the end.
    """
    args = parse_args()
    model_path = Path(args.model)
    output_dir = Path(args.output_dir)

    # Resolve inputs before constructing the remover. Building it presumably
    # loads the ONNX model (TODO confirm in warpless_docs), which is wasted
    # work, or an unrelated error, when there is nothing to process.
    input_paths = resolve_inputs(args)
    if not input_paths:
        print("No input images found.")
        print("Run frame correction first with: python scan_document_frame.py")
        return

    output_dir.mkdir(parents=True, exist_ok=True)
    remover = DocShadowONNXRemover(
        model_path=model_path,
        input_size=(args.size[0], args.size[1]),
        prefer_gpu=not args.cpu,
    )

    total = len(input_paths)
    print("WarpLess Docs shadow removal")
    print(f"Model : {model_path}")
    print(f"Output dir: {output_dir}")
    print(f"Providers : {remover.providers}")
    print(f"Images : {total}")
    print("-" * 68)

    claimed_outputs: set[Path] = set()
    saved = failed = skipped = 0
    for index, input_path in enumerate(input_paths, start=1):
        # A discovered file may have been removed before its turn comes.
        if not input_path.exists():
            print(f"[SKIP] Input not found: {input_path}")
            skipped += 1
            continue
        default_output = build_output_path(input_path, output_dir)
        output_path = _unique_output_path(default_output, claimed_outputs)
        claimed_outputs.add(output_path)
        print(f"[{index}/{total}] Processing: {input_path}")
        if output_path != default_output:
            print(f" note: {default_output.name} already used in this run; writing {output_path.name}")
        try:
            saved_path = remover.remove_shadow_from_path(input_path=input_path, output_path=output_path)
        except Exception as exc:  # batch boundary: report and move on to the next image
            failed += 1
            print(f" failed: {input_path}")
            print(f" reason: {exc}")
        else:
            saved += 1
            print(f" saved: {saved_path}")

    print("-" * 68)
    print(f"Done. {saved} saved, {failed} failed, {skipped} skipped.")
# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()