class Classifier(nn.Module):
    """Per-node multi-class classification head trained with cross-entropy.

    Turns contextualised node embeddings into per-node class predictions.
    One output layer of width ``max_classes`` is shared by every problem;
    ``forward`` keeps only the first ``num_classes`` logits, so 2-class
    problems (MaxCut, Vertex Cover, Independent Set) and k-class problems
    (Graph Coloring) reuse the same weights.

    Input:  [batch, n_nodes, hidden_dim]
    Output: [batch, n_nodes, num_classes] (logits)
    """

    def __init__(self, hidden_dim=256, max_classes=10, dropout=0.1):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.max_classes = max_classes

        self.layers = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, max_classes),
        )

    def forward(self, x, num_classes=2):
        """Classify every node.

        Args:
            x: [batch, n_nodes, hidden_dim] contextualised embeddings.
            num_classes: number of classes of the current problem; must be
                in ``[1, max_classes]``.

        Returns:
            dict with
                'logits':      [batch, n_nodes, num_classes]
                'probs':       [batch, n_nodes, num_classes] (softmax of logits)
                'predictions': [batch, n_nodes] (argmax of logits)

        Raises:
            ValueError: if ``num_classes`` is outside ``[1, max_classes]``.
        """
        # Slicing would otherwise hide the mistake: a value above max_classes
        # is truncated and a negative value drops trailing classes.
        if not 1 <= num_classes <= self.max_classes:
            raise ValueError(
                f"num_classes must be in [1, {self.max_classes}], got {num_classes}"
            )

        logits = self.layers(x)[..., :num_classes]
        probs = F.softmax(logits, dim=-1)
        predictions = torch.argmax(logits, dim=-1)

        return {
            'logits': logits,
            'probs': probs,
            'predictions': predictions,
        }

    def compute_loss(self, logits, targets, mask=None):
        """Cross-entropy loss over nodes.

        Args:
            logits: [batch, n_nodes, num_classes]
            targets: [batch, n_nodes] class indices in {0, ..., k-1}.
            mask: optional [batch, n_nodes] per-node weights; 0 marks nodes
                (e.g. padding) that must not contribute to the loss.

        Returns:
            Scalar loss. Without a mask this is the plain mean; with a mask
            it is the weighted sum divided by ``max(mask.sum(), 1)``.
        """
        logits_flat = logits.reshape(-1, logits.size(-1))
        targets_flat = targets.reshape(-1).long()

        if mask is None:
            return F.cross_entropy(logits_flat, targets_flat)

        weights = mask.reshape(-1).float()
        # Padding targets are often placeholders such as -1, which
        # cross_entropy rejects; point them at class 0 since their loss is
        # multiplied by a zero weight anyway.
        safe_targets = targets_flat.masked_fill(weights == 0, 0)
        per_node = F.cross_entropy(logits_flat, safe_targets, reduction='none')
        return (per_node * weights).sum() / weights.sum().clamp(min=1)
def __init__(
    self,
    input_dim=7,      # node feature dimension
    hidden_dim=128,   # embedding dimension
    num_layers=4,     # number of GAT layers
    num_heads=4,      # attention heads per GAT layer
    dropout=0.1
):
    """Build a stack of GAT layers, each followed by a LayerNorm.

    Every layer uses ``num_heads`` heads of width ``hidden_dim // num_heads``
    with ``concat=True``, so the concatenated output is ``hidden_dim`` wide
    only when ``hidden_dim`` is a multiple of ``num_heads``.

    Args:
        input_dim: width of the input node features.
        hidden_dim: width of the node embeddings produced by every layer.
        num_layers: number of GAT layers (>= 1).
        num_heads: attention heads per layer (>= 1, must divide hidden_dim).
        dropout: dropout rate used inside GATConv and after each layer.

    Raises:
        ValueError: if ``num_layers`` or ``num_heads`` is < 1, or if
            ``hidden_dim`` is not divisible by ``num_heads``.
    """
    super().__init__()

    if num_layers < 1:
        raise ValueError(f"num_layers must be >= 1, got {num_layers}")
    if num_heads < 1:
        raise ValueError(f"num_heads must be >= 1, got {num_heads}")
    # Without this check the mismatch only surfaces in forward(), when
    # LayerNorm(hidden_dim) receives heads * (hidden_dim // heads) features.
    if hidden_dim % num_heads != 0:
        raise ValueError(
            f"hidden_dim ({hidden_dim}) must be divisible by num_heads ({num_heads})"
        )

    self.hidden_dim = hidden_dim
    self.num_layers = num_layers

    head_dim = hidden_dim // num_heads
    self.gat_layers = nn.ModuleList()
    self.norms = nn.ModuleList()

    # First layer maps input_dim -> hidden_dim, the others hidden_dim -> hidden_dim.
    for layer_idx in range(num_layers):
        in_dim = input_dim if layer_idx == 0 else hidden_dim
        self.gat_layers.append(
            GATConv(
                in_dim,
                head_dim,
                heads=num_heads,
                dropout=dropout,
                concat=True
            )
        )
        self.norms.append(nn.LayerNorm(hidden_dim))

    self.dropout = nn.Dropout(dropout)
