Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions example/gpt2/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ DEFINE_uint32(freq_generate_txt, 10, "frequency of text generation");
DEFINE_uint32(text_length, 64, "the length of the generated text");
// optimization
// Learning rate passed to the optimizer factory (optimizers::SGD::Create(FLAGS_learning_rate)).
// The help text now describes the learning rate itself rather than a warmup iteration count.
DEFINE_double(learning_rate, 1e-4, "learning rate");
DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)");
DEFINE_int32(zero_stage, 0, "ZeRO stage (0/1/2/3); 0 disables DistributedOptimizer");
// evaluation
DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?");
DEFINE_uint32(sample_every, 0, "how often to sample from the model?");
Expand Down Expand Up @@ -114,6 +114,7 @@ const std::unordered_map<std::string, nn::TransformerConfig> kModelToConfigs = {
DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); });
DEFINE_validator(device,
[](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; });
DEFINE_validator(zero_stage, [](const char *, int32_t value) { return value >= 0 && value <= 3; });

void Train(const nn::parallel::Rank &rank) {
using namespace nn::parallel;
Expand Down Expand Up @@ -252,8 +253,7 @@ void Train(const nn::parallel::Rank &rank) {
model = std::make_shared<nn::parallel::PipelineParallel>(model, pp_world_size, num_micro_batches, shapes,
pp_rank, device, model_config.GetChunkSize());
if (ddp_world_size > 1) {
auto ddp_config
= DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
auto ddp_config = DistributedDataParallelConfig{.zero_stage = FLAGS_zero_stage};
auto *mutable_chunks = dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks();
for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) {
(*mutable_chunks)[chunk_id]
Expand All @@ -265,7 +265,7 @@ void Train(const nn::parallel::Rank &rank) {
// before wrapping the model with DistributedDataParallel (DDP).
// Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
// are created during the conversion.
auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
auto ddp_config = DistributedDataParallelConfig{.zero_stage = FLAGS_zero_stage};
model = std::make_shared<DistributedDataParallel>(model, rank, ddp_config);
}

Expand Down Expand Up @@ -294,7 +294,7 @@ void Train(const nn::parallel::Rank &rank) {
auto optimizer_creator = optimizers::SGD::Create(FLAGS_learning_rate);
std::shared_ptr<Optimizer> optimizer = nullptr;

if (FLAGS_use_distributed_optimizer) {
if (FLAGS_zero_stage >= 1) {
auto model_chunks = (pp_world_size > 1)
? *(dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks())
: std::vector<std::shared_ptr<nn::Module>>{model};
Expand Down
10 changes: 5 additions & 5 deletions example/llama3/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ DEFINE_uint32(freq_generate_txt, 10, "frequency of text generation");
DEFINE_uint32(text_length, 64, "the length of the generated text");
// optimization
// Help text names the value this flag holds (a learning rate, not a warmup iteration count).
DEFINE_double(learning_rate, 1e-5, "learning rate");
DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)");
DEFINE_int32(zero_stage, 0, "ZeRO stage (0/1/2/3); 0 disables DistributedOptimizer");
// evaluation
DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?");
DEFINE_uint32(sample_every, 0, "how often to sample from the model?");
Expand Down Expand Up @@ -100,6 +100,7 @@ constexpr char kDtypeBF16[] = "bfloat16";
DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); });
DEFINE_validator(device,
[](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; });
DEFINE_validator(zero_stage, [](const char *, int32_t value) { return value >= 0 && value <= 3; });

void Train(const nn::parallel::Rank &rank) {
using namespace nn::parallel;
Expand Down Expand Up @@ -222,8 +223,7 @@ void Train(const nn::parallel::Rank &rank) {
model = std::make_shared<nn::parallel::PipelineParallel>(model, pp_world_size, num_micro_batches, shapes,
pp_rank, device, model_config.GetChunkSize());
if (ddp_world_size > 1) {
auto ddp_config
= DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
auto ddp_config = DistributedDataParallelConfig{.zero_stage = FLAGS_zero_stage};
auto *mutable_chunks = dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks();
for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) {
(*mutable_chunks)[chunk_id]
Expand All @@ -236,7 +236,7 @@ void Train(const nn::parallel::Rank &rank) {
// Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
// are created during the conversion.

auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
auto ddp_config = DistributedDataParallelConfig{.zero_stage = FLAGS_zero_stage};
model = std::make_shared<DistributedDataParallel>(model, rank, ddp_config);
}

Expand Down Expand Up @@ -273,7 +273,7 @@ void Train(const nn::parallel::Rank &rank) {
LOG(INFO) << "Optimizing " << params_to_optimize.size() << " model parameters";
}

if (FLAGS_use_distributed_optimizer) {
if (FLAGS_zero_stage >= 1) {
auto model_chunks = (pp_world_size > 1)
? *(dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks())
: std::vector<std::shared_ptr<nn::Module>>{model};
Expand Down
13 changes: 13 additions & 0 deletions infini_train/include/autograd/function_hook.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,22 @@ class ProcessGroup;

namespace infini_train::autograd {

/**
 * @brief Hook interface consulted by AccumulateGrad::Backward before a gradient is accumulated.
 *
 * AccumulateGrad::Backward first calls TryBypassAccumulate(). If that returns true, the caller
 * resets the accumulator and returns without performing its own accumulation. Otherwise it calls
 * operator() with the incoming gradient and then continues with the regular accumulation path.
 */
class PreAccumulateGradHook {
public:
    /**
     * @brief Called with the incoming gradient when the accumulation was not bypassed.
     *
     * @param grad_output Gradient about to be accumulated into the tensor.
     */
    virtual void operator()(const std::shared_ptr<Tensor> &grad_output) = 0;

    /**
     * @brief Give the hook a chance to handle the current gradient accumulation itself.
     *
     * Arguments, in the order AccumulateGrad::Backward passes them:
     *   1) the tensor whose gradient is being accumulated,
     *   2) the incoming gradient (grad_output),
     *   3) the consumed grad-overwrite flag of that tensor,
     *   4) the accumulator's learning rate.
     *
     * @return true if this hook has handled the current gradient accumulation; the default
     *         implementation handles nothing and returns false.
     */
    virtual bool TryBypassAccumulate(const std::shared_ptr<Tensor> & /*tensor*/,
                                     const std::shared_ptr<Tensor> & /*grad_output*/, bool /*overwrite*/,
                                     float /*learning_rate*/) {
        return false;
    }

    virtual ~PreAccumulateGradHook() = default;
};

/**
 * @brief Callback interface run after a gradient has been accumulated.
 *
 * AccumulateGrad::Backward invokes the registered hook once the tensor's gradient has been
 * updated, passing the result of tensor->grad() as the argument.
 */
class PostAccumulateGradHook {
public:
    // Polymorphic base: hooks are owned and destroyed through base-class pointers.
    virtual ~PostAccumulateGradHook() = default;

    /**
     * @brief Run the hook.
     *
     * @param tensor Tensor handed over by the caller (the accumulated gradient at the
     *               AccumulateGrad::Backward call site).
     */
    virtual void operator()(const std::shared_ptr<Tensor> &tensor) = 0;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,17 @@ class DistributedDataParallelConfig {
// Ref:
// https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/distributed/distributed_data_parallel_config.py
// ======================================================
// Whether to enable DistributedOptimizer (ZeRO-1 equivalent).
// When set true:
// 1) Gradients/params are managed by ParamAndGradBuffer and reduced in groups.
// 2) The classic DDP reducer path is not used (i.e., disable reducer/bucketing in the DDP sense).
bool use_distributed_optimizer = false;

// Whether to overlap gradient reduce-scatter/all-reduce with backward compute.
// In this case, grad reduce is triggered immediately when a grad is ready or till all grads are ready.
bool overlap_grad_reduce = true;

// ZeRO-DP stage for memory optimization.
// ZeRO-0: Disabled; use the classic DDP reducer path.
// ZeRO-1: Optimizer states partitioning
// ZeRO-2: Gradients partitioning
// ZeRO-3: Parameters partitioning
int zero_stage = 0;

// Whether to overlap parameter all-gather with forward compute.
bool overlap_param_gather = true;

Expand All @@ -58,8 +59,7 @@ class DistributedDataParallelConfig {

// Maximum number of parameters in each ParamAndGradBucket.
// NOTE(zbl): This is distinct from DDP Reducer's MB-based bucket caps.
// TODO(zbl): To unify the definition of bucket_size argument for users
size_t bucket_size_in_elements = 40000000;
size_t bucket_size_in_elements = 1000000;

// Whether to pad bucket sizes to improve NCCL bus bandwidth utilization.
bool pad_buckets_for_high_nccl_busbw = false;
Expand Down
54 changes: 52 additions & 2 deletions infini_train/include/nn/parallel/ddp/param_and_grad_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,22 @@ class Work;
namespace infini_train::nn::parallel {
class ParamAndGradBucket {
public:
/**
* @brief Create bucket metadata and flat-buffer views.
*
* @param params Parameters in bucket-local order.
* @param param_data View of this bucket in the flat parameter buffer, or nullptr if unused.
* @param param_dtype Parameter storage dtype.
* @param grad_data View of this bucket in the flat gradient buffer; nullptr for ZeRO-2.
* @param grad_dtype Gradient storage dtype.
* @param offset Bucket start offset in the owning flat buffer.
* @param num_elements_unpadded Bucket element count before padding.
* @param gradient_scaling_factor Pre-collective gradient scale factor.
* @param bucket_id Bucket index in the owning ParamAndGradBuffer.
*/
ParamAndGradBucket(const std::vector<std::shared_ptr<Tensor>> &params, const std::shared_ptr<Tensor> &param_data,
Comment thread
kilinchange marked this conversation as resolved.
const std::shared_ptr<Tensor> &grad_data, size_t offset, size_t num_elements_unpadded,
float gradient_scaling_factor, size_t bucket_id);
DataType param_dtype, const std::shared_ptr<Tensor> &grad_data, DataType grad_dtype,
size_t offset, size_t num_elements_unpadded, float gradient_scaling_factor, size_t bucket_id);

size_t bucket_id() const { return bucket_id_; }

Expand All @@ -33,6 +46,10 @@ class ParamAndGradBucket {

const std::shared_ptr<Tensor> &grad_data() const { return grad_data_; }

DataType param_dtype() const { return param_dtype_; }

DataType grad_dtype() const { return grad_dtype_; }

size_t offset() const { return offset_; }

size_t num_elements_unpadded() const { return num_elements_unpadded_; }
Expand All @@ -49,6 +66,8 @@ class ParamAndGradBucket {
std::vector<std::shared_ptr<Tensor>> params_;
std::shared_ptr<Tensor> param_data_;
std::shared_ptr<Tensor> grad_data_;
DataType param_dtype_;
DataType grad_dtype_;

size_t offset_ = 0;
size_t num_elements_unpadded_ = 0;
Expand All @@ -59,6 +78,14 @@ class ParamAndGradBucket {

class ParamAndGradBucketGroup {
public:
/**
* @brief Group buckets that synchronize gradients and parameters together.
*
* @param buckets Buckets owned by this group.
* @param collective_pg Process group for gradient and parameter collectives.
* @param process_group_size Number of ranks in collective_pg.
* @param ddp_config DDP/DistributedOptimizer behavior config.
*/
ParamAndGradBucketGroup(const std::vector<std::shared_ptr<ParamAndGradBucket>> &buckets,
const ProcessGroup *collective_pg, size_t process_group_size,
DistributedDataParallelConfig ddp_config);
Expand All @@ -73,6 +100,10 @@ class ParamAndGradBucketGroup {
// Start grad reduce
void StartGradSync();

// Accumulate a parameter grad into bucket storage for the ZeRO-2 pre-accumulate hook.
void AccumulateParamGrad(const std::shared_ptr<Tensor> &parameter, const std::shared_ptr<Tensor> &grad,
bool overwrite, float learning_rate);

// Wait for gradient reduce to complete
void FinishGradSync();

Expand All @@ -87,6 +118,9 @@ class ParamAndGradBucketGroup {

const std::vector<std::shared_ptr<ParamAndGradBucket>> &buckets() const { return buckets_; }

// ZeRO-2: Get a bucket's local grad shard buffer
std::shared_ptr<Tensor> GetLocalGradShardBuffer(size_t bucket_idx) const;

const DistributedDataParallelConfig &config() const { return ddp_config_; }

private:
Expand All @@ -98,12 +132,19 @@ class ParamAndGradBucketGroup {

std::unordered_set<Tensor *> params_;
std::unordered_set<Tensor *> params_with_grad_;
// Tensor -> (Bucket, Bucket Index)
std::unordered_map<Tensor *, std::pair<std::shared_ptr<ParamAndGradBucket>, size_t>> param_to_bucket_;

// TODO(zbl): Implement CoalescedWork for aggregate works
// According to Megatron-LM's _coalescing_manager
std::vector<std::shared_ptr<Work>> grad_reduce_work_list_;
std::vector<size_t> grad_reduce_bucket_indices_;
std::vector<std::shared_ptr<Work>> param_gather_work_list_;

// ZeRO-2: persistent grad shard buffers and temporary full grad buffers
std::vector<std::shared_ptr<Tensor>> grad_shard_buffer_list_;
std::vector<std::shared_ptr<Tensor>> temp_full_grad_buffer_list_;

std::shared_ptr<ParamAndGradBucketGroup> next_param_gather_bucket_group_ = nullptr;

std::vector<std::vector<std::shared_ptr<Tensor>>> param_buffer_shard_list_;
Expand All @@ -117,6 +158,15 @@ class ParamAndGradBucketGroup {

class ParamAndGradBuffer {
public:
/**
* @brief Own flat buffers and bucket metadata for one dtype group.
*
* @param params Parameters with the same parameter/gradient dtype pair.
* @param param_dtype Flat parameter-buffer dtype.
* @param grad_dtype Gradient storage dtype.
* @param ddp_pg Data-parallel process group used by derived bucket groups.
* @param ddp_config DDP/DistributedOptimizer bucketing and padding config.
*/
ParamAndGradBuffer(const std::vector<std::shared_ptr<Tensor>> &params, DataType &param_dtype, DataType &grad_dtype,
const ProcessGroup *ddp_pg, DistributedDataParallelConfig ddp_config);

Expand Down
6 changes: 5 additions & 1 deletion infini_train/include/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ namespace infini_train {
namespace autograd {
class Function;
class AccumulateGrad;
class PreAccumulateGradHook;
class PostAccumulateGradHook;
} // namespace autograd

Expand Down Expand Up @@ -230,8 +231,10 @@ class Tensor : public std::enable_shared_from_this<Tensor> {
std::shared_ptr<autograd::AccumulateGrad> grad_accumulator();
void ResetAccumulator();

void RegisterPostAccumulateGradHook(std::shared_ptr<autograd::PostAccumulateGradHook> hook);
void RegisterPreAccumulateGradHook(std::shared_ptr<autograd::PreAccumulateGradHook> hook);
autograd::PreAccumulateGradHook *pre_accumulate_grad_hook() const;

void RegisterPostAccumulateGradHook(std::shared_ptr<autograd::PostAccumulateGradHook> hook);
autograd::PostAccumulateGradHook *post_accumulate_grad_hook() const;

private:
Expand All @@ -243,6 +246,7 @@ class Tensor : public std::enable_shared_from_this<Tensor> {
// FIXME(dcj): This should be a weak_ptr. The autograd graph should hold
// a strong reference to the accumulator to manage its lifetime.
std::shared_ptr<autograd::AccumulateGrad> grad_accumulator_ = nullptr;
std::shared_ptr<autograd::PreAccumulateGradHook> pre_accumulate_grad_hook_ = nullptr;
std::shared_ptr<autograd::PostAccumulateGradHook> post_accumulate_grad_hook_ = nullptr;

bool grad_overwrite_once_ = false;
Expand Down
20 changes: 15 additions & 5 deletions infini_train/src/autograd/accumulate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ AccumulateGrad::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_output
CHECK_EQ(grad_outputs.size(), 1);
auto grad_output = grad_outputs[0];

auto grad = tensor_->grad();
auto device = tensor_->GetDevice();
core::DeviceGuard guard(device);

Expand All @@ -33,8 +32,19 @@ AccumulateGrad::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_output
"running before autograd). The grad is not cast and will be used as-is.";
}

const bool overwrite = tensor_->ConsumeGradOverwriteFlag();
auto pre_hook = tensor_->pre_accumulate_grad_hook();
if (pre_hook) {
if (pre_hook->TryBypassAccumulate(tensor_, grad_output, overwrite, learning_rate_)) {
tensor_->ResetAccumulator();
return {};
}
(*pre_hook)(grad_output);
}

auto grad = tensor_->grad();
if (grad) {
if (tensor_->ConsumeGradOverwriteFlag()) {
if (overwrite) {
// If the tensor is marked to overwrite its current grad on next grad update
// See notes in `infini_train::nn::parallel::Reducer::PrepareForBackward()`
// NOTE(zbl): must copy, cannot change grad buffer address
Expand All @@ -48,9 +58,9 @@ AccumulateGrad::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_output
auto new_grad = std::make_shared<Tensor>(*grad_output.get(), 0, grad_output->Dims());
tensor_->set_grad(new_grad);
}
auto hook = tensor_->post_accumulate_grad_hook();
if (hook != nullptr) {
(*hook)(tensor_->grad());
auto post_hook = tensor_->post_accumulate_grad_hook();
if (post_hook != nullptr) {
(*post_hook)(tensor_->grad());
}
tensor_->ResetAccumulator();
}
Expand Down
Loading
Loading