Searching for stars with stellar spectroscopy: pairing stellar parameters with spectra using a contrastive objective
import h5py
import warnings
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from stellarperceptron.model import StellarPerceptron
from astropy.io import fits
from astroNN.apogee import allstar
# APOGEE DR17 allStar catalog; used at the end of each query to cross-match the
# retrieved Gaia source IDs against APOGEE T_eff / log(g) / J, H photometry.
allstar_f = fits.getdata(allstar(dr=17))
# ================== hardware-related settings ==================
device = "cpu"  # "cpu" for CPU or "cuda:x" for an NVIDIA GPU
mixed_precision = False
# turn TF32 off for both cuBLAS matmuls and cuDNN so they run in full FP32
torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = False
# ================== hardware-related settings ==================
# The contrastive encoders below reuse the trained main model's encoder and
# embeddings, so that model has to be loaded first.
nn_model = StellarPerceptron.load(
    "./model_torch/",
    device=device,
    mixed_precision=mixed_precision,
)
def find_topk_matches(source_id, spec_embeddings, queries_embedding, k=10):
    """
    Look up the ``k`` database entries closest to each query in embedding space.

    Both sets of embeddings are L2-normalised along dim 1, so the dot product
    below is the cosine similarity.

    Parameters
    ----------
    source_id : sequence
        Identifier of each row of ``spec_embeddings`` (indexed by integer).
    spec_embeddings : torch.Tensor
        2-D (N, D) database embeddings.
    queries_embedding : torch.Tensor
        2-D (Q, D) query embeddings, on the same device as ``spec_embeddings``.
    k : int
        Number of matches per query. Clamped to N so a database smaller than
        ``k`` does not make ``torch.topk`` raise.

    Returns
    -------
    list of list
        For each query, ``min(k, N)`` identifiers ordered from most to least
        similar.
    """
    spec_embeddings = torch.nn.functional.normalize(spec_embeddings, p=2, dim=1)
    queries_embedding = torch.nn.functional.normalize(queries_embedding, p=2, dim=1)
    # (Q, D) @ (D, N) -> (Q, N) cosine similarities
    dot_similarity = torch.matmul(
        queries_embedding, torch.transpose(spec_embeddings, 0, 1)
    )
    # torch.topk errors if k exceeds the size of the searched dimension
    k = min(k, dot_similarity.shape[1])
    # topk returns indices sorted by descending similarity by default
    results = torch.topk(dot_similarity, k).indices.cpu().numpy()
    return [[source_id[idx] for idx in indices] for indices in results]
