JAX 사용할 때, 계속해서 "AttributeError: partially initialized module 'jax' has no attribute 'version' (most likely due to a circular import)" 이런 오류가 발생.
고치면 또 까먹고 고치면 또 까먹어서 정리하는 글.
저 에러는 jaxlib랑 jax가 버전이 맞지 않아서 발생하는 오류라고 한다.
각 버전을 먼저 확인해주기 위해 아래를 실행:
pip list | grep jax
그리고 각 jaxlib랑 jax를 업데이트 진행하면 해결 가능.
pip install -U jax jaxlib
업데이트 후 각 버전 확인해보면 버전이 동일해진걸 확인할 수 있다.
그런데 가끔 colab 환경에서 진행할때 업데이트를 해도 에러가 계속해서 발생하는 경우도 있다..
이럴땐 그냥 런타임 다시 돌리는게 답 :)