|
(without-gradient (entropy-of-distribution actor (tensor [-1 0 0])))
|
;; ## Proximal Policy Optimization
;;
;; ### Sampling data
;;
;; To perform the optimization, we first sample the environment using the current policy (non-deterministic actions drawn from the actor via `indeterministic-act`).
(defn sample-environment
  "Collect trajectory data from environment"
  [environment-factory policy size]
  (loop [state (environment-factory)
         observations []
         actions []
         logprobs []
         next-observations []
         rewards []
         dones []
         truncates []
         i size]
    (if (pos? i)
      (let [observation (environment-observation state)
            sample (policy observation)
            action (:action sample)
            logprob (:logprob sample)
            reward (environment-reward state action)
            done (environment-done? state)
            truncate (environment-truncate? state)
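            ;; on termination or truncation, reset the environment so the
            ;; next step starts a fresh episode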
            next-state (if (or done truncate)
                         (environment-factory)
                         (environment-update state action))
            next-observation (environment-observation next-state)]
        (recur next-state
               (conj observations observation)
               (conj actions action)
               (conj logprobs logprob)
               (conj next-observations next-observation)
               (conj rewards reward)
               (conj dones done)
               (conj truncates truncate)
               (dec i)))
      {:observations observations
       :actions actions
       :logprobs logprobs
       :next-observations next-observations
       :rewards rewards
       :dones dones
       :truncates truncates})))
;; Here, for example, we sample 3 consecutive states of the pendulum.
(sample-environment pendulum-factory (indeterministic-act actor) 3)
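;; Note that the trajectory comes back in columnar form: one vector per
;; field, with entries aligned by timestep.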
| 628 | + |
;; $\hat{A}_{T-1} = -V(S_{T-1}) + r_{T-1} + \gamma V(S_T)$
;;
;; $\hat{A}_{T-2} = -V(S_{T-2}) + r_{T-2} + \gamma r_{T-1} + \gamma^2 V(S_T)$
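;;
;; These estimates follow a backward recursion: the discounted return at
;; step $t$ is $r_t + \gamma$ times the return at step $t+1$, seeded with
;; the bootstrap value $V(S_T)$, and the advantage is that return minus
;; $V(S_t)$. Below is a minimal plain-Clojure sketch of this recursion;
;; the name `n-step-advantages` and its signature are illustrative (not
;; part of this notebook's API), and it assumes `rewards` and `values`
;; are plain vectors of numbers rather than tensors.
(defn n-step-advantages
  "Advantage estimates: discounted returns accumulated backwards from the
  bootstrap value V(S_T), minus the predicted values V(S_t)."
  [rewards values bootstrap-value gamma]
  (loop [t (dec (count rewards))
         ret bootstrap-value
         advantages (list)]
    (if (neg? t)
      (vec advantages)
      (let [ret (+ (nth rewards t) (* gamma ret))]
        (recur (dec t) ret (cons (- ret (nth values t)) advantages))))))

;; For example, with two steps and $\gamma = 0.9$:
(n-step-advantages [1.0 1.0] [0.5 0.5] 0.25 0.9)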
|