Skip to content

Commit 8f9cdf0

Browse files
tv-karthikeyaAmit Raj
authored andcommitted
Adding pytest for flux (#18)
Signed-off-by: vtirumal <vtirumal@qti.qualcomm.com> Signed-off-by: Amit Raj <amitraj@qti.qualcommm.com>
1 parent 6d5bac4 commit 8f9cdf0

File tree

3 files changed

+746
-0
lines changed

3 files changed

+746
-0
lines changed

tests/diffusers/diffusers_utils.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
# -----------------------------------------------------------------------------
2+
#
3+
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4+
# SPDX-License-Identifier: BSD-3-Clause
5+
#
6+
# -----------------------------------------------------------------------------
7+
8+
"""
9+
Common utilities for diffusion pipeline testing.
10+
Provides essential functions for MAD validation, image validation
11+
hash verification, and other testing utilities.
12+
"""
13+
14+
import os
15+
from typing import Any, Dict, Tuple, Union
16+
17+
import numpy as np
18+
import torch
19+
from PIL import Image
20+
21+
22+
class DiffusersTestUtils:
23+
"""Essential utilities for diffusion pipeline testing"""
24+
25+
@staticmethod
26+
def validate_image_generation(
27+
image: Image.Image, expected_size: Tuple[int, int], min_variance: float = 1.0
28+
) -> Dict[str, Any]:
29+
"""
30+
Validate generated image properties.
31+
Args:
32+
image: Generated PIL Image
33+
expected_size: Expected (width, height) tuple
34+
min_variance: Minimum pixel variance to ensure image is not blank
35+
36+
Returns:
37+
Dict containing validation results
38+
Raises:
39+
AssertionError: If image validation fails
40+
"""
41+
# Basic image validation
42+
assert isinstance(image, Image.Image), f"Expected PIL Image, got {type(image)}"
43+
assert image.size == expected_size, f"Expected size {expected_size}, got {image.size}"
44+
assert image.mode in ["RGB", "RGBA"], f"Unexpected image mode: {image.mode}"
45+
46+
# Variance check (ensure image is not blank)
47+
img_array = np.array(image)
48+
image_variance = float(img_array.std())
49+
assert image_variance > min_variance, f"Generated image appears blank (variance: {image_variance:.2f})"
50+
51+
return {
52+
"size": image.size,
53+
"mode": image.mode,
54+
"variance": image_variance,
55+
"mean_pixel_value": float(img_array.mean()),
56+
"min_pixel": int(img_array.min()),
57+
"max_pixel": int(img_array.max()),
58+
"valid": True,
59+
}
60+
61+
@staticmethod
62+
def check_file_exists(file_path: str, file_type: str = "file") -> bool:
63+
"""
64+
Check if file exists and log result.
65+
Args:
66+
file_path: Path to check
67+
file_type: Description of file type for logging
68+
Returns:
69+
bool: True if file exists
70+
"""
71+
exists = os.path.exists(file_path)
72+
status = "✅" if exists else "❌"
73+
print(f"{status} {file_type}: {file_path}")
74+
return exists
75+
76+
@staticmethod
77+
def print_test_header(title: str, config: Dict[str, Any]) -> None:
78+
"""
79+
Print formatted test header with configuration details.
80+
81+
Args:
82+
title: Test title
83+
config: Test configuration dictionary
84+
"""
85+
print(f"\n{'=' * 80}")
86+
print(f"{title}")
87+
print(f"{'=' * 80}")
88+
89+
if "model_setup" in config:
90+
setup = config["model_setup"]
91+
for k, v in setup.items():
92+
print(f"{k} : {v}")
93+
94+
if "functional_testing" in config:
95+
func = config["functional_testing"]
96+
print(f"Test Prompt: {func.get('test_prompt', 'N/A')}")
97+
print(f"Inference Steps: {func.get('num_inference_steps', 'N/A')}")
98+
print(f"Guidance Scale: {func.get('guidance_scale', 'N/A')}")
99+
100+
print(f"{'=' * 80}")
101+
102+
103+
class MADValidator:
104+
"""Specialized class for MAD validation - always enabled, always reports, always fails on exceed"""
105+
106+
def __init__(self, tolerances: Dict[str, float] = None):
107+
"""
108+
Initialize MAD validator.
109+
MAD validation is always enabled, always reports values, and always fails if tolerance is exceeded.
110+
111+
Args:
112+
tolerances: Dictionary of module_name -> tolerance mappings
113+
"""
114+
self.tolerances = tolerances
115+
self.results = {}
116+
117+
def calculate_mad(
118+
self, tensor1: Union[torch.Tensor, np.ndarray], tensor2: Union[torch.Tensor, np.ndarray]
119+
) -> float:
120+
"""
121+
Calculate Max Absolute Deviation between two tensors.
122+
123+
Args:
124+
tensor1: First tensor (PyTorch or NumPy)
125+
tensor2: Second tensor (PyTorch or NumPy)
126+
127+
Returns:
128+
float: Maximum absolute difference between tensors
129+
"""
130+
if isinstance(tensor1, torch.Tensor):
131+
tensor1 = tensor1.detach().numpy()
132+
if isinstance(tensor2, torch.Tensor):
133+
tensor2 = tensor2.detach().numpy()
134+
135+
return float(np.max(np.abs(tensor1 - tensor2)))
136+
137+
def validate_module_mad(
138+
self,
139+
pytorch_output: Union[torch.Tensor, np.ndarray],
140+
qaic_output: Union[torch.Tensor, np.ndarray],
141+
module_name: str,
142+
step_info: str = "",
143+
) -> bool:
144+
"""
145+
Validate MAD for a specific module.
146+
Always validates, always reports, always fails if tolerance exceeded.
147+
148+
Args:
149+
pytorch_output: PyTorch reference output
150+
qaic_output: QAIC inference output
151+
module_name: Name of the module
152+
step_info: Additional step information for logging
153+
154+
Returns:
155+
bool: True if validation passed
156+
157+
Raises:
158+
AssertionError: If MAD exceeds tolerance
159+
"""
160+
mad_value = self.calculate_mad(pytorch_output, qaic_output)
161+
162+
# Always report MAD value
163+
step_str = f" {step_info}" if step_info else ""
164+
print(f"🔍 {module_name.upper()} MAD{step_str}: {mad_value:.8f}")
165+
166+
# Always validate - fail if tolerance exceeded
167+
tolerance = self.tolerances.get(module_name, 1e-2)
168+
if mad_value > tolerance:
169+
raise AssertionError(f"{module_name} MAD {mad_value:.6f} exceeds tolerance {tolerance:.6f}")
170+
171+
# Store result
172+
if module_name not in self.results:
173+
self.results[module_name] = []
174+
self.results[module_name].append({"mad": mad_value, "step_info": step_info, "tolerance": tolerance})
175+
return True
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
{
2+
"model_setup": {
3+
"height": 256,
4+
"width": 256,
5+
"num_transformer_layers": 2,
6+
"num_single_layers": 2,
7+
"use_onnx_subfunctions": false
8+
},
9+
"mad_validation": {
10+
"tolerances": {
11+
"clip_text_encoder": 0.1,
12+
"t5_text_encoder": 5.5,
13+
"transformer": 2.0,
14+
"vae_decoder": 1.0
15+
}
16+
},
17+
"pipeline_params": {
18+
"test_prompt": "A cat holding a sign that says hello world",
19+
"num_inference_steps": 2,
20+
"guidance_scale": 0.0,
21+
"max_sequence_length": 256,
22+
"validate_gen_img": true,
23+
"min_image_variance": 1.0,
24+
"custom_config_path": null
25+
},
26+
"validation_checks": {
27+
"image_generation": true,
28+
"onnx_export": true,
29+
"compilation": true
30+
},
31+
"modules":
32+
{
33+
"text_encoder":
34+
{
35+
"specializations":{
36+
"batch_size": 1,
37+
"seq_len": 77
38+
},
39+
"compilation":
40+
{
41+
"onnx_path": null,
42+
"compile_dir": null,
43+
"mdp_ts_num_devices": 1,
44+
"mxfp6_matmul": false,
45+
"convert_to_fp16": true,
46+
"aic_num_cores": 16
47+
},
48+
"execute":
49+
{
50+
"device_ids": null
51+
}
52+
53+
},
54+
"text_encoder_2":
55+
{
56+
"specializations":
57+
{
58+
"batch_size": 1,
59+
"seq_len": 256
60+
},
61+
"compilation":
62+
{
63+
"onnx_path": null,
64+
"compile_dir": null,
65+
"mdp_ts_num_devices": 1,
66+
"mxfp6_matmul": false,
67+
"convert_to_fp16": true,
68+
"aic_num_cores": 16
69+
},
70+
"execute":
71+
{
72+
"device_ids": null
73+
}
74+
},
75+
"transformer":
76+
{
77+
"specializations":
78+
{
79+
"batch_size": 1,
80+
"seq_len": 256,
81+
"steps": 1
82+
},
83+
"compilation":
84+
{
85+
"onnx_path": null,
86+
"compile_dir": null,
87+
"mdp_ts_num_devices": 1,
88+
"mxfp6_matmul": true,
89+
"convert_to_fp16": true,
90+
"aic_num_cores": 16,
91+
"mos": 1,
92+
"mdts-mos": 1,
93+
"aic-enable-depth-first": true
94+
},
95+
"execute":
96+
{
97+
"device_ids": null
98+
}
99+
},
100+
"vae_decoder":
101+
{
102+
"specializations":
103+
{
104+
"batch_size": 1,
105+
"channels": 16
106+
},
107+
"compilation":
108+
{
109+
"onnx_path": null,
110+
"compile_dir": null,
111+
"mdp_ts_num_devices": 1,
112+
"mxfp6_matmul": false,
113+
"convert_to_fp16": true,
114+
"aic_num_cores": 16
115+
},
116+
"execute":
117+
{
118+
"device_ids": null
119+
}
120+
}
121+
}
122+
123+
}

0 commit comments

Comments
 (0)