How do the hooks of the `LightningModule` interact with the hooks of the `LightningDataModule`?
Does one override the other? Previously, I was able to call `LightningDataModule.train_dataloader()` from within `LightningModule.train_dataloader()`, but the latter no longer seems to be called at all when I use `trainer.fit(model, datamodule=datamodule)`.
My use case: I'd like to modify the dataloader using a function that needs access to the optimizer, so I'd like to do it from within the model:
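```python
from typing import Any

from pytorch_lightning import LightningModule

# Assumption: wrap_data_loader is Opacus' helper that splits logical batches
# into physical batches (the keyword arguments below match its signature).
from opacus.utils.batch_memory_manager import wrap_data_loader


class Classifier(LightningModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        # model initialized here

    def train_dataloader(self) -> Any:
        dl = self.trainer.datamodule.train_dataloader()
        if not hasattr(self.trainer.datamodule, "batch_size_physical"):
            return dl  # just use the LightningDataModule's loader as is
        # otherwise, wrap the loader so batches are capped at batch_size_physical
        return wrap_data_loader(
            data_loader=dl,
            max_batch_size=self.trainer.datamodule.batch_size_physical,
            optimizer=self.optimizer,  # assumed to be stored on the model elsewhere
        )
```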
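For readers without Lightning or Opacus at hand, the delegation logic in `train_dataloader` above can be sketched in plain Python. This is a hypothetical, framework-free sketch: `DataModule`, `Model`, and `split_into_physical_batches` are illustrative names, not Lightning or Opacus APIs. The splitter only cuts each logical batch into chunks of at most `batch_size_physical` items; it does not coordinate with an optimizer the way the real `wrap_data_loader` presumably does.

```python
from typing import Iterable, Iterator, List, Optional


def split_into_physical_batches(
    batches: Iterable[List[int]], max_batch_size: int
) -> Iterator[List[int]]:
    """Hypothetical stand-in for wrap_data_loader: yield each logical batch
    in chunks of at most max_batch_size items (no optimizer involvement)."""
    if max_batch_size < 1:
        raise ValueError("max_batch_size must be >= 1")
    for batch in batches:
        for start in range(0, len(batch), max_batch_size):
            yield batch[start:start + max_batch_size]


class DataModule:
    """Hypothetical minimal stand-in for a LightningDataModule."""

    def __init__(self, data: List[int], batch_size: int,
                 batch_size_physical: Optional[int] = None) -> None:
        self.data = data
        self.batch_size = batch_size
        # Only define the attribute when requested, so hasattr() can test it.
        if batch_size_physical is not None:
            self.batch_size_physical = batch_size_physical

    def train_dataloader(self) -> List[List[int]]:
        # Logical batches of size batch_size (the last one may be smaller).
        return [self.data[i:i + self.batch_size]
                for i in range(0, len(self.data), self.batch_size)]


class Model:
    """Mirrors Classifier.train_dataloader, with the datamodule held directly."""

    def __init__(self, datamodule: DataModule) -> None:
        self.datamodule = datamodule

    def train_dataloader(self) -> Iterable[List[int]]:
        dl = self.datamodule.train_dataloader()
        if not hasattr(self.datamodule, "batch_size_physical"):
            return dl  # use the datamodule's loader as is
        return split_into_physical_batches(dl, self.datamodule.batch_size_physical)


if __name__ == "__main__":
    dm = DataModule(list(range(6)), batch_size=4, batch_size_physical=3)
    print(list(Model(dm).train_dataloader()))  # [[0, 1, 2], [3], [4, 5]]
```

Here the model, not a trainer, decides which loader is returned, which is the behaviour the post expected from `trainer.fit`.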