DC2 object Run1.1p Apache Spark tutorial -- Part I: Apache Spark access

Author: Julien Peloton @JulienPeloton
Based on the work of: Francois Lanusse @EiffL, Javier Sanchez @fjaviersanchez using GCR.
Last Verified to Run: 2018-10-22

This notebook will illustrate the basics of accessing the merged object catalogs through Apache Spark, as well as how to select useful samples of stars/galaxies from the dpdd catalogs (built from DM outputs). It follows the same steps (and sometimes the same documentation) as this notebook, which uses the Generic Catalog Reader (GCR).

Learning objectives:

After going through this notebook, you should be able to:

  1. Load and efficiently access a DC2 object catalog with Apache Spark
  2. Understand and have references for the catalog schema
  3. Apply cuts to the catalog using Spark SQL functionalities
  4. Have an example of quality cuts and simple star/galaxy separation cut
  5. Distribute the computation and the plotting routine to go faster!

Logistics: This notebook is intended to be run through the JupyterHub NERSC interface with the desc-pyspark kernel. The kernel is automatically installed in your environment when you use the kernel setup script:

source /global/common/software/lsst/common/miniconda/kernels/setup.sh

For more information see LSSTDESC/jupyter-kernels. Note that a general introduction and tutorials for Apache Spark can be found at astrolabsoftware/spark-tutorials (under construction).

Note concerning resources

The large-memory login node used by https://jupyter-dev.nersc.gov/
is a shared resource, so please be careful not to use too many CPUs
or too much memory.

That means avoid using `--master local[*]` in your kernel, and limit
the resources to a few cores. Typically `--master local[4]` is enough for
prototyping a program.

To scale up the analysis, the best option is to switch to batch mode, where no such limit applies!

This is already taken care of for you by the Spark+DESC kernel setup script (from desc-pyspark), but keep it in mind if you use a custom kernel.
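If you ever build a custom kernel instead, a minimal sketch of limiting the resources when creating the session could look like this (the app name below is purely illustrative; the desc-pyspark kernel sets the master for you):

from pyspark.sql import SparkSession

# Illustrative only: `local[4]` caps Spark at 4 cores on the shared login node
spark = SparkSession.builder \
    .master("local[4]") \
    .appName("dc2_object_spark_tutorial") \
    .getOrCreate()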

In [1]:
%matplotlib inline

import os

import numpy as np
import matplotlib.pyplot as plt

Accessing the object catalog with Apache Spark

In this section, we illustrate how to use Apache Spark to access the object catalogs from DC2 Run1.1p. Let's initialise Spark and load the data into a DataFrame. We will focus on data stored in the parquet data format.

In [2]:
# Where the dpdd data is stored
base_dir = '/global/projecta/projectdirs/lsst/global/in2p3/Run1.1/summary'

# Load the object catalog (all tracts and patches)
datafile = os.path.join(base_dir, 'dpdd_object.parquet')
print("Data will be read from: \n", datafile)
Data will be read from: 
 /global/projecta/projectdirs/lsst/global/in2p3/Run1.1/summary/dpdd_object.parquet
In [3]:
from pyspark.sql import SparkSession

# Initialise our Spark session
spark = SparkSession.builder.getOrCreate()

# Read the data as DataFrame
df = spark.read.format("parquet").load(datafile)

DC2 Object catalog Schema

To see the quantities available in the catalog, you can use the following:

In [4]:
# Check what we have in the file
df.printSchema()
root
 |-- magerr_i: double (nullable = true)
 |-- psFlux_i: double (nullable = true)
 |-- Ixx_r: double (nullable = true)
 |-- mag_i_cModel: double (nullable = true)
 |-- IxxPSF_u: double (nullable = true)
 |-- magerr_r: double (nullable = true)
 |-- psf_fwhm_i: double (nullable = true)
 |-- psf_fwhm_r: double (nullable = true)
 |-- Ixx: double (nullable = true)
 |-- magerr_g_cModel: double (nullable = true)
 |-- I_flag_y: boolean (nullable = true)
 |-- Iyy_z: double (nullable = true)
 |-- IxyPSF_i: double (nullable = true)
 |-- Ixx_z: double (nullable = true)
 |-- magerr_u_cModel: double (nullable = true)
 |-- IxyPSF: double (nullable = true)
 |-- snr_u_cModel: double (nullable = true)
 |-- IxxPSF_y: double (nullable = true)
 |-- psFlux_flag_i: boolean (nullable = true)
 |-- IyyPSF_g: double (nullable = true)
 |-- Ixy: double (nullable = true)
 |-- magerr_y: double (nullable = true)
 |-- psFlux_g: double (nullable = true)
 |-- snr_y_cModel: double (nullable = true)
 |-- Ixy_z: double (nullable = true)
 |-- psFlux_flag_r: boolean (nullable = true)
 |-- Iyy_g: double (nullable = true)
 |-- psFluxErr_r: double (nullable = true)
 |-- Ixx_i: double (nullable = true)
 |-- snr_z_cModel: double (nullable = true)
 |-- psf_fwhm_g: double (nullable = true)
 |-- Ixx_y: double (nullable = true)
 |-- I_flag: boolean (nullable = true)
 |-- magerr_z: double (nullable = true)
 |-- I_flag_i: boolean (nullable = true)
 |-- IyyPSF_i: double (nullable = true)
 |-- yErr: float (nullable = true)
 |-- magerr_r_cModel: double (nullable = true)
 |-- magerr_i_cModel: double (nullable = true)
 |-- clean: boolean (nullable = true)
 |-- IxyPSF_y: double (nullable = true)
 |-- mag_g: double (nullable = true)
 |-- mag_r: double (nullable = true)
 |-- psf_fwhm_y: double (nullable = true)
 |-- IxyPSF_g: double (nullable = true)
 |-- ra: double (nullable = true)
 |-- extendedness: double (nullable = true)
 |-- IxxPSF: double (nullable = true)
 |-- x: double (nullable = true)
 |-- Ixx_u: double (nullable = true)
 |-- mag_g_cModel: double (nullable = true)
 |-- psFluxErr_u: double (nullable = true)
 |-- I_flag_r: boolean (nullable = true)
 |-- IyyPSF_r: double (nullable = true)
 |-- psFluxErr_y: double (nullable = true)
 |-- psNdata: float (nullable = true)
 |-- psFlux_y: double (nullable = true)
 |-- psFlux_u: double (nullable = true)
 |-- Iyy: double (nullable = true)
 |-- IxxPSF_r: double (nullable = true)
 |-- mag_u: double (nullable = true)
 |-- dec: double (nullable = true)
 |-- IxyPSF_z: double (nullable = true)
 |-- mag_y: double (nullable = true)
 |-- Ixx_g: double (nullable = true)
 |-- Ixy_y: double (nullable = true)
 |-- IxxPSF_i: double (nullable = true)
 |-- blendedness: double (nullable = true)
 |-- Iyy_y: double (nullable = true)
 |-- IxxPSF_g: double (nullable = true)
 |-- psFlux_flag_u: boolean (nullable = true)
 |-- psFluxErr_g: double (nullable = true)
 |-- Iyy_r: double (nullable = true)
 |-- magerr_u: double (nullable = true)
 |-- I_flag_g: boolean (nullable = true)
 |-- snr_i_cModel: double (nullable = true)
 |-- psFluxErr_i: double (nullable = true)
 |-- IxxPSF_z: double (nullable = true)
 |-- IyyPSF_y: double (nullable = true)
 |-- I_flag_z: boolean (nullable = true)
 |-- snr_r_cModel: double (nullable = true)
 |-- Ixy_g: double (nullable = true)
 |-- mag_z: double (nullable = true)
 |-- IyyPSF_z: double (nullable = true)
 |-- psFlux_r: double (nullable = true)
 |-- IxyPSF_r: double (nullable = true)
 |-- psFlux_flag_g: boolean (nullable = true)
 |-- Iyy_u: double (nullable = true)
 |-- psf_fwhm_u: double (nullable = true)
 |-- objectId: long (nullable = true)
 |-- magerr_z_cModel: double (nullable = true)
 |-- snr_g_cModel: double (nullable = true)
 |-- psFlux_flag_y: boolean (nullable = true)
 |-- I_flag_u: boolean (nullable = true)
 |-- Ixy_u: double (nullable = true)
 |-- mag_i: double (nullable = true)
 |-- psFluxErr_z: double (nullable = true)
 |-- good: boolean (nullable = true)
 |-- Ixy_r: double (nullable = true)
 |-- parentObjectId: long (nullable = true)
 |-- Ixy_i: double (nullable = true)
 |-- IyyPSF: double (nullable = true)
 |-- xErr: float (nullable = true)
 |-- mag_u_cModel: double (nullable = true)
 |-- xy_flag: boolean (nullable = true)
 |-- IyyPSF_u: double (nullable = true)
 |-- mag_r_cModel: double (nullable = true)
 |-- mag_y_cModel: double (nullable = true)
 |-- magerr_g: double (nullable = true)
 |-- mag_z_cModel: double (nullable = true)
 |-- psf_fwhm_z: double (nullable = true)
 |-- psFlux_flag_z: boolean (nullable = true)
 |-- IxyPSF_u: double (nullable = true)
 |-- Iyy_i: double (nullable = true)
 |-- psFlux_z: double (nullable = true)
 |-- y: double (nullable = true)
 |-- magerr_y_cModel: double (nullable = true)
 |-- tract: integer (nullable = true)
 |-- patch: string (nullable = true)

The meaning of these fields follows the standard nomenclature of the LSST Data Products Definition Document DPDD.

The DPDD is an effort made by the LSST project to standardize the format of the official Data Release Products (DRP). While the native outputs of the DM stack are susceptible to change, the DPDD will be more stable. An early adoption of these conventions by the DESC will save time and energy down the road.

We can see that the catalog includes:

  • Positions
  • Fluxes and magnitudes (PSF and CModel)
  • Shapes (using GalSim's HSM)
  • Quality flags: e.g., does the source have any interpolated pixels? Did any of the measurement algorithms return an error?
  • Other useful quantities: blendedness, a measure of how much the flux is affected by neighbors, (1 - flux.child/flux.parent) (see section 4.9.11 of Bosch et al. 2018); extendedness, which classifies sources as extended or PSF-like.
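If you prefer to explore the schema programmatically rather than scanning the printed list, the column names are available directly on the DataFrame. A minimal sketch (the `mag_` prefix is just one example pattern):

# Number of columns, and e.g. all magnitude-related columns
print(len(df.columns))
print(sorted(c for c in df.columns if c.startswith('mag_')))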

Accessing the data (taken from the original GCR notebook)

While Run1.1p is still of manageable size, the full DC2 dataset will be much larger, and accessing all of it at once can be challenging. In order to access the data efficiently, it is important to understand how it is physically stored and how to access it one piece at a time.

The coadds produced by the DM stack are structured in terms of large tracts and smaller patches, illustrated here for DC2: the tracts are labelled with large blue numbers, and the patches are denoted in an (x,y) format. For DC2, each tract has 8x8 patches.

You can learn more about how to make such a plot of the tracts and patches here.

Spark preserves the structure of the data, so any particular quantity can be accessed on a tract/patch basis. The tracts available in the catalog can be listed using the following command:

In [5]:
# Show all available tracts
df.select('tract').distinct().show()
+-----+
|tract|
+-----+
| 4637|
| 4848|
| 5066|
| 5062|
| 4851|
| 4639|
| 4852|
| 4849|
| 4638|
| 5064|
| 5065|
| 5063|
| 4431|
| 4433|
| 4432|
| 4640|
| 4636|
| 4850|
| 4430|
+-----+
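You can also see how the objects are distributed over tracts (or patches within a tract) without collecting anything to the driver; a short sketch using the standard groupBy/count pattern:

# Number of objects per tract, computed on the executors
df.groupBy('tract').count().orderBy('tract').show()

# Number of objects per patch within a single tract
df.where('tract == 4430').groupBy('patch').count().show()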

The DM stack includes functionality to get the tract and patch numbers corresponding to a given (RA, Dec) position, but that is beyond the scope of this tutorial.

Apache Spark provides filter mechanisms, which you can use to speed up data access if you only need certain chunks of the dataset. For the object catalog, the data is chunked by tract and patch, and hence those are the filters you can use:

In [6]:
# Retrieve the ra,dec coordinates of all sources within tract number 4430
data = df.select('ra', 'dec').where('tract == 4430').collect()

# `collect` returns a list of Row(ra, dec), so for 
# plotting purposes we transpose the output:
ra, dec = np.transpose(data)

# Plot a 2d histogram of sources
plt.figure(figsize=(10,7))
plt.hist2d(ra, dec, 100)
plt.gca().set_aspect('equal')
plt.colorbar(label='Number of objects')
plt.xlabel('RA [deg]')
plt.ylabel('dec [deg]');

It is interesting to note that there are several ways in Spark to apply these filtering mechanisms.

Pure SQL

In [7]:
# Pure SQL
cols = "ra, dec"

# SQL - register first the DataFrame
df.createOrReplaceTempView("full_tract")

# Keep only the sources in tract 4430 (an example with an extra magnitude error cut follows below)
sql_command = """
    SELECT {}
    FROM full_tract 
    WHERE 
        tract == 4430
""".format(cols)

# Execute the expression - return a DataFrame
df_sub = spark.sql(sql_command)
data = df_sub.collect()
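Additional cuts simply extend the WHERE clause. For example, here is a sketch adding a g-band magnitude error cut on top of the tract selection (the 0.0 < magerr_g < 0.3 thresholds are purely illustrative, and df_sub_cut is just a hypothetical name):

sql_command = """
    SELECT ra, dec
    FROM full_tract
    WHERE
        tract == 4430
        AND magerr_g > 0.0
        AND magerr_g < 0.3
"""

# Execute the expression - returns a (smaller) DataFrame
df_sub_cut = spark.sql(sql_command)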

Spark DataFrame built-in methods

In [8]:
# Using select/where
data = df.select('ra', 'dec').where('tract == 4430').collect()

# Or using select/filter
data = df.select('ra', 'dec').filter('tract == 4430').collect()
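The same selection can also be written with column expressions instead of a SQL string, which can be easier to compose programmatically; a minimal equivalent sketch:

from pyspark.sql.functions import col

# Same tract selection, expressed with column objects
data = df.select('ra', 'dec').where(col('tract') == 4430).collect()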

Data type

The data returned by collecting a DataFrame (collect) in Spark is structured as a native Python list of Row:

In [9]:
print("Data type is {}".format(type(data[0])))
print("Example: {}".format(data[0]))
Data type is <class 'pyspark.sql.types.Row'>
Example: Row(ra=54.404930306668895, dec=-31.365531039028056)

But you can easily convert back to a standard list or a NumPy array (NumPy methods do the conversion for you most of the time) if needed:

In [10]:
arow = data[0]

# Explicit conversion
mylist = list(arow)
print("input type: {} / output type {}".format(type(arow), type(mylist)))

# Implicit conversion
cols = np.transpose(arow)
print("input type: {} / output type {}".format(type(arow), type(cols)))
input type: <class 'pyspark.sql.types.Row'> / output type <class 'list'>
input type: <class 'pyspark.sql.types.Row'> / output type <class 'numpy.ndarray'>
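Row objects also carry their column names, so you can access fields by name or turn a Row into a plain dictionary with the standard asDict method:

# Access fields by name, or convert the Row to a dict
print(arow.ra, arow.dec)
print(arow.asDict())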

Spark to Pandas DataFrame

A Spark DataFrame can also easily be converted into a Pandas DataFrame:

In [11]:
pdata = df.select('ra', 'dec').where('tract == 4430').toPandas()
pdata
Out[11]:
ra dec
0 54.404930 -31.365531
1 54.410349 -31.365466
2 54.487789 -31.364878
3 54.457632 -31.364972
4 54.414210 -31.364857
5 54.471451 -31.364524
6 54.483118 -31.364427
7 54.494475 -31.364112
8 54.498064 -31.364036
9 54.461496 -31.363817
10 54.444731 -31.363971
11 54.407138 -31.363816
12 54.479225 -31.363226
13 54.480464 -31.363059
14 54.405230 -31.363196
15 54.415338 -31.362927
16 54.439836 -31.362678
17 54.490157 -31.362344
18 54.484377 -31.362356
19 54.495764 -31.362086
20 54.489013 -31.361885
21 54.481877 -31.361311
22 54.431351 -31.361542
23 54.465088 -31.361079
24 54.471405 -31.361044
25 54.449442 -31.360700
26 54.499241 -31.360447
27 54.502800 -31.359996
28 54.464624 -31.359868
29 54.492899 -31.359802
... ... ...
103857 54.139486 -31.989889
103858 54.135182 -31.989360
103859 54.143822 -31.988384
103860 54.140683 -31.986849
103861 54.129652 -31.986672
103862 54.146430 -31.986221
103863 54.132536 -31.984499
103864 54.139372 -31.996074
103865 54.138669 -31.996317
103866 54.137469 -31.996449
103867 54.138623 -31.996733
103868 54.137008 -31.996776
103869 54.137856 -31.996862
103870 54.142427 -31.994695
103871 54.143435 -31.995051
103872 54.149626 -31.989562
103873 54.149592 -31.989129
103874 54.140357 -31.988662
103875 54.140206 -31.989310
103876 54.148349 -31.986927
103877 54.147498 -31.986709
103878 54.405782 -32.406046
103879 54.402353 -32.406134
103880 54.400020 -32.399785
103881 54.404683 -32.397523
103882 54.410731 -32.395946
103883 54.413689 -32.395284
103884 54.429948 -32.394602
103885 54.394665 -32.403016
103886 54.391354 -32.398250

103887 rows × 2 columns
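Keep in mind that toPandas, like collect, brings everything back to the driver, so it is best used on data that has already been filtered down. Once in Pandas, the usual Pandas tooling applies; for instance:

# Quick summary statistics of the converted DataFrame
print(pdata.describe())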

Access time

As a simple test, you can show the advantage of loading one tract at a time compared to the entire catalog:

In [12]:
df_radec = df.select('ra', 'dec')
%time data = df_radec.where('tract == 4430').collect()
CPU times: user 447 ms, sys: 16.3 ms, total: 463 ms
Wall time: 634 ms
In [13]:
%time data = df_radec.collect()
CPU times: user 26.4 s, sys: 854 ms, total: 27.3 s
Wall time: 41 s

Note that we timed the collect action, which is very specific (it gathers data from the executors onto the driver). In practice, we do not perform this action often (only at the very end of a pipeline, because it implies communication between machines). In Spark, we do most of the computation (including plotting!) in the executors (distributed computation), and we collect the data only once it is sufficiently reduced. Therefore, what matters more is the time to load the data and perform an action inside the executors. The simplest one (but still relevant!) is to count the elements (O(n) complexity):

In [14]:
%time data = df_radec.where('tract == 4430').count()
CPU times: user 659 ms, sys: 132 ms, total: 791 ms
Wall time: 1.46 s
In [15]:
%time data = df_radec.count()
CPU times: user 2.49 ms, sys: 496 µs, total: 2.99 ms
Wall time: 4.92 s

Here we are, super fast! Note that we loaded the data AND performed a simple action, so these benchmarks give you the IO overhead for this kind of catalog. More about Apache Spark benchmarks can be found here.
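As an illustration of keeping the reduction in the executors until the result is tiny, you can aggregate directly on the DataFrame and collect only the final numbers. A minimal sketch using the standard Spark SQL aggregation functions:

from pyspark.sql import functions as F

# Coordinate bounds for one tract, computed in the executors;
# only a single row travels back to the driver
bounds = df_radec.where('tract == 4430') \
    .agg(F.min('ra'), F.max('ra'), F.min('dec'), F.max('dec')) \
    .collect()[0]
print(bounds)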

Applying filters and cuts

Applying more than one cut works in exactly the same way:

In [16]:
# Simple cut to remove unreliable detections
# More cuts can be added, as a logical AND, inside the WHERE string (see the sketch after the counts below)
# good: 
#   The source has no flagged pixels (interpolated, saturated, edge, clipped...) 
#   and was not skipped by the deblender
# tract == 4849:
#   Data only for tract 4849

# Data after cut (DataFrame)
df_cut = df_radec.where("tract == 4849 AND good")

# Data without cuts (DataFrame)
df_full = df_radec.where("tract == 4849")
In [17]:
print("Number of sources before cut : {}".format(df_full.count()))
print("Number of sources after cut  : {}".format(df_cut.count()))
Number of sources before cut : 785795
Number of sources after cut  : 780845
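As mentioned in the comment above, more cuts are added by simply extending the WHERE string. Purely as an illustration, here is how one might combine the good and clean flags with a crude point-source selection based on the extendedness column (the 0.5 threshold and the df_point_like name are illustrative; a careful star/galaxy separation deserves more thought):

# Illustrative only: combine several quality flags and a simple
# point-source cut via the extendedness column
df_point_like = df_radec.where("tract == 4849 AND good AND clean AND extendedness < 0.5")
print(df_point_like.count())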

Plotting the result - the standard way

The standard way means filtering the data, collecting it, and plotting (this is, e.g., what you would do with GCR). This can be written as:

In [18]:
# Plot a 2d histogram of sources
plt.figure(figsize=(15, 7))
for index, dataframe in enumerate([df_full, df_cut]):
    ra, dec = np.transpose(dataframe.collect())
    plt.subplot(121 + index)
    (counts, xe, ye, Image) = plt.hist2d(ra, dec, 256); 
    plt.gca().set_aspect('equal'); 
    plt.xlabel('RA [deg]');
    plt.ylabel('dec [deg]');
    if index == 0:
        plt.title('Full sample');
    else:
        plt.title('Clean objects');
    plt.colorbar(label='Number of objects');

All of that works because we are dealing with only a small fraction of the data. Imagine you had terabytes of data, even after cuts. What would you do? GCR makes use of iterators; that is a workaround, but not an entirely satisfactory one, since the work is done serially. For terabytes of data it will still work, but it will take a very long time.

Spark point of view: distribute the computation (and plot!)

The way to be faster is to distribute the plot (or rather the computation that produces the data to be plotted). Histograms are particularly easy to distribute. Let's write a method to be applied on each Spark partition (each partition contains only a fraction of the data):

In [19]:
def hist2d(partition, nbins=256, xyrange=None):
    """ Produce 2D histograms from (x, y) data
    
    Parameters
    ----------
    partition : Iterator
        Iterator over the partition data *[x, y].
    nbins : int, optional
        Number of bins per axis (default 256).
    xyrange : list of lists, optional
        Histogram range, [[xmin, xmax], [ymin, ymax]].

    Returns
    -------
    Generator yielding the counts for the partition.
    Counts is an array of dimension nbins x nbins.
    """
    # Unwrap the iterator
    radec = [*partition]
    ra, dec = np.transpose(radec)
    
    (counts, xedges, yedges, Image) = plt.hist2d(
        ra, dec, nbins, 
        range=xyrange)
    
    yield counts

# Min/Max values, taken from the (xe, ye) edges returned by the
# previous cell's hist2d call - just to fix a common plot range
xyrange = [[np.min(xe), np.max(xe)], [np.min(ye), np.max(ye)]]

plt.figure(figsize=(15, 7))
for index, dataframe in enumerate([df_full, df_cut]):
    plt.subplot(121 + index)
    # This is the crucial part - build the plot data in parallel!
    im = dataframe\
        .rdd\
        .mapPartitions(lambda partition: hist2d(partition, 256, xyrange))\
        .reduce(lambda x, y: x+y)
    
    plt.imshow(im.T, origin='lower', aspect='equal');
    plt.xlabel('RA [deg]');
    plt.ylabel('dec [deg]');
    if index == 0:
        plt.title('Full sample');
    else:
        plt.title('Clean objects');
    plt.colorbar(label='Number of objects');
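A design note: calling plt.hist2d inside mapPartitions builds a matplotlib figure on each executor just to obtain the counts. If you would rather avoid matplotlib on the workers, the same partition-wise pattern can be sketched with np.histogram2d (same nbins/xyrange assumptions as above; hist2d_np is just a hypothetical name):

def hist2d_np(partition, nbins=256, xyrange=None):
    """ Same per-partition 2D histogram, computed with NumPy only """
    ra, dec = np.transpose([*partition])
    counts, xedges, yedges = np.histogram2d(ra, dec, bins=nbins, range=xyrange)
    yield counts

# Build the histogram for the clean sample, as before
im = df_cut.rdd \
    .mapPartitions(lambda part: hist2d_np(part, 256, xyrange)) \
    .reduce(lambda x, y: x + y)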