-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_sagemaker.py
More file actions
35 lines (29 loc) · 963 Bytes
/
train_sagemaker.py
File metadata and controls
35 lines (29 loc) · 963 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from sagemaker.pytorch import PyTorch
from sagemaker.debugger import TensorBoardOutputConfig
def start_training(
    *,
    bucket="your-bucket-name",
    role="your-execution-role",
    instance_type="ml.g5.xlarge",
    instance_count=1,
    batch_size=32,
    epochs=25,
    wait=True,
):
    """Launch a SageMaker PyTorch training job for ``training/train.py``.

    All arguments are keyword-only and default to the values this script
    originally hard-coded, so ``start_training()`` behaves exactly as before.
    The ``bucket`` and ``role`` defaults are placeholders and must be
    overridden with a real S3 bucket and IAM role before a job can succeed.

    Args:
        bucket: S3 bucket name (no ``s3://`` prefix). It holds the
            ``dataset/{train,dev,test}`` inputs and receives TensorBoard
            output under ``tensorboard/``.
        role: IAM role name or ARN that the training job runs as.
        instance_type: SageMaker instance type for training.
        instance_count: Number of training instances.
        batch_size: Passed to the training script as the ``batch-size``
            hyperparameter.
        epochs: Passed to the training script as the ``epochs`` hyperparameter.
        wait: Forwarded to ``estimator.fit``. ``True`` (the SDK default)
            blocks until the job finishes.

    Returns:
        The configured ``PyTorch`` estimator, after ``fit`` has been called.
    """
    s3_root = f"s3://{bucket}"

    # TensorBoard event files written inside the container to this local path
    # are uploaded to the S3 prefix below.
    tensorboard_config = TensorBoardOutputConfig(
        s3_output_path=f"{s3_root}/tensorboard",
        container_local_output_path="/opt/ml/output/tensorboard",
    )

    estimator = PyTorch(
        entry_point="train.py",
        source_dir="training",
        role=role,
        framework_version="2.5.1",
        py_version="py311",
        instance_count=instance_count,
        instance_type=instance_type,
        # NOTE(review): keys presumably map to --batch-size / --epochs flags
        # parsed by training/train.py; confirm against that script before
        # renaming.
        hyperparameters={
            "batch-size": batch_size,
            "epochs": epochs,
        },
        tensorboard_config=tensorboard_config,
    )

    # Start training. The "validation" channel is fed from the dataset's
    # "dev" split.
    estimator.fit(
        {
            "training": f"{s3_root}/dataset/train",
            "validation": f"{s3_root}/dataset/dev",
            "test": f"{s3_root}/dataset/test",
        },
        wait=wait,
    )
    return estimator
if __name__ == "__main__":
    # Script entry point: launch the training job with the default settings.
    start_training()