当前位置: 移动技术网 > IT编程>脚本编程>Python > Tensorflow导出pb模型,并在python和matlab下分别进行预测

Tensorflow导出pb模型,并在python和matlab下分别进行预测

2020年07月18日  | 移动技术网IT编程  | 我要评论
tensorflow下训练完模型测试程序比较杂乱,特此整理一下。1、我都是在linux下训练,windows下测试分析,训练保存模型如下所示。2、然后调用frozen_model.py将模型进行固化

tensorflow下训练完模型测试程序比较杂乱,特此整理一下。

1、我都是在linux下训练,windows下调用测试,训练保存模型如下所示。

2、然后调用frozen_model.py将模型进行固化,这里需要注意一点就是网络输出结点的名称,可以在tensorboard查看GRAPHS中网络输出结点名或训练时进行命名。

frozen_model.py

import tensorflow as tf
from tensorflow.python.framework import graph_util
def freeze_graph(input_checkpoint,output_graph):
    # 原模型中输出节点名称
    output_node_names = "generator/decoder_1/output_node"
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
    graph = tf.get_default_graph() 
    input_graph_def = graph.as_graph_def() 
 
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint) 
        output_graph_def = graph_util.convert_variables_to_constants(  # 模型持久化,将变量值固定
            sess=sess,
            input_graph_def=input_graph_def,# 等于:sess.graph_def
            output_node_names=output_node_names.split(","))
 
        with tf.gfile.GFile(output_graph, "wb") as f: 
            f.write(output_graph_def.SerializeToString()) 
        print("%d ops in the final graph." % len(output_graph_def.node)) 
        

tf.reset_default_graph()
input_checkpoint = "tensorflow_model/spot_train/model-500"
output_graph = "frozen_model/frozen_model.pb"
freeze_graph(input_checkpoint,output_graph)

3、得到pb模型后,调用test.py进行测试。需要注意输入输出tensor名字一定要写对,一般结点名字后面加":0"就是对应tensor名,可以在这个网站打开pb模型查看tensor名

https://lutzroeder.github.io/netron/

test.py

#-*- coding:utf-8 -*-
import os
import tensorflow as tf
from tensorflow.python.framework import graph_util
import numpy as np
import scipy.io 
from tensorflow.python.platform import gfile
tf.reset_default_graph()

pb_file_path = 'model_package/frozen_model/'
result_file_path = 'test_results/'

def preprocess(x):   
    Max = np.max(x)
    Min = np.min(x)
    x = (x-Min)/(Max-Min)
    return x*2-1

def deprocess(x):
    return (x+1)/2

data = scipy.io.loadmat('1.mat')['data']
flatten_img = preprocess(np.reshape(data, [1, 256,256,1]))   

sess = tf.Session()
with gfile.FastGFile(pb_file_path + 'frozen_model.pb', 'rb') as f: #加载模型
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    sess.graph.as_default()
    tf.import_graph_def(graph_def, name='')  # 导入计算图

# 初始化
#sess.run(tf.global_variables_initializer())

x = sess.graph.get_tensor_by_name('batch:1')  
y = sess.graph.get_tensor_by_name('generator/decoder_1/output_node:0')  
                                                       
y_out=sess.run(y,feed_dict={x:flatten_img})
scipy.io.savemat(result_file_path+'test.mat', {'output':y_out})

后面为方便matlab调用,又整理成类了

python_test.py

#-*- coding:utf-8 -*-
import os
import tensorflow as tf
from tensorflow.python.framework import graph_util
import numpy as np
import scipy.io 
from tensorflow.python.platform import gfile
from glob import glob

#pb_file_path = 'model_package/frozen_model/'

class Predict(object):
    
    def __init__(self):
        tf.reset_default_graph()    
        self.sess = tf.Session()
        with gfile.FastGFile('frozen_model/frozen_model.pb', 'rb') as f: #加载模型
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            sess.graph.as_default()
            tf.import_graph_def(graph_def, name='')  #导入计算图
        print('load model success')
        #sess.run(tf.global_variables_initializer()) #初始化
        self.x = sess.graph.get_tensor_by_name('batch:1')  
        self.y = sess.graph.get_tensor_by_name('generator/decoder_1/output_node:0') 
        
    def preprocess(x):   
        Max = np.max(x)
        Min = np.min(x)
        x = (x-Min)/(Max-Min)
        return x*2-1

    def deprocess(x):
        return (x+1)/2   
    
    def predict(self,input_path):
        #加载测试输出数据
        img= scipy.io.loadmat(input_path)['data'] 
        flatten_img = preprocess(np.reshape(img, [1, 256,256,1]))  
        
        y_out = sess.run(self.y,feed_dict={self.x:flatten_img})
        y_out = np.squeeze(deprocess(y_out))
        scipy.io.savemat('test_results/'+input_path[17:], {'output':y_out})

if __name__ == '__main__':
    model = Predict()
    model.predict("1.mat")

4、Matlab中测试采用的是调用python测试脚本还实现的。

clear;close all;clc
clear classes

tf = py.importlib.import_module('tensorflow');
np = py.importlib.import_module('numpy');
%plt = py.importlib.import_module('matplotlib.pyplot');
sio = py.importlib.import_module('scipy.io');

obj = py.importlib.import_module('python_test'); %python测试脚本路径
py.importlib.reload(obj);
a = py.pix2pix_test.Predict();
a.predict('test_data/2.bmp')
a.predict('test_data/3.bmp')

写完回头看了一眼,咋这么混乱都没说清也就自己能看懂了,先这样吧后面要提高一下写作能力了,再接再厉,加油!

本文地址:https://blog.csdn.net/megaoliyuanzhende/article/details/107365281

如您对本文有疑问或者有任何想说的,请点击进行留言回复,万千网友为您解惑!

相关文章:

验证码:
移动技术网