当前位置: 移动技术网 > IT编程>脚本编程>Python > 文字检测模型EAST应用详解 ckpt pb的tf加载,opencv加载

文字检测模型EAST应用详解 ckpt pb的tf加载,opencv加载

2020年04月25日  | 移动技术网IT编程  | 我要评论

室外温度传感器,梁艺,南周事件是什么

参考链接:https://github.com/argman/east (项目来源)

                    (遇到的问题)

         (opencv加载)

 

文字检测有很多比较好的现成的模型比如yolov3,pesnet,pennet,east。不一一赘述,讲一下自己跑通east的过程。

https://github.com/argman/east链接中下载项目,windows下,各种包的版本要正确否则会出一些乱七八糟的错误。

运行east/eval.py。没有什么特别的问题要说,我在cpu下单张640*480的图能够达到每张0.4秒左右,还是非常优秀的。中英文数字都可。

 

但是源代码是ckpt,非常大,转成pb会稍微小点。添加:

##生成pb模型,但需要修改model.py
output_graph_def = tf.graph_util.convert_variables_to_constants(self.sess, # the session is used to retrieve the weights
tf.get_default_graph().as_graph_def(), # the graph_def is used to retrieve the nodes
["feature_fusion/conv_7/sigmoid", "feature_fusion/concat_3"]
)
output_graph='d:\\work\\video\\hand_tracking_no_op\\hand_tracking\\east\\east_icdar2015_resnet_v1_50_rbox\\out.pb'
with tf.gfile.gfile(output_graph, "wb") as f:
f.write(output_graph_def.serializetostring())
print("%d ops in the final graph." % len(output_graph_def.node))

位置在eval.py中的

saver.restore(self.sess, model_path)后面。注意如果你想要opencv加载pb还要修改model.py中的内容,这个在后面一篇文章中会讲到。
生成后用tf加载,方法跟加载ckpt相似:

import os
os.environ['cuda_visible_devices'] = flags.gpu_list

try:
os.makedirs(flags.output_dir)
except oserror as e:
if e.errno != 17:
raise

print("load_graph")
graph = load_graph(flags.checkpoint_path)

input_images = graph.get_tensor_by_name(
'import/input_images:0')

f_score = graph.get_tensor_by_name('import/feature_fusion/conv_7/sigmoid:0')
f_geometry = graph.get_tensor_by_name(
'import/feature_fusion/concat_3:0')

with tf.session(graph=graph) as sess:

im_fn_list = get_images()
for im_fn in im_fn_list:
im = cv2.imread(im_fn)[:, :, ::-1]
start_time = time.time()
im_resized, (ratio_h, ratio_w) = resize_image(im)

timer = {'net': 0, 'restore': 0, 'nms': 0}
start = time.time()

#file_writer = tf.summary.filewriter('tmp/log', sess.graph)

score, geometry = sess.run([f_score, f_geometry], feed_dict={
input_images: [im_resized]})
timer['net'] = time.time() - start

boxes, timer = detect(score_map=score, geo_map=geometry, timer=timer)
print('{} : net {:.0f}ms, restore {:.0f}ms, nms {:.0f}ms'.format(
im_fn, timer['net']*1000, timer['restore']*1000, timer['nms']*1000))

if boxes is not none:
boxes = boxes[:, :8].reshape((-1, 4, 2))
boxes[:, :, 0] /= ratio_w
boxes[:, :, 1] /= ratio_h

duration = time.time() - start_time
print('[timing] {}'.format(duration))

# save to file
if boxes is not none:
res_file = os.path.join(
flags.output_dir,
'{}.txt'.format(
os.path.basename(im_fn).split('.')[0]))

with open(res_file, 'w') as f:
for box in boxes:
# to avoid submitting errors
box = sort_poly(box.astype(np.int32))
if np.linalg.norm(box[0] - box[1]) < 5 or np.linalg.norm(box[3]-box[0]) < 5:
continue
f.write('{},{},{},{},{},{},{},{}\r\n'.format(
box[0, 0], box[0, 1], box[1, 0], box[1, 1], box[2, 0], box[2, 1], box[3, 0], box[3, 1],
))
cv2.polylines(im[:, :, ::-1], [box.astype(np.int32).reshape((-1, 1, 2))], true, color=(255, 255, 0), thickness=1)
if not flags.no_write_images:
img_path = os.path.join(flags.output_dir, os.path.basename(im_fn))
cv2.imwrite(img_path, im[:, :, ::-1])

以上就是east的ckpt转pb用tf加载啦。
下一篇讲opencv加载east的pb。



如对本文有疑问,请在下面进行留言讨论,广大热心网友会与你互动!! 点击进行留言回复

相关文章:

验证码:
移动技术网