@@ -1761,7 +1761,11 @@ def training_step(self, batch, batch_idx):
17611761 """Lightning calls this inside the training loop."""
17621762 # Separate training batch from validation batch (with the latter being used for visualizations)
17631763 train_batch , val_batch = batch ['train_batch' ], batch ['val_batch' ]
1764- graph1 , graph2 , examples_list , filepaths = train_batch [0 ], train_batch [1 ], train_batch [2 ], train_batch [3 ]
1764+
1765+ # Unpack list of input training complexes - Assume batch_size=1
1766+ train_batch = train_batch [0 ]
1767+ graph1 , graph2 = train_batch ['graph1' ], train_batch ['graph2' ]
1768+ examples_list , filepaths = [train_batch ['examples' ]], [train_batch ['filepath' ]]
17651769
17661770 # Forward propagate with network layers
17671771 logits_list , _ , _ , _ , _ = self .shared_step (graph1 , graph2 ) # The forward method must be named something new
@@ -1832,10 +1836,13 @@ def training_step(self, batch, batch_idx):
18321836 # Val Sample
18331837 # ------------
18341838 # Make a forward pass through the network for a held-out validation complex for visualization
1835- val_graph1 , val_graph2 , val_examples_list = val_batch [0 ], val_batch [1 ], val_batch [2 ][0 ]
1839+ val_batch = val_batch [0 ] # Unpack list of input validation complexes - Assume batch_size=1
1840+ val_graph1 , val_graph2 = val_batch ['graph1' ], val_batch ['graph2' ]
1841+ val_examples_list = val_batch ['examples' ]
18361842
18371843 # Forward propagate with network layers without accumulating any gradients
1838- val_logits , _ , _ , _ , _ = self .shared_step (val_graph1 , val_graph2 )[0 ].squeeze ()
1844+ val_logits_list , _ , _ , _ , _ = self .shared_step (val_graph1 , val_graph2 )
1845+ val_logits = val_logits_list [0 ].squeeze ()
18391846 val_len_1 , val_len_2 = val_logits .shape [1 :]
18401847
18411848 # Construct the predicted M x N interaction tensor and its corresponding labels
@@ -1916,7 +1923,9 @@ def training_epoch_end(self, outputs: pl.utilities.types.EPOCH_OUTPUT) -> None:
19161923 def validation_step (self , batch , batch_idx ):
19171924 """Lightning calls this inside the validation loop."""
19181925 # Make a forward pass through the network for a batch of protein complexes
1919- graph1 , graph2 , examples_list , filepaths = batch [0 ], batch [1 ], batch [2 ], batch [3 ]
1926+ batch = batch [0 ] # Unpack list of input complexes - Assume batch_size=1
1927+ graph1 , graph2 = batch ['graph1' ], batch ['graph2' ]
1928+ examples_list , filepaths = [batch ['examples' ]], [batch ['filepath' ]]
19201929
19211930 # Forward propagate with network layers
19221931 logits_list , _ , _ , _ , _ = self .shared_step (graph1 , graph2 )
@@ -2018,8 +2027,10 @@ def validation_epoch_end(self, outputs: pl.utilities.types.EPOCH_OUTPUT) -> None
20182027
20192028 def test_step (self , batch , batch_idx ):
20202029 """Lightning calls this inside the testing loop."""
2021- # Make a forward pass through the network for a batch of protein complexes (batch_size=1)
2022- graph1 , graph2 , examples_list , filepaths = batch [0 ], batch [1 ], batch [2 ], batch [3 ]
2030+ # Make a forward pass through the network for a batch of protein complexes
2031+ batch = batch [0 ] # Unpack list of input complexes - Assume batch_size=1
2032+ graph1 , graph2 = batch ['graph1' ], batch ['graph2' ]
2033+ examples_list , filepaths = [batch ['examples' ]], [batch ['filepath' ]]
20232034
20242035 # Forward propagate with network layers
20252036 logits_list , _ , _ , _ , _ = self .shared_step (graph1 , graph2 )
0 commit comments