๐ก 'Deep Learning from Scratch'์ 'CS231N'์ ์ฐธ๊ณ ํ์ฌ ์์ฑ
์ ๊ฒฝ๋ง(neural network)์ ํ์ต ๋ชฉ์ ์ ์์ค ํจ์(loss function)์ ๊ฐ์ ์ต๋ํ ๋ฎ์ถ๋ ๋งค๊ฐ๋ณ์(parameter)๋ฅผ ์ฐพ๋ ๊ฒ์ด์์ต๋๋ค. ์ด๋ ๊ณง ๋งค๊ฐ๋ณ์์ ์ต์ ๊ฐ์ ์ฐพ๋ ๋ฌธ์ ์ด๋ฉฐ, ์ด๋ฅผ ์ต์ ํ ๋ฌธ์ (optimization)๋ผ ํฉ๋๋ค.
์ต์ ์ ๋งค๊ฐ๋ณ์๋ฅผ ์ฐพ๊ธฐ ์ํด์๋ ํ์ต ๋ฐ์ดํฐ๋ค์ ์ด์ฉํด ๊ธฐ์ธ๊ธฐ(gradient)์ ๊ฐ์ ๊ตฌํ๊ณ , ๊ทธ ๊ฐ์ ๊ธฐ์ค์ผ๋ก ๋์๊ฐ ๋ฐฉํฅ์ ๊ฒฐ์ ํด์ผ ํฉ๋๋ค. ์ด๋ฒ ๊ฒ์๋ฌผ์์๋ ์ต์ ์ ๋งค๊ฐ๋ณ์๋ฅผ ์ฐพ๋ ๋ฐฉ๋ฒ์ธ ์ตํฐ๋ง์ด์ ์ ๋ํด ์์๋ณด๊ฒ ์ต๋๋ค.
1. ๊ธฐ์ธ๊ธฐ
ํ์ต ๋ฐ์ดํฐ๋ฅผ ์ด์ฉํ์ฌ ๊ธฐ์ธ๊ธฐ๋ฅผ ๊ตฌํ๊ธฐ ์ํด์๋ ๊ฐ ๋ณ์๋ค์ ๋ํ ํธ๋ฏธ๋ถ์ ๋์์ ๊ณ์ฐํด์ผ ํฉ๋๋ค.
$$ f(x_0, x_1) = x_0^2 + x_1^2 $$
์์ ๊ณต์์์ $ x_0 = 3, x_1 = 4 $์ผ ๋, ($ x_0, x_1 $)์ ์์ชฝ์ ํธ๋ฏธ๋ถ์ ๋ฌถ์ด์ ($ \frac{\partial{f}}{\partial{x_0}}, \frac{\partial{f}}{\partial{x_1}} $)๋ก ๊ณ์ฐํฉ๋๋ค. ($ \frac{\partial{f}}{\partial{x_0}}, \frac{\partial{f}}{\partial{x_1}} $)์ฒ๋ผ ๋ชจ๋ ๋ณ์์ ํธ๋ฏธ๋ถ์ ๋ฒกํฐ๋ก ์ ๋ฆฌํ ๊ฒ์ ๊ธฐ์ธ๊ธฐ๋ผ๊ณ ํฉ๋๋ค. ๊ธฐ์ธ๊ธฐ๋ฅผ ๊ตฌํ๋ ํจ์๋ ์๋์ ๊ฐ์ด ๊ตฌํํ ์ ์์ต๋๋ค.
def numerical_gradient_no_batch(f: Callable, x: np.array) -> np.array:
h = 1e-4
# x์ ๊ฐ์ ํ์์ 0์ผ๋ก ๊ตฌ์ฑ๋ ๋ฐฐ์ด ์์ฑ
grad = np.zeros_like(x)
for idx in range(x.size):
tmp_val = x[idx]
# f(x+h) ๊ณ์ฐ
x[idx] = float(tmp_val) + h
fxh1 = f(x)
# f(x-h) ๊ณ์ฐ
x[idx] = tmp_val - h
fxh2 = f(x)
grad[idx] = (fxh1 - fxh2) / (2 * h)
x[idx] = tmp_val # ๊ฐ ๋ณต์
return grad
๊ธฐ์ธ๊ธฐ์ ๊ฒฐ๊ณผ์ ๋ง์ด๋์ค๋ฅผ ๋ถ์ธ ๋ฒกํฐ๋ฅผ ๊ทธ๋ ค๋ณด๋ฉด, ๊ธฐ์ธ๊ธฐ๊ฐ ์๋ฏธํ๋ ๊ฒ์ ๋ ์ ์ดํดํ ์ ์์ต๋๋ค. ์๋์ ๊ทธ๋ฆผ 1์ ๋ฐฉํฅ์ ๊ฐ์ง ๋ฒกํฐ(ํ์ดํ)๋ก ๊ทธ๋ ค์ก์ต๋๋ค. ์ด๋ ๋ง์น ํจ์์ '๊ฐ์ฅ ๋ฎ์ ์ฅ์(์ต์๊ฐ)'๋ฅผ ๊ฐ๋ฆฌํค๋ ๊ฒ ๊ฐ์ต๋๋ค. ๋ํ, '๊ฐ์ฅ ๋ฎ์ ๊ณณ'์์ ๋ฉ์ด์ง์๋ก ํ์ดํ์ ํฌ๊ธฐ๊ฐ ์ปค์ง๋ค๋ ๊ฒ ์ญ์ ํ์ธํ ์ ์์ต๋๋ค.
๊ธฐ์ธ๊ธฐ๋ ๊ฐ ์ง์ ์์ ๋ฎ์์ง๋ ๋ฐฉํฅ์ ๊ฐ๋ฆฌํต๋๋ค. ์ ํํ ํํํ๋ฉด, ๊ธฐ์ธ๊ธฐ๊ฐ ๊ฐ๋ฆฌํค๋ ์ชฝ์ ๊ฐ ์ฅ์์์ ํจ์์ ์ถ๋ ฅ ๊ฐ์ ๊ฐ์ฅ ํฌ๊ฒ ์ค์ฌ์ฃผ๋ ๋ฐฉํฅ์ด๋ผ๊ณ ํ ์ ์์ต๋๋ค.
2. ์ตํฐ๋ง์ด์
2.1 ์ตํฐ๋ง์ด์ ๋?
์ตํฐ๋ง์ด์ ๋ '์ต์ ์ ํํ๋ ์'๋ผ๋ ๋ป์ ๋จ์ด์ ๋๋ค. ๊ธฐ๊ณํ์ต(machine learning)์ ํ์ต ๋จ๊ณ์์ ์ต์ ์ ๋งค๊ฐ๋ณ์๋ฅผ ์ฐพ์๋ด๋ ๊ฒ์ด ์ตํฐ๋ง์ด์ ์ ์ญํ ์ ๋๋ค. ์ฌ๊ธฐ์ ์ต์ ์ด๋, ์์ค ํจ์๊ฐ ์ต์๊ฐ์ด ๋ ๋์ ๋งค๊ฐ๋ณ์ ๊ฐ์ ์๋ฏธํฉ๋๋ค. ๊ทธ๋ฌ๋ ์ผ๋ฐ์ ์ธ ๋ฌธ์ ์ ์์ค ํจ์๋ ๋งค์ฐ ๋ณต์กํฉ๋๋ค. ์ด๋ฐ ์ํฉ์์ ๊ธฐ์ธ๊ธฐ๋ฅผ ์ ์ด์ฉํ์ฌ ํจ์์ ์ต์๊ฐ์ ์ฐพ๋ ๊ฒ์ด ๋ฐ๋ก ์ตํฐ๋ง์ด์ ์ ๋๋ค.
์ฌ๊ธฐ์ ์ฃผ์ํ ์ ์ ๊ฐ ์ง์ ์์ ํจ์์ ๊ฐ์ ๋ฎ์ถ๋ ๋ฐฉ์์ ์ ์ํ๋ ์งํ๊ฐ ๊ธฐ์ธ๊ธฐ๋ผ๋ ๊ฒ์ ๋๋ค. ๋ณต์กํ ํจ์์์ ๊ธฐ์ธ๊ธฐ๊ฐ ๊ฐ๋ฆฌํค๋ ๊ณณ์ด ์ ๋ง๋ก ๋์๊ฐ์ผ ํ ๋ฐฉํฅ์ธ์ง๋ฅผ ๋ณด์ฅํ ์ ์์ต๋๋ค. ์ค์ ๋ก ํจ์๊ฐ ๋ณต์กํ ์๋ก ๊ธฐ์ธ๊ธฐ๊ฐ ๊ฐ๋ฆฌํค๋ ๋ฐฉํฅ์ ์ต์๊ฐ์ด ์๋ ๊ฒฝ์ฐ๊ฐ ๋๋ถ๋ถ์ ๋๋ค. ์์ฅ์ (saddle point)๊ณผ ์ง์ญ ์ต์ ํด(local minima)๊ฐ ๊ทธ ๋ํ์ ์ธ ์ฌ๋ก์ ๋๋ค.
์์ฅ์
์ด๋ ๋ฐฉํฅ์์ ๋ณด๋ฉด ๊ทน๋๊ฐ์ด๊ณ ๋ค๋ฅธ ๋ฐฉํฅ์์ ๋ณด๋ฉด ๊ทน์๊ฐ์ด ๋๋ ์ .
์ง์ญ ์ต์ ํด
์ฃผ์์ ๋ชจ๋ ์ ์ ํจ์๊ฐ๋ณด๋ค๋ ์ดํ์ ๊ฐ์ ๊ฐ๋ ์ ์ด์ง๋ง, ๋ชจ๋ ์คํ ๊ฐ๋ฅํ ํจ์ซ๊ฐ๋ณด๋ค๋ ๋ ํด ์ ์๋ ์ .
์ด์ฒ๋ผ ๊ธฐ์ธ์ด์ง ๋ฐฉํฅ์ด ๊ผญ ์ต์๊ฐ์ ๊ฐ๋ฆฌํค๋ ๊ฒ์ ์๋์ง๋ง, ๊ทธ ๋ฐฉํฅ์ผ๋ก ๊ฐ์ผ ํจ์์ ๊ฐ์ ์ค์ผ ์ ์์ต๋๋ค. ๊ทธ๋์ ์ ์ญ ์ต์๊ฐ(global minimum)์ด ๋๋ ์ฅ์๋ฅผ ์ฐพ๋ ๋ฌธ์ ์์๋ ๊ธฐ์ธ๊ธฐ์ ์ ๋ณด๋ฅผ ๋ฐํ์ผ๋ก ๋์๊ฐ์ผํ ๋ฐฉํฅ์ ์ ํด์ผ ํฉ๋๋ค. ์ฌ๊ธฐ์ ๋ฑ์ฅํ๋ ๋ฐฉ๋ฒ์ด ์ตํฐ๋ง์ด์ ์ ์์ด์ธ ๊ฒฝ์ฌํ๊ฐ๋ฒ(gradient descent)์ ๋๋ค.
2.2 Stochastic Gradient Descent (SGD)
๊ฒฝ์ฌํ๊ฐ๋ฒ์ ํ ์์น์์ ์ต์๊ฐ์ ํฅํด ์ผ์ ๊ฑฐ๋ฆฌ๋งํผ ์ด๋ํฉ๋๋ค. ์ด๋ํ ๊ณณ์์ ๋ค์ ๊ธฐ์ธ๊ธฐ๋ฅผ ๊ตฌํ๊ณ ๋ ์ต์๊ฐ์ ํฅํด ๋์๊ฐ๊ธฐ๋ฅผ ๋ฐ๋ณตํฉ๋๋ค. ์ต์๊ฐ ๋๋ฌ์ด๋ผ๋ ๋ชฉํ๋ฅผ ๋ฌ์ฑํ๊ธฐ ์ํด ๊ฒฝ์ฌํ๊ฐ๋ฒ์ ๊ณ์ํด์ ๋งค๊ฐ๋ณ์๋ฅผ ๊ฐฑ์ ํฉ๋๋ค. ์ด๋, ๋ฐ์ดํฐ๋ฅผ ๋ฏธ๋๋ฐฐ์น(mini-batch)๋ก ๋๋คํ๊ฒ ์ ์ ํ๋ ๊ฒฝ์ฌํ๊ฐ๋ฒ์ SGD๋ผ๊ณ ๋ถ๋ฆ ๋๋ค. ์ด๋ 'ํ๋ฅ ์ (stochastic)์ผ๋ก ๋ฌด์์ํ๊ฒ ๊ณจ๋ผ๋ธ ๋ฐ์ดํฐ'์ ๋ํด ์ํํ๋ ๊ฒฝ์ฌํ๊ฐ๋ฒ์ด๋ผ๋ ์๋ฏธ๋ฅผ ๊ฐ์ง๊ณ ์์ต๋๋ค. SGD์ ์์์ ์๋์ ๊ฐ์ด ์์ฑํ ์ ์์ต๋๋ค.
$$ \theta_{t} = \theta_{t-1} - \eta \triangledown J(\theta_{t-1}) $$
- $ \theta $ : ๋งค๊ฐ๋ณ์
- $ \eta $ : ํ์ต๋ฅ (learning rate)
- $ \triangledown J(\theta) $ : ์์ค ํจ์์ ๊ธฐ์ธ๊ธฐ
ํ์ต๋ฅ
๊ฐฑ์ ํ๋ ์. ํ ๋ฒ์ ํ์ต์ผ๋ก ์ผ๋ง๋งํผ ํ์ตํด์ผ ํ ์ง, ์ฆ ๋งค๊ฐ๋ณ์ ๊ฐ์ ์ผ๋ง๋ ๊ฐฑ์ ํ๋ ์ง๋ฅผ ์ ํจ
์ด๋ฅผ ๊ตฌํํ๋ฉด ์๋์ ๊ฐ์ต๋๋ค.
class SGD:
def __init__(self, lr: float=0.01) -> None:
self.lr = lr
def update(self, params: dict, grads: dict) -> None:
for key in params.keys():
params[key] -= self.lr * grads[key]
SGD๋ ๋จ์ํ๊ณ ๊ตฌํํ๊ธฐ ์ฝ์ง๋ง, ๋ฌธ์ ์ ๋ฐ๋ผ์ ๋นํจ์จ์ ์ธ ๊ฒฝ์ฐ๊ฐ ์์ต๋๋ค. ์๋์ ์์์ SGD๋ฅผ ์ ์ฉํ์ฌ ์ต์๊ฐ์ ๊ตฌํ๋ ๋ฌธ์ ๋ฅผ ์๊ฐํด ๋ณด๊ฒ ์ต๋๋ค.
$$ f(x,y)=\frac{1}{20}x^2+y^2 $$
ํ์์ ์์ํ๋ ์ด๊ธฐ๊ฐ์ $ (x, y) = (-8.0, 2.0) $์ผ๋ก ํ๊ฒ ์ต๋๋ค. ์ด๋ฅผ ์๊ฐํํ๋ฉด ์๋์ ๊ฐ์ต๋๋ค.
SGD๋ ๊ทธ๋ฆผ 3๊ณผ ๊ฐ์ด ์ฌํ๊ฒ ๊ตฝ์ด์ ธ ๋นํจ์จ์ ์ธ ์์ง์์ ๋ณด์ฌ์ค๋๋ค. SGD์ ๋จ์ ์ ๋น๋ฑ๋ฐฉ์ฑ(anisotropy) ํจ์์์๋ ํ์ ๊ฒฝ๋ก๊ฐ ๋นํจ์จ์ ์ด๋ผ๋ ๊ฒ์ ๋๋ค. SGD ๊ฐ์ด ๋ฌด์์ ๊ธฐ์ธ์ด์ง ๋ฐฉํฅ์ผ๋ก ์งํํ๋ ๋จ์ํ ๋ฐฉ์์ ์ฌ์ฉํ๊ธฐ๋ณด๋ค ์ด๋ฅผ ๊ฐ์ ํ ๋ฐฉ๋ฒ์ด ํ์ํ์ต๋๋ค.
๋น๋ฑ๋ฐฉ์ฑ ํจ์
๋ฐฉํฅ์ ๋ฐ๋ผ ์ฑ์ง(์ฌ๊ธฐ์๋ ๊ธฐ์ธ๊ธฐ)์ด ๋ฌ๋ผ์ง๋ ํจ์
์ด์ธ์๋ SGD๋ ์ง์ญ ์ต์ ํด์ ์์ฅ์ ์์ ๋ฒ์ด๋์ง ๋ชปํ๋ ๋ฌธ์ ๋ฅผ ๊ฐ๊ณ ์์ต๋๋ค. ๋ํ, ๋ชจ๋ ํ๋ผ๋ฏธํฐ์์ ๋์ผํ ํ์ต ๋ณดํญ(step size)์ด ์ ์ฉ๋๊ธฐ ๋๋ฌธ์ ํ๋ผ๋ฏธํฐ๊ฐ ๋ณํ ๊ฒฐ๊ณผ๋ฅผ ํ์ต์ ๋ฐ์ํ์ง ๋ชปํ๋ค๋ ๋ฌธ์ ์ ์ด ์์ต๋๋ค. ์ด๋ฌํ ๋ฌธ์ ๋ฅผ ๊ฐ์ ํ๊ธฐ ์ํด ์๋ก์ด ์ตํฐ๋ง์ด์ ๋ค์ด ๋ฑ์ฅํ์ต๋๋ค.
2.3 SGD + Momentum
๋ชจ๋ฉํ (Momentum)์ '์ด๋๋'์ ๋ปํ๋ ๋จ์ด๋ก ๋ฌผ๋ฆฌ์ ๊ด๊ณ๊ฐ ์์ต๋๋ค. SGD๊ฐ ์ง์ญ ์ต์ ํด์ ์์ฅ์ ์์ ๋ฒ์ด๋์ง ๋ชปํ๋ ๋ฌธ์ ๋ฅผ ๊ฐ์ ํ๊ธฐ ์ํด SGD์ ๋ชจ๋ฉํ ์ ๊ฐ๋ ์ ์ถ๊ฐํ์์ต๋๋ค. ๋ชจ๋ฉํ ์ ๋ค์๊ณผ ๊ฐ์ด ์์์ ์์ฑํ ์ ์์ต๋๋ค.
$$ v_t = \gamma v_{t-1}-\eta\triangledown J(\theta_{t-1}) $$
$$ \theta_{t} = \theta_{t-1} + v_t $$
- $ v_t $ : ์๋(velocity)
์ ๊ณต์์์ $ \gamma v_t $๋ ๋ฌผ์ฒด๊ฐ ์๋ฌด๋ฐ ํ์ ๋ฐ์ง ์์ ๋ ์์ํ ํ๊ฐ์ํค๋ ์ญํ ์ ํฉ๋๋ค($ \gamma $๋ 0.9 ๋ฑ์ ๊ฐ์ผ๋ก ์ค์ ). $ \gamma $๋ ๋ฌผ๋ฆฌ์์ ์ง๋ฉด ๋ง์ฐฐ์ด๋ ๊ณต๊ธฐ ์ ํญ์ ํด๋นํฉ๋๋ค. ๋ชจ๋ฉํ ์ ์๋ ๊ฐ๋ ์ ํตํด ์ต์๊ฐ์์ ๋ฐ๋ก ๋ฉ์ถ์ง ์๊ณ '์ค๋ฒ์ํ (Overshooting)'ํ์ฌ ์ง์ญ ์ต์ ํด์ ์์ฅ์ ์ ํต๊ณผํ ์ ์์์ต๋๋ค. ๋ชจ๋ฉํ ์ ๊ตฌํํ๋ฉด ์๋์ ๊ฐ์ต๋๋ค.
class momentum:
def __init__(self, lr: float=0.01, momentum: float=0.9) -> None:
self.lr = lr
self.m = momentum
self.v = None
def update(self, params: dict, grads: dict) -> None:
if self.v is None:
self.v = {}
for key, val in params.items():
self.v[key] = np.zeros_like(val)
for key in params.keys():
self.v[key] = self.m * self.v[key] - self.lr * grads[key]
params[key] += self.v[key]
์ฝ๋์์ ๋ณ์ v๋ ๋ฌผ์ฒด์ ์๋๋ฅผ ๋ปํฉ๋๋ค. ์ด๋ฅผ ์๊ฐํํ๋ฉด ์๋ ๊ทธ๋ฆผ 5์ ๊ฐ์ต๋๋ค.
๊ทธ๋ฆผ 5์์ ๊ฐฑ์ ๊ฒฝ๋ก๊ฐ SGD์ ๋นํ์ฌ ๋ถ๋๋ฌ์์ง ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค. ์ด๋ $ x $์ถ์ ํ์ ์์ฃผ ์์ง๋ง ๋ฐฉํฅ์ ๋ณํ์ง ์์์ ํ ๋ฐฉํฅ์ผ๋ก ์ผ์ ํ๊ฒ ๊ฐ์ํ๊ธฐ ๋๋ฌธ์ ๋๋ค. ๋ฐ๋๋ก $ y $์ถ์ ํ์ ํฌ์ง๋ง ์์๋๋ก ๋ฒ๊ฐ์ ๋ฐ์์ ์์ถฉํ์ฌ $ y $์ถ ๋ฐฉํฅ์ ์๋๋ ์์ ์ ์ด์ง ์์ต๋๋ค.
๋ํ, ๊ทธ๋ฆผ 5์์ ์ ์ญ ์ต์๊ฐ์์๋ ๋ฐ๋ก ๋ฉ์ถ๊ฑฐ๋ ๋ฆ์ถฐ์ง์ง ์๊ณ ์ง๋์ณ๊ฐ๋ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค. ์ด๋ ์ง์ญ ์ต์ ํด๋ ์์ฅ์ ์ ํต๊ณผํ๋๋ฐ๋ ์ ํจํ์์ง๋ง, ์ ์ญ ์ต์๊ฐ์์ ๋น ๋ฅด๊ฒ ์๋ ดํ์ง ์์ ํ์ต ์๊ฐ์ด ๋์ด๋๋ค๋ ๋ฌธ์ ๋ก ์ฐ๊ฒฐ๋ฉ๋๋ค.
2.4 Nesterov Accelerated Gradient (NAG)
๋ชจ๋ฉํ ์ ๊ฐฑ์ ๊ณผ์ ์์ ํ์ฌ์ ๊ธฐ์ธ๊ธฐ ๊ฐ์ ๊ธฐ๋ฐ์ผ๋ก ๋ค์ ๋จ๊ณ์ ๊ฐ์ ๋์ถํ์์ต๋๋ค. ์ด๋ ๊ด์ฑ์ ์ํด ์ต์ ์ ํ๋ผ๋ฏธํฐ๋ฅผ ์ง๋์น ์ ์๋ค๋ ๋ฌธ์ ๋ก ์ด์ด์ง๋๋ค. NAG์์๋ ๋ชจ๋ฉํ ์ผ๋ก ์ด๋ํ ์ง์ ์์์ ๊ธฐ์ธ๊ธฐ๋ฅผ ํ์ฉํ์ฌ ๊ฐฑ์ ํ๊ธฐ ๋๋ฌธ์ ์ด๋ฌํ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ ์ ์์์ต๋๋ค. ์ด๋ฅผ ์์์ผ๋ก ํํํ๋ฉด ์๋์ ๊ฐ์ต๋๋ค.
$$ v_t = \gamma v_{t-1}-\eta\triangledown J(\theta_{t-1}-\gamma v_{t-1}) $$
$$ \theta_{t} = \theta_{t-1} + v_t $$
๋ชจ๋ฉํ ๊ณผ ๋ฌ๋ฆฌ $ v_t $๋ฅผ ๊ตฌํ ๋, $ \triangledown J(\theta-\gamma v_{t-1}) $๊ณผ ๊ฐ์ด ๊ด์ฑ์ ์ํ์ฌ ์ด๋ํ ๊ณณ์ ๊ธฐ์ธ๊ธฐ๋ฅผ ์ ์ฉํ๋ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค. ์ด๋ ๊ด์ฑ์ ์ํ์ฌ ๋น ๋ฅด๊ฒ ์ด๋ํ๋ ์ด์ ์ ์ ์ฉํ๋ฉด์๋ ์๋ ดํด์ผ ํ๋ ๊ณณ์์ ํจ๊ณผ์ ์ผ๋ก ๋ฉ์ถ๋ ๊ฒฐ๊ณผ๋ฅผ ๋ง๋ค์์ต๋๋ค. NAG์ ๋ชจ๋ฉํ ์ ์์ง์์ ๋น๊ตํ๋ฉด ์๋์ ๊ทธ๋ฆผ 6๊ณผ ๊ฐ์ต๋๋ค.
NAG๋ ์๋์ ๊ฐ์ด ๊ตฌํํ ์ ์์ต๋๋ค.
class momentum:
def __init__(self, lr: float=0.01, momentum: float=0.9,
nesterov : bool=False) -> None:
self.lr = lr
self.m = momentum
self.v = None
self.nesterov = nesterov
def update(self, params: dict, grads: dict) -> None:
if self.v is None:
self.v = {}
for key, val in params.items():
self.v[key] = np.zeros_like(val)
for key in params.keys():
if self.nesterov: # NAG
self.v[key] = self.m * self.v[key] - self.lr * grads[key]
params[key] += (self.m * self.v[key] - self.lr * grads[key])
else: # vanila momentum
self.v[key] = self.m * self.v[key] - self.lr * grads[key]
params[key] += self.v[key]
๊ธฐ์กด์ ๋ชจ๋ฉํ ํด๋์ค์ nesterov ์ฌ๋ถ๋ฅผ ์ถ๊ฐํ์ฌ ๊ตฌํํ์์ต๋๋ค. ์ด๋ฅผ ์๊ฐํํ ๊ฒฐ๊ณผ๋ ์๋์ ๊ฐ์ต๋๋ค.
๋ชจ๋ฉํ ๊ณผ ๋์ผํ๊ฒ ์ค๋ฒ์ํ ์ ํตํด ์ ์ญ ์ต์๊ฐ์ ์ง๋์น์ง๋ง, ๋ชจ๋ฉํ ์ ๋นํด ๋ ๋น ๋ฅด๊ฒ ์๋ ดํ๋ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค.
(์คํ ์ฌ์ด์ฆ๋ฅผ ๊ฐ์ ํ ์ตํฐ๋ง์ด์ ๋ ๋ค์ ๊ธ์์ ์ด์ด์ง๋๋ค)
์ฐธ๊ณ ์๋ฃ ์ถ์ฒ
- ํ์ฉํธ, "์์ตํด๋ ๋ชจ๋ฅด๊ฒ ๋ ๋ฅ๋ฌ๋, ๋จธ๋ฆฌ์์ ์ธ์คํจ ์์ผ๋๋ฆฝ๋๋ค.", https://www.slideshare.net/yongho/ss-79607172, 2017
- CS213N, "Convolutional Neural Networks for Visual Recognition", https://cs231n.github.io/neural-networks-3/, 2021
'๊ธฐ์ด > ์ธ๊ณต์ง๋ฅ' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
ํ์ต๊ณผ ๊ด๋ จ๋ ๊ธฐ์ ๋ค (0) | 2021.09.10 |
---|---|
์ตํฐ๋ง์ด์ (Optimizer) (2/2) (0) | 2021.09.02 |
์ค์ฐจ์ญ์ ํ(Back-Propagation) (0) | 2021.08.29 |
์์ค ํจ์(Loss function) (0) | 2021.08.28 |
ํ์ฑํ ํจ์(Activation function) (0) | 2021.08.26 |