Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate to using golang.org/x/sys/unix over syscall #74

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ import (
"regexp"
"strconv"
"strings"
"syscall"

"github.com/spf13/viper"
"golang.org/x/sys/unix"
"gopkg.in/Graylog2/go-gelf.v2/gelf"
)

Expand Down Expand Up @@ -249,7 +249,7 @@ func handleLogRotation(config *viper.Viper, writer *AuditWriter) {
// Re-open our log file. This is triggered by a USR1 signal and is meant to be used upon log rotation

sigc := make(chan os.Signal, 1)
signal.Notify(sigc, syscall.SIGUSR1)
signal.Notify(sigc, unix.SIGUSR1)

for range sigc {
newWriter, err := createFileOutput(config)
Expand Down
10 changes: 5 additions & 5 deletions audit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ import (
"os/user"
"path"
"strconv"
"syscall"
"testing"
"time"

"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"golang.org/x/sys/unix"
"gopkg.in/Graylog2/go-gelf.v2/gelf"
)

Expand Down Expand Up @@ -413,7 +413,7 @@ func Test_createOutput(t *testing.T) {
os.Rename(path.Join(os.TempDir(), "go-audit.test.log"), path.Join(os.TempDir(), "go-audit.test.log.rotated"))
_, err = os.Stat(path.Join(os.TempDir(), "go-audit.test.log"))
assert.True(t, os.IsNotExist(err))
syscall.Kill(syscall.Getpid(), syscall.SIGUSR1)
unix.Kill(unix.Getpid(), unix.SIGUSR1)
time.Sleep(100 * time.Millisecond)
_, err = os.Stat(path.Join(os.TempDir(), "go-audit.test.log"))
assert.Nil(t, err)
Expand Down Expand Up @@ -565,15 +565,15 @@ func Benchmark_MultiPacketMessage(b *testing.B) {
for i := 0; i < b.N; i++ {
for n := 0; n < len(data); n++ {
nlen := len(data[n])
msg := &syscall.NetlinkMessage{
Header: syscall.NlMsghdr{
msg := &NetlinkMessage{
Header: NetlinkPacket{
Len: Endianness.Uint32(data[n][0:4]),
Type: Endianness.Uint16(data[n][4:6]),
Flags: Endianness.Uint16(data[n][6:8]),
Seq: Endianness.Uint32(data[n][8:12]),
Pid: Endianness.Uint32(data[n][12:16]),
},
Data: data[n][syscall.SizeofNlMsghdr:nlen],
Data: data[n][unix.SizeofNlMsghdr:nlen],
}
marshaller.Consume(msg)
}
Expand Down
45 changes: 26 additions & 19 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"sync/atomic"
"syscall"
"time"
"fmt"

"golang.org/x/sys/unix"
)

// Endianness is an alias for what we assume is the current machine endianness
Expand All @@ -33,42 +34,48 @@ type AuditStatusPayload struct {
}

// NetlinkPacket is an alias to give the header a similar name here
type NetlinkPacket syscall.NlMsghdr
type NetlinkPacket unix.NlMsghdr

// NetlinkMessage is copied from syscall.NetlinkMessage as x/sys/unix does not have it
type NetlinkMessage struct {
Header NetlinkPacket
Data []byte
}

type NetlinkClient struct {
fd int
address syscall.Sockaddr
address unix.Sockaddr
seq uint32
buf []byte
}

// NewNetlinkClient creates a new NetLinkClient and optionally tries to modify the netlink recv buffer
func NewNetlinkClient(recvSize int) (*NetlinkClient, error) {
fd, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_RAW, syscall.NETLINK_AUDIT)
fd, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_AUDIT)
if err != nil {
return nil, fmt.Errorf("Could not create a socket: %s", err)
}

n := &NetlinkClient{
fd: fd,
address: &syscall.SockaddrNetlink{Family: syscall.AF_NETLINK, Groups: 0, Pid: 0},
address: &unix.SockaddrNetlink{Family: unix.AF_NETLINK, Groups: 0, Pid: 0},
buf: make([]byte, MAX_AUDIT_MESSAGE_LENGTH),
}

if err = syscall.Bind(fd, n.address); err != nil {
syscall.Close(fd)
if err = unix.Bind(fd, n.address); err != nil {
unix.Close(fd)
return nil, fmt.Errorf("Could not bind to netlink socket: %s", err)
}

// Set the buffer size if we were asked
if recvSize > 0 {
if err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, recvSize); err != nil {
if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_RCVBUF, recvSize); err != nil {
el.Println("Failed to set receive buffer size")
}
}

// Print the current receive buffer size
if v, err := syscall.GetsockoptInt(n.fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF); err == nil {
if v, err := unix.GetsockoptInt(n.fd, unix.SOL_SOCKET, unix.SO_RCVBUF); err == nil {
l.Println("Socket receive buffer size:", v)
}

Expand Down Expand Up @@ -102,16 +109,16 @@ func (n *NetlinkClient) Send(np *NetlinkPacket, a *AuditStatusPayload) error {
}
}

if err := syscall.Sendto(n.fd, buf.Bytes(), 0, n.address); err != nil {
if err := unix.Sendto(n.fd, buf.Bytes(), 0, n.address); err != nil {
return err
}

return nil
}

// Receive will receive a packet from a netlink socket
func (n *NetlinkClient) Receive() (*syscall.NetlinkMessage, error) {
nlen, _, err := syscall.Recvfrom(n.fd, n.buf, 0)
func (n *NetlinkClient) Receive() (*NetlinkMessage, error) {
nlen, _, err := unix.Recvfrom(n.fd, n.buf, 0)
if err != nil {
return nil, err
}
Expand All @@ -120,15 +127,15 @@ func (n *NetlinkClient) Receive() (*syscall.NetlinkMessage, error) {
return nil, errors.New("Got a 0 length packet")
}

msg := &syscall.NetlinkMessage{
Header: syscall.NlMsghdr{
msg := &NetlinkMessage{
Header: NetlinkPacket{
Len: Endianness.Uint32(n.buf[0:4]),
Type: Endianness.Uint16(n.buf[4:6]),
Flags: Endianness.Uint16(n.buf[6:8]),
Seq: Endianness.Uint32(n.buf[8:12]),
Pid: Endianness.Uint32(n.buf[12:16]),
},
Data: n.buf[syscall.SizeofNlMsghdr:nlen],
Data: n.buf[unix.SizeofNlMsghdr:nlen],
}

return msg, nil
Expand All @@ -139,14 +146,14 @@ func (n *NetlinkClient) KeepConnection() {
payload := &AuditStatusPayload{
Mask: 4,
Enabled: 1,
Pid: uint32(syscall.Getpid()),
Pid: uint32(unix.Getpid()),
//TODO: Failure: http://lxr.free-electrons.com/source/include/uapi/linux/audit.h#L338
}

packet := &NetlinkPacket{
Type: uint16(1001),
Flags: syscall.NLM_F_REQUEST | syscall.NLM_F_ACK,
Pid: uint32(syscall.Getpid()),
Flags: unix.NLM_F_REQUEST | unix.NLM_F_ACK,
Pid: uint32(unix.Getpid()),
}

err := n.Send(packet, payload)
Expand Down
29 changes: 15 additions & 14 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@ package main
import (
"bytes"
"encoding/binary"
"github.com/stretchr/testify/assert"
"os"
"syscall"
"testing"

"github.com/stretchr/testify/assert"
"golang.org/x/sys/unix"
)

func TestNetlinkClient_KeepConnection(t *testing.T) {
n := makeNelinkClient(t)
defer syscall.Close(n.fd)
defer unix.Close(n.fd)

n.KeepConnection()
msg, err := n.Receive()
Expand All @@ -31,19 +32,19 @@ func TestNetlinkClient_KeepConnection(t *testing.T) {
// Make sure we get errors printed
lb, elb := hookLogger()
defer resetLogger()
syscall.Close(n.fd)
unix.Close(n.fd)
n.KeepConnection()
assert.Equal(t, "", lb.String(), "Got some log lines we did not expect")
assert.Equal(t, "Error occurred while trying to keep the connection: bad file descriptor\n", elb.String(), "Figured we would have an error")
}

func TestNetlinkClient_SendReceive(t *testing.T) {
var err error
var msg *syscall.NetlinkMessage
var msg *NetlinkMessage

// Build our client
n := makeNelinkClient(t)
defer syscall.Close(n.fd)
defer unix.Close(n.fd)

// Make sure we can encode/decode properly
payload := &AuditStatusPayload{
Expand All @@ -54,7 +55,7 @@ func TestNetlinkClient_SendReceive(t *testing.T) {

packet := &NetlinkPacket{
Type: uint16(1001),
Flags: syscall.NLM_F_REQUEST | syscall.NLM_F_ACK,
Flags: unix.NLM_F_REQUEST | unix.NLM_F_ACK,
Pid: uint32(1006),
}

Expand All @@ -72,12 +73,12 @@ func TestNetlinkClient_SendReceive(t *testing.T) {
assert.Equal(t, uint32(2), msg.Header.Seq, "Header.Seq did not increment")

// Make sure 0 length packets result in an error
syscall.Sendto(n.fd, []byte{}, 0, n.address)
unix.Sendto(n.fd, []byte{}, 0, n.address)
_, err = n.Receive()
assert.Equal(t, "Got a 0 length packet", err.Error(), "Error was incorrect")

// Make sure we get errors from sendto back
syscall.Close(n.fd)
unix.Close(n.fd)
err = n.Send(packet, payload)
assert.Equal(t, "bad file descriptor", err.Error(), "Error was incorrect")

Expand Down Expand Up @@ -110,27 +111,27 @@ func TestNewNetlinkClient(t *testing.T) {
// Helper to make a client listening on a unix socket
func makeNelinkClient(t *testing.T) *NetlinkClient {
os.Remove("go-audit.test.sock")
fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_RAW, 0)
fd, err := unix.Socket(unix.AF_UNIX, unix.SOCK_RAW, 0)
if err != nil {
t.Fatal("Could not create a socket:", err)
}

n := &NetlinkClient{
fd: fd,
address: &syscall.SockaddrUnix{Name: "go-audit.test.sock"},
address: &unix.SockaddrUnix{Name: "go-audit.test.sock"},
buf: make([]byte, MAX_AUDIT_MESSAGE_LENGTH),
}

if err = syscall.Bind(fd, n.address); err != nil {
syscall.Close(fd)
if err = unix.Bind(fd, n.address); err != nil {
unix.Close(fd)
t.Fatal("Could not bind to netlink socket:", err)
}

return n
}

// Helper to send and then receive a message with the netlink client
func sendReceive(t *testing.T, n *NetlinkClient, packet *NetlinkPacket, payload *AuditStatusPayload) *syscall.NetlinkMessage {
func sendReceive(t *testing.T, n *NetlinkClient, packet *NetlinkPacket, payload *AuditStatusPayload) *NetlinkMessage {
err := n.Send(packet, payload)
if err != nil {
t.Fatal("Failed to send:", err)
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ require (
github.com/spf13/viper v0.0.0-20170217163817-7538d73b4eb9
github.com/stretchr/testify v1.2.2
golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553 // indirect
golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect
google.golang.org/grpc v1.25.1 // indirect
gopkg.in/Graylog2/go-gelf.v2 v2.0.0-20180326133423-4dbb9d721348
Expand Down
3 changes: 1 addition & 2 deletions marshaller.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package main
import (
"os"
"regexp"
"syscall"
"time"
)

Expand Down Expand Up @@ -64,7 +63,7 @@ func NewAuditMarshaller(w *AuditWriter, eventMin uint16, eventMax uint16, trackM
}

// Ingests a netlink message and likely prepares it to be logged
func (a *AuditMarshaller) Consume(nlMsg *syscall.NetlinkMessage) {
func (a *AuditMarshaller) Consume(nlMsg *NetlinkMessage) {
aMsg := NewAuditMessage(nlMsg)

if aMsg.Seq == 0 {
Expand Down
Loading