JAX에서 모델을 사용할 때, 주로 Flax를 이용하여 모델을 다룬다. Flax official page: https://flax.readthedocs.io/en/latest/index.html Flax Functional API Flax’s functional API radically redefines what Modules can do via lifted transformations like vmap, scan, etc, while also enabling seamless integration with other JAX libraries like Optax and Chex. flax.readthedocs.io Flax에도 Pytorch에서 nn.Module과 동일한 Flax.linen.Module이 ..