Skip to content

Commit

Permalink
Altered to have a struct owning the connection to NATS
Browse files Browse the repository at this point in the history
  • Loading branch information
strottos committed Dec 3, 2024
1 parent d5e2bcf commit 8970749
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 63 deletions.
17 changes: 9 additions & 8 deletions cmd/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ type agentPolicy string
type agentPluginConfig map[string]string

type agentPlugin struct {
AssessmentPlanIds []*string `json:"assessmentPlanIds"`
AssessmentPlanIds []*string `json:"assessment_plan_ids"`
Source *string `json:"source"`
Policies []*agentPolicy `json:"policy"`
Config agentPluginConfig `json:"config"`
Expand Down Expand Up @@ -132,7 +132,6 @@ func mergeConfig(cmd *cobra.Command, fileConfig *viper.Viper) (*agentConfig, err

config := &agentConfig{}
err := fileConfig.Unmarshal(config)
log.Println("Merged config", "config", config.Plugins["local-ssh-security"])
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -193,6 +192,7 @@ func agentRunner(cmd *cobra.Command, args []string) error {
logger: logger,
config: *config,

natsBus: event.NewNatsBus(logger),
pluginLocations: map[string]string{},
}

Expand Down Expand Up @@ -244,6 +244,7 @@ type AgentRunner struct {
mu sync.Mutex

config agentConfig
natsBus *event.NatsBus

pluginLocations map[string]string

Expand All @@ -253,12 +254,12 @@ type AgentRunner struct {
func (ar *AgentRunner) Run() error {
ar.logger.Info("Starting agent", "daemon", ar.config.Daemon, "nats_uri", ar.config.Nats.Url)

err := event.Connect(ar.config.Nats.Url)
err := ar.natsBus.Connect(ar.config.Nats.Url)
if err != nil {
log.Fatal(err)
}

defer event.Close()
defer ar.natsBus.Close()

err = ar.DownloadPlugins()
if err != nil {
Expand Down Expand Up @@ -352,7 +353,7 @@ func (ar *AgentRunner) runInstance() error {
if err != nil {
for _, assessmentPlanId := range assessmentPlanIds {
result := runner.ErrorResult(assessmentPlanId, err)
if pubErr := event.Publish(result, "job.result"); pubErr != nil {
if pubErr := event.Publish(ar.natsBus, result, "job.result"); pubErr != nil {
logger.Error("Error publishing configure result", "error", pubErr)
}
}
Expand All @@ -363,7 +364,7 @@ func (ar *AgentRunner) runInstance() error {
if err != nil {
for _, assessmentPlanId := range assessmentPlanIds {
result := runner.ErrorResult(assessmentPlanId, err)
if pubErr := event.Publish(result, "job.result"); pubErr != nil {
if pubErr := event.Publish(ar.natsBus, result, "job.result"); pubErr != nil {
logger.Error("Error publishing evaslutae result", "error", pubErr)
}
}
Expand All @@ -378,7 +379,7 @@ func (ar *AgentRunner) runInstance() error {
if err != nil {
for _, assessmentPlanId := range assessmentPlanIds {
result := runner.ErrorResult(assessmentPlanId, err)
if pubErr := event.Publish(result, "job.result"); pubErr != nil {
if pubErr := event.Publish(ar.natsBus, result, "job.result"); pubErr != nil {
logger.Error("Error publishing evaslutae result", "error", pubErr)
}
}
Expand All @@ -402,7 +403,7 @@ func (ar *AgentRunner) runInstance() error {
}

// Publish findings to nats
if pubErr := event.Publish(result, "job.result"); pubErr != nil {
if pubErr := event.Publish(ar.natsBus, result, "job.result"); pubErr != nil {
logger.Error("Error publishing result", "error", pubErr)
}
}
Expand Down
71 changes: 22 additions & 49 deletions internal/event/bus.go
Original file line number Diff line number Diff line change
@@ -1,82 +1,55 @@
package event

import (
"bytes"
"encoding/json"
"errors"
"log"
"sync"

"github.com/nats-io/nats.go"
"github.com/hashicorp/go-hclog"
"sync"
)

const NATS_RECONNECT_BUF_SIZE = 5*1024*1024

type chanHolder struct {
Ch interface{}
}
type NatsBus struct {
logger hclog.Logger

var (
conn *nats.Conn
subCh []chanHolder
mu sync.Mutex
)
}

func Connect(server string) error {
mu.Lock()
defer mu.Unlock()
func NewNatsBus(logger hclog.Logger) *NatsBus {
return &NatsBus{
logger: logger,
}
}

if conn != nil {
func (nb *NatsBus) Connect(server string) error {
nb.mu.Lock()
defer nb.mu.Unlock()

if nb.conn != nil {
return errors.New("already connected")
}

c, err := nats.Connect(server, nats.ReconnectBufSize(NATS_RECONNECT_BUF_SIZE))
if err != nil {
return err
}
conn = c
subCh = make([]chanHolder, 0)
nb.conn = c

return nil
}

func Subscribe[T any](topic string) (chan T, error) {
ch := make(chan T)
_, err := conn.Subscribe(topic, func(m *nats.Msg) {
var msg T
decoder := json.NewDecoder(bytes.NewReader(m.Data))
decoder.DisallowUnknownFields()
err := decoder.Decode(&msg)
if err != nil {
log.Printf("Error unmarshalling message: %v", err)
return
}
ch <- msg
})
if err != nil {
return nil, err
}
mu.Lock()
subCh = append(subCh, chanHolder{Ch: ch})
mu.Unlock()

return ch, nil
}

func Publish[T any](msg T, topic string) error {
// Not a method due to Golang limitations on generics there, so we just pass the bus as a parameter.
func Publish[T any](nb *NatsBus, msg T, topic string) error {
data, err := json.Marshal(msg)
if err != nil {
return err
}
log.Printf("Publishing message to %s: %s", topic, string(data))
return conn.Publish(topic, data)
nb.logger.Trace("Publishing message", "topic", topic, "data", string(data))
return nb.conn.Publish(topic, data)
}

func Close() {
conn.Close()
for _, holder := range subCh {
if ch, ok := holder.Ch.(chan interface{}); ok {
close(ch)
}
}
func (nb *NatsBus) Close() {
nb.conn.Close()
}
21 changes: 15 additions & 6 deletions internal/event/bus_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package event

import (
"encoding/json"
"fmt"
"net"
"testing"

"github.com/hashicorp/go-hclog"
natsserver "github.com/nats-io/nats-server/v2/test"
"github.com/nats-io/nats.go"
"github.com/stretchr/testify/assert"
)

Expand All @@ -31,21 +34,27 @@ func TestBus(t *testing.T) {
s := natsserver.RunServer(&options)
defer s.Shutdown()

err = Connect(fmt.Sprintf("nats://localhost:%d", port))
nb := NewNatsBus(hclog.Default())

err = nb.Connect(fmt.Sprintf("nats://localhost:%d", port))
assert.NoError(t, err)

topic := "test"
msg := Message{Text: "Hello World"}

ch, err := Subscribe[Message](topic)
assert.NoError(t, err)
assert.NotNil(t, ch)
ch := make(chan Message)

_, err = nb.conn.Subscribe(topic, func(m *nats.Msg) {
var msg Message
json.Unmarshal(m.Data, &msg)
ch <- msg
})

err = Publish(msg, topic)
err = Publish(nb, msg, topic)
assert.NoError(t, err)

received := <-ch
assert.Equal(t, msg.Text, received.Text)

Close()
nb.Close()
}

0 comments on commit 8970749

Please sign in to comment.