Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -194,17 +194,19 @@ add_executable(gpt2
example/gpt2/main.cc
example/common/tiny_shakespeare_dataset.cc
example/common/utils.cc
example/gpt2/checkpoint_loader.cc
example/common/checkpoint_loader.cc
example/common/tokenizer.cc
example/gpt2/checkpoint_loader.cc
)
link_infini_train_exe(gpt2)

add_executable(llama3
example/llama3/main.cc
example/common/tiny_shakespeare_dataset.cc
example/common/utils.cc
example/llama3/checkpoint_loader.cc
example/common/checkpoint_loader.cc
example/common/tokenizer.cc
example/llama3/checkpoint_loader.cc
)
link_infini_train_exe(llama3)

Expand Down
106 changes: 106 additions & 0 deletions example/common/checkpoint_loader.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
#include "example/common/checkpoint_loader.h"
Copy link
Copy Markdown
Contributor

@chen2021673 chen2021673 May 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个文件太重了,有一千多行,checkpoint相关的基建和 llama / gpt 的 save / load 都混在一起了。要不要拆分一个example/common/checkpoint_utils.h/.cc,然后保留 gpt2 和 llama3 各自的特化调用?这个可以再讨论一下

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

还是按模型拆分吧,通用的公共函数放这里,gpt2/llama3 的特化部分放 example 下模型各自文件夹里。

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


#include <cmath>
#include <cstdlib>
#include <filesystem>
#include <memory>
#include <string>
#include <vector>

#include "glog/logging.h"

#include "infini_train/include/nn/parallel/global.h"
#include "infini_train/include/tensor.h"

using namespace infini_train;
namespace nn = infini_train::nn;

ResumeFromCheckpointResult ResumeFromCheckpoint(const ResumeFromCheckpointArgs &args) {
ResumeFromCheckpointResult result;
int ddp_world_size = nn::parallel::global::GetDataParallelSize();
Comment thread
chen2021673 marked this conversation as resolved.
int tp_world_size = nn::parallel::global::GetTensorParallelSize();
int sp_world_size = nn::parallel::global::GetSequenceParallelEnabled() ? tp_world_size : 1;
int pp_world_size = nn::parallel::global::GetPipelineParallelSize();

if (args.resume_root.empty()) {
LOG(INFO) << "No checkpoint specified for resume. Starting training from scratch.";
return result;
}

std::filesystem::path resume_dir = args.resume_root;
if (args.rank.IsParallel()) {
const auto rank_dir = resume_dir / std::format("rank_{:06d}", args.rank.GlobalRank());
if (std::filesystem::exists(rank_dir)) {
resume_dir = rank_dir;
}
}

Checkpoint::Load(resume_dir, *args.model, args.optimizer.get(), args.state);

result.global_step = static_cast<int>(args.state.global_step);
if (args.state.data_batch_stride != static_cast<int64_t>(ddp_world_size)) {
LOG(FATAL) << std::format("Checkpoint data_batch_stride {} mismatches current ddp_world_size {}. "
"Proceeding with recorded data_batch_idx {}.",
args.state.data_batch_stride, ddp_world_size, args.state.data_batch_idx);
}

CHECK_EQ(args.state.tp_size, tp_world_size)
<< "TP size mismatch: checkpoint has TP=" << args.state.tp_size << ", but current run has TP=" << tp_world_size;
CHECK_EQ(args.state.sp_size, sp_world_size)
<< "SP size mismatch: checkpoint has SP=" << args.state.sp_size << ", but current run has SP=" << sp_world_size;
CHECK_EQ(args.state.pp_size, pp_world_size)
<< "PP size mismatch: checkpoint has PP=" << args.state.pp_size << ", but current run has PP=" << pp_world_size;

result.data_batch_idx = static_cast<size_t>(std::max<int64_t>(args.state.data_batch_idx, 0));
args.train_iter = args.train_loader.IteratorAtBatchIndex(result.data_batch_idx);
if (args.rank.IsMainRank()) {
LOG(INFO) << std::format("Resume training from step {}, last_lr {:.3e}, data_batch_idx {}",
args.state.global_step, args.state.last_lr, args.state.data_batch_idx);
}

return result;
}

