Hi everyone,
Some metrics in torchmetrics return a dictionary when I call their compute() method. For example, calling compute() on torchmetrics.SQuAD() returns:
{'exact_match': tensor(0., device='cuda:0'), 'f1': tensor(3.2220, device='cuda:0')}
In the [mode]_step method of my LightningModule, I tried to log the SQuAD() metric object directly, but it failed:
self.log("score", METRIC_OBJECT, on_epoch=True, prog_bar=True, logger=True)
Error message: The .compute() return of the metric logged as ‘score’ must be a tensor.
Is there a way to log these metrics without overriding on_[mode]_epoch_end?
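One possible workaround, sketched under assumptions: instead of passing the metric object to self.log, call the metric in the step to get this batch's dict of tensors and pass that dict to self.log_dict with on_epoch=True. If Lightning's default epoch reduction (reduce_fx="mean") is a mean weighted by batch_size, which is an assumption worth checking against your Lightning version, this reproduces the epoch-level compute() result for SQuAD, because its exact_match and f1 are per-sample averages. The plain-Python code below only checks that arithmetic; the Lightning snippet lives in comments and is untested, and batch_scores and weighted_epoch_mean are hypothetical helpers, not library functions.

```python
import math
from typing import Dict, List, Tuple

# Sketch of the arithmetic behind logging a dict-returning metric per step.
# Assumption: Lightning's on_epoch reduction with the default reduce_fx="mean"
# is a mean weighted by the batch_size passed to self.log / self.log_dict.
#
# Hypothetical usage inside a LightningModule (self.squad = torchmetrics.SQuAD()):
#     def validation_step(self, batch, batch_idx):
#         preds, target = ...  # your SQuAD-format predictions and targets
#         scores = self.squad(preds, target)  # dict of tensors for this batch
#         self.log_dict(scores, on_step=False, on_epoch=True,
#                       prog_bar=True, logger=True, batch_size=len(preds))

Sample = Tuple[float, float]  # hypothetical per-sample (exact_match, f1), each in [0, 1]


def batch_scores(samples: List[Sample]) -> Dict[str, float]:
    """Stand-in for a SQuAD-style metric: per-sample averages, scaled to percent."""
    n = len(samples)
    return {
        "exact_match": 100.0 * sum(em for em, _ in samples) / n,
        "f1": 100.0 * sum(f1 for _, f1 in samples) / n,
    }


def weighted_epoch_mean(step_logs: List[Tuple[Dict[str, float], int]]) -> Dict[str, float]:
    """Combine per-step dicts into one epoch value, weighting each step by its batch size."""
    totals: Dict[str, float] = {}
    seen = 0
    for scores, batch_size in step_logs:
        for key, value in scores.items():
            totals[key] = totals.get(key, 0.0) + value * batch_size
        seen += batch_size
    return {key: total / seen for key, total in totals.items()}


if __name__ == "__main__":
    # Two batches of unequal size.
    batches = [[(1.0, 1.0), (0.0, 0.5)], [(1.0, 0.8)]]
    epoch = weighted_epoch_mean([(batch_scores(b), len(b)) for b in batches])
    # Score computed over all samples at once, like an epoch-level compute().
    overall = batch_scores([s for b in batches for s in b])
    for key in overall:
        assert math.isclose(epoch[key], overall[key])
    print(epoch)
```

With these batches, an unweighted mean of the two per-batch values would give exact_match = 75.0 instead of about 66.67, which is why the sketch passes batch_size explicitly. Note also that calling self.squad(...) still updates the module's internal state, which Lightning presumably won't reset when it is only given plain tensors, so you may want to reset it yourself.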