Pytorch-TPU-WandB

Pytorch의 TPU 학습 방법, WandB의 로깅 방법

2022-06-04 03:43

3,819 views

안녕하세요

이번에 소개드릴 내용은 Pytorch의 TPU 학습을 하는 방법과 함께, 실험 관리 및 아티펙트 버전관리 기능을 제공하는 Weights and Biases (a.k.a wandb)의 로깅 방법에 대해 소개드리려 합니다.

Index

  1. TPU
  2. Weights and Biases
  3. Pytorch + XLA with Google Colaboratory

TPU

Google-TPU-v4-OK-datacenter-May-2021_534x

TPU는 구글에서 개발된 “Tensor Processing Units”을 의미합니다.

딥러닝 모델을 학습시키는 방법으로는, CPU를 사용한 방법과, 하드웨어 가속을 사용하는 GPU연산을 보편적으로 많이 사용합니다. 또한 거대한 모델을 학습시킬때에는 다수의 GPU를 사용하여 학습을 진행합니다. 그러나 프로토타입, POC, 또는 Kaggle과 같은 경진대회에서는 1장의 GPU를 가지고 있는 인스턴스의 제공이 일반적입니다. 반면에 TPU 인스턴스를 할당 받았을때에는 8-Cores의 TPU를 할당 받을수 있습니다.

tpu_vs_gpu


TPU를 사용하는 방법은 클라우드 TPU를 사용하는게 일반적이지만, 이번 블로그는 Google Colab의 TPU 인스턴스를 기준으로 소개드리려 합니다.

Google Colab

Weights and Biases

wandb

wandb

wandb는 머신러닝 실험관리, 데이터 및 아티펙트 버전관리, 그리고 리포팅을 도와주는 플랫폼입니다.

아래 예시는 저희 연구팀에서의 실험관리 대시보드입니다.

pz-nlp-wandb

다양한 실험관리 플랫폼이 있지만, 저희 연구팀에서 wandb를 사용하는 이유는 저장된 모델의 버전관리와 데이터 버전관리가 용이하여 사용하고 있습니다

Pytorch + XLA with Google Colaboratory

딥러닝 모델의 프로토타입을 개발 할 때, 고비용의 Cloud TPU는 비용적인 측면에서 부담이 될 수 있습니다. 가설 검증을 위한 POC단계에서 빠르게 프로토타이핑을 하고 가능 여부를 확인하는데에는 딥러닝 모델의 학습 시간 또한 중요합니다. 저희 연구팀은 간단한 프로토타이핑은 GCS와 Google Colaboratory (a.k.a colab)을 사용하여 딥러닝 모델의 학습을 진행합니다. 물론, GPU 인스턴스도 함께 사용합니다.

이번에 소개드릴 내용은 “ Google colab의 TPU인스턴스에서 모든 TPU코어를 사용하여 Pytorch 모델의 학습을 진행하면서 wandb에 로깅 ” 을 하는 방법을 공유드리려 합니다.

Colab의 런타임을 TPU로 설정해주셔야 학습까지 진행이 가능합니다.

colab-tpu

Tensorflow 프레임워크를 사용할 때에는 별도의 추가적인 작업 없이 TPU학습이 가능합니다. 그러나 Pytorch를 사용하실 때에는 추가적인 사전 작업이 필요합니다.

바로 Pytorch/xla 설치가 필요합니다.

torch-xla

설치가 완료되면 소스코드를 작성합니다.

중요한점은 8-Core를 모두 사용하는 학습은 분산학습 방법중, DDP Strategy를 사용함으로 로직이 실행되는 함수안에 데이터 로더의 생성과, 모델 생성하는 로직이 함께 포함 되어야 합니다.

import torch
import torchvision
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pl


def _run(index, args):
   …

_run 이라는 이름의 함수를 정의하고, 학습 관련 코드를 추가하도록 하겠습니다.

함수의 아규먼트로 index와 args가 있는데 여기서 index는 TPU코어의 인덱스를 의미합니다. args는 추후 함수를 실행 할 시 주어질 파라미터입니다.

예시를 위해 데이터셋은 Torchvision에서 제공하는 데이터셋을 사용해보겠습니다.

train_dataset = torchvision.datasets.FashionMNIST( … )
test_dataset = torchvision.datasets.FashionMNIST( … )

