語音轉文字準確率計算 轉寫正確率的衡量項:ACC、Corr
H = 正確的字數
D = 刪減的錯誤,例:“我是中國人” → “我是中國”
S = 替換的錯誤,例:“我是中國人” → “我是中華人”
I = 插入的錯誤,例:“我是中國人” → “我是中國男人”
N = 參考文本(sample)的總字數,即 N = H + D + S
ACC = (H - I)/N
Corr= H/N
思路:
sample "我是中國人學大生,今天要測試,錄音結束"
test "中國人民生,今天有個大事情,我想吃飯"
語音轉寫的文字需要遵從原文的語義,所以不能只要文字出現就算正確。在不考慮各種複雜因素的前提下,
要從test中找到sample中對應的字符,並且順序要按sample中的字符順序(排除標點符號)。
如:
sample中每個字在test中找到的位置(-1表示未找到):
我:12 是:-1 中:0 國:1 人:2 學:-1 大:9 生:4 今:5 天:6 要:-1 測:-1 試:-1 錄:-1 音:-1 結:-1 束:-1
那麼如何找到正確的文字個數呢?應該就是在這組序號中剔除未找到的-1,然後從剩下的序號中找到最長(嚴格)遞增子序列(竟然是個算法問題,丟!)
12 0 1 2 9 4 5 6
如上,剔除-1後,剩下的就是12,0,1,2,9,4,5,6,其中最長的遞增子序列不就是0,1,2,4,5,6嗎?對應的文字就是“中國人生今天”。那麼算法如何實現呢?
這裡給出一個笨辦法:
遍歷序列,將每個升序數組都保存在list裡,如果不是升序,就新建一個list。
如當遇到12,則新建一個list:
[12]
當遇到0,則需要新建一個list
[12]  [0]
另外還需要注意,即使是升序,也不一定就能組成最長
如0 1 2,後面是9,如果加進去了,就只能組成0 1 2 9,長度為4;如果放棄加9,則有機會組成0 1 2 4 5 6,長度為6。所以每次添加一個符合升序規則的數字的時候,我們要提前將原list備份一個,留個機會看能否組成更長的序列。
剩下就是計算刪減、替換、插入的文字個數了,這個比較簡單,看看代碼邏輯就行了。
本文代碼沒有考慮時間複雜度和內存佔用(僅用於測試),請自行優化。注意:上面的笨辦法會保留所有遞增子序列的備份,遇到按順序匹配的字時list數量每次翻倍,對完全正確的長句會呈指數增長,實際使用時應改用動態規劃求最長遞增子序列(O(n²))。
上代碼:
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.NavigableMap;
import java.util.Set;
import java.util.function.IntPredicate;
import java.util.logging.Logger;
/**
 * Speech-to-text (ASR) transcription accuracy calculator.
 *
 * <p>Both texts are compared character by character after punctuation is removed:
 *
 * <ul>
 *   <li>H = number of correctly transcribed characters
 *   <li>D = deletions, e.g. "我是中國人" → "我是中國"
 *   <li>S = substitutions, e.g. "我是中國人" → "我是中華人"
 *   <li>I = insertions, e.g. "我是中國人" → "我是中國男人"
 *   <li>N = number of characters in the reference (original) text, i.e. N = H + D + S
 *   <li>ACC = (H - I) / N, Corr = H / N
 * </ul>
 *
 * <p>A character counts as correct only if it belongs to the longest run of characters that
 * appear in the test text in the same relative order as in the original text.
 */
public class StringCompare {
    private static final String ORIGINAL_STRING = "我是中國人學大生,今天要測試,錄音結束";
    private static final String TEST_STRING = "中國人民生,今天有個大事情,我想吃飯";
    private static float mAcc = 0;
    private static float mCorr = 0;
    private static int H = 0; // correct characters
    private static int N = 0; // characters in the reference (original) text
    private static int D = 0; // deleted characters
    private static int S = 0; // substituted characters
    private static int I = 0; // inserted characters

    /** Returns {@code sentence} with every Unicode punctuation character removed. */
    private static String removePunctuation(String sentence) {
        return sentence.replaceAll("\\pP", "");
    }

    /** Derives ACC and Corr from the static counters and prints all metrics. */
    private static void print() {
        // Float division: an empty reference (N == 0) yields NaN rather than throwing.
        mAcc = (float) (H - I) / (float) N;
        mCorr = H / (float) N;
        System.out.println("H:" + H + ", N:" + N + ", D:" + D + ", I:" + I + ", S:" + S);
        System.out.println("mAcc" + mAcc + ",mCorr:" + mCorr);
    }

