diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py
index 866d00d..216d77b 100644
--- a/bmtrain/block_layer.py
+++ b/bmtrain/block_layer.py
@@ -311,9 +311,13 @@ def post_hook(self, out):
             post_out = tuple(post_out)
         return post_out
 
-    def forward(self, *args):
+    def forward(self, *args, **kwargs):
+        signature = inspect.signature(self._module.forward)
+        bound_args = signature.bind(*args, **kwargs)
+        args = bound_args.args
         arg_list = self.pre_hook(*args)
+
         if self.all_input_no_grad and not self.all_param_no_grad:
             placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled())
             return hook_func.OneStepNoGradFunc.apply(self, placeholder, *arg_list)
 
diff --git a/bmtrain/wrapper.py b/bmtrain/wrapper.py
index 93d9877..e64fd5b 100644
--- a/bmtrain/wrapper.py
+++ b/bmtrain/wrapper.py
@@ -15,11 +15,15 @@ def make_distributed(model: torch.nn.Module):
     for kw in list(model._buffers.keys()):
         if model._buffers[kw] is not None:
             model._buffers[kw] = model._buffers[kw].cuda()
-
+    is_module_list = isinstance(model, torch.nn.ModuleList)
+    pre_module = None
     for kw in list(model._modules.keys()):
-        if isinstance(model, torch.nn.ModuleList):
+        if is_module_list:
             if not isinstance(model._modules[kw], Block):
                 model._modules[kw] = Block(model_wrapper_dispatch(model._modules[kw]))
+            if pre_module is not None:
+                model._modules[kw].set_pre_module(pre_module)
+            pre_module = model._modules[kw]
         else:
             model._modules[kw] = model_wrapper_dispatch(model._modules[kw])
 
diff --git a/setup.py b/setup.py
index 737e6f1..70752ff 100644
--- a/setup.py
+++ b/setup.py
@@ -93,7 +93,7 @@ def build_extension(self, ext):
 ]
 setup(
     name='bmtrain',
-    version='1.0.0',
+    version='1.0.1',
     author="Guoyang Zeng",
     author_email="qbjooo@qq.com",
     description="A toolkit for training big models",
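
Note (not part of the patch): the block_layer.py hunk normalizes keyword arguments by binding them against the wrapped module's forward signature and reading them back as a positional tuple, so the existing pre_hook / checkpoint machinery keeps seeing positional arguments only. A minimal, self-contained sketch of that pattern follows; the forward function here is a hypothetical stand-in for self._module.forward, not bmtrain code.

import inspect

# Hypothetical stand-in for the wrapped module's forward.
def forward(x, scale=1.0, bias=0.0):
    return x * scale + bias

# Same pattern as the patched Block.forward: bind whatever mix of positional
# and keyword arguments the caller used against the wrapped forward's signature.
signature = inspect.signature(forward)
bound_args = signature.bind(3.0, scale=2.0, bias=1.0)

# Read the bound arguments back as a purely positional tuple.
args = bound_args.args          # (3.0, 2.0, 1.0)

print(forward(*args))           # 7.0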