---
title: '🔳 `l2hmc-qcd` Example: 4D SU(3)'
date: "2025-07-24"
date-modified: last-modified
categories:
- LQCD
- MCMC
- HMC
- SU3
- l2hmc
jupyter:
  jupytext:
    formats: ipynb,py:percent,qmd:quarto
    text_representation:
      extension: .qmd
      format_name: quarto
      format_version: '1.0'
      jupytext_version: 1.17.2
  kernelspec:
    display_name: Python 3 (ipykernel)
    language: python
    name: python3
---
```{python}
#| jupyter: {outputs_hidden: false, source_hidden: false}
# Notebook-wide setup: inline figure format, terminal colour, tensor
# pretty-printing, module autoreload, and the global matplotlib style.
# %matplotlib inline
import matplotlib_inline
# Render inline matplotlib figures as SVG.
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
import os
# NOTE(review): set before the imports below, presumably so terminal-colour
# detection (e.g. rich) sees truecolor support — keep this ordering.
os.environ['COLORTERM'] = 'truecolor'
import lovely_tensors as lt
# lovely_tensors: presumably patches torch tensor reprs into compact
# summaries; colour is disabled in those summaries.
lt.monkey_patch()
lt.set_config(color=False)
# automatically detect and reload local changes to modules
%load_ext autoreload
%autoreload 2
import ezpz
import numpy as np
import matplotlib.pyplot as plt
# NOTE(review): FigAxes is not referenced in any visible cell of this notebook.
from l2hmc.utils.plot_helpers import FigAxes
import ambivalent
# Apply the `ambivalent` matplotlib style to every subsequent figure.
plt.style.use(ambivalent.STYLES['ambivalent'])
#set_plot_style()
```
```{python}
#| jupyter: {outputs_hidden: false, source_hidden: false}
import ezpz
from pathlib import Path
from typing import Optional
from rich import print
import lovely_tensors as lt
import matplotlib.pyplot as plt
import numpy as np
import torch
import yaml
# from l2hmc.utils.dist import setup_torch
# Draw a fresh 32-bit seed for this run. ``default_rng().integers`` returns
# int64 on every platform; the legacy ``np.random.randint(2 ** 32)`` uses the
# C ``long`` default dtype, which is int32 on Windows with NumPy < 2 and makes
# ``high=2**32`` raise ValueError. Cast to a plain Python int for ezpz.
seed = int(np.random.default_rng().integers(2 ** 32))
print(f"seed: {seed}")
# Seed torch (and whatever else ezpz.setup_torch initialises) with `seed`.
_ = ezpz.setup_torch(seed=seed)
# All tensors created from here on default to double precision.
torch.set_default_dtype(torch.float64)
logger = ezpz.get_logger()
import l2hmc.group.su3.pytorch.group as g
# from l2hmc.utils.rich import get_console
from l2hmc.common import grab_tensor, print_dict
from l2hmc.configs import dict_to_list_of_overrides, get_experiment
from l2hmc.experiment.pytorch.experiment import Experiment, evaluate # noqa
```
```{python}
#| jupyter: {outputs_hidden: false, source_hidden: false}
from l2hmc.utils.plot_helpers import ( # noqa
plot_scalar,
plot_chains,
plot_leapfrogs
)
def savefig(fig: plt.Figure, fname: str, outdir: os.PathLike):
    """Save ``fig`` as ``<outdir>/svgs/<fname>.svg`` and ``<outdir>/pngs/<fname>.png``.

    Both outputs use a transparent background and a tight bounding box; the
    PNG is rasterised at 300 dpi. Missing sub-directories are created first.
    """
    root = Path(outdir)
    png_target = root / f"pngs/{fname}.png"
    svg_target = root / f"svgs/{fname}.svg"
    for target in (png_target, svg_target):
        target.parent.mkdir(parents=True, exist_ok=True)
    shared = {'transparent': True, 'bbox_inches': 'tight'}
    # SVG is written first, then the PNG (same order as before).
    fig.savefig(svg_target, **shared)
    fig.savefig(png_target, dpi=300, **shared)
def plot_metrics(metrics: dict, title: Optional[str] = None, **kwargs):
    """Plot every entry of ``metrics`` and save each figure to disk.

    Figures are written (SVG + PNG, via ``savefig``) under
    ``./plots-4dSU3/<title>/``, or directly under ``./plots-4dSU3/`` when no
    title is given.

    Args:
        metrics: Mapping of metric name -> recorded values; each value is
            handed to ``plot_metric``.
        title: Optional axes title, also used as the output sub-directory.
        **kwargs: Forwarded unchanged to ``plot_metric``.
    """
    outdir = Path("./plots-4dSU3")
    # Avoid creating a literal "None" directory when no title is supplied.
    if title is not None:
        outdir = outdir.joinpath(title)
    outdir.mkdir(exist_ok=True, parents=True)
    for key, val in metrics.items():
        fig, ax = plot_metric(val, name=key, **kwargs)
        if title is not None:
            ax.set_title(title)
        # ``console`` is not defined in this notebook (its import is commented
        # out), so log through the notebook-level ``logger`` instead.
        logger.info(f"Saving {key} to {outdir}")
        savefig(fig, f"{key}", outdir=outdir)
        plt.show()
def plot_metric(
    metric: torch.Tensor,
    name: Optional[str] = None,
    **kwargs,
):
    """Plot one recorded metric, choosing a plot helper from its entries' shape.

    Args:
        metric: Sequence of per-step values (indexed and stacked, so in
            practice a list of Python/NumPy scalars or torch tensors).
        name: Used as the y-axis label.
        **kwargs: Forwarded to ``plot_scalar`` / ``plot_chains``
            (``plot_leapfrogs`` is called without them, as before).

    Returns:
        Whatever the selected plot helper returns.

    Raises:
        ValueError: If ``metric`` is empty or its entries have rank > 2.
    """
    if len(metric) == 0:
        raise ValueError(f"Cannot plot empty metric {name!r}")
    first = metric[0]
    # Python and NumPy scalars (including NumPy ints/bools, which torch.stack
    # cannot handle) are stacked with NumPy and plotted as a scalar series.
    if isinstance(first, (int, float, bool, np.integer, np.floating, np.bool_)):
        return plot_scalar(np.stack(metric), ylabel=name, **kwargs)
    ndim = len(first.shape)
    if ndim > 2:
        raise ValueError(
            f"Unsupported element shape {tuple(first.shape)} for metric "
            f"{name!r}; expected rank 0, 1 or 2"
        )
    y = grab_tensor(torch.stack(metric))
    if ndim == 2:
        return plot_leapfrogs(y, ylabel=name)
    if ndim == 1:
        return plot_chains(y, ylabel=name, **kwargs)
    return plot_scalar(y, ylabel=name, **kwargs)
```
## Load config + build Experiment
```{python}
#| jupyter: {outputs_hidden: false, source_hidden: false}
from rich import print
from l2hmc.configs import CONF_DIR
# Read the bundled SU(3) test configuration into a plain dict.
su3conf = Path(CONF_DIR) / "su3test.yaml"
with open(su3conf, mode='r') as stream:
    conf = dict(yaml.safe_load(stream))
```
```{python}
#| jupyter: {outputs_hidden: false, source_hidden: false}
import json
from rich import print_json
# Pretty-print the config. Pass the dict via ``data=`` so rich applies
# indent/sort_keys itself: a pre-dumped string is re-parsed and re-indented
# by rich with its own default indent, which discarded ``indent=4``.
print_json(data=conf, indent=4, sort_keys=True)
# Flatten the config dict into override strings for get_experiment.
overrides = dict_to_list_of_overrides(conf)
```
```{python}
#| jupyter: {outputs_hidden: false, source_hidden: false}
# Build the experiment (with networks) from the config overrides.
ptExpSU3 = get_experiment(overrides=[*overrides], build_networks=True)
assert isinstance(ptExpSU3, Experiment)
# Draw a random initial state at beta = 6.0.
state = ptExpSU3.trainer.dynamics.random_state(6.0)
assert isinstance(state.x, torch.Tensor)
assert isinstance(state.beta, torch.Tensor)
# SU(3) sanity check on the raw state and on its re-projection.
logger.info(f"checkSU(state.x): {g.checkSU(state.x)}")
logger.info(f"checkSU(projectSU(state.x)): {g.checkSU(g.projectSU(state.x))}")
```
```{python}
#| jupyter: {outputs_hidden: false, source_hidden: false}
# from l2hmc.utils.plot_helpers import set_plot_style
# set_plot_style()
from l2hmc.common import get_timestamp
# Timestamped output tree: outputs/pt4dSU3/<timestamp>/{hmc,eval,train}.
TSTAMP = get_timestamp()
OUTPUT_DIR = Path("./outputs").joinpath(f"pt4dSU3/{TSTAMP}")
HMC_DIR, EVAL_DIR, TRAIN_DIR = (
    OUTPUT_DIR / sub for sub in ("hmc", "eval", "train")
)
HMC_DIR.mkdir(parents=True, exist_ok=True)
EVAL_DIR.mkdir(parents=True, exist_ok=True)
TRAIN_DIR.mkdir(parents=True, exist_ok=True)
```
```{python}
#| jupyter: {outputs_hidden: false, source_hidden: false}
# Print the trainer's gradients/weights (via print_grads_and_weights), then
# log the full experiment config.
ptExpSU3.trainer.print_grads_and_weights()
logger.info(ptExpSU3.config)
# NOTE: `console` is not defined in this notebook; the logger call above
# replaces this line.
#console.print(ptExpSU3.config)
```
## HMC
```{python}
#| jupyter: {outputs_hidden: false, source_hidden: false}
# Baseline run with job_type='hmc', starting from the random initial state.
# eps / nleapfrog are presumably the step size and leapfrog steps per
# trajectory — confirm against `evaluate`.
xhmc, history_hmc = evaluate(
    exp=ptExpSU3,
    job_type='hmc',
    x=state.x,
    beta=6.0,
    eps=0.1,
    nleapfrog=8,
    nsteps=50,
    nlog=1,
    nprint=50,
    grab=True,
)
```
```{python}
#| jupyter: {outputs_hidden: false, source_hidden: false}
# Collect the HMC history into a dataset (its `.plaqs` is compared against
# eval further down; presumably an xarray Dataset) and plot all metrics.
dataset_hmc = history_hmc.get_dataset()
_ = history_hmc.plot_all(title='HMC')
```
```{python}
#| jupyter: {outputs_hidden: false, source_hidden: false}
# Restore the flat HMC output to the dynamics' native layout, then run the
# SU(3) check on the raw and on the re-projected configuration.
xhmc = ptExpSU3.trainer.dynamics.unflatten(xhmc)
logger.info(f"checkSU(x_hmc): {g.checkSU(xhmc)}")
logger.info(f"checkSU(projectSU(x_hmc)): {g.checkSU(g.projectSU(xhmc))}")
```
## Training
```{python}
#| jupyter: {outputs_hidden: false, source_hidden: false}
import time
from l2hmc.utils.history import BaseHistory, summarize_dict
# Train for 100 steps starting from the initial random state; every 5th
# step's metrics are added to the history and a summary is logged.
history_train = BaseHistory()
x = state.x
for step in range(100):
    tic = time.perf_counter()
    x, metrics_ = ptExpSU3.trainer.train_step((x, state.beta))
    toc = time.perf_counter()
    metrics = {'train_step': step, 'dt': toc - tic, **metrics_}
    if step % 5:
        continue
    avgs = history_train.update(metrics)
    summary = summarize_dict(avgs)
    logger.info(summary)
```
```{python}
#| jupyter: {outputs_hidden: false, source_hidden: false}
# Collect the (every-5th-step) training history and plot all metrics.
# num_chains comes from the leading dim of the final training state `x`
# (presumably the chain/batch dimension — confirm).
dataset_train = history_train.get_dataset()
_ = history_train.plot_all(
    title='Train',
    num_chains=x.shape[0],
)
```
## Evaluation
```{python}
#| jupyter: {outputs_hidden: false, source_hidden: false}
# Evaluate with job_type='eval' for 50 steps at beta = 6.0. No `x` is passed,
# so `evaluate` presumably draws its own starting state — confirm.
xeval, history_eval = evaluate(
    exp=ptExpSU3,
    job_type='eval',
    beta=6.0,
    nsteps=50,
    nlog=1,
    nprint=50,
    grab=True,
)
```
```{python}
#| jupyter: {outputs_hidden: false, source_hidden: false}
# Collect the eval history into a dataset (its `.plaqs` is compared against
# HMC further down) and plot all recorded metrics.
dataset_eval = history_eval.get_dataset()
_ = history_eval.plot_all(title='Eval')
```
```{python}
#| jupyter: {outputs_hidden: false, source_hidden: false}
# Restore the flat eval output to the dynamics' native layout, then run the
# SU(3) check on the raw and on the re-projected configuration.
xeval = ptExpSU3.trainer.dynamics.unflatten(xeval)
logger.info(f"checkSU(x_eval): {g.checkSU(xeval)}")
logger.info(f"checkSU(projectSU(x_eval)): {g.checkSU(g.projectSU(xeval))}")
```
```{python}
#| jupyter: {outputs_hidden: false, source_hidden: false}
import matplotlib.pyplot as plt
# Squared plaquette difference between the eval and HMC runs; the figure is
# shown inline and saved to EVAL_DIR/pdiff.svg.
pdiff = dataset_eval.plaqs - dataset_hmc.plaqs
import xarray as xr  # noqa: F401  (not used directly in this cell)
fig, ax = plt.subplots(figsize=(12, 4))
(pdiff ** 2).plot(ax=ax)  # , robust=True)
# pdiff is eval - hmc; label it accordingly (sign is irrelevant once squared).
ax.set_title(r"$\left|\delta U_{\mu\nu}\right|^{2}$ (Eval - HMC)")
outfile = Path(EVAL_DIR).joinpath('pdiff.svg')
# Save before plt.show(): the inline backend may close the figure on show.
fig.savefig(outfile.as_posix(), dpi=400, bbox_inches='tight')
plt.show()
```