
Commit 79c8c9f

fix tied/untied LM-head

1 parent c05d604

2 files changed: 15 additions & 6 deletions

src/models/llama_gradients.cpp

Lines changed: 3 additions & 5 deletions
@@ -11,8 +11,7 @@
 void LLamaGradientsUnsharded::on_first_micro_step(cudaStream_t stream) {
     using namespace LLamaWeightID;
     fill_zero(mNonBlockGradients.get_tensor(LNF_W), stream);
-    if(mNonBlockGradients.get_tensor(LM_HEAD).Data != mNonBlockGradients.get_tensor(EMBEDDING).Data) {
-        fill_zero(mNonBlockGradients.get_tensor(LM_HEAD), stream);// TODO superfluous?
+    if(mNonBlockGradients.get_tensor(LM_HEAD).Data != nullptr) {
         fill_zero(mNonBlockGradients.get_tensor(EMBEDDING), stream);
     } else {
         // embedding backward comes after LMHead backward; and LMHead backward *sets* the gradient
@@ -32,8 +31,8 @@ void LLamaGradientsUnsharded::on_first_micro_step(cudaStream_t stream) {
 // shard the transformer blocks, but not the embeddings and lmhead.

 void LLamaGradientsBlockShardedBase::on_first_micro_step(cudaStream_t stream) {
-    // if we have untied embeddings, we need to zero them out
-    if(mFullNonBlock.get_tensor(LLamaWeightID::EMBEDDING).Data != mFullNonBlock.get_tensor(LLamaWeightID::LM_HEAD).Data) {
+    // if we have untied embeddings, we need to zero them out, same as above
+    if(mFullNonBlock.get_tensor(LLamaWeightID::LM_HEAD).Data != nullptr) {
         fill_zero(mFullNonBlock.get_tensor(LLamaWeightID::EMBEDDING), stream);
     }
     fill_zero(mFullNonBlock.get_tensor(LLamaWeightID::LNF_W), stream);
@@ -102,7 +101,6 @@ void LLamaGradientsBlockSharded_AllToAll::on_notify_block(int layer_idx, SimpleT
     }

     // make sure we've done the local accumulation before we allow communication to begin.
-
     CUDA_CHECK(cudaEventRecord(signal, stream));
     NvtxRange range("all-to-all-gradients", layer_idx);

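Note: the hunks above replace the old "are LM_HEAD and EMBEDDING distinct buffers?" pointer comparison with a simpler "was a separate LM_HEAD gradient allocated at all?" check; in the tied case the LM_HEAD gradient is now created with zero size (see llama_model.cpp below), so its Data pointer stays nullptr. Below is a minimal host-side sketch of that check; Tensor, fill_zero, and the function shape are simplified stand-ins, not the repo's actual types (the real code zeroes device memory on a cudaStream_t).

// Minimal sketch of the new tied/untied check (simplified stand-ins, host-only).
#include <cstddef>
#include <cstdio>
#include <vector>

struct Tensor {
    float* Data = nullptr;   // stays nullptr when the tensor was created with zero elements
    size_t Count = 0;
};

void fill_zero(Tensor& t) {                       // stand-in for fill_zero(tensor, stream)
    for (size_t i = 0; i < t.Count; ++i) t.Data[i] = 0.f;
}

// stand-in for LLamaGradientsUnsharded::on_first_micro_step
void on_first_micro_step(Tensor& d_lnf_w, Tensor& d_embedding, Tensor& d_lm_head) {
    fill_zero(d_lnf_w);
    if (d_lm_head.Data != nullptr) {
        // untied: zero the embedding gradient before backward accumulates into it
        fill_zero(d_embedding);
    }
    // tied (d_lm_head.Data == nullptr): LM-head backward runs first and *sets* the
    // shared gradient, so zeroing it here would be redundant work
}

int main() {
    std::vector<float> lnf(8, 1.f), emb(16, 1.f);
    Tensor d_lnf{lnf.data(), lnf.size()};
    Tensor d_emb{emb.data(), emb.size()};
    Tensor d_head{};                              // tied case: no separate LM-head gradient
    on_first_micro_step(d_lnf, d_emb, d_head);
    std::printf("d_emb[0] = %.1f (untouched in the tied case)\n", d_emb.Data[0]);
    return 0;
}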

src/models/llama_model.cpp

Lines changed: 12 additions & 1 deletion
@@ -509,7 +509,16 @@ void LLamaModel::_backward_lmhead(long B, long T, float z_loss, int micro_step,

     // handle the LM-head. We run the d_lmhead matmul first, so that the gradient reduction can overlap with the DLNF matmul.
     bool accumulate;
-    auto& d_lmhead = Grads->get_non_block_full(LLamaWeightID::LM_HEAD, main_stream, comm, accumulate);
+    // get the correct matrix depending on whether we have tied embeddings
+    auto& d_lmhead = [&]() -> Tensor& {
+        if (Config.TiedWordEmbeddings) {
+            return Grads->get_non_block_full(LLamaWeightID::EMBEDDING, main_stream, comm, accumulate);
+        } else {
+            return Grads->get_non_block_full(LLamaWeightID::LM_HEAD, main_stream, comm, accumulate);
+        }
+    }();
+
+    // even if we overwrite for first micro-batch, we need to accumulate on non-first nano batch
     accumulate |= nano_step != 0;
     matmul(d_lmhead, lnf_slice, rs->Output, Tensor{}, nullptr, nullptr,
            rs->CublasLtHandle, rs->CuBlasWorkspace, C, V, nano_batch_size, EMMTranspose::NT, accumulate, main_stream, rs->MatmulBackend);
@@ -757,6 +766,8 @@ void LLamaModel::fill_non_block_shapes(GenericTensorContainer& target, const Tra
     create(target.get_tensor(LLamaWeightID::LNF_W), C, 0, other_dtype);
     if(!config.TiedWordEmbeddings) {
         create(target.get_tensor(LLamaWeightID::LM_HEAD), V, C, matrix_dtype);
+    } else {
+        create(target.get_tensor(LLamaWeightID::LM_HEAD), 0, 0, matrix_dtype);
     }
 }

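Note: the first hunk routes the LM-head gradient straight into the EMBEDDING gradient whenever Config.TiedWordEmbeddings is set, and the second hunk makes fill_non_block_shapes create a zero-sized LM_HEAD gradient in the tied case, which is what the nullptr checks in llama_gradients.cpp key off. Below is a self-contained sketch of the immediately-invoked-lambda selection pattern; Tensor, Gradients, and the flag are invented stand-ins, not the repo's API.

// Minimal sketch of selecting the gradient buffer via an immediately invoked lambda.
#include <iostream>
#include <string>

struct Tensor { std::string name; };

struct Gradients {
    Tensor embedding{"d_embedding"};
    Tensor lm_head{"d_lm_head"};
    // stand-in for Grads->get_non_block_full(id, stream, comm, accumulate)
    Tensor& get_non_block_full(bool want_lm_head, bool& accumulate) {
        accumulate = false;                   // first touch this micro step: overwrite
        return want_lm_head ? lm_head : embedding;
    }
};

int main() {
    const bool tied_word_embeddings = true;   // stand-in for Config.TiedWordEmbeddings
    Gradients grads;
    bool accumulate;

    // tied: write the LM-head gradient directly into the embedding gradient,
    // so the two never need a separate add afterwards
    Tensor& d_lmhead = [&]() -> Tensor& {
        if (tied_word_embeddings) {
            return grads.get_non_block_full(/*want_lm_head=*/false, accumulate);
        } else {
            return grads.get_non_block_full(/*want_lm_head=*/true, accumulate);
        }
    }();

    int nano_step = 1;
    accumulate |= nano_step != 0;             // non-first nano batches always accumulate
    std::cout << d_lmhead.name << ", accumulate=" << std::boolalpha << accumulate << "\n";
    return 0;
}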
