faiss是Facebook开源的相似性搜索库,为稠密向量提供高效相似度搜索和聚类,支持十亿级别向量的搜索,是目前最为成熟的近似近邻搜索库

faiss不直接提供余弦距离计算,而是提供了欧式距离和点积,利用余弦距离公式,经过L2正则后的向量点积结果即为余弦距离,所以利用faiss计算余弦距离需要先对输入进行L2正则化

  • 安装

    参照官方开源安装https://github.com/facebookresearch/faiss/blob/main/INSTALL.md

    # CPU-only version
    $ conda install -c pytorch faiss-cpu
    $ pip install faiss-cpu
    
    # GPU(+CPU) version
    $ conda install -c pytorch faiss-gpu
    $ pip install faiss-cpu
     
    
  • 常规计算余弦距离方式

    常规一般使用sklearn包的cosine_similarity计算余弦距离,因为该包自动对向量进行L2正则,所以不要求输入必须为正则结果,代码如下:

    ## 计算余弦距离
    from sklearn.metrics.pairwise import cosine_similarity
    from sklearn import preprocessing
    def get_cos_result(embeding_library, persons, embeding_search):
        simi = cosine_similarity(embeding_search, embeding_library)
        max_argmin = np.argmax(simi,axis=1)
        search_speaker = [[persons[id],simi[i][id]] for i, id in enumerate(max_argmin)]
        return search_speaker
    ## 对输入进行正则化,可以不用正则
    def l2_normal(embeding):
        return preprocessing.normalize(embeding)
    
  • faiss的精确搜索

    faiss并不提供计算与余弦距离,只提供了点积计算和欧式距离,所以在计算余弦距离时,需要对输入进行L2正则,代码如下:

    import faiss
    from faiss import normalize_L2
    def faiss_precise_search(embeding_library, persons, embeding_search,topk=1):
        ## 这里也可以使用上文的sklearn的包进行正则
        normalize_L2(embeding_search)
        normalize_L2(embeding_library)
        # faiss.IndexFlatIP是内积 ;faiss.indexFlatL2是欧式距离
        quantizer = faiss.IndexFlatIP(embeding_library.shape[1])
        index = quantizer
        ## 要保证输入为np.float32格式
        index.add(embeding_library.astype(np.float32))
        library = {'persons': persons, 'index': index}
        st = time.time()
        distance,idx = library['index'].search(embeding_search,topk)
        print('precise search:',time.time()-st)
        combined_results = []
        for p in range(len(distance)):
            results = [[library["persons"][i], s] for i, s in zip(idx[p], distance[p]) if s >= 0][0]
            combined_results.append(results)
        return combined_results
    
  • faiss快速搜索

    faiss提供了多种快速搜索的方式,这里介绍常用的一种加速搜索的方式:倒排索引,这种方式与ES快速搜索的方式类似,需要先使用k-means建立聚类中心,通过查询最近的聚类中心,然后比较聚类中所有向量得到相似向量,这里需要两个超参数,一个是聚类中心num_cells,一个是查找聚类中心的个数num_cells_in_search,具体代码如下

    def faiss_fast_search(embeding_library, persons, embeding_search,topk=1):
        normalize_L2(embeding_search)
        normalize_L2(embeding_library)
        d = embeding_library.shape[1]
        num_cells = 50
        num_cells_in_search = 5
        # 声明量化器
        quantizer = faiss.IndexFlatIP(embeding_library.shape[1])
         # faiss.METRIC_INNER_PRODUCT计算内积 faiss.METRIC_L2j计算欧式距离
        index = faiss.IndexIVFFlat(quantizer, d,min(num_cells, len(persons)),faiss.METRIC_INNER_PRODUCT)
        assert not index.is_trained
        index.train(embeding_library.astype(np.float32))
        index.add(embeding_library.astype(np.float32))
        index.nprobe = min(num_cells_in_search,len(persons))
        library = {'persons': persons, 'index': index}
        st = time.time()
        distance, idx = library['index'].search(embeding_search, topk)
        print('fast search:',time.time()-st)
        combined_results = []
        for p in range(len(distance)):
            results = [[library["persons"][i], s] for i, s in zip(idx[p], distance[p]) if s >= 0][0]
            combined_results.append(results)
        return combined_results
    
  • 整体代码

    # -*- coding: utf-8 -*-
    import faiss
    from faiss import normalize_L2
    from sklearn.metrics.pairwise import cosine_similarity
    from sklearn import preprocessing
    import numpy as np
    import time
    
    def l2_normal(embeding):
        return preprocessing.normalize(embeding)
    
    def get_cos_result(embeding_search, persons, embeding_library):
        simi = cosine_similarity(embeding_search, embeding_library)
        max_argmin = np.argmax(simi,axis=1)
        search_speaker = [[persons[id],simi[i][id]] for i, id in enumerate(max_argmin)]
        return search_speaker
    
    def faiss_precise_search(embeding_library, persons, embeding_search):
        normalize_L2(embeding_search)
        normalize_L2(embeding_library)
        # faiss.IndexFlatIP是内积 ;faiss.indexFlatL2是欧式距离
        quantizer = faiss.IndexFlatIP(embeding_library.shape[1])
        index = quantizer
        index.add(embeding_library.astype(np.float32))
        library = {'persons': persons, 'index': index}
        st = time.time()
        distance,idx = library['index'].search(embeding_search,1)
        print('precise search:',time.time()-st)
        combined_results = []
        for p in range(len(distance)):
            results = [[library["persons"][i], s] for i, s in zip(idx[p], distance[p]) if s >= 0][0]
            combined_results.append(results)
        return combined_results
    
    def faiss_fast_search(embeding_library, persons, embeding_search,topk=1):
        normalize_L2(embeding_search)
        normalize_L2(embeding_library)
        num_cells = 500
        num_cells_in_search = 10
        quantizer = faiss.IndexFlatIP(embeding_library.shape[1])
        index = faiss.IndexIVFFlat(quantizer, embeding_library.shape[1],min(num_cells, len(persons)),faiss.METRIC_INNER_PRODUCT) #faiss.METRIC_INNER_PRODUCT计算内积 faiss.METRIC_L2j计算欧式距离
        assert not index.is_trained
        index.train(embeding_library.astype(np.float32))
        index.add(embeding_library.astype(np.float32))
        index.nprobe = min(num_cells_in_search,len(persons))
        library = {'persons': persons, 'index': index}
        st = time.time()
        distance, idx = library['index'].search(embeding_search, topk)
        print('fast search:',time.time()-st)
        combined_results = []
        for p in range(len(distance)):
            results = [[library["persons"][i], s] for i, s in zip(idx[p], distance[p]) if s >= 0][0]
            combined_results.append(results)
        return combined_results
    
    if __name__ == '__main__':
        d = 512
        n_library = 100000
        n_search = 1
        embeding_library = np.random.random((n_library, d)).astype(np.float32)
        persons = ['Speak' + "%0d" % (i + 1) for i in range(n_library)]
        embeding_search = np.random.random((n_search, d)).astype(np.float32)
        print(faiss_fast_search(embeding_library, persons, embeding_search))
        print(faiss_precise_search(embeding_library, persons, embeding_search))
        st = time.time()
        print(get_cos_result(embeding_search, persons, embeding_library))
        en1 = time.time()
        print(en1-st)