void SaveCheckpoint(const SaveCheckpointArgs &args) {
const auto ckpt_start = std::chrono::high_resolution_clock::now();

TrainerState state;
state.global_step = args.global_step;
state.data_batch_idx = static_cast<int64_t>(args.data_batch_idx);
state.data_batch_stride = args.ddp_size;
state.last_lr = args.last_lr;
state.optimizer_type = args.optimizer_type;
state.ddp_size = args.ddp_size;
state.tp_size = args.tp_size;
state.sp_size = args.sp_size;
state.pp_size = args.pp_size;

Checkpoint::Save(args.save_dir, args.model, &args.optimizer, state);

const auto ckpt_end = std::chrono::high_resolution_clock::now();
const double ckpt_ms = std::chrono::duration<double, std::milli>(ckpt_end - ckpt_start).count();

if (!args.rank.IsMainRank()) {
return;
}

LOG(INFO) << std::format("Checkpoint saved at: {} ({:.2f} ms)", args.save_dir.string(), ckpt_ms);

if (!args.prune_step_checkpoints) {
return;
}

std::vector<std::filesystem::path> ckpts;
if (std::filesystem::exists(args.checkpoint_root_dir)) {
for (const auto &entry : std::filesystem::directory_iterator(args.checkpoint_root_dir)) {
if (entry.is_directory() && entry.path().filename().string().starts_with("checkpoint_step_")) {
ckpts.push_back(entry.path());
}
}
std::sort(ckpts.begin(), ckpts.end());
while (ckpts.size() > args.max_checkpoint_keep) {
std::filesystem::remove_all(ckpts.front());
ckpts.erase(ckpts.begin());
}
}
}
57 changes: 57 additions & 0 deletions example/common/checkpoint_loader.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#pragma once

#include "gflags/gflags.h"

#include <cstdint>
#include <cstring>
#include <filesystem>
#include <functional>
#include <limits>
#include <string>

#include "infini_train/include/checkpoint.h"
#include "infini_train/include/dataloader.h"
#include "infini_train/include/nn/modules/module.h"
#include "infini_train/include/nn/parallel/rank.h"
#include "infini_train/include/optimizer.h"

using namespace infini_train;
namespace nn = infini_train::nn;

struct ResumeFromCheckpointArgs {
std::filesystem::path resume_root;
const nn::parallel::Rank &rank;
std::shared_ptr<nn::Module> model;
std::shared_ptr<Optimizer> optimizer;
DistributedDataLoader &train_loader;
TrainerState &state;
DataLoaderIterator &train_iter;
};

struct ResumeFromCheckpointResult {
int global_step = 0;
size_t data_batch_idx = 0;
};

struct SaveCheckpointArgs {
std::filesystem::path save_dir;
int64_t global_step = 0;
size_t data_batch_idx = 0;
double last_lr = 0.0;
std::string optimizer_type;
int ddp_size = 1;
int tp_size = 1;
int sp_size = 1;
int pp_size = 1;
bool no_save_optim = false;
bool prune_step_checkpoints = false;
std::filesystem::path checkpoint_root_dir;
size_t max_checkpoint_keep = 0;
const nn::parallel::Rank &rank;
const nn::Module &model;
const Optimizer &optimizer;
};

ResumeFromCheckpointResult ResumeFromCheckpoint(const ResumeFromCheckpointArgs &args);

void SaveCheckpoint(const SaveCheckpointArgs &args);
163 changes: 138 additions & 25 deletions example/gpt2/checkpoint_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,12 @@
#include <filesystem>
#include <fstream>
#include <memory>
#include <random>
#include <string>
#include <tuple>
#include <vector>

#include "glog/logging.h"

#include "example/common/utils.h"
#include "example/gpt2/config.h"
#include "infini_train/include/nn/modules/normalization.h"
#include "infini_train/include/nn/modules/sparse.h"
#include "infini_train/include/nn/modules/transformer/causal_self_attention.h"
Expand All @@ -24,39 +21,32 @@
#include "infini_train/include/nn/parallel/tensor_parallel.h"
#include "infini_train/include/tensor.h"

#include "example/common/utils.h"
#include "example/gpt2/config.h"

using namespace infini_train;
namespace nn = infini_train::nn;

namespace {
constexpr int kRandomSeed = 42;

// TODO(dcj): make this rng generator compatible with torch later
static std::mt19937 gen{kRandomSeed};
} // namespace

