
LightningModule.train_dataloader()


How do the hooks on the LightningModule interact with the hooks on the LightningDataModule?
Does one override the other? Previously, I was able to call LightningDataModule.train_dataloader() from within LightningModule.train_dataloader(), but the latter no longer seems to be called at all when using trainer.fit(model, datamodule=datamodule).
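
For reference, this is roughly the call pattern in question (a minimal sketch; MyDataModule is a placeholder name, and the import path is lightning.pytorch in recent releases, pytorch_lightning in older ones):

from lightning.pytorch import Trainer

datamodule = MyDataModule()  # placeholder LightningDataModule
model = Classifier()         # defined below

# When a datamodule is passed here, the trainer sources the dataloaders
# from it, and the model's own train_dataloader() hook appears to be skipped.
trainer = Trainer(max_epochs=1)
trainer.fit(model, datamodule=datamodule)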

My use case is that I’d like to modify the dataloader using a function that needs access to the optimizer, so I’d like to do it from the model:

from typing import Any

from lightning.pytorch import LightningModule  # `pytorch_lightning` in older releases


class Classifier(LightningModule):
    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__()
        # model initialized here

    def train_dataloader(self) -> Any:
        dl = self.trainer.datamodule.train_dataloader()
        if not hasattr(self.trainer.datamodule, "batch_size_physical"):
            return dl  # just use the LightningDataModule's loader as-is
        # otherwise wrap the loader; the wrapper needs the optimizer,
        # which is why this has to live on the model rather than the datamodule
        return wrap_data_loader(
            data_loader=dl,
            max_batch_size=self.trainer.datamodule.batch_size_physical,
            # Lightning never sets `self.optimizer` itself; fetch the configured
            # optimizer from the trainer (assumes optimizers exist by this point)
            optimizer=self.trainer.optimizers[0],
        )
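
For completeness, a datamodule that would exercise the wrapping branch might look like the sketch below; treating batch_size_physical as a plain attribute is an assumption, and the random-tensor dataset is only a placeholder:

import torch
from torch.utils.data import DataLoader, TensorDataset
from lightning.pytorch import LightningDataModule


class MyDataModule(LightningDataModule):
    def __init__(self, batch_size: int = 256, batch_size_physical: int = 32):
        super().__init__()
        self.batch_size = batch_size
        # the mere presence of this attribute makes
        # Classifier.train_dataloader() wrap the loader
        self.batch_size_physical = batch_size_physical

    def train_dataloader(self) -> DataLoader:
        # placeholder for the real dataset
        dataset = TensorDataset(torch.randn(1024, 16), torch.randint(0, 2, (1024,)))
        return DataLoader(dataset, batch_size=self.batch_size)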


