搭建模塊化的神經(jīng)網(wǎng)絡(luò)模板

CSDN文章地址:https://blog.csdn.net/kdongyi

1.前向傳播就是搭建網(wǎng)絡(luò),設(shè)計(jì)網(wǎng)絡(luò)結(jié)構(gòu)(forward.py)

前向傳播網(wǎng)絡(luò)結(jié)構(gòu):

#前向傳播網(wǎng)絡(luò)結(jié)構(gòu)
def forword(x, regularizer):
    w=
    b=
    y=
    return y

定義權(quán)重函數(shù):

#定義權(quán)重函數(shù)
def get_weight(shape, regularizer):
    w = tf.Variable( )
    tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(regularizer)(w))
    return w

定義偏置量:

#定義偏置量:
def get_bias(shape):
    b = tf.Variable( ) 
    return b

2.反向傳播就是訓(xùn)練網(wǎng)絡(luò),優(yōu)化網(wǎng)絡(luò)參數(shù)(backword.py)

def backword( ):
    x = tf.placeholder( )
    y_ = tf.placeholder( )
    y = forward.forword(x, REGULARIZER) #前向傳播網(wǎng)絡(luò)結(jié)構(gòu),計(jì)算求y     
    global_step = tf.Variable(0, trainable = False) #定義輪數(shù)計(jì)數(shù)器global_step     
    loss = #定義損失函數(shù),可以選用以下均方誤差、交叉熵、自定義,表示計(jì)算出來的y與y_的差距
    
    #loss可以是(如果使用均方誤差):
    y與y_的差距(loss_mse) = tf.reduce_mean(tf.square(y-y_))

    #也可以是(如果使用交叉熵,則用下面兩行代碼):
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y,labels=tf.argmax(y_, 1))
    y與y_的差距(cem) = tf.reduce_mean(ce)      #加入正則化后(如果使用正則化)    
    loss = y與y_的差距 + tf.add_n(tf.get_collection('losses'))

    #指數(shù)衰減學(xué)習(xí)率(如果使用指數(shù)衰減學(xué)習(xí)率,動(dòng)態(tài)計(jì)算學(xué)習(xí)率)
    learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE, global_step, 數(shù)據(jù)集總樣本數(shù) / BATCH_SASE, LEARNING_RATE_DECAY,        staircase=True) 
    #訓(xùn)練過程
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss,global_step=global_step) 

    #如果用滑動(dòng)平均:
    ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY,global_step)    
    ema_op = ema.apply(tf.trainable_variables())

    with tf.control_dependencies([train_step, ema_op]):
          train_op = tf.no_op(name='train')

    with tf.Session() as sess:
          init_op = tf.global_variables_initializer()
          sess.run(init_op)

    for i in range(STEPS): 
          sess.run(train_step, feed_dict={x: , y_: })
          if i % 輪數(shù) ==0: 
              print( ) 

#判斷運(yùn)行的是否為主文件

#判斷運(yùn)行的是否為主文件
if __name__=='__main__': 
     backward()

MNIST數(shù)據(jù)集下載地址:https://download.csdn.net/download/kdongyi/10666285

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺,僅提供信息存儲(chǔ)服務(wù)。