#!/usr/bin/env python # coding: utf-8 # # セマンティックリランキング (Amazon SageMaker 編) # ## 概要 # リランキングは、検索結果の並べ替えを行うアプローチです。リランキングにはいくつかの手法が存在します。 # # 本ラボでは、クロスエンコーダーモデルによるセマンティックリランキングを活用した検索の改善効果を確認していきます。 # # ### 前提事項 # 本ラボの実施にあたっては、以下のラボを事前に完了している必要があります。これらのラボで作成したインデックスを元にリランキングを行っていきます。 # - [ベクトル検索の実装 (Amazon SageMaker 編)](../vector-search/vector-search-with-sagemaker.ipynb) # - [ニューラル検索の実装 (Amazon SageMaker 編)](../vector-search/neural-search-with-sagemaker.ipynb) # - [ハイブリッド検索 (Amazon SageMaker 編)](../hybrid-search/hybrid-search-with-sagemaker.ipynb) # # ### 使用するモデル # 本ラボでは、Apache license 2.0 ライセンスで公開されている [BAAI/bge-reranker-v2-m3][bge-reranker-v2-m3] を Amazon SageMaker 上にデプロイし、リランキングに使用します。モデルの詳細については Hugging Face 上の解説を参照してください。 # # [bge-reranker-v2-m3]: https://huggingface.co/BAAI/bge-reranker-v2-m3 # ## 事前作業 # ### パッケージインストール # In[1]: get_ipython().system('pip install opensearch-py requests-aws4auth --quiet') # ### インポート # In[2]: from IPython.core.magic import register_cell_magic from IPython import get_ipython import ipywidgets as widgets import boto3 import json import time import logging from tqdm import tqdm from datetime import datetime, timedelta from functools import lru_cache import pandas as pd import numpy as np from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth import sagemaker from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri # ### ヘルパー関数の定義 # 以降の処理を実行する際に必要なヘルパー関数を定義しておきます。 # In[3]: def search_cloudformation_output(stackname, key): cloudformation_client = boto3.client("cloudformation", region_name=default_region) for output in cloudformation_client.describe_stacks(StackName=stackname)["Stacks"][0]["Outputs"]: if output["OutputKey"] == key: return output["OutputValue"] raise ValueError(f"{key} is not found in outputs of {stackname}.") def get_huggingface_tei_image_uri(instance_type, region): key = "huggingface-tei" if instance_type.startswith("ml.g") or instance_type.startswith("ml.p") else "huggingface-tei-cpu" return get_huggingface_llm_image_uri(key, version="1.2.3", region=region) @lru_cache(maxsize=None) def list_instance_quotas_for_realtime_inference(instance_family, region): service_quotas_client = boto3.client("service-quotas", region_name=region) quotas = [] paginator = service_quotas_client.get_paginator('list_service_quotas') page_iterator = paginator.paginate(ServiceCode='sagemaker', PaginationConfig={'MaxResults': 100}) for page in page_iterator: for quota in page['Quotas']: if quota["QuotaName"].endswith("endpoint usage") and quota["Value"] > 0 and quota["QuotaName"].startswith("ml."+instance_family): quotas.append(quota) return quotas def list_instance_usages(region): sagemaker_client = boto3.client("sagemaker", region_name=region) response = sagemaker_client.list_endpoints( ) instances = [] for endpoint in response["Endpoints"]: response = sagemaker_client.describe_endpoint(EndpointName=endpoint["EndpointName"]) response = sagemaker_client.describe_endpoint_config(EndpointConfigName=response["EndpointConfigName"]) if "InstanceType" in response["ProductionVariants"][0]: instances.append(response["ProductionVariants"][0]["InstanceType"]) values, counts = np.unique(instances, return_counts=True) return values,counts def list_instance_attributes_realtime_inference(instance_family, region): pricing = boto3.client("pricing", region_name="us-east-1") instance_types = [] paginator = pricing.get_paginator("get_products") page_iterator = paginator.paginate( ServiceCode="AmazonSageMaker", Filters=[ { "Type": "TERM_MATCH", "Field": "productFamily", "Value": "ML Instance" }, { "Type": "TERM_MATCH", "Field": "regionCode", "Value": region }, { "Type": "TERM_MATCH", "Field": "platoinstancetype", "Value": "Hosting" }, { "Type": "TERM_MATCH", "Field": "platoinstancename", "Value": instance_family }, ], ) products = [] for page in page_iterator: for product in page["PriceList"]: products.append(json.loads(product)["product"]["attributes"]) return products def list_available_instance_types_for_realtime_inference(instance_family, region): quotas = list_instance_quotas_for_realtime_inference(instance_family=instance_family, region=region) quotas_df = pd.json_normalize(quotas).loc[:,["QuotaName","Value"]] quotas_df["InstanceType"] = quotas_df["QuotaName"].str.removesuffix(" for endpoint usage") quotas_df = quotas_df.drop(columns=["QuotaName"]).rename(columns={"Value":"Limit"}) quotas_df["Limit"] = quotas_df["Limit"].astype(int) usage_values,usage_counts = list_instance_usages(region=region) usages_df = pd.DataFrame({"InstanceType": usage_values, "Usage": usage_counts}) attributes = list_instance_attributes_realtime_inference(instance_family=instance_family, region=region) attributes_df = pd.json_normalize(attributes) attributes_df = attributes_df.loc[:,["instanceName","vCpu"]].rename(columns={"instanceName": "InstanceType"}) merged_df = pd.merge(pd.merge(quotas_df, usages_df, how="left", on="InstanceType"), attributes_df, on='InstanceType').fillna(value=0) filtered_df = merged_df.query("Usage 0: instance_type = available_instance_types[0] else: print("There is no available gpu instance type. Trying to use cpu instance type.") available_instance_types = list_available_instance_types_for_realtime_inference(instance_family="c5", region=sagemaker_region) if (len(available_instance_types)) > 0: instance_type = available_instance_types[0] else: raise ValueException("There is no available instance types. Please change deployment option to on serverless endpoint, and try again.") print(f"Found an eligable instance type. {instance_type} will be used for a realtime {realtime_inference_endpoint_type} endpoint.") reranking_model = HuggingFaceModel( name=sagemaker.utils.name_from_base(reranking_model_name), env=hub, # configuration for loading model from Hub role=role, # iam role with permissions to create an Endpoint image_uri=get_huggingface_tei_image_uri(instance_type=instance_type, region=sagemaker_region) ) print(f"start deploy {reranking_model_id_on_hf} on a realtime {realtime_inference_endpoint_type} endpoint.") reranking_model.deploy( endpoint_name=sagemaker.utils.name_from_base(reranking_endpoint_name_prefix), initial_instance_count=1, instance_type=instance_type ) reranking_endpoint_name = reranking_model.endpoint_name else: print(f"Model {reranking_model_id_on_hf} is already deployed on a realtime {realtime_inference_endpoint_type} endpoint.") reranking_endpoint = response[0]["Endpoint"] reranking_endpoint_name = reranking_endpoint["EndpointName"] reranking_endpoint_url = f"https://runtime.sagemaker.{default_region}.amazonaws.com/endpoints/{reranking_endpoint_name}/invocations" print(f"\nendpoint name: {reranking_endpoint_name}") print(f"endpoint url: {reranking_endpoint_url}") except Exception as e: print(e) # ##### 推論エンドポイントのテスト呼び出し # 推論エンドポイントに対してテスト呼び出しを実行します。推論エンドポイントからは、ドキュメントリストのインデックス番号とスコアが返却されます。結果はスコアの高い順番に返却されます。 # In[6]: get_ipython().run_cell_magic('time', '', '\npayload = {\n "query":"What is the capital city of America?",\n "texts":[\n "Carson City is the capital city of the American state of Nevada.",\n "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",\n "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",\n "Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."\n ]\n}\n\nbody = bytes(json.dumps(payload), \'utf-8\')\n\nsagemaker_runtime_client = boto3.client("sagemaker-runtime", region_name=sagemaker_region)\nresponse = sagemaker_runtime_client.invoke_endpoint(\n EndpointName=reranking_endpoint_name,\n ContentType="application/json",\n Accept="application/json",\n Body=body\n)\n\nresult = eval(response[\'Body\'].read().decode(\'utf-8\'))\nprint(json.dumps(result, indent=2))\n') # 送信したドキュメントリストと順番を揃える場合は、index でソートする必要があります。 # In[7]: print(json.dumps(sorted(result, key=lambda x: x['index']),indent=2)) # ### OpenSearch 関連リソースの作成 # #### OpenSearch クライアントの作成 # ドメイン(クラスター)に接続するためのエンドポイント情報を CloudFormation スタックの出力から取得し、OpenSearch クライアントを作成します。 # In[8]: cloudformation_stack_name = "search-lab-jp" opensearch_cluster_endpoint = search_cloudformation_output(cloudformation_stack_name, "OpenSearchDomainEndpoint") credentials = boto3.Session().get_credentials() service_code = "es" auth = AWSV4SignerAuth(credentials=credentials, region=default_region, service=service_code) opensearch_client = OpenSearch( hosts=[{"host": opensearch_cluster_endpoint, "port": 443}], http_compress=True, http_auth=auth, use_ssl=True, verify_certs=True, connection_class = RequestsHttpConnection ) opensearch_client # OpenSearch クラスターへのネットワーク接続性が確保されており、OpenSearch の Security 機能により API リクエストが許可されているかを確認します。 # レスポンスに cluster_name や cluster_uuid が含まれていれば、接続確認が無事完了したと判断できます # In[9]: opensearch_client.info() # #### OpenSearch へのモデル登録 # SageMaker 上にデプロイしたモデルを呼び出すためのコンポーネントを作成します。 # # モデルは、コネクタと呼ばれる外部接続を定義したコンポーネントで構成されています。 # 今回の構成では、モデルは Rerank Processor から呼び出されます。Rerank processor は、リランキングモデルと連携し、クエリと検索結果の組み合わせを元に検索結果の並べ替えを行います。 # # # # ##### Amazon SageMaker のモデル情報・エンドポイント情報の確認 # 本セルの実行でエラーが発生する場合は、再度 SageMaker 上でのモデルデプロイをお試しください。 # In[10]: print(f"model name: {reranking_model_name}") print(f"endpoint name: {reranking_endpoint_name}") print(f"endpoint url: {reranking_endpoint_url}") # ##### コネクタ用 IAM Role ARN の確認 # OpenSearch コネクタから AWS サービスに接続する際、任意の IAM ロールの権限を引き受ける必要があります。引受対象の IAM ロールを CloudFormation スタックの出力から取得します。 # In[11]: cloudformation_stack_name = "search-lab-jp" opensearch_connector_role_arn = search_cloudformation_output(cloudformation_stack_name, 'OpenSearchConnectorRoleArn') opensearch_connector_role_arn # ##### コネクタの作成 # Amazon SageMaker 上のモデルを呼び出す定義を記載したコネクタを作成します。 # コネクタは、OpenSearch におけるモデルの一要素です。 # # コネクタの処理の流れは以下の通りです。 # # 1. pre_process_function の定義を元に、OpenSearch の Search pipline 内の Rerank processor から与えられた入力から、推論エンドポイントに与えるパラメーターを作成 # 1. pre_process_function によって変換されたパラメーターを元に、request_body の定義に沿ってペイロードを組み立て、推論エンドポイントの呼び出しを行う # 1. post_process_function の定義を元に、推論エンドポイントから返却された推論結果を加工し、Rerank processor に返却 # # In[12]: payload = { "name": reranking_model_name, "description": "Remote connector for "+ reranking_model_name, "version": 1, "protocol": "aws_sigv4", "credential": { "roleArn": opensearch_connector_role_arn }, "parameters": { "region": default_region, "service_name": "sagemaker" }, "actions": [ { "action_type": "predict", "method": "POST", "url": reranking_endpoint_url, "headers": { "content-type": "application/json" }, "pre_process_function": """ def query_text = params.query_text; def text_docs = params.text_docs; def textDocsBuilder = new StringBuilder('['); for (int i=0; i