forked from ChangwenXu98/TransPolymer
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdeploy_sagemaker.py
More file actions
344 lines (279 loc) · 11.4 KB
/
deploy_sagemaker.py
File metadata and controls
344 lines (279 loc) · 11.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
import os
import boto3
import sagemaker
from sagemaker.pytorch import PyTorchModel
from sagemaker.predictor import Predictor
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer
import json
import time
from datetime import datetime
class TransPolymerSageMakerDeployer:
    """
    Deploy a TransPolymer model to an Amazon SageMaker real-time endpoint.

    Typical workflow:
        prepare_model_artifacts() -> upload_model_to_s3() -> create_model()
        -> deploy_endpoint() -> test_endpoint() -> cleanup_endpoint()
    """

    def __init__(self, region_name='us-east-1', role_arn=None):
        """
        Initialize the deployer.

        Args:
            region_name: AWS region name.
            role_arn: IAM role ARN for SageMaker. If None, falls back to
                sagemaker.get_execution_role(), which only succeeds inside a
                SageMaker-managed environment (notebook instance / Studio).
        """
        self.region_name = region_name
        self.session = sagemaker.Session(boto_session=boto3.Session(region_name=region_name))
        # Resolve the execution role SageMaker will assume for the model.
        if role_arn:
            self.role = role_arn
        else:
            self.role = sagemaker.get_execution_role()
        self.bucket = self.session.default_bucket()
        print(f"Using S3 bucket: {self.bucket}")
        print(f"Using IAM role: {self.role}")

    def prepare_model_artifacts(self, model_dir="model_artifacts"):
        """
        Stage inference code and model weights, then pack them into model.tar.gz.

        Args:
            model_dir: Directory to stage model artifacts in.

        Returns:
            Path to the created model.tar.gz archive.
        """
        import shutil
        import tarfile

        print("Preparing model artifacts...")
        # The SageMaker PyTorch container expects inference code under a
        # "code/" subdirectory inside the archive. Create it up front:
        # shutil.copy raises FileNotFoundError when the destination
        # directory does not exist (bug fix: the original never created
        # model_dir or model_dir/code).
        code_dir = os.path.join(model_dir, "code")
        os.makedirs(code_dir, exist_ok=True)

        # Required: entry-point script and custom tokenizer.
        shutil.copy("inference.py", os.path.join(code_dir, "inference.py"))
        shutil.copy("PolymerSmilesTokenization.py",
                    os.path.join(code_dir, "PolymerSmilesTokenization.py"))

        # Optional: dataset module (needed for imports inside inference.py).
        if os.path.exists("dataset.py"):
            shutil.copy("dataset.py", os.path.join(code_dir, "dataset.py"))

        # Optional: pretrained model files go at the archive root.
        if os.path.exists("ckpt/pretrain.pt/config.json"):
            shutil.copy("ckpt/pretrain.pt/config.json",
                        os.path.join(model_dir, "config.json"))
        if os.path.exists("ckpt/pretrain.pt/pytorch_model.bin"):
            shutil.copy("ckpt/pretrain.pt/pytorch_model.bin",
                        os.path.join(model_dir, "pytorch_model.bin"))

        # Requirements file lets the container install dependencies at
        # startup; guarded like the other copies so a missing file skips
        # instead of crashing the whole packaging step.
        if os.path.exists("requirements.txt"):
            shutil.copy("requirements.txt", os.path.join(code_dir, "requirements.txt"))

        # Pack the staged directory into the tarball layout SageMaker expects.
        model_path = f"{model_dir}.tar.gz"
        with tarfile.open(model_path, "w:gz") as tar:
            tar.add(model_dir, arcname=".")
        print(f"Model artifacts prepared: {model_path}")
        return model_path

    def upload_model_to_s3(self, model_path, s3_key_prefix="transpolymer-model"):
        """
        Upload the model archive to the session's default S3 bucket.

        Args:
            model_path: Path to the model.tar.gz file.
            s3_key_prefix: S3 key prefix; a timestamp subfolder is appended
                so repeated deployments never overwrite each other.

        Returns:
            The S3 URI of the uploaded archive.
        """
        print("Uploading model to S3...")
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        # (Removed dead local `s3_key` — upload_data builds the key itself.)
        s3_uri = self.session.upload_data(
            path=model_path,
            bucket=self.bucket,
            key_prefix=f"{s3_key_prefix}/{timestamp}"
        )
        print(f"Model uploaded to: {s3_uri}")
        return s3_uri

    def create_model(self, model_s3_uri, model_name=None):
        """
        Create a SageMaker PyTorch model object from uploaded artifacts.

        Args:
            model_s3_uri: S3 URI of the model.tar.gz artifacts.
            model_name: Name for the model (if None, a timestamped name is
                generated).

        Returns:
            The created sagemaker.pytorch.PyTorchModel.
        """
        if model_name is None:
            timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
            model_name = f"transpolymer-model-{timestamp}"
        print(f"Creating SageMaker model: {model_name}")
        # framework_version/py_version select the prebuilt PyTorch 1.12
        # inference container; inference.py inside code/ is the entry point.
        model = PyTorchModel(
            model_data=model_s3_uri,
            role=self.role,
            entry_point="inference.py",
            framework_version="1.12.0",
            py_version="py38",
            name=model_name,
            sagemaker_session=self.session
        )
        self.model = model
        self.model_name = model_name
        print(f"Model created successfully: {model_name}")
        return model

    def deploy_endpoint(self, instance_type="ml.m5.large", initial_instance_count=1,
                        endpoint_name=None, auto_scaling=True):
        """
        Deploy the created model to a SageMaker real-time endpoint.

        Args:
            instance_type: EC2 instance type for the endpoint.
            initial_instance_count: Initial number of instances.
            endpoint_name: Name for the endpoint (if None, a timestamped
                name is generated).
            auto_scaling: Whether to enable target-tracking auto-scaling.

        Returns:
            The sagemaker Predictor for the live endpoint.

        Raises:
            RuntimeError: If create_model() has not been called yet.
        """
        # Fail fast with a clear message instead of an opaque AttributeError.
        if not hasattr(self, 'model'):
            raise RuntimeError("No model available; call create_model() before deploy_endpoint().")
        if endpoint_name is None:
            timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
            endpoint_name = f"transpolymer-endpoint-{timestamp}"
        print(f"Deploying endpoint: {endpoint_name}")
        print(f"Instance type: {instance_type}")
        print(f"Initial instance count: {initial_instance_count}")
        # JSON in / JSON out so predict() accepts plain dicts and lists.
        predictor = self.model.deploy(
            initial_instance_count=initial_instance_count,
            instance_type=instance_type,
            endpoint_name=endpoint_name,
            serializer=JSONSerializer(),
            deserializer=JSONDeserializer()
        )
        self.predictor = predictor
        self.endpoint_name = endpoint_name
        # Auto-scaling is best-effort; failures are logged, not fatal.
        if auto_scaling:
            self._configure_auto_scaling()
        print(f"Endpoint deployed successfully: {endpoint_name}")
        return predictor

    def _configure_auto_scaling(self, min_capacity=1, max_capacity=5, target_value=70.0):
        """
        Attach a target-tracking scaling policy to the endpoint's variant.

        Args:
            min_capacity: Minimum instance count.
            max_capacity: Maximum instance count.
            target_value: Target invocations-per-instance before scaling out.
        """
        print("Configuring auto-scaling...")
        autoscaling_client = boto3.client('application-autoscaling', region_name=self.region_name)
        # SageMaker's default production variant is named "AllTraffic".
        resource_id = f"endpoint/{self.endpoint_name}/variant/AllTraffic"
        try:
            autoscaling_client.register_scalable_target(
                ServiceNamespace='sagemaker',
                ResourceId=resource_id,
                ScalableDimension='sagemaker:variant:DesiredInstanceCount',
                MinCapacity=min_capacity,
                MaxCapacity=max_capacity
            )
            autoscaling_client.put_scaling_policy(
                PolicyName=f'{self.endpoint_name}-scaling-policy',
                ServiceNamespace='sagemaker',
                ResourceId=resource_id,
                ScalableDimension='sagemaker:variant:DesiredInstanceCount',
                PolicyType='TargetTrackingScaling',
                TargetTrackingScalingPolicyConfiguration={
                    'TargetValue': target_value,
                    'PredefinedMetricSpecification': {
                        'PredefinedMetricType': 'SageMakerVariantInvocationsPerInstance'
                    },
                    'ScaleOutCooldown': 300,
                    'ScaleInCooldown': 300
                }
            )
            print("Auto-scaling configured successfully")
        except Exception as e:
            # Deliberately non-fatal: the endpoint still works without
            # auto-scaling (e.g. missing IAM permissions).
            print(f"Warning: Could not configure auto-scaling: {e}")

    def test_endpoint(self, test_data=None):
        """
        Smoke-test the deployed endpoint with a prediction request.

        Args:
            test_data: Payload for predict() (if None, default polymer
                SMILES examples are used).

        Returns:
            The deserialized prediction result.

        Raises:
            Re-raises any exception from the prediction call after logging it.
        """
        if test_data is None:
            # Default test data - polymer SMILES examples
            test_data = [
                {
                    "smiles": "CC(C)(C)OC(=O)NC1=CC=CC=C1",
                    "property": "conductivity",
                    "model_type": "PE_I"
                },
                {
                    "smiles": "c1ccc2c(c1)oc1ccccc12",
                    "property": "band_gap",
                    "model_type": "Egc"
                }
            ]
        print("Testing endpoint...")
        print(f"Test data: {test_data}")
        try:
            result = self.predictor.predict(test_data)
            print("Prediction successful!")
            print(f"Results: {json.dumps(result, indent=2)}")
            return result
        except Exception as e:
            print(f"Error during prediction: {e}")
            raise

    def cleanup_endpoint(self):
        """
        Delete the endpoint (and its backing model) to avoid ongoing charges.

        No-op if no endpoint was deployed by this instance.
        """
        if hasattr(self, 'predictor'):
            print(f"Deleting endpoint: {self.endpoint_name}")
            # Also remove the SageMaker model resource; previously only the
            # endpoint was deleted, leaking the model. Best-effort so a
            # model-deletion failure doesn't block endpoint teardown.
            try:
                self.predictor.delete_model()
            except Exception as e:
                print(f"Warning: could not delete model: {e}")
            self.predictor.delete_endpoint()
            print("Endpoint deleted successfully")

    def get_endpoint_info(self):
        """
        Return a summary dict for the deployed endpoint, or None if no
        endpoint has been deployed by this instance.
        """
        if hasattr(self, 'endpoint_name'):
            return {
                "endpoint_name": self.endpoint_name,
                "model_name": self.model_name,
                "region": self.region_name,
                "endpoint_url": f"https://runtime.sagemaker.{self.region_name}.amazonaws.com/endpoints/{self.endpoint_name}/invocations"
            }
        else:
            return None
