Pyro Showcase


Website: https://pyro.ai

'Pyro is a universal probabilistic programming language (PPL) written in Python and supported by PyTorch on the backend. Pyro enables flexible and expressive deep probabilistic modeling, unifying the best of modern deep learning and Bayesian modeling.'

If you want to use deep probabilistic models (VAEs, Deep Markov Models, ...) or are looking for a framework to develop your own probabilistic model, Pyro is worth a look.

Applications: Deep Generative Models, Time Series, Gaussian Processes, Epidemiology.
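Deep probabilistic models such as VAEs are usually fit in Pyro by stochastic variational inference (`pyro.infer.SVI`, typically with the `Trace_ELBO` loss), which maximises the evidence lower bound ELBO = E_q[log p(x, z) - log q(z)]. The sketch below does not use Pyro. It is a plain NumPy illustration, assuming a toy model z ~ N(0, 1), x | z ~ N(z, 1) and a Gaussian guide q(z) chosen only for this example, of how a reparameterised Monte Carlo ELBO estimate behaves.

```python
import numpy as np


def log_normal(v, mean, var):
    """Log-density of the univariate normal N(mean, var) at v."""
    return -0.5 * np.log(2.0 * np.pi * var) - 0.5 * (v - mean) ** 2 / var


def elbo_estimate(x, q_mean, q_std, num_samples=10_000, seed=0):
    """Monte Carlo ELBO for the toy model (an assumption for illustration)
    z ~ N(0, 1), x | z ~ N(z, 1), with guide q(z) = N(q_mean, q_std**2)."""
    rng = np.random.default_rng(seed)
    # Reparameterised samples: z = mean + std * eps with eps ~ N(0, 1)
    z = q_mean + q_std * rng.standard_normal(num_samples)
    log_joint = log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)  # log p(z) + log p(x | z)
    log_q = log_normal(z, q_mean, q_std ** 2)
    return float(np.mean(log_joint - log_q))


x = 1.0
log_evidence = float(log_normal(x, 0.0, 2.0))  # exact log p(x): marginally x ~ N(0, 2)
elbo_exact_guide = elbo_estimate(x, q_mean=x / 2, q_std=np.sqrt(0.5))  # q = true posterior
elbo_poor_guide = elbo_estimate(x, q_mean=-1.0, q_std=1.0)  # a deliberately bad guide
print(log_evidence, elbo_exact_guide, elbo_poor_guide)
```

When the guide equals the exact posterior N(x/2, 1/2), the bound is tight and matches log p(x); a worse guide gives a strictly lower value, which is why maximising the ELBO pushes q towards the posterior.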

In [ ]:
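# Mount Google Drive and work from the notebook folder; the example scripts
# (scanvi.py, sir_hmc.py) are run from this folder below
from google.colab import drive
drive.mount('/content/drive')

%cd /content/drive/My Drive/Colab Notebooks

# Install Pyro, plus the single-cell libraries used by the scANVI example
!pip3 install pyro-ppl
!pip3 install scanpy
!pip3 install scvi-tools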

Example 1: Single Cell RNA Sequencing Analysis with VAEs

'We use a semi-supervised deep generative model of transcriptomics data to propagate labels from a small set of labeled cells to a larger set of unlabeled cells. In particular we use a dataset of peripheral blood mononuclear cells (PBMC) from 10x Genomics and (approximately) reproduce Figure 6 in reference [1].'

References: [1] "Harmonization and Annotation of Single-cell Transcriptomics data with Deep Generative Models," Chenling Xu, Romain Lopez, Edouard Mehlman, Jeffrey Regier, Michael I. Jordan, Nir Yosef.

In [25]:
# Train for 5 epochs instead of 60 to keep the run short
! python scanvi.py --num-epochs=5 --plot
INFO     File data/PurifiedPBMCDataset.h5ad already downloaded                  
INFO     Using batches from adata.obs["batch"]                                  
INFO     Using labels from adata.obs["labels"]                                  
INFO     Using data from adata.X                                                
INFO     Computing library size prior per batch                                 
INFO     Successfully registered anndata object containing 42919 cells, 21932   
         vars, 4 batches, 4 labels, and 0 proteins. Also registered 0 extra     
         categorical covariates and 0 extra continuous covariates.              
INFO     Please do not further modify adata until model is trained.             
[Epoch 0000]  Loss: 0.37968
[Epoch 0001]  Loss: 0.25020
[Epoch 0002]  Loss: 0.17182
[Epoch 0003]  Loss: 0.13909
[Epoch 0004]  Loss: 0.12325

[Figure: scanvi.py --plot output]

Fig 1: Seed cells, colored by their annotation (assigned using known marker genes).

Figs 2-5: Posterior probability, obtained from the trained model, of each cell belonging to each of the four T cell subtypes.
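Once such per-cell posterior probabilities exist, a common way to propagate labels is to give each unlabeled cell its most probable subtype while seed cells keep their annotation. The sketch below assumes the classifier produces unnormalised scores (logits) per subtype; the function name, the input layout and the -1 "unlabeled" convention are hypothetical and not taken from scanvi.py.

```python
import numpy as np


def propagate_labels(logits, seed_labels):
    """Turn per-cell class logits into probabilities and hard labels.

    logits: (num_cells, num_classes) unnormalised scores (hypothetical input).
    seed_labels: (num_cells,) int array, -1 marking unlabeled cells (assumed convention).
    """
    # Numerically stable softmax over the class axis
    shifted = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=1, keepdims=True)
    # Unlabeled cells take the most probable class; seed cells keep their label
    labels = np.where(seed_labels >= 0, seed_labels, probs.argmax(axis=1))
    return probs, labels


# Three toy cells, four subtypes; the third cell is a seed cell annotated as class 0
logits = np.array([[2.0, 0.1, 0.0, -1.0],
                   [0.0, 0.0, 3.0, 0.5],
                   [1.0, 1.5, 0.2, 0.0]])
seed = np.array([-1, -1, 0])
probs, labels = propagate_labels(logits, seed)
print(labels)  # [0 2 0]
```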

Example 2: Epidemiological inference via Hamiltonian Monte Carlo

'Generate outbreak data for a population of 10000 observed for 60 days, then infer infection parameters and forecast new infections for another 30 days.' The data are simulated with a discrete-time SIR model, which splits the population into Susceptible, Infected and Recovered compartments (population = S + I + R).

Reference: https://github.com/pyro-ppl/pyro/blob/dev/examples/sir_hmc.py
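A discrete SIR model advances the three compartments one day at a time with random transitions. The NumPy sketch below simulates such an outbreak to make the generated data easier to picture; it is not the code in sir_hmc.py. The binomial parameterisation (infection probability 1 - exp(-R0 * I / (tau * N)), recovery probability 1 / (1 + tau), each new infection reported with probability rho) and the recovery time tau = 7 days are assumptions; R0 = 1.5 and rho = 0.5 match the "truth" values printed in the output below.

```python
import numpy as np


def simulate_sir(population, duration, R0=1.5, tau=7.0, rho=0.5, seed=0):
    """Simulate a discrete-time stochastic SIR outbreak (assumed parameterisation):
      new infections  S2I ~ Binomial(S, 1 - exp(-R0 * I / (tau * population)))
      new recoveries  I2R ~ Binomial(I, 1 / (1 + tau))
      reported cases  obs ~ Binomial(S2I, rho)
    """
    rng = np.random.default_rng(seed)
    S, I, R = population - 1, 1, 0  # start with a single infected individual
    new_cases, observed, states = [], [], []
    for _ in range(duration):
        p_inf = 1.0 - np.exp(-R0 * I / (tau * population))
        S2I = rng.binomial(S, p_inf)
        I2R = rng.binomial(I, 1.0 / (1.0 + tau))
        S, I, R = S - S2I, I + S2I - I2R, R + I2R
        new_cases.append(int(S2I))
        observed.append(int(rng.binomial(S2I, rho)))  # only a fraction is reported
        states.append((int(S), int(I), int(R)))
    return np.array(new_cases), np.array(observed), states


new, obs, states = simulate_sir(population=10000, duration=60)
print(f"Observed {obs.sum()}/{new.sum()} infections")
```

The "Observed 708/1471 infections" line in the output below appears to be the same kind of ratio (reported versus all new infections), which is close to rho = 0.5. With a single initial case, a simulated outbreak can also die out early.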

In [24]:
# -p: population size, -d: days observed, -f: days to forecast
! python sir_hmc.py -p 10000 -d 60 -f 30 --plot
Generating data...
Observed 708/1471 infections:
1 1 0 0 0 0 1 1 0 0 1 1 2 0 1 1 0 1 1 1 2 1 1 3 3 7 3 5 5 5 2 10 4 10 6 8 8 13 11 17 11 13 16 17 12 21 25 31 26 10 34 23 37 37 42 38 42 42 45 49
Running inference...
Sample: 100%|██| 300/300 [01:31,  3.29it/s, step size=6.41e-02, acc. prob=0.885]


                mean       std    median      5.0%     95.0%     n_eff     r_hat
  I_aux[0]      2.48      0.75      2.32      1.39      3.51      9.98      1.00
  I_aux[1]      4.07      1.03      4.13      2.55      5.77     28.41      1.08
  I_aux[2]      5.01      0.82      5.00      3.57      6.16     22.15      1.07
  I_aux[3]      4.95      1.17      5.09      3.01      6.50     21.79      1.16
  I_aux[4]      4.67      1.07      4.68      2.85      6.13      7.16      1.14
  I_aux[5]      4.38      1.10      4.32      2.50      6.07     29.30      1.00
  I_aux[6]      5.06      0.90      5.02      3.71      6.52     24.36      1.05
  I_aux[7]      5.67      0.94      5.48      4.54      7.39      9.50      1.10
  I_aux[8]      4.55      1.25      4.45      2.86      6.75     25.57      1.03
  I_aux[9]      4.08      1.19      3.87      2.20      6.10     22.67      1.03
 I_aux[10]      4.81      1.10      4.63      3.25      6.68     11.05      1.04
 I_aux[11]      5.77      1.44      5.61      3.22      7.86     17.06      1.07
 I_aux[12]      7.62      1.25      7.45      5.92     10.02     28.46      1.00
 I_aux[13]      7.74      1.22      7.68      5.95      9.82     30.23      1.00
 I_aux[14]      8.79      1.09      8.68      7.29     10.57     43.36      1.00
 I_aux[15]      9.79      1.57      9.85      6.72     11.87     11.07      1.26
 I_aux[16]      9.61      1.91      9.63      6.42     12.65      6.20      1.49
 I_aux[17]     10.37      2.28     10.45      6.26     13.48      4.98      1.61
 I_aux[18]     10.97      1.99     10.69      8.08     14.46      5.08      1.50
 I_aux[19]     11.94      2.06     12.08      7.96     14.65      5.12      1.34
 I_aux[20]     13.92      2.65     13.95      9.56     18.18      9.62      1.10
 I_aux[21]     15.29      2.91     14.94     10.38     19.79     10.39      1.05
 I_aux[22]     16.70      2.85     16.51     11.84     20.19     10.26      1.04
 I_aux[23]     20.52      2.74     20.17     16.97     25.22     14.64      1.01
 I_aux[24]     23.91      2.59     23.71     19.39     27.22     15.90      1.00
 I_aux[25]     30.66      2.95     30.31     26.63     35.81     15.16      1.12
 I_aux[26]     32.97      3.00     32.97     27.85     37.55     15.36      1.09
 I_aux[27]     37.56      3.12     37.01     33.65     44.14     17.93      1.04
 I_aux[28]     41.95      3.57     41.79     37.09     47.52     19.97      1.02
 I_aux[29]     46.44      3.92     46.14     38.94     51.50     15.85      1.01
 I_aux[30]     48.21      4.62     48.07     41.64     56.17     12.61      1.00
 I_aux[31]     57.15      4.94     57.56     49.66     64.97     10.96      1.04
 I_aux[32]     59.84      5.35     59.58     51.21     67.82     12.36      1.07
 I_aux[33]     67.60      6.23     67.34     57.41     77.22     12.82      1.08
 I_aux[34]     70.93      6.03     70.73     61.16     80.49     16.76      1.07
 I_aux[35]     76.40      5.47     75.68     66.88     84.43     23.79      1.04
 I_aux[36]     82.92      5.03     82.24     75.99     91.12     21.68      1.02
 I_aux[37]     94.27      5.02     93.95     86.79    102.92     19.12      1.00
 I_aux[38]    102.94      5.79    103.41     94.07    112.32     15.35      1.00
 I_aux[39]    116.25      6.24    116.38    105.12    125.54     12.45      1.00
 I_aux[40]    124.01      6.50    124.66    111.46    133.53     12.56      1.01
 I_aux[41]    134.06      7.33    134.02    122.37    146.24     12.44      1.06
 I_aux[42]    146.18      7.76    146.51    132.72    157.70     13.18      1.07
 I_aux[43]    159.32      8.46    159.20    146.05    173.46     12.92      1.06
 I_aux[44]    168.71      9.67    167.68    154.83    186.64     12.04      1.08
 I_aux[45]    185.78     10.41    184.46    170.41    203.78      9.17      1.15
 I_aux[46]    205.49     11.73    204.05    187.21    225.64      8.71      1.19
 I_aux[47]    228.54     13.30    227.40    212.30    256.50      7.25      1.28
 I_aux[48]    246.12     13.91    244.64    220.90    268.35      6.90      1.32
 I_aux[49]    250.80     13.97    248.92    229.14    274.00      5.89      1.37
 I_aux[50]    277.93     14.61    276.39    256.11    303.55      5.82      1.37
 I_aux[51]    294.99     16.38    293.10    269.14    319.73      5.54      1.37
 I_aux[52]    324.63     17.46    323.57    297.04    350.89      5.29      1.35
 I_aux[53]    353.03     19.44    350.69    321.82    382.07      4.54      1.43
 I_aux[54]    384.58     20.78    382.88    356.25    418.31      4.56      1.40
 I_aux[55]    411.26     22.23    411.33    375.48    441.19      4.23      1.46
 I_aux[56]    439.62     24.43    438.69    398.34    477.19      4.17      1.50
 I_aux[57]    468.47     26.30    466.51    428.35    509.93      3.90      1.60
 I_aux[58]    498.41     29.22    496.96    459.90    549.66      3.97      1.58
 I_aux[59]    530.38     32.84    529.58    485.21    587.78      4.02      1.63
        R0      1.56      0.06      1.55      1.47      1.64     54.25      1.10
  S_aux[0]   9997.40      0.76   9997.40   9996.17   9998.64     15.28      1.00
  S_aux[1]   9995.71      0.94   9995.66   9994.24   9997.23     22.73      1.01
  S_aux[2]   9994.43      0.87   9994.34   9993.08   9995.73     14.47      1.08
  S_aux[3]   9993.92      1.00   9994.10   9992.38   9995.58     12.83      1.00
  S_aux[4]   9993.60      1.01   9993.75   9992.18   9995.45     12.67      1.01
  S_aux[5]   9993.58      0.93   9993.71   9992.04   9994.90     12.93      1.03
  S_aux[6]   9992.37      0.91   9992.47   9990.68   9993.60     16.76      1.01
  S_aux[7]   9991.22      0.94   9991.26   9989.84   9992.80     14.10      1.00
  S_aux[8]   9991.04      1.00   9991.09   9989.53   9992.66     11.52      1.00
  S_aux[9]   9990.91      1.02   9990.86   9989.15   9992.44     11.06      1.00
 S_aux[10]   9989.63      1.09   9989.63   9987.92   9991.34     10.84      1.00
 S_aux[11]   9988.24      1.37   9988.30   9986.05   9990.56      7.01      1.08
 S_aux[12]   9985.75      1.38   9985.72   9983.59   9987.94      7.09      1.06
 S_aux[13]   9984.83      1.48   9984.89   9982.08   9986.80      7.01      1.06
 S_aux[14]   9983.01      1.51   9983.13   9980.39   9985.23      7.63      1.05
 S_aux[15]   9980.82      1.47   9980.70   9978.60   9983.14      9.43      1.00
 S_aux[16]   9979.83      1.65   9979.86   9977.54   9982.70      8.15      1.03
 S_aux[17]   9978.02      1.76   9978.04   9975.41   9980.83      7.17      1.13
 S_aux[18]   9976.14      1.77   9976.03   9973.02   9978.88     12.79      1.11
 S_aux[19]   9974.10      1.77   9973.83   9971.39   9977.26     16.95      1.16
 S_aux[20]   9970.70      1.35   9970.76   9968.49   9972.94     29.13      1.07
 S_aux[21]   9968.20      1.60   9968.16   9966.17   9971.13     28.43      1.05
 S_aux[22]   9965.52      1.63   9965.73   9962.63   9967.90     24.28      1.04
 S_aux[23]   9960.10      1.75   9960.10   9957.32   9962.79     52.79      1.02
 S_aux[24]   9954.74      1.63   9954.80   9952.79   9958.21     45.85      1.01
 S_aux[25]   9945.28      1.78   9945.39   9942.51   9948.36     33.70      1.04
 S_aux[26]   9939.44      1.79   9939.56   9936.71   9942.34     34.02      1.04
 S_aux[27]   9931.07      2.35   9931.42   9927.16   9934.75     21.29      1.05
 S_aux[28]   9922.29      2.75   9922.42   9918.13   9926.59     20.42      1.05
 S_aux[29]   9912.74      3.22   9912.84   9907.37   9917.62     18.73      1.04
 S_aux[30]   9905.50      3.64   9905.34   9899.06   9910.70     16.24      1.01
 S_aux[31]   9890.52      4.19   9890.01   9882.97   9896.66     12.72      1.00
 S_aux[32]   9880.39      4.61   9879.97   9874.15   9888.51     11.60      1.01
 S_aux[33]   9864.63      5.31   9864.07   9856.76   9873.34     10.42      1.01
 S_aux[34]   9852.21      5.72   9852.23   9842.25   9860.11     10.42      1.00
 S_aux[35]   9837.20      5.77   9838.14   9827.68   9845.91     10.80      1.00
 S_aux[36]   9821.31      5.99   9821.98   9811.17   9830.51     11.28      1.00
 S_aux[37]   9799.70      6.28   9800.29   9789.93   9809.13     10.87      1.02
 S_aux[38]   9779.01      6.94   9779.19   9768.69   9792.88     10.32      1.04
 S_aux[39]   9751.99      7.45   9752.49   9740.92   9766.44      9.40      1.08
 S_aux[40]   9729.62      7.77   9730.23   9718.26   9744.69      8.15      1.16
 S_aux[41]   9704.29      8.64   9704.81   9687.33   9716.32      7.07      1.23
 S_aux[42]   9675.34      9.50   9675.74   9658.27   9688.11      6.16      1.32
 S_aux[43]   9644.10     10.06   9644.10   9629.96   9661.24      5.86      1.40
 S_aux[44]   9615.95     10.55   9615.83   9599.18   9632.25      4.99      1.57
 S_aux[45]   9578.00     11.45   9578.79   9558.27   9595.19      4.60      1.72
 S_aux[46]   9534.93     13.36   9535.45   9510.82   9553.77      4.47      1.86
 S_aux[47]   9484.82     14.72   9484.94   9461.15   9507.63      4.13      1.99
 S_aux[48]   9437.22     16.19   9437.81   9414.63   9465.15      3.85      2.11
 S_aux[49]   9402.43     17.59   9403.89   9374.93   9427.60      3.66      2.19
 S_aux[50]   9344.05     19.72   9346.50   9314.32   9371.84      3.59      2.15
 S_aux[51]   9293.30     22.22   9296.66   9258.55   9322.21      3.49      2.22
 S_aux[52]   9227.66     25.07   9233.32   9187.92   9259.82      3.54      2.08
 S_aux[53]   9158.50     28.52   9163.67   9117.09   9201.48      3.45      2.06
 S_aux[54]   9082.78     32.03   9090.60   9041.17   9134.18      3.50      1.98
 S_aux[55]   9007.95     35.26   9015.33   8957.47   9062.69      3.41      2.02
 S_aux[56]   8927.66     40.01   8934.04   8866.73   8986.92      3.47      1.97
 S_aux[57]   8844.19     44.97   8851.60   8775.30   8909.05      3.43      1.98
 S_aux[58]   8755.67     50.73   8762.55   8681.01   8836.91      3.45      1.94
 S_aux[59]   8661.00     56.82   8668.57   8574.77   8746.93      3.42      1.95
       rho      0.53      0.03      0.53      0.49      0.58      4.70      1.57

Number of divergences: 0
R0: truth = 1.5, estimate = 1.56 ± 0.0567
rho: truth = 0.5, estimate = 0.531 ± 0.0281
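The r_hat column is the Gelman-Rubin convergence diagnostic, and the ± figures above match the posterior std column. A common rule of thumb treats r_hat above about 1.1 as a sign that sampling has not converged; here many entries exceed that (several S_aux values exceed 2, and rho sits at 1.57), and n_eff drops below 5 for a number of variables, so with only 200 samples the estimates should be read as rough even though they land near the truth. The sketch below computes a split R-hat with NumPy, assuming that is the variant the r_hat column reports; it is an illustration, not Pyro's implementation.

```python
import numpy as np


def split_r_hat(chains):
    """Split Gelman-Rubin statistic for samples of shape (num_chains, num_samples)."""
    chains = np.asarray(chains, dtype=float)
    if chains.ndim == 1:
        chains = chains[None, :]  # treat a 1-D array as a single chain
    n = chains.shape[1] // 2
    # Split every chain in half so drift within one chain shows up as disagreement
    halves = np.concatenate([chains[:, :n], chains[:, n:2 * n]], axis=0)
    within = halves.var(axis=1, ddof=1).mean()  # W: mean within-chain variance
    between = halves.mean(axis=1).var(ddof=1)   # B / n: variance of the chain means
    var_plus = (n - 1) / n * within + between   # pooled posterior variance estimate
    return float(np.sqrt(var_plus / within))


rng = np.random.default_rng(0)
well_mixed = rng.standard_normal((4, 1000))
drifting = np.concatenate([rng.standard_normal(500), 5.0 + rng.standard_normal(500)])
print(split_r_hat(well_mixed), split_r_hat(drifting))
```

A well-mixed chain gives a value near 1, while a chain whose two halves disagree gives a value far above 1, which is the pattern visible in the later S_aux and I_aux rows.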
Forecasting 30 steps ahead...
Median prediction of new infections (starting on day 0):
1 2 1 0 0 0 1 1 0 0 1 1 2 1 2 2 1 2 2 2 3 2 3 5 5 9 6 8 9 10 7 15 10 15 12 15 16 21 20 27 22 25 29 31 28 38 43 50 47 35 58 51 65 69 75 75 80 83 88 94 100 106 113 118 123 124 132 134 140 144 145 146 150 151 151 153 153 151 150 147 145 142 139 133 131 127 124 116 112 107
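The line above holds 90 values: the 60 observed days followed by the 30 forecast days. It counts all new infections rather than only reported ones, which would explain why the early values are roughly double the observed counts (rho ≈ 0.5). A posterior-predictive forecast can be sketched by pushing each posterior draw forward with the same kind of simulator and taking a per-day median. Everything here is an assumption for illustration: the (R0, S, I) draw format, tau = 7 days, the binomial transitions, and the inputs, which jitter R0 around the printed estimate and loosely read the final state off the S_aux[59] and I_aux[59] rows (treating those auxiliary variables as S and I counts is itself an assumption).

```python
import numpy as np


def forecast_median(draws, population, steps, tau=7.0, seed=0):
    """Median posterior-predictive number of new infections for each future day.

    draws: iterable of (R0, S, I) tuples, one joint posterior draw each of the
    reproduction number and the last susceptible/infected counts
    (hypothetical format, not the one used inside sir_hmc.py).
    tau: assumed mean recovery time in days.
    """
    rng = np.random.default_rng(seed)
    draws = list(draws)
    paths = np.zeros((len(draws), steps))
    for k, (R0, S, I) in enumerate(draws):
        S, I = int(S), int(I)
        for t in range(steps):
            # Assumed binomial transitions of a discrete-time SIR model
            S2I = rng.binomial(S, 1.0 - np.exp(-R0 * I / (tau * population)))
            I2R = rng.binomial(I, 1.0 / (1.0 + tau))
            S, I = S - S2I, I + S2I - I2R
            paths[k, t] = S2I
    # Summarise across draws: one median per forecast day
    return np.median(paths, axis=0)


# Illustrative inputs only; real draws would come from the MCMC samples
rng = np.random.default_rng(1)
draws = [(r0, 8661, 530) for r0 in rng.normal(1.56, 0.06, size=200)]
median_forecast = forecast_median(draws, population=10000, steps=30)
print(np.round(median_forecast).astype(int))
```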

[Figure: sir_hmc.py --plot output]