diff --git a/.travis.yml b/.travis.yml index 31f9c88..211edd7 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,6 +7,7 @@ go: - 1.5 - 1.6 - 1.7 + - 1.8 - tip script: go test -race -coverprofile=coverage.txt -covermode=atomic diff --git a/LICENSE b/LICENSE index 25255bb..7f8bedf 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ The three clause BSD license (http://en.wikipedia.org/wiki/BSD_licenses) -Copyright (c) 2013-2016, DATA-DOG team +Copyright (c) 2013-2017, DATA-DOG team All rights reserved. Redistribution and use in source and binary forms, with or without diff --git a/README.md b/README.md index 80b5bc0..f3d5487 100644 --- a/README.md +++ b/README.md @@ -10,19 +10,16 @@ maintain correct **TDD** workflow. - this library is now complete and stable. (you may not find new changes for this reason) - supports concurrency and multiple connections. +- supports **go1.8** Context related feature mocking and Named sql parameters. - does not require any modifications to your source code. - the driver allows to mock any sql driver method behavior. - has strict by default expectation order matching. -- has no vendor dependencies. +- has no third party dependencies. ## Install go get gopkg.in/DATA-DOG/go-sqlmock.v1 -If you need an old version, checkout **go-sqlmock** at gopkg.in: - - go get gopkg.in/DATA-DOG/go-sqlmock.v0 - ## Documentation and Examples Visit [godoc](http://godoc.org/github.com/DATA-DOG/go-sqlmock) for general examples and public api reference. @@ -187,8 +184,11 @@ It only asserts that argument is of `time.Time` type. go test -race -## Changes +## Change Log +- **2017-02-09** - implemented support for **go1.8** features. **Rows** interface was changed to struct + but contains all methods as before and should maintain backwards compatibility. **ExpectedQuery.WillReturnRows** may now + accept multiple row sets. - **2016-11-02** - `db.Prepare()` was not validating expected prepare SQL query. It should still be validated even if Exec or Query is not executed on that prepared statement. diff --git a/expectations.go b/expectations.go index 5b6865e..415759e 100644 --- a/expectations.go +++ b/expectations.go @@ -3,10 +3,10 @@ package sqlmock import ( "database/sql/driver" "fmt" - "reflect" "regexp" "strings" "sync" + "time" ) // an expectation interface @@ -54,6 +54,7 @@ func (e *ExpectedClose) String() string { // returned by *Sqlmock.ExpectBegin. type ExpectedBegin struct { commonExpectation + delay time.Duration } // WillReturnError allows to set an error for *sql.DB.Begin action @@ -71,6 +72,13 @@ func (e *ExpectedBegin) String() string { return msg } +// WillDelayFor allows to specify duration for which it will delay +// result. May be used together with Context +func (e *ExpectedBegin) WillDelayFor(duration time.Duration) *ExpectedBegin { + e.delay = duration + return e +} + // ExpectedCommit is used to manage *sql.Tx.Commit expectation // returned by *Sqlmock.ExpectCommit. type ExpectedCommit struct { @@ -118,7 +126,8 @@ func (e *ExpectedRollback) String() string { // Returned by *Sqlmock.ExpectQuery. type ExpectedQuery struct { queryBasedExpectation - rows driver.Rows + rows driver.Rows + delay time.Duration } // WithArgs will match given expected args to actual database query arguments. @@ -135,10 +144,10 @@ func (e *ExpectedQuery) WillReturnError(err error) *ExpectedQuery { return e } -// WillReturnRows specifies the set of resulting rows that will be returned -// by the triggered query -func (e *ExpectedQuery) WillReturnRows(rows driver.Rows) *ExpectedQuery { - e.rows = rows +// WillDelayFor allows to specify duration for which it will delay +// result. May be used together with Context +func (e *ExpectedQuery) WillDelayFor(duration time.Duration) *ExpectedQuery { + e.delay = duration return e } @@ -158,12 +167,7 @@ func (e *ExpectedQuery) String() string { } if e.rows != nil { - msg += "\n - should return rows:\n" - rs, _ := e.rows.(*rows) - for i, row := range rs.rows { - msg += fmt.Sprintf(" %d - %+v\n", i, row) - } - msg = strings.TrimSpace(msg) + msg += fmt.Sprintf("\n - %s", e.rows) } if e.err != nil { @@ -178,6 +182,7 @@ func (e *ExpectedQuery) String() string { type ExpectedExec struct { queryBasedExpectation result driver.Result + delay time.Duration } // WithArgs will match given expected args to actual database exec operation arguments. @@ -194,6 +199,13 @@ func (e *ExpectedExec) WillReturnError(err error) *ExpectedExec { return e } +// WillDelayFor allows to specify duration for which it will delay +// result. May be used together with Context +func (e *ExpectedExec) WillDelayFor(duration time.Duration) *ExpectedExec { + e.delay = duration + return e +} + // String returns string representation func (e *ExpectedExec) String() string { msg := "ExpectedExec => expecting Exec which:" @@ -244,6 +256,7 @@ type ExpectedPrepare struct { sqlRegex *regexp.Regexp statement driver.Stmt closeErr error + delay time.Duration } // WillReturnError allows to set an error for the expected *sql.DB.Prepare or *sql.Tx.Prepare action. @@ -258,6 +271,13 @@ func (e *ExpectedPrepare) WillReturnCloseError(err error) *ExpectedPrepare { return e } +// WillDelayFor allows to specify duration for which it will delay +// result. May be used together with Context +func (e *ExpectedPrepare) WillDelayFor(duration time.Duration) *ExpectedPrepare { + e.delay = duration + return e +} + // ExpectQuery allows to expect Query() or QueryRow() on this prepared statement. // this method is convenient in order to prevent duplicating sql query string matching. func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery { @@ -300,7 +320,7 @@ type queryBasedExpectation struct { args []driver.Value } -func (e *queryBasedExpectation) attemptMatch(sql string, args []driver.Value) (err error) { +func (e *queryBasedExpectation) attemptMatch(sql string, args []namedValue) (err error) { if !e.queryMatches(sql) { return fmt.Errorf(`could not match sql: "%s" with expected regexp "%s"`, sql, e.sqlRegex.String()) } @@ -322,37 +342,3 @@ func (e *queryBasedExpectation) attemptMatch(sql string, args []driver.Value) (e func (e *queryBasedExpectation) queryMatches(sql string) bool { return e.sqlRegex.MatchString(sql) } - -func (e *queryBasedExpectation) argsMatches(args []driver.Value) error { - if nil == e.args { - return nil - } - if len(args) != len(e.args) { - return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args)) - } - for k, v := range args { - // custom argument matcher - matcher, ok := e.args[k].(Argument) - if ok { - if !matcher.Match(v) { - return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k]) - } - continue - } - - // convert to driver converter - darg, err := driver.DefaultParameterConverter.ConvertValue(e.args[k]) - if err != nil { - return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err) - } - - if !driver.IsValue(darg) { - return fmt.Errorf("argument %d: non-subset type %T returned from Value", k, darg) - } - - if !reflect.DeepEqual(darg, args[k]) { - return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, args[k], args[k]) - } - } - return nil -} diff --git a/expectations_before_go18.go b/expectations_before_go18.go new file mode 100644 index 0000000..146f240 --- /dev/null +++ b/expectations_before_go18.go @@ -0,0 +1,52 @@ +// +build !go1.8 + +package sqlmock + +import ( + "database/sql/driver" + "fmt" + "reflect" +) + +// WillReturnRows specifies the set of resulting rows that will be returned +// by the triggered query +func (e *ExpectedQuery) WillReturnRows(rows *Rows) *ExpectedQuery { + e.rows = &rowSets{sets: []*Rows{rows}} + return e +} + +func (e *queryBasedExpectation) argsMatches(args []namedValue) error { + if nil == e.args { + return nil + } + if len(args) != len(e.args) { + return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args)) + } + for k, v := range args { + // custom argument matcher + matcher, ok := e.args[k].(Argument) + if ok { + // @TODO: does it make sense to pass value instead of named value? + if !matcher.Match(v.Value) { + return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k]) + } + continue + } + + dval := e.args[k] + // convert to driver converter + darg, err := driver.DefaultParameterConverter.ConvertValue(dval) + if err != nil { + return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err) + } + + if !driver.IsValue(darg) { + return fmt.Errorf("argument %d: non-subset type %T returned from Value", k, darg) + } + + if !reflect.DeepEqual(darg, v.Value) { + return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v.Value, v.Value) + } + } + return nil +} diff --git a/expectations_go18.go b/expectations_go18.go new file mode 100644 index 0000000..2b4b44e --- /dev/null +++ b/expectations_go18.go @@ -0,0 +1,66 @@ +// +build go1.8 + +package sqlmock + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "reflect" +) + +// WillReturnRows specifies the set of resulting rows that will be returned +// by the triggered query +func (e *ExpectedQuery) WillReturnRows(rows ...*Rows) *ExpectedQuery { + sets := make([]*Rows, len(rows)) + for i, r := range rows { + sets[i] = r + } + e.rows = &rowSets{sets: sets} + return e +} + +func (e *queryBasedExpectation) argsMatches(args []namedValue) error { + if nil == e.args { + return nil + } + if len(args) != len(e.args) { + return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args)) + } + // @TODO should we assert either all args are named or ordinal? + for k, v := range args { + // custom argument matcher + matcher, ok := e.args[k].(Argument) + if ok { + if !matcher.Match(v.Value) { + return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k]) + } + continue + } + + dval := e.args[k] + if named, isNamed := dval.(sql.NamedArg); isNamed { + dval = named.Value + if v.Name != named.Name { + return fmt.Errorf("named argument %d: name: \"%s\" does not match expected: \"%s\"", k, v.Name, named.Name) + } + } else if k+1 != v.Ordinal { + return fmt.Errorf("argument %d: ordinal position: %d does not match expected: %d", k, k+1, v.Ordinal) + } + + // convert to driver converter + darg, err := driver.DefaultParameterConverter.ConvertValue(dval) + if err != nil { + return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err) + } + + if !driver.IsValue(darg) { + return fmt.Errorf("argument %d: non-subset type %T returned from Value", k, darg) + } + + if !reflect.DeepEqual(darg, v.Value) { + return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v.Value, v.Value) + } + } + return nil +} diff --git a/expectations_go18_test.go b/expectations_go18_test.go new file mode 100644 index 0000000..5f30d2f --- /dev/null +++ b/expectations_go18_test.go @@ -0,0 +1,64 @@ +// +build go1.8 + +package sqlmock + +import ( + "database/sql" + "database/sql/driver" + "testing" +) + +func TestQueryExpectationNamedArgComparison(t *testing.T) { + e := &queryBasedExpectation{} + against := []namedValue{{Value: int64(5), Name: "id"}} + if err := e.argsMatches(against); err != nil { + t.Errorf("arguments should match, since the no expectation was set, but got err: %s", err) + } + + e.args = []driver.Value{ + sql.Named("id", 5), + sql.Named("s", "str"), + } + + if err := e.argsMatches(against); err == nil { + t.Error("arguments should not match, since the size is not the same") + } + + against = []namedValue{ + {Value: int64(5), Name: "id"}, + {Value: "str", Name: "s"}, + } + + if err := e.argsMatches(against); err != nil { + t.Errorf("arguments should have matched, but it did not: %v", err) + } + + against = []namedValue{ + {Value: int64(5), Name: "id"}, + {Value: "str", Name: "username"}, + } + + if err := e.argsMatches(against); err == nil { + t.Error("arguments matched, but it should have not due to Name") + } + + e.args = []driver.Value{int64(5), "str"} + + against = []namedValue{ + {Value: int64(5), Ordinal: 0}, + {Value: "str", Ordinal: 1}, + } + + if err := e.argsMatches(against); err == nil { + t.Error("arguments matched, but it should have not due to wrong Ordinal position") + } + + against = []namedValue{ + {Value: int64(5), Ordinal: 1}, + {Value: "str", Ordinal: 2}, + } + + if err := e.argsMatches(against); err != nil { + t.Errorf("arguments should have matched, but it did not: %v", err) + } +} diff --git a/expectations_test.go b/expectations_test.go index 032f029..2e3c097 100644 --- a/expectations_test.go +++ b/expectations_test.go @@ -10,29 +10,38 @@ import ( func TestQueryExpectationArgComparison(t *testing.T) { e := &queryBasedExpectation{} - against := []driver.Value{int64(5)} + against := []namedValue{{Value: int64(5), Ordinal: 1}} if err := e.argsMatches(against); err != nil { t.Errorf("arguments should match, since the no expectation was set, but got err: %s", err) } e.args = []driver.Value{5, "str"} - against = []driver.Value{int64(5)} + against = []namedValue{{Value: int64(5), Ordinal: 1}} if err := e.argsMatches(against); err == nil { t.Error("arguments should not match, since the size is not the same") } - against = []driver.Value{int64(3), "str"} + against = []namedValue{ + {Value: int64(3), Ordinal: 1}, + {Value: "str", Ordinal: 2}, + } if err := e.argsMatches(against); err == nil { t.Error("arguments should not match, since the first argument (int value) is different") } - against = []driver.Value{int64(5), "st"} + against = []namedValue{ + {Value: int64(5), Ordinal: 1}, + {Value: "st", Ordinal: 2}, + } if err := e.argsMatches(against); err == nil { t.Error("arguments should not match, since the second argument (string value) is different") } - against = []driver.Value{int64(5), "str"} + against = []namedValue{ + {Value: int64(5), Ordinal: 1}, + {Value: "str", Ordinal: 2}, + } if err := e.argsMatches(against); err != nil { t.Errorf("arguments should match, but it did not: %s", err) } @@ -41,7 +50,10 @@ func TestQueryExpectationArgComparison(t *testing.T) { tm, _ := time.Parse(longForm, "Feb 3, 2013 at 7:54pm (PST)") e.args = []driver.Value{5, tm} - against = []driver.Value{int64(5), tm} + against = []namedValue{ + {Value: int64(5), Ordinal: 1}, + {Value: tm, Ordinal: 2}, + } if err := e.argsMatches(against); err != nil { t.Error("arguments should match, but it did not") } @@ -56,25 +68,33 @@ func TestQueryExpectationArgComparisonBool(t *testing.T) { var e *queryBasedExpectation e = &queryBasedExpectation{args: []driver.Value{true}} - against := []driver.Value{true} + against := []namedValue{ + {Value: true, Ordinal: 1}, + } if err := e.argsMatches(against); err != nil { t.Error("arguments should match, since arguments are the same") } e = &queryBasedExpectation{args: []driver.Value{false}} - against = []driver.Value{false} + against = []namedValue{ + {Value: false, Ordinal: 1}, + } if err := e.argsMatches(against); err != nil { t.Error("arguments should match, since argument are the same") } e = &queryBasedExpectation{args: []driver.Value{true}} - against = []driver.Value{false} + against = []namedValue{ + {Value: false, Ordinal: 1}, + } if err := e.argsMatches(against); err == nil { t.Error("arguments should not match, since argument is different") } e = &queryBasedExpectation{args: []driver.Value{false}} - against = []driver.Value{true} + against = []namedValue{ + {Value: true, Ordinal: 1}, + } if err := e.argsMatches(against); err == nil { t.Error("arguments should not match, since argument is different") } @@ -117,7 +137,7 @@ func TestBuildQuery(t *testing.T) { name = 'John' and address = 'Jakarta' - + ` mock.ExpectQuery(query) diff --git a/rows.go b/rows.go index 8b6beb6..39f9f83 100644 --- a/rows.go +++ b/rows.go @@ -3,6 +3,7 @@ package sqlmock import ( "database/sql/driver" "encoding/csv" + "fmt" "io" "strings" ) @@ -18,57 +19,22 @@ var CSVColumnParser = func(s string) []byte { return []byte(s) } -// Rows interface allows to construct rows -// which also satisfies database/sql/driver.Rows interface -type Rows interface { - // composed interface, supports sql driver.Rows - driver.Rows - - // AddRow composed from database driver.Value slice - // return the same instance to perform subsequent actions. - // Note that the number of values must match the number - // of columns - AddRow(columns ...driver.Value) Rows - - // FromCSVString build rows from csv string. - // return the same instance to perform subsequent actions. - // Note that the number of values must match the number - // of columns - FromCSVString(s string) Rows - - // RowError allows to set an error - // which will be returned when a given - // row number is read - RowError(row int, err error) Rows - - // CloseError allows to set an error - // which will be returned by rows.Close - // function. - // - // The close error will be triggered only in cases - // when rows.Next() EOF was not yet reached, that is - // a default sql library behavior - CloseError(err error) Rows +type rowSets struct { + sets []*Rows + pos int } -type rows struct { - cols []string - rows [][]driver.Value - pos int - nextErr map[int]error - closeErr error +func (rs *rowSets) Columns() []string { + return rs.sets[rs.pos].cols } -func (r *rows) Columns() []string { - return r.cols -} - -func (r *rows) Close() error { - return r.closeErr +func (rs *rowSets) Close() error { + return rs.sets[rs.pos].closeErr } // advances to next row -func (r *rows) Next(dest []driver.Value) error { +func (rs *rowSets) Next(dest []driver.Value) error { + r := rs.sets[rs.pos] r.pos++ if r.pos > len(r.rows) { return io.EOF // per interface spec @@ -81,24 +47,66 @@ func (r *rows) Next(dest []driver.Value) error { return r.nextErr[r.pos-1] } +// transforms to debuggable printable string +func (rs *rowSets) String() string { + msg := "should return rows:\n" + if len(rs.sets) == 1 { + for n, row := range rs.sets[0].rows { + msg += fmt.Sprintf(" row %d - %+v\n", n, row) + } + return strings.TrimSpace(msg) + } + for i, set := range rs.sets { + msg += fmt.Sprintf(" result set: %d\n", i) + for n, row := range set.rows { + msg += fmt.Sprintf(" row %d - %+v\n", n, row) + } + } + return strings.TrimSpace(msg) +} + +// Rows is a mocked collection of rows to +// return for Query result +type Rows struct { + cols []string + rows [][]driver.Value + pos int + nextErr map[int]error + closeErr error +} + // NewRows allows Rows to be created from a // sql driver.Value slice or from the CSV string and // to be used as sql driver.Rows -func NewRows(columns []string) Rows { - return &rows{cols: columns, nextErr: make(map[int]error)} +func NewRows(columns []string) *Rows { + return &Rows{cols: columns, nextErr: make(map[int]error)} } -func (r *rows) CloseError(err error) Rows { +// CloseError allows to set an error +// which will be returned by rows.Close +// function. +// +// The close error will be triggered only in cases +// when rows.Next() EOF was not yet reached, that is +// a default sql library behavior +func (r *Rows) CloseError(err error) *Rows { r.closeErr = err return r } -func (r *rows) RowError(row int, err error) Rows { +// RowError allows to set an error +// which will be returned when a given +// row number is read +func (r *Rows) RowError(row int, err error) *Rows { r.nextErr[row] = err return r } -func (r *rows) AddRow(values ...driver.Value) Rows { +// AddRow composed from database driver.Value slice +// return the same instance to perform subsequent actions. +// Note that the number of values must match the number +// of columns +func (r *Rows) AddRow(values ...driver.Value) *Rows { if len(values) != len(r.cols) { panic("Expected number of values to match number of columns") } @@ -112,7 +120,11 @@ func (r *rows) AddRow(values ...driver.Value) Rows { return r } -func (r *rows) FromCSVString(s string) Rows { +// FromCSVString build rows from csv string. +// return the same instance to perform subsequent actions. +// Note that the number of values must match the number +// of columns +func (r *Rows) FromCSVString(s string) *Rows { res := strings.NewReader(strings.TrimSpace(s)) csvReader := csv.NewReader(res) diff --git a/rows_go18.go b/rows_go18.go new file mode 100644 index 0000000..4ecf84e --- /dev/null +++ b/rows_go18.go @@ -0,0 +1,20 @@ +// +build go1.8 + +package sqlmock + +import "io" + +// Implement the "RowsNextResultSet" interface +func (rs *rowSets) HasNextResultSet() bool { + return rs.pos+1 < len(rs.sets) +} + +// Implement the "RowsNextResultSet" interface +func (rs *rowSets) NextResultSet() error { + if !rs.HasNextResultSet() { + return io.EOF + } + + rs.pos++ + return nil +} diff --git a/rows_go18_test.go b/rows_go18_test.go new file mode 100644 index 0000000..297e7c0 --- /dev/null +++ b/rows_go18_test.go @@ -0,0 +1,92 @@ +// +build go1.8 + +package sqlmock + +import ( + "fmt" + "testing" +) + +func TestQueryMultiRows(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + rs1 := NewRows([]string{"id", "title"}).AddRow(5, "hello world") + rs2 := NewRows([]string{"name"}).AddRow("gopher").AddRow("john").AddRow("jane").RowError(2, fmt.Errorf("error")) + + mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = \\?;SELECT name FROM users"). + WithArgs(5). + WillReturnRows(rs1, rs2) + + rows, err := db.Query("SELECT id, title FROM articles WHERE id = ?;SELECT name FROM users", 5) + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + defer rows.Close() + + if !rows.Next() { + t.Error("expected a row to be available in first result set") + } + + var id int + var name string + + err = rows.Scan(&id, &name) + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + if id != 5 || name != "hello world" { + t.Errorf("unexpected row values id: %v name: %v", id, name) + } + + if rows.Next() { + t.Error("was not expecting next row in first result set") + } + + if !rows.NextResultSet() { + t.Error("had to have next result set") + } + + if !rows.Next() { + t.Error("expected a row to be available in second result set") + } + + err = rows.Scan(&name) + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + if name != "gopher" { + t.Errorf("unexpected row name: %v", name) + } + + if !rows.Next() { + t.Error("expected a row to be available in second result set") + } + + err = rows.Scan(&name) + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + if name != "john" { + t.Errorf("unexpected row name: %v", name) + } + + if rows.Next() { + t.Error("expected next row to produce error") + } + + if rows.Err() == nil { + t.Error("expected an error, but there was none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) + } +} diff --git a/sqlmock.go b/sqlmock.go index 7ac8076..b906a3f 100644 --- a/sqlmock.go +++ b/sqlmock.go @@ -15,6 +15,7 @@ import ( "database/sql/driver" "fmt" "regexp" + "time" ) // Sqlmock interface serves to create expectations @@ -66,6 +67,11 @@ type Sqlmock interface { // By default it is set to - true. But if you use goroutines // to parallelize your query executation, that option may // be handy. + // + // This option may be turned on anytime during tests. As soon + // as it is switched to false, expectations will be matched + // in any order. Or otherwise if switched to true, any unmatched + // expectations will be expected in order MatchExpectationsInOrder(bool) } @@ -154,6 +160,16 @@ func (c *sqlmock) ExpectationsWereMet() error { // Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface func (c *sqlmock) Begin() (driver.Tx, error) { + ex, err := c.begin() + if err != nil { + return nil, err + } + + time.Sleep(ex.delay) + return c, nil +} + +func (c *sqlmock) begin() (*ExpectedBegin, error) { var expected *ExpectedBegin var ok bool var fulfilled int @@ -184,7 +200,8 @@ func (c *sqlmock) Begin() (driver.Tx, error) { expected.triggered = true expected.Unlock() - return c, expected.err + + return expected, expected.err } func (c *sqlmock) ExpectBegin() *ExpectedBegin { @@ -194,7 +211,25 @@ func (c *sqlmock) ExpectBegin() *ExpectedBegin { } // Exec meets http://golang.org/pkg/database/sql/driver/#Execer -func (c *sqlmock) Exec(query string, args []driver.Value) (res driver.Result, err error) { +func (c *sqlmock) Exec(query string, args []driver.Value) (driver.Result, error) { + namedArgs := make([]namedValue, len(args)) + for i, v := range args { + namedArgs[i] = namedValue{ + Ordinal: i + 1, + Value: v, + } + } + + ex, err := c.exec(query, namedArgs) + if err != nil { + return nil, err + } + + time.Sleep(ex.delay) + return ex.result, nil +} + +func (c *sqlmock) exec(query string, args []namedValue) (*ExpectedExec, error) { query = stripQuery(query) var expected *ExpectedExec var fulfilled int @@ -229,7 +264,6 @@ func (c *sqlmock) Exec(query string, args []driver.Value) (res driver.Result, er } return nil, fmt.Errorf(msg, query, args) } - defer expected.Unlock() if !expected.queryMatches(query) { @@ -241,7 +275,6 @@ func (c *sqlmock) Exec(query string, args []driver.Value) (res driver.Result, er } expected.triggered = true - if expected.err != nil { return nil, expected.err // mocked to return error } @@ -250,7 +283,7 @@ func (c *sqlmock) Exec(query string, args []driver.Value) (res driver.Result, er return nil, fmt.Errorf("exec query '%s' with args %+v, must return a database/sql/driver.result, but it was not set for expectation %T as %+v", query, args, expected, expected) } - return expected.result, err + return expected, nil } func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec { @@ -263,6 +296,16 @@ func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec { // Prepare meets http://golang.org/pkg/database/sql/driver/#Conn interface func (c *sqlmock) Prepare(query string) (driver.Stmt, error) { + ex, err := c.prepare(query) + if err != nil { + return nil, err + } + + time.Sleep(ex.delay) + return &statement{c, query, ex.closeErr}, nil +} + +func (c *sqlmock) prepare(query string) (*ExpectedPrepare, error) { var expected *ExpectedPrepare var fulfilled int var ok bool @@ -298,7 +341,7 @@ func (c *sqlmock) Prepare(query string) (driver.Stmt, error) { } expected.triggered = true - return &statement{c, query, expected.closeErr}, expected.err + return expected, expected.err } func (c *sqlmock) ExpectPrepare(sqlRegexStr string) *ExpectedPrepare { @@ -308,8 +351,32 @@ func (c *sqlmock) ExpectPrepare(sqlRegexStr string) *ExpectedPrepare { return e } +type namedValue struct { + Name string + Ordinal int + Value driver.Value +} + // Query meets http://golang.org/pkg/database/sql/driver/#Queryer -func (c *sqlmock) Query(query string, args []driver.Value) (rw driver.Rows, err error) { +func (c *sqlmock) Query(query string, args []driver.Value) (driver.Rows, error) { + namedArgs := make([]namedValue, len(args)) + for i, v := range args { + namedArgs[i] = namedValue{ + Ordinal: i + 1, + Value: v, + } + } + + ex, err := c.query(query, namedArgs) + if err != nil { + return nil, err + } + + time.Sleep(ex.delay) + return ex.rows, nil +} + +func (c *sqlmock) query(query string, args []namedValue) (*ExpectedQuery, error) { query = stripQuery(query) var expected *ExpectedQuery var fulfilled int @@ -357,7 +424,6 @@ func (c *sqlmock) Query(query string, args []driver.Value) (rw driver.Rows, err } expected.triggered = true - if expected.err != nil { return nil, expected.err // mocked to return error } @@ -365,8 +431,7 @@ func (c *sqlmock) Query(query string, args []driver.Value) (rw driver.Rows, err if expected.rows == nil { return nil, fmt.Errorf("query '%s' with args %+v, must return a database/sql/driver.rows, but it was not set for expectation %T as %+v", query, args, expected, expected) } - - return expected.rows, err + return expected, nil } func (c *sqlmock) ExpectQuery(sqlRegexStr string) *ExpectedQuery { diff --git a/sqlmock_go18.go b/sqlmock_go18.go new file mode 100644 index 0000000..c49429c --- /dev/null +++ b/sqlmock_go18.go @@ -0,0 +1,101 @@ +// +build go1.8 + +package sqlmock + +import ( + "context" + "database/sql/driver" + "errors" + "time" +) + +var ErrCancelled = errors.New("canceling query due to user request") + +// Implement the "QueryerContext" interface +func (c *sqlmock) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + namedArgs := make([]namedValue, len(args)) + for i, nv := range args { + namedArgs[i] = namedValue(nv) + } + + ex, err := c.query(query, namedArgs) + if err != nil { + return nil, err + } + + select { + case <-time.After(ex.delay): + return ex.rows, nil + case <-ctx.Done(): + return nil, ErrCancelled + } +} + +// Implement the "ExecerContext" interface +func (c *sqlmock) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + namedArgs := make([]namedValue, len(args)) + for i, nv := range args { + namedArgs[i] = namedValue(nv) + } + + ex, err := c.exec(query, namedArgs) + if err != nil { + return nil, err + } + + select { + case <-time.After(ex.delay): + return ex.result, nil + case <-ctx.Done(): + return nil, ErrCancelled + } +} + +// Implement the "ConnBeginTx" interface +func (c *sqlmock) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + ex, err := c.begin() + if err != nil { + return nil, err + } + + select { + case <-time.After(ex.delay): + return c, nil + case <-ctx.Done(): + return nil, ErrCancelled + } +} + +// Implement the "ConnPrepareContext" interface +func (c *sqlmock) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + ex, err := c.prepare(query) + if err != nil { + return nil, err + } + + select { + case <-time.After(ex.delay): + return &statement{c, query, ex.closeErr}, nil + case <-ctx.Done(): + return nil, ErrCancelled + } +} + +// Implement the "Pinger" interface +// for now we do not have a Ping expectation +// may be something for the future +func (c *sqlmock) Ping(ctx context.Context) error { + return nil +} + +// Implement the "StmtExecContext" interface +func (stmt *statement) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + return stmt.conn.ExecContext(ctx, stmt.query, args) +} + +// Implement the "StmtQueryContext" interface +func (stmt *statement) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + return stmt.conn.QueryContext(ctx, stmt.query, args) +} + +// @TODO maybe add ExpectedBegin.WithOptions(driver.TxOptions) diff --git a/sqlmock_go18_test.go b/sqlmock_go18_test.go new file mode 100644 index 0000000..9eadcb5 --- /dev/null +++ b/sqlmock_go18_test.go @@ -0,0 +1,426 @@ +// +build go1.8 + +package sqlmock + +import ( + "context" + "database/sql" + "testing" + "time" +) + +func TestContextExecCancel(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectExec("DELETE FROM users"). + WillDelayFor(time.Second). + WillReturnResult(NewResult(1, 1)) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + _, err = db.ExecContext(ctx, "DELETE FROM users") + if err == nil { + t.Error("error was expected, but there was none") + } + + if err != ErrCancelled { + t.Errorf("was expecting cancel error, but got: %v", err) + } + + _, err = db.ExecContext(ctx, "DELETE FROM users") + if err != context.Canceled { + t.Error("error was expected since context was already done, but there was none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) + } +} + +func TestPreparedStatementContextExecCancel(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectPrepare("DELETE FROM users"). + ExpectExec(). + WillDelayFor(time.Second). + WillReturnResult(NewResult(1, 1)) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + stmt, err := db.Prepare("DELETE FROM users") + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + _, err = stmt.ExecContext(ctx) + if err == nil { + t.Error("error was expected, but there was none") + } + + if err != ErrCancelled { + t.Errorf("was expecting cancel error, but got: %v", err) + } + + _, err = stmt.ExecContext(ctx) + if err != context.Canceled { + t.Error("error was expected since context was already done, but there was none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) + } +} + +func TestContextExecWithNamedArg(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectExec("DELETE FROM users"). + WithArgs(sql.Named("id", 5)). + WillDelayFor(time.Second). + WillReturnResult(NewResult(1, 1)) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + _, err = db.ExecContext(ctx, "DELETE FROM users WHERE id = :id", sql.Named("id", 5)) + if err == nil { + t.Error("error was expected, but there was none") + } + + if err != ErrCancelled { + t.Errorf("was expecting cancel error, but got: %v", err) + } + + _, err = db.ExecContext(ctx, "DELETE FROM users WHERE id = :id", sql.Named("id", 5)) + if err != context.Canceled { + t.Error("error was expected since context was already done, but there was none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) + } +} + +func TestContextExec(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectExec("DELETE FROM users"). + WillReturnResult(NewResult(1, 1)) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + res, err := db.ExecContext(ctx, "DELETE FROM users") + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + affected, err := res.RowsAffected() + if affected != 1 { + t.Errorf("expected affected rows 1, but got %v", affected) + } + + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) + } +} + +func TestContextQueryCancel(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + rs := NewRows([]string{"id", "title"}).AddRow(5, "hello world") + + mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). + WithArgs(5). + WillDelayFor(time.Second). + WillReturnRows(rs) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + _, err = db.QueryContext(ctx, "SELECT id, title FROM articles WHERE id = ?", 5) + if err == nil { + t.Error("error was expected, but there was none") + } + + if err != ErrCancelled { + t.Errorf("was expecting cancel error, but got: %v", err) + } + + _, err = db.QueryContext(ctx, "SELECT id, title FROM articles WHERE id = ?", 5) + if err != context.Canceled { + t.Error("error was expected since context was already done, but there was none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) + } +} + +func TestPreparedStatementContextQueryCancel(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + rs := NewRows([]string{"id", "title"}).AddRow(5, "hello world") + + mock.ExpectPrepare("SELECT (.+) FROM articles WHERE id = ?"). + ExpectQuery(). + WithArgs(5). + WillDelayFor(time.Second). + WillReturnRows(rs) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + stmt, err := db.Prepare("SELECT id, title FROM articles WHERE id = ?") + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + _, err = stmt.QueryContext(ctx, 5) + if err == nil { + t.Error("error was expected, but there was none") + } + + if err != ErrCancelled { + t.Errorf("was expecting cancel error, but got: %v", err) + } + + _, err = stmt.QueryContext(ctx, 5) + if err != context.Canceled { + t.Error("error was expected since context was already done, but there was none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) + } +} + +func TestContextQuery(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + rs := NewRows([]string{"id", "title"}).AddRow(5, "hello world") + + mock.ExpectQuery("SELECT (.+) FROM articles WHERE id ="). + WithArgs(sql.Named("id", 5)). + WillDelayFor(time.Millisecond * 3). + WillReturnRows(rs) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + rows, err := db.QueryContext(ctx, "SELECT id, title FROM articles WHERE id = :id", sql.Named("id", 5)) + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + if !rows.Next() { + t.Error("expected one row, but there was none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) + } +} + +func TestContextBeginCancel(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectBegin().WillDelayFor(time.Second) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + _, err = db.BeginTx(ctx, nil) + if err == nil { + t.Error("error was expected, but there was none") + } + + if err != ErrCancelled { + t.Errorf("was expecting cancel error, but got: %v", err) + } + + _, err = db.BeginTx(ctx, nil) + if err != context.Canceled { + t.Error("error was expected since context was already done, but there was none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) + } +} + +func TestContextBegin(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectBegin().WillDelayFor(time.Millisecond * 3) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + if tx == nil { + t.Error("expected tx, but there was nil") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) + } +} + +func TestContextPrepareCancel(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectPrepare("SELECT").WillDelayFor(time.Second) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + _, err = db.PrepareContext(ctx, "SELECT") + if err == nil { + t.Error("error was expected, but there was none") + } + + if err != ErrCancelled { + t.Errorf("was expecting cancel error, but got: %v", err) + } + + _, err = db.PrepareContext(ctx, "SELECT") + if err != context.Canceled { + t.Error("error was expected since context was already done, but there was none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) + } +} + +func TestContextPrepare(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectPrepare("SELECT").WillDelayFor(time.Millisecond * 3) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + stmt, err := db.PrepareContext(ctx, "SELECT") + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + if stmt == nil { + t.Error("expected stmt, but there was nil") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) + } +}