Posts 우리가 PyTorch Lightning을 써야 하는 이유
Post
Cancel

우리가 PyTorch Lightning을 써야 하는 이유

다음 글은 PyTorch Lightning 라이브러리에 있는 여러 내용들을 참고하여 작성했습니다.


PyTorch Lightning

The lightweight PyTorch wrapper for ML researchers. Scale your models. Write less boilerplate.



PyTorch Geometric에 이어 PyTorch Lightning라는 라이브러리를 소개하려고 합니다. “도대체 왜 PyTorch를 사용하면서, 거기에 또 다른 라이브러리들까지 사용해야 할까?”라는 의문이 들 수 있는데요. 이는 정말 다루고 싶은 문제들에 더 집중하고 싶기 때문이라고 생각합니다. 마찬가지로, PyTorch Lightning 또한 이런 목적에서 탄생했습니다. 몇 가지 예제를 통해 PyTorch Lightning 라이브러리에 대해 살펴보도록 하겠습니다.





PyTorch Lightning이란 무엇인가?

PyTorch Lightning은 PyTorch에 대한 High-level 인터페이스를 제공하는 오픈소스 Python 라이브러리입니다. PyTorch만으로도 충분히 다양한 AI 모델들을 쉽게 생성할 수 있지만 GPU나 TPU, 그리고 16-bit precision, 분산학습 등 더욱 복잡한 조건에서 실험하게 될 경우, 코드가 복잡해집니다. 따라서 코드의 추상화를 통해, 프레임워크를 넘어 하나의 코드 스타일로 자리 잡기 위해 탄생한 프로젝트가 바로 PyTorch Lightning입니다.

위 그림은 MNIST 예제에 대해 PyTorch와 PyTorch Lightning을 각각 사용해 작성한 코드입니다. 왼쪽보다 오른쪽이 훨씬 정돈된 느낌을 주는 데요. 특히, 기존 PyTorch를 사용하시는 분들은 파란 영역에 해당하는 모델 학습/평가 부분이 단 2줄로 줄어들었다는 사실을 확연히 보실 수 있습니다. 이에 대한 기능은 아래에서 조금 더 자세히 설명해 드리도록 하겠습니다.




우리가 Lightning을 써야하는 이유

라이브러리를 구현해주신 분들이 추천한 Lightning 사용 대상은 다음과 같습니다.

  • PyTorch를 사용하는 모든 사람
  • 딥러닝 연구자들
  • 추상화하여 딥러닝 모델을 훈련하고 싶은 사람들
  • PyTorch로 딥러닝 모델을 서비스화하는 엔지니어들

(저자들은 특히, dataset, dataloader, train/valid/test loop 등을 사용한 기존 PyTorch 예제를 몇 가지 살펴본 뒤, PyTorch Lightning을 적용해보는 것이 가장 효율적이다! 라고 말하고 있네요. 전적으로 공감합니다!)


먼저, PyTorch Lightning이 매력적인 이유는, (저처럼 추상화를 하지 않고) 코드를 작성하던 기존 PyTorch 사용자들이 더욱 정돈된 코드 스타일을 갖추게 된다는 점입니다. 제가 아직도 써먹는… 모델 학습/평가 과정 템플릿을 통해 구체적으로 설명해 드리도록 하겠습니다.


아래 예제는 크게 <TRAINING LOOP>와 <VALIDATION LOOP>로 이루어져 있습니다. mnist_trainmnist_val이라는 dataloader를 통해 배치학습을 진행 중입니다. (제가 처음 이 코드를 볼 때는 정말 힘들었습니다 :cold_sweat: 만약 여기에 early stooping 등의 기법을 추가한다면 더욱 복잡한 코드가 되겠죠…?)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# ----------------
# TRAINING LOOP
# ----------------
num_epochs = 1
for epoch in range(num_epochs):

  # TRAINING LOOP
  for train_batch in mnist_train:
    x, y = train_batch

    logits = pytorch_model(x)
    loss = cross_entropy_loss(logits, y)
    print('train loss: ', loss.item())

    loss.backward()

    optimizer.step()
    optimizer.zero_grad()

  # VALIDATION LOOP
  with torch.no_grad():
    val_loss = []
    for val_batch in mnist_val:
      x, y = val_batch
      logits = pytorch_model(x)
      val_loss.append(cross_entropy_loss(logits, y).item())

    val_loss = torch.mean(torch.tensor(val_loss))
    print('val_loss: ', val_loss.item())