def main():
    """
    Run the end-to-end deployment: package, upload, create, deploy, test.

    Prompts the operator at the end whether to keep the endpoint alive and
    tears it down on failure so no idle instances keep billing.
    """
    print("🚀 Starting TransPolymer SageMaker Deployment")
    dep = TransPolymerSageMakerDeployer()
    try:
        # Push the local artifacts through the whole deployment chain.
        artifact_path = dep.prepare_model_artifacts()
        artifact_s3_uri = dep.upload_model_to_s3(artifact_path)
        dep.create_model(artifact_s3_uri)
        dep.deploy_endpoint(
            instance_type="ml.m5.large",  # Change to ml.g4dn.xlarge for GPU
            initial_instance_count=1
        )
        # Smoke-test the live endpoint with the built-in sample payloads.
        dep.test_endpoint()
        info = dep.get_endpoint_info()
        print("\n✅ Deployment completed successfully!")
        print(f"Endpoint info: {json.dumps(info, indent=2)}")
        print("\n📋 Usage example:")
        print(f"""
import boto3
import json
runtime = boto3.client('sagemaker-runtime', region_name='{dep.region_name}')
payload = {{
    "smiles": "CC(C)(C)OC(=O)NC1=CC=CC=C1",
    "property": "conductivity",
    "model_type": "PE_I"
}}
response = runtime.invoke_endpoint(
    EndpointName='{dep.endpoint_name}',
    ContentType='application/json',
    Body=json.dumps(payload)
)
result = json.loads(response['Body'].read().decode())
print(result)
""")
        # Give the operator a chance to tear the endpoint down immediately.
        answer = input("\nDo you want to keep the endpoint running? (y/n): ")
        if answer.lower() != 'y':
            dep.cleanup_endpoint()
    except Exception as e:
        print(f"❌ Deployment failed: {e}")
        # Best-effort teardown so a half-finished deployment doesn't bill.
        if hasattr(dep, 'predictor'):
            dep.cleanup_endpoint()
        raise
# Run the full deployment workflow only when executed as a script,
# not when imported as a module.
if __name__ == "__main__":
    main()