Pytorch - Dataloader collate_fn
Pytorch for AI
Pytorch ๊ฐ์ข๋ฅผ ์๊ฐํ๋ฉด์ ์ดํด ์๋๋ collate_fn
์ ๋ํด ํฌ์คํ
ํ๋ ค ํ๋ค.
What is Dataloader?
Pytorch ๊ณต์๋ฌธ์
๋ฅผ ํ์ธํด๋ณด์
Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset. The DataLoader supports both map-style and iterable-style datasets with single- or multi-process loading, customizing loading order and optional automatic batching (collation) and memory pinning.
์ ์ ์๋ฅผ ํด์ํ์๋ฉด ๋ฐ์ดํฐ์
๊ณผ ์ํ๋ฌ๋ฅผ ํฉ์น๊ณ ์ฃผ์ด์ง ๋ฐ์ดํฐ์
์ ๋ํ ๋ฐ๋ณต ๊ฐ๋ฅํ ๊ธฐ๋ฅ์ ์ ๊ณตํ๋ค. ๋ํ collation
๊ธฐ๋ฅ๊ณผ memory pinnin
๊ธฐ๋ฅ์ ์ ๊ณตํ๋ค๊ณ ๋์์๋ค. ๊ทธ๋ผ collation
์ ๋ฌด์์ผ๊น?
What is Collation?
๊ณต์๋ฌธ์์ ์ ์์์๋ ์๋์ ๊ฐ์ด ๋์์๋ค.
Merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
map-style
๋ฐ์ดํฐ์
์์ ์ํ ๋ฆฌ์คํธ๋ฅผ ๋ฏธ๋๋ฐฐ์น ๋จ์๋ก ๋ฐ๊พธ๊ธฐ ์ํด ํ์ํ ๊ธฐ๋ฅ์ด๋ค. ๊ทธ๋ผ ์ด๋จ ๋ ์ด๋ค๋๊ฑด๊ฐ,,,?
๋ฐ์ดํฐ ์ฌ์ด์ฆ๋ฅผ ๋ฏธ๋๋ฐฐ์น ์ฌ์ด์ฆ๋งํผ ๋ง์ถ๊ธฐ ์ํด ์ฌ์ฉ
๊ฐ๋จํ ์์๋ฅผ ํตํด ์ตํ๋ณด์
Example using of collate_fn
์๋์ ๊ฐ์ด ์ฃผ์ด์ง ๋ฐ์ดํฐ์
์ด ์๋ค๊ณ ํ์. ํด๋น ๋ฐ์ดํฐ์
์ ๋ฐฐ์น ์ฌ์ด์ฆ๋ฅผ 1 ๋ก ์ค์ ํด๋ Dataloader
๋ก ๊ฐ์ ธ์ค๋ฉด ์๋์ ๊ฐ์ด ๋ฌธ์ ์์ด ๊ฐ์ ธ์จ๋ค.
tensor([[0.]])
tensor([[1., 1.]])
tensor([[2., 2., 2.]])
tensor([[3., 3., 3., 3.]])
tensor([[4., 4., 4., 4., 4.]])
tensor([[5., 5., 5., 5., 5., 5.]])
tensor([[6., 6., 6., 6., 6., 6., 6.]])
tensor([[7., 7., 7., 7., 7., 7., 7., 7.]])
tensor([[8., 8., 8., 8., 8., 8., 8., 8., 8.]])
tensor([[9., 9., 9., 9., 9., 9., 9., 9., 9., 9.]])
ํ์ง๋ง ๋ฐฐ์น ์ฌ์ด์ฆ๋ฅผ 2 ์ด์์ผ๋ก ํด๋ฒ๋ฆฌ๋ฉด ์๋์ ๊ฐ์ ์ค๋ฅ๊ฐ ๋ฌ๋ค.
RuntimeError: stack expects each tensor to be equal size ~~
๊ฐ๋จํ๊ฒ ๋ณด๋ฉด ๋ฐฐ์น ์ฌ์ด์ฆ๋ฅผ 2 ๋ก ํ์ ๋ [0, ]
๊ณผ [1., 1.]
์ ๊ธธ์ด๊ฐ ๋ค๋ฅด๋ฏ๋ก ๋ฐฐ์น ์ฌ์ด์ฆ๋ก ๋ฌถ์ ์ ์๋ค๋ ์๋ฌ์ด๋ค. ๊ทธ๋ ๋ค๋ฉด ์์์ ์ธ ํด๊ฒฐ๋ฒ์ผ๋ก ๊ธธ์ด์ ๋ง์ถฐ 0 ์ ์ฑ์๋ฃ๋๋ค๊ณ ๊ฐ์ ํ์. ์ด ํด๊ฒฐ๋ฒ์ ์ฌ์ฉํ๊ธฐ ์ํด collate_fn
์ ์ฌ์ฉํ๋ค.
Solution
ํด๋น ๋ฌธ์ ๋ฅผ ํ๊ธฐ ์ํด ํค์๋๊ฐ ์๋ค๋ฉด batch size
์ ๋ง๊ฒ 0์ผ๋ก ์ฑ์๋ฃ๋๋ค
์ด๋ค. ๊ทธ๋์ ์๊ฐํด๋ณด๋ฉด ๋ฐฐ์น ์ฌ์ด์ฆ๋งํผ ๋ฌถ์ธ ๋ฐ์ดํฐ์
๋ค ์ค์์ ๊ฐ์ฅ ๊ธธ์ด๊ฐ ๊ธด ํ
์์ ๊ธธ์ด๋ฅผ ์์๋ธ ๋ค์ max_len - ๋๋จธ์ง ๊ฐ ํ
์๋ค์ ๊ธธ์ด
๋งํผ 0 ์ผ๋ก ์ฑ์์ฃผ๋ฉด ๋ ๊ฒ ๊ฐ๋ค
๊ทธ๋์ 0 ์ผ๋ก ์ด๋ป๊ฒ ์ฑ์์ฃผ๋ฉด ์ข์๊น ์ฐพ์๋ณด๋ค๊ฐ TORCH.NN.FUNCTIONAL
์ pad
ํจ์๋ฅผ ์ด์ฉํ๋ฉด ์ข์ ๊ฑฐ ๊ฐ์๋ค. ๊ณต์๋ฌธ์
๊ณต์๋ฌธ์๋ฅผ ๋ณด๋ฉด ์ด๋ป๊ฒ ์ฐ๋ฉด ๋์ด์๋์ง ์ ๋์์๋ค. ์ฝ์ผ๋ฉด
to pad only the last dimension of the input tensor, then pad has the form (padding_left,padding_right)
์ด๋ผ๋ ๊ฒ์ ์ฝ์ ์ ์๋ค. ์ฐ๋ฆฌ๋ ๋ง์ง๋ง, ์ฆ ์ค๋ฅธ์ชฝ์๋ง 0์ ์ฑ์ฐ๋ฉด ๋๋ค. ์๋ ์ฝ๋๋ฅผ ๋ณด์.
code
def my_collate_fn(samples):
collate_X = []
max_len = max([len(sample['X']) for sample in samples])
# ๋ฐฐ์น ์ฌ์ด์ฆ๋งํผ ๋ฌถ์ธ ๋ฐ์ดํฐ๋ค ์ค ๊ฐ์ฅ ๊ธด ๋ฐ์ดํฐ์
๊ธธ์ด
for _x in [sample['X'] for sample in samples]:
tensor_len = _x.size(dim=0) # ํ
์ ๊ธธ์ด
p2d = (0, max_len - tensor_len)
# right ์๋ง max_len - tensor_len ๋งํผ 0์ผ๋ก ์ฑ์์ค ๊ฒ์ด๋ค.
_x = F.pad(_x, p2d)
collate_X.append(_x)
return {'X': torch.stack(collate_X)}
pad
ํจ์๋ฅผ ์ฌ์ฉํด ๊ฐ๊ฒฐํ๊ฒ ํด๊ฒฐํ ์ ์์๋ค. 0์ผ๋ก ์ฑ์์ผ ๋ ๊ฐ์๋ฅผ max_len - tensor_len
๋ก ์์๋ด ์๋์ผ๋ก ์ฑ์ธ ์ ์์๋ค.
Result
์์์ ์ง my_collate_fn
๋ฅผ ํ๋ผ๋ฏธํฐ๋ก ๋๊ธฐ๊ณ ๋ฐฐ์น ์ฌ์ด์ฆ๋ฅผ 3 ์ผ๋ก ์ค์ ํ์ ๋ ์๋์ ๊ฐ์ ๊ฒฐ๊ณผ๊ฐ์ด ๋์จ๋ค.
tensor([[0., 0., 0.],
[1., 1., 0.],
[2., 2., 2.]])
tensor([[3., 3., 3., 3., 0., 0.],
[4., 4., 4., 4., 4., 0.],
[5., 5., 5., 5., 5., 5.]])
tensor([[6., 6., 6., 6., 6., 6., 6., 0., 0.],
[7., 7., 7., 7., 7., 7., 7., 7., 0.],
[8., 8., 8., 8., 8., 8., 8., 8., 8.]])
tensor([[9., 9., 9., 9., 9., 9., 9., 9., 9., 9.]])
์ํ๋๋๋ก ๋ฐฐ์น ์ฌ์ด์ฆ 3 ๋งํผ ๋ฐ์ดํฐ๋ค์ ๋ฌถ์ ๋ชจ์ต์ ๋ณด์ด๊ณ ๋ง์ง๋ง ํ
์๋ง ๋จ์์๋ ๊ฒฐ๊ณผ๊ฐ์ด ๋ณด์ธ๋ค. ์ด๋ ๊ฒ ๋ง์ง๋ง ๋ฐฐ์น๋ง ๋ฉ๊ทธ๋ฌ๋ ๋จ์ ๊ฒฝ์ฐ๊ฐ ์๋ค. ๋ฐ๋ผ์ ์ด๋ด ๋๋ Dataloader
์ drop_last
์ต์
์ True
๋ก ์ค์ ํด๋๋ฉด ํด๊ฒฐํ ์ ์๋ค.
What I learned?
์ฌ์ค ๊ฐ์๋ฅผ ๋ค์ ๋ ๋ง์ ๋ถ๋ถ๋ค์ ์ดํด๋ฅผ ๋ชปํ๋ค. ์ธ๊ณต์ง๋ฅ ๋ถ์ผ ๊ณต๋ถ๊ฐ ์ฒ์์ด๋ผ ์กฐ๊ธ ๋ฒ
์ฐผ๋๋ฐ ๋ฐ๋ฐ๋ฅ๋ถํฐ ์์ํ๋ ๋ฅ๋ฌ๋
์ฑ
์ ์ฝ์ด๊ฐ๋ฉด์ ์ดํดํ๊ณ ๊ฐ์๋ฅผ ๋ง์ด ๋๋ ค๋ณด๋ ์ดํด๋๋ ๋ถ๋ถ๋ค์ด ๋ง์๋ค. ๊ธฐ๋ณธ๊ธฐ๋ฅผ ์ํํ ํ์ง ๋ง์!