📌 Transfer Learning이란?
작업을 위해 어느정도 학습된 매개변수를 관련 작업의 시작점으로 재사용하는 학습 방법으로,
사전 학습된 네트워크의 매개변수를 기반으로 현재 네트워크 매개변수를 초기화하여 활용
데이터가 충분하지 않거나, GPU가 고사양이지 않아도 성능을 낼 수 있는 방식 중 하나
📌 Network Saving & Loading 실습
1. 모델 전체 저장
구조, 파라미터 등 전체가 저장되며, 모델로써 불러와 사용하면 됨
[saving]
torch.save(model, "model.pt")
[Loading]
with torch.no_grad():
model = torch.load("model.pt")
print(model)
⬇️ 출력 결과
2. 모델 state 저장
모델의 상태가 저장되며, 모델의 구조를 알 때 활용 가능
모델의 구조를 먼저 잡아준 후, 해당 구조에 state를 적용하는 방식
[saving]
torch.save(model.state_dict(), "model_state_dict.pt")
[Loading]
with torch.no_grad():
model_state_dict = torch.load("model_state_dict.pt")
model.load_state_dict(model_state_dict)
print(model_state_dict.keys())
⬇️ 출력 결과