let us not love with words or tongue but actions and truth.

IT/강화학습

Twin Delayed DDPG (TD3), Soft Actor-Critic (SAC)

sarah0518 2022. 12. 6. 19:10

Twin Delayed DDPG (TD3)

TD3와 DDPG의 차이점

1. Clipped double Q learning

기존의 DDPG는 4개의 layer를 사용했다면,

TD3는 main/target critic network를 하나씩 더 사용하여 총 6개를 사용함

2. Delayed policy updates

the actor network parameter is delayed and updated only after two steps of the episode.

Crtic을 많이 업데이트하고, actor를 중간중간 가끔식 update하자라는 뜻임

(최소 2번에 1번씩 actor를 update해주자)

3. Target policy smoothing

target action에도 noise를 추가해주자.

 

 

 

[Clipped double Q learning] RL에서의 overestimation

실제 Q값보다 너무 높게 예측하는 것

원인: 아래와 같은 수식에서 maximum값을 가져오기 때문에

해결책: 2개의 network를 추가하여 그 중에 작은 값을 쓰자.

We compute the Q value of the next state-action pair using two separate target critic networks

we use the minimum of the two in computing the target value.

 

 

 

[Delayed policy updates] 

통상적으로 2번에 1번씩 actor network를 업데이트 하면 됨.

 θ1,  θ2중에 아무거나 선택해서 loss를 계산함

 

 

[Target policy smoothing] Target policy smoothing이 필요한 이유

target값 계산할 때, noise를 추가하는 것임

 

continuous action space에서 평균이 95일 때,

유사한 action(95, 95.1, 94.9)이 뽑힐 때마다 action value의 값은 차이가 크다는 단점이 있음. 

해결책: target에 noise를 추가해 줌으로써, “similar actions have similar value”를 얻을 수 있음

 

 

 

TD3: 6개의 network

 

 

Soft Actor Critic (SAC)

SAC 특징1.

SAC의 취지는 DDPG와 달리 stochastic policy를 쓰면서 off-policy 빙밥을 쓰고 싶은 것임

(stochastic policy optimization과 DDPG-style 을 합친 것임)

** off-policy : training하는 것과 target하는 policy가 다른 것

 

 

SAC특징2: Entropy를 object에 사용함

Entropy: 무질서함(불확실성)

무질서함을 감소시키기 위해서는 에너지가 필요함

Entropy 수식: 

 

 

SAC의 critic network 의 2가지 network

1. value network(value function) - main(ψ) & targt(ψ') value network

2. Q network (Q function) - only main Q network

** 참고로 Actor-critic methods에서는 actor/critic network가 있었음 

 

 

 

SAC 목적함수

SAC 목적함수는 오른쪽과 같이 기존 목적함수에서 entropy를 추가한 것임

불확실성이 클 수록 return값이 커진다는 뜻임

alpha(temperature)가 값이 커질 수록 더 많은 exploration을 취함 (=Entropy사용목적)

 

 

 

SAC-Critic: value function

St라고 하는 state에서의 stochastic policy의 entropy임

아래에서는 확률1이 더 entropy가 커지므로, return값이 커지게 됨

 

 

 

V(s)와 Q(s,a)와의 차이

v(s): s라는 state에서 시작해서 끝까지 갔을 때의 return

Q(s,a): action이 추가되었다는 것이 v(s)와의 차이점임

 

 

 

SAC Critic - Q function

Q function도 오른쪽과 같이 기존 목적함수에서 entropy를 추가한 것임

단, t=0일 때는 entropy가 없음

왜냐하면 t=0일 때의 초기 action은 정해져있음

그러므로 t=1일 때만 entroy항을 추가하면 됨

 

 

 

The Soft State-Value Function 

From the equation of V and Q, we can get the relation between the two.

위의 식에 아래식을 적용하여

아래식과 같이 표현할 수 있음 ( = Soft State-Value Function ) 

 

 

 

SAC에서 사용하는 5개의 network

main: prediction을 위한 network임.

target network와 같은 network를 사용하므로 main과 target을 분리함

 

 

 

State의 value 계산방법

아래의 수식을 여러 episode에 대해 계산하여 평균값을 취함

또, Q(s,a)는 Q network θ에서 계산되고,

ㅠ(s,a)는 actor network에서 계산 된 값임

 

** 추가로, TD3의 clipped double Q learning을 적용하여

target state value 를 아래와 같이 계산함 (main Q network 2개 : j=1,2)

Clipped double Q learning을 활용하여 min값을 취함

 

 

 

Components of SAC: Critic

1. value network - loss function

위의 loss를 최소화 하는 방향으로 training 함

 

2. Q network (2개의 main network를 사용)

위의 loss를 최소화 하는 방향으로 training 함

 

 

 

* Q network에서 target network가 없는 이유: 

보통 training하는 main network와 target network를 분리해야 하지만,

SAC에서는 value network의 value값(빨간색밑줄)을 가져오기 때문에

분리할 필요가 없음

 

 

 

Components of SAC: Actor

Q값이 높아지는 방향으로 training 함 (파란색 증가)

Entropy도 증가하는 방향으로 training 함 (빨강색 증가)

 

 

 

코드 설명

 

Hard replacement vs. Soft replacement

 

 

Q network (j=1, 2)을 구성한 부분

 

 

아래 코드를 보면 continuous한 state를 알 수있음 

(mean과 std를 사용하여 Gaussian distribution에서의 policy를 취함)

log std 값이 너무 작아지거나 커지는 것을 방지하기 위해 min, max를 제한함

 

 

위에서 생성된 mean과 std로 분포를 만들고 그 안에서 action을 추출하는 코드임

추가로 노란색 줄은, 모든 액션값이 -1~1까지의 값이 되도록 tanh함수를 씌워준 것임

(왜냐하면 action의 범위가 -2~2까지이므로

파란색 줄은 아래 수식의 파란색 영역을 계산하기 위해 따로 기록해두려고 하는 것임

 

각 network를 표시해 보면 아래와 같음

 

 

 

 

 

reparameterization trick

Back propagation시, 미분이 가능하게 하기위해서 reparameterization form을 사용함

state를 넣어서 나온 mean과 std로 분포를 정하고 

그 분포에서 action을 sampling 함

x_t는 action을 의미함

log_prob는 x_t가 sampling 된 확률을 위에서 구한 분포를 통해서 알 수 있음

 

 

위에서 return된 값을 사용하여 policy loss를 아래와 같이 구함

 

'IT > 강화학습' 카테고리의 다른 글

Proximal Policy Optimization (PPO)  (0) 2022.12.13
Multi-Armed Bandit (MAB)  (0) 2022.12.04
Actor Critic methods - DDPG  (0) 2022.11.29
Policy Gradient Methods  (0) 2022.11.15
Deep Q Network  (0) 2022.11.01