
Custom model definition is not included in checkpoint hyper_parameters


Hi, I have the following dummy LightningModule:

import torch
from pytorch_lightning import LightningModule

class MyLightningModule(LightningModule):
    def __init__(
        self,
        param_1: torch.nn.Module = torch.nn.Conv2d(1, 1, 1),
        param_2: torch.nn.Module = MyCustomModule(...),
    ):
        super().__init__()
        self.save_hyperparameters()
        print(self.hparams.param_1)  # prints out correctly
        print(self.hparams.param_2)  # prints out correctly

When I try to load a checkpoint via MyLightningModule.load_from_checkpoint(ckpt_path), I notice that checkpoint["hyper_parameters"] does NOT contain a key for param_2, while it DOES contain a key for param_1. I DO see hparams.param_2 printed correctly in my logger, which is really weird.
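For reference, here is a minimal sketch of how I inspect the saved hyperparameters (ckpt_path is a placeholder for my checkpoint path; weights_only=False is needed on recent PyTorch versions to unpickle full module objects):

import torch

# Load the raw checkpoint dict and list the saved hyperparameter keys
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
print(ckpt["hyper_parameters"].keys())  # contains "param_1" but not "param_2"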

For param_2 I used a network from the escnn library, which is derived from torch.nn.Module. I traced the problem back to using any layer from that library:

import escnn
import escnn.nn as enn

gspace = escnn.gspaces.rot2dOnR2(8)
field_type = enn.FieldType(gspace, [gspace.regular_repr])
param_1 = enn.R2Conv(field_type, field_type, 7)
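In case it matters, one workaround I'm considering (just a sketch, assuming the escnn module simply fails to serialize into the checkpoint) is to exclude param_2 from save_hyperparameters and supply it again explicitly at load time:

class MyLightningModule(LightningModule):
    def __init__(self, param_1: torch.nn.Module, param_2: torch.nn.Module):
        super().__init__()
        # Keep the escnn module out of hparams so the checkpoint stays picklable
        self.save_hyperparameters(ignore=["param_2"])
        self.param_2 = param_2

# At load time, the excluded argument has to be passed in again:
model = MyLightningModule.load_from_checkpoint(ckpt_path, param_2=MyCustomModule(...))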

What could be the reason that the custom model definition is not part of the checkpoint's hyper_parameters? Thanks in advance!

3 posts - 2 participants


