본문 바로가기
[AI]/인공지능 이론 및 실습

[인공지능] Transfer Learning : 전이학습이란?, Network Saving & Loading 실습

by seom-j 2024. 2. 8.

 

📌 Transfer Learning이란?

작업을 위해 어느정도 학습된 매개변수를 관련 작업의 시작점으로 재사용하는 학습 방법으로,

사전 학습된 네트워크의 매개변수를 기반으로 현재 네트워크 매개변수를 초기화하여 활용

 

데이터가 충분하지 않거나, GPU가 고사양이지 않아도 성능을 낼 수 있는 방식 중 하나

 

 

📌 Network Saving & Loading 실습

1. 모델 전체 저장

구조, 파라미터 등 전체가 저장되며, 모델로써 불러와 사용하면 됨

 

[saving]

torch.save(model, "model.pt")

 

[Loading]

with torch.no_grad():
  model = torch.load("model.pt")
  print(model)

⬇️ 출력 결과

 

2. 모델 state 저장

모델의 상태가 저장되며, 모델의 구조를 알 때 활용 가능

모델의 구조를 먼저 잡아준 후, 해당 구조에 state를 적용하는 방식

 

[saving]

torch.save(model.state_dict(), "model_state_dict.pt")

 

[Loading]

with torch.no_grad():
  model_state_dict = torch.load("model_state_dict.pt")
  model.load_state_dict(model_state_dict)
  print(model_state_dict.keys())

⬇️ 출력 결과