Metrics是torchmetrics库里的度量类基类,本篇大体介绍一下它是如何工作的
它也是一个Model
由类的定义可以看到,它继承与两个类,一个是我们熟悉的Module,另外一个是ABC,所以它从行为上来说,跟Module一样
class Metric(Module, ABC):
第一步 __call__
它的行为同Model,所以通过__call__调用。
所以,第一步是Model的__call__
__call__ : Callable[..., Any] = _call_impl
__call__实际是直接调用_call_impl,这里在727行,直接调用self.forward
第二步 forward
同pytorch里的module子类一样,重载forward方法。
Metrics的forward函数,内部定义了update函数和compute函数,所以自定义的Metrics需要重载update和compute
这里有个参数compute_on_step,默认是True。默认情况下,update会在上面一行192行调用一次;然后在204行调用一次。 所以在默认情况下会调用两次。
compute方法仅仅在compute_on_step为True时调用,且在此时才有返回值
第三步 update 和 compute
每一个Metrics的子类都需要重载这两个函数 (默认compute_on_step=True的情况)
下面以一个自定义的Metrics子类为例
内部用于计算的变量通过add_state注册,然后在update里更新,最后在compute里运算出结果
class MyAccuracy(Metric):
def __init__(self, dist_sync_on_step=False):
super().__init__(dist_sync_on_step=dist_sync_on_step)
self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
def update(self, preds: torch.Tensor, target: torch.Tensor):
# update metric states
self.correct += torch.sum(preds == target)
self.total += target.numel()
def compute(self):
# compute final result
return self.correct.float() / self.total
metrics = MyAccuracy()
preds = torch.tensor([0, 1, 0])
target = torch.tensor([1, 1, 0])
t = metrics(preds, target)
print(t)
结果为tensor(0.6667)