機器學習筆記-基于梯度下降的曲線擬合

背景

7月份的時候導師布置了個作業,他給了一條用程序生成的曲線,然后讓我們用代碼實現一個梯度下降算法來擬合曲線。具體要求:

data.csv文件中包含兩列用逗號分隔的數據。第一列是x,第二列是y。完成如下工作:
(1)在data.csv中隨機選擇80%的數據作為訓練集,剩余20%作為測試集。
(2)構造模型,采用梯度下降算法訓練模型。
(3)用測試集對訓練的模型進行評估,將測試集中的x作為輸入,用模型計算y,計算預測值與實際值的RMSE。
(4)繪制data.csv中的點,繪制x ∈ [0,1] 之間模型的對應曲線。

數據格式如下:

0.000000000000000000,0.000045401991009684
0.010010010010010010,0.000067487908347918
0.020020020020020020,0.000099516665248245
0.030030030030030030,0.000145574221405758
0.040040040040040040,0.000211247752152538
0.050050050050050046,0.000304101936049645
0.060060060060060060,0.000434277611628926
0.070070070070070073,0.000615236631426893
0.080080080080080079,0.000864687227990188
0.090090090090090086,0.001205760122738213
0.100100100100100092,0.001668621265042236

上面的csv文件一共有1000行數據,在xy平面上繪制出來的曲線如下:


思路

老師的意思是先猜這條曲線是什么函數的曲線(先確定函數的基本形式),一開始函數的具體參數是不知道的,需要猜幾個初始值,那么猜出來的曲線一定和實際曲線有較大差異,再用最優化的方法找到使差異最小化的函數參數,從而實現曲線的擬合。這里要求實現梯度下降算法來求解最小值。

從曲線的圖像來看原始數據應該是幾個均值方差不同的高斯函數疊加而成的,圖中有4個峰,因此可以假設曲線的模型為:f(x)=\alpha_1e^{-\frac{(x-\mu_1)^2}{2\sigma^2_1}}+\alpha_2e^{-\frac{(x-\mu_2)^2}{2\sigma^2_2}}+\alpha_3e^{-\frac{(x-\mu_3)^2}{2\sigma^2_3}}+\alpha_4e^{-\frac{(x-\mu_4)^2}{2\sigma^2_4}}。
令誤差函數為E=\sum\limits_{i=1}^{n} (f(x_i) - y_i)^2。則理想的模型參數:
(\alpha_1,\mu_1,\sigma_1,\alpha_2,...,\sigma_4)=\min\limits_{\alpha_1,...,\sigma_4}E

梯度下降算法每次求出函數(E)在某個點(當前參數)的梯度,因為梯度就是函數值增長最快的那個方向,所以讓參數沿著梯度的負方向乘以一定的步長進行更新,就一定能抵達一個局部極小點。所以只要給定了這里的誤差函數E(\alpha_1,\mu_1,\sigma_1,\alpha_2,\mu_2,\sigma_2,\alpha_3,\mu_3,\sigma_3,\alpha_4,\mu_4,\sigma_4),就可以通過梯度下降算法來找到使誤差函數達到局部極小的12個參數。

為了便于計算,可以把\sigma^2當成一個整體,此時需要求出E在某個點的梯度的一般表示:(\frac{\partial E}{\partial \alpha_1},\frac{\partial E}{\partial \mu_1},\frac{\partial E}{\partial \sigma_1^2},\frac{\partial E}{\partial \alpha_2},\frac{\partial E}{\partial \mu_2},\frac{\partial E}{\partial \sigma_2^2},\frac{\partial E}{\partial \alpha_3},\frac{\partial E}{\partial \mu_3},\frac{\partial E}{\partial \sigma_3^2},\frac{\partial E}{\partial \alpha_4},\frac{\partial E}{\partial \mu_4},\frac{\partial E}{\partial \sigma_4^2},)。其中\frac{\partial E}{\partial \alpha_1}=2\sum\limits_{i=1}^{n}((f(x_i)-y_i)e^{-\frac{(x_i-\mu_1)^2}{2\sigma_1^2}}) \frac{\partial E}{\partial \mu_1}=2\sum\limits_{i=1}^{n}(\frac{\alpha_1(x_i-\mu_1)}{\sigma_1^2}(f(x_i)-y_i)e^{-\frac{(x_i-\mu_1)^2}{2\sigma_1^2}}) \frac{\partial E}{\partial \sigma_1^2}=2\sum\limits_{i=1}^{n}(\frac{\alpha_1(x_i-\mu_1)^2}{2\sigma_1^4}(f(x_i)-y_i)e^{-\frac{(x_i-\mu_1)^2}{2\sigma_1^2}}),其余參數的偏導數以此類推。

