当前位置: 移动技术网 > IT编程>脚本编程>Python > 编辑距离WER/CER计算的一种python实现

编辑距离WER/CER计算的一种python实现

2020年07月18日  | 移动技术网IT编程  | 我要评论

WER(word error rate)经常作为语音识别任务的性能评测指标,WER的计算公式,直接从网上粘贴过来了。
在这里插入图片描述
一些语音识别框架(如:Kaldi、ESPNet等)中,都会包含wer的计算方法,其中ESPNet的结果展示如下:
在这里插入图片描述
我们希望用python实现上面的效果,首先来看看wer是怎么计算的。
首先,随便写个例子,ref(reference)表示标注文本序列,hyp(hypothesis)表示预测文本序列,则可以计算 cer/wer = 3,其中一次替换错误(S),一次删除错误(D),一次插入错误(I)。
在这里插入图片描述
参考资料:https://martin-thoma.com/word-error-rate-calculation/
我们列出WER的计算公式如下,看似很绕,我们用图来画一下:
在这里插入图片描述
首先,横轴为ref(标注序列),我们列出来,最后加一个<b>作为占位,纵轴为hyp(预测序列),也列出来,同样最后加上<b>占位。
在这里插入图片描述
然后,中间的每个方格代表一个cost,什么是cost呢,就是到目前为止两个序列错了多少。如ref=b,hyp=b时,错误为1,因为此时的子序列 sub_ref=ab,sub_hyp=bb,所以只有一个替换错误(S)。
而前面公式中列出来的,就是计算递归计算cost的方法。也就是说当前位置ij的cost只与相邻的前面三个位置有关(上图中紫色部分),而且是三个紫色方块的最小值+1,翻译一下:
(1)对于位置 i、j,如果 hyp(i-1) == ref(j-i),则 cost(i,j)=cost(i-1,j-1);
(2)如果 hyp(i-1) != ref(j-i) ,也就是图中ij的位置,那么 cost(i,j) 就是三个紫色方块的最小值+1
而三个紫色方块代表的物理意义如下:(左上表示替换错误S,右上表示插入错误I,左下表示删除错误D)
在这里插入图片描述
一直迭代下去,直到所有cost都被计算出来之后,整个cost矩阵右下角的为的值就是你要的wer了(也就是<b>和<b>的位置),上上图中wer=3。

得到了wer,我们还想直到 I、D、S 到底各占多少,对齐的文本到底是什么样子的,这是我们要从右下角回溯。对于右下角的“3”,它的前三个值最小为3,说明没有发生错误。接下来(f,f)位置,前三个值为2、2、3,最小值为2,说明发生了错误,这是按照 insert > delete > substitution 的优先级,选择上方的方格,并记录一次插入错误。以此类推,直到遍历到左上角为止。如下图所示,我们就得到了所有的错误类型。
在这里插入图片描述

下面看一个特殊情况,即句子开头有插入或者删除错误。
在这里插入图片描述
这时如果我们回溯整个矩阵发现,hyp先结束了,而ref还没有结束,为了得到所有的操作,我们必须要遍历的左上角才行,所以,强行从遍历结束的位置移动到左上角。那么,如果是hyp先结束,则所有移动都是删除错误(D),如果是ref先结束,那么所有错误都是插入错误(I)。
在这里插入图片描述
python的实现如下,供参考:

import numpy as np

