人工智能学习

用Qdrant玩转恐怖万圣节故事搜索工具

本文主要是介绍用Qdrant玩转恐怖万圣节故事搜索工具,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!

万圣节终于到了! 🎃

那是南瓜雕刻、性感的服装和围着闪烁烛光低声细语讲述恐怖故事的时节。

但如果你和我一样,总是在最需要的时候记不起任何恐怖故事。所以我想到,为什么不做一个工具,可以从大量故事中挑选出真正能让我们感到毛骨悚然的那种故事呢。

这就是我们今天在建的东西。

这个计划很简单。

我们将使用Reddit上的恐怖故事数据集,对这些数据进行嵌入处理,并建立一个Qdrant集合,以便根据主题和氛围等进行搜索。简单来说,就是捕捉到这种“氛围”,比如“闹鬼的房子”或“诡异的森林”。

我会展示你需要的所有步骤来构建一个这样的应用:设置向量数据库的步骤,嵌入和索引数据,以及召唤最恐怖的万圣节故事。

咱们就开始吧。

1. 安装库文件

宇航员超级想吃南瓜王

首先,我们先来安装我们将要用到的工具:

在命令行中输入以下命令:

pip install qdrant-client sentence_transformers datasets

切换到全屏模式,退出全屏

2.: 下载数据集文件

我们将使用这个Reddit上的恐怖故事数据集。我们用datasets库来下载一下这个数据集吧。

    从datasets库导入load_dataset函数

    ds = load_dataset("intone/horror_stories_reddit")
    # ds变量被赋值为load_dataset函数加载的数据集,其来源为'intone/horror_stories_reddit'。

进入全屏模式,退出全屏模式

3. 加载嵌入式模型.

我们将使用sentence_transformers库来帮助我们将数据嵌入到all-MiniLM-L6-v2模型中。这里是如何进行设置的:

    从'sentence_transformers'导入SentenceTransformer

    model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', device='cpu')

进入全屏 退出全屏

如果你手头上有一个可用的GPU并且想要加快处理速度,只需改成 device='cuda:0'

4. 创建嵌入

宇航员与棉花糖

generate_embeddings_direct函数处理数据集的某部分(例如“train”)时,将其拆分成更小的组,称为批次(batch),根据指定的batch_size。这有助于更高效地管理内存。

对于每批,该函数提取一组句子(例如,每次32个),并使用加载的嵌入模型来处理这些句子。

    导入 tqdm 模块

    def generate_embeddings(split, batch_size=32):
        embeddings = []
        split_name = [name for name, data_split in ds.items() if data_split is split][0]

        with tqdm(total=len(split), desc=f"正在为 {split_name} 分割生成嵌入") as pbar:
            for i in range(0, len(split), batch_size):
                batch_sentences = split['text'][i:i+batch_size]
                batch_embeddings = model.encode(batch_sentences)
                embeddings.extend(batch_embeddings)
                pbar.update(len(batch_sentences))

        return embeddings

全屏 退出全屏

它会立即在数据集中添加一个新的列,把它们放进去。这样,这个功能能高效更新数据,同时节省内存。

# 生成训练数据集的嵌入表示
train_embeddings = generate_embeddings(ds['train'])
# 将嵌入表示添加到训练数据集中
ds["train"] = ds["train"].add_column("embeddings", train_embeddings)

点击全屏 点击退出全屏

5. 设置客户端

如图所示:宇航员和幽灵

现在我们可以开始使用Qdrant客户端了。如果你是在本地操作,只需连接到默认端点即可。

    from qdrant_client import QdrantClient

    # 连接到本地的Qdrant实例
    qdrant_client = QdrantClient(url="http://localhost:6333")

可以进入全屏模式,也可以退出全屏

很简单吧?但在实际工作中,你很可能是在云端工作。这意味着你需要设置并验证Qdrant Cloud实例。

云配置

要连接到您的云实例,您需要实例网址和一个API密钥(API key)。具体操作步骤如下。

    导入 QdrantClient 从 qdrant_client

    # 每处需用您的 Qdrant 云实例 URL 和 API Key 替换
    qdrant_client = QdrantClient(
        url="https://YOUR_CLOUD_INSTANCE_ID.aws.qdrant.tech",  # 用您的云实例 URL 替换这里的 URL
        api_key="YOUR_API_KEY"  # 用您的 API Key 替换这里的 api_key
    )

全屏模式 退出全屏

确保将 YOUR_CLOUD_INSTANCE_ID 替换成你的实际实例 ID,并将 YOUR_API_KEY 替换成你创建的那个 API 密钥。你可以在 Qdrant Cloud 控制台里找到这些信息。

6. 创建一个收藏