    /**
     * Finds, for each character of {@code original}, its position in {@code test}.
     *
     * <p>Matching is greedy: each character takes the first occurrence in {@code test} that no
     * earlier character has already claimed, so a test character is never matched twice.
     *
     * @return array of length {@code original.length()}; entry i is the matched position in
     *     {@code test}, or -1 if character i has no unclaimed occurrence
     */
    private static int[] getInIndex(String original, String test) {
        boolean[] claimed = new boolean[test.length()];
        int[] positions = new int[original.length()];
        for (int i = 0; i < original.length(); i++) {
            char c = original.charAt(i);
            int pos = test.indexOf(c);
            while (pos != -1 && claimed[pos]) {
                pos = test.indexOf(c, pos + 1);
            }
            if (pos != -1) {
                claimed[pos] = true;
            }
            positions[i] = pos;
        }
        return positions;
    }

    /**
     * Finds the longest strictly increasing subsequence of the non-negative entries of
     * {@code index} using O(n²) dynamic programming.
     *
     * <p>Enumerating every increasing subsequence (the approach described in the article) doubles
     * the working set for each in-order character and exhausts memory on a correct transcription
     * of a few dozen characters; this DP keeps only one length and one predecessor per entry.
     *
     * @param index match positions as returned by {@code getInIndex}; -1 entries are skipped
     * @return map from position in the original text to position in the test text, in ascending
     *     key order; empty if nothing matched. Among equally long subsequences, the one ending
     *     earliest (built from the earliest predecessors) is returned.
     */
    private static Map<Integer, Integer> getMaxSortedIndexs(int[] index) {
        int n = index.length;
        int[] runLength = new int[n]; // longest run ending at i; 0 for unmatched entries
        int[] previous = new int[n]; // predecessor of i in that run, or -1
        int best = -1;
        for (int i = 0; i < n; i++) {
            previous[i] = -1;
            if (index[i] < 0) {
                continue;
            }
            runLength[i] = 1;
            for (int j = 0; j < i; j++) {
                if (index[j] >= 0 && index[j] < index[i] && runLength[j] + 1 > runLength[i]) {
                    runLength[i] = runLength[j] + 1;
                    previous[i] = j;
                }
            }
            if (best < 0 || runLength[i] > runLength[best]) {
                best = i;
            }
        }
        // Walk the predecessor chain backwards, then emit it in ascending order.
        List<Integer> chain = new ArrayList<Integer>();
        for (int i = best; i >= 0; i = previous[i]) {
            chain.add(i);
        }
        Map<Integer, Integer> result = new LinkedHashMap<Integer, Integer>();
        for (int k = chain.size() - 1; k >= 0; k--) {
            int key = chain.get(k);
            result.put(key, index[key]);
        }
        return result;
    }

    /**
     * Counts deletions, substitutions and insertions in the unmatched gaps around the matched
     * characters.
     *
     * <p>In each gap (before the first match, between consecutive matches, after the last match)
     * the overlapping part of the two unmatched stretches counts as substitutions; the surplus of
     * original characters counts as deletions and the surplus of test characters as insertions.
     *
     * @param sortedIndex matches in ascending key (and value) order
     * @return {@code {D, S, I}}
     */
    private static int[] countErrors(int originalLength, int testLength, Map<Integer, Integer> sortedIndex) {
        int[] dsi = new int[3];
        // Virtual match at (-1, -1) so the leading gap uses the same formula as the others.
        int prevKey = -1;
        int prevValue = -1;
        for (Map.Entry<Integer, Integer> entry : sortedIndex.entrySet()) {
            addGap(dsi, entry.getKey() - prevKey - 1, entry.getValue() - prevValue - 1);
            prevKey = entry.getKey();
            prevValue = entry.getValue();
        }
        // Trailing gap, as if there were a virtual match just past the end of both texts.
        addGap(dsi, originalLength - prevKey - 1, testLength - prevValue - 1);
        return dsi;
    }

    /** Adds a gap of {@code originalGap} original and {@code testGap} test characters to {@code dsi}. */
    private static void addGap(int[] dsi, int originalGap, int testGap) {
        int substituted = Math.min(originalGap, testGap);
        dsi[0] += originalGap - substituted;
        dsi[1] += substituted;
        dsi[2] += testGap - substituted;
    }

    /**
     * Scores {@code test} against the reference {@code original}; both are stripped of
     * punctuation first.
     *
     * @return {@code {H, N, D, S, I}} where N is the reference length
     */
    static int[] evaluate(String original, String test) {
        String reference = removePunctuation(original);
        String hypothesis = removePunctuation(test);
        Map<Integer, Integer> matched = getMaxSortedIndexs(getInIndex(reference, hypothesis));
        int[] dsi = countErrors(reference.length(), hypothesis.length(), matched);
        return new int[] {matched.size(), reference.length(), dsi[0], dsi[1], dsi[2]};
    }