데이터 다운로드가 완료되면 _run 함수에 pytorch학습을 위한 device설정과 함께 데이터 로더를 정의합니다.

중요한점은 데이터로더를 생성할때, Sampler를 정의하여 주어야 합니다.

이 Sampler는 TPU코어에 데이터를 샘플링하는 역활을 합니다.

def _run(index, args):

	device = xm.xla_device()

	train_sampler = torch.utils.data.distributed.DistributedSampler(
					train_dataset,
					num_replicas = xm.xrt_world_size(),
					rank = xm.get_ordinal(),
					shuffle = True
                                        )
	
	train_dataloader = torch.utils.data.DataLoader(
					dataset = train_dataset,
					sampler = train_sampler,
					batch_size = 128,
					drop_last = True
					)

	…

동일한 방법으로 validation데이터 로더도 함께 만들어 주시면 됩니다.

학습을 위한 모델과 loss object, optimizer를 정의합니다.

mx = torchvision.models.alexnet(num_classes=10)

def _run(index, args):

	device = xm.xla_device()

	train_sampler = torch.utils.data.distributed.DistributedSampler(
					train_dataset,
					num_replicas = xm.xrt_world_size(),
					rank = xm.get_ordinal(),
					shuffle = True
                                        )
	
	train_dataloader = torch.utils.data.DataLoader(
					dataset = train_dataset,
					sampler = train_sampler,
					batch_size = 128,
					drop_last = True
					)

	model = mx.to(device)
	loss_fn = torch.nn.CrossEntropyLoss()
	optimizer = torch.optim.Adam(model.parameters())

	…

모델 학습을 위한 Train 로직을 정의합니다.

데이터로더를 직접 주는것이 아닌 xla의 데이터 로더를 같이 정의합니다.

Pytorch로 학습할때와 다르게 optimizer의 스텝은 xla를 통하여 진행합니다.

mx = torchvision.models.alexnet(num_classes=10)

def _run(index, args):

	device = xm.xla_device()

	train_sampler = torch.utils.data.distributed.DistributedSampler(
					train_dataset,
					num_replicas = xm.xrt_world_size(),
					rank = xm.get_ordinal(),
					shuffle = True
                                        )
	
	train_dataloader = torch.utils.data.DataLoader(
					dataset = train_dataset,
					sampler = train_sampler,
					batch_size = 128,
					drop_last = True
					)

	model = mx.to(device)
	loss_fn = torch.nn.CrossEntropyLoss()
	optimizer = torch.optim.Adam(model.parameters())
	
	for epoch in range(epochs):
		para_train_loader = pl.ParallelLoader(
                               train_loader, [device]
                                ).per_device_loader(device)
		
		
		for batch in para_train_loader:
			model.train()
			data, label = batch
			data = data.to(device)
			label = label.to(device)
			
			optimizer.zero_grad()

			pred = model(data)
			loss = loss_fn(pred, label)
			
			loss.backward()

xm.optimizer_step(optimizer)	
							
…

비슷하게 Validation 스텝을 추가해 줍니다.

mx= torchvision.models.alexnet(num_classes=10)

def _run(index, args):

	device = xm.xla_device()

	train_sampler = torch.utils.data.distributed.DistributedSampler(
					train_dataset,
					num_replicas = xm.xrt_world_size(),
					rank = xm.get_ordinal(),
					shuffle = True
                                        )
	
	train_dataloader = torch.utils.data.DataLoader(
					dataset = train_dataset,
					sampler = train_sampler,
					batch_size = args[“batch_size”],
					drop_last = True
					)

	model =mx.to(device)
	loss_fn = torch.nn.CrossEntropyLoss()
	optimizer = torch.optim.Adam(model.parameters())
	
	for epoch in range(epochs):
		para_train_loader = pl.ParallelLoader(
                                train_loader, [device]
                                ).per_device_loader(device)
		
		para_valid_loader =  …

		for batch in para_train_loader:
			model.train()
			data, label = batch
			data = data.to(device)
			label = label.to(device)
			
			optimizer.zero_grad()

			pred = model(data)
			loss = loss_fn(pred, label)
			
			loss.backward()

xm.optimizer_step(optimizer)			
			
		
for val_batch in para_valid_loader:
	model.eval()
        with torch.no_grad():
	    val_pred = …
	    val_loss = …
        …						
