mobielnet提取特征图片比对

本文详细介绍使用TensorFlow和MobileNet V1模型进行图像特征提取的过程。从加载图像到预处理,再到模型加载和特征提取,最后计算两个图像特征之间的相似度。涉及的关键步骤包括图像裁剪、调整大小、模型加载、特征提取和相似度计算。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

平时pytorch用得多,tf的代码还是弄了半天,网上的code不靠谱太多。当然得先down模型,clone tensorflow models,然后执行下代码里的export。

# encoding: utf-8

import os
import sys
import cv2
import glob
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
from nets import mobilenet_v1
import skimage
import skimage.io
import skimage.transform
from sklearn.preprocessing import Normalizer

# export PYTHONPATH="$PYTHONPATH:/ai/tensorflow/models/research/slim"

def load_image(path):
    # load image
    img = skimage.io.imread(path)
    img = img / 255.0
    assert (0 <= img).all() and (img <= 1.0).all()
    # print "Original Image Shape: ", img.shape
    # we crop image from center
    short_edge = min(img.shape[:2])
    yy = int((img.shape[0] - short_edge) / 2)
    xx = int((img.shape[1] - short_edge) / 2)
    crop_img = img[yy: yy + short_edge, xx: xx + short_edge]
    # resize to 224, 224
    resized_img = skimage.transform.resize(crop_img, (224, 224))
    return resized_img

if __name__ == "__main__":
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"

    gpu_options = tf.GPUOptions(allow_growth=True)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

    norm2 = Normalizer(norm='l2')

    ckpt_path = './mobilenet_v1/mobilenet_v1_1.0_224.ckpt'

    img1 = load_image(sys.argv[1])
    img2 = load_image(sys.argv[2])

    batch1 = img1.reshape((1, 224, 224, 3))
    batch2 = img2.reshape((1, 224, 224, 3))
    batch = np.concatenate((batch1, batch2), 0)

    images = tf.placeholder("float", [2, 224, 224, 3])

    with tf.contrib.slim.arg_scope(mobilenet_v1.mobilenet_v1_arg_scope()):
        logits, endpoints = mobilenet_v1.mobilenet_v1(images, num_classes=1001)

        with tf.Session() as sess:
            saver = tf.train.Saver()
            saver.restore(sess, ckpt_path)

            fc_map = endpoints['AvgPool_1a']
            fc_feat = tf.squeeze(fc_map, [1, 2])

            feed_dict = { images: batch }
            fc_feature = sess.run(fc_feat, feed_dict=feed_dict)

            norm_feas = norm2.fit_transform(fc_feature)

            print(np.matmul(norm_feas[0], norm_feas[1].T))

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值