[Feature] Implement multi MTP Loss#2
Conversation
Change loss context from single object to dict-based API:
- Update loss_cfg.build() to accept data parameter as dict
- Change ModelItem.loss_ctx to dict with loss type keys (e.g. 'lm')
- Update model forward pass to accept loss_ctx_dict parameter
- Update all tests to use the new dict-based loss context API

ghstack-source-id: c938145 Pull-Request: InternLM#1569
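A minimal, self-contained sketch of what a dict-keyed loss context can look like; `ToyLossCtx` and the way the losses are combined are illustrative assumptions, not the actual xtuner API described in the commit above.

```python
# Sketch only: the real loss_cfg.build()/ModelItem API differs; this just shows
# the shape of a loss context keyed by loss type instead of a single object.
from dataclasses import dataclass

import torch
import torch.nn.functional as F


@dataclass
class ToyLossCtx:
    labels: torch.Tensor  # (batch, seq_len); -100 marks ignored positions

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        return F.cross_entropy(logits.flatten(0, 1), self.labels.flatten(), ignore_index=-100)


# loss_ctx is keyed by loss type ('lm', 'mtp', ...) rather than being one object.
loss_ctx_dict = {
    "lm": ToyLossCtx(labels=torch.randint(0, 10, (2, 8))),
    "mtp": ToyLossCtx(labels=torch.randint(0, 10, (2, 8))),
}
logits = torch.randn(2, 8, 10)
total_loss = sum(ctx.forward(logits) for ctx in loss_ctx_dict.values())
```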
ghstack-source-id: 8c98b3a Pull-Request: InternLM#1570
…s context base class
- Rename CELossContext to LMHeadLossContext for better semantic clarity
- Refactor BaseLossContext to be more abstract by removing LM-specific logic
- Move eager_mode and chunk_mode implementations from the base class to LMHeadLossContext
- Make loss_ctx_cls and _loss_kwargs_cls abstract properties in BaseLossConfig
- Remove sp_split() and to() implementations from the BaseLossKwargs base class
- Move sp_split() and to() to the CELossKwargs subclass
- Update BaseRLLossKwargs to properly inherit and extend the sp_split() and to() methods
- Add deprecation alias: CELossContext = LMHeadLossContext for backward compatibility
- Export LMHeadLossContext in __init__.py

ghstack-source-id: 1b3d648 Pull-Request: InternLM#1571
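A rough skeleton of the class layout this commit describes; a sketch only, not the actual xtuner class definitions.

```python
# Skeleton assumed from the commit message above.
from abc import ABC


class BaseLossKwargs(ABC):
    """Per-step loss inputs; sp_split()/to() now live on concrete subclasses."""


class CELossKwargs(BaseLossKwargs):
    def sp_split(self, sp_mesh):
        """Shard labels/weights for sequence parallelism."""
        ...

    def to(self, device):
        ...


class BaseLossContext(ABC):
    """Abstract base; LM-specific eager/chunk logic removed."""


class LMHeadLossContext(BaseLossContext):
    """Carries the eager_mode / chunk_mode cross-entropy paths moved out of the base."""


# Deprecation alias kept so existing imports keep working.
CELossContext = LMHeadLossContext
```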
- Add raw_input_ids, raw_inputs_embeds, raw_position_ids, raw_rollout_routed_experts properties to SequenceContext for reconstructing full tensors from SP shards
- Store raw_input_ids (full padded tensor), shard_start, and shard_size in SequenceContext.split() for zero-communication input_ids rolling
- raw_inputs_embeds triggers a single allgather on first access and caches the result, amortising communication across MTP layers
- roll_sequence_context: remove the SP assert; always operate on full tensors via the raw_* properties, and slice to the local shard only when in SP

ghstack-source-id: cc60a14 Pull-Request: InternLM#1629
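A sketch of the "allgather once, cache, slice per shard" idea behind raw_inputs_embeds; `ToySequenceContext` and its fields are assumptions for illustration, not the real SequenceContext.

```python
import torch
import torch.distributed as dist


class ToySequenceContext:
    def __init__(self, local_embeds, shard_start, shard_size, sp_group=None):
        self.local_embeds = local_embeds      # this rank's SP shard, (batch, shard_size, hidden)
        self.shard_start = shard_start
        self.shard_size = shard_size
        self.sp_group = sp_group
        self._raw_inputs_embeds = None        # cached full tensor

    @property
    def raw_inputs_embeds(self):
        # One allgather on first access; every MTP depth reuses the cached result.
        if self._raw_inputs_embeds is None:
            if self.sp_group is None:
                self._raw_inputs_embeds = self.local_embeds
            else:
                parts = [torch.empty_like(self.local_embeds)
                         for _ in range(dist.get_world_size(self.sp_group))]
                dist.all_gather(parts, self.local_embeds, group=self.sp_group)
                self._raw_inputs_embeds = torch.cat(parts, dim=1)
        return self._raw_inputs_embeds

    def local_slice(self, full):
        # After rolling on the full sequence, keep only this rank's shard.
        return full[:, self.shard_start:self.shard_start + self.shard_size]
```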
Summary of Changes (Gemini Code Assist): This pull request enhances the Multi-Token Prediction (MTP) loss mechanism by allowing multiple distinct MTP losses to be configured and applied simultaneously. This provides greater flexibility in training by enabling different masking strategies and loss scaling factors for the various MTP components. The changes update the model's output structure to accommodate multiple MTP losses, modify the loss calculation logic, and introduce new methods for applying specific masking types.
Code Review
This pull request introduces significant changes to support multiple Multi-Token Prediction (MTP) configurations within the xtuner framework. This includes refactoring MTP loss handling, updating configuration structures to support lists of MTP configurations, and adapting model forward passes and sharding logic. Several critical issues were identified in the review: batch size limitations in CELoss's process_loss_weights_v1 and v2, a typo and TypeError risk in MoE's _should_recompute method, a ValueError risk in MTPBlock's forward method due to incorrect tuple unpacking, and a bug in roll_packed_tensor where fill_value was hardcoded to zero instead of using the provided parameter, potentially leading to incorrect loss calculations. Additionally, numerical stability and readability improvements were suggested for CELoss, along with type hints for the mtp_config parameters.
```python
for j in range(seq_len):
    token = shifted_labels[0, j].item()
```
This implementation appears to only support a batch size of 1. The use of shifted_labels[0, j] hardcodes the batch index to 0. If shifted_labels has a batch size greater than 1, this will only process the first sample in the batch, leading to incorrect behavior. Please refactor this to support arbitrary batch sizes, for example by iterating over the batch dimension or, preferably, by vectorizing the operations across the batch dimension.
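For illustration, one way the loop could be vectorized across the batch dimension; the loop body shown (checking each label against a small set of special ids) is a hypothetical stand-in for whatever the original loop computes.

```python
import torch

# Hypothetical data: the point is that torch.isin handles every batch row at
# once instead of hardcoding batch index 0 and looping over seq_len.
shifted_labels = torch.tensor([[5, 17, 3, 22],
                               [14, 9, 25, 16]])
special_ids = torch.tensor([16, 17, 22])
mask = torch.isin(shifted_labels, special_ids)  # (batch, seq_len) bool mask, no Python loop
```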
```diff
 else:
     current_hidden_states = output
 # Save output for this depth
-mtp_outputs.append(current_hidden_states)
+mtp_outputs.append(output)
```
In the else branch, output is a single tensor, but it's appended directly to mtp_outputs. The method's return type hint is list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]], and the caller in moe.py unpacks a 3-element tuple. This will cause a ValueError at runtime if a decoder layer without router results (like a dense layer) is used.
To fix this, you should ensure that a tuple is always appended. In the else branch, you could create dummy tensors for router_results and router_weights.
Suggested change:

```python
else:
    current_hidden_states = output
    output = (current_hidden_states, torch.empty(0, device=current_hidden_states.device), torch.empty(0, device=current_hidden_states.device))
# Save output for this depth
mtp_outputs.append(output)
```
```diff
-# Fill the last |shifts| positions along dim to avoid information
-# leakage across sequences. For shifts=-1 the last 1 position is
-# filled; for shifts=-2 the last 2 positions are filled, etc.
-fill_len = -shifts
-fill_start = (end_idx - start_idx) - fill_len
-fill_slice = rolled_seq.narrow(dim, fill_start, fill_len)  # type: ignore[arg-type]
-fill_slice.fill_(fill_value)
+# zeroed; for shifts=-2 the last 2 positions are zeroed, etc.
+zero_len = -shifts
+zero_len = min(zero_len, (end_idx - start_idx))
+zero_start = (end_idx - start_idx) - zero_len
+zero_slice = rolled_seq.narrow(dim, zero_start, zero_len)  # type: ignore[arg-type]
+zero_slice.zero_()
```
This change hardcodes the fill value to 0 by using zero_(). However, the function's docstring correctly states that for label tensors, a fill_value like the loss ignore index (e.g., -100) should be used. Hardcoding to zero will lead to incorrect loss calculations when this function is used on label tensors for MTP.
Please revert to using the fill_value parameter, like fill_slice.fill_(fill_value). The caller of this function should then be responsible for passing the correct fill_value (e.g., the ignore index for labels).
Suggested change:

```python
# Fill the last |shifts| positions along dim to avoid information
# leakage across sequences. For shifts=-1 the last 1 position is
# filled; for shifts=-2 the last 2 positions are filled, etc.
fill_len = -shifts
fill_len = min(fill_len, (end_idx - start_idx))
fill_start = (end_idx - start_idx) - fill_len
fill_slice = rolled_seq.narrow(dim, fill_start, fill_len)  # type: ignore[arg-type]
fill_slice.fill_(fill_value)
```
```diff
     hidden_states: torch.Tensor,
     head_weight: torch.Tensor,
     head_bias: torch.Tensor | None = None,
+    mtp_config = None,
```
For better type safety and readability, please add a type hint for the mtp_config parameter. Based on its usage, it should be Optional[MTPConfig]. You may need to import MTPConfig from xtuner.v1.module.mtp.config and Optional from typing.
Suggested change:

```diff
-    mtp_config = None,
+    mtp_config: Optional['MTPConfig'] = None,
```
```python
if loss_weight.sum().item() != 0:
    loss_weight = loss_weight * sum_loss_weight / loss_weight.sum()
```
The division by loss_weight.sum() could lead to numerical instability if the sum is very close to zero. It's good practice to add a small epsilon to the denominator to prevent this, for example: loss_weight.sum() + 1e-12. This pattern is used elsewhere in the codebase (e.g., line 185).
This same issue exists in process_loss_weights_v2 (lines 360-361) and process_loss_weights_v3 (lines 381-382).
Suggested change:

```diff
 if loss_weight.sum().item() != 0:
-    loss_weight = loss_weight * sum_loss_weight / loss_weight.sum()
+    loss_weight = loss_weight * sum_loss_weight / (loss_weight.sum() + 1e-12)
```
```python
easy_to_use = torch.cat([shifted_labels, torch.tensor([[0]], device=shifted_labels.device)], dim=-1)
is_digit = torch.where(easy_to_use < 25, easy_to_use > 14, 0)
```
The expression torch.where(easy_to_use < 25, easy_to_use > 14, 0) is a bit obscure. For better readability and to make the intent clearer, consider using a more direct boolean expression like ((easy_to_use > 14) & (easy_to_use < 25)).to(easy_to_use.dtype).
This also applies to process_loss_weights_v3 on line 373.
Suggested change:

```diff
-is_digit = torch.where(easy_to_use < 25, easy_to_use > 14, 0)
+is_digit = ((easy_to_use > 14) & (easy_to_use < 25)).to(easy_to_use.dtype)
```
```diff
 def forward(  # type: ignore[override]
-    self, hidden_states: torch.Tensor, loss_ctx: LMHeadLossContext | None = None
+    self, hidden_states: torch.Tensor, loss_ctx: LMHeadLossContext | None = None,
+    mtp_config = None, layer_idx: int = 0,
```
For better type safety and readability, please add a type hint for the mtp_config parameter. Based on its usage, it appears to be an MTPConfig object. You could use Optional['MTPConfig'] to avoid potential circular import issues.
Suggested change:

```diff
-    mtp_config = None, layer_idx: int = 0,
+    mtp_config: 'MTPConfig' | None = None, layer_idx: int = 0,
```
Force-pushed from ba182c8 to f611c0d.
* Fix MLLMPretrainHybridPackDataset
* [Fix] Include MLLMPretrainHybridPackDataset in LengthGroupedSampler assertion

Co-authored-by: nil0x9 <nil.0x9@proton.me>
MTPConfig example:
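A stand-in sketch of what such a configuration could look like; every field name below is an assumption based on the features described in this PR (per-loss masking strategy and loss scale), not the real MTPConfig schema from xtuner.v1.module.mtp.config.

```python
from dataclasses import dataclass


@dataclass
class ToyMTPConfig:
    depth: int          # MTP prediction depth this loss applies to (assumed field)
    loss_scale: float   # per-loss scaling factor (assumed field)
    mask_type: str      # masking strategy, e.g. "default" or "digit_only" (assumed field)


# Multiple MTP losses configured side by side, as the feature allows.
mtp_configs = [
    ToyMTPConfig(depth=1, loss_scale=0.3, mask_type="default"),
    ToyMTPConfig(depth=2, loss_scale=0.1, mask_type="digit_only"),
]
```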