…

이렇게 학습을 위한 _run 함수를 정의가 되었습니다.

pytorch/xla를 통해 학습을 하기 위해서는 spawn을 통해 학습을 진행함으로 코드를 py 파일화와 함께 main구문을 작성해 줍니다.

colab에서는 cell의 내용을 py파일로 작성하기위해 magic command를 사용해서 쉽게 작성할수 있습니다

%%writefile pytorch_tpu_train.py

import …

def _run(index, args):
	…



if __name __ == “__main__”:
	FLAGS = {“batch_size” : 128}
	xmp.spawn(_run, args =(FLAGS, ), nprocs=8, start_method = ‘fork’)

이렇게 main 구문으로 학습을 실행하는 코드를 정의합니다.

pytorch/xla를 사용하여 TPU의 8 core를 사용하는 로직이 완성되었습니다.

이제 wandb를 통한 실험관리를 위해 정의된 _run함수에 wandb와 관련된 로깅 코드를 추가해 보도록 하겠습니다.

Before Wandb Logging Code Snippet

pytorch_tpu_train.py

import torch
import torchvision
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pl


train_transform = torchvision.transforms.Compose(
				[ … ]
				)

valid_transform = torchvision.transforms.Compose(
				[ … ]
				)

train_dataset = torchvision.datasets.FashionMNIST( … )
test_dataset = torchvision.datasets.FashionMNIST( … )

mx = torchvision.models.alexnet(num_classes=10)

def _run(index, args):

	device = xm.xla_device()

	train_sampler = torch.utils.data.distributed.DistributedSampler(
					train_dataset,
					num_replicas = xm.xrt_world_size(),
					rank = xm.get_ordinal(),
					shuffle = True
                                        )
	
	train_dataloader = torch.utils.data.DataLoader(
					dataset = train_dataset,
					sampler = train_sampler,
					batch_size = args[“batch_size”],
					drop_last = True
					)

	model = mx.to(device)
	loss_fn = torch.nn.CrossEntropyLoss()
	optimizer = torch.optim.Adam(model.parameters())
	
	for epoch in range(epochs):
		para_train_loader = pl.ParallelLoader(
                              train_loader, [device]
                              ).per_device_loader(device)
		
		para_valid_loader =  …

		for batch in para_train_loader:
			model.train()
			data, label = batch
			data = data.to(device)
			label = label.to(device)
			
			optimizer.zero_grad()

			pred = model(data)
			loss = loss_fn(pred, label)
			
			loss.backward()

xm.optimizer_step(optimizer)			
			
		
for val_batch in para_valid_loader:
	model.eval()
        with torch.no_grad():
	    val_pred = …
	    val_loss = …
            …	

					
if __name __ == “__main__”:
	FLAGS = {“batch_size” : 128}
	xmp.spawn(_run, args =(FLAGS, ), nprocs=8, start_method = ‘fork’)

Wandb Logging을 추가해 보도록 하겠습니다.

After Wandb Logging Code Snippet

pytorch_tpu_train.py

import torch
import torchvision
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.serialization as xser

import wandb

train_transform = torchvision.transforms.Compose(
				[ … ]
				)

valid_transform = torchvision.transforms.Compose(
				[ … ]
				)

train_dataset = torchvision.datasets.FashionMNIST( … )
test_dataset = torchvision.datasets.FashionMNIST( … )

mx = torchvision.models.alexnet(num_classes=10)

