55#include " gradients.h"
66
77#include " model.h"
8- #include " models/llama_weights.h"
98#include " utilities/allocator.h"
109#include " utilities/comm.h"
1110#include " utilities/lazy_allocator.h"
@@ -32,13 +31,13 @@ UnshardedGradientManager::UnshardedGradientManager(const TransformerConfig& cfg,
3231 mBlockGradients [i] = model.create_block_container (cfg, cfg.DType , cfg.DType );
3332 alloc_lazy.allocate (mBlockGradients [i]);
3433 alloc_lazy.commit (*alloc, EAllocationType::ON_DEVICE, " block_grad" );
35- mBlockShards [i] = shard_view (GenericTensorContainer ( mBlockGradients [i]) , rank, world);
34+ mBlockShards [i] = shard_view (mBlockGradients [i], rank, world);
3635 }
3736
3837 mNonBlockGradients = model.create_non_block_container (cfg, cfg.DType , cfg.DType );
3938 alloc_lazy.allocate (mNonBlockGradients );
4039 alloc_lazy.commit (*alloc, EAllocationType::ON_DEVICE, " nonblock_grad" );
41- mNonBlockShards = shard_view (GenericTensorContainer ( mNonBlockGradients ) , rank, world);
40+ mNonBlockShards = shard_view (mNonBlockGradients , rank, world);
4241
4342 mGradEvent = create_named_event (" grad_event" );
4443}
@@ -97,7 +96,7 @@ ShardedBlocksGradientManager::ShardedBlocksGradientManager(const TransformerConf
9796 mFullNonBlock = model.create_non_block_container (cfg, cfg.DType , cfg.DType );
9897 alloc_lazy.allocate (mFullNonBlock );
9998 alloc_lazy.commit (*alloc, EAllocationType::ON_DEVICE, " nonblock_grad" );
100- mNonBlockShards = shard_view (GenericTensorContainer ( mFullNonBlock ) , rank, world);
99+ mNonBlockShards = shard_view (mFullNonBlock , rank, world);
101100
102101 mGradBuffers [0 ] = model.create_block_container (cfg, cfg.DType , cfg.DType );
103102 mGradBuffers [1 ] = model.create_block_container (cfg, cfg.DType , cfg.DType );
0 commit comments