HybrIK 코드를 보다 발견한 torch.einsum() ... 익숙하지 않은 표현이다보니 검색을 통해 사용법을 익혀보았다!! 구글 검색을 해보면 여러 블로그 들에 자세히 설명이 되어 있었으며, 공식 문서를 통해서도 어떤 방식으로 사용하는 지 확인할 수 있었다.
[공식문서] (numpy 혹은 tensorflow에도 einsum이 존재한다)
https://pytorch.org/docs/stable/generated/torch.einsum.html
우선 Einsum은 Einstein Summation Convention의 준말로 특정 Index 집합에 대한 합연산을 간결하게 표시하는 방법을 의미한다고 한다.
사실 이 말만으로는 무슨 말인지 이해가 어려웠는데 선형대수의 행렬 곱을 시그마 형태로 표현하는 예시를 보고 더 이해가 쉽게 되었던 것 같다.
그림에서 만약 내가 $y_2$의 값이 알고 싶다고 하자. 그렇다면 $y_2 = a_{21}x^1 + a_{22}x^2 + a_{23}x^3$의 식이 나올 것이다. 지금은 Matrix의 크기가 작기 때문에 이정도 길이의 결과가 나오지만, Matrix가 조금만 더 커진다면 계산해야 하는 수식의 길이는 훨씬 더 길어질 것이다.
따라서 이를 시그마의 형태로 간단하게 정리한 것이 아인슈타인 표기법이다. 그리고! 이를 코드 상에서 간편히 볼 수 있도록 정리한 것이 einsum 함수이다.
사용법은 torch.einsum(equation, *operands)로 여기서 operand는 피연산자를 뜻한다.
위의 그림처럼 equation string에는 operands의 순서에 맞게 영어 소문자를 입력하고 ,(comma)를 통해 다음 tensor와 구분하여 equation을 완성한다. -> 뒤에는 output인 tensor_4의 차원을 영어 소문자 형태로 적어준다. 이 때, 뒤에 값이 생락된다면 한번씩 나온 영어 소문자들을 순서대로 나열한 값으로 처리한다.
# Example from https://stackoverflow.com/questions/55894693/understanding-pytorch-einsum
# 임의의 값들 지정
vec
'''
tensor([0, 1, 2, 3])
'''
aten
'''
tensor([[11, 12, 13, 14],
[21, 22, 23, 24],
[31, 32, 33, 34],
[41, 42, 43, 44]])
'''
bten
'''
tensor([[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3],
[4, 4, 4, 4]])
'''
# 2개의 Matrix 내적
torch.einsum('ij, jk -> ik', aten, bten)
'''
tensor([[130, 130, 130, 130],
[230, 230, 230, 230],
[330, 330, 330, 330],
[430, 430, 430, 430]])
'''
# 대각행렬 표현
torch.einsum('ii -> i', aten)
'''
tensor([11, 22, 33, 44])
'''
# Hadamard 곱(Element-wise 곱 표현)
torch.einsum('ij, ij -> ij', aten, bten)
'''
tensor([[ 11, 12, 13, 14],
[ 42, 44, 46, 48],
[ 93, 96, 99, 102],
[164, 168, 172, 176]])
'''
...
Clean Code를 짜는데 유용한 방법이니 다음에 꼭 시도해 볼 수 있도록 하자!!
https://theaisummer.com/einsum-attention/
댓글