
Conversation

@zhengchenyu
Contributor

Using replica groups offers the following advantages:

  • For stage 3, it ensures that parameter gather during forward and backward occurs only within the replica group.

  • Checkpointing is performed only on replica_group_rank=0, guaranteeing a constant checkpoint world size and avoiding universal checkpoint transformations when scaling up or down.

We can perform the gradient all-reduce within the replica group after backward and before optimizer.step(), but we must then wait for all buckets to complete, so we cannot leverage the concurrency advantage of overlapping communication with backward.
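For illustration, here is a minimal sketch of that non-overlapped variant in plain PyTorch (not DeepSpeed internals); the build_replica_groups helper, the contiguous rank layout, and the gradient averaging are assumptions made for the example, not code from this PR:

```python
import torch.distributed as dist


def build_replica_groups(world_size: int, replica_size: int):
    """Split ranks into contiguous replica groups, e.g. world_size=8,
    replica_size=4 -> ranks [0-3] and [4-7]. Returns this rank's group."""
    rank = dist.get_rank()
    my_group = None
    for start in range(0, world_size, replica_size):
        ranks = list(range(start, start + replica_size))
        group = dist.new_group(ranks)  # every rank must create every group
        if rank in ranks:
            my_group = group
    return my_group


def train_step_no_overlap(model, optimizer, loss, replica_group):
    # Non-overlapped variant: gradients are reduced only after the entire
    # backward pass has finished, so the communication for early buckets
    # cannot be hidden behind the remaining backward compute.
    loss.backward()
    group_size = dist.get_world_size(group=replica_group)
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, group=replica_group)
            param.grad.div_(group_size)
    optimizer.step()
    optimizer.zero_grad()
```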

I know MICS has similar functionality, but it currently only supports ZeRO stage 3. Additionally, I want to use this feature for compatibility with architectures like TorchFT.

…ica groups.

Signed-off-by: zhengchenyu <zhengchenyu16@163.com>
@sfc-gh-truwase
Collaborator

@zhengchenyu thanks for the PR. Can you provide some clarification of the motivation?

  • For stage 3, it ensures that parameter gather during forward and backward occurs only within the replica group.

We already provide a form of this functionality in the hpZ component of ZeRO++. Have you explored whether hpZ would meet your needs?
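For reference, a minimal sketch of how hpZ is enabled, assuming the standard zero_hpz_partition_size knob under zero_optimization; the batch size and partition size below are placeholders, not recommendations:

```python
# Illustrative DeepSpeed config fragment (placeholder values).
# zero_hpz_partition_size keeps a secondary copy of the partitioned
# parameters within each node, so forward/backward all-gathers stay
# intra-node after the primary partitioning.
ds_config = {
    "train_batch_size": 32,
    "zero_optimization": {
        "stage": 3,
        "zero_hpz_partition_size": 8,  # e.g. number of GPUs per node
    },
}
```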

I know MICS has similar functionality, but it currently only supports ZeRO stage 3.

My understanding is that replica groups are only relevant for ZeRO stage 3, since lower stages don't do parameter partitioning. Can you explain how replica groups arise in your workload?

@zhengchenyu
Contributor Author

@sfc-gh-truwase Thanks for your review!
My main motivation is to support torchft to achieve fault tolerance. At the same time, I aim to solve the following two problems:

  • (1) During forward and backward, the parameter gather occurs across all machines.
  • (2) The ZeRO checkpoint layout changes with the world size, which forces a universal checkpoint conversion when scaling up or down.

Regarding ZeRO++: it cannot solve problem (2). It can solve problem (1), but at a cost: we must introduce an extra ds_secondary_tensor. Moreover, in the first forward of each step, parameters still need to be gathered across all machines.

Regarding MICS: for ZeRO stage 3, these two problems do not exist. For stages 1/2, problem (1) does not arise, but if the optimizer state is taken into account when loading the checkpoint, problem (2) still occurs.
