Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 1 addition & 24 deletions cmd/protoc-gen-fastmarshal/funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,6 @@ func addSprigFunctions(fm template.FuncMap) template.FuncMap {
func addProtoFunctions(fm template.FuncMap, protoFile *protogen.File, names specialNames, goPackageForFile map[string]string) template.FuncMap {
fm["protoNumberEncodeMethod"] = protoNumberEncodeMethod
fm["getExtensions"] = getExtensions(protoFile)
fm["allMessages"] = allMessages(protoFile)
fm["getAdditionalImports"] = getAdditionalImports(protoFile, goPackageForFile)
fm["getImportPrefix"] = getImportPrefix(protoFile, goPackageForFile)
fm["mapFieldGoType"] = mapFieldGoType(protoFile, goPackageForFile)
Expand Down Expand Up @@ -250,28 +249,6 @@ func getExtensions(protoFile *protogen.File) func(*protogen.Message) []*protogen
}
}

// allMessages returns a list of all top-level and nested message definitions in protoFile
func allMessages(protoFile *protogen.File) func() []*protogen.Message {
var queue, msgs []*protogen.Message
queue = append(queue, protoFile.Messages...)
for len(queue) > 0 {
m := queue[0]
queue = queue[1:]
msgs = append(msgs, m)
for _, mm := range m.Messages {
// skip "messgaes" that represent map fields
if mm.Desc.IsMapEntry() {
continue
}
queue = append(queue, mm)
}
}

return func() []*protogen.Message {
return msgs
}
}

