---
title: "🎲 MCMC + Diffusion Sampling"
date: "2024-04-15"
lightbox: auto
toc: true
image: "./assets/qdist.svg"
open-graph:
  title: "🎲 MCMC + Diffusion Sampling"
  image: "./assets/qdist.svg"
twitter-card:
  site: "saforem2"
  creator: "saforem2"
  title: "🎲 MCMC + Diffusion Sampling"
  image: "./assets/qdist.svg"
# format:
#   html:
#     output-file: "diffusion.html"
---
## Denoising Diffusion Probabilistic Models
## Imports / Setup {#imports--setup}
::: {#f924fcb4-d11e-44c4-acc5-caf9f50ea803 .cell .code execution_count="1" execution="{\"iopub.execute_input\":\"2023-09-21T17:23: 25.674067Z\",\"iopub.status.busy\":\"2023-09-21T17:23: 25.673461Z\",\"iopub.status.idle\":\"2023-09-21T17:23: 26.976641Z\",\"shell.execute_reply\":\"2023-09-21T17:23: 26.976187Z\",\"shell.execute_reply.started\":\"2023-09-21T17:23: 25.674029Z\"}" jukit_cell_id="ZlqRKELWeh"}
``` python
from __future__ import annotations

import math
import os
import sys
import time
from dataclasses import dataclass
from random import randrange

import matplotlib.pyplot as plt
import numpy as np
import scipy
import seaborn as sns
```
:::
::: {#707c54d9-1cee-4452-a9e8-beab67f6bb5e .cell .code execution_count="2" execution="{'iopub.execute_input':'2023-09-21T17:23: 27.398552Z','iopub.status.busy':'2023-09-21T17:23: 27.398037Z','iopub.status.idle':'2023-09-21T17:23: 29.370420Z','shell.execute_reply':'2023-09-21T17:23: 29.370048Z','shell.execute_reply.started':'2023-09-21T17:23: 27.398512Z'}"}
``` python
from ezpz.dist import setup_torch

# Pick a random port for the DDP rendezvous to avoid collisions
port = np.random.randint(5000, 6000)
print(f"Using port: {port}")
RANK = setup_torch(
    backend="DDP",
    port=f"{port}",
)
```
:::
::: {.output .cell-output .cell-output-display}
```{=html}
<pre style="white-space:pre;overflow-x:auto;line-height:normal;">
Using port: 5561
Using DDP for distributed training
Global Rank: <span style="color: #008080; text-decoration-color: #008080; font-weight: bold">0</span> <span style="color: #800080; text-decoration-color: #800080">/</span> <span style="color: #008080; text-decoration-color: #008080; font-weight: bold">0</span>
</pre>
```
:::
::: {#b55dfeab-27a1-4322-af80-44380adbaadf .cell .code execution_count="3" execution="{\"iopub.execute_input\":\"2023-09-21T17:23: 29.371628Z\",\"iopub.status.busy\":\"2023-09-21T17:23: 29.371422Z\",\"iopub.status.idle\":\"2023-09-21T17:23: 30.756041Z\",\"shell.execute_reply\":\"2023-09-21T17:23: 30.755862Z\",\"shell.execute_reply.started\":\"2023-09-21T17:23: 29.371616Z\"}" jukit_cell_id="ZlqRKELWeh"}
``` python
%matplotlib inline
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

import opinionated
from l2hmc.main import build_experiment
from l2hmc.utils.rich import get_console
from l2hmc.utils.plot_helpers import set_plot_style
from l2hmc.diffusion.diffusion import PureMH, MH_Diffusion

set_plot_style()
console = get_console()
print(console.is_jupyter)
# Turn off rich's Jupyter-specific rendering for this console
if console.is_jupyter:
    console.is_jupyter = False
print(console.is_jupyter)
```
:::
::: {.output .cell-output .cell-output-display}
```{=html}
<pre style="white-space:pre;overflow-x:auto;line-height:normal;">
Using device: cpu
Failed to download font: Source Sans Pro, skipping!
Failed to download font: Titillium WebRoboto Condensed, skipping!
True
False
</pre>
```
:::
::: {#a6523472-488a-49e9-bfd6-6ca8fb953650 .cell .code execution_count="5" execution="{\"iopub.execute_input\":\"2023-09-21T17:23: 30.954621Z\",\"iopub.status.busy\":\"2023-09-21T17:23: 30.954289Z\",\"iopub.status.idle\":\"2023-09-21T17:23: 31.061599Z\",\"shell.execute_reply\":\"2023-09-21T17:23: 31.058688Z\",\"shell.execute_reply.started\":\"2023-09-21T17:23: 30.954598Z\"}"}
``` python
plt.style.use(opinionated.STYLES['opinionated_min'])
sns.set_context('notebook')
```
:::
::: {#3c521572 .cell .markdown}
## 2D $U(1)$ {#2d-u1}
:::
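The 2D $U(1)$ model samples link angles $\theta_\mu(n) \in [-\pi, \pi)$ on a periodic lattice. The following is a minimal NumPy sketch that shows how the plaquette angles and the action for a single configuration can be computed. It assumes the standard Wilson gauge action, $S = \beta \sum_P (1 - \cos\theta_P)$, and it assumes links are stored as an array of shape `[2, L, L]`, which matches the per-chain `2, 32, 32` layout used in this notebook. The helper names are hypothetical, and this is not l2hmc's implementation.

```python
import numpy as np


def plaquette_angles(x: np.ndarray) -> np.ndarray:
    """Plaquette angles for links x of shape [2, L, L] (direction, axis 0, axis 1).

    theta_P(n) = theta_0(n) + theta_1(n + e0) - theta_0(n + e1) - theta_1(n),
    with periodic boundary conditions handled by np.roll.
    """
    x0, x1 = x[0], x[1]
    return (
        x0
        + np.roll(x1, -1, axis=0)  # theta_1(n + e0)
        - np.roll(x0, -1, axis=1)  # theta_0(n + e1)
        - x1
    )


def wilson_action(x: np.ndarray, beta: float) -> float:
    """Wilson gauge action S = beta * sum_P (1 - cos theta_P)."""
    return float(beta * np.sum(1.0 - np.cos(plaquette_angles(x))))


rng = np.random.default_rng(0)
cold = np.zeros((2, 32, 32))                        # all links at angle 0
hot = rng.uniform(-np.pi, np.pi, size=(2, 32, 32))  # uniformly random links
print(wilson_action(cold, beta=6.0))                # 0.0 for the ordered configuration
print(np.mean(np.cos(plaquette_angles(hot))))       # average plaquette, close to 0
```

Under this convention, the average plaquette `np.mean(np.cos(theta_P))` is presumably the quantity reported as `plaqs` in the HMC logs further down. Treat that correspondence as an assumption.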
::: {#b853ba57 .cell .code execution_count="6" execution="{\"iopub.execute_input\":\"2023-09-21T17:23: 46.575842Z\",\"iopub.status.busy\":\"2023-09-21T17:23: 46.575002Z\",\"iopub.status.idle\":\"2023-09-21T17:23: 46.644400Z\",\"shell.execute_reply\":\"2023-09-21T17:23: 46.644215Z\",\"shell.execute_reply.started\":\"2023-09-21T17:23: 46.575810Z\"}" jukit_cell_id="x8zCEYjYIg"}
``` python
from l2hmc.configs import dict_to_list_of_overrides

seed = np.random.randint(0, 2**32)
console.print(f"seed = {seed}")
overrides = {
    "seed": f"{seed}",
    "precision": "float32",
    "init_wandb": False,
    "init_aim": False,
    "use_wandb": False,
    "dynamics": {
        "latvolume": [32, 32],
        "nleapfrog": 10,
        "nchains": 16,
        "eps": 0.05,
    },
    "network": {
        "use_batch_norm": False,
    },
    "annealing_schedule": {
        "beta_init": 6.0,
        "beta_final": 6.0,
    },
}
OVERRIDES = dict_to_list_of_overrides(overrides)
```
:::
::: {.output .cell-output .cell-output-display}
seed = 1675333995
:::
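`dict_to_list_of_overrides` converts the nested `overrides` dict into a list of Hydra-style overrides for `build_experiment`. We haven't checked l2hmc's exact implementation. What follows is therefore a hypothetical re-implementation (`flatten_overrides` is our own name) that illustrates the general idea: nested keys are mapped to dotted `key.subkey=value` strings.

```python
from __future__ import annotations

from typing import Any


def flatten_overrides(d: dict[str, Any], prefix: str = "") -> list[str]:
    """Flatten a nested dict into Hydra-style 'a.b=value' override strings."""
    out: list[str] = []
    for key, val in d.items():
        name = f"{prefix}.{key}" if prefix else str(key)
        if isinstance(val, dict):
            # Recurse into nested config groups, extending the dotted key
            out.extend(flatten_overrides(val, prefix=name))
        else:
            out.append(f"{name}={val}")
    return out


print(flatten_overrides({"precision": "float32", "dynamics": {"nleapfrog": 10, "latvolume": [32, 32]}}))
# ['precision=float32', 'dynamics.nleapfrog=10', 'dynamics.latvolume=[32, 32]']
```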
::: {#5ae7bc5d .cell .code execution_count="7" execution="{\"iopub.execute_input\":\"2023-09-21T17:23: 53.284515Z\",\"iopub.status.busy\":\"2023-09-21T17:23: 53.284077Z\",\"iopub.status.idle\":\"2023-09-21T17:23: 53.362789Z\",\"shell.execute_reply\":\"2023-09-21T17:23: 53.362643Z\",\"shell.execute_reply.started\":\"2023-09-21T17:23: 53.284485Z\"}" jukit_cell_id="iKzpP0xANE"}
``` python
from pathlib import Path
from l2hmc.common import get_timestamp
from enrich.console import get_theme, Console

console = Console(theme=get_theme())
OUTDIR = Path(
    'l2hmc-diffusion-2dU1'
).joinpath(get_timestamp("%Y-%m-%d"))
OUTDIR.mkdir(exist_ok=True, parents=True)
console.print(f"OUTDIR: {OUTDIR}")
date = get_timestamp('%Y-%m-%d')
PLOTS_DIR = OUTDIR.joinpath('plots')
PLOTS_DIR.mkdir(exist_ok=True, parents=True)
console.print(f"Saving figures to: {PLOTS_DIR}")
```
:::
::: {.output .cell-output .cell-output-display}
```{=html}
<pre style="white-space:pre;overflow-x:auto;line-height:normal;">
OUTDIR: l2hmc-diffusion-2dU1/<span style="color: #2094f3; text-decoration-color: #2094f3">2023</span>-<span style="color: #2094f3; text-decoration-color: #2094f3">09</span>-<span style="color: #2094f3; text-decoration-color: #2094f3">21</span>
Saving figures to: l2hmc-diffusion-2dU1/<span style="color: #2094f3; text-decoration-color: #2094f3">2023</span>-<span style="color: #2094f3; text-decoration-color: #2094f3">09</span>-<span style="color: #2094f3; text-decoration-color: #2094f3">21</span>/plots
</pre>
```
:::
::: {#4ca95dea .cell .code execution_count="8" execution="{\"iopub.execute_input\":\"2023-09-21T17:23: 55.013350Z\",\"iopub.status.busy\":\"2023-09-21T17:23: 55.012634Z\",\"iopub.status.idle\":\"2023-09-21T17:24: 09.645580Z\",\"shell.execute_reply\":\"2023-09-21T17:24: 09.645178Z\",\"shell.execute_reply.started\":\"2023-09-21T17:23: 55.013316Z\"}" jukit_cell_id="lHICETmP7j"}
``` python
#os.environ['MASTER_PORT'] = '5436'
exp = build_experiment(
    overrides=[
        *OVERRIDES,
        'framework=pytorch',
        'backend=DDP',
    ]
)
```
:::
::: {.output .cell-output .cell-output-display}
```{=html}
<pre style="white-space:pre;overflow-x:auto;line-height:normal; margin-bottom:0;">
[09/21/23 12:23:55][INFO][dist.py:226] - Caught MASTER_PORT:5561 from environment!
[09/21/23 12:23:55][INFO][dist.py:338] - Global Rank: 0 / 0
[09/21/23 12:23:58][INFO][experiment.py:251] - Creating outputs/2023-09-21-122358/pytorch/train
[09/21/23 12:23:58][INFO][experiment.py:251] - Creating outputs/2023-09-21-122358/pytorch/eval
[09/21/23 12:23:58][INFO][experiment.py:251] - Creating outputs/2023-09-21-122358/pytorch/hmc
[09/21/23 12:23:58][INFO][dist.py:226] - Caught MASTER_PORT:5561 from environment!
[09/21/23 12:23:58][INFO][dist.py:226] - Caught MASTER_PORT:5561 from environment!
[09/21/23 12:24:06][INFO][trainer.py:441] - Looking for checkpoints in:
/Users/samforeman/projects/saforem2/l2hmc-qcd/src/l2hmc/checkpoints/U1/2-32-32/nlf-10/xsplit-True/sepnets-True/merge-True/conv-8-16-32-64-128_5-3-3-3-2_2-2-2-2-2/net-16-16-16-16_dp-0.2_bn-False/pytorch
[09/21/23 12:24:06][WARNING][trainer.py:437] - No checkpoints found to load from
[09/21/23 12:24:06][WARNING][trainer.py:437] - Restoring global step from ckpt! self._gstep: 0
[09/21/23 12:24:06][WARNING][trainer.py:437] - Using `torch.optim.Adam` optimizer
[09/21/23 12:24:06][INFO][trainer.py:284] - num_params in model: 958628260
[09/21/23 12:24:09][WARNING][trainer.py:250] - logging with freq 50 for wandb.watch
</pre>
```
:::
::: {#de33e6c7 .cell .code execution_count="10" execution="{\"iopub.execute_input\":\"2023-09-21T17:24: 11.069316Z\",\"iopub.status.busy\":\"2023-09-21T17:24: 11.067417Z\",\"iopub.status.idle\":\"2023-09-21T17:24: 11.148435Z\",\"shell.execute_reply\":\"2023-09-21T17:24: 11.148240Z\",\"shell.execute_reply.started\":\"2023-09-21T17:24: 11.069247Z\"}" jukit_cell_id="9XP3KbsEEH"}
``` python
state = exp.trainer.dynamics.random_state(6.0)
xdim = state.x.flatten().shape[0]
dim = xdim
low_bound = (-np.pi) * np.ones(dim)
high_bound = (np.pi) * np.ones(dim)
sigma = 0.15
retrains = 10
samples_per_retrain = 100
diffusion_prob = 0.1
```
:::
::: {#292fced2 .cell .code execution_count="11" execution="{\"iopub.execute_input\":\"2023-09-21T17:24: 21.526369Z\",\"iopub.status.busy\":\"2023-09-21T17:24: 21.525718Z\",\"iopub.status.idle\":\"2023-09-21T17:25: 01.413937Z\",\"shell.execute_reply\":\"2023-09-21T17:25: 01.413545Z\",\"shell.execute_reply.started\":\"2023-09-21T17:24: 21.526320Z\"}" jukit_cell_id="KCwuqXXGds" scrolled="true"}
``` python
sns.set_context('notebook')
outputs = {}
outputs['hmc'] = exp.trainer.eval(
    job_type='hmc',
    beta=6.0,
    nprint=100,
    nchains=16,
    eval_steps=1000,
)
#hdset = exp.save_dataset(job_type='hmc', nchains=1)
```
:::
::: {.output .cell-output .cell-output-display}
```{=html}
<pre style="white-space:pre;overflow-x:auto;">
[09/21/23 12:24:21][WARNING][trainer.py:437] - Step size `eps` not specified for HMC! Using default: 0.1000 for generic HMC
[09/21/23 12:24:21][WARNING][trainer.py:437] - x.shape (original): torch.Size([16, 2, 32, 32])
[09/21/23 12:24:21][WARNING][trainer.py:437] - x[:nchains].shape: torch.Size([16, 2, 32, 32])
[09/21/23 12:24:21][INFO][trainer.py:1058] - eps=0.1
beta=6.0
nlog=10
table=<rich.table.Table object at 0x2e1b98520>
nprint=100
eval_steps=1000
nleapfrog=20
[09/21/23 12:24:24][INFO][trainer.py:1188] - hstep=0 dt=0.024 beta=6.000 loss=3.410 dQsin=0.125 dQint=0.000 energy=1586.502 logprob=1586.502 logdet=0.000 acc=0.472 sumlogdet=0.000 acc_mask=0.500 plaqs=0.909 intQ=0.000 sinQ=0.051
[09/21/23 12:24:27][INFO][trainer.py:1188] - hstep=100 dt=0.026 beta=6.000 loss=2.876 dQsin=0.163 dQint=0.000 energy=1555.800 logprob=1555.800 logdet=0.000 acc=0.593 sumlogdet=0.000 acc_mask=0.688 plaqs=0.912 intQ=-0.125 sinQ=-0.159
[09/21/23 12:24:31][INFO][trainer.py:1188] - hstep=200 dt=0.025 beta=6.000 loss=4.678 dQsin=0.088 dQint=0.063 energy=1569.994 logprob=1569.994 logdet=0.000 acc=0.451 sumlogdet=0.000 acc_mask=0.250 plaqs=0.912 intQ=-0.187 sinQ=-0.149
[09/21/23 12:24:34][INFO][trainer.py:1188] - hstep=300 dt=0.024 beta=6.000 loss=14.041 dQsin=0.094 dQint=0.000 energy=1554.118 logprob=1554.118 logdet=0.000 acc=0.438 sumlogdet=0.000 acc_mask=0.438 plaqs=0.914 intQ=-0.125 sinQ=-0.114
[09/21/23 12:24:38][INFO][trainer.py:1188] - hstep=400 dt=0.024 beta=6.000 loss=-0.739 dQsin=0.199 dQint=0.000 energy=1566.516 logprob=1566.516 logdet=0.000 acc=0.509 sumlogdet=0.000 acc_mask=0.562 plaqs=0.912 intQ=-0.437 sinQ=-0.452
[09/21/23 12:24:41][INFO][trainer.py:1188] - hstep=500 dt=0.045 beta=6.000 loss=1.545 dQsin=0.100 dQint=0.000 energy=1570.837 logprob=1570.837 logdet=0.000 acc=0.448 sumlogdet=0.000 acc_mask=0.562 plaqs=0.911 intQ=0.125 sinQ=0.189
[09/21/23 12:24:45][INFO][trainer.py:1188] - hstep=600 dt=0.025 beta=6.000 loss=3.780 dQsin=0.094 dQint=0.000 energy=1568.012 logprob=1568.012 logdet=0.000 acc=0.463 sumlogdet=0.000 acc_mask=0.500 plaqs=0.913 intQ=0.438 sinQ=0.466
[09/21/23 12:24:50][INFO][trainer.py:1188] - hstep=700 dt=0.023 beta=6.000 loss=-0.902 dQsin=0.113 dQint=0.000 energy=1563.778 logprob=1563.778 logdet=0.000 acc=0.475 sumlogdet=0.000 acc_mask=0.375 plaqs=0.913 intQ=0.688 sinQ=0.628
[09/21/23 12:24:53][INFO][trainer.py:1188] - hstep=800 dt=0.024 beta=6.000 loss=11.416 dQsin=0.061 dQint=0.000 energy=1561.427 logprob=1561.427 logdet=0.000 acc=0.339 sumlogdet=0.000 acc_mask=0.438 plaqs=0.913 intQ=0.813 sinQ=0.755
[09/21/23 12:24:57][INFO][trainer.py:1188] - hstep=900 dt=0.028 beta=6.000 loss=1.114 dQsin=0.127 dQint=0.000 energy=1564.465 logprob=1564.465 logdet=0.000 acc=0.699 sumlogdet=0.000 acc_mask=0.625 plaqs=0.913 intQ=0.938 sinQ=0.893
</pre>
```
![](assets/diffusion/4291d2280ebb3b7d7fd410b6fbd5a7999b28d1e6.svg)
:::
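`exp.trainer.eval(job_type='hmc', ...)` runs plain HMC. According to the log above, it used `eps=0.1` and `nleapfrog=20`. For readers unfamiliar with the update, here is a minimal, self-contained HMC sketch. It shows the leapfrog integration and the Metropolis accept/reject that each HMC step performs. Note that it runs on a 1D standard Gaussian rather than the lattice, and it is not l2hmc's implementation. The target and function names are illustrative assumptions.

```python
import numpy as np


def leapfrog(x, p, grad_u, eps, nsteps):
    """Integrate Hamilton's equations for H = U(x) + p^2/2 with the leapfrog scheme."""
    p = p - 0.5 * eps * grad_u(x)      # initial half step in momentum
    for _ in range(nsteps - 1):
        x = x + eps * p                # full step in position
        p = p - eps * grad_u(x)        # full step in momentum
    x = x + eps * p
    p = p - 0.5 * eps * grad_u(x)      # final half step in momentum
    return x, p


def hmc_step(x, u, grad_u, eps, nsteps, rng):
    """One HMC update: resample momenta, integrate, then Metropolis accept/reject."""
    p0 = rng.standard_normal(x.shape)
    x1, p1 = leapfrog(x, p0, grad_u, eps, nsteps)
    h0 = u(x) + 0.5 * np.sum(p0 ** 2)
    h1 = u(x1) + 0.5 * np.sum(p1 ** 2)
    if np.log(rng.uniform()) < h0 - h1:  # accept with probability min(1, exp(h0 - h1))
        return x1, True
    return x, False


def u_gauss(x):
    return 0.5 * np.sum(x ** 2)  # U(x) = -log N(x; 0, 1) up to a constant


def grad_u_gauss(x):
    return x


rng = np.random.default_rng(0)
x = np.zeros(1)
samples, n_acc = [], 0
for _ in range(5000):
    x, accepted = hmc_step(x, u_gauss, grad_u_gauss, eps=0.1, nsteps=20, rng=rng)
    n_acc += accepted
    samples.append(x[0])

samples = np.asarray(samples)
print(n_acc / len(samples), samples.mean(), samples.var())
```

For this easy target the acceptance rate is close to 1, and the sample mean and variance approach 0 and 1. On the lattice, `u` would be the gauge action and `x` the link angles.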
::: {#a62b5515-1fbe-4a78-9330-a4e675adc55d .cell .code execution_count="12" execution="{\"iopub.execute_input\":\"2023-09-21T17:25: 05.409303Z\",\"iopub.status.busy\":\"2023-09-21T17:25: 05.408268Z\",\"iopub.status.idle\":\"2023-09-21T17:25: 33.378971Z\",\"shell.execute_reply\":\"2023-09-21T17:25: 33.378687Z\",\"shell.execute_reply.started\":\"2023-09-21T17:25: 05.409256Z\"}" scrolled="true"}
``` python
from l2hmc.common import plot_dataset

sns.set_context('notebook')
hdataset = outputs['hmc']['history'].get_dataset()
plot_dataset(hdataset, outdir=PLOTS_DIR, job_type='HMC')
```
:::
::: {.output .cell-output .cell-output-display}
```{=html}
<pre style="white-space:pre;overflow-x:auto;line-height:normal; margin-bottom:0; margin-bottom: 0;">
[09/21/23 12:25:06][INFO][plot_helpers.py:1049] - Saving figure to: l2hmc-diffusion-2dU1/2023-09-21/plots/ridgeplots/svgs/energy_ridgeplot.svg
[09/21/23 12:25:09][INFO][plot_helpers.py:1049] - Saving figure to: l2hmc-diffusion-2dU1/2023-09-21/plots/ridgeplots/svgs/logprob_ridgeplot.svg
[09/21/23 12:25:11][INFO][plot_helpers.py:1049] - Saving figure to: l2hmc-diffusion-2dU1/2023-09-21/plots/ridgeplots/svgs/logdet_ridgeplot.svg
</pre>
```
![](assets/diffusion/68eb7a87933506e4ea9f017913c8f1eb3d474f82.svg)
![](assets/diffusion/91f9822f7846afea800b646ef6d7f00adbe10a41.svg)
![](assets/diffusion/c51d8b2dcc6ec0e2de63ff069d3730bb571457b0.svg)
![](assets/diffusion/f4712af7b26ced3d77075d986993d9ed41e01438.svg)
![](assets/diffusion/52754b8f1d67eb90a526f91bf820f0ab6f00f3e6.svg)
![](assets/diffusion/150b885c11a3eb085ac620b73b278b467bf2accf.svg)
![](assets/diffusion/eee837b584211a155df0693c80ab0f74ef3df2d9.svg)
![](assets/diffusion/d7f0c56ea62b31af4a43a53d6eb9d0e46a0d0979.svg)
![](assets/diffusion/79d781dfd4250c92dd76a8d8d3185ef5f4bf0edb.svg)
![](assets/diffusion/ba6c79eb08e2df2807d207201f4260fcfb41a55d.svg)
![](assets/diffusion/be9f0e6e22b1f8d41ba35ef1c2e7c53026cdd606.svg)
![](assets/diffusion/acbf03bb04c361e167eefed84a13de13ef8ee37c.svg)
![](assets/diffusion/c919c93d92e6aec85e840c2ad8cb47eb30a8b015.svg)
![](assets/diffusion/7e5ffe849dbd6a7b8314042e953d762ab9e4286b.svg)
![](assets/diffusion/713d68e5d1fdd22fb92855cc3e25255035f0770c.svg)
![](assets/diffusion/368f5610bb37915389764b0b7392edd577eaf415.svg)
![](assets/diffusion/86fc3f47b8b4020e2b94ff74e79fdb3558c43f80.svg)
![](assets/diffusion/835139501cc9fcf6442be81e0033045b8f33ae74.svg)
:::
::: {#80a4cca0-19fd-43aa-83a5-8bd9c18da011 .cell .code execution_count="13" execution="{\"iopub.execute_input\":\"2023-09-21T17:26: 40.599785Z\",\"iopub.status.busy\":\"2023-09-21T17:26: 40.598293Z\",\"iopub.status.idle\":\"2023-09-21T17:27: 07.102561Z\",\"shell.execute_reply\":\"2023-09-21T17:27: 07.101951Z\",\"shell.execute_reply.started\":\"2023-09-21T17:26: 40.599694Z\"}" scrolled="true"}
``` python
import torch

initial_states = []
state_init = exp.trainer.dynamics.random_state(6.0)
x = state_init.x
beta = state_init.beta
NSAMPLES = 1000
NBURN = int(0.1 * NSAMPLES)  # discard the first 10% of steps as burn-in
for idx in range(NSAMPLES + NBURN):
    if idx % 100 == 0:
        console.print(f"step: {idx}")
    x, metrics = exp.trainer.hmc_step((x, beta))
    # The strict `>` keeps 999 (not 1000) states, matching the shape below
    if idx > NBURN:
        initial_states.append(x)
# Stack into a tensor of shape [num_kept, nchains, xdim]
initial_states = torch.stack(initial_states).squeeze()
initial_states_np = initial_states.detach().cpu().numpy()
```
:::
::: {.output .cell-output .cell-output-display}
```{=html}
<pre style="white-space:pre;overflow-x:auto;line-height:normal;">
step: <span style="color: #2094f3; text-decoration-color: #2094f3">0</span>
step: <span style="color: #2094f3; text-decoration-color: #2094f3">100</span>
step: <span style="color: #2094f3; text-decoration-color: #2094f3">200</span>
step: <span style="color: #2094f3; text-decoration-color: #2094f3">300</span>
step: <span style="color: #2094f3; text-decoration-color: #2094f3">400</span>
step: <span style="color: #2094f3; text-decoration-color: #2094f3">500</span>
step: <span style="color: #2094f3; text-decoration-color: #2094f3">600</span>
step: <span style="color: #2094f3; text-decoration-color: #2094f3">700</span>
step: <span style="color: #2094f3; text-decoration-color: #2094f3">800</span>
step: <span style="color: #2094f3; text-decoration-color: #2094f3">900</span>
step: <span style="color: #2094f3; text-decoration-color: #2094f3">1000</span>
</pre>
```
:::
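The loop runs `NSAMPLES + int(0.1 * NSAMPLES)` = 1100 HMC steps and keeps a state only when `idx > 100`. The first 101 steps are therefore discarded and 999 states are kept, which is why the array below has 999 rows rather than 1000. A quick check of the count follows. Using `>=` instead would keep exactly 1000 states.

```python
NSAMPLES = 1000
nburn = int(0.1 * NSAMPLES)     # 100 burn-in steps
total_steps = NSAMPLES + nburn  # 1100 HMC steps in total

kept_strict = sum(1 for idx in range(total_steps) if idx > nburn)      # as in the loop above
kept_inclusive = sum(1 for idx in range(total_steps) if idx >= nburn)  # alternative

print(kept_strict, kept_inclusive)  # 999 1000
```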
::: {#b3874832-0782-4b49-a025-742be370675d .cell .code execution_count="14" execution="{\"iopub.execute_input\":\"2023-09-21T17:27: 37.023220Z\",\"iopub.status.busy\":\"2023-09-21T17:27: 37.022568Z\",\"iopub.status.idle\":\"2023-09-21T17:27: 37.170197Z\",\"shell.execute_reply\":\"2023-09-21T17:27: 37.169976Z\",\"shell.execute_reply.started\":\"2023-09-21T17:27: 37.023184Z\"}"}
``` python
initial_states_np.shape
```
:::
::: {.output .cell-output .cell-output-display .execute_result execution_count="14"}
(999, 16, 2048)
:::
::: {#46eea215-ac24-46ed-b334-648ab0f082b0 .cell .code execution_count="15" execution="{\"iopub.execute_input\":\"2023-09-21T17:27: 45.524060Z\",\"iopub.status.busy\":\"2023-09-21T17:27: 45.523874Z\",\"iopub.status.idle\":\"2023-09-21T17:27: 45.610750Z\",\"shell.execute_reply\":\"2023-09-21T17:27: 45.610562Z\",\"shell.execute_reply.started\":\"2023-09-21T17:27: 45.524045Z\"}"}
``` python
# Unflatten each state into [nchains, 2 link directions, 32, 32]
x_ = initial_states_np.reshape(-1, 16, 2, 32, 32)
tmp_ = x_[:, 0, ...]  # keep only the first chain
console.print(f'{x_.shape}')
console.print(f'{tmp_.shape}')
```
:::
::: {.output .cell-output .cell-output-display}
```{=html}
<pre style="white-space:pre;overflow-x:auto;line-height:normal;">
<span style="font-weight: bold">(</span><span style="color: #2094f3; text-decoration-color: #2094f3">999</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">16</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">2</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">32</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">32</span><span style="font-weight: bold">)</span>
<span style="font-weight: bold">(</span><span style="color: #2094f3; text-decoration-color: #2094f3">999</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">2</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">32</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">32</span><span style="font-weight: bold">)</span>
</pre>
```
:::
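Each flattened state has `2 * 32 * 32 = 2048` entries: two link directions on a 32×32 lattice. As a result, `reshape(-1, 16, 2, 32, 32)` restores the per-chain link layout without reordering any data. Here is a small NumPy check of that bookkeeping. It uses a few steps of dummy data in place of `initial_states_np`.

```python
import numpy as np

nsteps, nchains, ndirs, L = 3, 16, 2, 32  # a few dummy steps instead of 999
flat = np.arange(nsteps * nchains * ndirs * L * L).reshape(nsteps, nchains, ndirs * L * L)

links = flat.reshape(-1, nchains, ndirs, L, L)  # [step, chain, direction, axis 0, axis 1]
chain0 = links[:, 0, ...]                       # first chain only

# Row-major (C-order) reshape: flat index k maps to (d, i, j) with k = d*L*L + i*L + j
s, c, d, i, j = 2, 3, 1, 10, 7
assert links[s, c, d, i, j] == flat[s, c, d * L * L + i * L + j]
print(flat.shape, links.shape, chain0.shape)  # (3, 16, 2048) (3, 16, 2, 32, 32) (3, 2, 32, 32)
```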
::: {#00862566-256a-4e1b-b96f-b289e146ed65 .cell .code execution_count="16" execution="{\"iopub.execute_input\":\"2023-09-21T17:28: 23.005041Z\",\"iopub.status.busy\":\"2023-09-21T17:28: 23.004835Z\",\"iopub.status.idle\":\"2023-09-21T17:28: 41.019060Z\",\"shell.execute_reply\":\"2023-09-21T17:28: 41.018864Z\",\"shell.execute_reply.started\":\"2023-09-21T17:28: 23.005025Z\"}"}
``` python
from l2hmc.common import savefig

tmp_ = x_[:, 0, ...]
fig, ax = plt.subplots()
# Joint density of the two link angles (U_0, U_1) over the last 100 samples of chain 0
sns.kdeplot(
    x=tmp_[-100:, 0].flatten(),
    y=tmp_[-100:, 1].flatten(),
    ax=ax,
    cmap='viridis',
    fill=False,  # `shade` is deprecated in recent seaborn; `fill` replaces it
    thresh=0,
)
ax.set_xlim((-4, 4))
ax.set_ylim((-4, 4))
savefig(
    f'hmc_samples-{NSAMPLES}',
    Path(PLOTS_DIR),
    tstamp=True,
)
```
:::
::: {.output .cell-output .cell-output-display}
Saving hmc_samples-1000-2023-09-21-122840 to l2hmc-diffusion-2dU1/2023-09-21/plots
:::
::: {.output .cell-output .cell-output-display}
![](assets/diffusion/70234c6bf6df3cbd32f9de96e36e740eda0ed20d.svg)
:::
::: {#544a477a-05cd-4e1f-b11f-5f9d29c61918 .cell .code execution_count="28" execution="{\"iopub.execute_input\":\"2023-09-21T18:03: 31.911850Z\",\"iopub.status.busy\":\"2023-09-21T18:03: 31.910702Z\",\"iopub.status.idle\":\"2023-09-21T18:03: 32.065942Z\",\"shell.execute_reply\":\"2023-09-21T18:03: 32.065721Z\",\"shell.execute_reply.started\":\"2023-09-21T18:03: 31.911790Z\"}"}
``` python
import numpy as np
import torch
from tqdm.auto import tqdm


class Diffusion:
    """DDPM forward (noising) and reverse (sampling) processes for 2-channel link fields."""

    def __init__(
        self,
        noise_steps: int = 1000,
        beta_start: float = 1e-4,
        beta_end: float = 0.02,
        nchannels: int = 2,
        img_size: int = 256,
        device: str = "cuda",
    ):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.img_size = img_size
        self.device = device
        self.nchannels = nchannels
        # Linear variance schedule beta_t and the cumulative products alpha_hat_t
        self.beta = self.prepare_noise_schedule().to(device)
        self.alpha = 1.0 - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

    def prepare_noise_schedule(self):
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def noise_images(self, x, t):
        # Closed-form q(x_t | x_0): sqrt(alpha_hat_t) * x_0 + sqrt(1 - alpha_hat_t) * eps
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(
            1 - self.alpha_hat[t]
        )[:, None, None, None]
        eps = torch.randn_like(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * eps, eps

    def sample_timesteps(self, n):
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    def sample(self, model, n):
        model.eval()
        with torch.no_grad():
            # Start from pure Gaussian noise and denoise one step at a time
            x = torch.randn(
                (n, self.nchannels, self.img_size, self.img_size)
            ).to(self.device)
            sample_bar = tqdm(
                reversed(range(1, self.noise_steps)),
                position=0,
                total=self.noise_steps - 1,
                dynamic_ncols=True,
            )
            for i in sample_bar:
                t = (torch.ones(n) * i).long().to(self.device)
                predicted_noise = model(x, t)
                alpha = self.alpha[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]
                # No fresh noise is added on the final denoising step
                if i > 1:
                    noise = torch.randn_like(x)
                else:
                    noise = torch.zeros_like(x)
                x = (
                    (1 / torch.sqrt(alpha))
                    * (
                        x
                        - ((1 - alpha) / (torch.sqrt(1 - alpha_hat)))
                        * predicted_noise
                    )
                    + (torch.sqrt(beta) * noise)
                )
        model.train()
        # Wrap the generated link angles back into [-pi, pi)
        x = (x + np.pi) % (2 * np.pi) - np.pi
        return x
```
:::
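`noise_images` uses the closed-form DDPM forward marginal $x_t = \sqrt{\hat\alpha_t}\,x_0 + \sqrt{1-\hat\alpha_t}\,\epsilon$, where $\hat\alpha_t = \prod_{s\le t}(1-\beta_s)$ and $\beta$ follows a linear schedule. Below is a framework-free NumPy sketch of the same schedule. The function names are our own, not l2hmc's. The sketch also reports how much signal survives at the last step for the two `noise_steps` values that appear in this notebook.

```python
import numpy as np


def linear_alpha_hat(noise_steps: int, beta_start: float = 1e-4, beta_end: float = 0.02) -> np.ndarray:
    """alpha_hat_t = prod_{s<=t} (1 - beta_s) for a linear beta schedule."""
    beta = np.linspace(beta_start, beta_end, noise_steps)
    return np.cumprod(1.0 - beta)


def noise_sample(x0: np.ndarray, t: int, alpha_hat: np.ndarray, rng: np.random.Generator):
    """Draw x_t ~ q(x_t | x_0) in closed form; returns (x_t, eps)."""
    eps = rng.standard_normal(x0.shape)
    return np.sqrt(alpha_hat[t]) * x0 + np.sqrt(1.0 - alpha_hat[t]) * eps, eps


ah100 = linear_alpha_hat(100)    # schedule used for training in this notebook
ah1000 = linear_alpha_hat(1000)  # the class default
print(ah100[-1], ah1000[-1])     # fraction of signal variance left at the final step

rng = np.random.default_rng(0)
x0 = rng.uniform(-np.pi, np.pi, size=(2, 32, 32))
xt, eps = noise_sample(x0, t=50, alpha_hat=ah100, rng=rng)
```

With `noise_steps=100`, the final $\hat\alpha$ is roughly 0.36. With the default 1000 steps, it is on the order of $10^{-5}$. The 100-step forward process therefore stops well short of pure noise, even though `sample` starts from $\mathcal{N}(0, I)$.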
::: {#f49c6213-2274-45e5-bc10-01a46e19d7c6 .cell .code execution_count="29" execution="{\"iopub.execute_input\":\"2023-09-21T18:03: 32.936448Z\",\"iopub.status.busy\":\"2023-09-21T18:03: 32.936214Z\",\"iopub.status.idle\":\"2023-09-21T18:03: 33.006979Z\",\"shell.execute_reply\":\"2023-09-21T18:03: 33.006767Z\",\"shell.execute_reply.started\":\"2023-09-21T18:03: 32.936420Z\"}"}
``` python
initial_states.shape
```
:::
::: {.output .cell-output .cell-output-display .execute_result execution_count="29"}
torch.Size([999, 16, 2048])
:::
::: {#dcfb6322-b8d5-4f4a-8c74-6a9e1f9eb2c3 .cell .markdown}
## Train Diffusion Model
:::
::: {#24658ef0-6421-465a-aee3-45cbaa84fe16 .cell .code execution_count="19" execution="{\"iopub.execute_input\":\"2023-09-21T17:29: 05.270957Z\",\"iopub.status.busy\":\"2023-09-21T17:29: 05.270198Z\",\"iopub.status.idle\":\"2023-09-21T17:29: 05.643233Z\",\"shell.execute_reply\":\"2023-09-21T17:29: 05.643039Z\",\"shell.execute_reply.started\":\"2023-09-21T17:29: 05.270924Z\"}"}
``` python
import torch
import torchvision
from PIL import Image


def save_images(images, path, **kwargs):
    """Tile a batch of images [N, C, H, W] into a grid and save it with PIL.

    Expects uint8 data: PIL's Image.fromarray cannot build ordinary RGB images from floats.
    """
    grid = torchvision.utils.make_grid(images, **kwargs)
    ndarr = grid.permute(1, 2, 0).to('cpu').numpy()  # [C, H, W] -> [H, W, C]
    im = Image.fromarray(ndarr)
    im.save(path)
```
:::
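`save_images` passes the grid to `PIL.Image.fromarray`, which expects 8-bit data for ordinary image modes. The samples here, however, are float link angles in $[-\pi, \pi)$. One hypothetical way to bridge the two is to map the angle range linearly onto 0 to 255 before saving. The helper `angles_to_uint8` is our own and is not part of the notebook or l2hmc.

```python
import numpy as np


def angles_to_uint8(theta: np.ndarray) -> np.ndarray:
    """Linearly map angles in [-pi, pi] to uint8 values in [0, 255]."""
    scaled = (np.clip(theta, -np.pi, np.pi) + np.pi) / (2 * np.pi)  # -> [0, 1]
    return np.round(scaled * 255).astype(np.uint8)


print(angles_to_uint8(np.array([-np.pi, 0.0, np.pi])))  # [  0 128 255]
```

For a torch tensor, call `.cpu().numpy()` first. The converted array can then be wrapped with `torch.from_numpy` before it is handed to `save_images`.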
::: {#bf6c6419-6d67-4e82-800f-19b896dc6a1b .cell .markdown}
## Build Diffusion Model with UNet Architecture
:::
::: {#d7adbbf1-9a17-459c-ba2b-144605576790 .cell .code execution_count="20" execution="{\"iopub.execute_input\":\"2023-09-21T17:29: 33.246111Z\",\"iopub.status.busy\":\"2023-09-21T17:29: 33.245275Z\",\"iopub.status.idle\":\"2023-09-21T17:29: 33.647190Z\",\"shell.execute_reply\":\"2023-09-21T17:29: 33.646950Z\",\"shell.execute_reply.started\":\"2023-09-21T17:29: 33.246094Z\"}"}
``` python
from torch import nn
from torch import optim
import torch.nn.functional as F
from tqdm.auto import tqdm
from torch.utils.data import DataLoader, TensorDataset
from l2hmc.diffusion.modules import NoiseScheduler, UNet
from l2hmc.diffusion import ddpm

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
config = {
    'channels_in': 2,
    'channels_out': 2,
    'train_batch_size': 5,
    'learning_rate': 0.001,
    'num_epochs': 1,
    'noise_steps': 100,
    'beta': 6.0,
    'img_size': 32,
    'retrains': 10,
    'samples_per_retrain': 500,
    'diffusion_prob': 0.1,
}
# Keep the model on the same device as the noise schedule inside `diffusion`
model = UNet(c_in=config['channels_in'], c_out=config['channels_out']).to(DEVICE)
# Each (step, chain) state becomes one [2, 32, 32] training example
dataset = TensorDataset(initial_states.reshape(-1, 2, 32, 32))
dataloader = DataLoader(
    dataset,
    batch_size=config["train_batch_size"],
    shuffle=False,
    drop_last=True,
)
optimizer = optim.AdamW(model.parameters(), lr=config['learning_rate'])
mse = nn.MSELoss()
diffusion = Diffusion(
    noise_steps=config['noise_steps'],
    img_size=config['img_size'],
    device=DEVICE,
    nchannels=config['channels_in'],
)
l = len(dataloader)  # number of batches per epoch
run_name = 'diffusion2dU1'
```
:::
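Reshaping `initial_states` (999 kept steps × 16 chains) to `[-1, 2, 32, 32]` treats each chain's state at each step as a separate training example. The dataset therefore has 999 × 16 = 15,984 samples. With `train_batch_size=5` and `drop_last=True`, `len(dataloader)` (stored in `l`) is the floor of 15,984 / 5, as the arithmetic below shows.

```python
nsteps, nchains, batch_size = 999, 16, 5
num_examples = nsteps * nchains           # one [2, 32, 32] example per (step, chain)
num_batches = num_examples // batch_size  # drop_last=True discards the partial batch
dropped = num_examples % batch_size       # examples left out of every epoch

print(num_examples, num_batches, dropped)  # 15984 3196 4
```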
::: {#5aed3b53-6261-4e0c-938d-07b0616d0a44 .cell .markdown}
### Perform initial training on HMC samples
:::
::: {#27fcd372-259b-4be5-830d-167a396aeac1 .cell .code execution_count="21" execution="{\"iopub.execute_input\":\"2023-09-21T17:29: 54.637896Z\",\"iopub.status.busy\":\"2023-09-21T17:29: 54.634898Z\",\"iopub.status.idle\":\"2023-09-21T17:45: 06.940844Z\",\"shell.execute_reply\":\"2023-09-21T17:45: 06.940627Z\",\"shell.execute_reply.started\":\"2023-09-21T17:29: 54.637847Z\"}" scrolled="true"}
```python
device = DEVICE  # keep data, model, and noise schedule on the same device
sampled_images_history = []
for epoch in range(config['num_epochs']):
    console.print(f"Starting epoch {epoch}:")
    pbar = tqdm(dataloader)
    for i, images in enumerate(pbar):
        # TensorDataset yields 1-tuples; unwrap to the tensor itself
        if isinstance(images, (tuple, list)) and len(images) == 1:
            images = images[0]
        images = images.to(device)
        # Noise each example to a random timestep and regress the injected noise
        t = diffusion.sample_timesteps(images.shape[0]).to(device)
        x_t, noise = diffusion.noise_images(images, t)
        predicted_noise = model(x_t, t)
        loss = mse(noise, predicted_noise)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        pbar.set_postfix({'epoch': epoch, 'batch': i, 'MSE': loss.item()})
    console.print(f'epoch: {epoch}, loss: {loss.item()}')
    # Draw a batch of new samples from the current model
    sampled_images = diffusion.sample(model, n=images.shape[0]).cpu()
    sampled_images_history.append(sampled_images)
    sns.set_context('notebook')
    fig, ax = plt.subplots(ncols=2)
    _ = ax[0].imshow(sampled_images[0, 0, :, :])
    _ = ax[1].imshow(sampled_images[0, 1, :, :])
    for a in ax:
        a.set_xticklabels([])
        a.set_yticklabels([])
    _ = ax[0].set_title(r"$U_{0}$", loc='center')
    _ = ax[1].set_title(r"$U_{1}$", loc='center')
    _ = fig.suptitle('Diffusion Samples', y=0.8)
    # Save before plt.show(), which can close the figure under the inline backend
    savefig(fname=f'sampled_image_epoch{epoch}', outdir=PLOTS_DIR, tstamp=True)
    plt.show()
    MODEL_FILE = OUTDIR.joinpath("models", f"unet-diffusion-epoch{epoch}.pt")
    MODEL_FILE.parent.mkdir(exist_ok=True, parents=True)
    console.print(f"Saving model checkpoint to: {MODEL_FILE}")
    torch.save(model.state_dict(), MODEL_FILE)
```
:::
::: {.output .cell-output .cell-output-display}
```{=html}
<pre style="white-space:pre;overflow-x:auto;line-height:normal;">
Starting epoch <span style="color: #2094f3; text-decoration-color: #2094f3">0</span>:
</pre>
```
:::
::: {.output .cell-output .cell-output-display}
```{=html}
<pre style="white-space:pre;overflow-x:auto;line-height:normal;">
epoch: <span style="color: #2094f3; text-decoration-color: #2094f3">0</span>, loss: <span style="color: #2094f3; text-decoration-color: #2094f3">0.6023472547531128</span>
</pre>
```
:::
::: {.output .cell-output .cell-output-display}
![](assets/diffusion/e90001ab4bbbe38f50cca429a69bc2b212ed8a4d.svg)
:::
::: {.output .cell-output .cell-output-display}
Saving sampled_image_epoch0-2023-09-21-124506 to l2hmc-diffusion-2dU1/2023-09-21/plots
:::
::: {.output .cell-output .cell-output-display}
```{=html}
<pre style="white-space:pre;overflow-x:auto;line-height:normal;">
Saving model checkpoint to: l2hmc-diffusion-2dU1/<span style="color: #2094f3; text-decoration-color: #2094f3">2023</span>-<span style="color: #2094f3; text-decoration-color: #2094f3">09</span>-<span style="color: #2094f3; text-decoration-color: #2094f3">21</span>/models/unet-diffusion-epoch0.pt
</pre>
```
:::
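Each training step above draws a timestep, noises the batch in closed form, and regresses the network's output onto the injected noise $\epsilon$ with an MSE loss (the simplified DDPM objective). The following NumPy sketch, which uses our own names and no network, illustrates what that regression target means. An oracle that knew $x_0$ could invert the forward map, recover $\epsilon$ exactly, and reach zero loss. Always predicting zero, by contrast, gives a loss near $\mathbb{E}[\epsilon^2] = 1$.

```python
import numpy as np

rng = np.random.default_rng(0)
noise_steps, batch = 100, 5
alpha_hat = np.cumprod(1.0 - np.linspace(1e-4, 0.02, noise_steps))

x0 = rng.uniform(-np.pi, np.pi, size=(batch, 2, 32, 32))
t = rng.integers(low=1, high=noise_steps, size=batch)  # like sample_timesteps
ah = alpha_hat[t][:, None, None, None]
eps = rng.standard_normal(x0.shape)
x_t = np.sqrt(ah) * x0 + np.sqrt(1.0 - ah) * eps       # like noise_images


def mse(a: np.ndarray, b: np.ndarray) -> float:
    return float(np.mean((a - b) ** 2))


oracle_eps = (x_t - np.sqrt(ah) * x0) / np.sqrt(1.0 - ah)  # invert the forward map
loss_oracle = mse(eps, oracle_eps)               # ~0, floating-point error only
loss_zero = mse(eps, np.zeros_like(eps))         # ~1
print(loss_oracle, loss_zero)
```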
::: {#c4ae942b-4617-4f1e-a837-68e95d6ddabc .cell .code execution_count="22" execution="{\"iopub.execute_input\":\"2023-09-21T18:01: 57.860377Z\",\"iopub.status.busy\":\"2023-09-21T18:01: 57.859019Z\",\"iopub.status.idle\":\"2023-09-21T18:01: 58.250584Z\",\"shell.execute_reply\":\"2023-09-21T18:01: 58.250409Z\",\"shell.execute_reply.started\":\"2023-09-21T18:01: 57.860309Z\"}"}
``` python
sns.set_context('notebook')
tmp = initial_states.reshape(-1, 2, 32, 32)
fig, ax = plt.subplots(ncols=2)
_ = ax[0].imshow(tmp[0, 0, :, :])
_ = ax[1].imshow(tmp[0, 1, :, :])
_ = ax[0].set_title(r"$U_{0}$", loc='center')
_ = ax[0].set_xticklabels([])
_ = ax[1].set_xticklabels([])
_ = ax[0].set_yticklabels([])
_ = ax[1].set_yticklabels([])
_ = ax[1].set_title(r"$U_{1}$", loc='center')
_ = fig.suptitle('HMC Samples', y=0.8)
```
::: {.output .cell-output .cell-output-display}
![](assets/diffusion/85261feb8c9ba70c38f9f40eb69b7b3bc4c670b0.svg)
:::
:::
::: {#2637733d-1da9-42e9-b00f-630b6b5d8269 .cell .code execution_count="23" execution="{\"iopub.execute_input\":\"2023-09-21T18:02: 10.910463Z\",\"iopub.status.busy\":\"2023-09-21T18:02: 10.910215Z\",\"iopub.status.idle\":\"2023-09-21T18:02: 10.982520Z\",\"shell.execute_reply\":\"2023-09-21T18:02: 10.982237Z\",\"shell.execute_reply.started\":\"2023-09-21T18:02: 10.910447Z\"}"}
``` python
sampled_images_history_ = torch.stack(sampled_images_history)
```
:::
::: {#2afaff5c-a946-4080-90d7-bbb1af973ce0 .cell .code execution_count="24" execution="{\"iopub.execute_input\":\"2023-09-21T18:02: 11.478836Z\",\"iopub.status.busy\":\"2023-09-21T18:02: 11.478478Z\",\"iopub.status.idle\":\"2023-09-21T18:02: 11.546438Z\",\"shell.execute_reply\":\"2023-09-21T18:02: 11.546257Z\",\"shell.execute_reply.started\":\"2023-09-21T18:02: 11.478818Z\"}"}
``` python
sampled_images_history_.shape
```
::: {.output .cell-output .cell-output-display .execute_result execution_count="24"}
torch.Size([1, 5, 2, 32, 32])
:::
:::
::: {#c9397f4d-fbad-4332-b909-c671469b1f7d .cell .code execution_count="25" execution="{\"iopub.execute_input\":\"2023-09-21T18:02: 13.627514Z\",\"iopub.status.busy\":\"2023-09-21T18:02: 13.627140Z\",\"iopub.status.idle\":\"2023-09-21T18:02: 13.825037Z\",\"shell.execute_reply\":\"2023-09-21T18:02: 13.824814Z\",\"shell.execute_reply.started\":\"2023-09-21T18:02: 13.627490Z\"}"}
``` python
sns.set_context('notebook')
fig, ax = plt.subplots(ncols=2)
_ = ax[0].imshow(sampled_images_history_[0][0][0])
_ = ax[1].imshow(sampled_images_history_[0][0][1])
_ = ax[0].set_xticklabels([])
_ = ax[1].set_xticklabels([])
_ = ax[0].set_yticklabels([])
_ = ax[1].set_yticklabels([])
_ = ax[0].set_title(r"$U_{0}$", loc='center')
_ = ax[1].set_title(r"$U_{1}$", loc='center')
_ = fig.suptitle('Diffusion Samples', y=0.85)
```
::: {.output .cell-output .cell-output-display}
![](assets/diffusion/b2807643bf74b33689a17754209e56554bfe1f52.svg)
:::
:::
::: {#70324363-7315-4c3f-9304-24c0509bdd6c .cell .code execution_count="27" execution="{\"iopub.execute_input\":\"2023-09-21T18:02: 25.159709Z\",\"iopub.status.busy\":\"2023-09-21T18:02: 25.158545Z\",\"iopub.status.idle\":\"2023-09-21T18:02: 25.258988Z\",\"shell.execute_reply\":\"2023-09-21T18:02: 25.258835Z\",\"shell.execute_reply.started\":\"2023-09-21T18:02: 25.159655Z\"}"}
``` python
for idx in range(sampled_images_history_.shape[0]):
    q = exp.trainer.lattice.charges(x=sampled_images_history_[idx])
    console.print(f'{idx}: {q}')
```
:::
::: {.output .cell-output .cell-output-display}
```{=html}
<pre style="white-space:pre;overflow-x:auto;line-height:normal;">
<span style="color: #2094f3; text-decoration-color: #2094f3">0</span>: <span style="color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold">Charges</span><span style="font-weight: bold">(</span><span style="color: #7d8697; text-decoration-color: #7d8697">intQ</span>=<span style="color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold">tensor</span><span style="font-weight: bold">([</span> <span style="color: #2094f3; text-decoration-color: #2094f3">5.0000e+00</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">-4.0000e+00</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">-6.0000e+00</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">-4.5535e-07</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">1.0000e+00</span><span style="font-weight: bold">])</span>, <span style="color: #7d8697; text-decoration-color: #7d8697">sinQ</span>=<span style="color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold">tensor</span><span style="font-weight: bold">([</span> <span style="color: #2094f3; text-decoration-color: #2094f3">1.6426</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">-1.7244</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">-4.4651</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">0.5680</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">0.7046</span><span style="font-weight: bold">]))</span>
</pre>
```
:::
::: {#e8f0fae1-63eb-4522-8342-2d051f4705c4 .cell .markdown}
## HMC Sampling with Diffusion
:::
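The cells below interleave two update types: from the second retraining iteration on, with probability `diffusion_prob` every chain receives a global proposal from the trained diffusion model, which is accepted or rejected with a Metropolis test; otherwise the chains take an ordinary HMC step. As a self-contained illustration of that mixture kernel, here is a toy 1D sketch under stated assumptions: the target is a hypothetical double-well density, a two-component Gaussian mixture with a known density stands in for the diffusion model, and a random-walk Metropolis step stands in for HMC. Because the toy proposal density is known, its acceptance test includes the independence-sampler correction $q(x)/q(x')$, which the notebook's cell omits.

```python
import math
import random

# Hypothetical toy target: p(x) ∝ exp(-U(x)) with a double-well potential U.
def potential(x: float) -> float:
    return 4.0 * (x * x - 1.0) ** 2

SIGMA = 0.3  # width of each component of the stand-in global proposal

def log_q(x: float) -> float:
    """Log-density of the stand-in global proposal 0.5 N(-1, s^2) + 0.5 N(+1, s^2)."""
    norm = -0.5 * math.log(2.0 * math.pi * SIGMA**2)
    a = norm - (x + 1.0) ** 2 / (2.0 * SIGMA**2)
    b = norm - (x - 1.0) ** 2 / (2.0 * SIGMA**2)
    m = max(a, b)  # log-sum-exp for numerical stability
    return m + math.log(0.5 * math.exp(a - m) + 0.5 * math.exp(b - m))

def sample_q(rng: random.Random) -> float:
    return rng.choice((-1.0, 1.0)) + rng.gauss(0.0, SIGMA)

def accept(log_alpha: float, rng: random.Random) -> bool:
    """Metropolis test: accept with probability min(1, exp(log_alpha))."""
    return log_alpha >= 0.0 or rng.random() < math.exp(log_alpha)

def local_step(x: float, rng: random.Random, step: float = 0.2) -> float:
    """Random-walk Metropolis update (stand-in for an HMC step)."""
    y = x + rng.gauss(0.0, step)
    return y if accept(potential(x) - potential(y), rng) else x

def global_step(x: float, rng: random.Random) -> tuple[float, bool]:
    """Independence Metropolis-Hastings update using the global proposal q."""
    y = sample_q(rng)
    # log[p(y) q(x)] - log[p(x) q(y)]
    log_alpha = (potential(x) - potential(y)) + (log_q(x) - log_q(y))
    ok = accept(log_alpha, rng)
    return (y if ok else x), ok

def hybrid_chain(n_steps: int, p_global: float = 0.1, seed: int = 0):
    """With probability p_global take a global move, otherwise a local one."""
    rng = random.Random(seed)
    x, xs, n_prop, n_acc = 1.0, [], 0, 0
    for _ in range(n_steps):
        if rng.random() < p_global:
            n_prop += 1
            x, ok = global_step(x, rng)
            n_acc += ok
        else:
            x = local_step(x, rng)
        xs.append(x)
    return xs, n_prop, n_acc

xs, n_prop, n_acc = hybrid_chain(20_000)
frac_left = sum(x < 0 for x in xs) / len(xs)
print(f"accepted {n_acc}/{n_prop} global proposals; fraction in left well: {frac_left:.3f}")
```

Each update leaves $p$ invariant on its own, so choosing between them at random (independently of the current state) also leaves $p$ invariant; the global move is what lets the chain hop between wells that the local move rarely crosses.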
::: {#1b571b62-b713-4481-a328-fe3d1bcc0030 .cell .code execution_count="32" execution="{\"iopub.execute_input\":\"2023-09-21T18:08: 36.186462Z\",\"iopub.status.busy\":\"2023-09-21T18:08: 36.185269Z\",\"iopub.status.idle\":\"2023-09-21T18:08: 36.391707Z\",\"shell.execute_reply\":\"2023-09-21T18:08: 36.391500Z\",\"shell.execute_reply.started\":\"2023-09-21T18:08: 36.186399Z\"}" scrolled="true"}
``` python
#for retrain_iter in range(config['retrains']):
state = exp.trainer.dynamics.random_state(config['beta'])
x = state.x
histories = {}
samples = []
hmc_samples = []
diffusion_samples = []
global_step = 0
watcher = {}
update_types = []
combined_samples = {}
```
:::
::: {#4deb9af0-c9ca-4314-a393-7a206c15d1e6 .cell .code execution_count="33" execution="{\"iopub.execute_input\":\"2023-09-21T18:08: 40.518616Z\",\"iopub.status.busy\":\"2023-09-21T18:08: 40.518438Z\",\"iopub.status.idle\":\"2023-09-21T18:08: 40.627283Z\",\"shell.execute_reply\":\"2023-09-21T18:08: 40.625908Z\",\"shell.execute_reply.started\":\"2023-09-21T18:08: 40.518604Z\"}"}
``` python
global_step
```
:::
::: {.output .cell-output .cell-output-display .execute_result execution_count="33"}
0
:::
::: {#d3d35fc8-5067-4d51-8cc0-8136b34669a3 .cell .code execution_count="35" execution="{\"iopub.execute_input\":\"2023-09-21T18:09: 58.481960Z\",\"iopub.status.busy\":\"2023-09-21T18:09: 58.480609Z\",\"iopub.status.idle\":\"2023-09-21T18:12: 28.387917Z\",\"shell.execute_reply\":\"2023-09-21T18:12: 28.387487Z\",\"shell.execute_reply.started\":\"2023-09-21T18:09: 58.481905Z\"}" scrolled="true"}
``` python
for retrain_iter in range(2):
    console.print(f'retrain_iter: {retrain_iter}')
    histories[retrain_iter] = {
        'diffusion': [],
        'hmc': [],
    }
    # for idx in range(config['samples_per_retrain']):
    sbar = tqdm(range(10))
    for idx in sbar:
        t0_ = time.perf_counter()
        if idx % 100 == 0:
            console.print(f'sample idx: {idx}')
        rand = np.random.uniform()
        if (retrain_iter >= 1) and rand < diffusion_prob:
            console.print(f'rand: {rand} < {diffusion_prob}')
            # Propose a new state for every chain from the diffusion model
            x_ = diffusion.sample(model, n=x.shape[0])
            # potential_energy returns U(x) = -log p(x) + const, so the
            # Metropolis ratio p(x_) / p(x) is exp(U(x) - U(x_)).
            # NOTE: the proposal-density correction q(x) / q(x_) is omitted,
            # since the diffusion model's density is not evaluated here.
            u_new = exp.trainer.dynamics.potential_energy(x_, config['beta'])
            u_old = exp.trainer.dynamics.potential_energy(x, config['beta'])
            a = torch.exp(torch.clamp(u_old - u_new, max=0.0))
            u = torch.rand_like(a)
            # Accept or reject independently for each chain
            x = torch.where((u < a)[:, None, None, None], x_, x.reshape_as(x_))
            samples.append(x)
            diffusion_samples.append(x)
            combined_samples[global_step] = x
            watcher[global_step] = 'diffusion'
        else:
            # Otherwise, take a standard HMC step
            x, metrics = exp.trainer.hmc_step((x, config['beta']))
            hmc_samples.append(x)
            samples.append(x)
            combined_samples[global_step] = x
            watcher[global_step] = 'HMC'
        smetrics = {
            'idx': idx,
            'global_step': global_step,
            'dt': time.perf_counter() - t0_,
        }
        global_step += 1
        sbar.set_postfix(smetrics)
    # Retrain the diffusion model on all HMC samples collected so far
    dataset = TensorDataset(
        torch.cat([s.reshape(-1, 2, 32, 32) for s in hmc_samples])
    )
    dataloader = DataLoader(
        dataset,
        shuffle=False,
        drop_last=True,
        batch_size=config["train_batch_size"],
    )
    pbar = tqdm(dataloader)
    for i, batch in enumerate(pbar):
        if i == 0:
            console.print('Retraining...')
        if isinstance(batch, (tuple, list)) and len(batch) == 1:
            batch, = batch
        batch = batch.reshape(-1, 2, 32, 32).to(device)
        t0 = time.time()
        # Sample timesteps, noise the batch, and regress the added noise
        t = diffusion.sample_timesteps(batch.shape[0]).to(device)
        x_t, noise = diffusion.noise_images(batch, t)
        predicted_noise = model(x_t, t)
        loss = mse(noise, predicted_noise)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        t1 = time.time()
        pbar.set_postfix(
            {
                'global_step': global_step,
                'retrain_iter': retrain_iter,
                'batch': i,
                'dt': t1 - t0,
                'MSE': loss.item(),
            }
        )
```
:::
::: {.output .cell-output .cell-output-display}
```{=html}
<pre style="white-space:pre;overflow-x:auto;line-height:normal;">
retrain_iter: <span style="color: #2094f3; text-decoration-color: #2094f3">0</span>
</pre>
```
```{=html}
<pre style="white-space:pre;overflow-x:auto;line-height:normal;">
sample idx: <span style="color: #2094f3; text-decoration-color: #2094f3">0</span>
</pre>
```
```{=html}
<pre style="white-space:pre;overflow-x:auto;line-height:normal;">
Retraining<span style="color: #808000; text-decoration-color: #808000">...</span>
</pre>
```
```{=html}
<pre style="white-space:pre;overflow-x:auto;line-height:normal;">
retrain_iter: <span style="color: #2094f3; text-decoration-color: #2094f3">1</span>
</pre>
```
```{=html}
<pre style="white-space:pre;overflow-x:auto;line-height:normal;">
sample idx: <span style="color: #2094f3; text-decoration-color: #2094f3">0</span>
</pre>
```
```{=html}
<pre style="white-space:pre;overflow-x:auto;line-height:normal;">
rand: <span style="color: #2094f3; text-decoration-color: #2094f3">0.05506106760134255</span> < <span style="color: #2094f3; text-decoration-color: #2094f3">0.1</span>
</pre>
```
```{=html}
<pre style="white-space:pre;overflow-x:auto;line-height:normal;">
rand: <span style="color: #2094f3; text-decoration-color: #2094f3">0.07860283644524213</span> < <span style="color: #2094f3; text-decoration-color: #2094f3">0.1</span>
</pre>
```
```{=html}
<pre style="white-space:pre;overflow-x:auto;line-height:normal;">
Retraining<span style="color: #808000; text-decoration-color: #808000">...</span>
</pre>
```
:::
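The retraining loop relies on `diffusion.sample_timesteps`, `diffusion.noise_images`, and an MSE between the true and predicted noise. Those objects are defined earlier in the notebook, outside this excerpt. Assuming they follow the standard DDPM recipe, with a linear $\beta$ schedule from $10^{-4}$ to $0.02$ over $T = 1000$ steps (a common default, and an assumption here), the forward noising step $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$ and the regression target can be sketched in plain Python. All function names below are hypothetical.

```python
import math
import random

def linear_betas(T: int = 1000, beta_start: float = 1e-4, beta_end: float = 0.02) -> list[float]:
    """Linearly spaced noise schedule beta_1..beta_T (assumed defaults)."""
    return [beta_start + (beta_end - beta_start) * i / (T - 1) for i in range(T)]

def cumulative_alphas(betas: list[float]) -> list[float]:
    """alpha_bar_t = prod_{s <= t} (1 - beta_s)."""
    out, prod = [], 1.0
    for b in betas:
        prod *= 1.0 - b
        out.append(prod)
    return out

def noise_sample(x0, t, alpha_bar, rng):
    """Draw x_t ~ N(sqrt(abar_t) x_0, (1 - abar_t) I); return (x_t, eps)."""
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    s, n = math.sqrt(alpha_bar[t]), math.sqrt(1.0 - alpha_bar[t])
    return [s * a + n * e for a, e in zip(x0, eps)], eps

def mse(a, b) -> float:
    return sum((u - v) ** 2 for u, v in zip(a, b)) / len(a)

# Usage: noise a flattened toy "configuration" of link angles at a random timestep.
rng = random.Random(0)
alpha_bar = cumulative_alphas(linear_betas())
x0 = [rng.uniform(-math.pi, math.pi) for _ in range(2 * 8 * 8)]
t = rng.randrange(len(alpha_bar))
x_t, eps = noise_sample(x0, t, alpha_bar, rng)
# A network is trained to predict eps from (x_t, t); an all-zero prediction
# gives a baseline loss near 1, the variance of eps.
baseline = mse(eps, [0.0] * len(eps))
print(t, round(baseline, 3))
```

Because $\bar\alpha_T$ is tiny for this schedule, $x_T$ is essentially pure Gaussian noise, which is what lets `diffusion.sample` start from noise and denoise step by step.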
::: {#a83cf9fb-9c2c-430e-957a-d42eaed79401 .cell .code execution_count="36" execution="{\"iopub.execute_input\":\"2023-09-21T18:12: 32.389697Z\",\"iopub.status.busy\":\"2023-09-21T18:12: 32.389127Z\",\"iopub.status.idle\":\"2023-09-21T18:12: 32.627601Z\",\"shell.execute_reply\":\"2023-09-21T18:12: 32.627412Z\",\"shell.execute_reply.started\":\"2023-09-21T18:12: 32.389644Z\"}" scrolled="true"}
``` python
console.print('\n'.join([f"{i.shape}" for i in samples[:100]]))
```
:::
::: {.output .cell-output .cell-output-display}
```{=html}
<pre style="white-space:pre;overflow-x:auto;line-height:normal;">
<span style="color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold">torch.Size</span><span style="font-weight: bold">([</span><span style="color: #2094f3; text-decoration-color: #2094f3">16</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">2048</span><span style="font-weight: bold">])</span>
<span style="color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold">torch.Size</span><span style="font-weight: bold">([</span><span style="color: #2094f3; text-decoration-color: #2094f3">16</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">2048</span><span style="font-weight: bold">])</span>
<span style="color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold">torch.Size</span><span style="font-weight: bold">([</span><span style="color: #2094f3; text-decoration-color: #2094f3">16</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">2048</span><span style="font-weight: bold">])</span>
<span style="color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold">torch.Size</span><span style="font-weight: bold">([</span><span style="color: #2094f3; text-decoration-color: #2094f3">16</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">2048</span><span style="font-weight: bold">])</span>
<span style="color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold">torch.Size</span><span style="font-weight: bold">([</span><span style="color: #2094f3; text-decoration-color: #2094f3">16</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">2048</span><span style="font-weight: bold">])</span>
<span style="color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold">torch.Size</span><span style="font-weight: bold">([</span><span style="color: #2094f3; text-decoration-color: #2094f3">16</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">2048</span><span style="font-weight: bold">])</span>
<span style="color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold">torch.Size</span><span style="font-weight: bold">([</span><span style="color: #2094f3; text-decoration-color: #2094f3">16</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">2048</span><span style="font-weight: bold">])</span>
<span style="color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold">torch.Size</span><span style="font-weight: bold">([</span><span style="color: #2094f3; text-decoration-color: #2094f3">16</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">2048</span><span style="font-weight: bold">])</span>
<span style="color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold">torch.Size</span><span style="font-weight: bold">([</span><span style="color: #2094f3; text-decoration-color: #2094f3">16</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">2048</span><span style="font-weight: bold">])</span>
<span style="color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold">torch.Size</span><span style="font-weight: bold">([</span><span style="color: #2094f3; text-decoration-color: #2094f3">16</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">2048</span><span style="font-weight: bold">])</span>
<span style="color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold">torch.Size</span><span style="font-weight: bold">([</span><span style="color: #2094f3; text-decoration-color: #2094f3">16</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">2048</span><span style="font-weight: bold">])</span>
<span style="color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold">torch.Size</span><span style="font-weight: bold">([</span><span style="color: #2094f3; text-decoration-color: #2094f3">16</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">2048</span><span style="font-weight: bold">])</span>
<span style="color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold">torch.Size</span><span style="font-weight: bold">([</span><span style="color: #2094f3; text-decoration-color: #2094f3">16</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">2048</span><span style="font-weight: bold">])</span>
<span style="color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold">torch.Size</span><span style="font-weight: bold">([</span><span style="color: #2094f3; text-decoration-color: #2094f3">16</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">2048</span><span style="font-weight: bold">])</span>
<span style="color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold">torch.Size</span><span style="font-weight: bold">([</span><span style="color: #2094f3; text-decoration-color: #2094f3">16</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">2048</span><span style="font-weight: bold">])</span>
<span style="color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold">torch.Size</span><span style="font-weight: bold">([</span><span style="color: #2094f3; text-decoration-color: #2094f3">16</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">2048</span><span style="font-weight: bold">])</span>
<span style="color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold">torch.Size</span><span style="font-weight: bold">([</span><span style="color: #2094f3; text-decoration-color: #2094f3">16</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">2048</span><span style="font-weight: bold">])</span>
<span style="color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold">torch.Size</span><span style="font-weight: bold">([</span><span style="color: #2094f3; text-decoration-color: #2094f3">16</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">2048</span><span style="font-weight: bold">])</span>
<span style="color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold">torch.Size</span><span style="font-weight: bold">([</span><span style="color: #2094f3; text-decoration-color: #2094f3">16</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">2048</span><span style="font-weight: bold">])</span>
<span style="color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold">torch.Size</span><span style="font-weight: bold">([</span><span style="color: #2094f3; text-decoration-color: #2094f3">16</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">2048</span><span style="font-weight: bold">])</span>
<span style="color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold">torch.Size</span><span style="font-weight: bold">([</span><span style="color: #2094f3; text-decoration-color: #2094f3">16</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">2048</span><span style="font-weight: bold">])</span>
<span style="color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold">torch.Size</span><span style="font-weight: bold">([</span><span style="color: #2094f3; text-decoration-color: #2094f3">16</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">2048</span><span style="font-weight: bold">])</span>
<span style="color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold">torch.Size</span><span style="font-weight: bold">([</span><span style="color: #2094f3; text-decoration-color: #2094f3">16</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">2048</span><span style="font-weight: bold">])</span>
<span style="color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold">torch.Size</span><span style="font-weight: bold">([</span><span style="color: #2094f3; text-decoration-color: #2094f3">16</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">2048</span><span style="font-weight: bold">])</span>
<span style="color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold">torch.Size</span><span style="font-weight: bold">([</span><span style="color: #2094f3; text-decoration-color: #2094f3">16</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">2048</span><span style="font-weight: bold">])</span>
<span style="color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold">torch.Size</span><span style="font-weight: bold">([</span><span style="color: #2094f3; text-decoration-color: #2094f3">16</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">2048</span><span style="font-weight: bold">])</span>
<span style="color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold">torch.Size</span><span style="font-weight: bold">([</span><span style="color: #2094f3; text-decoration-color: #2094f3">16</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">2048</span><span style="font-weight: bold">])</span>
<span style="color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold">torch.Size</span><span style="font-weight: bold">([</span><span style="color: #2094f3; text-decoration-color: #2094f3">16</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">2</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">32</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">32</span><span style="font-weight: bold">])</span>
<span style="color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold">torch.Size</span><span style="font-weight: bold">([</span><span style="color: #2094f3; text-decoration-color: #2094f3">16</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">2048</span><span style="font-weight: bold">])</span>
<span style="color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold">torch.Size</span><span style="font-weight: bold">([</span><span style="color: #2094f3; text-decoration-color: #2094f3">16</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">2</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">32</span>, <span style="color: #2094f3; text-decoration-color: #2094f3">32</span><span style="font-weight: bold">])</span>
</pre>
```
:::
::: {#4b0e9e7c-9585-4f50-81c7-4c9dd37e18ce .cell .code execution_count="37" execution="{\"iopub.execute_input\":\"2023-09-21T18:12: 38.701814Z\",\"iopub.status.busy\":\"2023-09-21T18:12: 38.699745Z\",\"iopub.status.idle\":\"2023-09-21T18:12: 38.789521Z\",\"shell.execute_reply\":\"2023-09-21T18:12: 38.789359Z\",\"shell.execute_reply.started\":\"2023-09-21T18:12: 38.701745Z\"}"}
``` python
samples_ = torch.stack([i.reshape(-1, 2, 32, 32) for i in samples])
samples_.shape
```
::: {.output .cell-output .cell-output-display .execute_result execution_count="37"}
torch.Size([30, 16, 2, 32, 32])
:::
:::
::: {#bef6ff41-1561-40ce-980c-29014f7efad5 .cell .code execution_count="38" execution="{\"iopub.execute_input\":\"2023-09-21T18:12: 40.804170Z\",\"iopub.status.busy\":\"2023-09-21T18:12: 40.803208Z\",\"iopub.status.idle\":\"2023-09-21T18:12: 40.896475Z\",\"shell.execute_reply\":\"2023-09-21T18:12: 40.896319Z\",\"shell.execute_reply.started\":\"2023-09-21T18:12: 40.804119Z\"}"}
``` python
len(hmc_samples)
```
:::
::: {.output .cell-output .cell-output-display .execute_result execution_count="38"}
28
:::
::: {#a2279e49-e463-4a3a-a08e-14b1a636222f .cell .code execution_count="41" execution="{\"iopub.execute_input\":\"2023-09-21T18:13: 00.841115Z\",\"iopub.status.busy\":\"2023-09-21T18:13: 00.840869Z\",\"iopub.status.idle\":\"2023-09-21T18:13: 00.912440Z\",\"shell.execute_reply\":\"2023-09-21T18:13: 00.912247Z\",\"shell.execute_reply.started\":\"2023-09-21T18:13: 00.841101Z\"}"}
``` python
len(diffusion_samples)
```
::: {.output .cell-output .cell-output-display .execute_result execution_count="41"}
2
:::
:::
::: {#13d3d5cd-df10-40e0-8dff-b1625a3ead57 .cell .code execution_count="43" execution="{\"iopub.execute_input\":\"2023-09-21T18:13: 04.231527Z\",\"iopub.status.busy\":\"2023-09-21T18:13: 04.231202Z\",\"iopub.status.idle\":\"2023-09-21T18:13: 04.328804Z\",\"shell.execute_reply\":\"2023-09-21T18:13: 04.328067Z\",\"shell.execute_reply.started\":\"2023-09-21T18:13: 04.231509Z\"}"}
``` python
hmc_samples_ = torch.stack([i.reshape(-1, 2, 32, 32) for i in hmc_samples])
diffusion_samples_ = torch.stack(
    [i.reshape(-1, 2, 32, 32) for i in diffusion_samples]
)
```
:::
::: {#7b58ba9e-4a8d-4cdb-a14e-87695db13ec3 .cell .code execution_count="44" execution="{\"iopub.execute_input\":\"2023-09-21T18:13: 04.966106Z\",\"iopub.status.busy\":\"2023-09-21T18:13: 04.965159Z\",\"iopub.status.idle\":\"2023-09-21T18:13: 05.068570Z\",\"shell.execute_reply\":\"2023-09-21T18:13: 05.067781Z\",\"shell.execute_reply.started\":\"2023-09-21T18:13: 04.966070Z\"}"}
``` python
hmc_samples_.shape
```
::: {.output .cell-output .cell-output-display .execute_result execution_count="44"}
torch.Size([28, 16, 2, 32, 32])
:::
:::
::: {#21c844ca-b9dd-4dd4-96cb-c5bb83ad93b5 .cell .code execution_count="45" execution="{\"iopub.execute_input\":\"2023-09-21T18:13: 12.549057Z\",\"iopub.status.busy\":\"2023-09-21T18:13: 12.548060Z\",\"iopub.status.idle\":\"2023-09-21T18:13: 12.726595Z\",\"shell.execute_reply\":\"2023-09-21T18:13: 12.726410Z\",\"shell.execute_reply.started\":\"2023-09-21T18:13: 12.549036Z\"}"}
``` python
diffusion_samples_.shape
```
:::
::: {.output .cell-output .cell-output-display .execute_result execution_count="45"}
torch.Size([2, 16, 2, 32, 32])
:::
::: {#bee61425-74d1-4faa-9447-ef4c6c71a8ba .cell .code execution_count="46" execution="{\"iopub.execute_input\":\"2023-09-21T18:13: 15.683862Z\",\"iopub.status.busy\":\"2023-09-21T18:13: 15.683665Z\",\"iopub.status.idle\":\"2023-09-21T18:13: 15.753128Z\",\"shell.execute_reply\":\"2023-09-21T18:13: 15.752924Z\",\"shell.execute_reply.started\":\"2023-09-21T18:13: 15.683847Z\"}"}
``` python
samples_.shape
```
:::
::: {.output .cell-output .cell-output-display .execute_result execution_count="46"}
torch.Size([30, 16, 2, 32, 32])
:::
::: {#c637a290-4dc3-42f6-abeb-ee807b991ffd .cell .code execution_count="47" execution="{\"iopub.execute_input\":\"2023-09-21T18:13: 29.661939Z\",\"iopub.status.busy\":\"2023-09-21T18:13: 29.661388Z\",\"iopub.status.idle\":\"2023-09-21T18:13: 29.751394Z\",\"shell.execute_reply\":\"2023-09-21T18:13: 29.751170Z\",\"shell.execute_reply.started\":\"2023-09-21T18:13: 29.661915Z\"}"}
``` python
def calc_plaqs(x):
    return torch.stack([
        exp.trainer.lattice.plaqs(
            x[:, idx]
        ) for idx in range(x.shape[1])
    ], -1)


def calc_intQ(x):
    return torch.stack([
        exp.trainer.lattice.int_charges(
            x[:, idx]
        ) for idx in range(x.shape[1])
    ], -1)


def calc_sinQ(x):
    return torch.stack([
        exp.trainer.lattice.sin_charges(
            x[:, idx]
        ) for idx in range(x.shape[1])
    ], -1)
```
:::
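`calc_plaqs`, `calc_intQ`, and `calc_sinQ` delegate to `exp.trainer.lattice`, whose implementation is not shown here. For reference, here is a plain-Python sketch of the standard 2D U(1) definitions, assuming link angles stored as `theta[mu][x][y]` on a periodic $L \times L$ lattice (the axis and sign conventions in `l2hmc` may differ). The plaquette angle is $\theta_P(n) = \theta_0(n) + \theta_1(n+\hat{0}) - \theta_0(n+\hat{1}) - \theta_1(n)$; the average plaquette is the mean of $\cos\theta_P$; `intQ` sums $\theta_P$ wrapped into $[-\pi, \pi)$ and divides by $2\pi$; `sinQ` uses $\sin\theta_P$ in place of the wrapped angle.

```python
import math
import random

def plaquette_angles(theta):
    """Plaquette angles on a periodic L x L lattice; theta[mu][x][y], mu in {0, 1}."""
    L = len(theta[0])
    return [
        theta[0][x][y] + theta[1][(x + 1) % L][y]
        - theta[0][x][(y + 1) % L] - theta[1][x][y]
        for x in range(L) for y in range(L)
    ]

def wrap(angle: float) -> float:
    """Map an angle into [-pi, pi)."""
    return (angle + math.pi) % (2.0 * math.pi) - math.pi

def observables(theta):
    """Return (average plaquette, integer-valued charge, sin-based charge)."""
    plaqs = plaquette_angles(theta)
    avg_plaq = sum(math.cos(p) for p in plaqs) / len(plaqs)
    int_q = sum(wrap(p) for p in plaqs) / (2.0 * math.pi)
    sin_q = sum(math.sin(p) for p in plaqs) / (2.0 * math.pi)
    return avg_plaq, int_q, sin_q

# Usage: a "cold" (all-zero) configuration and a random one on an 8 x 8 lattice.
L = 8
cold = [[[0.0] * L for _ in range(L)] for _ in range(2)]
rng = random.Random(0)
hot = [[[rng.uniform(-math.pi, math.pi) for _ in range(L)] for _ in range(L)] for _ in range(2)]
print(observables(cold))  # (1.0, 0.0, 0.0)
print(observables(hot))
```

On a periodic lattice each link enters two plaquettes with opposite signs, so the unwrapped $\theta_P$ sum to zero and `intQ` comes out an integer up to rounding. That is why the `intQ` values printed earlier sit at whole numbers (e.g. `-4.5535e-07` for a charge of 0) while `sinQ` does not.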
::: {#be21453d-5bd3-4f17-b9eb-294adcdf2e6b .cell .code execution_count="64" execution="{\"iopub.execute_input\":\"2023-09-21T18:17: 29.078038Z\",\"iopub.status.busy\":\"2023-09-21T18:17: 29.076956Z\",\"iopub.status.idle\":\"2023-09-21T18:17: 29.288381Z\",\"shell.execute_reply\":\"2023-09-21T18:17: 29.288242Z\",\"shell.execute_reply.started\":\"2023-09-21T18:17: 29.077980Z\"}"}
``` python
samples_init_ = initial_states.reshape(-1, initial_states.shape[1], 2, 32, 32)
samples_init_.shape
```
:::
::: {.output .cell-output .cell-output-display .execute_result execution_count="64"}
torch.Size([999, 16, 2, 32, 32])
:::
::: {#8542582c-65ae-40b5-b5b9-70292acc2c81 .cell .code execution_count="65" execution="{\"iopub.execute_input\":\"2023-09-21T18:17: 43.351050Z\",\"iopub.status.busy\":\"2023-09-21T18:17: 43.347957Z\",\"iopub.status.idle\":\"2023-09-21T18:17: 43.699383Z\",\"shell.execute_reply\":\"2023-09-21T18:17: 43.699177Z\",\"shell.execute_reply.started\":\"2023-09-21T18:17: 43.350913Z\"}"}
``` python
metrics_init_ = {
    'plaqs': calc_plaqs(samples_init_),
    'intQ': calc_intQ(samples_init_),
    'sinQ': calc_sinQ(samples_init_),
}
metrics_ = {
    'plaqs': calc_plaqs(samples_),
    'intQ': calc_intQ(samples_),
    'sinQ': calc_sinQ(samples_),
}
metrics_hmc_ = {
    'plaqs': calc_plaqs(hmc_samples_),
    'intQ': calc_intQ(hmc_samples_),
    'sinQ': calc_sinQ(hmc_samples_),
}
metrics_diffusion_ = {
    'plaqs': calc_plaqs(diffusion_samples_),
    'intQ': calc_intQ(diffusion_samples_),
    'sinQ': calc_sinQ(diffusion_samples_),
}
```
:::
::: {#21bfb626-a2c1-46db-8f1a-7eb86d81eafc .cell .code execution_count="49" execution="{\"iopub.execute_input\":\"2023-09-21T18:13: 32.364043Z\",\"iopub.status.busy\":\"2023-09-21T18:13: 32.363212Z\",\"iopub.status.idle\":\"2023-09-21T18:13: 32.470577Z\",\"shell.execute_reply\":\"2023-09-21T18:13: 32.470322Z\",\"shell.execute_reply.started\":\"2023-09-21T18:13: 32.364002Z\"}"}
``` python
metrics_['plaqs'].shape
```
:::
::: {.output .cell-output .cell-output-display .execute_result execution_count="49"}
torch.Size([30, 16])
:::
::: {#59f33ef8-cfad-41d5-bfdd-0717c8700376 .cell .code execution_count="52" execution="{\"iopub.execute_input\":\"2023-09-21T18:14: 16.986650Z\",\"iopub.status.busy\":\"2023-09-21T18:14: 16.985673Z\",\"iopub.status.idle\":\"2023-09-21T18:14: 17.106591Z\",\"shell.execute_reply\":\"2023-09-21T18:14: 17.105464Z\",\"shell.execute_reply.started\":\"2023-09-21T18:14: 16.986586Z\"}"}
``` python
console.print('\n'.join([f"{k}: {v}" for k, v in watcher.items()]))
```
:::
::: {.output .cell-output .cell-output-display}
```{=html}
<pre style="white-space:pre;overflow-x:auto;line-height:normal;">
<span style="color: #2094f3; text-decoration-color: #2094f3">0</span>: HMC
<span style="color: #2094f3; text-decoration-color: #2094f3">1</span>: HMC
<span style="color: #2094f3; text-decoration-color: #2094f3">2</span>: HMC
<span style="color: #2094f3; text-decoration-color: #2094f3">3</span>: HMC
<span style="color: #2094f3; text-decoration-color: #2094f3">4</span>: HMC
<span style="color: #2094f3; text-decoration-color: #2094f3">5</span>: HMC
<span style="color: #2094f3; text-decoration-color: #2094f3">6</span>: HMC
<span style="color: #2094f3; text-decoration-color: #2094f3">7</span>: HMC
<span style="color: #2094f3; text-decoration-color: #2094f3">8</span>: HMC
<span style="color: #2094f3; text-decoration-color: #2094f3">9</span>: HMC
<span style="color: #2094f3; text-decoration-color: #2094f3">10</span>: HMC
<span style="color: #2094f3; text-decoration-color: #2094f3">11</span>: HMC
<span style="color: #2094f3; text-decoration-color: #2094f3">12</span>: HMC
<span style="color: #2094f3; text-decoration-color: #2094f3">13</span>: HMC
<span style="color: #2094f3; text-decoration-color: #2094f3">14</span>: HMC
<span style="color: #2094f3; text-decoration-color: #2094f3">15</span>: HMC
<span style="color: #2094f3; text-decoration-color: #2094f3">16</span>: HMC
<span style="color: #2094f3; text-decoration-color: #2094f3">17</span>: HMC
<span style="color: #2094f3; text-decoration-color: #2094f3">18</span>: HMC
<span style="color: #2094f3; text-decoration-color: #2094f3">19</span>: HMC
<span style="color: #2094f3; text-decoration-color: #2094f3">20</span>: HMC
<span style="color: #2094f3; text-decoration-color: #2094f3">21</span>: HMC
<span style="color: #2094f3; text-decoration-color: #2094f3">22</span>: HMC
<span style="color: #2094f3; text-decoration-color: #2094f3">23</span>: HMC
<span style="color: #2094f3; text-decoration-color: #2094f3">24</span>: HMC
<span style="color: #2094f3; text-decoration-color: #2094f3">25</span>: HMC
<span style="color: #2094f3; text-decoration-color: #2094f3">26</span>: HMC
<span style="color: #2094f3; text-decoration-color: #2094f3">27</span>: diffusion
<span style="color: #2094f3; text-decoration-color: #2094f3">28</span>: HMC
<span style="color: #2094f3; text-decoration-color: #2094f3">29</span>: diffusion
</pre>
```
:::
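`watcher` labels a step `'diffusion'` whenever a diffusion proposal was made, whether or not individual chains accepted it, and `metrics_diffusion_` indexes those steps 0, 1, ... rather than by their global step. A small helper (hypothetical, not part of the notebook) that tallies update types and recovers the global steps of each type makes it easier to place the diffusion points on the same axis as the combined chain, for example by passing those steps as x-values to `ax.plot`.

```python
from collections import Counter

def steps_by_type(watcher: dict[int, str]) -> dict[str, list[int]]:
    """Group global-step indices by update type, in increasing step order."""
    groups: dict[str, list[int]] = {}
    for step, kind in sorted(watcher.items()):
        groups.setdefault(kind, []).append(step)
    return groups

# Usage with the labels printed above: steps 0-26 and 28 were HMC, 27 and 29 diffusion.
watcher_demo = {i: 'HMC' for i in range(30)}
watcher_demo[27] = watcher_demo[29] = 'diffusion'
print(Counter(watcher_demo.values()))
print(steps_by_type(watcher_demo)['diffusion'])  # [27, 29]
```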
::: {#4e5d7c22-e454-4368-ba7e-599f1d68fa19 .cell .code execution_count="54" execution="{\"iopub.execute_input\":\"2023-09-21T18:14: 41.427638Z\",\"iopub.status.busy\":\"2023-09-21T18:14: 41.427237Z\",\"iopub.status.idle\":\"2023-09-21T18:14: 41.629492Z\",\"shell.execute_reply\":\"2023-09-21T18:14: 41.629312Z\",\"shell.execute_reply.started\":\"2023-09-21T18:14: 41.427623Z\"}"}
``` python
fig, ax = plt.subplots()
_ = ax.plot(metrics_['plaqs'][:, 0], label='Combined')
_ = ax.plot(metrics_hmc_['plaqs'][:, 0], label='HMC')
_ = ax.plot(metrics_diffusion_['plaqs'][:, 0], label='Diffusion')
#_ = ax.plot(metrics_hmc1['plaqs'], label='HMC 1')
#_ = ax.plot(metrics_diff_['plaqs'], label='Diffusion')
_ = ax.legend(loc='upper left', bbox_to_anchor=(1.05, 1.00))
```
:::
::: {.output .cell-output .cell-output-display}
![](assets/diffusion/4509936f2dbb7e69226f3d2851bc462cf8046e8e.svg)
:::
::: {#ce6ca162-689d-4673-b093-261e465eefc7 .cell .code execution_count="66" execution="{\"iopub.execute_input\":\"2023-09-21T18:18: 08.961955Z\",\"iopub.status.busy\":\"2023-09-21T18:18: 08.961699Z\",\"iopub.status.idle\":\"2023-09-21T18:18: 10.065123Z\",\"shell.execute_reply\":\"2023-09-21T18:18: 10.064912Z\",\"shell.execute_reply.started\":\"2023-09-21T18:18: 08.961923Z\"}"}
``` python
fig, ax = plt.subplots(ncols=3, figsize=(14, 4))
for idx, (key, val) in enumerate(metrics_init_.items()):
    _ = ax[idx].plot(val[:, 0], label='HMC (Initial Samples)')
    _ = ax[idx].set_ylabel(key, loc='center')
    _ = ax[idx].set_xlabel('Draw', loc='center')
    #_ = ax[idx].legend(loc='best', frameon=True, edgecolor="#838383")
_ = fig.suptitle("Initial HMC States", y=0.92)
```
::: {.output .cell-output .cell-output-display}
![](assets/diffusion/9d36e3726eb2a63bcae157590b44c3710847442f.svg)
:::
:::
::: {#333d6fce-2124-46c7-809b-e584e2efc057 .cell .code execution_count="68" execution="{\"iopub.execute_input\":\"2023-09-21T18:18: 30.709489Z\",\"iopub.status.busy\":\"2023-09-21T18:18: 30.708639Z\",\"iopub.status.idle\":\"2023-09-21T18:18: 31.422559Z\",\"shell.execute_reply\":\"2023-09-21T18:18: 31.422402Z\",\"shell.execute_reply.started\":\"2023-09-21T18:18: 30.709448Z\"}"}
``` python
fig, ax = plt.subplots(ncols=3, figsize=(14, 4))
for idx, (key, val) in enumerate(metrics_.items()):
    _ = ax[idx].plot(val[:, 0], label='Combined')
    _ = ax[idx].set_ylabel(key, loc='center')
    _ = ax[idx].set_xlabel('Draw', loc='center')
    #_ = ax[idx].legend(loc='best', frameon=True, edgecolor="#838383")
_ = fig.suptitle("Combined Samples", y=0.92)
```
:::
::: {.output .cell-output .cell-output-display}
![](assets/diffusion/6ba854aa0d42fd2eae98f6d82c0d6566005c421b.svg)
:::
::: {#c348510d-9ec3-41c7-b04e-654b5ba0bd3d .cell .code execution_count="71" execution="{\"iopub.execute_input\":\"2023-09-21T18:19: 01.490764Z\",\"iopub.status.busy\":\"2023-09-21T18:19: 01.490310Z\",\"iopub.status.idle\":\"2023-09-21T18:19: 01.844808Z\",\"shell.execute_reply\":\"2023-09-21T18:19: 01.844643Z\",\"shell.execute_reply.started\":\"2023-09-21T18:19: 01.490733Z\"}"}
``` python
fig, ax = plt.subplots(ncols=3, figsize=(14, 4))
for idx, (key, val) in enumerate(metrics_hmc_.items()):
    _ = ax[idx].plot(val[:, 0], label='HMC')
    _ = ax[idx].set_ylabel(key, loc='center')
    _ = ax[idx].set_xlabel('Draw', loc='center')
_ = fig.suptitle("Generated HMC States", y=0.92)
```
::: {.output .cell-output .cell-output-display}
![](assets/diffusion/9d3a4133b6d0c2bf60da7558af21452baab7d5c3.svg)
:::
:::
::: {#20a16129-4648-4c62-9ed4-d6753d425f89 .cell .code execution_count="73" execution="{\"iopub.execute_input\":\"2023-09-21T18:19: 03.884288Z\",\"iopub.status.busy\":\"2023-09-21T18:19: 03.884107Z\",\"iopub.status.idle\":\"2023-09-21T18:19: 04.146335Z\",\"shell.execute_reply\":\"2023-09-21T18:19: 04.146166Z\",\"shell.execute_reply.started\":\"2023-09-21T18:19: 03.884273Z\"}"}
``` python
fig, ax = plt.subplots(ncols=3, figsize=(14, 4))
for idx, (key, val) in enumerate(metrics_diffusion_.items()):
    _ = ax[idx].plot(val[:, 0], label='Diffusion')
    _ = ax[idx].set_ylabel(key, loc='center')
    _ = ax[idx].set_xlabel('Draw', loc='center')
_ = fig.suptitle("Generated Diffusion States", y=0.92)
```
::: {.output .cell-output .cell-output-display}
![](assets/diffusion/02d2b941bf240ac2a6659c37519a6e2ebec63b41.svg)
:::
:::
::: {#9f329885-fa42-4e13-bd2c-e6fff4418cdb .cell .code execution_count="74" execution="{\"iopub.execute_input\":\"2023-09-21T18:19: 05.363396Z\",\"iopub.status.busy\":\"2023-09-21T18:19: 05.363028Z\",\"iopub.status.idle\":\"2023-09-21T18:19: 05.494802Z\",\"shell.execute_reply\":\"2023-09-21T18:19: 05.494503Z\",\"shell.execute_reply.started\":\"2023-09-21T18:19: 05.363383Z\"}"}
``` python
import torch
from l2hmc.lattice.u1.pytorch.lattice import plaq_exact
plaq_exact(torch.tensor(6.0))
```
:::
::: {.output .cell-output .cell-output-display .execute_result execution_count="74"}
tensor(0.9124)
:::
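For 2D U(1) with the Wilson action, the infinite-volume expectation of the plaquette is $\langle \cos\theta_P\rangle = I_1(\beta)/I_0(\beta)$, where $I_n$ are modified Bessel functions of the first kind. Assuming `plaq_exact` implements this formula (its source is not shown here), the value above can be cross-checked with SciPy; `plaq_exact_np` is a hypothetical name used only for this sketch.

```python
from scipy.special import i0, i1

def plaq_exact_np(beta):
    # Infinite-volume <cos(theta_P)> for 2D U(1) with the Wilson action
    return i1(beta) / i0(beta)

print(round(float(plaq_exact_np(6.0)), 4))  # prints 0.9124
```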
::: {#0646c637-546c-49df-876b-4dfb7abb9497 .cell .code execution_count="78" execution="{\"iopub.execute_input\":\"2023-09-21T18:19: 41.898242Z\",\"iopub.status.busy\":\"2023-09-21T18:19: 41.898063Z\",\"iopub.status.idle\":\"2023-09-21T18:19: 42.092250Z\",\"shell.execute_reply\":\"2023-09-21T18:19: 42.092113Z\",\"shell.execute_reply.started\":\"2023-09-21T18:19: 41.898226Z\"}"}
``` python
fig, ax = plt.subplots()
#_ = plt.hist(metrics_['intQ'].flatten(), color='C0', alpha=0.6, label='Combined', edgecolor='none')
_ = ax.hist(
    metrics_diffusion_['intQ'].flatten(),
    color='C0',
    alpha=0.6,
    edgecolor='none',
    label='Diffusion',
    density=True,
)
_ = ax.hist(
    metrics_hmc_['intQ'].flatten(),
    color='C1',
    alpha=0.6,
    edgecolor='none',
    label='HMC',
    density=True,
)
_ = ax.legend(loc='best', frameon=True, edgecolor='#666666')
_ = ax.set_xlabel(r"$Q$", loc='center')
_ = ax.set_title('Topological Charge ($Q$) Distribution', loc='center')
```
::: {.output .cell-output .cell-output-display}
![](assets/diffusion/3f9d006e6a51dac62b80daab5d4bb4ac70596db2.svg)
:::
:::
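The histogram above uses matplotlib's default binning. Assuming `intQ` holds integer-valued charges (as the name suggests), placing bin edges at half-integers gives exactly one bar per topological sector, which makes the two samplers easier to compare. A small sketch; the helper names and toy arrays are hypothetical:

```python
import numpy as np

def integer_bins(*arrays):
    # Bin edges at half-integers so each integer charge gets its own bin
    values = np.concatenate([np.ravel(a) for a in arrays])
    lo, hi = np.floor(values.min()), np.ceil(values.max())
    return np.arange(lo - 0.5, hi + 1.0, 1.0)

def charge_distribution(q, bins):
    # Normalized counts per sector (round first in case Q is stored as float)
    counts, _ = np.histogram(np.rint(np.ravel(q)), bins=bins)
    return counts / counts.sum()

q_a = np.array([0, 0, 1, -1])  # toy charges, not data from this notebook
q_b = np.array([0, 2])
bins = integer_bins(q_a, q_b)
print(bins)
print(charge_distribution(q_a, bins))
```

The same `bins` array can be passed to both `ax.hist` calls (after converting the tensors to NumPy) so the diffusion and HMC bars line up sector by sector.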
::: {#f44d12ff-37f7-452e-94d7-6c3f0b28205e .cell .code execution_count="79" execution="{\"iopub.execute_input\":\"2023-09-21T18:19: 47.593614Z\",\"iopub.status.busy\":\"2023-09-21T18:19: 47.593330Z\",\"iopub.status.idle\":\"2023-09-21T18:19: 47.769152Z\",\"shell.execute_reply\":\"2023-09-21T18:19: 47.768975Z\",\"shell.execute_reply.started\":\"2023-09-21T18:19: 47.593598Z\"}"}
``` python
fig, ax = plt.subplots()
_ = ax.plot(metrics_['plaqs'][:, 0], color='C0', label='Diffusion')
_ = ax.plot(metrics_hmc_['plaqs'][:, 0], color='C1', label='HMC')
_ = ax.legend(loc='best', frameon=True, edgecolor='#666666', ncols=2)
_ = ax.set_ylabel(r"$\left\langle U_{\mu\nu}\right\rangle$", loc='center')
_ = ax.set_xlabel("Draw", loc='center')
```
:::
::: {.output .cell-output .cell-output-display}
![](assets/diffusion/88b5366a0686ecb675259d3e1459d08d6c39e750.svg)
:::
::: {#c0574a41-5d7e-45d9-97e9-0d81a0a6fa84 .cell .code execution_count="80" execution="{\"iopub.execute_input\":\"2023-09-21T18:19: 53.009094Z\",\"iopub.status.busy\":\"2023-09-21T18:19: 53.008705Z\",\"iopub.status.idle\":\"2023-09-21T18:19: 53.084368Z\",\"shell.execute_reply\":\"2023-09-21T18:19: 53.084175Z\",\"shell.execute_reply.started\":\"2023-09-21T18:19: 53.009076Z\"}"}
``` python
wloops = {
    'hmc': [
        exp.trainer.lattice.wilson_loops(i) for i in hmc_samples_
    ],
    'diffusion': [
        exp.trainer.lattice.wilson_loops(i) for i in diffusion_samples_
    ],
}
plaqs = {
    'hmc': [
        exp.trainer.lattice.plaqs(i) for i in hmc_samples_
    ],
    'diffusion': [
        exp.trainer.lattice.plaqs(i) for i in diffusion_samples_
    ],
}
```
:::
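`wilson_loops`, `plaqs`, and `intQ` come from l2hmc internals that are not shown here. For reference, a minimal NumPy sketch of the standard 2D U(1) definitions, assuming link angles stored with shape `(2, Nt, Nx)` on a periodic lattice. The layout and sign convention are assumptions and may differ in detail from l2hmc's implementation. The integer charge sums the plaquette angles after projecting each into $[-\pi, \pi)$; because the unprojected angles telescope to zero on a periodic lattice, the result is always an integer.

```python
import numpy as np

def plaquette_angles(links):
    # links[mu] holds theta_mu(x); axis 0 is direction 0, axis 1 is direction 1
    # theta_P(x) = theta_0(x) + theta_1(x + e_0) - theta_0(x + e_1) - theta_1(x)
    t0, t1 = links[0], links[1]
    return t0 + np.roll(t1, -1, axis=0) - np.roll(t0, -1, axis=1) - t1

def average_plaquette(links):
    # <cos(theta_P)>, the quantity whose exact value is I_1(beta) / I_0(beta)
    return np.cos(plaquette_angles(links)).mean()

def integer_charge(links):
    # Project each plaquette angle into [-pi, pi) before summing
    wrapped = (plaquette_angles(links) + np.pi) % (2 * np.pi) - np.pi
    return wrapped.sum() / (2 * np.pi)

rng = np.random.default_rng(0)
cold = np.zeros((2, 16, 16))                        # ordered start
hot = rng.uniform(-np.pi, np.pi, size=(2, 16, 16))  # random start
print(average_plaquette(cold), integer_charge(cold))  # cold start: 1.0 and 0.0
print(integer_charge(hot))  # an integer, up to floating-point error
```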
::: {#a7659697-40bd-489f-9445-1b27a5c79d8e .cell .code execution_count="81" execution="{\"iopub.execute_input\":\"2023-09-21T18:19: 53.670229Z\",\"iopub.status.busy\":\"2023-09-21T18:19: 53.669768Z\",\"iopub.status.idle\":\"2023-09-21T18:19: 53.757525Z\",\"shell.execute_reply\":\"2023-09-21T18:19: 53.757339Z\",\"shell.execute_reply.started\":\"2023-09-21T18:19: 53.670188Z\"}"}
``` python
wlhmc = torch.stack(wloops['hmc']).squeeze()
wldiff = torch.stack(wloops['diffusion']).squeeze()
```
:::
::: {#63818acf-ed67-4f87-a2c2-065edebd94ba .cell .code execution_count="82" execution="{\"iopub.execute_input\":\"2023-09-21T18:19: 54.068148Z\",\"iopub.status.busy\":\"2023-09-21T18:19: 54.067480Z\",\"iopub.status.idle\":\"2023-09-21T18:19: 54.181682Z\",\"shell.execute_reply\":\"2023-09-21T18:19: 54.181151Z\",\"shell.execute_reply.started\":\"2023-09-21T18:19: 54.068106Z\"}"}
``` python
wlhmc.shape
```
:::
::: {.output .cell-output .cell-output-display .execute_result execution_count="82"}
torch.Size([28, 16, 32, 32])
:::
::: {#3e408fed-abb8-486c-a560-3af19f121179 .cell .code execution_count="85" execution="{\"iopub.execute_input\":\"2023-09-21T18:20: 23.993243Z\",\"iopub.status.busy\":\"2023-09-21T18:20: 23.992854Z\",\"iopub.status.idle\":\"2023-09-21T18:20: 24.334498Z\",\"shell.execute_reply\":\"2023-09-21T18:20: 24.334319Z\",\"shell.execute_reply.started\":\"2023-09-21T18:20: 23.993216Z\"}"}
``` python
_ = plt.tight_layout()
for idx in range(2):
    fig, ax = plt.subplots(ncols=2)
    _ = ax[0].imshow(wlhmc[idx, 0])
    _ = ax[0].set_title("HMC", loc='center')
    _ = ax[1].imshow(wldiff[idx, 0])
    _ = ax[1].set_title("Diffusion", loc='center')
    _ = fig.suptitle(r"$U_{\mu\nu}$", y=0.8)
    for ax_ in ax:
        _ = ax_.set_xticklabels([])
        _ = ax_.set_yticklabels([])
```
:::
::: {.flex-container .output .cell-output .cell-output-display}
![](assets/diffusion/292acc374ca58b3045c07221e75d5ad20d83a646.svg)
![](assets/diffusion/3b6a02bc8d6489b5a0147b59009b8a3be270ba14.svg)
:::
::: {#ff7b6cca-0327-4be2-aab6-aaa8df6ce2bc .cell .code execution_count="87" execution="{\"iopub.execute_input\":\"2023-09-21T18:20: 36.056148Z\",\"iopub.status.busy\":\"2023-09-21T18:20: 36.055543Z\",\"iopub.status.idle\":\"2023-09-21T18:20: 36.149377Z\",\"shell.execute_reply\":\"2023-09-21T18:20: 36.149190Z\",\"shell.execute_reply.started\":\"2023-09-21T18:20: 36.056103Z\"}"}
``` python
qhmc = metrics_hmc_['intQ']
qdiff = metrics_diffusion_['intQ']
qhmc.shape
```
:::
::: {.output .cell-output .cell-output-display .execute_result execution_count="87"}
torch.Size([28, 16])
:::
::: {#5230a441-1e1b-4901-b4d0-aacf5413ff35 .cell .code execution_count="88" execution="{\"iopub.execute_input\":\"2023-09-21T18:20: 39.115417Z\",\"iopub.status.busy\":\"2023-09-21T18:20: 39.114730Z\",\"iopub.status.idle\":\"2023-09-21T18:20: 39.197664Z\",\"shell.execute_reply\":\"2023-09-21T18:20: 39.197492Z\",\"shell.execute_reply.started\":\"2023-09-21T18:20: 39.115373Z\"}"}
``` python
phmc = torch.stack(plaqs['hmc']).squeeze()
pdiff = torch.stack(plaqs['diffusion']).squeeze()
phmc.shape
```
:::
::: {.output .cell-output .cell-output-display .execute_result execution_count="88"}
torch.Size([28, 16])
:::
::: {#1af88db9-f706-4b58-8e81-a54cf92f144d .cell .code execution_count="89" execution="{\"iopub.execute_input\":\"2023-09-21T18:20: 39.947370Z\",\"iopub.status.busy\":\"2023-09-21T18:20: 39.946997Z\",\"iopub.status.idle\":\"2023-09-21T18:20: 40.026391Z\",\"shell.execute_reply\":\"2023-09-21T18:20: 40.026205Z\",\"shell.execute_reply.started\":\"2023-09-21T18:20: 39.947336Z\"}"}
``` python
pdiff.shape
```
:::
::: {.output .cell-output .cell-output-display .execute_result execution_count="89"}
torch.Size([2, 16])
:::
::: {#1eeaeb30-abd1-47e9-8752-42d26873629b .cell .code execution_count="90" execution="{\"iopub.execute_input\":\"2023-09-21T18:20: 40.855087Z\",\"iopub.status.busy\":\"2023-09-21T18:20: 40.854732Z\",\"iopub.status.idle\":\"2023-09-21T18:20: 41.041979Z\",\"shell.execute_reply\":\"2023-09-21T18:20: 41.041831Z\",\"shell.execute_reply.started\":\"2023-09-21T18:20: 40.855068Z\"}"}
``` python
fig, ax = plt.subplots()
# metrics_hmc_ holds the pure-HMC states and metrics_ the combined chain,
# matching the "Generated HMC States" / "Combined Samples" plots above
_ = ax.hist(
    metrics_hmc_['plaqs'].flatten(),
    color='C1',
    histtype='step',
    stacked=True,
    density=True,
    label='HMC',
    linewidth=1.5,
)
_ = ax.hist(
    metrics_diffusion_['plaqs'].flatten(),
    color='C0',
    histtype='step',
    stacked=True,
    density=True,
    label='Diffusion',
    linewidth=1.5,
)
_ = ax.hist(
    metrics_['plaqs'].flatten(),
    color='C2',
    histtype='step',
    stacked=True,
    density=True,
    label='Combined',
    linewidth=1.5,
)
_ = ax.set_xlabel(r"$U_{\mu\nu}$", loc='center')
_ = ax.legend(
    loc='upper left',
    frameon=True,
    #ncols=2,
    bbox_to_anchor=(0.55, 1.00),
    edgecolor="#838383",
)
_ = ax.set_title('Plaquette Distribution', loc='center')
```
:::
::: {.output .cell-output .cell-output-display .cell-output .cell-output-display}
![](assets/diffusion/b8adc3f97c471460069de45c953d36d2a16fc858.svg)
:::
::: {#b40d5264-8818-40d6-9b6f-1a8e5c62f8a0 .cell .code execution_count="92" execution="{\"iopub.execute_input\":\"2023-09-21T18:21: 00.286255Z\",\"iopub.status.busy\":\"2023-09-21T18:21: 00.285869Z\",\"iopub.status.idle\":\"2023-09-21T18:21: 00.480910Z\",\"shell.execute_reply\":\"2023-09-21T18:21: 00.480726Z\",\"shell.execute_reply.started\":\"2023-09-21T18:21: 00.286225Z\"}"}
``` python
fig, ax = plt.subplots()
# metrics_hmc_ holds the pure-HMC states and metrics_ the combined chain
_ = ax.hist(
    metrics_hmc_['intQ'].flatten(),
    color='C1',
    histtype='step',
    stacked=True,
    density=True,
    label='HMC',
    linewidth=1.5,
)
_ = ax.hist(
    metrics_diffusion_['intQ'].flatten(),
    color='C0',
    histtype='step',
    stacked=True,
    density=True,
    label='Diffusion',
    linewidth=1.5,
)
_ = ax.hist(
    metrics_['intQ'].flatten(),
    color='C2',
    histtype='step',
    stacked=True,
    density=True,
    label='Combined',
    linewidth=1.5,
)
# Raw string so the backslash in \mathbb is not treated as an escape sequence
_ = ax.set_xlabel(r"$Q_{\mathbb{Z}}$", loc='center')
_ = ax.legend(
    loc='upper left',
    frameon=True,
    #ncols=2,
    bbox_to_anchor=(0.55, 1.00),
    edgecolor="#838383",
)
_ = ax.set_title('Charge Distribution', loc='center')
```
:::
::: {#667c0f58-fbcf-4268-9193-4d475fbbe01b .cell .code}
``` python
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from tqdm.auto import tqdm

global_step = 0
frames = []
losses = []
print("Training model...")
for epoch in range(config["num_epochs"]):
    model.train()
    progress_bar = tqdm(total=len(dataloader))
    progress_bar.set_description(f"Epoch {epoch}")
    for step, batch in enumerate(dataloader):
        # Draw Gaussian noise and one random diffusion timestep per example
        noise = torch.randn(batch.shape)
        timesteps = torch.randint(
            0, noise_scheduler.num_timesteps, (batch.shape[0],)
        ).long()
        # Corrupt the clean batch with *this* noise, so that the regression
        # target in the loss below is the noise that was actually added
        noisy = noise_scheduler.add_noise(batch, noise, timesteps)
        noise_pred = model(noisy, timesteps)
        loss = F.mse_loss(noise_pred, noise)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()
        progress_bar.update(1)
        logs = {"loss": loss.detach().item(), "step": global_step}
        losses.append(loss.detach().item())
        progress_bar.set_postfix(**logs)
        global_step += 1
    progress_bar.close()
    if epoch % config["save_images_step"] == 0 or epoch == config["num_epochs"] - 1:
        # generate data with the model to later visualize the learning process
        model.eval()
        sample = torch.randn(config["eval_batch_size"], 2)
        timesteps = list(range(len(noise_scheduler)))[::-1]
        for i, t in enumerate(tqdm(timesteps)):
            t = torch.from_numpy(np.repeat(t, config["eval_batch_size"])).long()
            with torch.no_grad():
                residual = model(sample, t)
            sample = noise_scheduler.step(residual, t[0], sample)
        frames.append(sample.numpy())
```
:::
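The training step regresses the model output onto the noise used to corrupt each clean sample. Assuming `noise_scheduler.add_noise` implements the standard closed-form DDPM forward process with a linear $\beta$ schedule (its source is not shown here, and the schedule endpoints below are illustrative), the corruption is $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$ with $\bar\alpha_t = \prod_{s\le t}(1-\beta_s)$. A NumPy sketch:

```python
import numpy as np

def linear_beta_schedule(num_timesteps=50, beta_start=1e-4, beta_end=0.02):
    # Linearly spaced variances beta_t and the cumulative products abar_t
    betas = np.linspace(beta_start, beta_end, num_timesteps)
    alphas_cumprod = np.cumprod(1.0 - betas)
    return betas, alphas_cumprod

def add_noise(x0, noise, timesteps, alphas_cumprod):
    # Forward process q(x_t | x_0): x_t = sqrt(abar_t) x_0 + sqrt(1 - abar_t) eps
    abar = alphas_cumprod[timesteps].reshape(-1, *([1] * (x0.ndim - 1)))
    return np.sqrt(abar) * x0 + np.sqrt(1.0 - abar) * noise

rng = np.random.default_rng(0)
betas, alphas_cumprod = linear_beta_schedule()
x0 = rng.standard_normal((4096, 2))      # stand-in for a clean batch
noise = rng.standard_normal(x0.shape)    # the regression target
timesteps = rng.integers(0, len(betas), size=x0.shape[0])
noisy = add_noise(x0, noise, timesteps, alphas_cumprod)
```

Because the coefficients satisfy $\bar\alpha_t + (1-\bar\alpha_t) = 1$, unit-variance data stays unit-variance at every $t$; this is why the target must be the same `noise` tensor passed to `add_noise`, not a freshly drawn one.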
::: {#bca44d4c-aecd-4348-8ad8-be9a404c8946 .cell .code}
``` python
dataset[6]
```
:::
::: {#169c00d5-8ef2-42df-8f21-7ce08837ede4 .cell .code}
``` python
len(dataloader)
```
:::
::: {#8d47bdc2-41a1-4236-8674-090e98f2cd49 .cell .code}
``` python
eval_batch_size = 10
num_timesteps = 50
plot_step = 5
noise_scheduler = ddpm.NoiseScheduler(num_timesteps=num_timesteps)
sample = torch.randn(eval_batch_size, 2)
timesteps = list(range(num_timesteps))[::-1]
samples = []
steps = []
retrains = 10
diffusion_prob = 0.3
samples_per_retrain = 100
```
:::
::: {#11ce1f12-beb8-4d8b-9475-d27c1b91beee .cell .code}
``` python
eval_batch_size = 10
t = torch.from_numpy(np.repeat(timesteps[0], eval_batch_size)).long()
with torch.no_grad():
    residual = model(sample, t)
sample_ = noise_scheduler.step(residual, t[0], sample)
```
:::
::: {#7db39b1f-807d-4656-a473-fcc818f2d157 .cell .code}
``` python
sample.shape
```
:::
::: {#b7118647-55d6-4605-87cf-c43fbe8d735c .cell .code}
``` python
residual.shape
```
:::
::: {#77a6735b-f21f-4bc3-adc5-35c075f61bdd .cell .code}
``` python
sample_.shape
```
:::
::: {#2648c025-75c8-4555-9d75-3c1c6426f3a1 .cell .code}
``` python
diffusion_samples = []
hmc_samples = []
beta = 1.
for retrain_iter in range(retrains):
    console.print(f'retrain_iter: {retrain_iter}')
    ndiff_acc = 0
    ndiff_proposed = 0
    for idx in range(samples_per_retrain):
        console.print(f'sample idx: {idx}')
        rand = np.random.uniform()
        if rand < diffusion_prob:
            ndiff_proposed += 1
            rand_pick = randrange(len(dataloader))
            #theta_prime = dataset[rand_pick]
            # Diffusion proposal: one denoising step from the first timestep
            t = torch.from_numpy(np.repeat(timesteps[0], eval_batch_size)).long()
            with torch.no_grad():
                residual = model(sample, t)
            sample_ = noise_scheduler.step(residual, t[0], sample)
            # Metropolis test: the acceptance probability is the exponential of
            # the *difference* of log-likelihoods, not their quotient.
            # (No proposal-density correction is applied here.)
            log_ratio = (
                log_likelihood_2dU1(sample_, 2)
                - log_likelihood_2dU1(sample, 2)
            )
            a = math.exp(min(0.0, float(log_ratio)))
            u = np.random.uniform()
            if u < a:
                ndiff_acc += 1
                sample = sample_
            diffusion_samples.append(sample)
        else:
            # HMC update of the current state
            sample, metrics = exp.trainer.hmc_step((sample, beta))
            hmc_samples.append(sample)
```
:::
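The loop above mixes two Markov kernels: with probability `diffusion_prob` a globally generated proposal is accepted or rejected by a Metropolis test, otherwise a local HMC step is taken. A self-contained toy version of that structure follows. The 1D Gaussian target, the independent Gaussian proposal standing in for the diffusion model, and the random-walk Metropolis step standing in for HMC are all assumptions chosen for illustration. Because the global proposal does not depend on the current state, the acceptance ratio needs the Hastings factor $q(x)/q(x')$ to leave the target invariant.

```python
import numpy as np

def log_target(x):
    # Toy target: standard normal, log pi(x) up to a constant
    return -0.5 * x ** 2

def log_q(x, scale=2.0):
    # Independent proposal q(x) = N(0, scale^2), up to a constant
    return -0.5 * (x / scale) ** 2

def mixed_chain(n_steps, p_global=0.3, step=0.5, scale=2.0, seed=0):
    rng = np.random.default_rng(seed)
    x = 0.0
    chain = np.empty(n_steps)
    n_prop = n_acc = 0
    for i in range(n_steps):
        is_global = rng.uniform() < p_global
        if is_global:
            # Global move: independence Metropolis-Hastings
            # (stand-in for a diffusion-generated proposal)
            n_prop += 1
            x_new = rng.normal(0.0, scale)
            log_a = (log_target(x_new) - log_target(x)
                     + log_q(x, scale) - log_q(x_new, scale))
        else:
            # Local move: random-walk Metropolis (stand-in for an HMC step)
            x_new = x + step * rng.normal()
            log_a = log_target(x_new) - log_target(x)
        # Accept with probability min(1, exp(log_a)), compared in log space
        if np.log(rng.uniform()) < log_a:
            x = x_new
            n_acc += int(is_global)
        chain[i] = x
    return chain, n_acc / max(n_prop, 1)

chain, acc_rate = mixed_chain(50_000)
print(chain.mean(), chain.var(), acc_rate)
```

Comparing in log space also sidesteps overflow when the log-likelihood difference is large.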
::: {#8333621a-2fee-4736-a66c-bc42650d72aa .cell .code}
``` python
for i, t in enumerate(tqdm(timesteps)):
    t = torch.from_numpy(np.repeat(t, eval_batch_size)).long()
    with torch.no_grad():
        residual = model(sample, t)
    sample = noise_scheduler.step(residual, t[0], sample)
    if (i + 1) % plot_step == 0:
        samples.append(sample.numpy())
        steps.append(i + 1)
```
:::
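`noise_scheduler.step(residual, t, sample)` is not shown here. Assuming it implements the standard DDPM ancestral update in the $\epsilon$ parameterisation, one reverse step computes $\mu = \big(x_t - \tfrac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta\big)/\sqrt{\alpha_t}$ and adds Gaussian noise with the posterior variance $\beta_t(1-\bar\alpha_{t-1})/(1-\bar\alpha_t)$, except on the final step. In the NumPy sketch below, `oracle_eps` is a hypothetical stand-in for the trained `model`: for data concentrated at a single point `c` it returns the exact noise, so the sampler should land on `c`, which makes the update easy to check.

```python
import numpy as np

def make_schedule(num_timesteps=50, beta_start=1e-4, beta_end=0.02):
    betas = np.linspace(beta_start, beta_end, num_timesteps)
    alphas = 1.0 - betas
    return betas, alphas, np.cumprod(alphas)

def reverse_step(eps_pred, t, x_t, betas, alphas, alphas_cumprod, rng):
    # Mean of p(x_{t-1} | x_t) given the predicted noise
    mean = (x_t - betas[t] / np.sqrt(1.0 - alphas_cumprod[t]) * eps_pred) / np.sqrt(alphas[t])
    if t == 0:
        return mean  # the final step adds no noise
    var = betas[t] * (1.0 - alphas_cumprod[t - 1]) / (1.0 - alphas_cumprod[t])
    return mean + np.sqrt(var) * rng.standard_normal(x_t.shape)

def sample(eps_model, shape, num_timesteps=50, seed=0):
    betas, alphas, alphas_cumprod = make_schedule(num_timesteps)
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(shape)  # start from Gaussian noise
    for t in reversed(range(num_timesteps)):
        eps = eps_model(x, t, alphas_cumprod)
        x = reverse_step(eps, t, x, betas, alphas, alphas_cumprod, rng)
    return x

c = np.array([1.0, -2.0])  # hypothetical single-point "dataset"

def oracle_eps(x, t, alphas_cumprod):
    # Exact noise that maps c to x at timestep t
    return (x - np.sqrt(alphas_cumprod[t]) * c) / np.sqrt(1.0 - alphas_cumprod[t])

out = sample(oracle_eps, shape=(10, 2))
print(out[:2])
```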
::: {#014d60bf-7f4b-48f6-a0e0-20b73392446d .cell .markdown jp-MarkdownHeadingCollapsed="true"}
## Alternate
:::
::: {#385670d3-4c8a-40e6-abd8-7adf803d0042 .cell .code}
``` python
diffusion_ = DiffusionAlt(img_size=64, device='cpu')
```
:::
::: {#88d31f53-1808-4a91-b26d-61b6328b824f .cell .code}
``` python
image = torch.rand(1, 2, 64, 64)
t = diffusion_.sample_timesteps(image.shape[0]).to('cpu')
unet(image, t)
```
:::