rueki

Pytorch로 FLOPs 계산하는 방법 본문

pytorch

Pytorch로 FLOPs 계산하는 방법

륵기 2022. 9. 27. 17:54
728x90
반응형

https://github.com/facebookresearch/fvcore

 

GitHub - facebookresearch/fvcore: Collection of common code that's shared among different research projects in FAIR computer vis

Collection of common code that's shared among different research projects in FAIR computer vision team. - GitHub - facebookresearch/fvcore: Collection of common code that's shared among dif...

github.com

 

from fvcore.nn import FlopCountAnalysis, flop_count_table


model = "선언할 모델"
input_img = torch.ones("shape 입력")

flops = FlopCountAnalysis(Teacher_model, input_img)

print(flops.total()) # kb단위로 모델전체 FLOPs 출력해줌
print(flop_count_table(flops)) # 테이블 형태로 각 연산하는 모듈마다 출력해주고, 전체도 출력해줌

 

- 예시 출력 결과 (Table 형태)

 

 

  • FLOPs = Floating point operations -> 모델 연산량 의미
  • FLOPS = Floating point operations per second -> 컴퓨팅 소스 관련 = 1초간 얼마나 많은 부동소수점 연산을 하는지
728x90
반응형
Comments