
Commit f22e4a7

Merge pull request #8 from PoCInnovation/feat/prototype-v1
[ADD]: Symmetric BCE + binary class
2 parents 4c0e408 + b44469b commit f22e4a7

4 files changed

Lines changed: 338 additions & 150 deletions


prototype_v1/classifier.py

Lines changed: 101 additions & 45 deletions
@@ -1,10 +1,11 @@
 """
-Classifier with Cross-Entropy
+Binary classifier with Symmetric BCE Loss

-Turns contextualized embeddings into per-node predictions.
+Turns contextualized embeddings into binary per-node predictions.
+The Symmetric BCE Loss automatically handles solution symmetry (MaxCut).

 Input: [batch, n_nodes, hidden_dim]
-Output: [batch, n_nodes, num_classes] (logits)
+Output: [batch, n_nodes] (probabilities between 0 and 1)
 """

 import torch
@@ -14,41 +15,38 @@

 class Classifier(nn.Module):
     """
-    Multi-class classifier for optimization problems.
+    Binary classifier for graph optimization problems.

-    Supports:
-    - MaxCut, Vertex Cover, Independent Set: 2 classes
-    - Graph Coloring: k classes
+    Supports: MaxCut, Vertex Cover, Independent Set (all binary).
+
+    Loss: Symmetric BCE
+        loss = min(BCE(pred, target), BCE(pred, 1-target))
+        → automatically handles solution symmetry
     """

-    def __init__(self, hidden_dim=256, max_classes=10, dropout=0.1):
+    def __init__(self, hidden_dim=256, dropout=0.1):
         super().__init__()

-        self.hidden_dim = hidden_dim
-        self.max_classes = max_classes
-
         self.layers = nn.Sequential(
             nn.LayerNorm(hidden_dim),
             nn.Linear(hidden_dim, hidden_dim // 2),
             nn.GELU(),
             nn.Dropout(dropout),
-            nn.Linear(hidden_dim // 2, max_classes)
+            nn.Linear(hidden_dim // 2, 1)  # single output → binary
         )

-    def forward(self, x, num_classes=2):
+    def forward(self, x):
         """
         Args:
             x: [batch, n_nodes, hidden_dim] - contextualized embeddings
-            num_classes: int - number of classes

         Returns:
-            logits: [batch, n_nodes, num_classes]
-            probs: [batch, n_nodes, num_classes]
-            predictions: [batch, n_nodes]
+            probs: [batch, n_nodes] - probabilities between 0 and 1
+            predictions: [batch, n_nodes] - 0 or 1
         """
-        logits = self.layers(x)[:, :, :num_classes]
-        probs = F.softmax(logits, dim=-1)
-        predictions = torch.argmax(logits, dim=-1)
+        logits = self.layers(x).squeeze(-1)  # [batch, n_nodes]
+        probs = torch.sigmoid(logits)        # sigmoid → [0, 1]
+        predictions = (probs > 0.5).long()   # threshold → 0 or 1

         return {
             'logits': logits,
@@ -58,47 +56,105 @@ def forward(self, x, num_classes=2):

     def compute_loss(self, logits, targets, mask=None):
         """
-        Cross-Entropy Loss.
+        Symmetric BCE Loss.
+
+        Computes the BCE both ways (target and 1 - target)
+        and keeps the minimum → handles the symmetry.

         Args:
-            logits: [batch, n_nodes, num_classes]
-            targets: [batch, n_nodes] - classes {0, 1, ..., k-1}
-            mask: [batch, n_nodes] - optional
+            logits: [batch, n_nodes] - raw outputs (before sigmoid)
+            targets: [batch, n_nodes] - values 0 or 1
+            mask: [batch, n_nodes] - optional (for graphs of different sizes)

         Returns:
             loss: scalar
         """
-        b, n, c = logits.shape
-        logits_flat = logits.reshape(-1, c)
-        targets_flat = targets.reshape(-1).long()
+        targets = targets.float()

         if mask is not None:
-            mask_flat = mask.reshape(-1).float()
-            loss = F.cross_entropy(logits_flat, targets_flat, reduction='none')
-            return (loss * mask_flat).sum() / mask_flat.sum().clamp(min=1)
+            mask = mask.float()
+
+            # Direct loss: pred vs target
+            loss_direct = F.binary_cross_entropy_with_logits(
+                logits, targets, reduction='none'
+            )
+            loss_direct = (loss_direct * mask).sum() / mask.sum().clamp(min=1)
+
+            # Inverted loss: pred vs (1 - target)
+            loss_inverse = F.binary_cross_entropy_with_logits(
+                logits, 1.0 - targets, reduction='none'
+            )
+            loss_inverse = (loss_inverse * mask).sum() / mask.sum().clamp(min=1)
+        else:
+            # Direct loss: pred vs target
+            loss_direct = F.binary_cross_entropy_with_logits(logits, targets)

-        return F.cross_entropy(logits_flat, targets_flat)
+            # Inverted loss: pred vs (1 - target)
+            loss_inverse = F.binary_cross_entropy_with_logits(logits, 1.0 - targets)
+
+        # Symmetric: take the minimum of the two
+        loss = torch.min(loss_direct, loss_inverse)
+
+        return loss
+
+    def compute_similarity(self, predictions, targets):
+        """
+        Computes the match percentage (taking symmetry into account).
+
+        Args:
+            predictions: [batch, n_nodes] - 0 or 1
+            targets: [batch, n_nodes] - 0 or 1
+
+        Returns:
+            similarity: float between 0 and 1 (1 = perfect match)
+        """
+        predictions = predictions.float()
+        targets = targets.float()
+
+        # Direct match
+        match_direct = (predictions == targets).float().mean()
+
+        # Inverted match
+        match_inverse = (predictions == (1.0 - targets)).float().mean()
+
+        # Best of the two
+        similarity = torch.max(match_direct, match_inverse)
+
+        return similarity.item()


 if __name__ == "__main__":
-    print("=== Test Classifier ===")
+    print("=== Test Classifier (Binary + Symmetric BCE) ===\n")

     x = torch.randn(4, 6, 256)  # [batch, n_nodes, hidden_dim]
-    classifier = Classifier(hidden_dim=256, max_classes=10)
+    classifier = Classifier(hidden_dim=256)

-    # Test 2 classes (MaxCut)
-    output = classifier(x, num_classes=2)
-    print(f"Logits (2 classes): {output['logits'].shape}")
+    # Forward
+    output = classifier(x)
+    print(f"Logits: {output['logits'].shape}")
+    print(f"Probs: {output['probs'].shape}")
+    print(f"Predictions: {output['predictions'].shape}")
+    print(f"Example probs: {output['probs'][0].tolist()}")
+    print(f"Example preds: {output['predictions'][0].tolist()}")

-    # Test 5 classes (Graph Coloring)
-    output = classifier(x, num_classes=5)
-    print(f"Logits (5 classes): {output['logits'].shape}")
+    # Test the Symmetric Loss
+    targets = torch.tensor([[1, 0, 1, 0, 1, 0]] * 4).float()

-    # Test loss
-    targets = torch.randint(0, 2, (4, 6))
-    output = classifier(x, num_classes=2)
     loss = classifier.compute_loss(output['logits'], targets)
-    print(f"Loss: {loss.item():.4f}")
-
-    print(f"Params: {sum(p.numel() for p in classifier.parameters()):,}")
+    print(f"\nSymmetric BCE Loss: {loss.item():.4f}")
+
+    # Symmetry test: target and 1 - target must give the same loss
+    loss_normal = classifier.compute_loss(output['logits'], targets)
+    loss_inverted = classifier.compute_loss(output['logits'], 1 - targets)
+    print(f"Loss (normal target): {loss_normal.item():.4f}")
+    print(f"Loss (inverted target): {loss_inverted.item():.4f}")
+    print(f"Equal? {'✅ YES' if abs(loss_normal.item() - loss_inverted.item()) < 1e-6 else '❌ NO'}")
+
+    # Similarity test
+    pred = torch.tensor([[0, 1, 0, 1, 0, 1]])
+    target = torch.tensor([[1, 0, 1, 0, 1, 0]])
+    sim = classifier.compute_similarity(pred, target)
+    print(f"\nSimilarity [0,1,0,1,0,1] vs [1,0,1,0,1,0]: {sim:.0%}")
+
+    print(f"\nParams: {sum(p.numel() for p in classifier.parameters()):,}")
     print("✅ OK")
Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+QAOA format (Qiskit)
+
+result = solver.solve(problem)
+print(result.x)        # numpy array: [0, 1, 0, 1]
+print(type(result.x))  # <class 'numpy.ndarray'>
+
+Our model's format
+
+output = model(x, edge_index, problem_id=0)
+print(output['predictions'])        # tensor([[0, 1, 0, 1]])
+print(type(output['predictions'])) # <class 'torch.Tensor'>
+
+
+Are they comparable?
+
+         QAOA                  Our model
+Type     numpy.ndarray         torch.Tensor
+Shape    [n_nodes]             [batch, n_nodes]
+Values   {0, 1}                {0, 1}
+
+Meaning  node i in set 0 or 1  node i in set 0 or 1
+
+YES, but a conversion is needed:
+
+# QAOA → tensor for comparison
+qaoa_target = torch.tensor(result.x)  # [0, 1, 0, 1]
+
+# Our model → squeeze to drop the batch dim
+model_pred = output['predictions'].squeeze(0)  # [0, 1, 0, 1]
+
+# Now comparable!
+
+⚠️ WARNING: problem symmetry!
+
+For MaxCut there is a subtlety:
+
+Solution [0, 1, 0, 1] = Set A: {0, 2}, Set B: {1, 3}
+Solution [1, 0, 1, 0] = Set A: {1, 3}, Set B: {0, 2}
+
+THESE ARE THE SAME SOLUTION! (just inverted)
+
+So if:
+
+QAOA says:  [0, 1, 0, 1]
+Model says: [1, 0, 1, 0]
+
+→ This is CORRECT! Same partition, just inverted labels.
+
+Ways to handle the symmetry
+
+Option 1: normalize (force the bitstring to start with 0)
+
+def normalize_bitstring(bits):
+    if bits[0] == 1:
+        return 1 - bits  # flip
+    return bits
+
+target = normalize_bitstring(qaoa_result)
+pred = normalize_bitstring(model_pred)
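
The normalization idea above as a self-contained snippet; the concrete values are hypothetical stand-ins for `result.x` and `output['predictions']`:

import torch

def normalize_bitstring(bits):
    # Force the first bit to 0 so symmetric solutions compare equal
    if bits[0] == 1:
        return 1 - bits  # flip
    return bits

qaoa_target = torch.tensor([0, 1, 0, 1])  # stand-in for torch.tensor(result.x)
model_pred = torch.tensor([1, 0, 1, 0])   # stand-in for output['predictions'].squeeze(0)

# Same partition with inverted labels → identical after normalization
assert torch.equal(normalize_bitstring(qaoa_target),
                   normalize_bitstring(model_pred))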

prototype_v1/model.py

Lines changed: 31 additions & 25 deletions
@@ -5,7 +5,7 @@
 Graph → GNN Encoder → E_local, E_global
 problem_id → Lookup Table → E_prob
 Concat [E_global || E_local || E_prob] → Transformer → contextualized embeddings
-Classifier → logits → Cross-Entropy Loss → Backpropagation
+Binary classifier → probs → Symmetric BCE Loss → Backpropagation
 """

 import torch
@@ -20,11 +20,12 @@ class QuantumGraphModel(nn.Module):
     """
     Full model for solving graph optimization problems.

-    Supports:
+    Supports (all binary):
     - MaxCut (2 classes)
     - Vertex Cover (2 classes)
     - Independent Set (2 classes)
-    - Graph Coloring (k classes)
+
+    Loss: Symmetric BCE (handles solution symmetry)
     """

     def __init__(
@@ -36,7 +37,6 @@ def __init__(
         transformer_layers=4,
         num_heads=8,
         num_problems=10,
-        max_classes=10,
         dropout=0.1
     ):
         super().__init__()
@@ -67,24 +67,22 @@
             dropout=dropout
         )

-        # 4. Classifier
+        # 4. Classifier (binary)
         self.classifier = Classifier(
             hidden_dim=hidden_dim,
-            max_classes=max_classes,
             dropout=dropout
         )

-    def forward(self, x, edge_index, problem_id, batch=None, num_classes=2):
+    def forward(self, x, edge_index, problem_id, batch=None):
         """
         Args:
             x: [n_nodes, node_input_dim]
             edge_index: [2, n_edges]
             problem_id: int or [batch_size]
             batch: [n_nodes] (optional)
-            num_classes: int

         Returns:
-            dict with logits, probs, predictions
+            dict with logits, probs, predictions (all [batch, n_nodes])
         """
         # 1. GNN Encoder → E_local, E_global
         e_local, e_global = self.encoder(x, edge_index, batch)
@@ -105,8 +103,8 @@ def forward(self, x, edge_index, problem_id, batch=None, num_classes=2):
         # 3. Transformer → contextualized embeddings
         contextualized = self.transformer(e_local, e_global, e_prob)

-        # 4. Classifier → logits, probs, predictions
-        output = self.classifier(contextualized, num_classes=num_classes)
+        # 4. Binary classifier → probs, predictions
+        output = self.classifier(contextualized)

         return output

@@ -126,14 +124,19 @@ def _batch_node_embeddings(self, e_local, batch, batch_size):
         return out

     def compute_loss(self, logits, targets, mask=None):
-        """Cross-Entropy Loss"""
+        """Symmetric BCE Loss"""
         return self.classifier.compute_loss(logits, targets, mask)

-    def forward_with_loss(self, x, edge_index, problem_id, targets, batch=None, num_classes=2):
+    def compute_similarity(self, predictions, targets):
+        """Match percentage (symmetry-aware)"""
+        return self.classifier.compute_similarity(predictions, targets)
+
+    def forward_with_loss(self, x, edge_index, problem_id, targets, batch=None):
         """Forward + loss in a single pass"""
-        output = self.forward(x, edge_index, problem_id, batch, num_classes)
+        output = self.forward(x, edge_index, problem_id, batch)
         loss = self.compute_loss(output['logits'], targets)
-        return output, loss
+        similarity = self.compute_similarity(output['predictions'], targets)
+        return output, loss, similarity


 if __name__ == "__main__":
@@ -159,24 +162,27 @@ def forward_with_loss(self, x, edge_index, problem_id, targets, batch=None, num_
     print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

     # Forward (MaxCut)
-    output = model(x, edge_index, problem_id=0, num_classes=2)
-    print(f"\nMaxCut (2 classes):")
-    print(f" Logits: {output['logits'].shape}")
+    output = model(x, edge_index, problem_id=0)
+    print(f"\nMaxCut:")
+    print(f" Probs: {output['probs'].shape}")
     print(f" Predictions: {output['predictions']}")

-    # Loss
+    # Symmetric Loss
     targets = torch.tensor([[1, 0, 1, 0, 1, 0]])
     loss = model.compute_loss(output['logits'], targets)
     print(f" Loss: {loss.item():.4f}")

+    # Symmetry test
+    loss_inv = model.compute_loss(output['logits'], 1 - targets)
+    print(f" Inverted loss: {loss_inv.item():.4f}")
+    print(f" Symmetry OK? {'✅' if abs(loss.item() - loss_inv.item()) < 1e-6 else '❌'}")
+
+    # Similarity
+    sim = model.compute_similarity(output['predictions'], targets)
+    print(f" Similarity: {sim:.0%}")
+
     # Backprop
     loss.backward()
     print(" Backprop OK")

-    # Graph Coloring
-    output = model(x, edge_index, problem_id=3, num_classes=5)
-    print(f"\nGraph Coloring (5 classes):")
-    print(f" Logits: {output['logits'].shape}")
-    print(f" Predictions: {output['predictions']}")
-
     print("\n✅ All tests passed!")
