Skip to content

Commit b817f2f

Browse files
committed
Update training, validation, and testing steps to match the shared_step() design pattern
1 parent 70c3e49 commit b817f2f

1 file changed

Lines changed: 17 additions & 6 deletions

File tree

project/utils/deepinteract_modules.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1761,7 +1761,11 @@ def training_step(self, batch, batch_idx):
17611761
"""Lightning calls this inside the training loop."""
17621762
# Separate training batch from validation batch (with the latter being used for visualizations)
17631763
train_batch, val_batch = batch['train_batch'], batch['val_batch']
1764-
graph1, graph2, examples_list, filepaths = train_batch[0], train_batch[1], train_batch[2], train_batch[3]
1764+
1765+
# Unpack list of input training complexes - Assume batch_size=1
1766+
train_batch = train_batch[0]
1767+
graph1, graph2 = train_batch['graph1'], train_batch['graph2']
1768+
examples_list, filepaths = [train_batch['examples']], [train_batch['filepath']]
17651769

17661770
# Forward propagate with network layers
17671771
logits_list, _, _, _, _ = self.shared_step(graph1, graph2) # The forward method must be named something new
@@ -1832,10 +1836,13 @@ def training_step(self, batch, batch_idx):
18321836
# Val Sample
18331837
# ------------
18341838
# Make a forward pass through the network for a held-out validation complex for visualization
1835-
val_graph1, val_graph2, val_examples_list = val_batch[0], val_batch[1], val_batch[2][0]
1839+
val_batch = val_batch[0] # Unpack list of input validation complexes - Assume batch_size=1
1840+
val_graph1, val_graph2 = val_batch['graph1'], val_batch['graph2']
1841+
val_examples_list = val_batch['examples']
18361842

18371843
# Forward propagate with network layers without accumulating any gradients
1838-
val_logits, _, _, _, _ = self.shared_step(val_graph1, val_graph2)[0].squeeze()
1844+
val_logits_list, _, _, _, _ = self.shared_step(val_graph1, val_graph2)
1845+
val_logits = val_logits_list[0].squeeze()
18391846
val_len_1, val_len_2 = val_logits.shape[1:]
18401847

18411848
# Construct the predicted M x N interaction tensor and its corresponding labels
@@ -1916,7 +1923,9 @@ def training_epoch_end(self, outputs: pl.utilities.types.EPOCH_OUTPUT) -> None:
19161923
def validation_step(self, batch, batch_idx):
19171924
"""Lightning calls this inside the validation loop."""
19181925
# Make a forward pass through the network for a batch of protein complexes
1919-
graph1, graph2, examples_list, filepaths = batch[0], batch[1], batch[2], batch[3]
1926+
batch = batch[0] # Unpack list of input complexes - Assume batch_size=1
1927+
graph1, graph2 = batch['graph1'], batch['graph2']
1928+
examples_list, filepaths = [batch['examples']], [batch['filepath']]
19201929

19211930
# Forward propagate with network layers
19221931
logits_list, _, _, _, _ = self.shared_step(graph1, graph2)
@@ -2018,8 +2027,10 @@ def validation_epoch_end(self, outputs: pl.utilities.types.EPOCH_OUTPUT) -> None
20182027

20192028
def test_step(self, batch, batch_idx):
20202029
"""Lightning calls this inside the testing loop."""
2021-
# Make a forward pass through the network for a batch of protein complexes (batch_size=1)
2022-
graph1, graph2, examples_list, filepaths = batch[0], batch[1], batch[2], batch[3]
2030+
# Make a forward pass through the network for a batch of protein complexes
2031+
batch = batch[0] # Unpack list of input complexes - Assume batch_size=1
2032+
graph1, graph2 = batch['graph1'], batch['graph2']
2033+
examples_list, filepaths = [batch['examples']], [batch['filepath']]
20232034

20242035
# Forward propagate with network layers
20252036
logits_list, _, _, _, _ = self.shared_step(graph1, graph2)

0 commit comments

Comments
 (0)