diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index c12b5a28..67738e14 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -57,7 +57,7 @@ DEFINE_uint32(freq_generate_txt, 10, "frequency of text generation"); DEFINE_uint32(text_length, 64, "the length of the generated text"); // optimization DEFINE_double(learning_rate, 1e-4, "learning rate warmup iterations"); -DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)"); +DEFINE_int32(zero_stage, 0, "ZeRO stage (0/1/2/3); 0 disables DistributedOptimizer"); // evaluation DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?"); DEFINE_uint32(sample_every, 0, "how often to sample from the model?"); @@ -114,6 +114,7 @@ const std::unordered_map kModelToConfigs = { DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); }); DEFINE_validator(device, [](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; }); +DEFINE_validator(zero_stage, [](const char *, int32_t value) { return value >= 0 && value <= 3; }); void Train(const nn::parallel::Rank &rank) { using namespace nn::parallel; @@ -253,8 +254,7 @@ void Train(const nn::parallel::Rank &rank) { model = std::make_shared(model, pp_world_size, num_micro_batches, shapes, pp_rank, device, model_config.GetChunkSize()); if (ddp_world_size > 1) { - auto ddp_config - = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer}; + auto ddp_config = DistributedDataParallelConfig{.zero_stage = FLAGS_zero_stage}; auto *mutable_chunks = dynamic_cast(model.get())->mutable_chunks(); for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) { (*mutable_chunks)[chunk_id] @@ -266,7 +266,7 @@ void Train(const nn::parallel::Rank &rank) { // before wrapping the model with DistributedDataParallel (DDP). // Otherwise, DDP’s gradient hooks may be lost because new parameter tensors // are created during the conversion. - auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer}; + auto ddp_config = DistributedDataParallelConfig{.zero_stage = FLAGS_zero_stage}; model = std::make_shared(model, rank, ddp_config); } @@ -295,7 +295,7 @@ void Train(const nn::parallel::Rank &rank) { auto optimizer_creator = optimizers::SGD::Create(FLAGS_learning_rate); std::shared_ptr optimizer = nullptr; - if (FLAGS_use_distributed_optimizer) { + if (FLAGS_zero_stage >= 1) { auto model_chunks = (pp_world_size > 1) ? *(dynamic_cast(model.get())->mutable_chunks()) : std::vector>{model}; diff --git a/example/llama3/main.cc b/example/llama3/main.cc index 117551d5..fadf205e 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -56,7 +56,7 @@ DEFINE_uint32(freq_generate_txt, 10, "frequency of text generation"); DEFINE_uint32(text_length, 64, "the length of the generated text"); // optimization DEFINE_double(learning_rate, 1e-5, "learning rate warmup iterations"); -DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)"); +DEFINE_int32(zero_stage, 0, "ZeRO stage (0/1/2/3); 0 disables DistributedOptimizer"); // evaluation DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?"); DEFINE_uint32(sample_every, 0, "how often to sample from the model?"); @@ -100,6 +100,7 @@ constexpr char kDtypeBF16[] = "bfloat16"; DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); }); DEFINE_validator(device, [](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; }); +DEFINE_validator(zero_stage, [](const char *, int32_t value) { return value >= 0 && value <= 3; }); void Train(const nn::parallel::Rank &rank) { using namespace nn::parallel; @@ -223,8 +224,7 @@ void Train(const nn::parallel::Rank &rank) { model = std::make_shared(model, pp_world_size, num_micro_batches, shapes, pp_rank, device, model_config.GetChunkSize()); if (ddp_world_size > 1) { - auto ddp_config - = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer}; + auto ddp_config = DistributedDataParallelConfig{.zero_stage = FLAGS_zero_stage}; auto *mutable_chunks = dynamic_cast(model.get())->mutable_chunks(); for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) { (*mutable_chunks)[chunk_id] @@ -237,7 +237,7 @@ void Train(const nn::parallel::Rank &rank) { // Otherwise, DDP’s gradient hooks may be lost because new parameter tensors // are created during the conversion. - auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer}; + auto ddp_config = DistributedDataParallelConfig{.zero_stage = FLAGS_zero_stage}; model = std::make_shared(model, rank, ddp_config); } @@ -274,7 +274,7 @@ void Train(const nn::parallel::Rank &rank) { LOG(INFO) << "Optimizing " << params_to_optimize.size() << " model parameters"; } - if (FLAGS_use_distributed_optimizer) { + if (FLAGS_zero_stage >= 1) { auto model_chunks = (pp_world_size > 1) ? *(dynamic_cast(model.get())->mutable_chunks()) : std::vector>{model}; diff --git a/infini_train/include/autograd/function_hook.h b/infini_train/include/autograd/function_hook.h index 0cdd4170..734a0930 100644 --- a/infini_train/include/autograd/function_hook.h +++ b/infini_train/include/autograd/function_hook.h @@ -14,9 +14,22 @@ class ProcessGroup; namespace infini_train::autograd { +class PreAccumulateGradHook { +public: + virtual void operator()(const std::shared_ptr &grad_output) = 0; + + // Return true if this hook has handled the current gradient accumulation. + virtual bool TryBypassAccumulate(const std::shared_ptr &, const std::shared_ptr &, bool, float) { + return false; + } + + virtual ~PreAccumulateGradHook() = default; +}; + class PostAccumulateGradHook { public: virtual void operator()(const std::shared_ptr &tensor) = 0; + virtual ~PostAccumulateGradHook() = default; }; diff --git a/infini_train/include/nn/parallel/ddp/distributed_data_parallel_config.h b/infini_train/include/nn/parallel/ddp/distributed_data_parallel_config.h index 99d30703..729456ce 100644 --- a/infini_train/include/nn/parallel/ddp/distributed_data_parallel_config.h +++ b/infini_train/include/nn/parallel/ddp/distributed_data_parallel_config.h @@ -30,16 +30,17 @@ class DistributedDataParallelConfig { // Ref: // https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/distributed/distributed_data_parallel_config.py // ====================================================== - // Whether to enable DistributedOptimizer (ZeRO-1 equivalent). - // When set true: - // 1) Gradients/params are managed by ParamAndGradBuffer and reduced in groups. - // 2) The classic DDP reducer path is not used (i.e., disable reducer/bucketing in the DDP sense). - bool use_distributed_optimizer = false; - // Whether to overlap gradient reduce-scatter/all-reduce with backward compute. // In this case, grad reduce is triggered immediately when a grad is ready or till all grads are ready. bool overlap_grad_reduce = true; + // ZeRO-DP stage for memory optimization. + // ZeRO-0: Disabled; use the classic DDP reducer path. + // ZeRO-1: Optimizer states partitioning + // ZeRO-2: Gradients partitioning + // ZeRO-3: Parameters partitioning + int zero_stage = 0; + // Whether to overlap parameter all-gather with forward compute. bool overlap_param_gather = true; @@ -58,8 +59,7 @@ class DistributedDataParallelConfig { // Maximum number of parameters in each ParamAndGradBucket. // NOTE(zbl): This is distinct from DDP Reducer's MB-based bucket caps. - // TODO(zbl): To unify the definition of bucket_size argument for users - size_t bucket_size_in_elements = 40000000; + size_t bucket_size_in_elements = 1000000; // Whether to pad bucket sizes to improve NCCL bus bandwidth utilization. bool pad_buckets_for_high_nccl_busbw = false; diff --git a/infini_train/include/nn/parallel/ddp/param_and_grad_buffer.h b/infini_train/include/nn/parallel/ddp/param_and_grad_buffer.h index c83fe9a5..4af99d81 100644 --- a/infini_train/include/nn/parallel/ddp/param_and_grad_buffer.h +++ b/infini_train/include/nn/parallel/ddp/param_and_grad_buffer.h @@ -21,9 +21,22 @@ class Work; namespace infini_train::nn::parallel { class ParamAndGradBucket { public: + /** + * @brief Create bucket metadata and flat-buffer views. + * + * @param params Parameters in bucket-local order. + * @param param_data View of this bucket in the flat parameter buffer, or nullptr if unused. + * @param param_dtype Parameter storage dtype. + * @param grad_data View of this bucket in the flat gradient buffer; nullptr for ZeRO-2. + * @param grad_dtype Gradient storage dtype. + * @param offset Bucket start offset in the owning flat buffer. + * @param num_elements_unpadded Bucket element count before padding. + * @param gradient_scaling_factor Pre-collective gradient scale factor. + * @param bucket_id Bucket index in the owning ParamAndGradBuffer. + */ ParamAndGradBucket(const std::vector> ¶ms, const std::shared_ptr ¶m_data, - const std::shared_ptr &grad_data, size_t offset, size_t num_elements_unpadded, - float gradient_scaling_factor, size_t bucket_id); + DataType param_dtype, const std::shared_ptr &grad_data, DataType grad_dtype, + size_t offset, size_t num_elements_unpadded, float gradient_scaling_factor, size_t bucket_id); size_t bucket_id() const { return bucket_id_; } @@ -33,6 +46,10 @@ class ParamAndGradBucket { const std::shared_ptr &grad_data() const { return grad_data_; } + DataType param_dtype() const { return param_dtype_; } + + DataType grad_dtype() const { return grad_dtype_; } + size_t offset() const { return offset_; } size_t num_elements_unpadded() const { return num_elements_unpadded_; } @@ -49,6 +66,8 @@ class ParamAndGradBucket { std::vector> params_; std::shared_ptr param_data_; std::shared_ptr grad_data_; + DataType param_dtype_; + DataType grad_dtype_; size_t offset_ = 0; size_t num_elements_unpadded_ = 0; @@ -59,6 +78,14 @@ class ParamAndGradBucket { class ParamAndGradBucketGroup { public: + /** + * @brief Group buckets that synchronize gradients and parameters together. + * + * @param buckets Buckets owned by this group. + * @param collective_pg Process group for gradient and parameter collectives. + * @param process_group_size Number of ranks in collective_pg. + * @param ddp_config DDP/DistributedOptimizer behavior config. + */ ParamAndGradBucketGroup(const std::vector> &buckets, const ProcessGroup *collective_pg, size_t process_group_size, DistributedDataParallelConfig ddp_config); @@ -73,6 +100,10 @@ class ParamAndGradBucketGroup { // Start grad reduce void StartGradSync(); + // Accumulate a parameter grad into bucket storage for the ZeRO-2 pre-accumulate hook. + void AccumulateParamGrad(const std::shared_ptr ¶meter, const std::shared_ptr &grad, + bool overwrite, float learning_rate); + // Wait for gradient reduce to complete void FinishGradSync(); @@ -87,6 +118,9 @@ class ParamAndGradBucketGroup { const std::vector> &buckets() const { return buckets_; } + // ZeRO-2: Get a bucket's local grad shard buffer + std::shared_ptr GetLocalGradShardBuffer(size_t bucket_idx) const; + const DistributedDataParallelConfig &config() const { return ddp_config_; } private: @@ -98,12 +132,19 @@ class ParamAndGradBucketGroup { std::unordered_set params_; std::unordered_set params_with_grad_; + // Tensor -> (Bucket, Bucket Index) + std::unordered_map, size_t>> param_to_bucket_; // TODO(zbl): Implement CoalescedWork for aggregate works // According to Megatron-LM's _coalescing_manager std::vector> grad_reduce_work_list_; + std::vector grad_reduce_bucket_indices_; std::vector> param_gather_work_list_; + // ZeRO-2: persistent grad shard buffers and temporary full grad buffers + std::vector> grad_shard_buffer_list_; + std::vector> temp_full_grad_buffer_list_; + std::shared_ptr next_param_gather_bucket_group_ = nullptr; std::vector>> param_buffer_shard_list_; @@ -117,6 +158,15 @@ class ParamAndGradBucketGroup { class ParamAndGradBuffer { public: + /** + * @brief Own flat buffers and bucket metadata for one dtype group. + * + * @param params Parameters with the same parameter/gradient dtype pair. + * @param param_dtype Flat parameter-buffer dtype. + * @param grad_dtype Gradient storage dtype. + * @param ddp_pg Data-parallel process group used by derived bucket groups. + * @param ddp_config DDP/DistributedOptimizer bucketing and padding config. + */ ParamAndGradBuffer(const std::vector> ¶ms, DataType ¶m_dtype, DataType &grad_dtype, const ProcessGroup *ddp_pg, DistributedDataParallelConfig ddp_config); diff --git a/infini_train/include/tensor.h b/infini_train/include/tensor.h index 12f45f57..b6de3340 100644 --- a/infini_train/include/tensor.h +++ b/infini_train/include/tensor.h @@ -18,6 +18,7 @@ namespace infini_train { namespace autograd { class Function; class AccumulateGrad; +class PreAccumulateGradHook; class PostAccumulateGradHook; } // namespace autograd @@ -230,8 +231,10 @@ class Tensor : public std::enable_shared_from_this { std::shared_ptr grad_accumulator(); void ResetAccumulator(); - void RegisterPostAccumulateGradHook(std::shared_ptr hook); + void RegisterPreAccumulateGradHook(std::shared_ptr hook); + autograd::PreAccumulateGradHook *pre_accumulate_grad_hook() const; + void RegisterPostAccumulateGradHook(std::shared_ptr hook); autograd::PostAccumulateGradHook *post_accumulate_grad_hook() const; private: @@ -243,6 +246,7 @@ class Tensor : public std::enable_shared_from_this { // FIXME(dcj): This should be a weak_ptr. The autograd graph should hold // a strong reference to the accumulator to manage its lifetime. std::shared_ptr grad_accumulator_ = nullptr; + std::shared_ptr pre_accumulate_grad_hook_ = nullptr; std::shared_ptr post_accumulate_grad_hook_ = nullptr; bool grad_overwrite_once_ = false; diff --git a/infini_train/src/autograd/accumulate.cc b/infini_train/src/autograd/accumulate.cc index d9b70bc1..0c34819f 100644 --- a/infini_train/src/autograd/accumulate.cc +++ b/infini_train/src/autograd/accumulate.cc @@ -21,7 +21,6 @@ AccumulateGrad::Backward(const std::vector> &grad_output CHECK_EQ(grad_outputs.size(), 1); auto grad_output = grad_outputs[0]; - auto grad = tensor_->grad(); auto device = tensor_->GetDevice(); core::DeviceGuard guard(device); @@ -33,8 +32,19 @@ AccumulateGrad::Backward(const std::vector> &grad_output "running before autograd). The grad is not cast and will be used as-is."; } + const bool overwrite = tensor_->ConsumeGradOverwriteFlag(); + auto pre_hook = tensor_->pre_accumulate_grad_hook(); + if (pre_hook) { + if (pre_hook->TryBypassAccumulate(tensor_, grad_output, overwrite, learning_rate_)) { + tensor_->ResetAccumulator(); + return {}; + } + (*pre_hook)(grad_output); + } + + auto grad = tensor_->grad(); if (grad) { - if (tensor_->ConsumeGradOverwriteFlag()) { + if (overwrite) { // If the tensor is marked to overrite its current grad on next grad update // See notes in `infini_train::nn::parallel::Reducer::PrepareForBackward()` // NOTE(zbl): must copy, cannot change grad buffer address @@ -48,9 +58,9 @@ AccumulateGrad::Backward(const std::vector> &grad_output auto new_grad = std::make_shared(*grad_output.get(), 0, grad_output->Dims()); tensor_->set_grad(new_grad); } - auto hook = tensor_->post_accumulate_grad_hook(); - if (hook != nullptr) { - (*hook)(tensor_->grad()); + auto post_hook = tensor_->post_accumulate_grad_hook(); + if (post_hook != nullptr) { + (*post_hook)(tensor_->grad()); } tensor_->ResetAccumulator(); } diff --git a/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc b/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc index 82f143a9..a3bfe008 100644 --- a/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc +++ b/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc @@ -1,5 +1,6 @@ #include "infini_train/include/nn/parallel/ddp/distributed_data_parallel.h" +#include #include #include #include @@ -18,19 +19,25 @@ namespace infini_train::nn::parallel { namespace { constexpr char kModuleName[] = "module"; + } // namespace DistributedDataParallel::DistributedDataParallel(std::shared_ptr module, const Rank &rank, const DistributedDataParallelConfig ddp_config) : ddp_config_(ddp_config), ddp_pg_(ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(rank.GlobalRank()))) { + CHECK(ddp_config_.zero_stage >= 0 && ddp_config_.zero_stage <= 3) + << "DistributedDataParallel: zero_stage must be in 0/1/2/3."; + if (ddp_config_.zero_stage == 3) { + LOG(FATAL) << "DistributedDataParallel: ZeRO-3 is not implemented yet."; + } for (auto ¶m : module->Parameters()) { if (!param->requires_grad()) { continue; } auto device = param->GetDevice(); CHECK_EQ(device.index(), rank.thread_rank()) << "All parameters must be on the same device as the module"; - if (!ddp_config.gradient_bucketing_enabled && !ddp_config.use_distributed_optimizer) { + if (!ddp_config.gradient_bucketing_enabled && ddp_config.zero_stage < 1) { auto hook = std::make_unique( function::ReduceOpType::kAvg, ddp_pg_); param->RegisterPostAccumulateGradHook(std::move(hook)); @@ -42,7 +49,7 @@ DistributedDataParallel::DistributedDataParallel(std::shared_ptr mod } modules_[kModuleName] = std::move(module); - if (ddp_config.use_distributed_optimizer) { + if (ddp_config.zero_stage >= 1) { BuildParamAndGradBuffers(); RegisterBackwardHooks(); } else if (ddp_config.gradient_bucketing_enabled) { @@ -91,7 +98,7 @@ void DistributedDataParallel::BuildParamAndGradBuffers() { // TODO(zbl): option for disable bucketing bucket_groups_ = PartitionBuckets(param_grad_buffers_, /*force_single_bucket_group=*/false); - if (ddp_config_.use_distributed_optimizer && ddp_config_.overlap_param_gather) { + if (ddp_config_.zero_stage >= 1 && ddp_config_.overlap_param_gather) { auto num_bucket_groups = bucket_groups_.size(); for (auto i = num_bucket_groups - 1; i > 0; --i) { bucket_groups_[i]->SetNextParamGatherBucketGroup(bucket_groups_[i - 1]); @@ -116,6 +123,47 @@ void DistributedDataParallel::BuildParamAndGradBuffers() { } void DistributedDataParallel::RegisterBackwardHooks() { + if (ddp_config_.zero_stage >= 2) { + // NOTE(zbl): ZeRO-2 bypasses Tensor::grad accumulation: stash grads in the bucket group's + // temporary full-grad buffer, then mark the bucket ready for reduce-scatter. + class Zero2PreAccumulateGradHook final : public autograd::PreAccumulateGradHook { + public: + explicit Zero2PreAccumulateGradHook(std::weak_ptr group) + : group_(std::move(group)) {} + + bool TryBypassAccumulate(const std::shared_ptr ¶m, const std::shared_ptr &grad_output, + bool overwrite, float learning_rate) override { + if (auto group = group_.lock(); group) { + group->AccumulateParamGrad(param, grad_output, overwrite, learning_rate); + if (group->config().overlap_grad_reduce) { + group->RegisterGradReady(param); + } + return true; + } + return false; + } + + void operator()(const std::shared_ptr &) override {} + + private: + std::weak_ptr group_; + }; + + auto &module = modules_.at(kModuleName); + for (auto ¶m : module->Parameters()) { + if (!param->requires_grad()) { + continue; + } + auto it = param_to_bucket_group_.find(param.get()); + CHECK(it != param_to_bucket_group_.end()); + + std::weak_ptr weak_group = it->second; + auto hook = std::make_unique(weak_group); + param->RegisterPreAccumulateGradHook(std::move(hook)); + } + return; + } + class DDPPostAccumulateHook final : public autograd::PostAccumulateGradHook { public: DDPPostAccumulateHook(DistributedDataParallel *ddp, const std::weak_ptr param) @@ -147,7 +195,7 @@ void DistributedDataParallel::OnGradReady(const std::shared_ptr ¶m) auto it = param_to_bucket_group_.find(param.get()); if (it != param_to_bucket_group_.end()) { CHECK(param->requires_grad()); - if (ddp_config_.overlap_grad_reduce) { + if (ddp_config_.overlap_grad_reduce && (ddp_config_.zero_stage < 2)) { CHECK(param->grad()) << "param.grad being None is not safe when overlap_grad_reduce is True"; } @@ -163,7 +211,7 @@ DistributedDataParallel::Forward(const std::vector> &inp if (reducer_) { reducer_->PrepareForBackward(); } - if (ddp_config_.use_distributed_optimizer) { + if (ddp_config_.zero_stage >= 1) { for (auto buffer : param_grad_buffers_) { buffer->RebindGradViews(); } } return outputs; diff --git a/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc b/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc index 55e5800b..3b86106b 100644 --- a/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc +++ b/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc @@ -35,10 +35,13 @@ void DistributedOptimizer::BuildShardParamsAndBindGrads() { shard_params_.clear(); for (const auto &group : bucket_groups_) { - for (const auto &bucket : group->buckets()) { + const bool use_grad_shard = group->config().zero_stage >= 2; + const auto &buckets = group->buckets(); + for (size_t bucket_idx = 0; bucket_idx < buckets.size(); ++bucket_idx) { + const auto &bucket = buckets[bucket_idx]; auto bucket_param = bucket->param_data(); - auto bucket_grad = bucket->grad_data(); + auto bucket_grad = use_grad_shard ? group->GetLocalGradShardBuffer(bucket_idx) : bucket->grad_data(); CHECK(bucket_param) << "DistributedOptimizer requires param buffer."; CHECK(bucket_grad) << "DistributedOptimizer requires grad buffer."; @@ -65,7 +68,9 @@ void DistributedOptimizer::BuildShardParamsAndBindGrads() { CHECK_GT(piece_numel, 0); const size_t param_piece_offset_bytes = local_start * kDataTypeToSize.at(bucket_param->Dtype()); - const size_t grad_piece_offset_bytes = local_start * kDataTypeToSize.at(bucket_grad->Dtype()); + // Adjust the offset since bucket_grad is already the shard of grad under ZeRO-2. + auto offset = use_grad_shard ? (local_start - bucket_shard_start) : local_start; + size_t grad_piece_offset_bytes = offset * kDataTypeToSize.at(bucket_grad->Dtype()); auto param_piece = std::make_shared(*bucket_param, param_piece_offset_bytes, std::vector{static_cast(piece_numel)}); @@ -74,6 +79,9 @@ void DistributedOptimizer::BuildShardParamsAndBindGrads() { std::vector{static_cast(piece_numel)}); param_piece->set_grad(grad_piece); + // NOTE(zbl): Do not call `param->set_grad(grad_piece);` under ZeRO-2. + // The base optimizer updates param_piece views only; original param->grad() + // would be a partial flattened shard and does not represent the full parameter grad. shard_params_.push_back(param_piece); } } @@ -124,7 +132,7 @@ void DistributedOptimizer::Step() { // 3. Gather updated param shards back to full params StartParamSync(/*force_sync=*/false); - // FIXME(zbl): Call sync before param is actually used in next step + // TODO(zbl): Delay sync call until param is actually used in next step FinishParamSync(/*skip_next_bucket_dispatch=*/true); } diff --git a/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc b/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc index 75a21f63..6771654f 100644 --- a/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc +++ b/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc @@ -6,6 +6,7 @@ #include "glog/logging.h" +#include "infini_train/include/dispatcher.h" #include "infini_train/include/nn/modules/module.h" #include "infini_train/include/nn/parallel/ddp/distributed_data_parallel_config.h" #include "infini_train/include/nn/parallel/global.h" @@ -53,12 +54,12 @@ std::vector> ShardBuffer(const std::shared_ptr b } // namespace ParamAndGradBucket::ParamAndGradBucket(const std::vector> ¶ms, - const std::shared_ptr ¶m_data, - const std::shared_ptr &grad_data, size_t offset, + const std::shared_ptr ¶m_data, DataType param_dtype, + const std::shared_ptr &grad_data, DataType grad_dtype, size_t offset, size_t num_elements_unpadded, float gradient_scaling_factor, size_t bucket_id) - : bucket_id_(bucket_id), params_(std::move(params)), param_data_(std::move(param_data)), - grad_data_(std::move(grad_data)), offset_(offset), num_elements_unpadded_(num_elements_unpadded), - gradient_scaling_factor_(gradient_scaling_factor) { + : bucket_id_(bucket_id), params_(std::move(params)), param_data_(std::move(param_data)), param_dtype_(param_dtype), + grad_data_(std::move(grad_data)), grad_dtype_(grad_dtype), offset_(offset), + num_elements_unpadded_(num_elements_unpadded), gradient_scaling_factor_(gradient_scaling_factor) { size_t current_offset = 0; for (const auto ¶m : params_) { auto numel = param->NumElements(); @@ -85,7 +86,7 @@ void ParamAndGradBucket::ScaleGradients(float scaling_factor) { // FIXME(zbl): should perform in-place multiply // grad_data_ *= scaling_factor; - LOG(FATAL) << "ParamAndGradBucket: Should not arrive here"; + LOG(FATAL) << "ParamAndGradBucket::ScaleGradients(): Inplace multiply not implemented yet."; } ParamAndGradBucketGroup::ParamAndGradBucketGroup(const std::vector> &buckets, @@ -97,26 +98,52 @@ ParamAndGradBucketGroup::ParamAndGradBucketGroup(const std::vectorparams()) { params_.insert(param.get()); } + for (size_t bucket_idx = 0; bucket_idx < buckets_.size(); ++bucket_idx) { + const auto &bucket = buckets_[bucket_idx]; + for (const auto ¶m : bucket->params()) { + params_.insert(param.get()); + param_to_bucket_[param.get()] = {bucket, bucket_idx}; + } } if (rank_in_collective_pg_ == -1) { auto param = *params_.begin(); - // FIXME(zbl): get correct rank in multi-node settings - rank_in_collective_pg_ = collective_pg_->GetGroupRank(param->GetDevice().Rank().thread_rank()); + rank_in_collective_pg_ = collective_pg_->GetGroupRank(param->GetDevice().Rank().GlobalRank()); } param_buffer_shard_list_.resize(buckets_.size()); grad_buffer_shard_list_.resize(buckets_.size()); + + grad_shard_buffer_list_.resize(buckets_.size()); + temp_full_grad_buffer_list_.resize(buckets_.size()); + + if (ddp_config_.zero_stage >= 2) { + for (size_t i = 0; i < buckets_.size(); ++i) { + auto bucket = buckets_[i]; + CHECK(bucket->param_data()) << "ParamAndGradBucketGroup: param buffer required for ZeRO-2."; + const size_t bucket_numel = bucket->param_data()->NumElements(); + if (bucket_numel == 0) { + continue; + } + CHECK_EQ(bucket_numel % collective_pg_size_, 0); + const size_t shard_numel = bucket_numel / collective_pg_size_; + auto param = bucket->params().front(); + grad_shard_buffer_list_[i] = AllocateFlatBuffer(shard_numel, bucket->grad_dtype(), param->GetDevice()); + } + } } void ParamAndGradBucketGroup::Reset() { params_with_grad_.clear(); grad_reduce_work_list_.clear(); + grad_reduce_bucket_indices_.clear(); param_gather_work_list_.clear(); is_last_microbatch_ = true; grad_reduce_dispatched_ = false; param_gather_dispatched_ = false; + + if (ddp_config_.zero_stage >= 2) { + std::fill(temp_full_grad_buffer_list_.begin(), temp_full_grad_buffer_list_.end(), nullptr); + } } void ParamAndGradBucketGroup::RegisterGradReady(const std::shared_ptr ¶meter) { @@ -127,18 +154,19 @@ void ParamAndGradBucketGroup::RegisterGradReady(const std::shared_ptr &p return; } - // Only register grads as ready when processing the last microbatch + // TODO(zbl): Only register grads as ready and trigger grad sync when processing the last microbatch + // For now, is_last_microbatch_ is always true if (is_last_microbatch_) { if (!parameter || params_.find(parameter.get()) == params_.end()) { return; } const bool inserted = params_with_grad_.insert(parameter.get()).second; - if (!inserted) { - LOG(FATAL) << "ParamAndGradBucketGroup: RegisterGradReady() was called twice for the same parameter in a " - "bucket group."; - return; - } + // TODO(zbl): check this if sync is only done in last mircobatch + // if (!inserted) { + // LOG(FATAL) << "ParamAndGradBucketGroup: RegisterGradReady() was called twice for the same parameter in a + // bucket group."; return; + // } if (params_with_grad_.size() == params_.size()) { // All param grads are ready in this group, trigger grad sync @@ -147,6 +175,65 @@ void ParamAndGradBucketGroup::RegisterGradReady(const std::shared_ptr &p } } +void ParamAndGradBucketGroup::AccumulateParamGrad(const std::shared_ptr ¶meter, + const std::shared_ptr &grad, bool overwrite, + float learning_rate) { + if (ddp_config_.zero_stage < 2) { + LOG(FATAL) << "ParamAndGradBucketGroup: AccumulateParamGrad called when ZeRO-2 is disabled."; + return; + } + if (!grad || !parameter) { + return; + } + + auto it = param_to_bucket_.find(parameter.get()); + if (it == param_to_bucket_.end()) { + return; + } + auto bucket = it->second.first; + const size_t bucket_idx = it->second.second; + + size_t param_start_in_bucket = 0, param_end_in_bucket = 0; + auto found = bucket->GetTensorLocInBucket(parameter, param_start_in_bucket, param_end_in_bucket); + if (!found) { + return; + } + + if (!temp_full_grad_buffer_list_[bucket_idx]) { + CHECK(bucket->param_data()) << "ParamAndGradBucketGroup: param buffer required for ZeRO-2."; + const size_t bucket_numel = bucket->param_data()->NumElements(); + if (bucket_numel == 0) { + return; + } + temp_full_grad_buffer_list_[bucket_idx] + = AllocateFlatBuffer(bucket_numel, bucket->grad_dtype(), parameter->GetDevice()); + temp_full_grad_buffer_list_[bucket_idx]->Fill(0.0f); + } + + const size_t offset_bytes = param_start_in_bucket * kDataTypeToSize.at(bucket->grad_dtype()); + auto bucket_grad_view + = std::make_shared(*temp_full_grad_buffer_list_[bucket_idx], offset_bytes, parameter->Dims()); + + if (overwrite) { + bucket_grad_view->CopyFrom(*grad); + } else { + auto device = parameter->GetDevice(); + auto kernel = Dispatcher::Instance().GetKernel({device.type(), "AccumulateGrad"}); + kernel.Call(grad, learning_rate, bucket_grad_view); + } +} + +std::shared_ptr ParamAndGradBucketGroup::GetLocalGradShardBuffer(size_t bucket_idx) const { + if (ddp_config_.zero_stage < 2) { + LOG(WARNING) << "ParamAndGradBucketGroup: GetLocalGradShardBuffer called when ZeRO-2 is disabled."; + return nullptr; + } + if (bucket_idx >= grad_shard_buffer_list_.size()) { + return nullptr; + } + return grad_shard_buffer_list_[bucket_idx]; +} + void ParamAndGradBucketGroup::StartGradSync() { if (!collective_pg_) { LOG(FATAL) << "ParamAndGradBucketGroup: StartGradSync() called with null collective_pg_."; @@ -163,23 +250,40 @@ void ParamAndGradBucketGroup::StartGradSync() { // TODO(zbl): Check NaN/Inf/too large in grad (options in DistributedDataParallelConfig) - for (auto bucket : buckets_) { - if (bucket->gradient_scaling_factor() != 1.f) { - bucket->ScaleGradients(bucket->gradient_scaling_factor()); - } - } - auto reduce_op = ddp_config_.average_in_collective ? function::ReduceOpType::kAvg : function::ReduceOpType::kSum; auto async_op = ddp_config_.overlap_grad_reduce && (ddp_config_.num_distributed_optimizer_instances == 1); for (auto i = 0; i < buckets_.size(); ++i) { auto bucket = buckets_[i]; + + if (ddp_config_.zero_stage >= 2) { + auto full_grad_buffer = temp_full_grad_buffer_list_[i]; + if (!full_grad_buffer) { + continue; + } + if (bucket->gradient_scaling_factor() != 1.f) { + // FIXME(zbl): should perform in-place multiply + // full_grad_buffer *= bucket->gradient_scaling_factor(); + LOG(FATAL) << "ParamAndGradBucketGroup::StartGradSync(): Inplace multiply not implemented yet."; + } + CHECK(grad_shard_buffer_list_[i]) << "ParamAndGradBucketGroup: grad shard buffer missing."; + auto local_data_view = grad_shard_buffer_list_[i]; + grad_reduce_work_list_.push_back( + collective_pg_->ReduceScatter(local_data_view, full_grad_buffer, reduce_op, async_op)); + grad_reduce_bucket_indices_.push_back(i); + continue; + } + + if (bucket->gradient_scaling_factor() != 1.f) { + bucket->ScaleGradients(bucket->gradient_scaling_factor()); + } + std::shared_ptr grad_buffer = bucket->grad_data(); if (!grad_buffer) { continue; } - if (ddp_config_.use_distributed_optimizer) { + if (ddp_config_.zero_stage >= 1) { if (grad_buffer_shard_list_[i].empty()) { grad_buffer_shard_list_[i] = ShardBuffer(grad_buffer, collective_pg_size_); } @@ -193,6 +297,8 @@ void ParamAndGradBucketGroup::StartGradSync() { } grad_reduce_dispatched_ = true; + // TODO(zbl): no need to clear params_with_grad_ here if grad sync is only done on last microbatch + params_with_grad_.clear(); } void ParamAndGradBucketGroup::FinishGradSync() { @@ -203,21 +309,40 @@ void ParamAndGradBucketGroup::FinishGradSync() { if (!ddp_config_.overlap_grad_reduce) { // Assume reduce ops are synced and no work needs to be resolved grad_reduce_work_list_.clear(); + grad_reduce_bucket_indices_.clear(); + grad_reduce_dispatched_ = false; + return; + } + + if (grad_reduce_work_list_.empty()) { + grad_reduce_bucket_indices_.clear(); grad_reduce_dispatched_ = false; return; } - CHECK(!grad_reduce_work_list_.empty()) - << "ParamAndGradBucketGroup: Communication call has not been issued for this bucket(" - << params_with_grad_.size() << "/" << params_.size() << " params have grad available)"; + if (ddp_config_.zero_stage >= 2) { + CHECK_EQ(grad_reduce_work_list_.size(), grad_reduce_bucket_indices_.size()) + << "ParamAndGradBucketGroup: grad reduce works and bucket indices are out of sync."; + for (size_t idx = 0; idx < grad_reduce_work_list_.size(); ++idx) { + auto &work = grad_reduce_work_list_[idx]; + work->WaitNonBlocking(); + const size_t bucket_idx = grad_reduce_bucket_indices_[idx]; + temp_full_grad_buffer_list_[bucket_idx].reset(); + } + grad_reduce_work_list_.clear(); + grad_reduce_bucket_indices_.clear(); + grad_reduce_dispatched_ = false; + return; + } for (auto work : grad_reduce_work_list_) { work->WaitNonBlocking(); } grad_reduce_work_list_.clear(); + grad_reduce_bucket_indices_.clear(); grad_reduce_dispatched_ = false; } void ParamAndGradBucketGroup::StartParamSync(bool force_sync) { - CHECK(ddp_config_.use_distributed_optimizer); + CHECK(ddp_config_.zero_stage >= 1); if (!collective_pg_) { LOG(ERROR) << "ParamAndGradBucketGroup: StartParamSync called with null collective_pg_."; @@ -253,7 +378,7 @@ void ParamAndGradBucketGroup::StartParamSync(bool force_sync) { } void ParamAndGradBucketGroup::FinishParamSync(bool skip_next_bucket_dispatch) { - if (!ddp_config_.use_distributed_optimizer || !ddp_config_.overlap_param_gather) { + if (ddp_config_.zero_stage < 1 || !ddp_config_.overlap_param_gather) { return; } @@ -302,7 +427,7 @@ void ParamAndGradBuffer::BuildBuckets(DataType param_dtype, DataType grad_dtype) // Param start must be multiple of 64 auto PadParamStartIfNeeded = [&](size_t start) -> size_t { - if (ddp_config_.use_distributed_optimizer) { + if (ddp_config_.zero_stage >= 1) { // According to Megatron-LM, make sure each param starts at 128B aligned address (by default align to 64 // elements for precision >=16-bit) return PadTo(start, kParamStartAlignElements); @@ -312,7 +437,7 @@ void ParamAndGradBuffer::BuildBuckets(DataType param_dtype, DataType grad_dtype) // Bucket size shoule be multiple of ddp size and 128 (sweet spot for NCCL) auto PadBucketEndIfNeeded = [&](size_t bucket_end_index) -> size_t { - if (ddp_config_.use_distributed_optimizer) { + if (ddp_config_.zero_stage >= 1) { // According to Megatron-LM, ensure that all buckets start at a memory address that is 256B // aligned(128 values since params and grads use >= 16-bit precision) size_t lcm_val = std::lcm(ddp_world_size_, kBucketEndAlignElements); @@ -384,7 +509,7 @@ void ParamAndGradBuffer::BuildBuckets(DataType param_dtype, DataType grad_dtype) static_cast(0), std::plus()); CHECK(numel_unpadded_ <= numel_); - if (ddp_config_.use_distributed_optimizer) { + if (ddp_config_.zero_stage >= 1) { // numel must be multiple of ddp size (so that reduce-scatter could easily shard the buffer among ranks) CHECK_EQ(numel_ % ddp_world_size_, 0); } else { @@ -393,13 +518,17 @@ void ParamAndGradBuffer::BuildBuckets(DataType param_dtype, DataType grad_dtype) // 2. Allocate buffer auto device = params_.front()->GetDevice(); - if (ddp_config_.use_distributed_optimizer) { + if (ddp_config_.zero_stage >= 1) { param_buffer_ = AllocateFlatBuffer(numel_, param_dtype, device); } else { - // No param buffer needed if optimzer is not distributed + // No param buffer needed if optimizer is not distributed param_buffer_.reset(); } - grad_buffer_ = AllocateFlatBuffer(numel_, grad_dtype, device); + if (ddp_config_.zero_stage >= 2) { + grad_buffer_.reset(); + } else { + grad_buffer_ = AllocateFlatBuffer(numel_, grad_dtype, device); + } LOG(INFO) << "ParamAndGradBuffer: numel_unpadded=" << numel_unpadded_ << ", numel (padded)=" << numel_; @@ -412,7 +541,7 @@ void ParamAndGradBuffer::BuildBuckets(DataType param_dtype, DataType grad_dtype) auto NewBucket = [&](const std::vector> &bucket_params, size_t start_index, size_t end_index, size_t num_elements_unpadded, size_t bucket_id) -> std::shared_ptr { - if (ddp_config_.use_distributed_optimizer) { + if (ddp_config_.zero_stage >= 1) { CHECK_EQ(start_index % ddp_world_size_, 0); CHECK_EQ(end_index % ddp_world_size_, 0); CHECK_EQ(bucket_indices_.at(bucket_id).first, start_index); @@ -424,14 +553,17 @@ void ParamAndGradBuffer::BuildBuckets(DataType param_dtype, DataType grad_dtype) bucket_param_view = GetBufferView(param_buffer_, start_index, std::vector{static_cast(end_index - start_index)}); } - std::shared_ptr bucket_grad_view = GetBufferView( - grad_buffer_, start_index, std::vector{static_cast(end_index - start_index)}); + std::shared_ptr bucket_grad_view; + if (grad_buffer_) { + bucket_grad_view = GetBufferView(grad_buffer_, start_index, + std::vector{static_cast(end_index - start_index)}); + } // FIXME(zbl): Use default for now float gradient_scaling_factor = 1.0f; - auto bucket - = std::make_shared(bucket_params, bucket_param_view, bucket_grad_view, start_index, - num_elements_unpadded, gradient_scaling_factor, bucket_id); + auto bucket = std::make_shared(bucket_params, bucket_param_view, param_dtype, + bucket_grad_view, grad_dtype, start_index, + num_elements_unpadded, gradient_scaling_factor, bucket_id); for (auto param : bucket_params) { CHECK(param_bucket_map_.find(param.get()) == param_bucket_map_.end()) @@ -454,8 +586,11 @@ void ParamAndGradBuffer::BuildBuckets(DataType param_dtype, DataType grad_dtype) param->SetData(*param_buffer_, param_start_index * kDataTypeToSize.at(param_buffer_->Dtype()), true); } - auto grad_view = GetBufferView(grad_buffer_, param_start_index, param->Dims()); - param->set_grad(grad_view); + std::shared_ptr grad_view; + if (grad_buffer_) { + grad_view = GetBufferView(grad_buffer_, param_start_index, param->Dims()); + param->set_grad(grad_view); + } // Save grad view for each params --i; grads_[i] = grad_view; @@ -496,7 +631,7 @@ void ParamAndGradBuffer::ScaleGradients(float scaling_factor) { // FIXME(zbl): should perform in-place multiply // grad_data_ *= scaling_factor; - LOG(FATAL) << "Should not arrive here"; + LOG(FATAL) << "ParamAndGradBuffer::ScaleGradients(): Inplace multiply not implemented yet."; } void ParamAndGradBuffer::Reset(bool need_rebind) { @@ -506,7 +641,9 @@ void ParamAndGradBuffer::Reset(bool need_rebind) { if (!need_rebind) { grad_buffer_->Fill(0.f); } - need_rebind_grad_views_ = need_rebind; + // NOTE(zbl): Under ZeRO-2, param->grad() is the shard of grad, not the full grad. + // It is constantly pointed to the shard of grad, so no need to rebind. + need_rebind_grad_views_ = need_rebind && (ddp_config_.zero_stage < 2); } void ParamAndGradBuffer::RebindGradViews() { @@ -514,10 +651,16 @@ void ParamAndGradBuffer::RebindGradViews() { return; } + if (!grad_buffer_) { + return; + } + CHECK_EQ(params_.size(), grads_.size()); for (size_t i = 0; i < params_.size(); ++i) { - params_[i]->set_grad(grads_[i]); - params_[i]->MarkGradOverwriteOnNextAccum(); + if (grads_[i]) { + params_[i]->set_grad(grads_[i]); + params_[i]->MarkGradOverwriteOnNextAccum(); + } } need_rebind_grad_views_ = false; diff --git a/infini_train/src/tensor.cc b/infini_train/src/tensor.cc index f7947030..3c2ae69b 100644 --- a/infini_train/src/tensor.cc +++ b/infini_train/src/tensor.cc @@ -559,6 +559,16 @@ void Tensor::ResetAccumulator() { } } +void Tensor::RegisterPreAccumulateGradHook(std::shared_ptr hook) { + CHECK(requires_grad_) << "cannot register a hook on a tensor that doesn't require gradient"; + + CHECK_EQ(grad_fn_, nullptr) << "pre accumulate grad hooks cannot be registered on non-leaf tensors"; + + pre_accumulate_grad_hook_ = hook; +} + +autograd::PreAccumulateGradHook *Tensor::pre_accumulate_grad_hook() const { return pre_accumulate_grad_hook_.get(); } + void Tensor::RegisterPostAccumulateGradHook(std::shared_ptr hook) { CHECK(requires_grad_) << "cannot register a hook on a tensor that doesn't require gradient"; diff --git a/scripts/test_config.json b/scripts/test_config.json index 2f061528..32c3e202 100644 --- a/scripts/test_config.json +++ b/scripts/test_config.json @@ -210,7 +210,18 @@ "num_iteration": 10, "batch_size": 10, "total_batch_size": 5120, - "use_distributed_optimizer": true + "zero_stage": 1 + } + }, + { + "id": "3_zero2", + "args": { + "dtype": "float32", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "zero_stage": 2 } }, { @@ -221,7 +232,18 @@ "num_iteration": 10, "batch_size": 10, "total_batch_size": 5120, - "use_distributed_optimizer": true + "zero_stage": 1 + } + }, + { + "id": "3_bfloat16_zero2", + "args": { + "dtype": "bfloat16", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "zero_stage": 2 } }, { @@ -233,7 +255,19 @@ "batch_size": 40, "total_batch_size": 5120, "tensor_parallel": 4, - "use_distributed_optimizer": true + "zero_stage": 1 + } + }, + { + "id": "4_zero2", + "args": { + "dtype": "float32", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 40, + "total_batch_size": 5120, + "tensor_parallel": 4, + "zero_stage": 2 } }, { @@ -245,7 +279,19 @@ "batch_size": 40, "total_batch_size": 5120, "tensor_parallel": 4, - "use_distributed_optimizer": true + "zero_stage": 1 + } + }, + { + "id": "4_bfloat16_zero2", + "args": { + "dtype": "bfloat16", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 40, + "total_batch_size": 5120, + "tensor_parallel": 4, + "zero_stage": 2 } }, { @@ -258,7 +304,20 @@ "total_batch_size": 5120, "tensor_parallel": 4, "sequence_parallel": true, - "use_distributed_optimizer": true + "zero_stage": 1 + } + }, + { + "id": "5_zero2", + "args": { + "dtype": "float32", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 40, + "total_batch_size": 5120, + "tensor_parallel": 4, + "sequence_parallel": true, + "zero_stage": 2 } }, { @@ -271,7 +330,20 @@ "total_batch_size": 5120, "tensor_parallel": 4, "sequence_parallel": true, - "use_distributed_optimizer": true + "zero_stage": 1 + } + }, + { + "id": "5_bfloat16_zero2", + "args": { + "dtype": "bfloat16", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 40, + "total_batch_size": 5120, + "tensor_parallel": 4, + "sequence_parallel": true, + "zero_stage": 2 } }, { @@ -286,7 +358,7 @@ "sequence_parallel": true, "pipeline_parallel": 2, "virtual_pipeline_parallel": 2, - "use_distributed_optimizer": true + "zero_stage": 1 } }, { @@ -301,7 +373,7 @@ "sequence_parallel": true, "pipeline_parallel": 2, "virtual_pipeline_parallel": 2, - "use_distributed_optimizer": true + "zero_stage": 1 } } ]