本ラボでは、Amazon SageMaker 上にデプロイした密ベクトル埋め込みモデルを活用したベクトル検索を実装していきます。
ベクトル検索とは、与えられたクエリアイテムに類似または関連するアイテムを効率的かつ効果的に検索する手法です。ベクトル間の距離や角度の近さといった数値に基づき、類似のアイテムやエンティティを探査します。従来の検索エンジンが苦手とする類似表現や関連語を含むクエリによる問い合わせでも類似性の高い結果を返すことができるため、レコメンドや類似検索、検索検索拡張生成(RAG) に代表される文書検索・ナレッジ検索で幅広く活用されています。
全文検索はクエリと検索対象のデータ間で厳密なマッチングが要求される一方、ベクトル検索は"意味的に近い" 文書を取得する際に有用であるため、うまく使い分けることで幅広い検索要件を達成できます。
一般的にベクトル検索とは、N 次元の数値配列からなる密ベクトルを使った検索のことを指します。密ベクトル検索では、クエリと検索対象のデータは N 次元の数値配列として扱われ、それらの距離や角度の差異が類似度として表されます。距離や角度が近いほど類似度が高いとみなすことができます。
ベクトル検索を行う上では、検索対象のテキストやクエリ文字列をベクトルデータに変換し、格納する必要があります。
データをベクトルに変換する処理を "埋め込み (Embedding)" と呼びます。埋め込み処理は、一般的に機械学習モデルの一種である埋め込みモデル (Embedding model) によって生成します。
本ラボでは、MIT ライセンスで公開されている BAAI/bge-m3 を Amazon SageMaker 上にデプロイし、テキストからベクトルデータを生成します。モデルの詳細については Hugging Face 上の解説を参照してください。
OpenSearch においては、ベクトル検索を実行するために k-NN search(k-nearest neighbors search) と呼ばれる機能を提供しています。k-NN search は、ベクトル空間内で最も近い k 個の近傍点を探す機能です。
k-NN search では、データセットの規模や要件に応じた複数の方式を提供しています。大規模データには Approximate k-NN、フィルタリングが必要な小規模データには Script Score k-NN、複雑なスコアリングが必要な場合は Painless extensions が推奨されてます。
本ラボでは、実ユースケースでも多く採用されている Approximate kNN を使用して k-NN search を実行していきます。
Approximate k-NN (近似最近傍探索、ANN)は、大規模なデータセットで効率的な類似検索を実現するための手法です。Script score による厳密な k-NN search はクエリと全てのデータポイント間の距離を総当たりで計算するため、高次元の大規模データセットでは処理効率が低下します。
ANN は、グラフやバケットなど独自のデータ構造にベクトルデータを格納することで、検索速度を大幅に向上させるアプローチです。精度が若干低下するものの、大規模なベクトルデータに対して効率的な検索が可能になります。
OpenSearch では、ANN を実行するためのエンジンを以下 3 つ用意しています。ただし nmslib は将来廃止予定であるため、実質的には Faiss と Lucene のどちらかから選択する形となります。Faiss は高機能かつ高速であり、大規模データセットに適しています。Lucene は省リソースが特徴で、数百万ベクトルまでの小規模なデータセットで良好な性能を発揮します。
本ラボでは、実ユースケースでも多く採用されている Faiss エンジンを使用します。
もう一つ ANN を使用するうえで必要になるのがアルゴリズムの選定です。OpenSearch では、以下 2 つのアルゴリズムを提供しています。HNSW は多くのメモリを必要としますが、高速な検索が可能です。IVF はメモリ効率が良好ですが、検索にあたっては事前トレーニングが必要です。各アルゴリズムの詳細については、AWS Bigdata blog の OpenSearch における 10 億規模のユースケースに適した k-NN アルゴリズムの選定 に詳しい解説が掲載されています。
本ラボでは、トレーニングが不要な HNSW を使用します。
OpenSearch におけるベクトル検索の流れは以下の通りです。
ベクトルデータの登録
検索
本ラボでは、ノートブック環境(JupyterLab) および Amazon OpenSearch Service、Amazon SageMaker を使用します。
!pip install opensearch-py requests-aws4auth "awswrangler[opensearch]" --quiet
from IPython.core.magic import register_cell_magic
from IPython import get_ipython
import ipywidgets as widgets
import boto3
import json
import time
import logging
from tqdm import tqdm
from datetime import datetime, timedelta
from functools import lru_cache
import awswrangler as wr
import pandas as pd
import numpy as np
from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth
import sagemaker
from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri
/opt/conda/lib/python3.12/site-packages/pydantic/_internal/_fields.py:192: UserWarning: Field name "json" in "MonitoringDatasetFormat" shadows an attribute in parent "Base" warnings.warn(
sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml sagemaker.config INFO - Not applying SDK defaults from location: /home/sagemaker-user/.config/sagemaker/config.yaml
以降の処理を実行する際に必要なヘルパー関数を定義しておきます。
def search_cloudformation_output(stackname, key):
cloudformation_client = boto3.client("cloudformation", region_name=default_region)
for output in cloudformation_client.describe_stacks(StackName=stackname)["Stacks"][0]["Outputs"]:
if output["OutputKey"] == key:
return output["OutputValue"]
raise ValueError(f"{key} is not found in outputs of {stackname}.")
def get_huggingface_tei_image_uri(instance_type, region):
key = "huggingface-tei" if instance_type.startswith("ml.g") or instance_type.startswith("ml.p") else "huggingface-tei-cpu"
return get_huggingface_llm_image_uri(key, version="1.2.3", region=region)
@lru_cache(maxsize=None)
def list_instance_quotas_for_realtime_inference(instance_family, region):
service_quotas_client = boto3.client("service-quotas", region_name=region)
quotas = []
paginator = service_quotas_client.get_paginator("list_service_quotas")
page_iterator = paginator.paginate(ServiceCode="sagemaker",
PaginationConfig={"MaxResults": 100})
for page in page_iterator:
for quota in page["Quotas"]:
if quota["QuotaName"].endswith("endpoint usage") and quota["Value"] > 0 and quota["QuotaName"].startswith("ml."+instance_family):
quotas.append(quota)
return quotas
def list_instance_usages(region):
sagemaker_client = boto3.client("sagemaker", region_name=region)
response = sagemaker_client.list_endpoints(
)
instances = []
for endpoint in response["Endpoints"]:
response = sagemaker_client.describe_endpoint(EndpointName=endpoint["EndpointName"])
response = sagemaker_client.describe_endpoint_config(EndpointConfigName=response["EndpointConfigName"])
if "InstanceType" in response["ProductionVariants"][0]:
instances.append(response["ProductionVariants"][0]["InstanceType"])
values, counts = np.unique(instances, return_counts=True)
return values,counts
def list_instance_attributes_realtime_inference(instance_family, region):
pricing = boto3.client("pricing", region_name="us-east-1")
instance_types = []
paginator = pricing.get_paginator("get_products")
page_iterator = paginator.paginate(
ServiceCode="AmazonSageMaker",
Filters=[
{
"Type": "TERM_MATCH",
"Field": "productFamily",
"Value": "ML Instance"
},
{
"Type": "TERM_MATCH",
"Field": "regionCode",
"Value": region
},
{
"Type": "TERM_MATCH",
"Field": "platoinstancetype",
"Value": "Hosting"
},
{
"Type": "TERM_MATCH",
"Field": "platoinstancename",
"Value": instance_family
},
],
)
products = []
for page in page_iterator:
for product in page["PriceList"]:
products.append(json.loads(product)["product"]["attributes"])
return products
def list_available_instance_types_for_realtime_inference(instance_family, region):
quotas = list_instance_quotas_for_realtime_inference(instance_family=instance_family, region=region)
quotas_df = pd.json_normalize(quotas).loc[:,["QuotaName","Value"]]
quotas_df["InstanceType"] = quotas_df["QuotaName"].str.removesuffix(" for endpoint usage")
quotas_df = quotas_df.drop(columns=["QuotaName"]).rename(columns={"Value":"Limit"})
quotas_df["Limit"] = quotas_df["Limit"].astype(int)
usage_values,usage_counts = list_instance_usages(region=region)
usages_df = pd.DataFrame({"InstanceType": usage_values, "Usage": usage_counts})
attributes = list_instance_attributes_realtime_inference(instance_family=instance_family, region=region)
attributes_df = pd.json_normalize(attributes)
attributes_df = attributes_df.loc[:,["instanceName","vCpu"]].rename(columns={"instanceName": "InstanceType"})
merged_df = pd.merge(pd.merge(quotas_df, usages_df, how="left", on="InstanceType"), attributes_df, on='InstanceType').fillna(value=0)
filtered_df = merged_df.query("Usage<Limit")
filtered_df["vCpu"] = filtered_df["vCpu"].astype(int)
instance_types = filtered_df.sort_values("vCpu", ascending=True, ignore_index=True).InstanceType
return instance_types.array
def search_sagemaker_inference_endpoint(endpoint_name, region):
sagemaker_client = boto3.client("sagemaker", region_name=region)
try:
response = sagemaker_client.search(
Resource="Endpoint",
SearchExpression={
"Filters": [
{
"Name": "EndpointName",
"Operator": "Contains",
"Value": endpoint_name
},
],
},
SortBy="LastModifiedTime",
SortOrder="Descending"
)
return response["Results"]
except Exception as e:
print(e)
default_region = boto3.Session().region_name
logging.getLogger().setLevel(logging.ERROR)
Amazon SageMaker に埋め込みモデル BAAI/bge-m3 を実行するリアルタイム推論エンドポイントを作成します。
重複してエンドポイントを作成しないように、推論エンドポイントが既に作成されている場合はモデルの作成をスキップし、既存のエンドポイント情報を返します。
try:
sagemaker_region = default_region
embedding_model_id_on_hf = "BAAI/bge-m3"
embedding_model_name = embedding_model_id_on_hf.lower().replace("/", "-")
realtime_inference_endpoint_type = "instance"
embedding_endpoint_name_prefix = f"{embedding_model_name}-{realtime_inference_endpoint_type}"
response = search_sagemaker_inference_endpoint(embedding_endpoint_name_prefix, sagemaker_region)
if len(response) == 0:
print(f"Model {embedding_model_id_on_hf} is not deployed on a realtime {realtime_inference_endpoint_type} endpoint." )
hub = {
"HF_MODEL_ID": embedding_model_id_on_hf,
"HF_TASK":"feature-extraction", # NLP task you want to use for predictions
}
role = sagemaker.get_execution_role()
available_instance_types = list_available_instance_types_for_realtime_inference(instance_family="g5", region=sagemaker_region)
if (len(available_instance_types)) > 0:
instance_type = available_instance_types[0]
else:
print("There is no available gpu instance type. Trying to use cpu instance type.")
available_instance_types = list_available_instance_types_for_realtime_inference(instance_family="c5", region=sagemaker_region)
if (len(available_instance_types)) > 0:
instance_type = available_instance_types[0]
else:
raise ValueException("There is no available instance types. Please change deployment option to on serverless endpoint, and try again.")
print(f"Found an eligable instance type. {instance_type} will be used for a realtime {realtime_inference_endpoint_type} endpoint.")
embedding_model = HuggingFaceModel(
name=sagemaker.utils.name_from_base(embedding_model_name),
env=hub, # configuration for loading model from Hub
role=role, # iam role with permissions to create an Endpoint
image_uri=get_huggingface_tei_image_uri(instance_type=instance_type, region=sagemaker_region)
)
print(f"start deploy {embedding_model_id_on_hf} on a realtime {realtime_inference_endpoint_type} endpoint.")
embedding_model.deploy(
endpoint_name=sagemaker.utils.name_from_base(embedding_endpoint_name_prefix),
initial_instance_count=1,
instance_type=instance_type
)
embedding_endpoint_name = embedding_model.endpoint_name
else:
print(f"Model {embedding_model_id_on_hf} is already deployed on a realtime {realtime_inference_endpoint_type} endpoint.")
embedding_endpoint = response[0]["Endpoint"]
embedding_endpoint_name = embedding_endpoint["EndpointName"]
embedding_endpoint_url = f"https://runtime.sagemaker.{default_region}.amazonaws.com/endpoints/{embedding_endpoint_name}/invocations"
print(f"\nendpoint name: {embedding_endpoint_name}")
print(f"endpoint url: {embedding_endpoint_url}")
except Exception as e:
print(e)
Model BAAI/bge-m3 is not deployed on a realtime instance endpoint. Found an eligable instance type. ml.g5.xlarge will be used for a realtime instance endpoint. start deploy BAAI/bge-m3 on a realtime instance endpoint. -------------! endpoint name: baai-bge-m3-instance-2025-04-14-05-40-48-158 endpoint url: https://runtime.sagemaker.us-west-2.amazonaws.com/endpoints/baai-bge-m3-instance-2025-04-14-05-40-48-158/invocations
%%time
payload = {
"inputs": ["hello world!"]
}
body = bytes(json.dumps(payload), 'utf-8')
sagemaker_runtime_client = boto3.client("sagemaker-runtime", region_name=sagemaker_region)
response = sagemaker_runtime_client.invoke_endpoint(
EndpointName=embedding_endpoint_name,
ContentType="application/json",
Accept="application/json",
Body=body
)
embeddings = eval(response['Body'].read().decode('utf-8'))
print("embedding:")
print(np.array(embeddings[0]))
print("dimension:")
print(np.shape(embeddings[0]))
embedding: [-0.03411743 0.02971338 -0.04121593 ... 0.02504918 -0.03434042 0.00949565] dimension: (1024,) CPU times: user 19.3 ms, sys: 0 ns, total: 19.3 ms Wall time: 128 ms
サンプルデータをダウンロードし、Pandas の DataFrame 形式に変換します
%%time
dataset_dir = "./dataset/jsquad/"
%mkdir -p $dataset_dir
!curl -L -s -o $dataset_dir/valid.json https://github.com/yahoojapan/JGLUE/raw/main/datasets/jsquad-v1.3/valid-v1.3.json
#!curl -L -s -o $dataset_dir/train.json https://github.com/yahoojapan/JGLUE/raw/main/datasets/jsquad-v1.3/train-v1.3.json
CPU times: user 11 ms, sys: 15.8 ms, total: 26.8 ms Wall time: 1.1 s
%%time
import pandas as pd
import json
def squad_json_to_dataframe(input_file_path, record_path=["data", "paragraphs", "qas", "answers"]):
file = json.loads(open(input_file_path).read())
m = pd.json_normalize(file, record_path[:-1])
r = pd.json_normalize(file, record_path[:-2])
idx = np.repeat(r["context"].values, r.qas.str.len())
m["context"] = idx
m["answers"] = m["answers"]
m["answers"] = m["answers"].apply(lambda x: np.unique(pd.json_normalize(x)["text"].to_list()))
return m[["id", "question", "context", "answers"]]
valid_filename = f"{dataset_dir}/valid.json"
valid_df = squad_json_to_dataframe(valid_filename)
#train_filename = f"{dataset_dir}/train.json"
#train_df = squad_json_to_dataframe(train_filename)
CPU times: user 1.4 s, sys: 15.4 ms, total: 1.41 s Wall time: 1.4 s
サンプルデータは質問文フィールドの question、回答の answers、説明文の context フィールド、問題 ID である id フィールドから構成されています。
サンプルデータの一部を見ていきましょう。
valid_df
id | question | context | answers | |
---|---|---|---|---|
0 | a10336p0q0 | 日本で梅雨がないのは北海道とどこか。 | 梅雨 [SEP] 梅雨(つゆ、ばいう)は、北海道と小笠原諸島を除く日本、朝鮮半島南部、中国の... | [小笠原諸島, 小笠原諸島を除く日本] |
1 | a10336p0q1 | 梅雨とは何季の一種か? | 梅雨 [SEP] 梅雨(つゆ、ばいう)は、北海道と小笠原諸島を除く日本、朝鮮半島南部、中国の... | [雨季] |
2 | a10336p0q2 | 梅雨は、世界的にどのあたりで見られる気象ですか? | 梅雨 [SEP] 梅雨(つゆ、ばいう)は、北海道と小笠原諸島を除く日本、朝鮮半島南部、中国の... | [東アジア, 東アジアの広範囲] |
3 | a10336p0q3 | 梅雨がみられるのはどの期間? | 梅雨 [SEP] 梅雨(つゆ、ばいう)は、北海道と小笠原諸島を除く日本、朝鮮半島南部、中国の... | [5月から7月, 5月から7月にかけて] |
4 | a10336p1q0 | 入梅は何の目安の時期か? | 梅雨 [SEP] 梅雨の時期が始まることを梅雨入りや入梅(にゅうばい)といい、社会通念上・気... | [春の終わりであるとともに夏の始まり(初夏), 田植えの時期, 田植えの時期の目安] |
... | ... | ... | ... | ... |
4437 | a95156p5q3 | 国際銀行間通信協会ならびに国際決済機関の何と何も企業体である | 多国籍企業 [SEP] 国際銀行間通信協会ならびに国際決済機関のクリアストリームとユーロクリ... | [クリアストリームとユーロクリア] |
4438 | a95156p6q0 | ゼネコンはどの国特有の形態か | 多国籍企業 [SEP] ゼネコンは日本特有の形態。セメントメジャーにラファージュホルシムやイ... | [日本] |
4439 | a95156p6q1 | 多国籍企業においてゼネコンはどこの国特有の形態であるか? | 多国籍企業 [SEP] ゼネコンは日本特有の形態。セメントメジャーにラファージュホルシムやイ... | [日本] |
4440 | a95156p6q2 | 多国籍企業を一つ挙げよ | 多国籍企業 [SEP] ゼネコンは日本特有の形態。セメントメジャーにラファージュホルシムやイ... | [イタルチェメンティ, ラファージュホルシム] |
4441 | a95156p6q3 | ゼネコンはどの国の特有の形態か? | 多国籍企業 [SEP] ゼネコンは日本特有の形態。セメントメジャーにラファージュホルシムやイ... | [日本] |
4442 rows × 4 columns
ドメイン(クラスター)に接続するためのエンドポイント情報を CloudFormation スタックの出力から取得し、OpenSearch クライアントを作成します。
cloudformation_stack_name = "search-lab-jp"
opensearch_cluster_endpoint = search_cloudformation_output(cloudformation_stack_name, "OpenSearchDomainEndpoint")
credentials = boto3.Session().get_credentials()
service_code = "es"
auth = AWSV4SignerAuth(credentials=credentials, region=default_region, service=service_code)
opensearch_client = OpenSearch(
hosts=[{"host": opensearch_cluster_endpoint, "port": 443}],
http_compress=True,
http_auth=auth,
use_ssl=True,
verify_certs=True,
connection_class = RequestsHttpConnection
)
opensearch_client
<OpenSearch([{'host': 'vpc-opensearchservi-cyiiwlmtgk2r-jkahtzltl6zkpgjx5czbxq5vvy.us-west-2.es.amazonaws.com', 'port': 443}])>
OpenSearch クラスターへのネットワーク接続性が確保されており、OpenSearch の Security 機能により API リクエストが許可されているかを確認します。 レスポンスに cluster_name や cluster_uuid が含まれていれば、接続確認が無事完了したと判断できます
opensearch_client.info()
{'name': 'cf756e86f83b28e0bd2ffe2ff501ccf4', 'cluster_name': '123456789012:opensearchservi-cyiiwlmtgk2r', 'cluster_uuid': 'UoIf1GJCTauJlwbKQrxTUA', 'version': {'distribution': 'opensearch', 'number': '2.17.0', 'build_type': 'tar', 'build_hash': 'unknown', 'build_date': '2025-02-14T09:38:50.023788640Z', 'build_snapshot': False, 'lucene_version': '9.11.1', 'minimum_wire_compatibility_version': '7.10.0', 'minimum_index_compatibility_version': '7.0.0'}, 'tagline': 'The OpenSearch Project: https://opensearch.org/'}
id、question、context、answers フィールドを格納するための文字列型フィールドに加えて、question、context フィールドから生成したベクトルデータを格納するための context_embedding、question_embedding フィールドを持つインデックスを作成します。
question、context、answers フィールドについては、テキスト検索でもある程度の検索精度を出せるように、id フィールドを除いて sudachi のカスタムアナライザーをセットしています。
OpenSearch では、ベクトルデータを格納するためのフィールドタイプとして knn_vector タイプを提供しています。
payload = {
"mappings": {
"properties": {
"id": {"type": "keyword"},
"question": {"type": "text", "analyzer": "custom_sudachi_analyzer"},
"context": {"type": "text", "analyzer": "custom_sudachi_analyzer"},
"answers": {"type": "text", "analyzer": "custom_sudachi_analyzer"},
"question_embedding": {
"type": "knn_vector",
"dimension": 1024,
"space_type": "l2",
"method": {
"name": "hnsw",
"engine": "faiss",
}
},
"context_embedding": {
"type": "knn_vector",
"dimension": 1024,
"space_type": "l2",
"method": {
"name": "hnsw",
"engine": "faiss",
},
}
}
},
"settings": {
"index.knn": True,
"index.number_of_shards": 1,
"index.number_of_replicas": 0,
"analysis": {
"analyzer": {
"custom_sudachi_analyzer": {
"char_filter": ["icu_normalizer"],
"filter": [
"sudachi_normalizedform",
"custom_sudachi_part_of_speech"
],
"tokenizer": "sudachi_tokenizer",
"type": "custom"
}
},
"filter": {
"custom_sudachi_part_of_speech": {
"type": "sudachi_part_of_speech",
"stoptags": ["感動詞,フィラー","接頭辞","代名詞","副詞","助詞","助動詞","動詞,一般,*,*,*,終止形-一般","名詞,普通名詞,副詞可能"]
}
}
}
}
}
# インデックス名を指定
index_name = "jsquad-knn"
try:
# 既に同名のインデックスが存在する場合、いったん削除を行う
print("# delete index")
response = opensearch_client.indices.delete(index=index_name)
print(json.dumps(response, indent=2))
except Exception as e:
print(e)
# インデックスを作成
response = opensearch_client.indices.create(index_name, body=payload)
response
# delete index NotFoundError(404, 'index_not_found_exception', 'no such index [jsquad-knn]', jsquad-knn, index_or_alias)
{'acknowledged': True, 'shards_acknowledged': True, 'index': 'jsquad-knn'}
サンプルデータにベクトルデータを追加し、OpenSearch に格納します。
DataFrame 形式に加工したサンプルデータの question フィールドと context フィールドを対象に、SageMaker 上で稼働している埋め込みモデルの推論エンドポイントを呼び出してベクトルデータを生成、結合する処理を実行します。
%%time
def get_df_with_embeddings(input_df, field_mappings, embedding_endpoint_name, sagemaker_region, batch_size):
output_df = pd.DataFrame([]) #create empty dataframe
df_list = np.array_split(input_df, input_df.shape[0]/batch_size)
for df in tqdm(df_list):
index = df.index #backup index number
df_with_embeddings = df
for field_mapping in field_mappings:
input_field_name = field_mapping["InputFieldName"]
embedding_field_name = field_mapping["EmbeddingFieldName"]
#index = df.index #backup index number
payload = {
"inputs": df_with_embeddings[input_field_name].values.tolist()
}
body = bytes(json.dumps(payload), 'utf-8')
sagemaker_runtime_client = boto3.client("sagemaker-runtime", region_name=sagemaker_region)
response = sagemaker_runtime_client.invoke_endpoint(
EndpointName=embedding_endpoint_name,
ContentType="application/json",
Accept="application/json",
Body=body
)
embeddings = eval(response['Body'].read().decode('utf-8'))
df_with_embeddings = pd.concat([df_with_embeddings.reset_index(drop=True), pd.Series(embeddings,name=embedding_field_name).reset_index(drop=True)],axis=1) #join embedding results to source dataframe
df_with_embeddings = df_with_embeddings.set_index(index) #restore index number
output_df = pd.concat([output_df, df_with_embeddings])
return output_df
valid_df_with_embeddings = get_df_with_embeddings(
input_df=valid_df,
field_mappings=[
{"InputFieldName": "question", "EmbeddingFieldName": "question_embedding"},
{"InputFieldName": "context", "EmbeddingFieldName": "context_embedding"},
],
embedding_endpoint_name=embedding_endpoint_name,
sagemaker_region=sagemaker_region,
batch_size=20
)
/opt/conda/lib/python3.12/site-packages/numpy/core/fromnumeric.py:59: FutureWarning: 'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead. return bound(*args, **kwds) 100%|██████████| 222/222 [00:54<00:00, 4.07it/s]
CPU times: user 27.3 s, sys: 537 ms, total: 27.9 s Wall time: 54.6 s
実行後の DataFrame は以下の通りです。数値配列の question_embedding フィールドおよび context_embedding フィールドが追加されていることが確認できます。
valid_df_with_embeddings
id | question | context | answers | question_embedding | context_embedding | |
---|---|---|---|---|---|---|
0 | a10336p0q0 | 日本で梅雨がないのは北海道とどこか。 | 梅雨 [SEP] 梅雨(つゆ、ばいう)は、北海道と小笠原諸島を除く日本、朝鮮半島南部、中国の... | [小笠原諸島, 小笠原諸島を除く日本] | [-0.025198229, 0.011716546, -0.02442373, -0.03... | [-0.0050813365, 0.005010302, -0.053455852, -0.... |
1 | a10336p0q1 | 梅雨とは何季の一種か? | 梅雨 [SEP] 梅雨(つゆ、ばいう)は、北海道と小笠原諸島を除く日本、朝鮮半島南部、中国の... | [雨季] | [-0.03248885, 0.0052572014, -0.027507849, -0.0... | [-0.0050813365, 0.005010302, -0.053455852, -0.... |
2 | a10336p0q2 | 梅雨は、世界的にどのあたりで見られる気象ですか? | 梅雨 [SEP] 梅雨(つゆ、ばいう)は、北海道と小笠原諸島を除く日本、朝鮮半島南部、中国の... | [東アジア, 東アジアの広範囲] | [-0.01773896, 0.024719613, -0.03772002, -0.009... | [-0.0050813365, 0.005010302, -0.053455852, -0.... |
3 | a10336p0q3 | 梅雨がみられるのはどの期間? | 梅雨 [SEP] 梅雨(つゆ、ばいう)は、北海道と小笠原諸島を除く日本、朝鮮半島南部、中国の... | [5月から7月, 5月から7月にかけて] | [-0.041717306, 0.019786566, -0.036074746, -0.0... | [-0.0050813365, 0.005010302, -0.053455852, -0.... |
4 | a10336p1q0 | 入梅は何の目安の時期か? | 梅雨 [SEP] 梅雨の時期が始まることを梅雨入りや入梅(にゅうばい)といい、社会通念上・気... | [春の終わりであるとともに夏の始まり(初夏), 田植えの時期, 田植えの時期の目安] | [-0.010945866, 0.039105743, -0.03545712, -1.30... | [-0.0046881377, 0.036880646, -0.02849782, 0.01... |
... | ... | ... | ... | ... | ... | ... |
4437 | a95156p5q3 | 国際銀行間通信協会ならびに国際決済機関の何と何も企業体である | 多国籍企業 [SEP] 国際銀行間通信協会ならびに国際決済機関のクリアストリームとユーロクリ... | [クリアストリームとユーロクリア] | [0.028033428, 0.015143309, -0.02286987, 0.0034... | [-0.03483603, 0.0009232421, -0.0002205119, 0.0... |
4438 | a95156p6q0 | ゼネコンはどの国特有の形態か | 多国籍企業 [SEP] ゼネコンは日本特有の形態。セメントメジャーにラファージュホルシムやイ... | [日本] | [-0.0072188377, 0.027737496, 0.0027397075, -0.... | [0.005166414, 0.04251439, -0.0018962601, -0.00... |
4439 | a95156p6q1 | 多国籍企業においてゼネコンはどこの国特有の形態であるか? | 多国籍企業 [SEP] ゼネコンは日本特有の形態。セメントメジャーにラファージュホルシムやイ... | [日本] | [-0.027483532, 0.0289756, 0.005295922, 0.00192... | [0.005166414, 0.04251439, -0.0018962601, -0.00... |
4440 | a95156p6q2 | 多国籍企業を一つ挙げよ | 多国籍企業 [SEP] ゼネコンは日本特有の形態。セメントメジャーにラファージュホルシムやイ... | [イタルチェメンティ, ラファージュホルシム] | [-0.029884888, 0.046195686, -0.057726365, -0.0... | [0.005166414, 0.04251439, -0.0018962601, -0.00... |
4441 | a95156p6q3 | ゼネコンはどの国の特有の形態か? | 多国籍企業 [SEP] ゼネコンは日本特有の形態。セメントメジャーにラファージュホルシムやイ... | [日本] | [-0.011311548, 0.027140185, 0.002658496, -0.02... | [0.005166414, 0.04251439, -0.0018962601, -0.00... |
4442 rows × 6 columns
ドキュメントのロードを行います。ドキュメントのロードは "OpenSearch の基本概念・基本操作の理解" でも解説した通り bulk API を使用することで効率よく進められますが、データ処理フレームワークを利用することでより簡単にデータを取り込むことも可能です。本ワークショップでは、AWS SDK for Pandas を使用したデータ取り込みを行います。
%%time
index_name = "jsquad-knn"
response = wr.opensearch.index_df(
client=opensearch_client,
df=valid_df_with_embeddings,
use_threads=True,
id_keys=["id"],
index=index_name,
bulk_size=200, # 200 件ずつ書き込み
refresh=False,
)
CPU times: user 37.2 s, sys: 261 ms, total: 37.4 s Wall time: 43.1 s
response["success"] の値が DataFrame の件数と一致しているかを確認します。True が表示される場合は全件登録に成功していると判断できます。
response["success"] == valid_df["id"].count()
True
本ラボではデータ登録時に意図的に Refresh オプションを無効化しているため、念のため Refresh API を実行し、登録されたドキュメントが確実に検索可能となるようにします
index_name = "jsquad-knn"
response = opensearch_client.indices.refresh(index_name)
response = opensearch_client.indices.forcemerge(index_name, max_num_segments=1)
テキスト検索とベクトル検索を実行し、結果を比較していきます。
テキスト検索のヒット率は検索キーワードとインデックスに格納されたコンテンツの内容、およびアナライザーによる正規化設定により左右されます。 テキスト検索では極力不要なキーワードは排除して検索が実行されることが好まれます。以下のような単語の組み合わせによる検索で性能を発揮します。
index_name = "jsquad-knn"
query = "日本 梅雨 ない どこ"
payload = {
"query": {
"match": {
"question": {
"query": query,
"operator": "and"
}
}
},
"_source": False,
"fields": ["question", "answers", "context"],
"size": 10
}
response = opensearch_client.search(
index=index_name,
body=payload
)
pd.json_normalize(response["hits"]["hits"])
_index | _id | _score | fields.question | fields.answers | fields.context | |
---|---|---|---|---|---|---|
0 | jsquad-knn | a10336p0q0 | 14.434446 | [日本で梅雨がないのは北海道とどこか。] | [小笠原諸島, 小笠原諸島を除く日本] | [梅雨 [SEP] 梅雨(つゆ、ばいう)は、北海道と小笠原諸島を除く日本、朝鮮半島南部、中国... |
1 | jsquad-knn | a10336p24q1 | 14.434446 | [梅雨が日本の中でない地域はどこか。] | [北海道, 東北地方] | [梅雨 [SEP] 年によっては梅雨明けの時期が特定できなかったり、あるいは発表がされないこ... |
一方、以下のような会話に近いクエリは、ノイズが増加するためうまく処理できない場合があります。
index_name = "jsquad-knn"
query = "日本で梅雨がない場所は?"
payload = {
"query": {
"match": {
"question": {
"query": query,
"operator": "and"
}
}
},
"_source": False,
"fields": ["question", "answers", "context"],
"size": 10
}
response = opensearch_client.search(
index=index_name,
body=payload
)
response
{'took': 2, 'timed_out': False, '_shards': {'total': 1, 'successful': 1, 'skipped': 0, 'failed': 0}, 'hits': {'total': {'value': 0, 'relation': 'eq'}, 'max_score': None, 'hits': []}}
従来は、minimum_should_match といったパラメーターによるチューニングを行ってきました。以下は検索クエリに含まれるトークンのうち 75% がマッチするドキュメントを返却するクエリです。
index_name = "jsquad-knn"
query = "日本で梅雨がない場所は?"
payload = {
"query": {
"match": {
"question": {
"query": query,
"operator": "or",
"minimum_should_match": "75%"
}
}
},
"_source": False,
"fields": ["question", "answers", "context"],
"size": 10
}
response = opensearch_client.search(
index=index_name,
body=payload
)
pd.json_normalize(response["hits"]["hits"])
_index | _id | _score | fields.question | fields.answers | fields.context | |
---|---|---|---|---|---|---|
0 | jsquad-knn | a10336p0q0 | 14.434446 | [日本で梅雨がないのは北海道とどこか。] | [小笠原諸島, 小笠原諸島を除く日本] | [梅雨 [SEP] 梅雨(つゆ、ばいう)は、北海道と小笠原諸島を除く日本、朝鮮半島南部、中国... |
1 | jsquad-knn | a10336p24q1 | 14.434446 | [梅雨が日本の中でない地域はどこか。] | [北海道, 東北地方] | [梅雨 [SEP] 年によっては梅雨明けの時期が特定できなかったり、あるいは発表がされないこ... |
2 | jsquad-knn | a10336p32q2 | 14.325790 | [気候学的には梅雨はないとされている場所は?] | [北海道] | [梅雨 [SEP] 実際の気象としては北海道にも道南を中心に梅雨前線がかかることはあるが、平... |
テキスト検索では対応が難しい会話に近い問い合わせ分をベクトル検索で処理していきます。
OpenSearch では knn クエリを使用してベクトル検索を実行します。vector フィールドにはベクトルデータを、k には取得したい近似ベクトルの件数を指定しています。
OpenSearch は knn クエリも分散実行されるため、インデックスの構成によっては k の値と戻りの総件数の値が異なる場合があります。k の値と size の値はそろえることを推奨しています。詳細は The number of returned results を参照してください。
以下のコードでは、クエリテキストを Amazon SageMaker の推論エンドポイントに渡してベクトルデータを生成し、knn クエリの vector パラメーターに渡しています。
index_name = "jsquad-knn"
query = "日本で梅雨がない場所は?"
def text_to_embedding(text, region_name, embedding_endpoint_name):
payload = {
"inputs": [
query
]
}
body = bytes(json.dumps(payload), 'utf-8')
sagemaker_runtime_client = boto3.client("sagemaker-runtime", region_name=region_name)
response = sagemaker_runtime_client.invoke_endpoint(
EndpointName=embedding_endpoint_name,
ContentType="application/json",
Accept="application/json",
Body=body
)
embeddings = eval(response['Body'].read().decode('utf-8'))
return embeddings[0]
vector = text_to_embedding(text=query, region_name=sagemaker_region, embedding_endpoint_name=embedding_endpoint_name)
k = 10
payload = {
"query": {
"knn": {
"question_embedding": {
"vector": vector,
"k": k
}
}
},
"_source": False,
"fields": ["question", "answers", "context"],
"size": k
}
response = opensearch_client.search(
index=index_name,
body=payload
)
pd.json_normalize(response["hits"]["hits"])
_index | _id | _score | fields.question | fields.answers | fields.context | |
---|---|---|---|---|---|---|
0 | jsquad-knn | a10336p24q1 | 0.803372 | [梅雨が日本の中でない地域はどこか。] | [北海道, 東北地方] | [梅雨 [SEP] 年によっては梅雨明けの時期が特定できなかったり、あるいは発表がされないこ... |
1 | jsquad-knn | a10336p32q3 | 0.799814 | [梅雨がないとされている都道府県はどこ?] | [北海道] | [梅雨 [SEP] 実際の気象としては北海道にも道南を中心に梅雨前線がかかることはあるが、平... |
2 | jsquad-knn | a10336p32q2 | 0.799580 | [気候学的には梅雨はないとされている場所は?] | [北海道] | [梅雨 [SEP] 実際の気象としては北海道にも道南を中心に梅雨前線がかかることはあるが、平... |
3 | jsquad-knn | a10336p0q0 | 0.782644 | [日本で梅雨がないのは北海道とどこか。] | [小笠原諸島, 小笠原諸島を除く日本] | [梅雨 [SEP] 梅雨(つゆ、ばいう)は、北海道と小笠原諸島を除く日本、朝鮮半島南部、中国... |
4 | jsquad-knn | a10336p18q0 | 0.773163 | [日本の地域で本格的な長雨に突入しない場所はどこか。] | [北海道] | [梅雨 [SEP] 次に梅雨前線は中国の江淮(長江流域・淮河流域)に北上する。6月下旬には華... |
5 | jsquad-knn | a10336p42q4 | 0.669258 | [ほとんど雨が降らない梅雨を何という?] | [空梅雨, 空梅雨(からつゆ)] | [梅雨 [SEP] 梅雨の期間中ほとんど雨が降らない場合がある。このような梅雨のことを空梅雨... |
6 | jsquad-knn | a10336p42q2 | 0.667771 | [梅雨の期間中ほとんど雨が降らない場合をなんという?] | [空梅雨, 空梅雨(からつゆ)] | [梅雨 [SEP] 梅雨の期間中ほとんど雨が降らない場合がある。このような梅雨のことを空梅雨... |
7 | jsquad-knn | a10336p42q1 | 0.663624 | [梅雨の期間中ほとんど雨が降らない場合を何と呼ぶ?] | [空梅雨, 空梅雨(からつゆ)] | [梅雨 [SEP] 梅雨の期間中ほとんど雨が降らない場合がある。このような梅雨のことを空梅雨... |
8 | jsquad-knn | a10336p42q3 | 0.657195 | [ほとんど雨が降らない梅雨を何と呼ぶか] | [空梅雨, 空梅雨(からつゆ)] | [梅雨 [SEP] 梅雨の期間中ほとんど雨が降らない場合がある。このような梅雨のことを空梅雨... |
9 | jsquad-knn | a10336p42q0 | 0.654841 | [梅雨の期間中ほとんど雨が降らない場合がある。このような梅雨のことをなんというか?] | [空梅雨, 空梅雨(からつゆ)] | [梅雨 [SEP] 梅雨の期間中ほとんど雨が降らない場合がある。このような梅雨のことを空梅雨... |
ベクトル検索は _score が 1 未満となります。ベクトル検索においては、クエリと対象ドキュメントの距離が近いほど距離の値は小さくなります。 したがって、距離を 0 から 1 の間で正規化したうえで、1 から距離を引いた値をスコア(関連度)としています。
knn search では、上位 k 個のベクトルという条件以外に、こうした距離やスコアを使った絞り込みが可能です。フィルタリングには以下のオプションを使用可能です。これらのオプションは k と併用不可能です。
index_name = "jsquad-knn"
query = "日本で梅雨がない場所は?"
def text_to_embedding(text, region_name, embedding_endpoint_name):
payload = {
"inputs": [
query
]
}
body = bytes(json.dumps(payload), 'utf-8')
sagemaker_runtime_client = boto3.client("sagemaker-runtime", region_name=region_name)
response = sagemaker_runtime_client.invoke_endpoint(
EndpointName=embedding_endpoint_name,
ContentType="application/json",
Accept="application/json",
Body=body
)
embeddings = eval(response['Body'].read().decode('utf-8'))
return embeddings[0]
vector = text_to_embedding(text=query, region_name=sagemaker_region, embedding_endpoint_name=embedding_endpoint_name)
min_score に 0.7 をセットした結果は以下の通りです。0.7 以上のスコアのベクトルのみが返却されました。
k = 10
payload = {
"query": {
"knn": {
"question_embedding": {
"vector": vector,
"min_score": 0.7
}
}
},
"_source": False,
"fields": ["question", "answers", "context"],
"size": k,
}
response = opensearch_client.search(
index=index_name,
body=payload
)
pd.json_normalize(response["hits"]["hits"])
_index | _id | _score | fields.question | fields.answers | fields.context | |
---|---|---|---|---|---|---|
0 | jsquad-knn | a10336p24q1 | 0.803372 | [梅雨が日本の中でない地域はどこか。] | [北海道, 東北地方] | [梅雨 [SEP] 年によっては梅雨明けの時期が特定できなかったり、あるいは発表がされないこ... |
1 | jsquad-knn | a10336p32q3 | 0.799814 | [梅雨がないとされている都道府県はどこ?] | [北海道] | [梅雨 [SEP] 実際の気象としては北海道にも道南を中心に梅雨前線がかかることはあるが、平... |
2 | jsquad-knn | a10336p32q2 | 0.799580 | [気候学的には梅雨はないとされている場所は?] | [北海道] | [梅雨 [SEP] 実際の気象としては北海道にも道南を中心に梅雨前線がかかることはあるが、平... |
3 | jsquad-knn | a10336p0q0 | 0.782644 | [日本で梅雨がないのは北海道とどこか。] | [小笠原諸島, 小笠原諸島を除く日本] | [梅雨 [SEP] 梅雨(つゆ、ばいう)は、北海道と小笠原諸島を除く日本、朝鮮半島南部、中国... |
4 | jsquad-knn | a10336p18q0 | 0.773163 | [日本の地域で本格的な長雨に突入しない場所はどこか。] | [北海道] | [梅雨 [SEP] 次に梅雨前線は中国の江淮(長江流域・淮河流域)に北上する。6月下旬には華... |
max_distance に 0.3 をセットしても同様の結果が得られます。
1 - 0.3 = 0.7
に相当するスコアのベクトルが変えるためです。
k = 10
payload = {
"query": {
"knn": {
"question_embedding": {
"vector": vector,
"max_distance": 0.3
}
}
},
"_source": False,
"fields": ["question", "answers", "context"],
"size": k,
}
response = opensearch_client.search(
index=index_name,
body=payload
)
pd.json_normalize(response["hits"]["hits"])
_index | _id | _score | fields.question | fields.answers | fields.context | |
---|---|---|---|---|---|---|
0 | jsquad-knn | a10336p24q1 | 0.803372 | [梅雨が日本の中でない地域はどこか。] | [北海道, 東北地方] | [梅雨 [SEP] 年によっては梅雨明けの時期が特定できなかったり、あるいは発表がされないこ... |
1 | jsquad-knn | a10336p32q3 | 0.799814 | [梅雨がないとされている都道府県はどこ?] | [北海道] | [梅雨 [SEP] 実際の気象としては北海道にも道南を中心に梅雨前線がかかることはあるが、平... |
2 | jsquad-knn | a10336p32q2 | 0.799580 | [気候学的には梅雨はないとされている場所は?] | [北海道] | [梅雨 [SEP] 実際の気象としては北海道にも道南を中心に梅雨前線がかかることはあるが、平... |
3 | jsquad-knn | a10336p0q0 | 0.782644 | [日本で梅雨がないのは北海道とどこか。] | [小笠原諸島, 小笠原諸島を除く日本] | [梅雨 [SEP] 梅雨(つゆ、ばいう)は、北海道と小笠原諸島を除く日本、朝鮮半島南部、中国... |
4 | jsquad-knn | a10336p18q0 | 0.773163 | [日本の地域で本格的な長雨に突入しない場所はどこか。] | [北海道] | [梅雨 [SEP] 次に梅雨前線は中国の江淮(長江流域・淮河流域)に北上する。6月下旬には華... |
ラボを通して、全文検索では対応が難しいクエリをベクトル検索で処理できることが確認できました。時間がある方は、続いて以下のラボも実施してみましょう。
ダウンロードしたデータセットを削除します。./dataset ディレクトリ配下に何もない場合は、./dataset ディレクトリも合わせて削除します。
%rm -rf {dataset_dir}
%rmdir ./dataset