Commit 55dd5f8

Merge pull request #360 from wedesoft/ppo-draft-7

PPO draft 7

2 parents d183486 + 925482b commit 55dd5f8

1 file changed: src/ppo/main.clj (89 additions & 31 deletions)
@@ -107,12 +107,16 @@
 ;; A simulation step of the pendulum is implemented as follows.
 (defn update-state
   "Perform simulation step of pendulum"
-  ([{:keys [angle velocity t]} {:keys [control]} {:keys [dt motor gravitation length max-speed]}]
+  ([{:keys [angle velocity t]}
+    {:keys [control]}
+    {:keys [dt motor gravitation length max-speed]}]
    (let [gravity (pendulum-gravity gravitation length angle)
          motor (motor-acceleration control motor)
          t (+ t dt)
          acceleration (+ motor gravity)
-         velocity (max (- max-speed) (min max-speed (+ velocity (* acceleration dt))))
+         velocity (max (- max-speed)
+                       (min max-speed
+                            (+ velocity (* acceleration dt))))
          angle (+ angle (* velocity dt))]
      {:angle angle
       :velocity velocity
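;; For illustration, a single step starting from the same state as the Quil sketch's setup (angle PI, velocity 0) could look as follows; the config values here are illustrative assumptions, not the project's actual defaults:
(update-state {:angle PI :velocity 0.0 :t 0.0}
              {:control 0.5}
              {:dt 0.01 :motor 8.0 :gravitation 9.81 :length 1.0 :max-speed 10.0})
;; => state map with :t advanced by dt and :velocity clamped to [-max-speed, max-speed]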
@@ -191,7 +195,9 @@
 ;; Note that it is important that the reward function is continuous because machine learning uses gradient descent.
 (defn reward
   "Reward function"
-  [{:keys [angle velocity]} {:keys [angle-weight velocity-weight control-weight]} {:keys [control]}]
+  [{:keys [angle velocity]}
+   {:keys [angle-weight velocity-weight control-weight]}
+   {:keys [control]}]
   (- (+ (* angle-weight (sqr (normalize-angle angle)))
         (* velocity-weight (sqr velocity))
         (* control-weight (sqr control)))))
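;; In formula form the reward is the negative weighted sum of squares
;; $R = -(w_{\theta}\,\bar{\theta}^2 + w_{\omega}\,\omega^2 + w_u\,u^2)$
;; with normalized angle $\bar{\theta}$, velocity $\omega$ and control $u$; it is largest (zero) when all three are zero.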
@@ -242,8 +248,11 @@
   (q/stroke-weight 1)
   (q/ellipse pendulum-x pendulum-y size size)
   (q/no-fill)
-  (q/arc origin-x origin-y (* 2 arc-radius) (* 2 arc-radius) (to-radians -45) (to-radians 225))
-  (q/with-translation [(+ origin-x (* (cos (to-radians tip-angle)) arc-radius)) (+ origin-y (* (sin (to-radians tip-angle)) arc-radius))]
+  (q/arc origin-x origin-y
+         (* 2 arc-radius) (* 2 arc-radius)
+         (to-radians -45) (to-radians 225))
+  (q/with-translation [(+ origin-x (* (cos (to-radians tip-angle)) arc-radius))
+                       (+ origin-y (* (sin (to-radians tip-angle)) arc-radius))]
     (q/with-rotation [(to-radians (if positive 225 -45))]
       (q/triangle 0 (if positive 10 -10) -5 0 5 0)))
   (when (:save config)
@@ -260,7 +269,10 @@
     :size [854 480]
     :setup #(setup PI 0.0)
     :update (fn [state]
-              (let [action {:control (min 1.0 (max -1.0 (- 1.0 (/ (q/mouse-x) (/ (q/width) 2.0)))))}
+              (let [action {:control (min 1.0
+                                          (max -1.0
+                                               (- 1.0 (/ (q/mouse-x)
+                                                         (/ (q/width) 2.0)))))}
                     state (update-state state action config)]
                 (when (done? state config) (async/close! done-chan))
                 (reset! last-action action)
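;; With the 854-pixel-wide window used here, this maps the left window edge (mouse-x 0) to a control of 1.0, the centre (mouse-x 427) to 0.0, and the right edge (mouse-x 854) to -1.0, clamped to $[-1, 1]$.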
@@ -605,7 +617,9 @@
           reward (environment-reward state action)
           done (environment-done? state)
           truncate (environment-truncate? state)
-          next-state (if (or done truncate) (environment-factory) (environment-update state action))
+          next-state (if (or done truncate)
+                       (environment-factory)
+                       (environment-update state action))
           next-observation (environment-observation next-state)]
       (recur next-state
              (conj observations observation)
@@ -677,25 +691,37 @@
   "Compute difference between actual reward plus discounted estimate of next state and estimated value of current state"
   [{:keys [observations next-observations rewards dones]} critic gamma]
   (mapv (fn [observation next-observation reward done]
-          (- (+ reward (if done 0.0 (* gamma (critic next-observation)))) (critic observation)))
+          (- (+ reward
+                (if done 0.0 (* gamma (critic next-observation))))
+             (critic observation)))
         observations next-observations rewards dones))
 
 ;; If the reward is zero and the critic outputs constant zero, there is no difference between the expected and received reward.
-(deltas {:observations [[4]] :next-observations [[3]] :rewards [0] :dones [false]} (constantly 0) 1.0)
+(deltas {:observations [[4]] :next-observations [[3]] :rewards [0] :dones [false]}
+        (constantly 0)
+        1.0)
 
 ;; If the reward is 1.0 and the critic outputs zero for both observations, the difference is 1.0.
-(deltas {:observations [[4]] :next-observations [[3]] :rewards [1] :dones [false]} (constantly 0) 1.0)
+(deltas {:observations [[4]] :next-observations [[3]] :rewards [1] :dones [false]}
+        (constantly 0)
+        1.0)
 
 ;; If the reward is 1.0 and the difference of critic outputs is also 1.0, then there is no difference between the expected and received reward (when $\gamma=1$).
 (defn linear-critic [observation] (first observation))
-(deltas {:observations [[4]] :next-observations [[3]] :rewards [1] :dones [false]} linear-critic 1.0)
+(deltas {:observations [[4]] :next-observations [[3]] :rewards [1] :dones [false]}
+        linear-critic
+        1.0)
 
 ;; If the next critic value is 1.0 and discounted with 0.5 and the current critic value is 2.0, we expect a reward of 1.5.
 ;; If we only get a reward of 1.0, the difference is -0.5.
-(deltas {:observations [[2]] :next-observations [[1]] :rewards [1] :dones [false]} linear-critic 0.5)
+(deltas {:observations [[2]] :next-observations [[1]] :rewards [1] :dones [false]}
+        linear-critic
+        0.5)
 
 ;; If the run is terminated, the current critic value is compared with the reward, which in this case is the last reward received in this run.
-(deltas {:observations [[4]] :next-observations [[3]] :rewards [4] :dones [true]} linear-critic 1.0)
+(deltas {:observations [[4]] :next-observations [[3]] :rewards [4] :dones [true]}
+        linear-critic
+        1.0)
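;; In formula form, each delta is the temporal difference residual
;; $\delta_t = r_t + \gamma\,(1 - d_t)\,V(s_{t+1}) - V(s_t)$
;; where $d_t = 1$ marks a terminated episode, so the last example evaluates to $4 + 0 - 4 = 0$.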

 ;; #### Implementation of Advantages
 ;;
@@ -713,13 +739,18 @@
             (reverse (map vector deltas dones truncates)))))))
 
 ;; For example, if using a discount factor of 0.5, the advantages approach 2.0 asymptotically when going backwards in time.
-(advantages {:dones [false false false] :truncates [false false false]} [1.0 1.0 1.0] 0.5 1.0)
+(advantages {:dones [false false false] :truncates [false false false]}
+            [1.0 1.0 1.0]
+            0.5
+            1.0)
 
 ;; When an episode is terminated (or truncated), the accumulation of advantages starts again when going backwards in time.
 ;; I.e. the computation of advantages does not distinguish between terminated and truncated episodes (unlike the deltas).
 (advantages {:dones [false false true false false true]
              :truncates [false false false false false false]}
-            [1.0 1.0 1.0 1.0 1.0 1.0] 0.5 1.0)
+            [1.0 1.0 1.0 1.0 1.0 1.0]
+            0.5
+            1.0)
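;; In formula form, the accumulation going backwards in time is
;; $A_t = \delta_t + \gamma\,\lambda\,A_{t+1}$ (restarting from $A_{t+1} = 0$ after a done or truncated step),
;; so the first example with $\delta_t = 1$ and $\gamma\,\lambda = 0.5$ yields $[1.75, 1.5, 1.0]$, the partial sums of a geometric series with limit $2$.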

 ;; We add the advantages to the batch of samples with the following function.
 (defn assoc-advantages
@@ -786,13 +817,18 @@
     (torch/neg
       (torch/min
         (torch/mul probability-ratios advantages)
-        (torch/mul (torch/clamp probability-ratios (- 1.0 epsilon) (+ 1.0 epsilon)) advantages)))))
+        (torch/mul (torch/clamp probability-ratios (- 1.0 epsilon) (+ 1.0 epsilon))
+                   advantages)))))
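;; This is the standard PPO clipped surrogate objective, negated because the optimizer minimises:
;; $L = -\min\big(r\,A,\ \mathrm{clip}(r,\,1 - \epsilon,\,1 + \epsilon)\,A\big)$
;; with probability ratio $r$ and advantage $A$.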

 ;; We can plot the objective function for a single action and a positive advantage.
 (without-gradient
   (let [ratios (range 0.0 2.01 0.01)
         loss (fn [ratio advantage epsilon]
-               (toitem (torch/neg (clipped-surrogate-loss (tensor ratio) (tensor advantage) epsilon))))
+               (toitem
+                 (torch/neg
+                   (clipped-surrogate-loss (tensor ratio)
+                                           (tensor advantage)
+                                           epsilon))))
         scatter (tc/dataset
                   {:x ratios
                    :y (map (fn [ratio] (loss ratio 0.5 0.2)) ratios)})]
@@ -804,7 +840,11 @@
 (without-gradient
   (let [ratios (range 0.0 2.01 0.01)
         loss (fn [ratio advantage epsilon]
-               (toitem (torch/neg (clipped-surrogate-loss (tensor ratio) (tensor advantage) epsilon))))
+               (toitem
+                 (torch/neg
+                   (clipped-surrogate-loss (tensor ratio)
+                                           (tensor advantage)
+                                           epsilon))))
         scatter (tc/dataset
                   {:x ratios
                    :y (map (fn [ratio] (loss ratio -0.5 0.2)) ratios)})]
@@ -819,7 +859,11 @@
   "Compute loss value for batch of samples and actor"
   [samples actor epsilon entropy-factor]
   (let [ratios (probability-ratios samples (logprob-of-action actor))
-        entropy (torch/mul entropy-factor (torch/neg (torch/mean (entropy-of-distribution actor (:observations samples)))))
+        entropy (torch/mul
+                  entropy-factor
+                  (torch/neg
+                    (torch/mean
+                      (entropy-of-distribution actor (:observations samples)))))
         surrogate-loss (clipped-surrogate-loss ratios (:advantages samples) epsilon)]
     (torch/add surrogate-loss entropy)))
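;; Written out, the actor loss is $L_{actor} = L_{clip} - c\,\bar{H}$, where $c$ is the entropy-factor and $\bar{H}$ the mean entropy of the action distribution, so minimising the loss also rewards high entropy and thereby encourages exploration.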

@@ -828,7 +872,8 @@
   "Normalize advantages"
   [batch]
   (let [advantages (:advantages batch)]
-    (assoc batch :advantages (torch/div (torch/sub advantages (torch/mean advantages)) (torch/std advantages)))))
+    (assoc batch :advantages (torch/div (torch/sub advantages (torch/mean advantages))
+                                        (torch/std advantages)))))
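;; I.e. the advantages are standardised to zero mean and unit standard deviation: $\hat{A} = (A - \mathrm{mean}(A)) / \mathrm{std}(A)$.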

 ;; ### Preparing Samples
 ;;
@@ -858,7 +903,8 @@
   ([samples]
    (shuffle-samples samples (random-order (python/len (first (vals samples))))))
   ([samples indices]
-   (zipmap (keys samples) (map #(torch/index_select % 0 (torch/tensor indices)) (vals samples)))))
+   (zipmap (keys samples)
+           (map #(torch/index_select % 0 (torch/tensor indices)) (vals samples)))))
 
 ;; Here is an example of shuffling observations:
 (shuffle-samples {:observations (tensor [[1] [2] [3] [4] [5] [6] [7] [8] [9] [10]])})
@@ -869,7 +915,9 @@
 (defn create-batches
   "Create mini batches from environment samples"
   [batch-size samples]
-  (apply mapv (fn [& args] (zipmap (keys samples) args)) (map #(py. % split batch-size) (vals samples))))
+  (apply mapv
+         (fn [& args] (zipmap (keys samples) args))
+         (map #(py. % split batch-size) (vals samples))))
 
 (create-batches 5 {:observations (tensor [[1] [2] [3] [4] [5] [6] [7] [8] [9] [10]])})
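;; With the ten observations above and a batch size of 5 this returns two mini-batches, each a map holding a tensor of five observations.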

@@ -930,31 +978,36 @@
         actor-optimizer (adam-optimizer actor lr weight-decay)
         critic-optimizer (adam-optimizer critic lr weight-decay)]
     (doseq [epoch (range n-epochs)]
-      (let [samples (sample-with-advantage-and-critic-target factory actor critic (* batch-size n-batches)
-                                                             batch-size gamma lambda)]
+      (let [samples (sample-with-advantage-and-critic-target factory actor critic
+                                                             (* batch-size n-batches)
+                                                             batch-size
+                                                             gamma lambda)]
        (doseq [k (range n-updates)]
          (doseq [batch samples]
            (let [loss (actor-loss batch actor epsilon @entropy-factor)]
              (py. actor-optimizer zero_grad)
              (py. loss backward)
              (utils/clip_grad_norm_ (py. actor parameters) 0.5)
              (py. actor-optimizer step)
-             (swap! smooth-actor-loss (fn [x] (+ (* 0.999 x) (* 0.001 (toitem loss)))))))
+             (swap! smooth-actor-loss
+                    (fn [x] (+ (* 0.999 x) (* 0.001 (toitem loss)))))))
          (doseq [batch samples]
            (let [loss (critic-loss batch critic)]
              (py. critic-optimizer zero_grad)
              (py. loss backward)
              (py. critic-optimizer step)
-             (swap! smooth-critic-loss (fn [x] (+ (* 0.999 x) (* 0.001 (toitem loss))))))))
+             (swap! smooth-critic-loss
+                    (fn [x] (+ (* 0.999 x) (* 0.001 (toitem loss))))))))
        (println "Epoch:" epoch
                 "Actor Loss:" @smooth-actor-loss
                 "Critic Loss:" @smooth-critic-loss
                 "Entropy Factor:" @entropy-factor))
      (without-gradient
        (doseq [input [[1 0 -1.0] [1 0 1.0] [0 -1 -1.0] [0 -1 1.0] [0 1 -1.0] [0 1 1.0] [-1 0 -1.0] [-1 0 1.0]]]
-         (println input
-                  "->" (action (tolist (py. actor deterministic_act (tensor input))))
-                  "entropy" (toitem (entropy-of-distribution actor (tensor input))))))
+         (println
+           input
+           "->" (action (tolist (py. actor deterministic_act (tensor input))))
+           "entropy" (toitem (entropy-of-distribution actor (tensor input))))))
      (swap! entropy-factor * entropy-decay)
      (when (= (mod epoch checkpoint) (dec checkpoint))
        (println "Saving models")
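;; The smoothed losses printed above are exponential moving averages; the update performed by the swap! calls corresponds to this standalone sketch (same 0.999/0.001 weights as in the training loop):
(defn ema
  "Blend a new value into an exponential moving average"
  [average value]
  (+ (* 0.999 average) (* 0.001 value)))
;; e.g. (ema 1.0 0.0) => 0.999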
@@ -980,8 +1033,13 @@
     :update (fn [state]
               (let [observation (observation state config)
                     action (if (q/mouse-pressed?)
-                             (action (tolist (py. actor deterministic_act (tensor observation))))
-                             {:control (min 1.0 (max -1.0 (- 1.0 (/ (q/mouse-x) (/ (q/width) 2.0)))))})
+                             (action (tolist (py. actor
+                                                  deterministic_act
+                                                  (tensor observation))))
+                             {:control (min 1.0
+                                            (max -1.0
+                                                 (- 1.0 (/ (q/mouse-x)
+                                                           (/ (q/width) 2.0)))))})
                     state (update-state state action)]
                 (when (done? state) (async/close! done-chan))
                 (reset! last-action action)
