這個項目基于coursera上的ML課程,學過神經網絡之后,就利用octave做了一個神經網絡識別手寫數字的程序。 hidden layer有2層,theta都是給好的,所以完整的代碼就是
function p = predict(Theta1, Theta2, X)
m = size(X, 1);
num_labels = size(Theta2, 1);
p = zeros(size(X, 1), 1);
% Second layer
X = [ones(m, 1) X];
z_two = Theta1 * X';
a_two = sigmoid(z_two);
second_m = size(a_two, 2);
a_two = [ones(1, second_m); a_two];
% Third layer
z_three = a_two' * Theta2';
a_three = sigmoid(z_three);
[max_value max_index] = max(a_three');
p = max_index';
end
準確率可以達到97%,相比羅輯回歸跟高。 下面是識別的效果
1.png
識別手寫1的圖片
5.png
識別手寫5的圖片
6.png
識別手寫6的圖片