[Flux] MNIST example update

수정부문:

기존
accuracy(x,y) = mean(onecold(m(x) .== onecold(y))
loss(x,y) = crossentropy(m(x),y)

변경
# ϵ : loss함수가 NaN이 되는것을 방지
ϵ = 1.0f-32
loss(x,y) = crossentropy(m(x) .+ ϵ,y)

# onecold가 GPU에서는 에러 발생 하기 때무에 아래로 대체
참조 : https://github.com/FluxML/Flux.jl/issues/556

compare(y::OneHotMatrix, y′) = maximum(y′, dims = 1) .== maximum(y .* y′, dims = 1)
accuracy(x, y::OneHotMatrix) = mean(compare(y, m(x)))

mnist.jl

using Flux
using Flux.Data.MNIST, Statistics
using Flux: onehotbatch, onecold, crossentropy,OneHotMatrix,@epochs
using Base.Iterators: repeated
using CuArrays
CuArrays.allowscalar(false)

# loss함수의 NaN 방지
ϵ = 1.0f-32

# Classify MNIST digits with a simple multi-layer-perceptron
# imgs : [(28x28),(28x28),...,(28x28))]  60,000개의 데이터
# MNIST.images()는 MNIST.images(:train)과 동일하며 60,000개의 학습데이터를 가져온다.
imgs = MNIST.images()
# reshape.(imgs,:) : [(784,),(784,),...,(784,)]  60,000개의 데이터
# X : (784x60,000)
X = hcat(float.(reshape.(imgs,:))...)
# labels : (60,000,)
labels = MNIST.labels()
# label : 0 ~ 9
# Y : (10x60,000)
Y = onehotbatch(labels,0:9)
# Model
m = Chain(
  Dense(28^2,32,relu), # y1 = relu(W1*x + b1), y1 : (32x?), W1 : (32x784), b1 : (32,)
  Dense(32,10), # y2 = W2*y1 + b2, y2 : (10,?), W2: (10x32), b2:(10,)
  softmax
)

# GPU로 보내는 경우 ;를 붙여 에러가 나는 것을 방지
X = X |> gpu;

# ;가 없으면 아래 문장에서 에러 발생
# 내부적으로 getindex에서 에러 발생
Y = Y |> gpu;

m = m |> gpu;
loss(x,y) = crossentropy(m(x) .+ ϵ,y)

# 참조 : https://github.com/FluxML/Flux.jl/issues/556
compare(y::OneHotMatrix, y′) = maximum(y′, dims = 1) .== maximum(y .* y′, dims = 1)
accuracy(x, y::OneHotMatrix) = mean(compare(y, m(x)))

# 아래 accuracy실행시 에러 발생 : ERROR: scalar setindex! is disallowed

# onecold가 GPU기반에서 제대로 동작 하지 않음
# accuracy(x,y) = mean(onecold(m(x)) .== onecold(y))

dataset = repeated((X,Y),200);

opt = ADAM()

@epochs 50 begin
  Flux.train!(loss,params(m),dataset,opt)
  println("Train loss:",loss(X,Y))
  println("Train accuracy:", accuracy(X,Y))
end

# Test Accuracy
tX = hcat(float.(reshape.(MNIST.images(:test),:))...)
tY = onehotbatch(MNIST.labels(:test),0:9)

tX = tX |> gpu;
tY = tY |> gpu;

println("Test loss:",loss(tX,tY))
println("Test accuracy:", accuracy(tX,tY))

댓글 달기

이메일 주소는 공개되지 않습니다. 필수 필드는 *로 표시됩니다