@@ -509,7 +509,16 @@ void LLamaModel::_backward_lmhead(long B, long T, float z_loss, int micro_step,
509509
510510 // handle the LM-head. We run the d_lmhead matmul first, so that the gradient reduction can overlap with the DLNF matmul.
511511 bool accumulate;
512- auto & d_lmhead = Grads->get_non_block_full (LLamaWeightID::LM_HEAD, main_stream, comm, accumulate);
512+ // get the correct matrix depending on whether we have tied embeddings
513+ auto & d_lmhead = [&]() -> Tensor& {
514+ if (Config.TiedWordEmbeddings ) {
515+ return Grads->get_non_block_full (LLamaWeightID::EMBEDDING, main_stream, comm, accumulate);
516+ } else {
517+ return Grads->get_non_block_full (LLamaWeightID::LM_HEAD, main_stream, comm, accumulate);
518+ }
519+ }();
520+
521+ // even if we overwrite for first micro-batch, we need to accumulate on non-first nano batch
513522 accumulate |= nano_step != 0 ;
514523 matmul (d_lmhead, lnf_slice, rs->Output , Tensor{}, nullptr , nullptr ,
515524 rs->CublasLtHandle , rs->CuBlasWorkspace , C, V, nano_batch_size, EMMTranspose::NT, accumulate, main_stream, rs->MatmulBackend );
@@ -757,6 +766,8 @@ void LLamaModel::fill_non_block_shapes(GenericTensorContainer& target, const Tra
757766 create (target.get_tensor (LLamaWeightID::LNF_W), C, 0 , other_dtype);
758767 if (!config.TiedWordEmbeddings ) {
759768 create (target.get_tensor (LLamaWeightID::LM_HEAD), V, C, matrix_dtype);
769+ } else {
770+ create (target.get_tensor (LLamaWeightID::LM_HEAD), 0 , 0 , matrix_dtype);
760771 }
761772}
762773
0 commit comments