This is heavily modeled on the Pytorch tutorial: https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html
We use fastai libraries extensively to make dataloading and training easier
This is a list of surnames and their ethnicities
#!wget https://download.pytorch.org/tutorial/data.zip
#!unzip -o data.zip
fastai imports pandas and all sorts of other goodies
from fastai import *
from fastai.text import *
from unidecode import unidecode
import string
Reduce the display to 20 rows to prevent tables from taking up too much of the output.
pd.options.display.max_rows = 20
Read in the data; the names for each language are in a separate file
path = Path('data/names')
!ls {path}
Arabic.txt English.txt Irish.txt Polish.txt Spanish.txt Chinese.txt French.txt Italian.txt Portuguese.txt Vietnamese.txt Czech.txt German.txt Japanese.txt Russian.txt Dutch.txt Greek.txt Korean.txt Scottish.txt
!head -n5 {path}/Arabic.txt
Khoury
Nahas
Daher
Gerges
Nazari
names = []
for p in path.glob('*.txt'):
    lang = p.name[:-4]  # strip the '.txt' extension to get the language
    with open(p) as f:
        names += [(lang, l.strip()) for l in f]
df = pd.DataFrame(names, columns=['cl', 'name'])
It's always worth doing some sanity checks on your data (even supposedly clean tutorial data).
No matter how good your model is: garbage in, garbage out.
df.head()
cl | name | |
---|---|---|
0 | Korean | Ahn |
1 | Korean | Baik |
2 | Korean | Bang |
3 | Korean | Byon |
4 | Korean | Cha |
len(df)
20074
What characters outside of ASCII letters appear in the names?
foreign_chars = Counter(_ for _ in ''.join(list(df.name)) if _ not in string.ascii_letters)
foreign_chars.most_common()
[(' ', 115), ("'", 87), ('-', 25), ('ö', 24), ('é', 22), ('í', 14), ('ó', 13), ('ä', 13), ('á', 12), ('ü', 11), ('à', 10), ('ß', 9), ('ú', 7), ('ñ', 6), ('ò', 3), ('Ś', 3), ('1', 3), (',', 3), ('è', 2), ('ã', 2), ('ù', 1), ('ì', 1), ('ż', 1), ('ń', 1), ('ł', 1), ('ą', 1), ('Ż', 1), ('/', 1), (':', 1), ('Á', 1), ('\xa0', 1), ('õ', 1), ('É', 1), ('ê', 1), ('ç', 1)]
A few of these look suspicious. (Note the use of a regular expression in contains to check for each of the characters.)
suss_chars = [':', '/', '\xa0', ',', '1']
df[df.name.str.contains('|'.join(suss_chars))]
cl | name | |
---|---|---|
2494 | Czech | Maxa/B |
2590 | Czech | Rafaj1 |
2703 | Czech | Urbanek1 |
2732 | Czech | Whitmire1 |
3214 | Chinese | Lu: |
14506 | Russian | Jevolojnov, |
15347 | Russian | Lysansky, |
15366 | Russian | Lytkin, |
18052 | Russian | To The First Page |
Most of these look like legitimate names with extra junk (except 'To The First Page'). Since it's so few names it's easiest just to drop them.
df = df[~df.name.str.contains('|'.join(suss_chars))]
Single quotes and spaces are common
df[df.name.str.contains("'| ")]
cl | name | |
---|---|---|
369 | Italian | D'ambrosio |
371 | Italian | D'amore |
372 | Italian | D'angelo |
373 | Italian | D'antonio |
374 | Italian | De angelis |
375 | Italian | De campo |
376 | Italian | De felice |
377 | Italian | De filippis |
378 | Italian | De fiore |
379 | Italian | De laurentis |
... | ... | ... |
18161 | Russian | V'Yurkov |
19061 | Russian | Zasyad'Ko |
19740 | Portuguese | D'cruz |
19741 | Portuguese | D'cruze |
19743 | Portuguese | De santigo |
19858 | French | D'aramitz |
19864 | French | De la fontaine |
19869 | French | De sauveterre |
20051 | French | St martin |
20052 | French | St pierre |
150 rows × 2 columns
Since hyphens mainly join multiple last names (and are pretty rare) we won't lose heaps by dropping them.
df[df.name.str.contains('-')]
cl | name | |
---|---|---|
2982 | Chinese | Au-Yong |
3088 | Chinese | Ou-Yang |
3089 | Chinese | Ow-Yang |
10156 | Russian | Abdank-Kossovsky |
10639 | Russian | Amet-Han |
11221 | Russian | Bagai-Ool |
11757 | Russian | Bei-Bienko |
11787 | Russian | Beknazar-Yuzbashev |
11790 | Russian | Bekovich-Cherkassky |
11904 | Russian | Bestujev-Lada |
... | ... | ... |
11952 | Russian | Bim-Bad |
12209 | Russian | Chyrgal-Ool |
13071 | Russian | Galkin-Vraskoi |
13307 | Russian | Gorbunov-Posadov |
16687 | Russian | Porai-Koshits |
17222 | Russian | Shah-Nazaroff |
17430 | Russian | Shirinsky-Shikhmatov |
17748 | Russian | Tsann-Kay-Si |
17999 | Russian | Tzann-Kay-Si |
18315 | Russian | Van-Puteren |
23 rows × 2 columns
df = df[~df.name.str.contains('-')]
Let's normalise all non-ASCII characters to ASCII equivalents.
This makes our classification problem harder in practice: any name containing a ß is almost surely German, whereas "ss" could occur in many languages. But it also reduces the set of characters we need to represent our language.
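For example, a quick illustration of what unidecode does (these two names are just for demonstration):
unidecode('Groß'), unidecode('Muñoz')
# -> ('Gross', 'Munoz')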
df['ascii_name'] = df.name.apply(unidecode)
df[df.name != df.ascii_name]
cl | name | ascii_name | |
---|---|---|---|
100 | Italian | Abbà | Abba |
112 | Italian | Abelló | Abello |
160 | Italian | Airò | Airo |
195 | Italian | Alò | Alo |
238 | Italian | Azzarà | Azzara |
300 | Italian | Bovér | Bover |
445 | Italian | Giùgovaz | Giugovaz |
461 | Italian | Làconi | Laconi |
462 | Italian | Laganà | Lagana |
463 | Italian | Lagomarsìno | Lagomarsino |
... | ... | ... | ... |
19912 | French | Géroux | Geroux |
19920 | French | Guérin | Guerin |
19924 | French | Hébert | Hebert |
19949 | French | Lécuyer | Lecuyer |
19951 | French | Lefévre | Lefevre |
19955 | French | Lémieux | Lemieux |
19960 | French | Lévêque | Leveque |
19961 | French | Lévesque | Levesque |
19965 | French | Maçon | Macon |
20047 | French | Séverin | Severin |
156 rows × 3 columns
Let's check case: I expect names to be in CamelCase.
These seem to be mistakes.
df[~df.ascii_name.str.contains("^[A-Z][^A-Z]*(?:[' -][A-Z][^A-Z]*)*$")]
cl | name | ascii_name | |
---|---|---|---|
2508 | Czech | MonkoAustria | MonkoAustria |
2677 | Czech | StrakaO | StrakaO |
3266 | Vietnamese | an | an |
df = df[df.ascii_name.str.contains("^[A-Z][^A-Z]*(?:[' -][A-Z][^A-Z]*)*$")]
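For reference, here's how that pattern breaks down (as annotated comments):
# ^[A-Z]                   starts with a capital letter
# [^A-Z]*                  followed by any number of non-capitals
# (?:[' -][A-Z][^A-Z]*)*   optionally more capitalised words joined by ', space or -
# $                        and nothing else to the end of the string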
Let's lowercase the ascii_names
df['ascii_name'] = df.ascii_name.str.lower()
Check that we've normalised correctly.
ascii_chars = Counter(''.join(list(df.ascii_name)))
ascii_chars.most_common()
[('a', 16511), ('o', 11120), ('e', 10768), ('i', 10416), ('n', 9943), ('r', 8245), ('s', 7980), ('h', 7673), ('k', 6902), ('l', 6704), ('v', 6301), ('t', 5939), ('u', 4725), ('m', 4343), ('d', 3894), ('b', 3641), ('y', 3604), ('g', 3209), ('c', 3068), ('z', 1928), ('f', 1774), ('p', 1707), ('j', 1346), ('w', 1125), (' ', 112), ('q', 98), ("'", 87), ('x', 72)]
In practice a surname could have multiple ethnicities, but we'd have to be really careful of how we use this in training.
If we end up with e.g. 'Michel' as French in the training dataset, but German in the validation set, our model has no hope of getting it right (and we may discard an actually good model).
We could handle this by:
1. treating it as a multi-label problem and predicting every ethnicity a name belongs to,
2. keeping only the most frequent ethnicity for each name, or
3. dropping the ambiguous names entirely.
Without any information about frequency we can't do (2), and (1) is a harder problem, so we'll stick to (3).
name_classes = (df
    .groupby('ascii_name')
    .nunique()
    .cl
    .sort_values(ascending=False))
name_classes.head(20)
ascii_name
michel     6
adam       5
albert     5
abel       5
martin     5
simon      5
ventura    4
costa      4
jordan     4
han        4
salomon    4
samuel     4
klein      4
franco     4
wang       4
oliver     4
garcia     3
horn       3
lim        3
rose       3
Name: cl, dtype: int64
df[df.name == 'Michel']
cl | name | ascii_name | |
---|---|---|---|
872 | Polish | Michel | michel |
2077 | Dutch | Michel | michel |
3489 | Spanish | Michel | michel |
6163 | German | Michel | michel |
8709 | English | Michel | michel |
19978 | French | Michel | michel |
About 1 in 40 of our names has multiple classes (and most of these were already ambiguous before normalisation).
len(name_classes), sum(name_classes > 1) / len(name_classes)
(17380, 0.027445339470655927)
Some names like Abel do seem to occur commonly in multiple countries.
It also seems like Korean and Chinese have a lot of overlap, as do English and Scottish. While this makes some linguistic sense, it will make it hard to build a reliable classifier.
Note that most names only occur once; so we can't pick a "most common" frequency class.
with pd.option_context('display.max_rows', 60):
print(df[df.ascii_name.isin(name_classes[name_classes > 1].index)].groupby(['ascii_name', 'cl']).count())
                       name
ascii_name cl
abel       English        1
           French         1
           German         1
           Russian        1
           Spanish        1
abello     Italian        1
           Spanish        1
abraham    English        1
           French         1
abreu      Portuguese     1
           Spanish        1
adam       English        1
           French         1
           German         1
           Irish          1
           Russian        1
adams      English        1
           Russian        1
adamson    English        1
           Russian        1
adler      English        1
           German         1
           Russian        1
aitken     English        1
           Scottish       1
albert     English        1
           French         1
           German         1
           Russian        1
           Spanish        1
...                      ...
wilson     English        1
           Scottish       1
winter     English        1
           German         1
wolf       English        1
           German         1
wong       Chinese        1
           English        1
woo        Chinese        1
           Korean         1
wood       Czech          1
           English        1
           Scottish       1
wright     English        1
           Scottish       1
yan        Chinese        2
           Russian        1
yang       Chinese        1
           English        1
           Korean         1
yim        Chinese        1
           Korean         1
you        Chinese        1
           Korean         1
young      English        1
           Scottish       1
yun        Chinese        1
           Korean         1
zambrano   Italian        1
           Spanish        1

[1051 rows x 1 columns]
Rather than finding the "right" ethnicity the easy thing to do is to remove all ambiguous cases.
df = df[~df.ascii_name.isin(name_classes[name_classes > 1].index)]
We need exactly one row per (ascii_name, cl) pair; if separate copies appear in the training and validation sets we'll get a higher validation accuracy than is reasonable.
Some names occur very frequently.
counts = df.assign(n=1).groupby(['ascii_name', 'cl']).count().sort_values('n', ascending=False)
counts.head(n=20)
name | n | ||
---|---|---|---|
ascii_name | cl | ||
tahan | Arabic | 28 | 28 |
fakhoury | Arabic | 28 | 28 |
koury | Arabic | 27 | 27 |
nader | Arabic | 27 | 27 |
sarraf | Arabic | 26 | 26 |
hadad | Arabic | 26 | 26 |
kassis | Arabic | 26 | 26 |
antar | Arabic | 26 | 26 |
shadid | Arabic | 25 | 25 |
cham | Arabic | 25 | 25 |
mifsud | Arabic | 25 | 25 |
nahas | Arabic | 24 | 24 |
gerges | Arabic | 24 | 24 |
ganim | Arabic | 23 | 23 |
tuma | Arabic | 23 | 23 |
to the first page | Russian | 23 | 23 |
atiyeh | Arabic | 23 | 23 |
malouf | Arabic | 23 | 23 |
sayegh | Arabic | 22 | 22 |
naifeh | Arabic | 22 | 22 |
Let's remove the "To The First Page" junk (probably some artifact of where the data was scraped from)
df = df[df.ascii_name != 'to the first page']
There are no multiples in English, and a lot in Arabic. This looks like a data entry artifact rather than anything meaningful.
counts.assign(multiple=counts.n > 1, rows=1).groupby('cl').sum().sort_values('n', ascending=False)
name | n | multiple | rows | |
---|---|---|---|---|
cl | ||||
Russian | 9326 | 9326 | 35.0 | 9263 |
English | 3359 | 3359 | 0.0 | 3359 |
Arabic | 1892 | 1892 | 103.0 | 103 |
Japanese | 983 | 983 | 1.0 | 982 |
Italian | 665 | 665 | 5.0 | 660 |
German | 613 | 613 | 33.0 | 578 |
Czech | 480 | 480 | 16.0 | 464 |
Dutch | 255 | 255 | 10.0 | 244 |
Chinese | 219 | 219 | 19.0 | 200 |
Spanish | 214 | 214 | 2.0 | 212 |
French | 213 | 213 | 3.0 | 210 |
Greek | 195 | 195 | 2.0 | 192 |
Irish | 170 | 170 | 6.0 | 164 |
Polish | 124 | 124 | 1.0 | 123 |
Korean | 61 | 61 | 0.0 | 61 |
Vietnamese | 56 | 56 | 1.0 | 55 |
Portuguese | 32 | 32 | 0.0 | 32 |
Scottish | 1 | 1 | 0.0 | 1 |
It makes sense to drop the duplicates and only have a single row per ascii_name and cl.
df = df.drop_duplicates(['ascii_name', 'cl'])
len(df)
16902
It's worth checking if the shortest and longest names make sense.
They look reasonable.
df.assign(len=df.name.str.len()).sort_values('len')
cl | name | ascii_name | len | |
---|---|---|---|---|
3265 | Vietnamese | An | an | 2 |
50 | Korean | Oh | oh | 2 |
1150 | Japanese | Ii | ii | 2 |
54 | Korean | Ra | ra | 2 |
3891 | Arabic | Ba | ba | 2 |
57 | Korean | Ri | ri | 2 |
69 | Korean | Si | si | 2 |
71 | Korean | So | so | 2 |
3311 | Vietnamese | To | to | 2 |
85 | Korean | Yi | yi | 2 |
... | ... | ... | ... | ... |
11475 | Russian | Bakhtchivandzhi | bakhtchivandzhi | 15 |
10191 | Russian | Abdulladzhanoff | abdulladzhanoff | 15 |
17299 | Russian | Shakhnazaryants | shakhnazaryants | 15 |
11393 | Russian | Baistryutchenko | baistryutchenko | 15 |
14965 | Russian | Katzenellenbogen | katzenellenbogen | 16 |
2228 | Dutch | Vandroogenbroeck | vandroogenbroeck | 16 |
14947 | Russian | Katsenellenbogen | katsenellenbogen | 16 |
19552 | Greek | Chrysanthopoulos | chrysanthopoulos | 16 |
2841 | Irish | Maceachthighearna | maceachthighearna | 17 |
6380 | German | Von grimmelshausen | von grimmelshausen | 18 |
16902 rows × 4 columns
The dataset is very unbalanced.
I doubt there's enough data to tackle Portuguese (which will be close to Spanish) or Scottish (which will be close to English)
df.groupby('cl').name.count().sort_values(ascending=False)
cl
Russian       9262
English       3359
Japanese       982
Italian        660
German         578
Czech          464
Dutch          244
Spanish        212
French         210
Chinese        200
Greek          192
Irish          164
Polish         123
Arabic         103
Korean          61
Vietnamese      55
Portuguese      32
Scottish         1
Name: name, dtype: int64
df[df.cl.isin(['Scottish'])]
cl | name | ascii_name | |
---|---|---|---|
3711 | Scottish | Hay | hay |
Let's remove the rarest classes; we're not likely to have enough data to guess them.
df = df[~df.cl.isin(['Scottish', 'Portuguese'])]
Note that Russian contains variant transliterations to English like Abaimoff and Abaimov (which both correspond to Абаимов).
But this doesn't quite explain its high frequency: it seems a lot more Russian data was simply collected.
(Side note: Chebyshev can also be spelt e.g. Chebychev, Tchebycheff, Tschebyschef)
df[df.cl == 'Russian']
cl | name | ascii_name | |
---|---|---|---|
10112 | Russian | Ababko | ababko |
10113 | Russian | Abaev | abaev |
10114 | Russian | Abagyan | abagyan |
10115 | Russian | Abaidulin | abaidulin |
10116 | Russian | Abaidullin | abaidullin |
10117 | Russian | Abaimoff | abaimoff |
10118 | Russian | Abaimov | abaimov |
10119 | Russian | Abakeliya | abakeliya |
10120 | Russian | Abakovsky | abakovsky |
10121 | Russian | Abakshin | abakshin |
... | ... | ... | ... |
19510 | Russian | Zolotavin | zolotavin |
19511 | Russian | Zolotdinov | zolotdinov |
19512 | Russian | Zolotenkov | zolotenkov |
19513 | Russian | Zolotilin | zolotilin |
19514 | Russian | Zolotkov | zolotkov |
19515 | Russian | Zolotnitsky | zolotnitsky |
19516 | Russian | Zolotnitzky | zolotnitzky |
19517 | Russian | Zozrov | zozrov |
19518 | Russian | Zozulya | zozulya |
19519 | Russian | Zukerman | zukerman |
9262 rows × 3 columns
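We can spot some of these transliteration pairs directly. A quick throwaway check (the names russian and pairs are made up for this snippet): names ending in 'off' whose 'ov' spelling also appears.
russian = set(df[df.cl == 'Russian'].ascii_name)
pairs = [(n, n[:-3] + 'ov') for n in sorted(russian)
         if n.endswith('off') and n[:-3] + 'ov' in russian]
pairs[:3]  # e.g. includes ('abaimoff', 'abaimov')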
We want our final model to work well on any language.
But if we pick our validation set uniformly at random from the data we're likely to get many Russian names and not many Vietnamese names, which isn't a good test of this.
So instead we'll take our validation set from an equal number from each subclass.
df = df.reset_index(drop=True)
df
cl | name | ascii_name | |
---|---|---|---|
0 | Korean | Ahn | ahn |
1 | Korean | Baik | baik |
2 | Korean | Bang | bang |
3 | Korean | Byon | byon |
4 | Korean | Cha | cha |
5 | Korean | Cho | cho |
6 | Korean | Choe | choe |
7 | Korean | Choi | choi |
8 | Korean | Chun | chun |
9 | Korean | Chweh | chweh |
... | ... | ... | ... |
16859 | French | Travere | travere |
16860 | French | Traverse | traverse |
16861 | French | Travert | travert |
16862 | French | Tremblay | tremblay |
16863 | French | Tremble | tremble |
16864 | French | Victors | victors |
16865 | French | Villeneuve | villeneuve |
16866 | French | Vipond | vipond |
16867 | French | Voclain | voclain |
16868 | French | Yount | yount |
16869 rows × 3 columns
counts = df.groupby('cl').name.count().sort_values(ascending=False)
counts
cl
Russian       9262
English       3359
Japanese       982
Italian        660
German         578
Czech          464
Dutch          244
Spanish        212
French         210
Chinese        200
Greek          192
Irish          164
Polish         123
Arabic         103
Korean          61
Vietnamese      55
Name: name, dtype: int64
valid_size = 30 # We'll pick 30 at random from each subclass
train_size = 500 # For a balanced training set we'll pick 500 at random with replacement
np.random.seed(6011)
valid_idx = []
for cl in counts.keys():
    # Random sample of size "valid_size" from each class
    valid_idx += list(df[df.cl == cl].sample(valid_size).index)
df['valid'] = False
df.loc[valid_idx, 'valid'] = True
Let's also create a balanced training set as an alternative to using everything not in validation
np.random.seed(7012)
balanced_idx = []
for cl in counts.keys():
    # Random sample of size "train_size" from each class, drawn (with replacement)
    # from the data outside the validation set
    balanced_idx += list(df[(df.cl == cl) & ~df.valid].sample(train_size, replace=True).index)
Note the balanced index contains all 25 (= 55 − 30) Vietnamese names outside of the validation set, but only 486 of the Russian names (since we sampled randomly with replacement there will be a couple of double-ups).
df.loc[balanced_idx].groupby('cl').nunique().sort_values('ascii_name', ascending=False)
cl | name | ascii_name | valid | |
---|---|---|---|---|
cl | ||||
Russian | 1 | 486 | 486 | 1 |
English | 1 | 459 | 459 | 1 |
Japanese | 1 | 383 | 383 | 1 |
Italian | 1 | 357 | 357 | 1 |
German | 1 | 330 | 330 | 1 |
Czech | 1 | 295 | 295 | 1 |
Dutch | 1 | 195 | 195 | 1 |
French | 1 | 172 | 172 | 1 |
Spanish | 1 | 170 | 170 | 1 |
Chinese | 1 | 158 | 158 | 1 |
Greek | 1 | 153 | 153 | 1 |
Irish | 1 | 129 | 129 | 1 |
Polish | 1 | 93 | 93 | 1 |
Arabic | 1 | 73 | 73 | 1 |
Korean | 1 | 31 | 31 | 1 |
Vietnamese | 1 | 25 | 25 | 1 |
Let's record our balanced set in the dataframe: this will make it easy to reload at a later point.
df['bal'] = 0
for k, v in Counter(balanced_idx).items():
    df.loc[k, 'bal'] += v
df.head()
cl | name | ascii_name | valid | bal | |
---|---|---|---|---|---|
0 | Korean | Ahn | ahn | False | 13 |
1 | Korean | Baik | baik | True | 0 |
2 | Korean | Bang | bang | False | 13 |
3 | Korean | Byon | byon | False | 15 |
4 | Korean | Cha | cha | True | 0 |
We can always retrieve the indexes from the dataframe
idx = []
for k, v in zip(df.index, df.bal):
    idx += [k]*v
sorted(balanced_idx) == idx
True
df.to_csv('names_clean.csv', index=False)
The first benchmark is random guessing/always guessing the same class.
The expected accuracy is 1/(number of classes) = 1/16 ≈ 6.25%.
df = pd.read_csv('names_clean.csv')
valid_idx = df[df.valid].index
train_idx = df[~df.valid].index
bal_idx = []
for k, v in zip(df.index, df.bal):
    bal_idx += [k]*v
Check that neither the training data nor the balanced training data contains any names from the validation set.
train_intersect_valid = sum(df.iloc[train_idx].ascii_name.isin(df.iloc[valid_idx].ascii_name))
bal_intersect_valid = sum(df.iloc[bal_idx].ascii_name.isin(df.iloc[valid_idx].ascii_name))
train_intersect_valid, bal_intersect_valid
(0, 0)
Make sure the data looks right
df.iloc[train_idx].groupby('cl').nunique().sort_values('ascii_name', ascending=False)
cl | name | ascii_name | valid | bal | |
---|---|---|---|---|---|
cl | |||||
Russian | 1 | 9232 | 9232 | 1 | 3 |
English | 1 | 3329 | 3329 | 1 | 4 |
Japanese | 1 | 952 | 952 | 1 | 5 |
Italian | 1 | 630 | 630 | 1 | 5 |
German | 1 | 548 | 548 | 1 | 5 |
Czech | 1 | 434 | 434 | 1 | 6 |
Dutch | 1 | 214 | 214 | 1 | 9 |
Spanish | 1 | 182 | 182 | 1 | 10 |
French | 1 | 180 | 180 | 1 | 9 |
Chinese | 1 | 170 | 170 | 1 | 9 |
Greek | 1 | 162 | 162 | 1 | 10 |
Irish | 1 | 134 | 134 | 1 | 10 |
Polish | 1 | 93 | 93 | 1 | 11 |
Arabic | 1 | 73 | 73 | 1 | 13 |
Korean | 1 | 31 | 31 | 1 | 13 |
Vietnamese | 1 | 25 | 25 | 1 | 16 |
df.iloc[bal_idx].groupby('cl').nunique().sort_values('ascii_name', ascending=False)
cl | name | ascii_name | valid | bal | |
---|---|---|---|---|---|
cl | |||||
Russian | 1 | 486 | 486 | 1 | 2 |
English | 1 | 459 | 459 | 1 | 3 |
Japanese | 1 | 383 | 383 | 1 | 4 |
Italian | 1 | 357 | 357 | 1 | 4 |
German | 1 | 330 | 330 | 1 | 4 |
Czech | 1 | 295 | 295 | 1 | 5 |
Dutch | 1 | 195 | 195 | 1 | 8 |
French | 1 | 172 | 172 | 1 | 8 |
Spanish | 1 | 170 | 170 | 1 | 9 |
Chinese | 1 | 158 | 158 | 1 | 8 |
Greek | 1 | 153 | 153 | 1 | 9 |
Irish | 1 | 129 | 129 | 1 | 9 |
Polish | 1 | 93 | 93 | 1 | 11 |
Arabic | 1 | 73 | 73 | 1 | 13 |
Korean | 1 | 31 | 31 | 1 | 13 |
Vietnamese | 1 | 25 | 25 | 1 | 16 |
df.iloc[valid_idx].groupby('cl').nunique().sort_values('ascii_name', ascending=False)
cl | name | ascii_name | valid | bal | |
---|---|---|---|---|---|
cl | |||||
Arabic | 1 | 30 | 30 | 1 | 1 |
Chinese | 1 | 30 | 30 | 1 | 1 |
Czech | 1 | 30 | 30 | 1 | 1 |
Dutch | 1 | 30 | 30 | 1 | 1 |
English | 1 | 30 | 30 | 1 | 1 |
French | 1 | 30 | 30 | 1 | 1 |
German | 1 | 30 | 30 | 1 | 1 |
Greek | 1 | 30 | 30 | 1 | 1 |
Irish | 1 | 30 | 30 | 1 | 1 |
Italian | 1 | 30 | 30 | 1 | 1 |
Japanese | 1 | 30 | 30 | 1 | 1 |
Korean | 1 | 30 | 30 | 1 | 1 |
Polish | 1 | 30 | 30 | 1 | 1 |
Russian | 1 | 30 | 30 | 1 | 1 |
Spanish | 1 | 30 | 30 | 1 | 1 |
Vietnamese | 1 | 30 | 30 | 1 | 1 |
Picking any one class in validation will give 1/16 = 6.25%
(df[df.valid] == 'Korean').cl.sum() / df.valid.sum()
0.0625
A reasonable way to guess a language is by the frequency of characters and pairs of characters.
For example 'cz' is very rare in English, but quite common in the Slavic languages.
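We can sanity-check that against our data; a quick look at the share of names containing 'cz' in each class (a throwaway check, not part of the modelling):
(df
 .assign(has_cz=df.ascii_name.str.contains('cz'))
 .groupby('cl')
 .has_cz.mean()
 .sort_values(ascending=False)
 .head())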
name = 'zozrov'
A function to count the occurrences of sequences of one, two or three letters (in general these sequences are called "n-grams", particularly when referring to sequences of words).
def ngrams(s, n=1):
    parts = [s[i:] for i in range(n)]  # e.g. for n=3: ['zozrov', 'ozrov', 'zrov']
    return Counter(''.join(_) for _ in zip(*parts))
ngrams(name, 1), ngrams(name, 2), ngrams(name, 3)
(Counter({'z': 2, 'o': 2, 'r': 1, 'v': 1}), Counter({'zo': 1, 'oz': 1, 'zr': 1, 'ro': 1, 'ov': 1}), Counter({'zoz': 1, 'ozr': 1, 'zro': 1, 'rov': 1}))
df = df.assign(letters=df.ascii_name.apply(ngrams))
df = df.assign(bigrams=df.ascii_name.apply(ngrams, n=2))
df = df.assign(trigrams=df.ascii_name.apply(ngrams, n=3))
df.head()
cl | name | ascii_name | valid | bal | letters | bigrams | trigrams | |
---|---|---|---|---|---|---|---|---|
0 | Korean | Ahn | ahn | False | 13 | {'a': 1, 'h': 1, 'n': 1} | {'ah': 1, 'hn': 1} | {'ahn': 1} |
1 | Korean | Baik | baik | True | 0 | {'b': 1, 'a': 1, 'i': 1, 'k': 1} | {'ba': 1, 'ai': 1, 'ik': 1} | {'bai': 1, 'aik': 1} |
2 | Korean | Bang | bang | False | 13 | {'b': 1, 'a': 1, 'n': 1, 'g': 1} | {'ba': 1, 'an': 1, 'ng': 1} | {'ban': 1, 'ang': 1} |
3 | Korean | Byon | byon | False | 15 | {'b': 1, 'y': 1, 'o': 1, 'n': 1} | {'by': 1, 'yo': 1, 'on': 1} | {'byo': 1, 'yon': 1} |
4 | Korean | Cha | cha | True | 0 | {'c': 1, 'h': 1, 'a': 1} | {'ch': 1, 'ha': 1} | {'cha': 1} |
Let's try to guess the language using Naive Bayes.
TL;DR: this is a really simple model that works quite well and will give us a good benchmark.
It applies Bayes' rule to answer questions from the data like: "given that the name contains the bigram 'ah', what's the probability it's Korean?".
The "Naive" part means we assume all these probabilities are independent (knowing a name contains 'ah' doesn't tell you anything about whether it contains 'hn'). Even though this definitely isn't true, it's often a reasonable approximation.
This makes the model really fast and simple to fit, and it often works well.
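To make that concrete: for count features, the per-class score factorises into a prior plus a sum of per-feature log probabilities. A minimal sketch of the scoring (nb_log_scores is a made-up name, not sklearn's API):
def nb_log_scores(counts, class_log_prior, feature_log_prob):
    # log P(class | x) is proportional to
    # log P(class) + sum_i x_i * log P(feature_i | class)
    # counts: (n_samples, n_features) -> scores: (n_samples, n_classes)
    return class_log_prior + counts @ feature_log_prob.T
After fitting, sklearn exposes the two learned terms as model.class_log_prior_ and model.feature_log_prob_, so this can be checked against the model's own predictions.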
from sklearn.naive_bayes import MultinomialNB
from sklearn.feature_extraction import DictVectorizer
vd1 = DictVectorizer(sparse=False)
vd2 = DictVectorizer(sparse=False)
vd3 = DictVectorizer(sparse=False)
y = df.cl
letters = vd1.fit_transform(df.letters)
bigrams = vd2.fit_transform(df.bigrams)
trigrams = vd3.fit_transform(df.trigrams)
The letters matrix contains the number of times each of the 28 letters occurs (e.g. number of spaces, number of apostrophes, number of 'a', ...).
vd1.get_feature_names()[:10]
[' ', "'", 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
letters
array([[0., 0., 1., 0., ..., 0., 0., 0., 0.], [0., 0., 1., 1., ..., 0., 0., 0., 0.], [0., 0., 1., 1., ..., 0., 0., 0., 0.], [0., 0., 0., 1., ..., 0., 0., 1., 0.], ..., [0., 0., 0., 0., ..., 0., 0., 0., 0.], [0., 0., 0., 0., ..., 0., 0., 0., 0.], [0., 0., 1., 0., ..., 0., 0., 0., 0.], [0., 0., 0., 0., ..., 0., 0., 1., 0.]])
Similarly bigrams and trigrams contain the number of times each sequence of 2 or 3 letters occurs
vd2.get_feature_names()[:5], vd2.get_feature_names()[-5:]
([' a', ' b', ' c', ' e', ' f'], ['zu', 'zv', 'zw', 'zy', 'zz'])
letters.shape, bigrams.shape, trigrams.shape, y.shape
((16869, 28), (16869, 623), (16869, 5794), (16869,))
How good a model can we get looking at individual letters (e.g. 'z' occurs much more frequently in Chinese names than in English ones)?
letter_nb = MultinomialNB()
letter_nb.fit(letters[train_idx],y[train_idx])
bal_letter_nb = MultinomialNB()
bal_letter_nb.fit(letters[bal_idx],y[bal_idx])
MultinomialNB(alpha=1.0, class_prior=None, fit_prior=True)
The balanced set does much better than random, at around 33%; the full unbalanced training set only manages around 15%.
letter_pred = letter_nb.predict(letters[valid_idx])
bal_letter_pred = bal_letter_nb.predict(letters[valid_idx])
(letter_pred == y[valid_idx]).mean(), (bal_letter_pred == y[valid_idx]).mean()
(0.14791666666666667, 0.33541666666666664)
Let's write a function to test the Naive Bayes on any dataset; fitting on the whole dataset and the balanced dataset separately.
def nb(x):
    model = MultinomialNB()
    model.fit(x[train_idx], y[train_idx])
    preds = model.predict(x[valid_idx])
    acc_train = (preds == y[valid_idx]).mean()

    model = MultinomialNB()
    model.fit(x[bal_idx], y[bal_idx])
    preds = model.predict(x[valid_idx])
    acc_bal = (preds == y[valid_idx]).mean()
    return acc_train, acc_bal
nb(letters)
(0.14791666666666667, 0.33541666666666664)
Using bigrams and a balanced training set gives much better prediction performance: 53% (up from the baseline of 6.25%).
nb(bigrams)
(0.35833333333333334, 0.5291666666666667)
Adding letters doesn't make much difference (which isn't surprising, since the bigram counts already carry most of the single-letter information).
nb(np.concatenate((letters, bigrams), axis=1))
(0.3854166666666667, 0.5166666666666667)
Trigrams alone also perform worse
nb(trigrams)
(0.33958333333333335, 0.4895833333333333)
Let's try every combination with trigrams:
nb(np.concatenate((letters, trigrams), axis=1))
(0.24375, 0.5083333333333333)
nb(np.concatenate((bigrams, trigrams), axis=1))
(0.36875, 0.5416666666666666)
nb(np.concatenate((letters, bigrams, trigrams), axis=1))
(0.32916666666666666, 0.55625)
None of them significantly outperforms the simple bigram model (which has 623 features; we could probably remove some of the uncommon ones without much loss).
Let's remove the bigrams that occur fewer than twice in the balanced training set, as they carry practically no signal (there are over 100 of them).
common_bigrams = (bigrams[bal_idx].sum(axis=0)) >= 2
common_bigrams.sum()
503
common_bigram_index = [i for i, t in enumerate(common_bigrams) if t]
bigrams_min = bigrams[:, common_bigram_index]
bigrams_min.shape
(16869, 503)
bigram_model = MultinomialNB()
bigram_model.fit(bigrams_min[bal_idx], y[bal_idx])
MultinomialNB(alpha=1.0, class_prior=None, fit_prior=True)
We get around 53% accuracy.
bigram_pred = bigram_model.predict(bigrams_min[valid_idx])
(bigram_pred == y[valid_idx]).mean()
0.5291666666666667
bigram_prob = bigram_model.predict_proba(bigrams_min[valid_idx])
bigram_prob.max(axis=1)
bigram_preds = (df
.iloc[valid_idx]
.assign(pred = bigram_pred)[['name', 'cl', 'pred']]
.assign(prob = bigram_prob.max(axis=1)))
bigram_preds.sort_values('prob', ascending=False).head(15)
name | cl | pred | prob | |
---|---|---|---|---|
16557 | Kotsiopoulos | Greek | Greek | 1.000000 |
2012 | Rooijakker | Dutch | Dutch | 1.000000 |
16470 | Akrivopoulos | Greek | Greek | 0.999999 |
826 | Warszawski | Polish | Polish | 0.999998 |
1997 | Romeijnders | Dutch | Dutch | 0.999997 |
16478 | Antonopoulos | Greek | Greek | 0.999995 |
813 | Sokolowski | Polish | Polish | 0.999994 |
839 | Zdunowski | Polish | Polish | 0.999950 |
1996 | Romeijn | Dutch | Dutch | 0.999917 |
16497 | Chrysanthopoulos | Greek | Greek | 0.999895 |
2053 | Sneijers | Dutch | Dutch | 0.999792 |
2031 | Schwarzenberg | Dutch | German | 0.999774 |
795 | Rudawski | Polish | Polish | 0.999751 |
16715 | De sauveterre | French | French | 0.999604 |
1160 | Kawagichi | Japanese | Japanese | 0.999600 |
The names it's least confident about: they typically seem to be quite short.
bigram_preds.sort_values('prob', ascending=True).head(15)
name | cl | pred | prob | |
---|---|---|---|---|
2907 | Do | Vietnamese | Irish | 0.176679 |
24 | Mo | Korean | Japanese | 0.179534 |
47 | So | Korean | Korean | 0.188088 |
45 | Si | Korean | Greek | 0.190236 |
13775 | Prigojin | Russian | Italian | 0.191639 |
41 | Seok | Korean | French | 0.197154 |
5 | Cho | Korean | German | 0.202991 |
1091 | Isobe | Japanese | English | 0.206442 |
46 | Sin | Korean | Italian | 0.218300 |
5332 | Ingram | English | Spanish | 0.220205 |
2935 | Ta | Vietnamese | Japanese | 0.226875 |
2700 | Ban | Chinese | Vietnamese | 0.228022 |
1697 | Togo | Japanese | Japanese | 0.236658 |
4 | Cha | Korean | Irish | 0.239172 |
3445 | Graner | German | Spanish | 0.240844 |
The names it's most confidently wrong about:
bigram_preds[bigram_preds.cl != bigram_preds.pred].sort_values('prob', ascending=False).head(15)
name | cl | pred | prob | |
---|---|---|---|---|
2031 | Schwarzenberg | Dutch | German | 0.999774 |
16578 | Malihoudis | Greek | Arabic | 0.992311 |
16576 | Louverdis | Greek | French | 0.990256 |
4758 | Fairbrace | English | Irish | 0.987143 |
3743 | Spellmeyer | German | English | 0.976530 |
16468 | Adamou | Greek | Arabic | 0.973496 |
3009 | De la fuente | Spanish | French | 0.969431 |
3011 | De leon | Spanish | French | 0.964697 |
3263 | Boulos | Arabic | Greek | 0.962321 |
16513 | Egonidis | Greek | Italian | 0.954264 |
2478 | Suchanka | Czech | Japanese | 0.949000 |
2515 | Weichert | Czech | German | 0.946457 |
5476 | Keene | English | Dutch | 0.944270 |
3511 | Jaeger | German | Dutch | 0.938891 |
3174 | Attia | Arabic | Italian | 0.935905 |
Our very simple system does great on Japanese and Russian, but relatively poorly on Vietnamese where our data is most sparse (but still much better than random).
(bigram_preds
 .assign(yes=bigram_preds.cl == bigram_preds.pred)
 .groupby('cl')
 .yes
 .mean()
 .sort_values(ascending=False))
cl
Japanese      0.866667
Russian       0.733333
Polish        0.666667
Irish         0.666667
Dutch         0.633333
Italian       0.600000
Greek         0.533333
German        0.500000
English       0.500000
Spanish       0.466667
French        0.466667
Arabic        0.433333
Czech         0.400000
Chinese       0.400000
Korean        0.366667
Vietnamese    0.233333
Name: yes, dtype: float64
from sklearn.metrics import confusion_matrix
bigram_pred
array(['Arabic', 'Irish', 'German', 'Dutch', ..., 'Russian', 'Polish', 'Irish', 'Italian'], dtype='<U10')
cm = confusion_matrix(y[valid_idx], bigram_pred, labels=y.unique())
cm
array([[11, 1, 0, 6, ..., 0, 0, 2, 1], [ 0, 18, 0, 1, ..., 2, 1, 2, 1], [ 0, 1, 20, 1, ..., 0, 0, 0, 0], [ 1, 0, 0, 26, ..., 1, 0, 0, 0], ..., [ 1, 2, 0, 1, ..., 15, 0, 1, 1], [ 0, 2, 1, 1, ..., 1, 22, 0, 0], [ 0, 2, 1, 1, ..., 0, 1, 16, 3], [ 0, 3, 1, 0, ..., 1, 1, 4, 14]])
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=90)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
Vietnamese is often confused for Chinese (which makes sense) and Irish (which doesn't). Korean is often confused for Japanese. Spanish is often confused for Italian.
plt.figure(figsize=(12,12))
plot_confusion_matrix(cm, y.unique())
Confusion matrix, without normalization
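Since our validation classes are balanced the raw counts are comparable; for unbalanced data we could instead plot per-class rates using the normalize option:
plot_confusion_matrix(cm, y.unique(), normalize=True)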
bigram_preds[bigram_preds.cl == 'Vietnamese'].sort_values('prob').head(20)
name | cl | pred | prob | |
---|---|---|---|---|
2907 | Do | Vietnamese | Irish | 0.176679 |
2935 | Ta | Vietnamese | Japanese | 0.226875 |
2924 | Luc | Vietnamese | Vietnamese | 0.253900 |
2944 | Ton | Vietnamese | Vietnamese | 0.282859 |
2930 | Pho | Vietnamese | Dutch | 0.296166 |
2910 | Ly | Vietnamese | Russian | 0.298872 |
2915 | Doan | Vietnamese | Chinese | 0.307318 |
2916 | Dam | Vietnamese | Arabic | 0.325199 |
2906 | Dang | Vietnamese | Chinese | 0.388393 |
2900 | Pham | Vietnamese | Arabic | 0.396346 |
2928 | Nghiem | Vietnamese | English | 0.428147 |
2932 | Quach | Vietnamese | Vietnamese | 0.450314 |
2922 | Lac | Vietnamese | Irish | 0.450851 |
2938 | Thi | Vietnamese | Chinese | 0.455845 |
2902 | Hoang | Vietnamese | Korean | 0.487704 |
2927 | Mach | Vietnamese | Irish | 0.505350 |
2917 | Dao | Vietnamese | Irish | 0.512267 |
2923 | Lieu | Vietnamese | French | 0.515336 |
2939 | Than | Vietnamese | Chinese | 0.530805 |
2937 | Thai | Vietnamese | Irish | 0.580569 |
So our baseline is 53%. Let's see if we can do better with deep learning.
Load in the dataframe and extract indexes for the training, validation and balanced training sets.
df = pd.read_csv('names_clean.csv')
valid_idx = df[df.valid].index
train_idx = df[~df.valid].index
bal_idx = []
for k, v in zip(df.index, df.bal):
    bal_idx += [k]*v
As of December 2018 fastai only has word-level tokenizers; we'll have to create our own letter tokenizer.
The fastai library injects BOS markers (xxbos) at the start of every string; we'll have to parse them separately.
class LetterTokenizer(BaseTokenizer):
    "Character level tokenizer function."
    def __init__(self, lang): pass

    def tokenizer(self, t:str) -> List[str]:
        out = []
        i = 0
        while i < len(t):
            # Emit the BOS marker as a single token rather than letter by letter
            if t[i:].startswith(BOS):
                out.append(BOS)
                i += len(BOS)
            else:
                out.append(t[i])
                i += 1
        return out

    def add_special_cases(self, toks:Collection[str]): pass
We create a vocab of all ASCII letters, and a character tokenizer that doesn't do any specific processing.
itos = [UNK, BOS] + list(string.ascii_lowercase + " -'")
vocab=Vocab(itos)
tokenizer=Tokenizer(LetterTokenizer, pre_rules=[], post_rules=[])
We can create a data pipeline using the TextClasDataBunch.from_df constructor.
mark_fields puts an extra xxfld marker between each field of text; since we only have 1 field this is unnecessary.
train_df = df.iloc[train_idx, [0,2]]
valid_df = df.iloc[valid_idx, [0,2]]
train_df.head()
cl | ascii_name | |
---|---|---|
0 | Korean | ahn |
2 | Korean | bang |
3 | Korean | byon |
10 | Korean | gil |
11 | Korean | gu |
data = TextClasDataBunch.from_df(path='.', train_df=train_df, valid_df=valid_df,
tokenizer=tokenizer, vocab=vocab,
mark_fields=False)
data.show_batch()
text | target |
---|---|
v o n g r i m m e l s h a u s e n | German |
m a c e a c h t h i g h e a r n a | Irish |
c h k h a r t i s h v i l i | Russian |
t z e h m i s t r e n k o | Russian |
c h e p t y g m a s h e v | Russian |
Or we can create it using the data block API.
This uses the processors to tokenize and numericalize the input.
processors = [TokenizeProcessor(tokenizer=tokenizer, mark_fields=False),
NumericalizeProcessor(vocab=vocab)]
data = (TextList
.from_df(df,
cols=[2],
processor=processors)
.split_by_idxs(train_idx=train_idx, valid_idx=valid_idx)
.label_from_df(cols=0)
.databunch(bs=32))
data.show_batch()
text | target |
---|---|
v o n g r i m m e l s h a u s e n | German |
p a r a s k e v o p o u l o s | Greek |
d z h a v a h i s h v i l i | Russian |
s h a h n a z a r y a n t s | Russian |
m o g i l n i c h e n k o | Russian |
Counter(_.obj for _ in data.valid_ds.y)
Counter({'Korean': 30, 'Italian': 30, 'Polish': 30, 'Japanese': 30, 'Dutch': 30, 'Czech': 30, 'Irish': 30, 'Chinese': 30, 'Vietnamese': 30, 'Spanish': 30, 'Arabic': 30, 'German': 30, 'English': 30, 'Russian': 30, 'Greek': 30, 'French': 30})
Counter(_.obj for _ in data.train_ds.y).most_common()
[('Russian', 9232), ('English', 3329), ('Japanese', 952), ('Italian', 630), ('German', 548), ('Czech', 434), ('Dutch', 214), ('Spanish', 182), ('French', 180), ('Chinese', 170), ('Greek', 162), ('Irish', 134), ('Polish', 93), ('Arabic', 73), ('Korean', 31), ('Vietnamese', 25)]
Check that no text appears in both the validation and training sets
valid_set = set(_.text for _ in data.valid_ds.x)
for _ in data.train_ds.x:
assert _.text not in valid_set, _.text
trainiter = iter(data.train_dl)
batch, cl = next(trainiter)
batch2, cl2 = next(trainiter)
cl, len(cl)
(tensor([ 6, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13], device='cuda:0'), 16)
batch.shape
torch.Size([20, 16])
The letters run down the batch (20 rows for this batch), padded at the front with BOS; we have 16 names across.
Somehow it looks like we also have an extra space at the beginning of each name that wasn't in the input data.
(Note this is different to what the fastai wrappers will give you; they concatenate the data and split it into 16 chunks.)
pd.options.display.max_columns = 100
(pd
.DataFrame([[vocab.itos[y] for y in x] for x in batch])
.T
.assign(category=[data.classes[_] for _ in cl])
.T)
0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos |
1 | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | |
2 | v | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos |
3 | o | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos |
4 | n | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos |
5 | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | |||
6 | g | t | l | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | |||||||
7 | r | c | e | t | m | b | c | z | b | p | ||||||
8 | i | h | i | c | i | a | h | h | a | a | g | s | a | b | v | v |
9 | m | a | h | h | n | k | a | e | h | t | r | h | w | a | y | i |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
11 | e | t | e | h | a | h | t | o | h | i | s | n | o | h | c | c |
12 | l | o | n | l | z | t | o | k | i | o | h | d | r | t | h | h |
13 | s | r | b | a | e | a | r | h | v | r | e | e | k | i | e | e |
14 | h | i | e | k | t | n | i | o | a | k | l | r | h | g | s | p |
15 | a | z | r | o | d | o | z | v | n | o | e | o | a | a | l | o |
16 | u | h | g | v | i | w | h | t | d | v | v | v | n | r | a | l |
17 | s | s | s | s | n | s | s | s | z | s | s | i | o | e | v | s |
18 | e | k | k | k | o | k | k | e | h | k | k | c | f | e | o | k |
19 | n | y | y | y | v | i | y | v | i | y | y | h | f | v | v | y |
category | German | Russian | Russian | Russian | Russian | Russian | Russian | Russian | Russian | Russian | Russian | Russian | Russian | Russian | Russian | Russian |
21 rows × 16 columns
[vocab.itos[_] for _ in data.train_ds[0][0].data]
['xxbos', ' ', 'a', 'h', 'n']
list(df.iloc[0,1])
['A', 'h', 'n']
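The extra space comes from fastai joining the xxbos marker and the name with a space before tokenizing, which our character tokenizer then faithfully emits. A quick check (assuming the processor hands the tokenizer the string 'xxbos ahn'):
LetterTokenizer('en').tokenizer('xxbos ahn')
# -> ['xxbos', ' ', 'a', 'h', 'n']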
Note the length of strings varies between batches.
(pd
.DataFrame([[vocab.itos[y] for y in x] for x in batch2])
.T
.assign(category=[data.classes[_] for _ in cl2])
.T)
0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos | xxbos |
1 | ||||||||||||||||
2 | b | m | r | f | m | b | d | j | m | b | k | u | m | m | a | b |
3 | e | c | i | e | a | a | e | e | o | a | i | f | a | e | n | a |
4 | n | g | d | r | s | j | m | f | r | l | n | i | k | a | s | b |
5 | i | r | g | r | a | e | a | f | a | b | c | m | i | d | e | u |
6 | t | o | w | a | o | n | k | e | n | o | h | k | o | h | l | r |
7 | e | r | a | r | k | o | i | r | d | n | i | i | k | r | m | i |
8 | z | y | y | o | a | v | s | s | i | i | n | n | a | a | i | n |
category | Spanish | English | English | Italian | Japanese | Russian | Greek | English | Italian | Italian | English | Russian | Japanese | Irish | Italian | Russian |
vocab.textify(batch2[:,0])
'xxbos b e n i t e z'
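Comparing the shapes of the two batches confirms this; each batch is padded to the length of its longest name:
batch.shape, batch2.shape
# -> (torch.Size([20, 16]), torch.Size([9, 16])) for the two batches above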
data.show_batch(ds_type=DatasetType.Valid)
text | target |
---|---|
c h r y s a n t h o p o u l o s | Greek |
v o n i n g e r s l e b e n | German |
s c h w a r z e n b e r g | Dutch |
d e s a u v e t e r r e | French |
a r e c h a v a l e t a | Spanish |
The torch nn.RNN expects the data to be one-hot encoded (a float vector per letter, rather than an integer index)
one_hot = torch.eye(len(vocab.itos))
one_hot[batch][:2]
tensor([[[0., 1., 0.,  ..., 0., 0., 0.],
         [0., 1., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 1., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 1., 0., 0.],
         [0., 1., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 1., 0.,  ..., 0., 0., 0.]]])
one_hot[batch].shape
torch.Size([20, 16, 31])
Here's how we could do it without storing the one_hot matrix in memory.
def one_hot_fly(y, length=len(vocab.itos)):
    shape = list(y.shape)
    assert len(shape) == 2
    tensor = torch.zeros(shape + [length])
    for i, row in enumerate(y):
        for j, val in enumerate(row):
            tensor[i][j][val] = 1.
    return tensor
(one_hot[batch] == one_hot_fly(batch)).all()
tensor(1, dtype=torch.uint8)
Using matrix operations is ~250 times faster at this size than the double for loop.
%timeit one_hot[batch]
%timeit one_hot_fly(batch)
36.1 µs ± 2.61 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each) 8.91 ms ± 210 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
n_letters = len(vocab.itos)
n_hidden = 128
n_output = df.cl.nunique()
n_letters, n_output
(31, 16)
We use an RNN to take in our sequence of letters and calculate a hidden state at each step.
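For reference, with nonlinearity='relu' each step of nn.RNN computes h_t = relu(W_ih x_t + b_ih + W_hh h_(t-1) + b_hh), where x_t is the one-hot encoded letter at step t and h_t is the hidden state carried forward.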
rnn = nn.RNN(input_size=n_letters,
hidden_size=n_hidden,
num_layers=1,
nonlinearity='relu',
dropout=0.)
output, hidden = rnn(one_hot[batch])
output.shape, hidden.shape
(torch.Size([20, 16, 128]), torch.Size([1, 16, 128]))
lo = nn.Linear(n_hidden, n_output)
preds = lo(output)
preds.shape
torch.Size([20, 16, 16])
cl
tensor([ 6, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13], device='cuda:0')
nn.functional.softmax(preds[-1], dim=1).argmax(dim=1)
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8])
one_hot = torch.eye(len(vocab.itos))
class MyLetterRNN(nn.Module):
    def __init__(self, dropout=0., n_layers=1, n_input=n_letters, n_hidden=n_hidden, n_output=n_output):
        super().__init__()
        self.one_hot = torch.eye(n_input).cuda()
        self.rnn = nn.RNN(input_size=n_input,
                          hidden_size=n_hidden,
                          num_layers=n_layers,
                          nonlinearity='relu',
                          dropout=dropout)
        self.lo = nn.Linear(n_hidden, n_output)

    def forward(self, input):
        # One-hot encode, run the RNN over the whole sequence,
        # then classify from the output at the final timestep
        rnn_out, _ = self.rnn(self.one_hot[input])
        out = self.lo(rnn_out)
        return out[-1]
rnn = MyLetterRNN().cuda()
out = rnn(batch)
out.argmax(dim=1), cl
(tensor([6, 6, 6, 6, 1, 6, 6, 1, 6, 6, 6, 1, 6, 1, 1, 6], device='cuda:0'), tensor([ 6, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13], device='cuda:0'))
Fit the model
F.cross_entropy(out, cl)
tensor(2.7832, device='cuda:0', grad_fn=<NllLossBackward>)
learn = Learner(data, rnn, loss_func=F.cross_entropy, metrics=[accuracy])
learn.lr_find()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
learn.recorder.plot()
learn.fit_one_cycle(10, max_lr=3e-2)
epoch | train_loss | valid_loss | accuracy |
---|---|---|---|
1 | 0.811252 | 2.653636 | 0.260417 |
2 | 0.928508 | 3.329767 | 0.216667 |
3 | 0.830531 | 3.436638 | 0.218750 |
4 | 0.947136 | 3.056552 | 0.202083 |
5 | 0.878935 | 3.361734 | 0.210417 |
6 | 0.818984 | 3.208372 | 0.214583 |
7 | 0.811538 | 2.896590 | 0.252083 |
8 | 0.745542 | 3.237130 | 0.283333 |
9 | 0.753505 | 2.819807 | 0.302083 |
10 | 0.763112 | 2.878011 | 0.297917 |
learn.lr_find()
learn.recorder.plot()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
learn.save('char_rnn_1')
learn.fit_one_cycle(5, 3e-3)
epoch | train_loss | valid_loss | accuracy |
---|---|---|---|
1 | 0.696038 | 2.910773 | 0.304167 |
2 | 0.734545 | 2.814250 | 0.306250 |
3 | 0.634951 | 2.827829 | 0.295833 |
4 | 0.636780 | 2.758662 | 0.312500 |
5 | 0.696148 | 2.838843 | 0.312500 |
learn.save('char_rnn_1_final')
This is abysmal; 31% is much worse than the 53% from the simple Naive Bayes bigram model.
Does it improve if we add another layer?
learn = Learner(data, MyLetterRNN(n_layers=2), loss_func=F.cross_entropy, metrics=[accuracy])
learn.lr_find()
learn.recorder.plot()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
learn.fit_one_cycle(20, max_lr=1e-2)
epoch | train_loss | valid_loss | accuracy |
---|---|---|---|
1 | 0.929248 | 3.101529 | 0.189583 |
2 | 0.695901 | 2.869615 | 0.250000 |
3 | 0.745567 | 2.520683 | 0.316667 |
4 | 0.620927 | 3.530135 | 0.262500 |
5 | 0.742575 | 2.512531 | 0.318750 |
6 | 0.723677 | 2.616584 | 0.343750 |
7 | 0.839355 | 2.454891 | 0.335417 |
8 | 0.817186 | 2.794391 | 0.291667 |
9 | 0.653000 | 2.695168 | 0.302083 |
10 | 0.683367 | 2.637764 | 0.358333 |
11 | 0.611877 | 2.308675 | 0.333333 |
12 | 0.586979 | 2.296229 | 0.352083 |
13 | 0.611386 | 2.224956 | 0.381250 |
14 | 0.580687 | 2.247524 | 0.383333 |
15 | 0.512482 | 2.244857 | 0.387500 |
16 | 0.516693 | 2.303736 | 0.412500 |
17 | 0.409016 | 2.413911 | 0.412500 |
18 | 0.435291 | 2.442951 | 0.422917 |
19 | 0.386392 | 2.507006 | 0.425000 |
20 | 0.352908 | 2.518786 | 0.433333 |
It looks like the fit has converged, again at a much worse result than our Naive Bayes bigrams.
But that was trained using a balanced dataset; maybe that will help with RNNs too.
learn.recorder.plot_losses()
learn.save('char_rnn_2_p0')
prob, targ = learn.get_preds()
Counter(data.classes[_.item()] for _ in prob.argmax(dim=1)).most_common()
[('English', 141), ('Russian', 97), ('Chinese', 54), ('Italian', 43), ('Japanese', 38), ('Greek', 25), ('German', 22), ('Czech', 12), ('French', 11), ('Dutch', 11), ('Spanish', 10), ('Polish', 10), ('Korean', 4), ('Vietnamese', 2)]
Even though the balanced set is a subset of the training set (and throws away a lot of data), the model performs much better on the balanced validation set with it.
This is because on the whole training set heuristics like "when in doubt, guess Russian/English" and "it's almost never Vietnamese" are good, but are terrible on our validation set.
data = (TextList
.from_df(df,
cols=[2],
processor=processors)
.split_by_idxs(train_idx=bal_idx, valid_idx=valid_idx)
.label_from_df(cols=0)
.databunch(bs=1024))
Counter(_.obj for _ in data.valid_ds.y)
Counter({'Korean': 30, 'Italian': 30, 'Polish': 30, 'Japanese': 30, 'Dutch': 30, 'Czech': 30, 'Irish': 30, 'Chinese': 30, 'Vietnamese': 30, 'Spanish': 30, 'Arabic': 30, 'German': 30, 'English': 30, 'Russian': 30, 'Greek': 30, 'French': 30})
Counter(_.obj for _ in data.train_ds.y).most_common()
[('Korean', 500), ('Italian', 500), ('Polish', 500), ('Japanese', 500), ('Dutch', 500), ('Czech', 500), ('Irish', 500), ('Chinese', 500), ('Vietnamese', 500), ('Spanish', 500), ('Arabic', 500), ('German', 500), ('English', 500), ('Russian', 500), ('Greek', 500), ('French', 500)]
(pd.DataFrame({'x': [_.text for _ in data