Skip to content

Commit b8a63d3

Browse files
authored
Merge pull request #231 from bonitoo-io/pr-152-again
Add Column Metadata
2 parents 7c0bc43 + 3b533ba commit b8a63d3

File tree

10 files changed

+422
-3
lines changed

10 files changed

+422
-3
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ It only asserts that argument is of `time.Time` type.
222222

223223
## Change Log
224224

225+
- **2019-04-06** - added functionality to mock a sql MetaData request
225226
- **2019-02-13** - added `go.mod` removed the references and suggestions using `gopkg.in`.
226227
- **2018-12-11** - added expectation of Rows to be closed, while mocking expected query.
227228
- **2018-12-11** - introduced an option to provide **QueryMatcher** in order to customize SQL query matching.

column.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
package sqlmock
2+
3+
import "reflect"
4+
5+
// Column is a mocked column Metadata for rows.ColumnTypes()
6+
type Column struct {
7+
name string
8+
dbType string
9+
nullable bool
10+
nullableOk bool
11+
length int64
12+
lengthOk bool
13+
precision int64
14+
scale int64
15+
psOk bool
16+
scanType reflect.Type
17+
}
18+
19+
func (c *Column) Name() string {
20+
return c.name
21+
}
22+
23+
func (c *Column) DbType() string {
24+
return c.dbType
25+
}
26+
27+
func (c *Column) IsNullable() (bool, bool) {
28+
return c.nullable, c.nullableOk
29+
}
30+
31+
func (c *Column) Length() (int64, bool) {
32+
return c.length, c.lengthOk
33+
}
34+
35+
func (c *Column) PrecisionScale() (int64, int64, bool) {
36+
return c.precision, c.scale, c.psOk
37+
}
38+
39+
func (c *Column) ScanType() reflect.Type {
40+
return c.scanType
41+
}
42+
43+
// NewColumn returns a Column with specified name
44+
func NewColumn(name string) *Column {
45+
return &Column{
46+
name: name,
47+
}
48+
}
49+
50+
// Nullable returns the column with nullable metadata set
51+
func (c *Column) Nullable(nullable bool) *Column {
52+
c.nullable = nullable
53+
c.nullableOk = true
54+
return c
55+
}
56+
57+
// OfType returns the column with type metadata set
58+
func (c *Column) OfType(dbType string, sampleValue interface{}) *Column {
59+
c.dbType = dbType
60+
c.scanType = reflect.TypeOf(sampleValue)
61+
return c
62+
}
63+
64+
// WithLength returns the column with length metadata set.
65+
func (c *Column) WithLength(length int64) *Column {
66+
c.length = length
67+
c.lengthOk = true
68+
return c
69+
}
70+
71+
// WithPrecisionAndScale returns the column with precision and scale metadata set.
72+
func (c *Column) WithPrecisionAndScale(precision, scale int64) *Column {
73+
c.precision = precision
74+
c.scale = scale
75+
c.psOk = true
76+
return c
77+
}

column_test.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
package sqlmock
2+
3+
import (
4+
"reflect"
5+
"testing"
6+
"time"
7+
)
8+
9+
func TestColumn(t *testing.T) {
10+
now, _ := time.Parse(time.RFC3339, "2020-06-20T22:08:41Z")
11+
column1 := NewColumn("test").OfType("VARCHAR", "").Nullable(true).WithLength(100)
12+
column2 := NewColumn("number").OfType("DECIMAL", float64(0.0)).Nullable(false).WithPrecisionAndScale(10, 4)
13+
column3 := NewColumn("when").OfType("TIMESTAMP", now)
14+
15+
if column1.ScanType().Kind() != reflect.String {
16+
t.Errorf("string scanType mismatch: %v", column1.ScanType())
17+
}
18+
if column2.ScanType().Kind() != reflect.Float64 {
19+
t.Errorf("float scanType mismatch: %v", column2.ScanType())
20+
}
21+
if column3.ScanType() != reflect.TypeOf(time.Time{}) {
22+
t.Errorf("time scanType mismatch: %v", column3.ScanType())
23+
}
24+
25+
nullable, ok := column1.IsNullable()
26+
if !nullable || !ok {
27+
t.Errorf("'test' column should be nullable")
28+
}
29+
nullable, ok = column2.IsNullable()
30+
if nullable || !ok {
31+
t.Errorf("'number' column should not be nullable")
32+
}
33+
nullable, ok = column3.IsNullable()
34+
if ok {
35+
t.Errorf("'when' column nullability should be unknown")
36+
}
37+
38+
length, ok := column1.Length()
39+
if length != 100 || !ok {
40+
t.Errorf("'test' column wrong length")
41+
}
42+
length, ok = column2.Length()
43+
if ok {
44+
t.Errorf("'number' column is not of variable length type")
45+
}
46+
length, ok = column3.Length()
47+
if ok {
48+
t.Errorf("'when' column is not of variable length type")
49+
}
50+
51+
_, _, ok = column1.PrecisionScale()
52+
if ok {
53+
t.Errorf("'test' column not applicable")
54+
}
55+
precision, scale, ok := column2.PrecisionScale()
56+
if precision != 10 || scale != 4 || !ok {
57+
t.Errorf("'number' column not applicable")
58+
}
59+
_, _, ok = column3.PrecisionScale()
60+
if ok {
61+
t.Errorf("'when' column not applicable")
62+
}
63+
}

