diff --git a/infini_train/include/nn/parallel/process_group.h b/infini_train/include/nn/parallel/process_group.h index 74bf80c6..9d2d251e 100644 --- a/infini_train/include/nn/parallel/process_group.h +++ b/infini_train/include/nn/parallel/process_group.h @@ -52,6 +52,15 @@ class ProcessGroup { function::ReduceOpType reduce_op = function::ReduceOpType::kSum, bool async_op = false) const; + // root_rank_in_group is ProcessGroup-local rank. Broadcast updates tensors in place. + virtual std::shared_ptr Broadcast(const std::vector> &tensors, int root_rank_in_group, + bool async_op = false) const; + + // Root provides rank-major input_tensors: rank * output_tensors.size() + tensor_index. + virtual std::shared_ptr Scatter(const std::vector> &output_tensors, + const std::vector> &input_tensors, + int root_rank_in_group, bool async_op = false) const; + virtual std::shared_ptr Send(std::vector> tensors, int dest_rank, bool async_op = false) const; @@ -60,13 +69,13 @@ class ProcessGroup { // Legacy communication APIs (Single-stream) virtual std::vector> - BroadCast(const std::vector> &input_tensors) const; + BroadCast_(const std::vector> &input_tensors) const; virtual std::vector> ReduceAddCoalesced(const std::vector>> &grads, Device destination) const; - virtual std::vector> Scatter(const std::shared_ptr &tensor, - std::vector devices, int64_t dim) const; + virtual std::vector> Scatter_(const std::shared_ptr &tensor, + std::vector devices, int64_t dim) const; virtual std::shared_ptr Gather(const std::vector> &tensors, Device destination, int64_t dim) const; diff --git a/infini_train/src/autograd/comm.cc b/infini_train/src/autograd/comm.cc index d524088a..df40f916 100644 --- a/infini_train/src/autograd/comm.cc +++ b/infini_train/src/autograd/comm.cc @@ -19,7 +19,7 @@ std::vector> Scatter::Forward(const std::vector> output_tensors; auto device = input->GetDevice().type(); - output_tensors = pg_->Scatter(input, target_gpus_, dim_); + output_tensors = pg_->Scatter_(input, target_gpus_, dim_); return output_tensors; } @@ -83,7 +83,7 @@ std::vector> Broadcast::Forward(const std::vectorBroadCast(input_tensors); + return pg_->BroadCast_(input_tensors); } void Broadcast::SetupContext(const std::vector> &input_tensors, diff --git a/infini_train/src/nn/lora/lora_parallel_linear.cc b/infini_train/src/nn/lora/lora_parallel_linear.cc index 760ed3d8..595ad2ca 100644 --- a/infini_train/src/nn/lora/lora_parallel_linear.cc +++ b/infini_train/src/nn/lora/lora_parallel_linear.cc @@ -11,6 +11,7 @@ #include "infini_train/include/nn/init.h" #include "infini_train/include/nn/modules/linear.h" #include "infini_train/include/nn/parallel/global.h" +#include "infini_train/include/nn/parallel/process_group.h" #include "infini_train/include/nn/parallel/tensor_parallel.h" #include "infini_train/include/nn/parallel/utils.h" #include "infini_train/include/tensor.h" @@ -89,22 +90,38 @@ LoRAColumnParallelLinear::LoRAColumnParallelLinear(std::shared_ptr(std::vector{config_.rank, in_features_}, DataType::kFLOAT32, device_) ->RequiresGrad(); - if (config_.use_kaiming_a) { - init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param); + + if (parallel::global::GetTensorParallelSize() > 1) { + const auto global_rank = device_.Rank().GlobalRank(); + auto *tp_group = parallel::ProcessGroupFactory::Instance(device_.type()) + ->Get(parallel::GetTensorParallelProcessGroupName(global_rank)); + const int tp_rank = tp_group->GetGroupRank(global_rank); + + // Only TP rank 0 generates random values; others zero-init. + // AllReduce(sum) then broadcasts rank-0's values to all TP ranks. + if (tp_rank == 0) { + if (config_.use_kaiming_a) { + init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param); + } else { + init::Normal(parameters_[kParamLoraAName], 0.0f, 0.02f); + } + } else { + init::Zeros(parameters_[kParamLoraAName]); + } + tp_group->AllReduce(parameters_[kParamLoraAName]); } else { - init::Normal(parameters_[kParamLoraAName], 0.0f, 0.02f); + if (config_.use_kaiming_a) { + init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param); + } else { + init::Normal(parameters_[kParamLoraAName], 0.0f, 0.02f); + } } - // lora_B: [out_per_partition, rank] - sharded like base weight parameters_[kParamLoraBName] = std::make_shared(std::vector{out_features_per_partition_, config_.rank}, DataType::kFLOAT32, device_) @@ -126,39 +143,35 @@ LoRAColumnParallelLinear::Forward(const std::vector> &in << "Forward() on merged LoRA with requires_grad=true. Call UnmergeWeights() before training."; if (!merged_) { - // 1. Compute base output via parent class - auto base_result = ColumnParallelLinear::Forward(input_tensors); - auto base_output = base_result[0]; - - // 2. Compute LoRA output using the SAME input that base module uses - // Match base input path exactly: use direct input if input_is_parallel_ or sequence_parallel_, - // otherwise copy to TP region - auto lora_input = (input_is_parallel_ || sequence_parallel_) - ? input_tensors[0] - : parallel::CopyToTPRegionFunc(input_tensors[0])[0]; + // Inline base + LoRA matmuls, add locally, then single collective op. + // This avoids 2 separate AllGather ops which cause floating-point divergence. + auto input = (input_is_parallel_ || sequence_parallel_) ? input_tensors[0] + : parallel::CopyToTPRegionFunc(input_tensors[0])[0]; if (sequence_parallel_) { - // Base uses GatherFromSPRegionFunc to gather sequence dimension - lora_input = parallel::GatherFromSPRegionFunc(lora_input)[0]; + input = parallel::GatherFromSPRegionFunc(input)[0]; } - // Compute LoRA: lora_A: [rank, in_features], lora_B: [out_per_partition, rank] - auto lora_proj = std::make_shared()->Apply({lora_input, parameters_[kParamLoraAName]})[0]; + // Base matmul (bias folded in when applicable, matching ColumnParallelLinear::Forward) + auto base_shard = std::make_shared()->Apply( + (bias_ && !skip_bias_add_) + ? std::vector>{input, parameters_.at(kParamWeightName), + parameters_[kParamBiasName]} + : std::vector>{input, parameters_.at(kParamWeightName)})[0]; + + // LoRA matmul (local) + // Wrap replicated lora_A through CopyToTPRegion so its gradient gets AllReduced in backward + auto lora_A = parallel::CopyToTPRegionFunc(parameters_[kParamLoraAName])[0]; + auto lora_proj = std::make_shared()->Apply({input, lora_A})[0]; auto lora_output = std::make_shared()->Apply({lora_proj, parameters_[kParamLoraBName]})[0]; - // Match base output layout (gather if base gathers) - if (gather_output_) { - lora_output = parallel::GatherFromTPRegionFunc(lora_output)[0]; - } + // Local add before collective + auto combined = base_shard->Add(lora_output->Mul(config_.Scaling())); - auto scaled_lora = lora_output->Mul(config_.Scaling()); + // Single collective op + auto output = gather_output_ ? parallel::GatherFromTPRegionFunc(combined)[0] : combined; - // 3. Add LoRA contribution to base output - // Both should now have the same sequence dimension - auto output = base_output->Add(scaled_lora); - - // Return in same format as base module return skip_bias_add_ - ? std::vector>{output, bias_ ? parameters_[kParamBiasName] : nullptr} + ? std::vector>{output, bias_ ? parameters_.at(kParamBiasName) : nullptr} : std::vector>{output}; } @@ -321,42 +334,32 @@ LoRARowParallelLinear::Forward(const std::vector> &input << "Forward() on merged LoRA with requires_grad=true. Call UnmergeWeights() before training."; if (!merged_) { - // Get effective input - match what base module uses - auto effective_input = input_tensors[0]; - const int64_t in_dim = effective_input->Dims().back(); - - if (!input_is_parallel_) { - // base would scatter; lora must match - effective_input = parallel::ScatterToTPRegionFunc(effective_input)[0]; - CHECK_EQ(effective_input->Dims().back(), in_features_per_partition_); - } else { - // input_is_parallel_=true means caller promised shard input - CHECK_EQ(in_dim, in_features_per_partition_) - << "RowParallel expects sharded input when input_is_parallel_=true. " - << "Got full in_dim=" << in_dim << " (likely upstream gathered TP output)."; + // Inline base + LoRA matmuls, add locally, then single collective op. + // This avoids 2 separate AllReduce ops which cause floating-point divergence. + auto input = input_is_parallel_ ? input_tensors[0] : parallel::ScatterToTPRegionFunc(input_tensors[0])[0]; + + // Base matmul (no bias — RowParallel adds bias AFTER collective) + auto base_shard = std::make_shared()->Apply({input, parameters_.at(kParamWeightName)})[0]; + + // LoRA matmul (local) + // Wrap replicated lora_B through CopyToTPRegion so its gradient gets AllReduced in backward + auto lora_proj = std::make_shared()->Apply({input, parameters_[kParamLoraAName]})[0]; + auto lora_B = parallel::CopyToTPRegionFunc(parameters_[kParamLoraBName])[0]; + auto lora_output = std::make_shared()->Apply({lora_proj, lora_B})[0]; + + // Local add before collective + auto combined = base_shard->Add(lora_output->Mul(config_.Scaling())); + + // Single collective op + auto output = reduce_output_ ? (sequence_parallel_ ? parallel::ReduceScatterToSPRegionFunc(combined)[0] + : parallel::ReduceFromTPRegionFunc(combined)[0]) + : combined; + + // Bias after collective (matching RowParallelLinear::Forward) + if (bias_ && !skip_bias_add_) { + output = output->Add(parameters_[kParamBiasName]); } - // 1) base output - use effective_input - auto base_result = RowParallelLinear::Forward({effective_input}); - auto base_output = base_result[0]; - - // 2) lora branch uses the SAME effective_input - auto lora_proj - = std::make_shared()->Apply({effective_input, parameters_[kParamLoraAName]})[0]; - auto lora_output = std::make_shared()->Apply({lora_proj, parameters_[kParamLoraBName]})[0]; - - // 3) apply same reduction as base - auto lora_out = lora_output; - if (reduce_output_) { - lora_out = sequence_parallel_ ? parallel::ReduceScatterToSPRegionFunc(lora_out)[0] - : parallel::ReduceFromTPRegionFunc(lora_out)[0]; - } - - auto scaled_lora = lora_out->Mul(config_.Scaling()); - CHECK_EQ(base_output->NumElements(), scaled_lora->NumElements()); - auto output = base_output->Add(scaled_lora); - - // Return in same format as base module return skip_bias_add_ ? std::vector>{output, bias_ ? parameters_[kParamBiasName] : nullptr} : std::vector>{output}; diff --git a/infini_train/src/nn/lora/lora_utils.cc b/infini_train/src/nn/lora/lora_utils.cc index 7b8f3668..56f5f012 100644 --- a/infini_train/src/nn/lora/lora_utils.cc +++ b/infini_train/src/nn/lora/lora_utils.cc @@ -15,6 +15,7 @@ #include "infini_train/include/nn/lora/lora_parallel_linear.h" #include "infini_train/include/nn/modules/linear.h" #include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/nn/parallel/tensor_parallel.h" #include "infini_train/include/tensor.h" @@ -392,10 +393,30 @@ void LoadLoRAWeights(std::shared_ptr model, const std::string &filepath) auto cpu_tensor = std::make_shared(dims, DataType::kFLOAT32, Device(Device::DeviceType::kCPU, 0)); file.read(reinterpret_cast(cpu_tensor->DataPtr()), num_elements * sizeof(float)); - // Load into model + // Load into model, slicing sharded tensors by tp_rank if shapes differ auto it = model_state_dict.find(name); if (it != model_state_dict.end()) { - it->second->CopyFrom(cpu_tensor); + auto &dst = it->second; + const auto &dst_dims = dst->Dims(); + if (dst_dims == dims) { + dst->CopyFrom(cpu_tensor); + } else { + // Determine which dim is sharded: find first dim where sizes differ + int shard_dim = -1; + for (int d = 0; d < static_cast(dims.size()); ++d) { + if (d < static_cast(dst_dims.size()) && dst_dims[d] != dims[d]) { + shard_dim = d; + break; + } + } + CHECK(shard_dim >= 0) << "LoadLoRAWeights: shape mismatch for " << name + << " but no differing dim found"; + int tp_size = parallel::global::GetTensorParallelSize(); + int64_t shard_size = dims[shard_dim] / tp_size; + int64_t start = parallel::tp_rank * shard_size; + auto sliced = cpu_tensor->Slice(shard_dim, start, start + shard_size); + dst->CopyFrom(sliced); + } } else { LOG(WARNING) << "LoRA parameter not found in model: " << name; } diff --git a/infini_train/src/nn/parallel/process_group.cc b/infini_train/src/nn/parallel/process_group.cc index 3c4c4910..174aa645 100644 --- a/infini_train/src/nn/parallel/process_group.cc +++ b/infini_train/src/nn/parallel/process_group.cc @@ -194,6 +194,116 @@ std::shared_ptr ProcessGroup::ReduceScatter(const std::shared_ptr } } +std::shared_ptr ProcessGroup::Broadcast(const std::vector> &tensors, + int root_rank_in_group, bool async_op) const { + CHECK_GE(root_rank_in_group, 0); + CHECK_LT(root_rank_in_group, world_size_); + CHECK_GT(tensors.size(), 0); + CHECK_NOTNULL(tensors[0]); + + auto device = tensors[0]->GetDevice(); + auto group_rank = GetGroupRank(device.Rank().GlobalRank()); + core::DeviceGuard guard(device); + auto *compute_stream = runtime_impl_->GetStream(device); + auto *comm_stream = device_stream_map_.at(device.index()); + auto comm = device_comm_map_.at(device.index()); + + auto work = std::make_shared(device, comm); + runtime_impl_->EventRecord(work->ready_event(), compute_stream); + runtime_impl_->StreamWaitEvent(comm_stream, work->ready_event(), 0); + for (const auto &tensor : tensors) { + CHECK_NOTNULL(tensor); + CHECK_EQ(device, tensor->GetDevice()); + const void *send_buffer = (group_rank == root_rank_in_group) ? tensor->DataPtr() : nullptr; + ccl_impl_->Broadcast(send_buffer, tensor->DataPtr(), tensor->NumElements(), tensor->Dtype(), root_rank_in_group, + comm, comm_stream); + } + runtime_impl_->EventRecord(work->done_event(), comm_stream); + + if (async_op) { + return work; + } else { + work->WaitNonBlocking(); + return nullptr; + } +} + +std::shared_ptr ProcessGroup::Scatter(const std::vector> &output_tensors, + const std::vector> &input_tensors, + int root_rank_in_group, bool async_op) const { + CHECK_GE(root_rank_in_group, 0); + CHECK_LT(root_rank_in_group, world_size_); + CHECK_GT(output_tensors.size(), 0); + CHECK_NOTNULL(output_tensors[0]); + + auto device = output_tensors[0]->GetDevice(); + auto group_rank = GetGroupRank(device.Rank().GlobalRank()); + core::DeviceGuard guard(device); + auto *compute_stream = runtime_impl_->GetStream(device); + auto *comm_stream = device_stream_map_.at(device.index()); + auto comm = device_comm_map_.at(device.index()); + + for (const auto &output_tensor : output_tensors) { + CHECK_NOTNULL(output_tensor); + CHECK_EQ(device, output_tensor->GetDevice()); + } + + const bool is_root = group_rank == root_rank_in_group; + const size_t num_outputs = output_tensors.size(); + if (is_root) { + CHECK_EQ(input_tensors.size(), static_cast(world_size_) * num_outputs) + << "Root rank must provide rank-major input tensors for every rank."; + for (const auto &input_tensor : input_tensors) { + CHECK_NOTNULL(input_tensor); + CHECK_EQ(device, input_tensor->GetDevice()); + } + for (size_t tensor_idx = 0; tensor_idx < num_outputs; ++tensor_idx) { + const auto &local_input = input_tensors[static_cast(group_rank) * num_outputs + tensor_idx]; + CHECK(local_input->Dtype() == output_tensors[tensor_idx]->Dtype()); + CHECK(local_input->Dims() == output_tensors[tensor_idx]->Dims()); + } + } else { + CHECK(input_tensors.empty()) << "Only root rank should provide scatter input tensors."; + } + + auto work = std::make_shared(device, comm); + runtime_impl_->EventRecord(work->ready_event(), compute_stream); + runtime_impl_->StreamWaitEvent(comm_stream, work->ready_event(), 0); + if (is_root) { + for (size_t tensor_idx = 0; tensor_idx < num_outputs; ++tensor_idx) { + const auto &input_tensor = input_tensors[static_cast(group_rank) * num_outputs + tensor_idx]; + runtime_impl_->MemcpyAsync(output_tensors[tensor_idx]->DataPtr(), input_tensor->DataPtr(), + input_tensor->SizeInBytes(), core::MemcpyKind::kD2D, comm_stream); + } + + core::CclGroupGuard ccl_group_guard(device.type()); + for (int rank = 0; rank < world_size_; ++rank) { + if (rank == group_rank) { + continue; + } + for (size_t tensor_idx = 0; tensor_idx < num_outputs; ++tensor_idx) { + const auto &input_tensor = input_tensors[static_cast(rank) * num_outputs + tensor_idx]; + ccl_impl_->Send(input_tensor->DataPtr(), input_tensor->NumElements(), input_tensor->Dtype(), rank, comm, + comm_stream); + } + } + } else { + core::CclGroupGuard ccl_group_guard(device.type()); + for (const auto &output_tensor : output_tensors) { + ccl_impl_->Recv(output_tensor->DataPtr(), output_tensor->NumElements(), output_tensor->Dtype(), + root_rank_in_group, comm, comm_stream); + } + } + runtime_impl_->EventRecord(work->done_event(), comm_stream); + + if (async_op) { + return work; + } else { + work->WaitNonBlocking(); + return nullptr; + } +} + std::shared_ptr ProcessGroup::Send(std::vector> tensors, int dest_rank, bool async_op) const { CHECK_GT(tensors.size(), 0); @@ -249,7 +359,7 @@ std::shared_ptr ProcessGroup::Recv(std::vector> te } std::vector> -ProcessGroup::BroadCast(const std::vector> &input_tensors) const { +ProcessGroup::BroadCast_(const std::vector> &input_tensors) const { std::vector> outputs; std::vector streams; std::vector comms; @@ -329,8 +439,8 @@ ProcessGroup::ReduceAddCoalesced(const std::vector> ProcessGroup::Scatter(const std::shared_ptr &tensor, - std::vector devices, int64_t dim) const { +std::vector> ProcessGroup::Scatter_(const std::shared_ptr &tensor, + std::vector devices, int64_t dim) const { std::vector> outputs; auto split_tensors = tensor->Split(tensor->Dims()[dim] / devices.size(), dim); std::vector streams; diff --git a/scripts/run_models_and_profile.bash b/scripts/run_models_and_profile.bash index 06589904..15d32770 100755 --- a/scripts/run_models_and_profile.bash +++ b/scripts/run_models_and_profile.bash @@ -154,8 +154,9 @@ run_and_log() { > "$log_path" fi - # Write the current run command to the log - echo "[COMMAND] $cmd" >> "$log_path" + # Write the current run command to the log (expand $LORA_WEIGHTS_DIR) + local expanded_cmd="${cmd//\$LORA_WEIGHTS_DIR/$LORA_WEIGHTS_DIR}" + echo "[COMMAND] $expanded_cmd" >> "$log_path" # Run the command and append both stdout and stderr to the log file if ! eval "$cmd" >> "$log_path" 2>&1; then @@ -267,10 +268,12 @@ for ((id=0; id