본문 바로가기

Upstage AI Lab 2기

Pytorch Lightning

PyTorch에 대한 high-level 인터페이스

코드 템플릿으로써 기능 -> 더 간결, 정돈

 

  1. 코드 추상화 및 하드웨어 호출 자동화
    • 기존의 PyTorch에서는 model, optimizer, training loop 등을 전부 따로 구현해야 했지만 lightning module 안에 한꺼번에 구현되어 있음.
    • .to(device) 안 해도 됨 (자동으로 하드웨어 호출)
  2. 다양한 콜백 함수와 로깅
    • early stopping 등의 콜백 함수
    • 로깅 -> tensorboard, wandb 등과 연동 가능
  3. 16-bit precision - 계산속도 향상, 메모리 사용량 감소
    • quantization - 너무 큰 모델의 경우 모두 로드해서 쓰기 어렵기 때문에 모델 크기를 줄여서 GPU에 올려줌

 

Lightning Module을 상속받아 모델의 구조, 손실함수, 학습 및 평가 방법과 최적화 알고리즘을 클래스에 선언

Lightning Module의 구성

  • __init__ : 모델의 레이어 초기화, loss function, evaluation metric 등 선언
  • forward
  • configure_optimizers : 반환은 optimizer, scheduler( -> learning rate scheduler, 생략 가능) 순서로!
  • training_step : 미니 배치에 대한 loss 반환, -> optimizer.zero_grad(), loss.backward(), optimizer.step() 작성 안해도 됨

[[PyTorch] (6-1) 파이토치 라이트닝 소개] 강의 자료 중

 

구현해야하는 method 3개

validation_step, test_step, predict_step - 평가 및 추론에 쓰이는 method

  • validation_step, test_step -> 미니 배치에 대한 성능 평가 (loss 또는 metric) 반환
  • predict_step -> 미니 배치에 대한 예측결과 반환

 

Trainer : Lightning Module의 메서드 이용해 모델 학습 실행

trainer = Trainer(     )

trainer.fit(model, train_dataloader, valid_dataloader) - 학습과 평가를 반복

장점 : 분산학습환경을 자동으로 관리, epoch나 반복물을 명시하지 않아도 됨.

[[PyTorch] (6-1) 파이토치 라이트닝 소개] 강의 자료 중

trainer.validate() -> validation_step 호출

trainer.test() -> test_step 호출

trainer.predict() -> predict_step 호출