在 Qdrant 中,一个集合就像是一个迷你数据库,优化存储和查询向量。当我们定义一个集合时,我们需要设定向量的尺寸和用于衡量相似性的度量。其设置可能如下所示:

    导入qdrant_client中的models

    collection_name="halloween"

    # 下面的代码创建了一个集合来存储产品特性的向量
    qdrant_client.create_collection(
        collection_name=collection_name,
        vectors_config=models.VectorParams(size=384, distance=models.Distance.COSINE) # 向量参数配置
    )

进入全屏,退出全屏

我们定义了一个名为 halloween 的集合,包含384维向量,这是 all-MiniLM-L6-v2 嵌入的大小。这里我们使用余弦距离作为相似度的衡量标准。根据您的数据和应用场景,您可能希望使用不同的距离度量方式,例如 Distance.EUCLIDDistance.DOT

7. 加载向量数据

看!宇航员在太空里吃苹果呢! :)

没有数据,集合就什么都不是。现在是时候把我们之前创建的嵌入插入其中了。这里有一个策略,可以分批加载嵌入:

def batched(iterable, n):
    iterator = iter(iterable)
    while 批次 := list(islice(iterator, n)):
        yield 批次

批大小 = 100
当前ID = 0  # 初始化计数器

切换到全屏模式 退出全屏

batched 函数将一个可迭代对象分割成大小为 n 的多个小块。它使用 islice 提取连续的元素组,并为每一块生成结果,直到数据集被完全处理完。

    from itertools 导入 islice

    for batch in batched(ds["train"], batch_size):
        # 生成一个ID列表,使用计数器
        ids = list(range(current_id, current_id + len(batch)))

        # 更新计数器,使其从批次后的下一个ID开始
        current_id += len(batch)

        vectors = [point.pop("embeddings") for point in batch]

        qdrant_client.upsert(
            collection_name=collection_name,
            points=models.Batch(
                ids=ids,
                vectors=vectors,
                payloads=batch,
            ),
        )

切换到全屏模式,切换回正常模式

每个批次都会通过upsert方法发送至Qdrant,该方法用来插入数据批次。upsert方法接收一组ID、向量以及其余项目数据(负载数据),以存储或更新在Qdrant集合里。

8. 讲述诅咒故事

宇航员在烛光下

终于,到了这个时候。

一切准备好了,现在来看看我们的恐怖故事搜索工具能否真的吓到人。我们试着搜一下“恐怖小丑”,看看结果。

    import json
    import textwrap

    # 用于包装并打印长文本的函数
    def print_wrapped(text, width=80):
        wrapped_text = textwrap.fill(text, width=width)
        print(wrapped_text)

    # 搜索结果查询
    search_result = qdrant_client.query_points(
        collection_name=collection_name,
        query=model.encode("creepy clown").tolist(),
        limit=1,
    )

    # 获取第一个结果
    if search_result.points:
        tale = search_result.points[0]

        # 打印负载信息
        print("ID:", tale.id)
        print("Score:", tale.score)
        print("Original:", tale.payload.get('isOriginal', 'N/A'))

        # 打印特定的负载字段
        print("Title:", tale.payload.get('title', 'N/A'))
        print("Author:", tale.payload.get('author', 'N/A'))
        print("Subreddit:", tale.payload.get('subreddit', 'N/A'))
        print("URL:", tale.payload.get('url', 'N/A'))

        # 单独打印故事文本,带有单词包装以提高可读性
        print("\nStory Text:\n")
        print_wrapped(tale.payload.get('text', 'No text available'), width=80)
    else:
        print("未找到结果.")

全屏 退出全屏

结果就这么出现了,标题叫做:“偷瞄小贼”。

说实话,这真的挺诡异的啊。

它到底是基于真实故事还是虚构的?说实话,我真的不清楚。它会让人感到一种若有若无的不安,仿佛有什么东西在注视着你。它挺长的,所以我就不贴在这里了,但如果你想亲自看看,尽管去运行程序试试吧。

你可以探索任何其他气氛:“闹鬼的屋子”,“阴森的森林”,“被诅咒的玩偶”,或者你想体验的任何其他气氛。谁知道呢,你可能会发现更恐怖的东西。

如果你发现了,不妨在评论里发出来。我很想看看它还能发现什么新东西。

下一步

感谢你陪我一起完成这次万圣节的实验!如果你一直跟着的话,你已经迈出了进入向量搜索世界的第一步,并学会了如何找到那些感觉恐怖但又不仅仅是包含恐怖词汇的。

如果你准备好了,就可以进入向量搜索的神秘领域,你可以探索许多更高级的话题,比如multitenancy、payload结构体和批量上传。

所以,放手去做,看看你能深入到什么地步。

祝你好运! 👻

这篇关于用Qdrant玩转恐怖万圣节故事搜索工具的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!