import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
%matplotlib inline
column_names = ['year', *('average' + str(i) for i in range(12)), *('covariance' + str(i) for i in range(78))]
data = pd.read_csv("YearPredictionMSD.txt", sep=",", header=None, names=column_names)
data.head()
| | year | average0 | average1 | average2 | average3 | average4 | average5 | average6 | average7 | average8 | ... | covariance68 | covariance69 | covariance70 | covariance71 | covariance72 | covariance73 | covariance74 | covariance75 | covariance76 | covariance77 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 2001 | 49.94357 | 21.47114 | 73.07750 | 8.74861 | -17.40628 | -13.09905 | -25.01202 | -12.23257 | 7.83089 | ... | 13.01620 | -54.40548 | 58.99367 | 15.37344 | 1.11144 | -23.08793 | 68.40795 | -1.82223 | -27.46348 | 2.26327 |
1 | 2001 | 48.73215 | 18.42930 | 70.32679 | 12.94636 | -10.32437 | -24.83777 | 8.76630 | -0.92019 | 18.76548 | ... | 5.66812 | -19.68073 | 33.04964 | 42.87836 | -9.90378 | -32.22788 | 70.49388 | 12.04941 | 58.43453 | 26.92061 |
2 | 2001 | 50.95714 | 31.85602 | 55.81851 | 13.41693 | -6.57898 | -18.54940 | -3.27872 | -2.35035 | 16.07017 | ... | 3.03800 | 26.05866 | -50.92779 | 10.93792 | -0.07568 | 43.20130 | -115.00698 | -0.05859 | 39.67068 | -0.66345 |
3 | 2001 | 48.24750 | -1.89837 | 36.29772 | 2.58776 | 0.97170 | -26.21683 | 5.05097 | -10.34124 | 3.55005 | ... | 34.57337 | -171.70734 | -16.96705 | -46.67617 | -12.51516 | 82.58061 | -72.08993 | 9.90558 | 199.62971 | 18.85382 |
4 | 2001 | 50.97020 | 42.20998 | 67.09964 | 8.46791 | -15.85279 | -16.81409 | -12.48207 | -9.37636 | 12.63699 | ... | 9.92661 | -55.95724 | 64.92712 | -17.72522 | -1.49237 | -7.50035 | 51.76631 | 7.88713 | 55.66926 | 28.74903 |
5 rows × 91 columns
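Before going further, a minimal sanity check (assuming the standard UCI copy of YearPredictionMSD): 515,345 songs, one target year plus 12 timbre averages and 78 timbre covariances.
# 463,715 train + 51,630 test songs; 1 year column + 90 timbre features
assert data.shape == (515345, 91)
assert data['year'].between(1922, 2011).all()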
# the dataset's recommended split: first 463,715 songs for training, last 51,630 for testing
# (respecting this boundary avoids putting songs by the same artist in both sets)
train = data.iloc[:463715, :]
test = data.iloc[463715:, :]
len(train), len(test)
(463715, 51630)
first_year, last_year = 1922, 2011
num_years = last_year - first_year + 1  # 90 distinct year classes
import torch
from torch import nn

model = nn.Sequential()
model.add_module('l1', nn.Linear(90, num_years))
#model.add_module('activ', nn.ReLU())
# softmax over dim=1 so each row (one song) becomes a probability distribution over the 90 years;
# nn.Softmax(0) would wrongly normalize across the batch dimension instead
model.add_module('smax', nn.Softmax(dim=1))

opt = torch.optim.Adam(model.parameters(), lr=1e-3)
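A minimal smoke test of the head (the dummy batch below is made up, not real data): every output row should be a valid probability distribution.
# hypothetical smoke test: each song's 90 year-probabilities should sum to 1
dummy = torch.randn(4, 90)
assert torch.allclose(model(dummy).sum(dim=1), torch.ones(4), atol=1e-5)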
# check that the sampled indices differ every time
a = np.random.randint(0, len(train), 5)
b = np.random.randint(0, len(train), 5)
a, b
(array([249807, 398891, 389318, 385061, 425867]), array([297735, 435139, 453997, 407096, 194297]))
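If reproducible sampling is wanted instead, seeding NumPy's generator first is enough (an optional aside, not part of the original flow):
np.random.seed(42)  # makes the sampled indices repeat across runs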
X = train.iloc[:, 1:].values
Y = train.iloc[:, 0].values
list(Y)
[2001, 2001, 2001, 2001, 2001, 2001, 2001, 2001, 2001, 2007, 2008, 2002, 2004, 2003, 1999, 2003, 2002, 1992, 1997, 1987, ...]
# one-hot encode each year: index (year - first_year) gets a 1, the other 89 entries 0
Y = list(Y)
for idx in range(len(Y)):
    year = Y[idx]
    Y[idx] = [int(i + first_year == year) for i in range(num_years)]
Y = torch.tensor(Y, dtype=torch.float32)
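For reference, the same one-hot matrix can be built in a single vectorized call; this is a sketch using torch.nn.functional.one_hot, assuming the integer years are re-read from the training frame (Y_onehot is a name introduced here):
import torch.nn.functional as F

# equivalent vectorized encoding: shift years so 1922 maps to class 0, then one-hot
years = torch.as_tensor(train.iloc[:, 0].values, dtype=torch.long)
Y_onehot = F.one_hot(years - first_year, num_classes=num_years).float()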
history = []
batch_size = int(len(train) / 50)  # ~9274 songs per gradient step

for i in range(10):  # 10 training steps
    # sample batch_size random rows
    ix = np.random.randint(0, len(train), batch_size)
    x_batch = torch.tensor(X[ix], dtype=torch.float32)
    y_batch = Y[ix]  # Y is already a float32 tensor

    # predict probabilities
    y_predicted = model(x_batch)
    #assert y_predicted.dim() == 1, "did you forget to select first column with [:, 0]"

    # compute loss, just like before
    loss = torch.mean((y_predicted - y_batch) ** 2)

    loss.backward()  # add new gradients
    opt.step()       # change weights
    opt.zero_grad()  # clear gradients

    history.append(loss.item())
    print("step #%i | mean loss = %.3f" % (i, np.mean(history[-10:])))
step #0 | mean loss = 0.011
step #1 | mean loss = 0.011
step #2 | mean loss = 0.011
step #3 | mean loss = 0.011
step #4 | mean loss = 0.011
step #5 | mean loss = 0.011
step #6 | mean loss = 0.011
step #7 | mean loss = 0.011
step #8 | mean loss = 0.011
step #9 | mean loss = 0.011
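The loss sits near 0.011 ≈ 1/90, roughly what MSE against a one-hot target gives for a near-uniform predictor, so ten steps are too few to see real progress. A sketch of the natural next cells, plotting the recorded loss curve and checking exact-year accuracy on the held-out split (the names X_test, y_test, predicted_year are introduced here for illustration):
# plot the loss curve recorded during training
plt.plot(history)
plt.xlabel('step')
plt.ylabel('mean squared error')
plt.show()

# hypothetical evaluation cell: exact-year accuracy on the held-out split
X_test = torch.tensor(test.iloc[:, 1:].values, dtype=torch.float32)
y_test = torch.as_tensor(test.iloc[:, 0].values, dtype=torch.long)
with torch.no_grad():
    predicted_year = model(X_test).argmax(dim=1) + first_year
accuracy = (predicted_year == y_test).float().mean().item()
print("exact-year test accuracy: %.3f" % accuracy)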