diff --git a/internal/dhcpsvc/config_test.go b/internal/dhcpsvc/config_test.go index f446bacd..dd59c35f 100644 --- a/internal/dhcpsvc/config_test.go +++ b/internal/dhcpsvc/config_test.go @@ -11,6 +11,7 @@ import ( "github.com/AdguardTeam/golibs/timeutil" ) +// TODO(e.burkov): Split into several tests for each part of the configuration. func TestConfig_Validate(t *testing.T) { validIPv4Conf := &dhcpsvc.IPv4Config{ Enabled: true, @@ -121,7 +122,7 @@ func TestConfig_Validate(t *testing.T) { }, { conf: &dhcpsvc.Config{ Enabled: true, - Logger: discardLog, + Logger: testLogger, LocalDomainName: testLocalTLD, Interfaces: map[string]*dhcpsvc.InterfaceConfig{ "eth0": { @@ -136,7 +137,7 @@ func TestConfig_Validate(t *testing.T) { }, { conf: &dhcpsvc.Config{ Enabled: true, - Logger: discardLog, + Logger: testLogger, LocalDomainName: testLocalTLD, Interfaces: map[string]*dhcpsvc.InterfaceConfig{ "eth0": { @@ -151,7 +152,7 @@ func TestConfig_Validate(t *testing.T) { }, { conf: &dhcpsvc.Config{ Enabled: true, - Logger: discardLog, + Logger: testLogger, LocalDomainName: testLocalTLD, Interfaces: map[string]*dhcpsvc.InterfaceConfig{ "eth0": { @@ -167,7 +168,7 @@ func TestConfig_Validate(t *testing.T) { }, { conf: &dhcpsvc.Config{ Enabled: true, - Logger: discardLog, + Logger: testLogger, LocalDomainName: testLocalTLD, Interfaces: map[string]*dhcpsvc.InterfaceConfig{ "eth0": { diff --git a/internal/dhcpsvc/dhcpsvc_test.go b/internal/dhcpsvc/dhcpsvc_test.go index b768d8a9..198ca165 100644 --- a/internal/dhcpsvc/dhcpsvc_test.go +++ b/internal/dhcpsvc/dhcpsvc_test.go @@ -1,22 +1,44 @@ package dhcpsvc_test import ( + "cmp" + "io/fs" "net/netip" + "os" + "path" + "path/filepath" + "testing" "time" "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" "github.com/AdguardTeam/golibs/logutil/slogutil" + "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/timeutil" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/stretchr/testify/require" ) // testLocalTLD is a common local TLD for tests. const testLocalTLD = "local" -// testTimeout is a common timeout for tests and contexts. -const testTimeout time.Duration = 10 * time.Second +// testIfaceName is the name of the test network interface. +const testIfaceName = "iface0" -// discardLog is a logger to discard test output. -var discardLog = slogutil.NewDiscardLogger() +// testTimeout is a common timeout for tests and contexts. +const testTimeout = 10 * time.Second + +// testLeaseTTL is the lease duration used in tests. +const testLeaseTTL = 24 * time.Hour + +// testXid is a common transaction ID for DHCPv4 tests. +const testXid = 1 + +// testLogger is a common logger for tests. +var testLogger = slogutil.NewDiscardLogger() + +// testdata is a filesystem containing data for tests. +var testdata = os.DirFS("testdata") // testInterfaceConf is a common set of interface configurations for tests. var testInterfaceConf = map[string]*dhcpsvc.InterfaceConfig{ @@ -57,3 +79,64 @@ var testInterfaceConf = map[string]*dhcpsvc.InterfaceConfig{ }, }, } + +// disabledIPv6Config is a configuration of IPv6 part of the interfaces +// configuration that is disabled. +var disabledIPv6Config = &dhcpsvc.IPv6Config{Enabled: false} + +// fullLayersStack is the complete stack of layers expected to appear in the +// DHCP response packets. +var fullLayersStack = []gopacket.LayerType{ + layers.LayerTypeEthernet, + layers.LayerTypeIPv4, + layers.LayerTypeUDP, + layers.LayerTypeDHCPv4, +} + +// newTempDB copies the leases database file located in the testdata FS, under +// tb.Name()/leases.json, to a temporary directory and returns the path to the +// copied file. +func newTempDB(tb testing.TB) (dst string) { + tb.Helper() + + const filename = "leases.json" + + data, err := fs.ReadFile(testdata, path.Join(tb.Name(), filename)) + require.NoError(tb, err) + + dst = filepath.Join(tb.TempDir(), filename) + + err = os.WriteFile(dst, data, dhcpsvc.DatabasePerm) + require.NoError(tb, err) + + return dst +} + +// newTestDHCPServer creates a new DHCPServer for testing. It uses the default +// values of config in case it's nil or some of its fields aren't set. +func newTestDHCPServer(tb testing.TB, conf *dhcpsvc.Config) (srv *dhcpsvc.DHCPServer) { + tb.Helper() + + conf = cmp.Or(conf, &dhcpsvc.Config{ + Enabled: true, + }) + + conf.NetworkDeviceManager = cmp.Or[dhcpsvc.NetworkDeviceManager]( + conf.NetworkDeviceManager, + dhcpsvc.EmptyNetworkDeviceManager{}, + ) + conf.Logger = cmp.Or(conf.Logger, testLogger) + conf.LocalDomainName = cmp.Or(conf.LocalDomainName, testLocalTLD) + if conf.DBFilePath == "" { + conf.DBFilePath = filepath.Join(tb.TempDir(), "leases.json") + } + conf.ICMPTimeout = cmp.Or(conf.ICMPTimeout, testTimeout) + if conf.Interfaces == nil { + conf.Interfaces = testInterfaceConf + } + + srv, err := dhcpsvc.New(testutil.ContextWithTimeout(tb, testTimeout), conf) + require.NoError(tb, err) + + return srv +} diff --git a/internal/dhcpsvc/handler4.go b/internal/dhcpsvc/handler4.go index e6feb3b3..90ceac41 100644 --- a/internal/dhcpsvc/handler4.go +++ b/internal/dhcpsvc/handler4.go @@ -159,15 +159,15 @@ func (iface *dhcpInterfaceV4) handleDiscover( l.DebugContext(ctx, "different requested ip", "requested", reqIP, "lease", lease.IP) } + lease.updateExpiry(iface.clock, iface.common.leaseTTL) iface.respondOffer(ctx, req, fd, lease) return } - // TODO(e.burkov): Allocate a new lease. lease, err := iface.allocateLease(ctx, mac) if err != nil { - l.ErrorContext(ctx, "allocating a lease", "error", err) + l.ErrorContext(ctx, "allocating a lease", slogutil.KeyError, err) return } @@ -222,7 +222,15 @@ func (iface *dhcpInterfaceV4) handleSelecting( } // Commit the lease and send ACK. - iface.commitLease(ctx, lease, hostname4(req)) + lease.Hostname = hostname4(req) + err := iface.commitLease(ctx, lease) + if err != nil { + l.ErrorContext(ctx, "selecting request failed", slogutil.KeyError, err) + iface.respondNAK(ctx, req, fd) + + return + } + iface.respondACK(ctx, req, fd, lease) } @@ -276,7 +284,15 @@ func (iface *dhcpInterfaceV4) handleInitReboot( } // Commit the lease and send ACK. - iface.commitLease(ctx, lease, hostname4(req)) + lease.Hostname = hostname4(req) + err := iface.commitLease(ctx, lease) + if err != nil { + l.ErrorContext(ctx, "init-reboot request failed", slogutil.KeyError, err) + iface.respondNAK(ctx, req, fd) + + return + } + iface.respondACK(ctx, req, fd, lease) } @@ -316,7 +332,15 @@ func (iface *dhcpInterfaceV4) handleRenew( } // Commit the lease and send ACK. - iface.commitLease(ctx, lease, hostname4(req)) + lease.Hostname = hostname4(req) + err := iface.commitLease(ctx, lease) + if err != nil { + l.ErrorContext(ctx, "renew request failed", slogutil.KeyError, err) + iface.respondNAK(ctx, req, fd) + + return + } + iface.respondACK(ctx, req, fd, lease) } diff --git a/internal/dhcpsvc/handler4_test.go b/internal/dhcpsvc/handler4_test.go new file mode 100644 index 00000000..c92c6868 --- /dev/null +++ b/internal/dhcpsvc/handler4_test.go @@ -0,0 +1,292 @@ +package dhcpsvc_test + +import ( + "net" + "net/netip" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" + "github.com/AdguardTeam/golibs/testutil" + "github.com/AdguardTeam/golibs/testutil/faketime" + "github.com/AdguardTeam/golibs/testutil/servicetest" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// testCurrentTime is the fixed time returned by [testClock] to ensure +// reproducible tests. +var testCurrentTime = time.Date(2025, 1, 1, 1, 1, 1, 0, time.UTC) + +// testClock is the test [timeutil.Clock] that always returns [testCurrentTime]. +var testClock = &faketime.Clock{ + OnNow: func() (now time.Time) { + return testCurrentTime + }, +} + +func TestDHCPServer_ServeEther4_discover(t *testing.T) { + t.Parallel() + + // NOTE: Keep in sync with testdata. + const ( + // leaseHostnameStatic is the hostname for the static lease. + leaseHostnameStatic = "static4" + + // leaseHostnameDynamic is the hostname for the dynamic lease. + leaseHostnameDynamic = "dynamic4" + + // leaseHostnameExpired is the hostname for the expired lease. + leaseHostnameExpired = "expired4" + ) + + // NOTE: Keep in sync with testdata. + var ( + // hwAddrUnknown is the MAC address for an unknown client. + hwAddrUnknown = net.HardwareAddr{0x0, 0x1, 0x2, 0x3, 0x4, 0x5} + + // hwAddrStatic is the MAC address for a known static lease. + hwAddrStatic = net.HardwareAddr{0x1, 0x2, 0x3, 0x4, 0x5, 0x6} + + // hwAddrDynamic is the MAC address for a known dynamic lease. + hwAddrDynamic = net.HardwareAddr{0x2, 0x3, 0x4, 0x5, 0x6, 0x7} + + // hwAddrExpired is the MAC address for a known expired lease. + hwAddrExpired = net.HardwareAddr{0x3, 0x4, 0x5, 0x6, 0x7, 0x8} + ) + + // NOTE: Keep in sync with testdata. + dynamicLeaseExpiry := time.Date(2025, 1, 1, 10, 1, 1, 0, time.UTC) + dynamicLeaseTTL := dynamicLeaseExpiry.Sub(testCurrentTime) + + ipv4Conf := &dhcpsvc.IPv4Config{ + Clock: testClock, + SubnetMask: netip.MustParseAddr("255.255.255.0"), + GatewayIP: netip.MustParseAddr("192.168.0.1"), + RangeStart: netip.MustParseAddr("192.168.0.100"), + RangeEnd: netip.MustParseAddr("192.168.0.200"), + LeaseDuration: testLeaseTTL, + Enabled: true, + } + ifacesConfig := map[string]*dhcpsvc.InterfaceConfig{ + testIfaceName: {IPv4: ipv4Conf, IPv6: disabledIPv6Config}, + } + + testCases := []struct { + name string + in gopacket.Packet + wantOpts layers.DHCPOptions + }{{ + name: "new", + in: newDHCPDISCOVER(t, hwAddrUnknown), + wantOpts: layers.DHCPOptions{ + newOptMessageType(t, layers.DHCPMsgTypeOffer), + newOptServerID(t, ipv4Conf.GatewayIP), + newOptLeaseTime(t, testLeaseTTL), + }, + }, { + name: "existing_static", + in: newDHCPDISCOVER(t, hwAddrStatic), + wantOpts: layers.DHCPOptions{ + newOptMessageType(t, layers.DHCPMsgTypeOffer), + newOptServerID(t, ipv4Conf.GatewayIP), + newOptLeaseTime(t, testLeaseTTL), + newOptHostname(t, leaseHostnameStatic), + }, + }, { + name: "existing_dynamic", + in: newDHCPDISCOVER(t, hwAddrDynamic), + wantOpts: layers.DHCPOptions{ + newOptMessageType(t, layers.DHCPMsgTypeOffer), + newOptServerID(t, ipv4Conf.GatewayIP), + newOptLeaseTime(t, dynamicLeaseTTL), + newOptHostname(t, leaseHostnameDynamic), + }, + }, { + name: "existing_dynamic_expired", + in: newDHCPDISCOVER(t, hwAddrExpired), + wantOpts: layers.DHCPOptions{ + newOptMessageType(t, layers.DHCPMsgTypeOffer), + newOptServerID(t, ipv4Conf.GatewayIP), + newOptLeaseTime(t, testLeaseTTL), + newOptHostname(t, leaseHostnameExpired), + }, + }} + + for _, tc := range testCases { + req := testutil.RequireTypeAssert[*layers.DHCPv4](t, tc.in.Layer(layers.LayerTypeDHCPv4)) + + ndMgr, inCh, outCh := newTestNetworkDeviceManager(t, testIfaceName) + srv := newTestDHCPServer(t, &dhcpsvc.Config{ + Interfaces: ifacesConfig, + NetworkDeviceManager: ndMgr, + DBFilePath: newTempDB(t), + Enabled: true, + }) + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + servicetest.RequireRun(t, srv, testTimeout) + + testutil.RequireSend(t, inCh, tc.in, testTimeout) + + respData, ok := testutil.RequireReceive(t, outCh, testTimeout) + require.True(t, ok) + + assertValidOffer(t, req, respData, tc.wantOpts) + }) + } +} + +func TestDHCPServer_ServeEther4_discoverExpired(t *testing.T) { + t.Parallel() + + // hwAddrUnknown is the MAC address for an unknown client, not related to + // any existing lease. + // + // NOTE: Keep in sync with testdata. + hwAddrUnknown := net.HardwareAddr{0x0, 0x1, 0x2, 0x3, 0x4, 0x5} + + pkt := newDHCPDISCOVER(t, hwAddrUnknown) + req := testutil.RequireTypeAssert[*layers.DHCPv4](t, pkt.Layer(layers.LayerTypeDHCPv4)) + + ndMgr, inCh, outCh := newTestNetworkDeviceManager(t, testIfaceName) + + ipv4Conf := &dhcpsvc.IPv4Config{ + Clock: testClock, + SubnetMask: netip.MustParseAddr("255.255.255.0"), + GatewayIP: netip.MustParseAddr("192.168.0.1"), + RangeStart: netip.MustParseAddr("192.168.0.100"), + RangeEnd: netip.MustParseAddr("192.168.0.100"), + LeaseDuration: testLeaseTTL, + Enabled: true, + } + srv := newTestDHCPServer(t, &dhcpsvc.Config{ + Interfaces: map[string]*dhcpsvc.InterfaceConfig{ + testIfaceName: {IPv4: ipv4Conf, IPv6: disabledIPv6Config}, + }, + NetworkDeviceManager: ndMgr, + DBFilePath: newTempDB(t), + Enabled: true, + }) + servicetest.RequireRun(t, srv, testTimeout) + + testutil.RequireSend(t, inCh, pkt, testTimeout) + + respData, ok := testutil.RequireReceive(t, outCh, testTimeout) + require.True(t, ok) + + assertValidOffer(t, req, respData, layers.DHCPOptions{ + newOptMessageType(t, layers.DHCPMsgTypeOffer), + newOptServerID(t, ipv4Conf.GatewayIP), + newOptLeaseTime(t, testLeaseTTL), + }) +} + +// TODO(e.burkov): Add tests for DHCPREQUEST, DHCPRELEASE, DHCPDECLINE. + +// TODO(e.burkov): Add tests for wrong packets. + +// newDHCPDISCOVER creates a new DHCPDISCOVER packet for testing. +// +// TODO(e.burkov): Add parameters. +func newDHCPDISCOVER(tb testing.TB, clientHWAddr net.HardwareAddr) (pkt gopacket.Packet) { + tb.Helper() + + eth := &layers.Ethernet{ + SrcMAC: clientHWAddr, + DstMAC: net.HardwareAddr{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + EthernetType: layers.EthernetTypeIPv4, + } + ip := &layers.IPv4{ + Version: 4, + TTL: dhcpsvc.IPv4DefaultTTL, + SrcIP: net.IPv4zero.To4(), + DstIP: net.IPv4bcast.To4(), + Protocol: layers.IPProtocolUDP, + } + udp := &layers.UDP{ + SrcPort: dhcpsvc.ClientPortV4, + DstPort: dhcpsvc.ServerPortV4, + } + _ = udp.SetNetworkLayerForChecksum(ip) + + dhcp := &layers.DHCPv4{ + Operation: layers.DHCPOpRequest, + HardwareType: layers.LinkTypeEthernet, + HardwareLen: dhcpsvc.EUI48AddrLen, + Xid: testXid, + ClientHWAddr: clientHWAddr, + Options: layers.DHCPOptions{ + layers.NewDHCPOption( + layers.DHCPOptMessageType, + []byte{byte(layers.DHCPMsgTypeDiscover)}, + ), + }, + } + + return newTestPacket(tb, layers.LinkTypeEthernet, eth, ip, udp, dhcp) +} + +// newTestPacket creates a valid packet from ls using first as first layer +// decoder. +func newTestPacket( + tb testing.TB, + first gopacket.Decoder, + ls ...gopacket.SerializableLayer, +) (pkg gopacket.Packet) { + tb.Helper() + + buf := gopacket.NewSerializeBuffer() + + opts := gopacket.SerializeOptions{ + FixLengths: true, + ComputeChecksums: true, + } + err := gopacket.SerializeLayers(buf, opts, ls...) + require.NoError(tb, err) + + return gopacket.NewPacket(buf.Bytes(), first, gopacket.Default) +} + +// requireEthernet requires data to contain an Ethernet layer and all layers +// from ls. First of ls must be of type [layers.LayerTypeEthernet]. +func requireEthernet( + tb testing.TB, + data []byte, + ls ...gopacket.DecodingLayer, +) (types []gopacket.LayerType) { + tb.Helper() + + parser := gopacket.NewDecodingLayerParser(layers.LayerTypeEthernet, ls...) + + err := parser.DecodeLayers(data, &types) + require.NoError(tb, err) + + return types +} + +// assertValidOffer asserts that respData contains a complete DHCPOFFER response +// with the expected options, wrapped with all layers down to Ethernet. +func assertValidOffer( + tb testing.TB, + discover *layers.DHCPv4, + respData []byte, + wantOpts layers.DHCPOptions, +) { + tb.Helper() + + resp := &layers.DHCPv4{} + types := requireEthernet(tb, respData, &layers.Ethernet{}, &layers.IPv4{}, &layers.UDP{}, resp) + require.Equal(tb, fullLayersStack, types) + + assert.Equal(tb, layers.DHCPOpReply, resp.Operation, "operation") + assert.Equal(tb, discover.HardwareType, resp.HardwareType, "hardware type") + assert.Equal(tb, discover.HardwareLen, resp.HardwareLen, "hardware length") + assert.Equal(tb, discover.Xid, resp.Xid, "xid") + assert.Equal(tb, discover.ClientHWAddr, resp.ClientHWAddr, "client hardware address") + assert.Equal(tb, wantOpts, resp.Options, "options") +} diff --git a/internal/dhcpsvc/lease.go b/internal/dhcpsvc/lease.go index e27b064c..f0e97758 100644 --- a/internal/dhcpsvc/lease.go +++ b/internal/dhcpsvc/lease.go @@ -6,6 +6,8 @@ import ( "net/netip" "slices" "time" + + "github.com/AdguardTeam/golibs/timeutil" ) // Lease is a DHCP lease. @@ -47,13 +49,28 @@ func (l *Lease) Clone() (clone *Lease) { } } -// eui48AddrLen is the length of a valid EUI-48 hardware address. -const eui48AddrLen = 6 +// EUI48AddrLen is the length of a valid EUI-48 hardware address. +const EUI48AddrLen = 6 // blockedHardwareAddr is the hardware address used to mark a lease as blocked. -var blockedHardwareAddr = make(net.HardwareAddr, eui48AddrLen) +var blockedHardwareAddr = make(net.HardwareAddr, EUI48AddrLen) // IsBlocked returns true if the lease is blocked. func (l *Lease) IsBlocked() (blocked bool) { return bytes.Equal(l.HWAddr, blockedHardwareAddr) } + +// updateExpiry updates the lease expiry time if the current time is past the +// expiry. For static leases, this operation is a no-op. +func (l *Lease) updateExpiry(clock timeutil.Clock, ttl time.Duration) { + if l.IsStatic { + return + } + + now := clock.Now() + if now.Before(l.Expiry) { + return + } + + l.Expiry = now.Add(ttl) +} diff --git a/internal/dhcpsvc/networkdevice.go b/internal/dhcpsvc/networkdevice.go index f5a1b24d..a9008792 100644 --- a/internal/dhcpsvc/networkdevice.go +++ b/internal/dhcpsvc/networkdevice.go @@ -34,7 +34,27 @@ type NetworkDeviceManager interface { Open(ctx context.Context, conf *NetworkDeviceConfig) (dev NetworkDevice, err error) } -// NetworkDevice provides reading and writing packets to a network interface. +// EmptyNetworkDeviceManager is an empty implementation of +// [NetworkDeviceManager]. +type EmptyNetworkDeviceManager struct{} + +// type check +var _ NetworkDeviceManager = EmptyNetworkDeviceManager{} + +// Open implements the [NetworkDeviceManager] interface for +// [EmptyNetworkDeviceManager]. It always returns [EmptyNetworkDevice]. +func (EmptyNetworkDeviceManager) Open( + _ context.Context, + _ *NetworkDeviceConfig, +) (nd NetworkDevice, err error) { + return nil, nil +} + +// NetworkDevice provides an ability of reading and writing packets to a network +// interface. It used to generalize implementations for different platforms and +// to simplify testing. +// +// It's based on [pcap.Handle]. type NetworkDevice interface { gopacket.PacketDataSource @@ -45,6 +65,31 @@ type NetworkDevice interface { WritePacketData(data []byte) (err error) } +// EmptyNetworkDevice is an empty implementation of NetworkDevice. +type EmptyNetworkDevice struct{} + +// type check +var _ NetworkDevice = EmptyNetworkDevice{} + +// ReadPacketData implements the [gopacket.PacketDataSource] interface for +// [EmptyNetworkDevice]. It always returns no data, empty capture info and a +// nil error. +func (EmptyNetworkDevice) ReadPacketData() (data []byte, ci gopacket.CaptureInfo, err error) { + return nil, gopacket.CaptureInfo{}, nil +} + +// LinkType implements the [NetworkDevice] interface for [EmptyNetworkDevice]. +// It always returns [layers.LinkTypeNull]. +func (EmptyNetworkDevice) LinkType() (lt layers.LinkType) { + return layers.LinkTypeNull +} + +// WritePacketData implements the [NetworkDevice] interface for +// [EmptyNetworkDevice]. It always returns nil. +func (EmptyNetworkDevice) WritePacketData(_ []byte) (err error) { + return nil +} + // frameData stores the Ethernet and IPv4 layers of the incoming packet, and // the network device that the packet was received from. type frameData struct { diff --git a/internal/dhcpsvc/networkdevice_test.go b/internal/dhcpsvc/networkdevice_test.go new file mode 100644 index 00000000..a5e427d1 --- /dev/null +++ b/internal/dhcpsvc/networkdevice_test.go @@ -0,0 +1,118 @@ +package dhcpsvc_test + +import ( + "context" + "testing" + + "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" + "github.com/AdguardTeam/golibs/testutil" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/stretchr/testify/require" +) + +// testNetworkDeviceManager is a mock implementation of the +// [dhcpsvc.NetworkDeviceManager] interface. +// +// TODO(e.burkov): Move to aghtest. +type testNetworkDeviceManager struct { + onOpen func( + ctx context.Context, + conf *dhcpsvc.NetworkDeviceConfig, + ) (nd dhcpsvc.NetworkDevice, err error) +} + +// type check +var _ dhcpsvc.NetworkDeviceManager = (*testNetworkDeviceManager)(nil) + +// Open implements the [dhcpsvc.NetworkDeviceManager] interface for +// *testNetworkDeviceManager. +func (ndm *testNetworkDeviceManager) Open( + ctx context.Context, + conf *dhcpsvc.NetworkDeviceConfig, +) (dev dhcpsvc.NetworkDevice, err error) { + return ndm.onOpen(ctx, conf) +} + +// testNetworkDevice is a mock implementation of the [dhcpsvc.NetworkDevice] +// interface. +// +// TODO(e.burkov): Move to aghtest. +type testNetworkDevice struct { + onReadPacketData func() (data []byte, ci gopacket.CaptureInfo, err error) + onLinkType func() (lt layers.LinkType) + onWritePacketData func(data []byte) (err error) +} + +// type check +var _ dhcpsvc.NetworkDevice = (*testNetworkDevice)(nil) + +// ReadPacketData implements the [dhcpsvc.NetworkDevice] interface for +// *testNetworkDevice. +func (nd *testNetworkDevice) ReadPacketData() (data []byte, ci gopacket.CaptureInfo, err error) { + return nd.onReadPacketData() +} + +// WritePacketData implements the [dhcpsvc.NetworkDevice] interface for +// *testNetworkDevice. +func (nd *testNetworkDevice) WritePacketData(data []byte) (err error) { + return nd.onWritePacketData(data) +} + +// LinkType implements the [dhcpsvc.NetworkDevice] interface for +// *testNetworkDevice. +func (nd *testNetworkDevice) LinkType() (lt layers.LinkType) { + return nd.onLinkType() +} + +// newTestNetworkDeviceManager creates a network device manager for testing. It +// requires that device opened have a deviceName. The device itself has a link +// type [layers.LinkTypeEthernet]. Incoming packets are received from inCh and +// outgoing packets are sent to outCh. +func newTestNetworkDeviceManager( + tb testing.TB, + deviceName string, +) (ndMgr dhcpsvc.NetworkDeviceManager, inCh chan gopacket.Packet, outCh chan []byte) { + tb.Helper() + + inCh = make(chan gopacket.Packet) + outCh = make(chan []byte) + + pt := testutil.PanicT{} + + dev := &testNetworkDevice{ + onReadPacketData: func() (data []byte, ci gopacket.CaptureInfo, err error) { + pkt, ok := testutil.RequireReceive(pt, inCh, testTimeout) + require.True(pt, ok) + + data = pkt.Data() + ci = gopacket.CaptureInfo{ + Length: len(data), + CaptureLength: len(data), + } + + return data, ci, nil + }, + onLinkType: func() (lt layers.LinkType) { + return layers.LinkTypeEthernet + }, + onWritePacketData: func(data []byte) (err error) { + testutil.RequireSend(pt, outCh, data, testTimeout) + + return nil + }, + } + + ndMgr = &testNetworkDeviceManager{ + onOpen: func( + _ context.Context, + conf *dhcpsvc.NetworkDeviceConfig, + ) (nd dhcpsvc.NetworkDevice, err error) { + require.Equal(pt, deviceName, conf.Name) + + return dev, nil + }, + } + + return ndMgr, inCh, outCh +} diff --git a/internal/dhcpsvc/options4.go b/internal/dhcpsvc/options4.go index 6f3f07ad..e0f7fcc6 100644 --- a/internal/dhcpsvc/options4.go +++ b/internal/dhcpsvc/options4.go @@ -288,9 +288,17 @@ func (iface *dhcpInterfaceV4) updateOptions(req, resp *layers.DHCPv4) { } } -// appendLeaseTime appends the lease time option to the response. -func appendLeaseTime(resp *layers.DHCPv4, leaseTime time.Duration) { - leaseTimeData := binary.BigEndian.AppendUint32(nil, uint32(leaseTime.Seconds())) +// appendLeaseTime appends the lease time option to the response. lease must +// not be nil. +func (iface *dhcpInterfaceV4) appendLeaseTime(resp *layers.DHCPv4, lease *Lease) { + var dur time.Duration + if lease.IsStatic { + dur = iface.common.leaseTTL + } else { + dur = lease.Expiry.Sub(iface.clock.Now()) + } + + leaseTimeData := binary.BigEndian.AppendUint32(nil, uint32(dur.Seconds())) resp.Options = append( resp.Options, diff --git a/internal/dhcpsvc/options4_test.go b/internal/dhcpsvc/options4_test.go new file mode 100644 index 00000000..b5c73ebe --- /dev/null +++ b/internal/dhcpsvc/options4_test.go @@ -0,0 +1,42 @@ +package dhcpsvc_test + +import ( + "encoding/binary" + "net/netip" + "testing" + "time" + + "github.com/google/gopacket/layers" +) + +// newOptHostname creates a DHCP hostname (12) option. +func newOptHostname(tb testing.TB, hostname string) (opt layers.DHCPOption) { + tb.Helper() + + return layers.NewDHCPOption(layers.DHCPOptHostname, []byte(hostname)) +} + +// newOptLeaseTime creates a DHCP lease time (51) option. +func newOptLeaseTime(tb testing.TB, dur time.Duration) (opt layers.DHCPOption) { + tb.Helper() + + secs := uint32(dur.Seconds()) + var buf [4]byte + binary.BigEndian.PutUint32(buf[:], secs) + + return layers.NewDHCPOption(layers.DHCPOptLeaseTime, buf[:]) +} + +// newOptMessageType creates a DHCP message type (53) option. +func newOptMessageType(tb testing.TB, msgType layers.DHCPMsgType) (opt layers.DHCPOption) { + tb.Helper() + + return layers.NewDHCPOption(layers.DHCPOptMessageType, []byte{byte(msgType)}) +} + +// newOptServerID creates a DHCP server identifier (54) option. +func newOptServerID(tb testing.TB, serverIP netip.Addr) (opt layers.DHCPOption) { + tb.Helper() + + return layers.NewDHCPOption(layers.DHCPOptServerID, serverIP.AsSlice()) +} diff --git a/internal/dhcpsvc/server_test.go b/internal/dhcpsvc/server_test.go index 598f2d00..0d85e481 100644 --- a/internal/dhcpsvc/server_test.go +++ b/internal/dhcpsvc/server_test.go @@ -1,11 +1,8 @@ package dhcpsvc_test import ( - "io/fs" "net" "net/netip" - "os" - "path" "path/filepath" "strings" "testing" @@ -18,40 +15,12 @@ import ( "github.com/stretchr/testify/require" ) -// testdata is a filesystem containing data for tests. -var testdata = os.DirFS("testdata") - -// newTempDB copies the leases database file located in the testdata FS, under -// tb.Name()/leases.json, to a temporary directory and returns the path to the -// copied file. -func newTempDB(tb testing.TB) (dst string) { - tb.Helper() - - const filename = "leases.json" - - data, err := fs.ReadFile(testdata, path.Join(tb.Name(), filename)) - require.NoError(tb, err) - - dst = filepath.Join(tb.TempDir(), filename) - - err = os.WriteFile(dst, data, dhcpsvc.DatabasePerm) - require.NoError(tb, err) - - return dst -} - func TestDHCPServer_AddLease(t *testing.T) { - ctx := testutil.ContextWithTimeout(t, testTimeout) - leasesPath := filepath.Join(t.TempDir(), "leases.json") - srv, err := dhcpsvc.New(ctx, &dhcpsvc.Config{ - Enabled: true, - Logger: discardLog, - LocalDomainName: testLocalTLD, - Interfaces: testInterfaceConf, - DBFilePath: leasesPath, + srv := newTestDHCPServer(t, &dhcpsvc.Config{ + DBFilePath: leasesPath, + Enabled: true, }) - require.NoError(t, err) const ( existHost = "host1" @@ -69,6 +38,7 @@ func TestDHCPServer_AddLease(t *testing.T) { ipv6MAC = errors.Must(net.ParseMAC("02:03:04:05:06:07")) ) + ctx := testutil.ContextWithTimeout(t, testTimeout) require.NoError(t, srv.AddLease(ctx, &dhcpsvc.Lease{ Hostname: existHost, IP: existIP, @@ -144,6 +114,7 @@ func TestDHCPServer_AddLease(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + ctx = testutil.ContextWithTimeout(t, testTimeout) testutil.AssertErrorMsg(t, tc.wantErrMsg, srv.AddLease(ctx, tc.lease)) }) } @@ -153,17 +124,11 @@ func TestDHCPServer_AddLease(t *testing.T) { } func TestDHCPServer_index(t *testing.T) { - ctx := testutil.ContextWithTimeout(t, testTimeout) - leasesPath := newTempDB(t) - srv, err := dhcpsvc.New(ctx, &dhcpsvc.Config{ - Enabled: true, - Logger: discardLog, - LocalDomainName: testLocalTLD, - Interfaces: testInterfaceConf, - DBFilePath: leasesPath, + srv := newTestDHCPServer(t, &dhcpsvc.Config{ + DBFilePath: leasesPath, + Enabled: true, }) - require.NoError(t, err) const ( host1 = "host1" @@ -210,17 +175,11 @@ func TestDHCPServer_index(t *testing.T) { } func TestDHCPServer_UpdateStaticLease(t *testing.T) { - ctx := testutil.ContextWithTimeout(t, testTimeout) - leasesPath := newTempDB(t) - srv, err := dhcpsvc.New(ctx, &dhcpsvc.Config{ - Enabled: true, - Logger: discardLog, - LocalDomainName: testLocalTLD, - Interfaces: testInterfaceConf, - DBFilePath: leasesPath, + srv := newTestDHCPServer(t, &dhcpsvc.Config{ + DBFilePath: leasesPath, + Enabled: true, }) - require.NoError(t, err) const ( host1 = "host1" @@ -309,6 +268,7 @@ func TestDHCPServer_UpdateStaticLease(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + ctx := testutil.ContextWithTimeout(t, testTimeout) testutil.AssertErrorMsg(t, tc.wantErrMsg, srv.UpdateStaticLease(ctx, tc.lease)) }) } @@ -317,17 +277,11 @@ func TestDHCPServer_UpdateStaticLease(t *testing.T) { } func TestDHCPServer_RemoveLease(t *testing.T) { - ctx := testutil.ContextWithTimeout(t, testTimeout) - leasesPath := newTempDB(t) - srv, err := dhcpsvc.New(ctx, &dhcpsvc.Config{ - Enabled: true, - Logger: discardLog, - LocalDomainName: testLocalTLD, - Interfaces: testInterfaceConf, - DBFilePath: leasesPath, + srv := newTestDHCPServer(t, &dhcpsvc.Config{ + DBFilePath: leasesPath, + Enabled: true, }) - require.NoError(t, err) const ( host1 = "host1" @@ -393,6 +347,7 @@ func TestDHCPServer_RemoveLease(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + ctx := testutil.ContextWithTimeout(t, testTimeout) testutil.AssertErrorMsg(t, tc.wantErrMsg, srv.RemoveLease(ctx, tc.lease)) }) } @@ -403,22 +358,16 @@ func TestDHCPServer_RemoveLease(t *testing.T) { func TestDHCPServer_Reset(t *testing.T) { leasesPath := newTempDB(t) - conf := &dhcpsvc.Config{ - Enabled: true, - Logger: discardLog, - LocalDomainName: testLocalTLD, - Interfaces: testInterfaceConf, - DBFilePath: leasesPath, - } - - ctx := testutil.ContextWithTimeout(t, testTimeout) - srv, err := dhcpsvc.New(ctx, conf) - require.NoError(t, err) + srv := newTestDHCPServer(t, &dhcpsvc.Config{ + DBFilePath: leasesPath, + Enabled: true, + }) const leasesNum = 4 require.Len(t, srv.Leases(), leasesNum) + ctx := testutil.ContextWithTimeout(t, testTimeout) require.NoError(t, srv.Reset(ctx)) assert.FileExists(t, leasesPath) @@ -427,18 +376,10 @@ func TestDHCPServer_Reset(t *testing.T) { func TestServer_Leases(t *testing.T) { leasesPath := newTempDB(t) - conf := &dhcpsvc.Config{ - Enabled: true, - Logger: discardLog, - LocalDomainName: testLocalTLD, - Interfaces: testInterfaceConf, - DBFilePath: leasesPath, - } - - ctx := testutil.ContextWithTimeout(t, testTimeout) - - srv, err := dhcpsvc.New(ctx, conf) - require.NoError(t, err) + srv := newTestDHCPServer(t, &dhcpsvc.Config{ + DBFilePath: leasesPath, + Enabled: true, + }) expiry, err := time.Parse(time.RFC3339, "2042-01-02T03:04:05Z") require.NoError(t, err) diff --git a/internal/dhcpsvc/testdata/TestDHCPServer_ServeEther4_discover/leases.json b/internal/dhcpsvc/testdata/TestDHCPServer_ServeEther4_discover/leases.json new file mode 100644 index 00000000..5cb6ca54 --- /dev/null +++ b/internal/dhcpsvc/testdata/TestDHCPServer_ServeEther4_discover/leases.json @@ -0,0 +1,26 @@ +{ + "leases": [ + { + "expires": "2025-01-01T10:01:01Z", + "ip": "192.168.0.102", + "hostname": "dynamic4", + "mac": "02:03:04:05:06:07", + "static": false + }, + { + "expires": "2025-01-01T01:01:01Z", + "ip": "192.168.0.103", + "hostname": "expired4", + "mac": "03:04:05:06:07:08", + "static": false + }, + { + "expires": "", + "ip": "192.168.0.101", + "hostname": "static4", + "mac": "01:02:03:04:05:06", + "static": true + } + ], + "version": 1 +} diff --git a/internal/dhcpsvc/testdata/TestDHCPServer_ServeEther4_discoverExpired/leases.json b/internal/dhcpsvc/testdata/TestDHCPServer_ServeEther4_discoverExpired/leases.json new file mode 100644 index 00000000..3fee9b4e --- /dev/null +++ b/internal/dhcpsvc/testdata/TestDHCPServer_ServeEther4_discoverExpired/leases.json @@ -0,0 +1,12 @@ +{ + "leases": [ + { + "expires": "2025-01-01T01:01:00Z", + "ip": "192.168.0.100", + "hostname": "dynamic4", + "mac": "02:03:04:05:06:07", + "static": false + } + ], + "version": 1 +} diff --git a/internal/dhcpsvc/v4.go b/internal/dhcpsvc/v4.go index 37bb3637..bf763e5e 100644 --- a/internal/dhcpsvc/v4.go +++ b/internal/dhcpsvc/v4.go @@ -7,6 +7,7 @@ import ( "net" "net/netip" "slices" + "strings" "time" "github.com/AdguardTeam/golibs/errors" @@ -197,10 +198,13 @@ func (srv *DHCPServer) newDHCPInterfaceV4( return iface } -// commitLease updates the lease in database, using new hostname if it's valid. -// -// TODO(e.burkov): Implement. -func (iface *dhcpInterfaceV4) commitLease(ctx context.Context, l *Lease, hostname string) {} +// commitLease writes l into database. l must be valid and not expired. +func (iface *dhcpInterfaceV4) commitLease( + ctx context.Context, + l *Lease, +) (err error) { + return iface.common.index.add(ctx, iface.common.logger, l, iface.common) +} // respondOffer sends a DHCPOFFER message to the client. req, fd, and l must // not be nil. @@ -211,7 +215,9 @@ func (iface *dhcpInterfaceV4) respondOffer( l *Lease, ) { resp := iface.buildResponse(req, l, layers.DHCPMsgTypeOffer) - if err := respond4(fd, resp); err != nil { + + err := respond4(fd, resp) + if err != nil { iface.common.logger.ErrorContext(ctx, "writing offer", "error", err) } } @@ -288,10 +294,11 @@ func (iface *dhcpInterfaceV4) buildResponse( resp.Options = append( resp.Options, layers.NewDHCPOption(layers.DHCPOptMessageType, []byte{byte(msgType)}), + // TODO(e.burkov): Use network device address. layers.NewDHCPOption(layers.DHCPOptServerID, iface.gateway.AsSlice()), ) - appendLeaseTime(resp, iface.common.leaseTTL) + iface.appendLeaseTime(resp, l) iface.updateOptions(req, resp) // Add hostname option if the lease has a hostname. @@ -334,7 +341,7 @@ func (iface *dhcpInterfaceV4) allocateLease( for { l, err = iface.reserveLease(ctx, mac) if err != nil { - return nil, fmt.Errorf("reserving a lease: %w", err) + return nil, err } var ok bool @@ -353,65 +360,103 @@ func (iface *dhcpInterfaceV4) allocateLease( // reserveLease reserves a lease for a client by its MAC-address. l is nil if a // new lease can't be allocated. mac must be a valid according to -// [netutil.ValidateMAC]. +// [netutil.ValidateMAC]. index mutex must be locked. func (iface *dhcpInterfaceV4) reserveLease( ctx context.Context, mac net.HardwareAddr, ) (l *Lease, err error) { - iface.common.indexMu.Lock() - defer iface.common.indexMu.Unlock() - nextIP := iface.common.nextIP() - if nextIP == (netip.Addr{}) { - l = iface.common.findExpiredLease(iface.clock.Now()) - if l == nil { - return nil, nil + if nextIP != (netip.Addr{}) { + l = &Lease{ + HWAddr: slices.Clone(mac), + IP: nextIP, + Expiry: iface.clock.Now().Add(iface.common.leaseTTL), } - // TODO(e.burkov): Move validation from index methods into server's - // methods and use index here. - delete(iface.common.leases, macToKey(l.HWAddr)) - - l.HWAddr = slices.Clone(mac) - iface.common.leases[macToKey(mac)] = l - return l, nil } - l = &Lease{ - HWAddr: slices.Clone(mac), - IP: nextIP, + l = iface.common.findExpiredLease(iface.clock.Now()) + if l == nil { + return nil, errors.Error("no addresses available to lease") } - err = iface.common.index.add(ctx, iface.common.logger, l, iface.common) + // TODO(e.burkov): Move validation from index methods into server's + // methods and use index here. + delete(iface.common.leases, macToKey(l.HWAddr)) + + idx := iface.common.index + delete(idx.byAddr, l.IP) + delete(idx.byName, strings.ToLower(l.Hostname)) + + err = idx.dbStore(ctx, iface.common.logger) if err != nil { + // Don't wrap the error since it's informative enough as is. return nil, err } + l.HWAddr = slices.Clone(mac) + l.Hostname = "" + l.IsStatic = false + l.updateExpiry(iface.clock, iface.common.leaseTTL) + + iface.common.leases[macToKey(mac)] = l + return l, nil } +const ( + // IPv4DefaultTTL is the default Time to Live value in seconds as + // recommended by RFC 1700. + IPv4DefaultTTL = 64 + + // IPProtoVersion is the IP internetwork general protocol version number as + // defined by RFC 1700. + IPProtoVersion = 4 +) + +// Port numbers for DHCPv4. +// +// See RFC 2131 Section 4.1. +const ( + // ServerPortV4 is the standard DHCPv4 server port. + ServerPortV4 layers.UDPPort = 67 + + // ClientPortV4 is the standard DHCPv4 client port. + ClientPortV4 layers.UDPPort = 68 +) + // respond4 sends a DHCPv4 response. fd and resp must not be nil. func respond4(fd *frameData, resp *layers.DHCPv4) (err error) { + // TODO(e.burkov): Use pools for buffer and layers. buf := gopacket.NewSerializeBuffer() - opts := gopacket.SerializeOptions{ - FixLengths: true, - ComputeChecksums: true, - } eth := &layers.Ethernet{ SrcMAC: fd.ether.SrcMAC, DstMAC: fd.ether.DstMAC, EthernetType: layers.EthernetTypeIPv4, } + ip := &layers.IPv4{ + Version: IPProtoVersion, + TTL: IPv4DefaultTTL, + SrcIP: net.IPv4zero.To4(), + DstIP: net.IPv4bcast.To4(), + Protocol: layers.IPProtocolUDP, + } + udp := &layers.UDP{ + SrcPort: ServerPortV4, + DstPort: ClientPortV4, + } + _ = udp.SetNetworkLayerForChecksum(ip) - // TODO(e.burkov): Handle IP layer. + opts := gopacket.SerializeOptions{ + FixLengths: true, + ComputeChecksums: true, + } - all := []gopacket.SerializableLayer{eth, resp} - - err = gopacket.SerializeLayers(buf, opts, all...) + err = gopacket.SerializeLayers(buf, opts, eth, ip, udp, resp) if err != nil { - return err + return fmt.Errorf("constructing dhcp v4 response: %w", err) } return fd.device.WritePacketData(buf.Bytes())