namespace {
constexpr int32_t kHeaderMagic = 20240326;
constexpr int32_t kHeaderFP32Version = 3;
constexpr int32_t kHeaderBF16Version = 5;
constexpr int32_t kGPT2Magic = 20240326;
constexpr int32_t kGPT2FP32Version = 3;
constexpr int32_t kGPT2BF16Version = 5;

std::tuple<int32_t, infini_train::DataType> DetermineAndCheckVersion(const std::vector<uint8_t> &header,
size_t offset) {
std::tuple<int32_t, DataType> DetermineAndCheckVersion(const std::vector<uint8_t> &header, size_t offset) {
const auto version = BytesToType<uint32_t>(header, offset);
switch (version) {
case kHeaderBF16Version:
return {version, infini_train::DataType::kBFLOAT16};
case kHeaderFP32Version:
return {version, infini_train::DataType::kFLOAT32};
case kGPT2BF16Version:
return {version, DataType::kBFLOAT16};
case kGPT2FP32Version:
return {version, DataType::kFLOAT32};
default:
LOG(FATAL) << "Unsupported version: " << version << " at " << __FILE__ << ":" << __LINE__;
return {}; // Unreachable, but keeps compiler happy
}
}
} // namespace

namespace gpt2 {

std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath) {
std::shared_ptr<nn::TransformerModel> gpt2::LoadFromLLMC(const std::string &filepath) {
if (!std::filesystem::exists(filepath)) {
LOG(FATAL) << "File not found: " << filepath;
}
Expand All @@ -65,9 +55,9 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
const auto header = ReadSeveralBytesFromIfstream(256 * sizeof(int32_t), &ifs);

const auto magic = BytesToType<uint32_t>(header, 0);
CHECK_EQ(magic, kHeaderMagic);
CHECK_EQ(magic, kGPT2Magic);
auto [version, dtype] = DetermineAndCheckVersion(header, 4);
CHECK_EQ(version, kHeaderFP32Version);
CHECK_EQ(version, kGPT2FP32Version);

auto tp_size = nn::parallel::global::GetTensorParallelSize();

Expand Down Expand Up @@ -428,4 +418,127 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)

return local_gpt2;
}
} // namespace gpt2

