diff --git a/animation/modules/unet.py b/animation/modules/unet.py index 52a8725..f335a3e 100644 --- a/animation/modules/unet.py +++ b/animation/modules/unet.py @@ -10,7 +10,7 @@ from diffusers.models.modeling_utils import ModelMixin from diffusers.utils import BaseOutput, logging -from animation.modules.unet_3d_blocks import get_down_block, UNetMidBlockSpatioTemporal, get_up_block +from .unet_3d_blocks import get_down_block, UNetMidBlockSpatioTemporal, get_up_block # from diffusers.models.unets.unet_3d_blocks import get_down_block, get_up_block, UNetMidBlockSpatioTemporal diff --git a/animation/modules/unet_3d_blocks.py b/animation/modules/unet_3d_blocks.py index 7ab636b..bbaec55 100644 --- a/animation/modules/unet_3d_blocks.py +++ b/animation/modules/unet_3d_blocks.py @@ -32,7 +32,7 @@ # TransformerSpatioTemporalModel, # TransformerTemporalModel, # ) -from animation.modules.transformer_temporal import TransformerTemporalModel, TransformerSpatioTemporalModel +from .transformer_temporal import TransformerTemporalModel, TransformerSpatioTemporalModel from diffusers.models.unets.unet_motion_model import ( CrossAttnDownBlockMotion, diff --git a/animation/pipelines/inference_pipeline_animation.py b/animation/pipelines/inference_pipeline_animation.py index 7d6c412..7500edc 100644 --- a/animation/pipelines/inference_pipeline_animation.py +++ b/animation/pipelines/inference_pipeline_animation.py @@ -14,7 +14,7 @@ from diffusers.utils import BaseOutput, logging from diffusers.utils.torch_utils import is_compiled_module, randn_tensor -from animation.modules.attention_processor import AnimationAttnProcessor, AnimationIDAttnProcessor +from ..modules.attention_processor import AnimationAttnProcessor, AnimationIDAttnProcessor from einops import rearrange logger = logging.get_logger(__name__) # pylint: disable=invalid-name