
confusions about load_from_checkpoint() and save_hyperparameters()


According to Saving and loading checkpoints (basic) — PyTorch Lightning 2.1.3 documentation, there is a model like this:

import lightning as L


class Encoder(L.LightningModule):
    ...

class Decoder(L.LightningModule):
    ...

class Autoencoder(L.LightningModule):
    def __init__(self, encoder, decoder, *args, **kwargs):
        super().__init__()
        # exclude the nn.Module arguments from the saved hyperparameters
        self.save_hyperparameters(ignore=['encoder', 'decoder'])
        self.encoder = encoder
        self.encoder.freeze()
        self.decoder = decoder
        ...

# training code (trainer and datamodule are assumed to be defined elsewhere)
encoder = Encoder.load_from_checkpoint("encoder.ckpt")
decoder = Decoder(...)  # the decoder's hyperparameters go here
autoencoder = Autoencoder(encoder, decoder)
trainer.fit(autoencoder, datamodule)
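
For reference, since encoder and decoder are excluded via save_hyperparameters(ignore=...), my understanding of the linked docs is that reloading the composed model later requires passing those modules back in explicitly, roughly like this:

autoencoder = Autoencoder.load_from_checkpoint(
    "autoencoder.ckpt",
    encoder=Encoder.load_from_checkpoint("encoder.ckpt"),
    decoder=Decoder(...),  # same hyperparameters as during training
)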

We assume that the autoencoder has been stored in the autoencoder.ckpt file. There are three key points I am curious about:

  1. Does the autoencoder.ckpt file include both the encoder and decoder weights?
  2. If autoencoder.ckpt contains the encoder weights, how can I import the weights from encoder.ckpt into the autoencoder without them being overwritten?
  3. If autoencoder.ckpt does not include the decoder weights, what is the procedure to save the decoder weights separately?
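
To make questions 1 and 3 concrete, here is a rough sketch (assuming the paths above) of how I would check what autoencoder.ckpt actually contains, and of how the decoder weights could be written out on their own if they are not in there:

import torch

# list which parameter tensors the Lightning checkpoint holds
checkpoint = torch.load("autoencoder.ckpt", map_location="cpu")
encoder_keys = [k for k in checkpoint["state_dict"] if k.startswith("encoder.")]
decoder_keys = [k for k in checkpoint["state_dict"] if k.startswith("decoder.")]
print(f"encoder tensors: {len(encoder_keys)}, decoder tensors: {len(decoder_keys)}")

# save the decoder weights separately as a plain state_dict
# (`autoencoder` is the trained module from the code above)
torch.save(autoencoder.decoder.state_dict(), "decoder_state_dict.pt")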
