-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmodel_loader.py
More file actions
29 lines (21 loc) · 840 Bytes
/
model_loader.py
File metadata and controls
29 lines (21 loc) · 840 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from clip import CLIP
from encoder import VAE_Encoder
from decoder import VAE_Decoder
from diffusion import Diffusion
import model_converter
def preload_models_from_standard_weights(ckpt_path, device):
    """Build every model component and load its weights from one checkpoint.

    The checkpoint is converted by ``model_converter.load_from_standard_weights``
    into a dict with one sub-state-dict per component. Each component is
    constructed, moved to ``device`` and loaded with ``strict=True``.

    Args:
        ckpt_path: Path to the checkpoint, passed unchanged to
            ``model_converter.load_from_standard_weights``.
        device: Target device, passed to the converter and to each
            component's ``.to(...)``.

    Returns:
        dict: Maps ``'clip'``, ``'encoder'``, ``'decoder'`` and
        ``'diffusion'`` to the corresponding loaded model instance.

    Raises:
        KeyError: If the converted state dict lacks one of the component keys.
        Any error raised by ``load_state_dict``. With ``strict=True``,
        presumably PyTorch's ``RuntimeError`` on missing or unexpected
        parameter names (assumes the components are ``torch.nn.Module``
        subclasses; confirm).
    """
    state_dict = model_converter.load_from_standard_weights(ckpt_path, device)

    encoder = VAE_Encoder().to(device)
    encoder.load_state_dict(state_dict['encoder'], strict=True)

    decoder = VAE_Decoder().to(device)
    decoder.load_state_dict(state_dict['decoder'], strict=True)

    diffusion = Diffusion().to(device)
    diffusion.load_state_dict(state_dict['diffusion'], strict=True)

    clip = CLIP().to(device)
    # Index with the string key 'clip'. Indexing with the `clip` module object
    # itself can never match and always raised KeyError.
    clip.load_state_dict(state_dict['clip'], strict=True)

    return {
        'clip': clip,
        'encoder': encoder,
        'decoder': decoder,
        'diffusion': diffusion,
    }