2018년 10월 15일 월요일

gradient descent

Single Variable
import matplotlib.pyplot as plt

x = [1, 2, 3]
y = [1, 2, 3]

# y = x  : True function# h(x) = wx : Hyperthesisdef cost(x, y, w):
   sum = 0   for a, b in zip(x, y):
       sum += (w*a - b)**2   return sum/len(x)

def show_cost():
    print(cost(x, y, 0))
    print(cost(x, y, 1))
    print(cost(x, y, 2))

    c = []
    for w in range(-10, 11):
      c.append(cost(x, y, w))

    print(c)
    plt.plot(range(-10, 11), c)
    plt.show()

    # 미분 : 기울기. 순간 변화량    #        x축으로 1만큼 움직일 때, y축으로 움직인 거리    #  y = x  1 = 1, 2 = 2, 3 = 3    #  y = 2x 2 = 1, 4 = 2, 6 = 3

def gradient_descent(x, y, w):
   gd = 0   for a, b in zip(x, y):
       gd += (w*a - b)*a # MSE cost w에 대해 미분시 x, 여기서는 a를 한번 더 곱해 주는 식이 된다.   return gd/len(x)

def update_gradient_descent(init=3, learning_rate=0.2, epoch=20):
   x = [1, 2, 3]
   y = [1, 2, 3]

   w = init
   for i in range(epoch):
       c = cost(x, y, w)
       g = gradient_descent(x, y, w)
       w -= learning_rate * g
       print(i, c)
   return w

w = update_gradient_descent()
print (5*w)
print (7*w)

# 문제 w 1.0으로 만드는 방법. Hyper parameter# 1. step을 올린다.# 2. learning late# 3. 초기값
# x 5 6일 대의 결과를 예측


Multi Variables
x1, x2, x3는 각각 피처. w1, w2, w3는 각각의 비중.
y는 실측

def loss(x1, x2, x3, y, w1, w2, w3):
    c = 0    for x_1, x_2, x_3, y_1 in zip(x1, x2, x3, y):
        c += (x_1*w1 + x_2*w2 + x_3*w3 - y_1) ** 2    
    return c / len(y)

def update_gradient_decent_one(x, y, w):
    d = 0    for x_1, y_1 in zip(x, y):
        d += (x_1*w - y_1)*x_1
    return d

def update_gradient_decent(x1, x2, x3, y, w1, w2, w3):
    w1 = w1 - 0.2 * update_gradient_decent_one(x1, y, w1)
    w2 = w2 - 0.2 * update_gradient_decent_one(x2, y, w2)
    w3 = w2 - 0.2 * update_gradient_decent_one(x3, y, w3)
    return w1, w2, w3

x1 = [1, 4, 5, 8] # 공부 시간
x2 = [3, 4, 7, 9] # 출석 일 수
x3 = [4, 6, 1, 3] # 학습 열정
y = [3, 5, 4, 7] # 성적

w1 = 3
w2 = 4
w3 = 5

for i in range(50):
    w1, w2, w3 = update_gradient_decent(x1, x2, x3, y, w1, w2, w3)
    print(loss(x1, x2, x3, y, w1, w2, w3))

댓글 없음:

댓글 쓰기