def levenshtein_distance(hypothesis: list, reference: list):
    """编辑距离
    计算两个序列的levenshtein distance,可用于计算 WER/CER
    参考资料:https://martin-thoma.com/word-error-rate-calculation/

    C: correct
    W: wrong
    I: insert
    D: delete
    S: substitution

    :param hypothesis: 预测序列
    :param reference: 标注序列
    :return: 1: 错误操作,所需要的 S,D,I 操作的次数;
             2: ref 与 hyp 的所有对齐下标
             3: 返回 C、W、S、D、I 各自的数量
    """
    len_hyp = len(hypothesis)
    len_ref = len(reference)
    cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16)

    # 记录所有的操作,0-equal;1-insertion;2-deletion;3-substitution
    ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)

    for i in range(len_hyp + 1):
        cost_matrix[i][0] = i
    for j in range(len_ref + 1):
        cost_matrix[0][j] = j

    # 生成 cost 矩阵和 operation矩阵,i:外层hyp,j:内层ref
    for i in range(1, len_hyp + 1):
        for j in range(1, len_ref + 1):
            if hypothesis[i-1] == reference[j-1]:
                cost_matrix[i][j] = cost_matrix[i-1][j-1]
            else:
                substitution = cost_matrix[i-1][j-1] + 1
                insertion = cost_matrix[i-1][j] + 1
                deletion = cost_matrix[i][j-1] + 1

                compare_val = [insertion, deletion, substitution]   # 优先级
                min_val = min(compare_val)
                operation_idx = compare_val.index(min_val) + 1
                cost_matrix[i][j] = min_val
                ops_matrix[i][j] = operation_idx

    match_idx = []  # 保存 hyp与ref 中所有对齐的元素下标
    i = len_hyp
    j = len_ref
    nb_map = {"N": len_hyp, "C": 0, "W": 0, "I": 0, "D": 0, "S": 0}
    while i >= 0 or j >= 0:
        i_idx = max(0, i)
        j_idx = max(0, j)

        if ops_matrix[i_idx][j_idx] == 0:     # correct
            if i-1 >= 0 and j-1 >= 0:
                match_idx.append((j-1, i-1))
                nb_map['C'] += 1

            # 出边界后,这里仍然使用,应为第一行与第一列必然是全零的
            i -= 1
            j -= 1
        elif ops_matrix[i_idx][j_idx] == 1:   # insert
            i -= 1
            nb_map['I'] += 1
        elif ops_matrix[i_idx][j_idx] == 2:   # delete
            j -= 1
            nb_map['D'] += 1
        elif ops_matrix[i_idx][j_idx] == 3:   # substitute
            i -= 1
            j -= 1
            nb_map['S'] += 1

        # 出边界处理
        if i < 0 and j >= 0:
            nb_map['D'] += 1
        elif j < 0 and i >= 0:
            nb_map['I'] += 1

    match_idx.reverse()
    wrong_cnt = cost_matrix[len_hyp][len_ref]
    nb_map["W"] = wrong_cnt

    print("ref: %s" % " ".join(reference))
    print("hyp: %s" % " ".join(hypothesis))
    print(nb_map)
    print("match_idx: %s" % str(match_idx))
    print()
    return wrong_cnt, match_idx, nb_map


def test():
    """
    id: (301225575230191207_spkb_f-301225575230191207_spkb_f_slice19)
    Scores: (#C #S #D #I) 27 4 1 2
    REF:  然 后 而 且 这 个 账 号 , 你 这 边 *** 做 车 商 续 费 的 话 就 发 真 车 应 该 *** 稍 微 再 便 宜 点 。
    HYP:  然 后 而 且 这 个 账 号 *** 你 这 边 要 做 车 商 续 费 的 话 就 发 真 车 应 该 还 有 一 个 便 宜 的 。
    Eval:
    :return:
    """
    wrong_cnt, match_idx, nb_map = levenshtein_distance(
        reference=list('abcdef'),
        hypothesis=list('cdefg')
    )

    wrong_cnt, match_idx, nb_map = levenshtein_distance(
        reference=list('cdefg'),
        hypothesis=list('abcdef')
    )

    wrong_cnt, match_idx, nb_map = levenshtein_distance(
        reference=list('cdefg'),
        hypothesis=list('')
    )

    wrong_cnt, match_idx, nb_map = levenshtein_distance(
        reference=list(''),
        hypothesis=list('')
    )

    wrong_cnt, match_idx, nb_map = levenshtein_distance(
        reference=list('abcdf'),
        hypothesis=list('bbdef')
    )

    wrong_cnt, match_idx, nb_map = levenshtein_distance(
        hypothesis=list('然后而且这个账号,你这边做车商续费的话就发真车应该稍微再便宜点。'),
        reference=list('然后而且这个账号你这边要做车商续费的话就发真车应该还有一个便宜的。')
    )

if __name__ == '__main__':
    test()

本文地址:https://blog.csdn.net/baobao3456810/article/details/107381052

如对本文有疑问, 点击进行留言回复!!

相关文章:

验证码:
移动技术网