{"id":2537,"date":"2020-02-18T03:37:42","date_gmt":"2020-02-17T18:37:42","guid":{"rendered":"https:\/\/julialang.kr\/?p=2537"},"modified":"2020-02-20T17:57:08","modified_gmt":"2020-02-20T08:57:08","slug":"flux-rnn-example-char_rnn_gpu_minibatch","status":"publish","type":"post","link":"https:\/\/julialang.kr\/?p=2537","title":{"rendered":"[Flux] RNN example &#8211; char_rnn_gpu_minibatch"},"content":{"rendered":"\n<p><a href=\"https:\/\/github.com\/FluxML\/model-zoo\/blob\/master\/text\/char-rnn\/char-rnn.jl\">https:\/\/github.com\/FluxML\/model-zoo\/blob\/master\/text\/char-rnn\/char-rnn.jl<\/a> \uc758 error fix \ubc0f \uc218\uc815 \ud655\uc7a5\ubcf8<\/p>\n\n\n\n<p>\ud658\uacbd : Julia v1.3.1, Flux v0.10.1<\/p>\n\n\n\n<p>char-rnn\uc758 \uc774\ud574\ub97c \uc704\ud55c Jupyter notebook (html) &#8211;&gt; <a rel=\"noreferrer noopener\" aria-label=\"click here (\uc0c8\ud0ed\uc73c\ub85c \uc5f4\uae30)\" href=\"https:\/\/julialang.kr\/wp-content\/uploads\/2020\/02\/char_rnn_gpu_minibatch_\uc758_\uc774\ud574.html\" target=\"_blank\">click here<\/a><\/p>\n\n\n\n<p>github :  <a href=\"https:\/\/github.com\/mrchaos\/model-zoo\/blob\/master\/text\/char-rnn\/char_rnn_gpu_minibatch.jl\">https:\/\/github.com\/mrchaos\/model-zoo\/blob\/master\/text\/char-rnn\/char_rnn_gpu_minibatch.jl<\/a> <\/p>\n\n\n\n<p><a rel=\"noreferrer noopener\" aria-label=\"char_rnn_gpu_minibatch.jl (\uc0c8\ud0ed\uc73c\ub85c \uc5f4\uae30)\" href=\"https:\/\/julialang.kr\/wp-content\/uploads\/2020\/02\/char_rnn_gpu_minibatch.zip\" target=\"_blank\">char_rnn_gpu_minibatch.jl<\/a><\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>using BSON\nusing BSON: @save,@load\nusing Flux\nusing Flux: onehot, chunk, batchseq, throttle, crossentropy\nusing StatsBase: wsample\nusing Base.Iterators: partition\nusing CuArrays\nusing CUDAnative: device!\nusing Random\nusing Dates\nusing Logging\n\n\u03f5 = 1.0f-32\nworking_path = dirname(@__FILE__)\nfile_path(file_name) = 
joinpath(working_path,file_name)\ninclude(file_path(\"cmd_parser.jl\"))\n\nmodel_file = file_path(\"char_rnn_gpu_minibatch.bson\")\n\n# # Get arguments\nparsed_args = CmdParser.parse_commandline()\nepochs = parsed_args&#91;\"epochs\"]\nbatch_size = parsed_args&#91;\"batch\"]\nuse_saved_model = parsed_args&#91;\"model\"]\ngpu_device = parsed_args&#91;\"gpu\"]\ncreate_log_file = parsed_args&#91;\"log\"]\nsequence = parsed_args&#91;\"seq\"]\n\nif create_log_file\n  log_file =\".\/char_rnn_gpu_minibatch_$(Dates.format(now(),\"yyyymmdd-HHMMSS\")).log\"\n  log = open(log_file, \"w+\")\nelse\n  log = stdout\nend\nglobal_logger(ConsoleLogger(log))\n\nstart_time = now()\n@info \"Start - $(start_time)\";flush(log)\n@info \"=============== Arguments ===============\"\n@info \"epochs=$(epochs)\"\n@info \"batch_size=$(batch_size)\"\n@info \"use_saved_model=$(use_saved_model)\"\n@info \"gpu_device=$(gpu_device)\"\n@info \"sequence=$(sequence)\"\n@info \"log_file=$(create_log_file)\"\n@info \"=========================================\";flush(log)\n\n\ndevice!(gpu_device)\nCuArrays.allowscalar(false)\n\ninput_file = file_path(\"input.txt\")\nisfile(input_file) ||\n    download(\"https:\/\/cs.stanford.edu\/people\/karpathy\/char-rnn\/shakespeare_input.txt\",\n             input_file)\n\n# read(input_file) : \ud30c\uc77c\uc5d0\uc11c \ud14d\uc2a4\ud2b8 \uc77d\uc624\uc634 - \ubc14\uc774\ub108\ub9ac\n# String(read(input_file)) : \ubc14\uc774\ub108\ub9ac\ub97c \uc2a4\ud2b8\ub9c1\uc73c\ub85c \ubcc0\ud658\n# collect(String(read(input_file)) : \uc2a4\ud2b8\ub9c1\uc744 \uac1c\ubcc4 char array\ub85c \ubcc0\ud658 - Array{Char,1}\ntext = collect(String(read(input_file)))\n\n# unique(text) : text\uc5d0\uc11c unique\ud55c char array\ub97c \ub9cc\ub4e0\ub2e4 - \uc911\ubcf5\uc81c\uac70 - \ud558\uace0\n# \ub9e8\ub4a4\uc5d0 '_' \ub97c \ucd94\uac00 \ud55c\ub2e4.\n# unique\ud55c char -\uc54c\ud30c\ubcb3 array\ub97c \ub9cc\ub4e0\ub2e4.\nalphabet = &#91;unique(text)...,'_']\n# ch onehot\uc744 
\ub9cc\ub4e0\ub2e4. onhot\uc758 \uae38\uc774\ub294 length(alphabet)\uc774\uace0 onehot\uc5d0\uc11c 1\uc774 \uc788\ub294 \uc704\uce58\ub294\n# alphabet\uc5d0\uc11c ch\uac00 \uc788\ub294 \uc704\uce58\uc640 \ub3d9\uc77c\ntext = map(ch -> Float32.(onehot(ch,alphabet)),text)\nstop = Float32.(onehot('_',alphabet))\n\nN = length(alphabet)\nseqlen = sequence\nnbatch = batch_size\n\nXs = collect(partition(batchseq(chunk(text, nbatch), stop), seqlen))\ntxt = circshift(text,-1)\ntxt&#91;end] = stop\nYs = collect(partition(batchseq(chunk(txt, nbatch), stop), seqlen))\n\nvloss=Inf; epoch = 0; t_sec = Second(0);\nif use_saved_model &amp;&amp; isfile(model_file) &amp;&amp; filesize(model_file) > 0\n  # flush : \ubc84\ud37c\ub9c1 \uc5c6\uc774 \uc989\uac01 log\ub97c \ud30c\uc77c \ub610\ub294 console\uc5d0 write\ud558\ub3c4\ub85d \ud568\n  @info \"Load saved model $(model_file) ...\";flush(log)\n  # model : @save\uc2dc \uc0ac\uc6a9\ud55c object\uba85\n  @load model_file model vloss epoch sec\n  t_sec = sec \n  m = model |> gpu\n  run_min = round(Second(t_sec), Minute)\n  @info \" -> loss : $(vloss), epochs : $(epoch), run time : $(run_min)\";flush(log)\nelse\n  @info \"Create new model ...\";flush(log)  \n  model = Chain(\n    LSTM(N, 128),\n    LSTM(128, 256),\n    LSTM(256, 128),\n    Dense(128, N),\n    softmax)\n    m = model |>gpu\nend\n\nopt = ADAM(0.01)\ntx, ty = (Xs&#91;1]|>gpu, Ys&#91;1]|>gpu)\n\nfunction loss(xx, yy)\n  out = 0.0f0\n  for (idx, x) in enumerate(xx)\n    out += crossentropy(m(x) .+ \u03f5, yy&#91;idx])\n  end  \n  Flux.reset!(m)\n  out\nend\n\n@info \"Training model...\";flush(log)\n\nidxs = length(Xs)\nbest_loss = vloss\nlast_improvement = epoch\nepoch_start_time = now() \nepochs += epoch\nepoch += 1\nfor epoch_idx in epoch:epochs\n  global best_loss,last_improvement,t_sec,epoch_start_time\n  mean_loss = 0.0f0\n  for (idx,(xs,ys)) in enumerate(zip(Xs, Ys))\n    Flux.train!(loss, params(m), &#91;(xs|>gpu,ys|>gpu)], opt)\n    lss = loss(tx,ty)\n    mean_loss 
+= lss\n    if idx % 10 == 0\n      @info \"epoch# $(epoch_idx)\/$(epochs)-$(idx)\/$(idxs) loss = $(lss)\";flush(log)\n    end\n  end\n  mean_loss \/= idxs\n\n  run_sec = round(Millisecond(now()-epoch_start_time), Second)\n  run_min = round(Second(run_sec), Minute)\n  t_run_min = round(Second(t_sec+run_sec), Minute)\n  @info \"epoch# $(epoch_idx)\/$(epochs)-> mean loss : $(mean_loss), running time : $(run_min)\/$(t_run_min)\";flush(log)\n  \n  # If this is the best accuracy we've seen so far, save the model out\n  if mean_loss &lt;= best_loss\n    @info \" -> New best loss! saving model out to $(model_file)\"; flush(log)\n    model = m |> cpu\n    vloss = mean_loss;epoch = epoch_idx; sec = t_sec + run_sec\n    # @save,@load \uc2dc \uac19\uc740 \uc774\ub984\uc744 \uc0ac\uc6a9\ud574\uc57c \ud568, \uc5ec\uae30\uc11c\ub294 \"model\"\uc744 \uc0ac\uc6a9\ud568\n    @save model_file model vloss epoch sec\n    best_loss = mean_loss\n    last_improvement = epoch_idx    \n  end\n\n  # If we haven't seen improvement in 5 epochs, drop out learning rate:\n  if epoch_idx - last_improvement >= 5 &amp;&amp; opt.eta > 1e-6\n    opt.eta \/= 10.0\n    @info \" -> Haven't improved in a while, dropping learning rate to $(opt.eta)!\";flush(log)\n    # After dropping learning rate, give it a  few epochs to improve\n    last_improvement = epoch_idx\n  end  \n\n  if epoch_idx - last_improvement >= 10  \n    @info \" -> We're calling this converged.\"; flush(log)\n    break\n  end  \nend\nend_time = now()\n@info \"End - $(end_time)\";flush(log)\nrun_min = round(round(Millisecond(end_time - start_time), Second),Minute)\n@info \"Running time : $(run_min)\";flush(log)\n\n# Sampling\n\nfunction sample(m, alphabet, len)\n  m = cpu(m)\n  Flux.reset!(m)\n  buf = IOBuffer()\n  c = rand(alphabet)\n  for i = 1:len\n    write(buf, c)\n    c = wsample(alphabet, m(onehot(c, alphabet)))\n  end\n  return String(take!(buf))\nend\n\n@info sample(m, alphabet, 1000);flush(log)\n\nif create_log_file\n  
close(log)\nend<\/code><\/pre>\n","protected":false},"excerpt":{"rendered":"<p>https:\/\/github.com\/FluxML\/model-zoo\/blob\/master\/text\/char-rnn\/char-rnn.jl \uc758 error fix \ubc0f \uc218\uc815 \ud655\uc7a5\ubcf8 \ud658\uacbd : Julia v1.3.1, Flux v0.10.1 char-rnn\uc758 \uc774\ud574\ub97c \uc704\ud55c Jupyter notebook (html) &#8211;&gt; click here github : https:\/\/github.com\/mrchaos\/model-zoo\/blob\/master\/text\/char-rnn\/char_rnn_gpu_minibatch.jl char_rnn_gpu_minibatch.jl<\/p>\n","protected":false},"author":1,"featured_media":0,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"site-sidebar-layout":"default","site-content-layout":"default","ast-site-content-layout":"","site-content-style":"default","site-sidebar-style":"default","ast-global-header-display":"","ast-banner-title-visibility":"","ast-main-header-display":"","ast-hfb-above-header-display":"","ast-hfb-below-header-display":"","ast-hfb-mobile-header-display":"","site-post-title":"","ast-breadcrumbs-content":"","ast-featured-img":"","footer-sml-layout":"","theme-transparent-header-meta":"default","adv-header-id-meta":"","stick-header-meta":"","header-above-stick-meta":"","header-main-stick-meta":"","header-below-stick-meta":"","astra-migrate-meta-layouts":"default","ast-page-background-enabled":"default","ast-page-background-meta":{"desktop":{"background-color":"","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""},"tablet":{"background-color":"","background-image":"","background-repeat":"repeat","background-position":"center 
center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""},"mobile":{"background-color":"","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""}},"ast-content-background-meta":{"desktop":{"background-color":"var(--ast-global-color-5)","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""},"tablet":{"background-color":"var(--ast-global-color-5)","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""},"mobile":{"background-color":"var(--ast-global-color-5)","background-image":"","background-repeat":"repeat","background-position":"center 
center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""}},"footnotes":""},"categories":[18,21],"tags":[],"_links":{"self":[{"href":"https:\/\/julialang.kr\/index.php?rest_route=\/wp\/v2\/posts\/2537"}],"collection":[{"href":"https:\/\/julialang.kr\/index.php?rest_route=\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/julialang.kr\/index.php?rest_route=\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/julialang.kr\/index.php?rest_route=\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/julialang.kr\/index.php?rest_route=%2Fwp%2Fv2%2Fcomments&post=2537"}],"version-history":[{"count":3,"href":"https:\/\/julialang.kr\/index.php?rest_route=\/wp\/v2\/posts\/2537\/revisions"}],"predecessor-version":[{"id":2544,"href":"https:\/\/julialang.kr\/index.php?rest_route=\/wp\/v2\/posts\/2537\/revisions\/2544"}],"wp:attachment":[{"href":"https:\/\/julialang.kr\/index.php?rest_route=%2Fwp%2Fv2%2Fmedia&parent=2537"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/julialang.kr\/index.php?rest_route=%2Fwp%2Fv2%2Fcategories&post=2537"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/julialang.kr\/index.php?rest_route=%2Fwp%2Fv2%2Ftags&post=2537"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}