#!/usr/bin/env python
# coding: utf-8

# # Searching for machine learning models using semantic search
# 
# > Finding models on the Hugging Face hub using semantic search
# 
# - toc: true
# - badges: true
# - comments: true
# - categories: [huggingface, huggingface-datasets, semantic-search]
# - search_exclude: false
# - image: https://github.com/davanstrien/blog/blob/master/images/hub_model_search.webp?raw=true

# The [Hugging Face model hub](https://huggingface.co/models) has (at the time of writing) 60,509 models publicly available. Some of these models are useful as base models for further fine-tuning; these include your classics like `bert-base-uncased`.
# 
# The hub also has more obscure indie hits that might already do a good job on your desired downstream task, or at least offer a closer starting point. For example, if you wanted to classify the genre of 18th Century books, it might make sense to start with [a model for classifying 19th Century books](https://huggingface.co/BritishLibraryLabs/bl-books-genre).
# 
# ## Finding candidate models
# 
# Ideally, we'd like a quick way to identify whether a model might already do something close to what we want. From there, we would likely want to review a bunch of other information about the model before deciding whether it might be helpful for us or not.
# 
# Unfortunately, finding suitable models on the hub isn't always that easy. Even though we know models for genre classification exist on the hub, searching for them doesn't return any results.
# 
# ![](../images/hub_model_search.webp)
# 
# It's not documented exactly how the search on the hub works, but it seems to be based mainly on the model's name rather than the README or other information. In this blog post, I will continue some [previous experiments with embeddings](https://danielvanstrien.xyz/metadata/deployment/huggingface/ethics/huggingface-datasets/faiss/2022/01/13/image_search.html) to see if there might be different ways in which we could identify potential models.
# 
# This will be a very rough experiment and is more about establishing whether this is an avenue worth exploring rather than a fully fleshed-out approach.

# First install some libraries we'll use:

# In[1]:

import torch

# In[2]:

deps = ["datasets", "sentence-transformers", "rich[jupyter]", "requests"]
if torch.cuda.is_available():
    deps.append("faiss-gpu")
else:
    deps.append("faiss-cpu")

# In[3]:

get_ipython().run_cell_magic('capture', '', '!pip install {" ".join(deps)} --upgrade\n')

# In[4]:

get_ipython().system('git config --global credential.helper store')

# These days I almost always have the rich extension loaded!

# In[5]:

get_ipython().run_line_magic('load_ext', 'rich')

# ## Using the huggingface_hub API to download some model metadata
# 
# Our goal is to see if we can find suitable models more efficiently using some form of semantic search (i.e. using embeddings). To do this, we need some model data from the hub. The easiest way to get it is via the hub API.

# In[6]:

from huggingface_hub import hf_api
import re
from rich import print

# In[7]:

api = hf_api.HfApi()

# In[8]:

api

# We can take a look at some example models.

# In[9]:

all_models = api.list_models()
all_models[:3]

# For a particular model we can also see which files it contains.

# In[10]:

files = api.list_repo_files(all_models[0].modelId)

# In[11]:

files
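# Before moving on, here is a quick look at the `ModelInfo` fields we'll rely on later. This is a hedged sketch: the exact attributes depend on the `huggingface_hub` version, and `siblings` may be empty depending on how the listing was fetched.

# In[ ]:

example_model = all_models[0]
print(example_model.modelId)       # the repo id we'll use to download files
print(example_model.pipeline_tag)  # task metadata (may be None)
print([s.rfilename for s in (example_model.siblings or [])][:5])  # files in the repo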
# ### Filtering
# 
# To limit the scope of this blog post, we'll focus only on PyTorch models and 'text classification' models. The metadata about which library a model uses is usually pretty reliable. The model task metadata, on the other hand, is not always reliable in my experience. This means our dataset probably includes some models that aren't text-classification models and misses some actual text classification models. For now, we won't worry too much about this.

# In[12]:

from huggingface_hub import ModelSearchArguments

# In[13]:

model_args = ModelSearchArguments()

# In[14]:

from huggingface_hub import ModelFilter

model_filter = ModelFilter(
    task=model_args.pipeline_tag.TextClassification,
    library=model_args.library.PyTorch,
)
api.list_models(filter=model_filter)[0]

# Now that we have a filter, we'll use it to grab all the models that match.

# In[15]:

all_models = api.list_models(filter=model_filter)

# In[16]:

all_models[0]

# Let's see how many models that gives us.

# In[17]:

len(all_models)

# Later on in this blog, we'll want to work with the `config.json` files (we'll get back to why later!), so we'll quickly check that all our models have one.

# In[18]:

def has_config(model):
    """Return True if the model repo contains a config.json file."""
    for file in model.siblings:
        if "config.json" in file.rfilename:
            return True
    return False

# In[19]:

has_config(all_models[0])

# In[20]:

has_config = [model for model in all_models if has_config(model)]

# Let's check how many we have now.

# In[21]:

len(has_config)

# We can also download a particular file from the hub.

# In[22]:

from huggingface_hub import hf_hub_download

file = hf_hub_download(repo_id=all_models[0].modelId, filename="config.json")

# In[23]:

file

# In[24]:

import json

with open(file) as f:
    data = json.load(f)

# In[25]:

data

# We can also check if the model has a `README.md`.

# In[26]:

def has_file_in_repo(model, file_name):
    """Return True if the model repo contains a file whose name matches file_name."""
    for file in model.siblings:
        if file_name in file.rfilename:
            return True
    return False

# In[27]:

has_file_in_repo(has_config[0], 'README.md')

# In[28]:

has_readme = [model for model in has_config if has_file_in_repo(model, "README.md")]

# We can see that there are more configs than READMEs.

# In[29]:

len(has_readme)

# In[30]:

len(has_config)

# We now write some functions to grab both the `README.md` and `config.json` files from the hub, requesting them directly via their URLs.

# In[ ]:

import requests
from functools import lru_cache
from huggingface_hub import hf_hub_url
from requests.exceptions import JSONDecodeError
import concurrent.futures

# In[ ]:

@lru_cache(maxsize=None)
def get_model_labels(model):
    """Return (modelId, labels) using the label2id mapping from the model's config.json."""
    try:
        url = hf_hub_url(repo_id=model.modelId, filename="config.json")
        return model.modelId, list(requests.get(url).json()['label2id'].keys())
    except (KeyError, JSONDecodeError, AttributeError):
        return model.modelId, None

# In[ ]:

get_model_labels(has_config[0])

# In[ ]:

def get_model_readme(model):
    """Return the raw text of the model's README.md."""
    url = hf_hub_url(repo_id=model.modelId, filename="README.md")
    return requests.get(url).text

# In[ ]:

def get_data(model):
    """Return (modelId, labels, readme) for a model."""
    readme = get_model_readme(model)
    _, labels = get_model_labels(model)
    return model.modelId, labels, readme

# Since this takes a little while, we add a progress bar and make the requests using multiple threads.

# In[ ]:

from tqdm.auto import tqdm

# In[ ]:

with tqdm(total=len(has_config)) as progress:
    with concurrent.futures.ThreadPoolExecutor() as e:
        tasks = []
        for model in has_config:
            future = e.submit(get_data, model)
            future.add_done_callback(lambda p: progress.update())
            tasks.append(future)
        results = [task.result() for task in tasks]
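# Each entry in `results` is a `(modelId, labels, readme)` tuple. A quick sanity check before loading everything into a DataFrame (hedged: just illustrative, the exact first entry depends on the listing order):

# In[ ]:

model_id, labels, readme = results[0]
print(model_id, labels)
print(readme[:200])  # first 200 characters of the README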
# Now we load our data into Pandas.

# In[ ]:

import pandas as pd

# In[ ]:

df = pd.DataFrame(results, columns=['modelId', 'label', 'readme'])

# In[ ]:

df

# You can see we now have a DataFrame containing the modelId, the model labels and the `README.md` for each model (where it exists).
# 
# Since the `README.md` (the model card) is the obvious source of information about a model, we'll start here. One question we may have is how long the `README.md`s are. Some models have very detailed model cards whilst others have very little information in them. We can get a sense of this by looking at the range of `README.md` lengths:

# In[ ]:

df['readme'].apply(len).describe()

# We might want to filter on the length of the README, so we'll store that info in a new column.

# In[ ]:

df['readme_len'] = df['readme'].apply(len)

# Since we might want to work with this data again, let's load it into a `datasets` Dataset and use `push_to_hub` to store a copy.

# In[8]:

from datasets import Dataset

# In[ ]:

ds = Dataset.from_pandas(df)
ds

# In[9]:

from huggingface_hub import notebook_login

# In[10]:

notebook_login()

# In[ ]:

ds.push_to_hub('davanstrien/hf_model_metadata')

# We can now load it again using `load_dataset`.

# In[11]:

from datasets import load_dataset

# In[12]:

ds = load_dataset('davanstrien/hf_model_metadata', split='train')

# Clean up some memory...

# In[ ]:

del df

# ## Semantic search of model cards
# 
# We now get to the main point of all of this: can we use semantic search to find models of interest? For this, we'll use the sentence-transformers library. This blog won't cover all the background of this library; the [docs](https://www.sbert.net/index.html) give a helpful overview and some tutorials.
# 
# To start, we'll see if we can search using the information in the `README.md`. This should, in theory, contain text that is similar to the kinds of things we want to search for when finding candidate models. We might prefer semantic search over exact matching because the terms we use might differ from those in the model card, or because there is a related concept/model that might be close enough to make it worthwhile for fine-tuning.
# 
# First, we import the `SentenceTransformer` class and some util functions.

# In[13]:

from sentence_transformers import SentenceTransformer, util

# We'll now download an embedding model. There are many we could choose from, but since we're just trying things out at the moment we won't stress about the particular model we use here.

# In[14]:

model = SentenceTransformer('all-MiniLM-L6-v2')

# Let's restrict ourselves to longer READMEs; by 'longer' I just mean not super short...

# In[15]:

ds_longer_readmes = ds.filter(lambda x: x['readme_len'] > 100)

# We now create embeddings for the `readme` column and store them in a new `embedding` column.

# In[16]:

def encode_readme(readme):
    return model.encode(readme, device='cuda')

# In[17]:

ds_with_embeddings = ds_longer_readmes.map(
    lambda example: {"embedding": encode_readme(example['readme'])},
    batched=True,
    batch_size=16,
)

# In[18]:

ds_with_embeddings

# We can now use `add_faiss_index` to create an index which allows us to query these embeddings efficiently.

# In[19]:

ds_with_embeddings.add_faiss_index(column='embedding')
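# For intuition, the Faiss index gives us fast nearest-neighbour search over these vectors (by default ranked by L2 distance). On a dataset this size we could get a similar ranking by brute force. The cell below is a hedged sketch of that, using cosine similarity, and isn't used in the rest of the post.

# In[ ]:

import numpy as np

emb = np.array(ds_with_embeddings["embedding"], dtype="float32")
emb_norm = emb / np.linalg.norm(emb, axis=1, keepdims=True)  # unit-length vectors
model_ids = ds_with_embeddings["modelId"]

def brute_force_search(query: str, k: int = 5):
    """Rank all READMEs by cosine similarity to the query."""
    q = model.encode(query)
    q = q / np.linalg.norm(q)
    scores = emb_norm @ q
    top = np.argsort(-scores)[:k]
    return [(model_ids[i], float(scores[i])) for i in top]

brute_force_search("genre classification", k=5)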
# ### Similar models
# 
# To start, we'll take the README for one model and see how well we do at finding similar models.

# In[31]:

query_readme = ds_with_embeddings[35]['readme']

# In[32]:

print(query_readme)

# We pass this README into the model we used to create our embeddings. This gives us a query embedding for the README.

# In[33]:

q = model.encode(query_readme)

# We can use `get_nearest_examples` to look for the most similar results to this query.

# In[34]:

scores, retrieved_examples = ds_with_embeddings.get_nearest_examples('embedding', q, k=10)

# Let's take a look at the first result.

# In[38]:

print(retrieved_examples['modelId'][0])

# In[39]:

print(retrieved_examples["readme"][0])

# and a lower-similarity result

# In[42]:

print(retrieved_examples["readme"][9])

# The results seem pretty reasonable; the first result appears to be a duplicate of the query model. The lower-ranked result is for a slightly different task using social media data.

# ### Searching
# 
# Being able to find similar model cards is a start, but we actually want to be able to search directly. Let's see how the results look if we instead search for some terms we might use to try and find suitable models.

# In[43]:

q = model.encode("fake news")

# In[44]:

scores, retrieved_examples = ds_with_embeddings.get_nearest_examples('embedding', q, k=10)

# In[45]:

print(retrieved_examples["readme"][0])

# In[46]:

print(retrieved_examples["readme"][1])

# In[47]:

print(retrieved_examples["readme"][2])

# Not a bad start. Let's try another one.

# In[58]:

q = model.encode("financial sentiment")
scores, retrieved_examples = ds_with_embeddings.get_nearest_examples('embedding', q, k=10)
print(retrieved_examples["readme"][0])

# In[59]:

print(retrieved_examples["readme"][1])

# In[61]:

print(retrieved_examples["readme"][9])

# These seem like a good starting point. However, relying on model cards alone has a few issues. Firstly, a lot of models don't include one, and the quality of those that do exist is mixed. Whether we'd want to use a model with no model card at all is an open question, but even a good model card may not capture everything we'd need for searching in the README.

# ## Can we search using model labels?
# 
# We're only working with classification models in this case. For most PyTorch models on the hub, we have a config file. This config usually contains the model's labels, for example 'positive' and 'negative'.
# 
# Maybe instead of relying only on the model card, we can search 'inside' the model. The labels will often be a helpful reflection of what we're looking for. For example, we might want a sentiment classification model that roughly puts text into positive or negative sentiment. Again, relying on exact label matches may not work well, but maybe embeddings get around this problem. Let's try it out!
# 
# Let's look at an example label.

# In[62]:

ds[0]['label']

# Since we're expecting labels to match this format, let's filter out any that don't fit this structure.

# In[63]:

ds = ds.filter(lambda example: isinstance(example['label'], list))

# ### How to create embeddings for our labels?
# 
# How should we encode our labels? At the moment, we have a list of labels per model. One option would be to create an embedding for every single label, which would require us to query multiple embeddings to check for a match. Intuitively, we might prefer a single embedding for the combination of labels, because the full set of labels tells us more about the model than any single label on its own. We'll deal with the labels very crudely by joining them on `,` to create a single string, as illustrated below. I'm sure this isn't the best possible approach, but it might be a good place to start testing this idea.
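# For example (a hypothetical label set, purely to illustrate the joining step):

# In[ ]:

example_labels = ["positive", "neutral", "negative"]  # hypothetical labels, not taken from a real model
",".join(example_labels)  # -> 'positive,neutral,negative'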
# In[64]:

ds = ds.map(lambda example: {"string_label": ",".join(example['label'])})

# In[65]:

ds

# In[66]:

ds_with_embeddings = ds.map(
    lambda example: {"label_embedding": encode_readme(example['string_label'])},
    batched=True,
    batch_size=16,
)

# In[67]:

ds_with_embeddings

# ### Searching with labels
# 
# Now that we have some embeddings for the labels, let's try searching. We'll start with an existing label to see how well we can match it.

# In[68]:

ds_with_embeddings[0]['string_label']

# In[69]:

q = model.encode("negative")

# In[70]:

ds_with_embeddings.add_faiss_index(column='label_embedding')

# In[71]:

scores, retrieved_examples = ds_with_embeddings.get_nearest_examples('label_embedding', q, k=10)

# In[72]:

retrieved_examples['label'][:10]

# So far, these results look pretty good, although we haven't done anything we couldn't do with simple string matching. Let's see what happens if we use a slightly more abstract search.

# In[73]:

q = model.encode("music")

# In[74]:

scores, retrieved_examples = ds_with_embeddings.get_nearest_examples('label_embedding', q, k=10)

# In[75]:

retrieved_examples['label'][:10]

# We can see that we get back labels related to music genre, `['Dance', 'Heavy Metal', 'Hip Hop', 'Indie', 'Pop', 'Rock']`, for our first four results. After that, we get back `['business', 'entertainment', 'sports']`, which might not be too far off what we want if we searched for music.
# 
# How about another search term?

# In[76]:

q = model.encode("hateful")

# In[77]:

scores, retrieved_examples = ds_with_embeddings.get_nearest_examples('label_embedding', q, k=10)

# In[78]:

retrieved_examples['label'][:10]

# Again, we get something quite close to what we'd get with string matching, but with a bit more flexibility in how we spell/define our labels, which might help surface more possible results.
# 
# We'll try a bunch more things...

# In[79]:

def query_labels(query: str):
    """Encode a query and print the closest label sets along with their model ids."""
    q = model.encode(query)
    scores, retrieved_examples = ds_with_embeddings.get_nearest_examples('label_embedding', q, k=10)
    print(f"results for: {query}")
    print(list(zip(retrieved_examples['label'][:10], retrieved_examples['modelId'][:10])))

# In[80]:

query_labels("politics")

# In[84]:

query_labels("fiction, non_fiction")

# Let's try the set of emotions one should feel every day.

# In[87]:

query_labels("worry, disgust, anxiety, fear")

# Searching for a set of labels like this might be a better approach in general, since the query then better matches the format of the strings we indexed.

# ## Conclusion
# 
# It seems like there is some merit in exploring some of these ideas further. There are a lot of improvements that could be made:
# - how the embeddings are created
# - removing some 'noise' from the README, for example by first parsing the Markdown
# - improving how the embeddings are created for the labels
# - combining the embeddings in some way, either upfront or at query time (see the sketch below)
# - a bunch of other things...
# 
# If I find some spare time, I plan to dig into these topics a bit further. This is also a nice excuse to play with one of the new open source embedding databases that have popped up in the last couple of years.
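# As a rough illustration of the 'combining at query time' idea, the sketch below queries both a README index and a label index with the same encoded query and merges the results. It is a hedged sketch only: `readme_ds` and `label_ds` are hypothetical names for datasets carrying the `embedding` and `label_embedding` Faiss indexes built above, and the two distance scales aren't strictly comparable, so treat this as a way of surfacing candidates rather than a proper ranking.

# In[ ]:

def combined_search(query, readme_ds, label_ds, k=10):
    """Query the README index and the label index, then merge candidates by best distance."""
    q = model.encode(query)
    candidates = {}
    for dataset, index_name in [(readme_ds, "embedding"), (label_ds, "label_embedding")]:
        scores, examples = dataset.get_nearest_examples(index_name, q, k=k)
        for score, model_id in zip(scores, examples["modelId"]):
            # keep the smallest (closest) distance seen for each model
            candidates[model_id] = min(float(score), candidates.get(model_id, float("inf")))
    return sorted(candidates.items(), key=lambda item: item[1])[:k]

# Hypothetical usage, assuming both indexed datasets are still in memory:
# combined_search("financial sentiment", readme_ds, label_ds)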