package pq

import (
	"math"
	"reflect"
	"testing"

	"github.com/lib/pq/oid"
)

// TestDataTypeName checks that fieldDesc.Name reports the expected
// upper-case type name for each listed OID.
func TestDataTypeName(t *testing.T) {
	tts := []struct {
		typ  oid.Oid
		name string
	}{
		{oid.T_int8, "INT8"},
		{oid.T_int4, "INT4"},
		{oid.T_int2, "INT2"},
		{oid.T_varchar, "VARCHAR"},
		{oid.T_text, "TEXT"},
		{oid.T_bool, "BOOL"},
		{oid.T_numeric, "NUMERIC"},
		{oid.T_date, "DATE"},
		{oid.T_time, "TIME"},
		{oid.T_timetz, "TIMETZ"},
		{oid.T_timestamp, "TIMESTAMP"},
		{oid.T_timestamptz, "TIMESTAMPTZ"},
		{oid.T_bytea, "BYTEA"},
	}

	for i, tt := range tts {
		dt := fieldDesc{OID: tt.typ}
		if name := dt.Name(); name != tt.name {
			t.Errorf("(%d) got: %s want: %s", i, name, tt.name)
		}
	}
}

// TestDataType checks the reflect.Kind of the Go type that fieldDesc.Type
// reports for each listed OID.
func TestDataType(t *testing.T) {
	tts := []struct {
		typ  oid.Oid
		kind reflect.Kind
	}{
		{oid.T_int8, reflect.Int64},
		{oid.T_int4, reflect.Int32},
		{oid.T_int2, reflect.Int16},
		{oid.T_varchar, reflect.String},
		{oid.T_text, reflect.String},
		{oid.T_bool, reflect.Bool},
		{oid.T_date, reflect.Struct},
		{oid.T_time, reflect.Struct},
		{oid.T_timetz, reflect.Struct},
		{oid.T_timestamp, reflect.Struct},
		{oid.T_timestamptz, reflect.Struct},
		{oid.T_bytea, reflect.Slice},
	}

	for i, tt := range tts {
		dt := fieldDesc{OID: tt.typ}
		if kind := dt.Type().Kind(); kind != tt.kind {
			t.Errorf("(%d) got: %s want: %s", i, kind, tt.kind)
		}
	}
}

// TestDataTypeLength checks fieldDesc.Length for fixed-size, bounded and
// unbounded types: int4 reports no length, varchar with Mod 9 reports 5,
// and text/bytea report math.MaxInt64.
func TestDataTypeLength(t *testing.T) {
	tts := []struct {
		typ    oid.Oid
		len    int
		mod    int
		length int64
		ok     bool
	}{
		{oid.T_int4, 0, -1, 0, false},
		// NOTE(review): Mod 9 -> length 5 presumably reflects a 4-byte
		// header folded into the type modifier — confirm against Length.
		{oid.T_varchar, 65535, 9, 5, true},
		{oid.T_text, 65535, -1, math.MaxInt64, true},
		{oid.T_bytea, 65535, -1, math.MaxInt64, true},
	}

	for i, tt := range tts {
		dt := fieldDesc{OID: tt.typ, Len: tt.len, Mod: tt.mod}
		if l, k := dt.Length(); k != tt.ok || l != tt.length {
			t.Errorf("(%d) got: %d, %t want: %d, %t", i, l, k, tt.length, tt.ok)
		}
	}
}

// TestDataTypePrecisionScale checks fieldDesc.PrecisionScale: only the
// numeric type reports a precision and scale.
func TestDataTypePrecisionScale(t *testing.T) {
	tts := []struct {
		typ              oid.Oid
		mod              int
		precision, scale int64
		ok               bool
	}{
		{oid.T_int4, -1, 0, 0, false},
		// 589830 == (9<<16 | 2) + 4, the type modifier for numeric(9, 2).
		{oid.T_numeric, 589830, 9, 2, true},
		{oid.T_text, -1, 0, 0, false},
	}

	for i, tt := range tts {
		dt := fieldDesc{OID: tt.typ, Mod: tt.mod}
		p, s, k := dt.PrecisionScale()
		if k != tt.ok {
			t.Errorf("(%d) got: %t want: %t", i, k, tt.ok)
		}
		if p != tt.precision {
			t.Errorf("(%d) wrong precision got: %d want: %d", i, p, tt.precision)
		}
		if s != tt.scale {
			t.Errorf("(%d) wrong scale got: %d want: %d", i, s, tt.scale)
		}
	}
}

// TestRowsColumnTypes runs a query through a test connection and checks the
// column metadata reported by Rows.ColumnTypes (Name, DatabaseTypeName,
// Length, DecimalSize and ScanType) for the leading columns of the result.
func TestRowsColumnTypes(t *testing.T) {
	// Named local types replace the anonymous structs that were previously
	// repeated verbatim in every table entry.
	type length struct {
		Len int64
		OK  bool
	}
	type decimalSize struct {
		Precision int64
		Scale     int64
		OK        bool
	}

	columnTypesTests := []struct {
		Name        string
		TypeName    string
		Length      length
		DecimalSize decimalSize
		ScanType    reflect.Type
	}{
		{
			Name:        "a",
			TypeName:    "INT4",
			Length:      length{Len: 0, OK: false},
			DecimalSize: decimalSize{Precision: 0, Scale: 0, OK: false},
			ScanType:    reflect.TypeOf(int32(0)),
		},
		{
			Name:        "bar",
			TypeName:    "TEXT",
			Length:      length{Len: math.MaxInt64, OK: true},
			DecimalSize: decimalSize{Precision: 0, Scale: 0, OK: false},
			ScanType:    reflect.TypeOf(""),
		},
	}

	db := openTestConn(t)
	defer db.Close()

	rows, err := db.Query("SELECT 1 AS a, text 'bar' AS bar, 1.28::numeric(9, 2) AS dec")
	if err != nil {
		t.Fatal(err)
	}
	// The rows are never iterated with Next, so they are not closed
	// automatically; close them so the connection they hold is released.
	defer rows.Close()

	columns, err := rows.ColumnTypes()
	if err != nil {
		t.Fatal(err)
	}
	// Fatal rather than Errorf: the loop below indexes columns[i] and would
	// panic if fewer columns than expected came back.
	if len(columns) != 3 {
		t.Fatalf("expected 3 columns found %d", len(columns))
	}

	for i, tt := range columnTypesTests {
		c := columns[i]
		if got := c.Name(); got != tt.Name {
			t.Errorf("(%d) Name() got: %s, want: %s", i, got, tt.Name)
		}
		if got := c.DatabaseTypeName(); got != tt.TypeName {
			t.Errorf("(%d) DatabaseTypeName() got: %s, want: %s", i, got, tt.TypeName)
		}
		if l, ok := c.Length(); l != tt.Length.Len || ok != tt.Length.OK {
			t.Errorf("(%d) Length() got: %d, %t, want: %d, %t",
				i, l, ok, tt.Length.Len, tt.Length.OK)
		}
		p, s, ok := c.DecimalSize()
		if p != tt.DecimalSize.Precision || s != tt.DecimalSize.Scale || ok != tt.DecimalSize.OK {
			t.Errorf("(%d) DecimalSize() got: %d, %d, %t, want: %d, %d, %t",
				i, p, s, ok, tt.DecimalSize.Precision, tt.DecimalSize.Scale, tt.DecimalSize.OK)
		}
		if got := c.ScanType(); got != tt.ScanType {
			t.Errorf("(%d) ScanType() got: %v, want: %v", i, got, tt.ScanType)
		}
	}
}