from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from numpy import *
import numpy as np
import time
import struct
#讀取圖片
def read_image(file_name):
#先用二進制方式把文件都讀進來
file_handle=open(file_name,"rb") #以二進制打開文檔
file_content=file_handle.read() #讀取到緩沖區中
offset=0
head = struct.unpack_from('>IIII', file_content, offset) # 取前4個整數,返回一個元組
offset += struct.calcsize('>IIII')
imgNum = head[1] #圖片數
rows = head[2] #寬度
cols = head[3] #高度
images=np.empty((imgNum , 784))#empty,是它所常見的數組內的所有元素均為空,沒有實際意義,它是創建數組最快的方法
image_size=rows*cols#單個圖片的大小
fmt='>' + str(image_size) + 'B'#單個圖片的format
for i in range(imgNum):
images[i] = np.array(struct.unpack_from(fmt, file_content, offset))
# images[i] = np.array(struct.unpack_from(fmt, file_content, offset)).reshape((rows, cols))
offset += struct.calcsize(fmt)
return images
#讀取標簽
def read_label(file_name):
file_handle = open(file_name, "rb") # 以二進制打開文檔
file_content = file_handle.read() # 讀取到緩沖區中
head = struct.unpack_from('>II', file_content, 0) # 取前2個整數,返回一個元組
offset = struct.calcsize('>II')
labelNum = head[1] # label數
# print(labelNum)
bitsString = '>' + str(labelNum) + 'B' # fmt格式:'>47040000B'
label = struct.unpack_from(bitsString, file_content, offset) # 取data數據,返回一個元組
return np.array(label)
def loadDataSet():
train_x_filename="/Users/didi/work/git_repo/demo_proj/src/logist_test/train-images-idx3-ubyte"
train_y_filename="/Users/didi/work/git_repo/demo_proj/src/logist_test/train-labels-idx1-ubyte"
test_x_filename="/Users/didi/work/git_repo/demo_proj/src/logist_test/t10k-images-idx3-ubyte"
test_y_filename="/Users/didi/work/git_repo/demo_proj/src/logist_test/t10k-labels-idx1-ubyte"
train_x=read_image(train_x_filename)
train_y=read_label(train_y_filename)
test_x=read_image(test_x_filename)
test_y=read_label(test_y_filename)
return train_x, test_x, train_y, test_y
if __name__=='__main__':
print("Start reading data...")
time1=time.time()
train_x, test_x, train_y, test_y = loadDataSet()
clf = LogisticRegression()
clf.fit(train_x, train_y)
y_pred = clf.predict(test_x)
print('準確率:'+accuracy_score(test_y, y_pred))
Logistic 回歸(mnist數據集)
?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。
推薦閱讀更多精彩內容
- MNIST數據集是一個入門級的計算機視覺數據集,它包含各種手寫數字照片,它也包含每一張圖片對應的標簽,告訴我們這是...
- 本文作者:陳?鼎,中南財經政法大學統計與數學學院文字編輯:任?哲技術總編:張馨月 ??Logistic回歸分析是一...
- MNIST 數據集已經是一個被”嚼爛”了的數據集, 很多教程都會對它”下手”, 幾乎成為一個 “典范” 1.邏輯回...