def _run(index, args):
	wandb.login(key=args["wandb_key"])
	run = wandb.init(
		project = "pytorch-xla-fashionmnist",
		group = "xla-colab-tpu")
	device = xm.xla_device()

	train_sampler = torch.utils.data.distributed.DistributedSampler(
					train_dataset,
					num_replicas = xm.xrt_world_size(),
					rank = xm.get_ordinal(),
					shuffle = True
                    )
	
	train_dataloader = torch.utils.data.DataLoader(
					dataset = train_dataset,
					sampler = train_sampler,
					batch_size = args["batch_size"],
					drop_last = True
					)

	model = mx.to(device)
	loss_fn = torch.nn.CrossEntropyLoss()
	optimizer = torch.optim.Adam(model.parameters())
	
	for epoch in range(epochs):
		para_train_loader = pl.ParallelLoader(
            train_loader, [device]
            ).per_device_loader(device)
		
		para_valid_loader =  …
		
		train_cum_loss = 0
		valid_cum_los s= 0
		for n, batch in enumerate(para_train_loader):
			model.train()
			data, label = batch
			data = data.to(device)
			label = label.to(device)
			
			optimizer.zero_grad()

			pred = model(data)
			loss = loss_fn(pred, label)
			
			loss.backward()

            xm.optimizer_step(optimizer)	

            train_cum_loss += loss.item()
            if n % 100 == 0:
                        run.log(
                            {"epoch": epoch,
                            "train_loss": loss}
                    
        for val_batch in para_valid_loader:
                model.eval()
                with torch.no_grad():
                    val_pred = …
                    val_loss = …
                    valid_cum_loss += val_loss
                    …
		run.log(
        {"epoch": epoch,
        "valid_loss": valid_cum_loss / len(valid_dataloader)
        }
		
		xser.save(model.state_dict(), f"{index}_model.pt")
		
		run.log_artifact(f"{index}_model.pt", 
				name = f"{epoch}_model",
				type="model")

					
if __name __ == "__main__":
	FLAGS = {"batch_size" : 128,
                       "wandb_key": <PERSONAL-KEY>}

	xmp.spawn(_run, args =(FLAGS, ), nprocs=8, start_method = ‘fork’)
	
	wandb.finish()

로깅을 할 수치에 대해서 wandb.log를 이용하여 로깅을 진행합니다.

TPU의 8 core를 사용하기때문에 wandb또한 8개의 run이 실행됩니다. 여러개의 wandb run을 관리하기 위해 동일한 group으로 묶어서 표시하도록 wandb.init에 group을 지정해줍니다

Full Source Code

Pytorch / XLA와 wandb를 설치합니다.

!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
 
!python pytorch-xla-env-setup.py --apt-packages libomp5 libopenblas-dev
 
!pip install -qqq wandb

학습 로직이 정의된 파이썬 스크립트를 생성합니다.

%%writefile pytorch_tpu_train.py

import torch
import torchvision
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pl
import torchmetrics

import wandb

train_transform = torchvision.transforms.Compose([
   torchvision.transforms.Resize((128, 128)),
   torchvision.transforms.Lambda(lambda x: x.convert('RGB')),
   torchvision.transforms.ToTensor(),
   torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                   std=[0.229, 0.224, 0.225]),
   ])
 
valid_transform = torchvision.transforms.Compose([                           
   torchvision.transforms.Resize((128, 128)),
   torchvision.transforms.Lambda(lambda x: x.convert('RGB')),
   torchvision.transforms.ToTensor(),
   torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                       std=[0.229, 0.224, 0.225]),
   ])
 
 
train_dataset = torchvision.datasets.FashionMNIST('.',
   train=True,
   download=True,
   transform=train_transform)
 
valid_dataset = torchvision.datasets.FashionMNIST('.',
    train=False,
    download=True,
    transform=valid_transform)
 
