import numpy as np
from collections import Counter
from sklearn.datasets import fetch_openml
from skimage.transform import resize
import warnings
warnings.filterwarnings("ignore")

# --- STEP 1: Load and preprocess MNIST zeros (2x2 binarized) ---

print("Downloading and preprocessing MNIST...")
mnist = fetch_openml("mnist_784", version=1, as_frame=False)
X, y = mnist["data"], mnist["target"]
X_zeros = X[y == '0'] / 255.0  # Normalize pixel values to [0, 1]
X_zeros = X_zeros[:200]        # Subsample for speed

def downsample_binarize(img, size=2):
    # Downsample a 28x28 image to size x size, threshold at 0.5, and flatten
    # into a bitstring. size=2 yields 4-bit strings, which must match the
    # number of qubits in the model below (size=4 would give 16-bit strings
    # that a 4-qubit model can never produce).
    img = img.reshape(28, 28)
    small = resize(img, (size, size), order=0, anti_aliasing=False, preserve_range=True)
    binary = (small > 0.5).astype(int)
    return ''.join(map(str, binary.flatten()))

samples_bin = [downsample_binarize(img) for img in X_zeros]
data_dist = Counter(samples_bin)
total = sum(data_dist.values())
data_dist = {k: v / total for k, v in data_dist.items()}

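# Illustrative sanity check (added; not required for training): every data
# pattern is a 4-bit string, so the data support lives in the same sample
# space as the 4-qubit model defined below.
assert all(len(s) == 4 for s in samples_bin)
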
# --- STEP 2: Quantum Circuit Utils ---

# Single-qubit R_y(theta) rotation
def Ry(theta):
    return np.array([
        [np.cos(theta/2), -np.sin(theta/2)],
        [np.sin(theta/2),  np.cos(theta/2)]
    ])

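# Illustrative check (added): with this convention, Ry(pi) maps |0> to |1>.
assert np.allclose(Ry(np.pi) @ np.array([1.0, 0.0]), np.array([0.0, 1.0]))
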
# CNOT gate between any two qubits of an n-qubit register, built as a
# permutation matrix over the 2^n computational basis states.
def CNOT(n, control, target):
    dim = 2**n
    op = np.zeros((dim, dim), dtype=complex)
    for i in range(dim):
        bits = list(np.binary_repr(i, width=n))
        if bits[control] == '1':
            bits[target] = '1' if bits[target] == '0' else '0'
        j = int(''.join(bits), 2)
        # CNOT is an involution, so its permutation matrix is symmetric
        # (op[i, j] = op[j, i]) and either index order gives the same gate.
        op[i, j] = 1
    return op

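# Illustrative check (added): on 2 qubits, CNOT(2, 0, 1) maps |10> to |11>.
ket_10 = np.zeros(4, dtype=complex)
ket_10[0b10] = 1
assert np.allclose(CNOT(2, 0, 1) @ ket_10, np.eye(4)[0b11])
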
# Build the variational state from the parameters: one Ry rotation per
# qubit, followed by a linear chain of entangling CNOTs.
def variational_state(params):
    n = len(params)
    state = np.zeros(2**n, dtype=complex)
    state[0] = 1  # start in |0...0>

    # Apply Ry rotations (qubit 0 is the leftmost Kronecker factor)
    U = np.array([[1.0]])
    for theta in params:
        U = np.kron(U, Ry(theta))
    state = U @ state

    # Apply entangling CNOTs in a linear chain: 0->1, 1->2, ...
    for i in range(n - 1):
        state = CNOT(n, i, i + 1) @ state

    return state

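# Illustrative check (added): the circuit is unitary, so the output state
# stays normalized.
assert np.isclose(np.sum(np.abs(variational_state(np.zeros(4)))**2), 1.0)
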
# Sample bitstrings from the state's Born-rule distribution |psi|^2
def sample_state(psi, num_samples=1000):
    probs = np.abs(psi)**2
    probs = probs / probs.sum()  # guard against float rounding for np.random.choice
    n = int(np.log2(len(psi)))
    states = [format(i, f'0{n}b') for i in range(len(psi))]
    return np.random.choice(states, size=num_samples, p=probs)

# Turn raw samples into an empirical probability distribution
def get_prob_dist(samples):
    counts = Counter(samples)
    total = sum(counts.values())
    return {x: c / total for x, c in counts.items()}

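# Illustrative check (added): sampling the single-qubit |+> state gives
# roughly equal '0'/'1' frequencies, up to shot noise.
_plus = np.array([1.0, 1.0], dtype=complex) / np.sqrt(2)
_freqs = get_prob_dist(sample_state(_plus, num_samples=2000))
assert abs(_freqs.get('0', 0) - 0.5) < 0.1
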
# KL divergence D_KL(p || q), summed over the support of p.
# States missing from q get probability eps so the log stays finite.
def kl_divergence(p, q, eps=1e-10):
    kl = 0.0
    for x in p:
        px = p[x]
        qx = q.get(x, eps)
        kl += px * np.log(px / qx)
    return kl

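# Worked example (added): the KL divergence of a point mass against a fair
# coin is log(2) ~= 0.693.
assert np.isclose(kl_divergence({'0': 1.0}, {'0': 0.5, '1': 0.5}), np.log(2))
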
# Parameter-shift-style gradients: evaluate the sampled KL loss at
# theta_i +/- pi/2 and take half the difference. Note the exact
# parameter-shift rule applies to observable expectations; for the KL loss
# this is a shift-based finite-difference estimate.
def parameter_shift_grad(params, data_dist, shift=np.pi/2, num_samples=500):
    grads = np.zeros_like(params)
    for i in range(len(params)):
        plus = params.copy()
        minus = params.copy()
        plus[i] += shift
        minus[i] -= shift

        psi_plus = variational_state(plus)
        psi_minus = variational_state(minus)
        dist_plus = get_prob_dist(sample_state(psi_plus, num_samples))
        dist_minus = get_prob_dist(sample_state(psi_minus, num_samples))

        kl_plus = kl_divergence(data_dist, dist_plus)
        kl_minus = kl_divergence(data_dist, dist_minus)
        grads[i] = 0.5 * (kl_plus - kl_minus)
    return grads

# --- STEP 3: Training VQBM on MNIST patterns ---

n_qubits = 4  # must equal the bitstring length from STEP 1 (2x2 = 4 bits)
params = np.random.uniform(0, 2*np.pi, size=n_qubits)
lr = 0.2

print("\nTraining VQBM...\n")
for step in range(100):
    psi = variational_state(params)
    model_samples = sample_state(psi, num_samples=1000)
    model_dist = get_prob_dist(model_samples)
    loss = kl_divergence(data_dist, model_dist)

    grads = parameter_shift_grad(params, data_dist)
    params -= lr * grads  # gradient-descent step

    if step % 10 == 0:
        print(f"Step {step:3d}: KL Divergence = {loss:.5f}")

# --- STEP 4: Results ---

print("\nFinal learned distribution (top states):")
final_samples = sample_state(variational_state(params), num_samples=2000)
final_dist = get_prob_dist(final_samples)
for k, v in sorted(final_dist.items(), key=lambda x: -x[1])[:10]:
    print(f"{k}: {v:.4f}")
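
# Side-by-side comparison (added for inspection): print the target data
# distribution's top states next to the learned ones above.
print("\nTarget data distribution (top states):")
for k, v in sorted(data_dist.items(), key=lambda x: -x[1])[:10]:
    print(f"{k}: {v:.4f}")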