카테고리 없음

[JAX] Version Error 해결 방법

juice_moon 2023. 8. 1. 22:31

JAX 사용할 때, 계속해서 "AttributeError: partially initialized module 'jax' has no attribute 'version' (most likely due to a circular import)" 이런 오류가 발생.

 

고치면 또 까먹고 고치면 또 까먹어서 정리하는 글.

 

저 에러는 jaxlib랑 jax가 버전이 맞지 않아서 발생하는 오류라고 한다.

 

각 버전을 먼저 확인해주기 위해 아래를 실행:

pip list | grep jax

 

그리고 각 jaxlib랑 jax를 업데이트 진행하면 해결 가능.

pip install -U jax jaxlib

업데이트 후 각 버전 확인해보면 버전이 동일해진걸 확인할 수 있다.

그런데 가끔 colab 환경에서 진행할때 업데이트를 해도 에러가 계속해서 발생하는 경우도 있다..

이럴땐 그냥 런타임 다시 돌리는게 답 :)