Skip to content

Commit c05d604

Browse files
committed
cleanups
1 parent abe7400 commit c05d604

3 files changed

Lines changed: 9 additions & 10 deletions

File tree

src/training/gradients.cpp

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,6 @@
55
#include "gradients.h"
66

77
#include "model.h"
8-
#include "models/llama_weights.h"
98
#include "utilities/allocator.h"
109
#include "utilities/comm.h"
1110
#include "utilities/lazy_allocator.h"
@@ -32,13 +31,13 @@ UnshardedGradientManager::UnshardedGradientManager(const TransformerConfig& cfg,
3231
mBlockGradients[i] = model.create_block_container(cfg, cfg.DType, cfg.DType);
3332
alloc_lazy.allocate(mBlockGradients[i]);
3433
alloc_lazy.commit(*alloc, EAllocationType::ON_DEVICE, "block_grad");
35-
mBlockShards[i] = shard_view(GenericTensorContainer(mBlockGradients[i]), rank, world);
34+
mBlockShards[i] = shard_view(mBlockGradients[i], rank, world);
3635
}
3736

3837
mNonBlockGradients = model.create_non_block_container(cfg, cfg.DType, cfg.DType);
3938
alloc_lazy.allocate(mNonBlockGradients);
4039
alloc_lazy.commit(*alloc, EAllocationType::ON_DEVICE, "nonblock_grad");
41-
mNonBlockShards = shard_view(GenericTensorContainer(mNonBlockGradients), rank, world);
40+
mNonBlockShards = shard_view(mNonBlockGradients, rank, world);
4241

4342
mGradEvent = create_named_event("grad_event");
4443
}
@@ -97,7 +96,7 @@ ShardedBlocksGradientManager::ShardedBlocksGradientManager(const TransformerConf
9796
mFullNonBlock = model.create_non_block_container(cfg, cfg.DType, cfg.DType);
9897
alloc_lazy.allocate(mFullNonBlock);
9998
alloc_lazy.commit(*alloc, EAllocationType::ON_DEVICE, "nonblock_grad");
100-
mNonBlockShards = shard_view(GenericTensorContainer(mFullNonBlock), rank, world);
99+
mNonBlockShards = shard_view(mFullNonBlock, rank, world);
101100

102101
mGradBuffers[0] = model.create_block_container(cfg, cfg.DType, cfg.DType);
103102
mGradBuffers[1] = model.create_block_container(cfg, cfg.DType, cfg.DType);

src/utilities/tensor.cpp

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -160,17 +160,17 @@ const Tensor& GenericTensorContainer::get_tensor(std::size_t idx) const {
160160

161161
GenericTensorContainer shard_empty_container(GenericTensorContainer&& c, int world) {
162162
// can't use visit here, because we explicitly want to iterate over empty tensors
163-
for (int i = 0; i < c.num_tensors(); ++i) {
163+
for (std::size_t i = 0; i < c.num_tensors(); ++i) {
164164
auto& t = c.get_tensor(i);
165-
if (!t.empty()) { throw std::logic_error("shard_container called with non-empty tensor"); }
165+
if (!t.empty()) { throw std::logic_error("shard_empty_container called with non-empty tensor"); }
166166
t.Sizes[0] = div_exact(t.Sizes[0], static_cast<long>(world));
167167
}
168-
return c;
168+
return std::move(c);
169169
}
170170

171171
GenericTensorContainer shard_view(const GenericTensorContainer& c, int rank, int world) {
172172
std::vector<Tensor> shards(c.num_tensors());
173-
for (int i = 0; i < c.num_tensors(); ++i) {
173+
for (std::size_t i = 0; i < c.num_tensors(); ++i) {
174174
shards.at(i) = static_cast<Tensor>(shard_view(c.get_tensor(i), rank, world));
175175
}
176176
return GenericTensorContainer{shards};

src/utilities/tensor_container.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -57,11 +57,11 @@ class GenericTensorContainer final : public SimpleTensorContainer {
5757
std::vector<Tensor> mTensors;
5858
};
5959

60-
//! shards an _empty_ container, i.e., a container in which all data pointers
60+
//! Shards an _empty_ container, i.e., a container in which all data pointers
6161
//! are `nullptr`, but sizes have been set up.
6262
GenericTensorContainer shard_empty_container(GenericTensorContainer&& c, int world);
6363

64-
//! shard a non-empty tensor container. The returned container's tensors are _views_ into
64+
//! Shards a non-empty tensor container. The returned container's tensors are _views_ into
6565
//! the original container's tensors.
6666
GenericTensorContainer shard_view(const GenericTensorContainer& c, int rank, int world);
6767

0 commit comments

Comments
 (0)