우선, 해당 포스트는 Stanford University School of Engineering의 CS231n 강의자료와 모두를 위한 딥러닝 시즌2의 자료를 기본으로 하여 정리한 내용임을 밝힙니다.
Reshape
numpy에서는 .reshape()함수를 사용하여 array의 shape를 다시 설정해주었지만, (Py)Torch에서는 .view()라는 함수를 사용한다.
import numpy as np
import torch
### view
t = np.array([[[0, 1, 2],
[3, 4, 5]],
[[6, 7, 8],
[9, 10, 11]]])
ft = torch.FloatTensor(t)
print(ft.shape)
view를 살펴보기 위하여 우선 위와 같이 size torch.Size([2, 2, 3]) 의 ft Tensor를 만들어 주었다.
print(ft.view([-1, 3]))
print(ft.view([-1, 3]).shape)
이 때 ft를 ft.view([-1, 3]) 을 통하여 4 by 3의 Tensor로 reshaping을 해줄 수 있다. shape을 다시 정의할 때 -1 을 사용하는 것은 몇으로 정의할지 모르겠다 는 뜻이다. 즉 O by O 형태의 2차원 Tensor를 만들 것인데 O by 3 인 것은 알겠으나, 앞의 숫자는 확실히 모르겠을 때 -1을 사용해주면 된다.
이렇게 간단한 예제에서는 [4, 3]을 명시해주면 되는데 왜 굳이 -1을 쓰는가에 대한 의문이 생길 수 있다. 하지만 computer vision이나, nlp 모델을 다루다 보면 상당히 복잡한 shape의 tensor를 다루게 되는데, 이 때 shape을 계산하는 것이 만만치 않다. 그럴 때 -1을 사용해주면 pytorch가 쉽게 계산을 해주기 때문에 매우 유용하다.
print(ft.view([-1, 1, 3]))
print(ft.view([-1, 1, 3]).shape)
이 경우, torch.Size([4, 1, 3]) 의 Tensor로 ft Tensor를 reshaping해준다.
Squeeze
ft = torch.FloatTensor([[0], [1], [2]])
print(ft)
print(ft.shape)
print(ft.squeeze())
print(ft.squeeze().shape)
print(ft.squeeze(dim=0).shape)
print(ft.squeeze(dim=1).shape)
우선 torch.Size([3, 1]) size의 ft Tensor를 만들었다.
ft.squeeze().shape는 torch.Size([3]) 를 return하는데, ft Tensor의 shape에서 1이 들어 있는 부분을 없애 차원을 축소시키라는 의미가 된다.
ft.squeeze(dim=1)과 같이 축소시킬 dimension을 정해줄 수도 있는데, 위의 경우 1번째 dimension이 1이므로 ft.squeeze(dim=1).shape의 결과는 변한 바가 없이 torch.Size([3]) 를 return한다.
ft.squeeze(dim=0)과 같이 축소시킬 dimension을 정해줄 수도 있는데, 위의 경우 0번째 dimension이 3이므로 ft.squeeze(dim=0).shape의 결과는 변한 바가 없이 torch.Size([3, 1]) 를 return하게 된다.
Unsqueeze
Unsqueeze는 shape에서 1을 없애는 Squeeze과는 반대로, 하나의 dimension을 1로 더 추가해주는 함수이다. 예시를 통해 살펴보는 것이 더 이해가 쉽다.
### Unsqueeze
ft = torch.Tensor([0, 1, 2])
print(ft.shape)
print(ft.unsqueeze(dim=0))
print(ft.unsqueeze(dim=0).shape)
우선, torch.Size([3]) 의 ft Tensor를 만든다. 만약 첫 번째 dimension(dim=0)에 1의 dimension을 끼워넣고 싶다면 ft.unsqueeze(dim=0)를 사용하면 된다.
결과적으로 ft.unsqueeze(dim=0).shape는 torch.Size([1, 3]) 를 return하게 될 것이다.
print(ft.view(1, -1))
print(ft.view(1, -1).shape)
print(ft.unsqueeze(dim=0))
print(ft.unsqueeze(dim=0).shape)
print(ft.unsqueeze(dim=-1))
print(ft.unsqueeze(dim=-1).shape)
또 다른 예시를 살펴보자. 먼저, 앞서 배운 view 함수를 통하여 torch.Size([3]) 의 tensor를 torch.Size([1, 3]) 으로 reshape해 줄 수 있었다.
만약 unsqueeze 함수로 똑같은 작업을 하고 싶다면, unsqueeze(dim=0)을 사용하여 첫 번째 dimension에 1을 추가함으로써 torch.Size([1, 3]) 의 tensor로 reshape해 줄 수 있다.
한편, unsqueeze(dim=-1)을 사용한다면 마지막 dimension(-1번째 dimension)에 1을 추가함으로써 torch.Size([3, 1]) 의 tensor로 reshape하여 줄 수 있다. 즉, 이 경우 unsqueeze(dim=1)과 같은 결과가 나타날 것이다.
Reference
CS231n, Stanford University School of Engineering