Trident: 딥러닝 모델 개발을 위한 도구

안녕하세요, 카카오브레인 ML Optimization팀의 리더 danny.jang(장대명)입니다. 카카오 기술블로그를 통해 ML Optimization팀이 개발하고 있는 Trident에 대해서 소개와 설명을 드리고자 합니다.

Trident란?

Trident는 딥러닝 모델의 훈련과 추론 속도를 향상할 수 있는 성능 라이브러리입니다. 기본적으로 Trident는 고도화된 커널과 다양한 함수 그리고 모듈들을 제공합니다. 또한, OpenAI Triton을 기반으로 작성되었기 때문에, PyTorch와 동일한 추상화 계층을 제공합니다.

이제 본격적으로 Trident의 사용법을 살펴보면서, 위에서 언급한 ‘추상화 계층이 동일’하다는 의미를 자세히 알아보겠습니다. PyTorch로 예시 모델을 작성하면 아래와 같이 구현할 수 있습니다.

				
					class Net(nn.Module):
   def __init__(self):
       super().__init__()
       self.rnn = nn.LSTM(input_size=28, hidden_size=64, batch_first=True)
       self.norm = nn.InstanceNorm1d(64)
       self.dropout1 = nn.Dropout(0.25)
       self.dropout2 = nn.Dropout(0.5)
       self.fc1 = nn.Linear(64, 32)
       self.fc2 = nn.Linear(32, 10)


				
			

이때, Trident를 사용하기 위해서는, 아래와 같이 nn 객체를 trident로 변경하기만 하면 됩니다.

				
					import trident

class Net(nn.Module):
   def __init__(self):
       super().__init__()
       self.rnn = nn.LSTM(input_size=28, hidden_size=64, batch_first=True)
       self.norm = trident.InstanceNorm1d(64)
       self.dropout1 = trident.Dropout(0.25)
       self.dropout2 = nn.Dropout(0.5)
       self.fc1 = trident.Linear(64, 32)
       self.fc2 = nn.Linear(32, 10)
				
			

이와 같이, Trident는 PyTorch와 동일한 추상화 계층을 제공하기 때문에, 모델에 적용하고자 한다면 관련 코드를 손쉽게 교체할 수 있습니다. 또한, 기존의 PyTorch와 혼합해서 사용할 수 있는 부분도 큰 장점입니다.

이제 Trident의 속도를 살펴보겠습니다. 아래의 그래프들에서 파란색 선은 PyTorch, 주황색 선은 Trident를 의미합니다. 그리고 X축은 텐서의 크기, Y축은 연산에 소요된 총시간을 의미합니다. 따라서, 그래프의 특정 지점에서 Y 값이 낮을수록, 속도가 더 좋다는 것을 의미합니다.

Linear 속도
RMSNorm 속도
ShiftGELU 성능

위 그래프들에서 확인할 수 있듯, PyTorch보다 속도가 빠른 Trident를 모델에 적용하면 동일한 하드웨어 기준 약 15% 수준으로 훈련 및 추론 속도가 향상될 수 있습니다.

마치며

Trident는 카카오브레인의 Foundation 모델을 효율적이고 빠르게 개발하기 위해서 시작되었습니다. Trident를 오픈소스로 공개한 이유는 카카오브레인이 추구하는 핵심 가치인 공유와 협력 때문입니다. Trident는 기술의 장벽을 무너뜨리고 AI의 미래를 더욱 밝게 만들기 위한 카카오브레인의 노력의 일환입니다. 모든 사람들이 AI의 무한한 가능성을 빠르게 탐색할 수 있도록, 자주 사용되는 연산들의 조합을 하나의 커널로 구현하여 Trident를 지속해서 발전시킬 예정입니다.

카카오톡 공유 보내기 버튼

Latest Posts