設定一個迭代次數,每次求出誤差函數的梯度后,設定步長\eta,讓參數沿梯度的負方向更新,如:\alpha_1=\alpha_1-\eta\frac{\partial E}{\partial \alpha_1},\mu_1=\mu_1-\eta\frac{\partial E}{\partial \mu_1},然后重復這個步驟,直到達到一定迭代次數或者總誤差小于一定閾值停止迭代。

程序

程序使用Java實現。(C++寫起來麻煩而且沒有合適的圖表顯示庫,Python太慢,Java寫起來最順手)

一開始我面臨的問題就是選擇一個圖表顯示庫,簡單地調研了一下選了XChart,但是去了該項目的Github主頁發現居然沒有打包好的 jar 包,于是需要 clone 下來然后用 mvn package 命令把 jar 包打出來。

然后我定義了一個模型類 Model,這個模型類的成員變量是 double數組,用來放待調的參數,比如上文中的f(x)對應的參數數組長度就為12。Model類有一些待實現的方法如函數的求值(val)、梯度的求值(grad)等,其派生類GaussianModel就是上文中的模型。另外,因為梯度下降會抵達最近的極小點而不是全局最小點,最終的收斂點極大依賴于參數的初始值,我每次隨機選取了一部分數據點來求梯度以跳出局部極小。

Java代碼如下:

package com.company;

import org.knowm.xchart.QuickChart;
import org.knowm.xchart.SwingWrapper;
import org.knowm.xchart.XYChart;

import java.io.File;
import java.io.FileNotFoundException;
import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;
import static java.lang.Math.E;
import static java.lang.Math.pow;
import static java.lang.Math.sqrt;
import static java.lang.System.exit;


public class Solver {

    private List<Point> rawData = new ArrayList<>();
    private List<Point> trainData = new ArrayList<>();
    private List<Point> testData = new ArrayList<>();
    private Model model = null;
    private Function<Model, Double> loss = null;

    public Solver(String csvPath) throws FileNotFoundException {
        Scanner scanner = new Scanner(new File(csvPath));
        while (scanner.hasNextLine()) {
            String[] xy = scanner.nextLine().split(",");
            rawData.add(new Point(Double.valueOf(xy[0]), Double.valueOf(xy[1])));
        }
    }

    private Function<Model, Double> mse = (m) -> {
        double lossSum = 0.0;
        for (Point p : trainData) {
            double diff = m.val(p.x) - p.y;
            lossSum += (diff * diff);
        }
        return lossSum / 2.0;
    };

    private void divide(float ratio4Train) {
        trainData.clear();
        testData.clear();
        if (ratio4Train <= 0) throw new IllegalArgumentException("Ratio <= 0");
        int testCount = (int) (rawData.size() * (1 - ratio4Train));
        Random rand = new Random(System.currentTimeMillis());
        Set<Integer> exclusiveIndices4Test = new HashSet<>();
        while (exclusiveIndices4Test.size() < testCount) {
            int index = rand.nextInt(rawData.size());
            if (! exclusiveIndices4Test.contains(index)) {
                testData.add(rawData.get(index));
                exclusiveIndices4Test.add(index);
            }
        }
        for (int i = 0; i < rawData.size(); i ++) {
            if (! exclusiveIndices4Test.contains(i)) {
                trainData.add(rawData.get(i));
            }
        }
    }

    private void train() {
        System.out.println("Train data size: " + trainData.size());
        System.out.println("Test data size: " + testData.size());
//        model = new PolyModel(4);
        model = new GaussianModel(5);
        loss = mse;
        // ==========================================================
        for (int i = 0; i < 10000; i ++) {
            double lossVal = loss.apply(model);
            double[] gradVal = model.grad(trainData);
            System.out.println(String.format("Iter: %d, loss: %f ", i, lossVal));
            System.out.println(String.format("Theta: %f, %f, %f", model.theta[0], model.theta[1], model.theta[2]));
            System.out.println(String.format("Grad: %f, %f, %f\n", gradVal[0], gradVal[1], gradVal[2]));
            if (Double.isNaN(lossVal)) {
                model.randomize(); i = 0;
                continue;
            }
            for (int j = 0; j < gradVal.length; j ++) {
                double delta = model.rate(j) * gradVal[j];
                model.theta[j] -= delta;
            }
//            if (lossVal < 1.06) break;
        }
        System.out.println(String.format("Theta: %f, %f, %f", model.theta[0], model.theta[1], model.theta[2]));
    }

