From the Data Science from Scratch book.
from typing import List
import random
Vector = List[float]
def add(vector1: Vector, vector2: Vector) -> Vector:
assert len(vector1) == len(vector2)
return [v1 + v2 for v1, v2 in zip(vector1, vector2)]
def scalar_multiply(c: float, vector: Vector) -> Vector:
return [c * v for v in vector]
def gradient_step(v: Vector, gradient: Vector, step_size: float) -> Vector:
"""Return vector adjusted with step. Step is gradient times step size.
"""
step = scalar_multiply(step_size, gradient)
return add(v, step)
def linear_gradient(x: float, y: float, theta: Vector) -> Vector:
slope, intercept = theta
predicted = slope * x + intercept
error = (predicted - y) #** 2
# print(x, y, theta, predicted, error)
return [2 * error * x, 2 * error]
Here we use one training example at a time to calculate the gradient steps
inputs = [(x, 20 * x + 5) for x in range(-50, 50)]
theta = [random.uniform(-1, 1), random.uniform(-1, 1)]
learning_rate = 0.001
for epoch in range(100):
for x, y in inputs:
grad = linear_gradient(x, y, theta)
theta = gradient_step(theta, grad, -learning_rate)
print(epoch, theta)
0 [20.108274621088928, -0.3890550572184463] 1 [20.103628550173042, -0.15784430337372637] 2 [20.09918250047512, 0.06344662581205483] 3 [20.094927182760102, 0.2752433412342787] 4 [20.090854449810823, 0.47795318030884215] 5 [20.086956448727392, 0.6719660044610173] 6 [20.0832257045743, 0.8576549486610185] 7 [20.07965500742264, 1.0353771386943718] 8 [20.076237509713653, 1.2054743780282082] 9 [20.07296662801438, 1.3682738056445862] 10 [20.06983608174483, 1.5240885250583422] 11 [20.06683986242782, 1.6732182068542554] 12 [20.063972181799524, 1.8159496643823287] 13 [20.06122752813499, 1.9525574051756995] 14 [20.05860064833006, 2.083304159961789] 15 [20.056086446554076, 2.2084413869124764] 16 [20.05368014168373, 2.328209755991209] 17 [20.051377044125662, 2.442839611270341] 18 [20.049172772009843, 2.552551414011119] 19 [20.047063092087537, 2.6575561678668613] 20 [20.045043898683737, 2.758055822780714] 21 [20.04311134626141, 2.8542436641396707] 22 [20.04126169735246, 2.9463046849793786] 23 [20.039491436402585, 3.0344159418586236] 24 [20.037797100119768, 3.118746894557255] 25 [20.0361754521855, 3.19945973173537] 26 [20.034623390938133, 3.276709684335131] 27 [20.03313793652982, 3.3506453237233305] 28 [20.03171617359179, 3.421408845892839] 29 [20.030355435914238, 3.4891363463999814] 30 [20.02905307411413, 3.5539580824049044] 31 [20.02780658595349, 3.6159987219880825] 32 [20.026613585687915, 3.6753775847836647] 33 [20.025471758004617, 3.732208870958282] 34 [20.024378926653775, 3.7866018810682345] 35 [20.023332973880006, 3.8386612263190933] 36 [20.022331891127603, 3.8884870294569893] 37 [20.021373782454166, 3.936175118240505] 38 [20.020456760427248, 3.981817208697045] 39 [20.019579095445714, 4.0255010817522585] 40 [20.018739087243826, 4.067310752589061] 41 [20.017935091021968, 4.107326630985043] 42 [20.017165620461952, 4.14562567751338] 43 [20.016429140217177, 4.182281550851487] 44 [20.015724268710606, 4.217364749099938] 45 [20.015049636287394, 4.250942746103223] 46 [20.014403952714495, 4.283080120704653] 47 [20.01378597502875, 4.313838681185705] 48 [20.013194514460565, 4.343277583993619] 49 [20.012628414045615, 4.371453447141006] 50 [20.01208658825722, 4.398420459215999] 51 [20.011568035372246, 4.424230484791519] 52 [20.011071730850883, 4.448933163459061] 53 [20.010596714895236, 4.472576004466116] 54 [20.01014208044218, 4.495204478772731] 55 [20.009706939528638, 4.5168621063046] 56 [20.009290475161325, 4.5375905399679235] 57 [20.008891875158, 4.557429645750311] 58 [20.008510370231065, 4.576417578971538] 59 [20.00814524887334, 4.594590858346274] 60 [20.007795789585977, 4.611984435823321] 61 [20.00746133605139, 4.628631763741193] 62 [20.00714119205082, 4.644564858428533] 63 [20.00683481657478, 4.659814363112007] 64 [20.00654157694565, 4.674409606833622] 65 [20.00626093700885, 4.688378660053233] 66 [20.0059923009138, 4.701748388289354] 67 [20.005735205576336, 4.714544504461379] 68 [20.00548916033653, 4.726791619391082] 69 [20.00525365468124, 4.738513287317169] 70 [20.005028243503418, 4.749732051371429] 71 [20.004812513452155, 4.760469488056405] 72 [20.0046060409254, 4.770746248364355] 73 [20.004408419316054, 4.780582096925131] 74 [20.004219291985194, 4.789995950689194] 75 [20.004038256610244, 4.799005914654827] 76 [20.00386500754045, 4.807629317187821] 77 [20.003699183947933, 4.815882743439104] 78 [20.00354046799362, 4.823782066495483] 79 [20.003388552379008, 4.831342478409528] 80 [20.003243202425267, 4.838578520565049] 81 [20.003104049179598, 4.8455041097305935] 82 [20.002970873552083, 4.852132564899691] 83 [20.0028434121415, 4.858476634395222] 84 [20.00272141143189, 4.864548519281559] 85 [20.002604655271004, 4.870359897369896] 86 [20.002492920554257, 4.875921945826025] 87 [20.00238595286676, 4.881245361497983] 88 [20.002283589604943, 4.886340382432449] 89 [20.0021855928701, 4.8912168074015] 90 [20.002091842376988, 4.8958840153869385] 91 [20.002002090950178, 4.90035098289899] 92 [20.001916203724086, 4.904626300833264] 93 [20.001833987421463, 4.908718191643037] 94 [20.00175530458321, 4.912634524898607] 95 [20.001680007220763, 4.916382832992619] 96 [20.001607906672934, 4.919970324321837] 97 [20.001538916328176, 4.923403898248704] 98 [20.001472898762803, 4.926690159008146] 99 [20.001409710601003, 4.929835427066589]