[Flux] mnist example with gpu, mini-batch, fix loss NaN

소스 URL : https://github.com/mrchaos/model-zoo/blob/master/vision/mnist/mlp_gpu_minibatch.jl

model zoo의 mnist예제 중 mlp.jl 소스에 몇가지 이슈가 있어 수정하고 미니배치를 적용 하였다.

Julia 1.3.1, Flux 0.10.1 을 사용 하였다.

가장 큰 이슈는 loss function에서 NaN이 발생하여 train이 제대로 되지 않는 문제를 수정 했다.

loss가 아래와 같이 정의 되는 경우 mini-batch를 적용하면 계산 되는 batch 데이터가 적기 때문에 자주 NaN이 발생한다.

loss(x,y) = crossentropy(m(x), y)

예를 들어 아래의 경우 계산 데이터가 적기 때문에 crossentropy값이 NaN이 발생하여 가중치 업데이트 즉 학습이 되지 않는다.

julia> crossentropy([1,1,0],[1,0,0])
NaN

위의 문제를 해결 하기 위해 아주 작은 값은 더 해 주면 문제를 피해 갈 수 있으며 학습이 정상적으로 잘 된다.

julia> const ϵ = 1.0f-10
julia> crossentropy([1,1,0] .+ ϵ,[1,0,0] .+ ϵ)
2.302585f-9

또한 onehotbatch 또는 onecold 가 gpu에서 재대로 동작 하지 않는 문제가 있으며 또한 gpu scalar 연산으로 인해 느려진다는 경고가 계속 발생한다.

gpu scalar 연산속도 저하에 대한 경고를 없애고 onehotbatch 또는 onecold 이 gpu에서 문제를 없애기 위해 아래와 같이 수정 하였다.

CuArrays.allowscalar(false) 로 gpu scalar을 방지함

onehotbatch(labels,0:9) |> gpu —–> float.(onehotbatch(labels,0:9)) |> gpu

accuracy 계산시에도 문제가 발생하여 아래와 같이 수정 하였다. onecode계산시 gpu메모리 데이터를 사용하는 경우 문제가 발생하는 듯함
추후에 개선될 것으로 기대 한다.

gpu메모리를 cpu 모드로 복사 후 onecold를 적용하면 문제 없이 수행된다.

에러가 발생하는 accuracy 버전

accuracy(x,y) = mean(onecold(m(x)) .== onecold(y)) # 에러 발생

수정된 accuracy 버전

accuracy(x,y) = mean(onecold(m(x)|>cpu) .== onecold(y|>cpu)) # 잘 동작함

mini-batch(미니배치) 를 적용하여 GPU의 메모리를 아끼고 학습결과를 향상 시킴

아래는 mini-batch 데이터를 만는 function

function make_minibatch(imgs,labels,batch_size)
  #=
   reshape.(MNIST.images(),:) : [(784,),(784,),...,(784,)]  60,000개의 데이터
   X : (784x60,000)
   Y : (10x60,000)
  =#
  X = hcat(float.(reshape.(imgs,:))...) |> gpu
  Y = float.(onehotbatch(labels,0:9)) |> gpu
  # Y = Float32.(onehotbatch(labels,0:9))
  
  data_set = [(X[:,i],Y[:,i]) for i in partition(1:length(labels),batch_size)]
  return data_set
end

참고로 여러개의 GPU가 있는 경우 특정 GPU를 사용 할 수 있다. GPU 사용전에 먼저 아래와 같이 설정 하면 지정된 GPU에서 연산이 발생된다.

# use 1nd GPU : default
CUDAnative.device!(0)
# use 2nd GPU
#CUDAnative.device!(1)

전체 소스 코드

#=
Julia version: 1.3.1
Flux version : 0.10.1
=#
__precompile__()
module MNIST_BATCH
using Flux
using Flux.Data.MNIST, Statistics
using Flux: onehotbatch, onecold, crossentropy,throttle
using Base.Iterators: repeated,partition

using CUDAnative
using CuArrays
CuArrays.allowscalar(false)

#= 
Very important !!
ϵ is used to prevent loss NaN
=#
const ϵ = 1.0f-10

# Load training labels and images from Flux.Data.MNIST
@info("Loading data...")
#=
MNIST.images() : [(28x28),...,(28x28)] 60,000x28x28 training images
MNIST.labels() : 0 ~ 9 labels , 60,000x10 training labels
=#
train_imgs = MNIST.images()
train_labels = MNIST.labels()

# use 1nd GPU : default
#CUDAnative.device!(0)
# use 2nd GPU
#CUDAnative.device!(1)

# Bundle images together with labels and group into minibatch
function make_minibatch(imgs,labels,batch_size)
  #=
   reshape.(MNIST.images(),:) : [(784,),(784,),...,(784,)]  60,000개의 데이터
   X : (784x60,000)
   Y : (10x60,000)
  =#
  X = hcat(float.(reshape.(imgs,:))...) |> gpu
  Y = float.(onehotbatch(labels,0:9)) |> gpu
  # Y = Float32.(onehotbatch(labels,0:9))
  
  data_set = [(X[:,i],Y[:,i]) for i in partition(1:length(labels),batch_size)]
  return data_set
end

@info("Making model...")
# Model
m = Chain(
  Dense(28^2,32,relu), # y1 = relu(W1*x + b1), y1 : (32x?), W1 : (32x784), b1 : (32,)
  Dense(32,10), # y2 = W2*y1 + b2, y2 : (10,?), W2: (10x32), b2:(10,)
  softmax
) |> gpu
loss(x,y) = crossentropy(m(x) .+ ϵ, y .+ ϵ)
accuracy(x,y) = mean(onecold(m(x)|>cpu) .== onecold(y|>cpu))

batch_size = 500
train_dataset = make_minibatch(train_imgs,train_labels,batch_size)

opt = ADAM()


@info("Training model...")

epochs = 200
# used for plots
accs = Array{Float32}(undef,0)

dataset_len = length(train_dataset)
for i in 1:epochs
  for (idx,dataset) in enumerate(train_dataset)
    Flux.train!(loss,params(m),[dataset],opt)
    # Flux.train!(loss,params(m),[dataset],opt,cb = throttle(()->@show(loss(dataset...)),20))
    acc = accuracy(dataset...)
    if idx == dataset_len
      @info("Epoch# $(i)/$(epochs) - loss: $(loss(dataset...)), accuracy: $(acc)")
      push!(accs,acc)
    end
  end
end

# Test Accuracy
tX = hcat(float.(reshape.(MNIST.images(:test),:))...) |> gpu
tY = float.(onehotbatch(MNIST.labels(:test),0:9)) |> gpu

println("Test loss:", loss(tX,tY))
println("Test accuracy:", accuracy(tX,tY))

end

using Plots;gr()
plot(MNIST_BATCH.accs)

“[Flux] mnist example with gpu, mini-batch, fix loss NaN”의 1개의 댓글

  1. 핑백: [Flux] cifar10 example – gpu,minibatch, fix loss NaN – Power UP!

댓글 달기

이메일 주소는 공개되지 않습니다. 필수 필드는 *로 표시됩니다