    public static void main(String[] args) {
        long time = System.currentTimeMillis();
        // Position of every original character in the test string
        String original = removePunctuation(ORIGINAL_STRING);
        String test = removePunctuation(TEST_STRING);
        System.out.println(original);
        System.out.println(test);
        int[] index = getInIndex(original, test);
        for (int position : index) {
            System.out.print(position + "\t");
        }
        System.out.print("\n");
        Map<Integer, Integer> sortedIndex = getMaxSortedIndexs(index);
        System.out.println("cost:" + (System.currentTimeMillis() - time));
        System.out.println("在original字符串中正確的字符");
        System.out.println("[" + ORIGINAL_STRING + "]");
        for (int key : sortedIndex.keySet()) {
            System.out.println("index:" + key + "value:" + original.charAt(key));
        }
        System.out.println("--------------------------------------------------------------");
        System.out.print("在test字符串中正確的字符\n");
        System.out.println("[" + TEST_STRING + "]");
        for (int value : sortedIndex.values()) {
            System.out.println("index:" + value + "value:" + test.charAt(value));
        }
        System.out.println("--------------------------------------------------------------");
        System.out.print("計算準確率\n");
        int[] dsi = countErrors(original.length(), test.length(), sortedIndex);
        D = dsi[0];
        S = dsi[1];
        I = dsi[2];
        H = sortedIndex.size();
        // N is the reference length (H + D + S), per the standard Corr/Acc definition.
        N = original.length();
        print();
        System.out.println("--------------------------------------------------------------");
    }
}
C#版本
using System;
using System.Collections.Generic;
using System.Text.RegularExpressions;
using System.Linq;
namespace Calibration.utils
{
    /// <summary>
    /// Speech-to-text (ASR) accuracy calculator. Texts are compared character by character
    /// after punctuation is removed:
    /// H = correct characters, D = deletions, S = substitutions, I = insertions,
    /// N = characters in the reference (original) text, i.e. N = H + D + S,
    /// ACC = (H - I) / N, Corr = H / N.
    /// </summary>
    class EsrUtils
    {
        private EsrUtils() { }

        /// <summary>
        /// Removes the characters listed in the regex character class (punctuation, some
        /// symbols and ASCII spaces) from <paramref name="sentence"/>.
        /// </summary>
        public static string RemovePunctuation(string sentence)
        {
            return Regex.Replace(sentence, "[ \\[ \\] \\^ \\-_*×――(^)$%~!@#$…&%¥—+=<>《》!!???::?`·、。,;,.;\"‘’“”-]", "");
        }

        /// <summary>
        /// Finds, for each character of <paramref name="original"/>, its position in
        /// <paramref name="test"/>. Matching is greedy: each character takes the first
        /// occurrence that no earlier character has already claimed.
        /// </summary>
        /// <returns>One entry per original character: the matched position, or -1.</returns>
        public static int[] GetInIndex(String original, String test)
        {
            bool[] claimed = new bool[test.Length];
            int[] positions = new int[original.Length];
            for (int i = 0; i < original.Length; i++)
            {
                char c = original[i];
                int pos = test.IndexOf(c);
                while (pos != -1 && claimed[pos])
                {
                    pos = test.IndexOf(c, pos + 1);
                }
                if (pos != -1)
                {
                    claimed[pos] = true;
                }
                positions[i] = pos;
            }
            return positions;
        }

        /// <summary>
        /// Finds the longest strictly increasing subsequence of the non-negative entries of
        /// <paramref name="index"/> with O(n²) dynamic programming. (Enumerating every
        /// increasing subsequence doubles the work for each in-order character and blows up
        /// exponentially on correct transcriptions.)
        /// </summary>
        /// <returns>
        /// Map from original position to test position, inserted in ascending key order; empty
        /// if nothing matched. Ties go to the subsequence that ends earliest.
        /// </returns>
        public static Dictionary<int, int> GetMaxSortedIndexs(int[] index)
        {
            int n = index.Length;
            int[] runLength = new int[n];   // longest run ending at i; 0 for unmatched entries
            int[] previous = new int[n];    // predecessor of i in that run, or -1
            int best = -1;
            for (int i = 0; i < n; i++)
            {
                previous[i] = -1;
                if (index[i] < 0)
                {
                    continue;
                }
                runLength[i] = 1;
                for (int j = 0; j < i; j++)
                {
                    if (index[j] >= 0 && index[j] < index[i] && runLength[j] + 1 > runLength[i])
                    {
                        runLength[i] = runLength[j] + 1;
                        previous[i] = j;
                    }
                }
                if (best < 0 || runLength[i] > runLength[best])
                {
                    best = i;
                }
            }
            // Walk the predecessor chain backwards, then emit it in ascending order.
            List<int> chain = new List<int>();
            for (int i = best; i >= 0; i = previous[i])
            {
                chain.Add(i);
            }
            chain.Reverse();
            Dictionary<int, int> result = new Dictionary<int, int>();
            foreach (int key in chain)
            {
                result.Add(key, index[key]);
            }
            return result;
        }

