当前位置: 移动技术网 > IT编程>脚本编程>Python > 用bert训练模型并转换为pb格式

用bert训练模型并转换为pb格式

2020年07月03日  | 移动技术网IT编程  | 我要评论

具体代码在github:
https://github.com/danan0755/Bert_Classifier/blob/master/Bert_Train.py

def serving_input_fn():
    # 保存模型为SaveModel格式
    # 采用最原始的feature方式,输入是feature Tensors。
    # 如果采用build_parsing_serving_input_receiver_fn,则输入是tf.Examples
    df = pd.read_csv(FLAGS.data_dir, delimiter="\t", names=['labels', 'text'], header=None)

    dense_units = len(df.labels.unique())
    label_ids = tf.placeholder(tf.int32, [None, dense_units], name='label_ids')
    input_ids = tf.placeholder(tf.int32, [None, 128], name='input_ids')
    input_mask = tf.placeholder(tf.int32, [None, 128], name='input_mask')
    segment_ids = tf.placeholder(tf.int32, [None, 128], name='segment_ids')
    input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
        'label_ids': label_ids,
        'input_ids': input_ids,
        'input_mask': input_mask,
        'segment_ids': segment_ids,
    })()
    return input_fn

本文地址:https://blog.csdn.net/qq236237606/article/details/107078973

如对本文有疑问, 点击进行留言回复!!

相关文章:

验证码:
移动技术网