본문 바로가기
개발 관련 지식/PyTorch

[PyTorch] torch.einsum() 사용하기

by rahites 2024. 2. 5.

 

HybrIK 코드를 보다 발견한 torch.einsum() ... 익숙하지 않은 표현이다보니 검색을 통해 사용법을 익혀보았다!! 구글 검색을 해보면 여러 블로그 들에 자세히 설명이 되어 있었으며, 공식 문서를 통해서도 어떤 방식으로 사용하는 지 확인할 수 있었다.

 

[공식문서] (numpy 혹은 tensorflow에도 einsum이 존재한다)

https://pytorch.org/docs/stable/generated/torch.einsum.html

 

torch.einsum — PyTorch 2.2 documentation

Shortcuts

pytorch.org


우선 EinsumEinstein Summation Convention의 준말로 특정 Index 집합에 대한 합연산을 간결하게 표시하는 방법을 의미한다고 한다. 

 

사실 이 말만으로는 무슨 말인지 이해가 어려웠는데 선형대수의 행렬 곱을 시그마 형태로 표현하는 예시를 보고 더 이해가 쉽게 되었던 것 같다.

출처 : 나무위키(아인슈타인 표기법)

그림에서 만약 내가 $y_2$의 값이 알고 싶다고 하자. 그렇다면 $y_2 = a_{21}x^1 + a_{22}x^2 + a_{23}x^3$의 식이 나올 것이다. 지금은 Matrix의 크기가 작기 때문에 이정도 길이의 결과가 나오지만, Matrix가 조금만 더 커진다면 계산해야 하는 수식의 길이는 훨씬 더 길어질 것이다.

 

따라서 이를 시그마의 형태로 간단하게 정리한 것이 아인슈타인 표기법이다. 그리고! 이를 코드 상에서 간편히 볼 수 있도록 정리한 것이 einsum 함수이다.

 

사용법은 torch.einsum(equation, *operands)로 여기서 operand는 피연산자를 뜻한다.

위의 그림처럼 equation string에는 operands의 순서에 맞게 영어 소문자를 입력하고 ,(comma)를 통해 다음 tensor와 구분하여 equation을 완성한다. -> 뒤에는 output인 tensor_4의 차원을 영어 소문자 형태로 적어준다. 이 때, 뒤에 값이 생락된다면 한번씩 나온 영어 소문자들을 순서대로 나열한 값으로 처리한다.

 

# Example from https://stackoverflow.com/questions/55894693/understanding-pytorch-einsum

# 임의의 값들 지정
vec
'''
tensor([0, 1, 2, 3])
'''
aten
'''
tensor([[11, 12, 13, 14],
        [21, 22, 23, 24],
        [31, 32, 33, 34],
        [41, 42, 43, 44]])
'''
bten
'''
tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3],
        [4, 4, 4, 4]])
'''
# 2개의 Matrix 내적
torch.einsum('ij, jk -> ik', aten, bten)
'''
tensor([[130, 130, 130, 130],
        [230, 230, 230, 230],
        [330, 330, 330, 330],
        [430, 430, 430, 430]])
'''
# 대각행렬 표현
torch.einsum('ii -> i', aten)
'''
tensor([11, 22, 33, 44])
'''
# Hadamard 곱(Element-wise 곱 표현)
torch.einsum('ij, ij -> ij', aten, bten)
'''
tensor([[ 11,  12,  13,  14],
        [ 42,  44,  46,  48],
        [ 93,  96,  99, 102],
        [164, 168, 172, 176]])
'''

...

 

 

Clean Code를 짜는데 유용한 방법이니 다음에 꼭 시도해 볼 수 있도록 하자!!

 

https://theaisummer.com/einsum-attention/

 

Understanding einsum for Deep learning: implement a transformer with multi-head self-attention from scratch | AI Summer

Learn about the einsum notation and einops by coding a custom multi-head self-attention unit and a transformer block

theaisummer.com

 

댓글