카테고리 없음

[JAX] 모델 구조 확인하기

juice_moon 2023. 8. 1. 23:05

JAX에서 모델을 사용할 때, 주로 Flax를 이용하여 모델을 다룬다. 

Flax official page: https://flax.readthedocs.io/en/latest/index.html

 

Flax

Functional API Flax’s functional API radically redefines what Modules can do via lifted transformations like vmap, scan, etc, while also enabling seamless integration with other JAX libraries like Optax and Chex.

flax.readthedocs.io

Flax에도 Pytorch에서 nn.Module과 동일한 Flax.linen.Module이 존재하고, 이를 이용해서 모델 구현을 하게 되는다. (모델은 pytree 구조로 이루어져있음!)

Pytorch는 torchsummary 라이브러리를 사용하거나 그냥 print를 하면 간단하게 모델 구조를 확인할 수 있는데, Flax도 이와 비슷하게 clu에서 제동하는 parameter_overview를 이용해서  모델 구조 확인 가능하다. 

from clu import parameter_overview

params = model.params
print(parameter_overview.get_parameter_overview(params))