Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions conn_with_mockserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,15 @@ import (
"reflect"
"testing"

"cloud.google.com/go/longrunning/autogen/longrunningpb"
"cloud.google.com/go/spanner"
"cloud.google.com/go/spanner/admin/database/apiv1/databasepb"
"cloud.google.com/go/spanner/apiv1/spannerpb"
"github.com/googleapis/go-sql-spanner/connectionstate"
"github.com/googleapis/go-sql-spanner/testutil"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/emptypb"
)

func TestBeginTx(t *testing.T) {
Expand Down Expand Up @@ -310,6 +314,78 @@ func TestIsolationLevelAutoCommit(t *testing.T) {
}
}

func TestCreateDatabase(t *testing.T) {
t.Parallel()

ctx := context.Background()
db, server, teardown := setupTestDBConnection(t)
defer teardown()

var expectedResponse = &databasepb.Database{}
anyMsg, _ := anypb.New(expectedResponse)
server.TestDatabaseAdmin.SetResps([]proto.Message{
&longrunningpb.Operation{
Done: true,
Result: &longrunningpb.Operation_Response{Response: anyMsg},
Name: "test-operation",
},
})

conn, err := db.Conn(ctx)
if err != nil {
t.Fatal(err)
}
defer silentClose(conn)

if _, err = conn.ExecContext(ctx, "create database `foo`"); err != nil {
t.Fatalf("failed to execute CREATE DATABASE: %v", err)
}

requests := server.TestDatabaseAdmin.Reqs()
if g, w := len(requests), 1; g != w {
t.Fatalf("requests count mismatch\nGot: %v\nWant: %v", g, w)
}
if req, ok := requests[0].(*databasepb.CreateDatabaseRequest); ok {
if g, w := req.Parent, "projects/p/instances/i"; g != w {
t.Fatalf("parent mismatch\n Got: %v\nWant: %v", g, w)
}
} else {
t.Fatalf("request type mismatch, got %v", requests[0])
}
}

func TestDropDatabase(t *testing.T) {
t.Parallel()

ctx := context.Background()
db, server, teardown := setupTestDBConnection(t)
defer teardown()

server.TestDatabaseAdmin.SetResps([]proto.Message{&emptypb.Empty{}})

conn, err := db.Conn(ctx)
if err != nil {
t.Fatal(err)
}
defer silentClose(conn)

if _, err = conn.ExecContext(ctx, "drop database foo"); err != nil {
t.Fatalf("failed to execute DROP DATABASE: %v", err)
}

requests := server.TestDatabaseAdmin.Reqs()
if g, w := len(requests), 1; g != w {
t.Fatalf("requests count mismatch\nGot: %v\nWant: %v", g, w)
}
if req, ok := requests[0].(*databasepb.DropDatabaseRequest); ok {
if g, w := req.Database, "projects/p/instances/i/databases/foo"; g != w {
t.Fatalf("database name mismatch\n Got: %v\nWant: %v", g, w)
}
} else {
t.Fatalf("request type mismatch, got %v", requests[0])
}
}

