#!/usr/bin/env python
# coding: utf-8

# # Magic the Gathering Card Art Color Classifier
# ### *Fine tuning ResNet50 to classify MTG cards into mana colors*
# There are 5 color types in MTG. Each card can be of multiple color types.
# Each color has a mechanical identity and flavor, reflected in the art. For
# example, blue cards tend to be blue and feature wizards and spells. Black
# cards, on the other hand, tend to be grey and depict evil monsters.
#
# Our self-imposed challenge is to __train a classifier predicting a card's
# color from its art__.
#
# To simplify things, let's only consider monotype (single-color) cards,
# restricting each sample to a single class.
#
# I've used the [Scryfall API](https://scryfall.com/docs/api) to fetch all
# card art. In total we have ~26K training samples (2.3 GB). I [uploaded the
# dataset to kaggle](https://www.kaggle.com/datasets/kassouni/mtg-monotype-card-art)
# if you want to take a stab at it yourself.

# In[1]:


import os

# Set before the TensorFlow import so the log level takes effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

tf.get_logger().setLevel('ERROR')

try:
    # `get_ipython` only exists inside IPython/Jupyter; when this export is
    # run as a plain script (see the shebang) the magic is simply skipped.
    get_ipython().run_line_magic('matplotlib', 'inline')
except NameError:
    pass


# In[2]:


IMAGE_SIZE = (224, 224)  # (height, width) every image is resized to on load
BATCH_SIZE = 128
VAL_SPLIT = 0.2  # fraction of samples held out for validation
WUBRG = list("WUBRG")  # class order shared by labels, logits and plots
N_CLASSES = len(WUBRG)


# In[3]:


# Built once at import time instead of on every wubrg2name() call.
_WUBRG_NAMES = {
    "W": "white",
    "U": "blue",
    "B": "black",
    "R": "red",
    "G": "green",
}


def wubrg2name(cls: str) -> str:
    """Return the English color name for a single WUBRG letter.

    Raises:
        KeyError: if ``cls`` is not one of "W", "U", "B", "R", "G".
    """
    return _WUBRG_NAMES[cls]


# In[72]:


# Load the images. With label_mode="categorical" the labels are one-hot
# vectors whose positions follow WUBRG; the fixed seed keeps the
# train/validation split reproducible between runs.
train, test = keras.utils.image_dataset_from_directory(
    "data/art_by_color",
    label_mode="categorical",
    image_size=IMAGE_SIZE,
    shuffle=True,
    batch_size=BATCH_SIZE,
    validation_split=VAL_SPLIT,
    subset="both",
    seed=1234,
    class_names=WUBRG,
)

# Show a 3x3 grid of samples from the first training batch.
imgs, labels = next(iter(train.take(1)))

fig, axs = plt.subplots(3, 3, figsize=(12, 8))
plt.subplots_adjust(wspace=0.05, hspace=0.3)
for i, ax in enumerate(axs.flat):
    ax.imshow(imgs[i].numpy().astype("uint8"), aspect="auto")
    color = np.argwhere(labels[i])[0][0]  # position of the single 1 in the one-hot label
    ax.set_title(wubrg2name(WUBRG[color]))
    ax.axis('off')

# Looks
# like Scryfall did a good job of cropping the images. I don't see any
# borders, which would have led to
# [leakage](https://en.wikipedia.org/wiki/Leakage_(machine_learning)) in our
# training data.

# Let's start with a relatively straightforward CNN architecture:

# In[30]:


def make_basic_cnn(include_top: bool = True) -> keras.Sequential:
    """Build a small three-stage convolutional classifier.

    Args:
        include_top: if True, finish with ``Dense(N_CLASSES)`` producing
            unnormalised logits (this notebook compiles with
            ``from_logits=True``). If False, the last layer is an
            ``Identity`` so the 128-unit hidden features are returned
            instead, letting the network serve as a feature branch.

    Returns:
        An uncompiled ``keras.Sequential`` model named "basic_cnn".
    """
    return keras.Sequential([
        # Images from image_dataset_from_directory are in [0, 255]; scale to [0, 1].
        keras.layers.Rescaling(1./255, input_shape=(*IMAGE_SIZE, 3)),
        keras.layers.Conv2D(16, 3, padding='same', activation='relu'),
        keras.layers.MaxPooling2D(),
        keras.layers.Conv2D(32, 3, padding='same', activation='relu'),
        keras.layers.MaxPooling2D(),
        keras.layers.Conv2D(64, 3, padding='same', activation='relu'),
        keras.layers.MaxPooling2D(),
        keras.layers.Flatten(),
        keras.layers.Dense(
            units=128,
            activation='relu',
            kernel_regularizer=keras.regularizers.L2(l2=1e-2),
            bias_regularizer=keras.regularizers.L2(l2=1e-2),
        ),
        keras.layers.Dropout(0.5),
        # One logit per color, kept in sync with the WUBRG class list.
        keras.layers.Dense(N_CLASSES) if include_top else keras.layers.Identity(),
    ], name="basic_cnn")