expectations_go18.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,19 @@ import (
1212
// WillReturnRows specifies the set of resulting rows that will be returned
1313
// by the triggered query
1414
func (e *ExpectedQuery) WillReturnRows(rows ...*Rows) *ExpectedQuery {
15+
defs := 0
1516
sets := make([]*Rows, len(rows))
1617
for i, r := range rows {
1718
sets[i] = r
19+
if r.def != nil {
20+
defs++
21+
}
22+
}
23+
if defs > 0 && defs == len(sets) {
24+
e.rows = &rowSetsWithDefinition{&rowSets{sets: sets, ex: e}}
25+
} else {
26+
e.rows = &rowSets{sets: sets, ex: e}
1827
}
19-
e.rows = &rowSets{sets: sets, ex: e}
2028
return e
2129
}
2230

rows.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ func (rs *rowSets) invalidateRaw() {
120120
type Rows struct {
121121
converter driver.ValueConverter
122122
cols []string
123+
def []*Column
123124
rows [][]driver.Value
124125
pos int
125126
nextErr map[int]error

rows_go18.go

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22

33
package sqlmock
44

5-
import "io"
5+
import (
6+
"database/sql/driver"
7+
"io"
8+
"reflect"
9+
)
610

711
// Implement the "RowsNextResultSet" interface
812
func (rs *rowSets) HasNextResultSet() bool {
@@ -18,3 +22,53 @@ func (rs *rowSets) NextResultSet() error {
1822
rs.pos++
1923
return nil
2024
}
25+
26+
// type for rows with columns definition created with sqlmock.NewRowsWithColumnDefinition
27+
type rowSetsWithDefinition struct {
28+
*rowSets
29+
}
30+
31+
// Implement the "RowsColumnTypeDatabaseTypeName" interface
32+
func (rs *rowSetsWithDefinition) ColumnTypeDatabaseTypeName(index int) string {
33+
return rs.getDefinition(index).DbType()
34+
}
35+
36+
// Implement the "RowsColumnTypeLength" interface
37+
func (rs *rowSetsWithDefinition) ColumnTypeLength(index int) (length int64, ok bool) {
38+
return rs.getDefinition(index).Length()
39+
}
40+
41+
// Implement the "RowsColumnTypeNullable" interface
42+
func (rs *rowSetsWithDefinition) ColumnTypeNullable(index int) (nullable, ok bool) {
43+
return rs.getDefinition(index).IsNullable()
44+
}
45+
46+
// Implement the "RowsColumnTypePrecisionScale" interface
47+
func (rs *rowSetsWithDefinition) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
48+
return rs.getDefinition(index).PrecisionScale()
49+
}
50+
51+
// ColumnTypeScanType is defined from driver.RowsColumnTypeScanType
52+
func (rs *rowSetsWithDefinition) ColumnTypeScanType(index int) reflect.Type {
53+
return rs.getDefinition(index).ScanType()
54+
}
55+
56+
// return column definition from current set metadata
57+
func (rs *rowSetsWithDefinition) getDefinition(index int) *Column {
58+
return rs.sets[rs.pos].def[index]
59+
}
60+
61+
// NewRowsWithColumnDefinition return rows with columns metadata
62+
func NewRowsWithColumnDefinition(columns ...*Column) *Rows {
63+
cols := make([]string, len(columns))
64+
for i, column := range columns {
65+
cols[i] = column.Name()
66+
}
67+
68+
return &Rows{
69+
cols: cols,
70+
def: columns,
71+
nextErr: make(map[int]error),
72+
converter: driver.DefaultParameterConverter,
73+
}
74+
}

0 commit comments

Comments
 (0)