#!/usr/bin/env python
# coding: utf-8
"""Build a synthetic neighbour graph and its two-hop (peers-of-peers) expansion.

Steps:
1. Draw N random (lat, lon) points in a box roughly covering India.
2. Use a BallTree to find, for every point, all points within ~1 km.
3. Write the (id, neighbor_id) edge list to parquet with polars.
4. Self-join the edge list lazily to obtain neighbours of neighbours and
   write that to parquet as well.
"""

# In[18]:

import numpy as np
import pandas as pd
import polars as pl
import multiprocess as mp
from sklearn.neighbors import BallTree

# Create synthetic data of the location of people (latitude/longitude of
# India, approx).
# For 40 M people, they will have on average ~15 connections; for 10 M people,
# ~4 connections.
np.random.seed(0)
N = 40_000_000

# Search radius in decimal degrees (1 decimal degree ~ 111 km, so ~1 km).
RADIUS_DEG = 1 / 111

# Output files; each is written once and then re-scanned lazily.
NEIGHBORS_PATH = "~/Downloads/neighbors.parquet"
NEIGHBORS_OF_NEIGHBORS_PATH = "~/Downloads/neighbors_of_neighbors.parquet"

lat = np.random.uniform(low=10, high=35, size=N)
lon = np.random.uniform(low=70, high=95, size=N)

# Make it a pandas dataframe
df = pd.DataFrame({"id": range(N), "lat": lat, "lon": lon})


# In[19]:

# Build the (N, 2) coordinate matrix once; it is needed both to fit the tree
# and to query it, and each materialisation is a large copy at this N.
coords = df[["lat", "lon"]].to_numpy()

# Fit BallTree for fast queries
bt = BallTree(coords, metric="euclidean")

# Approximate geographical distance with euclidean distance on raw degrees.
# NOTE: this is exact only north-south. A degree of longitude shrinks by
# cos(latitude), so the east-west reach of the radius is roughly 0.98 km at
# 10N but only about 0.82 km at 35N. Double check that this is acceptable.
df["neighbor_id"] = bt.query_radius(coords, r=RADIUS_DEG)

# A haversine BallTree would be more precise but much, much slower. Note that
# scikit-learn's haversine metric expects [lat, lon] in radians, not degrees:
# coords_rad = np.radians(coords)
# bt = BallTree(coords_rad, metric="haversine")
# radius = 1 / 6371.0  # 1 km as an angle on a sphere of radius 6371 km
# neighbors = bt.query_radius(coords_rad, r=radius)

print(df.shape)
print(df.head())


# In[ ]:

# Total number of (id, neighbour) pairs, in millions. Every point is still
# listed as its own neighbour at this stage (removed further down).
print(sum(len(x) for x in df["neighbor_id"]) / 1e6)


# In[ ]:

# Convert to polars for efficiency
df = pl.LazyFrame(df[["id", "neighbor_id"]])

# One row per (id, neighbor_id) pair, dropping the self-match where
# neighbor_id equals id.
# Save to disk to allow for larger-than-memory operations.
(
    df
    .explode("neighbor_id")
    .filter(pl.col("id") != pl.col("neighbor_id"))
    .sink_parquet(NEIGHBORS_PATH)
)


# In[ ]:

# set_sorted only flags the "id" column as sorted, which makes the join
# faster; polars does not verify it. Ensure the file really is sorted by id,
# otherwise the join can silently return wrong results.
df = pl.scan_parquet(NEIGHBORS_PATH).set_sorted("id")


# In[ ]:

print(df.head().collect())


# In[ ]:

# Join with itself to find peers of peers and save to parquet (lazy
# operations, does not need to fit in memory). Pairs whose second hop leads
# back to the starting id are not filtered out.
(
    df
    .join(df, left_on="neighbor_id", right_on="id", how="inner")
    .sink_parquet(NEIGHBORS_OF_NEIGHBORS_PATH)
)


# In[ ]:

data = pl.scan_parquet(NEIGHBORS_OF_NEIGHBORS_PATH)


# In[ ]:

# Head of file
print(data.head().collect())


# In[ ]:

# Number of rows
print(data.select(pl.len()).collect())


# In[ ]: