Переглянути джерело

Update CBOR serialization to make sure it follows expected format

Version || CBOR(Envelop) || CBOR(Payload)
Marcelo Fornet 2 роки тому
батько
коміт
8c68bbd48c
7 змінених файлів з 146 додано та 140 видалено
  1. 5 5
      bus.go
  2. 4 4
      event.go
  3. 2 2
      event_test.go
  4. 1 1
      go.mod
  5. 8 29
      message.go
  6. 103 82
      raw_event.go
  7. 23 17
      raw_event_test.go

+ 5 - 5
bus.go

@@ -10,10 +10,10 @@ type Bus struct {
 	NATS *nats.EncodedConn
 }
 
-type SubscriberHandler func(event *Event)
+type SubscriberHandler func(event *Event[any])
 
-func Connect(url string) (*Bus, error) {
-	nats.RegisterEncoder("v1", &V1MessageEncoder{})
+func Connect[T any](url string) (*Bus, error) {
+	nats.RegisterEncoder("v1", &V1MessageCodec[T]{})
 	conn, err := nats.Connect(url)
 	if err != nil {
 		return nil, fmt.Errorf("failed to connect to the broker: %w", err)
@@ -61,7 +61,7 @@ func (bus *Bus) Publish(stream string, request interface{}) error {
 	return bus.PublishRawEvent(stream, rawEvent)
 }
 
-func (bus *Bus) PublishRawEvent(stream string, rawEvent *RawEvent) error {
+func (bus *Bus) PublishRawEvent(stream string, rawEvent *RawEvent[any]) error {
 	if err := bus.NATS.Publish(stream, rawEvent); err != nil {
 		return fmt.Errorf("failed to publish message: %w", err)
 	}
@@ -69,7 +69,7 @@ func (bus *Bus) PublishRawEvent(stream string, rawEvent *RawEvent) error {
 }
 
 func (bus *Bus) Subscribe(stream string, handler SubscriberHandler) (*nats.Subscription, error) {
-	return bus.NATS.Subscribe(stream, func(rawEvent RawEvent) {
+	return bus.NATS.Subscribe(stream, func(rawEvent RawEvent[any]) {
 		event, err := rawEvent.Check()
 		if err == nil { // ignore invalid messages
 			handler(event)

+ 4 - 4
event.go

@@ -8,19 +8,19 @@ import (
 	"github.com/segmentio/ksuid"
 )
 
-type Event struct {
+type Event[T any] struct {
 	Type         EventType   `json:"type,string"`
 	SequentialID uint64      `json:"sequential_id"`
 	Timestamp    time.Time   `json:"timestamp"`
 	UniqueID     ksuid.KSUID `json:"unique_id"`
-	Payload      interface{} `json:"payload"`
+	Payload      T           `json:"payload"`
 }
 
-func (event Event) Equal(other Event) bool {
+func (event Event[T]) Equal(other Event[T]) bool {
 	return reflect.DeepEqual(event, other)
 }
 
-func (event Event) JSON() string {
+func (event Event[T]) JSON() string {
 	output, err := json.Marshal(event)
 	if err != nil {
 		panic(err)

+ 2 - 2
event_test.go

@@ -6,7 +6,7 @@ import (
 )
 
 const (
-	expectedEventJSON = `{"type":"1","sequential_id":2,"timestamp":"2009-01-03T20:15:08.004+02:00","unique_id":"Z5ZWs1U9RoGQVugoz5SIrkmkj9T","payload":{}}`
+	expectedEventJSON = `{"type":"1","sequential_id":2,"timestamp":"2009-01-03T19:15:08.004+01:00","unique_id":"Z5ZWs1U9RoGQVugoz5SIrkmkj9T","payload":"Testing"}`
 )
 
 func TestEncodeEventJSON(t *testing.T) {
@@ -16,6 +16,6 @@ func TestEncodeEventJSON(t *testing.T) {
 	}
 	actual := string(output)
 	if actual != expectedEventJSON {
-		t.Errorf("expected %v, but got %v", expectedRawEventJSON, actual)
+		t.Errorf("expected %v, but got %v", expectedEventJSON, actual)
 	}
 }

+ 1 - 1
go.mod

@@ -1,6 +1,6 @@
 module github.com/aurora-is-near/borealis.go
 
-go 1.17
+go 1.18
 
 replace github.com/aurora-is-near/aurora-events/go v0.0.0 => ./events/go
 

+ 8 - 29
message.go

@@ -1,38 +1,17 @@
 package borealis
 
-import (
-	"fmt"
+type V1MessageCodec[T any] struct{}
 
-	"github.com/fxamacker/cbor/v2"
-)
-
-const (
-	V1MessageVersion = 1
-	V1MessageMinLen  = 23
-)
-
-type V1MessageEncoder struct{}
-
-func (encoder *V1MessageEncoder) Encode(stream string, message interface{}) ([]byte, error) {
-	output, err := cbor.Marshal(message)
+func (encoder *V1MessageCodec[T]) Encode(stream string, message any) ([]byte, error) {
+	event := message.(RawEvent[T])
+	output, err := event.EncodeCBOR()
 	if err != nil {
 		return nil, err
 	}
-	return append([]byte{V1MessageVersion}, output...), nil
+	return output, nil
 }
 
-func (encoder *V1MessageEncoder) Decode(stream string, input []byte, messagePtr interface{}) error {
-	if len(input) < V1MessageMinLen {
-		return fmt.Errorf("expected message length %d+, but got %d", V1MessageMinLen, len(input))
-	}
-	if input[0] != V1MessageVersion {
-		return fmt.Errorf("expected message version %d, but got %d", V1MessageVersion, input[0])
-	}
-	switch message := messagePtr.(type) {
-	case *[]byte:
-		*message = input[1:]
-		return nil
-	default:
-		return cbor.Unmarshal(input[1:], message)
-	}
+func (encoder *V1MessageCodec[T]) Decode(stream string, input []byte, messagePtr any) error {
+	event := messagePtr.(RawEvent[T])
+	return event.DecodeCBOR(input)
 }

+ 103 - 82
raw_event.go

@@ -13,50 +13,67 @@ import (
 	"github.com/segmentio/ksuid"
 )
 
-const btcEpoch = 1231006505 // Bitcoin genesis, 2009-01-03T18:15:05Z
-
-type RawEvent struct {
-	_            struct{} `cbor:",toarray" json:"-"`
-	Type         uint16
-	SequentialID uint64
-	TimestampS   uint32
-	TimestampMS  uint16
-	UniqueID     []byte
-	Payload      interface{}
+const (
+	btcEpoch              = 1231006505 // Bitcoin genesis, 2009-01-03T18:15:05Z
+	V1MessageVersion      = 1          // Current version for messages on the bus
+	DefaultMessageVersion = V1MessageVersion
+)
+
+type Envelop struct {
+	_            struct{} `json:"-"`
+	Type         uint16   `cbor:"event_type"`
+	SequentialID uint64   `cbor:"sequential_id"`
+	TimestampS   uint32   `cbor:"timestamp_s"`
+	TimestampMS  uint16   `cbor:"timestamp_ms"`
+	UniqueID     [16]byte `cbor:"unique_id"`
+}
+
+type RawEvent[T any] struct {
+	_       struct{} `json:"-"`
+	Version uint8
+	Envelop Envelop
+	Payload T
 }
 
-func NewRawEvent(type_ uint16, payload interface{}) (*RawEvent, error) {
+func NewRawEvent[T any](type_ uint16, payload T) (*RawEvent[T], error) {
 	now := time.Now()
 	ksuid, err := ksuid.NewRandomWithTime(now)
 	if err != nil {
 		return nil, err
 	}
-	return &RawEvent{
-		Type:         type_,
-		SequentialID: 0,
-		TimestampS:   uint32(now.Unix() - btcEpoch),
-		TimestampMS:  uint16(now.UnixMilli() % 1000),
-		UniqueID:     ksuid[4:],
-		Payload:      payload,
+
+	unique_id := [16]byte{}
+	copy(unique_id[:], ksuid[4:])
+
+	return &RawEvent[T]{
+		Version: DefaultMessageVersion,
+		Envelop: Envelop{
+			Type:         type_,
+			SequentialID: 0,
+			TimestampS:   uint32(now.Unix() - btcEpoch),
+			TimestampMS:  uint16(now.UnixMilli() % 1000),
+			UniqueID:     unique_id,
+		},
+		Payload: payload,
 	}, nil
 }
 
-func (rawEvent RawEvent) Check() (*Event, error) {
-	timestampS, ok := overflow.Add64(btcEpoch, int64(rawEvent.TimestampS))
-	if !ok || rawEvent.TimestampMS > 999 {
+func (rawEvent RawEvent[T]) Check() (*Event[T], error) {
+	timestampS, ok := overflow.Add64(btcEpoch, int64(rawEvent.Envelop.TimestampS))
+	if !ok || rawEvent.Envelop.TimestampMS > 999 {
 		return nil, fmt.Errorf("timestamp overflow")
 	}
-	timestampNS := int64(rawEvent.TimestampMS) * 1000000
+	timestampNS := int64(rawEvent.Envelop.TimestampMS) * 1000000
 	timestamp := time.Unix(timestampS, timestampNS)
 
-	ksuid, err := ksuid.FromParts(timestamp, rawEvent.UniqueID)
+	ksuid, err := ksuid.FromParts(timestamp, rawEvent.Envelop.UniqueID[:])
 	if err != nil {
 		return nil, err
 	}
 
-	event := Event{
-		Type:         EventType(rawEvent.Type),
-		SequentialID: rawEvent.SequentialID,
+	event := Event[T]{
+		Type:         EventType(rawEvent.Envelop.Type),
+		SequentialID: rawEvent.Envelop.SequentialID,
 		Timestamp:    timestamp,
 		UniqueID:     ksuid,
 		Payload:      rawEvent.Payload,
@@ -64,18 +81,18 @@ func (rawEvent RawEvent) Check() (*Event, error) {
 	return &event, nil
 }
 
-func (event RawEvent) Equal(other RawEvent) bool {
-	if event.Type != other.Type ||
-		event.SequentialID != other.SequentialID ||
-		event.TimestampS != other.TimestampS ||
-		event.TimestampMS != other.TimestampMS ||
-		!bytes.Equal(event.UniqueID, other.UniqueID) {
+func (event RawEvent[T]) Equal(other RawEvent[T]) bool {
+	if event.Envelop.Type != other.Envelop.Type ||
+		event.Envelop.SequentialID != other.Envelop.SequentialID ||
+		event.Envelop.TimestampS != other.Envelop.TimestampS ||
+		event.Envelop.TimestampMS != other.Envelop.TimestampMS ||
+		!bytes.Equal(event.Envelop.UniqueID[:], other.Envelop.UniqueID[:]) {
 		return false
 	}
 	return reflect.DeepEqual(event.Payload, other.Payload)
 }
 
-func (event RawEvent) JSON() string {
+func (event RawEvent[T]) JSON() string {
 	output, err := json.Marshal(event)
 	if err != nil {
 		panic(err)
@@ -83,48 +100,41 @@ func (event RawEvent) JSON() string {
 	return string(output)
 }
 
-func (event RawEvent) MarshalJSON() ([]byte, error) {
+func (event RawEvent[T]) MarshalJSON() ([]byte, error) {
 	return json.Marshal([]interface{}{
-		event.Type,
-		event.SequentialID,
-		event.TimestampS,
-		event.TimestampMS,
-		event.UniqueID, // automatically Base64-encoded
-		event.Payload,
-	})
-}
-
-func (event RawEvent) MarshalCBOR() ([]byte, error) {
-	return cbor.Marshal([]interface{}{
-		event.Type,
-		event.SequentialID,
-		event.TimestampS,
-		event.TimestampMS,
-		event.UniqueID,
+		event.Version,
+		event.Envelop.Type,
+		event.Envelop.SequentialID,
+		event.Envelop.TimestampS,
+		event.Envelop.TimestampMS,
+		event.Envelop.UniqueID[:], // automatically Base64-encoded
 		event.Payload,
 	})
 }
 
-func (event *RawEvent) UnmarshalJSON(input []byte) error {
+func (event *RawEvent[T]) UnmarshalJSON(input []byte) error {
 	var err error
 	array := []interface{}{}
 	if err = json.Unmarshal(input, &array); err != nil {
 		return err
 	}
-	if len(array) != 6 {
+	if len(array) != 7 {
 		return fmt.Errorf("event must be an array of length %d, but got %d", 6, len(array))
 	}
 
-	event.Type = uint16(array[0].(float64))
-	event.SequentialID = uint64(array[1].(float64))
-	event.TimestampS = uint32(array[2].(float64))
-	event.TimestampMS = uint16(array[3].(float64))
-	event.UniqueID, err = base64.StdEncoding.DecodeString(array[4].(string))
+	event.Version = uint8(array[0].(float64))
+	event.Envelop.Type = uint16(array[1].(float64))
+	event.Envelop.SequentialID = uint64(array[2].(float64))
+	event.Envelop.TimestampS = uint32(array[3].(float64))
+	event.Envelop.TimestampMS = uint16(array[4].(float64))
+
+	unique_id, err := base64.StdEncoding.DecodeString(array[5].(string))
 	if err != nil {
 		return err
 	}
+	copy(event.Envelop.UniqueID[:], unique_id)
 
-	payload, ok := array[5].(map[string]interface{})
+	payload, ok := array[6].(T)
 	if !ok {
 		return fmt.Errorf("event payload must be a map")
 	}
@@ -133,35 +143,46 @@ func (event *RawEvent) UnmarshalJSON(input []byte) error {
 	return nil
 }
 
-func (event *RawEvent) UnmarshalCBOR(input []byte) error {
-	var err error
-	array := []interface{}{}
-	if err = cbor.Unmarshal(input, &array); err != nil {
-		return err
+func (event RawEvent[T]) EncodeCBOR() ([]byte, error) {
+	encoded_envelop, err := cbor.Marshal(event.Envelop)
+	if err != nil {
+		return nil, err
 	}
-	if len(array) != 6 {
-		return fmt.Errorf("event must be an array of length %d, but got %d", 6, len(array))
+
+	encoded_payload, err := cbor.Marshal(event.Payload)
+	if err != nil {
+		return nil, err
 	}
 
-	event.Type = uint16(array[0].(uint64))
-	event.SequentialID = uint64(array[1].(uint64))
-	event.TimestampS = uint32(array[2].(uint64))
-	event.TimestampMS = uint16(array[3].(uint64))
-	event.UniqueID = array[4].([]byte)
+	partial := append([]byte{V1MessageVersion}, encoded_envelop...)
+	return append(partial, encoded_payload...), nil
+}
+
+func (event *RawEvent[T]) DecodeCBOR(input []byte) error {
+	var err error
 
-	payloadIn, ok := array[5].(map[interface{}]interface{})
-	if !ok {
-		return fmt.Errorf("event payload must be a map")
+	if len(input) < 1 {
+		return fmt.Errorf("Empty message. Missing Version number")
 	}
-	payloadOut := make(map[string]interface{})
-	for k, v := range payloadIn {
-		var key string
-		if key, ok = k.(string); !ok {
-			return fmt.Errorf("event payload must be a map with string keys, but got the key %v", k)
+
+	version_message := input[0]
+
+	switch version_message {
+	case V1MessageVersion:
+		decoder := cbor.NewDecoder(bytes.NewReader(input[1:]))
+
+		err = decoder.Decode(&event.Envelop)
+		if err != nil {
+			return err
 		}
-		payloadOut[key] = v
-	}
-	event.Payload = payloadOut
 
-	return nil
+		err = decoder.Decode(&event.Payload)
+		if err != nil {
+			return err
+		}
+
+		return nil
+	default:
+		return fmt.Errorf("Unsupported version message %d", version_message)
+	}
 }

+ 23 - 17
raw_event_test.go

@@ -6,29 +6,35 @@ import (
 	"testing"
 	"time"
 
-	"github.com/fxamacker/cbor/v2"
 	"github.com/segmentio/ksuid"
 )
 
-var expectedEvent = Event{
+var expectedEvent = Event[string]{
 	Type:         1,
 	SequentialID: 2,
-	Timestamp:    time.UnixMicro(1231006508004000), // 2009-01-03T18:15:08.004Z
+	Timestamp:    time.UnixMicro(1231006508004000),
 	UniqueID:     ksuid.KSUID{245, 237, 93, 44, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
-	Payload:      map[string]interface{}{},
+	Payload:      "Testing",
 }
 
-var expectedRawEvent = RawEvent{
-	Type:         1,
-	SequentialID: 2,
-	TimestampS:   3,
-	TimestampMS:  4,
-	UniqueID:     []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF},
-	Payload:      map[string]interface{}{},
+var expectedRawEvent = RawEvent[string]{
+	Version: V1MessageVersion,
+	Envelop: Envelop{
+		Type:         1,
+		SequentialID: 2,
+		TimestampS:   3,
+		TimestampMS:  4,
+		UniqueID:     [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
+	},
+	Payload: "Testing",
 }
 
-var expectedRawEventJSON = "[1,2,3,4,\"AAECAwQFBgcICQoLDA0ODw==\",{}]"
-var expectedRawEventCBOR = []byte{0x86, 1, 2, 3, 4, 0x50, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0xa0}
+var expectedRawEventJSON = "[1,1,2,3,4,\"AAECAwQFBgcICQoLDA0ODw==\",\"Testing\"]"
+
+var expectedRawEventCBOR = []byte{1, 165, 106, 101, 118, 101, 110, 116, 95, 116, 121, 112, 101, 1, 109, 115, 101, 113,
+	117, 101, 110, 116, 105, 97, 108, 95, 105, 100, 2, 107, 116, 105, 109, 101, 115, 116, 97, 109, 112, 95, 115, 3,
+	108, 116, 105, 109, 101, 115, 116, 97, 109, 112, 95, 109, 115, 4, 105, 117, 110, 105, 113, 117, 101, 95, 105, 100,
+	80, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 103, 84, 101, 115, 116, 105, 110, 103}
 
 func TestNewRawEvent(t *testing.T) {
 	_, err := NewRawEvent(0x1234, make(map[string]interface{}))
@@ -59,7 +65,7 @@ func TestEncodeRawEventJSON(t *testing.T) {
 }
 
 func TestEncodeRawEventCBOR(t *testing.T) {
-	actual, err := cbor.Marshal(expectedRawEvent)
+	actual, err := expectedRawEvent.EncodeCBOR()
 	if err != nil {
 		t.Error(err)
 	}
@@ -69,7 +75,7 @@ func TestEncodeRawEventCBOR(t *testing.T) {
 }
 
 func TestDecodeRawEventJSON(t *testing.T) {
-	var actual RawEvent
+	var actual RawEvent[string]
 	if err := json.Unmarshal([]byte(expectedRawEventJSON), &actual); err != nil {
 		t.Error(err)
 	}
@@ -78,8 +84,8 @@ func TestDecodeRawEventJSON(t *testing.T) {
 	}
 }
 func TestDecodeRawEventCBOR(t *testing.T) {
-	var actual RawEvent
-	if err := cbor.Unmarshal([]byte(expectedRawEventCBOR), &actual); err != nil {
+	var actual RawEvent[string]
+	if err := actual.DecodeCBOR(expectedRawEventCBOR); err != nil {
 		t.Error(err)
 	}
 	if !expectedRawEvent.Equal(actual) {