There are five color types in MTG, and each card can belong to more than one of them. Each color has a mechanical identity and flavor, reflected in the art. For example, blue cards tend to be literally blue and feature wizards and spells; black cards, on the other hand, tend to be grey and depict evil monsters.
Our self-imposed challenge is to train a classifier that predicts a card's color from its art.
To simplify things, let's only consider monotype cards, restricting each sample to a single class.
I used the Scryfall API to fetch all of the card art. In total we have ~26K training samples (2.3 GB). I've uploaded the dataset to Kaggle if you want to take a stab at it yourself.
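If you'd rather build the dataset yourself, the fetch looks roughly like the sketch below. It uses Scryfall's bulk-data endpoint and the colors and image_uris.art_crop fields on each card object; the len(colors) == 1 filter is what restricts us to monotype cards. (In practice you'd also want to rate-limit the downloads per Scryfall's API guidelines.)

import os
import requests

# Scryfall publishes daily bulk exports; grab the "default_cards" dump.
bulk = requests.get("https://api.scryfall.com/bulk-data").json()
cards_url = next(b["download_uri"] for b in bulk["data"] if b["type"] == "default_cards")
cards = requests.get(cards_url).json()

for card in cards:
    colors = card.get("colors")
    art_url = card.get("image_uris", {}).get("art_crop")
    # Keep only monotype cards that have an art crop.
    if not art_url or not colors or len(colors) != 1:
        continue
    out_dir = os.path.join("data/art_by_color", colors[0])
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, f"{card['id']}.jpg"), "wb") as f:
        f.write(requests.get(art_url).content)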
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
tf.get_logger().setLevel('ERROR')
%matplotlib inline
IMAGE_SIZE = (224, 224)
BATCH_SIZE = 128
VAL_SPLIT = 0.2
WUBRG = list("WUBRG")
N_CLASSES = len(WUBRG)
def wubrg2name(cls: str) -> str:
    return {
        "W": "white",
        "U": "blue",
        "B": "black",
        "R": "red",
        "G": "green",
    }[cls]
# load the images
train, test = keras.utils.image_dataset_from_directory(
    "data/art_by_color",
    label_mode="categorical",
    image_size=IMAGE_SIZE,
    shuffle=True,
    batch_size=BATCH_SIZE,
    validation_split=VAL_SPLIT,
    subset="both",
    seed=1234,
    class_names=WUBRG,
)
imgs, labels = next(iter(train.take(1)))
fig, axs = plt.subplots(3, 3, figsize=(12, 8))
plt.subplots_adjust(wspace=0.05, hspace=0.3)
for i, ax in enumerate(axs.flat):
    ax.imshow(imgs[i].numpy().astype("uint8"), aspect="auto")
    color = np.argmax(labels[i])
    ax.set_title(wubrg2name(WUBRG[color]))
    ax.axis('off')
Found 26360 files belonging to 5 classes.
Using 21088 files for training.
Using 5272 files for validation.
Looks like Scryfall did a good job of cropping the images. I don't see any borders, which would have led to leakage in our training data: card frames are color-coded, so any visible frame pixels would essentially hand the model the label.
Let's start with a relatively straightforward CNN architecture:
def make_basic_cnn(include_top: bool = True):
    return keras.Sequential([
        keras.layers.Rescaling(1./255, input_shape=(*IMAGE_SIZE, 3)),
        keras.layers.Conv2D(16, 3, padding='same', activation='relu'),
        keras.layers.MaxPooling2D(),
        keras.layers.Conv2D(32, 3, padding='same', activation='relu'),
        keras.layers.MaxPooling2D(),
        keras.layers.Conv2D(64, 3, padding='same', activation='relu'),
        keras.layers.MaxPooling2D(),
        keras.layers.Flatten(),
        keras.layers.Dense(
            units=128,
            activation='relu',
            kernel_regularizer=keras.regularizers.L2(l2=1e-2),
            bias_regularizer=keras.regularizers.L2(l2=1e-2),
        ),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(N_CLASSES) if include_top else keras.layers.Identity(),
    ], name="basic_cnn")
basic_cnn = make_basic_cnn()
basic_cnn.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=["categorical_accuracy"],
)
basic_cnn.summary()
Model: "basic_cnn" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= rescaling_7 (Rescaling) (None, 224, 224, 3) 0 conv2d_21 (Conv2D) (None, 224, 224, 16) 448 max_pooling2d_21 (MaxPooli (None, 112, 112, 16) 0 ng2D) conv2d_22 (Conv2D) (None, 112, 112, 32) 4640 max_pooling2d_22 (MaxPooli (None, 56, 56, 32) 0 ng2D) conv2d_23 (Conv2D) (None, 56, 56, 64) 18496 max_pooling2d_23 (MaxPooli (None, 28, 28, 64) 0 ng2D) flatten_7 (Flatten) (None, 50176) 0 dense_16 (Dense) (None, 128) 6422656 dropout_9 (Dropout) (None, 128) 0 dense_17 (Dense) (None, 5) 645 ================================================================= Total params: 6446885 (24.59 MB) Trainable params: 6446885 (24.59 MB) Non-trainable params: 0 (0.00 Byte) _________________________________________________________________
mc = tf.keras.callbacks.ModelCheckpoint("models/cnn.keras", save_best_only=True, save_weights_only=True)
basic_cnn_history = basic_cnn.fit(train, validation_data=test, epochs=16, callbacks=[mc])
Epoch 1/16
330/330 [==============================] - 24s 68ms/step - loss: 1.6867 - categorical_accuracy: 0.3515 - val_loss: 1.4909 - val_categorical_accuracy: 0.4346
Epoch 2/16
330/330 [==============================] - 23s 67ms/step - loss: 1.5116 - categorical_accuracy: 0.4183 - val_loss: 1.4684 - val_categorical_accuracy: 0.4558
Epoch 3/16
330/330 [==============================] - 23s 68ms/step - loss: 1.4910 - categorical_accuracy: 0.4320 - val_loss: 1.4323 - val_categorical_accuracy: 0.4693
Epoch 4/16
330/330 [==============================] - 23s 67ms/step - loss: 1.4786 - categorical_accuracy: 0.4421 - val_loss: 1.4644 - val_categorical_accuracy: 0.4461
Epoch 5/16
330/330 [==============================] - 23s 69ms/step - loss: 1.4714 - categorical_accuracy: 0.4407 - val_loss: 1.4449 - val_categorical_accuracy: 0.4623
Epoch 6/16
330/330 [==============================] - 23s 69ms/step - loss: 1.4669 - categorical_accuracy: 0.4454 - val_loss: 1.4276 - val_categorical_accuracy: 0.4651
Epoch 7/16
330/330 [==============================] - 23s 69ms/step - loss: 1.4589 - categorical_accuracy: 0.4457 - val_loss: 1.4289 - val_categorical_accuracy: 0.4678
Epoch 8/16
330/330 [==============================] - 22s 67ms/step - loss: 1.4683 - categorical_accuracy: 0.4487 - val_loss: 1.4486 - val_categorical_accuracy: 0.4685
Epoch 9/16
330/330 [==============================] - 22s 67ms/step - loss: 1.4625 - categorical_accuracy: 0.4527 - val_loss: 1.4299 - val_categorical_accuracy: 0.4820
Epoch 10/16
330/330 [==============================] - 23s 67ms/step - loss: 1.4661 - categorical_accuracy: 0.4559 - val_loss: 1.4360 - val_categorical_accuracy: 0.4788
Epoch 11/16
330/330 [==============================] - 22s 67ms/step - loss: 1.4654 - categorical_accuracy: 0.4579 - val_loss: 1.4607 - val_categorical_accuracy: 0.4569
Epoch 12/16
330/330 [==============================] - 23s 67ms/step - loss: 1.4592 - categorical_accuracy: 0.4613 - val_loss: 1.4076 - val_categorical_accuracy: 0.4881
Epoch 13/16
330/330 [==============================] - 22s 67ms/step - loss: 1.4683 - categorical_accuracy: 0.4645 - val_loss: 1.4412 - val_categorical_accuracy: 0.4837
Epoch 14/16
330/330 [==============================] - 22s 67ms/step - loss: 1.4616 - categorical_accuracy: 0.4656 - val_loss: 1.4545 - val_categorical_accuracy: 0.4729
Epoch 15/16
330/330 [==============================] - 23s 67ms/step - loss: 1.4640 - categorical_accuracy: 0.4668 - val_loss: 1.4519 - val_categorical_accuracy: 0.4795
Epoch 16/16
330/330 [==============================] - 22s 67ms/step - loss: 1.4630 - categorical_accuracy: 0.4699 - val_loss: 1.4691 - val_categorical_accuracy: 0.4695
basic_cnn.load_weights("models/cnn.keras")
def plt_model(history):
    plt.plot(history.history['loss'], label="loss")
    plt.plot(history.history['val_loss'], label="val_loss")
    plt.xlabel("Epoch")
    plt.ylabel("CategoricalCrossentropy")
    plt.legend()
    plt.show()
    n_epochs = len(history.history['categorical_accuracy'])
    plt.plot(history.history['categorical_accuracy'], label="accuracy")
    plt.plot(history.history['val_categorical_accuracy'], label="val_accuracy")
    # Random guessing over 5 balanced classes gives ~20% accuracy.
    plt.plot([0, n_epochs], [0.2, 0.2], linestyle='dashed', label="baseline")
    plt.xlabel("Epoch")
    plt.ylabel("CategoricalAccuracy")
    plt.legend()
    plt.show()
plt_model(basic_cnn_history)
Looks like we got a decent model, plateauing at around 48% validation accuracy. Not terrible, but not great. There is a fair amount of irreducible error in our dataset: many cards have an ambiguous classification, even to the human eye.
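As a sanity check on the dashed 20% baseline, we can confirm the classes are roughly balanced by counting the files in each class directory (a quick sketch, assuming the data/art_by_color/W ... G layout the loader above expects):

import os

# Count samples per class; a roughly uniform split justifies the 20% baseline.
for cls in WUBRG:
    n = len(os.listdir(os.path.join("data/art_by_color", cls)))
    print(f"{wubrg2name(cls):>5}: {n}")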
Let's plot a confusion matrix to see where the model is struggling:
from sklearn.metrics import ConfusionMatrixDisplay

def plt_confusion(model, test_data):
    # Note: this iterates test_data twice (once for predictions, once for
    # labels), so it relies on the validation pipeline not reshuffling
    # between iterations.
    y_test_pred_labels = tf.argmax(tf.nn.softmax(model.predict(test_data)), axis=1).numpy()
    y_test_true = np.argmax(np.vstack([y for _, y in test_data]), axis=1)
    return ConfusionMatrixDisplay.from_predictions(
        y_test_true,
        y_test_pred_labels,
        display_labels=[wubrg2name(c) for c in WUBRG],
    )
plt_confusion(basic_cnn, test)
83/83 [==============================] - 3s 38ms/step
Looks like our model struggles with mislabelling cards as white. This isn't surprising: white card art is far less thematically consistent than that of the other colors. A quick scan of the dataset turns up these cards, all labelled as white:
Our model has correctly learned that the literal color of the art correlates with its label. It isn't complex enough to recognize objects and associate them with colors (e.g. black cards are more likely to feature snakes).
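To make the "literal color" point concrete, here's a quick sketch (not part of the original analysis) that estimates the mean RGB value per class over a few training batches; if the claim holds, each class mean should drift toward its namesake hue:

# Accumulate the mean pixel color of each class over 20 batches.
sums = np.zeros((N_CLASSES, 3))
counts = np.zeros(N_CLASSES)
for imgs, labels in train.take(20):
    classes = np.argmax(labels.numpy(), axis=1)
    means = imgs.numpy().reshape(len(imgs), -1, 3).mean(axis=1)  # per-image mean RGB
    for c, m in zip(classes, means):
        sums[c] += m
        counts[c] += 1

for cls, rgb in zip(WUBRG, sums / counts[:, None]):
    print(f"{wubrg2name(cls):>5}: mean RGB = {rgb.round(1)}")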
Since we don't have enough data to support such a model, we can fine-tune an existing image classifier on our dataset. I'll use ResNet50 as a base, freeze its layers, add an extra dense layer, and concatenate its output with a fresh instance of our basic CNN (minus its classification head), trained jointly.
def freeze_model(m: keras.Model):
    for layer in m.layers:
        layer.trainable = False
    return m
input_layer = keras.layers.Input((*IMAGE_SIZE, 3))
resnet = freeze_model(keras.applications.ResNet50(include_top=False, weights="imagenet"))
rnx = resnet(input_layer)
rnx = keras.layers.GlobalAveragePooling2D()(rnx)
rnx = keras.layers.Dense(
    units=256,
    kernel_regularizer=keras.regularizers.L2(1e-5),
    bias_regularizer=keras.regularizers.L2(1e-5),
    activation="relu",
)(rnx)
rnx = keras.layers.Dropout(0.5)(rnx)
mtg_net = make_basic_cnn(include_top=False)
mtgx = mtg_net(input_layer)
xs = keras.layers.concatenate([rnx, mtgx])
output_layer = keras.layers.Dense(N_CLASSES)(xs)
fine_tuned_resnet = keras.Model(input_layer, output_layer)
fine_tuned_resnet.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=["categorical_accuracy"],
)
fine_tuned_resnet.summary()
tf.keras.utils.plot_model(fine_tuned_resnet)
Model: "model_3" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_7 (InputLayer) [(None, 224, 224, 3)] 0 [] resnet50 (Functional) (None, None, None, 2048) 2358771 ['input_7[0][0]'] 2 global_average_pooling2d_3 (None, 2048) 0 ['resnet50[0][0]'] (GlobalAveragePooling2D) dense_21 (Dense) (None, 256) 524544 ['global_average_pooling2d_3[0 ][0]'] dropout_12 (Dropout) (None, 256) 0 ['dense_21[0][0]'] basic_cnn (Sequential) (None, 128) 6446240 ['input_7[0][0]'] concatenate_3 (Concatenate (None, 384) 0 ['dropout_12[0][0]', ) 'basic_cnn[0][0]'] dense_23 (Dense) (None, 5) 1925 ['concatenate_3[0][0]'] ================================================================================================== Total params: 30560421 (116.58 MB) Trainable params: 6972709 (26.60 MB) Non-trainable params: 23587712 (89.98 MB) __________________________________________________________________________________________________
fine_tuned_resnet_history = fine_tuned_resnet.fit(
    train,
    validation_data=test,
    epochs=12,
    callbacks=[
        keras.callbacks.ModelCheckpoint("models/ftrn.keras", save_best_only=True, save_weights_only=True)
    ]
)
fine_tuned_resnet.load_weights("models/ftrn.keras")
Epoch 1/12
330/330 [==============================] - 78s 226ms/step - loss: 1.4985 - categorical_accuracy: 0.4300 - val_loss: 1.3031 - val_categorical_accuracy: 0.5061
Epoch 2/12
330/330 [==============================] - 74s 225ms/step - loss: 1.3243 - categorical_accuracy: 0.4825 - val_loss: 1.2732 - val_categorical_accuracy: 0.5070
Epoch 3/12
330/330 [==============================] - 74s 224ms/step - loss: 1.2929 - categorical_accuracy: 0.4951 - val_loss: 1.2568 - val_categorical_accuracy: 0.5220
Epoch 4/12
330/330 [==============================] - 74s 224ms/step - loss: 1.2689 - categorical_accuracy: 0.5070 - val_loss: 1.2493 - val_categorical_accuracy: 0.5176
Epoch 5/12
330/330 [==============================] - 74s 224ms/step - loss: 1.2494 - categorical_accuracy: 0.5186 - val_loss: 1.2493 - val_categorical_accuracy: 0.5220
Epoch 6/12
330/330 [==============================] - 74s 225ms/step - loss: 1.2324 - categorical_accuracy: 0.5189 - val_loss: 1.2332 - val_categorical_accuracy: 0.5285
Epoch 7/12
330/330 [==============================] - 74s 223ms/step - loss: 1.2128 - categorical_accuracy: 0.5285 - val_loss: 1.2389 - val_categorical_accuracy: 0.5239
Epoch 8/12
330/330 [==============================] - 74s 224ms/step - loss: 1.2013 - categorical_accuracy: 0.5360 - val_loss: 1.2293 - val_categorical_accuracy: 0.5328
Epoch 9/12
330/330 [==============================] - 74s 224ms/step - loss: 1.1940 - categorical_accuracy: 0.5412 - val_loss: 1.2444 - val_categorical_accuracy: 0.5190
Epoch 10/12
330/330 [==============================] - 74s 224ms/step - loss: 1.1824 - categorical_accuracy: 0.5441 - val_loss: 1.2274 - val_categorical_accuracy: 0.5338
Epoch 11/12
330/330 [==============================] - 74s 223ms/step - loss: 1.1709 - categorical_accuracy: 0.5481 - val_loss: 1.2328 - val_categorical_accuracy: 0.5351
Epoch 12/12
330/330 [==============================] - 74s 223ms/step - loss: 1.1633 - categorical_accuracy: 0.5549 - val_loss: 1.2310 - val_categorical_accuracy: 0.5281
plt_model(fine_tuned_resnet_history)
plt_confusion(fine_tuned_resnet, test)
83/83 [==============================] - 13s 156ms/step
Looks like the improvement was small but noticeable: we squeezed out an extra 5% accuracy. There might be some hyperparameter optimizations I've left on the table, but without a beefier GPU I just don't have the patience to tune them. More epochs, unfreezing the final ResNet layers, stronger regularizers, and experimenting with alternate base models might all yield a better model.
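For example, unfreezing just the final ResNet block might look like the sketch below. I haven't run this; it assumes the conv5 layer-name prefix that Keras's ResNet50 uses for its last block, keeps the batch-norm layers frozen (standard fine-tuning advice), and drops the learning rate so the pretrained weights aren't clobbered:

# Unfreeze the final ResNet block, leaving batch norm frozen.
for layer in resnet.layers:
    layer.trainable = (
        layer.name.startswith("conv5")
        and not isinstance(layer, keras.layers.BatchNormalization)
    )

# Recompile with a much lower learning rate before resuming training.
fine_tuned_resnet.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-5),
    loss=keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=["categorical_accuracy"],
)
fine_tuned_resnet.fit(train, validation_data=test, epochs=4)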
We are approaching the upper bound on model accuracy. Due to the irreducible error present in the data, I suspect a "perfect" classifier will cap out around 60%. For a lot of cards, the classification is just a judgement call. With that in mind, our model is pretty good!
Now for the fun part: let's try the models on some classic works of art!
import os
image_dir = "data/test_images"
def classify(model, image):
    prediction = tf.nn.softmax(model.predict(np.array([image])))
    index = tf.argmax(prediction, axis=1).numpy()[0]
    return wubrg2name(WUBRG[index]), prediction[0][index]
filenames = [os.path.join(image_dir, file) for file in os.listdir(image_dir) if file.endswith(".jpg")]
fig, axs = plt.subplots(3, 3, figsize=(12, 8))
plt.subplots_adjust(wspace=0.05, hspace=0.5)
for filename, ax in zip(filenames, axs.flat):
    img = tf.keras.utils.img_to_array(
        tf.keras.utils.load_img(filename, target_size=IMAGE_SIZE)
    ).astype("uint8")
    ax.imshow(img)
    bcnn_class, bcnn_conf = classify(basic_cnn, img)
    ftrn_class, ftrn_conf = classify(fine_tuned_resnet, img)
    ax.set_title(f"basic_cnn: {bcnn_class} ({bcnn_conf*100:.0f}%)\nft_resnet: {ftrn_class} ({ftrn_conf*100:.0f}%)")
    ax.axis('off')
The discrepancies between the two models' predictions are very interesting. The basic CNN tends to classify based on the literal color of the image, while the ResNet-based model has a more semantic interpretation. Super cool!
Overall I'm happy with the results, and I'd like to do more MTG ML work on my blog. Next, I'd like to train a "card2vec" model à la word2vec: using decklists as a corpus, cards which synergize well together should end up with a high cosine similarity.
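As a teaser, here's a minimal sketch of that idea with gensim, using a toy hand-written corpus where each decklist is a list of card names treated as a "sentence":

from gensim.models import Word2Vec

# Toy corpus: each decklist is an unordered list of card names.
decklists = [
    ["Lightning Bolt", "Monastery Swiftspear", "Lava Spike"],
    ["Counterspell", "Brainstorm", "Snapcaster Mage"],
    # ... thousands more decklists
]

card2vec = Word2Vec(
    sentences=decklists,
    vector_size=64,
    window=60,    # decklists are unordered, so span the whole list
    min_count=1,  # raise this with a real corpus
)
card2vec.wv.most_similar("Lightning Bolt")  # cards that co-occur with Bolt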