優化算法中梯度下降算法的編程實現

優化算法中梯度下降算法的編程實現

簡介

梯度下降算法是運籌學的基礎數學方法,用來求解運籌學所構造的數學問題。

本文在Linux平臺下,采用C++語言編寫梯度下降算法的實現程序。

程序的基本思路是以虛基類構建計算流程,以繼承類定義代價函數(Cost Function)的形式,代價函數由用戶自定義,繼承關系由C++的多態特性輔助。

優化算法通過三個部分實現,殘差塊(Residual Block)構造、代價函數(Cost Function)構造和優化計算(Optimization)。由于優化計算的框架是固定的,只有具體到殘差塊函數、代價函數等是不同的,因此將優化計算整體拆分為計算框架和具體內容兩個部分,其中的計算框架部分由程序構造,具體內容由用戶自定義。為了達成這一思路,程序以虛基類構建計算流程,以繼承類定義殘差塊、代價函數等的形式,代價函數由用戶自定義,繼承關系由C++的多態特性輔助。

程序包含三個虛基類,殘差塊類(ResidualBlockFunction)、代價函數類(CostFunction)和優化算法管理類(OptimizationManager),殘差塊類負責定義殘差塊的數學結構,為最基礎的類;代價函數類負責構造代價函數的數學結構,該結構通過殘差塊構造,并且負責代價函數的導數求解,導數使用數值微分的方法進行求解;優化算法管理類統籌整個優化計算的實現過程。

虛基類名稱以"User"開頭,繼承類名稱以描述自身目標的詞匯而非"User"開頭。

程序結構

殘差塊類

殘差塊類包含兩個部分,虛基類和繼承類。虛基類名為"UserResidualBlockFunction",繼承類名為"PolyResidualBlockFunction"。

代價函數類

代價函數類包含兩個部分,虛基類"UserCostFunction"和繼承類"SteepestCostFunction"。

代價函數類由優化算法管理類調用,負責為優化算法管理類提供代價函數的函數值、一階導數值(也稱為梯度、Jacobi)、二階導數(即黑森矩陣,Hessian Matrix)等,因此其核心功能是計算代價函數值、計算代價函數的導數值。同時,代價函數類負責管理殘差塊,由此,代價函數類的核心功能還包括殘差塊的添加。

一個典型的代價函數類的虛基類如下,核心功能所指的函數包括了添加計算代價函數值bool CostFunction,計算代價函數的一階導數值bool DerivativesFunction,殘差塊函數void AddResidualBlock。而其它函數是用于輔助核心功能的,包括計算代價函數對某一參數的一階偏導數值的函數bool GetOneDerivative,設定迭代步長void SetStepLength

在虛基類的這些函數中,有一部分是純虛函數(Pure Virtual Function),它們的定義交由繼承類完成,并且繼承類必須定義,否則程序無法完成編譯。這些函數都是與用戶的選擇掛鉤的,包括添加殘差塊函數和計算代價函數值的函數。殘差塊和代價函數必須由用戶自行定義,雖然在大量研究文獻中,殘差塊都定義為觀測數值與理論數值之差,代價函數都定義為殘差塊的平方和,但這不代表殘差塊和代價函數沒有其它定義方式,因此,這兩個函數的定義交由繼承類完成。

class UserCostFunction
{
    public:
        UserCostFunction(string name, int SizeObservations, int SizeVariables, int SizeResiduals);
        ~UserCostFunction();

    public:
        // pure virtual
        virtual void AddResidualBlock(vector<double> observations) = 0;
        virtual bool CostFunction(vector<double> variables, vector<double> &CostFunctionValues)=0;

    public:
        // virtual
        virtual bool GetOneDerivative(int VarialbleID, vector<double> variables, double &theDerivativeValue);

    public:                                                                                                                                                                                                      
        bool DerivativesFunction(vector<double> variables, vector<double> &theDerivatives);
        void SetStepLength(double delta);

    public:
        virtual void Show() = 0;

    protected:
        vector<UserResidualBlockFunction*> ResidualBlockFunctions_;
        int SizeObservations_;
        int SizeVariables_;
        int SizeResiduals_;

        // for derivative calculation
        double delta_;

    private:
        string name_;
};

<center>虛基類"UserCostFunction"的頭文件部分</center>

class SteepestCostFunction : virtual public UserCostFunction
{
    public:
        SteepestCostFunction(string name, int SizeObservations, int SizeVariables, int SizeResiduals);                                                                                                           
        ~SteepestCostFunction();
            
    public:
        void Show();

    public:
        virtual void AddResidualBlock(vector<double> observations);
        virtual bool CostFunction(vector<double> variables, vector<double> &CostFunctionValues);

    public:
        virtual bool GetOneDerivative(int VarialbleID, vector<double> variables, double &theDerivativeValue);

    private:
        string name_;
};

<center>繼承類"SteepestCostFunction"的頭文件部分</center>

代價函數值的計算框架

本文設定代價函數為F(A),殘差塊為f_i(A),參量以矩陣形式表達,這里假設參量數量為3,則參量為A[a_0,a_1,a_2]^T,觀測數據為x_iy_i,假設其理論數學關系符合三階多項式,y_i=a_0+a_1x_i+a_2x_i^2。雖然代價函數F(A)的形式可以根據用戶實際使用而不同,但殘差塊的平方和形式仍然在大量文獻中被采用,如下。
F(A) = \sum_{i=1}^{m} {(f_i(A))^2}
殘差塊f_i(A)也面臨相同的情況,雖然可以根據用戶使用情況而不同,但觀測值與理論值之差的形式依然是廣泛使用的形式,如下。
f_i(A) = y_i-(a_0+a_1x_i+a_2x_i^2)
采用上述形式,則代價函數可以表示如下。
F(A)= (y_0-(a_0+a_1x_0+a_2x_0^2))^2 \\ +(y_1-(a_0+a_1x_1+a_2x_1^2))^2 \\+\cdots \\+ (y_m-(a_0+a_1x_m+a_2x_m^2))^2
代價函數值的計算由函數bool CostFunction完成,由于該函數是核心功能,其名稱必須固定,因此該函數作為虛基類"UserCostFunction"的純虛函數進行聲明,在繼承類"PolyResidualBlockFunction"中進行定義。

一階導數的計算框架

優化算法管理類

優化算法管理類包含兩個部分,虛基類"UserOptimizationManager"和繼承類"SteepestOptimizationManager"。÷

?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。

推薦閱讀更多精彩內容

  • Swift1> Swift和OC的區別1.1> Swift沒有地址/指針的概念1.2> 泛型1.3> 類型嚴謹 對...
    cosWriter閱讀 11,136評論 1 32
  • 前言 把《C++ Primer》[https://book.douban.com/subject/25708312...
    尤汐Yogy閱讀 9,533評論 1 51
  • 今天晚上值班,每次值班大半夜都是電腦上看電影,連續劇。今天看了幾集連續劇,又看看其他的,此時此刻突然發現,今天的日...
    黃灰紅閱讀 162評論 0 0
  • 你好嗎?我親愛的!我知道你一直在陪伴著我! 你好嗎?當你出現糾結難受的時候,我知道,你在掉入到事情里面了,我知道你...
    COCO之聲閱讀 375評論 2 1