
[Feature] Implement multi MTP Loss #2

Open

x54-729 wants to merge 9 commits into HAOCHENYE:mtp from x54-729:mtp_sci

Conversation


@x54-729 x54-729 commented Mar 25, 2026

MTPConfig example:

model_cfg.text_config.mtp_config = [
    MTPConfig(name="normal", mask_type=None, num_layers=1, loss_scaling_factor=0.1),
    MTPConfig(name="sci", mask_type="v3", num_layers=1, loss_scaling_factor=0.1),
]
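
For orientation, below is a minimal sketch (not the actual xtuner code path; lm_loss, mtp_losses, and loss_scaling are placeholder names) of how losses keyed by MTPConfig.name could be scaled by loss_scaling_factor and summed into the total training loss:

import torch

# Hedged sketch only: placeholder tensors stand in for the real losses.
lm_loss = torch.tensor(2.0)                                            # main LM loss
mtp_losses = {"normal": torch.tensor(2.3), "sci": torch.tensor(1.8)}   # one loss per MTPConfig.name
loss_scaling = {"normal": 0.1, "sci": 0.1}                             # MTPConfig.loss_scaling_factor

total_loss = lm_loss
for name, mtp_loss in mtp_losses.items():
    total_loss = total_loss + loss_scaling[name] * mtp_loss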

HAOCHENYE and others added 5 commits March 24, 2026 14:19
Change loss context from single object to dict-based API:
- Update loss_cfg.build() to accept data parameter as dict
- Change ModelItem.loss_ctx to dict with loss type keys (e.g. 'lm')
- Update model forward pass to accept loss_ctx_dict parameter
- Update all tests to use new dict-based loss context API

ghstack-source-id: c938145
Pull-Request: InternLM#1569
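
A minimal sketch of the dict-keyed shape described in the commit message above, using stand-in classes (FakeLossCfg, FakeModel) rather than the real xtuner API:

# Hedged sketch only: FakeLossCfg and FakeModel are stand-ins, not xtuner classes.
class FakeLossCfg:
    def build(self, data: dict):
        return {"cfg": self, "data": data}       # stands in for a real loss context

class FakeModel:
    def forward(self, inputs, loss_ctx_dict: dict):
        lm_ctx = loss_ctx_dict["lm"]             # contexts are looked up by loss type key
        return {"used_lm_ctx": lm_ctx is not None}

loss_ctx_dict = {"lm": FakeLossCfg().build(data={"labels": [1, 2, 3]})}
out = FakeModel().forward(inputs=None, loss_ctx_dict=loss_ctx_dict)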
…s context base class

- Rename CELossContext to LMHeadLossContext for better semantic clarity
- Refactor BaseLossContext to be more abstract by removing LM-specific logic
- Move eager_mode and chunk_mode implementations from base class to LMHeadLossContext
- Make loss_ctx_cls and _loss_kwargs_cls abstract properties in BaseLossConfig
- Remove sp_split() and to() implementations from BaseLossKwargs base class
- Move sp_split() and to() to CELossKwargs subclass
- Update BaseRLLossKwargs to properly inherit and extend sp_split() and to() methods
- Add deprecation alias: CELossContext = LMHeadLossContext for backward compatibility
- Export LMHeadLossContext in __init__.py

ghstack-source-id: 1b3d648
Pull-Request: InternLM#1571
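
A minimal sketch of the resulting class relationships described in the commit message above; the method name and signatures are illustrative only, not the real xtuner definitions:

from abc import ABC, abstractmethod

# Hedged sketch only: illustrative skeleton of the refactored hierarchy.
class BaseLossContext(ABC):
    """Loss-agnostic base; LM-specific eager/chunk logic no longer lives here."""

    @abstractmethod
    def loss(self, *args, **kwargs): ...

class LMHeadLossContext(BaseLossContext):
    """LM-head cross-entropy context (formerly CELossContext)."""

    def loss(self, *args, **kwargs):
        raise NotImplementedError("placeholder for the eager/chunk implementations")

# Deprecation alias kept for backward compatibility.
CELossContext = LMHeadLossContext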
- Add raw_input_ids, raw_inputs_embeds, raw_position_ids,
  raw_rollout_routed_experts properties to SequenceContext for
  reconstructing full tensors from SP shards
- Store raw_input_ids (full padded tensor), shard_start, shard_size
  in SequenceContext.split() for zero-communication input_ids rolling
- raw_inputs_embeds triggers a single allgather on first access and
  caches the result, amortising communication across MTP layers
- roll_sequence_context: remove SP assert; always operate on full
  tensors via raw_* properties, slice to local shard only when in SP


ghstack-source-id: cc60a14
Pull-Request: InternLM#1629
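
A minimal sketch of the "operate on the full tensor, then slice back to the local SP shard" pattern described in the commit message above; the function and argument names are illustrative, not the real SequenceContext API:

import torch

# Hedged sketch only: mimics the rolling pattern, assuming shifts < 0.
def roll_full_then_slice(raw_input_ids: torch.Tensor, shard_start: int,
                         shard_size: int, shifts: int = -1,
                         fill_value: int = -100) -> torch.Tensor:
    rolled = torch.roll(raw_input_ids, shifts=shifts, dims=-1)
    rolled[..., shifts:] = fill_value                           # don't let wrapped-around tokens leak in
    return rolled[..., shard_start:shard_start + shard_size]    # back to the local shard

full_ids = torch.arange(16).unsqueeze(0)    # stand-in for the full padded input_ids
local_rolled = roll_full_then_slice(full_ids, shard_start=8, shard_size=8)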
@gemini-code-assist

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the Multi-Token Prediction (MTP) loss mechanism by allowing the configuration and application of multiple distinct MTP losses simultaneously. This provides greater flexibility in training models by enabling different masking strategies and loss scaling factors for various MTP components. The changes involve updating the model's output structure to accommodate multiple MTP losses, modifying the loss calculation logic, and introducing new methods for applying specific masking types.

Highlights

  • Multi-MTP Loss Support: Implemented the capability to define and utilize multiple Multi-Token Prediction (MTP) loss configurations, each with unique properties.
  • Flexible MTP Masking: Introduced three distinct masking types (v1, v2, v3) for MTP loss calculation, allowing for more granular control over which tokens contribute to the loss.
  • Refactored MTP Loss Handling: Updated the model's output structure and loss aggregation mechanisms to handle MTP losses as a dictionary, enabling named and independent MTP loss components.


@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces significant changes to support multiple Multi-Token Prediction (MTP) configurations within the xtuner framework: it refactors MTP loss handling, updates configuration structures to accept lists of MTP configurations, and adapts model forward passes and sharding logic. Several critical issues were identified in the review:

  • batch size limitations in CELoss's process_loss_weights_v1 and v2,
  • a typo and TypeError risk in MoE's _should_recompute method,
  • a ValueError risk in MTPBlock's forward method due to incorrect tuple unpacking, and
  • a bug in roll_packed_tensor where fill_value was hardcoded to zero instead of using the provided parameter, potentially leading to incorrect loss calculations.

Additionally, improvements were suggested for numerical stability and readability in CELoss and for type hints on the mtp_config parameters.

Comment thread xtuner/v1/loss/ce_loss.py
Comment on lines +314 to +315
for j in range(seq_len):
    token = shifted_labels[0, j].item()


critical

This implementation appears to only support a batch size of 1. The use of shifted_labels[0, j] hardcodes the batch index to 0. If shifted_labels has a batch size greater than 1, this will only process the first sample in the batch, leading to incorrect behavior. Please refactor this to support arbitrary batch sizes, for example by iterating over the batch dimension or, preferably, by vectorizing the operations across the batch dimension.
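
For illustration, a minimal sketch of the batch-aware loop suggested here (placeholder tensor; the actual per-token logic of process_loss_weights_v1 is omitted):

import torch

# Hedged sketch only: iterate every sample instead of hardcoding batch index 0.
shifted_labels = torch.randint(0, 50, (4, 10))     # (batch_size, seq_len) stand-in
for b in range(shifted_labels.size(0)):
    for j in range(shifted_labels.size(1)):
        token = shifted_labels[b, j].item()        # per-sample lookup, no [0, j]
        # ... the existing per-token weighting logic would run here ...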

Comment thread xtuner/v1/loss/ce_loss.py Outdated
Comment thread xtuner/v1/model/moe/moe.py
Comment on lines +120 to +123
else:
    current_hidden_states = output
# Save output for this depth
mtp_outputs.append(current_hidden_states)
mtp_outputs.append(output)


high

In the else branch, output is a single tensor, but it's appended directly to mtp_outputs. The method's return type hint is list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]], and the caller in moe.py unpacks a 3-element tuple. This will cause a ValueError at runtime if a decoder layer without router results (like a dense layer) is used.

To fix this, you should ensure that a tuple is always appended. In the else branch, you could create dummy tensors for router_results and router_weights.

Suggested change

else:
    current_hidden_states = output
# Save output for this depth
mtp_outputs.append(current_hidden_states)
mtp_outputs.append(output)

else:
    current_hidden_states = output
    output = (current_hidden_states, torch.empty(0, device=current_hidden_states.device), torch.empty(0, device=current_hidden_states.device))
# Save output for this depth
mtp_outputs.append(output)

Comment on lines 64 to +71
# Fill the last |shifts| positions along dim to avoid information
# leakage across sequences. For shifts=-1 the last 1 position is
# filled; for shifts=-2 the last 2 positions are filled, etc.
fill_len = -shifts
fill_start = (end_idx - start_idx) - fill_len
fill_slice = rolled_seq.narrow(dim, fill_start, fill_len) # type: ignore[arg-type]
fill_slice.fill_(fill_value)
# zeroed; for shifts=-2 the last 2 positions are zeroed, etc.
zero_len = -shifts
zero_len = min(zero_len, (end_idx - start_idx))
zero_start = (end_idx - start_idx) - zero_len
zero_slice = rolled_seq.narrow(dim, zero_start, zero_len) # type: ignore[arg-type]
zero_slice.zero_()


high

This change hardcodes the fill value to 0 by using zero_(). However, the function's docstring correctly states that for label tensors, a fill_value like the loss ignore index (e.g., -100) should be used. Hardcoding to zero will lead to incorrect loss calculations when this function is used on label tensors for MTP.

Please revert to using the fill_value parameter, like fill_slice.fill_(fill_value). The caller of this function should then be responsible for passing the correct fill_value (e.g., the ignore index for labels).

Suggested change

# Fill the last |shifts| positions along dim to avoid information
# leakage across sequences. For shifts=-1 the last 1 position is
# filled; for shifts=-2 the last 2 positions are filled, etc.
fill_len = -shifts
fill_start = (end_idx - start_idx) - fill_len
fill_slice = rolled_seq.narrow(dim, fill_start, fill_len)  # type: ignore[arg-type]
fill_slice.fill_(fill_value)
# zeroed; for shifts=-2 the last 2 positions are zeroed, etc.
zero_len = -shifts
zero_len = min(zero_len, (end_idx - start_idx))
zero_start = (end_idx - start_idx) - zero_len
zero_slice = rolled_seq.narrow(dim, zero_start, zero_len)  # type: ignore[arg-type]
zero_slice.zero_()

# Fill the last |shifts| positions along dim to avoid information
# leakage across sequences. For shifts=-1 the last 1 position is
# filled; for shifts=-2 the last 2 positions are filled, etc.
fill_len = -shifts
fill_len = min(fill_len, (end_idx - start_idx))
fill_start = (end_idx - start_idx) - fill_len
fill_slice = rolled_seq.narrow(dim, fill_start, fill_len)  # type: ignore[arg-type]
fill_slice.fill_(fill_value)

Comment thread xtuner/v1/loss/ce_loss.py
hidden_states: torch.Tensor,
head_weight: torch.Tensor,
head_bias: torch.Tensor | None = None,
mtp_config = None,


medium

For better type safety and readability, please add a type hint for the mtp_config parameter. Based on its usage, it should be Optional[MTPConfig]. You may need to import MTPConfig from xtuner.v1.module.mtp.config and Optional from typing.

Suggested change
mtp_config = None,
mtp_config: Optional['MTPConfig'] = None,

Comment thread xtuner/v1/loss/ce_loss.py
Comment on lines +333 to +334
if loss_weight.sum().item() != 0:
    loss_weight = loss_weight * sum_loss_weight / loss_weight.sum()


medium

The division by loss_weight.sum() could lead to numerical instability if the sum is very close to zero. It's good practice to add a small epsilon to the denominator to prevent this, for example: loss_weight.sum() + 1e-12. This pattern is used elsewhere in the codebase (e.g., line 185).

This same issue exists in process_loss_weights_v2 (lines 360-361) and process_loss_weights_v3 (lines 381-382).

Suggested change

if loss_weight.sum().item() != 0:
    loss_weight = loss_weight * sum_loss_weight / loss_weight.sum()

if loss_weight.sum().item() != 0:
    loss_weight = loss_weight * sum_loss_weight / (loss_weight.sum() + 1e-12)

Comment thread xtuner/v1/loss/ce_loss.py

easy_to_use = torch.cat([shifted_labels, torch.tensor([[0]], device=shifted_labels.device)], dim=-1)

is_digit = torch.where(easy_to_use < 25, easy_to_use > 14, 0)


medium

The expression torch.where(easy_to_use < 25, easy_to_use > 14, 0) is a bit obscure. For better readability and to make the intent clearer, consider using a more direct boolean expression like ((easy_to_use > 14) & (easy_to_use < 25)).to(easy_to_use.dtype).

This also applies to process_loss_weights_v3 on line 373.

Suggested change
is_digit = torch.where(easy_to_use < 25, easy_to_use > 14, 0)
is_digit = ((easy_to_use > 14) & (easy_to_use < 25)).to(easy_to_use.dtype)

def forward( # type: ignore[override]
self, hidden_states: torch.Tensor, loss_ctx: LMHeadLossContext | None = None
self, hidden_states: torch.Tensor, loss_ctx: LMHeadLossContext | None = None,
mtp_config = None, layer_idx: int = 0,


medium

For better type safety and readability, please add a type hint for the mtp_config parameter. Based on its usage, it appears to be an MTPConfig object. You could use Optional['MTPConfig'] to avoid potential circular import issues.

Suggested change
mtp_config = None, layer_idx: int = 0,
mtp_config: 'MTPConfig' | None = None, layer_idx: int = 0,

@HAOCHENYE force-pushed the mtp branch 5 times, most recently from ba182c8 to f611c0d on March 30, 2026 at 11:04
* Fix MLLMPretrainHybridPackDataset

* [Fix] Include MLLMPretrainHybridPackDataset in LengthGroupedSampler assertion

---------

Co-authored-by: nil0x9 <nil.0x9@proton.me>