이를 PyTorch Lightning을 바꿔본다면 어떻게 될까요?

1
2
3
4
5
# train
model = LightningMNISTClassifier()
trainer = pl.Trainer()

trainer.fit(model)

참 쉽죠…? 뭔가 sklearn.fit 메소드를 보는 듯합니다.


이 앞에서 살펴본 건 단순히 “복잡한 양의 코드 작업이 간단하게 줄었다”라는 부분입니다. 그렇다면, 그 복잡한 양의 코드 작업은 어떻게 변형(추상화)되었을까요? 이 부분에 바로 PyTorch Lightning의 핵심이 담겨 있습니다.

크게 파란색 영역인 Trainer와 빨간색 영역인 Lightning Module로 나누어 살펴볼 수 있습니다.

먼저, 파란색 영역Trainer는 모델의 학습에 관여되는 engineering을 담당하는 클래스라고 볼 수 있습니다. 모델의 학습 epoch이나 batch 등의 상태뿐만 아니라, 모델을 저장해 로그를 생성하는 부분까지 전반적으로 담당합니다. 실제로 코드에서는 pl.Trainer()라고 정의하면 끝입니다.

두 번째로, 빨간색 영역Lightning Module은 모델 내부의 구조를 설계하는 research & science 클래스라고 생각할 수 있습니다. 모델의 구조나 데이터 전처리, 손실함수 등의 설정 등을 통해 모델을 초기화하게 됩니다. 실제로 코드에서는 pl.LightningModule클래스를 상속받아 새로운 lightningmodule클래스를 생성합니다. 기존 PyTorch의 nn.Module과 같은 방식이라고 보시면 됩니다.

결국 두 가지의 큰 클래스를 통해, 복잡한 양의 작업들을 2가지 영역으로 추상화할 수 있게 됩니다.



이밖에 PyTorch Lightning 장점은 다음과 같습니다.

  • 굉장히 유연함
    • PyTorch Code에 아주 적합한 구조
    • Trainer를 다양한 방식으로 Override 가능
    • Callback System을 통해 추가적인 작업을 할 수 있음
  • 딥러닝 학습 시 다뤄야할 부분을 잘 구조화 하였음
    • Lightning module
    • Trainer
    • 기타(Callbacks, Loggers…)
  • 다양한 학습 방법에 적용가능함
    • GPU, TPU learning
    • 16-bit precision
  • PyTorch Ecosystem에 속해 있음
    • 엄격한 Testing 과정
    • PyTorch 친화적
  • 다양한 예제와 풍부한 Documentation
  • 많은 contributor들이 존재함
  • integration with logging/visualization frameworks
    • Tensorboard, MLFLow, Neptune.ai, Comet.ml…
  • 기타 등등…




Lightning의 핵심 요소 2가지

PyTorch Lightning은 크게 2가지 영역으로 추상화하여, 코드 스타일의 혁신을 추구하고 있는데요.
이 2가지 영역의 핵심 요소, LightningModuleTrainer에 대해 더 자세히 살펴보도록 하겠습니다.


LightningModule 클래스

