Skip to content

Commit b040d09

Browse files
committed
Sampling environment
1 parent 677d070 commit b040d09

1 file changed

Lines changed: 47 additions & 0 deletions

File tree

src/ppo/main.clj

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,53 @@
579579

580580
(without-gradient (entropy-of-distribution actor (tensor [-1 0 0])))
581581

582+
;; ## Proximal Policy Optimization
583+
;;
584+
;; ### Sampling data
585+
;;
586+
;; In order to perform optimization, we sample the environment using the current policy (indeterministic action using actor).
587+
(defn sample-environment
  "Collect trajectory data from the environment.

  Runs `policy` for `size` steps, starting from a fresh environment produced
  by `environment-factory`. Whenever a step ends the episode (done or
  truncated), a new environment is created and sampling continues.

  Returns a map of parallel vectors, one entry per step:
  :observations, :actions, :logprobs, :next-observations,
  :rewards, :dones and :truncates."
  [environment-factory policy size]
  (loop [state      (environment-factory)
         trajectory {:observations      []
                     :actions           []
                     :logprobs          []
                     :next-observations []
                     :rewards           []
                     :dones             []
                     :truncates         []}
         steps-left size]
    (if-not (pos? steps-left)
      trajectory
      (let [observation      (environment-observation state)
            ;; Policy returns a sampled action together with its log-probability.
            {:keys [action logprob]} (policy observation)
            reward           (environment-reward state action)
            done?            (environment-done? state)
            truncated?       (environment-truncate? state)
            ;; Auto-reset: start a fresh environment when the episode ends.
            state'           (if (or done? truncated?)
                               (environment-factory)
                               (environment-update state action))
            next-observation (environment-observation state')]
        (recur state'
               (-> trajectory
                   (update :observations conj observation)
                   (update :actions conj action)
                   (update :logprobs conj logprob)
                   (update :next-observations conj next-observation)
                   (update :rewards conj reward)
                   (update :dones conj done?)
                   (update :truncates conj truncated?))
               (dec steps-left))))))
625+
626+
;; Here, for example, we sample 3 consecutive states of the pendulum.
(sample-environment pendulum-factory (indeterministic-act actor) 3)
628+
582629
;; $\hat{A}_{T-1} = -V(S_{T-1}) + r_{T-1} + \gamma V(S_T)$
583630
;;
584631
;; $\hat{A}_{T-2} = -V(S_{T-2}) + r_{T-2} + \gamma r_{T-1} + \gamma^2 V(S_T)$

0 commit comments

Comments
 (0)