소스 URL : https://github.com/mrchaos/model-zoo/blob/master/vision/mnist/mlp_gpu_minibatch.jl
model zoo의 mnist예제 중 mlp.jl 소스에 몇가지 이슈가 있어 수정하고 미니배치를 적용 하였다.
Julia 1.3.1, Flux 0.10.1 을 사용 하였다.
가장 큰 이슈는 loss function에서 NaN이 발생하여 train이 제대로 되지 않는 문제를 수정 했다.
loss가 아래와 같이 정의 되는 경우 mini-batch를 적용하면 계산 되는 batch 데이터가 적기 때문에 자주 NaN이 발생한다.
loss(x,y) = crossentropy(m(x), y)
예를 들어 아래의 경우 계산 데이터가 적기 때문에 crossentropy값이 NaN이 발생하여 가중치 업데이트 즉 학습이 되지 않는다.
julia> crossentropy([1,1,0],[1,0,0])
NaN
위의 문제를 해결 하기 위해 아주 작은 값은 더 해 주면 문제를 피해 갈 수 있으며 학습이 정상적으로 잘 된다.
julia> const ϵ = 1.0f-10
julia> crossentropy([1,1,0] .+ ϵ,[1,0,0] .+ ϵ)
2.302585f-9
또한 onehotbatch 또는 onecold 가 gpu에서 재대로 동작 하지 않는 문제가 있으며 또한 gpu scalar 연산으로 인해 느려진다는 경고가 계속 발생한다.
gpu scalar 연산속도 저하에 대한 경고를 없애고 onehotbatch 또는 onecold 이 gpu에서 문제를 없애기 위해 아래와 같이 수정 하였다.
CuArrays.allowscalar(false) 로 gpu scalar을 방지함
onehotbatch(labels,0:9) |> gpu —–> float.(onehotbatch(labels,0:9)) |> gpu
accuracy 계산시에도 문제가 발생하여 아래와 같이 수정 하였다. onecode계산시 gpu메모리 데이터를 사용하는 경우 문제가 발생하는 듯함
추후에 개선될 것으로 기대 한다.
gpu메모리를 cpu 모드로 복사 후 onecold를 적용하면 문제 없이 수행된다.
에러가 발생하는 accuracy 버전
accuracy(x,y) = mean(onecold(m(x)) .== onecold(y)) # 에러 발생
수정된 accuracy 버전
accuracy(x,y) = mean(onecold(m(x)|>cpu) .== onecold(y|>cpu)) # 잘 동작함
mini-batch(미니배치) 를 적용하여 GPU의 메모리를 아끼고 학습결과를 향상 시킴
아래는 mini-batch 데이터를 만는 function
function make_minibatch(imgs,labels,batch_size)
#=
reshape.(MNIST.images(),:) : [(784,),(784,),...,(784,)] 60,000개의 데이터
X : (784x60,000)
Y : (10x60,000)
=#
X = hcat(float.(reshape.(imgs,:))...) |> gpu
Y = float.(onehotbatch(labels,0:9)) |> gpu
# Y = Float32.(onehotbatch(labels,0:9))
data_set = [(X[:,i],Y[:,i]) for i in partition(1:length(labels),batch_size)]
return data_set
end
참고로 여러개의 GPU가 있는 경우 특정 GPU를 사용 할 수 있다. GPU 사용전에 먼저 아래와 같이 설정 하면 지정된 GPU에서 연산이 발생된다.
# use 1nd GPU : default
CUDAnative.device!(0)
# use 2nd GPU
#CUDAnative.device!(1)
전체 소스 코드
#=
Julia version: 1.3.1
Flux version : 0.10.1
=#
__precompile__()
module MNIST_BATCH
using Flux
using Flux.Data.MNIST, Statistics
using Flux: onehotbatch, onecold, crossentropy,throttle
using Base.Iterators: repeated,partition
using CUDAnative
using CuArrays
CuArrays.allowscalar(false)
#=
Very important !!
ϵ is used to prevent loss NaN
=#
const ϵ = 1.0f-10
# Load training labels and images from Flux.Data.MNIST
@info("Loading data...")
#=
MNIST.images() : [(28x28),...,(28x28)] 60,000x28x28 training images
MNIST.labels() : 0 ~ 9 labels , 60,000x10 training labels
=#
train_imgs = MNIST.images()
train_labels = MNIST.labels()
# use 1nd GPU : default
#CUDAnative.device!(0)
# use 2nd GPU
#CUDAnative.device!(1)
# Bundle images together with labels and group into minibatch
function make_minibatch(imgs,labels,batch_size)
#=
reshape.(MNIST.images(),:) : [(784,),(784,),...,(784,)] 60,000개의 데이터
X : (784x60,000)
Y : (10x60,000)
=#
X = hcat(float.(reshape.(imgs,:))...) |> gpu
Y = float.(onehotbatch(labels,0:9)) |> gpu
# Y = Float32.(onehotbatch(labels,0:9))
data_set = [(X[:,i],Y[:,i]) for i in partition(1:length(labels),batch_size)]
return data_set
end
@info("Making model...")
# Model
m = Chain(
Dense(28^2,32,relu), # y1 = relu(W1*x + b1), y1 : (32x?), W1 : (32x784), b1 : (32,)
Dense(32,10), # y2 = W2*y1 + b2, y2 : (10,?), W2: (10x32), b2:(10,)
softmax
) |> gpu
loss(x,y) = crossentropy(m(x) .+ ϵ, y .+ ϵ)
accuracy(x,y) = mean(onecold(m(x)|>cpu) .== onecold(y|>cpu))
batch_size = 500
train_dataset = make_minibatch(train_imgs,train_labels,batch_size)
opt = ADAM()
@info("Training model...")
epochs = 200
# used for plots
accs = Array{Float32}(undef,0)
dataset_len = length(train_dataset)
for i in 1:epochs
for (idx,dataset) in enumerate(train_dataset)
Flux.train!(loss,params(m),[dataset],opt)
# Flux.train!(loss,params(m),[dataset],opt,cb = throttle(()->@show(loss(dataset...)),20))
acc = accuracy(dataset...)
if idx == dataset_len
@info("Epoch# $(i)/$(epochs) - loss: $(loss(dataset...)), accuracy: $(acc)")
push!(accs,acc)
end
end
end
# Test Accuracy
tX = hcat(float.(reshape.(MNIST.images(:test),:))...) |> gpu
tY = float.(onehotbatch(MNIST.labels(:test),0:9)) |> gpu
println("Test loss:", loss(tX,tY))
println("Test accuracy:", accuracy(tX,tY))
end
using Plots;gr()
plot(MNIST_BATCH.accs)
핑백: [Flux] cifar10 example – gpu,minibatch, fix loss NaN – Power UP!