From 9fb40bd6df17b70120f87162559f6ecc87157001 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Fri, 1 Sep 2023 21:00:33 +0800 Subject: [PATCH 01/43] rename pipe env variable name --- bmtrain/distributed/ops.py | 23 ++++- bmtrain/init.py | 18 ++-- bmtrain/inspect/tensor.py | 22 ++--- bmtrain/nn/parallel_cross_entropy_func.py | 2 +- bmtrain/pipe/comm.py | 69 +++++++++++++ bmtrain/pipe/debug.py | 18 ++++ bmtrain/pipe/schedule.py | 115 ++++++++++++++++++++++ bmtrain/pipe/topo.py | 85 ++++++++++++++++ bmtrain/pipe_layer.py | 96 +++++++++--------- example/layers/embedding.py | 18 +++- example/models/gpt.py | 11 ++- example/train.py | 3 +- tests/test_send_recv.py | 2 +- 13 files changed, 408 insertions(+), 74 deletions(-) create mode 100644 bmtrain/pipe/comm.py create mode 100644 bmtrain/pipe/debug.py create mode 100644 bmtrain/pipe/schedule.py create mode 100644 bmtrain/pipe/topo.py diff --git a/bmtrain/distributed/ops.py b/bmtrain/distributed/ops.py index ef69659a..a7a1e626 100644 --- a/bmtrain/distributed/ops.py +++ b/bmtrain/distributed/ops.py @@ -5,7 +5,7 @@ from ..nccl import broadcast as ncclBroadcast from ..nccl import send as ncclSend from ..nccl import recv as ncclRecv -from ..nccl import commCount,commRank,NCCLCommunicator +from ..nccl import commCount,commRank,NCCLCommunicator,groupStart,groupEnd DTYPE_LIST = [ torch.float64, torch.float32, @@ -17,7 +17,28 @@ torch.bfloat16, torch.bool ] +def send_activations_list(hidden_state_list, next_rank, comm): + length = torch.tensor(data=[0], device="cuda", dtype=torch.int) + length[0] = len(hidden_state_list) + ncclSend(length.storage(), next_rank, comm) + groupStart() + for i in range(length): + send_activations(hidden_state_list[i], next_rank, comm) + groupEnd() + +def recv_activations_list(prev_rank, comm): + length = torch.tensor(data=[0], device="cuda", dtype=torch.int) + ncclRecv(length.storage(), prev_rank, comm) + hidden_state_list = [] + groupStart() + for i in range(length[0].item()): + hidden_state_list.append(recv_activations(prev_rank, comm)) + groupEnd() + return hidden_state_list + + def send_activations(hidden_state, next_rank, comm): + hidden_state = hidden_state.contiguous() send_meta(hidden_state, next_rank, comm) ncclSend(hidden_state.storage(), next_rank, comm) diff --git a/bmtrain/init.py b/bmtrain/init.py index a6214d78..209d64c5 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -114,11 +114,11 @@ def init_distributed( if config['pipe_enabled']: config["micros"] = num_micro_batches if num_micro_batches else config["pipe_size"] - if topo.stage_id == 0: + if topo.pipe_rank == 0: unique_id = nccl.getUniqueId() store.set(f"PIPE_UNIQUE_ID{topo.pipe_idx}", unique_id.hex()) unique_id = bytes.fromhex(store.get(f"PIPE_UNIQUE_ID{topo.pipe_idx}").decode()) - config ['pipe_comm'] = nccl.commInitRank(unique_id, pipe_size, topo.stage_id) + config ['pipe_comm'] = nccl.commInitRank(unique_id, pipe_size, topo.pipe_rank) if topo.tp_id == 0: unique_id = nccl.getUniqueId() @@ -165,21 +165,21 @@ def __init__(self,config): config['zero_size'] = world_size // pp_size topo=torch.tensor(range(dp_size*tp_size*pp_size),dtype=torch.int,device='cuda') topo=topo.view(pp_size,dp_size*tp_size) - self.stages = config['pipe_size'] + self.pipe_size = config['pipe_size'] stage_size = world_size // pp_size for i in range(world_size): self.pipe_idx = self.rank % stage_size - self.stage_id = self.rank // stage_size + self.pipe_rank = self.rank // stage_size self.tp_id = self.rank % tp_size self.tp_idx = self.rank // tp_size - 
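A quick way to read the index arithmetic above is to evaluate it for a toy configuration. The sketch below is only an illustration; the world_size, pp_size and tp_size values are assumed, not taken from this patch.

    # Toy sanity check of the rank decomposition used in the topology above.
    def decompose(rank, world_size=8, pp_size=2, tp_size=2):
        stage_size = world_size // pp_size           # ranks per pipeline stage
        return dict(
            pipe_idx=rank % stage_size,              # position inside the stage group
            pipe_rank=rank // stage_size,            # which pipeline stage this rank sits on
            tp_id=rank % tp_size,                    # position inside the tensor-parallel group
            tp_idx=rank // tp_size,                  # which tensor-parallel group
        )

    print(decompose(5))
    # {'pipe_idx': 1, 'pipe_rank': 1, 'tp_id': 1, 'tp_idx': 2}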
self.zero_idx = self.stage_id + self.zero_idx = self.pipe_rank self.zero_id = self.pipe_idx - self.tp_zero_idx = self.stage_id * tp_size + self.tp_id + self.tp_zero_idx = self.pipe_rank * tp_size + self.tp_id self.tp_zero_id = self.pipe_idx // tp_size - self.next_rank = self.stage_id+1 if self.stage_id < config['pipe_size'] - 1 else -1 - self.prev_rank = self.stage_id-1 if self.stage_id > 0 else -1 + self.next_rank = self.pipe_rank+1 if self.pipe_rank < config['pipe_size'] - 1 else -1 + self.prev_rank = self.pipe_rank-1 if self.pipe_rank > 0 else -1 def get_group_id(self,group_name): @@ -194,7 +194,7 @@ def get_group_id(self,group_name): def get_group_rank(self,group_name): if group_name == "pipe": - return self.stage_id + return self.pipe_rank elif group_name == "zero": return self.zero_id elif group_name == "tp_zero": diff --git a/bmtrain/inspect/tensor.py b/bmtrain/inspect/tensor.py index 2c45fdac..bcd81407 100644 --- a/bmtrain/inspect/tensor.py +++ b/bmtrain/inspect/tensor.py @@ -39,7 +39,7 @@ def _set_summary(self, summary): kw = f'{item["prefix"]}{item["name"]}' assert item["inside_pipe"] is not None - stage_id = item["inside_pipe"]["stage_id"] + pipe_rank = item["inside_pipe"]["pipe_rank"] stages = item["inside_pipe"]["stages"] st = item["inside_pipe"]["st"] ed = item["inside_pipe"]["ed"] @@ -53,7 +53,7 @@ def _set_summary(self, summary): break for stage in range(stages): - if stage_id == stage: + if pipe_rank == stage: broadcast_object(pipe_cnt, config["pipe_comm"], src = stage) for k in range(i, j): item = summary[k] @@ -76,7 +76,7 @@ def _set_summary(self, summary): "tensor": tensor, "grad": grad, "requires_grad": item["requires_grad"], - "inside_pipe": {"stage_id": stage}, + "inside_pipe": {"pipe_rank": stage}, }) kw_cnt[kw] += 1 else: @@ -99,7 +99,7 @@ def _set_summary(self, summary): "tensor": None, "grad": None, "requires_grad": None, - "inside_pipe": {"stage_id": stage}, + "inside_pipe": {"pipe_rank": stage}, }) kw_cnt[kw] += 1 @@ -114,23 +114,23 @@ def _set_summary(self, summary): "requires_grad": it["requires_grad"], "has_grad": has_grad, } - broadcast_object(info, config["pipe_comm"], src = it["inside_pipe"]["stage_id"]) + broadcast_object(info, config["pipe_comm"], src = it["inside_pipe"]["pipe_rank"]) tensor = it["tensor"] - tensor = broadcast(tensor, it["inside_pipe"]["stage_id"], config["pipe_comm"]) + tensor = broadcast(tensor, it["inside_pipe"]["pipe_rank"], config["pipe_comm"]) grad = it["grad"] else: - info = broadcast_object({}, config["pipe_comm"], src = it["inside_pipe"]["stage_id"]) + info = broadcast_object({}, config["pipe_comm"], src = it["inside_pipe"]["pipe_rank"]) has_grad = info.pop("has_grad") it.update(info) tensor = torch.empty(it["shape"]).cuda().requires_grad_() - tensor = broadcast(tensor, it["inside_pipe"]["stage_id"], config["pipe_comm"]) + tensor = broadcast(tensor, it["inside_pipe"]["pipe_rank"], config["pipe_comm"]) if has_grad: grad = torch.empty(it["shape"]).cuda() - tensor = tensor.chunk(stages, dim=0)[stage_id].clone() + tensor = tensor.chunk(stages, dim=0)[pipe_rank].clone() it["tensor"] = tensor if has_grad: - grad = broadcast(grad, it["inside_pipe"]["stage_id"], config["pipe_comm"]) - grad = grad.chunk(stages, dim=0)[stage_id].clone() + grad = broadcast(grad, it["inside_pipe"]["pipe_rank"], config["pipe_comm"]) + grad = grad.chunk(stages, dim=0)[pipe_rank].clone() tensor.grad = grad it["shape"] = (it["shape"][0]//config["pipe_size"],) + it["shape"][1:] diff --git a/bmtrain/nn/parallel_cross_entropy_func.py 
b/bmtrain/nn/parallel_cross_entropy_func.py index cd1f63bf..8e4548ae 100644 --- a/bmtrain/nn/parallel_cross_entropy_func.py +++ b/bmtrain/nn/parallel_cross_entropy_func.py @@ -103,7 +103,7 @@ def backward(ctx, grad_output): else: grad_2d[arange_1d, masked_target_1d] -= softmax_update - grad_input.mul_(grad_output.flatten(0,1).unsqueeze(dim=-1)) + grad_input.mul_(grad_output.view(*grad_input.shape[:-1]).unsqueeze(dim=-1)) return grad_input, None, None diff --git a/bmtrain/pipe/comm.py b/bmtrain/pipe/comm.py new file mode 100644 index 00000000..cf5d73d6 --- /dev/null +++ b/bmtrain/pipe/comm.py @@ -0,0 +1,69 @@ +import torch +from ..distributed.ops import send_activations_list, recv_activations_list, send_activations, recv_activations +from ..global_var import config +class PipeCommander: + def __init__(self, topo, data_iterator, num_micros, num_warmup, forward_only, interleaving_size) -> None: + self.topo = topo + self.data_iterator = data_iterator + self.num_micros = num_micros + self.num_warmup = num_warmup + self.forward_only = forward_only + self.interleaving_size = interleaving_size + + def send_next(self, tensors): + if not self.is_last_stage(): + if not isinstance(tensors, list): + tensors = [tensors] + send_activations_list(tensors, self.topo.pipe_rank + 1, config["pipe_comm"]) + + def send_prev(self, tensors): + if not self.is_first_stage(): + if not isinstance(tensors, list): + tensors = [tensors] + send_activations_list(tensors, self.topo.pipe_rank - 1, config["pipe_comm"]) + + def recv_prev(self, need_data=False): + if not self.is_first_stage(): + return recv_activations_list(self.topo.pipe_rank - 1, config["pipe_comm"]) + else: + if need_data: + return next(self.data_iterator) + else: + return None + + def recv_next(self): + if not self.is_last_stage(): + return recv_activations_list(self.topo.pipe_rank + 1, config["pipe_comm"]) + else: + return None + + def allocate_tensor(self, shape, dtype): + return torch.empty(shape, dtype=dtype, device="cuda") + + def is_first_stage(self): + return self.topo.pipe_rank == 0 + + def is_last_stage(self): + return self.topo.pipe_rank == self.topo.pipe_size - 1 + + def send_forward_recv_backward(self, forward_state): + if not self.is_last_stage(): + self.send_next(forward_state) + backward_grad = self.recv_next() + else: + backward_grad = None + return backward_grad + + def send_backward_recv_forward(self, backward_grad, need_data=False): + if not self.is_first_stage(): + self.send_pre(backward_grad) + forward_state = self.recv_pre() + else: + if need_data: + forward_state = next(self.data_iterator) + else: + forward_state = None + return forward_state + + + \ No newline at end of file diff --git a/bmtrain/pipe/debug.py b/bmtrain/pipe/debug.py new file mode 100644 index 00000000..6f7063c2 --- /dev/null +++ b/bmtrain/pipe/debug.py @@ -0,0 +1,18 @@ +import logging + +logger = logging.getLogger('pipeline') +logger.setLevel(logging.DEBUG) + +fh = logging.FileHandler('pipe.log') +fh.setLevel(logging.DEBUG) + +ch = logging.StreamHandler() +ch.setLevel(logging.DEBUG) + +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + +fh.setFormatter(formatter) +ch.setFormatter(formatter) + +logger.addHandler(fh) +logger.addHandler(ch) diff --git a/bmtrain/pipe/schedule.py b/bmtrain/pipe/schedule.py new file mode 100644 index 00000000..fa7a63f5 --- /dev/null +++ b/bmtrain/pipe/schedule.py @@ -0,0 +1,115 @@ +from ..global_var import config +from .comm import PipeCommander +import torch + +def backward_step(input_tensor, 
output_tensor, output_tensor_grad): + """Backward step through passed-in output tensor. + + If last stage, output_tensor_grad is None, otherwise gradient of loss + with respect to stage's output tensor. + + Returns gradient of loss with respect to input tensor (None if first + stage).""" + + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + for x in input_tensor: + if x is not None: + x.retain_grad() + + if not isinstance(output_tensor, list): + output_tensor = [output_tensor] + if not isinstance(output_tensor_grad, list): + output_tensor_grad = [output_tensor_grad] + #TODO scale the grad + # if output_tensor_grad[0] is None and config.grad_scale_func is not None: + # output_tensor[0] = config.grad_scale_func(output_tensor[0]) + + torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0]) + + input_tensor_grad = [None] + if input_tensor is not None: + input_tensor_grad = [] + for x in input_tensor: + if x is None: + input_tensor_grad.append(None) + else: + input_tensor_grad.append(x.grad) + + return input_tensor_grad + +def pipeline_forward_backward(models, inputs, data_iterator, global_batch_size, interleaving_size=1): + """Forward and backward the pipeline model. + + Args: + models (TransformerBlocklist): The list of models. + data_iterator (iterator): The iterator of the dataset. + micro_batch_size (int): The micro batch size. + + Returns: + torch.Tensor: The loss of the model. + """ + + # forwrad unpack + inp, *args = data_iterator + micro_batch_size = inp.shape[0] + assert global_batch_size % micro_batch_size == 0, "The global batch size must be divisible by the micro batch size" + num_micro_batches = global_batch_size // micro_batch_size + assert (num_micro_batches) % config["pipe_size"] == 0, "The number of micro batches must be divisible by the pipeline size" + topo = config["topology"] + # construct Pipe Commander + forward_only = torch.is_grad_enabled() + if forward_only: + num_warmup = num_micro_batches + else: + num_warmup = topo.pipe_size - topo.pipe_rank - 1 + + commander = PipeCommander(topo, num_micros=num_micro_batches,\ + num_warmup=num_warmup, forward_only=False, \ + interleaving_size=interleaving_size, \ + data_iterator=data_iterator) + inps = [] + outputs = [] + + for micro in range(num_warmup): + inp = commander.recv_peer(need_data=True) + output = models(*inp) + # send activations + commander.send_peer(output) + if not forward_only: + inps.append(inp) + outputs.append(output) + remain_batch = num_micro_batches - num_warmup + + if remain_batch > 0: + inp = commander.recv_peer(need_data=True) + + for micro in range(num_micro_batches - num_warmup): + output = models(*inp) + grad_output = commander.send_forward_recv_backward(output) + inp_grad = backward_step(inp, output, grad_output) + if micro == remain_batch - 1: + input_tensor = None + commander.send_prev(inp_grad) + else: + input_tensor = commander.send_backward_recv_forward(inp_grad) + for i in range(num_warmup): + + # if i == num_warmup - 1: + # grad sync + # if config.grad_sync_func is None or rank == 0: + # enable_grad_sync() + + input_tensor = inp.pop(0) + output_tensor = output.pop(0) + + output_tensor_grad = commander.recv_next() + + input_tensor_grad = backward_step( + input_tensor, output_tensor, output_tensor_grad, + ) + + commander.send_prev(input_tensor_grad) + + + \ No newline at end of file diff --git a/bmtrain/pipe/topo.py b/bmtrain/pipe/topo.py new file mode 100644 index 00000000..f0641f48 --- /dev/null +++ b/bmtrain/pipe/topo.py @@ -0,0 +1,85 @@ +class 
topology: + def __init__(self,**config): + # pipe_idx is the idx of the pipeline in the group + self.rank = config['rank'] + pp_size = config["pipe_size"] + tp_size = config["tp_size"] + world_size = config["world_size"] + assert world_size % (pp_size * tp_size) == 0, "The nums of GPUs must be divisible by the pipeline parallel size * tensor parallel size" + + dp_size = world_size // (pp_size * tp_size) + config['tp_zero_size'] = dp_size + config['zero_size'] = world_size // pp_size + self.pipe_size = config['pipe_size'] + + stage_size = world_size // pp_size + for i in range(world_size): + self.pipe_idx = self.rank % stage_size + self.pipe_rank = self.rank // stage_size + self.tp_id = self.rank % tp_size + self.tp_idx = self.rank // tp_size + #pp->zero + self.pp_zero_idx = self.pipe_rank + self.pp_zero_id = self.pipe_idx + #tp->zero + self.tp_zero_idx = self.tp_id + self.tp_zero_id = self.tp_idx + #pp->tp->zero + self.pp_tp_zero_idx = self.pipe_rank * tp_size + self.tp_id + self.pp_tp_zero_id = self.pipe_idx // tp_size + #only zero + self.zero_idx = 0 + self.zero_id = self.rank + + + def get_group_id(self,group_name): + if group_name == "pipe": + return self.pipe_idx + elif group_name == "zero": + return self.zero_idx + elif group_name == "tp_zero": + return self.tp_zero_idx + elif group_name == "tp": + return self.tp_idx + + def get_group_rank(self,group_name): + if group_name == "pipe": + return self.pipe_rank + elif group_name == "zero": + return self.zero_id + elif group_name == "tp_zero": + return self.tp_zero_id + elif group_name == "tp": + return self.tp_id + + def get_peer(self, group_name, next_prev): + if group_name == "pipe": + if next_prev == "next": + return self.pipe_rank+1 if self.pipe_rank < self.pipe_size - 1 else -1 + elif next_prev == "prev": + return self.pipe_rank-1 if self.pipe_rank > 0 else -1 + elif group_name == "zero": + if next_prev == "next": + return self.zero_id+1 if self.zero_id < self.pipe_size - 1 else -1 + elif next_prev == "prev": + return self.zero_id-1 if self.zero_id > 0 else -1 + elif group_name == "tp_zero": + if next_prev == "next": + return self.tp_zero_id+1 if self.tp_zero_id < self.pipe_size - 1 else -1 + elif next_prev == "prev": + return self.tp_zero_id-1 if self.tp_zero_id > 0 else -1 + elif group_name == "tp": + if next_prev == "next": + return self.tp_id+1 if self.tp_id < self.pipe_size - 1 else -1 + elif next_prev == "prev": + return self.tp_id-1 if self.tp_id > 0 else -1 + return -1 + + +if __name__ == "__main__": + topology1 = topology(**{"rank":0,"pipe_size":4,"tp_size":8,"world_size":32}) + topology2 = topology(**{"rank":8,"pipe_size":4,"tp_size":8,"world_size":32}) + topology3 = topology(**{"rank":16,"pipe_size":4,"tp_size":8,"world_size":32}) + topology4 = topology(**{"rank":24,"pipe_size":4,"tp_size":8,"world_size":32}) + from IPython import embed;embed() + \ No newline at end of file diff --git a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index e3913b6c..a694a9a3 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -64,16 +64,16 @@ def backward(ctx, grads, arg_grads): if requires_grad: grad = torch.cat([ctx.args_list[m][idx].grad for m in range(num_micros)], dim=0) grad = all_reduce(grad, "sum", config["pipe_comm"]) - split_size = topo.stages if ctx.batch_related[idx] else num_micros + split_size = topo.pipe_size if ctx.batch_related[idx] else num_micros grad = grad.chunk(split_size) if ctx.batch_related[idx]: - arg_grads.append(grad[topo.stage_id]) + arg_grads.append(grad[topo.pipe_rank]) else: 
arg_grads.append(grad[0]) else: arg_grads.append(None) arg_grads.append(None) #for append(batch_related) - return grads.chunk(topo.stages, dim=0)[topo.stage_id], *arg_grads + return grads.chunk(topo.pipe_size, dim=0)[topo.pipe_rank], *arg_grads class PipePostFunction(torch.autograd.Function): @staticmethod @@ -81,24 +81,24 @@ def forward(ctx, last_hidden, hidden_states=None, forward_stage_ranges=None, bac topo = config['topology'] ctx.return_hidden_states = return_hidden_states last_hidden = broadcast(last_hidden, config["pipe_size"] - 1, config["pipe_comm"]) - last_hidden = last_hidden.chunk(topo.stages, dim=0) - output = last_hidden[topo.stage_id] + last_hidden = last_hidden.chunk(topo.pipe_size, dim=0) + output = last_hidden[topo.pipe_rank] output.requires_grad_() if return_hidden_states: - ctx.stage_id = topo.stage_id - ctx.stages = topo.stages + ctx.pipe_rank = topo.pipe_rank + ctx.pipe_size = topo.pipe_size ctx.backward_stage_ranges = backward_stage_ranges middle_hiddens = [] - for stage_id in range(ctx.stages): - if ctx.stage_id == stage_id: + for pipe_rank in range(ctx.pipe_size): + if ctx.pipe_rank == pipe_rank: middle_hidden = hidden_states else: - middle_shape = (forward_stage_ranges[stage_id],) + last_hidden_shape + middle_shape = (forward_stage_ranges[pipe_rank],) + last_hidden_shape middle_hidden = torch.zeros(middle_shape, device=hidden_states.device, dtype=hidden_states.dtype) - middle_hidden = broadcast(middle_hidden, stage_id, config["pipe_comm"]) - middle_hidden = middle_hidden.chunk(ctx.stages, dim=1) - middle_hidden = middle_hidden[ctx.stage_id].clone() + middle_hidden = broadcast(middle_hidden, pipe_rank, config["pipe_comm"]) + middle_hidden = middle_hidden.chunk(ctx.pipe_size, dim=1) + middle_hidden = middle_hidden[ctx.pipe_rank].clone() middle_hiddens.append(middle_hidden) middle_hiddens = torch.cat(middle_hiddens, dim=0) middle_hiddens.requires_grad_() @@ -112,12 +112,12 @@ def backward(ctx, grads, grad_middle=None): grad_list = grad_list.flatten(start_dim=0, end_dim=1) if ctx.return_hidden_states: - for stage_id in range(ctx.stages): - layer_range = ctx.backward_stage_ranges[stage_id] + for pipe_rank in range(ctx.pipe_size): + layer_range = ctx.backward_stage_ranges[pipe_rank] grad_middle_state = grad_middle[layer_range] grad_middle_state = all_gather(grad_middle_state.transpose(0,1), config["pipe_comm"]) grad_middle_state = grad_middle_state.flatten(start_dim=0, end_dim=1).transpose(0, 1) - if ctx.stage_id == stage_id: + if ctx.pipe_rank == pipe_rank: grad_hidden_state_list = grad_middle_state return grad_list, grad_hidden_state_list, None, None, None, None else: @@ -125,12 +125,12 @@ def backward(ctx, grads, grad_middle=None): class StagePreFunction(torch.autograd.Function): @staticmethod - def forward(ctx, input, stage_id): - ctx.stage_id = stage_id - ctx.is_first_stage = stage_id == 0 - ctx.is_last_stage = stage_id == config['pipe_size'] - 1 + def forward(ctx, input, pipe_rank): + ctx.pipe_rank = pipe_rank + ctx.is_first_stage = pipe_rank == 0 + ctx.is_last_stage = pipe_rank == config['pipe_size'] - 1 if not ctx.is_first_stage: - input = recv_activations(stage_id - 1, config['pipe_comm']) + input = recv_activations(pipe_rank - 1, config['pipe_comm']) input.requires_grad_() return input return input @@ -143,28 +143,28 @@ def backward(ctx, grad_outputs): with torch.cuda.stream(config['pp_comm_stream']): config['pp_comm_stream'].wait_stream(current_stream) send_data.record_stream(current_stream) - send_activations(send_data, ctx.stage_id - 1, config['pipe_comm']) 
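The backward pass above hands the input gradient to the previous stage on a dedicated communication stream so the NCCL send can overlap with later compute. A minimal sketch of that pattern follows; comm_stream, send_fn, peer and comm are placeholders, not the bmtrain API.

    import torch

    def send_on_side_stream(tensor, peer, comm, send_fn, comm_stream):
        current = torch.cuda.current_stream()
        with torch.cuda.stream(comm_stream):
            comm_stream.wait_stream(current)     # wait until the producing compute is done
            tensor.record_stream(comm_stream)    # keep the storage alive for the side stream
            send_fn(tensor, peer, comm)          # communication overlaps subsequent compute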
+ send_activations(send_data, ctx.pipe_rank - 1, config['pipe_comm']) return grad_outputs, None class StagePostFunction(torch.autograd.Function): @staticmethod - def forward(ctx, outputs, stage_id): - ctx.stage_id = stage_id - ctx.is_first_stage = stage_id == 0 - ctx.is_last_stage = stage_id == config['pipe_size'] - 1 + def forward(ctx, outputs, pipe_rank): + ctx.pipe_rank = pipe_rank + ctx.is_first_stage = pipe_rank == 0 + ctx.is_last_stage = pipe_rank == config['pipe_size'] - 1 if not ctx.is_last_stage: send_data = outputs[0] if isinstance(outputs, tuple) else outputs current_stream = torch.cuda.current_stream() with torch.cuda.stream(config['pp_comm_stream']): config['pp_comm_stream'].wait_stream(current_stream) send_data.record_stream(current_stream) - send_activations(send_data.detach(), stage_id + 1, config['pipe_comm']) + send_activations(send_data.detach(), pipe_rank + 1, config['pipe_comm']) return outputs @staticmethod def backward(ctx, grad_outputs): if not ctx.is_last_stage: - pre_grad_inputs = recv_activations(ctx.stage_id + 1, config['pipe_comm']) + pre_grad_inputs = recv_activations(ctx.pipe_rank + 1, config['pipe_comm']) return pre_grad_inputs, None return grad_outputs, None @@ -197,20 +197,20 @@ def __init__(self, modules: Iterable[Block], num_hidden=1) -> None: rank = config['rank'] topo = config['topology'] self.layer_ids = [] - self.stages = topo.stages - self.stage_id = topo.stage_id + self.pipe_size = topo.pipe_size + self.pipe_rank = topo.pipe_rank self.pipe_idx = topo.pipe_idx for idx, module in enumerate(modules): if not isinstance(module, Block): module = Block(module) module._mode = "PIPE" - module.stage_id = self.stage_id - module.stages = self.stages + module.pipe_rank = self.pipe_rank + module.pipe_size = self.pipe_size self._modules[str(idx)] = module - self.layer_ids = self.get_range_by_stage_id(self.stage_id) + self.layer_ids = self.get_range_by_pipe_rank(self.pipe_rank) pre_module = None for i,layer_id in enumerate(self.layer_ids): @@ -218,8 +218,8 @@ def __init__(self, modules: Iterable[Block], num_hidden=1) -> None: module.set_pre_module(pre_module) pre_module = module - module._is_first_stage = True if self.stage_id == 0 else False - module._is_last_stage = True if self.stage_id == self.stages-1 else False + module._is_first_stage = True if self.pipe_rank == 0 else False + module._is_last_stage = True if self.pipe_rank == self.pipe_size-1 else False module._is_first_layer = False module._is_last_layer = False self._modules[str(self.layer_ids[0])]._is_first_layer = True @@ -251,14 +251,14 @@ def forward(self, hidden_state, *args, batch_related=[], return_hidden_states=Fa for micro_idx, (hidden_state, arg) in enumerate(zip(hidden_state_list, args_list)): micro_hidden_states = [] - hidden_state = StagePreFunction.apply(hidden_state, self.stage_id) + hidden_state = StagePreFunction.apply(hidden_state, self.pipe_rank) for idx,layer_id in enumerate(self.layer_ids): self._modules[str(layer_id)]._micro_idx = micro_idx if return_hidden_states: micro_hidden_states.append(hidden_state) hidden_state = self._modules[str(layer_id)](hidden_state, *arg) - hidden_state = StagePostFunction.apply(hidden_state, self.stage_id) + hidden_state = StagePostFunction.apply(hidden_state, self.pipe_rank) outputs.append(hidden_state) if return_hidden_states: @@ -271,27 +271,27 @@ def forward(self, hidden_state, *args, batch_related=[], return_hidden_states=Fa hidden_states = torch.cat(hidden_states, dim=1) forward_stage_ranges = [] backward_stage_ranges = [] - for stage_id in 
range(self.stages): - forward_stage_ranges.append(self.get_part_len_by_stage_id(stage_id)) - backward_stage_ranges.append(self.get_range_by_stage_id(stage_id)) + for pipe_rank in range(self.pipe_size): + forward_stage_ranges.append(self.get_part_len_by_pipe_rank(pipe_rank)) + backward_stage_ranges.append(self.get_range_by_pipe_rank(pipe_rank)) outputs, hidden_states = PipePostFunction.apply(last_hidden, hidden_states, forward_stage_ranges, backward_stage_ranges, last_hidden_shape, return_hidden_states) return outputs, hidden_states else: outputs = PipePostFunction.apply(last_hidden) return outputs - def get_range_by_stage_id(self, stage_id : int) -> List[int]: - part_lens = [0]+[self.get_part_len_by_stage_id(i) for i in range(stage_id+1)] - start = sum(part_lens[:stage_id+1]) - end = start + part_lens[stage_id+1] + def get_range_by_pipe_rank(self, pipe_rank : int) -> List[int]: + part_lens = [0]+[self.get_part_len_by_pipe_rank(i) for i in range(pipe_rank+1)] + start = sum(part_lens[:pipe_rank+1]) + end = start + part_lens[pipe_rank+1] return range(start, end) - def get_part_len_by_stage_id(self, stage_id : int) -> int: - return len(self) // self.stages + (stage_id < (len(self) % self.stages)) + def get_part_len_by_pipe_rank(self, pipe_rank : int) -> int: + return len(self) // self.pipe_size + (pipe_rank < (len(self) % self.pipe_size)) def get_stage_by_layer_id(self, layer_id : int) -> int: - part_len = len(self) // self.stages - rest = len(self) % self.stages + part_len = len(self) // self.pipe_size + rest = len(self) % self.pipe_size if layer_id // (part_len + 1) < rest: return layer_id // (part_len + 1) else: diff --git a/example/layers/embedding.py b/example/layers/embedding.py index f62151c4..8a3bbd62 100644 --- a/example/layers/embedding.py +++ b/example/layers/embedding.py @@ -3,7 +3,22 @@ import torch import torch.nn.functional as F import bmtrain as bmt - +import inspect +def router(func): + params_kw = list(inspect.signature(func).parameters.keys()) + def wrapper(self,*args,**kwargs): + assert len(args) == 0, "In pipeline module , you have to pass variable in key=value manner" + sub_kwargs = {} + for key in kwargs: + if key in params_kw: + sub_kwargs[key] = kwargs[key] + next_module = self.next_module() + next_module.set_input() + return func(**sub_kwargs) + if bmt.config["pipe_size"] > 1: + return wrapper + else: + return func class Embedding(bmt.DistributedModule): def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, @@ -75,6 +90,7 @@ def from_pretrained(cls, embeddings, freeze=True, padding_idx=None, embedding.weight.requires_grad = not freeze return embedding + @router def forward(self, input: torch.Tensor, projection : bool = False) -> torch.Tensor: if not projection: out = F.embedding( diff --git a/example/models/gpt.py b/example/models/gpt.py index 456dbcc9..23b571db 100644 --- a/example/models/gpt.py +++ b/example/models/gpt.py @@ -2,7 +2,16 @@ import bmtrain as bmt from layers import TransformerEncoder, Layernorm, Embedding, TransformerEncoder from bmtrain.global_var import config - +class BlockInput(bmt.DistributedModule): + def __init__(self,word_emb, pos_emb, dtype = None): + self.embed_layers = embedding_layers + + def forward(self, input, pos): + output = self.embed_layers[0](input) + for layer in self.embed_layers[1:]: + output = layer(output) + return input + class GPT(bmt.DistributedModule): def __init__(self, num_layers : int, vocab_size : int, diff --git a/example/train.py b/example/train.py index 8aaf65e4..2386a31a 
100644 --- a/example/train.py +++ b/example/train.py @@ -9,7 +9,8 @@ def main(): bmt.init_distributed( seed=0, - tp_size=2, + tp_size=1, + pipe_size=4, ) model = GPT( diff --git a/tests/test_send_recv.py b/tests/test_send_recv.py index 95c9c1e5..83a7225c 100644 --- a/tests/test_send_recv.py +++ b/tests/test_send_recv.py @@ -5,7 +5,7 @@ from bmtrain.global_var import config def test_send_recv(): - if config["topology"].stage_id == 0: + if config["topology"].pipe_rank == 0: a = torch.ones((2,1)) * (config["zero_rank"]+1) a = a.cuda() print(f"send {a}") From c5ac256eb4d67255cb6d1ecd885ec80f611915ee Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Mon, 4 Sep 2023 15:35:26 +0800 Subject: [PATCH 02/43] pipe example test --- bmtrain/__init__.py | 2 +- bmtrain/block_layer.py | 14 +++++ bmtrain/nn/__init__.py | 3 +- bmtrain/nn/pipe_embedding.py | 110 +++++++++++++++++++++++++++++++++++ bmtrain/pipe/__init__.py | 0 bmtrain/pipe/comm.py | 20 ++++--- bmtrain/pipe/debug.py | 31 +++++----- bmtrain/pipe/example.py | 16 +++++ bmtrain/pipe/run.sh | 1 + bmtrain/pipe/salloc.sh | 1 + bmtrain/pipe/schedule.py | 75 ++++++++++++++++++------ example/models/gpt.py | 10 +--- 12 files changed, 228 insertions(+), 55 deletions(-) create mode 100644 bmtrain/nn/pipe_embedding.py create mode 100644 bmtrain/pipe/__init__.py create mode 100644 bmtrain/pipe/example.py create mode 100644 bmtrain/pipe/run.sh create mode 100644 bmtrain/pipe/salloc.sh diff --git a/bmtrain/__init__.py b/bmtrain/__init__.py index f4ac3642..05476ec8 100644 --- a/bmtrain/__init__.py +++ b/bmtrain/__init__.py @@ -10,7 +10,7 @@ from .layer import DistributedModule from .param_init import init_parameters, grouped_parameters from .synchronize import synchronize, sum_loss, wait_loader, gather_result -from .block_layer import Block, TransformerBlockList +from .block_layer import Block, TransformerBlockList,PipeDreamBlockList from .wrapper import BMTrainModelWrapper from .pipe_layer import PipelineTransformerBlockList from . 
import debug diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 61b335cf..b11112a5 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -603,3 +603,17 @@ def forward(self, *args, return_hidden_states = False): return outputs + tuple(hidden_states) else: return tuple(outputs[:self.num_hidden]) if self.num_hidden > 1 else outputs[0] + +class PipeDreamBlockList(TransformerBlockList): + def __init__(self, modules: Iterable[Block], num_hidden=1, sqrt=False) -> None: + modules,s,e = self.partition(modules) + print(s,"->",e) + super().__init__(modules, num_hidden, sqrt) + + def partition(self,modules): + pipe_size = config["topology"].pipe_size + pipe_rank = config["topology"].pipe_rank + part_lens = [0]+[len(modules) // pipe_size + (i < (len(modules) % pipe_size)) for i in range(pipe_rank+1)] + start = sum(part_lens[:pipe_rank+1]) + end = start + part_lens[pipe_rank+1] + return modules[start:end],start,end \ No newline at end of file diff --git a/bmtrain/nn/__init__.py b/bmtrain/nn/__init__.py index e22d8c55..3b24b3ca 100644 --- a/bmtrain/nn/__init__.py +++ b/bmtrain/nn/__init__.py @@ -3,4 +3,5 @@ from .row_parallel_linear import RowParallelLinear from .parallel_embedding import ParallelEmbedding from .parallel_cross_entropy_func import parallel_cross_entropy_func -from .parallel_linear_func import OpParallelLinear \ No newline at end of file +from .parallel_linear_func import OpParallelLinear +from .pipe_embedding import PipeEmbedding \ No newline at end of file diff --git a/bmtrain/nn/pipe_embedding.py b/bmtrain/nn/pipe_embedding.py new file mode 100644 index 00000000..faf691a1 --- /dev/null +++ b/bmtrain/nn/pipe_embedding.py @@ -0,0 +1,110 @@ +import math +from typing import Optional +import torch +import torch.nn.functional as F +import bmtrain as bmt +import inspect +def router(func): + def wrapper(self,*args,**kwargs): + if bmt.config["topology"].pipe_rank == 0: + return func(self,*args,**kwargs) + else: + return args,kwargs + return wrapper + +class PipeEmbedding(bmt.DistributedModule): + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False, + sparse: bool = False, _weight: Optional[torch.Tensor] = None, + dtype=None): + super().__init__() + + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + if padding_idx is not None: + if padding_idx > 0: + assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + elif padding_idx < 0: + assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + padding_idx = self.num_embeddings + padding_idx + self.padding_idx = padding_idx + self.max_norm = max_norm + self.norm_type = norm_type + self.scale_grad_by_freq = scale_grad_by_freq + if bmt.config['topology'].pipe_rank == 0: + if _weight is None: + self.weight = bmt.DistributedParameter(torch.empty(num_embeddings, embedding_dim, dtype=dtype, device="cuda"), init_method=torch.nn.init.normal_) + else: + self.weight = bmt.DistributedParameter(_weight) + self.sparse = sparse + + @classmethod + def from_pretrained(cls, embeddings, freeze=True, padding_idx=None, + max_norm=None, norm_type=2., scale_grad_by_freq=False, + sparse=False): + r"""Creates Embedding instance from given 2-dimensional FloatTensor. + + Args: + embeddings (Tensor): FloatTensor containing weights for the Embedding. 
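PipeDreamBlockList.partition above gives each pipeline stage a contiguous, near-equal slice of the module list, with any remainder going to the earlier stages. A hand check of the same arithmetic with assumed sizes (10 layers, pipe_size 4):

    # Mirrors the slicing in PipeDreamBlockList.partition; sizes are made up.
    def stage_slice(num_layers, pipe_size, pipe_rank):
        lens = [num_layers // pipe_size + (i < num_layers % pipe_size) for i in range(pipe_size)]
        start = sum(lens[:pipe_rank])
        return start, start + lens[pipe_rank]

    print([stage_slice(10, 4, r) for r in range(4)])
    # [(0, 3), (3, 6), (6, 8), (8, 10)]  -> the two extra layers land on stages 0 and 1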
+ First dimension is being passed to Embedding as ``num_embeddings``, second as ``embedding_dim``. + freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process. + Equivalent to ``embedding.weight.requires_grad = False``. Default: ``True`` + padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient; + therefore, the embedding vector at :attr:`padding_idx` is not updated during training, + i.e. it remains as a fixed "pad". + max_norm (float, optional): See module initialization documentation. + norm_type (float, optional): See module initialization documentation. Default ``2``. + scale_grad_by_freq (boolean, optional): See module initialization documentation. Default ``False``. + sparse (bool, optional): See module initialization documentation. + + Examples:: + + >>> # FloatTensor containing pretrained weights + >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]]) + >>> embedding = nn.Embedding.from_pretrained(weight) + >>> # Get embeddings for index 1 + >>> input = torch.LongTensor([1]) + >>> embedding(input) + tensor([[ 4.0000, 5.1000, 6.3000]]) + """ + assert embeddings.dim() == 2, \ + 'Embeddings parameter is expected to be 2-dimensional' + rows, cols = embeddings.shape + embedding = cls( + num_embeddings=rows, + embedding_dim=cols, + _weight=embeddings, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse) + embedding.weight.requires_grad = not freeze + return embedding + + @router + def forward(self, input: torch.Tensor, projection : bool = False) -> torch.Tensor: + if not projection: + out = F.embedding( + input, self.weight, self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse) + return out + else: + out = F.linear(input, self.weight) + return out + + def extra_repr(self) -> str: + s = '{num_embeddings}, {embedding_dim}' + if self.padding_idx is not None: + s += ', padding_idx={padding_idx}' + if self.max_norm is not None: + s += ', max_norm={max_norm}' + if self.norm_type != 2: + s += ', norm_type={norm_type}' + if self.scale_grad_by_freq is not False: + s += ', scale_grad_by_freq={scale_grad_by_freq}' + if self.sparse is not False: + s += ', sparse=True' + return s.format(**self.__dict__) + + diff --git a/bmtrain/pipe/__init__.py b/bmtrain/pipe/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/bmtrain/pipe/comm.py b/bmtrain/pipe/comm.py index cf5d73d6..f69a3d17 100644 --- a/bmtrain/pipe/comm.py +++ b/bmtrain/pipe/comm.py @@ -1,6 +1,6 @@ import torch -from ..distributed.ops import send_activations_list, recv_activations_list, send_activations, recv_activations -from ..global_var import config +from bmtrain.distributed.ops import send_activations_list, recv_activations_list, send_activations, recv_activations +from bmtrain.global_var import config class PipeCommander: def __init__(self, topo, data_iterator, num_micros, num_warmup, forward_only, interleaving_size) -> None: self.topo = topo @@ -14,26 +14,28 @@ def send_next(self, tensors): if not self.is_last_stage(): if not isinstance(tensors, list): tensors = [tensors] - send_activations_list(tensors, self.topo.pipe_rank + 1, config["pipe_comm"]) + # send_activations_list(tensors, self.topo.pipe_rank + 1, config["pipe_comm"]) def send_prev(self, tensors): if not self.is_first_stage(): if not isinstance(tensors, list): tensors = [tensors] - send_activations_list(tensors, self.topo.pipe_rank - 1, 
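The PipeEmbedding.forward above reuses one weight matrix in both directions: with projection=False it is a plain embedding lookup, with projection=True the same weight acts as a tied output projection. A toy illustration with assumed shapes, using the same functional calls as that forward:

    import torch
    import torch.nn.functional as F

    weight = torch.randn(1000, 64)        # (num_embeddings, embedding_dim)
    ids = torch.randint(0, 1000, (2, 8))
    hidden = F.embedding(ids, weight)     # projection=False path -> (2, 8, 64)
    logits = F.linear(hidden, weight)     # projection=True path  -> (2, 8, 1000)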
config["pipe_comm"]) + # send_activations_list(tensors, self.topo.pipe_rank - 1, config["pipe_comm"]) def recv_prev(self, need_data=False): if not self.is_first_stage(): - return recv_activations_list(self.topo.pipe_rank - 1, config["pipe_comm"]) + return [torch.randn((12,1024,128),device="cuda", dtype=torch.float16).requires_grad_()] + # return recv_activations_list(self.topo.pipe_rank - 1, config["pipe_comm"]) else: if need_data: - return next(self.data_iterator) + return list(next(self.data_iterator)) else: return None def recv_next(self): if not self.is_last_stage(): - return recv_activations_list(self.topo.pipe_rank + 1, config["pipe_comm"]) + # return recv_activations_list(self.topo.pipe_rank + 1, config["pipe_comm"]) + return [torch.randn((12,1024,128),device="cuda", dtype=torch.float16).requires_grad_()] else: return None @@ -56,8 +58,8 @@ def send_forward_recv_backward(self, forward_state): def send_backward_recv_forward(self, backward_grad, need_data=False): if not self.is_first_stage(): - self.send_pre(backward_grad) - forward_state = self.recv_pre() + self.send_prev(backward_grad) + forward_state = self.send_prev() else: if need_data: forward_state = next(self.data_iterator) diff --git a/bmtrain/pipe/debug.py b/bmtrain/pipe/debug.py index 6f7063c2..830b26b8 100644 --- a/bmtrain/pipe/debug.py +++ b/bmtrain/pipe/debug.py @@ -1,18 +1,17 @@ +from bmtrain.global_var import config import logging -logger = logging.getLogger('pipeline') -logger.setLevel(logging.DEBUG) - -fh = logging.FileHandler('pipe.log') -fh.setLevel(logging.DEBUG) - -ch = logging.StreamHandler() -ch.setLevel(logging.DEBUG) - -formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') - -fh.setFormatter(formatter) -ch.setFormatter(formatter) - -logger.addHandler(fh) -logger.addHandler(ch) +def get_logger(rank): + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + logger = logging.getLogger('pipeline') + logger.setLevel(logging.DEBUG) + if rank == 0: + ch = logging.StreamHandler() + ch.setLevel(logging.DEBUG) + ch.setFormatter(formatter) + logger.addHandler(ch) + fh = logging.FileHandler(f'pipe_{rank}.log',mode="w") + fh.setLevel(logging.DEBUG) + fh.setFormatter(formatter) + logger.addHandler(fh) + return logger diff --git a/bmtrain/pipe/example.py b/bmtrain/pipe/example.py new file mode 100644 index 00000000..e710c757 --- /dev/null +++ b/bmtrain/pipe/example.py @@ -0,0 +1,16 @@ +from schedule import pipeline_forward_backward +import torch +import bmtrain as bmt +def generate(iters): + for i in range(iters): + yield (torch.randint(0,1024,size=(12,1024),device="cuda", dtype=torch.int32),) + +data_loader = iter(generate(100)) +bmt.init_distributed(pipe_size=4) +# print(bmt.config['rank']) +models = [bmt.nn.PipeEmbedding(1024,128,dtype=torch.float16)] +for i in range(11): + models.append(bmt.nn.Linear(128,128,dtype=torch.float16)) +models = bmt.PipeDreamBlockList(models) + +pipeline_forward_backward(models, data_loader, 48) \ No newline at end of file diff --git a/bmtrain/pipe/run.sh b/bmtrain/pipe/run.sh new file mode 100644 index 00000000..68ea9bbd --- /dev/null +++ b/bmtrain/pipe/run.sh @@ -0,0 +1 @@ +torchrun --nnodes=1 --nproc_per_node=4 --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=localhost example.py diff --git a/bmtrain/pipe/salloc.sh b/bmtrain/pipe/salloc.sh new file mode 100644 index 00000000..c17843da --- /dev/null +++ b/bmtrain/pipe/salloc.sh @@ -0,0 +1 @@ +salloc --partition=gpu3 --nodelist=$1 diff --git a/bmtrain/pipe/schedule.py 
b/bmtrain/pipe/schedule.py index fa7a63f5..5ec37c50 100644 --- a/bmtrain/pipe/schedule.py +++ b/bmtrain/pipe/schedule.py @@ -1,6 +1,9 @@ -from ..global_var import config -from .comm import PipeCommander +import sys +from bmtrain.global_var import config +from bmtrain.loss import FusedCrossEntropy +from comm import PipeCommander import torch +from debug import get_logger def backward_step(input_tensor, output_tensor, output_tensor_grad): """Backward step through passed-in output tensor. @@ -11,12 +14,11 @@ def backward_step(input_tensor, output_tensor, output_tensor_grad): Returns gradient of loss with respect to input tensor (None if first stage).""" - if not isinstance(input_tensor, list): + if not isinstance(input_tensor, list) : input_tensor = [input_tensor] for x in input_tensor: - if x is not None: + if x is not None and x.requires_grad: x.retain_grad() - if not isinstance(output_tensor, list): output_tensor = [output_tensor] if not isinstance(output_tensor_grad, list): @@ -24,21 +26,35 @@ def backward_step(input_tensor, output_tensor, output_tensor_grad): #TODO scale the grad # if output_tensor_grad[0] is None and config.grad_scale_func is not None: # output_tensor[0] = config.grad_scale_func(output_tensor[0]) - torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0]) input_tensor_grad = [None] if input_tensor is not None: input_tensor_grad = [] for x in input_tensor: - if x is None: + if x is None or not x.requires_grad: input_tensor_grad.append(None) else: input_tensor_grad.append(x.grad) return input_tensor_grad -def pipeline_forward_backward(models, inputs, data_iterator, global_batch_size, interleaving_size=1): +def forward_func(model, inp): + if not isinstance(inp, list): + inp = [inp] + if config["topology"].pipe_rank == config["topology"].pipe_size - 1: + inp = model(*inp) + config['logger'].debug("inp shape: {}".format(inp[0].shape)) + loss = inp.mean() + config['logger'].debug("loss shape: {}".format(loss.shape)) + return loss + else: + hidden_state = model(*inp) + if not isinstance(hidden_state, list): + hidden_state = [hidden_state] + return hidden_state + +def pipeline_forward_backward(model, data_iterator, global_batch_size, interleaving_size=1): """Forward and backward the pipeline model. 
Args: @@ -51,14 +67,22 @@ def pipeline_forward_backward(models, inputs, data_iterator, global_batch_size, """ # forwrad unpack - inp, *args = data_iterator + inp, *args = next(data_iterator) micro_batch_size = inp.shape[0] assert global_batch_size % micro_batch_size == 0, "The global batch size must be divisible by the micro batch size" num_micro_batches = global_batch_size // micro_batch_size assert (num_micro_batches) % config["pipe_size"] == 0, "The number of micro batches must be divisible by the pipeline size" topo = config["topology"] + logger = get_logger(config['rank']) + config['logger'] = logger + logger.debug("topo: {}".format(topo)) + logger.debug("num_micro_batches: {}".format(num_micro_batches)) + logger.debug("micro_batch_size: {}".format(micro_batch_size)) + logger.debug("global_batch_size: {}".format(global_batch_size)) + logger.debug("interleaving_size: {}".format(interleaving_size)) # construct Pipe Commander - forward_only = torch.is_grad_enabled() + forward_only = False + logger.debug("forward_only: {}".format(forward_only)) if forward_only: num_warmup = num_micro_batches else: @@ -70,31 +94,43 @@ def pipeline_forward_backward(models, inputs, data_iterator, global_batch_size, data_iterator=data_iterator) inps = [] outputs = [] - + logger.debug("num_warmup: {}".format(num_warmup)) for micro in range(num_warmup): - inp = commander.recv_peer(need_data=True) - output = models(*inp) + logger.debug("{} recv micro {}th from prev neighbour".format(config['rank'], micro)) + inp = commander.recv_prev(need_data=True) + output = forward_func(model, inp) # send activations - commander.send_peer(output) + commander.send_next(output) + logger.debug("{} send micro {}th to next neighbour".format(config['rank'], micro)) if not forward_only: inps.append(inp) outputs.append(output) remain_batch = num_micro_batches - num_warmup - + logger.debug("remain_batch: {}".format(remain_batch)) if remain_batch > 0: - inp = commander.recv_peer(need_data=True) + inp = commander.recv_prev(need_data=True) for micro in range(num_micro_batches - num_warmup): - output = models(*inp) + output = forward_func(model, inp) + logger.debug("{} send micro hidden state {}th to next neighbour and recv micro grad {} from next neighbour".format(config['rank'], micro + num_warmup, micro)) grad_output = commander.send_forward_recv_backward(output) + logger.debug("inp shape: {}".format(inp[0].shape)) + logger.debug("output shape: {}".format(output[0].shape)) + if grad_output[0] is not None: + logger.debug("grad_output shape: {}".format(grad_output[0].shape)) inp_grad = backward_step(inp, output, grad_output) if micro == remain_batch - 1: input_tensor = None commander.send_prev(inp_grad) else: + logger.debug("{} send micro grad {}th to prev neighbour and recv micro hidden state {} from recv neighbour".format(config['rank'], micro, micro + num_warmup)) input_tensor = commander.send_backward_recv_forward(inp_grad) - for i in range(num_warmup): + + if not forward_only: + logger.debug("cooling stage") + for i in range(num_warmup): + logger.debug("{} recv micro grad {}th from next neighbour".format(config['rank'], i)) # if i == num_warmup - 1: # grad sync # if config.grad_sync_func is None or rank == 0: @@ -108,8 +144,9 @@ def pipeline_forward_backward(models, inputs, data_iterator, global_batch_size, input_tensor_grad = backward_step( input_tensor, output_tensor, output_tensor_grad, ) + logger.send("{} send micro grad {}th to prev neighbour".format(config['rank'], i)) - commander.send_prev(input_tensor_grad) + 
commander.send_prev(input_tensor_grad) \ No newline at end of file diff --git a/example/models/gpt.py b/example/models/gpt.py index 23b571db..de76e1fd 100644 --- a/example/models/gpt.py +++ b/example/models/gpt.py @@ -2,15 +2,7 @@ import bmtrain as bmt from layers import TransformerEncoder, Layernorm, Embedding, TransformerEncoder from bmtrain.global_var import config -class BlockInput(bmt.DistributedModule): - def __init__(self,word_emb, pos_emb, dtype = None): - self.embed_layers = embedding_layers - - def forward(self, input, pos): - output = self.embed_layers[0](input) - for layer in self.embed_layers[1:]: - output = layer(output) - return input + class GPT(bmt.DistributedModule): def __init__(self, From 22e9bc028b284dcc942966bf27ca87fb2912c145 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Mon, 4 Sep 2023 21:25:47 +0800 Subject: [PATCH 03/43] fix 1f1b stuck --- bmtrain/block_layer.py | 21 ++++++++++--------- bmtrain/distributed/__init__.py | 2 +- bmtrain/distributed/ops.py | 36 ++++++++++++++++++++++++--------- bmtrain/init.py | 6 +++--- bmtrain/inspect/tensor.py | 8 ++++---- bmtrain/nn/pipe_embedding.py | 9 ++++----- bmtrain/pipe/comm.py | 28 ++++++++++++++----------- bmtrain/pipe/schedule.py | 33 ++++++++++++++++++++---------- bmtrain/pipe/test_send_recv.py | 24 ++++++++++++++++++++++ docs/UPDATE_0.2.0.md | 2 +- tests/test_send_recv.py | 2 +- 11 files changed, 114 insertions(+), 57 deletions(-) create mode 100644 bmtrain/pipe/test_send_recv.py diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 648bd61f..c85c3ed8 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -106,13 +106,12 @@ def init_param_storage(self): storage_type = storage_type_cuda(param.storage_type()) kw_name = _get_param_kw(param) - if kw_name not in self._storage_info: if self._mode == "PIPE" and param._tp_mode: zero_comm = config["pp_tp_zero_comm"] elif self._mode != "PIPE" and param._tp_mode: zero_comm = config["tp_zero_comm"] - elif self._mode == "PIPE" and not param._tp_mode: + elif (self._mode == "PIPE" or self._mode == "1F1B") and not param._tp_mode: zero_comm = config["pp_zero_comm"] else: zero_comm = config["zero_comm"] @@ -531,7 +530,7 @@ def _block_wrapper(module, module_dict:dict, mode="BLOCK"): in_block = id(module) in module_dict new_module = Block(module, initialized=in_block, mode=mode) if in_block: - new_module.reference(modules[id(module)]) + new_module.reference(module_dict[id(module)]) else: module_dict[id(module)] = new_module else: @@ -565,15 +564,14 @@ class TransformerBlockList(torch.nn.Module): """ _modules: Dict[str, Block] - def __init__(self, modules: Iterable[Block], num_hidden=1) -> None: + def __init__(self, modules: Iterable[Block], num_hidden=1, mode="BLOCK") -> None: super().__init__() self._modules = {} pre_module = None module_dict = {} - module_dict = {} for i, module in enumerate(modules): - module = _block_wrapper(module, module_dict) + module = _block_wrapper(module, module_dict, mode=mode) module.set_pre_module(pre_module) pre_module = module module._is_first_layer = False @@ -620,10 +618,13 @@ def forward(self, *args, return_hidden_states = False): class PipeDreamBlockList(TransformerBlockList): def __init__(self, modules: Iterable[Block], num_hidden=1, sqrt=False) -> None: - modules,s,e = self.partition(modules) - print(s,"->",e) - super().__init__(modules, num_hidden, sqrt) - + module_dict = {} + mode = "1F1B" + for idx in range(len(modules)): + modules[idx] = _block_wrapper(modules[idx], module_dict,mode=mode) + 
modules,s,e = self.partition(modules) + super().__init__(modules, num_hidden, "1F1B") + def partition(self,modules): pipe_size = config["topology"].pipe_size pipe_rank = config["topology"].pipe_rank diff --git a/bmtrain/distributed/__init__.py b/bmtrain/distributed/__init__.py index 9dc64bb8..8671c4aa 100644 --- a/bmtrain/distributed/__init__.py +++ b/bmtrain/distributed/__init__.py @@ -1 +1 @@ -from .ops import all_gather, all_reduce, broadcast, recv_activations, send_activations +from .ops import all_gather, all_reduce, broadcast, recv_activations, send_activations,groupcall \ No newline at end of file diff --git a/bmtrain/distributed/ops.py b/bmtrain/distributed/ops.py index a7a1e626..182d7899 100644 --- a/bmtrain/distributed/ops.py +++ b/bmtrain/distributed/ops.py @@ -6,6 +6,7 @@ from ..nccl import send as ncclSend from ..nccl import recv as ncclRecv from ..nccl import commCount,commRank,NCCLCommunicator,groupStart,groupEnd +import contextlib DTYPE_LIST = [ torch.float64, torch.float32, @@ -17,23 +18,39 @@ torch.bfloat16, torch.bool ] -def send_activations_list(hidden_state_list, next_rank, comm): - length = torch.tensor(data=[0], device="cuda", dtype=torch.int) - length[0] = len(hidden_state_list) - ncclSend(length.storage(), next_rank, comm) +@contextlib.contextmanager +def groupcall(): groupStart() - for i in range(length): - send_activations(hidden_state_list[i], next_rank, comm) + yield groupEnd() +def send_activations_list(hidden_state_list, next_rank, comm, async_op=True): + if async_op: + current_stream = torch.cuda.current_stream() + with torch.cuda.stream(config["pp_comm_stream"]): + config["pp_comm_stream"].wait_stream(current_stream) + length = torch.tensor(data=[0], device="cuda", dtype=torch.int) + length[0] = len([h for h in hidden_state_list if h is not None]) + ncclSend(length.storage(), next_rank, comm) + for i in range(len(hidden_state_list)): + if hidden_state_list[i] is None: + continue + hidden_state_list[i].record_stream(config["pp_comm_stream"]) + send_activations(hidden_state_list[i], next_rank, comm) + else: + length = torch.tensor(data=[0], device="cuda", dtype=torch.int) + length[0] = len(hidden_state_list) + ncclSend(length.storage(), next_rank, comm) + for i in range(length): + send_activations(hidden_state_list[i], next_rank, comm) + + def recv_activations_list(prev_rank, comm): length = torch.tensor(data=[0], device="cuda", dtype=torch.int) - ncclRecv(length.storage(), prev_rank, comm) hidden_state_list = [] - groupStart() + ncclRecv(length.storage(), prev_rank, comm) for i in range(length[0].item()): hidden_state_list.append(recv_activations(prev_rank, comm)) - groupEnd() return hidden_state_list @@ -62,6 +79,7 @@ def recv_meta(prev_rank, comm): n_dims = meta_data[0].item() dtype = DTYPE_LIST[meta_data[1].item()] shape = meta_data[2:n_dims+2].tolist() + return dtype,shape class OpBroadcast(torch.autograd.Function): diff --git a/bmtrain/init.py b/bmtrain/init.py index fdce0c25..3a4420ad 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -171,7 +171,7 @@ def __init__(self,config): dp_size = world_size // (pp_size * tp_size) config['tp_zero_size'] = dp_size config['zero_size'] = world_size // pp_size - self.stages = config['pipe_size'] + self.pipe_size = config['pipe_size'] stage_size = world_size // pp_size for i in range(world_size): @@ -180,13 +180,13 @@ def __init__(self,config): self.tp_id = self.rank % tp_size self.tp_idx = self.rank // tp_size #pp->zero - self.pp_zero_idx = self.stage_id + self.pp_zero_idx = self.pipe_rank self.pp_zero_id = 
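The reworked send_activations_list above first transmits how many tensors follow and then sends each tensor (metadata plus payload) on the pp_comm_stream, while the receiver mirrors the same order. A rough sketch of that pairing, where send_count/recv_count and send_one/recv_one are placeholders for the NCCL helpers, not real bmtrain functions:

    # Length-prefixed list exchange; sender and receiver must agree on the order.
    def send_list(tensors, peer, comm, send_count, send_one):
        tensors = [t for t in tensors if t is not None]
        send_count(len(tensors), peer, comm)          # 1. number of tensors that follow
        for t in tensors:
            send_one(t.contiguous(), peer, comm)      # 2. per-tensor meta + payload

    def recv_list(peer, comm, recv_count, recv_one):
        n = recv_count(peer, comm)                    # must match the sender's count
        return [recv_one(peer, comm) for _ in range(n)]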
self.pipe_idx #tp->zero self.tp_zero_idx = self.tp_id self.tp_zero_id = self.tp_idx #pp->tp->zero - self.pp_tp_zero_idx = self.stage_id * tp_size + self.tp_id + self.pp_tp_zero_idx = self.pipe_rank * tp_size + self.tp_id self.pp_tp_zero_id = self.pipe_idx // tp_size #only zero self.zero_idx = 0 diff --git a/bmtrain/inspect/tensor.py b/bmtrain/inspect/tensor.py index bcd81407..0b5a9ae4 100644 --- a/bmtrain/inspect/tensor.py +++ b/bmtrain/inspect/tensor.py @@ -40,7 +40,7 @@ def _set_summary(self, summary): assert item["inside_pipe"] is not None pipe_rank = item["inside_pipe"]["pipe_rank"] - stages = item["inside_pipe"]["stages"] + pipe_size = item["inside_pipe"]["pipe_size"] st = item["inside_pipe"]["st"] ed = item["inside_pipe"]["ed"] @@ -52,7 +52,7 @@ def _set_summary(self, summary): if ed: break - for stage in range(stages): + for stage in range(pipe_size): if pipe_rank == stage: broadcast_object(pipe_cnt, config["pipe_comm"], src = stage) for k in range(i, j): @@ -126,11 +126,11 @@ def _set_summary(self, summary): tensor = broadcast(tensor, it["inside_pipe"]["pipe_rank"], config["pipe_comm"]) if has_grad: grad = torch.empty(it["shape"]).cuda() - tensor = tensor.chunk(stages, dim=0)[pipe_rank].clone() + tensor = tensor.chunk(pipe_size, dim=0)[pipe_rank].clone() it["tensor"] = tensor if has_grad: grad = broadcast(grad, it["inside_pipe"]["pipe_rank"], config["pipe_comm"]) - grad = grad.chunk(stages, dim=0)[pipe_rank].clone() + grad = grad.chunk(pipe_size, dim=0)[pipe_rank].clone() tensor.grad = grad it["shape"] = (it["shape"][0]//config["pipe_size"],) + it["shape"][1:] diff --git a/bmtrain/nn/pipe_embedding.py b/bmtrain/nn/pipe_embedding.py index faf691a1..5893dac7 100644 --- a/bmtrain/nn/pipe_embedding.py +++ b/bmtrain/nn/pipe_embedding.py @@ -31,11 +31,10 @@ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optiona self.max_norm = max_norm self.norm_type = norm_type self.scale_grad_by_freq = scale_grad_by_freq - if bmt.config['topology'].pipe_rank == 0: - if _weight is None: - self.weight = bmt.DistributedParameter(torch.empty(num_embeddings, embedding_dim, dtype=dtype, device="cuda"), init_method=torch.nn.init.normal_) - else: - self.weight = bmt.DistributedParameter(_weight) + if _weight is None: + self.weight = bmt.DistributedParameter(torch.empty(num_embeddings, embedding_dim, dtype=dtype, device="cuda"), init_method=torch.nn.init.normal_) + else: + self.weight = bmt.DistributedParameter(_weight) self.sparse = sparse @classmethod diff --git a/bmtrain/pipe/comm.py b/bmtrain/pipe/comm.py index f69a3d17..d47bb0fa 100644 --- a/bmtrain/pipe/comm.py +++ b/bmtrain/pipe/comm.py @@ -1,5 +1,5 @@ import torch -from bmtrain.distributed.ops import send_activations_list, recv_activations_list, send_activations, recv_activations +from bmtrain.distributed.ops import send_activations_list, recv_activations_list, send_activations, recv_activations, groupcall from bmtrain.global_var import config class PipeCommander: def __init__(self, topo, data_iterator, num_micros, num_warmup, forward_only, interleaving_size) -> None: @@ -14,30 +14,34 @@ def send_next(self, tensors): if not self.is_last_stage(): if not isinstance(tensors, list): tensors = [tensors] - # send_activations_list(tensors, self.topo.pipe_rank + 1, config["pipe_comm"]) + send_activations_list(tensors, self.topo.pipe_rank + 1, config["pipe_comm"]) def send_prev(self, tensors): if not self.is_first_stage(): if not isinstance(tensors, list): tensors = [tensors] - # send_activations_list(tensors, self.topo.pipe_rank - 1, 
config["pipe_comm"]) + send_activations_list(tensors, self.topo.pipe_rank - 1, config["pipe_comm"]) def recv_prev(self, need_data=False): if not self.is_first_stage(): - return [torch.randn((12,1024,128),device="cuda", dtype=torch.float16).requires_grad_()] - # return recv_activations_list(self.topo.pipe_rank - 1, config["pipe_comm"]) + # return [torch.randn((12,1024,128),device="cuda", dtype=torch.float16).requires_grad_()] + res = recv_activations_list(self.topo.pipe_rank - 1, config["pipe_comm"]) + for t in res: + t.requires_grad_() + + return res else: if need_data: return list(next(self.data_iterator)) else: - return None + return [None] def recv_next(self): if not self.is_last_stage(): - # return recv_activations_list(self.topo.pipe_rank + 1, config["pipe_comm"]) - return [torch.randn((12,1024,128),device="cuda", dtype=torch.float16).requires_grad_()] + return recv_activations_list(self.topo.pipe_rank + 1, config["pipe_comm"]) + # return [torch.randn((12,1024,128),device="cuda", dtype=torch.float16).requires_grad_()] else: - return None + return [None] def allocate_tensor(self, shape, dtype): return torch.empty(shape, dtype=dtype, device="cuda") @@ -53,18 +57,18 @@ def send_forward_recv_backward(self, forward_state): self.send_next(forward_state) backward_grad = self.recv_next() else: - backward_grad = None + backward_grad = [None] return backward_grad def send_backward_recv_forward(self, backward_grad, need_data=False): if not self.is_first_stage(): self.send_prev(backward_grad) - forward_state = self.send_prev() + forward_state = self.recv_prev() else: if need_data: forward_state = next(self.data_iterator) else: - forward_state = None + forward_state = [None] return forward_state diff --git a/bmtrain/pipe/schedule.py b/bmtrain/pipe/schedule.py index 5ec37c50..f012dbe3 100644 --- a/bmtrain/pipe/schedule.py +++ b/bmtrain/pipe/schedule.py @@ -1,6 +1,7 @@ import sys from bmtrain.global_var import config from bmtrain.loss import FusedCrossEntropy +import bmtrain as bmt from comm import PipeCommander import torch from debug import get_logger @@ -68,6 +69,7 @@ def pipeline_forward_backward(model, data_iterator, global_batch_size, interleav # forwrad unpack inp, *args = next(data_iterator) + optimizer = bmt.optim.AdamOptimizer(model.parameters(), lr=0.001) micro_batch_size = inp.shape[0] assert global_batch_size % micro_batch_size == 0, "The global batch size must be divisible by the micro batch size" num_micro_batches = global_batch_size // micro_batch_size @@ -96,9 +98,10 @@ def pipeline_forward_backward(model, data_iterator, global_batch_size, interleav outputs = [] logger.debug("num_warmup: {}".format(num_warmup)) for micro in range(num_warmup): - logger.debug("{} recv micro {}th from prev neighbour".format(config['rank'], micro)) inp = commander.recv_prev(need_data=True) + logger.debug("{} recv micro {}th from prev neighbour".format(config['rank'], micro)) output = forward_func(model, inp) + logger.debug("{} micro forward".format(micro)) # send activations commander.send_next(output) logger.debug("{} send micro {}th to next neighbour".format(config['rank'], micro)) @@ -112,41 +115,49 @@ def pipeline_forward_backward(model, data_iterator, global_batch_size, interleav for micro in range(num_micro_batches - num_warmup): output = forward_func(model, inp) - logger.debug("{} send micro hidden state {}th to next neighbour and recv micro grad {} from next neighbour".format(config['rank'], micro + num_warmup, micro)) + logger.debug("{} micro forward".format(micro+num_warmup)) grad_output = 
commander.send_forward_recv_backward(output) + print(len(grad_output)) + logger.debug("{} send micro hidden state {}th to next neighbour and recv micro grad {} from next neighbour".format(config['rank'], micro + num_warmup, micro)) logger.debug("inp shape: {}".format(inp[0].shape)) - logger.debug("output shape: {}".format(output[0].shape)) - if grad_output[0] is not None: + if not commander.is_last_stage(): + logger.debug("output shape: {}".format(output[0].shape)) + if grad_output[0] is not None : logger.debug("grad_output shape: {}".format(grad_output[0].shape)) inp_grad = backward_step(inp, output, grad_output) + logger.debug("{} micro backward".format(micro+num_warmup)) if micro == remain_batch - 1: input_tensor = None commander.send_prev(inp_grad) + logger.debug("{} send micro grad {}th to prev neighbour".format(config['rank'], micro + num_warmup)) else: - logger.debug("{} send micro grad {}th to prev neighbour and recv micro hidden state {} from recv neighbour".format(config['rank'], micro, micro + num_warmup)) + logger.debug("{} send micro grad {}th to prev neighbour and recv micro hidden state {} from recv neighbour".format(config['rank'], micro, micro + num_warmup + 1)) + logger.debug("inp_grad shape: {}".format(inp_grad[0].shape)) input_tensor = commander.send_backward_recv_forward(inp_grad) + inps.append(input_tensor) + outputs.append(output) if not forward_only: logger.debug("cooling stage") for i in range(num_warmup): - logger.debug("{} recv micro grad {}th from next neighbour".format(config['rank'], i)) + logger.debug("{} recv micro grad {}th from next neighbour".format(config['rank'], num_micro_batches - num_warmup + i)) # if i == num_warmup - 1: # grad sync # if config.grad_sync_func is None or rank == 0: # enable_grad_sync() - input_tensor = inp.pop(0) - output_tensor = output.pop(0) + input_tensor = inps.pop(0) + output_tensor = outputs.pop(0) output_tensor_grad = commander.recv_next() - + logger.debug("{} micro backward".format(num_micro_batches - num_warmup + i)) input_tensor_grad = backward_step( input_tensor, output_tensor, output_tensor_grad, ) - logger.send("{} send micro grad {}th to prev neighbour".format(config['rank'], i)) + logger.debug("{} send micro grad {}th to prev neighbour".format(config['rank'], i)) commander.send_prev(input_tensor_grad) - + optimizer.step() \ No newline at end of file diff --git a/bmtrain/pipe/test_send_recv.py b/bmtrain/pipe/test_send_recv.py new file mode 100644 index 00000000..c8d42f53 --- /dev/null +++ b/bmtrain/pipe/test_send_recv.py @@ -0,0 +1,24 @@ +from schedule import pipeline_forward_backward +import torch +import bmtrain as bmt +from comm import PipeCommander,groupcall +def generate(iters): + for i in range(iters): + yield (torch.randint(0,1024,size=(12,1024),device="cuda", dtype=torch.int32),) + +bmt.init_distributed(pipe_size=4) + +topo = bmt.config["topology"] +num_micro_batches = 48 +num_warmup = 3 +interleaving_size = 1 +data_iterator = iter(generate(100)) +commander = PipeCommander(topo, num_micros=num_micro_batches,\ + num_warmup=num_warmup, forward_only=False, \ + interleaving_size=interleaving_size, \ + data_iterator=data_iterator) +# with groupcall(): +commander.send_prev([torch.randn((12,1024,128),device="cuda", dtype=torch.float16).requires_grad_()]) +recv = commander.recv_next() +if recv[0] is not None: + print(recv[0].shape) \ No newline at end of file diff --git a/docs/UPDATE_0.2.0.md b/docs/UPDATE_0.2.0.md index 92819afd..5ee04639 100644 --- a/docs/UPDATE_0.2.0.md +++ b/docs/UPDATE_0.2.0.md @@ -70,7 +70,7 @@ 
layers = bmt.PipelineTransformerBlockList([ ``` Replacing TransformerBlockList with PipelineTransformerBlockList allows the parallel algorithm to switch from ZeRO to pipeline parallelism. -The number of stages in the pipeline can be set by passing the `pipe_size` parameter to bmtrain.init_distributed. +The number of pipeline stages can be set by passing the `pipe_size` parameter to bmtrain.init_distributed.  ### 3. Others diff --git a/tests/test_send_recv.py b/tests/test_send_recv.py index f933b0c2..2ec406b4 100644 --- a/tests/test_send_recv.py +++ b/tests/test_send_recv.py @@ -5,7 +5,7 @@ from bmtrain.global_var import config  def test_send_recv(): -    if config["topology"].stage_id == 0: +    if config["topology"].pipe_rank == 0:         a = torch.ones((2,1)) * (config["topology"].pp_zero_id+1)         a = a.cuda()         print(f"send {a}") From 7fcbf0f3113f669f1024befb932d70bb380d9590 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Thu, 7 Sep 2023 13:15:39 +0800 Subject: [PATCH 04/43] 1F1B Pipeline compatible with ZeRO --- bmtrain/block_layer.py       |  40 ++++++----- bmtrain/nn/pipe_embedding.py |   1 - bmtrain/param_init.py        |  26 ++++---- bmtrain/pipe/comm.py         |  16 +++-- bmtrain/pipe/debug.py        |   8 +-- bmtrain/pipe/example.py      |  55 ++++++++++++---- bmtrain/pipe/run.sh          |   7 +- bmtrain/pipe/schedule.py     | 124 +++++++++++++++++++---------------- 8 files changed, 171 insertions(+), 106 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index c85c3ed8..96146e48 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -61,7 +61,7 @@ class Block(torch.nn.Module):         >>> y2, ... = transformer_block(x)         >>> assert torch.allclose(y1, y2)     """ -    def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, zero_level=3, initialized=False, mode="BLOCK"): +    def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, zero_level=3, initialize_param=True, mode="BLOCK"):         super().__init__()         self._module = inner_module         self._inputs = None @@ -84,7 +84,7 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, zero_lev         self.all_input_no_grad = False         self.all_param_no_grad = False         self._zero_level = zero_level -        if not initialized: +        if initialize_param:             self.init_param_storage()      def reference(self, block): @@ -95,7 +95,7 @@ def reference(self, block):         self._initialized = True         self._need_release = False  -    def init_param_storage(self): +    def init_param_storage(self, throw=False):         # sort parameters by name         ordered_parameters = list(self._module.named_parameters())  @@ -191,7 +191,9 @@ def init_param_storage(self):             # make parameter contiguous in storage             with torch.no_grad():                 contiguous_param = OpAllGather.apply(param) - +            if throw: +                del contiguous_param +                return             if not (param_st >= storage_end or param_end <= storage_st):                 # copy offset in parameter storage                 offset_st = max(storage_st - param_st, 0) @@ -525,17 +527,18 @@ def eval(self):     def __repr__(self):         return self._module.__repr__()  -def _block_wrapper(module, module_dict:dict, mode="BLOCK"): +def _block_wrapper(module, module_dict:dict, mode="BLOCK", **kwargs):     if not isinstance(module, Block): -        in_block = id(module) in module_dict -        new_module = Block(module, initialized=in_block, mode=mode) -        if in_block: -            new_module.reference(module_dict[id(module)]) +        if mode == "BLOCK": +            in_block = id(module) in module_dict +            new_module = Block(module, initialize_param=not in_block, mode=mode, **kwargs) +            if in_block: +                new_module.reference(module_dict[id(module)]) +        elif mode == "PIPE" or mode == "1F1B": 
new_module = Block(module, initialize_param=False, mode=mode, **kwargs) else: module_dict[id(module)] = new_module else: - if mode == "PIPE" and module._mode != "PIPE": - assert False, "You must be set mode=\"PIPE\" in bmt.Block when use PipelineTransformerBlockList!" if id(module._module) in module_dict: assert False, "Duplicate bmt.Block not supported in same block list!" else: @@ -621,9 +624,16 @@ def __init__(self, modules: Iterable[Block], num_hidden=1, sqrt=False) -> None: module_dict = {} mode = "1F1B" for idx in range(len(modules)): - modules[idx] = _block_wrapper(modules[idx], module_dict,mode=mode) - modules,s,e = self.partition(modules) - super().__init__(modules, num_hidden, "1F1B") + modules[idx] = _block_wrapper(modules[idx], module_dict, mode=mode, zero_level=2) + s,e = self.partition(modules) + partition_module = [] + for idx,m in enumerate(modules): + if idx>=s and idx torch.Tensor: if not projection: out = F.embedding( diff --git a/bmtrain/param_init.py b/bmtrain/param_init.py index a46c7845..7d188a28 100644 --- a/bmtrain/param_init.py +++ b/bmtrain/param_init.py @@ -1,4 +1,4 @@ -from typing import Generator, Iterable, List, Tuple +from typing import Generator, Iterable, List, Tuple, Union import torch from .block_layer import Block from .parameter import DistributedParameter @@ -42,20 +42,22 @@ def iterate_parameters(model : torch.nn.Module): return [] yield val -def init_parameters(model : torch.nn.Module): +def init_parameters(models : Union[List[torch.nn.Module], torch.nn.Module]): """ Initialize the parameters of the model by calling the init_method of the distributed parameters. """ - - modules = model.named_modules() - for module_prefix, module in modules: - if isinstance(module, Block): - module.init_parameters() - else: - init_distributed_parameter( iterate_parameters(module) ) - - current_stream = torch.cuda.current_stream() - config['load_stream'].wait_stream(current_stream) + if not isinstance(models, list): + models = [models] + for model in models: + modules = model.named_modules() + for module_prefix, module in modules: + if isinstance(module, Block): + module.init_parameters() + else: + init_distributed_parameter( iterate_parameters(module) ) + + current_stream = torch.cuda.current_stream() + config['load_stream'].wait_stream(current_stream) def grouped_parameters(model : torch.nn.Module) -> Generator[Tuple[str, List[torch.nn.Parameter]], None, None]: """ diff --git a/bmtrain/pipe/comm.py b/bmtrain/pipe/comm.py index d47bb0fa..5a02b937 100644 --- a/bmtrain/pipe/comm.py +++ b/bmtrain/pipe/comm.py @@ -1,6 +1,7 @@ import torch from bmtrain.distributed.ops import send_activations_list, recv_activations_list, send_activations, recv_activations, groupcall from bmtrain.global_var import config +from collections.abc import Iterable class PipeCommander: def __init__(self, topo, data_iterator, num_micros, num_warmup, forward_only, interleaving_size) -> None: self.topo = topo @@ -9,6 +10,11 @@ def __init__(self, topo, data_iterator, num_micros, num_warmup, forward_only, in self.num_warmup = num_warmup self.forward_only = forward_only self.interleaving_size = interleaving_size + + def get_data(self): + micro_batch = next(self.data_iterator) + assert isinstance(micro_batch, Iterable) + return list(micro_batch) def send_next(self, tensors): if not self.is_last_stage(): @@ -32,7 +38,7 @@ def recv_prev(self, need_data=False): return res else: if need_data: - return list(next(self.data_iterator)) + return self.get_data() else: return [None] @@ -54,7 +60,8 @@ def 
is_last_stage(self): def send_forward_recv_backward(self, forward_state): if not self.is_last_stage(): - self.send_next(forward_state) + if forward_state[0] is not None: + self.send_next(forward_state) backward_grad = self.recv_next() else: backward_grad = [None] @@ -62,11 +69,12 @@ def send_forward_recv_backward(self, forward_state): def send_backward_recv_forward(self, backward_grad, need_data=False): if not self.is_first_stage(): - self.send_prev(backward_grad) + if backward_grad[0] is not None: + self.send_prev(backward_grad) forward_state = self.recv_prev() else: if need_data: - forward_state = next(self.data_iterator) + forward_state = self.get_data() else: forward_state = [None] return forward_state diff --git a/bmtrain/pipe/debug.py b/bmtrain/pipe/debug.py index 830b26b8..e1de6a14 100644 --- a/bmtrain/pipe/debug.py +++ b/bmtrain/pipe/debug.py @@ -1,17 +1,17 @@ from bmtrain.global_var import config import logging -def get_logger(rank): +def get_logger(rank, level): formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') logger = logging.getLogger('pipeline') - logger.setLevel(logging.DEBUG) + logger.setLevel(level) if rank == 0: ch = logging.StreamHandler() - ch.setLevel(logging.DEBUG) + ch.setLevel(level) ch.setFormatter(formatter) logger.addHandler(ch) fh = logging.FileHandler(f'pipe_{rank}.log',mode="w") - fh.setLevel(logging.DEBUG) + fh.setLevel(level) fh.setFormatter(formatter) logger.addHandler(fh) return logger diff --git a/bmtrain/pipe/example.py b/bmtrain/pipe/example.py index e710c757..fcea3306 100644 --- a/bmtrain/pipe/example.py +++ b/bmtrain/pipe/example.py @@ -1,16 +1,49 @@ from schedule import pipeline_forward_backward import torch import bmtrain as bmt +import time +import sys def generate(iters): + torch.manual_seed(42) for i in range(iters): - yield (torch.randint(0,1024,size=(12,1024),device="cuda", dtype=torch.int32),) - -data_loader = iter(generate(100)) -bmt.init_distributed(pipe_size=4) -# print(bmt.config['rank']) -models = [bmt.nn.PipeEmbedding(1024,128,dtype=torch.float16)] -for i in range(11): - models.append(bmt.nn.Linear(128,128,dtype=torch.float16)) -models = bmt.PipeDreamBlockList(models) - -pipeline_forward_backward(models, data_loader, 48) \ No newline at end of file + inp = (torch.randint(0,1024,size=(12,1024),device="cuda", dtype=torch.int32),) + print(inp[0][0]) + yield inp +data_loader = iter(generate(100*16)) +def test_pipe(): + bmt.init_distributed(seed=42, pipe_size=4) + models = [bmt.nn.PipeEmbedding(1024,128,dtype=torch.float16)] + for i in range(11): + models.append(bmt.nn.Linear(128,128,dtype=torch.float16)) + # print(models[0].weight) + bmt.init_parameters(models) + models = bmt.PipeDreamBlockList(models) + start = time.time() + for i in range(10): + pipeline_forward_backward(models, data_loader, 12*16) + if bmt.config['topology'].pipe_rank == 0: + print(models['0'].weight.grad) + t = time.time() - start + print(t) +def test_dp(): + bmt.init_distributed(seed=42, pipe_size=1) + models = [bmt.nn.PipeEmbedding(1024,128,dtype=torch.float16)] + for i in range(11): + models.append(bmt.nn.Linear(128,128,dtype=torch.float16)) + bmt.init_parameters(models) + models = bmt.TransformerBlockList(models) + loss = 0 + for i in range(16): + loss_tmp = models(*next(data_loader)) + loss_tmp = loss_tmp.mean() + print(loss_tmp.item()) + loss += loss_tmp + print(loss) + loss.backward() + print(models['0'].weight.grad) +if __name__ == "__main__": + if sys.argv[1] == "dp": + print("dp") + test_dp() + else: + test_pipe() \ No 
newline at end of file diff --git a/bmtrain/pipe/run.sh b/bmtrain/pipe/run.sh index 68ea9bbd..1bcc5c6c 100644 --- a/bmtrain/pipe/run.sh +++ b/bmtrain/pipe/run.sh @@ -1 +1,6 @@ -torchrun --nnodes=1 --nproc_per_node=4 --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=localhost example.py +if [ "$1" = "dp" ]; then + nproc=1 +else + nproc=4 +fi +torchrun --nnodes=1 --nproc_per_node=$nproc --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=localhost example.py $1 diff --git a/bmtrain/pipe/schedule.py b/bmtrain/pipe/schedule.py index f012dbe3..d7cba26a 100644 --- a/bmtrain/pipe/schedule.py +++ b/bmtrain/pipe/schedule.py @@ -5,8 +5,9 @@ from comm import PipeCommander import torch from debug import get_logger +import logging -def backward_step(input_tensor, output_tensor, output_tensor_grad): +def backward_step(inp, output, grad_output): """Backward step through passed-in output tensor. If last stage, output_tensor_grad is None, otherwise gradient of loss @@ -15,39 +16,40 @@ def backward_step(input_tensor, output_tensor, output_tensor_grad): Returns gradient of loss with respect to input tensor (None if first stage).""" - if not isinstance(input_tensor, list) : - input_tensor = [input_tensor] - for x in input_tensor: + if not isinstance(inp, list) : + inp = [inp] + for x in inp: if x is not None and x.requires_grad: x.retain_grad() - if not isinstance(output_tensor, list): - output_tensor = [output_tensor] - if not isinstance(output_tensor_grad, list): - output_tensor_grad = [output_tensor_grad] + if not isinstance(output, list): + output = [output] + if not isinstance(grad_output, list): + grad_output = [grad_output] #TODO scale the grad # if output_tensor_grad[0] is None and config.grad_scale_func is not None: # output_tensor[0] = config.grad_scale_func(output_tensor[0]) - torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0]) + torch.autograd.backward(output[0], grad_tensors=grad_output[0]) - input_tensor_grad = [None] - if input_tensor is not None: - input_tensor_grad = [] - for x in input_tensor: + input_grad = [None] + if inp is not None: + input_grad = [] + for x in inp: if x is None or not x.requires_grad: - input_tensor_grad.append(None) + input_grad.append(None) else: - input_tensor_grad.append(x.grad) + input_grad.append(x.grad) - return input_tensor_grad + return input_grad -def forward_func(model, inp): +def forward_func(model, inp, micro_idx): if not isinstance(inp, list): inp = [inp] if config["topology"].pipe_rank == config["topology"].pipe_size - 1: - inp = model(*inp) - config['logger'].debug("inp shape: {}".format(inp[0].shape)) - loss = inp.mean() - config['logger'].debug("loss shape: {}".format(loss.shape)) + output = model(*inp) + config['logger'].info("inp shape: {}".format(output.shape)) + loss = output.mean() + config['logger'].info("loss: {}".format(loss.item())) + return loss else: hidden_state = model(*inp) @@ -68,23 +70,23 @@ def pipeline_forward_backward(model, data_iterator, global_batch_size, interleav """ # forwrad unpack - inp, *args = next(data_iterator) optimizer = bmt.optim.AdamOptimizer(model.parameters(), lr=0.001) - micro_batch_size = inp.shape[0] + micro_batch_size = 12 assert global_batch_size % micro_batch_size == 0, "The global batch size must be divisible by the micro batch size" num_micro_batches = global_batch_size // micro_batch_size assert (num_micro_batches) % config["pipe_size"] == 0, "The number of micro batches must be divisible by the pipeline size" + config["micros"] = num_micro_batches topo = config["topology"] - logger = 
get_logger(config['rank']) + logger = get_logger(config['rank'], logging.DEBUG) config['logger'] = logger - logger.debug("topo: {}".format(topo)) - logger.debug("num_micro_batches: {}".format(num_micro_batches)) - logger.debug("micro_batch_size: {}".format(micro_batch_size)) - logger.debug("global_batch_size: {}".format(global_batch_size)) - logger.debug("interleaving_size: {}".format(interleaving_size)) + logger.info("topo: {}".format(topo)) + logger.info("num_micro_batches: {}".format(num_micro_batches)) + logger.info("micro_batch_size: {}".format(micro_batch_size)) + logger.info("global_batch_size: {}".format(global_batch_size)) + logger.info("interleaving_size: {}".format(interleaving_size)) # construct Pipe Commander forward_only = False - logger.debug("forward_only: {}".format(forward_only)) + logger.info("forward_only: {}".format(forward_only)) if forward_only: num_warmup = num_micro_batches else: @@ -96,68 +98,74 @@ def pipeline_forward_backward(model, data_iterator, global_batch_size, interleav data_iterator=data_iterator) inps = [] outputs = [] - logger.debug("num_warmup: {}".format(num_warmup)) + logger.info("num_warmup: {}".format(num_warmup)) for micro in range(num_warmup): inp = commander.recv_prev(need_data=True) - logger.debug("{} recv micro {}th from prev neighbour".format(config['rank'], micro)) - output = forward_func(model, inp) - logger.debug("{} micro forward".format(micro)) + logger.info("{} recv micro {}th from prev neighbour".format(config['rank'], micro)) + output = forward_func(model, inp, micro) + logger.info("{} micro forward".format(micro)) # send activations commander.send_next(output) - logger.debug("{} send micro {}th to next neighbour".format(config['rank'], micro)) + logger.info("{} send micro {}th to next neighbour".format(config['rank'], micro)) if not forward_only: inps.append(inp) outputs.append(output) remain_batch = num_micro_batches - num_warmup - logger.debug("remain_batch: {}".format(remain_batch)) + logger.info("remain_batch: {}".format(remain_batch)) if remain_batch > 0: inp = commander.recv_prev(need_data=True) for micro in range(num_micro_batches - num_warmup): - output = forward_func(model, inp) - logger.debug("{} micro forward".format(micro+num_warmup)) + output = forward_func(model, inp, micro + num_warmup) + logger.info("{} micro forward".format(micro+num_warmup)) grad_output = commander.send_forward_recv_backward(output) - print(len(grad_output)) - logger.debug("{} send micro hidden state {}th to next neighbour and recv micro grad {} from next neighbour".format(config['rank'], micro + num_warmup, micro)) + inps.append(inp) + outputs.append(output) + logger.info("{} send micro hidden state {}th to next neighbour and recv micro grad {} from next neighbour".format(config['rank'], micro + num_warmup, micro)) logger.debug("inp shape: {}".format(inp[0].shape)) if not commander.is_last_stage(): logger.debug("output shape: {}".format(output[0].shape)) if grad_output[0] is not None : logger.debug("grad_output shape: {}".format(grad_output[0].shape)) + inp = inps.pop(0) + output = outputs.pop(0) + for x in inp: + logger.info("inp requires_grad: {}".format(x.requires_grad)) inp_grad = backward_step(inp, output, grad_output) - logger.debug("{} micro backward".format(micro+num_warmup)) + logger.info("{} micro backward".format(micro+num_warmup)) if micro == remain_batch - 1: - input_tensor = None + inp = None commander.send_prev(inp_grad) - logger.debug("{} send micro grad {}th to prev neighbour".format(config['rank'], micro + num_warmup)) + 
logger.info("{} send micro grad {}th to prev neighbour".format(config['rank'], micro + num_warmup)) else: - logger.debug("{} send micro grad {}th to prev neighbour and recv micro hidden state {} from recv neighbour".format(config['rank'], micro, micro + num_warmup + 1)) - logger.debug("inp_grad shape: {}".format(inp_grad[0].shape)) - input_tensor = commander.send_backward_recv_forward(inp_grad) - inps.append(input_tensor) - outputs.append(output) + if inp_grad[0] is not None: + logger.debug("inp_grad shape: {}".format(inp_grad[0].shape)) + inp = commander.send_backward_recv_forward(inp_grad, need_data=True) + logger.debug("inp type: {}".format(type(inp))) + logger.debug("inp shape: {}".format(inp[0].shape)) + logger.info("{} send micro grad {}th to prev neighbour and recv micro hidden state {} from prev neighbour".format(config['rank'], micro, micro + num_warmup + 1)) if not forward_only: - logger.debug("cooling stage") + logger.info("cooling stage") for i in range(num_warmup): - logger.debug("{} recv micro grad {}th from next neighbour".format(config['rank'], num_micro_batches - num_warmup + i)) + logger.info("{} recv micro grad {}th from next neighbour".format(config['rank'], num_micro_batches - num_warmup + i)) # if i == num_warmup - 1: # grad sync # if config.grad_sync_func is None or rank == 0: # enable_grad_sync() - input_tensor = inps.pop(0) - output_tensor = outputs.pop(0) + inp = inps.pop(0) + output = outputs.pop(0) - output_tensor_grad = commander.recv_next() - logger.debug("{} micro backward".format(num_micro_batches - num_warmup + i)) - input_tensor_grad = backward_step( - input_tensor, output_tensor, output_tensor_grad, + grad_output = commander.recv_next() + logger.info("{} micro backward".format(num_micro_batches - num_warmup + i)) + input_grad = backward_step( + inp, output , grad_output, ) - logger.debug("{} send micro grad {}th to prev neighbour".format(config['rank'], i)) + logger.info("{} send micro grad {}th to prev neighbour".format(config['rank'], i)) - commander.send_prev(input_tensor_grad) + commander.send_prev(input_grad) optimizer.step() \ No newline at end of file From 3c235a4ffe020bc1db037b325e8ffbf0c5bc1d76 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Thu, 7 Sep 2023 16:20:28 +0800 Subject: [PATCH 05/43] fix pipe embedding --- bmtrain/block_layer.py | 12 ++++++++++-- bmtrain/nn/pipe_embedding.py | 11 ++--------- bmtrain/pipe/comm.py | 9 ++++----- bmtrain/pipe/example.py | 2 ++ bmtrain/pipe/schedule.py | 17 ++++++++++++++--- 5 files changed, 32 insertions(+), 19 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 96146e48..2155cdfd 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -64,6 +64,7 @@ class Block(torch.nn.Module): def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, zero_level=3, initialize_param=True, mode="BLOCK"): super().__init__() self._module = inner_module + self._module._in_block = True self._inputs = None self._layer_dict = {} self._forward_block_ctx = None @@ -276,8 +277,8 @@ def post_hook(self, out): return post_out def forward(self, *args): + arg_list = self.pre_hook(*args) - if self.all_input_no_grad and not self.all_param_no_grad: placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled()) return hook_func.OneStepNoGradFunc.apply(self, placeholder, *arg_list) @@ -620,6 +621,7 @@ def forward(self, *args, return_hidden_states = False): return tuple(outputs[:self.num_hidden]) if self.num_hidden > 1 else outputs[0] class 
PipeDreamBlockList(TransformerBlockList): + def __init__(self, modules: Iterable[Block], num_hidden=1, sqrt=False) -> None: module_dict = {} mode = "1F1B" @@ -641,4 +643,10 @@ def partition(self,modules): part_lens = [0]+[len(modules) // pipe_size + (i < (len(modules) % pipe_size)) for i in range(pipe_rank+1)] start = sum(part_lens[:pipe_rank+1]) end = start + part_lens[pipe_rank+1] - return start,end \ No newline at end of file + return start,end + + def get_embedding(self): + assert config["topology"].pipe_rank == 0 + return self._modules[str(0)] + + \ No newline at end of file diff --git a/bmtrain/nn/pipe_embedding.py b/bmtrain/nn/pipe_embedding.py index b3082e75..fc6d92b7 100644 --- a/bmtrain/nn/pipe_embedding.py +++ b/bmtrain/nn/pipe_embedding.py @@ -3,14 +3,7 @@ import torch import torch.nn.functional as F import bmtrain as bmt -import inspect -def router(func): - def wrapper(self,*args,**kwargs): - if bmt.config["topology"].pipe_rank == 0: - return func(self,*args,**kwargs) - else: - return args,kwargs - return wrapper + class PipeEmbedding(bmt.DistributedModule): def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, @@ -80,7 +73,7 @@ def from_pretrained(cls, embeddings, freeze=True, padding_idx=None, sparse=sparse) embedding.weight.requires_grad = not freeze return embedding - + def forward(self, input: torch.Tensor, projection : bool = False) -> torch.Tensor: if not projection: out = F.embedding( diff --git a/bmtrain/pipe/comm.py b/bmtrain/pipe/comm.py index 5a02b937..70d524c5 100644 --- a/bmtrain/pipe/comm.py +++ b/bmtrain/pipe/comm.py @@ -3,16 +3,17 @@ from bmtrain.global_var import config from collections.abc import Iterable class PipeCommander: - def __init__(self, topo, data_iterator, num_micros, num_warmup, forward_only, interleaving_size) -> None: + def __init__(self, topo, input_generator, num_micros, num_warmup, forward_only, interleaving_size) -> None: self.topo = topo - self.data_iterator = data_iterator + self.input_generator = input_generator self.num_micros = num_micros self.num_warmup = num_warmup self.forward_only = forward_only self.interleaving_size = interleaving_size def get_data(self): - micro_batch = next(self.data_iterator) + assert config["topology"].pipe_rank == 0 + micro_batch = next(self.input_generator) assert isinstance(micro_batch, Iterable) return list(micro_batch) @@ -30,7 +31,6 @@ def send_prev(self, tensors): def recv_prev(self, need_data=False): if not self.is_first_stage(): - # return [torch.randn((12,1024,128),device="cuda", dtype=torch.float16).requires_grad_()] res = recv_activations_list(self.topo.pipe_rank - 1, config["pipe_comm"]) for t in res: t.requires_grad_() @@ -45,7 +45,6 @@ def recv_prev(self, need_data=False): def recv_next(self): if not self.is_last_stage(): return recv_activations_list(self.topo.pipe_rank + 1, config["pipe_comm"]) - # return [torch.randn((12,1024,128),device="cuda", dtype=torch.float16).requires_grad_()] else: return [None] diff --git a/bmtrain/pipe/example.py b/bmtrain/pipe/example.py index fcea3306..05888ef8 100644 --- a/bmtrain/pipe/example.py +++ b/bmtrain/pipe/example.py @@ -10,6 +10,7 @@ def generate(iters): print(inp[0][0]) yield inp data_loader = iter(generate(100*16)) + def test_pipe(): bmt.init_distributed(seed=42, pipe_size=4) models = [bmt.nn.PipeEmbedding(1024,128,dtype=torch.float16)] @@ -25,6 +26,7 @@ def test_pipe(): print(models['0'].weight.grad) t = time.time() - start print(t) + def test_dp(): bmt.init_distributed(seed=42, pipe_size=1) models = 
[bmt.nn.PipeEmbedding(1024,128,dtype=torch.float16)] diff --git a/bmtrain/pipe/schedule.py b/bmtrain/pipe/schedule.py index d7cba26a..04a8c29e 100644 --- a/bmtrain/pipe/schedule.py +++ b/bmtrain/pipe/schedule.py @@ -6,7 +6,7 @@ import torch from debug import get_logger import logging - +from typing import Iterable def backward_step(inp, output, grad_output): """Backward step through passed-in output tensor. @@ -57,6 +57,17 @@ def forward_func(model, inp, micro_idx): hidden_state = [hidden_state] return hidden_state +def preprocess_func(model, data_iter): + + while True: + try: + inp = next(data_iter) + except StopIteration: + break + input_ids = inp[0] + embed = model.get_embedding() + yield embed(input_ids), inp[1:] + def pipeline_forward_backward(model, data_iterator, global_batch_size, interleaving_size=1): """Forward and backward the pipeline model. @@ -92,10 +103,10 @@ def pipeline_forward_backward(model, data_iterator, global_batch_size, interleav else: num_warmup = topo.pipe_size - topo.pipe_rank - 1 - commander = PipeCommander(topo, num_micros=num_micro_batches,\ + commander = PipeCommander(topo,input_generator=preprocess_func(model, data_iterator), num_micros=num_micro_batches,\ num_warmup=num_warmup, forward_only=False, \ interleaving_size=interleaving_size, \ - data_iterator=data_iterator) + ) inps = [] outputs = [] logger.info("num_warmup: {}".format(num_warmup)) From 872285bf366c3c408666f88eccc0921069fcf6ac Mon Sep 17 00:00:00 2001 From: maydomine <1583143678@qq.com> Date: Thu, 7 Sep 2023 17:37:00 +0800 Subject: [PATCH 06/43] fix param init --- bmtrain/block_layer.py | 29 ++++++++++++++++++++++++++--- bmtrain/pipe/example.py | 20 ++++++++++---------- bmtrain/pipe/schedule.py | 5 +++-- 3 files changed, 39 insertions(+), 15 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 2155cdfd..b8ca5246 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -193,8 +193,7 @@ def init_param_storage(self, throw=False): with torch.no_grad(): contiguous_param = OpAllGather.apply(param) if throw: - del contiguous_param - return + continue if not (param_st >= storage_end or param_end <= storage_st): # copy offset in parameter storage offset_st = max(storage_st - param_st, 0) @@ -219,7 +218,8 @@ def init_param_storage(self, throw=False): param.data = torch.tensor([], dtype=param.dtype, device=param.device) # clear parameter data, but keep the dtype and device setattr(param, "_in_block", True) - + if throw: + return for kw in offsets.keys(): assert offsets[kw] == self._storage_info[kw]["total"] @@ -637,6 +637,29 @@ def __init__(self, modules: Iterable[Block], num_hidden=1, sqrt=False) -> None: m.init_param_storage(throw=True) super().__init__(partition_module, num_hidden, mode=mode) + def forward(self, *args, return_hidden_states = False): + self.return_hidden_states = return_hidden_states + hidden_states = [] + for i in range(1, len(self)): + if return_hidden_states: + for hidden_state in args[:self.num_hidden]: + hidden_states.append(hidden_state) + outputs = self._modules[str(i)]._call_impl(*args) + if not isinstance(outputs, tuple): + outputs = (outputs, ) + args = outputs + args[self.num_hidden:] + + if return_hidden_states: + hidden_states = [ + torch.stack(hidden_states[i::self.num_hidden], dim=0) + for i in range(self.num_hidden) + ] + + if return_hidden_states: + return outputs + tuple(hidden_states) + else: + return tuple(outputs[:self.num_hidden]) if self.num_hidden > 1 else outputs[0] + def partition(self,modules): pipe_size = 
config["topology"].pipe_size pipe_rank = config["topology"].pipe_rank diff --git a/bmtrain/pipe/example.py b/bmtrain/pipe/example.py index 05888ef8..a3ef9d93 100644 --- a/bmtrain/pipe/example.py +++ b/bmtrain/pipe/example.py @@ -7,7 +7,6 @@ def generate(iters): torch.manual_seed(42) for i in range(iters): inp = (torch.randint(0,1024,size=(12,1024),device="cuda", dtype=torch.int32),) - print(inp[0][0]) yield inp data_loader = iter(generate(100*16)) @@ -20,7 +19,7 @@ def test_pipe(): bmt.init_parameters(models) models = bmt.PipeDreamBlockList(models) start = time.time() - for i in range(10): + for i in range(1): pipeline_forward_backward(models, data_loader, 12*16) if bmt.config['topology'].pipe_rank == 0: print(models['0'].weight.grad) @@ -34,14 +33,15 @@ def test_dp(): models.append(bmt.nn.Linear(128,128,dtype=torch.float16)) bmt.init_parameters(models) models = bmt.TransformerBlockList(models) - loss = 0 - for i in range(16): - loss_tmp = models(*next(data_loader)) - loss_tmp = loss_tmp.mean() - print(loss_tmp.item()) - loss += loss_tmp - print(loss) - loss.backward() + for iter in range(1): + loss = 0 + for i in range(16): + loss_tmp = models(*next(data_loader)) + loss_tmp = loss_tmp.mean() + print(loss_tmp.item()) + loss += loss_tmp + print(loss) + loss.backward() print(models['0'].weight.grad) if __name__ == "__main__": if sys.argv[1] == "dp": diff --git a/bmtrain/pipe/schedule.py b/bmtrain/pipe/schedule.py index 04a8c29e..375ac58e 100644 --- a/bmtrain/pipe/schedule.py +++ b/bmtrain/pipe/schedule.py @@ -62,11 +62,12 @@ def preprocess_func(model, data_iter): while True: try: inp = next(data_iter) + print(type(inp)) except StopIteration: break input_ids = inp[0] embed = model.get_embedding() - yield embed(input_ids), inp[1:] + yield embed(input_ids), *inp[1:] def pipeline_forward_backward(model, data_iterator, global_batch_size, interleaving_size=1): """Forward and backward the pipeline model. 
@@ -88,7 +89,7 @@ def pipeline_forward_backward(model, data_iterator, global_batch_size, interleav assert (num_micro_batches) % config["pipe_size"] == 0, "The number of micro batches must be divisible by the pipeline size" config["micros"] = num_micro_batches topo = config["topology"] - logger = get_logger(config['rank'], logging.DEBUG) + logger = get_logger(config['rank'], logging.INFO) config['logger'] = logger logger.info("topo: {}".format(topo)) logger.info("num_micro_batches: {}".format(num_micro_batches)) From 45e79d7230ce0992a0fa1d998035bdf0eb7945c4 Mon Sep 17 00:00:00 2001 From: maydomine <1583143678@qq.com> Date: Mon, 11 Sep 2023 20:59:16 +0800 Subject: [PATCH 07/43] fix multi step validation and trying to fix embedding tied and dataloader --- bmtrain/block_layer.py | 19 +++--- bmtrain/distributed/ops.py | 11 +++- bmtrain/init.py | 7 +++ bmtrain/pipe/__init__.py | 1 + bmtrain/pipe/comm.py | 23 +++++--- bmtrain/pipe/example.py | 37 ++++++------ bmtrain/pipe/schedule.py | 45 +++++--------- example/layers/embedding.py | 16 ----- example/models/__init__.py | 3 +- example/models/pipe_gpt.py | 85 ++++++++++++++++++++++++++ example/pipe.sh | 1 + example/pipe_train.py | 115 ++++++++++++++++++++++++++++++++++++ 12 files changed, 279 insertions(+), 84 deletions(-) create mode 100644 example/models/pipe_gpt.py create mode 100644 example/pipe.sh create mode 100644 example/pipe_train.py diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index b8ca5246..84a511a6 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -96,7 +96,7 @@ def reference(self, block): self._initialized = True self._need_release = False - def init_param_storage(self, throw=False): + def init_param_storage(self): # sort parameters by name ordered_parameters = list(self._module.named_parameters()) @@ -192,8 +192,6 @@ def init_param_storage(self, throw=False): # make parameter contiguous in storage with torch.no_grad(): contiguous_param = OpAllGather.apply(param) - if throw: - continue if not (param_st >= storage_end or param_end <= storage_st): # copy offset in parameter storage offset_st = max(storage_st - param_st, 0) @@ -218,8 +216,6 @@ def init_param_storage(self, throw=False): param.data = torch.tensor([], dtype=param.dtype, device=param.device) # clear parameter data, but keep the dtype and device setattr(param, "_in_block", True) - if throw: - return for kw in offsets.keys(): assert offsets[kw] == self._storage_info[kw]["total"] @@ -631,16 +627,24 @@ def __init__(self, modules: Iterable[Block], num_hidden=1, sqrt=False) -> None: partition_module = [] for idx,m in enumerate(modules): if idx>=s and idx 1 else outputs[0] + def partition(self,modules): pipe_size = config["topology"].pipe_size pipe_rank = config["topology"].pipe_rank diff --git a/bmtrain/distributed/ops.py b/bmtrain/distributed/ops.py index 182d7899..dc175b88 100644 --- a/bmtrain/distributed/ops.py +++ b/bmtrain/distributed/ops.py @@ -24,7 +24,7 @@ def groupcall(): yield groupEnd() -def send_activations_list(hidden_state_list, next_rank, comm, async_op=True): +def send_activations_list(hidden_state_list, next_rank, comm, async_op=False): if async_op: current_stream = torch.cuda.current_stream() with torch.cuda.stream(config["pp_comm_stream"]): @@ -34,7 +34,7 @@ def send_activations_list(hidden_state_list, next_rank, comm, async_op=True): ncclSend(length.storage(), next_rank, comm) for i in range(len(hidden_state_list)): if hidden_state_list[i] is None: - continue + hidden_state_list[i] = 
torch.tensor([12306],dtype=torch.int,device="cuda") hidden_state_list[i].record_stream(config["pp_comm_stream"]) send_activations(hidden_state_list[i], next_rank, comm) else: @@ -50,7 +50,12 @@ def recv_activations_list(prev_rank, comm): hidden_state_list = [] ncclRecv(length.storage(), prev_rank, comm) for i in range(length[0].item()): - hidden_state_list.append(recv_activations(prev_rank, comm)) + recv = recv_activations(prev_rank, comm) + if len(recv.shape) == 1 and recv.shape[0] == 1 and recv.item() == 12306: + hidden_state_list.append(None) + else: + hidden_state_list.append(recv) + return hidden_state_list diff --git a/bmtrain/init.py b/bmtrain/init.py index 3a4420ad..0e0672ca 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -119,6 +119,13 @@ def init_distributed( store.set(f"PIPE_UNIQUE_ID{topo.pipe_idx}", unique_id.hex()) unique_id = bytes.fromhex(store.get(f"PIPE_UNIQUE_ID{topo.pipe_idx}").decode()) config ['pipe_comm'] = nccl.commInitRank(unique_id, pipe_size, topo.pipe_rank) + if topo.pipe_rank == topo.pipe_size - 1 or topo.pipe_rank == 0: + if topo.pipe_rank == 0: + unique_tied_id = nccl.getUniqueId() + store.set(f"PIPE_TIED_UNIQUE_ID{topo.pipe_idx}", unique_tied_id.hex()) + unique_tied_id = bytes.fromhex(store.get(f"PIPE_TIED_UNIQUE_ID{topo.pipe_idx}").decode()) + rank = 0 if topo.pipe_rank == 0 else 1 + config['pipe_tied_comm'] = nccl.commInitRank(unique_tied_id, 2, rank) if topo.tp_id == 0: unique_id = nccl.getUniqueId() diff --git a/bmtrain/pipe/__init__.py b/bmtrain/pipe/__init__.py index e69de29b..f697a75d 100644 --- a/bmtrain/pipe/__init__.py +++ b/bmtrain/pipe/__init__.py @@ -0,0 +1 @@ +from .schedule import pipeline_forward_backward \ No newline at end of file diff --git a/bmtrain/pipe/comm.py b/bmtrain/pipe/comm.py index 70d524c5..24b8794c 100644 --- a/bmtrain/pipe/comm.py +++ b/bmtrain/pipe/comm.py @@ -1,5 +1,5 @@ import torch -from bmtrain.distributed.ops import send_activations_list, recv_activations_list, send_activations, recv_activations, groupcall +from bmtrain.distributed.ops import send_activations_list, recv_activations_list, send_activations, recv_activations, groupcall,all_reduce from bmtrain.global_var import config from collections.abc import Iterable class PipeCommander: @@ -10,7 +10,12 @@ def __init__(self, topo, input_generator, num_micros, num_warmup, forward_only, self.num_warmup = num_warmup self.forward_only = forward_only self.interleaving_size = interleaving_size - + + def param_reduce(self, module): + for name, param in module.named_parameters(): + p = all_reduce(param, "sum", config["pipe_tied_comm"]) + param.data = p + def get_data(self): assert config["topology"].pipe_rank == 0 micro_batch = next(self.input_generator) @@ -19,22 +24,22 @@ def get_data(self): def send_next(self, tensors): if not self.is_last_stage(): - if not isinstance(tensors, list): + if not isinstance(tensors, Iterable): tensors = [tensors] - send_activations_list(tensors, self.topo.pipe_rank + 1, config["pipe_comm"]) + send_activations_list(tensors, self.topo.pipe_rank + 1, config["pipe_comm"], async_op=True) def send_prev(self, tensors): if not self.is_first_stage(): - if not isinstance(tensors, list): + if not isinstance(tensors, Iterable): tensors = [tensors] - send_activations_list(tensors, self.topo.pipe_rank - 1, config["pipe_comm"]) + send_activations_list(tensors, self.topo.pipe_rank - 1, config["pipe_comm"], async_op=True) def recv_prev(self, need_data=False): if not self.is_first_stage(): res = recv_activations_list(self.topo.pipe_rank - 1, 
config["pipe_comm"]) - for t in res: - t.requires_grad_() - + for idx,tensor in enumerate(res): + if idx == 0: + tensor.requires_grad_() return res else: if need_data: diff --git a/bmtrain/pipe/example.py b/bmtrain/pipe/example.py index a3ef9d93..2a64e4b0 100644 --- a/bmtrain/pipe/example.py +++ b/bmtrain/pipe/example.py @@ -9,40 +9,43 @@ def generate(iters): inp = (torch.randint(0,1024,size=(12,1024),device="cuda", dtype=torch.int32),) yield inp data_loader = iter(generate(100*16)) +iters=10 +dtype=torch.half def test_pipe(): bmt.init_distributed(seed=42, pipe_size=4) - models = [bmt.nn.PipeEmbedding(1024,128,dtype=torch.float16)] + models = [bmt.nn.PipeEmbedding(1024,128,dtype=dtype)] for i in range(11): - models.append(bmt.nn.Linear(128,128,dtype=torch.float16)) - # print(models[0].weight) + models.append(bmt.nn.Linear(128,128,dtype=dtype)) bmt.init_parameters(models) models = bmt.PipeDreamBlockList(models) + optimizer = bmt.optim.AdamOptimizer(models.parameters(), lr=0.001) start = time.time() - for i in range(1): + for i in range(iters): pipeline_forward_backward(models, data_loader, 12*16) - if bmt.config['topology'].pipe_rank == 0: - print(models['0'].weight.grad) + if bmt.config['topology'].pipe_rank == 0: + print(models['0'].weight.grad) + optimizer.step() t = time.time() - start - print(t) def test_dp(): bmt.init_distributed(seed=42, pipe_size=1) - models = [bmt.nn.PipeEmbedding(1024,128,dtype=torch.float16)] + models = [bmt.nn.PipeEmbedding(1024,128,dtype=dtype)] for i in range(11): - models.append(bmt.nn.Linear(128,128,dtype=torch.float16)) + models.append(bmt.nn.Linear(128,128,dtype=dtype)) bmt.init_parameters(models) models = bmt.TransformerBlockList(models) - for iter in range(1): - loss = 0 + optimizer = bmt.optim.AdamOptimizer(models.parameters(), lr=0.001) + for it in range(iters): for i in range(16): - loss_tmp = models(*next(data_loader)) + inp = next(data_loader) + loss_tmp = models(*inp) loss_tmp = loss_tmp.mean() - print(loss_tmp.item()) - loss += loss_tmp - print(loss) - loss.backward() - print(models['0'].weight.grad) + loss_tmp.backward() + + print(models['0'].weight.grad) + optimizer.step() + if __name__ == "__main__": if sys.argv[1] == "dp": print("dp") diff --git a/bmtrain/pipe/schedule.py b/bmtrain/pipe/schedule.py index 375ac58e..e07960b8 100644 --- a/bmtrain/pipe/schedule.py +++ b/bmtrain/pipe/schedule.py @@ -2,9 +2,9 @@ from bmtrain.global_var import config from bmtrain.loss import FusedCrossEntropy import bmtrain as bmt -from comm import PipeCommander +from .debug import get_logger +from .comm import PipeCommander import torch -from debug import get_logger import logging from typing import Iterable def backward_step(inp, output, grad_output): @@ -42,33 +42,18 @@ def backward_step(inp, output, grad_output): return input_grad def forward_func(model, inp, micro_idx): - if not isinstance(inp, list): - inp = [inp] if config["topology"].pipe_rank == config["topology"].pipe_size - 1: - output = model(*inp) - config['logger'].info("inp shape: {}".format(output.shape)) - loss = output.mean() + loss = model(*inp) config['logger'].info("loss: {}".format(loss.item())) return loss else: hidden_state = model(*inp) - if not isinstance(hidden_state, list): + config['logger'].info("inp shape: {}".format(hidden_state[0].shape)) + if not isinstance(hidden_state, Iterable): hidden_state = [hidden_state] return hidden_state -def preprocess_func(model, data_iter): - - while True: - try: - inp = next(data_iter) - print(type(inp)) - except StopIteration: - break - input_ids = 
inp[0] - embed = model.get_embedding() - yield embed(input_ids), *inp[1:] - def pipeline_forward_backward(model, data_iterator, global_batch_size, interleaving_size=1): """Forward and backward the pipeline model. @@ -82,14 +67,13 @@ def pipeline_forward_backward(model, data_iterator, global_batch_size, interleav """ # forwrad unpack - optimizer = bmt.optim.AdamOptimizer(model.parameters(), lr=0.001) - micro_batch_size = 12 + micro_batch_size = 2 assert global_batch_size % micro_batch_size == 0, "The global batch size must be divisible by the micro batch size" num_micro_batches = global_batch_size // micro_batch_size assert (num_micro_batches) % config["pipe_size"] == 0, "The number of micro batches must be divisible by the pipeline size" config["micros"] = num_micro_batches topo = config["topology"] - logger = get_logger(config['rank'], logging.INFO) + logger = get_logger(config['rank'], logging.DEBUG) config['logger'] = logger logger.info("topo: {}".format(topo)) logger.info("num_micro_batches: {}".format(num_micro_batches)) @@ -103,11 +87,15 @@ def pipeline_forward_backward(model, data_iterator, global_batch_size, interleav num_warmup = num_micro_batches else: num_warmup = topo.pipe_size - topo.pipe_rank - 1 - - commander = PipeCommander(topo,input_generator=preprocess_func(model, data_iterator), num_micros=num_micro_batches,\ + def generator(data_iterator): + yield model.preprocess_func(next(data_iterator)) + commander = PipeCommander(topo,input_generator=generator(data_iterator), num_micros=num_micro_batches,\ num_warmup=num_warmup, forward_only=False, \ interleaving_size=interleaving_size, \ ) + # if commander.is_first_stage() or commander.is_last_stage(): + # module = model.head_layer() if commander.is_first_stage() else model.tail_layer() + # commander.param_reduce(module) inps = [] outputs = [] logger.info("num_warmup: {}".format(num_warmup)) @@ -162,11 +150,6 @@ def pipeline_forward_backward(model, data_iterator, global_batch_size, interleav logger.info("cooling stage") for i in range(num_warmup): logger.info("{} recv micro grad {}th from next neighbour".format(config['rank'], num_micro_batches - num_warmup + i)) - # if i == num_warmup - 1: - # grad sync - # if config.grad_sync_func is None or rank == 0: - # enable_grad_sync() - inp = inps.pop(0) output = outputs.pop(0) @@ -178,6 +161,6 @@ def pipeline_forward_backward(model, data_iterator, global_batch_size, interleav logger.info("{} send micro grad {}th to prev neighbour".format(config['rank'], i)) commander.send_prev(input_grad) - optimizer.step() + bmt.synchronize() \ No newline at end of file diff --git a/example/layers/embedding.py b/example/layers/embedding.py index 8a3bbd62..b78ca88d 100644 --- a/example/layers/embedding.py +++ b/example/layers/embedding.py @@ -4,21 +4,6 @@ import torch.nn.functional as F import bmtrain as bmt import inspect -def router(func): - params_kw = list(inspect.signature(func).parameters.keys()) - def wrapper(self,*args,**kwargs): - assert len(args) == 0, "In pipeline module , you have to pass variable in key=value manner" - sub_kwargs = {} - for key in kwargs: - if key in params_kw: - sub_kwargs[key] = kwargs[key] - next_module = self.next_module() - next_module.set_input() - return func(**sub_kwargs) - if bmt.config["pipe_size"] > 1: - return wrapper - else: - return func class Embedding(bmt.DistributedModule): def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, @@ -90,7 +75,6 @@ def from_pretrained(cls, embeddings, freeze=True, padding_idx=None, 
embedding.weight.requires_grad = not freeze return embedding - @router def forward(self, input: torch.Tensor, projection : bool = False) -> torch.Tensor: if not projection: out = F.embedding( diff --git a/example/models/__init__.py b/example/models/__init__.py index e7d1dcc9..a17709b8 100644 --- a/example/models/__init__.py +++ b/example/models/__init__.py @@ -1 +1,2 @@ -from .gpt import GPT \ No newline at end of file +from .gpt import GPT +from .pipe_gpt import GPTPipe \ No newline at end of file diff --git a/example/models/pipe_gpt.py b/example/models/pipe_gpt.py new file mode 100644 index 00000000..63f74a8e --- /dev/null +++ b/example/models/pipe_gpt.py @@ -0,0 +1,85 @@ +import torch +import bmtrain as bmt +from layers import TransformerEncoder, Layernorm, Embedding, TransformerEncoder +from bmtrain.global_var import config +class InputWrapper(bmt.DistributedModule): + def __init__(self, module_list): + super().__init__() + + self._module = {} + for i in range(len(module_list)): + self._module[str(i)] = module_list[i] + + def forward(self, *args): + output_list = [] + for idx,i in enumerate(args): + output_list.append(self._module[str(idx)](i)) + return sum(output_list) + + + + +class GPTPipe(bmt.DistributedModule): + def __init__(self, + num_layers : int, vocab_size : int, + dim_model : int, dim_head : int, num_heads : int, dim_ff : int, + max_distance : int, + bias : bool = True, dtype = None + ) -> None: + super().__init__() + + self.max_distance = max_distance + + if config['tp_size'] > 1: + word_emb = bmt.nn.ParallelEmbedding(vocab_size, dim_model, dtype=dtype) + else: + word_emb = bmt.nn.PipeEmbedding(vocab_size, dim_model, dtype=dtype) + pos_emb = Embedding(max_distance, dim_model, dtype=dtype) + # self.inp_emb = InputWrapper([word_emb, pos_emb]) + blocklist = [word_emb] + blocklist += [ + TransformerEncoder( + dim_model, dim_head, num_heads, dim_ff, bias, dtype + ) + for _ in range(num_layers)] + layernorm = Layernorm(dim_model, dtype=dtype) + self.transformers = bmt.PipeDreamBlockList( + blocklist, + ) + if config['topology'].pipe_rank == config['topology'].pipe_size - 1: + self.word_emb = word_emb + + if config['tp_size'] > 1: + self.loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=True) + else: + self.loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100) + def forward(self, + input : torch.LongTensor, # (batch, seq_len) + pos : torch.LongTensor, # (batch, seq_len) + mask : torch.BoolTensor, # (batch, seq_len) + target: torch.LongTensor, + ) -> torch.Tensor: + mask_2d = mask[:, None, :] & mask[:, :, None] # (batch, seq_len, seq_len) + mask_2d = mask_2d & (pos[:, None, :] >= pos[:, :, None]) + + + # for layer in self.transformers: + out = self.transformers(input, mask_2d, None) + if config['topology'].pipe_rank == config['topology'].pipe_size - 1: + if config['tp_size'] > 1: + logits = self.word_emb.projection(out) + else: + logits = self.word_emb(out, projection=True) + + return self.loss_func(logits, target) + else: + return out, pos, mask, target + + def preprocess_func(self, inp): + if config['topology'].pipe_rank == 0: + inp_pos = inp[:1] + return self.transformers['0'](*inp_pos), *inp[1:] + else: + return None + + \ No newline at end of file diff --git a/example/pipe.sh b/example/pipe.sh new file mode 100644 index 00000000..ea07e82a --- /dev/null +++ b/example/pipe.sh @@ -0,0 +1 @@ +torchrun --nnodes=1 --nproc_per_node=4 --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=localhost pipe_train.py diff --git a/example/pipe_train.py b/example/pipe_train.py 
new file mode 100644 index 00000000..64974195 --- /dev/null +++ b/example/pipe_train.py @@ -0,0 +1,115 @@ +import torch +import bmtrain as bmt +from models import GPTPipe +import time +from bmtrain import optim +from bmtrain.global_var import config +from bmtrain import inspect +from bmtrain.pipe import pipeline_forward_backward + +def main(): +    bmt.init_distributed( +        seed=0, +        tp_size=1, +        pipe_size=4, +    ) + +    model = GPTPipe( +        num_layers=8, +        vocab_size=10240, +        dim_model=2560, +        dim_head=80, +        num_heads=32, +        dim_ff=8192, +        max_distance=1024, +        bias=True, +        dtype=torch.half +    ) + +    bmt.init_parameters(model) + +    bmt.print_rank("Model memory") +    bmt.print_rank(torch.cuda.memory_summary()) +    bmt.synchronize() + +    # data +    # generate dummy data for each rank +    torch.manual_seed(1234) + +    batch_size = 2 * 4 +    seq_len = 512 +    def data_loader(): +        for i in range(1000): +            micro = 2 +            sent = torch.randint(0, 10240, (micro, seq_len + 1)) +            enc_length = torch.randint(128, seq_len, (micro,)).long().cuda() +            enc_input = sent[:, :-1].long().cuda() +            targets = sent[:, 1:].long().cuda() +            mask = torch.arange(seq_len).long().cuda()[None, :] < enc_length[:, None] +            targets = torch.where( +                mask, +                targets, +                torch.full_like(targets, -100, dtype=torch.long) +            ) +            pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1) +            yield enc_input, pos, pos < enc_length[:, None], targets + +    if config['tp_size'] > 1: +        loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=True) +    else: +        loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100) + +    optimizer = optim.AdamOffloadOptimizer(model.parameters(), weight_decay=1e-2) +    lr_scheduler = bmt.lr_scheduler.Noam(optimizer, start_lr=1e-3, warmup_iter=40, end_iter=1000, num_iter=0) + +    optim_manager = optim.OptimManager(loss_scale=2**20) +    optim_manager.add_optimizer(optimizer, lr_scheduler) + +    bmt.synchronize() + +    avg_time_recorder = bmt.utils.AverageRecorder() +    avg_loss_recorder = bmt.utils.AverageRecorder() +    bmt.init_parameters(model) +    for iteration in range(1000): +        # load data +        st = time.time() +        global_loss = pipeline_forward_backward(model, data_loader(), batch_size) + +        # print inspected tensors in the forward & backward pass +        # print parameters of the model +        # if iteration % 100 == 0: +        #     bmt.print_rank( +        #         inspect.format_summary( +        #             inspector.get_summary() +        #         ) +        #     ) +        #     bmt.print_rank( +        #         inspect.format_summary( +        #             inspect.inspect_model(model, "*") +        #         ) +        #     ) + +        optim_manager.step() + +        # record time and loss +        iteration_time = time.time() - st + +        avg_time_recorder.record(iteration_time) +        avg_loss_recorder.record(global_loss) + +        # print time and loss +        bmt.print_rank( +            "| Iter: {:6d} | loss: {:.4f} average_loss: {:.4f} | lr: {:.4e} scale: {:10.4f} | time: {:.4f}".format( +                iteration, +                global_loss, +                avg_loss_recorder.value, +                lr_scheduler.current_lr, +                optim_manager.loss_scale, +                avg_time_recorder.value +            ) +        ) + + + +if __name__ == '__main__': +    main() From 9c05206ccf5884e96dbf3c88f7da33b5d81301cd Mon Sep 17 00:00:00 2001 From: maydomine <1583143678@qq.com> Date: Thu, 14 Sep 2023 10:54:54 +0800 Subject: [PATCH 08/43] 1f1b example --- bmtrain/__init__.py        |   2 +- bmtrain/block_layer.py     |  42 ++++++++++++++++++++++++++------- bmtrain/distributed/ops.py |  35 ++++++++++++++++---------- bmtrain/init.py            |  28 ++++++++++++++++++++- bmtrain/pipe/comm.py       |  18 ++++++++++---- bmtrain/pipe/schedule.py   |  50 ++++++++++++++++++++++++++++++------ bmtrain/utils.py           |   6 ++++- example/models/pipe_gpt.py |  25 +++++++++++++------ example/pipe_train.py      |  32 
++++++++++++------------ 9 files changed, 179 insertions(+), 59 deletions(-) diff --git a/bmtrain/__init__.py b/bmtrain/__init__.py index 05476ec8..459bfcc6 100644 --- a/bmtrain/__init__.py +++ b/bmtrain/__init__.py @@ -1,4 +1,4 @@ -from .utils import print_block, print_dict, print_rank, see_memory, load_nccl_pypi +from .utils import print_block, print_dict, print_rank, print_rank_pp, see_memory, load_nccl_pypi try: from . import nccl except: diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 84a511a6..298c7a14 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -581,7 +581,7 @@ def __init__(self, modules: Iterable[Block], num_hidden=1, mode="BLOCK") -> None self._modules[str(0)]._is_first_layer = True self._modules[str(len(modules)-1)]._is_last_layer = True - + self.module_dict = module_dict self.num_hidden = num_hidden def __len__(self) -> int: @@ -624,10 +624,11 @@ def __init__(self, modules: Iterable[Block], num_hidden=1, sqrt=False) -> None: for idx in range(len(modules)): modules[idx] = _block_wrapper(modules[idx], module_dict, mode=mode, zero_level=2) s,e = self.partition(modules) + self.head_idx = 0 + self.tail_idx = e-s-1 partition_module = [] for idx,m in enumerate(modules): if idx>=s and idx None: def forward(self, *args, return_hidden_states = False): self.return_hidden_states = return_hidden_states hidden_states = [] - for i in range(len(self)): + for i in range(self.head_idx, self.tail_idx+1): if i == 0 and self.no_emb_forward: continue if return_hidden_states: @@ -665,7 +666,7 @@ def forward(self, *args, return_hidden_states = False): return tuple(outputs[:self.num_hidden]) if self.num_hidden > 1 else outputs[0] - def partition(self,modules): + def partition(self, modules): pipe_size = config["topology"].pipe_size pipe_rank = config["topology"].pipe_rank part_lens = [0]+[len(modules) // pipe_size + (i < (len(modules) % pipe_size)) for i in range(pipe_rank+1)] @@ -673,8 +674,33 @@ def partition(self,modules): end = start + part_lens[pipe_rank+1] return start,end - def get_embedding(self): - assert config["topology"].pipe_rank == 0 - return self._modules[str(0)] - + def add_head(self, module): + module = _block_wrapper(module, self.module_dict, mode="1F1B") + module.init_param_storage() + if config['topology'].pipe_rank != 0: + return + lens = len(self) + self.head_idx += 1 + self.tail_idx += 1 + for i in range(len(self), 0, -1): + self._modules[str(i)] = self._modules[str(i-1)] + self.add_module(str(i), self._modules[str(i-1)]) + self._modules['0'] = module + self.add_module('0', module) + self._modules['1'].set_pre_module(module) + self._modules['0'].set_pre_module(None) + self._modules['1']._is_first_layer = False + self._modules['0']._is_first_layer = True + + def add_tail(self, module): + module = _block_wrapper(module, self.module_dict, mode="1F1B") + module.init_param_storage() + if config['topology'].pipe_rank != config['topology'].pipe_size - 1: + return + lens = len(self) + self._modules[str(lens)] = module + self.add_module(str(lens), module) + self._modules[str(lens)].set_pre_module(self._modules[str(lens-1)]) + self._modules[str(lens-1)]._is_last_layer = False + self._modules[str(lens)]._is_last_layer = True \ No newline at end of file diff --git a/bmtrain/distributed/ops.py b/bmtrain/distributed/ops.py index dc175b88..3f2eca00 100644 --- a/bmtrain/distributed/ops.py +++ b/bmtrain/distributed/ops.py @@ -24,19 +24,27 @@ def groupcall(): yield groupEnd() +class handler: + def __init__(self, stream): + self.stream = stream + + def 
wait(self): + torch.cuda.current_stream().wait_stream(self.stream) + def send_activations_list(hidden_state_list, next_rank, comm, async_op=False): if async_op: current_stream = torch.cuda.current_stream() with torch.cuda.stream(config["pp_comm_stream"]): config["pp_comm_stream"].wait_stream(current_stream) length = torch.tensor(data=[0], device="cuda", dtype=torch.int) - length[0] = len([h for h in hidden_state_list if h is not None]) + length[0] = len([h for h in hidden_state_list ]) ncclSend(length.storage(), next_rank, comm) for i in range(len(hidden_state_list)): if hidden_state_list[i] is None: hidden_state_list[i] = torch.tensor([12306],dtype=torch.int,device="cuda") hidden_state_list[i].record_stream(config["pp_comm_stream"]) send_activations(hidden_state_list[i], next_rank, comm) + return handler(config["pp_comm_stream"]) else: length = torch.tensor(data=[0], device="cuda", dtype=torch.int) length[0] = len(hidden_state_list) @@ -45,18 +53,19 @@ def send_activations_list(hidden_state_list, next_rank, comm, async_op=False): send_activations(hidden_state_list[i], next_rank, comm) -def recv_activations_list(prev_rank, comm): - length = torch.tensor(data=[0], device="cuda", dtype=torch.int) - hidden_state_list = [] - ncclRecv(length.storage(), prev_rank, comm) - for i in range(length[0].item()): - recv = recv_activations(prev_rank, comm) - if len(recv.shape) == 1 and recv.shape[0] == 1 and recv.item() == 12306: - hidden_state_list.append(None) - else: - hidden_state_list.append(recv) - - return hidden_state_list +def recv_activations_list(prev_rank, comm, async_op = True): + if async_op: + length = torch.tensor(data=[0], device="cuda", dtype=torch.int) + hidden_state_list = [] + ncclRecv(length.storage(), prev_rank, comm) + for i in range(length[0].item()): + recv = recv_activations(prev_rank, comm) + if len(recv.shape) == 1 and recv.shape[0] == 1 and recv.item() == 12306: + hidden_state_list.append(None) + else: + hidden_state_list.append(recv) + + return hidden_state_list def send_activations(hidden_state, next_rank, comm): diff --git a/bmtrain/init.py b/bmtrain/init.py index 0e0672ca..c1af8342 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -3,6 +3,7 @@ import random import torch.distributed as dist import os +import logging from .utils import print_dict import ctypes from .global_var import config @@ -16,6 +17,7 @@ def init_distributed( pipe_size: int = -1, num_micro_batches: int = None, tp_size : int = 1, + debug_level="DEBUG", ): """Initialize distributed training. This function will initialize the distributed training, set the random seed and global configurations. 
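The async send path in this hunk queues the NCCL transfers on the dedicated pp_comm_stream and returns a small handle object, so the compute stream only blocks when the caller actually needs the transfer finished. A minimal sketch of that side-stream pattern using plain PyTorch streams (comm_stream, StreamHandle and _send_placeholder are stand-in names, not BMTrain APIs):

import torch

def _send_placeholder(t: torch.Tensor) -> None:
    # Stand-in for ncclSend(t.storage(), peer, comm); here we only touch the tensor.
    t.add_(0)

class StreamHandle:
    # Lets the caller block its compute stream on work queued on a side stream.
    def __init__(self, stream: torch.cuda.Stream):
        self.stream = stream

    def wait(self) -> None:
        torch.cuda.current_stream().wait_stream(self.stream)

comm_stream = torch.cuda.Stream()

def async_send(tensors):
    compute = torch.cuda.current_stream()
    with torch.cuda.stream(comm_stream):
        # The comm stream must observe every write the compute stream queued so far.
        comm_stream.wait_stream(compute)
        for t in tensors:
            # Keep the allocation alive until the comm stream has consumed it.
            t.record_stream(comm_stream)
            _send_placeholder(t)
    return StreamHandle(comm_stream)
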
@@ -82,6 +84,7 @@ def init_distributed( config["zero_rank"] = config['topology'].get_group_rank("zero") config["tp_rank"] = config['topology'].get_group_rank("tp") config["tp_zero_rank"] = config['topology'].get_group_rank("tp_zero") + config["logger"] = get_logger(rank, debug_level) cpus_this_worker = None all_available_cpus = sorted(list(os.sched_getaffinity(0))) @@ -179,7 +182,8 @@ def __init__(self,config): config['tp_zero_size'] = dp_size config['zero_size'] = world_size // pp_size self.pipe_size = config['pipe_size'] - + self.dp_size = dp_size + self.tp_size = tp_size stage_size = world_size // pp_size for i in range(world_size): self.pipe_idx = self.rank % stage_size @@ -219,7 +223,29 @@ def get_group_rank(self,group_name): return self.tp_zero_id elif group_name == "tp": return self.tp_id + + def is_last_rank(self, group_name): + if group_name == "pipe": + return self.pipe_rank == self.pipe_size - 1 + elif group_name == "zero": + return self.zero_id == self.dp_size - 1 + elif group_name == "tp": + return self.tp_id == self.tp_size - 1 def is_initialized() -> bool: return config["initialized"] +def get_logger(rank, level): + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + logger = logging.getLogger('pipeline') + logger.setLevel(level) + if rank == 0: + ch = logging.StreamHandler() + ch.setLevel(level) + ch.setFormatter(formatter) + logger.addHandler(ch) + fh = logging.FileHandler(f'pipe_{rank}.log',mode="w") + fh.setLevel(level) + fh.setFormatter(formatter) + logger.addHandler(fh) + return logger \ No newline at end of file diff --git a/bmtrain/pipe/comm.py b/bmtrain/pipe/comm.py index 24b8794c..41eb3aa6 100644 --- a/bmtrain/pipe/comm.py +++ b/bmtrain/pipe/comm.py @@ -19,20 +19,27 @@ def param_reduce(self, module): def get_data(self): assert config["topology"].pipe_rank == 0 micro_batch = next(self.input_generator) + config['logger'].debug("Input id ",micro_batch) assert isinstance(micro_batch, Iterable) return list(micro_batch) def send_next(self, tensors): + handle = [] if not self.is_last_stage(): if not isinstance(tensors, Iterable): tensors = [tensors] - send_activations_list(tensors, self.topo.pipe_rank + 1, config["pipe_comm"], async_op=True) - + handle.append(send_activations_list(tensors, self.topo.pipe_rank + 1, config["pipe_comm"], async_op=True)) + for h in handle: + h.wait() + def send_prev(self, tensors): + handle = [] if not self.is_first_stage(): if not isinstance(tensors, Iterable): tensors = [tensors] - send_activations_list(tensors, self.topo.pipe_rank - 1, config["pipe_comm"], async_op=True) + handle.append(send_activations_list(tensors, self.topo.pipe_rank - 1, config["pipe_comm"], async_op=True)) + for h in handle: + h.wait() def recv_prev(self, need_data=False): if not self.is_first_stage(): @@ -61,7 +68,8 @@ def is_first_stage(self): def is_last_stage(self): return self.topo.pipe_rank == self.topo.pipe_size - 1 - + def is_even_rank(self): + return self.topo.pipe_rank % 2 == 0 def send_forward_recv_backward(self, forward_state): if not self.is_last_stage(): if forward_state[0] is not None: @@ -73,9 +81,9 @@ def send_forward_recv_backward(self, forward_state): def send_backward_recv_forward(self, backward_grad, need_data=False): if not self.is_first_stage(): + forward_state = self.recv_prev() if backward_grad[0] is not None: self.send_prev(backward_grad) - forward_state = self.recv_prev() else: if need_data: forward_state = self.get_data() diff --git a/bmtrain/pipe/schedule.py b/bmtrain/pipe/schedule.py index 
e07960b8..8d7350b7 100644 --- a/bmtrain/pipe/schedule.py +++ b/bmtrain/pipe/schedule.py @@ -7,6 +7,21 @@ import torch import logging from typing import Iterable +def obj_str(objs): + string = "" + for o in objs: + if isinstance(o, torch.Tensor): + string += repr(o.shape) + elif o is None: + string += "None" + else: + string += repr(o) + + string += " ," + return string.rstrip(",") + + + def backward_step(inp, output, grad_output): """Backward step through passed-in output tensor. @@ -21,13 +36,15 @@ def backward_step(inp, output, grad_output): for x in inp: if x is not None and x.requires_grad: x.retain_grad() - if not isinstance(output, list): + if not isinstance(output, Iterable): output = [output] - if not isinstance(grad_output, list): + if not isinstance(grad_output, Iterable): grad_output = [grad_output] #TODO scale the grad # if output_tensor_grad[0] is None and config.grad_scale_func is not None: # output_tensor[0] = config.grad_scale_func(output_tensor[0]) + config['logger'].debug(obj_str(grad_output)) + # config['logger'].debug(obj_str(output)) torch.autograd.backward(output[0], grad_tensors=grad_output[0]) input_grad = [None] @@ -44,16 +61,18 @@ def backward_step(inp, output, grad_output): def forward_func(model, inp, micro_idx): if config["topology"].pipe_rank == config["topology"].pipe_size - 1: loss = model(*inp) + global global_loss + global_loss += loss config['logger'].info("loss: {}".format(loss.item())) - return loss + return [loss] else: hidden_state = model(*inp) config['logger'].info("inp shape: {}".format(hidden_state[0].shape)) if not isinstance(hidden_state, Iterable): hidden_state = [hidden_state] return hidden_state - +global_loss = 0 def pipeline_forward_backward(model, data_iterator, global_batch_size, interleaving_size=1): """Forward and backward the pipeline model. 
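pipeline_forward_backward follows the standard 1F1B schedule: a stage first runs pipe_size - pipe_rank - 1 warm-up forwards, then alternates one forward with one backward for the remaining micro-batches, and finally drains the leftover backwards in the cooling phase. A stage-local sketch of how those phase lengths fall out (plain Python, no BMTrain objects):

def one_f_one_b_phases(num_micro_batches: int, pipe_size: int, pipe_rank: int):
    # Returns (warmup, steady, cooldown) micro-batch counts for one pipeline stage.
    warmup = min(pipe_size - pipe_rank - 1, num_micro_batches)
    steady = num_micro_batches - warmup   # one forward plus one backward per step
    cooldown = warmup                     # backwards left over from the warm-up forwards
    return warmup, steady, cooldown

# With 4 stages and 8 micro-batches the first stage warms up with 3 forwards,
# while the last stage starts alternating immediately:
for rank in range(4):
    print(rank, one_f_one_b_phases(8, 4, rank))
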
@@ -67,14 +86,20 @@ def pipeline_forward_backward(model, data_iterator, global_batch_size, interleav """ # forwrad unpack + global global_loss + global_loss = 0 micro_batch_size = 2 assert global_batch_size % micro_batch_size == 0, "The global batch size must be divisible by the micro batch size" num_micro_batches = global_batch_size // micro_batch_size assert (num_micro_batches) % config["pipe_size"] == 0, "The number of micro batches must be divisible by the pipeline size" config["micros"] = num_micro_batches topo = config["topology"] - logger = get_logger(config['rank'], logging.DEBUG) - config['logger'] = logger + + logger = config['logger'] + + logger.info("model arch {}".format(model)) + logger.info("model partition s: {} , e: {} ".format(model.transformers.head_idx,model.transformers.tail_idx)) + logger.info("model length:{}".format(str(len(model.transformers)))) logger.info("topo: {}".format(topo)) logger.info("num_micro_batches: {}".format(num_micro_batches)) logger.info("micro_batch_size: {}".format(micro_batch_size)) @@ -88,7 +113,12 @@ def pipeline_forward_backward(model, data_iterator, global_batch_size, interleav else: num_warmup = topo.pipe_size - topo.pipe_rank - 1 def generator(data_iterator): - yield model.preprocess_func(next(data_iterator)) + while True: + try: + yield model.preprocess_func(next(data_iterator)) + except StopIteration: + break + commander = PipeCommander(topo,input_generator=generator(data_iterator), num_micros=num_micro_batches,\ num_warmup=num_warmup, forward_only=False, \ interleaving_size=interleaving_size, \ @@ -117,10 +147,13 @@ def generator(data_iterator): for micro in range(num_micro_batches - num_warmup): output = forward_func(model, inp, micro + num_warmup) + logger.debug("output :{}".format(obj_str(output))) logger.info("{} micro forward".format(micro+num_warmup)) + logger.debug("send forward and recv backward") grad_output = commander.send_forward_recv_backward(output) inps.append(inp) outputs.append(output) + logger.debug("grad output :{}".format(obj_str(grad_output))) logger.info("{} send micro hidden state {}th to next neighbour and recv micro grad {} from next neighbour".format(config['rank'], micro + num_warmup, micro)) logger.debug("inp shape: {}".format(inp[0].shape)) if not commander.is_last_stage(): @@ -138,8 +171,10 @@ def generator(data_iterator): commander.send_prev(inp_grad) logger.info("{} send micro grad {}th to prev neighbour".format(config['rank'], micro + num_warmup)) else: + logger.debug("grad inp :{}".format(obj_str(inp_grad))) if inp_grad[0] is not None: logger.debug("inp_grad shape: {}".format(inp_grad[0].shape)) + logger.debug("send backward and recv forward") inp = commander.send_backward_recv_forward(inp_grad, need_data=True) logger.debug("inp type: {}".format(type(inp))) logger.debug("inp shape: {}".format(inp[0].shape)) @@ -162,5 +197,6 @@ def generator(data_iterator): commander.send_prev(input_grad) bmt.synchronize() + return global_loss \ No newline at end of file diff --git a/bmtrain/utils.py b/bmtrain/utils.py index 8cb87808..bca462e2 100644 --- a/bmtrain/utils.py +++ b/bmtrain/utils.py @@ -66,7 +66,11 @@ def print_block(title : str, content : Optional[str] = None, file=sys.stdout): print("=" * left_title + " " + title + " " + "=" * right_title, file=file) if content is not None: print(content, file=file) - + +def print_rank_pp(*args, pipe_rank=0, **kwargs): + if config['topology'].pipe_rank == pipe_rank: + print(*args, **kwargs) + def print_rank(*args, rank=0, **kwargs): """ Prints the message only on the 
`rank` of the process. diff --git a/example/models/pipe_gpt.py b/example/models/pipe_gpt.py index 63f74a8e..148818ca 100644 --- a/example/models/pipe_gpt.py +++ b/example/models/pipe_gpt.py @@ -36,7 +36,7 @@ def __init__(self, word_emb = bmt.nn.PipeEmbedding(vocab_size, dim_model, dtype=dtype) pos_emb = Embedding(max_distance, dim_model, dtype=dtype) # self.inp_emb = InputWrapper([word_emb, pos_emb]) - blocklist = [word_emb] + blocklist = [] blocklist += [ TransformerEncoder( dim_model, dim_head, num_heads, dim_ff, bias, dtype @@ -46,9 +46,15 @@ def __init__(self, self.transformers = bmt.PipeDreamBlockList( blocklist, ) - if config['topology'].pipe_rank == config['topology'].pipe_size - 1: - self.word_emb = word_emb + self.transformers.add_tail(layernorm) + self.transformers.add_tail(word_emb) + self.transformers.add_head(word_emb) + self.transformers.add_head(pos_emb) + if config['topology'].pipe_rank == config['topology'].pipe_size - 1: + self.word_emb = self.transformers[str(len(self.transformers) - 1)] + print(self.word_emb._module.__class__.__name__) + print(self.transformers.head_idx, self.transformers.tail_idx, config['topology'].pipe_rank) if config['tp_size'] > 1: self.loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=True) else: @@ -69,16 +75,19 @@ def forward(self, if config['tp_size'] > 1: logits = self.word_emb.projection(out) else: - logits = self.word_emb(out, projection=True) - + logits = self.word_emb(out, True) + logits = logits.float().view(-1, logits.shape[-1]) + target = target.view(-1) return self.loss_func(logits, target) else: return out, pos, mask, target - def preprocess_func(self, inp): if config['topology'].pipe_rank == 0: - inp_pos = inp[:1] - return self.transformers['0'](*inp_pos), *inp[1:] + inp_id = inp[0] + pos = inp[1] + output =torch.randn((2,512,2560),dtype=torch.float16,device="cuda") + # return self.transformers['0'](inp_id)+self.transformers['1'](pos), *inp[1:] + return output, *inp[1:] else: return None diff --git a/example/pipe_train.py b/example/pipe_train.py index 64974195..5046844d 100644 --- a/example/pipe_train.py +++ b/example/pipe_train.py @@ -70,7 +70,7 @@ def data_loader(): avg_time_recorder = bmt.utils.AverageRecorder() avg_loss_recorder = bmt.utils.AverageRecorder() bmt.init_parameters(model) - for iteration in range(1000): + for iteration in range(1): # load data st = time.time() global_loss = pipeline_forward_backward(model, data_loader(), batch_size) @@ -93,21 +93,23 @@ def data_loader(): # record time and loss iteration_time = time.time() - st - - avg_time_recorder.record(iteration_time) - avg_loss_recorder.record(global_loss) - + # avg_time_recorder.record(iteration_time) + # avg_loss_recorder.record(global_loss) + if global_loss is not None: + print(global_loss) + # print("hello") # print time and loss - bmt.print_rank( - "| Iter: {:6d} | loss: {:.4f} average_loss: {:.4f} | lr: {:.4e} scale: {:10.4f} | time: {:.4f}".format( - iteration, - global_loss, - avg_loss_recorder.value, - lr_scheduler.current_lr, - optim_manager.loss_scale, - avg_time_recorder.value - ) - ) + # if config['topology'].pipe_rank == config['topology'].pipe_size - 1: + # bmt.print_rank_pp( + # "| Iter: {:6d} | loss: {:.4f} average_loss: {:.4f} | lr: {:.4e} scale: {:10.4f} | time: {:.4f}".format( + # iteration, + # global_loss, + # avg_loss_recorder.value, + # lr_scheduler.current_lr, + # optim_manager.loss_scale, + # avg_time_recorder.value + # ), pipe_rank=config['pipe_size'] - 1 + # ) From 9590493a7407cfa1106e507f370cf5eb358e6bab Mon Sep 17 
00:00:00 2001 From: maydomine <1583143678@qq.com> Date: Thu, 14 Sep 2023 13:25:13 +0800 Subject: [PATCH 09/43] add debug logger in bmt.init --- bmtrain/init.py | 5 +++-- bmtrain/pipe/comm.py | 1 - bmtrain/pipe/schedule.py | 4 +++- example/layers/embedding.py | 10 ++++------ example/pipe_train.py | 1 + 5 files changed, 11 insertions(+), 10 deletions(-) diff --git a/bmtrain/init.py b/bmtrain/init.py index c1af8342..131c95ed 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -17,7 +17,7 @@ def init_distributed( pipe_size: int = -1, num_micro_batches: int = None, tp_size : int = 1, - debug_level="DEBUG", + debug=False, ): """Initialize distributed training. This function will initialize the distributed training, set the random seed and global configurations. @@ -84,7 +84,8 @@ def init_distributed( config["zero_rank"] = config['topology'].get_group_rank("zero") config["tp_rank"] = config['topology'].get_group_rank("tp") config["tp_zero_rank"] = config['topology'].get_group_rank("tp_zero") - config["logger"] = get_logger(rank, debug_level) + if debug: + config["logger"] = get_logger(rank, "DEBUG") cpus_this_worker = None all_available_cpus = sorted(list(os.sched_getaffinity(0))) diff --git a/bmtrain/pipe/comm.py b/bmtrain/pipe/comm.py index 41eb3aa6..208e66c0 100644 --- a/bmtrain/pipe/comm.py +++ b/bmtrain/pipe/comm.py @@ -19,7 +19,6 @@ def param_reduce(self, module): def get_data(self): assert config["topology"].pipe_rank == 0 micro_batch = next(self.input_generator) - config['logger'].debug("Input id ",micro_batch) assert isinstance(micro_batch, Iterable) return list(micro_batch) diff --git a/bmtrain/pipe/schedule.py b/bmtrain/pipe/schedule.py index 8d7350b7..f66289e0 100644 --- a/bmtrain/pipe/schedule.py +++ b/bmtrain/pipe/schedule.py @@ -115,7 +115,9 @@ def pipeline_forward_backward(model, data_iterator, global_batch_size, interleav def generator(data_iterator): while True: try: - yield model.preprocess_func(next(data_iterator)) + inp = next(data_iterator) + logger.debug("Input id {}".format(inp)) + yield model.preprocess_func(inp) except StopIteration: break diff --git a/example/layers/embedding.py b/example/layers/embedding.py index b78ca88d..13c47384 100644 --- a/example/layers/embedding.py +++ b/example/layers/embedding.py @@ -3,7 +3,7 @@ import torch import torch.nn.functional as F import bmtrain as bmt -import inspect + class Embedding(bmt.DistributedModule): def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, @@ -77,13 +77,11 @@ def from_pretrained(cls, embeddings, freeze=True, padding_idx=None, def forward(self, input: torch.Tensor, projection : bool = False) -> torch.Tensor: if not projection: - out = F.embedding( + return F.embedding( input, self.weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse) - return out else: - out = F.linear(input, self.weight) - return out + return F.linear(input, self.weight) / math.sqrt(self.embedding_dim) def extra_repr(self) -> str: s = '{num_embeddings}, {embedding_dim}' @@ -99,4 +97,4 @@ def extra_repr(self) -> str: s += ', sparse=True' return s.format(**self.__dict__) - + \ No newline at end of file diff --git a/example/pipe_train.py b/example/pipe_train.py index 5046844d..71f000d3 100644 --- a/example/pipe_train.py +++ b/example/pipe_train.py @@ -12,6 +12,7 @@ def main(): seed=0, tp_size=1, pipe_size=4, + debug=True ) model = GPTPipe( From 8dec50de7ce7ad60ecb3b125467cc63fcbd646f4 Mon Sep 17 00:00:00 2001 From: maydomine <1583143678@qq.com> Date: Mon, 18 Sep 
2023 10:19:57 +0800 Subject: [PATCH 10/43] 1f1b inspect and tied embedding --- bmtrain/block_layer.py | 83 +++++++++++++++++++++++-------- bmtrain/distributed/ops.py | 9 ++++ bmtrain/init.py | 3 +- bmtrain/pipe/schedule.py | 18 +++---- example/convert_ckpt_pipe.py | 95 ++++++++++++++++++++++++++++++++++++ example/inspect.py | 32 ++++++++++++ example/models/pipe_gpt.py | 33 +++++++------ example/pipe_train.py | 48 +++++------------- example/train.py | 14 +++--- 9 files changed, 243 insertions(+), 92 deletions(-) create mode 100644 example/convert_ckpt_pipe.py create mode 100644 example/inspect.py diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 298c7a14..1e2e09f0 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -9,6 +9,7 @@ from . import hook_func import inspect from torch.utils.checkpoint import checkpoint +from .distributed.ops import send_activations_inplace, recv_activations_inplace def storage_type_cuda(storage_type): STORAGE_MAP = { @@ -578,7 +579,6 @@ def __init__(self, modules: Iterable[Block], num_hidden=1, mode="BLOCK") -> None module._is_last_layer = False self._modules[str(i)] = module self.add_module(str(i), module) - self._modules[str(0)]._is_first_layer = True self._modules[str(len(modules)-1)]._is_last_layer = True self.module_dict = module_dict @@ -626,26 +626,22 @@ def __init__(self, modules: Iterable[Block], num_hidden=1, sqrt=False) -> None: s,e = self.partition(modules) self.head_idx = 0 self.tail_idx = e-s-1 - partition_module = [] + self.tied_modules = [] + partition_modules = [] for idx,m in enumerate(modules): if idx>=s and idx= start and layer_idx < end): + if rank == 0: + if key == "word_emb.weight": + return "transformers.0.weight" + elif key == "pos_emb.weight": + return "transformers.1.weight" + if key.startswith("layernorm"): + return None + else: + if layer_idx is not None: + return re.sub(r"\d+", str(layer_idx + 2), key) + elif rank == pipe_size - 1: + if key == "word_emb.weight": + return "transformers.3.weight" + if key.startswith("layernorm"): + postfix = key.split(".")[-1] + return f"transformers.2.{postfix}" + elif key == "pos_emb.weight": + return None + else: + print(key) + if layer_idx is not None: + return re.sub(r"\d+", str(layer_idx - start), key) + else: + if layer_idx is not None: + return re.sub(r"\d+", str(layer_idx - start), key) + else: + return None + + + +def init_model(): + model = GPT( + num_layers=8, + vocab_size=10240, + dim_model=2560, + dim_head=80, + num_heads=32, + dim_ff=8192, + max_distance=1024, + bias=True, + dtype=torch.half + ) + return model + +def get_len_modules(state): + max_len = 0 + for key in state: + s = re.search("\.(\d+)\.", key) + if s is not None: + res = int(s.group(1)) + if res>max_len: + max_len = res + return max_len+1 + + +if __name__ == "__main__": + bmt.init_distributed() + model = init_model() + bmt.load(model, "ckpt-0.pt") + pipe_size = 4 + state = model.state_dict() + + for rank in range(pipe_size): + print(rank) + dic = OrderedDict() + len_modules = get_len_modules(state) + s,e = partition(rank, pipe_size, len_modules) + print(s," ",e) + for i in state.keys(): + k = key_process(i, pipe_size, rank, s, e) + if k is not None: + dic[k] = state[i] + print(dic.keys()) + torch.save(dic, f"pipe_{rank}.ckpt") + + \ No newline at end of file diff --git a/example/inspect.py b/example/inspect.py new file mode 100644 index 00000000..b4b17f5d --- /dev/null +++ b/example/inspect.py @@ -0,0 +1,32 @@ +from contextlib import contextmanager +from bmtrain import CheckpointBlock +import 
sys + +@contextmanager +def custom_redirection(fileobj): + old = sys.stdout + sys.stdout = fileobj + try: + yield fileobj + finally: + sys.stdout = old + +def look_var(layer, _, output): + try: + print(f"{layer.__name__}: {output.min()}") + except: + print(f"{layer.__name__}:{output[0].min()}") + +def lookup_output(model,layers=set()): + + for key,layer in model.named_modules(): + layer.__name__ = key + if layer not in layers: + layers.add(layer) + else: + continue + if len(layer._modules) !=0: + layer.register_forward_hook(look_var) + lookup_output(layer,layers) + else: + layer.register_forward_hook(look_var) diff --git a/example/models/pipe_gpt.py b/example/models/pipe_gpt.py index 148818ca..33dcfedf 100644 --- a/example/models/pipe_gpt.py +++ b/example/models/pipe_gpt.py @@ -16,9 +16,6 @@ def forward(self, *args): output_list.append(self._module[str(idx)](i)) return sum(output_list) - - - class GPTPipe(bmt.DistributedModule): def __init__(self, num_layers : int, vocab_size : int, @@ -33,7 +30,7 @@ def __init__(self, if config['tp_size'] > 1: word_emb = bmt.nn.ParallelEmbedding(vocab_size, dim_model, dtype=dtype) else: - word_emb = bmt.nn.PipeEmbedding(vocab_size, dim_model, dtype=dtype) + word_emb = Embedding(vocab_size, dim_model, dtype=dtype) pos_emb = Embedding(max_distance, dim_model, dtype=dtype) # self.inp_emb = InputWrapper([word_emb, pos_emb]) blocklist = [] @@ -46,19 +43,20 @@ def __init__(self, self.transformers = bmt.PipeDreamBlockList( blocklist, ) - self.transformers.add_tail(layernorm) - self.transformers.add_tail(word_emb) - self.transformers.add_head(word_emb) self.transformers.add_head(pos_emb) + self.transformers.add_tail(layernorm) + self.transformers.add_head_tail(word_emb) - if config['topology'].pipe_rank == config['topology'].pipe_size - 1: - self.word_emb = self.transformers[str(len(self.transformers) - 1)] - print(self.word_emb._module.__class__.__name__) - print(self.transformers.head_idx, self.transformers.tail_idx, config['topology'].pipe_rank) + if config['topology'].pipe_rank == config['topology'].pipe_size - 1 : + self.word_emb = self.transformers.get_last_layer + if config['topology'].pipe_rank == 0: + self.word_emb = self.transformers.get_first_layer + if config['tp_size'] > 1: self.loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=True) else: self.loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100) + def forward(self, input : torch.LongTensor, # (batch, seq_len) pos : torch.LongTensor, # (batch, seq_len) @@ -73,21 +71,24 @@ def forward(self, out = self.transformers(input, mask_2d, None) if config['topology'].pipe_rank == config['topology'].pipe_size - 1: if config['tp_size'] > 1: - logits = self.word_emb.projection(out) + logits = self.word_emb().projection(out) else: - logits = self.word_emb(out, True) + logits = self.word_emb()(out, True) logits = logits.float().view(-1, logits.shape[-1]) target = target.view(-1) + config["logger"].debug("logits:{}".format(logits)) return self.loss_func(logits, target) else: return out, pos, mask, target + def preprocess_func(self, inp): if config['topology'].pipe_rank == 0: inp_id = inp[0] pos = inp[1] - output =torch.randn((2,512,2560),dtype=torch.float16,device="cuda") - # return self.transformers['0'](inp_id)+self.transformers['1'](pos), *inp[1:] - return output, *inp[1:] + # output =torch.randn((2,512,2560),dtype=torch.float16,device="cuda") + config['logger'].debug("preprocess emb type{}".format(self.transformers['0']._module.__class__.__name__)) + return 
self.transformers['0'](inp_id)+self.transformers['1'](pos), *inp[1:] + # return output, *inp[1:] else: return None diff --git a/example/pipe_train.py b/example/pipe_train.py index 71f000d3..85f8f369 100644 --- a/example/pipe_train.py +++ b/example/pipe_train.py @@ -6,6 +6,7 @@ from bmtrain.global_var import config from bmtrain import inspect from bmtrain.pipe import pipeline_forward_backward +from inspect import custom_redirection, lookup_output def main(): bmt.init_distributed( @@ -27,7 +28,6 @@ def main(): dtype=torch.half ) - bmt.init_parameters(model) bmt.print_rank("Model memory") bmt.print_rank(torch.cuda.memory_summary()) @@ -65,53 +65,27 @@ def data_loader(): optim_manager = optim.OptimManager(loss_scale=2**20) optim_manager.add_optimizer(optimizer, lr_scheduler) - + pipe_rank = bmt.config["topology"].pipe_rank + model.load_state_dict(torch.load(f"pipe_{pipe_rank}.ckpt")) bmt.synchronize() - avg_time_recorder = bmt.utils.AverageRecorder() avg_loss_recorder = bmt.utils.AverageRecorder() - bmt.init_parameters(model) - for iteration in range(1): + model.transformers.sync_tied_module() + for iteration in range(10): # load data st = time.time() - global_loss = pipeline_forward_backward(model, data_loader(), batch_size) + rank = bmt.config["topology"].pipe_rank + with custom_redirection(open(f"pp_output_{rank}","w")): + lookup_output(model) + global_loss = pipeline_forward_backward(model, data_loader(), batch_size) - # print inspected tensors in the forward & backward pass - # print parameters of the model - # if iteration % 100 == 0: - # bmt.print_rank( - # inspect.format_summary( - # inspector.get_summary() - # ) - # ) - # bmt.print_rank( - # inspect.format_summary( - # inspect.inspect_model(model, "*") - # ) - # ) + optim_manager.step() # record time and loss iteration_time = time.time() - st - # avg_time_recorder.record(iteration_time) - # avg_loss_recorder.record(global_loss) - if global_loss is not None: - print(global_loss) - # print("hello") - # print time and loss - # if config['topology'].pipe_rank == config['topology'].pipe_size - 1: - # bmt.print_rank_pp( - # "| Iter: {:6d} | loss: {:.4f} average_loss: {:.4f} | lr: {:.4e} scale: {:10.4f} | time: {:.4f}".format( - # iteration, - # global_loss, - # avg_loss_recorder.value, - # lr_scheduler.current_lr, - # optim_manager.loss_scale, - # avg_time_recorder.value - # ), pipe_rank=config['pipe_size'] - 1 - # ) - + if __name__ == '__main__': diff --git a/example/train.py b/example/train.py index 2386a31a..76782d38 100644 --- a/example/train.py +++ b/example/train.py @@ -10,7 +10,6 @@ def main(): bmt.init_distributed( seed=0, tp_size=1, - pipe_size=4, ) model = GPT( @@ -25,8 +24,7 @@ def main(): dtype=torch.half ) - bmt.init_parameters(model) - + bmt.load(model, "./ckpt-0.pt") bmt.print_rank("Model memory") bmt.print_rank(torch.cuda.memory_summary()) bmt.synchronize() @@ -52,7 +50,7 @@ def main(): if i == bmt.rank(): break - + print(enc_input) if config['tp_size'] > 1: loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=True) else: @@ -68,7 +66,6 @@ def main(): avg_time_recorder = bmt.utils.AverageRecorder() avg_loss_recorder = bmt.utils.AverageRecorder() - for iteration in range(1000): # load data st = time.time() @@ -80,6 +77,7 @@ def main(): pos, pos < enc_length[:, None] ) + print(logits) batch, seq_len, vocab_out_size = logits.size() if config['tp_size'] > 1: @@ -98,7 +96,7 @@ def main(): if iteration % 100 == 0: bmt.print_rank( inspect.format_summary( - inspector.get_summary() + inspector.get_summary( ) ) 
bmt.print_rank( @@ -106,8 +104,8 @@ def main(): inspect.inspect_model(model, "*") ) ) - - optim_manager.step() + if (iteration + 1) % 4 == 0: + optim_manager.step() # record time and loss iteration_time = time.time() - st From 3384c5f99f657f3aa463c3f865e54831775899fa Mon Sep 17 00:00:00 2001 From: maydomine <1583143678@qq.com> Date: Wed, 20 Sep 2023 16:50:14 +0800 Subject: [PATCH 11/43] 1f1b stable version --- bmtrain/block_layer.py | 85 ++++++++++++++--------------- bmtrain/init.py | 23 +++++--- bmtrain/optim/optim_manager.py | 2 +- bmtrain/pipe/debug.py | 17 ------ bmtrain/pipe/example.py | 54 ------------------- bmtrain/pipe/run.sh | 6 --- bmtrain/pipe/salloc.sh | 1 - bmtrain/pipe/schedule.py | 85 ++++++++++------------------- bmtrain/pipe/test_send_recv.py | 24 --------- bmtrain/pipe/topo.py | 85 ----------------------------- example/convert_ckpt_pipe.py | 95 --------------------------------- example/inspect.py | 32 ----------- example/inspect_tools.py | 54 +++++++++++++++++++ example/models/gpt.py | 5 +- example/models/pipe_gpt.py | 23 +++----- example/pipe_train.py | 42 +++++++++++---- example/train.py | 97 +++++++++++++++++++--------------- 17 files changed, 235 insertions(+), 495 deletions(-) delete mode 100644 bmtrain/pipe/debug.py delete mode 100644 bmtrain/pipe/example.py delete mode 100644 bmtrain/pipe/run.sh delete mode 100644 bmtrain/pipe/salloc.sh delete mode 100644 bmtrain/pipe/test_send_recv.py delete mode 100644 bmtrain/pipe/topo.py delete mode 100644 example/convert_ckpt_pipe.py delete mode 100644 example/inspect.py create mode 100644 example/inspect_tools.py diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 1e2e09f0..53ffe608 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -616,17 +616,22 @@ def forward(self, *args, return_hidden_states = False): else: return tuple(outputs[:self.num_hidden]) if self.num_hidden > 1 else outputs[0] +def DummyForward(*args, **kwargs): + """ + Only useful for embedding and layernorm layer + """ + return args[0] + class PipeDreamBlockList(TransformerBlockList): def __init__(self, modules: Iterable[Block], num_hidden=1, sqrt=False) -> None: module_dict = {} mode = "1F1B" for idx in range(len(modules)): - modules[idx] = _block_wrapper(modules[idx], module_dict, mode=mode, zero_level=2) + modules[idx] = _block_wrapper(modules[idx], module_dict, mode=mode, zero_level=2, use_checkpoint=False) s,e = self.partition(modules) - self.head_idx = 0 - self.tail_idx = e-s-1 - self.tied_modules = [] + self.head_idx = s + self.tail_idx = e partition_modules = [] for idx,m in enumerate(modules): if idx>=s and idx None: else: m.init_param_storage() del m - super().__init__(partition_modules, num_hidden, mode=mode) + self.fisrt_module = (self._modules['0'],) + self.last_module = (self._modules[str(len(self._modules) - 1)],) + self.tied_modules = [] def forward(self, *args, return_hidden_states = False): + self.return_hidden_states = return_hidden_states hidden_states = [] - for i in range(self.head_idx, self.tail_idx+1): + for i in range(len(self)): if return_hidden_states: for hidden_state in args[:self.num_hidden]: hidden_states.append(hidden_state) @@ -671,37 +679,30 @@ def partition(self, modules): return start,end def _add_head(self, module): - self.head_idx += 1 - self.tail_idx += 1 - for i in range(len(self), 0, -1): - self._modules[str(i)] = self._modules[str(i-1)] - self.add_module(str(i), self._modules[str(i-1)]) - self._modules['0'] = module - self.add_module('0', module) - 
self._modules['1'].set_pre_module(module) - self._modules['0'].set_pre_module(None) - self._modules['1']._is_first_layer = False - self._modules['0']._is_first_layer = True + self.fisrt_module[0]._is_first_layer = False + module._is_first_layer = True + self.fisrt_module[0].set_pre_module(module) + self.fisrt_module = (module,) def add_head(self, module): module = _block_wrapper(module, self.module_dict, mode="1F1B") module.init_param_storage() if config['topology'].pipe_rank != 0: - return + return DummyForward self._add_head(module) + return module def get_first_layer(self): return self._modules['0'] def get_last_layer(self): return self._modules[str(len(self)-1)] - def add_head_tail(self, module): module = _block_wrapper(module, self.module_dict, mode="1F1B") module.init_param_storage() if config['topology'].pipe_rank != 0 and not config['topology'].is_last_rank(): - return + return DummyForward else: if config['topology'].pipe_rank == 0: module._tied = "head" @@ -710,42 +711,38 @@ def add_head_tail(self, module): module._tied = "tail" self._add_tail(module) self.tied_modules.append(module) + return module - def sync_tied_module(self): - if config['topology'].pipe_rank != 0 and not config['topology'].is_last_rank(): - return - else: - for tied_m in self.tied_modules: - for name, param in tied_m.named_parameters(): - if config['topology'].pipe_rank == 0: - send_activations_inplace(param, 1, config["pipe_tied_comm"]) - elif config['topology'].is_last_rank(): - recv_activations_inplace(param, 0, config["pipe_tied_comm"]) - + def reduce_tied_module(self): if config['topology'].pipe_rank != 0 and not config['topology'].is_last_rank(): return else: for tied_m in self.tied_modules: for name, param in tied_m.named_parameters(): - if config['topology'].pipe_rank == 0: - send_activations_inplace(param.gard, 1, config["pipe_tied_comm"]) - elif config['topology'].is_last_rank(): + if config['topology'].pipe_rank == 0 and param.grad is not None: + with torch.no_grad(): + grad = torch.empty_like(param) + param.grad += recv_activations_inplace(grad, 1, config["pipe_tied_comm"]) + send_activations_inplace(param.grad, 1, config["pipe_tied_comm"]) + elif config['topology'].pipe_rank == 0 and param.grad is None: grad = torch.empty_like(param) - param.grad += recv_activations_inplace(grad, 0, config["pipe_tied_comm"]) + param.grad = recv_activations_inplace(grad, 1, config["pipe_tied_comm"]) + elif config['topology'].is_last_rank() and param.grad is not None: + send_activations_inplace(param.grad, 0, config["pipe_tied_comm"]) + param.grad = recv_activations_inplace(param.grad, 0, config["pipe_tied_comm"]) def _add_tail(self, module): - lens = len(self) - self._modules[str(lens)] = module - self.add_module(str(lens), module) - self._modules[str(lens)].set_pre_module(self._modules[str(lens-1)]) - self._modules[str(lens-1)]._is_last_layer = False - self._modules[str(lens)]._is_last_layer = True - + self.last_module[0]._is_last_layer = False + module._is_last_layer = True + self.last_module[0].set_pre_module(module) + self.last_module = (module,) + def add_tail(self, module): module = _block_wrapper(module, self.module_dict, mode="1F1B") module.init_param_storage() if config['topology'].pipe_rank != config['topology'].pipe_size - 1: - return + return DummyForward else: - self._add_tail(module) \ No newline at end of file + self._add_tail(module) + return module \ No newline at end of file diff --git a/bmtrain/init.py b/bmtrain/init.py index aba4fc87..6125cd16 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py 
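reduce_tied_module above keeps the weight-tied embedding consistent: after backward, the first and last pipeline stages exchange the embedding gradient point-to-point, sum it, and both hold the same reduced gradient before the optimizer step. A minimal sketch of that exchange using torch.distributed send/recv, with ranks 0 and world_size - 1 standing in for the first and last pipe stages:

import torch
import torch.distributed as dist

def reduce_tied_grad(param: torch.nn.Parameter) -> None:
    # Sum the tied parameter's gradient across the first and last pipeline stages.
    rank, world = dist.get_rank(), dist.get_world_size()
    first, last = 0, world - 1
    if rank == first:
        remote = torch.empty_like(param.grad)
        dist.recv(remote, src=last)        # gradient contribution from the output head
        param.grad += remote
        dist.send(param.grad, dst=last)    # ship the reduced gradient back
    elif rank == last:
        dist.send(param.grad, dst=first)
        dist.recv(param.grad, src=first)   # overwrite with the reduced gradient
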
@@ -223,7 +223,15 @@ def get_group_rank(self,group_name): return self.tp_zero_id elif group_name == "tp": return self.tp_id - + + def is_first_rank(self, group_name="pipe"): + if group_name == "pipe": + return self.pipe_rank == 0 + elif group_name == "zero": + return self.zero_id == 0 + elif group_name == "tp": + return self.tp_id == 0 + def is_last_rank(self, group_name="pipe"): if group_name == "pipe": return self.pipe_rank == self.pipe_size - 1 @@ -235,15 +243,16 @@ def is_last_rank(self, group_name="pipe"): def is_initialized() -> bool: return config["initialized"] -def get_logger(rank, level): +def get_logger(rank, level, print_to_screen=False): formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') logger = logging.getLogger('pipeline') logger.setLevel(level) - if rank == 0: - ch = logging.StreamHandler() - ch.setLevel(level) - ch.setFormatter(formatter) - logger.addHandler(ch) + if print_to_screen: + if rank == 0: + ch = logging.StreamHandler() + ch.setLevel(level) + ch.setFormatter(formatter) + logger.addHandler(ch) fh = logging.FileHandler(f'pipe_{rank}.log',mode="w") fh.setLevel(level) fh.setFormatter(formatter) diff --git a/bmtrain/optim/optim_manager.py b/bmtrain/optim/optim_manager.py index 9b7a3120..f69c98e8 100644 --- a/bmtrain/optim/optim_manager.py +++ b/bmtrain/optim/optim_manager.py @@ -85,7 +85,7 @@ def add_optimizer( def scale_loss(self, loss : torch.Tensor) -> torch.Tensor: - return loss * (self.loss_scale / config['world_size']) # loss scale + return loss * (self.loss_scale / config['world_size'] * config['pipe_size']) # loss scale def backward(self, loss : torch.Tensor): """ diff --git a/bmtrain/pipe/debug.py b/bmtrain/pipe/debug.py deleted file mode 100644 index e1de6a14..00000000 --- a/bmtrain/pipe/debug.py +++ /dev/null @@ -1,17 +0,0 @@ -from bmtrain.global_var import config -import logging - -def get_logger(rank, level): - formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') - logger = logging.getLogger('pipeline') - logger.setLevel(level) - if rank == 0: - ch = logging.StreamHandler() - ch.setLevel(level) - ch.setFormatter(formatter) - logger.addHandler(ch) - fh = logging.FileHandler(f'pipe_{rank}.log',mode="w") - fh.setLevel(level) - fh.setFormatter(formatter) - logger.addHandler(fh) - return logger diff --git a/bmtrain/pipe/example.py b/bmtrain/pipe/example.py deleted file mode 100644 index 2a64e4b0..00000000 --- a/bmtrain/pipe/example.py +++ /dev/null @@ -1,54 +0,0 @@ -from schedule import pipeline_forward_backward -import torch -import bmtrain as bmt -import time -import sys -def generate(iters): - torch.manual_seed(42) - for i in range(iters): - inp = (torch.randint(0,1024,size=(12,1024),device="cuda", dtype=torch.int32),) - yield inp -data_loader = iter(generate(100*16)) -iters=10 -dtype=torch.half - -def test_pipe(): - bmt.init_distributed(seed=42, pipe_size=4) - models = [bmt.nn.PipeEmbedding(1024,128,dtype=dtype)] - for i in range(11): - models.append(bmt.nn.Linear(128,128,dtype=dtype)) - bmt.init_parameters(models) - models = bmt.PipeDreamBlockList(models) - optimizer = bmt.optim.AdamOptimizer(models.parameters(), lr=0.001) - start = time.time() - for i in range(iters): - pipeline_forward_backward(models, data_loader, 12*16) - if bmt.config['topology'].pipe_rank == 0: - print(models['0'].weight.grad) - optimizer.step() - t = time.time() - start - -def test_dp(): - bmt.init_distributed(seed=42, pipe_size=1) - models = [bmt.nn.PipeEmbedding(1024,128,dtype=dtype)] - for i in range(11): - 
models.append(bmt.nn.Linear(128,128,dtype=dtype)) - bmt.init_parameters(models) - models = bmt.TransformerBlockList(models) - optimizer = bmt.optim.AdamOptimizer(models.parameters(), lr=0.001) - for it in range(iters): - for i in range(16): - inp = next(data_loader) - loss_tmp = models(*inp) - loss_tmp = loss_tmp.mean() - loss_tmp.backward() - - print(models['0'].weight.grad) - optimizer.step() - -if __name__ == "__main__": - if sys.argv[1] == "dp": - print("dp") - test_dp() - else: - test_pipe() \ No newline at end of file diff --git a/bmtrain/pipe/run.sh b/bmtrain/pipe/run.sh deleted file mode 100644 index 1bcc5c6c..00000000 --- a/bmtrain/pipe/run.sh +++ /dev/null @@ -1,6 +0,0 @@ -if [ "$1" = "dp" ]; then - nproc=1 -else - nproc=4 -fi -torchrun --nnodes=1 --nproc_per_node=$nproc --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=localhost example.py $1 diff --git a/bmtrain/pipe/salloc.sh b/bmtrain/pipe/salloc.sh deleted file mode 100644 index c17843da..00000000 --- a/bmtrain/pipe/salloc.sh +++ /dev/null @@ -1 +0,0 @@ -salloc --partition=gpu3 --nodelist=$1 diff --git a/bmtrain/pipe/schedule.py b/bmtrain/pipe/schedule.py index b83bf64f..95d9432e 100644 --- a/bmtrain/pipe/schedule.py +++ b/bmtrain/pipe/schedule.py @@ -1,28 +1,12 @@ import sys from bmtrain.global_var import config -from bmtrain.loss import FusedCrossEntropy import bmtrain as bmt -from .debug import get_logger from .comm import PipeCommander import torch -import logging from typing import Iterable -def obj_str(objs): - string = "" - for o in objs: - if isinstance(o, torch.Tensor): - string += repr(o.shape) - elif o is None: - string += "None" - else: - string += repr(o) - string += " ," - return string.rstrip(",") - - -def backward_step(inp, output, grad_output): +def backward_step(inp, output, grad_output, optim_manager=None): """Backward step through passed-in output tensor. If last stage, output_tensor_grad is None, otherwise gradient of loss @@ -43,10 +27,13 @@ def backward_step(inp, output, grad_output): #TODO scale the grad # if output_tensor_grad[0] is None and config.grad_scale_func is not None: # output_tensor[0] = config.grad_scale_func(output_tensor[0]) - config['logger'].debug(obj_str(grad_output)) - # config['logger'].debug(obj_str(output)) - torch.autograd.backward(output[0], grad_tensors=grad_output[0]) - + if optim_manager is not None and config["topology"].is_last_rank(): + output = optim_manager.scale_loss(output[0]) + else: + output = output[0] + torch.autograd.backward(output, grad_tensors=grad_output[0]) + current_stream = torch.cuda.current_stream() + current_stream.wait_stream(config['load_stream']) input_grad = [None] if inp is not None: input_grad = [] @@ -58,10 +45,9 @@ def backward_step(inp, output, grad_output): return input_grad -def forward_func(model, inp, micro_idx): +def forward_func(model, inp, micro_idx, is_last_micro=False): if config["topology"].pipe_rank == config["topology"].pipe_size - 1: loss = model(*inp) - config['logger'].info("loss: {}".format(loss.item())) return [loss] else: @@ -70,7 +56,8 @@ def forward_func(model, inp, micro_idx): if not isinstance(hidden_state, Iterable): hidden_state = [hidden_state] return hidden_state -def pipeline_forward_backward(model, data_iterator, global_batch_size, interleaving_size=1): + +def pipeline_forward_backward(model, data_iterator, global_batch_size, optim_manager, interleaving_size=1): """Forward and backward the pipeline model. 
Args: @@ -83,22 +70,15 @@ def pipeline_forward_backward(model, data_iterator, global_batch_size, interleav """ # forwrad unpack + loss = None + optim_manager.zero_grad() micro_batch_size = 2 assert global_batch_size % micro_batch_size == 0, "The global batch size must be divisible by the micro batch size" num_micro_batches = global_batch_size // micro_batch_size assert (num_micro_batches) % config["pipe_size"] == 0, "The number of micro batches must be divisible by the pipeline size" config["micros"] = num_micro_batches topo = config["topology"] - logger = config['logger'] - - logger.debug("model arch {}".format(model)) - logger.debug("model partition s: {} , e: {} ".format(model.transformers.head_idx,model.transformers.tail_idx)) - logger.debug("model length:{}".format(str(len(model.transformers)))) - if config["topology"].pipe_rank == 0 or config["topology"].is_last_rank(): - lens = len(model.transformers) - emb = model.transformers['0'] if config["topology"].pipe_rank == 0 else model.transformers[str(lens - 1)] - logger.debug("embedding weight param {}".format(emb.weight)) logger.info("topo: {}".format(topo)) logger.info("num_micro_batches: {}".format(num_micro_batches)) logger.info("micro_batch_size: {}".format(micro_batch_size)) @@ -115,7 +95,6 @@ def generator(data_iterator): while True: try: inp = next(data_iterator) - logger.debug("Input id {}".format(inp)) yield model.preprocess_func(inp) except StopIteration: break @@ -124,9 +103,6 @@ def generator(data_iterator): num_warmup=num_warmup, forward_only=False, \ interleaving_size=interleaving_size, \ ) - # if commander.is_first_stage() or commander.is_last_stage(): - # module = model.head_layer() if commander.is_first_stage() else model.tail_layer() - # commander.param_reduce(module) inps = [] outputs = [] logger.info("num_warmup: {}".format(num_warmup)) @@ -147,56 +123,49 @@ def generator(data_iterator): inp = commander.recv_prev(need_data=True) for micro in range(num_micro_batches - num_warmup): - output = forward_func(model, inp, micro + num_warmup) - logger.debug("output :{}".format(obj_str(output))) + is_last_micro = micro == num_micro_batches - num_warmup - 1 + output = forward_func(model, inp, micro + num_warmup, is_last_micro) + if commander.is_last_stage(): + loss = output[0] logger.info("{} micro forward".format(micro+num_warmup)) - logger.debug("send forward and recv backward") grad_output = commander.send_forward_recv_backward(output) + inps.append(inp) outputs.append(output) - logger.debug("grad output :{}".format(obj_str(grad_output))) + logger.info("{} send micro hidden state {}th to next neighbour and recv micro grad {} from next neighbour".format(config['rank'], micro + num_warmup, micro)) - logger.debug("inp shape: {}".format(inp[0].shape)) - if not commander.is_last_stage(): - logger.debug("output shape: {}".format(output[0].shape)) - if grad_output[0] is not None : - logger.debug("grad_output shape: {}".format(grad_output[0].shape)) + inp = inps.pop(0) output = outputs.pop(0) + for x in inp: logger.info("inp requires_grad: {}".format(x.requires_grad)) - inp_grad = backward_step(inp, output, grad_output) + inp_grad = backward_step(inp, output, grad_output, optim_manager) logger.info("{} micro backward".format(micro+num_warmup)) if micro == remain_batch - 1: inp = None commander.send_prev(inp_grad) logger.info("{} send micro grad {}th to prev neighbour".format(config['rank'], micro + num_warmup)) else: - logger.debug("grad inp :{}".format(obj_str(inp_grad))) - if inp_grad[0] is not None: - logger.debug("inp_grad 
shape: {}".format(inp_grad[0].shape)) - logger.debug("send backward and recv forward") + logger.info("send backward and recv forward") inp = commander.send_backward_recv_forward(inp_grad, need_data=True) - logger.debug("inp type: {}".format(type(inp))) - logger.debug("inp shape: {}".format(inp[0].shape)) - logger.info("{} send micro grad {}th to prev neighbour and recv micro hidden state {} from prev neighbour".format(config['rank'], micro, micro + num_warmup + 1)) - - if not forward_only: logger.info("cooling stage") for i in range(num_warmup): logger.info("{} recv micro grad {}th from next neighbour".format(config['rank'], num_micro_batches - num_warmup + i)) inp = inps.pop(0) output = outputs.pop(0) - grad_output = commander.recv_next() logger.info("{} micro backward".format(num_micro_batches - num_warmup + i)) input_grad = backward_step( - inp, output , grad_output, + inp, output , grad_output, ) logger.info("{} send micro grad {}th to prev neighbour".format(config['rank'], i)) commander.send_prev(input_grad) + model.transformers.reduce_tied_module() + optim_manager.step() bmt.synchronize() + return loss \ No newline at end of file diff --git a/bmtrain/pipe/test_send_recv.py b/bmtrain/pipe/test_send_recv.py deleted file mode 100644 index c8d42f53..00000000 --- a/bmtrain/pipe/test_send_recv.py +++ /dev/null @@ -1,24 +0,0 @@ -from schedule import pipeline_forward_backward -import torch -import bmtrain as bmt -from comm import PipeCommander,groupcall -def generate(iters): - for i in range(iters): - yield (torch.randint(0,1024,size=(12,1024),device="cuda", dtype=torch.int32),) - -bmt.init_distributed(pipe_size=4) - -topo = bmt.config["topology"] -num_micro_batches = 48 -num_warmup = 3 -interleaving_size = 1 -data_iterator = iter(generate(100)) -commander = PipeCommander(topo, num_micros=num_micro_batches,\ - num_warmup=num_warmup, forward_only=False, \ - interleaving_size=interleaving_size, \ - data_iterator=data_iterator) -# with groupcall(): -commander.send_prev([torch.randn((12,1024,128),device="cuda", dtype=torch.float16).requires_grad_()]) -recv = commander.recv_next() -if recv[0] is not None: - print(recv[0].shape) \ No newline at end of file diff --git a/bmtrain/pipe/topo.py b/bmtrain/pipe/topo.py deleted file mode 100644 index f0641f48..00000000 --- a/bmtrain/pipe/topo.py +++ /dev/null @@ -1,85 +0,0 @@ -class topology: - def __init__(self,**config): - # pipe_idx is the idx of the pipeline in the group - self.rank = config['rank'] - pp_size = config["pipe_size"] - tp_size = config["tp_size"] - world_size = config["world_size"] - assert world_size % (pp_size * tp_size) == 0, "The nums of GPUs must be divisible by the pipeline parallel size * tensor parallel size" - - dp_size = world_size // (pp_size * tp_size) - config['tp_zero_size'] = dp_size - config['zero_size'] = world_size // pp_size - self.pipe_size = config['pipe_size'] - - stage_size = world_size // pp_size - for i in range(world_size): - self.pipe_idx = self.rank % stage_size - self.pipe_rank = self.rank // stage_size - self.tp_id = self.rank % tp_size - self.tp_idx = self.rank // tp_size - #pp->zero - self.pp_zero_idx = self.pipe_rank - self.pp_zero_id = self.pipe_idx - #tp->zero - self.tp_zero_idx = self.tp_id - self.tp_zero_id = self.tp_idx - #pp->tp->zero - self.pp_tp_zero_idx = self.pipe_rank * tp_size + self.tp_id - self.pp_tp_zero_id = self.pipe_idx // tp_size - #only zero - self.zero_idx = 0 - self.zero_id = self.rank - - - def get_group_id(self,group_name): - if group_name == "pipe": - return self.pipe_idx - elif 
group_name == "zero": - return self.zero_idx - elif group_name == "tp_zero": - return self.tp_zero_idx - elif group_name == "tp": - return self.tp_idx - - def get_group_rank(self,group_name): - if group_name == "pipe": - return self.pipe_rank - elif group_name == "zero": - return self.zero_id - elif group_name == "tp_zero": - return self.tp_zero_id - elif group_name == "tp": - return self.tp_id - - def get_peer(self, group_name, next_prev): - if group_name == "pipe": - if next_prev == "next": - return self.pipe_rank+1 if self.pipe_rank < self.pipe_size - 1 else -1 - elif next_prev == "prev": - return self.pipe_rank-1 if self.pipe_rank > 0 else -1 - elif group_name == "zero": - if next_prev == "next": - return self.zero_id+1 if self.zero_id < self.pipe_size - 1 else -1 - elif next_prev == "prev": - return self.zero_id-1 if self.zero_id > 0 else -1 - elif group_name == "tp_zero": - if next_prev == "next": - return self.tp_zero_id+1 if self.tp_zero_id < self.pipe_size - 1 else -1 - elif next_prev == "prev": - return self.tp_zero_id-1 if self.tp_zero_id > 0 else -1 - elif group_name == "tp": - if next_prev == "next": - return self.tp_id+1 if self.tp_id < self.pipe_size - 1 else -1 - elif next_prev == "prev": - return self.tp_id-1 if self.tp_id > 0 else -1 - return -1 - - -if __name__ == "__main__": - topology1 = topology(**{"rank":0,"pipe_size":4,"tp_size":8,"world_size":32}) - topology2 = topology(**{"rank":8,"pipe_size":4,"tp_size":8,"world_size":32}) - topology3 = topology(**{"rank":16,"pipe_size":4,"tp_size":8,"world_size":32}) - topology4 = topology(**{"rank":24,"pipe_size":4,"tp_size":8,"world_size":32}) - from IPython import embed;embed() - \ No newline at end of file diff --git a/example/convert_ckpt_pipe.py b/example/convert_ckpt_pipe.py deleted file mode 100644 index 90cb8027..00000000 --- a/example/convert_ckpt_pipe.py +++ /dev/null @@ -1,95 +0,0 @@ -import bmtrain as bmt -import torch -from models import GPT, GPTPipe -import re -from collections import OrderedDict - -def partition(pipe_rank,pipe_size,len_modules): - part_lens = [0]+[(len_modules // pipe_size + (i < (len_modules % pipe_size))) for i in range(pipe_rank+1)] - start = sum(part_lens[:pipe_rank+1]) - end = start + part_lens[pipe_rank+1] - return start,end - -def key_process(key, pipe_size , rank, start, end): - res = re.search("\.(\d+)\.", key) - if res is not None: - layer_idx = int(res.group(1)) - else: - layer_idx = None - if layer_idx is None or (layer_idx >= start and layer_idx < end): - if rank == 0: - if key == "word_emb.weight": - return "transformers.0.weight" - elif key == "pos_emb.weight": - return "transformers.1.weight" - if key.startswith("layernorm"): - return None - else: - if layer_idx is not None: - return re.sub(r"\d+", str(layer_idx + 2), key) - elif rank == pipe_size - 1: - if key == "word_emb.weight": - return "transformers.3.weight" - if key.startswith("layernorm"): - postfix = key.split(".")[-1] - return f"transformers.2.{postfix}" - elif key == "pos_emb.weight": - return None - else: - print(key) - if layer_idx is not None: - return re.sub(r"\d+", str(layer_idx - start), key) - else: - if layer_idx is not None: - return re.sub(r"\d+", str(layer_idx - start), key) - else: - return None - - - -def init_model(): - model = GPT( - num_layers=8, - vocab_size=10240, - dim_model=2560, - dim_head=80, - num_heads=32, - dim_ff=8192, - max_distance=1024, - bias=True, - dtype=torch.half - ) - return model - -def get_len_modules(state): - max_len = 0 - for key in state: - s = re.search("\.(\d+)\.", key) - if s 
is not None: - res = int(s.group(1)) - if res>max_len: - max_len = res - return max_len+1 - - -if __name__ == "__main__": - bmt.init_distributed() - model = init_model() - bmt.load(model, "ckpt-0.pt") - pipe_size = 4 - state = model.state_dict() - - for rank in range(pipe_size): - print(rank) - dic = OrderedDict() - len_modules = get_len_modules(state) - s,e = partition(rank, pipe_size, len_modules) - print(s," ",e) - for i in state.keys(): - k = key_process(i, pipe_size, rank, s, e) - if k is not None: - dic[k] = state[i] - print(dic.keys()) - torch.save(dic, f"pipe_{rank}.ckpt") - - \ No newline at end of file diff --git a/example/inspect.py b/example/inspect.py deleted file mode 100644 index b4b17f5d..00000000 --- a/example/inspect.py +++ /dev/null @@ -1,32 +0,0 @@ -from contextlib import contextmanager -from bmtrain import CheckpointBlock -import sys - -@contextmanager -def custom_redirection(fileobj): - old = sys.stdout - sys.stdout = fileobj - try: - yield fileobj - finally: - sys.stdout = old - -def look_var(layer, _, output): - try: - print(f"{layer.__name__}: {output.min()}") - except: - print(f"{layer.__name__}:{output[0].min()}") - -def lookup_output(model,layers=set()): - - for key,layer in model.named_modules(): - layer.__name__ = key - if layer not in layers: - layers.add(layer) - else: - continue - if len(layer._modules) !=0: - layer.register_forward_hook(look_var) - lookup_output(layer,layers) - else: - layer.register_forward_hook(look_var) diff --git a/example/inspect_tools.py b/example/inspect_tools.py new file mode 100644 index 00000000..c170c74c --- /dev/null +++ b/example/inspect_tools.py @@ -0,0 +1,54 @@ +from contextlib import contextmanager +from bmtrain import CheckpointBlock +import sys +log_file = set() +@contextmanager +def custom_redirection(fileobj): + if isinstance(fileobj, str): + if fileobj not in log_file: + ftmp = open(fileobj,"w") + ftmp.close() + log_file.add(fileobj) + file_handle = open(fileobj,"a") + else: + file_handle = fileobj + old = sys.stdout + sys.stdout = file_handle + try: + yield file_handle + finally: + sys.stdout = old + file_handle.close() + +def look_var(layer, _, output): + try: + print(f"{layer.__name__}: {output.min()}") + except: + print(f"{layer.__name__}: {output[0].min()}") + + +def look_inp_weight(look_inp,look_weight): + def look_inp_func(layer, inp): + if look_inp: + try: + print(f"{layer.__name__}: {inp.min()}") + except: + print(f"{layer.__name__}: {inp[0].min()}") + if look_weight: + print(f"{layer.__name__} weight: {layer._parameters}") + return look_inp_func + +def lookup_output(model,layers=set(), look_input=False, look_weight=False): + for key,layer in model.named_modules(): + layer.__name__ = key + if layer not in layers: + layers.add(layer) + else: + continue + if len(layer._modules) !=0: + layer.register_forward_hook(look_var) + lookup_output(layer,layers,look_input=look_input,look_weight=look_weight) + layer.register_forward_pre_hook(look_inp_weight(look_input,look_weight)) + else: + layer.register_forward_hook(look_var) + layer.register_forward_pre_hook(look_inp_weight(look_input,look_weight)) \ No newline at end of file diff --git a/example/models/gpt.py b/example/models/gpt.py index 4895c081..e83ca8c6 100644 --- a/example/models/gpt.py +++ b/example/models/gpt.py @@ -20,14 +20,13 @@ def __init__(self, else: self.word_emb = Embedding(vocab_size, dim_model, dtype=dtype) self.pos_emb = Embedding(max_distance, dim_model, dtype=dtype) - if config['pipe_size'] > 1: self.transformers = bmt.PipelineTransformerBlockList([ 
bmt.Block( TransformerEncoder( dim_model, dim_head, num_heads, dim_ff, bias, dtype ) - , mode="PIPE" + , mode="PIPE",use_checkpoint=False ) for _ in range(num_layers) ]) @@ -36,7 +35,7 @@ def __init__(self, bmt.Block( TransformerEncoder( dim_model, dim_head, num_heads, dim_ff, bias, dtype - ) + ),use_checkpoint=False ) for _ in range(num_layers) ]) diff --git a/example/models/pipe_gpt.py b/example/models/pipe_gpt.py index 33dcfedf..119f9538 100644 --- a/example/models/pipe_gpt.py +++ b/example/models/pipe_gpt.py @@ -43,15 +43,9 @@ def __init__(self, self.transformers = bmt.PipeDreamBlockList( blocklist, ) - self.transformers.add_head(pos_emb) - self.transformers.add_tail(layernorm) - self.transformers.add_head_tail(word_emb) - - if config['topology'].pipe_rank == config['topology'].pipe_size - 1 : - self.word_emb = self.transformers.get_last_layer - if config['topology'].pipe_rank == 0: - self.word_emb = self.transformers.get_first_layer - + self.pos_emb = self.transformers.add_head(pos_emb) + self.layernorm = self.transformers.add_tail(layernorm) + self.word_emb = self.transformers.add_head_tail(word_emb) if config['tp_size'] > 1: self.loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=True) else: @@ -69,14 +63,14 @@ def forward(self, # for layer in self.transformers: out = self.transformers(input, mask_2d, None) + out = self.layernorm(out) if config['topology'].pipe_rank == config['topology'].pipe_size - 1: if config['tp_size'] > 1: - logits = self.word_emb().projection(out) + logits = self.word_emb.projection(out) else: - logits = self.word_emb()(out, True) + logits = self.word_emb(out, True) logits = logits.float().view(-1, logits.shape[-1]) target = target.view(-1) - config["logger"].debug("logits:{}".format(logits)) return self.loss_func(logits, target) else: return out, pos, mask, target @@ -85,10 +79,7 @@ def preprocess_func(self, inp): if config['topology'].pipe_rank == 0: inp_id = inp[0] pos = inp[1] - # output =torch.randn((2,512,2560),dtype=torch.float16,device="cuda") - config['logger'].debug("preprocess emb type{}".format(self.transformers['0']._module.__class__.__name__)) - return self.transformers['0'](inp_id)+self.transformers['1'](pos), *inp[1:] - # return output, *inp[1:] + return self.pos_emb(pos) + self.word_emb(inp_id) , *inp[1:] else: return None diff --git a/example/pipe_train.py b/example/pipe_train.py index 85f8f369..27cdf44a 100644 --- a/example/pipe_train.py +++ b/example/pipe_train.py @@ -6,7 +6,7 @@ from bmtrain.global_var import config from bmtrain import inspect from bmtrain.pipe import pipeline_forward_backward -from inspect import custom_redirection, lookup_output +from inspect_tools import custom_redirection, lookup_output def main(): bmt.init_distributed( @@ -25,7 +25,7 @@ def main(): dim_ff=8192, max_distance=1024, bias=True, - dtype=torch.half + dtype=torch.float32 ) @@ -37,7 +37,7 @@ def main(): # generate dummy data for each rank torch.manual_seed(1234) - batch_size = 2 * 4 + batch_size = 2 * 16 seq_len = 512 def data_loader(): for i in range(1000): @@ -70,22 +70,44 @@ def data_loader(): bmt.synchronize() avg_time_recorder = bmt.utils.AverageRecorder() avg_loss_recorder = bmt.utils.AverageRecorder() - model.transformers.sync_tied_module() + + # lookup_output(model) for iteration in range(10): # load data st = time.time() rank = bmt.config["topology"].pipe_rank - with custom_redirection(open(f"pp_output_{rank}","w")): - lookup_output(model) - global_loss = pipeline_forward_backward(model, data_loader(), batch_size) + if iteration == 4: 
+ lookup_output(model, look_weight=False) + if iteration >= 4: + with custom_redirection(f"pp_output_{rank}"): + global_loss = pipeline_forward_backward(model, data_loader(), batch_size, optim_manager) + else: + global_loss = pipeline_forward_backward(model, data_loader(), batch_size, optim_manager) - - - optim_manager.step() + bmt.synchronize() + if iteration == 4: + bmt.save(model, f"pipe_{rank}_iter4.ckpt") + if bmt.config["topology"].is_last_rank() or bmt.config["topology"].pipe_rank == 0: + is_head = bmt.config["topology"].pipe_rank == 0 + torch.save(model.word_emb.weight.grad, f"word_emb.ckpt_{int(is_head)}") # record time and loss iteration_time = time.time() - st + if bmt.config["topology"].is_last_rank(): + avg_time_recorder.record(iteration_time) + avg_loss_recorder.record(global_loss) + print( + "| Iter: {:6d} | loss: {:.10f} average_loss: {:.4f} | lr: {:.4e} scale: {:10.4f} | time: {:.4f}".format( + iteration, + global_loss, + avg_loss_recorder.value, + lr_scheduler.current_lr, + optim_manager.loss_scale, + avg_time_recorder.value + ) + ) + if __name__ == '__main__': diff --git a/example/train.py b/example/train.py index 76782d38..cc8b3f61 100644 --- a/example/train.py +++ b/example/train.py @@ -5,7 +5,7 @@ from bmtrain import optim from bmtrain.global_var import config from bmtrain import inspect - +from inspect_tools import lookup_output, custom_redirection def main(): bmt.init_distributed( seed=0, @@ -23,34 +23,31 @@ def main(): bias=True, dtype=torch.half ) - bmt.load(model, "./ckpt-0.pt") bmt.print_rank("Model memory") bmt.print_rank(torch.cuda.memory_summary()) bmt.synchronize() - # data # generate dummy data for each rank torch.manual_seed(1234) - batch_size = 2 seq_len = 512 - - for i in range(bmt.world_size()): - sent = torch.randint(0, 10240, (batch_size, seq_len + 1)) - enc_length = torch.randint(128, seq_len, (batch_size,)).long().cuda() - enc_input = sent[:, :-1].long().cuda() - targets = sent[:, 1:].long().cuda() - mask = torch.arange(seq_len).long().cuda()[None, :] < enc_length[:, None] - targets = torch.where( - mask, - targets, - torch.full_like(targets, -100, dtype=torch.long) - ) - - if i == bmt.rank(): - break - print(enc_input) + global_batch = 2 * 16 + + # for i in range(bmt.world_size()): + # sent = torch.randint(0, 10240, (batch_size, seq_len + 1)) + # enc_length = torch.randint(128, seq_len, (batch_size,)).long().cuda() + # enc_input = sent[:, :-1].long().cuda() + # targets = sent[:, 1:].long().cuda() + # mask = torch.arange(seq_len).long().cuda()[None, :] < enc_length[:, None] + # targets = torch.where( + # mask, + # targets, + # torch.full_like(targets, -100, dtype=torch.long) + # ) + + # if i == bmt.rank(): + # break if config['tp_size'] > 1: loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=True) else: @@ -66,47 +63,63 @@ def main(): avg_time_recorder = bmt.utils.AverageRecorder() avg_loss_recorder = bmt.utils.AverageRecorder() - for iteration in range(1000): + for iteration in range(10): # load data st = time.time() + # if iteration == 1: + # lookup_output(model, look_weight=True) + for micro in range(global_batch // batch_size): + # for i in range(bmt.world_size()): + sent = torch.randint(0, 10240, (batch_size, seq_len + 1)) + enc_length = torch.randint(128, seq_len, (batch_size,)).long().cuda() + enc_input = sent[:, :-1].long().cuda() + targets = sent[:, 1:].long().cuda() + mask = torch.arange(seq_len).long().cuda()[None, :] < enc_length[:, None] + targets = torch.where( + mask, + targets, + torch.full_like(targets, -100, 
dtype=torch.long) + ) - with inspect.inspect_tensor() as inspector: + # if i == bmt.rank(): + # break + + # with inspect.inspect_tensor() as inspector: pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1) - logits = model( - enc_input, - pos, - pos < enc_length[:, None] - ) - print(logits) + # if iteration == 4: + # lookup_output(model) + if iteration >= 4: + with custom_redirection("dp_ref.output"): + logits = model( + enc_input, + pos, + pos < enc_length[:, None] + ) + else: + logits = model( + enc_input, + pos, + pos < enc_length[:, None] + ) batch, seq_len, vocab_out_size = logits.size() if config['tp_size'] > 1: loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets) else: loss = loss_func(logits.float().view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len)) - - global_loss = bmt.sum_loss(loss).item() - + global_loss = loss.item() optim_manager.zero_grad() - optim_manager.backward(loss) - + # print inspected tensors in the forward & backward pass # print parameters of the model if iteration % 100 == 0: - bmt.print_rank( - inspect.format_summary( - inspector.get_summary( - ) - ) bmt.print_rank( inspect.format_summary( inspect.inspect_model(model, "*") ) ) - if (iteration + 1) % 4 == 0: - optim_manager.step() - + optim_manager.step() # record time and loss iteration_time = time.time() - st @@ -115,7 +128,7 @@ def main(): # print time and loss bmt.print_rank( - "| Iter: {:6d} | loss: {:.4f} average_loss: {:.4f} | lr: {:.4e} scale: {:10.4f} | time: {:.4f}".format( + "| Iter: {:6d} | loss: {:.10f} average_loss: {:.4f} | lr: {:.4e} scale: {:10.4f} | time: {:.4f}".format( iteration, global_loss, avg_loss_recorder.value, From e35ff763257d051f0df65d40dcda2edaf1d1d4b5 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Thu, 12 Oct 2023 11:32:09 +0800 Subject: [PATCH 12/43] WIP: 1f1b training adaption --- bmtrain/block_layer.py | 6 ++-- bmtrain/distributed/__init__.py | 2 +- bmtrain/distributed/ops.py | 58 +++++++++++++++++++++++++++------ bmtrain/init.py | 4 +++ bmtrain/pipe/comm.py | 13 +++++--- bmtrain/pipe/schedule.py | 47 ++++++++++++++------------ bmtrain/synchronize.py | 6 ++-- 7 files changed, 95 insertions(+), 41 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 53ffe608..985b5348 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -109,9 +109,9 @@ def init_param_storage(self): storage_type = storage_type_cuda(param.storage_type()) kw_name = _get_param_kw(param) if kw_name not in self._storage_info: - if self._mode == "PIPE" and param._tp_mode: + if (self._mode == "PIPE" or self._mode == "1F1B") and param._tp_mode: zero_comm = config["pp_tp_zero_comm"] - elif self._mode != "PIPE" and param._tp_mode: + elif (self._mode != "PIPE" and self._mode != "1F1B") and param._tp_mode: zero_comm = config["tp_zero_comm"] elif (self._mode == "PIPE" or self._mode == "1F1B") and not param._tp_mode: zero_comm = config["pp_zero_comm"] @@ -735,7 +735,7 @@ def reduce_tied_module(self): def _add_tail(self, module): self.last_module[0]._is_last_layer = False module._is_last_layer = True - self.last_module[0].set_pre_module(module) + module.set_pre_module(self.last_module[0]) self.last_module = (module,) def add_tail(self, module): diff --git a/bmtrain/distributed/__init__.py b/bmtrain/distributed/__init__.py index 8671c4aa..2d247d7b 100644 --- a/bmtrain/distributed/__init__.py +++ b/bmtrain/distributed/__init__.py @@ -1 +1 @@ -from .ops import all_gather, all_reduce, broadcast, 
recv_activations, send_activations,groupcall \ No newline at end of file +from .ops import all_gather, all_reduce, broadcast, recv_activations, send_activations,groupcall,send_object, recv_object \ No newline at end of file diff --git a/bmtrain/distributed/ops.py b/bmtrain/distributed/ops.py index 903e2ae7..87e87315 100644 --- a/bmtrain/distributed/ops.py +++ b/bmtrain/distributed/ops.py @@ -1,5 +1,6 @@ import torch -from ..global_var import config +import bmtrain as bmt +from ..global_var import config, rank from ..nccl import allGather as ncclAllGather, recv from ..nccl import allReduce as ncclAllReduce from ..nccl import broadcast as ncclBroadcast @@ -7,6 +8,7 @@ from ..nccl import recv as ncclRecv from ..nccl import commCount,commRank,NCCLCommunicator,groupStart,groupEnd import contextlib +import pickle DTYPE_LIST = [ torch.float64, torch.float32, @@ -31,19 +33,47 @@ def __init__(self, stream): def wait(self): torch.cuda.current_stream().wait_stream(self.stream) +def send_object(obj, next_rank, comm): + data_bytes: bytes = pickle.dumps(obj) + data_length: int = len(data_bytes) + + gpu_data_length = torch.tensor([data_length], device="cuda", dtype=torch.long) + ncclSend(gpu_data_length.storage(), next_rank, comm) + byte_storage = torch.ByteStorage.from_buffer(data_bytes).cuda() + ncclSend(byte_storage, next_rank, comm) + +def recv_object(prev_rank, comm): + data_length = torch.tensor([0], device="cuda", dtype=torch.long) + ncclRecv(data_length.storage(), prev_rank, comm) + data_bytes_stor = torch.cuda.ByteStorage(data_length.item()) + ncclRecv(data_bytes_stor, prev_rank, comm) + tensor = torch.ByteTensor(data_bytes_stor.cpu()) + data = pickle.loads(tensor.numpy().tobytes()) + return data + def send_activations_list(hidden_state_list, next_rank, comm, async_op=False): if async_op: current_stream = torch.cuda.current_stream() with torch.cuda.stream(config["pp_comm_stream"]): config["pp_comm_stream"].wait_stream(current_stream) - length = torch.tensor(data=[0], device="cuda", dtype=torch.int) - length[0] = len([h for h in hidden_state_list ]) + length = torch.tensor(data=[len([h for h in hidden_state_list ])], device="cuda", dtype=torch.int) ncclSend(length.storage(), next_rank, comm) + flags = torch.tensor(data=[0 for _ in range(len(hidden_state_list))], device="cuda",dtype=torch.int) for i in range(len(hidden_state_list)): if hidden_state_list[i] is None: - hidden_state_list[i] = torch.tensor([12306],dtype=torch.int,device="cuda") - hidden_state_list[i].record_stream(config["pp_comm_stream"]) - send_activations(hidden_state_list[i], next_rank, comm) + flag = -1 + elif torch.is_tensor(hidden_state_list[i]): + flag = 0 + else: + flag = 1 + flags[i] = flag + ncclSend(flags.contiguous().storage(), next_rank, comm) + for i in range(len(hidden_state_list)): + if flags[i] == 0: + hidden_state_list[i].record_stream(config["pp_comm_stream"]) + send_activations(hidden_state_list[i], next_rank, comm) + elif flags[i] == 1: + send_object(hidden_state_list[i], next_rank, comm) return handler(config["pp_comm_stream"]) else: length = torch.tensor(data=[0], device="cuda", dtype=torch.int) @@ -58,16 +88,24 @@ def recv_activations_list(prev_rank, comm, async_op = True): length = torch.tensor(data=[0], device="cuda", dtype=torch.int) hidden_state_list = [] ncclRecv(length.storage(), prev_rank, comm) + flags = torch.tensor(data=[0 for _ in range(length)], device="cuda",dtype=torch.int) + ncclRecv(flags.storage(), prev_rank, comm) + bmt.synchronize(bmt.config["pp_zero_comm"]) for i in range(length[0].item()): 
- recv = recv_activations(prev_rank, comm) - if len(recv.shape) == 1 and recv.shape[0] == 1 and recv.item() == 12306: + flag = flags[i].item() + if flag == -1: hidden_state_list.append(None) - else: + elif flag == 0: + recv = recv_activations(prev_rank, comm) hidden_state_list.append(recv) - + elif flag == 1: + recv = recv_object(prev_rank, comm) + hidden_state_list.append(recv) + return hidden_state_list + def send_activations(hidden_state, next_rank, comm): hidden_state = hidden_state.contiguous() send_meta(hidden_state, next_rank, comm) diff --git a/bmtrain/init.py b/bmtrain/init.py index 6125cd16..39cae9e0 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -83,6 +83,7 @@ def init_distributed( config["zero_rank"] = config['topology'].get_group_rank("zero") config["tp_rank"] = config['topology'].get_group_rank("tp") config["tp_zero_rank"] = config['topology'].get_group_rank("tp_zero") + config["pipe_rank"] = config['topology'].get_group_rank("pipe") if debug: config["logger"] = get_logger(rank, "DEBUG") cpus_this_worker = None @@ -203,6 +204,9 @@ def __init__(self,config): self.zero_idx = 0 self.zero_id = self.rank + def get_comm(self, group_name): + if group_name == "pipe": + return config["pipe_comm"] def get_group_id(self,group_name): if group_name == "pipe": diff --git a/bmtrain/pipe/comm.py b/bmtrain/pipe/comm.py index 208e66c0..3bd031c4 100644 --- a/bmtrain/pipe/comm.py +++ b/bmtrain/pipe/comm.py @@ -5,6 +5,7 @@ class PipeCommander: def __init__(self, topo, input_generator, num_micros, num_warmup, forward_only, interleaving_size) -> None: self.topo = topo + self.comm = self.topo.get_comm("pipe") self.input_generator = input_generator self.num_micros = num_micros self.num_warmup = num_warmup @@ -27,7 +28,9 @@ def send_next(self, tensors): if not self.is_last_stage(): if not isinstance(tensors, Iterable): tensors = [tensors] - handle.append(send_activations_list(tensors, self.topo.pipe_rank + 1, config["pipe_comm"], async_op=True)) + elif not isinstance(tensors, list): + tensors = list(tensors) + handle.append(send_activations_list(tensors, self.topo.pipe_rank + 1, self.comm, async_op=True)) for h in handle: h.wait() @@ -36,13 +39,15 @@ def send_prev(self, tensors): if not self.is_first_stage(): if not isinstance(tensors, Iterable): tensors = [tensors] - handle.append(send_activations_list(tensors, self.topo.pipe_rank - 1, config["pipe_comm"], async_op=True)) + elif not isinstance(tensors, list): + tensors = list(tensors) + handle.append(send_activations_list(tensors, self.topo.pipe_rank - 1, self.comm, async_op=True)) for h in handle: h.wait() def recv_prev(self, need_data=False): if not self.is_first_stage(): - res = recv_activations_list(self.topo.pipe_rank - 1, config["pipe_comm"]) + res = recv_activations_list(self.topo.pipe_rank - 1, self.comm) for idx,tensor in enumerate(res): if idx == 0: tensor.requires_grad_() @@ -55,7 +60,7 @@ def recv_prev(self, need_data=False): def recv_next(self): if not self.is_last_stage(): - return recv_activations_list(self.topo.pipe_rank + 1, config["pipe_comm"]) + return recv_activations_list(self.topo.pipe_rank + 1, self.comm) else: return [None] diff --git a/bmtrain/pipe/schedule.py b/bmtrain/pipe/schedule.py index 95d9432e..d14d5c68 100644 --- a/bmtrain/pipe/schedule.py +++ b/bmtrain/pipe/schedule.py @@ -18,7 +18,7 @@ def backward_step(inp, output, grad_output, optim_manager=None): if not isinstance(inp, list) : inp = [inp] for x in inp: - if x is not None and x.requires_grad: + if x is not None and (torch.is_tensor(x) and x.requires_grad): 
x.retain_grad() if not isinstance(output, Iterable): output = [output] @@ -28,7 +28,11 @@ def backward_step(inp, output, grad_output, optim_manager=None): # if output_tensor_grad[0] is None and config.grad_scale_func is not None: # output_tensor[0] = config.grad_scale_func(output_tensor[0]) if optim_manager is not None and config["topology"].is_last_rank(): - output = optim_manager.scale_loss(output[0]) + if isinstance(output[0], Iterable): + output = optim_manager.scale_loss(output[0][0]) + else: + output = optim_manager.scale_loss(output[0]) + else: output = output[0] torch.autograd.backward(output, grad_tensors=grad_output[0]) @@ -38,7 +42,7 @@ def backward_step(inp, output, grad_output, optim_manager=None): if inp is not None: input_grad = [] for x in inp: - if x is None or not x.requires_grad: + if x is None or (not torch.is_tensor(x)) or (not x.requires_grad): input_grad.append(None) else: input_grad.append(x.grad) @@ -51,13 +55,14 @@ def forward_func(model, inp, micro_idx, is_last_micro=False): return [loss] else: + config['logger'].info("inp shape: {}".format(inp[0].shape)) hidden_state = model(*inp) config['logger'].info("inp shape: {}".format(hidden_state[0].shape)) if not isinstance(hidden_state, Iterable): hidden_state = [hidden_state] return hidden_state -def pipeline_forward_backward(model, data_iterator, global_batch_size, optim_manager, interleaving_size=1): +def pipeline_forward_backward(model, data_iterator, micro_batch_size, num_micros, optim_manager, clip_grad=1.0): """Forward and backward the pipeline model. Args: @@ -72,9 +77,9 @@ def pipeline_forward_backward(model, data_iterator, global_batch_size, optim_man # forwrad unpack loss = None optim_manager.zero_grad() - micro_batch_size = 2 - assert global_batch_size % micro_batch_size == 0, "The global batch size must be divisible by the micro batch size" - num_micro_batches = global_batch_size // micro_batch_size + micro_batch_size = micro_batch_size + num_micro_batches = num_micros + global_batch_size = micro_batch_size * num_micro_batches assert (num_micro_batches) % config["pipe_size"] == 0, "The number of micro batches must be divisible by the pipeline size" config["micros"] = num_micro_batches topo = config["topology"] @@ -83,7 +88,6 @@ def pipeline_forward_backward(model, data_iterator, global_batch_size, optim_man logger.info("num_micro_batches: {}".format(num_micro_batches)) logger.info("micro_batch_size: {}".format(micro_batch_size)) logger.info("global_batch_size: {}".format(global_batch_size)) - logger.info("interleaving_size: {}".format(interleaving_size)) # construct Pipe Commander forward_only = False logger.info("forward_only: {}".format(forward_only)) @@ -98,7 +102,7 @@ def generator(data_iterator): yield model.preprocess_func(inp) except StopIteration: break - + interleaving_size = 1 commander = PipeCommander(topo,input_generator=generator(data_iterator), num_micros=num_micro_batches,\ num_warmup=num_warmup, forward_only=False, \ interleaving_size=interleaving_size, \ @@ -108,12 +112,12 @@ def generator(data_iterator): logger.info("num_warmup: {}".format(num_warmup)) for micro in range(num_warmup): inp = commander.recv_prev(need_data=True) - logger.info("{} recv micro {}th from prev neighbour".format(config['rank'], micro)) + logger.info("{} recv micro {}th from prev neighbour".format(bmt.config["topology"].pipe_rank, micro)) output = forward_func(model, inp, micro) logger.info("{} micro forward".format(micro)) # send activations commander.send_next(output) - logger.info("{} send micro {}th to next 
neighbour".format(config['rank'], micro)) + logger.info("{} send micro {}th to next neighbour".format(bmt.config["topology"].pipe_rank, micro)) if not forward_only: inps.append(inp) outputs.append(output) @@ -121,38 +125,38 @@ def generator(data_iterator): logger.info("remain_batch: {}".format(remain_batch)) if remain_batch > 0: inp = commander.recv_prev(need_data=True) - + logger.info("recv micro from prev neighbour") + loss_items = [] for micro in range(num_micro_batches - num_warmup): is_last_micro = micro == num_micro_batches - num_warmup - 1 output = forward_func(model, inp, micro + num_warmup, is_last_micro) if commander.is_last_stage(): loss = output[0] + loss_items.append(loss) logger.info("{} micro forward".format(micro+num_warmup)) grad_output = commander.send_forward_recv_backward(output) inps.append(inp) outputs.append(output) - logger.info("{} send micro hidden state {}th to next neighbour and recv micro grad {} from next neighbour".format(config['rank'], micro + num_warmup, micro)) + logger.info("{} send micro hidden state {}th to next neighbour and recv micro grad {} from next neighbour".format(bmt.config["topology"].pipe_rank, micro + num_warmup, micro)) inp = inps.pop(0) output = outputs.pop(0) - for x in inp: - logger.info("inp requires_grad: {}".format(x.requires_grad)) inp_grad = backward_step(inp, output, grad_output, optim_manager) logger.info("{} micro backward".format(micro+num_warmup)) if micro == remain_batch - 1: inp = None commander.send_prev(inp_grad) - logger.info("{} send micro grad {}th to prev neighbour".format(config['rank'], micro + num_warmup)) + logger.info("{} send micro grad {}th to prev neighbour".format(bmt.config["topology"].pipe_rank, micro + num_warmup)) else: logger.info("send backward and recv forward") inp = commander.send_backward_recv_forward(inp_grad, need_data=True) if not forward_only: logger.info("cooling stage") for i in range(num_warmup): - logger.info("{} recv micro grad {}th from next neighbour".format(config['rank'], num_micro_batches - num_warmup + i)) + logger.info("{} recv micro grad {}th from next neighbour".format(bmt.config["topology"].pipe_rank, num_micro_batches - num_warmup + i)) inp = inps.pop(0) output = outputs.pop(0) grad_output = commander.recv_next() @@ -160,12 +164,15 @@ def generator(data_iterator): input_grad = backward_step( inp, output , grad_output, ) - logger.info("{} send micro grad {}th to prev neighbour".format(config['rank'], i)) + logger.info("{} send micro grad {}th to prev neighbour".format(bmt.config["topology"].pipe_rank, i)) commander.send_prev(input_grad) - model.transformers.reduce_tied_module() + blocklist = model.get_blocklist() + blocklist.reduce_tied_module() + grad_norm = optim_manager.clip_grad_norm(optim_manager.optimizers[0].param_groups, clip_grad, norm_type=2) optim_manager.step() + bmt.synchronize() - return loss + return loss_items, grad_norm \ No newline at end of file diff --git a/bmtrain/synchronize.py b/bmtrain/synchronize.py index d562cc21..6be916df 100644 --- a/bmtrain/synchronize.py +++ b/bmtrain/synchronize.py @@ -3,16 +3,16 @@ from .global_var import config import warnings -def synchronize(): +def synchronize(comm=None): """ Synchronize all the workers across all nodes. 
(both CPU and GPU are synchronized) """ if not config["initialized"]: raise RuntimeError("BMTrain is not initialized") - + comm = config['comm'] if comm is None else comm with torch.cuda.stream(config['barrier_stream']): barrier = torch.cuda.FloatTensor([1]) - nccl.allReduce(barrier.storage(), barrier.storage(), 'sum', config['comm']) + nccl.allReduce(barrier.storage(), barrier.storage(), 'sum', comm) config['barrier_stream'].synchronize() def wait_loader(): From 36830073d592918fd0e575ceec24e64e9e36a023 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Wed, 18 Oct 2023 15:50:04 +0800 Subject: [PATCH 13/43] 1f1b training adaption and fix example --- bmtrain/distributed/ops.py | 38 ++++++++++++++++++++------------------ bmtrain/pipe/comm.py | 34 +++++++++++++++++++++++----------- bmtrain/pipe/schedule.py | 17 +++++------------ example/models/pipe_gpt.py | 32 +++++++++----------------------- example/pipe_train.py | 22 ++++++++-------------- 5 files changed, 65 insertions(+), 78 deletions(-) diff --git a/bmtrain/distributed/ops.py b/bmtrain/distributed/ops.py index 87e87315..c1558b7e 100644 --- a/bmtrain/distributed/ops.py +++ b/bmtrain/distributed/ops.py @@ -85,24 +85,26 @@ def send_activations_list(hidden_state_list, next_rank, comm, async_op=False): def recv_activations_list(prev_rank, comm, async_op = True): if async_op: - length = torch.tensor(data=[0], device="cuda", dtype=torch.int) - hidden_state_list = [] - ncclRecv(length.storage(), prev_rank, comm) - flags = torch.tensor(data=[0 for _ in range(length)], device="cuda",dtype=torch.int) - ncclRecv(flags.storage(), prev_rank, comm) - bmt.synchronize(bmt.config["pp_zero_comm"]) - for i in range(length[0].item()): - flag = flags[i].item() - if flag == -1: - hidden_state_list.append(None) - elif flag == 0: - recv = recv_activations(prev_rank, comm) - hidden_state_list.append(recv) - elif flag == 1: - recv = recv_object(prev_rank, comm) - hidden_state_list.append(recv) - - return hidden_state_list + with torch.cuda.stream(config["pp_comm_stream"]): + length = torch.tensor(data=[0], device="cuda", dtype=torch.int) + hidden_state_list = [] + ncclRecv(length.storage(), prev_rank, comm) + flags = torch.tensor(data=[0 for _ in range(length)], device="cuda",dtype=torch.int) + ncclRecv(flags.storage(), prev_rank, comm) + for i in range(length[0].item()): + flag = flags[i].item() + if flag == -1: + hidden_state_list.append(None) + elif flag == 0: + recv = recv_activations(prev_rank, comm) + hidden_state_list.append(recv) + elif flag == 1: + recv = recv_object(prev_rank, comm) + hidden_state_list.append(recv) + for hidden_state in hidden_state_list: + if torch.is_tensor(hidden_state): + hidden_state.record_stream(torch.cuda.current_stream()) + return hidden_state_list, handler(config["pp_comm_stream"]) diff --git a/bmtrain/pipe/comm.py b/bmtrain/pipe/comm.py index 3bd031c4..79bb5ee2 100644 --- a/bmtrain/pipe/comm.py +++ b/bmtrain/pipe/comm.py @@ -2,15 +2,26 @@ from bmtrain.distributed.ops import send_activations_list, recv_activations_list, send_activations, recv_activations, groupcall,all_reduce from bmtrain.global_var import config from collections.abc import Iterable +from bmtrain.synchronize import synchronize class PipeCommander: - def __init__(self, topo, input_generator, num_micros, num_warmup, forward_only, interleaving_size) -> None: + def __init__(self, topo, model, data_iter, num_micros, num_warmup, forward_only, interleaving_size) -> None: self.topo = topo self.comm = self.topo.get_comm("pipe") - self.input_generator = 
input_generator + self.input_generator = self.generator(data_iter) self.num_micros = num_micros self.num_warmup = num_warmup self.forward_only = forward_only self.interleaving_size = interleaving_size + self.model = model + self.send_handle = {"next":[], "prev":[]} + self.recv_handle = {"next":[], "prev":[]} + def generator(self, data_iterator): + while True: + try: + inp = next(data_iterator) + yield self.model.preprocess_func(inp) + except StopIteration: + break def param_reduce(self, module): for name, param in module.named_parameters(): @@ -18,7 +29,6 @@ def param_reduce(self, module): param.data = p def get_data(self): - assert config["topology"].pipe_rank == 0 micro_batch = next(self.input_generator) assert isinstance(micro_batch, Iterable) return list(micro_batch) @@ -31,23 +41,21 @@ def send_next(self, tensors): elif not isinstance(tensors, list): tensors = list(tensors) handle.append(send_activations_list(tensors, self.topo.pipe_rank + 1, self.comm, async_op=True)) - for h in handle: - h.wait() + self.send_handle["next"] = handle def send_prev(self, tensors): - handle = [] if not self.is_first_stage(): if not isinstance(tensors, Iterable): tensors = [tensors] elif not isinstance(tensors, list): tensors = list(tensors) - handle.append(send_activations_list(tensors, self.topo.pipe_rank - 1, self.comm, async_op=True)) - for h in handle: - h.wait() + self.send_handle["prev"].append(send_activations_list(tensors, self.topo.pipe_rank - 1, self.comm, async_op=True)) def recv_prev(self, need_data=False): if not self.is_first_stage(): - res = recv_activations_list(self.topo.pipe_rank - 1, self.comm) + res, h = recv_activations_list(self.topo.pipe_rank - 1, self.comm) + self.recv_handle["prev"].append(h) + synchronize(config["pp_zero_comm"]) for idx,tensor in enumerate(res): if idx == 0: tensor.requires_grad_() @@ -60,7 +68,9 @@ def recv_prev(self, need_data=False): def recv_next(self): if not self.is_last_stage(): - return recv_activations_list(self.topo.pipe_rank + 1, self.comm) + res, h = recv_activations_list(self.topo.pipe_rank + 1, self.comm) + self.recv_handle["next"].append(h) + return res else: return [None] @@ -72,8 +82,10 @@ def is_first_stage(self): def is_last_stage(self): return self.topo.pipe_rank == self.topo.pipe_size - 1 + def is_even_rank(self): return self.topo.pipe_rank % 2 == 0 + def send_forward_recv_backward(self, forward_state): if not self.is_last_stage(): if forward_state[0] is not None: diff --git a/bmtrain/pipe/schedule.py b/bmtrain/pipe/schedule.py index d14d5c68..97f73b3e 100644 --- a/bmtrain/pipe/schedule.py +++ b/bmtrain/pipe/schedule.py @@ -28,9 +28,9 @@ def backward_step(inp, output, grad_output, optim_manager=None): # if output_tensor_grad[0] is None and config.grad_scale_func is not None: # output_tensor[0] = config.grad_scale_func(output_tensor[0]) if optim_manager is not None and config["topology"].is_last_rank(): - if isinstance(output[0], Iterable): + if not torch.is_tensor(output[0]) and isinstance(output[0], Iterable): output = optim_manager.scale_loss(output[0][0]) - else: + elif torch.is_tensor(output[0]): output = optim_manager.scale_loss(output[0]) else: @@ -95,17 +95,10 @@ def pipeline_forward_backward(model, data_iterator, micro_batch_size, num_micros num_warmup = num_micro_batches else: num_warmup = topo.pipe_size - topo.pipe_rank - 1 - def generator(data_iterator): - while True: - try: - inp = next(data_iterator) - yield model.preprocess_func(inp) - except StopIteration: - break interleaving_size = 1 - commander = 
PipeCommander(topo,input_generator=generator(data_iterator), num_micros=num_micro_batches,\ + commander = PipeCommander(topo,model=model, data_iter=data_iterator, num_micros=num_micro_batches,\ num_warmup=num_warmup, forward_only=False, \ - interleaving_size=interleaving_size, \ + interleaving_size=interleaving_size \ ) inps = [] outputs = [] @@ -168,7 +161,7 @@ def generator(data_iterator): commander.send_prev(input_grad) blocklist = model.get_blocklist() - blocklist.reduce_tied_module() + # blocklist.reduce_tied_module() grad_norm = optim_manager.clip_grad_norm(optim_manager.optimizers[0].param_groups, clip_grad, norm_type=2) optim_manager.step() diff --git a/example/models/pipe_gpt.py b/example/models/pipe_gpt.py index 119f9538..d416c941 100644 --- a/example/models/pipe_gpt.py +++ b/example/models/pipe_gpt.py @@ -2,19 +2,6 @@ import bmtrain as bmt from layers import TransformerEncoder, Layernorm, Embedding, TransformerEncoder from bmtrain.global_var import config -class InputWrapper(bmt.DistributedModule): - def __init__(self, module_list): - super().__init__() - - self._module = {} - for i in range(len(module_list)): - self._module[str(i)] = module_list[i] - - def forward(self, *args): - output_list = [] - for idx,i in enumerate(args): - output_list.append(self._module[str(idx)](i)) - return sum(output_list) class GPTPipe(bmt.DistributedModule): def __init__(self, @@ -27,12 +14,11 @@ def __init__(self, self.max_distance = max_distance - if config['tp_size'] > 1: - word_emb = bmt.nn.ParallelEmbedding(vocab_size, dim_model, dtype=dtype) - else: - word_emb = Embedding(vocab_size, dim_model, dtype=dtype) + # if config['tp_size'] > 1: + # word_emb = bmt.nn.ParallelEmbedding(vocab_size, dim_model, dtype=dtype) + # else: + word_emb = Embedding(vocab_size, dim_model, dtype=dtype) pos_emb = Embedding(max_distance, dim_model, dtype=dtype) - # self.inp_emb = InputWrapper([word_emb, pos_emb]) blocklist = [] blocklist += [ TransformerEncoder( @@ -47,10 +33,13 @@ def __init__(self, self.layernorm = self.transformers.add_tail(layernorm) self.word_emb = self.transformers.add_head_tail(word_emb) if config['tp_size'] > 1: - self.loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=True) + self.loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=False) else: self.loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100) + def get_blocklist(self): + return self.transformers + def forward(self, input : torch.LongTensor, # (batch, seq_len) pos : torch.LongTensor, # (batch, seq_len) @@ -65,10 +54,7 @@ def forward(self, out = self.transformers(input, mask_2d, None) out = self.layernorm(out) if config['topology'].pipe_rank == config['topology'].pipe_size - 1: - if config['tp_size'] > 1: - logits = self.word_emb.projection(out) - else: - logits = self.word_emb(out, True) + logits = self.word_emb(out, True) logits = logits.float().view(-1, logits.shape[-1]) target = target.view(-1) return self.loss_func(logits, target) diff --git a/example/pipe_train.py b/example/pipe_train.py index 27cdf44a..511e31e1 100644 --- a/example/pipe_train.py +++ b/example/pipe_train.py @@ -11,8 +11,8 @@ def main(): bmt.init_distributed( seed=0, - tp_size=1, pipe_size=4, + tp_size=1, debug=True ) @@ -25,7 +25,7 @@ def main(): dim_ff=8192, max_distance=1024, bias=True, - dtype=torch.float32 + dtype=torch.float16 ) @@ -36,8 +36,9 @@ def main(): # data # generate dummy data for each rank torch.manual_seed(1234) - - batch_size = 2 * 16 + micro = 2 + num_micros = 16 + batch_size = micro * num_micros seq_len = 512 
def data_loader(): for i in range(1000): @@ -80,21 +81,14 @@ def data_loader(): lookup_output(model, look_weight=False) if iteration >= 4: with custom_redirection(f"pp_output_{rank}"): - global_loss = pipeline_forward_backward(model, data_loader(), batch_size, optim_manager) + global_loss, grad_norm = pipeline_forward_backward(model, data_loader(), micro , num_micros, optim_manager) else: - global_loss = pipeline_forward_backward(model, data_loader(), batch_size, optim_manager) - - - bmt.synchronize() - if iteration == 4: - bmt.save(model, f"pipe_{rank}_iter4.ckpt") - if bmt.config["topology"].is_last_rank() or bmt.config["topology"].pipe_rank == 0: - is_head = bmt.config["topology"].pipe_rank == 0 - torch.save(model.word_emb.weight.grad, f"word_emb.ckpt_{int(is_head)}") + global_loss, grad_norm = pipeline_forward_backward(model, data_loader(), micro , num_micros, optim_manager) # record time and loss iteration_time = time.time() - st if bmt.config["topology"].is_last_rank(): + global_loss = sum(list(global_loss))/len(global_loss) avg_time_recorder.record(iteration_time) avg_loss_recorder.record(global_loss) print( From 05bb1b3be37c155f546739c5782071c36593a766 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Fri, 20 Oct 2023 14:41:52 +0800 Subject: [PATCH 14/43] fix data loader for 1f1b --- bmtrain/pipe/comm.py | 38 ++++++++++++++++++++++++++------------ bmtrain/pipe/schedule.py | 2 +- example/models/pipe_gpt.py | 2 +- example/pipe_train.py | 5 +---- 4 files changed, 29 insertions(+), 18 deletions(-) diff --git a/bmtrain/pipe/comm.py b/bmtrain/pipe/comm.py index 79bb5ee2..dcfa6c50 100644 --- a/bmtrain/pipe/comm.py +++ b/bmtrain/pipe/comm.py @@ -15,11 +15,27 @@ def __init__(self, topo, model, data_iter, num_micros, num_warmup, forward_only, self.model = model self.send_handle = {"next":[], "prev":[]} self.recv_handle = {"next":[], "prev":[]} + + def is_first_stage(self): + if self.interleaving_size == 1: + return self.topo.is_first_rank("pipe") + else: + raise ValueError("Now only supoort interleaving_size == 1") + + def is_last_stage(self): + if self.interleaving_size == 1: + return self.topo.is_last_rank("pipe") + else: + raise ValueError("Now only supoort interleaving_size == 1") + def generator(self, data_iterator): while True: try: inp = next(data_iterator) - yield self.model.preprocess_func(inp) + if self.is_first_stage(): + yield self.model.preprocess_func(inp) + else: + yield inp except StopIteration: break @@ -51,15 +67,19 @@ def send_prev(self, tensors): tensors = list(tensors) self.send_handle["prev"].append(send_activations_list(tensors, self.topo.pipe_rank - 1, self.comm, async_op=True)) + def wait(self): + torch.cuda.current_stream().wait_stream(config["pp_comm_stream"]) + def recv_prev(self, need_data=False): if not self.is_first_stage(): - res, h = recv_activations_list(self.topo.pipe_rank - 1, self.comm) - self.recv_handle["prev"].append(h) + res, handle = recv_activations_list(self.topo.pipe_rank - 1, self.comm) + self.recv_handle["prev"].append(handle) synchronize(config["pp_zero_comm"]) for idx,tensor in enumerate(res): if idx == 0: tensor.requires_grad_() - return res + data = self.get_data() + return res + data[1:] else: if need_data: return self.get_data() @@ -68,8 +88,8 @@ def recv_prev(self, need_data=False): def recv_next(self): if not self.is_last_stage(): - res, h = recv_activations_list(self.topo.pipe_rank + 1, self.comm) - self.recv_handle["next"].append(h) + res, handle = recv_activations_list(self.topo.pipe_rank + 1, self.comm) + 
self.recv_handle["next"].append(handle) return res else: return [None] @@ -77,12 +97,6 @@ def recv_next(self): def allocate_tensor(self, shape, dtype): return torch.empty(shape, dtype=dtype, device="cuda") - def is_first_stage(self): - return self.topo.pipe_rank == 0 - - def is_last_stage(self): - return self.topo.pipe_rank == self.topo.pipe_size - 1 - def is_even_rank(self): return self.topo.pipe_rank % 2 == 0 diff --git a/bmtrain/pipe/schedule.py b/bmtrain/pipe/schedule.py index 97f73b3e..9054b780 100644 --- a/bmtrain/pipe/schedule.py +++ b/bmtrain/pipe/schedule.py @@ -58,7 +58,7 @@ def forward_func(model, inp, micro_idx, is_last_micro=False): config['logger'].info("inp shape: {}".format(inp[0].shape)) hidden_state = model(*inp) config['logger'].info("inp shape: {}".format(hidden_state[0].shape)) - if not isinstance(hidden_state, Iterable): + if torch.is_tensor(hidden_state) or (not isinstance(hidden_state, Iterable)): hidden_state = [hidden_state] return hidden_state diff --git a/example/models/pipe_gpt.py b/example/models/pipe_gpt.py index d416c941..8bb1ffb8 100644 --- a/example/models/pipe_gpt.py +++ b/example/models/pipe_gpt.py @@ -59,7 +59,7 @@ def forward(self, target = target.view(-1) return self.loss_func(logits, target) else: - return out, pos, mask, target + return out def preprocess_func(self, inp): if config['topology'].pipe_rank == 0: diff --git a/example/pipe_train.py b/example/pipe_train.py index 511e31e1..3353060f 100644 --- a/example/pipe_train.py +++ b/example/pipe_train.py @@ -77,11 +77,8 @@ def data_loader(): # load data st = time.time() rank = bmt.config["topology"].pipe_rank - if iteration == 4: - lookup_output(model, look_weight=False) if iteration >= 4: - with custom_redirection(f"pp_output_{rank}"): - global_loss, grad_norm = pipeline_forward_backward(model, data_loader(), micro , num_micros, optim_manager) + global_loss, grad_norm = pipeline_forward_backward(model, data_loader(), micro , num_micros, optim_manager) else: global_loss, grad_norm = pipeline_forward_backward(model, data_loader(), micro , num_micros, optim_manager) # record time and loss From 5b7a18c4be51b97bd4da501428c9d2cb28e46525 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Fri, 20 Oct 2023 14:47:38 +0800 Subject: [PATCH 15/43] fix context for 1f1b --- bmtrain/hook_func.py | 56 +++++++++++++++++++++++++++++------------ bmtrain/pipe_layer.py | 2 +- bmtrain/zero_context.py | 6 ++++- 3 files changed, 46 insertions(+), 18 deletions(-) diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index 2c6108b0..a69aaa6f 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -4,10 +4,13 @@ def zero_pre_forward(module, inputs): enter = True - pipe = False - if module._mode == "PIPE": - enter = module._micro_idx == 0 - pipe = True + if module._mode == "PIPE" or module._mode == "1F1B": + if not hasattr(module, "_micro_forward_idx") or module._micro_forward_idx == -1: + module._micro_forward_idx = 0 + enter = True + else: + enter = False + module._micro_forward_idx += 1 if enter: zero_level = module._zero_level forward_flag = 1 if zero_level == 2 else 0 @@ -15,40 +18,61 @@ def zero_pre_forward(module, inputs): forward_flag = 2 # repeating forward in same layer if module.all_param_no_grad: #only forward forward_flag = 0 - module._forward_block_ctx = ZeroContext(module, module._layer_dict, pipe=pipe) - module._forward_block_ctx.enter(forward_flag) + if module._mode == "1F1B": + module._block_ctx = ZeroContext(module, module._layer_dict) + module._block_ctx.enter(0, 
requires_grad=True) + else: + module._forward_block_ctx = ZeroContext(module, module._layer_dict) + module._forward_block_ctx.enter(forward_flag) def zero_post_forward(module, inputs, outputs): forward_flag = 1 if module._zero_level == 2 else 0 if module.all_param_no_grad: forward_flag = 0 exit = True - if module._mode == "PIPE": - exit = module._micro_idx == config['micros'] - 1 + if module._mode == "PIPE" or module._mode == "1F1B": + if module._micro_forward_idx == config["micros"] - 1: + module._micro_forward_idx = -1 + if module._mode == "1F1B": + exit = False + else: + exit = True + else: + exit = False if exit: module._forward_block_ctx.exit(forward_flag) def zero_pre_backward(module, grad_outputs): backward_flag = 2 if module._zero_level == 2 else 0 - if module._mode != "PIPE": + if module._mode != "PIPE" and module._mode != "1F1B": module._backward_block_ctx = ZeroContext(module, module._layer_dict) module._backward_block_ctx.enter(backward_flag, True) module.release_next_module(backward_flag) else: - if module._micro_idx == config['micros'] - 1: - module._backward_block_ctx = ZeroContext(module, module._layer_dict, pipe=True) - module._backward_block_ctx.enter(backward_flag, True) + if not hasattr(module, "_micro_backward_idx") or module._micro_backward_idx == -1: + if module._mode == "1F1B": + module._micro_backward_idx = 0 + else: + module._micro_backward_idx = 0 + module._backward_block_ctx = ZeroContext(module, module._layer_dict) + module._backward_block_ctx.enter(backward_flag,requires_grad=True) + else: + module._micro_backward_idx += 1 def zero_post_backward(module, grad_inputs, grad_outputs): backward_flag = 2 if module._zero_level == 2 else 0 - if module._mode != "PIPE": + if module._mode != "PIPE" and module._mode != "1F1B": if module._is_first_layer: module.release(backward_flag) else: - if module._micro_idx == 0: - module.release(backward_flag) - module._micro_idx -= 1 + if module._micro_backward_idx == config["micros"] - 1: + if module._mode == "1F1B": + module._block_ctx.exit(0, backward=True) + config['load_stream'].record_event(config['load_event']) + else: + module.release(backward_flag) + module._micro_backward_idx = -1 class OneStepNoGradFunc(torch.autograd.Function): """ diff --git a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index 409e1eaa..00697134 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -245,7 +245,7 @@ def forward(self, hidden_state, *args, batch_related=[], return_hidden_states=Fa hidden_state = StagePreFunction.apply(hidden_state, self.pipe_rank) for idx,layer_id in enumerate(self.layer_ids): - self._modules[str(layer_id)]._micro_idx = micro_idx + # self._modules[str(layer_id)]._micro_idx = micro_idx if return_hidden_states: micro_hidden_states.append(hidden_state) hidden_state = self._modules[str(layer_id)](hidden_state, *arg) diff --git a/bmtrain/zero_context.py b/bmtrain/zero_context.py index 653f40fa..4f287e3f 100644 --- a/bmtrain/zero_context.py +++ b/bmtrain/zero_context.py @@ -4,7 +4,7 @@ from .synchronize import wait_loader class ZeroContext: - def __init__(self, block : 'Block', ctx_dict : dict = None, pipe = False) -> None: + def __init__(self, block : 'Block', ctx_dict : dict = None) -> None: self.block = block self.ctx_dict = ctx_dict self._param_buffer = {} @@ -16,6 +16,10 @@ def __init__(self, block : 'Block', ctx_dict : dict = None, pipe = False) -> Non def enter(self, flag=0, requires_grad=False): """ gather parameters + flags = 0: normal mode + flags = 1: gather param and not release , then save in ctx_dict + 
flags = 2: not gather param and use the param in ctx_dict + """ if self.block._ready: return From 2a1fda1d3623e94d34e2eab68e7e870a497b2f6f Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Fri, 20 Oct 2023 15:56:38 +0800 Subject: [PATCH 16/43] better example validation --- example/convert.py | 85 +++++++++++++++++++++++++++++++++++++++++++ example/pipe_train.py | 9 +++-- example/train.py | 18 +++++---- 3 files changed, 100 insertions(+), 12 deletions(-) create mode 100644 example/convert.py diff --git a/example/convert.py b/example/convert.py new file mode 100644 index 00000000..c72aee90 --- /dev/null +++ b/example/convert.py @@ -0,0 +1,85 @@ +import bmtrain as bmt +import torch +from models import GPT, GPTPipe +import re +from collections import OrderedDict + +def partition(pipe_rank,pipe_size,len_modules): + part_lens = [0]+[(len_modules // pipe_size + (i < (len_modules % pipe_size))) for i in range(pipe_rank+1)] + start = sum(part_lens[:pipe_rank+1]) + end = start + part_lens[pipe_rank+1] + return start,end + +def key_process(key, pipe_size , rank, start, end): + res = re.search("\.(\d+)\.", key) + if res is not None: + layer_idx = int(res.group(1)) + else: + layer_idx = None + if layer_idx is None or (layer_idx >= start and layer_idx < end): + if rank == 0: + if key in ["word_emb.weight","pos_emb.weight"]: + return key + else: + if layer_idx is not None: + return re.sub(r"\d+", str(layer_idx), key) + elif rank == pipe_size - 1: + if key in ["word_emb.weight"] or key.startswith("layernorm"): + return key + else: + if layer_idx is not None: + return re.sub(r"\d+", str(layer_idx - start), key) + else: + if layer_idx is not None: + return re.sub(r"\d+", str(layer_idx - start), key) + else: + return None + + + +def init_model(): + model = GPT( + num_layers=8, + vocab_size=10240, + dim_model=2560, + dim_head=80, + num_heads=32, + dim_ff=8192, + max_distance=1024, + bias=True, + dtype=torch.half + ) + return model + +def get_len_modules(state): + max_len = 0 + for key in state: + s = re.search("\.(\d+)\.", key) + if s is not None: + res = int(s.group(1)) + if res>max_len: + max_len = res + return max_len+1 + + +if __name__ == "__main__": + bmt.init_distributed() + model = init_model() + bmt.load(model, "ckpt-0.pt") + pipe_size = 4 + state = model.state_dict() + + for rank in range(pipe_size): + print(rank) + dic = OrderedDict() + len_modules = get_len_modules(state) + s,e = partition(rank, pipe_size, len_modules) + print(s," ",e) + for i in state.keys(): + k = key_process(i, pipe_size, rank, s, e) + if k is not None: + dic[k] = state[i] + print(dic.keys()) + torch.save(dic, f"pipe_{rank}.ckpt") + + \ No newline at end of file diff --git a/example/pipe_train.py b/example/pipe_train.py index 3353060f..d781ef65 100644 --- a/example/pipe_train.py +++ b/example/pipe_train.py @@ -27,8 +27,7 @@ def main(): bias=True, dtype=torch.float16 ) - - + inspect_iter = -1 bmt.print_rank("Model memory") bmt.print_rank(torch.cuda.memory_summary()) bmt.synchronize() @@ -77,8 +76,10 @@ def data_loader(): # load data st = time.time() rank = bmt.config["topology"].pipe_rank - if iteration >= 4: - global_loss, grad_norm = pipeline_forward_backward(model, data_loader(), micro , num_micros, optim_manager) + if iteration == inspect_iter: + lookup_output(model) + with custom_redirection(f"outputs/pp_output_{pipe_rank}"): + global_loss, grad_norm = pipeline_forward_backward(model, data_loader(), micro , num_micros, optim_manager) else: global_loss, grad_norm = pipeline_forward_backward(model, 
data_loader(), micro , num_micros, optim_manager) # record time and loss diff --git a/example/train.py b/example/train.py index cc8b3f61..53dc5e26 100644 --- a/example/train.py +++ b/example/train.py @@ -23,6 +23,7 @@ def main(): bias=True, dtype=torch.half ) + inspect_iter = -1 bmt.load(model, "./ckpt-0.pt") bmt.print_rank("Model memory") bmt.print_rank(torch.cuda.memory_summary()) @@ -66,8 +67,9 @@ def main(): for iteration in range(10): # load data st = time.time() - # if iteration == 1: - # lookup_output(model, look_weight=True) + if iteration == inspect_iter: + lookup_output(model) + sum_loss = 0 for micro in range(global_batch // batch_size): # for i in range(bmt.world_size()): sent = torch.randint(0, 10240, (batch_size, seq_len + 1)) @@ -88,7 +90,7 @@ def main(): pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1) # if iteration == 4: # lookup_output(model) - if iteration >= 4: + if iteration == inspect_iter: with custom_redirection("dp_ref.output"): logits = model( enc_input, @@ -108,9 +110,8 @@ def main(): else: loss = loss_func(logits.float().view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len)) global_loss = loss.item() - optim_manager.zero_grad() optim_manager.backward(loss) - + sum_loss += global_loss # print inspected tensors in the forward & backward pass # print parameters of the model if iteration % 100 == 0: @@ -120,17 +121,18 @@ def main(): ) ) optim_manager.step() + optim_manager.zero_grad() # record time and loss iteration_time = time.time() - st avg_time_recorder.record(iteration_time) - avg_loss_recorder.record(global_loss) - + num_micro = global_batch // batch_size + avg_loss_recorder.record(sum_loss/num_micro) # print time and loss bmt.print_rank( "| Iter: {:6d} | loss: {:.10f} average_loss: {:.4f} | lr: {:.4e} scale: {:10.4f} | time: {:.4f}".format( iteration, - global_loss, + sum_loss / num_micro, avg_loss_recorder.value, lr_scheduler.current_lr, optim_manager.loss_scale, From db60ce37c728c06553d5eb649fa2c0bf11bcfeef Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Wed, 25 Oct 2023 13:00:44 +0800 Subject: [PATCH 17/43] Optimizer for 1f1b adaption --- bmtrain/lr_scheduler/warmup.py | 8 +++++--- bmtrain/nn/parallel_embedding.py | 2 +- bmtrain/optim/optim_manager.py | 5 ++++- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/bmtrain/lr_scheduler/warmup.py b/bmtrain/lr_scheduler/warmup.py index 0f08a600..2eb23466 100644 --- a/bmtrain/lr_scheduler/warmup.py +++ b/bmtrain/lr_scheduler/warmup.py @@ -16,7 +16,7 @@ def __init__(self, optimizer : torch.optim.Optimizer, start_lr, warmup_iter, end self.warmup_iter = warmup_iter self.end_iter = end_iter self.optimizer = optimizer - self.num_iter = num_iter + self.num_iter = 0 self._current_lr = None self.step(self.num_iter) @@ -37,9 +37,11 @@ def get_lr(self): def current_lr(self): return self._current_lr - def step(self, num_iter = None) -> None: - if num_iter is None: + def step(self, num_step = None) -> None: + if num_step is None: num_iter = self.num_iter + 1 + else: + num_iter = self.num_iter + num_step self.num_iter = num_iter lr = self.get_lr() diff --git a/bmtrain/nn/parallel_embedding.py b/bmtrain/nn/parallel_embedding.py index cd567b4e..969bd923 100644 --- a/bmtrain/nn/parallel_embedding.py +++ b/bmtrain/nn/parallel_embedding.py @@ -54,7 +54,7 @@ def forward(self, ids: torch.Tensor, gather_input=True): embed_list = embeds.chunk(config['tp_size'], dim=0) embeds = embed_list[config['topology'].tp_id].flatten(0,1) - return embeds.clone() + 
return embeds def projection(self, x: torch.Tensor): """ diff --git a/bmtrain/optim/optim_manager.py b/bmtrain/optim/optim_manager.py index f69c98e8..c6131fe1 100644 --- a/bmtrain/optim/optim_manager.py +++ b/bmtrain/optim/optim_manager.py @@ -140,7 +140,10 @@ def step(self): optimizer.step() if lr_scheduler is not None: - lr_scheduler.step() + if config["pipe_size"] > 1: + lr_scheduler.step(config["micros"] // config["pipe_size"] // config["zero_size"]) + else: + lr_scheduler.step() if self.loss_scale_enabled: self.steps_since_last_scale += 1 From 5da4c7f092a7f73320acd2bddb0a26b7ba4edcb0 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Thu, 26 Oct 2023 20:53:48 +0800 Subject: [PATCH 18/43] fix logger --- bmtrain/init.py | 2 ++ example/convert.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/bmtrain/init.py b/bmtrain/init.py index 39cae9e0..7ba5f699 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -86,6 +86,8 @@ def init_distributed( config["pipe_rank"] = config['topology'].get_group_rank("pipe") if debug: config["logger"] = get_logger(rank, "DEBUG") + else: + config["logger"] = get_logger(rank, "ERROR") cpus_this_worker = None all_available_cpus = sorted(list(os.sched_getaffinity(0))) diff --git a/example/convert.py b/example/convert.py index c72aee90..2a20c281 100644 --- a/example/convert.py +++ b/example/convert.py @@ -4,7 +4,7 @@ import re from collections import OrderedDict -def partition(pipe_rank,pipe_size,len_modules): +def partition(pipe_rank, pipe_size, len_modules): part_lens = [0]+[(len_modules // pipe_size + (i < (len_modules % pipe_size))) for i in range(pipe_rank+1)] start = sum(part_lens[:pipe_rank+1]) end = start + part_lens[pipe_rank+1] From ca2363fabe4fcbcd9f6f907579a4f4d97bc89f47 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Fri, 27 Oct 2023 11:31:45 +0800 Subject: [PATCH 19/43] fix logger level --- bmtrain/init.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bmtrain/init.py b/bmtrain/init.py index 7ba5f699..4d7dbcb1 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -87,7 +87,7 @@ def init_distributed( if debug: config["logger"] = get_logger(rank, "DEBUG") else: - config["logger"] = get_logger(rank, "ERROR") + config["logger"] = get_logger(rank, "WARNING") cpus_this_worker = None all_available_cpus = sorted(list(os.sched_getaffinity(0))) From 99eda097632c82632267913a8c36203689a262f8 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Fri, 3 Nov 2023 13:20:04 +0800 Subject: [PATCH 20/43] fix scale --- bmtrain/optim/optim_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bmtrain/optim/optim_manager.py b/bmtrain/optim/optim_manager.py index c6131fe1..3be03bb7 100644 --- a/bmtrain/optim/optim_manager.py +++ b/bmtrain/optim/optim_manager.py @@ -85,7 +85,7 @@ def add_optimizer( def scale_loss(self, loss : torch.Tensor) -> torch.Tensor: - return loss * (self.loss_scale / config['world_size'] * config['pipe_size']) # loss scale + return loss * (self.loss_scale / config['world_size'] * config['pipe_size'] * config['tp_size']) # loss scale def backward(self, loss : torch.Tensor): """ From 586d0b888bf3b8fad307f1cc7f9c290c899d5704 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Mon, 6 Nov 2023 11:02:33 +0800 Subject: [PATCH 21/43] fix recv async bug --- bmtrain/distributed/ops.py | 2 ++ bmtrain/store.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/bmtrain/distributed/ops.py b/bmtrain/distributed/ops.py index 
248f452a..98269363 100644 --- a/bmtrain/distributed/ops.py +++ b/bmtrain/distributed/ops.py @@ -102,6 +102,8 @@ def recv_activations_list(prev_rank, comm, async_op = True): elif flag == 1: recv = recv_object(prev_rank, comm) hidden_state_list.append(recv) + current_stream = torch.cuda.current_stream() + current_stream.wait_stream(config["pp_comm_stream"]) for hidden_state in hidden_state_list: if torch.is_tensor(hidden_state): hidden_state.record_stream(torch.cuda.current_stream()) diff --git a/bmtrain/store.py b/bmtrain/store.py index 7279ac53..8f731190 100644 --- a/bmtrain/store.py +++ b/bmtrain/store.py @@ -87,7 +87,7 @@ def async_save_to_file(state_dict, file_path): config['finish_save'] = True print("finish save state_dict to ", file_path) -def save(model : torch.nn.Module, file_name : str, non_blocking : bool=True): +def save(model : torch.nn.Module, file_name : str, non_blocking : bool=False): """Saves the model to the file. Similar to torch.save, but it used for distributed modules. From 783adbf515da1247478c5831d7ffee550db0aed9 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Mon, 6 Nov 2023 12:25:31 +0800 Subject: [PATCH 22/43] fix lr_scheduler step and delete trash file --- bmtrain/nn/pipe_embedding.py | 101 --------------------------------- bmtrain/optim/optim_manager.py | 5 +- 2 files changed, 1 insertion(+), 105 deletions(-) delete mode 100644 bmtrain/nn/pipe_embedding.py diff --git a/bmtrain/nn/pipe_embedding.py b/bmtrain/nn/pipe_embedding.py deleted file mode 100644 index fc6d92b7..00000000 --- a/bmtrain/nn/pipe_embedding.py +++ /dev/null @@ -1,101 +0,0 @@ -import math -from typing import Optional -import torch -import torch.nn.functional as F -import bmtrain as bmt - - -class PipeEmbedding(bmt.DistributedModule): - def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, - max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False, - sparse: bool = False, _weight: Optional[torch.Tensor] = None, - dtype=None): - super().__init__() - - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - if padding_idx is not None: - if padding_idx > 0: - assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' - elif padding_idx < 0: - assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' - padding_idx = self.num_embeddings + padding_idx - self.padding_idx = padding_idx - self.max_norm = max_norm - self.norm_type = norm_type - self.scale_grad_by_freq = scale_grad_by_freq - if _weight is None: - self.weight = bmt.DistributedParameter(torch.empty(num_embeddings, embedding_dim, dtype=dtype, device="cuda"), init_method=torch.nn.init.normal_) - else: - self.weight = bmt.DistributedParameter(_weight) - self.sparse = sparse - - @classmethod - def from_pretrained(cls, embeddings, freeze=True, padding_idx=None, - max_norm=None, norm_type=2., scale_grad_by_freq=False, - sparse=False): - r"""Creates Embedding instance from given 2-dimensional FloatTensor. - - Args: - embeddings (Tensor): FloatTensor containing weights for the Embedding. - First dimension is being passed to Embedding as ``num_embeddings``, second as ``embedding_dim``. - freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process. - Equivalent to ``embedding.weight.requires_grad = False``. 
Default: ``True`` - padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient; - therefore, the embedding vector at :attr:`padding_idx` is not updated during training, - i.e. it remains as a fixed "pad". - max_norm (float, optional): See module initialization documentation. - norm_type (float, optional): See module initialization documentation. Default ``2``. - scale_grad_by_freq (boolean, optional): See module initialization documentation. Default ``False``. - sparse (bool, optional): See module initialization documentation. - - Examples:: - - >>> # FloatTensor containing pretrained weights - >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]]) - >>> embedding = nn.Embedding.from_pretrained(weight) - >>> # Get embeddings for index 1 - >>> input = torch.LongTensor([1]) - >>> embedding(input) - tensor([[ 4.0000, 5.1000, 6.3000]]) - """ - assert embeddings.dim() == 2, \ - 'Embeddings parameter is expected to be 2-dimensional' - rows, cols = embeddings.shape - embedding = cls( - num_embeddings=rows, - embedding_dim=cols, - _weight=embeddings, - padding_idx=padding_idx, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse) - embedding.weight.requires_grad = not freeze - return embedding - - def forward(self, input: torch.Tensor, projection : bool = False) -> torch.Tensor: - if not projection: - out = F.embedding( - input, self.weight, self.padding_idx, self.max_norm, - self.norm_type, self.scale_grad_by_freq, self.sparse) - return out - else: - out = F.linear(input, self.weight) - return out - - def extra_repr(self) -> str: - s = '{num_embeddings}, {embedding_dim}' - if self.padding_idx is not None: - s += ', padding_idx={padding_idx}' - if self.max_norm is not None: - s += ', max_norm={max_norm}' - if self.norm_type != 2: - s += ', norm_type={norm_type}' - if self.scale_grad_by_freq is not False: - s += ', scale_grad_by_freq={scale_grad_by_freq}' - if self.sparse is not False: - s += ', sparse=True' - return s.format(**self.__dict__) - - diff --git a/bmtrain/optim/optim_manager.py b/bmtrain/optim/optim_manager.py index 3be03bb7..7f198cdd 100644 --- a/bmtrain/optim/optim_manager.py +++ b/bmtrain/optim/optim_manager.py @@ -140,10 +140,7 @@ def step(self): optimizer.step() if lr_scheduler is not None: - if config["pipe_size"] > 1: - lr_scheduler.step(config["micros"] // config["pipe_size"] // config["zero_size"]) - else: - lr_scheduler.step() + lr_scheduler.step() if self.loss_scale_enabled: self.steps_since_last_scale += 1 From 5c2222b575e644158ebb8fc38a5ef3ac5e2edf90 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Wed, 8 Nov 2023 13:20:21 +0800 Subject: [PATCH 23/43] avoid comm when no need --- bmtrain/block_layer.py | 21 +++++++++++++++------ bmtrain/nn/__init__.py | 1 - bmtrain/zero_context.py | 21 +++++++++++++++++++-- 3 files changed, 34 insertions(+), 9 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 985b5348..54828e60 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -189,7 +189,7 @@ def init_param_storage(self): # copy values to buffer for normal parameter storage_st = self._storage_info[kw_name]["begin"] storage_end = self._storage_info[kw_name]["end"] - + comm = self._storage_info[kw_name]["zero_comm"] # make parameter contiguous in storage with torch.no_grad(): contiguous_param = OpAllGather.apply(param) @@ -207,11 +207,16 @@ def init_param_storage(self): # PyTorch 1.11 changed the API of storage.__getitem__ 
d_dtype = self._storage_params[kw_name].dtype d_device = self._storage_params[kw_name].device - param.data = torch.tensor([], dtype=param.dtype, device=param.device).set_(self._storage_params[kw_name].storage(), to_offset_st, (to_offset_end - to_offset_st,)) self._param_info[-1]["begin"] = to_offset_st self._param_info[-1]["end"] = (to_offset_end - to_offset_st,) - param.data[:] = \ - torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_param.storage(), offset_st, (offset_end - offset_st,))[:] + if nccl.commCount(comm) != 1: + param.data = torch.tensor([], dtype=param.dtype, device=param.device).set_(self._storage_params[kw_name].storage(), to_offset_st, (to_offset_end - to_offset_st,)) + param.data[:] = \ + torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_param.storage(), offset_st, (offset_end - offset_st,))[:] + else: + param.data = torch.tensor([], dtype=param.dtype, device=param.device).set_(self._storage_params[kw_name].storage(), to_offset_st, param_shape) + param.data[:] = \ + torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_param.storage(), offset_st, param_shape)[:] del contiguous_param else: param.data = torch.tensor([], dtype=param.dtype, device=param.device) @@ -455,8 +460,12 @@ def init_parameters(self): # PyTorch 1.11 changed the API of storage.__getitem__ d_dtype = self._storage_params[kw_name].dtype d_device = self._storage_params[kw_name].device - param.data[:] = \ - torch.tensor([], dtype=d_dtype, device=d_device).set_(tmp_tensor.storage(), offset_st, (offset_end - offset_st,))[:] + if nccl.commCount(self._storage_info[kw_name]["zero_comm"]) == 1: + param.data[:] = \ + torch.tensor([], dtype=d_dtype, device=d_device).set_(tmp_tensor.storage(), offset_st, it["shape"])[:] + else: + param.data[:] = \ + torch.tensor([], dtype=d_dtype, device=d_device).set_(tmp_tensor.storage(), offset_st, (offset_end - offset_st,))[:] del tmp_tensor diff --git a/bmtrain/nn/__init__.py b/bmtrain/nn/__init__.py index 3b24b3ca..b5ceb80d 100644 --- a/bmtrain/nn/__init__.py +++ b/bmtrain/nn/__init__.py @@ -4,4 +4,3 @@ from .parallel_embedding import ParallelEmbedding from .parallel_cross_entropy_func import parallel_cross_entropy_func from .parallel_linear_func import OpParallelLinear -from .pipe_embedding import PipeEmbedding \ No newline at end of file diff --git a/bmtrain/zero_context.py b/bmtrain/zero_context.py index 4f287e3f..030371b5 100644 --- a/bmtrain/zero_context.py +++ b/bmtrain/zero_context.py @@ -30,6 +30,8 @@ def enter(self, flag=0, requires_grad=False): with torch.cuda.stream(config["load_stream"]): for kw, val in self.block._storage_info.items(): assert self.block._storage_params[kw].is_cuda + if nccl.commCount(val['zero_comm']) == 1: + continue assert kw not in self._grad_buffer assert kw not in self._param_buffer local_param = self.block._storage_params[kw] @@ -45,6 +47,8 @@ def enter(self, flag=0, requires_grad=False): if flag != 2: nccl.groupStart() for kw, val in self.block._storage_info.items(): + if nccl.commCount(val['zero_comm']) == 1: + continue nccl.allGather( self.block._storage_params[kw].storage(), self._param_buffer[kw], @@ -57,6 +61,8 @@ def enter(self, flag=0, requires_grad=False): # set wait stream for each storage for kw in self.block._storage_info.keys(): + if nccl.commCount(self.block._storage_info[kw]['zero_comm']) == 1: + continue if flag != 2: self._param_tensor[kw].record_stream(current_stream) if requires_grad and kw in self._grad_tensor: @@ -68,6 +74,9 @@ def enter(self, flag=0, requires_grad=False): offset = 
param["offset"] shape = param["shape"] + if nccl.commCount(self.block._storage_info[kw_name]["zero_comm"]): + continue + if flag != 2: dtype = self._param_buffer[kw_name].dtype device = self._param_buffer[kw_name].device @@ -94,8 +103,11 @@ def exit(self, flag=0, backward=False): self.block._ready = False if backward: for kw, val in self.block._storage_info.items(): - local_param = self.block._storage_params[kw] + if nccl.commCount(val['zero_comm']) == 1: + continue + + local_param = self.block._storage_params[kw] # accumulate previous gradient if local_param.requires_grad: if local_param.grad is None: @@ -110,8 +122,11 @@ def exit(self, flag=0, backward=False): with torch.cuda.stream(config["load_stream"]): nccl.groupStart() for kw, val in self.block._storage_info.items(): - local_param = self.block._storage_params[kw] + if nccl.commCount(val["zero_comm"]): + continue + + local_param = self.block._storage_params[kw] # scatter gradient if local_param.requires_grad: nccl.reduceScatter( @@ -131,6 +146,8 @@ def exit(self, flag=0, backward=False): # Release all parameters from buffer to block_storge for param in self.block._param_info: kw_name = param["kw_name"] + if nccl.commCount(self.block._storage_info[kw_name]["zero_comm"]): + continue dtype = self.block._storage_params[kw_name].dtype device = self.block._storage_params[kw_name].device if "begin" not in param: From 72bfc33948787e3423cf7315a9fe01b01f5b3f3b Mon Sep 17 00:00:00 2001 From: Maydomine <1583143678@qq.com> Date: Wed, 8 Nov 2023 16:36:18 +0800 Subject: [PATCH 24/43] fix comm bug --- bmtrain/block_layer.py | 6 +++--- bmtrain/zero_context.py | 14 +++++++------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 54828e60..2ad140e1 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -694,7 +694,7 @@ def _add_head(self, module): self.fisrt_module = (module,) def add_head(self, module): - module = _block_wrapper(module, self.module_dict, mode="1F1B") + module = _block_wrapper(module, self.module_dict, mode="1F1B", zero_level=2, use_checkpoint=False) module.init_param_storage() if config['topology'].pipe_rank != 0: return DummyForward @@ -708,7 +708,7 @@ def get_last_layer(self): return self._modules[str(len(self)-1)] def add_head_tail(self, module): - module = _block_wrapper(module, self.module_dict, mode="1F1B") + module = _block_wrapper(module, self.module_dict, mode="1F1B", zero_level=2, use_checkpoint=False) module.init_param_storage() if config['topology'].pipe_rank != 0 and not config['topology'].is_last_rank(): return DummyForward @@ -748,7 +748,7 @@ def _add_tail(self, module): self.last_module = (module,) def add_tail(self, module): - module = _block_wrapper(module, self.module_dict, mode="1F1B") + module = _block_wrapper(module, self.module_dict, mode="1F1B", zero_level=2, use_checkpoint=False) module.init_param_storage() if config['topology'].pipe_rank != config['topology'].pipe_size - 1: return DummyForward diff --git a/bmtrain/zero_context.py b/bmtrain/zero_context.py index 030371b5..94c0112d 100644 --- a/bmtrain/zero_context.py +++ b/bmtrain/zero_context.py @@ -30,7 +30,7 @@ def enter(self, flag=0, requires_grad=False): with torch.cuda.stream(config["load_stream"]): for kw, val in self.block._storage_info.items(): assert self.block._storage_params[kw].is_cuda - if nccl.commCount(val['zero_comm']) == 1: + if val["world_size"] == 1: continue assert kw not in self._grad_buffer assert kw not in self._param_buffer @@ -47,7 +47,7 @@ def 
enter(self, flag=0, requires_grad=False): if flag != 2: nccl.groupStart() for kw, val in self.block._storage_info.items(): - if nccl.commCount(val['zero_comm']) == 1: + if val["world_size"] == 1: continue nccl.allGather( self.block._storage_params[kw].storage(), @@ -61,7 +61,7 @@ def enter(self, flag=0, requires_grad=False): # set wait stream for each storage for kw in self.block._storage_info.keys(): - if nccl.commCount(self.block._storage_info[kw]['zero_comm']) == 1: + if self.block._storage_info[kw]['world_size'] == 1: continue if flag != 2: self._param_tensor[kw].record_stream(current_stream) @@ -74,7 +74,7 @@ def enter(self, flag=0, requires_grad=False): offset = param["offset"] shape = param["shape"] - if nccl.commCount(self.block._storage_info[kw_name]["zero_comm"]): + if self.block._storage_info[kw_name]["world_size"] == 1: continue if flag != 2: @@ -104,7 +104,7 @@ def exit(self, flag=0, backward=False): if backward: for kw, val in self.block._storage_info.items(): - if nccl.commCount(val['zero_comm']) == 1: + if val['world_size'] == 1: continue local_param = self.block._storage_params[kw] @@ -123,7 +123,7 @@ def exit(self, flag=0, backward=False): nccl.groupStart() for kw, val in self.block._storage_info.items(): - if nccl.commCount(val["zero_comm"]): + if val["world_size"] == 1: continue local_param = self.block._storage_params[kw] @@ -146,7 +146,7 @@ def exit(self, flag=0, backward=False): # Release all parameters from buffer to block_storge for param in self.block._param_info: kw_name = param["kw_name"] - if nccl.commCount(self.block._storage_info[kw_name]["zero_comm"]): + if self.block._storage_info[kw_name]["world_size"] == 1: continue dtype = self.block._storage_params[kw_name].dtype device = self.block._storage_params[kw_name].device From b06c59f93e0a38898e9e4b4ded7045d18636959c Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Wed, 8 Nov 2023 18:27:17 +0800 Subject: [PATCH 25/43] add ckpt args in pipe blocklist --- bmtrain/block_layer.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 2ad140e1..856dd4f9 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -633,11 +633,14 @@ def DummyForward(*args, **kwargs): class PipeDreamBlockList(TransformerBlockList): - def __init__(self, modules: Iterable[Block], num_hidden=1, sqrt=False) -> None: + def __init__(self, modules: Iterable[Block], num_hidden=1, use_checkpoint=False) -> None: module_dict = {} mode = "1F1B" + if isinstance(use_checkpoint, bool): + use_checkpoint = [use_checkpoint for _ in range(len(modules))] + assert isinstance(use_checkpoint,Iterable) and len(use_checkpoint) == len(modules), "use_checkpoint should be a list of bool variable or a bool variable" for idx in range(len(modules)): - modules[idx] = _block_wrapper(modules[idx], module_dict, mode=mode, zero_level=2, use_checkpoint=False) + modules[idx] = _block_wrapper(modules[idx], module_dict, mode=mode, zero_level=2, use_checkpoint=use_checkpoint[idx]) s,e = self.partition(modules) self.head_idx = s self.tail_idx = e @@ -693,8 +696,8 @@ def _add_head(self, module): self.fisrt_module[0].set_pre_module(module) self.fisrt_module = (module,) - def add_head(self, module): - module = _block_wrapper(module, self.module_dict, mode="1F1B", zero_level=2, use_checkpoint=False) + def add_head(self, module, use_checkpoint=False): + module = _block_wrapper(module, self.module_dict, mode="1F1B", zero_level=2, use_checkpoint=use_checkpoint) 
module.init_param_storage() if config['topology'].pipe_rank != 0: return DummyForward @@ -707,8 +710,8 @@ def get_first_layer(self): def get_last_layer(self): return self._modules[str(len(self)-1)] - def add_head_tail(self, module): - module = _block_wrapper(module, self.module_dict, mode="1F1B", zero_level=2, use_checkpoint=False) + def add_head_tail(self, module, use_checkpoint=False): + module = _block_wrapper(module, self.module_dict, mode="1F1B", zero_level=2, use_checkpoint=use_checkpoint) module.init_param_storage() if config['topology'].pipe_rank != 0 and not config['topology'].is_last_rank(): return DummyForward @@ -747,8 +750,8 @@ def _add_tail(self, module): module.set_pre_module(self.last_module[0]) self.last_module = (module,) - def add_tail(self, module): - module = _block_wrapper(module, self.module_dict, mode="1F1B", zero_level=2, use_checkpoint=False) + def add_tail(self, module, use_checkpoint=False): + module = _block_wrapper(module, self.module_dict, mode="1F1B", zero_level=2, use_checkpoint=use_checkpoint) module.init_param_storage() if config['topology'].pipe_rank != config['topology'].pipe_size - 1: return DummyForward From 1e687ea2b7994873e80726dfbeb5028b2de79d70 Mon Sep 17 00:00:00 2001 From: Maydomine <1583143678@qq.com> Date: Thu, 9 Nov 2023 16:48:17 +0800 Subject: [PATCH 26/43] scale loss in 1f1b --- bmtrain/pipe/schedule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bmtrain/pipe/schedule.py b/bmtrain/pipe/schedule.py index 9054b780..0fe41326 100644 --- a/bmtrain/pipe/schedule.py +++ b/bmtrain/pipe/schedule.py @@ -32,7 +32,7 @@ def backward_step(inp, output, grad_output, optim_manager=None): output = optim_manager.scale_loss(output[0][0]) elif torch.is_tensor(output[0]): output = optim_manager.scale_loss(output[0]) - + output = output / config['micros'] else: output = output[0] torch.autograd.backward(output, grad_tensors=grad_output[0]) From 7ed0f89116e39a65dde740fff9c9ae1b3ffd26ab Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Fri, 10 Nov 2023 16:49:28 +0800 Subject: [PATCH 27/43] pipeline ckpt store and save --- bmtrain/pipe/__init__.py | 3 +- bmtrain/pipe/store.py | 74 ++++++++++++++++++++++++++++++++++++++++ bmtrain/store.py | 7 +++- 3 files changed, 82 insertions(+), 2 deletions(-) create mode 100644 bmtrain/pipe/store.py diff --git a/bmtrain/pipe/__init__.py b/bmtrain/pipe/__init__.py index f697a75d..410e3437 100644 --- a/bmtrain/pipe/__init__.py +++ b/bmtrain/pipe/__init__.py @@ -1 +1,2 @@ -from .schedule import pipeline_forward_backward \ No newline at end of file +from .schedule import pipeline_forward_backward +from .store import load_model_pipe, save_model_pipe \ No newline at end of file diff --git a/bmtrain/pipe/store.py b/bmtrain/pipe/store.py new file mode 100644 index 00000000..5683cb66 --- /dev/null +++ b/bmtrain/pipe/store.py @@ -0,0 +1,74 @@ +import bmtrain as bmt +import torch +import re +from collections import OrderedDict + +def partition(pipe_rank, pipe_size, len_modules): + part_lens = [0]+[(len_modules // pipe_size + (i < (len_modules % pipe_size))) for i in range(pipe_rank+1)] + start = sum(part_lens[:pipe_rank+1]) + end = start + part_lens[pipe_rank+1] + return start,end + +def key_process(key, pipe_size , rank, start, end): + res = re.search("\.(\d+)\.", key) + if res is not None: + layer_idx = int(res.group(1)) + else: + layer_idx = None + if layer_idx is None or (layer_idx >= start and layer_idx < end): + if layer_idx is not None: + return re.sub("\.(\d+)\.", "."+str(layer_idx - 
start)+".", key) + else: + return key + +def get_len_modules(state): + max_len = 0 + for key in state: + s = re.search("\.(\d+)\.", key) + if s is not None: + res = int(s.group(1)) + if res>max_len: + max_len = res + return max_len+1 + +def get_state_dict_pipe(path): + pipe_size = bmt.config["pipe_size"] + pipe_rank = bmt.config["pipe_rank"] + + if bmt.rank() == 0: + ds_state_dict = bmt.store.DistributedStateDictWrapper(torch.load(path)) + else: + ds_state_dict = bmt.store.DistributedStateDictWrapper({}) + + len_modules = get_len_modules(ds_state_dict) + s,e = partition(pipe_rank, pipe_size, len_modules) + state_dict = OrderedDict() + + for key in ds_state_dict: + param = ds_state_dict[key].broadcast() + k_p = key_process(key, pipe_size, pipe_rank, s, e) + if k_p is not None: + state_dict[k_p] = param + else: + del param + return state_dict + +def load_model_pipe(model, path, load_whole=True): + """ + load_whole: Boolean, if True, load from the whole model file, else load model from the pipeline/tensor parallel model file + """ + if load_whole: + state_dict = get_state_dict_pipe(path) + model.load_state_dict(state_dict, strict=False) + else: + pipe_rank = bmt.config["pipe_rank"] + tp_rank = bmt.config["tp_rank"] + ckpt_path = f"{path}_pp_{pipe_rank}_tp_{tp_rank}.pt" + state_dict = torch.load(ckpt_path) + model.load_state_dict(state_dict) + +def save_model_pipe(model, path): + pipe_rank = bmt.config["pipe_rank"] + tp_rank = bmt.config["tp_rank"] + state_dict = model.state_dict() + torch.save(state_dict, f"{path}_pp_{pipe_rank}_tp_{tp_rank}.pt") diff --git a/bmtrain/store.py b/bmtrain/store.py index 8f731190..509e0717 100644 --- a/bmtrain/store.py +++ b/bmtrain/store.py @@ -1,7 +1,7 @@ from collections import OrderedDict from typing import Dict import torch - +from .pipe import save_model_pipe, load_model_pipe from .pipe_layer import PipelineTransformerBlockList from .block_layer import TransformerBlockList from .global_var import config @@ -102,7 +102,12 @@ def save(model : torch.nn.Module, file_name : str, non_blocking : bool=False): >>> bmtrain.save(model, "model.pt") """ torch.cuda.synchronize() + if config["pipe_size"] > 1: + save_model_pipe(model, file_name) + return + state_dict = _save_to_rank0(model) + if config["rank"] == 0: if non_blocking is False: torch.save(state_dict, file_name) From f5933dbef29bccb188e3fe0a2062fb38fc46ae76 Mon Sep 17 00:00:00 2001 From: Achazwl Date: Mon, 13 Nov 2023 19:35:51 +0800 Subject: [PATCH 28/43] clone logits --- bmtrain/loss/cross_entropy.py | 70 +++++++++++++++++++++++++++++++ tests/test_parallel_projection.py | 55 ++++++++++++++++++++++++ 2 files changed, 125 insertions(+) create mode 100644 tests/test_parallel_projection.py diff --git a/bmtrain/loss/cross_entropy.py b/bmtrain/loss/cross_entropy.py index a2e123ad..de8c7de2 100644 --- a/bmtrain/loss/cross_entropy.py +++ b/bmtrain/loss/cross_entropy.py @@ -36,6 +36,76 @@ def backward(ctx, grad_output : torch.Tensor): ) return (softmax, None, None) +class VPFusedCrossEntropy(torch.autograd.Function): + @staticmethod + def forward(ctx, logits : torch.Tensor, target : torch.Tensor): + comm = config['tp_comm'] + rank = config['tp_rank'] + world_size = config['tp_size'] + + max_logits = torch.max(logits, dim=-1)[0].float() + max_logits = all_reduce(max_logits, op="max", comm=comm) + + partition_vocab_size = logits.size()[-1] + vocab_start_index = rank * partition_vocab_size + vocab_end_index = (rank + 1) * partition_vocab_size + + # Create a mask of valid vocab ids (1 means it needs to be masked). 
+ target_mask = (target < vocab_start_index) | (target >= vocab_end_index) + masked_target = target.clone() - vocab_start_index + masked_target[target_mask] = 0 + + logits_2d = logits.view(-1, partition_vocab_size) + masked_target_1d = masked_target.view(-1) + arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device) + predicted_logits_1d = logits_2d[arange_1d, masked_target_1d].contiguous() # (-1,) + predicted_logits = predicted_logits_1d.view_as(target) + predicted_logits[target_mask] = 0.0 # if target=-100, it will also be 0 + + # All reduce is needed to get the chunks from other GPUs. + predicted_logits = all_reduce(predicted_logits.float(), op="sum", comm=comm) + predicted_logits = predicted_logits - max_logits + # Sum of exponential of logits along vocab dimension across all GPUs. + + sum_exp_logits = torch.empty(logits.size(0), device=logits.device, dtype=torch.float) + sum_exp_logits = F.fused_sumexp(logits, max_logits) # float + sum_exp_logits = all_reduce(sum_exp_logits, op="sum", comm=comm) + 1e-10 # avoid nan + + softmax = logits.clone() + F.fused_softmax_inplace(softmax, max_logits, sum_exp_logits) # logits -> softmax + # logits = logits.float() - max_logits.unsqueeze(dim=-1).float() + # exp_logits = logits + # torch.exp(logits, out=exp_logits) + # sum_exp_logits = exp_logits.sum(dim=-1) + # exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + + loss = torch.log(sum_exp_logits.view(predicted_logits.shape)) - predicted_logits + + # Normalize + ctx.save_for_backward(softmax, target_mask, masked_target_1d) + + return loss + + @staticmethod + def backward(ctx, grad_output): + # Retreive tensors from the forward path. + softmax, target_mask, masked_target_1d = ctx.saved_tensors + # All the inputs have softmax as thier gradient. + grad_input = softmax + # For simplicity, work with the 2D gradient. + partition_vocab_size = softmax.size()[-1] + grad_2d = grad_input.view(-1, partition_vocab_size) + + # Add the gradient from matching classes. + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) + + softmax_update = 1.0 - target_mask.view(-1).float() + + grad_2d[arange_1d, masked_target_1d] -= softmax_update + grad_input.mul_(grad_output.view(*grad_input.shape[:-1]).unsqueeze(dim=-1)) + + return grad_input, None + class FusedCrossEntropy(torch.nn.Module): r"""This criterion computes the cross entropy loss between input and target. 
diff --git a/tests/test_parallel_projection.py b/tests/test_parallel_projection.py new file mode 100644 index 00000000..98de30a0 --- /dev/null +++ b/tests/test_parallel_projection.py @@ -0,0 +1,55 @@ +import torch +import bmtrain as bmt +from bmtrain.global_var import config +import numpy as np +import os + +def run_normal(x, t, ckp_path, dtype): + proj = bmt.nn.Projection(100, 64, dtype=dtype) + bmt.init_parameters(proj) + bmt.save(proj, ckp_path) + loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=False) + y = proj.projection(x) + y = y.detach().requires_grad_() + loss = loss_func(y, t) + loss.backward() + return y, loss, y.grad + +def run_vp(x, t, ckp_path, dtype): + proj = bmt.nn.VPProjection(100, 64, dtype=dtype) + bmt.load(proj, ckp_path) + loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=True) + y = proj.projection(x) + y = y.detach().requires_grad_() + loss = loss_func(y, t) + loss.backward() + return y, loss, y.grad + +def run(dtype): + ckp_path = 'embedding.pt' + torch.cuda.manual_seed(100) + tp_size = config["tp_size"] + tp_rank = config['tp_rank'] + x = torch.randn(110, 64, device='cuda', dtype=dtype) + t = torch.cat([torch.arange(100).view(10, 10), torch.ones((10, 1))*-100], dim=-1).view(110).int().cuda() + y1, loss1, grad1 = run_normal(x, t, ckp_path, dtype) + y2, loss2, grad2 = run_vp(x, t, ckp_path, dtype) + y1 = y1.chunk(tp_size, dim=-1)[tp_rank] + grad1 = grad1.chunk(tp_size, dim=-1)[tp_rank] + for r in range(tp_size): + if bmt.rank() == r: + print((y1-y2).abs().max()) + print((loss1-loss2).abs().max()) + print((grad1-grad2).abs().max()) + assert (y1-y2).abs().max() < 1e-4 + assert (loss1-loss2).abs().max() < 1e-4 + assert (grad1-grad2).abs().max() < 1e-4 + bmt.synchronize() + if bmt.rank() == 0: + os.remove(f"embedding.pt") + +if __name__ == "__main__": + bmt.init_distributed(tp_size=4) + run(torch.half) + run(torch.bfloat16) + From 0bd5f2f40ef923ff6a8e823609f9c665d353e4c6 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Tue, 5 Mar 2024 23:35:14 +0800 Subject: [PATCH 29/43] fix init --- bmtrain/distributed/dtype.py | 12 +++ bmtrain/distributed/ops.py | 143 +----------------------------- bmtrain/distributed/p2p_ops.py | 157 +++++++++++++++++++++++++++++++++ bmtrain/init.py | 60 ++++++------- bmtrain/pipe/comm.py | 18 ++-- 5 files changed, 206 insertions(+), 184 deletions(-) create mode 100644 bmtrain/distributed/dtype.py create mode 100644 bmtrain/distributed/p2p_ops.py diff --git a/bmtrain/distributed/dtype.py b/bmtrain/distributed/dtype.py new file mode 100644 index 00000000..a2b18db1 --- /dev/null +++ b/bmtrain/distributed/dtype.py @@ -0,0 +1,12 @@ +import torch +DTYPE_LIST = [ + torch.float64, + torch.float32, + torch.float16, + torch.int64, + torch.int32, + torch.int16, + torch.int8, + torch.bfloat16, + torch.bool +] \ No newline at end of file diff --git a/bmtrain/distributed/ops.py b/bmtrain/distributed/ops.py index 98269363..938ccc06 100644 --- a/bmtrain/distributed/ops.py +++ b/bmtrain/distributed/ops.py @@ -1,152 +1,13 @@ import torch import bmtrain as bmt from ..global_var import config, rank -from ..nccl import allGather as ncclAllGather, recv +from ..nccl import allGather as ncclAllGather from ..nccl import allReduce as ncclAllReduce from ..nccl import broadcast as ncclBroadcast from ..nccl import reduceScatter as ncclReduceScatter -from ..nccl import send as ncclSend -from ..nccl import recv as ncclRecv from ..nccl import commCount,commRank,NCCLCommunicator,groupStart,groupEnd -import contextlib -import 
pickle -DTYPE_LIST = [ - torch.float64, - torch.float32, - torch.float16, - torch.int64, - torch.int32, - torch.int16, - torch.int8, - torch.bfloat16, - torch.bool -] -@contextlib.contextmanager -def groupcall(): - groupStart() - yield - groupEnd() - -class handler: - def __init__(self, stream): - self.stream = stream - - def wait(self): - torch.cuda.current_stream().wait_stream(self.stream) - -def send_object(obj, next_rank, comm): - data_bytes: bytes = pickle.dumps(obj) - data_length: int = len(data_bytes) - - gpu_data_length = torch.tensor([data_length], device="cuda", dtype=torch.long) - ncclSend(gpu_data_length.storage(), next_rank, comm) - byte_storage = torch.ByteStorage.from_buffer(data_bytes).cuda() - ncclSend(byte_storage, next_rank, comm) - -def recv_object(prev_rank, comm): - data_length = torch.tensor([0], device="cuda", dtype=torch.long) - ncclRecv(data_length.storage(), prev_rank, comm) - data_bytes_stor = torch.cuda.ByteStorage(data_length.item()) - ncclRecv(data_bytes_stor, prev_rank, comm) - tensor = torch.ByteTensor(data_bytes_stor.cpu()) - data = pickle.loads(tensor.numpy().tobytes()) - return data +from .p2p_ops import * -def send_activations_list(hidden_state_list, next_rank, comm, async_op=False): - if async_op: - current_stream = torch.cuda.current_stream() - with torch.cuda.stream(config["pp_comm_stream"]): - config["pp_comm_stream"].wait_stream(current_stream) - length = torch.tensor(data=[len([h for h in hidden_state_list ])], device="cuda", dtype=torch.int) - ncclSend(length.storage(), next_rank, comm) - flags = torch.tensor(data=[0 for _ in range(len(hidden_state_list))], device="cuda",dtype=torch.int) - for i in range(len(hidden_state_list)): - if hidden_state_list[i] is None: - flag = -1 - elif torch.is_tensor(hidden_state_list[i]): - flag = 0 - else: - flag = 1 - flags[i] = flag - ncclSend(flags.contiguous().storage(), next_rank, comm) - for i in range(len(hidden_state_list)): - if flags[i] == 0: - hidden_state_list[i].record_stream(config["pp_comm_stream"]) - send_activations(hidden_state_list[i], next_rank, comm) - elif flags[i] == 1: - send_object(hidden_state_list[i], next_rank, comm) - return handler(config["pp_comm_stream"]) - else: - length = torch.tensor(data=[0], device="cuda", dtype=torch.int) - length[0] = len(hidden_state_list) - ncclSend(length.storage(), next_rank, comm) - for i in range(length): - send_activations(hidden_state_list[i], next_rank, comm) - - -def recv_activations_list(prev_rank, comm, async_op = True): - if async_op: - with torch.cuda.stream(config["pp_comm_stream"]): - length = torch.tensor(data=[0], device="cuda", dtype=torch.int) - hidden_state_list = [] - ncclRecv(length.storage(), prev_rank, comm) - flags = torch.tensor(data=[0 for _ in range(length)], device="cuda",dtype=torch.int) - ncclRecv(flags.storage(), prev_rank, comm) - for i in range(length[0].item()): - flag = flags[i].item() - if flag == -1: - hidden_state_list.append(None) - elif flag == 0: - recv = recv_activations(prev_rank, comm) - hidden_state_list.append(recv) - elif flag == 1: - recv = recv_object(prev_rank, comm) - hidden_state_list.append(recv) - current_stream = torch.cuda.current_stream() - current_stream.wait_stream(config["pp_comm_stream"]) - for hidden_state in hidden_state_list: - if torch.is_tensor(hidden_state): - hidden_state.record_stream(torch.cuda.current_stream()) - return hidden_state_list, handler(config["pp_comm_stream"]) - - - -def send_activations(hidden_state, next_rank, comm): - hidden_state = hidden_state.contiguous() - 
send_meta(hidden_state, next_rank, comm) - ncclSend(hidden_state.storage(), next_rank, comm) - -def send_activations_inplace(hidden_state, next_rank, comm): - hidden_state = hidden_state.contiguous() - ncclSend(hidden_state.storage(), next_rank, comm) - -def recv_activations_inplace(hidden_state, prev_rank, comm): - hidden_state = hidden_state.contiguous() - ncclRecv(hidden_state.storage(), prev_rank, comm) - return hidden_state - -def recv_activations(prev_rank, comm): - dtype, shape = recv_meta(prev_rank, comm) - hidden_state = torch.empty(shape, dtype=dtype, device="cuda") - ncclRecv(hidden_state.storage(), prev_rank, comm) - return hidden_state - -def send_meta(x, next_rank, comm): - meta_data = torch.tensor(data=[0]*50, device="cuda", dtype=torch.int) - meta_data[0] = len(x.size()) - meta_data[1] = DTYPE_LIST.index(x.dtype) - meta_data[2:len(x.size())+2] = torch.tensor(x.size(), device="cuda", dtype=torch.int) - meta_data = meta_data.contiguous() - ncclSend(meta_data.storage(), next_rank, comm) - -def recv_meta(prev_rank, comm): - meta_data = torch.tensor(data=[0]*50, device="cuda", dtype=torch.int) - ncclRecv(meta_data.storage(), prev_rank, comm) - n_dims = meta_data[0].item() - dtype = DTYPE_LIST[meta_data[1].item()] - shape = meta_data[2:n_dims+2].tolist() - - return dtype,shape class OpBroadcast(torch.autograd.Function): diff --git a/bmtrain/distributed/p2p_ops.py b/bmtrain/distributed/p2p_ops.py new file mode 100644 index 00000000..428940fd --- /dev/null +++ b/bmtrain/distributed/p2p_ops.py @@ -0,0 +1,157 @@ +import torch +from bmtrain import config +from ..nccl import reduceScatter as ncclReduceScatter +from ..nccl import send as ncclSend +from ..nccl import recv as ncclRecv +from ..nccl import groupStart,groupEnd +from .dtype import DTYPE_LIST +import pickle +import contextlib + +_p2p_stream = {} +_p2p_events = {} + +@contextlib.contextmanager +def groupcall(): + groupStart() + yield + groupEnd() +class handler: + def __init__(self, event): + self.event= event + + def wait(self): + torch.cuda.current_stream().wait_event(self.event) + +def send_object(obj, peer_rank, comm): + data_bytes: bytes = pickle.dumps(obj) + data_length: int = len(data_bytes) + + gpu_data_length = torch.tensor([data_length], device="cuda", dtype=torch.long) + ncclSend(gpu_data_length.storage(), peer_rank, comm) + byte_storage = torch.ByteStorage.from_buffer(data_bytes).cuda() + ncclSend(byte_storage, peer_rank, comm) + +def recv_object(peer_rank, comm): + data_length = torch.tensor([0], device="cuda", dtype=torch.long) + ncclRecv(data_length.storage(), peer_rank, comm) + data_bytes_stor = torch.cuda.ByteStorage(data_length.item()) + ncclRecv(data_bytes_stor, peer_rank, comm) + tensor = torch.ByteTensor(data_bytes_stor.cpu()) + data = pickle.loads(tensor.numpy().tobytes()) + return data + +def record_stream_helper(tensor_list, stream): + for t in tensor_list: + t.record_stream(stream) + +def send_tensor(tensor_list, peer_rank, comm): + handler = _send_tensors(tensor_list, peer_rank, comm) + handler.wait() + +def isend_tensor(tensor_list, peer_rank, comm): + return _send_tensors(tensor_list, peer_rank, comm) + +def _send_tensors(tensor_list, peer_rank, comm): + p2p_key = f"send {peer_rank}" + if p2p_key not in _p2p_stream: + _p2p_stream[p2p_key] = torch.cuda.Stream() + if p2p_key not in _p2p_events: + _p2p_events[p2p_key] = torch.cuda.Event() + stream = _p2p_stream[peer_rank] + end_event = _p2p_events[p2p_key] + with torch.cuda.stream(stream): + length = torch.tensor(data=[len([h for h in tensor_list ])], 
device="cuda", dtype=torch.int) + ncclSend(length.storage(), peer_rank, comm) + flags = torch.tensor(data=[0 for _ in range(len(tensor_list))], device="cuda",dtype=torch.int) + for i in range(len(tensor_list)): + if tensor_list[i] is None: + flag = -1 + elif torch.is_tensor(tensor_list[i]): + flag = 0 + else: + flag = 1 + flags[i] = flag + ncclSend(flags.contiguous().storage(), peer_rank, comm) + for i in range(len(tensor_list)): + if flags[i] == 0: + tensor_list[i].record_stream(stream) + send_tensor(tensor_list[i], peer_rank, comm) + elif flags[i] == 1: + send_object(tensor_list[i], peer_rank, comm) + end_event.record(stream) + return handler(end_event) + +def recv_tensors(peer_rank, comm): + tensors,handle = _recv_tensors(peer_rank, comm) + handle.wait() + return tensors + +def irecv_tensors(peer_rank, comm): + tensors, handle = _recv_tensors(peer_rank, comm) + return tensors, handle + +def _recv_tensors(peer_rank, comm): + p2p_key = f"recv {peer_rank}" + if p2p_key not in _p2p_stream: + _p2p_stream[peer_rank] = torch.cuda.Stream() + if p2p_key not in _p2p_events: + _p2p_events[p2p_key] = torch.cuda.Event() + stream = _p2p_stream[peer_rank] + end_event = _p2p_events[p2p_key] + with torch.cuda.stream(stream): + length = torch.tensor(data=[0], device="cuda", dtype=torch.int) + tensor_list = [] + ncclRecv(length.storage(), peer_rank, comm) + flags = torch.tensor(data=[0 for _ in range(length)], device="cuda",dtype=torch.int) + ncclRecv(flags.storage(), peer_rank, comm) + for i in range(length[0].item()): + flag = flags[i].item() + if flag == -1: + tensor_list.append(None) + elif flag == 0: + recv = recv_tensor(peer_rank, comm) + tensor_list.append(recv) + elif flag == 1: + recv = recv_object(peer_rank, comm) + tensor_list.append(recv) + end_event.record(stream) + record_stream_helper([tensor_list[i] for i in range(length[0]).item() if flags[i].item() == 0], torch.cuda.current_stream()) + return tensor_list, handler(end_event) + +def send_tensor(hidden_state, peer_rank, comm): + hidden_state = hidden_state.contiguous() + send_meta(hidden_state, peer_rank, comm) + ncclSend(hidden_state.storage(), peer_rank, comm) + +def send_tensor_inplace(hidden_state, peer_rank, comm): + hidden_state = hidden_state.contiguous() + ncclSend(hidden_state.storage(), peer_rank, comm) + +def recv_tensor_inplace(hidden_state, peer_rank, comm): + hidden_state = hidden_state.contiguous() + ncclRecv(hidden_state.storage(), peer_rank, comm) + return hidden_state + +def recv_tensor(peer_rank, comm): + dtype, shape = recv_meta(peer_rank, comm) + hidden_state = torch.empty(shape, dtype=dtype, device="cuda") + ncclRecv(hidden_state.storage(), peer_rank, comm) + return hidden_state + +def send_meta(x, peer_rank, comm): + meta_data = torch.tensor(data=[0]*50, device="cuda", dtype=torch.int) + meta_data[0] = len(x.size()) + meta_data[1] = DTYPE_LIST.index(x.dtype) + meta_data[2:len(x.size())+2] = torch.tensor(x.size(), device="cuda", dtype=torch.int) + meta_data = meta_data.contiguous() + ncclSend(meta_data.storage(), peer_rank, comm) + +def recv_meta(peer_rank, comm): + meta_data = torch.tensor(data=[0]*50, device="cuda", dtype=torch.int) + ncclRecv(meta_data.storage(), peer_rank, comm) + n_dims = meta_data[0].item() + dtype = DTYPE_LIST[meta_data[1].item()] + shape = meta_data[2:n_dims+2].tolist() + + return dtype,shape \ No newline at end of file diff --git a/bmtrain/init.py b/bmtrain/init.py index 0300e849..a65e6c27 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -13,7 +13,7 @@ def init_distributed( init_method : 
str = "env://", seed : int = 0, - pipe_size: int = -1, + pipe_size: int = 1, num_micro_batches: int = None, tp_size : int = 1, debug=False, @@ -67,7 +67,7 @@ def init_distributed( torch.cuda.set_device(local_rank) config["initialized"] = True config["pipe_size"] = pipe_size if pipe_size > 0 else 1 - config["pipe_enabled"] = pipe_size > 0 + config["pipe_enabled"] = pipe_size > 1 config["local_rank"] = local_rank config["local_size"] = local_size config["rank"] = rank @@ -119,13 +119,13 @@ def init_distributed( config['comm'] = nccl.commInitRank(unique_id, world_size, rank) topo = config['topology'] + config["micros"] = num_micro_batches if num_micro_batches else config["pipe_size"] + if topo.pipe_rank == 0: + unique_id = nccl.getUniqueId() + store.set(f"PIPE_UNIQUE_ID{topo.pipe_idx}", unique_id.hex()) + unique_id = bytes.fromhex(store.get(f"PIPE_UNIQUE_ID{topo.pipe_idx}").decode()) + config ['pipe_comm'] = nccl.commInitRank(unique_id, pipe_size, topo.pipe_rank) if config['pipe_enabled']: - config["micros"] = num_micro_batches if num_micro_batches else config["pipe_size"] - if topo.pipe_rank == 0: - unique_id = nccl.getUniqueId() - store.set(f"PIPE_UNIQUE_ID{topo.pipe_idx}", unique_id.hex()) - unique_id = bytes.fromhex(store.get(f"PIPE_UNIQUE_ID{topo.pipe_idx}").decode()) - config ['pipe_comm'] = nccl.commInitRank(unique_id, pipe_size, topo.pipe_rank) if topo.pipe_rank == topo.pipe_size - 1 or topo.pipe_rank == 0: if topo.pipe_rank == 0: unique_tied_id = nccl.getUniqueId() @@ -134,32 +134,30 @@ def init_distributed( rank = 0 if topo.pipe_rank == 0 else 1 config['pipe_tied_comm'] = nccl.commInitRank(unique_tied_id, 2, rank) - if topo.pp_zero_id == 0: - unique_id = nccl.getUniqueId() - store.set(f"PP_ZERO_UNIQUE_ID{topo.pp_zero_idx}", unique_id.hex() ) - unique_id = bytes.fromhex(store.get(f"PP_ZERO_UNIQUE_ID{topo.pp_zero_idx}").decode()) - config['pp_zero_comm'] = nccl.commInitRank(unique_id, world_size//config['pipe_size'], topo.pp_zero_id) + if topo.pp_zero_id == 0: + unique_id = nccl.getUniqueId() + store.set(f"PP_ZERO_UNIQUE_ID{topo.pp_zero_idx}", unique_id.hex() ) + unique_id = bytes.fromhex(store.get(f"PP_ZERO_UNIQUE_ID{topo.pp_zero_idx}").decode()) + config['pp_zero_comm'] = nccl.commInitRank(unique_id, world_size//config['pipe_size'], topo.pp_zero_id) - if config['tp_size'] > 1: - if topo.tp_id == 0: - unique_id = nccl.getUniqueId() - store.set(f"TP_UNIQUE_ID{topo.tp_idx}", unique_id.hex()) - unique_id = bytes.fromhex(store.get(f"TP_UNIQUE_ID{topo.tp_idx}").decode()) - config['tp_comm'] = nccl.commInitRank(unique_id, tp_size, topo.tp_id) + if topo.tp_id == 0: + unique_id = nccl.getUniqueId() + store.set(f"TP_UNIQUE_ID{topo.tp_idx}", unique_id.hex()) + unique_id = bytes.fromhex(store.get(f"TP_UNIQUE_ID{topo.tp_idx}").decode()) + config['tp_comm'] = nccl.commInitRank(unique_id, tp_size, topo.tp_id) - if topo.tp_zero_id == 0: - unique_id = nccl.getUniqueId() - store.set(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}", unique_id.hex() ) - unique_id = bytes.fromhex(store.get(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}").decode()) - config['tp_zero_comm'] = nccl.commInitRank(unique_id, world_size//config['tp_size'], topo.tp_zero_id) + if topo.tp_zero_id == 0: + unique_id = nccl.getUniqueId() + store.set(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}", unique_id.hex() ) + unique_id = bytes.fromhex(store.get(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}").decode()) + config['tp_zero_comm'] = nccl.commInitRank(unique_id, world_size//config['tp_size'], topo.tp_zero_id) - if config['pipe_size'] > 1 and config['tp_size'] > 1: - 
if topo.pp_tp_zero_id == 0: - unique_id = nccl.getUniqueId() - store.set(f"PP_TP_ZERO_UNIQUE_ID{topo.pp_tp_zero_idx}", unique_id.hex() ) - unique_id = bytes.fromhex(store.get(f"PP_TP_ZERO_UNIQUE_ID{topo.pp_tp_zero_idx}").decode()) - config['pp_tp_zero_comm'] = nccl.commInitRank(unique_id, world_size//(config['pipe_size'] * config['tp_size']), topo.pp_tp_zero_id) + if topo.pp_tp_zero_id == 0: + unique_id = nccl.getUniqueId() + store.set(f"PP_TP_ZERO_UNIQUE_ID{topo.pp_tp_zero_idx}", unique_id.hex() ) + unique_id = bytes.fromhex(store.get(f"PP_TP_ZERO_UNIQUE_ID{topo.pp_tp_zero_idx}").decode()) + config['pp_tp_zero_comm'] = nccl.commInitRank(unique_id, world_size//(config['pipe_size'] * config['tp_size']), topo.pp_tp_zero_id) config ['zero_comm'] = config['comm'] @@ -267,4 +265,4 @@ def get_logger(rank, level, print_to_screen=False): fh.setLevel(level) fh.setFormatter(formatter) logger.addHandler(fh) - return logger \ No newline at end of file + return logger diff --git a/bmtrain/pipe/comm.py b/bmtrain/pipe/comm.py index dcfa6c50..2659d8f7 100644 --- a/bmtrain/pipe/comm.py +++ b/bmtrain/pipe/comm.py @@ -1,5 +1,6 @@ import torch -from bmtrain.distributed.ops import send_activations_list, recv_activations_list, send_activations, recv_activations, groupcall,all_reduce +from bmtrain.distributed.ops import groupcall,all_reduce +from bmtrain.distributed.p2p_ops import send_tensors, recv_tensors from bmtrain.global_var import config from collections.abc import Iterable from bmtrain.synchronize import synchronize @@ -13,8 +14,6 @@ def __init__(self, topo, model, data_iter, num_micros, num_warmup, forward_only, self.forward_only = forward_only self.interleaving_size = interleaving_size self.model = model - self.send_handle = {"next":[], "prev":[]} - self.recv_handle = {"next":[], "prev":[]} def is_first_stage(self): if self.interleaving_size == 1: @@ -50,14 +49,12 @@ def get_data(self): return list(micro_batch) def send_next(self, tensors): - handle = [] if not self.is_last_stage(): if not isinstance(tensors, Iterable): tensors = [tensors] elif not isinstance(tensors, list): tensors = list(tensors) - handle.append(send_activations_list(tensors, self.topo.pipe_rank + 1, self.comm, async_op=True)) - self.send_handle["next"] = handle + send_tensors(tensors, self.topo.pipe_rank + 1, self.comm) def send_prev(self, tensors): if not self.is_first_stage(): @@ -65,16 +62,14 @@ def send_prev(self, tensors): tensors = [tensors] elif not isinstance(tensors, list): tensors = list(tensors) - self.send_handle["prev"].append(send_activations_list(tensors, self.topo.pipe_rank - 1, self.comm, async_op=True)) + send_tensors(tensors, self.topo.pipe_rank - 1, self.comm) def wait(self): torch.cuda.current_stream().wait_stream(config["pp_comm_stream"]) def recv_prev(self, need_data=False): if not self.is_first_stage(): - res, handle = recv_activations_list(self.topo.pipe_rank - 1, self.comm) - self.recv_handle["prev"].append(handle) - synchronize(config["pp_zero_comm"]) + res = recv_tensors(self.topo.pipe_rank - 1, self.comm) for idx,tensor in enumerate(res): if idx == 0: tensor.requires_grad_() @@ -88,8 +83,7 @@ def recv_prev(self, need_data=False): def recv_next(self): if not self.is_last_stage(): - res, handle = recv_activations_list(self.topo.pipe_rank + 1, self.comm) - self.recv_handle["next"].append(handle) + res = recv_tensors(self.topo.pipe_rank + 1, self.comm) return res else: return [None] From c8e184f34f90885a723eef581d0b18e610504b04 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Mon, 18 Mar 
2024 15:30:35 +0800 Subject: [PATCH 30/43] refactor p2p ops --- bmtrain/block_layer.py | 12 ++++++------ bmtrain/distributed/__init__.py | 2 +- bmtrain/distributed/p2p_ops.py | 2 +- bmtrain/pipe_layer.py | 14 +++++++------- example/train.py | 1 - tests/test_send_recv.py | 4 ++-- 6 files changed, 17 insertions(+), 18 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index d24f990c..a98fbd39 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -9,7 +9,7 @@ from . import hook_func import inspect from torch.utils.checkpoint import checkpoint -from .distributed.ops import send_activations_inplace, recv_activations_inplace +from .distributed.ops import send_tensor_inplace, recv_tensor_inplace def storage_type_cuda(storage_type): STORAGE_MAP = { @@ -739,14 +739,14 @@ def reduce_tied_module(self): if config['topology'].pipe_rank == 0 and param.grad is not None: with torch.no_grad(): grad = torch.empty_like(param) - param.grad += recv_activations_inplace(grad, 1, config["pipe_tied_comm"]) - send_activations_inplace(param.grad, 1, config["pipe_tied_comm"]) + param.grad += recv_tensor_inplace(grad, 1, config["pipe_tied_comm"]) + send_tensor_inplace(param.grad, 1, config["pipe_tied_comm"]) elif config['topology'].pipe_rank == 0 and param.grad is None: grad = torch.empty_like(param) - param.grad = recv_activations_inplace(grad, 1, config["pipe_tied_comm"]) + param.grad = recv_tensor_inplace(grad, 1, config["pipe_tied_comm"]) elif config['topology'].is_last_rank() and param.grad is not None: - send_activations_inplace(param.grad, 0, config["pipe_tied_comm"]) - param.grad = recv_activations_inplace(param.grad, 0, config["pipe_tied_comm"]) + send_tensor_inplace(param.grad, 0, config["pipe_tied_comm"]) + param.grad = recv_tensor_inplace(param.grad, 0, config["pipe_tied_comm"]) def _add_tail(self, module): self.last_module[0]._is_last_layer = False diff --git a/bmtrain/distributed/__init__.py b/bmtrain/distributed/__init__.py index 51dae1ce..8049b351 100644 --- a/bmtrain/distributed/__init__.py +++ b/bmtrain/distributed/__init__.py @@ -1 +1 @@ -from .ops import all_gather, all_reduce, broadcast, recv_activations, send_activations, groupcall, send_object, recv_object, reduce_scatter +from .ops import all_gather, all_reduce, broadcast, recv_tensor, send_tensor, groupcall, send_object, recv_object, reduce_scatter diff --git a/bmtrain/distributed/p2p_ops.py b/bmtrain/distributed/p2p_ops.py index 428940fd..0e46c0ac 100644 --- a/bmtrain/distributed/p2p_ops.py +++ b/bmtrain/distributed/p2p_ops.py @@ -45,7 +45,7 @@ def record_stream_helper(tensor_list, stream): for t in tensor_list: t.record_stream(stream) -def send_tensor(tensor_list, peer_rank, comm): +def send_tensors(tensor_list, peer_rank, comm): handler = _send_tensors(tensor_list, peer_rank, comm) handler.wait() diff --git a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index 00697134..5954f9b4 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -5,7 +5,7 @@ from typing import Dict, Iterable, Iterator, Tuple, Union, List import torch -from .distributed import all_gather, broadcast, all_reduce, send_activations, recv_activations +from .distributed import all_gather, broadcast, all_reduce, send_tensor, recv_tensor from .global_var import config from . 
import nccl from .zero_context import ( @@ -130,7 +130,7 @@ def forward(ctx, input, pipe_rank): ctx.is_first_stage = pipe_rank == 0 ctx.is_last_stage = pipe_rank == config['pipe_size'] - 1 if not ctx.is_first_stage: - input = recv_activations(pipe_rank - 1, config['pipe_comm']) + input = recv_tensor(pipe_rank - 1, config['pipe_comm']) input.requires_grad_() return input return input @@ -143,7 +143,7 @@ def backward(ctx, grad_outputs): with torch.cuda.stream(config['pp_comm_stream']): config['pp_comm_stream'].wait_stream(current_stream) send_data.record_stream(config['pp_comm_stream']) - send_activations(send_data, ctx.pipe_rank - 1, config['pipe_comm']) + send_tensor(send_data, ctx.pipe_rank - 1, config['pipe_comm']) return grad_outputs, None class StagePostFunction(torch.autograd.Function): @@ -158,13 +158,13 @@ def forward(ctx, outputs, pipe_rank): with torch.cuda.stream(config['pp_comm_stream']): config['pp_comm_stream'].wait_stream(current_stream) send_data.record_stream(config['pp_comm_stream']) - send_activations(send_data.detach(), pipe_rank + 1, config['pipe_comm']) + send_tensor(send_data.detach(), pipe_rank + 1, config['pipe_comm']) return outputs @staticmethod def backward(ctx, grad_outputs): if not ctx.is_last_stage: - pre_grad_inputs = recv_activations(ctx.pipe_rank + 1, config['pipe_comm']) + pre_grad_inputs = recv_tensor(ctx.pipe_rank + 1, config['pipe_comm']) return pre_grad_inputs, None return grad_outputs, None @@ -307,8 +307,8 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): else: assert list(dst.keys()) == [name+n for n, parameter in module._module.named_parameters()] for key, tensor in dst.items(): - send_activations(tensor.cuda(), 0, config['pipe_comm']) + send_tensor(tensor.cuda(), 0, config['pipe_comm']) if config['rank'] == 0 and idx not in self.layer_ids: for n, parameter in module._module.named_parameters(): - destination[name+n] = recv_activations(self.get_stage_by_layer_id(idx), config['pipe_comm']).cpu() + destination[name+n] = recv_tensor(self.get_stage_by_layer_id(idx), config['pipe_comm']).cpu() diff --git a/example/train.py b/example/train.py index f52aa7f0..65462593 100644 --- a/example/train.py +++ b/example/train.py @@ -24,7 +24,6 @@ def main(): dtype=torch.half ) inspect_iter = -1 - bmt.load(model, "./ckpt-0.pt") bmt.print_rank("Model memory") bmt.print_rank(torch.cuda.memory_summary()) bmt.synchronize() diff --git a/tests/test_send_recv.py b/tests/test_send_recv.py index 2ec406b4..009be56e 100644 --- a/tests/test_send_recv.py +++ b/tests/test_send_recv.py @@ -9,10 +9,10 @@ def test_send_recv(): a = torch.ones((2,1)) * (config["topology"].pp_zero_id+1) a = a.cuda() print(f"send {a}") - bmt.distributed.send_activations(a, 1, config["pipe_comm"]) + bmt.distributed.send_tensor(a, 1, config["pipe_comm"]) else: ref = torch.ones((2,1)) * (config["topology"].pp_zero_id+1) - a = bmt.distributed.recv_activations(0, config["pipe_comm"]) + a = bmt.distributed.recv_tensor(0, config["pipe_comm"]) print(f"recv {a}") assert_all_eq(a, ref.cuda()) From 851711c6b408e88674e79a10b34adf85c129161e Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Mon, 18 Mar 2024 18:09:58 +0800 Subject: [PATCH 31/43] formatting pipeline code --- bmtrain/__init__.py | 2 +- bmtrain/benchmark/all_reduce.py | 27 +++++++++++ bmtrain/distributed/ops.py | 2 +- bmtrain/distributed/p2p_ops.py | 26 +++++----- bmtrain/init.py | 27 +---------- bmtrain/lr_scheduler/warmup.py | 8 ++-- bmtrain/optim/optim_manager.py | 2 +- bmtrain/pipe/comm.py | 2 +- 
bmtrain/pipe/schedule.py | 26 +++++++++- example/convert.py | 85 --------------------------------- example/init_test.py | 2 + example/inspect_tools.py | 54 --------------------- example/layers/embedding.py | 2 +- example/models/gpt.py | 5 +- example/models/pipe_gpt.py | 8 ++-- example/pipe_train.py | 12 +---- example/train.py | 76 +++++++++-------------------- 17 files changed, 109 insertions(+), 257 deletions(-) create mode 100644 bmtrain/benchmark/all_reduce.py delete mode 100644 example/convert.py create mode 100644 example/init_test.py delete mode 100644 example/inspect_tools.py diff --git a/bmtrain/__init__.py b/bmtrain/__init__.py index 459bfcc6..a4a895bd 100644 --- a/bmtrain/__init__.py +++ b/bmtrain/__init__.py @@ -10,7 +10,7 @@ from .layer import DistributedModule from .param_init import init_parameters, grouped_parameters from .synchronize import synchronize, sum_loss, wait_loader, gather_result -from .block_layer import Block, TransformerBlockList,PipeDreamBlockList +from .block_layer import Block, TransformerBlockList, PipeDreamBlockList from .wrapper import BMTrainModelWrapper from .pipe_layer import PipelineTransformerBlockList from . import debug diff --git a/bmtrain/benchmark/all_reduce.py b/bmtrain/benchmark/all_reduce.py new file mode 100644 index 00000000..5a32db24 --- /dev/null +++ b/bmtrain/benchmark/all_reduce.py @@ -0,0 +1,27 @@ +from .. import nccl +from .shape import SHAPES +from ..global_var import config +from ..utils import round_up, print_rank +from .utils import format_size +import torch + +def all_reduce(): + current_stream = torch.cuda.current_stream() + for shape in SHAPES: + global_size = round_up(shape, config['world_size'] * 2) + + partition_tensor = torch.empty( global_size // 2, dtype=torch.half, device="cuda" ) + global_tensor = torch.empty( global_size // 2, dtype=torch.half, device="cuda" ) + + start_evt = torch.cuda.Event(enable_timing=True) + end_evt = torch.cuda.Event(enable_timing=True) + + current_stream.record_event(start_evt) + nccl.allReduce(partition_tensor.storage(), global_tensor.storage(),"sum", config['comm']) + current_stream.record_event(end_evt) + current_stream.synchronize() + time_usage = start_evt.elapsed_time(end_evt) + + bw = global_size / 1024 / 1024 / 1024 * 1000 / time_usage * 2 + print_rank("All reduce:\tsize {}\ttime: {:4.3f}\tbw: {:2.6f} GB/s".format(format_size(global_size), time_usage, bw)) + diff --git a/bmtrain/distributed/ops.py b/bmtrain/distributed/ops.py index 938ccc06..3690cc0e 100644 --- a/bmtrain/distributed/ops.py +++ b/bmtrain/distributed/ops.py @@ -5,7 +5,7 @@ from ..nccl import allReduce as ncclAllReduce from ..nccl import broadcast as ncclBroadcast from ..nccl import reduceScatter as ncclReduceScatter -from ..nccl import commCount,commRank,NCCLCommunicator,groupStart,groupEnd +from ..nccl import commCount, commRank, NCCLCommunicator, groupStart, groupEnd from .p2p_ops import * diff --git a/bmtrain/distributed/p2p_ops.py b/bmtrain/distributed/p2p_ops.py index 0e46c0ac..35f9c6ff 100644 --- a/bmtrain/distributed/p2p_ops.py +++ b/bmtrain/distributed/p2p_ops.py @@ -58,11 +58,12 @@ def _send_tensors(tensor_list, peer_rank, comm): _p2p_stream[p2p_key] = torch.cuda.Stream() if p2p_key not in _p2p_events: _p2p_events[p2p_key] = torch.cuda.Event() - stream = _p2p_stream[peer_rank] - end_event = _p2p_events[p2p_key] + stream = _p2p_stream[p2p_key] + event = _p2p_events[p2p_key] + event.record(torch.cuda.current_stream()) + stream.wait_event(event) with torch.cuda.stream(stream): length = torch.tensor(data=[len([h for 
h in tensor_list ])], device="cuda", dtype=torch.int) - ncclSend(length.storage(), peer_rank, comm) flags = torch.tensor(data=[0 for _ in range(len(tensor_list))], device="cuda",dtype=torch.int) for i in range(len(tensor_list)): if tensor_list[i] is None: @@ -72,6 +73,7 @@ def _send_tensors(tensor_list, peer_rank, comm): else: flag = 1 flags[i] = flag + ncclSend(length.storage(), peer_rank, comm) ncclSend(flags.contiguous().storage(), peer_rank, comm) for i in range(len(tensor_list)): if flags[i] == 0: @@ -79,11 +81,11 @@ def _send_tensors(tensor_list, peer_rank, comm): send_tensor(tensor_list[i], peer_rank, comm) elif flags[i] == 1: send_object(tensor_list[i], peer_rank, comm) - end_event.record(stream) - return handler(end_event) + event.record(stream) + return handler(event) def recv_tensors(peer_rank, comm): - tensors,handle = _recv_tensors(peer_rank, comm) + tensors, handle = _recv_tensors(peer_rank, comm) handle.wait() return tensors @@ -94,11 +96,11 @@ def irecv_tensors(peer_rank, comm): def _recv_tensors(peer_rank, comm): p2p_key = f"recv {peer_rank}" if p2p_key not in _p2p_stream: - _p2p_stream[peer_rank] = torch.cuda.Stream() + _p2p_stream[p2p_key] = torch.cuda.Stream() if p2p_key not in _p2p_events: _p2p_events[p2p_key] = torch.cuda.Event() - stream = _p2p_stream[peer_rank] - end_event = _p2p_events[p2p_key] + stream = _p2p_stream[p2p_key] + event = _p2p_events[p2p_key] with torch.cuda.stream(stream): length = torch.tensor(data=[0], device="cuda", dtype=torch.int) tensor_list = [] @@ -115,9 +117,9 @@ def _recv_tensors(peer_rank, comm): elif flag == 1: recv = recv_object(peer_rank, comm) tensor_list.append(recv) - end_event.record(stream) - record_stream_helper([tensor_list[i] for i in range(length[0]).item() if flags[i].item() == 0], torch.cuda.current_stream()) - return tensor_list, handler(end_event) + event.record(stream) + record_stream_helper([tensor_list[i] for i in range(length[0].item()) if flags[i].item() != -1], torch.cuda.current_stream()) + return tensor_list, handler(event) def send_tensor(hidden_state, peer_rank, comm): hidden_state = hidden_state.contiguous() diff --git a/bmtrain/init.py b/bmtrain/init.py index a65e6c27..bc990b69 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -3,7 +3,6 @@ import random import torch.distributed as dist import os -import logging from .utils import print_dict import ctypes from .global_var import config @@ -16,7 +15,6 @@ def init_distributed( pipe_size: int = 1, num_micro_batches: int = None, tp_size : int = 1, - debug=False, ): """Initialize distributed training. This function will initialize the distributed training, set the random seed and global configurations. 
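The reworked _send_tensors/_recv_tensors above exchange a whole list of mixed items in three phases: an item count, a per-item flag vector, and then the items themselves. A minimal sketch of that wire protocol follows; the values are illustrative, and only the flag convention (-1 for None, 0 for a tensor, 1 for a picklable object) is taken from the diff.

import torch

# Illustration only: how the sender classifies a mixed payload before the NCCL calls.
hidden_state = torch.randn(2, 512)                      # stand-in for an activation tensor
tensor_list = [hidden_state, None, {"micro_idx": 3}]
length = torch.tensor([len(tensor_list)], dtype=torch.int)       # message 1: item count
flags = torch.tensor(
    [0 if torch.is_tensor(x) else (-1 if x is None else 1) for x in tensor_list],
    dtype=torch.int,
)                                                                # message 2: per-item type tags
# messages 3..n: send_tensor(...) for flag 0, send_object(...) for flag 1, nothing for -1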
@@ -80,14 +78,10 @@ def init_distributed( config["load_event"] = torch.cuda.Event() config["tp_size"] = tp_size if tp_size > 0 else 1 config["topology"] = topology(config) + config["pipe_rank"] = config['topology'].get_group_rank("pipe") config["zero_rank"] = config['topology'].get_group_rank("zero") config["tp_rank"] = config['topology'].get_group_rank("tp") config["tp_zero_rank"] = config['topology'].get_group_rank("tp_zero") - config["pipe_rank"] = config['topology'].get_group_rank("pipe") - if debug: - config["logger"] = get_logger(rank, "DEBUG") - else: - config["logger"] = get_logger(rank, "WARNING") config["save_param_to_cpu"] = True cpus_this_worker = None @@ -208,10 +202,6 @@ def __init__(self,config): self.zero_idx = 0 self.zero_id = self.rank - def get_comm(self, group_name): - if group_name == "pipe": - return config["pipe_comm"] - def get_group_id(self,group_name): if group_name == "pipe": return self.pipe_idx @@ -251,18 +241,3 @@ def is_last_rank(self, group_name="pipe"): def is_initialized() -> bool: return config["initialized"] -def get_logger(rank, level, print_to_screen=False): - formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') - logger = logging.getLogger('pipeline') - logger.setLevel(level) - if print_to_screen: - if rank == 0: - ch = logging.StreamHandler() - ch.setLevel(level) - ch.setFormatter(formatter) - logger.addHandler(ch) - fh = logging.FileHandler(f'pipe_{rank}.log',mode="w") - fh.setLevel(level) - fh.setFormatter(formatter) - logger.addHandler(fh) - return logger diff --git a/bmtrain/lr_scheduler/warmup.py b/bmtrain/lr_scheduler/warmup.py index 2eb23466..c690a154 100644 --- a/bmtrain/lr_scheduler/warmup.py +++ b/bmtrain/lr_scheduler/warmup.py @@ -16,7 +16,7 @@ def __init__(self, optimizer : torch.optim.Optimizer, start_lr, warmup_iter, end self.warmup_iter = warmup_iter self.end_iter = end_iter self.optimizer = optimizer - self.num_iter = 0 + self.num_iter = num_iter self._current_lr = None self.step(self.num_iter) @@ -37,11 +37,9 @@ def get_lr(self): def current_lr(self): return self._current_lr - def step(self, num_step = None) -> None: - if num_step is None: + def step(self, num_iter = None) -> None: + if num_iter is None: num_iter = self.num_iter + 1 - else: - num_iter = self.num_iter + num_step self.num_iter = num_iter lr = self.get_lr() diff --git a/bmtrain/optim/optim_manager.py b/bmtrain/optim/optim_manager.py index 145a0f49..2d42b870 100644 --- a/bmtrain/optim/optim_manager.py +++ b/bmtrain/optim/optim_manager.py @@ -85,7 +85,7 @@ def add_optimizer( def scale_loss(self, loss : torch.Tensor) -> torch.Tensor: - return loss * (self.loss_scale / (config['world_size'] // (config['tp_size']*config['pipe_size']))) # loss scale + return loss * (self.loss_scale / (config['world_size'] // ( config['tp_size']*config['pipe_size'] ))) # loss scale def backward(self, loss : torch.Tensor): """ diff --git a/bmtrain/pipe/comm.py b/bmtrain/pipe/comm.py index 2659d8f7..25dfe6d8 100644 --- a/bmtrain/pipe/comm.py +++ b/bmtrain/pipe/comm.py @@ -7,7 +7,7 @@ class PipeCommander: def __init__(self, topo, model, data_iter, num_micros, num_warmup, forward_only, interleaving_size) -> None: self.topo = topo - self.comm = self.topo.get_comm("pipe") + self.comm = config['pipe_comm'] self.input_generator = self.generator(data_iter) self.num_micros = num_micros self.num_warmup = num_warmup diff --git a/bmtrain/pipe/schedule.py b/bmtrain/pipe/schedule.py index 0fe41326..84b156f1 100644 --- a/bmtrain/pipe/schedule.py +++ b/bmtrain/pipe/schedule.py 
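For reference, the scale_loss hunk in bmtrain/optim/optim_manager.py above divides the loss scale by the number of pure data-parallel replicas, i.e. the world size with tensor- and pipeline-parallel ranks factored out. A worked example with illustrative sizes:

# Illustrative numbers only; the formula is the one in OptimManager.scale_loss above.
world_size, tp_size, pipe_size = 8, 2, 2
dp_replicas = world_size // (tp_size * pipe_size)    # 8 // (2 * 2) = 2 replicas
loss_scale = 2 ** 20
loss = 1.0
scaled_loss = loss * (loss_scale / dp_replicas)      # 1.0 * 524288.0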
@@ -3,6 +3,7 @@ import bmtrain as bmt from .comm import PipeCommander import torch +import logging from typing import Iterable @@ -62,7 +63,23 @@ def forward_func(model, inp, micro_idx, is_last_micro=False): hidden_state = [hidden_state] return hidden_state -def pipeline_forward_backward(model, data_iterator, micro_batch_size, num_micros, optim_manager, clip_grad=1.0): +def get_logger(rank, level, print_to_screen=False): + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + logger = logging.getLogger('pipeline') + logger.setLevel(level) + if print_to_screen: + if rank == 0: + ch = logging.StreamHandler() + ch.setLevel(level) + ch.setFormatter(formatter) + logger.addHandler(ch) + fh = logging.FileHandler(f'pipe_{rank}.log',mode="w") + fh.setLevel(level) + fh.setFormatter(formatter) + logger.addHandler(fh) + return logger + +def pipeline_forward_backward(model, data_iterator, micro_batch_size, num_micros, optim_manager, clip_grad=1.0, debug_log=False): """Forward and backward the pipeline model. Args: @@ -76,6 +93,13 @@ def pipeline_forward_backward(model, data_iterator, micro_batch_size, num_micros # forwrad unpack loss = None + if 'logger' not in config: + if debug_log: + config['logger'] = get_logger(bmt.config['pipe_rank'], level="INFO", print_to_screen=True) + else: + config['logger'] = logging.getLogger("dummy") + config['logger'].addHandler(logging.NullHandler()) + optim_manager.zero_grad() micro_batch_size = micro_batch_size num_micro_batches = num_micros diff --git a/example/convert.py b/example/convert.py deleted file mode 100644 index 2a20c281..00000000 --- a/example/convert.py +++ /dev/null @@ -1,85 +0,0 @@ -import bmtrain as bmt -import torch -from models import GPT, GPTPipe -import re -from collections import OrderedDict - -def partition(pipe_rank, pipe_size, len_modules): - part_lens = [0]+[(len_modules // pipe_size + (i < (len_modules % pipe_size))) for i in range(pipe_rank+1)] - start = sum(part_lens[:pipe_rank+1]) - end = start + part_lens[pipe_rank+1] - return start,end - -def key_process(key, pipe_size , rank, start, end): - res = re.search("\.(\d+)\.", key) - if res is not None: - layer_idx = int(res.group(1)) - else: - layer_idx = None - if layer_idx is None or (layer_idx >= start and layer_idx < end): - if rank == 0: - if key in ["word_emb.weight","pos_emb.weight"]: - return key - else: - if layer_idx is not None: - return re.sub(r"\d+", str(layer_idx), key) - elif rank == pipe_size - 1: - if key in ["word_emb.weight"] or key.startswith("layernorm"): - return key - else: - if layer_idx is not None: - return re.sub(r"\d+", str(layer_idx - start), key) - else: - if layer_idx is not None: - return re.sub(r"\d+", str(layer_idx - start), key) - else: - return None - - - -def init_model(): - model = GPT( - num_layers=8, - vocab_size=10240, - dim_model=2560, - dim_head=80, - num_heads=32, - dim_ff=8192, - max_distance=1024, - bias=True, - dtype=torch.half - ) - return model - -def get_len_modules(state): - max_len = 0 - for key in state: - s = re.search("\.(\d+)\.", key) - if s is not None: - res = int(s.group(1)) - if res>max_len: - max_len = res - return max_len+1 - - -if __name__ == "__main__": - bmt.init_distributed() - model = init_model() - bmt.load(model, "ckpt-0.pt") - pipe_size = 4 - state = model.state_dict() - - for rank in range(pipe_size): - print(rank) - dic = OrderedDict() - len_modules = get_len_modules(state) - s,e = partition(rank, pipe_size, len_modules) - print(s," ",e) - for i in state.keys(): - k = key_process(i, 
pipe_size, rank, s, e) - if k is not None: - dic[k] = state[i] - print(dic.keys()) - torch.save(dic, f"pipe_{rank}.ckpt") - - \ No newline at end of file diff --git a/example/init_test.py b/example/init_test.py new file mode 100644 index 00000000..67d62581 --- /dev/null +++ b/example/init_test.py @@ -0,0 +1,2 @@ +import bmtrain +bmtrain.init_distributed(sp_size=4,tp_size=2) diff --git a/example/inspect_tools.py b/example/inspect_tools.py deleted file mode 100644 index c170c74c..00000000 --- a/example/inspect_tools.py +++ /dev/null @@ -1,54 +0,0 @@ -from contextlib import contextmanager -from bmtrain import CheckpointBlock -import sys -log_file = set() -@contextmanager -def custom_redirection(fileobj): - if isinstance(fileobj, str): - if fileobj not in log_file: - ftmp = open(fileobj,"w") - ftmp.close() - log_file.add(fileobj) - file_handle = open(fileobj,"a") - else: - file_handle = fileobj - old = sys.stdout - sys.stdout = file_handle - try: - yield file_handle - finally: - sys.stdout = old - file_handle.close() - -def look_var(layer, _, output): - try: - print(f"{layer.__name__}: {output.min()}") - except: - print(f"{layer.__name__}: {output[0].min()}") - - -def look_inp_weight(look_inp,look_weight): - def look_inp_func(layer, inp): - if look_inp: - try: - print(f"{layer.__name__}: {inp.min()}") - except: - print(f"{layer.__name__}: {inp[0].min()}") - if look_weight: - print(f"{layer.__name__} weight: {layer._parameters}") - return look_inp_func - -def lookup_output(model,layers=set(), look_input=False, look_weight=False): - for key,layer in model.named_modules(): - layer.__name__ = key - if layer not in layers: - layers.add(layer) - else: - continue - if len(layer._modules) !=0: - layer.register_forward_hook(look_var) - lookup_output(layer,layers,look_input=look_input,look_weight=look_weight) - layer.register_forward_pre_hook(look_inp_weight(look_input,look_weight)) - else: - layer.register_forward_hook(look_var) - layer.register_forward_pre_hook(look_inp_weight(look_input,look_weight)) \ No newline at end of file diff --git a/example/layers/embedding.py b/example/layers/embedding.py index 13c47384..9cbc6715 100644 --- a/example/layers/embedding.py +++ b/example/layers/embedding.py @@ -79,7 +79,7 @@ def forward(self, input: torch.Tensor, projection : bool = False) -> torch.Tenso if not projection: return F.embedding( input, self.weight, self.padding_idx, self.max_norm, - self.norm_type, self.scale_grad_by_freq, self.sparse) + self.norm_type, self.scale_grad_by_freq, self.sparse) else: return F.linear(input, self.weight) / math.sqrt(self.embedding_dim) diff --git a/example/models/gpt.py b/example/models/gpt.py index 791ac9d1..feb2bd59 100644 --- a/example/models/gpt.py +++ b/example/models/gpt.py @@ -20,13 +20,14 @@ def __init__(self, else: self.word_emb = Embedding(vocab_size, dim_model, dtype=dtype) self.pos_emb = Embedding(max_distance, dim_model, dtype=dtype) + if config['pipe_size'] > 1: self.transformers = bmt.PipelineTransformerBlockList([ bmt.Block( TransformerEncoder( dim_model, dim_head, num_heads, dim_ff, bias, dtype ) - , mode="PIPE",use_checkpoint=False + , mode="PIPE" ) for _ in range(num_layers) ]) @@ -35,7 +36,7 @@ def __init__(self, bmt.Block( TransformerEncoder( dim_model, dim_head, num_heads, dim_ff, bias, dtype - ),use_checkpoint=False + ) ) for _ in range(num_layers) ]) diff --git a/example/models/pipe_gpt.py b/example/models/pipe_gpt.py index 8bb1ffb8..bf11ac20 100644 --- a/example/models/pipe_gpt.py +++ b/example/models/pipe_gpt.py @@ -14,10 +14,10 @@ def 
__init__(self, self.max_distance = max_distance - # if config['tp_size'] > 1: - # word_emb = bmt.nn.ParallelEmbedding(vocab_size, dim_model, dtype=dtype) - # else: - word_emb = Embedding(vocab_size, dim_model, dtype=dtype) + if config['tp_size'] > 1: + word_emb = bmt.nn.VPEmbedding(vocab_size, dim_model, dtype=dtype) + else: + word_emb = Embedding(vocab_size, dim_model, dtype=dtype) pos_emb = Embedding(max_distance, dim_model, dtype=dtype) blocklist = [] blocklist += [ diff --git a/example/pipe_train.py b/example/pipe_train.py index d781ef65..1e60f50b 100644 --- a/example/pipe_train.py +++ b/example/pipe_train.py @@ -6,14 +6,12 @@ from bmtrain.global_var import config from bmtrain import inspect from bmtrain.pipe import pipeline_forward_backward -from inspect_tools import custom_redirection, lookup_output def main(): bmt.init_distributed( seed=0, pipe_size=4, tp_size=1, - debug=True ) model = GPTPipe( @@ -27,7 +25,7 @@ def main(): bias=True, dtype=torch.float16 ) - inspect_iter = -1 + bmt.init_parameters(model) bmt.print_rank("Model memory") bmt.print_rank(torch.cuda.memory_summary()) bmt.synchronize() @@ -71,17 +69,11 @@ def data_loader(): avg_time_recorder = bmt.utils.AverageRecorder() avg_loss_recorder = bmt.utils.AverageRecorder() - # lookup_output(model) for iteration in range(10): # load data st = time.time() rank = bmt.config["topology"].pipe_rank - if iteration == inspect_iter: - lookup_output(model) - with custom_redirection(f"outputs/pp_output_{pipe_rank}"): - global_loss, grad_norm = pipeline_forward_backward(model, data_loader(), micro , num_micros, optim_manager) - else: - global_loss, grad_norm = pipeline_forward_backward(model, data_loader(), micro , num_micros, optim_manager) + global_loss, grad_norm = pipeline_forward_backward(model, data_loader(), micro , num_micros, optim_manager) # record time and loss iteration_time = time.time() - st diff --git a/example/train.py b/example/train.py index 65462593..6e9694a6 100644 --- a/example/train.py +++ b/example/train.py @@ -5,7 +5,7 @@ from bmtrain import optim from bmtrain.global_var import config from bmtrain import inspect -from inspect_tools import lookup_output, custom_redirection + def main(): bmt.init_distributed( seed=0, @@ -23,7 +23,7 @@ def main(): bias=True, dtype=torch.half ) - inspect_iter = -1 + bmt.init_parameters(model) bmt.print_rank("Model memory") bmt.print_rank(torch.cuda.memory_summary()) bmt.synchronize() @@ -32,22 +32,9 @@ def main(): torch.manual_seed(1234) batch_size = 2 seq_len = 512 - global_batch = 2 * 16 - - # for i in range(bmt.world_size()): - # sent = torch.randint(0, 10240, (batch_size, seq_len + 1)) - # enc_length = torch.randint(128, seq_len, (batch_size,)).long().cuda() - # enc_input = sent[:, :-1].long().cuda() - # targets = sent[:, 1:].long().cuda() - # mask = torch.arange(seq_len).long().cuda()[None, :] < enc_length[:, None] - # targets = torch.where( - # mask, - # targets, - # torch.full_like(targets, -100, dtype=torch.long) - # ) + batch = 2 + grad_accum = 1 - # if i == bmt.rank(): - # break if config['tp_size'] > 1: loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=True) else: @@ -66,42 +53,26 @@ def main(): for iteration in range(10): # load data st = time.time() - if iteration == inspect_iter: - lookup_output(model) sum_loss = 0 - for micro in range(global_batch // batch_size): - # for i in range(bmt.world_size()): - sent = torch.randint(0, 10240, (batch_size, seq_len + 1)) - enc_length = torch.randint(128, seq_len, (batch_size,)).long().cuda() - enc_input = sent[:, 
:-1].long().cuda() - targets = sent[:, 1:].long().cuda() - mask = torch.arange(seq_len).long().cuda()[None, :] < enc_length[:, None] - targets = torch.where( - mask, - targets, - torch.full_like(targets, -100, dtype=torch.long) - ) + for micro in range(grad_accum): + for i in range(bmt.world_size()): + sent = torch.randint(0, 10240, (batch_size, seq_len + 1)) + enc_length = torch.randint(128, seq_len, (batch_size,)).long().cuda() + enc_input = sent[:, :-1].long().cuda() + targets = sent[:, 1:].long().cuda() + mask = torch.arange(seq_len).long().cuda()[None, :] < enc_length[:, None] + targets = torch.where( + mask, + targets, + torch.full_like(targets, -100, dtype=torch.long) + ) - # if i == bmt.rank(): - # break - - # with inspect.inspect_tensor() as inspector: pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1) - # if iteration == 4: - # lookup_output(model) - if iteration == inspect_iter: - with custom_redirection("dp_ref.output"): - logits = model( - enc_input, - pos, - pos < enc_length[:, None] - ) - else: - logits = model( - enc_input, - pos, - pos < enc_length[:, None] - ) + logits = model( + enc_input, + pos, + pos < enc_length[:, None] + ) batch, seq_len, vocab_out_size = logits.size() if config['tp_size'] > 1: @@ -125,13 +96,12 @@ def main(): iteration_time = time.time() - st avg_time_recorder.record(iteration_time) - num_micro = global_batch // batch_size - avg_loss_recorder.record(sum_loss/num_micro) + avg_loss_recorder.record(sum_loss / grad_accum) # print time and loss bmt.print_rank( "| Iter: {:6d} | loss: {:.10f} average_loss: {:.4f} | lr: {:.4e} scale: {:10.4f} | time: {:.4f}".format( iteration, - sum_loss / num_micro, + sum_loss / grad_accum, avg_loss_recorder.value, lr_scheduler.current_lr, optim_manager.loss_scale, From d6397e7ee9da49c829306c5c64212e66bd9e8521 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Mon, 6 May 2024 10:49:37 +0800 Subject: [PATCH 32/43] WIP: Pipeline example code refactor --- bmtrain/pipe/comm.py | 33 +++++++++------------- bmtrain/pipe/schedule.py | 61 ++++++++++++---------------------------- 2 files changed, 32 insertions(+), 62 deletions(-) diff --git a/bmtrain/pipe/comm.py b/bmtrain/pipe/comm.py index 25dfe6d8..30f4479a 100644 --- a/bmtrain/pipe/comm.py +++ b/bmtrain/pipe/comm.py @@ -8,7 +8,7 @@ class PipeCommander: def __init__(self, topo, model, data_iter, num_micros, num_warmup, forward_only, interleaving_size) -> None: self.topo = topo self.comm = config['pipe_comm'] - self.input_generator = self.generator(data_iter) + self.input_generator = data_iter self.num_micros = num_micros self.num_warmup = num_warmup self.forward_only = forward_only @@ -27,16 +27,6 @@ def is_last_stage(self): else: raise ValueError("Now only supoort interleaving_size == 1") - def generator(self, data_iterator): - while True: - try: - inp = next(data_iterator) - if self.is_first_stage(): - yield self.model.preprocess_func(inp) - else: - yield inp - except StopIteration: - break def param_reduce(self, module): for name, param in module.named_parameters(): @@ -46,7 +36,7 @@ def param_reduce(self, module): def get_data(self): micro_batch = next(self.input_generator) assert isinstance(micro_batch, Iterable) - return list(micro_batch) + return micro_batch def send_next(self, tensors): if not self.is_last_stage(): @@ -74,12 +64,15 @@ def recv_prev(self, need_data=False): if idx == 0: tensor.requires_grad_() data = self.get_data() - return res + data[1:] + # return hidden state and data + return res, data else: if need_data: 
- return self.get_data() + # for first stage , only data + return [None], self.get_data() else: - return [None] + # empty load for first stage + return [None], [None] def recv_next(self): if not self.is_last_stage(): @@ -105,15 +98,17 @@ def send_forward_recv_backward(self, forward_state): def send_backward_recv_forward(self, backward_grad, need_data=False): if not self.is_first_stage(): - forward_state = self.recv_prev() + forward_state, data = self.recv_prev() if backward_grad[0] is not None: self.send_prev(backward_grad) else: if need_data: - forward_state = self.get_data() + data = self.get_data() + forward_state = None else: forward_state = [None] - return forward_state + data = None + return forward_state, data - \ No newline at end of file + diff --git a/bmtrain/pipe/schedule.py b/bmtrain/pipe/schedule.py index 84b156f1..e60f4ec2 100644 --- a/bmtrain/pipe/schedule.py +++ b/bmtrain/pipe/schedule.py @@ -7,7 +7,7 @@ from typing import Iterable -def backward_step(inp, output, grad_output, optim_manager=None): +def backward_func(inp, backward_step, output, grad_output, optim_manager=None): """Backward step through passed-in output tensor. If last stage, output_tensor_grad is None, otherwise gradient of loss @@ -25,20 +25,7 @@ def backward_step(inp, output, grad_output, optim_manager=None): output = [output] if not isinstance(grad_output, Iterable): grad_output = [grad_output] - #TODO scale the grad - # if output_tensor_grad[0] is None and config.grad_scale_func is not None: - # output_tensor[0] = config.grad_scale_func(output_tensor[0]) - if optim_manager is not None and config["topology"].is_last_rank(): - if not torch.is_tensor(output[0]) and isinstance(output[0], Iterable): - output = optim_manager.scale_loss(output[0][0]) - elif torch.is_tensor(output[0]): - output = optim_manager.scale_loss(output[0]) - output = output / config['micros'] - else: - output = output[0] - torch.autograd.backward(output, grad_tensors=grad_output[0]) - current_stream = torch.cuda.current_stream() - current_stream.wait_stream(config['load_stream']) + backward_step(output, grad_output) input_grad = [None] if inp is not None: input_grad = [] @@ -50,18 +37,11 @@ def backward_step(inp, output, grad_output, optim_manager=None): return input_grad -def forward_func(model, inp, micro_idx, is_last_micro=False): - if config["topology"].pipe_rank == config["topology"].pipe_size - 1: - loss = model(*inp) - - return [loss] - else: - config['logger'].info("inp shape: {}".format(inp[0].shape)) - hidden_state = model(*inp) - config['logger'].info("inp shape: {}".format(hidden_state[0].shape)) - if torch.is_tensor(hidden_state) or (not isinstance(hidden_state, Iterable)): - hidden_state = [hidden_state] - return hidden_state +def forward_func(model, forward_step, inp, data, micro_idx, is_last_micro=False): + output = forward_step(model, inp, data) + if not isinstance(output, list) and not isinstance(output, tuple): + output = [output] + return output def get_logger(rank, level, print_to_screen=False): formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') @@ -79,12 +59,13 @@ def get_logger(rank, level, print_to_screen=False): logger.addHandler(fh) return logger -def pipeline_forward_backward(model, data_iterator, micro_batch_size, num_micros, optim_manager, clip_grad=1.0, debug_log=False): +def pipeline_forward_backward(model, data_iterator, forward_step, backward_step, micro_batch_size, num_micros, debug_log=False): """Forward and backward the pipeline model. 
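In this patch the commander starts returning the received activations and the raw micro-batch as two separate values rather than one concatenated list. A rough sketch of the calling convention consumed by the schedule.py changes that follow; commander, model, forward_step, forward_func and num_warmup are the names used in this patch series, so the snippet is not runnable on its own.

# Sketch of the warm-up phase only; see the schedule.py hunks below for the real loop.
for micro in range(num_warmup):
    inp, data = commander.recv_prev(need_data=True)   # first stage: inp == [None], data holds the micro-batch
    output = forward_func(model, forward_step, inp, data, micro)
    commander.send_next(output)                       # no-op on the last stage, which keeps its own output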
Args: models (TransformerBlocklist): The list of models. data_iterator (iterator): The iterator of the dataset. + forward_step(function): Describe how to forward the model and how to get loss micro_batch_size (int): The micro batch size. Returns: @@ -100,7 +81,6 @@ def pipeline_forward_backward(model, data_iterator, micro_batch_size, num_micros config['logger'] = logging.getLogger("dummy") config['logger'].addHandler(logging.NullHandler()) - optim_manager.zero_grad() micro_batch_size = micro_batch_size num_micro_batches = num_micros global_batch_size = micro_batch_size * num_micro_batches @@ -128,9 +108,9 @@ def pipeline_forward_backward(model, data_iterator, micro_batch_size, num_micros outputs = [] logger.info("num_warmup: {}".format(num_warmup)) for micro in range(num_warmup): - inp = commander.recv_prev(need_data=True) + inp, data = commander.recv_prev(need_data=True) logger.info("{} recv micro {}th from prev neighbour".format(bmt.config["topology"].pipe_rank, micro)) - output = forward_func(model, inp, micro) + output = forward_func(model, forward_step, inp, data, micro) logger.info("{} micro forward".format(micro)) # send activations commander.send_next(output) @@ -141,15 +121,13 @@ def pipeline_forward_backward(model, data_iterator, micro_batch_size, num_micros remain_batch = num_micro_batches - num_warmup logger.info("remain_batch: {}".format(remain_batch)) if remain_batch > 0: - inp = commander.recv_prev(need_data=True) + inp, data = commander.recv_prev(need_data=True) logger.info("recv micro from prev neighbour") - loss_items = [] for micro in range(num_micro_batches - num_warmup): is_last_micro = micro == num_micro_batches - num_warmup - 1 - output = forward_func(model, inp, micro + num_warmup, is_last_micro) + output = forward_func(model, forward_step, inp, data, micro + num_warmup, is_last_micro) if commander.is_last_stage(): loss = output[0] - loss_items.append(loss) logger.info("{} micro forward".format(micro+num_warmup)) grad_output = commander.send_forward_recv_backward(output) @@ -161,7 +139,7 @@ def pipeline_forward_backward(model, data_iterator, micro_batch_size, num_micros inp = inps.pop(0) output = outputs.pop(0) - inp_grad = backward_step(inp, output, grad_output, optim_manager) + inp_grad = backward_func(inp, backward_step, output, grad_output) logger.info("{} micro backward".format(micro+num_warmup)) if micro == remain_batch - 1: inp = None @@ -169,7 +147,7 @@ def pipeline_forward_backward(model, data_iterator, micro_batch_size, num_micros logger.info("{} send micro grad {}th to prev neighbour".format(bmt.config["topology"].pipe_rank, micro + num_warmup)) else: logger.info("send backward and recv forward") - inp = commander.send_backward_recv_forward(inp_grad, need_data=True) + inp, data = commander.send_backward_recv_forward(inp_grad, need_data=True) if not forward_only: logger.info("cooling stage") for i in range(num_warmup): @@ -178,18 +156,15 @@ def pipeline_forward_backward(model, data_iterator, micro_batch_size, num_micros output = outputs.pop(0) grad_output = commander.recv_next() logger.info("{} micro backward".format(num_micro_batches - num_warmup + i)) - input_grad = backward_step( - inp, output , grad_output, + input_grad = backward_func( + inp, backward_step, output , grad_output, ) logger.info("{} send micro grad {}th to prev neighbour".format(bmt.config["topology"].pipe_rank, i)) commander.send_prev(input_grad) blocklist = model.get_blocklist() # blocklist.reduce_tied_module() - grad_norm = 
optim_manager.clip_grad_norm(optim_manager.optimizers[0].param_groups, clip_grad, norm_type=2) - optim_manager.step() bmt.synchronize() - return loss_items, grad_norm - \ No newline at end of file + From 2de6becc7e749e2b90861593e4a62f31753f94a9 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Mon, 6 May 2024 10:58:02 +0800 Subject: [PATCH 33/43] WIP: Pipeline example code refactor --- example/pipe_train.py | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/example/pipe_train.py b/example/pipe_train.py index 1e60f50b..f20a90a7 100644 --- a/example/pipe_train.py +++ b/example/pipe_train.py @@ -68,12 +68,40 @@ def data_loader(): bmt.synchronize() avg_time_recorder = bmt.utils.AverageRecorder() avg_loss_recorder = bmt.utils.AverageRecorder() + global_loss = None + def forward_step(model, input, data): + enc_input, pos, mask, targets = data + input = model.preprocess_func((enc_input, pos)) if bmt.config["topology"].is_first_rank() else input + logits = model(input, pos, mask) + if bmt.config["topology"].is_last_rank(): + logits = logits.view(-1, logits.shape[-1]) + targets = targets.view(-1) + loss = loss_func(logits, targets) + global global_loss + global_loss = bmt.distributed.all_reduce(loss, comm=bmt.config["pp_tp_zero_comm"]).item() + return loss, logits + else: + return logits + + def backward_step(output, grad_output): + if not torch.is_tensor(output[0]) and isinstance(output[0], Iterable): + output = optim_manager.scale_loss(output[0][0]) + elif torch.is_tensor(output[0]): + output = optim_manager.scale_loss(output[0]) + output = output / bmt.config['micros'] + torch.autograd.backward(output, grad_tensors=grad_output[0]) + current_stream = torch.cuda.current_stream() + current_stream.wait_stream(bmt.config['load_stream']) + + + for iteration in range(10): # load data st = time.time() rank = bmt.config["topology"].pipe_rank - global_loss, grad_norm = pipeline_forward_backward(model, data_loader(), micro , num_micros, optim_manager) + # global_loss, grad_norm = pipeline_forward_backward(model, data_loader(), micro , num_micros, optim_manager) + pipeline_forward_backward(model, data_loader(), forward_step, backward_step, micro , num_micros) # record time and loss iteration_time = time.time() - st From 8b6f8db2abf020e4299dd9f715c967f2917af70a Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Mon, 6 May 2024 12:59:36 +0800 Subject: [PATCH 34/43] WIP: Pipeline example code refactor --- bmtrain/pipe/schedule.py | 2 +- example/models/pipe_gpt.py | 13 +++---------- example/pipe_train.py | 21 ++++++++++----------- example/run.sh | 4 +--- 4 files changed, 15 insertions(+), 25 deletions(-) diff --git a/bmtrain/pipe/schedule.py b/bmtrain/pipe/schedule.py index e60f4ec2..dd5fee45 100644 --- a/bmtrain/pipe/schedule.py +++ b/bmtrain/pipe/schedule.py @@ -25,7 +25,7 @@ def backward_func(inp, backward_step, output, grad_output, optim_manager=None): output = [output] if not isinstance(grad_output, Iterable): grad_output = [grad_output] - backward_step(output, grad_output) + backward_step(output[0], grad_output[0]) input_grad = [None] if inp is not None: input_grad = [] diff --git a/example/models/pipe_gpt.py b/example/models/pipe_gpt.py index bf11ac20..e57eae46 100644 --- a/example/models/pipe_gpt.py +++ b/example/models/pipe_gpt.py @@ -44,7 +44,6 @@ def forward(self, input : torch.LongTensor, # (batch, seq_len) pos : torch.LongTensor, # (batch, seq_len) mask : torch.BoolTensor, # (batch, seq_len) - target: torch.LongTensor, ) 
-> torch.Tensor: mask_2d = mask[:, None, :] & mask[:, :, None] # (batch, seq_len, seq_len) mask_2d = mask_2d & (pos[:, None, :] >= pos[:, :, None]) @@ -53,20 +52,14 @@ def forward(self, # for layer in self.transformers: out = self.transformers(input, mask_2d, None) out = self.layernorm(out) - if config['topology'].pipe_rank == config['topology'].pipe_size - 1: - logits = self.word_emb(out, True) - logits = logits.float().view(-1, logits.shape[-1]) - target = target.view(-1) - return self.loss_func(logits, target) - else: - return out + return out def preprocess_func(self, inp): if config['topology'].pipe_rank == 0: inp_id = inp[0] pos = inp[1] - return self.pos_emb(pos) + self.word_emb(inp_id) , *inp[1:] + return self.pos_emb(pos) + self.word_emb(inp_id) else: return None - \ No newline at end of file + diff --git a/example/pipe_train.py b/example/pipe_train.py index f20a90a7..73d4efbe 100644 --- a/example/pipe_train.py +++ b/example/pipe_train.py @@ -56,7 +56,7 @@ def data_loader(): if config['tp_size'] > 1: loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=True) else: - loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100) + loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100) optimizer = optim.AdamOffloadOptimizer(model.parameters(), weight_decay=1e-2) lr_scheduler = bmt.lr_scheduler.Noam(optimizer, start_lr=1e-3, warmup_iter=40, end_iter=1000, num_iter=0) @@ -64,33 +64,30 @@ def data_loader(): optim_manager = optim.OptimManager(loss_scale=2**20) optim_manager.add_optimizer(optimizer, lr_scheduler) pipe_rank = bmt.config["topology"].pipe_rank - model.load_state_dict(torch.load(f"pipe_{pipe_rank}.ckpt")) bmt.synchronize() avg_time_recorder = bmt.utils.AverageRecorder() avg_loss_recorder = bmt.utils.AverageRecorder() - global_loss = None + global_loss_items = [] def forward_step(model, input, data): enc_input, pos, mask, targets = data - input = model.preprocess_func((enc_input, pos)) if bmt.config["topology"].is_first_rank() else input + input = model.preprocess_func((enc_input, pos)) if bmt.config["topology"].is_first_rank() else input[0] logits = model(input, pos, mask) if bmt.config["topology"].is_last_rank(): logits = logits.view(-1, logits.shape[-1]) targets = targets.view(-1) loss = loss_func(logits, targets) - global global_loss + nonlocal global_loss_items global_loss = bmt.distributed.all_reduce(loss, comm=bmt.config["pp_tp_zero_comm"]).item() + global_loss_items.append(global_loss) return loss, logits else: return logits def backward_step(output, grad_output): - if not torch.is_tensor(output[0]) and isinstance(output[0], Iterable): - output = optim_manager.scale_loss(output[0][0]) - elif torch.is_tensor(output[0]): - output = optim_manager.scale_loss(output[0]) + output = optim_manager.scale_loss(output) output = output / bmt.config['micros'] - torch.autograd.backward(output, grad_tensors=grad_output[0]) + torch.autograd.backward(output, grad_tensors=grad_output) current_stream = torch.cuda.current_stream() current_stream.wait_stream(bmt.config['load_stream']) @@ -102,11 +99,13 @@ def backward_step(output, grad_output): rank = bmt.config["topology"].pipe_rank # global_loss, grad_norm = pipeline_forward_backward(model, data_loader(), micro , num_micros, optim_manager) pipeline_forward_backward(model, data_loader(), forward_step, backward_step, micro , num_micros) + optim_manager.step() + optim_manager.zero_grad() # record time and loss iteration_time = time.time() - st if bmt.config["topology"].is_last_rank(): - global_loss = 
sum(list(global_loss))/len(global_loss) + global_loss = sum(list(global_loss_items))/len(global_loss_items) avg_time_recorder.record(iteration_time) avg_loss_recorder.record(global_loss) print( diff --git a/example/run.sh b/example/run.sh index 542e5252..8a66db20 100644 --- a/example/run.sh +++ b/example/run.sh @@ -1,3 +1 @@ -export NCCL_P2P_DISABLE=1 -export CUDA_LAUNCH_BLOCKING=1 -torchrun --nnodes=1 --nproc_per_node=4 --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=localhost train.py +torchrun --nnodes=1 --nproc_per_node=1 --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=localhost train.py From 07cc443b6d0a95e305212d8e9e67dcc61f28e1af Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Mon, 6 May 2024 13:34:42 +0800 Subject: [PATCH 35/43] WIP: Pipeline example code refactor --- bmtrain/pipe/comm.py | 2 +- bmtrain/pipe/schedule.py | 2 +- example/models/pipe_gpt.py | 11 +++++------ example/pipe_train.py | 4 ++-- 4 files changed, 9 insertions(+), 10 deletions(-) diff --git a/bmtrain/pipe/comm.py b/bmtrain/pipe/comm.py index 30f4479a..32099a0e 100644 --- a/bmtrain/pipe/comm.py +++ b/bmtrain/pipe/comm.py @@ -103,8 +103,8 @@ def send_backward_recv_forward(self, backward_grad, need_data=False): self.send_prev(backward_grad) else: if need_data: + forward_state = [None] data = self.get_data() - forward_state = None else: forward_state = [None] data = None diff --git a/bmtrain/pipe/schedule.py b/bmtrain/pipe/schedule.py index dd5fee45..abd898d2 100644 --- a/bmtrain/pipe/schedule.py +++ b/bmtrain/pipe/schedule.py @@ -38,7 +38,7 @@ def backward_func(inp, backward_step, output, grad_output, optim_manager=None): return input_grad def forward_func(model, forward_step, inp, data, micro_idx, is_last_micro=False): - output = forward_step(model, inp, data) + output = forward_step(model, inp[0], data) if not isinstance(output, list) and not isinstance(output, tuple): output = [output] return output diff --git a/example/models/pipe_gpt.py b/example/models/pipe_gpt.py index e57eae46..acf7aa20 100644 --- a/example/models/pipe_gpt.py +++ b/example/models/pipe_gpt.py @@ -32,10 +32,6 @@ def __init__(self, self.pos_emb = self.transformers.add_head(pos_emb) self.layernorm = self.transformers.add_tail(layernorm) self.word_emb = self.transformers.add_head_tail(word_emb) - if config['tp_size'] > 1: - self.loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=False) - else: - self.loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100) def get_blocklist(self): return self.transformers @@ -51,14 +47,17 @@ def forward(self, # for layer in self.transformers: out = self.transformers(input, mask_2d, None) - out = self.layernorm(out) + if bmt.config['topology'].is_last_rank(): + out = self.layernorm(out) + out = self.word_emb(out, True) return out def preprocess_func(self, inp): if config['topology'].pipe_rank == 0: inp_id = inp[0] pos = inp[1] - return self.pos_emb(pos) + self.word_emb(inp_id) + out = self.pos_emb(pos) + self.word_emb(inp_id) + return out else: return None diff --git a/example/pipe_train.py b/example/pipe_train.py index 73d4efbe..331addb0 100644 --- a/example/pipe_train.py +++ b/example/pipe_train.py @@ -71,7 +71,7 @@ def data_loader(): def forward_step(model, input, data): enc_input, pos, mask, targets = data - input = model.preprocess_func((enc_input, pos)) if bmt.config["topology"].is_first_rank() else input[0] + input = model.preprocess_func((enc_input, pos)) if bmt.config["topology"].is_first_rank() else input logits = model(input, pos, mask) if 
bmt.config["topology"].is_last_rank(): logits = logits.view(-1, logits.shape[-1]) @@ -86,7 +86,7 @@ def forward_step(model, input, data): def backward_step(output, grad_output): output = optim_manager.scale_loss(output) - output = output / bmt.config['micros'] + output = output torch.autograd.backward(output, grad_tensors=grad_output) current_stream = torch.cuda.current_stream() current_stream.wait_stream(bmt.config['load_stream']) From 290c1e364e89edf55f7bdfd9f1d7d73698649b3e Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Mon, 6 May 2024 14:10:08 +0800 Subject: [PATCH 36/43] Pipeline example code refactor --- example/pipe_train.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/example/pipe_train.py b/example/pipe_train.py index 331addb0..c3c757fa 100644 --- a/example/pipe_train.py +++ b/example/pipe_train.py @@ -85,8 +85,9 @@ def forward_step(model, input, data): return logits def backward_step(output, grad_output): - output = optim_manager.scale_loss(output) - output = output + if bmt.config['topology'].is_last_rank(): + output = optim_manager.scale_loss(output) + output = output / bmt.config['micros'] torch.autograd.backward(output, grad_tensors=grad_output) current_stream = torch.cuda.current_stream() current_stream.wait_stream(bmt.config['load_stream']) @@ -99,6 +100,7 @@ def backward_step(output, grad_output): rank = bmt.config["topology"].pipe_rank # global_loss, grad_norm = pipeline_forward_backward(model, data_loader(), micro , num_micros, optim_manager) pipeline_forward_backward(model, data_loader(), forward_step, backward_step, micro , num_micros) + grad_norm = optim_manager.clip_grad_norm(optim_manager.optimizers[0].param_groups, 1.0, norm_type=2) optim_manager.step() optim_manager.zero_grad() # record time and loss From a7420920cea16fe4a4cbc00a03d1c55e4f77b56e Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Mon, 6 May 2024 15:33:47 +0800 Subject: [PATCH 37/43] Pipeline example code refactor --- bmtrain/pipe/comm.py | 2 +- bmtrain/pipe/schedule.py | 4 ++-- example/pipe_train.py | 10 +++++----- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/bmtrain/pipe/comm.py b/bmtrain/pipe/comm.py index 32099a0e..d65f0f1d 100644 --- a/bmtrain/pipe/comm.py +++ b/bmtrain/pipe/comm.py @@ -107,7 +107,7 @@ def send_backward_recv_forward(self, backward_grad, need_data=False): data = self.get_data() else: forward_state = [None] - data = None + data = [None] return forward_state, data diff --git a/bmtrain/pipe/schedule.py b/bmtrain/pipe/schedule.py index abd898d2..29299fb9 100644 --- a/bmtrain/pipe/schedule.py +++ b/bmtrain/pipe/schedule.py @@ -26,6 +26,8 @@ def backward_func(inp, backward_step, output, grad_output, optim_manager=None): if not isinstance(grad_output, Iterable): grad_output = [grad_output] backward_step(output[0], grad_output[0]) + current_stream = torch.cuda.current_stream() + current_stream.wait_stream(bmt.config['load_stream']) input_grad = [None] if inp is not None: input_grad = [] @@ -164,7 +166,5 @@ def pipeline_forward_backward(model, data_iterator, forward_step, backward_step, commander.send_prev(input_grad) blocklist = model.get_blocklist() # blocklist.reduce_tied_module() - - bmt.synchronize() diff --git a/example/pipe_train.py b/example/pipe_train.py index c3c757fa..64204a00 100644 --- a/example/pipe_train.py +++ b/example/pipe_train.py @@ -6,6 +6,7 @@ from bmtrain.global_var import config from bmtrain import inspect from bmtrain.pipe import pipeline_forward_backward +from typing import Iterable 
def main(): bmt.init_distributed( @@ -51,7 +52,7 @@ def data_loader(): torch.full_like(targets, -100, dtype=torch.long) ) pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1) - yield enc_input, pos, pos 1: loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=True) @@ -67,7 +68,6 @@ def data_loader(): bmt.synchronize() avg_time_recorder = bmt.utils.AverageRecorder() avg_loss_recorder = bmt.utils.AverageRecorder() - global_loss_items = [] def forward_step(model, input, data): enc_input, pos, mask, targets = data @@ -89,20 +89,20 @@ def backward_step(output, grad_output): output = optim_manager.scale_loss(output) output = output / bmt.config['micros'] torch.autograd.backward(output, grad_tensors=grad_output) - current_stream = torch.cuda.current_stream() - current_stream.wait_stream(bmt.config['load_stream']) for iteration in range(10): # load data + global_loss_items = [] st = time.time() rank = bmt.config["topology"].pipe_rank # global_loss, grad_norm = pipeline_forward_backward(model, data_loader(), micro , num_micros, optim_manager) + optim_manager.zero_grad() pipeline_forward_backward(model, data_loader(), forward_step, backward_step, micro , num_micros) grad_norm = optim_manager.clip_grad_norm(optim_manager.optimizers[0].param_groups, 1.0, norm_type=2) optim_manager.step() - optim_manager.zero_grad() + bmt.synchronize() # record time and loss iteration_time = time.time() - st From fd7ac11624217cbeeaf441f9667dae6c8425175b Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Wed, 8 May 2024 18:12:12 +0800 Subject: [PATCH 38/43] support bmt.save/load save model partition instead of whole model --- bmtrain/__init__.py | 2 +- bmtrain/block_layer.py | 14 +++++++-- bmtrain/init.py | 2 ++ bmtrain/layer.py | 16 +++++++--- bmtrain/store.py | 67 ++++++++++++++++++++++++++++++++++++----- tests/test_load_ckpt.py | 32 ++++++++++++-------- 6 files changed, 104 insertions(+), 29 deletions(-) diff --git a/bmtrain/__init__.py b/bmtrain/__init__.py index f4ac3642..3e4fdf42 100644 --- a/bmtrain/__init__.py +++ b/bmtrain/__init__.py @@ -14,7 +14,7 @@ from .wrapper import BMTrainModelWrapper from .pipe_layer import PipelineTransformerBlockList from . import debug -from .store import save, load +from .store import save, load, clean from . import loss from . 
import distributed diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 98200465..5d7170e9 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -314,8 +314,12 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): def state_dict(self, destination=None, prefix='', keep_vars=False): # gather here with torch.no_grad(): - with ZeroContext(self): + if config['save_param_gather']: + with ZeroContext(self): + return self._module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) + else: return self._module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): @@ -330,8 +334,10 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, tp_mode = param._tp_mode if input_param.__class__.__name__ == "DistributedTensorWrapper": input_param = input_param.broadcast() - - verify_shape = torch.Size(it["shape"] if not tp_mode else param._tp_original_shape) + if config['load_param_gather']: + verify_shape = torch.Size(it["shape"] if not tp_mode else param._tp_original_shape) + else: + verify_shape = param.shape if input_param.shape != verify_shape: error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' 'the shape in current model is {}.' @@ -353,6 +359,8 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, # copy to buffer verify_size = verify_shape.numel() assert input_param.numel() == verify_size + if not config['load_param_gather']: + continue contiguous_param = input_param.to(it["parameter"].dtype).cuda().contiguous() diff --git a/bmtrain/init.py b/bmtrain/init.py index 69273c09..44b7c88a 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -83,6 +83,8 @@ def init_distributed( config["tp_rank"] = config['topology'].get_group_rank("tp") config["tp_zero_rank"] = config['topology'].get_group_rank("tp_zero") config["save_param_to_cpu"] = True + config["save_param_gather"] = True + config["load_param_gather"] = True cpus_this_worker = None all_available_cpus = sorted(list(os.sched_getaffinity(0))) diff --git a/bmtrain/layer.py b/bmtrain/layer.py index e071e01b..8dee7167 100644 --- a/bmtrain/layer.py +++ b/bmtrain/layer.py @@ -33,10 +33,13 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): for name, param in self._parameters.items(): if param is not None: if isinstance(param, DistributedParameter):#and not param._in_block: - if param._in_block: - destination[prefix + name] = param.tp_gather().detach() # sync operation + if config["save_param_gather"]: + if param._in_block: + destination[prefix + name] = param.tp_gather().detach() # sync operation + else: + destination[prefix + name] = param.gather_all().detach() # sync operation else: - destination[prefix + name] = param.gather_all().detach() # sync operation + destination[prefix + name] = param.clone().detach() # sync operation if config['save_param_to_cpu']: destination[prefix + name] = destination[prefix + name].cpu() else: @@ -110,14 +113,17 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 'the shape in current model is {}.' 
.format(key, input_param.shape, param.shape)) continue - verify_shape = torch.Size(param._original_shape if not tp_mode else param._tp_original_shape) + if config['load_param_gather']: + verify_shape = torch.Size(param._original_shape if not tp_mode else param._tp_original_shape) + else: + verify_shape = param.shape if not is_param_lazy and isinstance(param, DistributedParameter) and input_param.shape != verify_shape: error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' 'the shape in current model is {}.' .format(key, input_param.shape, verify_shape)) try: with torch.no_grad(): - if isinstance(param, DistributedParameter): + if isinstance(param, DistributedParameter) and config['load_param_gather']: tp_split_dim = param._tp_split_dim if tp_mode and tp_split_dim >= 0: input_param = tp_split_tensor(input_param, tp_split_dim) diff --git a/bmtrain/store.py b/bmtrain/store.py index 2a3ee02c..f8cf7c19 100644 --- a/bmtrain/store.py +++ b/bmtrain/store.py @@ -11,6 +11,7 @@ from typing import Mapping import threading import bmtrain as bmt +import os def _save_to_state_dict(model : torch.nn.Module, rank, destination, prefix): if isinstance(model, Block): @@ -24,6 +25,21 @@ def _save_to_state_dict(model : torch.nn.Module, rank, destination, prefix): destination._metadata = OrderedDict() model._save_to_state_dict(destination, prefix, False) +def _save_to_each_rank(model : torch.nn.Module, destination=None, prefix=''): + if destination is None: + destination = OrderedDict() + destination._metadata = OrderedDict() + destination._metadata[prefix[:-1]] = local_metadata = dict(version=model._version) + _save_to_state_dict(model, 0, destination, prefix) + for name, module in model._modules.items(): + if module is not None: + _save_to_each_rank(module, destination, prefix + name + '.') + for hook in model._state_dict_hooks.values(): + hook_result = hook(model, destination, prefix, local_metadata) + if hook_result is not None: + destination = hook_result + return destination + def _save_to_local_rank0(model : torch.nn.Module, destination=None, prefix=''): if destination is None: destination = OrderedDict() @@ -88,7 +104,7 @@ def async_save_to_file(state_dict, file_path): config['finish_save'] = True print("finish save state_dict to ", file_path) -def save(model : torch.nn.Module, file_name : str, non_blocking : bool=False): +def save(model : torch.nn.Module, file_name : str, non_blocking : bool=False, save_gather : bool=True): """Saves the model to the file. Similar to torch.save, but it used for distributed modules. 
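The new save_gather switch, together with the load_gather flag and clean helper added further down in this patch, lets every rank keep its own parameter partition instead of gathering a full checkpoint on rank 0. A usage sketch with an illustrative path; model stands for any bmtrain module.

import bmtrain as bmt

# Gathered checkpoint (the default, same behaviour as before): rank 0 writes one file.
bmt.save(model, "model.pt")
bmt.load(model, "model.pt")

# Per-rank partitions: every rank writes and reads its own "model.pt_rank_{rank}" file.
bmt.save(model, "model.pt", save_gather=False)
bmt.load(model, "model.pt", load_gather=False)

# Remove every file whose name starts with "model.pt" (runs on rank 0 only).
bmt.clean("model.pt")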
@@ -100,11 +116,18 @@ def save(model : torch.nn.Module, file_name : str, non_blocking : bool=False): Examples: - >>> bmtrain.save(model, "model.pt") + >>> bmtrain """ torch.cuda.synchronize() - state_dict = _save_to_rank0(model) - if config["rank"] == 0: + if save_gather: + save_method = _save_to_rank0 + else: + save_method = _save_to_each_rank + file_name = f"{file_name}_rank_{bmt.rank()}" + tmp = bmt.config['save_param_gather'] + bmt.config['save_param_gather'] = save_gather + state_dict = save_method(model) + if config["rank"] == 0 or not save_gather: if non_blocking is False: torch.save(state_dict, file_name) else: @@ -118,6 +141,9 @@ def save(model : torch.nn.Module, file_name : str, non_blocking : bool=False): config['save_thread'] = threading.Thread(target=async_save_to_file, args=(state_dict, file_name)) config['save_thread'].start() bmt.synchronize() + bmt.config['save_param_gather'] = tmp + + DTYPE_LIST = [ torch.float64, @@ -299,7 +325,7 @@ def __iter__(self): # pytorch 1.12.0 updated the load_state_dict method, which needs the state_dict to be a `Mapping`. return iter(self.keys()) -def load(model : torch.nn.Module, file_name : str, strict : bool = True): +def load(model : torch.nn.Module, file_name : str, strict : bool = True, load_gather : bool = True): """Loads the model from the file. Similar to torch.load, but it uses less memory when loading large models. @@ -312,14 +338,39 @@ def load(model : torch.nn.Module, file_name : str, strict : bool = True): Example: >>> bmtrain.load(model, "model.pt", strict=True) """ - if config['rank'] == 0: - state_dict = DistributedStateDictWrapper(torch.load(file_name)) + tmp = config['load_param_gather'] + config['load_param_gather'] = load_gather + if load_gather: + if config['rank'] == 0: + state_dict = DistributedStateDictWrapper(torch.load(file_name)) + else: + state_dict = DistributedStateDictWrapper({}) else: - state_dict = DistributedStateDictWrapper({}) + if "rank" not in file_name: + file_name = f"{file_name}_rank_{bmt.rank()}" + state_dict = torch.load(file_name) ret = model.load_state_dict( state_dict, strict = strict ) + config['load_param_gather'] = tmp torch.cuda.synchronize() return ret + +def clean(file_name : str): + """Cleans the file. + + Args: + file_name (str): The file name of the checkpoint. 
+ + Example: + >>> bmtrain.clean("model.pt") + """ + if bmt.rank() == 0: + parent = os.path.dirname(os.path.abspath(file_name)) + for f in os.listdir(parent): + if f.startswith(file_name): + os.remove(os.path.join(parent, f)) + + diff --git a/tests/test_load_ckpt.py b/tests/test_load_ckpt.py index 0eb4f95f..7de6b590 100644 --- a/tests/test_load_ckpt.py +++ b/tests/test_load_ckpt.py @@ -3,6 +3,7 @@ import torch.nn.functional as F import bmtrain as bmt import os +from collections import OrderedDict class Linear_Normal(torch.nn.Module): def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None: @@ -36,25 +37,30 @@ def __init__(self, in_features : int, out_features: int, bias: bool = True, dtyp def forward(self, input): return F.linear(input, self.weight, self.bias) +def test_save_load(m): + bmt.save(m, "test.pt", non_blocking=False) + bmt.load(m, "test.pt") + bmt.save(m, "test.pt", non_blocking=True) + bmt.load(m, "test.pt") + bmt.save(m, "test.pt", non_blocking=False, save_gather=True) + bmt.load(m, "test.pt", load_gather=True) + bmt.clean("test.pt") + def test_main(): - ckpt_path = "test_ckpt.pt" # Transformer BlockList m = Linear_Normal(256, 256).cuda() m2 = bmt.TransformerBlockList([bmt.Block(Linear_BMT(256, 256))]) - if bmt.rank() == 0: - torch.save(m.state_dict(), ckpt_path) - dic2 = m.state_dict() - dic2["0.weight"] = dic2.pop("weight") - dic2["0.bias"] = dic2.pop("bias") - m2.load_state_dict(dic2) + m2_state = m.state_dict().copy() + m2_state["0.weight"] = m2_state.pop("weight") + m2_state["0.bias"] = m2_state.pop("bias") + test_save_load(m2) + m2.load_state_dict(m2_state) for key in m.state_dict(): bmt_key = f"0.{key}" assert bmt_key in m2.state_dict(), "wrong key in bmtrain model" assert (m2.state_dict()[bmt_key].cuda() == m.state_dict()[key]).all() , "wrong param in bmtrain model" - if bmt.rank() == 0: - os.remove(ckpt_path) - print("Transformer Blocklist load_state_dict and state_dict test passed") + print("Transformer Blocklist load_state_dict ,state_dict, bmt.load/save test passed") # Block m3 = bmt.Block(Linear_BMT(256, 256)) @@ -62,7 +68,8 @@ def test_main(): for key in m.state_dict(): assert key in m3.state_dict(), "wrong key in bmtrain model" assert (m.state_dict()[key] == m3.state_dict()[key].cuda()).all(), "wrong param in bmtrain model" - print("Block load_state_dict and state_dict test passed") + test_save_load(m2) + print("Block load_state_dict ,state_dict, bmt.load/save test passed") # normal Distributed module m4 = Linear_BMT(256, 256) @@ -70,7 +77,8 @@ def test_main(): for key in m.state_dict(): assert key in m4.state_dict(), "wrong key in bmtrain model" assert (m.state_dict()[key] == m4.state_dict()[key].cuda()).all(), "wrong param in bmtrain model" - print("bmt.distributedmodule load_state_dict and state_dict test passed") + test_save_load(m2) + print("bmt.distributedmodule load_state_dict, state_dict, bmt.load/save test passed") if __name__ == "__main__": bmt.init_distributed() From af3458d23f91bfef6eb1589b3f7b23462f5109d8 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Wed, 8 May 2024 19:31:22 +0800 Subject: [PATCH 39/43] fix load model pipe --- bmtrain/pipe/store.py | 2 +- bmtrain/store.py | 1 - example/pipe_train.py | 4 ++++ 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/bmtrain/pipe/store.py b/bmtrain/pipe/store.py index 2f9aaacc..bbfc6ab9 100644 --- a/bmtrain/pipe/store.py +++ b/bmtrain/pipe/store.py @@ -53,7 +53,7 @@ def get_state_dict_pipe(path): del param return state_dict -def 
load_model_pipe(model, path, load_whole=True): +def load_model_pipe(model, path, load_whole=False): """ load_whole: Boolean, if True, load from the whole model file, else load model from the pipeline/tensor parallel model file """ diff --git a/bmtrain/store.py b/bmtrain/store.py index afe1cb66..488fb0b6 100644 --- a/bmtrain/store.py +++ b/bmtrain/store.py @@ -119,7 +119,6 @@ def save(model : torch.nn.Module, file_name : str, non_blocking : bool=False, sa """ torch.cuda.synchronize() - if config["rank"] == 0: if save_gather: save_method = _save_to_rank0 else: diff --git a/example/pipe_train.py b/example/pipe_train.py index 64204a00..28c7ad1c 100644 --- a/example/pipe_train.py +++ b/example/pipe_train.py @@ -6,6 +6,7 @@ from bmtrain.global_var import config from bmtrain import inspect from bmtrain.pipe import pipeline_forward_backward +from bmtrain.pipe import load_model_pipe, save_model_pipe from typing import Iterable def main(): @@ -30,6 +31,9 @@ def main(): bmt.print_rank("Model memory") bmt.print_rank(torch.cuda.memory_summary()) bmt.synchronize() + # test save/load + save_model_pipe(model, "pipe.pt") + load_model_pipe(model, "pipe.pt") # data # generate dummy data for each rank From 0305c595ade2e609d91dad499d1bc32846482afe Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Fri, 10 May 2024 21:17:07 +0800 Subject: [PATCH 40/43] fix Block load param logic --- bmtrain/block_layer.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 3d4fb400..880efd25 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -363,26 +363,28 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, # copy to buffer verify_size = verify_shape.numel() assert input_param.numel() == verify_size + contiguous_param = input_param.to(it["parameter"].dtype).cuda().contiguous() + d_dtype = self._storage_params[kw_name].dtype + d_device = self._storage_params[kw_name].device + offset_st = max(storage_st - param_st, 0) + to_offset_st = offset_st + param_st - storage_st if not config['load_param_gather']: + partition_numel= len(contiguous_param) + torch.tensor([], dtype=d_dtype, device=d_device).set_(self._storage_params[kw_name].storage(), to_offset_st, (partition_numel,))[:] = \ + contiguous_param[:] continue - contiguous_param = input_param.to(it["parameter"].dtype).cuda().contiguous() - tp_split_dim = param._tp_split_dim if tp_mode and tp_split_dim >= 0: contiguous_param = tp_split_tensor(contiguous_param, tp_split_dim) - offset_st = max(storage_st - param_st, 0) offset_end = min(storage_end - param_st, contiguous_param.numel()) + to_offset_end = offset_end + param_st - storage_st assert offset_st < offset_end - to_offset_st = offset_st + param_st - storage_st - to_offset_end = offset_end + param_st - storage_st # copy to buffer # PyTorch 1.11 changed the API of storage.__getitem__ - d_dtype = self._storage_params[kw_name].dtype - d_device = self._storage_params[kw_name].device torch.tensor([], dtype=d_dtype, device=d_device).set_(self._storage_params[kw_name].storage(), to_offset_st, (to_offset_end - to_offset_st,))[:] = \ torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_param.storage(), offset_st, (offset_end - offset_st,))[:] del contiguous_param From 0e22e968e012a7a98e41a53639dce007eb8f7364 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Sat, 11 May 2024 15:07:50 +0800 Subject: [PATCH 41/43] fix load partition --- bmtrain/block_layer.py | 6 
+++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py
index 880efd25..d15d72bd 100644
--- a/bmtrain/block_layer.py
+++ b/bmtrain/block_layer.py
@@ -369,9 +369,9 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
             offset_st = max(storage_st - param_st, 0)
             to_offset_st = offset_st + param_st - storage_st
             if not config['load_param_gather']:
-                partition_numel= len(contiguous_param)
+                partition_numel= contiguous_param.numel()
                 torch.tensor([], dtype=d_dtype, device=d_device).set_(self._storage_params[kw_name].storage(), to_offset_st, (partition_numel,))[:] = \
-                    contiguous_param[:]
+                    torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_param.storage(), 0, (partition_numel,))[:]
                 continue
 
             tp_split_dim = param._tp_split_dim
@@ -771,4 +771,4 @@ def add_tail(self, module, use_checkpoint=False):
             return DummyForward
         else:
             self._add_tail(module)
-            return module
\ No newline at end of file
+            return module

From 88601be3fcebe7a22381a59b80366317d7a1977f Mon Sep 17 00:00:00 2001
From: MayDomine <1583143678@qq.com>
Date: Mon, 13 May 2024 20:21:34 +0800
Subject: [PATCH 42/43] fix OOM caused by PipeDreamBlockList.init_param_storage

---
 bmtrain/block_layer.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py
index d15d72bd..afbc6ef3 100644
--- a/bmtrain/block_layer.py
+++ b/bmtrain/block_layer.py
@@ -664,8 +664,10 @@ def __init__(self, modules: Iterable[Block], num_hidden=1, use_checkpoint=False)
                 m.init_param_storage()
                 partition_modules.append(m)
             else:
-                m.init_param_storage()
-                del m
+                #m.init_param_storage()
+                for name, param in m.named_parameters():
+                    c = OpAllGather.apply(param)
+                    del param
         super().__init__(partition_modules, num_hidden, mode=mode)
         self.fisrt_module = (self._modules['0'],)
         self.last_module = (self._modules[str(len(self._modules) - 1)],)
@@ -712,9 +714,13 @@ def _add_head(self, module):
 
     def add_head(self, module, use_checkpoint=False):
         module = _block_wrapper(module, self.module_dict, mode="1F1B", zero_level=2, use_checkpoint=use_checkpoint)
-        module.init_param_storage()
         if config['topology'].pipe_rank != 0:
+            for name, param in module.named_parameters():
+                c = OpAllGather.apply(param)
+                del param
             return DummyForward
+        else:
+            module.init_param_storage()
         self._add_head(module)
         return module
 
@@ -766,9 +772,12 @@ def _add_tail(self, module):
 
     def add_tail(self, module, use_checkpoint=False):
         module = _block_wrapper(module, self.module_dict, mode="1F1B", zero_level=2, use_checkpoint=use_checkpoint)
-        module.init_param_storage()
         if config['topology'].pipe_rank != config['topology'].pipe_size - 1:
+            for name, param in module.named_parameters():
+                c = OpAllGather.apply(param)
+                del param
             return DummyForward
         else:
+            module.init_param_storage()
             self._add_tail(module)
             return module

From 053eee5a965a0fa3deb8146d2ad32ea0166f3a88 Mon Sep 17 00:00:00 2001
From: MayDomine <1583143678@qq.com>
Date: Tue, 9 Jul 2024 00:50:04 +0800
Subject: [PATCH 43/43] add check_overflow even when no loss scale is enabled

---
 bmtrain/optim/optim_manager.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/bmtrain/optim/optim_manager.py b/bmtrain/optim/optim_manager.py
index 1a98ed92..e45daa23 100644
--- a/bmtrain/optim/optim_manager.py
+++ b/bmtrain/optim/optim_manager.py
@@ -136,6 +136,12 @@ def step(self):
             self.zero_grad()
             return
         for optimizer, lr_scheduler in zip(self.optimizers, self.lr_schedulers):
+            try:
+                check_overflow(optimizer.param_groups)
+            except OverflowError:
+                has_overflow = True
+                print_rank("Gradient overflow, change scale from %lf to %lf" % (self.loss_scale, self.loss_scale / self.loss_scale_factor))
+                break
             if hasattr(optimizer, "_bmtrain_optimizer") and optimizer._bmtrain_optimizer:
                 optimizer.step(scale=self.loss_scale)
             else:
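
A minimal sketch of the overflow-check contract that PATCH 43/43 relies on: check_overflow(optimizer.param_groups) is expected to raise OverflowError whenever a gradient contains inf or NaN, so the loop in OptimManager.step can break out instead of applying a bad update, even when no loss scale is enabled. The helper name and the torch.isfinite-based implementation below are illustrative assumptions, not the actual bmtrain.optim.check_overflow, which may use a fused GPU kernel and synchronize the result across ranks.

    import torch

    def check_overflow_sketch(param_groups):
        # Raise OverflowError if any gradient holds inf/NaN, mirroring the
        # contract the try/except added in OptimManager.step depends on.
        for group in param_groups:
            for p in group["params"]:
                if p.grad is not None and not torch.isfinite(p.grad).all():
                    raise OverflowError("gradient overflow detected")

    # Usage, mirroring the added code path (hypothetical, outside the patch):
    # try:
    #     check_overflow_sketch(optimizer.param_groups)
    # except OverflowError:
    #     has_overflow = True   # skip optimizer.step() for this iteration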