forked from ace-step/ACE-Step
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: infer.py
More file actions
114 lines (102 loc) · 3.48 KB
/
infer.py
File metadata and controls
114 lines (102 loc) · 3.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import click
import os
from acestep.pipeline_ace_step import ACEStepPipeline
from acestep.data_sampler import DataSampler
def sample_data(json_data):
    """Flatten a sampled example dict into the positional tuple used by ``main``.

    Args:
        json_data: dict holding the generation parameters for one example.
            ``actual_seeds`` and ``oss_steps`` are sequences that get joined
            into comma-separated strings; ``guidance_scale_text`` and
            ``guidance_scale_lyric`` are optional and default to 0.0.

    Returns:
        An 18-tuple in the exact order expected by the unpacking in ``main``.
    """
    return (
        json_data["audio_duration"],
        json_data["prompt"],
        json_data["lyrics"],
        json_data["infer_step"],
        json_data["guidance_scale"],
        json_data["scheduler_type"],
        json_data["cfg_type"],
        json_data["omega_scale"],
        # Seed and step lists travel through the pipeline as comma-separated strings.
        ", ".join(map(str, json_data["actual_seeds"])),
        json_data["guidance_interval"],
        json_data["guidance_interval_decay"],
        json_data["min_guidance_scale"],
        json_data["use_erg_tag"],
        json_data["use_erg_lyric"],
        json_data["use_erg_diffusion"],
        ", ".join(map(str, json_data["oss_steps"])),
        # Older sampled configs may omit these two keys; fall back to 0.0
        # (dict.get replaces the original double-lookup `in`-check idiom).
        json_data.get("guidance_scale_text", 0.0),
        json_data.get("guidance_scale_lyric", 0.0),
    )
@click.command()
@click.option(
    "--checkpoint_path", type=str, default="", help="Path to the checkpoint directory"
)
@click.option("--bf16", type=bool, default=True, help="Whether to use bfloat16")
@click.option(
    "--torch_compile", type=bool, default=False, help="Whether to use torch compile"
)
@click.option(
    "--cpu_offload", type=bool, default=False, help="Whether to use CPU offloading (only load current stage's model to GPU)"
)
@click.option(
    "--overlapped_decode", type=bool, default=False, help="Whether to use overlapped decoding (run dcae and vocoder using sliding windows)"
)
@click.option("--device_id", type=int, default=0, help="Device ID to use")
@click.option("--output_path", type=str, default=None, help="Path to save the output")
def main(checkpoint_path, bf16, torch_compile, cpu_offload, overlapped_decode, device_id, output_path):
    """Sample one random example config and run ACE-Step inference on it."""
    # Pin the visible GPU before any CUDA context is created by the pipeline.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)

    pipeline = ACEStepPipeline(
        checkpoint_dir=checkpoint_path,
        dtype="bfloat16" if bf16 else "float32",
        torch_compile=torch_compile,
        cpu_offload=cpu_offload,
        overlapped_decode=overlapped_decode,
    )
    print(pipeline)

    # Draw a random example and flatten it to the positional parameter tuple.
    params = sample_data(DataSampler().sample())
    print(params)

    # Keyword names for each position of the sample_data() tuple, in order.
    param_names = (
        "audio_duration",
        "prompt",
        "lyrics",
        "infer_step",
        "guidance_scale",
        "scheduler_type",
        "cfg_type",
        "omega_scale",
        "manual_seeds",
        "guidance_interval",
        "guidance_interval_decay",
        "min_guidance_scale",
        "use_erg_tag",
        "use_erg_lyric",
        "use_erg_diffusion",
        "oss_steps",
        "guidance_scale_text",
        "guidance_scale_lyric",
    )
    pipeline(save_path=output_path, **dict(zip(param_names, params)))
# Script entry point: let click parse CLI options and dispatch to main().
if __name__ == "__main__":
    main()