    private void validate() {
        double RMSE = 0.0;
        for (Point p : testData) {
            double diff = model.val(p.x) - p.y;
            RMSE += (diff * diff);
        }
        RMSE /= testData.size();
        RMSE = sqrt(RMSE);
        System.out.println("RMSE: " + RMSE);
    }

    private void plot() {
        XYChart chart = QuickChart.getChart(
                "Result", "X", "Y", "y(x)",
                trainData.stream().map(point -> point.x).collect(Collectors.toList()),
                trainData.stream().map(point -> point.y).collect(Collectors.toList()));

        double[] xPoints = new double[150];
        double[] yPoints = new double[150];
        for (int i = 0; i < 150; i ++) {
            xPoints[i] = i * 10.0 / 150;
            yPoints[i] = model.val(xPoints[i]);
        }
        chart.addSeries("model", xPoints, yPoints);

        new SwingWrapper<XYChart>(chart).displayChart();
    }

    public void solve() {
        divide(0.8f);
        train();
        validate();
        plot();
    }

    public static void main(String[] args) throws FileNotFoundException {
    // write your code here
        if (args.length < 1) {
            System.out.println("Usage: java -jar GradientDesent.jar data.csv");
            exit(0);
        }
        new Solver(args[0]).solve();
    }

    private static class Point {
        double x;
        double y;
        public Point(double x, double y) {this.x = x; this.y = y;}

    }

    private static abstract class Model {
        double theta[] = null;
        abstract double val(double x);
        abstract double[] grad(List<Point> trainData);
        abstract void randomize();
        abstract double rate(int i);
    }

    private static class PolyModel extends Model{

        public PolyModel(int n) {
            if (n < 2) throw new IllegalArgumentException("n MUST be larger than 2.");
            theta = new double[n];
            randomize();
        }

        double val(double x) {
            double result = 0.0;
            for (int i = 0; i < theta.length; i ++) {
                result += theta[i] * pow(x, i);
            }
            return result;
        }

        @Override
        double[] grad(List<Point> trainData) {
            double []gradVec = new double[theta.length];
            for (int i = 0; i < gradVec.length; i ++) {
                gradVec[i] = 0.0;
                Random r = new Random();
                List<Point> data = new ArrayList<>();
                for (int k = 0; k < 50; k ++)
                    data.add(trainData.get(r.nextInt(trainData.size())));
                for (Point p : data) {
                    double diff = val(p.x) - p.y;
                    gradVec[i] += (diff * pow(p.x, i));
                }
            }
            return gradVec;
        }

        @Override
        void randomize() {
            Random rand = new Random(System.currentTimeMillis());
            for (int i = 0; i < theta.length; i ++) {
                theta[i] = rand.nextDouble() ;
            }
        }

        @Override
        double rate(int i) {
            return 0.00000002;
        }
    }

    private static class GaussianModel extends Model{

        /**
         * f(x) = a * e ^ (- (x - μ)^2 / σ^2)
         * (a, μ, σ2) <<----
         * @param n number of gaussian function
         */
        public GaussianModel(int n) {
            if (n < 1) throw new IllegalArgumentException("n MUST be larger than 1.");
            theta = new double[n * 3];
            randomize();
        }

        @Override
        double val(double x) {
            double result = 0.0;
            for (int i = 0; i < theta.length / 3; i ++) {
                double alpha = theta[i * 3 + 0];
                double miu = theta[i * 3 + 1];
                double sigma2 = theta[i * 3 + 2];
                result += (alpha * pow(E, - pow((x - miu), 2) / sigma2 / 2));
            }
            return result;
        }

