Calculating Speech Transcription Accuracy

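Measuring speech-to-text accuracy. The transcription metrics are ACC and Corr:
H = number of correct characters
D = deletion errors, e.g. “我是中國人” → “我是中國”
S = substitution errors, e.g. “我是中國人” → “我是中華人”
I = insertion errors, e.g. “我是中國人” → “我是中國男人”
N = total number of characters in the reference text (so N = H + D + S)
ACC = (H - I)/N
Corr = H/N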
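To make the definitions concrete, here is a minimal sketch that computes Corr and Acc from the counts. It assumes N is the number of characters in the reference text, so N = H + D + S; the class and method names (AsrMetrics, correctness, accuracy) are hypothetical. The three calls in main use the one-character examples above.

```java
// Sketch only. Assumption: N is the reference length, so N = H + D + S.
class AsrMetrics {

    /** Corr = H / N: insertions do not lower it. */
    static double correctness(int h, int d, int s) {
        int n = h + d + s; // reference length
        return (double) h / n;
    }

    /** Acc = (H - I) / N: insertions are penalised here. */
    static double accuracy(int h, int d, int s, int i) {
        int n = h + d + s; // reference length
        return (double) (h - i) / n;
    }

    public static void main(String[] args) {
        // "我是中國人" -> "我是中國": H=4, D=1
        System.out.println(correctness(4, 1, 0) + " " + accuracy(4, 1, 0, 0)); // 0.8 0.8
        // "我是中國人" -> "我是中華人": H=4, S=1
        System.out.println(correctness(4, 0, 1) + " " + accuracy(4, 0, 1, 0)); // 0.8 0.8
        // "我是中國人" -> "我是中國男人": H=5, I=1
        System.out.println(correctness(5, 0, 0) + " " + accuracy(5, 0, 0, 1)); // 1.0 0.8
    }
}
```

The insertion example shows why both numbers are reported: Corr stays at 1.0 even though the transcription contains an extra character, while Acc drops to 0.8.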

Approach:
sample "我是中國人學大生,今天要測試,錄音結束" (the reference text)
test "中國人民生,今天有個大事情,我想吃飯" (the transcription)

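Transcribed text has to respect the meaning of the sentence, so a character cannot be counted as correct just because it appears somewhere in the transcription. Leaving more complicated factors aside,
we need to find in test the characters that correspond to those in sample, and they must appear in the same order as in sample (ignoring punctuation).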
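The example itself shows why order matters. A naive count that treats a sample character as correct whenever it occurs anywhere in test (a hypothetical helper, NaiveCount.countShared, shown only to illustrate the pitfall) finds 8 such characters, including 我 and 大, which occur in test only at positions that break the sample's order; the order-aware method described next keeps only 6.

```java
import java.util.HashSet;
import java.util.Set;

// Illustration of the order-insensitive count the text warns against:
// a sample character is "correct" if it appears anywhere in test.
class NaiveCount {

    static String stripPunctuation(String s) {
        return s.replaceAll("\\pP", ""); // remove Unicode punctuation characters
    }

    static int countShared(String sample, String test) {
        Set<Character> testChars = new HashSet<>();
        for (char c : stripPunctuation(test).toCharArray()) {
            testChars.add(c);
        }
        int shared = 0;
        for (char c : stripPunctuation(sample).toCharArray()) {
            if (testChars.contains(c)) {
                shared++; // position is ignored entirely
            }
        }
        return shared;
    }

    public static void main(String[] args) {
        String sample = "我是中國人學大生,今天要測試,錄音結束";
        String test = "中國人民生,今天有個大事情,我想吃飯";
        System.out.println(countShared(sample, test)); // 8
    }
}
```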

For example, here is the position in test at which each character of sample is found:


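[Figure: position of each sample character in test, -1 if not found: 12, -1, 0, 1, 2, -1, 9, 4, 5, 6, -1, -1, -1, -1, -1, -1, -1]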
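The index list can be reproduced with a short sketch. It follows the same rule as getInIndex in the code further down (take the first position in test not already claimed by an earlier sample character, or -1), but the class and method names (IndexMapping, positionsInTest) are made up for illustration, and a boolean array replaces the list lookup.

```java
import java.util.Arrays;

// Sketch: for each sample character, the first position in test that has not
// already been claimed by an earlier sample character; -1 if there is none.
class IndexMapping {

    static int[] positionsInTest(String sample, String test) {
        String s = sample.replaceAll("\\pP", ""); // punctuation is ignored
        String t = test.replaceAll("\\pP", "");
        boolean[] used = new boolean[t.length()];
        int[] result = new int[s.length()];
        for (int i = 0; i < s.length(); i++) {
            int pos = t.indexOf(s.charAt(i));
            while (pos != -1 && used[pos]) {
                pos = t.indexOf(s.charAt(i), pos + 1); // try the next occurrence
            }
            if (pos != -1) {
                used[pos] = true;
            }
            result[i] = pos;
        }
        return result;
    }

    public static void main(String[] args) {
        int[] idx = positionsInTest("我是中國人學大生,今天要測試,錄音結束",
                                    "中國人民生,今天有個大事情,我想吃飯");
        System.out.println(Arrays.toString(idx));
        // [12, -1, 0, 1, 2, -1, 9, 4, 5, 6, -1, -1, -1, -1, -1, -1, -1]
    }
}
```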

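So how do we find the number of correct characters? Drop the -1 entries (characters that were not found) from this index list, then find the longest increasing subsequence among the remaining indices (so it turns out to be an algorithm problem, argh!).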


[Figure: the same index list with the -1 entries removed]

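As the figure shows, after removing -1 we are left with 12, 0, 1, 2, 9, 4, 5, 6. The longest increasing subsequence is 0, 1, 2, 4, 5, 6, which corresponds to the characters “中國人生今天”. So how do we implement this?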

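Here is a clumsy approach:
Walk through the sequence and keep each increasing run in a list; when a number does not continue a run, start a new list.
For example, when we meet 12, we create a new list: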


[Figure: a new list containing 12]

When we then meet 0, which is smaller than 12, we need another new list:


[Figure: a new list started for 0]

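Note also that a number which keeps a list increasing does not necessarily belong in the longest sequence.
Take 0 1 2 followed by 9: appending 9 only yields 0 1 2 9, of length 4, whereas skipping 9 leaves room for 0 1 2 4 5 6, of length 6. So every time we append a number that satisfies the increasing rule, we first make a copy of the original list, keeping open the chance that it grows into a longer sequence.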
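The copy-before-extending idea can be sketched compactly. This is a simplified version of it, not the author's code: for every new number, each candidate list it can extend is copied and the copy is extended, the original list stays as the backup, and the number also starts a new single-element list. The candidates therefore cover every increasing subsequence, so both 0 1 2 9 and 0 1 2 4 5 6 survive, at exponential cost in the worst case. BranchingLis and longest are hypothetical names.

```java
import java.util.ArrayList;
import java.util.List;

// Sketch of "back up before extending": each candidate that can take x is
// copied and the copy extended, so both "take x" and "skip x" survive.
class BranchingLis {

    static List<Integer> longest(int[] index) {
        List<List<Integer>> candidates = new ArrayList<>();
        for (int x : index) {
            if (x < 0) continue; // character not found in test
            List<List<Integer>> extended = new ArrayList<>();
            for (List<Integer> c : candidates) {
                if (x > c.get(c.size() - 1)) {
                    List<Integer> copy = new ArrayList<>(c);
                    copy.add(x);
                    extended.add(copy); // c itself is kept as the backup
                }
            }
            List<Integer> single = new ArrayList<>();
            single.add(x); // x may also start a new list
            candidates.addAll(extended);
            candidates.add(single);
        }
        List<Integer> best = new ArrayList<>();
        for (List<Integer> c : candidates) {
            if (c.size() > best.size()) best = c;
        }
        return best;
    }

    public static void main(String[] args) {
        System.out.println(longest(new int[]{12, -1, 0, 1, 2, -1, 9, 4, 5, 6}));
        // [0, 1, 2, 4, 5, 6]
    }
}
```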

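What remains is counting the deleted, substituted and inserted characters. This part is simpler: between two consecutive matches, the shorter of the two unmatched stretches counts as substitutions, and the surplus counts as deletions if it is on the sample side or as insertions if it is on the test side; the code logic shows the details.
The code does not take time complexity or memory usage into account (it was written for testing), so optimize it as needed.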
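The gap rule applied by the code that follows can be isolated in a small sketch. It assumes the matched pairs (position in sample, position in test) are already in increasing order; every stretch of unmatched characters, including the one before the first match and the one after the last, is split into substitutions, deletions and insertions. GapCounter and count are hypothetical names, and the pairs in main are the six matched characters of “中國人生今天”.

```java
// Sketch: count D, S and I from matched pairs {samplePos, testPos}.
// In each gap the shorter side is substitutions; a surplus on the sample side
// is deletions, a surplus on the test side is insertions.
class GapCounter {

    /** Returns {D, S, I}. */
    static int[] count(int[][] matches, int sampleLen, int testLen) {
        int d = 0, s = 0, ins = 0;
        int prevSample = -1, prevTest = -1; // -1 so the leading gap is counted
        for (int k = 0; k <= matches.length; k++) {
            // after the last match, close the gap at the end of both strings
            int curSample = k < matches.length ? matches[k][0] : sampleLen;
            int curTest = k < matches.length ? matches[k][1] : testLen;
            int gapSample = curSample - prevSample - 1;
            int gapTest = curTest - prevTest - 1;
            s += Math.min(gapSample, gapTest);
            if (gapSample > gapTest) {
                d += gapSample - gapTest;
            } else {
                ins += gapTest - gapSample;
            }
            prevSample = curSample;
            prevTest = curTest;
        }
        return new int[]{d, s, ins};
    }

    public static void main(String[] args) {
        int[][] m = {{2, 0}, {3, 1}, {4, 2}, {7, 4}, {8, 5}, {9, 6}};
        int[] r = count(m, 17, 16); // punctuation-free lengths of sample and test
        System.out.println("D=" + r[0] + " S=" + r[1] + " I=" + r[2]); // D=3 S=8 I=2
    }
}
```

Here H + D + S = 6 + 3 + 8 = 17, the length of the punctuation-free sample, which matches the usual convention that N, the reference length, equals H + D + S.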
The code:

import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Set;

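/**
 * Speech-to-text accuracy. The transcription metrics are ACC and Corr.
 * H = number of correct characters
 * D = deletions, e.g. “我是中國人” → “我是中國”
 * S = substitutions, e.g. “我是中國人” → “我是中華人”
 * I = insertions, e.g. “我是中國人” → “我是中國男人”
 * N = number of characters in the reference text (N = H + D + S)
 * ACC = (H - I)/N, Corr = H/N
 */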

public class StringCompare {
    private static final String ORIGINAL_STRING = "我是中國人學大生,今天要測試,錄音結束";

    private static final String TEST_STRING = "中國人民生,今天有個大事情,我想吃飯";

    private static float mAcc = 0;
    private static float mCorr = 0;

    private static int H = 0; // correct characters
    private static int N = 0; // total characters in the reference text
    private static int D = 0; // deleted characters
    private static int S = 0; // substituted characters
    private static int I = 0; // inserted characters

    private static List<Character> arrayD = new ArrayList<Character>();
    private static List<Character> arrayI = new ArrayList<Character>();
    private static List<Character> arrayS = new ArrayList<Character>();

    /**
     * Remove punctuation.
     */
    private static String removePunctuation(String sentence) {
        String string = sentence.replaceAll("\\pP", ""); // strip every Unicode punctuation character
        System.out.println(string);
        return string;
    }

    private static void print() {
        mAcc = (float) (H - I) / (float) N;
        mCorr = H / (float) N;
        System.out.println("H:" + H + ", N:" + N + ", D:" + D + ", I:" + I + ", S:" + S);
        System.out.println("mAcc:" + mAcc + ", mCorr:" + mCorr);
    }

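    /*
     * For each character of string A (original), find its position in string B (test),
     * or -1 if absent. A position already claimed by an earlier character is skipped.
     */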
    private static int[] getInIndex(String original, String test) {
        char[] charOriginal = original.toCharArray();
        List<Integer> list = new ArrayList<Integer>();
        for (int i = 0; i < charOriginal.length; i++) {
            int tempIndex = test.indexOf(charOriginal[i]);
            System.out.println("tempIndex:" + tempIndex);
            while (list.contains(tempIndex) && tempIndex != -1) {
                System.out.println("while");
                tempIndex = test.indexOf(charOriginal[i], tempIndex + 1);
                System.out.println("tempIndex:" + tempIndex);
            }
            list.add(tempIndex);
            System.out.println("charOriginal[" + i + "]:" + charOriginal[i] + ", index:" + tempIndex);
        }
        return list.stream().mapToInt(Integer::intValue).toArray();
    }

    // Find the longest increasing list (map: position in original -> position in test)
    private static Map<Integer, Integer> getMaxSortedIndexs(int[] index) {
        List<Map<Integer, Integer>> list = new ArrayList<Map<Integer, Integer>>();
        for (int i = 0; i < index.length; i++) {

            if (index[i] >= 0) {
                System.out.println("cur index:" + index[i]);
                if (list.size() == 0) {
                    list.add(new LinkedHashMap<Integer, Integer>());
                    list.get(list.size() - 1).put(i, index[i]);
                } else {
                    ListIterator<Map<Integer, Integer>> it = list.listIterator();
                    while (it.hasNext()) {
                        Map<Integer, Integer> everyList = it.next();
                        // compare with the last element
                        if (index[i] > (int) everyList.values().toArray()[everyList.size() - 1]) {
                            Map<Integer, Integer> tempList = new LinkedHashMap<Integer, Integer>();
                            tempList.putAll(everyList);
                            everyList.put(i, index[i]);
                            it.add(tempList);
                        } else {
                            Map<Integer, Integer> tempList = new LinkedHashMap<Integer, Integer>();
                            tempList.put(i, index[i]);
                            it.add(tempList);
                        }
                    }
                }
            }
        }
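        System.out.println("list size:" + list.size());
        if (list.isEmpty()) {
            return new LinkedHashMap<Integer, Integer>(); // no character matched at all
        }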
        int longestIndex = 0;
        int maxlength = 0;
        for (int i = 0; i < list.size(); i++) {
            System.out.println("list[" + i + "]:" + list.get(i).toString());
            if (list.get(i).size() > maxlength) {
                maxlength = list.get(i).size();
                longestIndex = i;
            }
        }
        return list.get(longestIndex);
    }

    public static void main(String[] args) {
        long time = System.currentTimeMillis();
        // strip punctuation, then find where each original character occurs in the test string
        String original = removePunctuation(ORIGINAL_STRING);
        String test = removePunctuation(TEST_STRING);
        int[] index = getInIndex(original, test);
        for (int i = 0; i < index.length; i++) {
            System.out.print(index[i] + "\t");
        }
        System.out.print("\n");
        Map<Integer, Integer> sortedIndex = getMaxSortedIndexs(index);
        System.out.println("cost:" + (System.currentTimeMillis() - time));
        Set<Integer> keySet = sortedIndex.keySet();
        Collection<Integer> valueSet = sortedIndex.values();
        System.out.println("Correct characters in the original string");
        System.out.println("[" + ORIGINAL_STRING + "]");
        for (int key : keySet) {
            System.out.println("index:" + key + ", value:" + original.charAt(key));
        }
        System.out.println("--------------------------------------------------------------");
        System.out.print("Correct characters in the test string\n");
        System.out.println("[" + TEST_STRING + "]");
        for (int key : valueSet) {
            System.out.println("index:" + key + ", value:" + test.charAt(key));
        }
        System.out.println("--------------------------------------------------------------");
        System.out.print("Computing accuracy\n");

        // Count the errors in every gap between two consecutive matches.
        // prevKey/prevValue start at -1 so the gap before the first match is counted too.
        int prevKey = -1;
        int prevValue = -1;
        // number of unmatched original characters in the current gap
        int diffKey = 0;
        // number of unmatched test characters in the current gap
        int diffValue = 0;
        for (Map.Entry<Integer, Integer> entry : sortedIndex.entrySet()) {
            int key = entry.getKey();
            int value = entry.getValue();
            diffKey = key - prevKey - 1;
            diffValue = value - prevValue - 1;
            System.err.println("diffKey:" + diffKey + ", diffValue:" + diffValue);
            if (diffKey > diffValue) {
                D += diffKey - diffValue; // surplus original characters were dropped
                S += diffValue;
            } else if (diffKey == diffValue) {
                S += diffValue;
            } else {
                I += diffValue - diffKey; // surplus test characters were inserted
                S += diffKey;
            }
            prevKey = key;
            prevValue = value;
        }
        System.out.println("prevKey:" + prevKey + ",prevValue:" + prevValue);
        // the gap after the last match, up to the end of both strings
        diffKey = original.length() - prevKey - 1;
        diffValue = test.length() - prevValue - 1;
        System.out.println("diffKey:" + diffKey + ",diffValue:" + diffValue);
        if (diffKey > diffValue) {
            D += diffKey - diffValue;
            S += diffValue;
        } else if (diffKey == diffValue) {
            S += diffValue;
        } else {
            I += diffValue - diffKey;
            S += diffKey;
        }
        H = sortedIndex.size();
        N = original.length(); // N is the reference length, so N = H + D + S
        print();
        System.out.println("--------------------------------------------------------------");
    }
}
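As noted above, the code was not optimized. In particular, getMaxSortedIndexs keeps every candidate list and copies a list each time it is extended, so the number of lists can grow exponentially with the input length. A standard replacement, swapped in here and not the author's method, is the O(n log n) longest-increasing-subsequence search (patience sorting with predecessor links). This sketch, with the hypothetical names FastLis and longestSorted, returns the same original-position → test-position map shape as getMaxSortedIndexs. When several increasing subsequences share the maximum length, the two methods may pick different ones, which can shift how the leftover characters are split between D, S and I.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.LinkedHashMap;
import java.util.Map;

// Sketch: O(n log n) longest strictly increasing subsequence.
// Returns original position -> test position; -1 entries are skipped.
class FastLis {

    static Map<Integer, Integer> longestSorted(int[] index) {
        int n = index.length;
        int[] tailPos = new int[n]; // tailPos[L]: position of the smallest tail of a run of length L + 1
        int[] prev = new int[n];    // predecessor position in that run, -1 if none
        int length = 0;
        for (int i = 0; i < n; i++) {
            int x = index[i];
            if (x < 0) continue; // character not found in test
            // binary search for the first run whose tail is >= x
            int lo = 0, hi = length;
            while (lo < hi) {
                int mid = (lo + hi) >>> 1;
                if (index[tailPos[mid]] < x) lo = mid + 1;
                else hi = mid;
            }
            prev[i] = lo > 0 ? tailPos[lo - 1] : -1;
            tailPos[lo] = i;
            if (lo == length) length++;
        }
        // walk the predecessor links back from the tail of the longest run
        Deque<Integer> stack = new ArrayDeque<>();
        for (int k = length > 0 ? tailPos[length - 1] : -1; k != -1; k = prev[k]) {
            stack.push(k);
        }
        Map<Integer, Integer> result = new LinkedHashMap<>();
        while (!stack.isEmpty()) {
            int k = stack.pop();
            result.put(k, index[k]);
        }
        return result;
    }

    public static void main(String[] args) {
        int[] idx = {12, -1, 0, 1, 2, -1, 9, 4, 5, 6, -1, -1, -1, -1, -1, -1, -1};
        System.out.println(longestSorted(idx)); // {2=0, 3=1, 4=2, 7=4, 8=5, 9=6}
    }
}
```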

C# version

using System;
using System.Collections.Generic;
using System.Text.RegularExpressions;
using System.Linq;

namespace Calibration.utils
{
    class EsrUtils
    {
        private EsrUtils() { }

        /*private static readonly EsrUtils singleInstance = new EsrUtils();

        public static EsrUtils GetInstance
        {
            get
            {
                return singleInstance;
            }
        }*/

        // Remove punctuation, symbols and whitespace
        public static string RemovePunctuation(string sentence)
        {
            return Regex.Replace(sentence, @"[\p{P}\p{S}\s]", "");
        }


        // For each character of original, find its position in test (-1 if absent)
        public static int[] GetInIndex(string original, string test)
        {
            char[] charOriginal = original.ToCharArray();
            var list = new List<int>();
            for (int i = 0; i < charOriginal.Length; i++)
            {
                int tempIndex = test.IndexOf(charOriginal[i]);
                Console.WriteLine("tempIndex:" + tempIndex);
                while (list.Contains(tempIndex) && tempIndex != -1)
                {
                    Console.WriteLine("while");
                    tempIndex = test.IndexOf(charOriginal[i], tempIndex + 1);
                    Console.WriteLine("tempIndex:" + tempIndex);
                }
                list.Add(tempIndex);
                Console.WriteLine("charOriginal[" + i + "]:" + charOriginal[i] + ", index:" + tempIndex);
            }
            return list.ToArray();
        }

        // Find the longest increasing subsequence in the index array
        public static Dictionary<int, int> GetMaxSortedIndexs(int[] index)
        {
            List<Dictionary<int, int>> list = new List<Dictionary<int, int>> { };
            for (int i = 0; i < index.Length; i++)
            {

                if (index[i] >= 0)
                {
                    Console.WriteLine("cur index:" + index[i]);
                    if (list.Count == 0)
                    {
                        list.Add(new Dictionary<int, int>());
                        list.Last().Add(i, index[i]);
                    }
                    else
                    {
                        List<Dictionary<int, int>> listBackup = new List<Dictionary<int, int>> { };
                        for (int j = 0; j < list.Count; j++)
                        {
                            Dictionary<int, int> everyList = list[j];
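                            // compare with the last (largest) element; values in each candidate are increasing,
                            // so Max() avoids relying on Dictionary enumeration order
                            if (index[i] > everyList.Values.Max())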
                            {
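                                // Back up the current Dictionary: adding or not adding the value gives two different results.
                                // E.g. for the array 12 0 1 2 9 4 5 6:
                                // if 9 is appended to 0 1 2, we only get 0 1 2 9 (length 4);
                                // if 9 is skipped, we can get 0 1 2 4 5 6 (length 6).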
                                Dictionary<int, int> tempList = new Dictionary<int, int>(everyList);
                                listBackup.Add(tempList);
                                everyList.Add(i, index[i]);
                            }
                            else
                            {
                                Dictionary<int, int> tempList = new Dictionary<int, int>();
                                tempList.Add(i, index[i]);
                                listBackup.Add(tempList);
                            }
                        }
                        // merge the backups into the candidate list
                        list.AddRange(listBackup);
                    }
                }
            }
            Console.WriteLine("list size:" + list.Count);
            if (list.Count == 0)
            {
                return new Dictionary<int, int>(); // no character matched at all
            }
            int longestIndex = 0;
            int maxlength = 0;
            for (int i = 0; i < list.Count; i++)
            {
                Console.WriteLine("list[" + i + "]:" + string.Join(", ", list[i]));
                if (list[i].Count > maxlength)
                {
                    maxlength = list[i].Count;
                    longestIndex = i;
                }
            }
            return list[longestIndex];
        }

        public static List<float[]> GetParameters(string originalString, List<string> testString)
        {
            List<float[]> resultList = new List<float[]> { };
            try
            {
                string original = RemovePunctuation(originalString);
                for (int i = 0; i < testString.Count; i++)
                {
                    string test = testString[i];
                    test = RemovePunctuation(test);
                    int[] index = EsrUtils.GetInIndex(original, test);
                    Console.WriteLine("index:" + string.Join(" ", index));
                    // the search is expensive, so run it once and reuse the result
                    Dictionary<int, int> sortedIndex = EsrUtils.GetMaxSortedIndexs(index);
                    int H = 0;
                    int N = 0;
                    int I = 0;
                    int S = 0;
                    int D = 0;
                    float corr = 0;
                    float acc = 0;
                    Console.WriteLine("Correct characters in the original string");
                    Console.WriteLine("[" + originalString + "]");
                    foreach (int key in sortedIndex.Keys)
                    {
                        Console.WriteLine("index:" + key + ", value:" + original[key]);
                    }
                    Console.WriteLine("--------------------------------------------------------------");
                    Console.WriteLine("Correct characters in the test string");
                    Console.WriteLine("[" + testString[i] + "]");
                    foreach (int key in sortedIndex.Values)
                    {
                        Console.WriteLine("index:" + key + ", value:" + test[key]);
                    }
                    Console.WriteLine("--------------------------------------------------------------");
                    Console.WriteLine("Computing accuracy");

                    // Count the errors in every gap between two consecutive matches.
                    // prevKey/prevValue start at -1 so the gap before the first match is counted too.
                    int prevKey = -1;
                    int prevValue = -1;
                    // number of unmatched original characters in the current gap
                    int diffKey = 0;
                    // number of unmatched test characters in the current gap
                    int diffValue = 0;
                    // order by key explicitly rather than relying on Dictionary enumeration order
                    foreach (KeyValuePair<int, int> pair in sortedIndex.OrderBy(p => p.Key))
                    {
                        diffKey = pair.Key - prevKey - 1;
                        diffValue = pair.Value - prevValue - 1;
                        Console.WriteLine("diffKey:" + diffKey + ", diffValue:" + diffValue);
                        if (diffKey > diffValue)
                        {
                            D += diffKey - diffValue; // surplus original characters were dropped
                            S += diffValue;
                        }
                        else if (diffKey == diffValue)
                        {
                            S += diffValue;
                        }
                        else
                        {
                            I += diffValue - diffKey; // surplus test characters were inserted
                            S += diffKey;
                        }
                        prevKey = pair.Key;
                        prevValue = pair.Value;
                    }
                    Console.WriteLine("prevKey:" + prevKey + ",prevValue:" + prevValue);
                    // the gap after the last match, up to the end of both strings
                    diffKey = original.Length - prevKey - 1;
                    diffValue = test.Length - prevValue - 1;
                    Console.WriteLine("diffKey:" + diffKey + ",diffValue:" + diffValue);
                    if (diffKey > diffValue)
                    {
                        D += diffKey - diffValue;
                        S += diffValue;
                    }
                    else if (diffKey == diffValue)
                    {
                        S += diffValue;
                    }
                    else
                    {
                        I += diffValue - diffKey;
                        S += diffKey;
                    }
                    H = sortedIndex.Count;
                    N = original.Length; // N is the reference length, so N = H + D + S
                    float[] result = new float[7];
                    acc = (float)(H - I) / (float)N;
                    corr = H / (float)N;
                    result[0] = acc;
                    result[1] = corr;
                    result[2] = H;
                    result[3] = N;
                    result[4] = D;
                    result[5] = S;
                    result[6] = I;
                    Console.WriteLine("acc" + acc + ",corr: " + corr);
                    resultList.Add(result);
                }
            }
            catch (Exception e)
            {
                Console.WriteLine("error occurred:" + e.ToString());
                return null;
            }
            return resultList;
        }

        internal static object GetInstance()
        {
            throw new NotImplementedException();
        }
    }
}

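© Copyright belongs to the author. For reprints or content partnerships, please contact the author.
Platform statement: the content of this article (including any images or videos) was uploaded and published by the author and represents only the author's own views. Jianshu is an information-publishing platform and provides only information storage services.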
