diff --git a/cmd/proxygenerator/main.go b/cmd/proxygenerator/main.go index 0a7a593b..162bb92a 100644 --- a/cmd/proxygenerator/main.go +++ b/cmd/proxygenerator/main.go @@ -31,7 +31,12 @@ func main() { log.Print(interceptorErr) } - if serviceErr != nil || interceptorErr != nil { + requestHeaderErr := generateRequestHeader(cfg) + if requestHeaderErr != nil { + log.Print(requestHeaderErr) + } + + if serviceErr != nil || interceptorErr != nil || requestHeaderErr != nil { os.Exit(1) } } diff --git a/cmd/proxygenerator/request_header.go b/cmd/proxygenerator/request_header.go new file mode 100644 index 00000000..71bd9554 --- /dev/null +++ b/cmd/proxygenerator/request_header.go @@ -0,0 +1,317 @@ +package main + +import ( + "bytes" + "fmt" + "go/format" + "os" + "regexp" + "sort" + "strings" + "text/template" + + "golang.org/x/tools/imports" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protodesc" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/descriptorpb" + + protometa "go.temporal.io/api/protometa/v1" +) + +const requestHeaderFile = "../../proxy/request_header.go" + +const requestHeaderTemplateText = ` +// Code generated by proxygenerator; DO NOT EDIT. + +package proxy + +import ( + "context" + "errors" + "fmt" + + "google.golang.org/grpc/metadata" + "google.golang.org/protobuf/proto" + {{range $path, $alias := .Imports}} + {{$alias}} "{{$path}}" + {{end}} +) + +// ExtractHeadersOptions contains options for extracting Temporal request headers. +type ExtractHeadersOptions struct { + // Request is the proto message to extract headers from. Required. + Request proto.Message + + // ExistingMetadata contains existing metadata to check for duplicates. + // If provided, headers that already exist will not be added again. + // If nil, no duplicate checking is performed. + ExistingMetadata metadata.MD +} + +// ExtractTemporalRequestHeaders extracts field values from request messages and returns +// them as a slice of key-value pairs suitable for use with metadata.AppendToOutgoingContext. +// Returns nil if no headers should be set. +func ExtractTemporalRequestHeaders(ctx context.Context, opts ExtractHeadersOptions) ([]string, error) { + if opts.Request == nil { + return nil, errors.New("request cannot be nil") + } + + var headers []string + + // Set namespace header if present and not already exists + if len(opts.ExistingMetadata.Get("temporal-namespace")) == 0 { + if nsReq, ok := opts.Request.(interface{ GetNamespace() string }); ok { + if ns := nsReq.GetNamespace(); ns != "" { + headers = append(headers, "temporal-namespace", ns) + } + } + } + + // Set headers from proto annotations{{if .Methods}} + switch r := opts.Request.(type) { +{{range .Methods}} case *{{.PackageAlias}}.{{.RequestType}}: +{{range .Headers}}{{.Code}}{{end}} +{{end}} }{{end}} + + return headers, nil +} +` + +var requestHeaderTemplate = template.Must(template.New("request_header").Parse(requestHeaderTemplateText)) + +type requestHeaderTemplateInput struct { + Methods []methodHeaderInfo + Imports map[string]string // map[importPath]alias +} + +type methodHeaderInfo struct { + PackageAlias string + RequestType string + Headers []headerInfo +} + +type headerInfo struct { + Code string +} + +func generateRequestHeader(cfg config) error { + data, err := os.ReadFile(cfg.descriptorPath) + if err != nil { + return fmt.Errorf("reading descriptor set: %w", err) + } + + var fdSet descriptorpb.FileDescriptorSet + if err := proto.Unmarshal(data, &fdSet); err != nil { + return fmt.Errorf("unmarshalling descriptor set: %w", err) + } + + files, err := protodesc.NewFiles(&fdSet) + if err != nil { + return fmt.Errorf("creating file registry: %w", err) + } + + var allMethods []methodHeaderInfo + importsMap := make(map[string]string) // map[importPath]alias + var rangeErr error + + files.RangeFiles(func(fd protoreflect.FileDescriptor) bool { + services := fd.Services() + for i := 0; i < services.Len(); i++ { + service := services.Get(i) + methods, importPath, alias, err := extractMethodHeaders(service) + if err != nil { + rangeErr = err + return false + } + + if len(methods) > 0 && importPath != "" { + importsMap[importPath] = alias + } + + allMethods = append(allMethods, methods...) + } + return true + }) + + if rangeErr != nil { + return rangeErr + } + + // Sort methods alphabetically by RequestType for consistent output + sort.Slice(allMethods, func(i, j int) bool { + return allMethods[i].RequestType < allMethods[j].RequestType + }) + + buf := &bytes.Buffer{} + err = requestHeaderTemplate.Execute(buf, requestHeaderTemplateInput{ + Methods: allMethods, + Imports: importsMap, + }) + if err != nil { + return fmt.Errorf("executing template: %w", err) + } + + src, err := imports.Process(requestHeaderFile, buf.Bytes(), nil) + if err != nil { + return fmt.Errorf("failed to process imports: %w", err) + } + + src, err = format.Source(src) + if err != nil { + return fmt.Errorf("failed to format: %w", err) + } + + if cfg.verifyOnly { + currentSrc, err := os.ReadFile(requestHeaderFile) + if err != nil { + return fmt.Errorf("failed to read existing file: %w", err) + } + + if !bytes.Equal(src, currentSrc) { + return fmt.Errorf("generated file does not match existing file: %s", requestHeaderFile) + } + + return nil + } + + if err := os.WriteFile(requestHeaderFile, src, 0666); err != nil { + return fmt.Errorf("failed to write file: %w", err) + } + + return nil +} + +func extractMethodHeaders(service protoreflect.ServiceDescriptor) ([]methodHeaderInfo, string, string, error) { + var methods []methodHeaderInfo + + // Get the Go package info from the file descriptor + fileDesc := service.ParentFile() + goPackageOption := fileDesc.Options().(*descriptorpb.FileOptions).GetGoPackage() + + // Parse go_package option: "go.temporal.io/api/workflowservice/v1;workflowservice" + // Format is "import/path;packagename" or just "import/path" + parts := strings.Split(goPackageOption, ";") + importPath := parts[0] + var packageAlias string + if len(parts) > 1 { + packageAlias = parts[1] + } else { + // Use last part of import path as alias + pathParts := strings.Split(importPath, "/") + packageAlias = pathParts[len(pathParts)-1] + } + + for i := 0; i < service.Methods().Len(); i++ { + method := service.Methods().Get(i) + opts := method.Options() + if opts == nil { + continue + } + + if !proto.HasExtension(opts, protometa.E_RequestHeader) { + continue + } + + ext := proto.GetExtension(opts, protometa.E_RequestHeader) + annotations, ok := ext.([]*protometa.RequestHeaderAnnotation) + if !ok || len(annotations) == 0 { + continue + } + + requestTypeName := string(method.Input().Name()) + requestMsgDesc := method.Input() + var headerInfos []headerInfo + + for _, annotation := range annotations { + code, err := generateHeaderCode(annotation.GetHeader(), annotation.GetValue(), "r", requestMsgDesc) + if err != nil { + return nil, "", "", fmt.Errorf("failed to generate header code for %s.%s: %w", service.Name(), method.Name(), err) + } + headerInfos = append(headerInfos, headerInfo{Code: code}) + } + methods = append(methods, methodHeaderInfo{ + PackageAlias: packageAlias, + RequestType: requestTypeName, + Headers: headerInfos, + }) + } + + return methods, importPath, packageAlias, nil +} + +func generateHeaderCode(headerName, valueTemplate, reqVar string, msgDesc protoreflect.MessageDescriptor) (string, error) { + fieldPaths := parseValueTemplate(valueTemplate) + + if len(fieldPaths) == 0 { + return fmt.Sprintf("\t\tif %q != \"\" && len(opts.ExistingMetadata.Get(%q)) == 0 {\n\t\t\theaders = append(headers, %q, %q)\n\t\t}", valueTemplate, headerName, headerName, valueTemplate), nil + } + + if len(fieldPaths) > 1 { + return "", fmt.Errorf("only one field interpolation is supported, found %d", len(fieldPaths)) + } + + fieldPath := fieldPaths[0] + accessor, err := generateFieldAccessor(fieldPath, reqVar, msgDesc) + if err != nil { + return "", fmt.Errorf("failed to generate accessor for %s: %w", fieldPath, err) + } + + finalValue := strings.Replace(valueTemplate, "{"+fieldPath+"}", "%s", 1) + + // Generate code that checks if the field value is non-empty and header doesn't exist before formatting and appending + if finalValue == "%s" { + // No prefix/suffix around the field value, so use it directly + return fmt.Sprintf("\t\tif val := %s; val != \"\" && len(opts.ExistingMetadata.Get(%q)) == 0 {\n\t\t\theaders = append(headers, %q, val)\n\t\t}", + accessor, headerName, headerName), nil + } + return fmt.Sprintf("\t\tif val := %s; val != \"\" && len(opts.ExistingMetadata.Get(%q)) == 0 {\n\t\t\theaders = append(headers, %q, fmt.Sprintf(%q, val))\n\t\t}", + accessor, headerName, headerName, finalValue), nil +} + +func generateFieldAccessor(fieldPath, varName string, msgDesc protoreflect.MessageDescriptor) (string, error) { + parts := strings.Split(fieldPath, ".") + accessor := varName + currentMsg := msgDesc + + for _, part := range parts { + field := currentMsg.Fields().ByName(protoreflect.Name(part)) + if field == nil { + return "", fmt.Errorf("field %s not found in message %s", part, currentMsg.FullName()) + } + + goName := protoFieldToGoName(part) + accessor = fmt.Sprintf("%s.Get%s()", accessor, goName) + + if field.Kind() == protoreflect.MessageKind && field.Message() != nil { + currentMsg = field.Message() + } + } + + return accessor, nil +} + +func parseValueTemplate(valueTemplate string) []string { + re := regexp.MustCompile(`\{([^}]+)\}`) + matches := re.FindAllStringSubmatch(valueTemplate, -1) + + fieldPaths := make([]string, 0, len(matches)) + for _, match := range matches { + if len(match) > 1 { + fieldPaths = append(fieldPaths, match[1]) + } + } + + return fieldPaths +} + +func protoFieldToGoName(protoName string) string { + parts := strings.Split(protoName, "_") + for i := range parts { + if len(parts[i]) > 0 { + parts[i] = strings.ToUpper(parts[i][:1]) + parts[i][1:] + } + } + return strings.Join(parts, "") +} + diff --git a/proxy/interceptor.go b/proxy/interceptor.go index 6048b980..5cd474af 100644 --- a/proxy/interceptor.go +++ b/proxy/interceptor.go @@ -24,7 +24,7 @@ import ( "go.temporal.io/api/sdk/v1" "go.temporal.io/api/update/v1" "go.temporal.io/api/workflow/v1" - "go.temporal.io/api/workflowservice/v1" + workflowservice "go.temporal.io/api/workflowservice/v1" "google.golang.org/grpc" "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" diff --git a/proxy/request_header.go b/proxy/request_header.go new file mode 100644 index 00000000..b9a5f00b --- /dev/null +++ b/proxy/request_header.go @@ -0,0 +1,275 @@ +// Code generated by proxygenerator; DO NOT EDIT. + +package proxy + +import ( + "context" + "errors" + "fmt" + + "google.golang.org/grpc/metadata" + "google.golang.org/protobuf/proto" + + workflowservice "go.temporal.io/api/workflowservice/v1" +) + +// ExtractHeadersOptions contains options for extracting Temporal request headers. +type ExtractHeadersOptions struct { + // Request is the proto message to extract headers from. Required. + Request proto.Message + + // ExistingMetadata contains existing metadata to check for duplicates. + // If provided, headers that already exist will not be added again. + // If nil, no duplicate checking is performed. + ExistingMetadata metadata.MD +} + +// ExtractTemporalRequestHeaders extracts field values from request messages and returns +// them as a slice of key-value pairs suitable for use with metadata.AppendToOutgoingContext. +// Returns nil if no headers should be set. +func ExtractTemporalRequestHeaders(ctx context.Context, opts ExtractHeadersOptions) ([]string, error) { + if opts.Request == nil { + return nil, errors.New("request cannot be nil") + } + + var headers []string + + // Set namespace header if present and not already exists + if len(opts.ExistingMetadata.Get("temporal-namespace")) == 0 { + if nsReq, ok := opts.Request.(interface{ GetNamespace() string }); ok { + if ns := nsReq.GetNamespace(); ns != "" { + headers = append(headers, "temporal-namespace", ns) + } + } + } + + // Set headers from proto annotations + switch r := opts.Request.(type) { + case *workflowservice.CreateScheduleRequest: + if val := r.GetScheduleId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("schedule:%s", val)) + } + case *workflowservice.DeleteScheduleRequest: + if val := r.GetScheduleId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("schedule:%s", val)) + } + case *workflowservice.DeleteWorkerDeploymentRequest: + if val := r.GetDeploymentName(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("deployment:%s", val)) + } + case *workflowservice.DeleteWorkerDeploymentVersionRequest: + if val := r.GetDeploymentVersion().GetDeploymentName(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("deployment:%s", val)) + } + case *workflowservice.DeleteWorkflowExecutionRequest: + if val := r.GetWorkflowExecution().GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.DescribeBatchOperationRequest: + if val := r.GetJobId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("batch:%s", val)) + } + case *workflowservice.DescribeScheduleRequest: + if val := r.GetScheduleId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("schedule:%s", val)) + } + case *workflowservice.DescribeTaskQueueRequest: + if val := r.GetTaskQueue().GetName(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("taskqueue:%s", val)) + } + case *workflowservice.DescribeWorkerDeploymentRequest: + if val := r.GetDeploymentName(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("deployment:%s", val)) + } + case *workflowservice.DescribeWorkerDeploymentVersionRequest: + if val := r.GetDeploymentVersion().GetDeploymentName(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("deployment:%s", val)) + } + case *workflowservice.DescribeWorkerRequest: + if val := r.GetWorkerInstanceKey(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("worker:%s", val)) + } + case *workflowservice.DescribeWorkflowExecutionRequest: + if val := r.GetExecution().GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.ExecuteMultiOperationRequest: + if val := r.GetResourceId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", val) + } + case *workflowservice.FetchWorkerConfigRequest: + if val := r.GetResourceId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", val) + } + case *workflowservice.GetWorkflowExecutionHistoryRequest: + if val := r.GetExecution().GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.GetWorkflowExecutionHistoryReverseRequest: + if val := r.GetExecution().GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.ListScheduleMatchingTimesRequest: + if val := r.GetScheduleId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("schedule:%s", val)) + } + case *workflowservice.ListTaskQueuePartitionsRequest: + if val := r.GetTaskQueue().GetName(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("taskqueue:%s", val)) + } + case *workflowservice.PatchScheduleRequest: + if val := r.GetScheduleId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("schedule:%s", val)) + } + case *workflowservice.PauseActivityRequest: + if val := r.GetExecution().GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.PauseWorkflowExecutionRequest: + if val := r.GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.PollWorkflowExecutionUpdateRequest: + if val := r.GetUpdateRef().GetWorkflowExecution().GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.QueryWorkflowRequest: + if val := r.GetExecution().GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.RecordActivityTaskHeartbeatByIdRequest: + if val := r.GetResourceId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", val) + } + case *workflowservice.RecordActivityTaskHeartbeatRequest: + if val := r.GetResourceId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", val) + } + case *workflowservice.RecordWorkerHeartbeatRequest: + if val := r.GetResourceId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", val) + } + case *workflowservice.RequestCancelWorkflowExecutionRequest: + if val := r.GetWorkflowExecution().GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.ResetActivityRequest: + if val := r.GetExecution().GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.ResetStickyTaskQueueRequest: + if val := r.GetExecution().GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.ResetWorkflowExecutionRequest: + if val := r.GetWorkflowExecution().GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.RespondActivityTaskCanceledByIdRequest: + if val := r.GetResourceId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", val) + } + case *workflowservice.RespondActivityTaskCanceledRequest: + if val := r.GetResourceId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", val) + } + case *workflowservice.RespondActivityTaskCompletedByIdRequest: + if val := r.GetResourceId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", val) + } + case *workflowservice.RespondActivityTaskCompletedRequest: + if val := r.GetResourceId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", val) + } + case *workflowservice.RespondActivityTaskFailedByIdRequest: + if val := r.GetResourceId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", val) + } + case *workflowservice.RespondActivityTaskFailedRequest: + if val := r.GetResourceId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", val) + } + case *workflowservice.RespondWorkflowTaskCompletedRequest: + if val := r.GetResourceId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", val) + } + case *workflowservice.RespondWorkflowTaskFailedRequest: + if val := r.GetResourceId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", val) + } + case *workflowservice.SetWorkerDeploymentCurrentVersionRequest: + if val := r.GetDeploymentName(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("deployment:%s", val)) + } + case *workflowservice.SetWorkerDeploymentManagerRequest: + if val := r.GetDeploymentName(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("deployment:%s", val)) + } + case *workflowservice.SetWorkerDeploymentRampingVersionRequest: + if val := r.GetDeploymentName(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("deployment:%s", val)) + } + case *workflowservice.SignalWithStartWorkflowExecutionRequest: + if val := r.GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.SignalWorkflowExecutionRequest: + if val := r.GetWorkflowExecution().GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.StartBatchOperationRequest: + if val := r.GetJobId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("batch:%s", val)) + } + case *workflowservice.StartWorkflowExecutionRequest: + if val := r.GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.StopBatchOperationRequest: + if val := r.GetJobId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("batch:%s", val)) + } + case *workflowservice.TerminateWorkflowExecutionRequest: + if val := r.GetWorkflowExecution().GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.UnpauseActivityRequest: + if val := r.GetExecution().GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.UnpauseWorkflowExecutionRequest: + if val := r.GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.UpdateActivityOptionsRequest: + if val := r.GetExecution().GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.UpdateScheduleRequest: + if val := r.GetScheduleId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("schedule:%s", val)) + } + case *workflowservice.UpdateTaskQueueConfigRequest: + if val := r.GetTaskQueue(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("taskqueue:%s", val)) + } + case *workflowservice.UpdateWorkerConfigRequest: + if val := r.GetResourceId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", val) + } + case *workflowservice.UpdateWorkerDeploymentVersionMetadataRequest: + if val := r.GetDeploymentVersion().GetDeploymentName(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("deployment:%s", val)) + } + case *workflowservice.UpdateWorkflowExecutionOptionsRequest: + if val := r.GetWorkflowExecution().GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.UpdateWorkflowExecutionRequest: + if val := r.GetWorkflowExecution().GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + } + + return headers, nil +} diff --git a/proxy/request_header_test.go b/proxy/request_header_test.go new file mode 100644 index 00000000..b151fac3 --- /dev/null +++ b/proxy/request_header_test.go @@ -0,0 +1,208 @@ +package proxy + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/grpc/metadata" + + "go.temporal.io/api/common/v1" + "go.temporal.io/api/deployment/v1" + "go.temporal.io/api/taskqueue/v1" + "go.temporal.io/api/workflowservice/v1" +) + +// findHeader searches for a header key in the headers slice and returns its value +func findHeader(headers []string, key string) (string, bool) { + for i := 0; i < len(headers); i += 2 { + if i+1 < len(headers) && headers[i] == key { + return headers[i+1], true + } + } + return "", false +} + +func TestExtractTemporalRequestHeaders_NamespaceAlwaysIncluded(t *testing.T) { + req := &workflowservice.StartWorkflowExecutionRequest{ + Namespace: "test-namespace", + WorkflowId: "test-workflow", + } + + headers, err := ExtractTemporalRequestHeaders(context.Background(), ExtractHeadersOptions{ + Request: req, + }) + require.NoError(t, err) + + // Namespace should always be included in the headers + nsVal, found := findHeader(headers, "temporal-namespace") + require.True(t, found, "Expected temporal-namespace header, but not found") + require.Equal(t, "test-namespace", nsVal) +} + +func TestExtractTemporalRequestHeaders_EmptyWorkflowId(t *testing.T) { + req := &workflowservice.StartWorkflowExecutionRequest{ + Namespace: "test-namespace", + WorkflowId: "", + } + + headers, err := ExtractTemporalRequestHeaders(context.Background(), ExtractHeadersOptions{ + Request: req, + }) + require.NoError(t, err) + + // Should still set namespace + nsVal, found := findHeader(headers, "temporal-namespace") + require.True(t, found, "Expected temporal-namespace header even with empty workflow_id") + require.Equal(t, "test-namespace", nsVal) +} + +func TestExtractTemporalRequestHeaders_SkipExistingHeaders(t *testing.T) { + req := &workflowservice.StartWorkflowExecutionRequest{ + Namespace: "test-namespace", + WorkflowId: "test-workflow", + } + + existingMD := metadata.MD{} + existingMD.Set("temporal-namespace", "existing-namespace") + existingMD.Set("temporal-resource-id", "existing-resource-id") + + headers, err := ExtractTemporalRequestHeaders(context.Background(), ExtractHeadersOptions{ + Request: req, + ExistingMetadata: existingMD, + }) + require.NoError(t, err) + + // Should not add any headers since they already exist + require.Empty(t, headers, "Expected no headers to be added when they already exist") +} + +func TestExtractTemporalRequestHeaders_NilRequest(t *testing.T) { + _, err := ExtractTemporalRequestHeaders(context.Background(), ExtractHeadersOptions{ + Request: nil, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "request cannot be nil") +} + +func TestExtractTemporalRequestHeaders_ResourceIdWithPrefix(t *testing.T) { + req := &workflowservice.StartWorkflowExecutionRequest{ + Namespace: "test-namespace", + WorkflowId: "my-workflow-123", + } + + headers, err := ExtractTemporalRequestHeaders(context.Background(), ExtractHeadersOptions{ + Request: req, + }) + require.NoError(t, err) + + resVal, found := findHeader(headers, "temporal-resource-id") + require.True(t, found, "Expected temporal-resource-id header") + require.Equal(t, "workflow:my-workflow-123", resVal) +} + +func TestExtractTemporalRequestHeaders_NestedFieldAccess(t *testing.T) { + req := &workflowservice.DeleteWorkflowExecutionRequest{ + Namespace: "test-namespace", + WorkflowExecution: &common.WorkflowExecution{ + WorkflowId: "nested-workflow", + }, + } + + headers, err := ExtractTemporalRequestHeaders(context.Background(), ExtractHeadersOptions{ + Request: req, + }) + require.NoError(t, err) + + resVal, found := findHeader(headers, "temporal-resource-id") + require.True(t, found, "Expected temporal-resource-id header for nested field") + require.Equal(t, "workflow:nested-workflow", resVal) +} + +func TestExtractTemporalRequestHeaders_NestedFieldNilParent(t *testing.T) { + // WorkflowExecution is nil, so GetWorkflowId() should return "" + req := &workflowservice.DeleteWorkflowExecutionRequest{ + Namespace: "test-namespace", + WorkflowExecution: nil, + } + + headers, err := ExtractTemporalRequestHeaders(context.Background(), ExtractHeadersOptions{ + Request: req, + }) + require.NoError(t, err) + + _, found := findHeader(headers, "temporal-resource-id") + require.False(t, found, "Expected no temporal-resource-id header when nested field is nil") + + // Namespace should still be set + nsVal, found := findHeader(headers, "temporal-namespace") + require.True(t, found) + require.Equal(t, "test-namespace", nsVal) +} + +func TestExtractTemporalRequestHeaders_DeploymentNestedField(t *testing.T) { + req := &workflowservice.DeleteWorkerDeploymentVersionRequest{ + Namespace: "test-namespace", + DeploymentVersion: &deployment.WorkerDeploymentVersion{ + DeploymentName: "my-deployment", + }, + } + + headers, err := ExtractTemporalRequestHeaders(context.Background(), ExtractHeadersOptions{ + Request: req, + }) + require.NoError(t, err) + + resVal, found := findHeader(headers, "temporal-resource-id") + require.True(t, found, "Expected temporal-resource-id header for deployment") + require.Equal(t, "deployment:my-deployment", resVal) +} + +func TestExtractTemporalRequestHeaders_TaskQueueNestedField(t *testing.T) { + req := &workflowservice.DescribeTaskQueueRequest{ + Namespace: "test-namespace", + TaskQueue: &taskqueue.TaskQueue{ + Name: "my-task-queue", + }, + } + + headers, err := ExtractTemporalRequestHeaders(context.Background(), ExtractHeadersOptions{ + Request: req, + }) + require.NoError(t, err) + + resVal, found := findHeader(headers, "temporal-resource-id") + require.True(t, found, "Expected temporal-resource-id header for task queue") + require.Equal(t, "taskqueue:my-task-queue", resVal) +} + +func TestExtractTemporalRequestHeaders_ScheduleResourceId(t *testing.T) { + req := &workflowservice.CreateScheduleRequest{ + Namespace: "test-namespace", + ScheduleId: "my-schedule", + } + + headers, err := ExtractTemporalRequestHeaders(context.Background(), ExtractHeadersOptions{ + Request: req, + }) + require.NoError(t, err) + + resVal, found := findHeader(headers, "temporal-resource-id") + require.True(t, found, "Expected temporal-resource-id header for schedule") + require.Equal(t, "schedule:my-schedule", resVal) +} + +func TestExtractTemporalRequestHeaders_EmptyNamespace(t *testing.T) { + req := &workflowservice.RecordActivityTaskHeartbeatRequest{ + TaskToken: []byte("token"), + Namespace: "", + } + + headers, err := ExtractTemporalRequestHeaders(context.Background(), ExtractHeadersOptions{ + Request: req, + }) + require.NoError(t, err) + + _, found := findHeader(headers, "temporal-namespace") + require.False(t, found, "Expected no temporal-namespace header when namespace is empty") +}