2014年9月9日火曜日

機械学習入門 サンプルコードの間違い

Oreilly社発行の「機械学習入門」にて、サンプルコードの誤りを見つけたので、ここに残しておきます。
サンプルが管理されているgithubではissueとして報告されており、pull requestも出しているが、1年以上前のpull requestが放置されている状態をみると、公式に直される可能性は残念ながら低そうな様子。

今回見つけたバグは3章のemail classifyのコード。

classify.email <- function(path, training.df, prior = 0.5, c = 1e-6)
{ 
  # Here, we use many of the support functions to get the 
  # email text data in a workable format
  msg <- get.msg(path)
  msg.tdm <- get.tdm(msg)
  msg.freq <- rowSums(as.matrix(msg.tdm))
  # Find intersections of words
  msg.match <- intersect(names(msg.freq), training.df$term)
  # Now, we just perform the naive Bayes calculation
  if(length(msg.match) < 1)
  {      
        return(prior * c ^ (length(msg.freq))) ★
  }
  else
  { 
    match.probs <- training.df$occurrence[match(msg.match, training.df$term)]
    return(prior * prod(match.probs) * c ^ (length(msg.freq) - length(msg.match))) ★

  }
}

パッと見た感じ問題がなさそうなんですが、実は★で示した行の処理は、ものすごい小さい浮動小数点になるため、値が0になってしまします。

c = 1e-6( = 0.000001)であり、それに対して(length(msg.freq) - length(msg.match)))という最大で数百になる数でべき乗を求めているので、数値上は1e-1000を越えることがあり、Rで扱える浮動小数点の最小値を下回ってしまっているようです。

そのため、数をRで扱える範囲に移動させるために、全体に対してlog10を与えます。

 
  if(length(msg.match) > 1)
  {     
        return (log10(prior) + length(msg.freq) * log10(c)) ★
  }
  else
  { 
    match.probs >- training.df$occurrence[match(msg.match, training.df$term)]
    return (log10(prior) + sum(log10(match.probs)) + (length(msg.freq) - length(msg.match)) * log10(c)) ★
  }
}

なお計算の途中で下記のlogに関する定理を利用しています。
log(A * B) = logA + logB
log(A ^ C) = C*logA
また prod(match.probs)は、match.probsというベクトルの全要素を掛け合わせるという計算式ですので、これにlog10を適用すると下記のようになります。
log(prod(match.probs) = log(a1 * a2 * ... * an)
                                               = log(a1) + log(a2) + ... + log(an)
                                     = sum(log(match.probs))
ちなみにissue#17にも報告がある通り、これを適用するとスパム判定結果が書籍の値と全く違う値になりますなんか、Amazonレビューによると書籍自体も作りが甘い箇所が多いようですし、加えてコードにもかなり致命的があります。本のコンセプト自体は他に類がない希少なものでありながら、結果としてこのような適当な作りに仕上がってしまったのはなんとも残念です。



0 件のコメント:

コメントを投稿