        @Override
        double[] grad(List<Point> trainData) {
            double[] gradVec = new double[theta.length];
            for (int i = 0; i < theta.length / 3; i ++) {
                gradVec[i * 3 + 0] = 0;
                gradVec[i * 3 + 1] = 0;
                gradVec[i * 3 + 2] = 0;
                double alpha = theta[i * 3 + 0];
                double miu = theta[i * 3 + 1];
                double sigma2 = theta[i * 3 + 2];
                Random r = new Random();
                List<Point> stochasticData = new ArrayList<>();
                for (int k = 0; k < 30; k ++)
                    stochasticData.add(trainData.get(r.nextInt(trainData.size())));
                for (Point p : stochasticData) {
                    double val = val(p.x);
                    gradVec[i * 3 + 0] += 2
                            * (val - p.y)
                            * (pow(E, - pow((p.x - miu), 2) / sigma2 / 2));
                    gradVec[i * 3 + 1] += (2
                            * alpha
                            * (val - p.y)
                            * pow(E, - pow((p.x - miu), 2) / sigma2 / 2)
                            * ((p.x - miu) / sigma2));
                    gradVec[i * 3 + 2] += (2
                            * alpha
                            * (val - p.y)
                            * pow(E, - pow((p.x - miu), 2) / sigma2 / 2)
                            * (pow((p.x - miu), 2) / pow(sigma2, 2) / 2)); //把sigma平方當成了一個整體
                }
            }
            return gradVec;
        }

        @Override
        void randomize() {
            Random rand = new Random(System.currentTimeMillis());
            for (int i = 0; i < theta.length / 3; i ++) {
                theta[i * 3 + 0] = rand.nextDouble();
                theta[i * 3 + 1] = rand.nextDouble() * 5;
                theta[i * 3 + 2] = rand.nextDouble();
            }
        }

        @Override
        double rate(int i) {
            if (i % 3 == 0) {
                return 0.0005;
            } else if (i % 3 == 1) { // miu
                return 0.0005;
            } else {
                return 0.00005;
            }
        }

        public String toString() {
            StringBuilder builder = new StringBuilder("Theta: ");
            for (double t : theta) {
                builder.append(t);
                builder.append(", ");
            }
            builder.append("\nGrad: ");
            return builder.toString();
        }
    }
}

最后的結果還是比較看人品的,并不是每次都能擬合地比較好,貼一個結果的圖:


結果

數據和代碼我放到了我的Github:https://github.com/Jimmie00x0000/gradient_desent_demo。

最后編輯于
?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市,隨后出現的幾起案子,更是在濱河造成了極大的恐慌,老刑警劉巖,帶你破解...
    沈念sama閱讀 230,247評論 6 543
  • 序言:濱河連續發生了三起死亡事件,死亡現場離奇詭異,居然都是意外死亡,警方通過查閱死者的電腦和手機,發現死者居然都...
    沈念sama閱讀 99,520評論 3 429
  • 文/潘曉璐 我一進店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人,你說我怎么就攤上這事?!?“怎么了?”我有些...
    開封第一講書人閱讀 178,362評論 0 383
  • 文/不壞的土叔 我叫張陵,是天一觀的道長。 經常有香客問我,道長,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 63,805評論 1 317
  • 正文 為了忘掉前任,我火速辦了婚禮,結果婚禮上,老公的妹妹穿的比我還像新娘。我一直安慰自己,他們只是感情好,可當我...
    茶點故事閱讀 72,541評論 6 412
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著,像睡著了一般。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發上,一...
    開封第一講書人閱讀 55,896評論 1 328
  • 那天,我揣著相機與錄音,去河邊找鬼。 笑死,一個胖子當著我的面吹牛,可吹牛的內容都是我干的。 我是一名探鬼主播,決...
    沈念sama閱讀 43,887評論 3 447
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了?” 一聲冷哼從身側響起,我...
    開封第一講書人閱讀 43,062評論 0 290
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后,有當地人在樹林里發現了一具尸體,經...
    沈念sama閱讀 49,608評論 1 336
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內容為張勛視角 年9月15日...
    茶點故事閱讀 41,356評論 3 358
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發現自己被綠了。 大學時的朋友給我發了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 43,555評論 1 374
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖,靈堂內的尸體忽然破棺而出,到底是詐尸還是另有隱情,我是刑警寧澤,帶...
    沈念sama閱讀 39,077評論 5 364
  • 正文 年R本政府宣布,位于F島的核電站,受9級特大地震影響,放射性物質發生泄漏。R本人自食惡果不足惜,卻給世界環境...
    茶點故事閱讀 44,769評論 3 349
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧,春花似錦、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 35,175評論 0 28
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至,卻和暖如春,著一層夾襖步出監牢的瞬間,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 36,489評論 1 295
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人。 一個月前我還...
    沈念sama閱讀 52,289評論 3 400
  • 正文 我出身青樓,卻偏偏與公主長得像,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當晚...
    茶點故事閱讀 48,516評論 2 379

推薦閱讀更多精彩內容