import numpy as np
import pandas as pd
import polars as pl
import multiprocess as mp
from sklearn.neighbors import BallTree
# Create synthetic data of the location of people (latitude/longitude of India, approx)
# For 40 M people, they will have on average ~15 connections, for 10 M people, 4 connections
np.random.seed(0)
N = 40_000_000
lat = np.random.uniform(low=10, high=35, size=N)
lon = np.random.uniform(low=70, high=95, size=N)
# Make it a pandas dataframe
df = pd.DataFrame({'id': range(N), 'lat': lat, 'lon': lon})
# Fit BallTree for fast queries
bt = BallTree(df[["lat", "lon"]].values, metric="euclidean")
# Approximate distance to match euclidean to geographical distance (1 decimal degree ~ 111 km),
# should create only a small error over small distances (but double check)
df["neighbor_id"] = bt.query_radius(df[["lat", "lon"]].values, r=1/111)
## This would be more precise but much much slower
# bt = BallTree(df[["lat", "lon"]].values, metric="haversine")
# radius = 1 / 6371.0
# neighboors = bt.query_radius(df[["lat", "lon"]].values, r=radius)
print(df.shape)
df.head()
sum([len(x) for x in df["neighbor_id"]])/1E6
10520530
# Convert to polars for efficiency
df = pl.LazyFrame(df[["id", "neighbor_id"]])
# Remove from neighor_id the id of column id
# Save to disk to allow for larger than memory operations
(df
.explode("neighbor_id")
.filter(pl.col("id") != pl.col("neighbor_id"))
.sink_parquet("~/Downloads/neighbors.parquet")
)
# Set sorted makes the join faster, but ensure df is sorted by ID!!!!! (otherwise bad things happen)
df = pl.scan_parquet("~/Downloads/neighbors.parquet").set_sorted("id")
df.head().collect()
id | neighbor_id |
---|---|
i64 | i64 |
0 | 1861004 |
0 | 1413696 |
1 | 678927 |
2 | 1056336 |
3 | 3011696 |
# Join with itself to find peers of peers and save to parquet (lazy operations, does not need to fit in memory)
(df.join(df,
left_on="neighbor_id",
right_on="id",
how="inner")
.sink_parquet("~/Downloads/neighbors_of_neighbors.parquet")
)
data = pl.scan_parquet("~/Downloads/neighbors_of_neighbors.parquet")
# Head of file
data.head().collect()
id | neighbor_id | neighbor_id_right |
---|---|---|
i64 | i64 | i64 |
34810 | 3 | 34810 |
16481 | 37 | 16481 |
20681 | 46 | 20681 |
16268 | 73 | 16268 |
32768 | 5040 | 32768 |
# Number of rows
data.select(pl.len()).collect()
len |
---|
u32 |
630 |