        /// <summary>
        /// Scores every string in <paramref name="testString"/> against the reference
        /// <paramref name="originalString"/>.
        /// </summary>
        /// <returns>
        /// One array per test string: { acc, corr, H, N, D, S, I }; null if any error occurs.
        /// </returns>
        public static List<float[]> GetParameters(string originalString, List<string> testString)
        {
            List<float[]> resultList = new List<float[]>();
            try
            {
                string original = RemovePunctuation(originalString);
                for (int i = 0; i < testString.Count; i++)
                {
                    string test = RemovePunctuation(testString[i]);
                    int[] index = GetInIndex(original, test);
                    Console.WriteLine("index:" + string.Join(",", index));
                    // Computed once; order explicitly rather than relying on Dictionary enumeration order.
                    List<KeyValuePair<int, int>> matches = GetMaxSortedIndexs(index).OrderBy(p => p.Key).ToList();

                    Console.WriteLine("在original字符串中正確的字符");
                    Console.WriteLine("[" + originalString + "]");
                    foreach (KeyValuePair<int, int> match in matches)
                    {
                        Console.WriteLine("index:" + match.Key + "value:" + original[match.Key]);
                    }
                    Console.WriteLine("--------------------------------------------------------------");
                    Console.WriteLine("在test字符串中正確的字符\n");
                    Console.WriteLine("[" + testString[i] + "]");
                    foreach (KeyValuePair<int, int> match in matches)
                    {
                        Console.WriteLine("index:" + match.Value + "value:" + test[match.Value]);
                    }
                    Console.WriteLine("--------------------------------------------------------------");
                    Console.WriteLine("計算準確率\n");

                    int[] dsi = CountErrors(original.Length, test.Length, matches);
                    int H = matches.Count;
                    // N is the reference length (H + D + S), per the standard Corr/Acc definition.
                    int N = original.Length;
                    int D = dsi[0];
                    int S = dsi[1];
                    int I = dsi[2];
                    // Float division: an empty reference (N == 0) yields NaN rather than throwing.
                    float acc = (float)(H - I) / N;
                    float corr = H / (float)N;
                    resultList.Add(new float[] { acc, corr, H, N, D, S, I });
                    Console.WriteLine("acc" + acc + ",corr: " + corr);
                }
            }
            catch (Exception e)
            {
                Console.WriteLine("error accur:" + e.ToString());
                return null;
            }
            return resultList;
        }

        /// <summary>
        /// Counts deletions, substitutions and insertions in the unmatched gaps around the
        /// matches (which must be in ascending key order). In each gap the overlap counts as
        /// substitutions, the surplus of original characters as deletions and the surplus of
        /// test characters as insertions.
        /// </summary>
        /// <returns>{ D, S, I }</returns>
        private static int[] CountErrors(int originalLength, int testLength, IEnumerable<KeyValuePair<int, int>> matches)
        {
            int[] dsi = new int[3];
            // Virtual match at (-1, -1) so the leading gap uses the same formula as the others.
            int prevKey = -1;
            int prevValue = -1;
            foreach (KeyValuePair<int, int> match in matches)
            {
                AddGap(dsi, match.Key - prevKey - 1, match.Value - prevValue - 1);
                prevKey = match.Key;
                prevValue = match.Value;
            }
            // Trailing gap, as if there were a virtual match just past the end of both texts.
            AddGap(dsi, originalLength - prevKey - 1, testLength - prevValue - 1);
            return dsi;
        }

        /// <summary>Adds one gap of original/test unmatched characters to the { D, S, I } counters.</summary>
        private static void AddGap(int[] dsi, int originalGap, int testGap)
        {
            int substituted = Math.Min(originalGap, testGap);
            dsi[0] += originalGap - substituted;
            dsi[1] += substituted;
            dsi[2] += testGap - substituted;
        }

        internal static object GetInstance()
        {
            throw new NotImplementedException();
        }
    }
}