当前位置: 移动技术网 > IT编程>开发语言>Java > TensorFlow 2.1.0 使用 TFRecord 转存与读取文本数据

TensorFlow 2.1.0 使用 TFRecord 转存与读取文本数据

2020年07月03日  | 移动技术网IT编程  | 我要评论

前言:

上次记录了一下如何使用 TFRecord 来转存图片与 label ,后续经手了一些 NLP 任务,尝试使用了 TF 2.1.0,所以记录一下如何使用 TFRecord 来保存和读取文本数据。


准备工作:

TFRecord 无法直接记录文本信息,所以需要首先对文本内容进行一些预处理的准备工作,分别是分词,去停用词,建立词典,以及将文本转化为词典 index。再将 index 值写入 TFRecord。

 

TFRecord

首先这里把训练集和验证集分构造为了两个 DataFrame ,然后一个 Text 文本对应两个 label。使用 Keras 中的 Tokenizer 进行词到字典的映射,同时把 label 转化为相应的 label index。

与保存图片不同的是,保存 Text index 时,需要用到 tf.train.FeatureLists()。转换 Text index 时,先将文本中的每一个 index 转换为一个 Int64List,再将整篇文章转换为一个 FeatureLists。

writer = tf.io.TFRecordWriter('./train_data_content_with_title')

for _, data in tqdm(train_pd.iterrows()):
    text = tokenizer.texts_to_sequences([data['content with title'].split(' ')])[0]
    text = list(map(lambda idx: tf.train.Feature(int64_list=tf.train.Int64List(value=[idx])), text))
    
    exam = tf.train.SequenceExample(
        context = tf.train.Features(
            feature = {
                'industry_label': tf.train.Feature(int64_list=tf.train.Int64List (value=[industry_dict[data["industry_label"]]])),
                'use_label': tf.train.Feature(int64_list=tf.train.Int64List (value=[use_dict[data["use_label"]]]))
                
            }
        ),
        feature_lists = tf.train.FeatureLists(
            feature_list={
                'text' : tf.train.FeatureList(feature=text)
            }
        )
    )
    writer.write (exam.SerializeToString())
writer.close()  

writer = tf.io.TFRecordWriter('./valid_data_content_with_title')
for _, data in tqdm(valid_pd.iterrows()):
    text = tokenizer.texts_to_sequences([data['content with title'].split(' ')])[0]
    text = list(map(lambda idx: tf.train.Feature(int64_list=tf.train.Int64List(value=[idx])), text))
    
    exam = tf.train.SequenceExample(
        context = tf.train.Features(
            feature = {
                'industry_label': tf.train.Feature(int64_list=tf.train.Int64List (value=[industry_dict[data["industry_label"]]])),
                'use_label': tf.train.Feature(int64_list=tf.train.Int64List (value=[use_dict[data["use_label"]]]))
                
            }
        ),
        feature_lists = tf.train.FeatureLists(
            feature_list={
                'text' : tf.train.FeatureList(feature=text)
            }
        )
    )
    writer.write (exam.SerializeToString())
writer.close()  

读取:

读取时,同样也是将 Label 与 Text index 两部分分开解析。

train_reader = tf.data.TFRecordDataset('./train_data_content_with_title')
valid_reader = tf.data.TFRecordDataset('./valid_data_content_with_title')

context_features = {
    "industry_label": tf.io.FixedLenFeature([], dtype=tf.int64),
    "use_label": tf.io.FixedLenFeature([], dtype=tf.int64)
}
sequence_features = {
    "text": tf.io.FixedLenSequenceFeature([], dtype=tf.int64),
}

def parse_function(serialized_example):
    context_parsed, sequence_parsed = tf.io.parse_single_sequence_example(
        serialized=serialized_example,
        context_features=context_features,
        sequence_features=sequence_features
    )
    industry_label = context_parsed['industry_label']
    use_label = context_parsed['use_label']
    text = sequence_parsed['text']
    return text, industry_label, use_label


train_dataset = train_reader.repeat(1).shuffle(1280, reshuffle_each_iteration=True).map(parse_function).padded_batch(256, padded_shapes=([110], [], []))
valid_dataset = valid_reader.repeat(1).shuffle(1280, reshuffle_each_iteration=True).map(parse_function).padded_batch(256, padded_shapes=([110], [], []))

 

本文地址:https://blog.csdn.net/ZJRN1027/article/details/107079071

如对本文有疑问, 点击进行留言回复!!

相关文章:

验证码:
移动技术网