pip install -r requirements.txt python main.py """ CLI assistant that uses Databricks MCP Vector Search and UC Functions via the OpenAI Agents SDK. """ import asyncio import os import httpx from typing import Dict, Any from agents import Agent, Runner, function_tool, gen_trace_id, trace from agents.exceptions import ( InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered, ) from agents.model_settings import ModelSettings from databricks_mcp import DatabricksOAuthClientProvider from databricks.sdk import WorkspaceClient from supply_chain_guardrails import supply_chain_guardrail CATALOG = os.getenv("MCP_VECTOR_CATALOG", "main") # override catalog, schema, functions_path name if your data assets sit in a different location SCHEMA = os.getenv("MCP_VECTOR_SCHEMA", "supply_chain_db") FUNCTIONS_PATH = os.getenv("MCP_FUNCTIONS_PATH", "main/supply_chain_db") DATABRICKS_PROFILE = os.getenv("DATABRICKS_PROFILE", "DEFAULT") # override if using a different profile name HTTP_TIMEOUT = 30.0 # seconds async def _databricks_ctx(): """Return (workspace, PAT token, base_url).""" ws = WorkspaceClient(profile=DATABRICKS_PROFILE) token = DatabricksOAuthClientProvider(ws).get_token() return ws, token, ws.config.host @function_tool async def vector_search(query: str) -> Dict[str, Any]: """Query Databricks MCP Vector Search index.""" ws, token, base_url = await _databricks_ctx() url = f"{base_url}/api/2.0/mcp/vector-search/{CATALOG}/{SCHEMA}" headers = {"Authorization": f"Bearer {token}"} async with httpx.AsyncClient(timeout=HTTP_TIMEOUT) as client: resp = await client.post(url, json={"query": query}, headers=headers) resp.raise_for_status() return resp.json() @function_tool async def uc_function(function_name: str, params: Dict[str, Any]) -> Dict[str, Any]: """Invoke a Databricks Unity Catalog function with parameters.""" ws, token, base_url = await _databricks_ctx() url = f"{base_url}/api/2.0/mcp/functions/{FUNCTIONS_PATH}" headers = {"Authorization": f"Bearer {token}"} payload = {"function": function_name, "params": params} async with httpx.AsyncClient(timeout=HTTP_TIMEOUT) as client: resp = await client.post(url, json=payload, headers=headers) resp.raise_for_status() return resp.json() async def run_agent(): agent = Agent( name="Assistant", instructions="You are a supply-chain assistant for Databricks MCP; you must answer **only** questions that are **strictly** about supply-chain data, logistics, inventory, procurement, demand forecasting, etc; for every answer you must call one of the registered tools; if the user asks anything not related to supply chain, reply **exactly** with 'Sorry, I can only help with supply-chain questions'.", tools=[vector_search, uc_function], model_settings=ModelSettings(model="gpt-4o", tool_choice="required"), output_guardrails=[supply_chain_guardrail], ) print("Databricks MCP assistant ready. Type a question or 'exit' to quit.") while True: user_input = input("You: ").strip() if user_input.lower() in {"exit", "quit"}: break trace_id = gen_trace_id() with trace(workflow_name="Databricks MCP Agent", trace_id=trace_id): try: result = await Runner.run(starting_agent=agent, input=user_input) print("Assistant:", result.final_output) except InputGuardrailTripwireTriggered: print("Assistant: Sorry, I can only help with supply-chain questions.") except OutputGuardrailTripwireTriggered: print("Assistant: Sorry, I can only help with supply-chain questions.") def main(): asyncio.run(run_agent()) if __name__ == "__main__": main() """ Databricks OAuth client provider for MCP servers. """ class DatabricksOAuthClientProvider: def __init__(self, ws): self.ws = ws def get_token(self): # For Databricks SDK >=0.57.0, token is available as ws.config.token return self.ws.config.token """ Output guardrail that blocks answers not related to supply-chain topics. """ from __future__ import annotations from pydantic import BaseModel from agents import Agent, Runner, GuardrailFunctionOutput from agents import output_guardrail from agents.run_context import RunContextWrapper class SupplyChainCheckOutput(BaseModel): reasoning: str is_supply_chain: bool guardrail_agent = Agent( name="Supply-chain check", instructions=( "Check if the text is within the domain of supply-chain analytics and operations " "Return JSON strictly matching the SupplyChainCheckOutput schema" ), output_type=SupplyChainCheckOutput, ) @output_guardrail async def supply_chain_guardrail( ctx: RunContextWrapper, agent: Agent, output ) -> GuardrailFunctionOutput: """Output guardrail that blocks non-supply-chain answers""" text = output if isinstance(output, str) else getattr(output, "response", str(output)) result = await Runner.run(guardrail_agent, text, context=ctx.context) return GuardrailFunctionOutput( output_info=result.final_output, tripwire_triggered=not result.final_output.is_supply_chain, ) python -m uvicorn api_server:app --reload --port 8000 """ FastAPI wrapper that exposes the agent as a streaming `/chat` endpoint. """ import os import asyncio import logging from fastapi import FastAPI from fastapi.responses import StreamingResponse from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from agents.exceptions import ( InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered, ) from agents import Agent, Runner, gen_trace_id, trace from agents.mcp import MCPServerStreamableHttp, MCPServerStreamableHttpParams from agents.model_settings import ModelSettings from databricks_mcp import DatabricksOAuthClientProvider from databricks.sdk import WorkspaceClient from supply_chain_guardrails import supply_chain_guardrail CATALOG = os.getenv("MCP_VECTOR_CATALOG", "main") SCHEMA = os.getenv("MCP_VECTOR_SCHEMA", "supply_chain_db") FUNCTIONS_PATH = os.getenv("MCP_FUNCTIONS_PATH", "main/supply_chain_db") DATABRICKS_PROFILE = os.getenv("DATABRICKS_PROFILE", "DEFAULT") HTTP_TIMEOUT = 30.0 # seconds app = FastAPI() # Allow local dev front‑end app.add_middleware( CORSMiddleware, allow_origins=["http://localhost:5173"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) class ChatRequest(BaseModel): message: str async def build_mcp_servers(): """Initialise Databricks MCP vector & UC‑function servers.""" ws = WorkspaceClient(profile=DATABRICKS_PROFILE) token = DatabricksOAuthClientProvider(ws).get_token() base = ws.config.host vector_url = f"{base}/api/2.0/mcp/vector-search/{CATALOG}/{SCHEMA}" fn_url = f"{base}/api/2.0/mcp/functions/{FUNCTIONS_PATH}" async def _proxy_tool(request_json: dict, url: str): import httpx headers = {"Authorization": f"Bearer {token}"} async with httpx.AsyncClient(timeout=HTTP_TIMEOUT) as client: resp = await client.post(url, json=request_json, headers=headers) resp.raise_for_status() return resp.json() headers = {"Authorization": f"Bearer {token}"} servers = [ MCPServerStreamableHttp( MCPServerStreamableHttpParams( url=vector_url, headers=headers, timeout=HTTP_TIMEOUT, ), name="vector_search", client_session_timeout_seconds=60, ), MCPServerStreamableHttp( MCPServerStreamableHttpParams( url=fn_url, headers=headers, timeout=HTTP_TIMEOUT, ), name="uc_functions", client_session_timeout_seconds=60, ), ] # Ensure servers are initialized before use await asyncio.gather(*(s.connect() for s in servers)) return servers @app.post("/chat") async def chat_endpoint(req: ChatRequest): try: servers = await build_mcp_servers() agent = Agent( name="Assistant", instructions="Use the tools to answer the questions.", mcp_servers=servers, model_settings=ModelSettings(tool_choice="required"), output_guardrails=[supply_chain_guardrail], ) trace_id = gen_trace_id() async def agent_stream(): logging.info(f"[AGENT_STREAM] Input message: {req.message}") try: with trace(workflow_name="Databricks MCP Example", trace_id=trace_id): result = await Runner.run(starting_agent=agent, input=req.message) logging.info(f"[AGENT_STREAM] Raw agent result: {result}") try: logging.info( f"[AGENT_STREAM] RunResult __dict__: {getattr(result, '__dict__', str(result))}" ) raw_responses = getattr(result, "raw_responses", None) logging.info(f"[AGENT_STREAM] RunResult raw_responses: {raw_responses}") except Exception as log_exc: logging.warning(f"[AGENT_STREAM] Could not log RunResult details: {log_exc}") yield result.final_output except InputGuardrailTripwireTriggered: # Off-topic question denied by guardrail yield "Sorry, I can only help with supply-chain questions." except OutputGuardrailTripwireTriggered: # Out-of-scope answer blocked by guardrail yield "Sorry, I can only help with supply-chain questions." except Exception: logging.exception("[AGENT_STREAM] Exception during agent run") yield "[ERROR] Exception during agent run. Check backend logs for details." return StreamingResponse(agent_stream(), media_type="text/plain") except Exception: logging.exception("chat_endpoint failed") return StreamingResponse( (line.encode() for line in ["Internal server error 🙈"]), media_type="text/plain", status_code=500, ) cd ui npm install npm run dev """ Code snippet handling the token stream coming from the FastAPI /chat endpoint. """ const reader = response.body.getReader(); while (true) { const { done, value } = await reader.read(); if (done) break; assistantMsg.text += new TextDecoder().decode(value); setMessages(m => { const copy = [...m]; copy[copy.length - 1] = { ...assistantMsg }; return copy; }); }