Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 68 additions & 1 deletion landmarkdiff/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,14 @@ def generate(
else:
composited = mask_composite(raw_output, image_512, mask)

# confidence scoring and breakdown
confidence_data = self._calculate_confidence(
face=face,
identity_check=identity_check,
mask=mask,
mode=self.mode,
)

return {
"output": composited,
"output_raw": raw_output,
Expand All @@ -378,6 +386,56 @@ def generate(
"ip_adapter_active": self._ip_adapter_loaded,
"identity_check": identity_check,
"restore_used": restore_used,
"confidence": confidence_data["confidence"],
"confidence_breakdown": confidence_data["breakdown"],
}

def _calculate_confidence(
self,
face: Optional[FaceLandmarks],
identity_check: Optional[dict],
mask: np.ndarray,
mode: str,
) -> dict:
"""Aggregate multiple metrics into a unified confidence score (0-1)."""
breakdown = {
"face_detection": 1.0 if face is not None else 0.0,
"identity_preservation": 1.0,
"landmark_accuracy": 0.95, # baseline for MediaPipe/TPS
"mask_coverage": 1.0,
}

# identity score from ArcFace (if available)
if identity_check and identity_check.get("similarity", -1) >= 0:
breakdown["identity_preservation"] = float(identity_check["similarity"])
elif mode == "tps":
breakdown["identity_preservation"] = 1.0 # TPS is identity-perfect
else:
# without verification, we assume a lower bound for diffusion
breakdown["identity_preservation"] = 0.85 if mode == "controlnet_ip" else 0.7

# mask coverage (ensure mask isn't empty or oversized)
mask_area = np.mean(mask)
if mask_area < 0.01: # too small
breakdown["mask_coverage"] = 0.5
elif mask_area > 0.6: # too large (covers most of face)
breakdown["mask_coverage"] = 0.8
else:
breakdown["mask_coverage"] = 1.0

# overall confidence (weighted average)
weights = {
"face_detection": 0.2,
"identity_preservation": 0.5,
"landmark_accuracy": 0.2,
"mask_coverage": 0.1,
}

confidence = sum(breakdown[k] * weights[k] for k in weights)

return {
"confidence": round(confidence, 2),
"breakdown": {k: round(v, 2) for k, v in breakdown.items()},
}

def _generate_controlnet(
Expand Down Expand Up @@ -471,6 +529,7 @@ def run_inference(
seed: int = 42,
mode: str = "img2img",
ip_adapter_scale: float = 0.6,
explain: bool = False,
) -> None:
out = Path(output_dir)
out.mkdir(parents=True, exist_ok=True)
Expand Down Expand Up @@ -500,6 +559,13 @@ def run_inference(
if view.get("warning"):
print(f"WARNING: {view['warning']}")
print(f"Face view: {view.get('view', 'unknown')} (yaw={view.get('yaw', 0)})")

print(f"Confidence: {result['confidence']:.2f}")
if explain:
print("\nConfidence Breakdown:")
for k, v in result["confidence_breakdown"].items():
print(f" - {k.replace('_', ' ').capitalize()}: {v:.2f}")

print(f"Results saved to {out}/")


Expand All @@ -517,9 +583,10 @@ def run_inference(
choices=["img2img", "controlnet", "controlnet_ip", "tps"],
)
parser.add_argument("--ip-adapter-scale", type=float, default=0.6)
parser.add_argument("--explain", action="store_true", help="Show confidence breakdown")
args = parser.parse_args()

run_inference(
args.image, args.procedure, args.intensity, args.output,
args.seed, args.mode, args.ip_adapter_scale,
args.seed, args.mode, args.ip_adapter_scale, args.explain,
)