diff --git a/bench_test.go b/bench_test.go index 77acc88..0e83096 100644 --- a/bench_test.go +++ b/bench_test.go @@ -11,6 +11,43 @@ import ( "github.com/Basekick-Labs/msgpack/v6" ) +func benchmarkEncode(b *testing.B, src interface{}) { + enc := msgpack.NewEncoder(ioutil.Discard) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + if err := enc.Encode(src); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkEncodeInt8(b *testing.B) { benchmarkEncode(b, int8(-42)) } +func BenchmarkEncodeInt16(b *testing.B) { benchmarkEncode(b, int16(-3200)) } +func BenchmarkEncodeInt32(b *testing.B) { benchmarkEncode(b, int32(-320000)) } + +func BenchmarkEncodeUint8(b *testing.B) { benchmarkEncode(b, uint8(200)) } +func BenchmarkEncodeUint16(b *testing.B) { benchmarkEncode(b, uint16(64000)) } +func BenchmarkEncodeUint32(b *testing.B) { benchmarkEncode(b, uint32(4000000000)) } + +func BenchmarkEncodeStringSlice(b *testing.B) { + benchmarkEncode(b, []string{"hello", "world", "foo", "bar"}) +} + +func BenchmarkEncodeMapStringBool(b *testing.B) { + benchmarkEncode(b, map[string]bool{"hello": true, "world": false}) +} + +// namedInt misses every case in the Encode() type switch, measuring the cost +// of traversing the full switch and falling through to the reflection path. +type namedInt int64 + +func BenchmarkEncodeFallthrough(b *testing.B) { + benchmarkEncode(b, namedInt(42)) +} + func BenchmarkDiscard(b *testing.B) { enc := msgpack.NewEncoder(ioutil.Discard) diff --git a/encode.go b/encode.go index 6420663..da4a108 100644 --- a/encode.go +++ b/encode.go @@ -284,10 +284,22 @@ func (e *Encoder) Encode(v interface{}) error { return e.EncodeBytes(v) case int: return e.EncodeInt(int64(v)) + case int8: + return e.encodeInt8Cond(v) + case int16: + return e.encodeInt16Cond(v) + case int32: + return e.encodeInt32Cond(v) case int64: return e.encodeInt64Cond(v) case uint: return e.EncodeUint(uint64(v)) + case uint8: + return e.encodeUint8Cond(v) + case uint16: + return e.encodeUint16Cond(v) + case uint32: + return e.encodeUint32Cond(v) case uint64: return e.encodeUint64Cond(v) case bool: @@ -302,6 +314,8 @@ func (e *Encoder) Encode(v interface{}) error { return e.EncodeTime(v) case map[string]string: return e.encodeMapStringString(v) + case map[string]bool: + return e.encodeMapStringBool(v) case map[string]interface{}: if e.flags&sortMapKeysFlag != 0 { return e.EncodeMapSorted(v) @@ -309,6 +323,8 @@ func (e *Encoder) Encode(v interface{}) error { return e.EncodeMap(v) case []interface{}: return e.encodeInterfaceSlice(v) + case []string: + return e.encodeStringSlice(v) } return e.EncodeValue(reflect.ValueOf(v)) } diff --git a/encode_map.go b/encode_map.go index b8e5a38..38cd868 100644 --- a/encode_map.go +++ b/encode_map.go @@ -60,30 +60,13 @@ func encodeMapStringBoolValue(e *Encoder, v reflect.Value) error { return e.EncodeNil() } - if err := e.EncodeMapLen(v.Len()); err != nil { - return err - } - var m map[string]bool if v.Type() == mapStringBoolType { m = v.Interface().(map[string]bool) } else { m = v.Convert(mapStringBoolType).Interface().(map[string]bool) } - if e.flags&sortMapKeysFlag != 0 { - return e.encodeSortedMapStringBool(m) - } - - for mk, mv := range m { - if err := e.EncodeString(mk); err != nil { - return err - } - if err := e.EncodeBool(mv); err != nil { - return err - } - } - - return nil + return e.encodeMapStringBool(m) } func encodeMapStringStringValue(e *Encoder, v reflect.Value) error { @@ -91,30 +74,13 @@ func encodeMapStringStringValue(e *Encoder, v reflect.Value) error { return e.EncodeNil() } - if err := e.EncodeMapLen(v.Len()); err != nil { - return err - } - var m map[string]string if v.Type() == mapStringStringType { m = v.Interface().(map[string]string) } else { m = v.Convert(mapStringStringType).Interface().(map[string]string) } - if e.flags&sortMapKeysFlag != 0 { - return e.encodeSortedMapStringString(m) - } - - for mk, mv := range m { - if err := e.EncodeString(mk); err != nil { - return err - } - if err := e.EncodeString(mv); err != nil { - return err - } - } - - return nil + return e.encodeMapStringString(m) } func encodeMapStringInterfaceValue(e *Encoder, v reflect.Value) error { @@ -133,6 +99,27 @@ func encodeMapStringInterfaceValue(e *Encoder, v reflect.Value) error { return e.EncodeMap(m) } +func (e *Encoder) encodeMapStringBool(m map[string]bool) error { + if m == nil { + return e.EncodeNil() + } + if err := e.EncodeMapLen(len(m)); err != nil { + return err + } + if e.flags&sortMapKeysFlag != 0 { + return e.encodeSortedMapStringBool(m) + } + for mk, mv := range m { + if err := e.EncodeString(mk); err != nil { + return err + } + if err := e.EncodeBool(mv); err != nil { + return err + } + } + return nil +} + func (e *Encoder) encodeMapStringString(m map[string]string) error { if m == nil { return e.EncodeNil() diff --git a/msgpack_test.go b/msgpack_test.go index 6de7319..3d14e8f 100644 --- a/msgpack_test.go +++ b/msgpack_test.go @@ -99,6 +99,52 @@ func (t *MsgpackTest) TestMap() { } } +func (t *MsgpackTest) TestEncodeTypeSwitchFastPaths() { + // Every type fast-pathed in Encode() must produce output byte-identical + // to the reflection path (EncodeValue). + for _, v := range []interface{}{ + int8(-42), int8(127), int8(-128), + int16(-3200), int16(32767), + int32(-320000), int32(2147483647), + uint8(0), uint8(200), uint8(255), + uint16(64000), + uint32(4000000000), + []string{"hello", "world"}, + []string{}, + []string(nil), + map[string]bool{"hello": true}, + map[string]bool{}, + map[string]bool(nil), + } { + var fast, slow bytes.Buffer + + enc := msgpack.NewEncoder(&fast) + t.Nil(enc.Encode(v)) + + enc = msgpack.NewEncoder(&slow) + t.Nil(enc.EncodeValue(reflect.ValueOf(v))) + + t.Equal(slow.Bytes(), fast.Bytes(), fmt.Sprintf("encoding %T(%v)", v, v)) + } +} + +func (t *MsgpackTest) TestEncodeMapStringBoolSorted() { + in := map[string]bool{"c": true, "a": false, "b": true} + + t.enc.SetSortMapKeys(true) + t.Nil(t.enc.Encode(in)) + t.Equal([]byte{ + 0x83, + 0xa1, 'a', 0xc2, + 0xa1, 'b', 0xc3, + 0xa1, 'c', 0xc3, + }, t.buf.Bytes()) + + var out map[string]bool + t.Nil(t.dec.Decode(&out)) + t.Equal(in, out) +} + func (t *MsgpackTest) TestStructNil() { var dst *nameStruct diff --git a/types_test.go b/types_test.go index 91faff6..3b4f1d3 100644 --- a/types_test.go +++ b/types_test.go @@ -436,6 +436,7 @@ type ( sliceByte []byte sliceString []string mapStringString map[string]string + mapStringBool map[string]bool mapStringInterface map[string]interface{} ) @@ -544,6 +545,9 @@ var ( {in: map[string]string(nil), out: new(map[string]string)}, {in: map[string]interface{}{"foo": nil}, out: new(map[string]interface{})}, {in: mapStringString{"foo": "bar"}, out: new(mapStringString)}, + {in: map[string]bool{"foo": true, "bar": false}, out: new(map[string]bool)}, + {in: map[string]bool(nil), out: new(map[string]bool)}, + {in: mapStringBool{"foo": true}, out: new(mapStringBool)}, {in: map[stringAlias]stringAlias{"foo": "bar"}, out: new(map[stringAlias]stringAlias)}, {in: mapStringInterface{"foo": "bar"}, out: new(mapStringInterface)}, {in: map[stringAlias]interfaceAlias{"foo": "bar"}, out: new(map[stringAlias]interfaceAlias)},