diff --git a/README.md b/README.md index 6cf1c4a..3f3cc75 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,14 @@ -# SSM and S3 Server +# AWS Server -This project provides an HTTP server using Go, which interacts with AWS Systems Manager (SSM) to fetch/put parameters and with S3 to fetch/put files. This project is useful for running a sidecar, to use original images without a hassle. -The server exposes two main endpoints: `/ssm` and `/s3`. +This project provides an HTTP server using Go, which interacts with AWS Services. This project is useful for running a sidecar, to use original images without a hassle. ## Features - Fetch decrypted parameters from AWS SSM. - Fetch and serve files from AWS S3. - Upload files to AWS S3. +- Fetch ECR authorization token. +- Fetch caller identity from AWS STS. - Basic CI/CD pipeline using GitHub Actions for automatic builds and tests. ## Requirements @@ -115,6 +116,30 @@ Upload a file to an S3 bucket. curl -X POST -F 'file=@/path/to/your/file' "http://localhost:3000/s3?bucket=example-bucket&key=example-key" ``` +### Get ECR Login + +Fetch an authorization token for ECR. + +- **URL:** `/ecr/login` +- **Method:** `GET` +- **Example:** + + ```sh + curl "http://localhost:3000/ecr/login" | docker login --username AWS --password-stdin aws-account-id.dkr.ecr.eu-central-1.amazonaws.com + ``` + +### Get Caller Identity + +Fetch the caller identity from AWS STS. + +- **URL:** `/sts` +- **Method:** `GET` +- **Example:** + + ```sh + curl "http://localhost:3000/sts" + ``` + ## Running Tests To run the tests: diff --git a/cmd/server/main.go b/cmd/server/main.go index a50bc20..073dc2f 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -1,207 +1,45 @@ package main import ( - "context" - "io" - "io/ioutil" - "log" - "net/http" - "os" - "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/s3" - "github.com/aws/aws-sdk-go/service/ssm" - "github.com/aws/aws-sdk-go/service/ssm/ssmiface" + "log" + "net/http" + "os" + + "github.com/aws/aws-sdk-go/aws/session" + "github.com/leneffets/ssmserver/pkg/ecr" + "github.com/leneffets/ssmserver/pkg/s3" + "github.com/leneffets/ssmserver/pkg/ssm" + "github.com/leneffets/ssmserver/pkg/sts" ) -// Funktion, um SSM-Parameter abzurufen -func GetParameter(ctx context.Context, svc ssmiface.SSMAPI, name *string) (*ssm.GetParameterOutput, error) { - results, err := svc.GetParameterWithContext(ctx, &ssm.GetParameterInput{ - Name: name, - WithDecryption: aws.Bool(true), - }) - return results, err -} - -// Funktion, um ein Objekt aus S3 zu holen -func GetFromS3(ctx context.Context, sess *session.Session, bucket, key string) (io.ReadCloser, error) { - svc := s3.New(sess) - output, err := svc.GetObjectWithContext(ctx, &s3.GetObjectInput{ - Bucket: aws.String(bucket), - Key: aws.String(key), - }) - if err != nil { - return nil, err - } - return output.Body, nil -} - -// Funktion, um ein Objekt in S3 zu legen -func PutToS3(ctx context.Context, sess *session.Session, bucket, key string, body io.ReadSeeker) error { - svc := s3.New(sess) - _, err := svc.PutObjectWithContext(ctx, &s3.PutObjectInput{ - Bucket: aws.String(bucket), - Key: aws.String(key), - Body: body, - }) - return err -} - -// Funktion, um einen Parameter in den Parameter Store zu legen -func PutParameter(ctx context.Context, svc ssmiface.SSMAPI, name *string, value *string, typeStr *string) (*ssm.PutParameterOutput, error) { - results, err := svc.PutParameterWithContext(ctx, &ssm.PutParameterInput{ - Name: name, - Value: value, - Type: aws.String(*typeStr), - }) - return results, err -} - func main() { - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - svc := ssm.New(sess) - - creds, err := sess.Config.Credentials.Get() - if err != nil { - log.Fatalf("Failed to get credentials: %v", err) - } - log.Printf("Using credentials: %s/%s\n", creds.AccessKeyID, creds.SecretAccessKey) - - http.HandleFunc("/ssm", func(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodGet { - // Validate input parameter - id := r.URL.Query().Get("name") - if id == "" { - http.Error(w, "Parameter 'name' is required", http.StatusBadRequest) - return - } - - // Set context with timeout - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - // Fetch parameter - results, err := GetParameter(ctx, svc, &id) - if err != nil { - log.Printf("Error fetching parameter: %v", err) - http.Error(w, "Error fetching parameter", http.StatusInternalServerError) - return - } - - w.Header().Set("Content-Type", "text/plain") - w.Write([]byte(*results.Parameter.Value)) - } else if r.Method == http.MethodPost { - // Parse the form data - if err := r.ParseForm(); err != nil { - http.Error(w, "Invalid form data", http.StatusBadRequest) - log.Printf("Invalid form data: %v", err) - return - } - - // Validate input parameters - name := r.FormValue("name") - value := r.FormValue("value") - typeStr := r.FormValue("type") - - if name == "" || value == "" || typeStr == "" { - http.Error(w, "Parameters 'name', 'value', and 'type' are required", http.StatusBadRequest) - return - } - - // Set context with timeout - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - // Put parameter - _, err := PutParameter(ctx, svc, &name, &value, &typeStr) - if err != nil { - http.Error(w, "Error putting parameter", http.StatusInternalServerError) - log.Printf("Error putting parameter: %v", err) - return - } - - log.Printf("Parameter %s uploaded successfully", name) - w.WriteHeader(http.StatusOK) - } else { - http.Error(w, "Invalid method", http.StatusMethodNotAllowed) - } - }) - - http.HandleFunc("/s3", func(w http.ResponseWriter, r *http.Request) { - bucket := r.URL.Query().Get("bucket") - key := r.URL.Query().Get("key") - if bucket == "" || key == "" { - http.Error(w, "Parameters 'bucket' and 'key' are required", http.StatusBadRequest) - return - } - - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - if r.Method == http.MethodGet { - body, err := GetFromS3(ctx, sess, bucket, key) - if err != nil { - http.Error(w, "Error fetching file from S3", http.StatusInternalServerError) - log.Printf("Error fetching file from S3: %v", err) - return - } - defer body.Close() - - w.Header().Set("Content-Type", "application/octet-stream") - if _, err := io.Copy(w, body); err != nil { - http.Error(w, "Error sending file", http.StatusInternalServerError) - log.Printf("Error sending file: %v", err) - return - } - } else if r.Method == http.MethodPost { - file, _, err := r.FormFile("file") - if err != nil { - http.Error(w, "Error reading uploaded file", http.StatusBadRequest) - log.Printf("Error reading uploaded file: %v", err) - return - } - defer file.Close() - - tempFile, err := ioutil.TempFile("", "upload-*.tmp") - if err != nil { - http.Error(w, "Error creating temporary file", http.StatusInternalServerError) - log.Printf("Error creating temporary file: %v", err) - return - } - defer os.Remove(tempFile.Name()) - - if _, err := io.Copy(tempFile, file); err != nil { - http.Error(w, "Error saving file", http.StatusInternalServerError) - log.Printf("Error saving file: %v", err) - return - } - - tempFile.Seek(0, 0) - - if err := PutToS3(ctx, sess, bucket, key, tempFile); err != nil { - http.Error(w, "Error uploading file to S3", http.StatusInternalServerError) - log.Printf("Error uploading file to S3: %v", err) - return - } - - log.Printf("File uploaded successfully to bucket %s with key %s", bucket, key) - w.WriteHeader(http.StatusOK) - } else { - http.Error(w, "Invalid method", http.StatusMethodNotAllowed) - } - }) - - port := os.Getenv("PORT") - if port == "" { - port = "3000" - } - - log.Printf("Server running on port %s", port) - if err := http.ListenAndServe(":"+port, nil); err != nil { - log.Fatalf("Error starting server: %v", err) - } + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + http.HandleFunc("/ssm", func(w http.ResponseWriter, r *http.Request) { + ssm.HandleSSM(w, r, sess) + }) + + http.HandleFunc("/s3", func(w http.ResponseWriter, r *http.Request) { + s3.HandleS3(w, r, sess) + }) + + http.HandleFunc("/ecr/login", func(w http.ResponseWriter, r *http.Request) { + ecr.HandleECRLogin(w, r, sess) + }) + + http.HandleFunc("/sts", func(w http.ResponseWriter, r *http.Request) { + sts.HandleSTS(w, r, sess) + }) + + port := os.Getenv("PORT") + if port == "" { + port = "3000" + } + + log.Printf("Server running on port %s", port) + if err := http.ListenAndServe(":"+port, nil); err != nil { + log.Fatalf("Error starting server: %v", err) + } } diff --git a/go.mod b/go.mod index cb80184..8561f16 100644 --- a/go.mod +++ b/go.mod @@ -1,7 +1,7 @@ -module github.com/devstek/ssmserver +module github.com/leneffets/ssmserver go 1.22.7 require github.com/aws/aws-sdk-go v1.55.5 -require github.com/jmespath/go-jmespath v0.4.0 // indirect \ No newline at end of file +require github.com/jmespath/go-jmespath v0.4.0 // indirect diff --git a/go.sum b/go.sum index de3af83..d323266 100644 --- a/go.sum +++ b/go.sum @@ -11,4 +11,4 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= -gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= \ No newline at end of file +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/main_test.go b/main_test.go index 1514a84..0ab0fa8 100644 --- a/main_test.go +++ b/main_test.go @@ -3,32 +3,48 @@ package main import ( "bytes" "context" + "encoding/json" "io" "io/ioutil" "net/http" "net/http/httptest" + "net/url" + "os" + "strings" "testing" "time" + "reflect" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/service/ecr" + "github.com/aws/aws-sdk-go/service/ecr/ecriface" "github.com/aws/aws-sdk-go/service/s3" "github.com/aws/aws-sdk-go/service/s3/s3iface" "github.com/aws/aws-sdk-go/service/ssm" "github.com/aws/aws-sdk-go/service/ssm/ssmiface" - "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go/service/sts/stsiface" + ssmpkg "github.com/leneffets/ssmserver/pkg/ssm" ) // MockSSMAPI for testing type MockSSMAPI struct { ssmiface.SSMAPI - Response ssm.GetParameterOutput - Err error + Response ssm.GetParameterOutput + PutResponse ssm.PutParameterOutput + Err error } func (m *MockSSMAPI) GetParameterWithContext(ctx context.Context, input *ssm.GetParameterInput, opts ...request.Option) (*ssm.GetParameterOutput, error) { return &m.Response, m.Err } +func (m *MockSSMAPI) PutParameterWithContext(ctx context.Context, input *ssm.PutParameterInput, opts ...request.Option) (*ssm.PutParameterOutput, error) { + return &m.PutResponse, m.Err +} + // MockS3API for testing type MockS3API struct { s3iface.S3API @@ -36,29 +52,34 @@ type MockS3API struct { Err error } -func (m *MockS3API) GetObject(input *s3.GetObjectInput) (*s3.GetObjectOutput, error) { +func (m *MockS3API) GetObjectWithContext(ctx context.Context, input *s3.GetObjectInput, opts ...request.Option) (*s3.GetObjectOutput, error) { return &m.Response, m.Err } -// GetParameter function for testing -func GetParameter(ctx context.Context, svc ssmiface.SSMAPI, name *string) (*ssm.GetParameterOutput, error) { - results, err := svc.GetParameterWithContext(ctx, &ssm.GetParameterInput{ - Name: aws.String(*name), - WithDecryption: aws.Bool(true), - }) - return results, err +func (m *MockS3API) PutObjectWithContext(ctx context.Context, input *s3.PutObjectInput, opts ...request.Option) (*s3.PutObjectOutput, error) { + return &s3.PutObjectOutput{}, m.Err } -// GetFromS3 function for testing -func GetFromS3(sess s3iface.S3API, bucket, key string) (io.ReadCloser, error) { - output, err := sess.GetObject(&s3.GetObjectInput{ - Bucket: aws.String(bucket), - Key: aws.String(key), - }) - if err != nil { - return nil, err - } - return output.Body, nil +// MockECRAPI for testing +type MockECRAPI struct { + ecriface.ECRAPI + Response ecr.GetAuthorizationTokenOutput + Err error +} + +func (m *MockECRAPI) GetAuthorizationTokenWithContext(ctx context.Context, input *ecr.GetAuthorizationTokenInput, opts ...request.Option) (*ecr.GetAuthorizationTokenOutput, error) { + return &m.Response, m.Err +} + +// MockSTSAPI for testing +type MockSTSAPI struct { + stsiface.STSAPI + Response sts.GetCallerIdentityOutput + Err error +} + +func (m *MockSTSAPI) GetCallerIdentityWithContext(ctx context.Context, input *sts.GetCallerIdentityInput, opts ...request.Option) (*sts.GetCallerIdentityOutput, error) { + return &m.Response, m.Err } func TestGetParameter(t *testing.T) { @@ -86,7 +107,7 @@ func TestGetParameter(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - results, err := GetParameter(ctx, mockSvc, &id) + results, err := ssmpkg.GetParameter(ctx, mockSvc, &id) if err != nil { http.Error(w, "Error fetching parameter", http.StatusInternalServerError) return @@ -108,6 +129,32 @@ func TestGetParameter(t *testing.T) { } } +func TestPutParameter(t *testing.T) { + mockSvc := &MockSSMAPI{} + + form := url.Values{} + form.Add("name", "mock_name") + form.Add("value", "mock_value") + form.Add("type", "String") + + req, err := http.NewRequest("POST", "/ssm", strings.NewReader(form.Encode())) + if err != nil { + t.Fatal(err) + } + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + + rr := httptest.NewRecorder() + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ssmpkg.HandlePostSSM(w, r, mockSvc) + }) + + handler.ServeHTTP(rr, req) + + if status := rr.Code; status != http.StatusOK { + t.Errorf("Handler returned wrong status code: got %v want %v", status, http.StatusOK) + } +} + func TestGetFromS3(t *testing.T) { mockS3 := &MockS3API{ Response: s3.GetObjectOutput{ @@ -129,15 +176,21 @@ func TestGetFromS3(t *testing.T) { return } - body, err := GetFromS3(mockS3, bucket, key) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + body, err := mockS3.GetObjectWithContext(ctx, &s3.GetObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(key), + }) if err != nil { http.Error(w, "Error fetching file from S3", http.StatusInternalServerError) return } - defer body.Close() + defer body.Body.Close() w.Header().Set("Content-Type", "application/octet-stream") - if _, err := io.Copy(w, body); err != nil { + if _, err := io.Copy(w, body.Body); err != nil { http.Error(w, "Error sending file", http.StatusInternalServerError) return } @@ -154,3 +207,164 @@ func TestGetFromS3(t *testing.T) { t.Errorf("Handler returned unexpected body: got %v want %v", rr.Body.String(), expected) } } + +func TestPutToS3(t *testing.T) { + mockS3 := &MockS3API{} + + fileContent := "mock_file_content" + file := ioutil.NopCloser(bytes.NewReader([]byte(fileContent))) + + req, err := http.NewRequest("POST", "/s3?bucket=mock_bucket&key=mock_key", file) + if err != nil { + t.Fatal(err) + } + + rr := httptest.NewRecorder() + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + bucket := r.URL.Query().Get("bucket") + key := r.URL.Query().Get("key") + if bucket == "" || key == "" { + http.Error(w, "Parameters 'bucket' and 'key' are required", http.StatusBadRequest) + return + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Convert io.ReadCloser to io.ReadSeeker + readSeeker := bytes.NewReader([]byte(fileContent)) + + _, err = mockS3.PutObjectWithContext(ctx, &s3.PutObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(key), + Body: readSeeker, + }) + if err != nil { + http.Error(w, "Error uploading file to S3", http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) + }) + + handler.ServeHTTP(rr, req) + + if status := rr.Code; status != http.StatusOK { + t.Errorf("Handler returned wrong status code: got %v want %v", status, http.StatusOK) + } +} + +func TestGetECRLogin(t *testing.T) { + mockECR := &MockECRAPI{ + Response: ecr.GetAuthorizationTokenOutput{ + AuthorizationData: []*ecr.AuthorizationData{ + { + AuthorizationToken: aws.String("mock_token"), + ProxyEndpoint: aws.String("https://mock_endpoint"), + }, + }, + }, + } + + req, err := http.NewRequest("GET", "/ecr/login", nil) + if err != nil { + t.Fatal(err) + } + + rr := httptest.NewRecorder() + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + result, err := mockECR.GetAuthorizationTokenWithContext(ctx, &ecr.GetAuthorizationTokenInput{}) + if err != nil { + http.Error(w, "Error fetching ECR authorization token", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(result) + }) + + handler.ServeHTTP(rr, req) + + if status := rr.Code; status != http.StatusOK { + t.Errorf("Handler returned wrong status code: got %v want %v", status, http.StatusOK) + } + + expected := ecr.GetAuthorizationTokenOutput{ + AuthorizationData: []*ecr.AuthorizationData{ + { + AuthorizationToken: aws.String("mock_token"), + ProxyEndpoint: aws.String("https://mock_endpoint"), + ExpiresAt: nil, + }, + }, + } + + var actual ecr.GetAuthorizationTokenOutput + if err := json.Unmarshal(rr.Body.Bytes(), &actual); err != nil { + t.Fatalf("Failed to unmarshal response body: %v", err) + } + + if !reflect.DeepEqual(actual, expected) { + t.Errorf("Handler returned unexpected body: got %v want %v", actual, expected) + } +} + +func TestGetCallerIdentity(t *testing.T) { + mockSTS := &MockSTSAPI{ + Response: sts.GetCallerIdentityOutput{ + Account: aws.String("mock_account"), + Arn: aws.String("mock_arn"), + UserId: aws.String("mock_user_id"), + }, + } + + req, err := http.NewRequest("GET", "/sts", nil) + if err != nil { + t.Fatal(err) + } + + rr := httptest.NewRecorder() + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + result, err := mockSTS.GetCallerIdentityWithContext(ctx, &sts.GetCallerIdentityInput{}) + if err != nil { + http.Error(w, "Error fetching caller identity", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(result) + }) + + handler.ServeHTTP(rr, req) + + if status := rr.Code; status != http.StatusOK { + t.Errorf("Handler returned wrong status code: got %v want %v", status, http.StatusOK) + } + + expected := sts.GetCallerIdentityOutput{ + Account: aws.String("mock_account"), + Arn: aws.String("mock_arn"), + UserId: aws.String("mock_user_id"), + } + + var actual sts.GetCallerIdentityOutput + if err := json.Unmarshal(rr.Body.Bytes(), &actual); err != nil { + t.Fatalf("Failed to unmarshal response body: %v", err) + } + + if !reflect.DeepEqual(actual, expected) { + t.Errorf("Handler returned unexpected body: got %v want %v", actual, expected) + } +} + +func TestMain(m *testing.M) { + os.Setenv("PORT", "3000") + code := m.Run() + os.Exit(code) +} diff --git a/pkg/ecr/ecr.go b/pkg/ecr/ecr.go new file mode 100644 index 0000000..c50ee98 --- /dev/null +++ b/pkg/ecr/ecr.go @@ -0,0 +1,63 @@ +package ecr + +import ( + "context" + "encoding/base64" + "log" + "net/http" + "strings" + "time" + + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/ecr" +) + +func GetECRCredentials(ctx context.Context, sess *session.Session) (*ecr.GetAuthorizationTokenOutput, error) { + svc := ecr.New(sess) + return svc.GetAuthorizationTokenWithContext(ctx, &ecr.GetAuthorizationTokenInput{}) +} + +func HandleECRLogin(w http.ResponseWriter, r *http.Request, sess *session.Session) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + if r.Method != http.MethodGet { + http.Error(w, "Invalid method", http.StatusMethodNotAllowed) + return + } + + results, err := GetECRCredentials(ctx, sess) + if err != nil { + log.Printf("Error fetching ECR credentials: %v", err) + http.Error(w, "Error fetching ECR credentials", http.StatusInternalServerError) + return + } + + if len(results.AuthorizationData) == 0 { + http.Error(w, "No authorization data found", http.StatusInternalServerError) + return + } + + authData := results.AuthorizationData[0] + decodedToken, err := base64.StdEncoding.DecodeString(*authData.AuthorizationToken) + if err != nil { + log.Printf("Error decoding authorization token: %v", err) + http.Error(w, "Error decoding authorization token", http.StatusInternalServerError) + return + } + + tokenParts := strings.SplitN(string(decodedToken), ":", 2) + if len(tokenParts) != 2 { + http.Error(w, "Invalid authorization token format", http.StatusInternalServerError) + return + } + + password := tokenParts[1] + + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte(password)); err != nil { + http.Error(w, "Error writing response", http.StatusInternalServerError) + log.Printf("Error writing response: %v", err) + } +} diff --git a/pkg/mocks/mocks.go b/pkg/mocks/mocks.go deleted file mode 100644 index f726b26..0000000 --- a/pkg/mocks/mocks.go +++ /dev/null @@ -1 +0,0 @@ -package mocks diff --git a/pkg/s3/s3.go b/pkg/s3/s3.go new file mode 100644 index 0000000..58a0cad --- /dev/null +++ b/pkg/s3/s3.go @@ -0,0 +1,108 @@ +package s3 + +import ( + "context" + "io" + "log" + "net/http" + "os" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3" +) + +func GetFromS3(ctx context.Context, sess *session.Session, bucket, key string) (io.ReadCloser, error) { + svc := s3.New(sess) + output, err := svc.GetObjectWithContext(ctx, &s3.GetObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(key), + }) + if err != nil { + return nil, err + } + return output.Body, nil +} + +func PutToS3(ctx context.Context, sess *session.Session, bucket, key string, body io.ReadSeeker) error { + svc := s3.New(sess) + _, err := svc.PutObjectWithContext(ctx, &s3.PutObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(key), + Body: body, + }) + return err +} + +func HandleS3(w http.ResponseWriter, r *http.Request, sess *session.Session) { + bucket := r.URL.Query().Get("bucket") + key := r.URL.Query().Get("key") + if bucket == "" || key == "" { + http.Error(w, "Parameters 'bucket' and 'key' are required", http.StatusBadRequest) + return + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + if r.Method == http.MethodGet { + handleGetS3(w, r, sess, bucket, key, ctx) + } else if r.Method == http.MethodPost { + handlePostS3(w, r, sess, bucket, key, ctx) + } else { + http.Error(w, "Invalid method", http.StatusMethodNotAllowed) + } +} + +func handleGetS3(w http.ResponseWriter, r *http.Request, sess *session.Session, bucket, key string, ctx context.Context) { + body, err := GetFromS3(ctx, sess, bucket, key) + if err != nil { + http.Error(w, "Error fetching file from S3", http.StatusInternalServerError) + log.Printf("Error fetching file from S3: %v", err) + return + } + defer body.Close() + + w.Header().Set("Content-Type", "application/octet-stream") + if _, err := io.Copy(w, body); err != nil { + http.Error(w, "Error sending file", http.StatusInternalServerError) + log.Printf("Error sending file: %v", err) + return + } +} + +func handlePostS3(w http.ResponseWriter, r *http.Request, sess *session.Session, bucket, key string, ctx context.Context) { + file, _, err := r.FormFile("file") + if err != nil { + http.Error(w, "Error reading uploaded file", http.StatusBadRequest) + log.Printf("Error reading uploaded file: %v", err) + return + } + defer file.Close() + + tempFile, err := os.CreateTemp("", "upload-*.tmp") + if err != nil { + http.Error(w, "Error creating temporary file", http.StatusInternalServerError) + log.Printf("Error creating temporary file: %v", err) + return + } + defer os.Remove(tempFile.Name()) + + if _, err := io.Copy(tempFile, file); err != nil { + http.Error(w, "Error saving file", http.StatusInternalServerError) + log.Printf("Error saving file: %v", err) + return + } + + tempFile.Seek(0, 0) + + if err := PutToS3(ctx, sess, bucket, key, tempFile); err != nil { + http.Error(w, "Error uploading file to S3", http.StatusInternalServerError) + log.Printf("Error uploading file to S3: %v", err) + return + } + + log.Printf("File uploaded successfully to bucket %s with key %s", bucket, key) + w.WriteHeader(http.StatusOK) +} diff --git a/pkg/ssm/ssm.go b/pkg/ssm/ssm.go new file mode 100644 index 0000000..2a9626c --- /dev/null +++ b/pkg/ssm/ssm.go @@ -0,0 +1,93 @@ +package ssm + +import ( + "context" + "log" + "net/http" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/ssm" + "github.com/aws/aws-sdk-go/service/ssm/ssmiface" +) + +func GetParameter(ctx context.Context, svc ssmiface.SSMAPI, name *string) (*ssm.GetParameterOutput, error) { + results, err := svc.GetParameterWithContext(ctx, &ssm.GetParameterInput{ + Name: name, + WithDecryption: aws.Bool(true), + }) + return results, err +} + +func PutParameter(ctx context.Context, svc ssmiface.SSMAPI, name *string, value *string, typeStr *string) (*ssm.PutParameterOutput, error) { + results, err := svc.PutParameterWithContext(ctx, &ssm.PutParameterInput{ + Name: name, + Value: value, + Type: aws.String(*typeStr), + }) + return results, err +} + +func HandleSSM(w http.ResponseWriter, r *http.Request, sess *session.Session) { + svc := ssm.New(sess) + + if r.Method == http.MethodGet { + handleGetSSM(w, r, svc) + } else if r.Method == http.MethodPost { + HandlePostSSM(w, r, svc) + } else { + http.Error(w, "Invalid method", http.StatusMethodNotAllowed) + } +} + +func handleGetSSM(w http.ResponseWriter, r *http.Request, svc ssmiface.SSMAPI) { + id := r.URL.Query().Get("name") + if id == "" { + http.Error(w, "Parameter 'name' is required", http.StatusBadRequest) + return + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + results, err := GetParameter(ctx, svc, &id) + if err != nil { + log.Printf("Error fetching parameter: %v", err) + http.Error(w, "Error fetching parameter", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "text/plain") + w.Write([]byte(*results.Parameter.Value)) +} + +func HandlePostSSM(w http.ResponseWriter, r *http.Request, svc ssmiface.SSMAPI) { + if err := r.ParseForm(); err != nil { + http.Error(w, "Invalid form data", http.StatusBadRequest) + log.Printf("Invalid form data: %v", err) + return + } + + name := r.FormValue("name") + value := r.FormValue("value") + typeStr := r.FormValue("type") + + if name == "" || value == "" || typeStr == "" { + http.Error(w, "Parameters 'name', 'value', and 'type' are required", http.StatusBadRequest) + return + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + _, err := PutParameter(ctx, svc, &name, &value, &typeStr) + if err != nil { + http.Error(w, "Error putting parameter", http.StatusInternalServerError) + log.Printf("Error putting parameter: %v", err) + return + } + + log.Printf("Parameter %s uploaded successfully", name) + w.WriteHeader(http.StatusOK) +} diff --git a/pkg/sts/sts.go b/pkg/sts/sts.go new file mode 100644 index 0000000..013b6a2 --- /dev/null +++ b/pkg/sts/sts.go @@ -0,0 +1,41 @@ +package sts + +import ( + "context" + "encoding/json" + "log" + "net/http" + "time" + + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/sts" +) + +func GetCallerIdentity(ctx context.Context, sess *session.Session) (*sts.GetCallerIdentityOutput, error) { + svc := sts.New(sess) + return svc.GetCallerIdentityWithContext(ctx, &sts.GetCallerIdentityInput{}) +} + +func HandleSTS(w http.ResponseWriter, r *http.Request, sess *session.Session) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + if r.Method != http.MethodGet { + http.Error(w, "Invalid method", http.StatusMethodNotAllowed) + return + } + + results, err := GetCallerIdentity(ctx, sess) + if err != nil { + log.Printf("Error fetching caller identity: %v", err) + http.Error(w, "Error fetching caller identity", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(results); err != nil { + http.Error(w, "Error encoding response", http.StatusInternalServerError) + log.Printf("Error encoding response: %v", err) + } +}