|
546 | 546 | ;; c) sample a random action from the distribution and get the associated log-probability. |
547 | 547 | ((indeterministic-act actor) [-1 0 0]) |
548 | 548 |
|
;; We can also query the current log-probability of a previously sampled action.
(defn logprob-of-action
  "Returns a function of [observation action] that evaluates the
  log-probability of `action` under the actor's current policy
  distribution for `observation`."
  [actor]
  (fn [obs act]
    (-> actor
        (py. get_dist obs)
        (py. log_prob act))))
549 | 556 |
|
550 | | -(let [samples (repeatedly 256 #((indeterministic-act actor) [-1 0 0])) |
551 | | - scatter (tc/dataset {:x (map (fn [sample] (first (:action sample))) samples) |
552 | | - :y (map (fn [sample] (exp (first (:logprob sample)))) samples)})] |
553 | | - (-> scatter |
554 | | - (plotly/base {:=title "Actor output for a single observation"}) |
555 | | - (plotly/layer-point {:=x :x :=y :y}))) |
556 | | - |
;; Here is a plot of the probability density function (PDF) actor output for a single observation.
(without-gradient
 (let [logprob     (logprob-of-action actor)
       observation (tensor [-1 0 0])
       ;; PDF value = exp of the log-probability at a single action value.
       pdf-at      (fn [a] (exp (first (tolist (logprob observation (tensor [a]))))))
       actions     (range 0.0 1.01 0.01)]
   (-> (tc/dataset {:x actions
                    :y (mapv pdf-at actions)})
       (plotly/base {:=title "Actor output for a single observation" :=mode :lines})
       (plotly/layer-point {:=x :x :=y :y}))))
| 567 | + |
;; Finally, we can also query the entropy of the distribution.
;; By incorporating the entropy into the loss function later on, we can encourage exploration and prevent the probability density function from collapsing.
(defn entropy-of-distribution
  "Returns the entropy of the actor's policy distribution for `observation`."
  [actor observation]
  (-> actor
      (py. get_dist observation)
      (py. entropy)))

(without-gradient
 (entropy-of-distribution actor (tensor [-1 0 0])))
557 | 577 |
|
558 | | -;; # TODO |
559 | | -;; |
560 | | -;; * neural networks |
561 | | -;; * ppo |
562 | | -;; |
563 | 578 | ;; $\hat{A}_{T-1} = -V(S_{T-1}) + r_{T-1} + \gamma V(S_T)$ |
564 | 579 | ;; |
565 | 580 | ;; $\hat{A}_{T-2} = -V(S_{T-2}) + r_{T-2} + \gamma r_{T-1} + \gamma^2 V(S_T)$ |
|
0 commit comments