먼저 LightningModule에 대해 살펴보도록 하겠습니다.

  • 1) 모델의 기본적인 구조정의 (기존 코드와 동일)

    • 기존 모델을 초기화하듯이 그대로 사용 가능

      1
      2
      3
      4
      5
      6
      7
      
        def __init__(self):
          super(LightningMNISTClassifier, self).__init__()
          
          # mnist images are (1, 28, 28) (channels, width, height) 
          self.layer_1 = torch.nn.Linear(28 * 28, 128)
          self.layer_2 = torch.nn.Linear(128, 256)
          self.layer_3 = torch.nn.Linear(256, 10)
      
    • 기존 모델의 forward를 정의하는 부분을 그대로 사용 가능

      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      16
      17
      18
      19
      20
      21
      
      def forward(self, x):
            batch_size, channels, width, height = x.size()
          
            # (b, 1, 28, 28) -> (b, 1*28*28)
            x = x.view(batch_size, -1)
          
            # layer 1 (b, 1*28*28) -> (b, 128)
            x = self.layer_1(x)
            x = torch.relu(x)
          
            # layer 2 (b, 128) -> (b, 256)
            x = self.layer_2(x)
            x = torch.relu(x)
          
            # layer 3 (b, 256) -> (b, 10)
            x = self.layer_3(x)
          
            # probability distribution over labels
            x = torch.log_softmax(x, dim=1)
          
            return x
      
    • 손실함수 또한 클래스 내부에 정의하여 사용하는 것이 구조화되어 좋음

      1
      2
      
      def cross_entropy_loss(self, logits, labels):
          return F.nll_loss(logits, labels)
      


  • 2) 모델 학습 루프 (Training Loop Sturcutre)

    • 복잡하게 작성하던 내용을 추상화한 부분
    • 일반적인 루프 패턴은 루프마다 3가지 메소드를 가지고 있음
    • (Training, validation, test loop) * (___step(스텝마다), ___step_end(스텝 종료), ___epoch_end(1 epoch 종료))
    • 해당되는 이름에 루프 패턴을 붙여서 정의
    • 일반적으로 쓰는 구조는 다음과 같음
    • training_step - validation_step - validation_epoch_end

      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      16
      17
      18
      19
      20
      21
      
      def training_step(self, train_batch, batch_idx):
            x, y = train_batch
            logits = self.forward(x)
            loss = self.cross_entropy_loss(logits, y)
          
            logs = {'train_loss': loss}
            return {'loss': loss, 'log': logs}
          
      def validation_step(self, val_batch, batch_idx):
            x, y = val_batch
            logits = self.forward(x)
            loss = self.cross_entropy_loss(logits, y)
            return {'val_loss': loss}
          
      def validation_epoch_end(self, outputs):
            # called at the end of the validation epoch
            # outputs is an array with what you returned in validation_step for each batch
            # outputs = [{'loss': batch_0_loss}, {'loss': batch_1_loss}, ..., {'loss': batch_n_loss}] 
            avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
            tensorboard_logs = {'val_loss': avg_loss}
            return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}
      


  • 3) 데이터 준비
    • PyTorch의 데이터 준비하는 과정을 크게 5가지 형태로 구조화
      • 1) 다운로드
      • 2) 데이터 정리 혹은 메모리 저장
      • 3) 데이터셋 로드
      • 4) 데이터 전처리 (특히, transforms를 의미)
      • 5) dataloader 형태로 wrapping
    • 이에 맞게 코드를 추상화
      • prepare_data()
      • train_dataloader, val_dataloader, test_dataloader
    • 코드 참조

      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      16
      17
      18
      19
      
      def prepare_data(self):
          # transforms for images
          transform=transforms.Compose([transforms.ToTensor(), 
                                        transforms.Normalize((0.1307,), (0.3081,))])
                
          # prepare transforms standard to MNIST
          mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
          mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transform)
              
          self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])
          
        def train_dataloader(self):
          return DataLoader(self.mnist_train, batch_size=64)
          
        def val_dataloader(self):
          return DataLoader(self.mnist_val, batch_size=64)
          
        def test_dataloader(self):
          return DataLoader(self,mnist_test, batch_size=64)
      


  • 4) optimizer 설정

    1
    2
    
    def configure_optimizers(self):
      return torch.optim.Adam(self.parameters(), lr=0.02)
    


기존 PyTorch 학습 과정에 관여하는 여러 코드가 위와 같은 추상화된 함수형태로 LightningModule안에 포함되어 있다는 것을 확인할 수 있습니다. 특히, 상속받은 LightningModule 클래스는 위와 같은 함수들을 순서에 따라 실행하는데, 이를 바로 Lifecycle이라고 부릅니다. (즉, 해당하는 순서에 따라 함수를 작성하는 것이 중요합니다.)

  1. __init__
  2. prepare_data
  3. configure_optimizers
  4. train_dataloader
  5. val_dataloader
  6. test_dataloader (.test()가 호출될 때 호출)

