Skip to content

Commit

Permalink
Bump version to 1.8.2 (#3135)
Browse files Browse the repository at this point in the history
* Bump version to 1.8.2

* Run scripts/update_version.py

* Update to newer seaborn.kdeplot args

* Fix sns.scatterplot
  • Loading branch information
fritzo authored Sep 6, 2022
1 parent 7102cf5 commit ad53c72
Show file tree
Hide file tree
Showing 67 changed files with 95 additions and 95 deletions.
2 changes: 1 addition & 1 deletion examples/air/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def per_param_optim_args(param_name):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.1")
assert pyro.__version__.startswith("1.8.2")
parser = argparse.ArgumentParser(
description="Pyro AIR example", argument_default=argparse.SUPPRESS
)
Expand Down
2 changes: 1 addition & 1 deletion examples/baseball.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.1")
assert pyro.__version__.startswith("1.8.2")
parser = argparse.ArgumentParser(description="Baseball batting average using HMC")
parser.add_argument("-n", "--num-samples", nargs="?", default=200, type=int)
parser.add_argument("--num-chains", nargs="?", default=4, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/contrib/autoname/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.1")
assert pyro.__version__.startswith("1.8.2")
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument("-n", "--num-epochs", default=200, type=int)
parser.add_argument("--jit", action="store_true")
Expand Down
2 changes: 1 addition & 1 deletion examples/contrib/autoname/scoping_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.1")
assert pyro.__version__.startswith("1.8.2")
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument("-n", "--num-epochs", default=200, type=int)
args = parser.parse_args()
Expand Down
2 changes: 1 addition & 1 deletion examples/contrib/autoname/tree_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.1")
assert pyro.__version__.startswith("1.8.2")
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument("-n", "--num-epochs", default=100, type=int)
args = parser.parse_args()
Expand Down
2 changes: 1 addition & 1 deletion examples/contrib/cevae/synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.1")
assert pyro.__version__.startswith("1.8.2")
parser = argparse.ArgumentParser(
description="Causal Effect Variational Autoencoder"
)
Expand Down
2 changes: 1 addition & 1 deletion examples/contrib/epidemiology/regional.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.1")
assert pyro.__version__.startswith("1.8.2")
parser = argparse.ArgumentParser(
description="Regional compartmental epidemiology modeling using HMC"
)
Expand Down
2 changes: 1 addition & 1 deletion examples/contrib/epidemiology/sir.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.1")
assert pyro.__version__.startswith("1.8.2")
parser = argparse.ArgumentParser(
description="Compartmental epidemiology modeling using HMC"
)
Expand Down
2 changes: 1 addition & 1 deletion examples/contrib/forecast/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def transform(pred, truth):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.1")
assert pyro.__version__.startswith("1.8.2")
parser = argparse.ArgumentParser(description="Bart Ridership Forecasting Example")
parser.add_argument("--train-window", default=2160, type=int)
parser.add_argument("--test-window", default=336, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/contrib/funsor/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,7 +820,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.1")
assert pyro.__version__.startswith("1.8.2")
parser = argparse.ArgumentParser(
description="MAP Baum-Welch learning Bach Chorales"
)
Expand Down
2 changes: 1 addition & 1 deletion examples/contrib/gp/sv-dkl.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.1")
assert pyro.__version__.startswith("1.8.2")
parser = argparse.ArgumentParser(description="Pyro GP MNIST Example")
parser.add_argument(
"--data-dir",
Expand Down
2 changes: 1 addition & 1 deletion examples/contrib/oed/ab_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def main(num_vi_steps, num_bo_steps, seed):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.1")
assert pyro.__version__.startswith("1.8.2")
parser = argparse.ArgumentParser(description="A/B test experiment design using VI")
parser.add_argument("-n", "--num-vi-steps", nargs="?", default=5000, type=int)
parser.add_argument("--num-bo-steps", nargs="?", default=5, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/contrib/timeseries/gp_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.1")
assert pyro.__version__.startswith("1.8.2")
parser = argparse.ArgumentParser(description="contrib.timeseries example usage")
parser.add_argument("-n", "--num-steps", default=300, type=int)
parser.add_argument("-s", "--seed", default=0, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/cvae/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.1")
assert pyro.__version__.startswith("1.8.2")
# parse command line arguments
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument(
Expand Down
2 changes: 1 addition & 1 deletion examples/dmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ def do_evaluation():

# parse command-line arguments and execute the main method
if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.1")
assert pyro.__version__.startswith("1.8.2")

parser = argparse.ArgumentParser(description="parse args")
parser.add_argument("-n", "--num-epochs", type=int, default=5000)
Expand Down
2 changes: 1 addition & 1 deletion examples/eight_schools/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.1")
assert pyro.__version__.startswith("1.8.2")
parser = argparse.ArgumentParser(description="Eight Schools MCMC")
parser.add_argument(
"--num-samples",
Expand Down
2 changes: 1 addition & 1 deletion examples/eight_schools/svi.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.1")
assert pyro.__version__.startswith("1.8.2")
parser = argparse.ArgumentParser(description="Eight Schools SVI")
parser.add_argument(
"--lr", type=float, default=0.01, help="learning rate (default: 0.01)"
Expand Down
2 changes: 1 addition & 1 deletion examples/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.1")
assert pyro.__version__.startswith("1.8.2")
parser = argparse.ArgumentParser(
description="MAP Baum-Welch learning Bach Chorales"
)
Expand Down
2 changes: 1 addition & 1 deletion examples/inclined_plane.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.1")
assert pyro.__version__.startswith("1.8.2")
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument("-n", "--num-samples", default=500, type=int)
args = parser.parse_args()
Expand Down
2 changes: 1 addition & 1 deletion examples/lda.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.1")
assert pyro.__version__.startswith("1.8.2")
parser = argparse.ArgumentParser(
description="Amortized Latent Dirichlet Allocation"
)
Expand Down
2 changes: 1 addition & 1 deletion examples/lkj.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.1")
assert pyro.__version__.startswith("1.8.2")
parser = argparse.ArgumentParser(description="Demonstrate the use of an LKJ Prior")
parser.add_argument("--num-samples", nargs="?", default=200, type=int)
parser.add_argument("--n", nargs="?", default=500, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/minipyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def guide(data):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.1")
assert pyro.__version__.startswith("1.8.2")
parser = argparse.ArgumentParser(description="Mini Pyro demo")
parser.add_argument("-b", "--backend", default="minipyro")
parser.add_argument("-n", "--num-steps", default=1001, type=int)
Expand Down
16 changes: 8 additions & 8 deletions examples/neutra.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def main(args):
ylim=ylim,
title="Posterior \n(vanilla HMC)",
)
sns.kdeplot(vanilla_samples[:, 0], vanilla_samples[:, 1], ax=ax2)
sns.kdeplot(x=vanilla_samples[:, 0], y=vanilla_samples[:, 1], ax=ax2)

# 3(a). Fit a diagonal normal autoguide
logging.info("\nFitting a DiagNormal autoguide ...")
Expand All @@ -157,15 +157,15 @@ def main(args):
ylim=ylim,
title="Posterior \n(DiagNormal autoguide)",
)
sns.kdeplot(guide_samples[:, 0], guide_samples[:, 1], ax=ax3)
sns.kdeplot(x=guide_samples[:, 0], y=guide_samples[:, 1], ax=ax3)

# 3(b). Draw samples using NeuTra HMC
logging.info("\nDrawing samples using DiagNormal autoguide + NeuTra HMC ...")
neutra = NeuTraReparam(guide.requires_grad_(False))
neutra_model = poutine.reparam(model, config=lambda _: neutra)
mcmc = run_hmc(args, neutra_model)
zs = mcmc.get_samples()["x_shared_latent"]
sns.scatterplot(zs[:, 0], zs[:, 1], alpha=0.2, ax=ax4)
sns.scatterplot(x=zs[:, 0], y=zs[:, 1], alpha=0.2, ax=ax4)
ax4.set(
xlabel="x0",
ylabel="x1",
Expand All @@ -182,7 +182,7 @@ def main(args):
ylim=ylim,
title="Posterior (transformed) \n(DiagNormal + NeuTra HMC)",
)
sns.kdeplot(samples[:, 0], samples[:, 1], ax=ax5)
sns.kdeplot(x=samples[:, 0], y=samples[:, 1], ax=ax5)

# 4(a). Fit a BNAF autoguide
logging.info("\nFitting a BNAF autoguide ...")
Expand All @@ -201,15 +201,15 @@ def main(args):
ylim=ylim,
title="Posterior \n(BNAF autoguide)",
)
sns.kdeplot(guide_samples[:, 0], guide_samples[:, 1], ax=ax6)
sns.kdeplot(x=guide_samples[:, 0], y=guide_samples[:, 1], ax=ax6)

# 4(b). Draw samples using NeuTra HMC
logging.info("\nDrawing samples using BNAF autoguide + NeuTra HMC ...")
neutra = NeuTraReparam(guide.requires_grad_(False))
neutra_model = poutine.reparam(model, config=lambda _: neutra)
mcmc = run_hmc(args, neutra_model)
zs = mcmc.get_samples()["x_shared_latent"]
sns.scatterplot(zs[:, 0], zs[:, 1], alpha=0.2, ax=ax7)
sns.scatterplot(x=zs[:, 0], y=zs[:, 1], alpha=0.2, ax=ax7)
ax7.set(
xlabel="x0",
ylabel="x1",
Expand All @@ -226,13 +226,13 @@ def main(args):
ylim=ylim,
title="Posterior (transformed) \n(BNAF + NeuTra HMC)",
)
sns.kdeplot(samples[:, 0], samples[:, 1], ax=ax8)
sns.kdeplot(x=samples[:, 0], y=samples[:, 1], ax=ax8)

plt.savefig(os.path.join(os.path.dirname(__file__), "neutra.pdf"))


if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.1")
assert pyro.__version__.startswith("1.8.2")
parser = argparse.ArgumentParser(
description="Example illustrating NeuTra Reparametrizer"
)
Expand Down
2 changes: 1 addition & 1 deletion examples/rsa/generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.1")
assert pyro.__version__.startswith("1.8.2")
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument("-n", "--num-samples", default=10, type=int)
args = parser.parse_args()
Expand Down
2 changes: 1 addition & 1 deletion examples/rsa/hyperbole.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.1")
assert pyro.__version__.startswith("1.8.2")
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument("-n", "--num-samples", default=10, type=int)
parser.add_argument("--price", default=10000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/rsa/schelling.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.1")
assert pyro.__version__.startswith("1.8.2")
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument("-n", "--num-samples", default=10, type=int)
parser.add_argument("--depth", default=2, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/rsa/schelling_false.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.1")
assert pyro.__version__.startswith("1.8.2")
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument("-n", "--num-samples", default=10, type=int)
parser.add_argument("--depth", default=3, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/rsa/semantic_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def is_all_qud(world):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.1")
assert pyro.__version__.startswith("1.8.2")
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument("-n", "--num-samples", default=10, type=int)
args = parser.parse_args()
Expand Down
2 changes: 1 addition & 1 deletion examples/scanvi/scanvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.1")
assert pyro.__version__.startswith("1.8.2")
# Parse command line arguments
parser = argparse.ArgumentParser(
description="single-cell ANnotation using Variational Inference"
Expand Down
2 changes: 1 addition & 1 deletion examples/sir_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.1")
assert pyro.__version__.startswith("1.8.2")
parser = argparse.ArgumentParser(description="SIR epidemiology modeling using HMC")
parser.add_argument("-p", "--population", default=10, type=int)
parser.add_argument("-m", "--min-observations", default=3, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/sparse_gamma_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.1")
assert pyro.__version__.startswith("1.8.2")
# parse command line arguments
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument(
Expand Down
2 changes: 1 addition & 1 deletion examples/sparse_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.1")
assert pyro.__version__.startswith("1.8.2")
parser = argparse.ArgumentParser(description="Krylov KIT")
parser.add_argument("--num-data", type=int, default=750)
parser.add_argument("--num-steps", type=int, default=1000)
Expand Down
2 changes: 1 addition & 1 deletion examples/svi_horovod.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.1")
assert pyro.__version__.startswith("1.8.2")
parser = argparse.ArgumentParser(description="Distributed training via Horovod")
parser.add_argument("-o", "--outfile")
parser.add_argument("-s", "--size", default=1000000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/toy_mixture_model_discrete_enumeration.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def get_true_pred_CPDs(CPD, posterior_param):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.1")
assert pyro.__version__.startswith("1.8.2")
parser = argparse.ArgumentParser(description="Toy mixture model")
parser.add_argument("-n", "--num-steps", default=4000, type=int)
parser.add_argument("-o", "--num-obs", default=10000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/vae/ss_vae_M2.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ def main(args):
)

if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.1")
assert pyro.__version__.startswith("1.8.2")

parser = argparse.ArgumentParser(description="SS-VAE\n{}".format(EXAMPLE_RUN))

Expand Down
2 changes: 1 addition & 1 deletion examples/vae/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.1")
assert pyro.__version__.startswith("1.8.2")
# parse command line arguments
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument(
Expand Down
2 changes: 1 addition & 1 deletion examples/vae/vae_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def main(args):


if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.1")
assert pyro.__version__.startswith("1.8.2")
parser = argparse.ArgumentParser(description="VAE using MNIST dataset")
parser.add_argument("-n", "--num-epochs", nargs="?", default=10, type=int)
parser.add_argument("--batch_size", nargs="?", default=128, type=int)
Expand Down
Loading

0 comments on commit ad53c72

Please sign in to comment.