// getAdditionalImports returns a set of distinct imports paths required by the fields of v, which
// must be either a single protogen.Message or a slice of messages.
//
Expand Down Expand Up @@ -440,7 +417,7 @@ func msgHasRequiredField(m *protogen.Message) bool {
func hasRequiredFields(protoFile *protogen.File) func(*protogen.Message) bool {
anyMessageHasRequiredFields := false
if protoFile.Desc.Syntax() == protoreflect.Proto2 {
for _, m := range allMessages(protoFile)() {
for _, m := range allMessages(protoFile) {
anyMessageHasRequiredFields = anyMessageHasRequiredFields || msgHasRequiredField(m)
}
}
Expand Down
1 change: 1 addition & 0 deletions cmd/protoc-gen-fastmarshal/generate_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ const (
// Protobuf code generator
type generateRequest struct {
ProtoDesc *protogen.File
Messages []*protogen.Message
Mode outputMode
NameTemplate string
Funcs template.FuncMap
Expand Down
19 changes: 12 additions & 7 deletions cmd/protoc-gen-fastmarshal/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"fmt"
"maps"
"os"
"time"

Expand All @@ -21,6 +22,7 @@ func generateSingle(plugin *protogen.Plugin, req generateRequest) error {
Now time.Time
Pwd string
ProtoDesc *protogen.File
Messages []*protogen.Message
APIVersion string
SpecialNames specialNames
EnableUnsafeDecode bool
Expand All @@ -29,6 +31,7 @@ func generateSingle(plugin *protogen.Plugin, req generateRequest) error {
Now: time.Now().UTC(),
Pwd: func() string { p, _ := os.Getwd(); return p }(),
ProtoDesc: req.ProtoDesc,
Messages: req.Messages,
APIVersion: req.APIVersion,
SpecialNames: req.SpecialNames,
EnableUnsafeDecode: req.EnableUnsafeDecode,
Expand All @@ -38,26 +41,29 @@ func generateSingle(plugin *protogen.Plugin, req generateRequest) error {
name, content string
err error
)

goPackageForFile := make(map[string]string, len(plugin.Files))
for _, f := range plugin.Files {
goPackageForFile[f.Desc.Path()] = string(f.GoPackageName)
}
funcs := codeGenFunctions(req.ProtoDesc, req.SpecialNames, goPackageForFile)
for k, v := range req.Funcs {
funcs[k] = v
}
maps.Copy(funcs, req.Funcs)

nt, err := loadTemplateFromString(req.NameTemplate, funcs)
if err != nil {
return fmt.Errorf("unable to parse output file name template: %w", err)
}

name, err = renderTemplate(nt, args)
if err != nil {
return fmt.Errorf("unable to generate output file name from name template: %w", err)
}

ct, err := loadTemplateFromEmbedded(funcs)
if err != nil {
return fmt.Errorf("unable to load embedded content templates: %w", err)
}

content, err = renderNamedTemplate(ct, "SingleFile", args)
if err != nil {
return fmt.Errorf("unable to generate output from content template: %w", err)
Expand Down Expand Up @@ -86,9 +92,7 @@ func generatePerMessage(plugin *protogen.Plugin, req generateRequest) error {
goPackageForFile[f.Desc.Path()] = string(f.GoPackageName)
}
funcs := codeGenFunctions(req.ProtoDesc, req.SpecialNames, goPackageForFile)
for k, v := range req.Funcs {
funcs[k] = v
}
maps.Copy(funcs, req.Funcs)

var (
now = time.Now().UTC()
Expand All @@ -99,7 +103,7 @@ func generatePerMessage(plugin *protogen.Plugin, req generateRequest) error {
if err != nil {
return fmt.Errorf("unable to load embedded content templates: %w", err)
}
for _, msg := range allMessages(req.ProtoDesc)() {
for _, msg := range req.Messages {
args := genArgsPerFile{
Now: now,
Pwd: pwd,
Expand All @@ -109,6 +113,7 @@ func generatePerMessage(plugin *protogen.Plugin, req generateRequest) error {
SpecialNames: req.SpecialNames,
EnableUnsafeDecode: req.EnableUnsafeDecode,
}

content, err := renderNamedTemplate(tt, "PerMessage", args)
if err != nil {
return fmt.Errorf("error executing content template for message %s: %w", msg.Desc.FullName(), err)
Expand Down
25 changes: 25 additions & 0 deletions cmd/protoc-gen-fastmarshal/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,11 @@ func doGenerate(opts *options) func(*protogen.Plugin) error {
continue
}

messages := allMessages(protoFile)
if len(messages) == 0 {
continue
}

// account for per-message output mode
// - default output template is "[protofile].pb.fm.go"
// - file-per-message template is "[protofile]_[lower(messagename)].pb.fm.go"
Expand All @@ -139,6 +144,7 @@ func doGenerate(opts *options) func(*protogen.Plugin) error {
req := generateRequest{
Mode: outputModeSingleFile,
ProtoDesc: protoFile,
Messages: messages,
NameTemplate: nameTemplate,
APIVersion: opts.apiVersion.String(),
SpecialNames: opts.specialNames,
Expand All @@ -154,3 +160,22 @@ func doGenerate(opts *options) func(*protogen.Plugin) error {
return nil
}
}

// allMessages returns a list of all top-level and nested message definitions in protoFile
func allMessages(protoFile *protogen.File) []*protogen.Message {
var queue, msgs []*protogen.Message
queue = append(queue, protoFile.Messages...)
for len(queue) > 0 {
m := queue[0]
queue = queue[1:]
msgs = append(msgs, m)
for _, mm := range m.Messages {
// skip "messages" that represent map fields
if mm.Desc.IsMapEntry() {
continue
}
queue = append(queue, mm)
}
}
return msgs
}
43 changes: 43 additions & 0 deletions cmd/protoc-gen-fastmarshal/run_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package main

import (
"flag"
"testing"

"github.com/stretchr/testify/require"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/types/descriptorpb"
"google.golang.org/protobuf/types/pluginpb"

"github.com/CrowdStrike/csproto"
)

func Test_doGenerate_Empty(t *testing.T) {
flags := flag.NewFlagSet("", flag.PanicOnError)
pluginOpts := protogen.Options{ParamFunc: flags.Set}

plugin, err := pluginOpts.New(&pluginpb.CodeGeneratorRequest{
FileToGenerate: []string{"empty.proto"},
ProtoFile: []*descriptorpb.FileDescriptorProto{
{
Name: csproto.String("empty.proto"),
Package: csproto.String("crowdstrike.csproto.test"),
Syntax: csproto.String("proto3"),
Options: &descriptorpb.FileOptions{
GoPackage: csproto.String("github.com/CrowdStrike/csproto/testpb"),
},
},
},
})
require.NoError(t, err)

runPlugin := doGenerate(&options{
specialNames: make(specialNames),
})

err = runPlugin(plugin)
require.NoError(t, err)

resp := plugin.Response()
require.Empty(t, resp.File)
}
2 changes: 1 addition & 1 deletion cmd/protoc-gen-fastmarshal/templates/permessage.go.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -189,4 +189,4 @@ func (m *{{.Message.GoIdent.GoName}}) csprotoCheckRequiredFields() error {
return nil
}
{{ end }}
{{end}}
{{end}}
6 changes: 3 additions & 3 deletions cmd/protoc-gen-fastmarshal/templates/singlefile.go.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ import (
"strings"{{end}}
"sync/atomic"
"github.com/CrowdStrike/csproto"
{{range $path, $alias := (allMessages | getAdditionalImports)}}{{ (printf "%s %s" $alias $path) | trimspace}}
{{range $path, $alias := (.Messages | getAdditionalImports)}}{{ (printf "%s %s" $alias $path) | trimspace}}
{{end}}
)

{{ range allMessages }}
{{ range .Messages }}

//------------------------------------------------------------------------------
// Custom Protobuf size/marshal/unmarshal code for {{ .GoIdent.GoName }}
Expand Down Expand Up @@ -192,4 +192,4 @@ func (m *{{.GoIdent.GoName}}) csprotoCheckRequiredFields() error {
}
{{ end }}
{{ end }}
{{end}}
{{end}}