Loading [MathJax]/jax/output/CommonHTML/jax.js

Linear Regression을 위한 Backpropagation은 다음과 같은 모델들을 사용하여 순서대로 진행될 것이다.

  1. y=θx
  2. y=θ1x+θ0
  3. y=θ2x2+θ1x1+θ0

그리고 각 단계에서도 첫 번째로는 하나의 sample에 대한 backpropagation, 두 번째로는 2개의 sample에 대한 backpropagation, 마지막으로 Vectorized form의 backpropagation으로 진행될 것이다.

 

그러면 이번 포스트에서는 첫 번째 model부터 살펴보도록하자.

 

Backpropagation for One Sample


먼저 forward-backward propagation을 위한 node들을 먼저 만들어보면 다음과 같다.

 

 

즉, z1 node는 prediction value가 될 것이고, L는 loss를 구하게 된다. 그리고 각 node에서의 partial derivative를 구하면 다음과 같이 표시할 수 있다.

 

 

그리고 Chain rule에 의해 θ까지의 backpropagation을 계산하면 다음과 같다.

 

Lz1=Lz2z2z1=2z2(1)=2z2=2(yz1)

 

Lθ=Lz2z2z1z1θ=Lz1z1θ=2x(yz1)=2x(yθx)

 

따라서 Backpropagation은 다음과 같이 일어난다.

 

θ=θlrLθ=θlr(2x(yθx))=θ+lr2x(yθx)

 

Backpropagation for Two Samples


 

우리가 Learning 단계에서 mini-batch를 사용한다면 loss가 아닌 cost를 parameter update에 사용하게 된다. 

 

즉, loss들의 평균을 구하는 node가 추가된다. 그리고 mini-batch size가 2인 linear regression model은 다음과 같다.

 

그리고 각 node에서의 partial derivative를 표시하면

이 되고, Chain rule을 이용하여 backpropagation을 계산하면 다음과 같다.

 

Jz(1)2=JL(1)L(1)z(1)2=122z(1)2

 

Jz(2)2=JL(2)L(2)z(2)2=122z(2)2

 

Jz(1)1=JL(1)L(1)z(1)2z(1)2z(1)1=Jz(1)2z(1)2z(1)1=122z(1)2=122(y(1)z(1)1)

 

Jz(2)1=JL(2)L(2)z(2)2z(2)2z(2)1=Jz(2)2z(2)2z(2)1=122z(2)2=122(y(2)z(2)1)

 

Jθ=JL(1)L(1)z(1)2z(1)2z(1)1z(1)1θ=Jz(1)1z(1)1θ=122x(1)(y(1)z(1)1)=122x(1)(y(1)θx(1))

 

Jθ=JL(2)L(2)z(2)2z(2)2z(2)1z(2)1θ=Jz(2)1z(2)1θ=122x(2)(y(2)z(2)1)=122x(2)(y(2)θx(2))

 

이를 visualization하면 다음과 같다.

 

따라서 parameter update는 다음과 같이 된다.

 

θ=θlr(122x(1)(y(1)θx(1))122x(2)(y(2)θx(2)))=θ+lr122i=12x(i)(y(i)θx(i))

 

 

Backpropagation Vectorized form


위의 두 번째 단계는 실제 프로그래밍을 하기에도 힘들고, 같은 연산이 반복될 때 사용되는 vectorization을 이용하지 못한 모양이다. 따라서 mini-batch 사이즈가 임의의 n개일 때는 vectorization form을 이용하게 되고, 이 포스트에서는 3개의 mini-batch에 대한 backpropagation을 다룬다.

 

먼저 forward propgation을 포함한 model은 다음과 같다.

 

 

그리고 각 node에서 Jacobian을 이용한 partial derivative를 표시하면 다음과 같다.

 

 

그리고 각 node에서 실제로 backpropagationdmf 계산하면 다음과 같다.

따라서 Backpropagation은 다음과 같이 전체적으로 표시할 수 있다.

 

+ Recent posts