Minibatch Gradient Descent

fastai

Published October 21, 2020

Gradient descent with minibatches

From the book Data Science from Scratch by Joel Grus.

Libraries and helper functions

import random
from typing import TypeVar, List, Iterator
Vector = List[float]
def add(vector1: Vector, vector2: Vector) -> Vector:
    """Add corresponding elements of two vectors."""
    assert len(vector1) == len(vector2), "vectors must be the same length"
    return [v1 + v2 for v1, v2 in zip(vector1, vector2)]
def vector_sum(vectors: List[Vector]) -> Vector:
    """Sum a list of vectors element-wise."""
    assert vectors, "no vectors provided"

    vector_length = len(vectors[0])
    assert all(len(v) == vector_length for v in vectors), "vectors must be the same length"

    sums = [0] * vector_length
    for vector in vectors:
        sums = add(sums, vector)

    return sums
def scalar_multiply(c: float, vector: Vector) -> Vector:
    """Multiply every element of a vector by c."""
    return [c * v for v in vector]
def vector_mean(vectors: List[Vector]) -> Vector:
    """Element-wise mean of a list of vectors."""
    n = len(vectors)
    return scalar_multiply(1/n, vector_sum(vectors))
def gradient_step(v: Vector, gradient: Vector, step_size: float) -> Vector:
    """Return v moved step_size in the gradient direction."""
    step = scalar_multiply(step_size, gradient)
    return add(v, step)
def linear_gradient(x: float, y: float, theta: Vector) -> Vector:
    """Gradient of the squared error (predicted - y) ** 2 with respect to
    theta = [slope, intercept] for a single data point (x, y)."""
    slope, intercept = theta
    predicted = slope * x + intercept
    error = predicted - y
    return [2 * error * x, 2 * error]
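
A quick way to validate this analytic gradient is to compare it against a central-difference estimate. This check is not from the book; numeric_gradient and the test point are illustrative.

def numeric_gradient(x: float, y: float, theta: Vector, h: float = 1e-6) -> Vector:
    """Estimate the gradient of (predicted - y) ** 2 by central differences."""
    def loss(t: Vector) -> float:
        slope, intercept = t
        return (slope * x + intercept - y) ** 2
    grads = []
    for i in range(len(theta)):
        plus, minus = theta[:], theta[:]
        plus[i] += h
        minus[i] -= h
        grads.append((loss(plus) - loss(minus)) / (2 * h))
    return grads

analytic = linear_gradient(3, 65, [1.0, 1.0])
numeric = numeric_gradient(3, 65, [1.0, 1.0])
assert all(abs(a - n) < 1e-3 for a, n in zip(analytic, numeric))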

Minibatch gradient

T = TypeVar('T')
def minibatches(dataset: List[T], batch_size: int, shuffle: bool = True) -> Iterator[List[T]]:
    """Generate batch_size-sized slices of dataset, optionally in shuffled order."""
    # Start indices: 0, batch_size, 2 * batch_size, ...
    batch_starts = [start for start in range(0, len(dataset), batch_size)]

    if shuffle: random.shuffle(batch_starts)

    for start in batch_starts:
        end = start + batch_size
        yield dataset[start:end]
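Note that slicing past the end of a list is safe in Python, so when len(dataset) is not a multiple of batch_size the final batch simply comes out shorter. A small illustrative check (the numbers are arbitrary):

batches = list(minibatches(list(range(7)), batch_size=3, shuffle=False))
assert batches == [[0, 1, 2], [3, 4, 5], [6]]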
Generate synthetic data from the line y = 20x + 5 and peek at the first ten pairs:

inputs = [(x, 20 * x + 5) for x in range(-50, 50)]
inputs[:10]
[(-50, -995),
 (-49, -975),
 (-48, -955),
 (-47, -935),
 (-46, -915),
 (-45, -895),
 (-44, -875),
 (-43, -855),
 (-42, -835),
 (-41, -815)]
With shuffle=False the batches come out in order:

for batch in minibatches(inputs, batch_size=5, shuffle=False):
    print(batch)
[(-50, -995), (-49, -975), (-48, -955), (-47, -935), (-46, -915)]
[(-45, -895), (-44, -875), (-43, -855), (-42, -835), (-41, -815)]
[(-40, -795), (-39, -775), (-38, -755), (-37, -735), (-36, -715)]
[(-35, -695), (-34, -675), (-33, -655), (-32, -635), (-31, -615)]
[(-30, -595), (-29, -575), (-28, -555), (-27, -535), (-26, -515)]
[(-25, -495), (-24, -475), (-23, -455), (-22, -435), (-21, -415)]
[(-20, -395), (-19, -375), (-18, -355), (-17, -335), (-16, -315)]
[(-15, -295), (-14, -275), (-13, -255), (-12, -235), (-11, -215)]
[(-10, -195), (-9, -175), (-8, -155), (-7, -135), (-6, -115)]
[(-5, -95), (-4, -75), (-3, -55), (-2, -35), (-1, -15)]
[(0, 5), (1, 25), (2, 45), (3, 65), (4, 85)]
[(5, 105), (6, 125), (7, 145), (8, 165), (9, 185)]
[(10, 205), (11, 225), (12, 245), (13, 265), (14, 285)]
[(15, 305), (16, 325), (17, 345), (18, 365), (19, 385)]
[(20, 405), (21, 425), (22, 445), (23, 465), (24, 485)]
[(25, 505), (26, 525), (27, 545), (28, 565), (29, 585)]
[(30, 605), (31, 625), (32, 645), (33, 665), (34, 685)]
[(35, 705), (36, 725), (37, 745), (38, 765), (39, 785)]
[(40, 805), (41, 825), (42, 845), (43, 865), (44, 885)]
[(45, 905), (46, 925), (47, 945), (48, 965), (49, 985)]
After the loop, batch still refers to the final minibatch:

batch
[(45, 905), (46, 925), (47, 945), (48, 965), (49, 985)]
Starting from a random theta, the mean gradient over this batch is enormous, since a random slope is far from the true slope of 20:

theta = [random.uniform(-1, 1), random.uniform(-1, 1)]
vector_mean([linear_gradient(x, y, theta) for x, y in batch])
[-85269.18707965006, -1812.6065734463045]
Now run minibatch gradient descent: 1000 epochs with a batch size of 20, stepping against the gradient and recording theta after each epoch.

inputs = [(x, 20 * x + 5) for x in range(-50, 50)]
theta = [random.uniform(-1, 1), random.uniform(-1, 1)]  # random starting point
learning_rate = 0.001

minibatch_results = []

for epoch in range(1000):
    for batch in minibatches(inputs, batch_size=20):
        grad = vector_mean([linear_gradient(x, y, theta) for x, y in batch])
        theta = gradient_step(theta, grad, -learning_rate)  # negative step: descend
    minibatch_results.append([epoch, theta])

Last twenty epochs

minibatch_results[-20:]
[[980, [20.00000231135919, 4.999994332476798]],
 [981, [19.99999999323285, 4.999994612954206]],
 [982, [20.000000061792846, 4.999994646372994]],
 [983, [19.99999994338701, 4.999994665516976]],
 [984, [19.99999978951187, 4.999994715352837]],
 [985, [20.000000045649184, 4.999994748354316]],
 [986, [19.99999982652112, 4.999994803275764]],
 [987, [20.000000353260692, 4.999994822918053]],
 [988, [20.000000203297585, 4.99999512614361]],
 [989, [20.000000013241355, 4.999995351690553]],
 [990, [20.00000018064206, 4.999995396803261]],
 [991, [20.00000031292562, 4.999995409193917]],
 [992, [19.99999796367368, 4.999995522343663]],
 [993, [19.9999998507973, 4.999995639004705]],
 [994, [20.00000000615185, 4.999995667269925]],
 [995, [19.999999823836678, 4.999995780888271]],
 [996, [20.000000260844455, 4.999995797933743]],
 [997, [20.00000006542973, 4.999995821702883]],
 [998, [19.999999347865106, 4.999995882305616]],
 [999, [20.000000771099824, 4.999995951108064]]]
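
Since the data were generated from y = 20x + 5, theta should have converged to roughly [20, 5]. A final check, as in the book:

slope, intercept = theta
assert 19.9 < slope < 20.1, "slope should be about 20"
assert 4.9 < intercept < 5.1, "intercept should be about 5"

The book follows the minibatch version with a stochastic variant that takes one gradient step per data point rather than one per batch; it reaches the same answer in many fewer epochs, since each epoch now makes 100 updates. A sketch along those lines:

theta = [random.uniform(-1, 1), random.uniform(-1, 1)]

for epoch in range(100):
    for x, y in inputs:
        grad = linear_gradient(x, y, theta)
        theta = gradient_step(theta, grad, -learning_rate)

theta  # should again be close to [20, 5]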