void gpt2::SaveAsLLMC(const std::shared_ptr<nn::TransformerModel> &model, const std::string &filepath) {
CHECK_EQ(nn::parallel::global::GetTensorParallelSize(), 1) << "SaveAsLLMC currently supports TP=1 only.";
CHECK_EQ(nn::parallel::global::GetPipelineParallelSize(), 1) << "SaveAsLLMC currently supports PP=1 only.";

std::ofstream ofs(filepath, std::ios::binary);
CHECK(ofs.is_open()) << "Failed to open model file for write: " << filepath;

auto config = model->Config();
std::vector<int32_t> header(256, 0);
header[0] = kGPT2Magic;
header[1] = kGPT2FP32Version;
header[2] = static_cast<int32_t>(config.block_size);
header[3] = static_cast<int32_t>(config.original_vocab_size);
header[4] = static_cast<int32_t>(config.n_layer);
header[5] = static_cast<int32_t>(config.n_head);
header[6] = static_cast<int32_t>(config.n_embd);
header[7] = static_cast<int32_t>(config.vocab_size);
ofs.write(reinterpret_cast<const char *>(header.data()),
static_cast<std::streamsize>(header.size() * sizeof(int32_t)));

const auto state_dict = model->StateDict();
auto get_tensor = [&](const std::string &name) -> std::shared_ptr<Tensor> {
CHECK(state_dict.contains(name)) << "Missing tensor in GPT2 state_dict: " << name;
return state_dict.at(name);
};

auto write_tensor_fp32 = [&](const std::shared_ptr<Tensor> &tensor) {
Tensor cpu = tensor->To(Device());
if (cpu.Dtype() != DataType::kFLOAT32) {
cpu = cpu.To(DataType::kFLOAT32);
}
const auto bytes = static_cast<std::streamsize>(cpu.SizeInBytes());
ofs.write(reinterpret_cast<const char *>(cpu.DataPtr()), bytes);
};

// transformer.wte.weight
write_tensor_fp32(get_tensor(std::format("{}.{}.{}", nn::TransformerModel::kTransformerModelName,
nn::TransformerFirstStage::kWTELayerName,
nn::parallel::VocabParallelEmbedding::kParamWeightName)));

// transformer.wpe.weight
write_tensor_fp32(
get_tensor(std::format("{}.{}.{}", nn::TransformerModel::kTransformerModelName,
nn::TransformerFirstStage::kWPELayerName, nn::Embedding::kParamWeightName)));

for (int idx = 0; idx < config.n_layer; ++idx) {
write_tensor_fp32(get_tensor(std::format(
"{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, nn::TransformerChunk::kHLayerName, idx,
nn::TransformerLayer::kLn1LayerName, nn::LayerNorm::kParamWeightName)));
}
for (int idx = 0; idx < config.n_layer; ++idx) {
write_tensor_fp32(get_tensor(std::format("{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName,
nn::TransformerChunk::kHLayerName, idx,
nn::TransformerLayer::kLn1LayerName, nn::LayerNorm::kParamBiasName)));
}
for (int idx = 0; idx < config.n_layer; ++idx) {
write_tensor_fp32(get_tensor(std::format(
"{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, nn::TransformerChunk::kHLayerName, idx,
nn::TransformerLayer::kAttnLayerName, nn::CausalSelfAttention::kCAttnLayerName,
nn::parallel::ColumnParallelLinear::kParamWeightName)));
}
for (int idx = 0; idx < config.n_layer; ++idx) {
write_tensor_fp32(get_tensor(
std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName,
nn::TransformerChunk::kHLayerName, idx, nn::TransformerLayer::kAttnLayerName,
nn::CausalSelfAttention::kCAttnLayerName, nn::parallel::ColumnParallelLinear::kParamBiasName)));
}
for (int idx = 0; idx < config.n_layer; ++idx) {
write_tensor_fp32(get_tensor(
std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName,
nn::TransformerChunk::kHLayerName, idx, nn::TransformerLayer::kAttnLayerName,
nn::CausalSelfAttention::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName)));
}
for (int idx = 0; idx < config.n_layer; ++idx) {
write_tensor_fp32(get_tensor(
std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName,
nn::TransformerChunk::kHLayerName, idx, nn::TransformerLayer::kAttnLayerName,
nn::CausalSelfAttention::kCProjLayerName, nn::parallel::RowParallelLinear::kParamBiasName)));
}
for (int idx = 0; idx < config.n_layer; ++idx) {
write_tensor_fp32(get_tensor(std::format(
"{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, nn::TransformerChunk::kHLayerName, idx,
nn::TransformerLayer::kLn2LayerName, nn::LayerNorm::kParamWeightName)));
}
for (int idx = 0; idx < config.n_layer; ++idx) {
write_tensor_fp32(get_tensor(std::format("{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName,
nn::TransformerChunk::kHLayerName, idx,
nn::TransformerLayer::kLn2LayerName, nn::LayerNorm::kParamBiasName)));
}
for (int idx = 0; idx < config.n_layer; ++idx) {
write_tensor_fp32(
get_tensor(std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName,
nn::TransformerChunk::kHLayerName, idx, nn::TransformerLayer::kMlpLayerName,
nn::MLP::kCFcLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)));
}
for (int idx = 0; idx < config.n_layer; ++idx) {
write_tensor_fp32(
get_tensor(std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName,
nn::TransformerChunk::kHLayerName, idx, nn::TransformerLayer::kMlpLayerName,
nn::MLP::kCFcLayerName, nn::parallel::ColumnParallelLinear::kParamBiasName)));
}
for (int idx = 0; idx < config.n_layer; ++idx) {
write_tensor_fp32(
get_tensor(std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName,
nn::TransformerChunk::kHLayerName, idx, nn::TransformerLayer::kMlpLayerName,
nn::MLP::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName)));
}
for (int idx = 0; idx < config.n_layer; ++idx) {
write_tensor_fp32(
get_tensor(std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName,
nn::TransformerChunk::kHLayerName, idx, nn::TransformerLayer::kMlpLayerName,
nn::MLP::kCProjLayerName, nn::parallel::RowParallelLinear::kParamBiasName)));
}

write_tensor_fp32(
get_tensor(std::format("{}.{}.{}", nn::TransformerModel::kTransformerModelName,
nn::TransformerLastStage::kLnFLayerName, nn::LayerNorm::kParamWeightName)));
write_tensor_fp32(get_tensor(std::format("{}.{}.{}", nn::TransformerModel::kTransformerModelName,
nn::TransformerLastStage::kLnFLayerName, nn::LayerNorm::kParamBiasName)));

ofs.flush();
CHECK(ofs.good()) << "Failed to flush model file: " << filepath;
}
Loading
Loading