또한, 각 배치와 에폭마다 루프 메소드는 함수 이름에 맞게 정해진 순서대로 호출됩니다.

  1. validation_step : 배치마다 실행
  2. validation_epoch_end : 에폭마다 실행




Trainer 클래스

다음은 Trainer 입니다.

  • 기본 사용

    1
    2
    3
    4
    5
    6
    
    from pytorch_lightning import Trainer
      
    model = MyLightningModule()
      
    trainer = Trainer()
    trainer.fit(model)
    
  • main.py로 작성 시

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    
    from argparse import ArgumentParser
      
    def main(hparams):
        model = LightningModule()
        trainer = Trainer(gpus=hparams.gpus)
        trainer.fit(model)
      
    if __name__ == '__main__':
        parser = ArgumentParser()
        parser.add_argument('--gpus', default=None)
        args = parser.parse_args()
      
        main(args)
    
    1
    2
    
    # 실행
    $ python main.py --gpus 2
    
  • Testing

    1
    
    trainer.test()
    
  • Deployment / prediction

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    
    # load model
    pretrained_model = LightningModule.load_from_checkpoint(PATH)
    pretrained_model.freeze()
      
    # use it for finetuning
    def forward(self, x):
        features = pretrained_model(x)
        classes = classifier(features)
      
    # or for prediction
    out = pretrained_model(x)
    api_write({'response': out}
    


이밖에도 다양한 Trainer flags 옵션이 존재합니다.

  • callbacks
  • logger
  • max_epochs
  • min_epochs
  • distributed_backend
  • overfit_pct




PyTorch Lightning에 contribute하기

현재 PyTorch Lightning은 꾸준한 IssuesPull requests가 올라오고 있습니다. 해당 라이브러리를 사용하면서 불편하거나, 개선되어야 할 점을 제시했을 떄, contributor들이 적극적으로 디스커션하는 것을 볼 수 있습니다. 저도 아주 단순한 내용이긴 하지만… 태어나자마자 처음으로 Issues부터 시작해 PR까지 도전했던 라이브러리가 바로 PyTorch Lightning입니다. (해당 라이브러리 예제에 argparser 옵션에 관련된 내용이었습니다.)

Desktop View

알고보니 이전에 저의 내용보다 더 포괄적으로, default argparser를 만드는 것이 좋겠다는 이슈를 제시하신 분이 있었습니다. 이를 여러 개발자 분이 해결해, 해당 클래스에 get_init_arguments_and_types()라는 classmethod로 구현된 것을 확인할 수 있었습니다. (argparser에 대한 기초 개념을 공부해야 하겠다는 생각이 들었습니다ㅠㅠ)

Desktop View





PyTorch Lightning 소개를 마치며

PyTorch Lightning은 2018년도부터 시작해 현재까지 활성화된 PyTorch Project입니다. 현재 lightning과 비슷한 역할을 하는 High-level api로는 keras, Ignite, fast.ai 등이 있습니다. 각각의 장단점을 살펴보는 것도 좋을 것 같네요!

이러한 라이브러리들을 사용하면서, 장단점을 비교할 때 어떤 점을 보면 좋을까 곰곰이 생각해봤는데,
다음과 같은 질문들이 있을 수 있을 것 같아요.

  • 단순히 MNIST를 넘어 빠르게 다양한 예제들을 구성할 수 있는가?
  • 실제로 프로젝트, 논문 등에 구현된 사례들이 많은가?
  • 독자적인 생태계를 구축하는 것이 아니라 flexible한 연동이 가능한가?
  • 추가적인 작업을 하고 싶을 때, 구조적 변경이 자주 일어나는가?

그런 점에서는 만드신 모든 분이 대단하신 것 같아요 :smiley:

Updated Apr 14, 2020 2020-04-14T14:19:20+09:00
This post is licensed under CC BY 4.0