神經網絡實現手寫識別

這個項目基于coursera上的ML課程,學過神經網絡之后,就利用octave做了一個神經網絡識別手寫數字的程序。 hidden layer有2層,theta都是給好的,所以完整的代碼就是

function p = predict(Theta1, Theta2, X)
m = size(X, 1);
num_labels = size(Theta2, 1);
p = zeros(size(X, 1), 1);


% Second layer
X = [ones(m, 1) X];
z_two = Theta1 * X';
a_two = sigmoid(z_two);
second_m = size(a_two, 2);
a_two = [ones(1, second_m); a_two];

% Third layer
z_three = a_two' * Theta2';
a_three = sigmoid(z_three);

[max_value max_index] = max(a_three');
p = max_index';
end

準確率可以達到97%,相比羅輯回歸跟高。 下面是識別的效果

1.png

識別手寫1的圖片

5.png

識別手寫5的圖片

6.png

識別手寫6的圖片

?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。

推薦閱讀更多精彩內容