# main.py — detect tables on a page image with Table Transformer, then parse
# each cropped table with PaddleOCR PP-StructureV3.
# (Scraped GitHub UI chrome and gutter line numbers removed.)
# Imports
import os
from pathlib import Path
from PIL import Image
import torch
from transformers import AutoImageProcessor, TableTransformerForObjectDetection
from paddleocr import PPStructureV3
import time
# -- Run configuration -------------------------------------------------------
start_time = time.time()  # wall-clock start; total runtime is reported at exit

INPUT_IMG = "Path_to_the_image"
OUT_DIR = "output_directory_path_to_save_the_results"
DET_MODEL = "microsoft/table-transformer-detection"
DET_THRESH = 0.7     # minimum detection confidence for a table box to be kept
EXPAND_CM = 0.30     # margin added around each detected table, per side
DEFAULT_DPI = 300.0  # fallback when the image carries no DPI metadata
USE_GPU = torch.cuda.is_available()

# Make sure the output directory exists before any crops are written.
Path(OUT_DIR).mkdir(parents=True, exist_ok=True)
# This function is used to crop the image / box from the input
def crop_pil(img: Image.Image, box):
    """Crop *img* to the region given by box = (y1, x1, y2, x2).

    Reorders the coordinates into PIL's (left, upper, right, lower)
    convention before delegating to ``Image.crop``.
    """
    top, left, bottom, right = box
    return img.crop((left, top, right, bottom))
# This function is used to extract the DPI from the img metadata or else the default value will be used
def infer_image_dpi(pil_img: Image.Image, default: float = DEFAULT_DPI) -> float:
    """Return the horizontal DPI stored in the image metadata, or *default*.

    Pillow exposes DPI via ``Image.info["dpi"]``, usually as an
    ``(x_dpi, y_dpi)`` tuple, but some loaders store a single number
    (e.g. a float or TIFF ``IFDRational``).  Both forms are accepted
    here; the original version silently dropped non-tuple values.
    Anything missing, malformed, or non-positive falls back to *default*.
    """
    dpi = pil_img.info.get("dpi")
    # Normalize: take the horizontal component when a tuple is given.
    if isinstance(dpi, tuple) and len(dpi) >= 1:
        dpi = dpi[0]
    if dpi is not None:
        try:
            xdpi = float(dpi)
            if xdpi > 0:
                return xdpi
        except (TypeError, ValueError):
            # Unconvertible metadata value — ignore and use the default.
            pass
    return float(default)
# helper function to convert cm to pixels
def cm_to_px(cm: float, dpi: float) -> int:
    """Convert *cm* centimetres to whole pixels at *dpi*, clamped to >= 1."""
    raw = round(cm * dpi / 2.54)
    return max(1, int(raw))
# Expand the detected box by EXPAND_CM (0.3 cm) on each side so the
# cropped table keeps a small margin around it.
def expand_box_cm_on_page(box, page_w, page_h, cm_each_side, dpi):
    """Grow a (y1, x1, y2, x2) box by *cm_each_side* on every edge.

    The expanded box is clipped to the page bounds (page_w x page_h)
    and, if that leaves it degenerate, stretched back to a minimum
    extent of 2 pixels per axis.  Returns (y1, x1, y2, x2).
    """
    top, left, bottom, right = (int(v) for v in box)
    pad = cm_to_px(cm_each_side, dpi)

    top = max(0, top - pad)
    left = max(0, left - pad)
    bottom = min(page_h, bottom + pad)
    right = min(page_w, right + pad)

    # Guarantee at least a 2-pixel extent along each axis.
    if bottom - top < 2:
        bottom = min(page_h, top + 2)
    if right - left < 2:
        right = min(page_w, left + 2)
    return (top, left, bottom, right)
# ---- Load detection model ----
# Table Transformer detector from HuggingFace; eval() disables dropout/batchnorm
# training behavior for inference.
device = torch.device("cuda" if USE_GPU else "cpu")
det_processor = AutoImageProcessor.from_pretrained(DET_MODEL)
det_model = TableTransformerForObjectDetection.from_pretrained(DET_MODEL).to(device).eval()
# ---- Load PaddleOCR structure model ----
# PP-StructureV3 parses cropped table images into structured output.
pipeline = PPStructureV3(device="gpu" if USE_GPU else "cpu")
# ---- Detect tables ----
# Force RGB so downstream models get 3-channel input regardless of source format.
page_img = Image.open(INPUT_IMG).convert("RGB")
W, H = page_img.size  # page width/height in pixels
dpi = infer_image_dpi(page_img, default=DEFAULT_DPI)  # used to convert cm margin to px
# Run the detector once over the full page; no_grad avoids building autograd state.
inputs = det_processor(images=page_img, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = det_model(**inputs)
# Convert raw model outputs to thresholded detections scaled back to page pixels.
detections = det_processor.post_process_object_detection(
    outputs, threshold=DET_THRESH, target_sizes=[(H, W)])[0]
# for understanding
# Debug output: dump the raw detections before cropping.
print("DPI OF THE IMAGE :",dpi)
print("NOW i will be printing the detections , boxes and scores")
print("Detections")
print(detections)
# Boxes are (x1, y1, x2, y2) pixel coordinates — matches the unpacking in the
# crop loop below.
boxes = detections["boxes"].cpu().numpy().tolist()
scores = detections["scores"].cpu().numpy().tolist()
print("BOXES")
print(boxes)
print("Scores")
print(scores)
if not boxes:
    print("No tables detected.")
else:
    # Crop each detected table (with margin) and parse it with PP-StructureV3.
    for i, (b, sc) in enumerate(zip(boxes, scores)):
        print(f"Box {i}: {b}, score: {sc:.3f}")
        x1, y1, x2, y2 = map(int, b)
        # Reorder to the (y1, x1, y2, x2) convention used by the box helpers.
        box_yx = (y1, x1, y2, x2)
        y1e, x1e, y2e, x2e = expand_box_cm_on_page(box_yx, W, H, EXPAND_CM, dpi)
        crop_img = crop_pil(page_img, (y1e, x1e, y2e, x2e))
        crop_path = os.path.join(OUT_DIR, f"table_{i:02d}.png")
        crop_img.save(crop_path)
        print(f"Table {i} cropped: {crop_path} (score={sc:.3f})")
        # ---- Parse with PaddleOCR ----
        output = pipeline.predict(crop_path)
        for j, res in enumerate(output):
            res.print()
            # NOTE(review): every table writes into the same OUT_DIR — confirm
            # save_to_xlsx names its files uniquely per input, otherwise results
            # from successive tables may overwrite each other.
            res.save_to_xlsx(save_path=OUT_DIR)
        print(f"Parsed results saved in {OUT_DIR}")
# Report total wall-clock runtime in minutes.
end_time = time.time()
elapsed = end_time - start_time
elapsed= elapsed/60.0
print(f"Execution time: {elapsed:.2f} mins")