Metrics是torchmetrics库里的度量类基类,本篇大体介绍一下它是如何工作的

它也是一个Model

由类的定义可以看到,它继承与两个类,一个是我们熟悉的Module,另外一个是ABC,所以它从行为上来说,跟Module一样

class Metric(Module, ABC):

第一步 __call__

它的行为同Model,所以通过__call__调用。

所以,第一步是Model的__call__

__call__ : Callable[..., Any] = _call_impl

__call__实际是直接调用_call_impl,这里在727行,直接调用self.forward
【pytorch】Metrics的工作原理_yacc

第二步 forward

同pytorch里的module子类一样,重载forward方法。
Metrics的forward函数,内部定义了update函数和compute函数,所以自定义的Metrics需要重载update和compute

【pytorch】Metrics的工作原理_子类_02

这里有个参数compute_on_step,默认是True。默认情况下,update会在上面一行192行调用一次;然后在204行调用一次。 所以在默认情况下会调用两次

compute方法仅仅在compute_on_step为True时调用,且在此时才有返回值

第三步 update 和 compute

每一个Metrics的子类都需要重载这两个函数 (默认compute_on_step=True的情况)
下面以一个自定义的Metrics子类为例

内部用于计算的变量通过add_state注册,然后在update里更新,最后在compute里运算出结果

class MyAccuracy(Metric):
    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        # update metric states
        self.correct += torch.sum(preds == target)
        self.total += target.numel()

    def compute(self):
        # compute final result
        return self.correct.float() / self.total


metrics = MyAccuracy()

preds  = torch.tensor([0, 1, 0])
target = torch.tensor([1, 1, 0])

t = metrics(preds, target)
print(t)

结果为tensor(0.6667)