mx = torchvision.models.resnet50(num_classes=10)
 
 
def _run(index, args):
   wandb.login(key=args['wandb_key'])
  
   run = wandb.init(project='pytorch-xla-fashionmnist',
                    group='xla-colab-tpu')
  
   device = xm.xla_device()
 
   train_sampler = torch.utils.data.distributed.DistributedSampler(
                       train_dataset,
                       num_replicas = xm.xrt_world_size(),
                       rank = xm.get_ordinal(),
                       shuffle = True
   )
      
   train_dataloader = torch.utils.data.DataLoader(
                   dataset = train_dataset,
                   sampler = train_sampler,
                   batch_size = 128,
                   drop_last = True
                   )
  
   valid_sampler = torch.utils.data.distributed.DistributedSampler(
                       valid_dataset,
                       num_replicas = xm.xrt_world_size(),
                       rank = xm.get_ordinal(),
                       shuffle = True
   )
      
   valid_dataloader = torch.utils.data.DataLoader(
                   dataset = valid_dataset,
                   sampler = valid_sampler,
                   batch_size = 128,
                   drop_last = True
                   )
  
   model = mx.to(device)
   learning_rate = 0.001 * xm.xrt_world_size()
   loss_fn = torch.nn.CrossEntropyLoss()
   optimizer = torch.optim.Adam(model.parameters(), 
                          lr=learning_rate)
 
 
   for epoch in range(100):
       para_train_loader = pl.ParallelLoader(train_dataloader, [device]).per_device_loader(device)
       para_valid_loader = pl.ParallelLoader(valid_dataloader, [device]).per_device_loader(device)
 
       train_cum_loss = 0
       valid_cum_loss = 0     
       train_cum_acc = 0.
       valid_cum_acc = 0.
 
       for n, batch in enumerate(para_train_loader):
           model.train()
           data, label = batch
           data = data.to(device)
           label = label.to(device)
 
           optimizer.zero_grad()
           pred = model(data)
           loss = loss_fn(pred, label)
           loss.backward()
 
           xm.optimizer_step(optimizer)
 
           train_cum_loss += loss.item()
 
           pred_argmax = torch.argmax(pred, dim=-1)
           batch_acc = torch.sum(pred_argmax == label).float() / float(label.size(0))
           train_cum_acc += batch_acc
 
           if n % 10 == 0:
               print(f"XLA: [{index}]: Train Loss: {loss.item()} Train Accuracy: {batch_acc.cpu()}")
               run.log({'epoch': epoch,
                        'train_loss': loss.item(),
                        'train_accuracy': batch_acc.cpu()})
                      
       xm.master_print(f"{epoch} |Training Done")
 
       for n, val_batch in enumerate(para_valid_loader):
           model.eval()
           val_data, val_label = val_batch
           val_data = val_data.to(device)
           val_label = val_label.to(device)
 
           with torch.no_grad():
               val_pred = model(val_data)
               val_loss = loss_fn(val_pred, val_label)
               valid_cum_loss += val_loss.item()
              
               val_pred_argmax = torch.argmax(val_pred, dim=-1)
               val_batch_acc = torch.sum(val_pred_argmax == val_label).float() / float(val_label.size(0))
               valid_cum_acc += val_batch_acc
 
      
       run.log({"epoch":epoch,
                "valid_loss": valid_cum_loss / len(valid_dataloader),
                'valid_accuracy': valid_cum_acc / n })
    print(f"XLA: [{index}]: Valid Loss: {val_loss.item()} Valid Accuracy: {valid_cum_acc.cpu()}")
 
 
      
xm.master_print(f"{epoch} |Validation Done")
 
 
if __name__ == "__main__":
   FLAGS = {'wandb_key': "<YOUR-WANDB-KEY>"}
   xmp.spawn(_run, args=(FLAGS, ), nprocs=8, start_method='fork')
   wandb.finish()

학습 시작을 위해 파이썬 코드를 실행합니다

!python pytorch_tpu_train.py

그룹으로 묶여있는 wandb런을 보여줍니다.

wandb-fashion

왼쪽의 보라색으로 표시된 group보기를 제거 하면 각 wandb run에 대한 그래프를 확인 할 수 있습니다.

wandb_no_group

학습이 진행됨에 따라 colab의 표시 화면 입니다.

XLA: index는 각 core에 할당된 모델의 학습 과정입니다.

소스코드에서의 xm.master_print를 통해 각 core에서 출력되는것을 방지 할수 있습니다.

colab_output

Conclude

TPU를 통한 학습은 GPU를 사용할때보다 빠른 속도로 학습을 가능하도록 도와줍니다. 실험 관리를 위한 wandb를 사용함으로 학습 과정과 데이터, 모델을 아티펙트화 하고 저장하면서 이전 MLOPs 시리즈에서 소개드렸던 아티팩트 버저닝을 손쉽게 할수 있는 장점이 있습니다. Google Colab 을 사용하면서 GPU와 TPU를 함께 사용하면 다양한 실험을 시도해 볼 수 있습니다. 추가적으로 현재 Colab은 TPU-V2를 제공하고 있으며, Kaggle 노트북은 TPU-V3를 제공합니다.

오늘 소개드린 내용과 관련된 추가자료는 아래 링크에서 확인하실수 있습니다.

TPU

Weights and Biases

Pytorch/XLA - TPU Training

wandb Distributed Training



다음글 - Airflow와 Git, DVC를 활용하여 배치 서비스 운영 및 데이터 버전 컨트롤
이전글 - 플러스제로 MLOPs Series 3/3