PyTorch - DataLoader collate_fn

PyTorch for AI

The thumbnail for this post was made with Thumbnail-Maker, created by Wonkook Lee.

While taking a PyTorch course I could not make sense of collate_fn, so this post is about that.

What is DataLoader?

Let's check the official PyTorch documentation.

Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset. The DataLoader supports both map-style and iterable-style datasets with single- or multi-process loading, customizing loading order and optional automatic batching (collation) and memory pinning.

To paraphrase the definition above: it combines a dataset and a sampler, and provides an iterable over the given dataset. It also says that it offers collation and memory pinning. So what is collation?

What is Collation?

The official documentation defines it as follows.

Merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.

So it is the step that turns a list of samples from a map-style dataset into a mini-batch. But when would you actually need it?

In short: it is used to bring the samples to a common size so that they can be combined into a mini-batch.

Let's work through a simple example.

Example of using collate_fn

Suppose we have a dataset like the one below. If we load it with a DataLoader whose batch size is set to 1, every sample comes out without any problem.

tensor([[0.]])
tensor([[1., 1.]])
tensor([[2., 2., 2.]])
tensor([[3., 3., 3., 3.]])
tensor([[4., 4., 4., 4., 4.]])
tensor([[5., 5., 5., 5., 5., 5.]])
tensor([[6., 6., 6., 6., 6., 6., 6.]])
tensor([[7., 7., 7., 7., 7., 7., 7., 7.]])
tensor([[8., 8., 8., 8., 8., 8., 8., 8., 8.]])
tensor([[9., 9., 9., 9., 9., 9., 9., 9., 9., 9.]])
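
The dataset definition itself is not shown in this post. As a minimal sketch, a dataset that would produce output like the above could look like this. The class name VariableLengthDataset is hypothetical, and the 'X' key is assumed from the my_collate_fn code later in the post.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class VariableLengthDataset(Dataset):
    """Hypothetical dataset: sample i is a 1-D tensor of length i + 1 filled with i."""

    def __init__(self, num_samples=10):
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Each sample is a dict so it matches the sample['X'] access used later
        return {'X': torch.full((idx + 1,), float(idx))}


# With batch_size=1 every batch holds a single sample, so lengths never clash
loader = DataLoader(VariableLengthDataset(), batch_size=1)
batches = [batch['X'] for batch in loader]
for x in batches:
    print(x)
```

Each batch has shape (1, length), which is why the printed tensors carry the extra pair of brackets.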

But as soon as the batch size is set to 2 or more, we get the following error.

RuntimeError: stack expects each tensor to be equal size ~~
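
A small sketch that reproduces this failure, assuming the same kind of variable-length samples. The plain Python list used as a dataset here is a hypothetical stand-in, not the dataset from the course.

```python
import torch
from torch.utils.data import DataLoader

# Sample i is a dict holding a 1-D tensor of length i + 1
samples = [{'X': torch.full((i + 1,), float(i))} for i in range(10)]

# No collate_fn given, so the default collation tries to stack the tensors
loader = DataLoader(samples, batch_size=2)

try:
    next(iter(loader))
    error_message = None
except RuntimeError as err:
    error_message = str(err)

print(error_message)
```

The first batch already fails, because it contains a tensor of length 1 and a tensor of length 2.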

Put simply, with a batch size of 2 the tensors [0.] and [1., 1.] have different lengths, so they cannot be stacked into a single batch. As a stopgap, suppose we pad the shorter tensors with zeros until the lengths match. collate_fn is where we apply that fix.

Solution

The key idea is to pad with zeros so that every tensor in the batch has the same length. Concretely: find the length of the longest tensor among the samples grouped into the batch (max_len), then pad each tensor with max_len minus its own length zeros.

While looking for a way to do the zero padding, I found that the pad function in torch.nn.functional looked like a good fit (see the official documentation).

The official documentation explains how to use it. It says:

to pad only the last dimension of the input tensor, then pad has the form (padding_left,padding_right)

We only need to pad the last dimension, and only at the end, that is, on the right. The my_collate_fn code further down does exactly this.
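
Before the full collate function, here is a quick check of how F.pad treats the (padding_left, padding_right) pair on a 1-D tensor. The input values are made up for illustration.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([1., 1.])

# (0, 2): nothing on the left, two zeros on the right
padded = F.pad(x, (0, 2))
print(padded)  # tensor([1., 1., 0., 0.])

# (2, 0): two zeros on the left, nothing on the right
left_padded = F.pad(x, (2, 0))
print(left_padded)  # tensor([0., 0., 1., 1.])
```

The default mode is constant padding with value 0, so no extra arguments are needed for zero padding.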

code

import torch
import torch.nn.functional as F


def my_collate_fn(samples):
    collate_X = []

    # Length of the longest tensor among the samples grouped into this batch
    max_len = max(len(sample['X']) for sample in samples)

    for sample in samples:
        _x = sample['X']
        tensor_len = _x.size(dim=0)  # length of this tensor

        # Pad only on the right, with max_len - tensor_len zeros
        pad_widths = (0, max_len - tensor_len)
        _x = F.pad(_x, pad_widths)

        collate_X.append(_x)

    # All tensors now share the same length, so they can be stacked
    return {'X': torch.stack(collate_X)}

With the pad function the problem is solved concisely. The number of zeros to add is computed as max_len - tensor_len, so the padding happens automatically for every batch.
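
As an aside, PyTorch also ships torch.nn.utils.rnn.pad_sequence, which pads a list of tensors to the longest one in a single call. This is not what the course or the code above uses; it is only a sketch of an alternative, and the function name pad_sequence_collate_fn is hypothetical.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def pad_sequence_collate_fn(samples):
    xs = [sample['X'] for sample in samples]
    # batch_first=True gives shape (batch, max_len) instead of (max_len, batch)
    return {'X': pad_sequence(xs, batch_first=True, padding_value=0.0)}


samples = [{'X': torch.full((i + 1,), float(i))} for i in range(3)]
batch = pad_sequence_collate_fn(samples)
print(batch['X'])
# tensor([[0., 0., 0.],
#         [1., 1., 0.],
#         [2., 2., 2.]])
```

Writing the padding by hand with F.pad makes it easier to see what is going on, which is probably why it is the better choice while learning.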

Result

Passing the my_collate_fn written above as the collate_fn parameter and setting the batch size to 3 gives the following result.

tensor([[0., 0., 0.],
        [1., 1., 0.],
        [2., 2., 2.]])
tensor([[3., 3., 3., 3., 0., 0.],
        [4., 4., 4., 4., 4., 0.],
        [5., 5., 5., 5., 5., 5.]])
tensor([[6., 6., 6., 6., 6., 6., 6., 0., 0.],
        [7., 7., 7., 7., 7., 7., 7., 7., 0.],
        [8., 8., 8., 8., 8., 8., 8., 8., 8.]])
tensor([[9., 9., 9., 9., 9., 9., 9., 9., 9., 9.]])
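
For reference, the wiring that leads to output like the above can be sketched as follows. The list-based dataset is again a hypothetical stand-in, and my_collate_fn is repeated in compact form so the snippet runs on its own.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader


def my_collate_fn(samples):
    # Pad every tensor on the right up to the longest length in this batch
    max_len = max(len(sample['X']) for sample in samples)
    padded = [F.pad(sample['X'], (0, max_len - len(sample['X']))) for sample in samples]
    return {'X': torch.stack(padded)}


samples = [{'X': torch.full((i + 1,), float(i))} for i in range(10)]

# collate_fn receives the list of samples for each batch
loader = DataLoader(samples, batch_size=3, collate_fn=my_collate_fn)

shapes = [tuple(batch['X'].shape) for batch in loader]
print(shapes)  # [(3, 3), (3, 6), (3, 9), (1, 10)]
```

Note that padding is per batch: each batch is only as wide as its own longest tensor, not as wide as the longest tensor in the whole dataset.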

As intended, the data are grouped in batches of 3, and the last tensor ends up in a batch of its own. A smaller final batch like this is left over whenever the dataset size is not divisible by the batch size. If that is a problem, setting the DataLoader's drop_last option to True discards the incomplete batch.
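
A short sketch of the effect of drop_last, using ten made-up equal-length samples so that no custom collate_fn is needed:

```python
import torch
from torch.utils.data import DataLoader

# Ten samples of equal length, so the default collation works
data = [torch.tensor([float(i)]) for i in range(10)]

loader_keep = DataLoader(data, batch_size=3, drop_last=False)
loader_drop = DataLoader(data, batch_size=3, drop_last=True)

print(len(loader_keep))  # 4
print(len(loader_drop))  # 3

batch_sizes = [batch.shape[0] for batch in loader_drop]
print(batch_sizes)  # [3, 3, 3]
```

With drop_last=True the tenth sample is simply skipped in that pass over the data, so it is worth checking that losing it is acceptable.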

What I learned

To be honest, there was a lot I did not understand while watching the lectures. This is my first time studying AI, so it was a bit overwhelming, but reading the book Deep Learning from Scratch alongside the course and rewatching the lectures many times made a lot of it click. Don't neglect the fundamentals!