func TestDDLUsingQueryContext(t *testing.T) {
t.Parallel()

Expand Down
57 changes: 49 additions & 8 deletions statement_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,17 @@ var updateStatements = map[string]bool{"UPDATE": true}
var deleteStatements = map[string]bool{"DELETE": true}
var dmlStatements = union(insertStatements, union(updateStatements, deleteStatements))
var clientSideKeywords = map[string]bool{
"SHOW": true,
"SET": true,
"RESET": true,
"START": true,
"RUN": true,
"ABORT": true,
}
"SHOW": true,
"SET": true,
"RESET": true,
"START": true,
"RUN": true,
"ABORT": true,
"CREATE": true, // CREATE DATABASE is handled as a client-side statement
"DROP": true, // DROP DATABASE is handled as a client-side statement
}
var createStatements = map[string]bool{"CREATE": true}
var dropStatements = map[string]bool{"DROP": true}
var showStatements = map[string]bool{"SHOW": true}
var setStatements = map[string]bool{"SET": true}
var resetStatements = map[string]bool{"RESET": true}
Expand Down Expand Up @@ -201,18 +205,36 @@ func (i *identifier) String() string {
// eatIdentifier reads the identifier at the current parser position, updates the parser position,
// and returns the identifier.
func (p *simpleParser) eatIdentifier() (identifier, error) {
// TODO: Add support for quoted identifiers.
p.skipWhitespacesAndComments()
if p.pos >= len(p.sql) {
return identifier{}, status.Errorf(codes.InvalidArgument, "no identifier found at position %d", p.pos)
}

startPos := p.pos
first := true
result := identifier{parts: make([]string, 0, 1)}
appendLastPart := true
for p.pos < len(p.sql) {
if first {
first = false
// Check if this is a quoted identifier.
if p.sql[p.pos] == p.statementParser.identifierQuoteToken() {
pos, quoteLen, err := p.statementParser.skipQuoted(p.sql, p.pos, p.sql[p.pos])
if err != nil {
return identifier{}, err
}
p.pos = pos
result.parts = append(result.parts, string(p.sql[startPos+quoteLen:pos-quoteLen]))
if p.eatToken('.') {
p.skipWhitespacesAndComments()
startPos = p.pos
first = true
continue
} else {
appendLastPart = false
break
}
}
if !p.isValidFirstIdentifierChar() {
return identifier{}, status.Errorf(codes.InvalidArgument, "invalid first identifier character found at position %d: %s", p.pos, p.sql[p.pos:p.pos+1])
}
Expand Down Expand Up @@ -446,6 +468,13 @@ func (p *statementParser) supportsNestedComments() bool {
return p.dialect == databasepb.DatabaseDialect_POSTGRESQL
}

func (p *statementParser) identifierQuoteToken() byte {
if p.dialect == databasepb.DatabaseDialect_POSTGRESQL {
return '"'
}
return '`'
}

func (p *statementParser) supportsBacktickQuotes() bool {
return p.dialect != databasepb.DatabaseDialect_POSTGRESQL
}
Expand Down Expand Up @@ -653,6 +682,10 @@ func (p *statementParser) skipMultiLineComment(sql []byte, pos int) int {
return pos
}

// skipQuoted skips a quoted string at the given position in the sql string and
// returns the new position, the quote length, or an error if the quoted string
// could not be read.
// The quote length is either 1 for normal quoted strings, and 3 for triple-quoted string.
func (p *statementParser) skipQuoted(sql []byte, pos int, quote byte) (int, int, error) {
isTripleQuoted := p.supportsTripleQuotedLiterals() && len(sql) > pos+2 && sql[pos+1] == quote && sql[pos+2] == quote
if isTripleQuoted && (isMultibyte(sql[pos+1]) || isMultibyte(sql[pos+2])) {
Expand Down Expand Up @@ -951,6 +984,14 @@ func (p *statementParser) isQuery(query string) bool {
return info.statementType == statementTypeQuery
}

func isCreateKeyword(keyword string) bool {
return isStatementKeyword(keyword, createStatements)
}

func isDropKeyword(keyword string) bool {
return isStatementKeyword(keyword, dropStatements)
}

func isQueryKeyword(keyword string) bool {
return isStatementKeyword(keyword, selectStatements)
}
Expand Down
34 changes: 34 additions & 0 deletions statement_parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2494,15 +2494,28 @@ func TestEatIdentifier(t *testing.T) {
input: "my_property",
want: identifier{parts: []string{"my_property"}},
},
{
input: "`my_property`",
want: identifier{parts: []string{"my_property"}},
},
{
input: "my_extension.my_property",
want: identifier{parts: []string{"my_extension", "my_property"}},
},
{
input: "`my_extension`.`my_property`",
want: identifier{parts: []string{"my_extension", "my_property"}},
},
{
// spaces are allowed
input: " \n my_extension . \t my_property ",
want: identifier{parts: []string{"my_extension", "my_property"}},
},
{
// spaces are allowed
input: " \n `my_extension` . \t `my_property` ",
want: identifier{parts: []string{"my_extension", "my_property"}},
},
{
// comments are treated the same as spaces and are allowed
input: " /* comment */ \n my_extension -- yet another comment\n. \t -- Also a comment \nmy_property ",
Expand All @@ -2512,6 +2525,14 @@ func TestEatIdentifier(t *testing.T) {
input: "p1.p2.p3.p4",
want: identifier{parts: []string{"p1", "p2", "p3", "p4"}},
},
{
input: "`p1`.`p2`.`p3`.`p4`",
want: identifier{parts: []string{"p1", "p2", "p3", "p4"}},
},
{
input: "`p1`.p2.`p3`.p4",
want: identifier{parts: []string{"p1", "p2", "p3", "p4"}},
},
{
input: "a.b.c",
want: identifier{parts: []string{"a", "b", "c"}},
Expand All @@ -2520,6 +2541,11 @@ func TestEatIdentifier(t *testing.T) {
input: "1a",
wantErr: true,
},
{
// Double-quotes are not valid around identifiers in GoogleSQL.
input: `"1a""`,
wantErr: true,
},
{
input: "my_extension.",
wantErr: true,
Expand All @@ -2537,6 +2563,14 @@ func TestEatIdentifier(t *testing.T) {
input: "a . 1a",
wantErr: true,
},
{
input: "`p1 /* looks like a comment */ `.`p2`",
want: identifier{parts: []string{"p1 /* looks like a comment */ ", "p2"}},
},
{
input: "```p1 -- looks like a comment\n ```.`p2`",
want: identifier{parts: []string{"p1 -- looks like a comment\n ", "p2"}},
},
}
for _, test := range tests {
sp := &simpleParser{sql: []byte(test.input), statementParser: parser}
Expand Down
Loading
Loading