class QuantumGraphModel(nn.Module):
    """End-to-end model for graph optimisation problems.

    Pipeline:
        graph      -> GNN encoder      -> E_local (per node), E_global (per graph)
        problem_id -> embedding lookup -> E_prob
        [E_global || E_local || E_prob] -> transformer -> contextualised embeddings
        classifier -> per-node logits, trained with cross-entropy

    Supports MaxCut, Vertex Cover and Independent Set (2 classes) and
    Graph Coloring (k classes).
    """

    def __init__(
        self,
        node_input_dim=7,
        embedding_dim=128,
        hidden_dim=256,
        gnn_layers=4,
        transformer_layers=4,
        num_heads=8,
        num_problems=10,
        max_classes=10,
        dropout=0.1
    ):
        super().__init__()

        self.embedding_dim = embedding_dim

        # 1. GNN encoder. It uses half as many heads as the transformer;
        #    clamp to 1 so num_heads=1 does not request zero heads.
        self.encoder = GNNEncoder(
            input_dim=node_input_dim,
            hidden_dim=embedding_dim,
            num_layers=gnn_layers,
            num_heads=max(1, num_heads // 2),
            dropout=dropout
        )

        # 2. Problem embedding table
        self.problem_embedding = ProblemEmbeddingTable(
            num_problems=num_problems,
            embedding_dim=embedding_dim
        )

        # 3. Transformer
        self.transformer = GraphTransformer(
            input_dim=embedding_dim,
            hidden_dim=hidden_dim,
            num_layers=transformer_layers,
            num_heads=num_heads,
            dropout=dropout
        )

        # 4. Classifier
        self.classifier = Classifier(
            hidden_dim=hidden_dim,
            max_classes=max_classes,
            dropout=dropout
        )

    def forward(self, x, edge_index, problem_id, batch=None, num_classes=2):
        """Run the full pipeline.

        Args:
            x: [n_nodes, node_input_dim] node features.
            edge_index: [2, n_edges] edges.
            problem_id: int or [batch_size] problem identifiers.
            batch: optional [n_nodes] graph index of every node.
            num_classes: number of classes for the current problem.

        Returns:
            The classifier's dict with 'logits', 'probs' and 'predictions',
            shaped [batch_size, max_nodes, ...]. In batched mode, rows beyond
            a graph's node count are padding.

        Raises:
            ValueError: if the problem embedding batch is neither 1 nor
                ``batch_size``.
        """
        # 1. GNN encoder -> E_local, E_global
        e_local, e_global = self.encoder(x, edge_index, batch)
        batch_size = e_global.shape[0]

        # Reshape E_local to [batch_size, max_nodes, embedding_dim]
        if batch is None:
            e_local = e_local.unsqueeze(0)
        else:
            e_local = self._batch_node_embeddings(e_local, batch, batch_size)

        # 2. Problem embedding -> E_prob, one row per graph
        e_prob = self.problem_embedding(problem_id)
        if e_prob.dim() != 2 or e_prob.shape[0] not in (1, batch_size):
            raise ValueError(
                f"problem_id yields an embedding of shape {tuple(e_prob.shape)}; "
                f"expected [1, dim] or [{batch_size}, dim]"
            )
        if e_prob.shape[0] == 1 and batch_size > 1:
            e_prob = e_prob.expand(batch_size, -1)

        # 3. Transformer -> contextualised embeddings
        # NOTE(review): padded rows are passed to the transformer without a
        # key-padding mask, so real nodes can attend to padding in batched mode.
        contextualized = self.transformer(e_local, e_global, e_prob)

        # 4. Classifier -> logits, probs, predictions
        return self.classifier(contextualized, num_classes=num_classes)

    def _batch_node_embeddings(self, e_local, batch, batch_size):
        """Pack flat node embeddings into a zero-padded [batch, max_nodes, dim] tensor.

        Nodes of graph ``b`` fill ``out[b, :n_b]`` in their original order.
        The buffer inherits dtype and device from ``e_local`` so the result
        stays consistent with the encoder output (e.g. under half precision).
        """
        counts = torch.bincount(batch, minlength=batch_size)
        max_nodes = int(counts.max())

        out = e_local.new_zeros(batch_size, max_nodes, e_local.size(-1))
        for graph_idx in range(batch_size):
            rows = e_local[batch == graph_idx]
            out[graph_idx, :rows.size(0)] = rows

        return out

    def _node_mask(self, batch, batch_size, max_nodes):
        """Return a [batch_size, max_nodes] bool mask, True on real (non-padding) nodes."""
        counts = torch.bincount(batch, minlength=batch_size)
        positions = torch.arange(max_nodes, device=batch.device)
        return positions.unsqueeze(0) < counts.unsqueeze(1)

    def compute_loss(self, logits, targets, mask=None):
        """Cross-entropy loss, delegated to the classifier."""
        return self.classifier.compute_loss(logits, targets, mask)

    def forward_with_loss(self, x, edge_index, problem_id, targets,
                          batch=None, num_classes=2, mask=None):
        """Forward pass and loss in one call.

        Args:
            targets: [batch_size, max_nodes] class indices.
            mask: optional [batch_size, max_nodes] loss weights. When omitted
                in batched mode, a mask excluding padded nodes is derived
                from ``batch`` so padding does not dilute the loss.

        Returns:
            (output dict, scalar loss)
        """
        output = self.forward(x, edge_index, problem_id, batch, num_classes)
        logits = output['logits']
        if mask is None and batch is not None:
            mask = self._node_mask(batch, logits.shape[0], logits.shape[1])
        loss = self.compute_loss(logits, targets, mask)
        return output, loss
class ProblemEmbeddingTable(nn.Module):
    """Lookup table mapping a problem id to a dense, learned vector.

    The table holds ``num_problems`` rows of width ``embedding_dim`` and is
    trained by backpropagation like any other ``nn.Embedding``.
    """

    def __init__(self, num_problems=10, embedding_dim=128):
        super().__init__()
        self.num_problems = num_problems
        self.embedding_dim = embedding_dim

        # Lookup table, initialised with small Gaussian noise.
        self.embedding_table = nn.Embedding(num_problems, embedding_dim)
        nn.init.normal_(self.embedding_table.weight, mean=0.0, std=0.1)

    def forward(self, problem_id):
        """Look up the embedding of one or several problems.

        Args:
            problem_id: int, integer scalar (Python, NumPy or 0-dim tensor),
                sequence of ints, or integer Tensor [batch_size].

        Returns:
            embedding: [batch_size, embedding_dim]; a scalar id gives
            batch_size 1.

        Raises:
            TypeError: if ``problem_id`` is boolean or not integer-valued.
        """
        ids = torch.as_tensor(problem_id)
        if ids.dtype == torch.bool or ids.is_floating_point() or ids.is_complex():
            raise TypeError(
                f"problem_id must be integer-valued, got dtype {ids.dtype}"
            )

        ids = ids.to(device=self.embedding_table.weight.device, dtype=torch.long)
        # A scalar id (Python int, NumPy scalar or 0-dim tensor) must still
        # produce a leading batch axis, as the return contract states.
        if ids.dim() == 0:
            ids = ids.unsqueeze(0)
        return self.embedding_table(ids)
def create_test_graph(n_nodes=6, n_features=7):
    """Create a small undirected cycle graph with random node features.

    The cycle 0-1-...-(n_nodes-1)-0 is stored as directed edges: first every
    edge (i, i+1 mod n), then every reverse edge. Degenerate sizes are
    handled explicitly: duplicate edges (n_nodes == 2) are emitted once and
    self-loops (n_nodes == 1) are dropped, so n_nodes <= 1 yields no edges.

    Args:
        n_nodes: number of nodes.
        n_features: width of the random feature vector of each node.

    Returns:
        (x, edge_index): x is [n_nodes, n_features] float, edge_index is
        [2, n_edges] long.
    """
    x = torch.randn(n_nodes, n_features)

    forward_edges = [(i, (i + 1) % n_nodes) for i in range(n_nodes)]
    backward_edges = [(dst, src) for src, dst in forward_edges]
    # dict.fromkeys removes duplicates while keeping first-seen order.
    edges = list(dict.fromkeys(
        edge for edge in forward_edges + backward_edges if edge[0] != edge[1]
    ))

    # Explicit dtype and shape keep the empty case a valid [2, 0] long tensor.
    edge_index = (
        torch.tensor(edges, dtype=torch.long)
        .reshape(len(edges), 2)
        .t()
        .contiguous()
    )

    return x, edge_index
def test_transformer():
    """Smoke-test GraphTransformer on random inputs.

    Feeds random local, global and problem embeddings for a batch of two
    6-node graphs and checks that the output keeps the batch and node axes
    and has the transformer's hidden width (256).
    """
    separator = "=" * 60
    print(separator)
    print("TEST 3: Graph Transformer")
    print(separator)

    batch_size, n_nodes, input_dim = 2, 6, 128

    # Inputs are drawn before the model is built, in this order.
    node_emb = torch.randn(batch_size, n_nodes, input_dim)
    graph_emb = torch.randn(batch_size, input_dim)
    problem_emb = torch.randn(batch_size, input_dim)

    model = GraphTransformer(input_dim=128, hidden_dim=256, num_layers=4, num_heads=8)
    contextualized = model(node_emb, graph_emb, problem_emb)

    print(f"Input: e_local = {node_emb.shape}")
    print(f" e_global = {graph_emb.shape}")
    print(f" e_prob = {problem_emb.shape}")
    print(f"Output: contextualized = {contextualized.shape}")

    expected_shape = (batch_size, n_nodes, 256)
    assert contextualized.shape == expected_shape

    print("✅ Transformer OK\n")
    return True
def test_backpropagation():
    """Check that one backward pass reaches every component of the model.

    Runs a forward pass on a 6-node cycle (problem 0, 2 classes), computes
    the loss and backpropagates. The test fails if any of the four
    components has no non-zero gradient, or if the embedding row of the
    problem that was used received no gradient.

    Returns:
        True when all checks pass.

    Raises:
        AssertionError: if a component or the lookup-table row has no gradient.
    """
    print("=" * 60)
    print("TEST 6: Backpropagation")
    print("=" * 60)

    x, edge_index = create_test_graph(n_nodes=6, n_features=7)

    model = QuantumGraphModel(
        node_input_dim=7,
        embedding_dim=128,
        hidden_dim=256
    )

    # Forward
    output = model(x, edge_index, problem_id=0, num_classes=2)

    # Loss
    targets = torch.tensor([[1, 0, 1, 0, 1, 0]])
    loss = model.compute_loss(output['logits'], targets)
    print(f"Loss avant backward: {loss.item():.4f}")

    # Backward
    loss.backward()

    # Every component must receive a gradient.
    components = {
        'GNN Encoder': model.encoder,
        'Problem Embedding': model.problem_embedding,
        'Transformer': model.transformer,
        'Classifier': model.classifier
    }

    print("\nGradients par composant:")
    missing = []
    for name, component in components.items():
        has_grad = any(p.grad is not None and p.grad.abs().sum() > 0
                       for p in component.parameters() if p.requires_grad)
        status = "✅" if has_grad else "❌"
        print(f" {status} {name}")
        if not has_grad:
            missing.append(name)

    # Raise explicitly (not `assert`) so the check survives `python -O`;
    # previously a missing gradient was printed but still reported as OK.
    if missing:
        raise AssertionError(f"No gradient reached: {', '.join(missing)}")

    # The lookup-table row of the problem that was used must be updated too.
    lookup_grad = model.problem_embedding.embedding_table.weight.grad
    if lookup_grad is None:
        raise AssertionError("Lookup table received no gradient")
    grad_problem_0 = lookup_grad[0].abs().sum().item()
    print(f"\n Gradient lookup table (ID=0): {grad_problem_0:.6f}")
    if grad_problem_0 <= 0:
        raise AssertionError("Lookup table row for problem ID 0 has zero gradient")

    print("✅ Backpropagation OK\n")
    return True
test_different_graph_sizes(): + """Test avec différentes tailles de graphes""" + print("=" * 60) + print("TEST 8: Différentes Tailles de Graphes") + print("=" * 60) + + model = QuantumGraphModel( + node_input_dim=7, + embedding_dim=128, + hidden_dim=256 + ) + + for n_nodes in [4, 8, 16, 32]: + x, edge_index = create_test_graph(n_nodes=n_nodes, n_features=7) + + output = model(x, edge_index, problem_id=0, num_classes=2) + + print(f" {n_nodes} nœuds: predictions shape = {output['predictions'].shape}") + assert output['predictions'].shape == (1, n_nodes) + + print("✅ Différentes Tailles OK\n") + return True + + +def run_all_tests(): + """Lance tous les tests""" + print("\n" + "=" * 60) + print(" TESTS DU PROTOTYPE v1 - QuantumGraphModel") + print("=" * 60 + "\n") + + tests = [ + ("Encoder", test_encoder), + ("Problem Embedding", test_problem_embedding), + ("Transformer", test_transformer), + ("Classifier", test_classifier), + ("Full Model", test_full_model), + ("Backpropagation", test_backpropagation), + ("Training Step", test_training_step), + ("Different Sizes", test_different_graph_sizes), + ] + + results = [] + for name, test_fn in tests: + try: + success = test_fn() + results.append((name, success)) + except Exception as e: + print(f"❌ ERREUR dans {name}: {e}\n") + results.append((name, False)) + + # Résumé + print("=" * 60) + print(" RÉSUMÉ") + print("=" * 60) + + passed = sum(1 for _, success in results if success) + total = len(results) + + for name, success in results: + status = "✅" if success else "❌" + print(f" {status} {name}") + + print(f"\n Total: {passed}/{total} tests passés") + + if passed == total: + print("\n 🎉 TOUS LES TESTS SONT PASSÉS ! 
class PositionalEncoding(nn.Module):
    """Add sinusoidal positional encodings along the sequence axis, then apply dropout.

    A table for ``max_len`` positions is precomputed and stored as the
    buffer ``pe`` of shape [1, max_len, d_model]. Longer inputs are handled
    by computing the encoding on the fly, and odd ``d_model`` is supported.
    """

    def __init__(self, d_model, max_len=1000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pe', self._build_table(max_len, d_model).unsqueeze(0))

    @staticmethod
    def _build_table(length, d_model, device=None):
        """Return the [length, d_model] sinusoidal table (sin on even, cos on odd columns)."""
        position = torch.arange(0, length, dtype=torch.float, device=device).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, device=device).float()
            * (-math.log(10000.0) / d_model)
        )
        table = torch.zeros(length, d_model, device=device)
        table[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even columns,
        # so drop the last frequency instead of failing on a shape mismatch.
        table[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        return table

    def forward(self, x):
        """Add positional encodings to ``x`` of shape [batch, seq_len, d_model]."""
        seq_len = x.size(1)
        if seq_len <= self.pe.size(1):
            encoding = self.pe[:, :seq_len, :]
        else:
            # Input longer than the precomputed table: build the encoding for
            # this call only, leaving the registered buffer unchanged.
            encoding = self._build_table(
                seq_len, self.pe.size(2), device=x.device
            ).unsqueeze(0)
        return self.dropout(x + encoding)
def forward(self, e_local, e_global, e_prob, padding_mask=None):
    """Contextualise node embeddings with self-attention across nodes.

    Each node's input is the concatenation [E_global || E_local || E_prob],
    projected to ``hidden_dim``, given a positional encoding and passed
    through the transformer encoder.

    Args:
        e_local: [batch, n_nodes, input_dim] per-node embeddings.
        e_global: [batch, input_dim] or [1, input_dim] per-graph embedding.
        e_prob: [batch, input_dim] or [1, input_dim] problem embedding.
        padding_mask: optional [batch, n_nodes] mask, True on padded nodes.
            It is forwarded as ``src_key_padding_mask`` so attention ignores
            those positions. When omitted, every node is attended to.

    Returns:
        contextualized_embeddings: [batch, n_nodes, hidden_dim]
    """
    batch_size, n_nodes, _ = e_local.shape

    # Broadcast the per-graph vectors to every node. Expanding to the
    # explicit batch size also accepts a single shared [1, input_dim] vector.
    e_global_exp = e_global.unsqueeze(1).expand(batch_size, n_nodes, -1)
    e_prob_exp = e_prob.unsqueeze(1).expand(batch_size, n_nodes, -1)

    # Concat: [E_global || E_local || E_prob]
    x = torch.cat([e_global_exp, e_local, e_prob_exp], dim=-1)

    # Projection + positional encoding
    # NOTE(review): the positional encoding is indexed by node position, so
    # the output depends on node ordering — confirm this is intended.
    x = self.input_projection(x)
    x = self.pos_encoding(x)

    if padding_mask is None:
        return self.transformer(x)  # [batch, n_nodes, hidden_dim]
    return self.transformer(x, src_key_padding_mask=padding_mask.bool())