From 4a4ddcb3eba36aa5c196743cc299b7e3839617d3 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Tue, 22 Apr 2025 14:48:08 +0800 Subject: [PATCH 1/2] Make block can accept kwargs --- bmtrain/block_layer.py | 6 +++++- bmtrain/wrapper.py | 8 ++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 866d00d..216d77b 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -311,9 +311,13 @@ def post_hook(self, out): post_out = tuple(post_out) return post_out - def forward(self, *args): + def forward(self, *args, **kwargs): + signature = inspect.signature(self._module.forward) + bound_args = signature.bind(*args, **kwargs) + args = bound_args.args arg_list = self.pre_hook(*args) + if self.all_input_no_grad and not self.all_param_no_grad: placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled()) return hook_func.OneStepNoGradFunc.apply(self, placeholder, *arg_list) diff --git a/bmtrain/wrapper.py b/bmtrain/wrapper.py index 93d9877..e64fd5b 100644 --- a/bmtrain/wrapper.py +++ b/bmtrain/wrapper.py @@ -15,11 +15,15 @@ def make_distributed(model: torch.nn.Module): for kw in list(model._buffers.keys()): if model._buffers[kw] is not None: model._buffers[kw] = model._buffers[kw].cuda() - + is_module_list = isinstance(model, torch.nn.ModuleList) + pre_module = None for kw in list(model._modules.keys()): - if isinstance(model, torch.nn.ModuleList): + if is_module_list: if not isinstance(model._modules[kw], Block): model._modules[kw] = Block(model_wrapper_dispatch(model._modules[kw])) + if pre_module is not None: + model._modules[kw].set_pre_module(pre_module) + pre_module = model._modules[kw] else: model._modules[kw] = model_wrapper_dispatch(model._modules[kw]) From 249772129221135a6ddc5c077d877384993de85c Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Tue, 13 May 2025 14:42:51 +0800 Subject: [PATCH 2/2] update version 1.0.1 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 737e6f1..70752ff 100644 --- a/setup.py +++ b/setup.py @@ -93,7 +93,7 @@ def build_extension(self, ext): ] setup( name='bmtrain', - version='1.0.0', + version='1.0.1', author="Guoyang Zeng", author_email="qbjooo@qq.com", description="A toolkit for training big models",