Channel: LightningModule - Lightning AI

Skip instances during training


Hi, I am using the LightningModule to train a neural network across many instances/GPUs. The data is imbalanced (I cannot change this), so I want to skip some samples during training to balance it.

Here is the logic of my code inside training_step():

# called from training_step(), which passes 4 batches at a time

    def forward(...):
        batch_size = len(inputs)  # passing 4 batches at a time...
        total_loss = 0.

        for batch_idx in range(batch_size):

            train_model = True
            filtered_tags = tag[batch_idx][self.most_frequent_feature_indexes]

            if filtered_tags.sum() == 0:
                # Keep only 1 in 4 data points where all scenario tags are 0
                if self.count_zero % 4 != 0:
                    train_model = False
                else:
                    self.total_zero_data += 1
                self.count_zero += 1
            elif (filtered_tags.sum() == 1) and (torch.argmax(filtered_tags) == 5):
                # Keep only 1 in 10 data points where the only tag set is tag 5
                if self.count_stationary % 10 < 9:
                    train_model = False
                self.count_stationary += 1

            if train_model:
                logits = self.model(...)
            else:
                with torch.no_grad():
                    logits = self.model(...)
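An aside on the skip branch: under DDP the gradient all-reduce is a collective operation, so every rank has to run the same number of backward passes. If one rank takes the `no_grad` branch while another rank trains, the ranks stop agreeing on which collectives to run. One way to skip a sample's contribution without desynchronizing ranks is to keep the forward/backward pass but zero the skipped sample's loss. A minimal sketch — the `masked_loss` helper below is hypothetical, not part of Lightning or the code above:

```python
import torch
import torch.nn.functional as F

def masked_loss(logits, targets, keep_mask):
    """Mean cross-entropy over kept samples only.

    Skipped samples (keep_mask == 0) contribute zero loss and zero
    gradient, but the graph is identical on every rank, so DDP's
    gradient all-reduce stays in sync.
    """
    per_sample = F.cross_entropy(logits, targets, reduction="none")
    kept = keep_mask.float()
    # guard against a batch where every sample is skipped
    denom = kept.sum().clamp(min=1.0)
    return (per_sample * kept).sum() / denom
```

PyTorch also ships a `join()` context manager on `DistributedDataParallel` specifically for uneven inputs across ranks, which may be worth a look if skipping whole steps is unavoidable.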

But I get this error:

Latest log:
[E ProcessGroupNCCL.cpp:414] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data. To avoid this inconsistency, we are taking the entire process down.
[E ProcessGroupNCCL.cpp:737] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=27, OpType=BROADCAST, Timeout(ms)=7200000) ran for 7208733 milliseconds before timing out.

There must be a way to do this, right? Thank you!
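For reference, another route is to rebalance before the batch ever reaches `forward()`, by sampling rare classes more often at data-loading time. A sketch with toy labels using `torch.utils.data.WeightedRandomSampler` — the inverse-frequency weighting here is illustrative, not taken from the post:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# toy imbalanced labels: six samples of class 0, two of class 1
labels = torch.tensor([0, 0, 0, 0, 0, 0, 1, 1])

# inverse class frequency: rare classes get a larger sampling weight
class_count = torch.bincount(labels).float()
weights = 1.0 / class_count[labels]

sampler = WeightedRandomSampler(weights, num_samples=len(labels), replacement=True)
dataset = TensorDataset(torch.randn(8, 4), labels)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)
```

Caveat: under DDP, Lightning normally replaces the sampler with a `DistributedSampler`, so a custom weighted sampler needs the automatic replacement disabled (`replace_sampler_ddp=False` in older Lightning, `use_distributed_sampler=False` in 2.x) and per-rank sharding handled yourself.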

3 posts - 2 participants
