-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtest_device_detection.py
More file actions
executable file
·117 lines (96 loc) · 3.82 KB
/
test_device_detection.py
File metadata and controls
executable file
·117 lines (96 loc) · 3.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#!/usr/bin/env python3
"""
Test script for automatic device detection.
This script tests the device detection logic without running full training.
"""
import torch
def test_device_detection():
    """Probe CUDA, then MPS, then fall back to CPU, printing a report.

    Each accelerator is verified by actually allocating a tensor on it,
    because ``is_available()`` can report True while device initialization
    still fails at runtime.

    Returns:
        str: The selected device name: ``"cuda"``, ``"mps"``, or ``"cpu"``.
    """
    print("="*60)
    print("Device Detection Test")
    print("="*60)
    # Single initialization; each probe sets this only on success.
    selected_device = None
    # Test CUDA
    print("\n1. Testing CUDA (NVIDIA GPU):")
    if torch.cuda.is_available():
        print(" ✅ CUDA is available")
        try:
            # Real smoke test: move a tensor onto the GPU. Allocation can
            # fail even when is_available() reports True.
            torch.zeros(1).to(torch.device("cuda"))
            # Get GPU info
            gpu_name = torch.cuda.get_device_name(0)
            gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
            cuda_version = torch.version.cuda
            print(f" GPU: {gpu_name}")
            print(f" Memory: {gpu_memory:.1f} GB")
            print(f" CUDA Version: {cuda_version}")
            print(" ✅ CUDA initialization successful")
            selected_device = "cuda"
        except Exception as e:
            print(f" ❌ CUDA initialization failed: {e}")
    else:
        print(" ❌ CUDA not available")
    # Test MPS — only reached when CUDA was not selected.
    if selected_device is None:
        print("\n2. Testing MPS (Apple Silicon GPU):")
        if torch.backends.mps.is_available():
            print(" ✅ MPS is available")
            try:
                # Same smoke test as for CUDA.
                torch.zeros(1).to(torch.device("mps"))
                print(" ✅ MPS initialization successful")
                selected_device = "mps"
            except Exception as e:
                print(f" ❌ MPS initialization failed: {e}")
        else:
            print(" ❌ MPS not available")
    # Fallback to CPU — always succeeds.
    if selected_device is None:
        print("\n3. Falling back to CPU:")
        print(" ✅ CPU will be used")
        selected_device = "cpu"
    # Final result
    print("\n" + "="*60)
    print("Selected Device:", selected_device.upper())
    print("="*60)
    # Performance note
    if selected_device == "cpu":
        print("\n⚠️ Note: Training on CPU will be much slower than GPU.")
        print(" Consider using a machine with CUDA or MPS support for faster training.")
    elif selected_device == "cuda":
        print("\n✅ CUDA GPU detected - training will be fast!")
    elif selected_device == "mps":
        print("\n✅ Apple Silicon GPU detected - training will be accelerated!")
    return selected_device
def test_simple_operation():
    """Run small tensor operations on the auto-detected device.

    Calls :func:`test_device_detection` to pick a device, then exercises a
    100x100 matrix multiplication and a ``Conv2d`` forward pass as
    end-to-end sanity checks. Failures are printed rather than raised so
    the script can finish with an actionable message about the PyTorch
    installation.
    """
    print("\n" + "="*60)
    print("Testing Simple Operations")
    print("="*60)
    device = test_device_detection()
    device_obj = torch.device(device)
    try:
        print(f"\nTesting matrix multiplication on {device}...")
        x = torch.randn(100, 100).to(device_obj)
        y = torch.randn(100, 100).to(device_obj)
        z = torch.matmul(x, y)
        print("✅ Matrix multiplication successful")
        print(f" Result shape: {z.shape}")
        print(f"\nTesting convolution on {device}...")
        # 3-channel 32x32 input through a 64-filter 3x3 convolution.
        conv = torch.nn.Conv2d(3, 64, 3).to(device_obj)
        input_tensor = torch.randn(1, 3, 32, 32).to(device_obj)
        output = conv(input_tensor)
        print("✅ Convolution successful")
        print(f" Output shape: {output.shape}")
        print("\n" + "="*60)
        print("✅ All operations completed successfully!")
        print("="*60)
    except Exception as e:
        # Diagnostic script: report and continue instead of crashing.
        print(f"\n❌ Operation failed: {e}")
        print("\nThis might indicate a problem with your PyTorch installation.")
        print("Consider reinstalling PyTorch for your platform.")
# Script entry point: runs the operation smoke test, which itself invokes
# device detection first.
if __name__ == "__main__":
    test_simple_operation()