11"""
2- Classifier avec Cross-Entropy
2+ Classifier Binaire avec Symmetric BCE Loss
33
4- Transforme les embeddings contextualisés en prédictions par nœud.
4+ Transforme les embeddings contextualisés en prédictions binaires par nœud.
5+ La Symmetric BCE Loss gère automatiquement la symétrie des solutions (MaxCut).
56
67Input: [batch, n_nodes, hidden_dim]
7- Output: [batch, n_nodes, num_classes ] (logits )
8+ Output: [batch, n_nodes] (probabilités entre 0 et 1 )
89"""
910
import torch
import torch.nn as nn
import torch.nn.functional as F

class Classifier(nn.Module):
    """
    Binary classifier head for graph optimization problems.

    Supports MaxCut, Vertex Cover, Independent Set (all binary labelings).

    Loss: Symmetric BCE
        loss = min(BCE(pred, target), BCE(pred, 1 - target))
    evaluated independently for each graph in the batch, so every graph
    may match either the target labeling or its complement (the two are
    equivalent solutions for MaxCut-style problems).
    """

    def __init__(self, hidden_dim=256, dropout=0.1):
        """
        Args:
            hidden_dim: dimensionality of the input node embeddings.
            dropout: dropout probability applied inside the head.
        """
        super().__init__()

        self.layers = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1),  # single output -> binary decision
        )

    def forward(self, x):
        """
        Args:
            x: [batch, n_nodes, hidden_dim] - contextualized node embeddings

        Returns:
            dict with:
                logits:      [batch, n_nodes] - raw scores (pre-sigmoid)
                probs:       [batch, n_nodes] - probabilities in [0, 1]
                predictions: [batch, n_nodes] - hard labels, 0 or 1
        """
        logits = self.layers(x).squeeze(-1)   # [batch, n_nodes]
        probs = torch.sigmoid(logits)         # sigmoid -> [0, 1]
        predictions = (probs > 0.5).long()    # 0.5 threshold -> {0, 1}

        return {
            'logits': logits,
            'probs': probs,
            'predictions': predictions,
        }

    def compute_loss(self, logits, targets, mask=None):
        """
        Symmetric BCE Loss.

        Computes the BCE against both the target and its complement
        (1 - target), keeps the smaller of the two PER GRAPH, then
        averages over the batch. Taking the minimum per graph (rather
        than after batch averaging) lets each graph independently match
        either side of the symmetric solution; a global min over
        batch-averaged losses would penalize mixed batches incorrectly.

        Args:
            logits:  [batch, n_nodes] - raw scores (before sigmoid)
            targets: [batch, n_nodes] - values 0 or 1
            mask:    [batch, n_nodes] - optional, 1 for valid nodes
                     (for batches of graphs with different sizes)

        Returns:
            loss: scalar tensor
        """
        targets = targets.float()

        # Per-node BCE in both orientations: [batch, n_nodes]
        per_node_direct = F.binary_cross_entropy_with_logits(
            logits, targets, reduction='none'
        )
        per_node_inverse = F.binary_cross_entropy_with_logits(
            logits, 1.0 - targets, reduction='none'
        )

        if mask is not None:
            mask = mask.float()
            # Average over valid nodes only, per graph: [batch]
            denom = mask.sum(dim=-1).clamp(min=1)
            loss_direct = (per_node_direct * mask).sum(dim=-1) / denom
            loss_inverse = (per_node_inverse * mask).sum(dim=-1) / denom
        else:
            loss_direct = per_node_direct.mean(dim=-1)    # [batch]
            loss_inverse = per_node_inverse.mean(dim=-1)  # [batch]

        # Symmetric: per-graph minimum of the two orientations, then batch mean.
        return torch.min(loss_direct, loss_inverse).mean()

    def compute_similarity(self, predictions, targets):
        """
        Fraction of matching nodes, accounting for label symmetry.

        For each graph, agreement is measured against both the target and
        its complement; the better of the two is kept per graph, then
        averaged over the batch.

        Args:
            predictions: [batch, n_nodes] - values 0 or 1
            targets:     [batch, n_nodes] - values 0 or 1

        Returns:
            similarity: float in [0, 1] (1 = perfect match up to symmetry)
        """
        predictions = predictions.float()
        targets = targets.float()

        # Per-graph agreement in both orientations: [batch]
        match_direct = (predictions == targets).float().mean(dim=-1)
        match_inverse = (predictions == (1.0 - targets)).float().mean(dim=-1)

        # Best orientation per graph, averaged over the batch.
        return torch.max(match_direct, match_inverse).mean().item()
81124
82125
83126if __name__ == "__main__" :
84- print ("=== Test Classifier ===" )
127+ print ("=== Test Classifier (Binaire + Symmetric BCE) ===\n " )
85128
86129 x = torch .randn (4 , 6 , 256 ) # [batch, n_nodes, hidden_dim]
87- classifier = Classifier (hidden_dim = 256 , max_classes = 10 )
130+ classifier = Classifier (hidden_dim = 256 )
88131
89- # Test 2 classes (MaxCut)
90- output = classifier (x , num_classes = 2 )
91- print (f"Logits (2 classes): { output ['logits' ].shape } " )
132+ # Forward
133+ output = classifier (x )
134+ print (f"Logits: { output ['logits' ].shape } " )
135+ print (f"Probs: { output ['probs' ].shape } " )
136+ print (f"Predictions: { output ['predictions' ].shape } " )
137+ print (f"Exemple probs: { output ['probs' ][0 ].tolist ()} " )
138+ print (f"Exemple preds: { output ['predictions' ][0 ].tolist ()} " )
92139
93- # Test 5 classes (Graph Coloring)
94- output = classifier (x , num_classes = 5 )
95- print (f"Logits (5 classes): { output ['logits' ].shape } " )
140+ # Test Symmetric Loss
141+ targets = torch .tensor ([[1 , 0 , 1 , 0 , 1 , 0 ]] * 4 ).float ()
96142
97- # Test loss
98- targets = torch .randint (0 , 2 , (4 , 6 ))
99- output = classifier (x , num_classes = 2 )
100143 loss = classifier .compute_loss (output ['logits' ], targets )
101- print (f"Loss: { loss .item ():.4f} " )
102-
103- print (f"Params: { sum (p .numel () for p in classifier .parameters ()):,} " )
144+ print (f"\n Symmetric BCE Loss: { loss .item ():.4f} " )
145+
146+ # Test symétrie : target et 1-target doivent donner la même loss
147+ loss_normal = classifier .compute_loss (output ['logits' ], targets )
148+ loss_inverted = classifier .compute_loss (output ['logits' ], 1 - targets )
149+ print (f"Loss (target normal): { loss_normal .item ():.4f} " )
150+ print (f"Loss (target inversé): { loss_inverted .item ():.4f} " )
151+ print (f"Égales ? { '✅ OUI' if abs (loss_normal .item () - loss_inverted .item ()) < 1e-6 else '❌ NON' } " )
152+
153+ # Test similarité
154+ pred = torch .tensor ([[0 , 1 , 0 , 1 , 0 , 1 ]])
155+ target = torch .tensor ([[1 , 0 , 1 , 0 , 1 , 0 ]])
156+ sim = classifier .compute_similarity (pred , target )
157+ print (f"\n Similarité [0,1,0,1,0,1] vs [1,0,1,0,1,0]: { sim :.0%} " )
158+
159+ print (f"\n Params: { sum (p .numel () for p in classifier .parameters ()):,} " )
104160 print ("✅ OK" )