較大規模圖片 使用phash去重

起因

先說下為什么要做這個事。做的圖片站的圖片來源為很多美女圖片站,自然地,會有很多重復的圖片,而我的目標就是要把重復的圖片找出來,剔除掉或者是做其他處理。

什么樣的圖片屬于相同圖片呢?因為會存在一些有水印的圖片(如下圖),或者是略微變形的圖片(如1024 * 720 與1020 * 720的圖片)

with_logo.jpeg

without_logo.jpeg

phash

phash全稱是感知哈希算法(Perceptual hash algorithm),使用這玩意兒可以對每個圖片生成一個值,如上面兩個圖分別是2582314446007581403 與 2582314446141799129 (轉為二進制再比較),然后計算他們的hamming distance,簡單的說就是數一數二進制之后有幾位不同。整個處理流程有點像對文章去重時先算simhash再算hamming distance,很多東西都可以直接套用過來。

phash具體的實現可以很多地方都有了,可以搜到很多差不多的內容,在這我也就簡單的記錄下,具體可以去谷歌或者百度搜下。

  • 縮小尺寸 為了后邊的步驟計算簡單些
  • 簡化色彩 將圖片轉化成灰度圖像,進一步簡化計算量
  • 計算DCT 計算圖片的DCT變換,得到32*32的DCT系數矩陣。
  • 縮小DCT 雖然DCT的結果是32*32大小的矩陣,但我們只要保留左上角的8*8的矩陣,這部分呈現了圖片中的最低頻率。
  • 計算平均值 如同均值哈希一樣,計算DCT的均值。
  • 計算hash值 根據8*8的DCT矩陣,設置0或1的64位的hash值,大于等于DCT均值的設為”1”,小于DCT均值的設為“0”。組合在一起,就構成了一個64位的整數,這就是這張圖片的指紋。
    python 版本的實現
# -*- coding: utf-8 -*-

from compiler.ast import flatten
import cv2
import numpy as np

def pHash(imgfile):
    # 加載并調整圖片為32x32灰度圖片
    img = cv2.imread(imgfile, 0)
    img = cv2.resize(img, (32, 32), interpolation=cv2.INTER_CUBIC)

    # 創建二維列表
    h, w = img.shape[:2]
    vis0 = np.zeros((h, w), np.float32)
    vis0[:h, :w] = img  # 填充數據

    # 二維Dct變換
    vis1 = cv2.dct(cv2.dct(vis0))
    # 拿到左上角的8 * 8
    vis1 = vis1[0:8, 0:8]

    # 把二維list變成一維list
    img_list = flatten(vis1.tolist())

    # 計算均值
    avg = sum(img_list) * 1. / len(img_list)
    avg_list = ['0' if i < avg else '1' for i in img_list]

    # 得到哈希值
    return ''.join(['%x' % int(''.join(avg_list[x:x + 4]), 2) for x in range(0, 8 * 8, 4)])

這段代碼是網上找來做測試用的,當時有個坑,他沒有vis1 = vis1[0:8,0:8]這一步,然后出來的結果就很奇葩,而且指紋長的可怕(32 * 32位),準確率和召回率都低的驚人。這段代碼也很簡單,幾乎和白話一樣,就是把上面phash的流程給翻譯了一遍。

然鵝,我并沒有使用上面的python版的,出于兩個原因,一是我上邊說的坑,當時并沒有發現,二是畢竟是python,雖說大部分計算的部分是用c寫的(opencv),但還是覺得會慢。找到的是一個純c的,來自 phash.org (沒錯,就是這么官方)。安裝啥的網站里邊都有,附上一個python調用的腳本。

class pHash(object):
    def __init__(self):
        self._lib = ctypes.CDLL('/opt/local/lib/libpHash.dylib', use_errno=True)

    def dct_imagehash(self, path):
        phash = ctypes.c_uint64()
        if self._lib.ph_dct_imagehash(path, ctypes.pointer(phash)):
            errno_ = ctypes.get_errno()
            err, err_msg = (errno.errorcode[errno_], os.strerror(errno_)) \
                if errno_ else ('none', 'errno was set to 0')
            print(('Failed to get image hash'
                   ' ({!r}): [{}] {}').format(path, err, err_msg), file=sys.stderr)
            return None
        return phash.value

    def hamming_distance(self, hash1, hash2):
        return self._lib.ph_hamming_distance(
            *map(ctypes.c_uint64, [hash1, hash2]))

非常貼心的還附贈了海明距的計算。
因為我的圖片都是存在云端,為了速度更快,我會直接用云端圖像處理把圖片先縮小,壓縮后再處理。我本機測試的結果是一千張圖生成phash耗時1.5s,相當快了。(有個很驚悚的發現,上頭那個python版本千張耗時0.7s...驚呆了...可能實現不太一樣吧...)

大量數據hamming distance 計算

如標題所述,較大規模圖片,我這邊的大概是百萬級別,但是即便是千萬級別應該還是差不多的方式,億級別的數據可能我的小破開發機就受不了了(沒錯...沒用服務器...)

先說說海明距,咱們上邊不是生成了一段64位的數呢?海明距就是數一數兩個hash值有多少位的差異,一般小于5的都算近似,就是這么簡單:)

假設有1000萬已經處理完的phash值吧,現在來了一個新的phash,如何找出所有可能和他重復的圖呢?
最簡單粗暴的,直接遍歷一次...即遍歷1000萬次....那么耗時大概...不用算了,肯定是個很夸張的值,不靠譜。

這邊我采用的是一種內存換速度的方式,64位的的hash值,分為八組,每組八位。建立八個dict,每個dict代表一組,以每組的值作為key,value是一個list,存放key相同的hash值。查找的時候,把hash值分成八個,分別在八個map里邊查找,如果有key相同的,取出key相同的所有hash值進行遍歷。

說的相當的亂,下邊是代碼。

split_count = 8  # 每個64位的phash值分為八段,每段8位

def split(key, split_count):
    pre_length = 64 / split_count
    return [key[i * pre_length: (i + 1) * pre_length] for i in range(split_count)]

class ImageManager(object):
    def __init__(self):
        self.phash = pHash() # 就是上面那個pHash類
        self.phash_cache = [defaultdict(list) for i in range(split_count)] #
        self.init_phash_map()

    def init_phash_map(self):
        #我是把所有的phash存在sqlite里邊,這邊取出所有的Image
        for image in Image.select():
            self.add_to_image_cache(image)

    def add_to_image_cache(self, image):
        # 將hash值分割為8段
        key_split = split(bin(int(image.phash))[2:].rjust(64, '0'), split_count)
        for index, k in enumerate(key_split):
            self.phash_cache[index][k].append(image)

    def has_same(self, ori_image):
        phash = ori_image.phash
        key_split = split(bin(int(phash))[2:].rjust(64, '0'), split_count)
        result = set()
        for index, k in enumerate(key_split):
            if k in self.phash_cache[index]:
                for image in self.phash_cache[index][k]:
                    distance = self.distance(int(phash), int(image.phash))
                    if distance < 5 and ori_image.key != image.key:
                        result.add(image)
        if result:
            return True,list(result)
        return False,[]
最后編輯于
?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。

推薦閱讀更多精彩內容