basic_cnn = make_basic_cnn()
basic_cnn.compile(
    optimizer=keras.optimizers.Adam(),
    # The top layer has no softmax, so the loss consumes raw logits.
    loss=keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=["categorical_accuracy"],
)
basic_cnn.summary()


# In[31]:


# Save only the best weights seen during training (save_best_only=True).
mc = tf.keras.callbacks.ModelCheckpoint(
    "models/cnn.keras", save_best_only=True, save_weights_only=True
)
basic_cnn_history = basic_cnn.fit(train, validation_data=test, epochs=16, callbacks=[mc])


# In[32]:


# Restore the best checkpoint rather than keeping the final epoch's weights.
basic_cnn.load_weights("models/cnn.keras")


# In[41]:


def plt_model(history):
    """Plot training curves from a Keras ``History`` object.

    Draws two figures: training vs. validation loss, then training vs.
    validation categorical accuracy alongside the uniform random-guess
    baseline of ``1 / N_CLASSES``.
    """
    plt.plot(history.history['loss'], label="loss")
    plt.plot(history.history['val_loss'], label="val_loss")
    plt.xlabel("Epoch")
    plt.ylabel("CategoricalCrossentropy")
    plt.legend()
    plt.show()

    n_epochs = len(history.history['categorical_accuracy'])
    baseline = 1 / N_CLASSES
    plt.plot(history.history['categorical_accuracy'], label="accuracy")
    plt.plot(history.history['val_categorical_accuracy'], label="val_accuracy")
    # Epochs are plotted at x = 0 .. n_epochs - 1; span the same range.
    plt.plot([0, n_epochs - 1], [baseline, baseline], linestyle='dashed', label="baseline")
    plt.xlabel("Epoch")
    plt.ylabel("CategoricalAccuracy")
    plt.legend()
    plt.show()
plt_model(basic_cnn_history)


# Looks like we got a steady progression, plateauing at around 48% accuracy.
# Not terrible but not great. There is a fair amount of irreducible error in
# our dataset. *Many* cards have an ambiguous classification, even to the
# human eye.
#
# Let's plot a confusion matrix to see where the model is struggling:

# In[35]:


from sklearn.metrics import ConfusionMatrixDisplay


def plt_confusion(model, test_data):
    """Plot a confusion matrix of ``model``'s predictions on ``test_data``.

    Args:
        model: a Keras classifier whose outputs are per-class scores in
            WUBRG order (logits in this notebook).
        test_data: an iterable of ``(images, one_hot_labels)`` batches,
            e.g. a ``tf.data.Dataset``.

    Returns:
        The ``ConfusionMatrixDisplay`` created by scikit-learn.

    Raises:
        ValueError: if ``test_data`` yields no batches.
    """
    y_true, y_pred = [], []
    # Predict and read labels in the *same* pass. Iterating the dataset twice
    # (once for predict, once for labels) misaligns the two whenever its
    # order changes between iterations (e.g. tf.data shuffle() with the
    # default reshuffle_each_iteration=True), and decodes every image twice.
    for images, labels in test_data:
        scores = model.predict_on_batch(images)
        # softmax is monotonic, so argmax over raw scores picks the same class.
        y_pred.append(np.argmax(scores, axis=1))
        y_true.append(np.argmax(labels, axis=1))
    if not y_true:
        raise ValueError("test_data yielded no batches")
    return ConfusionMatrixDisplay.from_predictions(
        np.concatenate(y_true),
        np.concatenate(y_pred),
        display_labels=[wubrg2name(c) for c in WUBRG],
    )


plt_confusion(basic_cnn, test)


# Looks like our model struggles with mislabelling cards as white. This isn't
# surprising. White card art is far less thematically consistent than the
# other color types. A quick scan of the dataset yields these cards all
# labelled as white:
#
# ![Safeguard](https://cards.scryfall.io/art_crop/front/1/f/1f170fcf-5a29-4d4f-ae19-7bb1262ebcbf.jpg?1673158895)
# ![Unruly Mob](https://cards.scryfall.io/art_crop/front/2/c/2c8e174c-7abb-4a93-aa1d-8c2a2e815ba6.jpg?1562053270)
# ![Sammite Blessing](https://cards.scryfall.io/art_crop/front/3/b/3b21aa2e-df36-4101-89b6-515858f2ab88.jpg?1562596391)

# Our model correctly recognizes that the literal color of the art correlates
# with its label. The model isn't complex enough to recognize objects and
# associate them with colors (e.g. black cards are more likely to have
# snakes).
#
# Since we don't have enough data to support such a model, we can fine tune
# an existing image classifier on our dataset. I'll use ResNet50 as a base,
# freeze its layers, add an extra dense layer, and concatenate its output
# with a fresh copy of our prior CNN architecture.
# In[37]:


def freeze_model(m: keras.Model) -> keras.Model:
    """Set ``trainable = False`` on every layer in ``m.layers``.

    The model is modified in place and also returned so the call can be
    chained.
    """
    for layer in m.layers:
        layer.trainable = False
    return m


input_layer = keras.layers.Input((*IMAGE_SIZE, 3))

