#!/usr/bin/env python
# coding: utf-8

# In[ ]:

import pyarrow as pa
import pyarrow.parquet as pq
import pyarrow.flight as flight
import numpy as np
import pandas as pd
import time
import threading

from pyarrow.util import find_free_port

# # Implement a Flight server in Python
#
# This server has a few goals:
#
# * Clients can send ("put") datasets, to be kept in memory by the server
# * Clients can request a list of cached datasets ("list-tables")
# * Clients can request ("get") a cached table
#
# Note that this server is very simple and does not show some of the more
# sophisticated "query planning" capabilities of Arrow Flight, nor does it
# show parallel or multi-part access. My goal is to show you that
#
# * It's easy to write a Flight service in Python
# * The performance of Flight is **very, very good**

# In[ ]:

class DemoServer(flight.FlightServerBase):

    def __init__(self, location):
        self._cache = {}
        super().__init__(location)

    def list_actions(self, context):
        return [flight.ActionType('list-tables', 'List stored tables'),
                flight.ActionType('drop-table', 'Drop a stored table')]

    # -----------------------------------------------------------------
    # Implement actions

    def do_action(self, context, action):
        handlers = {
            'list-tables': self._list_tables,
            'drop-table': self._drop_table
        }
        handler = handlers.get(action.type)
        if not handler:
            raise NotImplementedError
        return handler(action)

    def _drop_table(self, action):
        # The action body is an Arrow buffer; convert it to bytes so it
        # matches the keys stored by do_put (descriptor.command is bytes)
        del self._cache[action.body.to_pybytes()]
        # Return an empty result stream rather than None so the server
        # has something to iterate over
        return iter([])

    def _list_tables(self, action):
        return iter([flight.Result(cache_key)
                     for cache_key in sorted(self._cache.keys())])

    # -----------------------------------------------------------------
    # Implement puts

    def do_put(self, context, descriptor, reader, writer):
        self._cache[descriptor.command] = reader.read_all()

    # -----------------------------------------------------------------
    # Implement gets

    def do_get(self, context, ticket):
        table = self._cache[ticket.ticket]
        return flight.RecordBatchStream(table)

# Some helper utilities, you can ignore this part

# ## Start server in background, connect client

# In[ ]:

# Equivalent to running `pa.ipc.IpcWriteOptions?` to inspect the available
# IPC write options (e.g. compression)
get_ipython().run_line_magic('pinfo', 'pa.ipc.IpcWriteOptions')

# In[117]:

port = 1337  # note: not actually used; find_free_port() below picks the port

location = flight.Location.for_grpc_tcp("localhost", find_free_port())
location

server = DemoServer(location)
thread = threading.Thread(target=lambda: server.serve(), daemon=True)
thread.start()

class DemoClient:

    def __init__(self, location, options=None):
        self.con = flight.connect(location)
        self.con.wait_for_available()
        self.options = options

    # Call "list-tables" RPC and return results as Python list
    def list_tables(self):
        action = flight.Action('list-tables', b'')
        return [x.body.to_pybytes().decode('utf8')
                for x in self.con.do_action(action)]

    # Send a pyarrow.Table to the server to be cached
    def cache_table_in_server(self, name, table):
        desc = flight.FlightDescriptor.for_command(name.encode('utf8'))
        put_writer, put_meta_reader = self.con.do_put(desc, table.schema,
                                                      options=self.options)
        put_writer.write(table)
        put_writer.close()

    # Request a pyarrow.Table by name
    def get_table(self, name):
        reader = self.con.do_get(flight.Ticket(name.encode('utf8')),
                                 options=self.options)
        return reader.read_all()

    def list_actions(self):
        return self.con.list_actions()

# Compress record batches with zstd when sending them over the wire
ipc_options = pa.ipc.IpcWriteOptions(compression='zstd')
options = flight.FlightCallOptions(write_options=ipc_options)

client = DemoClient(location, options=options)

# ### Ask server for supported actions
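# (Added illustration, not part of the original notebook.) Given the heading
# above, a natural first call is to ask the connected client which actions
# the server advertises; this exercises the `list_actions` handler defined
# on `DemoServer`.

# In[ ]:

client.list_actions()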
# In[118]:

table = pa.table([pa.array([1, 2, 3, 4, 5])], names=['f0'])
client.cache_table_in_server('table1', table)

# In[119]:

client.list_tables()

# In[120]:

client.cache_table_in_server('table2', table)
client.cache_table_in_server('table3', table)
client.cache_table_in_server('table4', table)

# In[121]:

client.list_tables()

# In[122]:

client.get_table('table1')

# ### Now let's make a much bigger table and test performance

# In[ ]:

# One-time conversion of the 2012 FEC contributions CSV to Parquet
# fec = pd.read_csv('/home/wesm/code/pydata-book/datasets/fec/P00000001-ALL.csv',
#                   low_memory=False)
# table = pa.table(fec)
# pq.write_table(table, 'fec-2012.parquet')

# In[123]:

fec_table = pq.read_table('fec-2012.parquet')

# In[124]:

# Concatenate 10 copies to make the dataset bigger
fec_table = pa.concat_tables([fec_table] * 10)

# In[125]:

# How big is it? Serialize with the same zstd-compressed IPC options the
# client uses and measure the resulting stream
out = pa.BufferOutputStream()
with pa.ipc.RecordBatchStreamWriter(out, fec_table.schema,
                                    options=ipc_options) as writer:
    writer.write(fec_table)
num_bytes = len(out.getvalue())

# In[126]:

print(f'Table is {num_bytes / (1 << 30):.2f} gigabytes')

# In[127]:

get_ipython().run_cell_magic('time', '', "client.cache_table_in_server('fec_table', fec_table)\n")

# In[128]:

client.list_tables()

# In[129]:

get_ipython().run_cell_magic('time', '', "fec_table_received = client.get_table('fec_table')\n")
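# (Added illustrations, not part of the original notebook; a minimal sketch
# using the objects defined above.) First, check that the table received over
# Flight is value-equal to the one that was cached, and derive a rough
# throughput figure from a manual timing. The names `start`, `elapsed`, and
# `fec_table_roundtrip` are introduced here for illustration only.

# In[ ]:

start = time.time()
fec_table_roundtrip = client.get_table('fec_table')
elapsed = time.time() - start

# The round-tripped table should match what we sent
assert fec_table_roundtrip.equals(fec_table)
print(f'Approximately {num_bytes / (1 << 30) / elapsed:.2f} GiB/s effective get throughput')

# Second, exercise the server's 'drop-table' action, which the notebook
# defines but never calls; the action body carries the table name as raw
# bytes, matching the keys DemoServer stores from descriptor.command.

# In[ ]:

list(client.con.do_action(flight.Action('drop-table', b'fec_table')))
client.list_tables()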