Skip to content

Commit 2226a18

Browse files
authored
Append @procname to start of procedures
2 parents cec1702 + 924f432 commit 2226a18

3 files changed

Lines changed: 67 additions & 7 deletions

File tree

deployable_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
package sqlcode
22

33
import (
4-
"github.com/stretchr/testify/assert"
5-
"github.com/stretchr/testify/require"
64
"testing"
75
"testing/fstest"
6+
7+
"github.com/stretchr/testify/assert"
8+
"github.com/stretchr/testify/require"
89
)
910

1011
func TestDeployable(t *testing.T) {

sqlparser/parser.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ import (
1414
"strings"
1515
)
1616

17+
var templateRoutineName string = "\ndeclare @RoutineName nvarchar(128)\nset @RoutineName = '%s'\n"
18+
1719
func CopyToken(s *Scanner, target *[]Unparsed) {
1820
*target = append(*target, CreateUnparsed(s))
1921
}
@@ -415,6 +417,8 @@ func (d *Document) parseCreate(s *Scanner, createCountInBatch int) (result Creat
415417
// point we copy the rest until the batch ends; *but* track dependencies
416418
// + some other details mentioned below
417419

420+
firstAs := true
421+
418422
tailloop:
419423
for {
420424
tt := s.TokenType()
@@ -466,6 +470,22 @@ tailloop:
466470
if !found {
467471
result.DependsOn = append(result.DependsOn, dep)
468472
}
473+
case tt == ReservedWordToken && s.Token() == "as":
474+
CopyToken(s, &result.Body)
475+
NextTokenCopyingWhitespace(s, &result.Body)
476+
if firstAs {
477+
// Add the `RoutineName` token as a convenience, so that we can refer to the procedure/function name
478+
// from inside the procedure (for example, when logging)
479+
if result.CreateType == "procedure" {
480+
procNameToken := Unparsed{
481+
Type: OtherToken,
482+
RawValue: fmt.Sprintf(templateRoutineName, strings.Trim(result.QuotedName.Value, "[]")),
483+
}
484+
result.Body = append(result.Body, procNameToken)
485+
}
486+
firstAs = false
487+
}
488+
469489
default:
470490
CopyToken(s, &result.Body)
471491
NextTokenCopyingWhitespace(s, &result.Body)

sqlparser/parser_test.go

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
package sqlparser
22

33
import (
4-
"github.com/stretchr/testify/assert"
5-
"github.com/stretchr/testify/require"
4+
"fmt"
65
"strings"
76
"testing"
7+
8+
"github.com/stretchr/testify/assert"
9+
"github.com/stretchr/testify/require"
810
)
911

1012
func TestParserSmokeTest(t *testing.T) {
@@ -44,16 +46,16 @@ end;
4446

4547
assert.Equal(t, "[TestFunc]", c.QuotedName.Value)
4648
assert.Equal(t, []string{"[HelloFunc]", "[OtherFunc]"}, c.DependsOnStrings())
47-
assert.Equal(t, `-- preceding comment 1
49+
assert.Equal(t, fmt.Sprintf(`-- preceding comment 1
4850
/* preceding comment 2
4951
50-
asdfasdf */create procedure [code].TestFunc as begin
52+
asdfasdf */create procedure [code].TestFunc as %sbegin
5153
refers to [code].OtherFunc [code].HelloFunc;
5254
create table x ( int x not null ); -- should be ok
5355
end;
5456
5557
/* trailing comment */
56-
`, c.String())
58+
`, fmt.Sprintf(templateRoutineName, "TestFunc")), c.String())
5759

5860
assert.Equal(t,
5961
[]Error{
@@ -271,6 +273,43 @@ create procedure [code].FirstProc as table (x int)
271273
assert.Equal(t, emsg, doc.Errors[0].Message)
272274
}
273275

276+
func TestCreateProcsAndCheckForRoutineName(t *testing.T) {
277+
testcases := []struct {
278+
name string
279+
doc Document
280+
expectedProcName string
281+
expectedIndex int
282+
}{
283+
{
284+
name: "Test simple proc",
285+
expectedProcName: "FirstProc",
286+
doc: ParseString("test.sql", `
287+
create procedure [code].FirstProc as
288+
begin
289+
end
290+
`),
291+
expectedIndex: 10,
292+
},
293+
{
294+
name: "Test proc with args",
295+
expectedProcName: "transform:safeguarding.Calculation/HEAD",
296+
doc: ParseString("test.sql", `
297+
create procedure [code].[transform:safeguarding.Calculation/HEAD](@now datetime2,
298+
@count bigint output) as
299+
`),
300+
expectedIndex: 22,
301+
},
302+
}
303+
for _, tc := range testcases {
304+
require.Equal(t, 0, len(tc.doc.Errors))
305+
assert.Len(t, tc.doc.Creates, 1)
306+
assert.Greater(t, len(tc.doc.Creates[0].Body), tc.expectedIndex)
307+
assert.Equal(t,
308+
fmt.Sprintf(templateRoutineName, tc.expectedProcName),
309+
tc.doc.Creates[0].Body[tc.expectedIndex].RawValue,
310+
)
311+
}
312+
}
274313

275314
func TestGoWithoutNewline(t *testing.T) {
276315
doc := ParseString("test.sql", `

0 commit comments

Comments
 (0)