
I'm looking for a way to write the equivalent of the following PyTorch module in Flax, but I haven't found one. The important thing is that the constant should be saved and loaded as part of a checkpoint.

import torch
import torch.nn as nn

class SillyModule(nn.Module):
    def __init__(self):
        super().__init__()
        # A registered buffer is stored in the state_dict (so it is
        # checkpointed) but is not a trainable parameter.
        self.register_buffer('constant', torch.randn(1, 128))

    def forward(self, x):
        return torch.matmul(x, self.constant)

Does anybody know how to do this? What is the equivalent of register_buffer in flax?

user1635327
ysig

0 Answers