class SpecEncoder(nn.Module):
    """
    Trainable projection head that maps the frozen main-model representation of
    the inputs (XP spectra coefficients in this script) into the shared
    contrastive search space.

    The trained StellarPerceptron encodes the inputs via ``perceive``. Its
    flattened perception memory is passed through a linear projection followed
    by three residual blocks of GELU -> Linear -> Dropout -> (+skip) -> LayerNorm.

    Parameters
    ----------
    trained_model : StellarPerceptron
        Trained main model whose encoder and embeddings are reused (defaults to
        the module-level ``nn_model``).
    projection_dims : int
        Width of the projection layers and of the output embedding.
    context_length : int
        Number of tokens in the perception memory. Together with the model's
        ``embedding_dim`` it fixes the input width of ``dense_base``.
    dropout_rate : float
        Dropout probability used in every residual block.
    device, dtype
        Passed to every newly created layer.
    """

    def __init__(
        self,
        trained_model: StellarPerceptron = nn_model,
        projection_dims: int = 32,
        context_length: int = 64,
        dropout_rate: float = 0.1,
        device: str = "cpu",
        dtype: torch.dtype = torch.float32,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.factory_kwargs = {"device": device, "dtype": dtype}
        self.base_trained_model = trained_model
        self.trained_encoder = trained_model.torch_encoder.eval()
        # take the embedding layer from the model we were given, not the global
        # nn_model (same object when the default is used, so state_dict keys are
        # unchanged)
        self.trained_nonlinear_embedding = trained_model.embedding_layer
        self.embedding_dim = trained_model.embedding_dim
        self.dropout_rate = dropout_rate
        self.projection_dims = projection_dims
        self.context_length = context_length
        self.dense_base = torch.nn.Linear(
            self.embedding_dim * self.context_length,
            self.projection_dims,
            **self.factory_kwargs,
        )
        # a single Dropout module is shared by all three residual blocks
        self.dropout_1 = torch.nn.Dropout(self.dropout_rate)
        self.dense_1 = torch.nn.Linear(
            self.projection_dims, self.projection_dims, **self.factory_kwargs
        )
        self.layernorm_1 = torch.nn.LayerNorm(
            self.projection_dims, **self.factory_kwargs
        )
        self.dense_2 = torch.nn.Linear(
            self.projection_dims, self.projection_dims, **self.factory_kwargs
        )
        self.layernorm_2 = torch.nn.LayerNorm(
            self.projection_dims, **self.factory_kwargs
        )
        self.dense_3 = torch.nn.Linear(
            self.projection_dims, self.projection_dims, **self.factory_kwargs
        )
        self.layernorm_3 = torch.nn.LayerNorm(
            self.projection_dims, **self.factory_kwargs
        )

    def _residual(self, hidden, dense, layernorm):
        """Apply one block: layernorm(hidden + dropout(dense(gelu(hidden))))."""
        x = self.dropout_1(dense(F.gelu(hidden)))
        return layernorm(hidden + x)

    def forward(self, inputs, inputs_names):
        """
        Embed ``inputs`` (columns named by ``inputs_names``) into the search space.

        Returns a tensor whose last dimension is ``projection_dims``.
        """
        with torch.no_grad():  # non-trainable trained encoder
            with warnings.catch_warnings():
                # silence any warnings emitted by perceive()
                warnings.simplefilter("ignore")
                self.base_trained_model.perceive(
                    inputs, inputs_names, inference_mode=False
                )
            # Read the memory of the model perceive() was just called on (not the
            # global nn_model). Flattening dims 1-2 gives the
            # context_length * embedding_dim width dense_base expects (presumably
            # memory is (batch, context_length, embedding_dim) -- TODO confirm).
            # `* 1.0` yields a fresh tensor rather than the model's own buffer.
            embeddings = (
                torch.flatten(
                    self.base_trained_model._perception_memory,
                    start_dim=1,
                    end_dim=2,
                )
                * 1.0
            )
        projected_embeddings = self.dense_base(embeddings)
        for dense, layernorm in (
            (self.dense_1, self.layernorm_1),
            (self.dense_2, self.layernorm_2),
            (self.dense_3, self.layernorm_3),
        ):
            projected_embeddings = self._residual(
                projected_embeddings, dense, layernorm
            )
        return projected_embeddings

    def predict(self, inputs, inputs_names):
        """Run ``forward`` under ``torch.inference_mode`` (no autograd tracking)."""
        with torch.inference_mode():
            return self(inputs, inputs_names)
class StellarEncoder(nn.Module):
    """
    Trainable projection head that maps the frozen main-model representation of
    stellar parameters (e.g. teff / logg / m_h, or colours) into the shared
    contrastive search space.

    The trained StellarPerceptron encodes the inputs via ``perceive``. Its
    flattened perception memory is passed through a linear projection followed
    by three residual blocks of GELU -> Linear -> Dropout -> (+skip) -> LayerNorm.

    Parameters
    ----------
    trained_model : StellarPerceptron
        Trained main model whose encoder is reused (defaults to the module-level
        ``nn_model``).
    projection_dims : int
        Width of the projection layers and of the output embedding.
    context_length : int
        Number of tokens in the perception memory. Together with the model's
        ``embedding_dim`` it fixes the input width of ``dense_base``.
    dropout_rate : float
        Dropout probability used in every residual block.
    device, dtype
        Passed to every newly created layer.
    """

    def __init__(
        self,
        trained_model: StellarPerceptron = nn_model,
        projection_dims: int = 128,
        context_length: int = 64,
        dropout_rate: float = 0.1,
        device: str = "cpu",
        dtype: torch.dtype = torch.float32,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.factory_kwargs = {"device": device, "dtype": dtype}
        self.base_trained_model = trained_model
        self.trained_encoder = trained_model.torch_encoder.eval()
        self.embedding_dim = trained_model.embedding_dim
        self.dropout_rate = dropout_rate
        self.projection_dims = projection_dims
        self.context_length = context_length
        self.dense_base = torch.nn.Linear(
            self.embedding_dim * self.context_length,
            self.projection_dims,
            **self.factory_kwargs,
        )
        # a single Dropout module is shared by all three residual blocks
        self.dropout_1 = torch.nn.Dropout(self.dropout_rate)
        self.dense_1 = torch.nn.Linear(
            self.projection_dims, self.projection_dims, **self.factory_kwargs
        )
        self.layernorm_1 = torch.nn.LayerNorm(
            self.projection_dims, **self.factory_kwargs
        )
        self.dense_2 = torch.nn.Linear(
            self.projection_dims, self.projection_dims, **self.factory_kwargs
        )
        self.layernorm_2 = torch.nn.LayerNorm(
            self.projection_dims, **self.factory_kwargs
        )
        self.dense_3 = torch.nn.Linear(
            self.projection_dims, self.projection_dims, **self.factory_kwargs
        )
        self.layernorm_3 = torch.nn.LayerNorm(
            self.projection_dims, **self.factory_kwargs
        )

    def _residual(self, hidden, dense, layernorm):
        """Apply one block: layernorm(hidden + dropout(dense(gelu(hidden))))."""
        x = self.dropout_1(dense(F.gelu(hidden)))
        return layernorm(hidden + x)

    def forward(self, inputs, inputs_names):
        """
        Embed ``inputs`` (columns named by ``inputs_names``) into the search space.

        Returns a tensor whose last dimension is ``projection_dims``.
        """
        with torch.no_grad():  # non-trainable trained encoder
            with warnings.catch_warnings():
                # silence any warnings emitted by perceive()
                warnings.simplefilter("ignore")
                self.base_trained_model.perceive(
                    inputs, inputs_names, inference_mode=False
                )
            # Read the memory of the model perceive() was just called on (not the
            # global nn_model). Flattening dims 1-2 gives the
            # context_length * embedding_dim width dense_base expects (presumably
            # memory is (batch, context_length, embedding_dim) -- TODO confirm).
            # `* 1.0` yields a fresh tensor rather than the model's own buffer.
            embeddings = (
                torch.flatten(
                    self.base_trained_model._perception_memory,
                    start_dim=1,
                    end_dim=2,
                )
                * 1.0
            )
        projected_embeddings = self.dense_base(embeddings)
        for dense, layernorm in (
            (self.dense_1, self.layernorm_1),
            (self.dense_2, self.layernorm_2),
            (self.dense_3, self.layernorm_3),
        ):
            projected_embeddings = self._residual(
                projected_embeddings, dense, layernorm
            )
        return projected_embeddings

    def predict(self, inputs, inputs_names):
        """Run ``forward`` under ``torch.inference_mode`` (no autograd tracking)."""
        with torch.inference_mode():
            return self(inputs, inputs_names)
# Projection heads for the spectra branch and the stellar-parameter branch; both
# sit on top of the frozen trained main model `nn_model`.
spec_nn = SpecEncoder(device=device, projection_dims=64)
stars_nn = StellarEncoder(device=device, projection_dims=64)
# load the trained model; map_location lets a checkpoint saved on another device
# (e.g. a GPU) be loaded onto the configured `device`
modelsearch = torch.load(
    "./model_torch_search/model_torch_search.pt",
    map_location=device,
    weights_only=True,
)
spec_nn.load_state_dict(
    modelsearch["specmodel_state_dict"],
    strict=True,
)
stars_nn.load_state_dict(
    modelsearch["starmodel_state_dict"],
    strict=True,
)
spec_nn.eval()
stars_nn.eval()
# load database (left open: the later query cells read it again)
stars_database = h5py.File("./model_torch_search/gaia_small_db.h5", "r")
# calculate embeddings from XP spectra only, using only the first 32 bp and rp coefficients
# NOTE(review): the dataset is named "rp32bp32" while inputs_names lists bp first --
# confirm the column order matches.
inputs_names = [*[f"bp{i}" for i in range(32)], *[f"rp{i}" for i in range(32)]]
xp_coeffs = stars_database["rp32bp32"]
num_stars = len(stars_database["source_id"])
spec_embedding = torch.zeros((num_stars, spec_nn.projection_dims))
batch_size = 1024
# Slice the HDF5 dataset per batch so only that batch is read from disk (the
# whole dataset used to be re-read every iteration). Stepping with
# range(0, n, batch_size) covers the final partial batch and also works when
# n < batch_size or n is an exact multiple of batch_size.
for start in range(0, num_stars, batch_size):
    stop = min(start + batch_size, num_stars)
    spec_embedding[start:stop] = spec_nn.predict(xp_coeffs[start:stop], inputs_names)

# change the parameters here to set up a query star, similar to the usage of perceive() in our main model
q_embedding = stars_nn.predict([[4700.0, 2.5, 0.0]], ["teff", "logg", "m_h"])
# find the stars with source id in the database that are most similar to the query star
source_id = find_topk_matches(
    stars_database["source_id"][()], spec_embedding, q_embedding.cpu(), k=10
)
# cross-match APOGEE allstar catalog to find the T_eff and log(g) of the most similar stars
_, query_idx, allstar_idx = np.intersect1d(
    np.array(source_id[0], dtype=np.int64),
    allstar_f["GAIAEDR3_SOURCE_ID"],
    return_indices=True,
)
# intersect1d returns matches sorted by source ID; reorder them back into the
# similarity ranking of source_id[0] (IDs without an APOGEE match are skipped)
allstar_idx = allstar_idx[np.argsort(query_idx)]
print("Most similar star in database (Gaia DR3 Source ID):\n", source_id[0])
print("Their T_eff:\n", allstar_f["TEFF"][allstar_idx])
print("Their log(g):\n", allstar_f["LOGG"][allstar_idx])
Most similar star in database (Gaia DR3 Source ID): [704126323012532864, 2608184212654581888, 4761510940323057792, 1632894240354190720, 2266653189281248384, 1422406139514522368, 2133086543966514432, 3933250136788680064, 2105722860647339776, 1223271389585049728] Their T_eff: [4701.993 4706.58 4705.3115 4589.256 4802.4175 4737.822 4703.627 4613.678 4365.014 4674.269 ] Their log(g): [2.4414918 2.4398525 2.3873532 2.4079142 2.50701 2.3648117 2.3617556 2.6147337 3.3534677 2.690805 ]
# change the parameters here to set up a query star, similar to the usage of perceive() in our main model
# use predict() (inference mode, no autograd graph) like the other queries
q_embedding = stars_nn.predict([[3900.0, 4.65, 0.0]], ["teff", "logg", "m_h"])
# find the stars with source id in the database that are most similar to the query star
source_id = find_topk_matches(
    stars_database["source_id"][()], spec_embedding, q_embedding.cpu(), k=10
)
# cross-match APOGEE allstar catalog to find the T_eff and log(g) of the most similar stars
_, query_idx, allstar_idx = np.intersect1d(
    np.array(source_id[0], dtype=np.int64),
    allstar_f["GAIAEDR3_SOURCE_ID"],
    return_indices=True,
)
# intersect1d returns matches sorted by source ID; reorder them back into the
# similarity ranking of source_id[0] (IDs without an APOGEE match are skipped)
allstar_idx = allstar_idx[np.argsort(query_idx)]
print("Most similar star in database (Gaia DR3 Source ID):\n", source_id[0])
print("Their T_eff:\n", allstar_f["TEFF"][allstar_idx])
print("Their log(g):\n", allstar_f["LOGG"][allstar_idx])
Most similar star in database (Gaia DR3 Source ID): [1634950739415310848, 3251625804573815552, 3336303180759424640, 2635281985958371584, 1577014963485955968, 2828370624526504064, 41383388583719168, 2255693154297886976, 1273850642449726464, 3337904997402262016] Their T_eff: [3819.9282 3651.9966 4044.0454 3769.9504 3624.199 3859.134 3741.5781 3765.8489 4074.1426 3713.9307] Their log(g): [4.658435 4.6930223 4.6133075 4.652622 4.6738906 4.7005982 4.6687317 4.6484175 3.5922148 4.6665177]
# change the parameters here to set up a query star, similar to the usage of perceive() in our main model
q_embedding = stars_nn.predict([[0.6]], ["jh"])
# find the stars with source id in the database that are most similar to the query star
source_id = find_topk_matches(
    stars_database["source_id"][()], spec_embedding, q_embedding.cpu(), k=10
)
# cross-match APOGEE allstar catalog to find the J-H colour of the most similar stars
_, query_idx, allstar_idx = np.intersect1d(
    np.array(source_id[0], dtype=np.int64),
    allstar_f["GAIAEDR3_SOURCE_ID"],
    return_indices=True,
)
# intersect1d returns matches sorted by source ID; reorder them back into the
# similarity ranking of source_id[0] (IDs without an APOGEE match are skipped)
allstar_idx = allstar_idx[np.argsort(query_idx)]
print("Most similar star in database (Gaia DR3 Source ID):\n", source_id[0])
print("Their J-H:\n", allstar_f["J"][allstar_idx] - allstar_f["H"][allstar_idx])
Most similar star in database (Gaia DR3 Source ID): [2653942966024313472, 4474962404650333184, 1031840951989590272, 438295266462409216, 1248344858902045568, 1876562612822977664, 4522365924339728512, 3379761206049876352, 1393180845569802880, 65161976802151680] Their J-H: [0.625 0.6040001 0.54100037 0.7119999 0.599 0.5550003 0.5959997 0.5770006 0.5830002 0.60200024]