# Branch 1: frozen ImageNet ResNet50 features.
resnet = freeze_model(keras.applications.ResNet50(include_top=False, weights="imagenet"))
# The pretrained ResNet50 weights expect inputs passed through
# resnet50.preprocess_input (RGB->BGR plus ImageNet mean subtraction), not
# raw [0, 255] RGB pixels. Doing it inside the graph means callers (fit,
# predict, classify) keep feeding raw images.
rnx = keras.applications.resnet50.preprocess_input(input_layer)
rnx = resnet(rnx)
rnx = keras.layers.GlobalAveragePooling2D()(rnx)
rnx = keras.layers.Dense(
    units=256,
    kernel_regularizer=keras.regularizers.L2(1e-5),
    bias_regularizer=keras.regularizers.L2(1e-5),
    activation="relu",
)(rnx)
rnx = keras.layers.Dropout(0.5)(rnx)

# Branch 2: a freshly initialised basic CNN without its classification head
# (it rescales raw pixels itself), trained jointly with the ResNet head.
mtg_net = make_basic_cnn(include_top=False)
mtgx = mtg_net(input_layer)

xs = keras.layers.concatenate([rnx, mtgx])
output_layer = keras.layers.Dense(N_CLASSES)(xs)  # one logit per WUBRG color

fine_tuned_resnet = keras.Model(input_layer, output_layer)
fine_tuned_resnet.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=["categorical_accuracy"],
)
fine_tuned_resnet.summary()
tf.keras.utils.plot_model(fine_tuned_resnet)


# In[38]:


fine_tuned_resnet_history = fine_tuned_resnet.fit(
    train,
    validation_data=test,
    epochs=12,
    callbacks=[
        # Save only the best weights seen during training.
        keras.callbacks.ModelCheckpoint(
            "models/ftrn.keras", save_best_only=True, save_weights_only=True
        )
    ],
)
# Restore the best checkpoint rather than the final epoch's weights.
fine_tuned_resnet.load_weights("models/ftrn.keras")


# In[42]:


plt_model(fine_tuned_resnet_history)
plt_confusion(fine_tuned_resnet, test)


# Looks like the improvement was small but noticeable--squeezing out an extra
# 5% accuracy. There might be some hyperparameter optimizations I've left on
# the table, but without a beefier GPU I just don't have the patience to tune
# them. Number of epochs, unfreezing the final ResNet layers, strengthening
# regularizers, and experimenting with alternate base models all might yield
# a better model.
#
# We are approaching the upper bound on model accuracy. Due to the irreducible
# error present in the data, I suspect a "perfect" classifier will cap out
# around 60%. For a lot of cards, the classification is just a judgement call.
# With that in mind, our model is pretty good!
# Now for the fun part: let's try the models on some classic works of art!

# In[76]:


import os

image_dir = "data/test_images"


def classify(model, image):
    """Classify one image and return ``(color_name, confidence)``.

    Args:
        model: a Keras classifier emitting WUBRG-ordered logits.
        image: a single image array; a batch axis of size 1 is added
            before prediction.

    Returns:
        The English name of the most likely color and its softmax
        probability (a scalar tensor).
    """
    prediction = tf.nn.softmax(model.predict(np.array([image])))
    index = tf.argmax(prediction, axis=1).numpy()[0]
    return wubrg2name(WUBRG[index]), prediction[0][index]


# Sorted so the grid layout is stable across runs (os.listdir order is
# arbitrary); matched case-insensitively so ".JPG" / ".jpeg" files count too.
filenames = sorted(
    os.path.join(image_dir, file)
    for file in os.listdir(image_dir)
    if file.lower().endswith((".jpg", ".jpeg"))
)

fig, axs = plt.subplots(3, 3, figsize=(12, 8))
plt.subplots_adjust(wspace=0.05, hspace=0.5)
# Hide every panel up front so panels left unused (fewer than 9 images) stay blank.
for ax in axs.flat:
    ax.axis('off')

for filename, ax in zip(filenames, axs.flat):
    img = tf.keras.utils.img_to_array(
        tf.keras.utils.load_img(filename, target_size=IMAGE_SIZE)
    ).astype("uint8")
    ax.imshow(img)

    bcnn_class, bcnn_conf = classify(basic_cnn, img)
    ftrn_class, ftrn_conf = classify(fine_tuned_resnet, img)
    # The percent sign belongs inside the parentheses, e.g. "red (87%)".
    ax.set_title(
        f"basic_cnn: {bcnn_class} ({bcnn_conf * 100:.0f}%)\n"
        f"ft_resnet: {ftrn_class} ({ftrn_conf * 100:.0f}%)"
    )


# The discrepancies in the model predictions are very interesting. The basic
# CNN tends to classify based on the literal color of the images. The
# ResNet-based model has a more semantic interpretation of the images.
# Super cool!
#
# Overall I'm happy with the results and I'd like to do more MTG ML work on my
# blog. I'd like to train a "card2vec" model à la word2vec. Using decklists as
# a corpus, cards which synergize well together would have a high cosine
# similarity.