pytorch
Pytorch로 FLOPs 계산하는 방법
륵기
2022. 9. 27. 17:54
728x90
반응형
https://github.com/facebookresearch/fvcore
GitHub - facebookresearch/fvcore: Collection of common code that's shared among different research projects in FAIR computer vis
Collection of common code that's shared among different research projects in FAIR computer vision team. - GitHub - facebookresearch/fvcore: Collection of common code that's shared among dif...
github.com
from fvcore.nn import FlopCountAnalysis, flop_count_table
model = "선언할 모델"
input_img = torch.ones("shape 입력")
flops = FlopCountAnalysis(Teacher_model, input_img)
print(flops.total()) # kb단위로 모델전체 FLOPs 출력해줌
print(flop_count_table(flops)) # 테이블 형태로 각 연산하는 모듈마다 출력해주고, 전체도 출력해줌
- 예시 출력 결과 (Table 형태)
- FLOPs = Floating point operations -> 모델 연산량 의미
- FLOPS = Floating point operations per second -> 컴퓨팅 소스 관련 = 1초간 얼마나 많은 부동소수점 연산을 하는지
728x90
반응형