DEVELOPMENT ENVIRONMENT

~liljamo/felu

3cd37b5d7df372ecbf494843ba642e004a54621a — Jonni Liljamo 16 days ago f9956b2
feat: custom TsigProvider implementation
2 files changed, 109 insertions(+), 0 deletions(-)

M internal/dns/server.go
A internal/dns/tsigprovider.go
M internal/dns/server.go => internal/dns/server.go +1 -0
@@ 21,6 21,7 @@ func Run(addr string, net string) error {
		Addr:          addr,
		Net:           net,
		MsgAcceptFunc: msgAcceptFunc,
		TsigProvider:  tsigProvider{},
	}

	return server.ListenAndServe()

A internal/dns/tsigprovider.go => internal/dns/tsigprovider.go +108 -0
@@ 0,0 1,108 @@
/*
 * Copyright (C) 2024 Jonni Liljamo <jonni@liljamo.com>
 *
 * This file is licensed under AGPL-3.0-or-later, see NOTICE and LICENSE for
 * more information.
 */

package dns

import (
	"crypto/hmac"
	"crypto/sha1"
	"crypto/sha256"
	"crypto/sha512"
	"encoding/base64"
	"encoding/hex"
	"hash"
	"strings"

	"git.src.quest/~liljamo/felu/internal/db"
	"github.com/miekg/dns"
)

// tsigProvider is an implementation of TsigProvider
type tsigProvider struct{}

func (ts tsigProvider) Generate(msg []byte, t *dns.TSIG) ([]byte, error) {
	var key string
	if index := strings.IndexByte(t.Hdr.Name, '.'); index >= 0 {
		var err error
		key, err = db.FetchDomainTsigKey(t.Hdr.Name[:index])
		if err != nil {
			return nil, dns.ErrSecret
		}
	} else {
		return nil, dns.ErrSecret
	}

	return tsigHMACProvider(key).Generate(msg, t)
}

func (ts tsigProvider) Verify(msg []byte, t *dns.TSIG) error {
	var key string
	if index := strings.IndexByte(t.Hdr.Name, '.'); index >= 0 {
		var err error
		key, err = db.FetchDomainTsigKey(t.Hdr.Name[:index])
		if err != nil {
			return dns.ErrSecret
		}
	} else {
		return dns.ErrSecret
	}

	return tsigHMACProvider(key).Verify(msg, t)
}

// tsigHMACProvider is a carbon copy of tsigHMACProvider from here:
// https://github.com/miekg/dns/blob/b77d1ed8e9282cadf21c4124f53a660fed55c8ca/tsig.go#L35
// But that's not exported, so thats why we copy it.
type tsigHMACProvider string

func (key tsigHMACProvider) Generate(msg []byte, t *dns.TSIG) ([]byte, error) {
	rawsecret, err := fromBase64([]byte(key))
	if err != nil {
		return nil, err
	}
	var h hash.Hash
	switch dns.CanonicalName(t.Algorithm) {
	case dns.HmacSHA1:
		h = hmac.New(sha1.New, rawsecret)
	case dns.HmacSHA224:
		h = hmac.New(sha256.New224, rawsecret)
	case dns.HmacSHA256:
		h = hmac.New(sha256.New, rawsecret)
	case dns.HmacSHA384:
		h = hmac.New(sha512.New384, rawsecret)
	case dns.HmacSHA512:
		h = hmac.New(sha512.New, rawsecret)
	default:
		return nil, dns.ErrKeyAlg
	}
	h.Write(msg)
	return h.Sum(nil), nil
}

func (key tsigHMACProvider) Verify(msg []byte, t *dns.TSIG) error {
	b, err := key.Generate(msg, t)
	if err != nil {
		return err
	}
	mac, err := hex.DecodeString(t.MAC)
	if err != nil {
		return err
	}
	if !hmac.Equal(b, mac) {
		return dns.ErrSig
	}
	return nil
}

// https://github.com/miekg/dns/blob/b77d1ed8e9282cadf21c4124f53a660fed55c8ca/msg_helpers.go#L161
func fromBase64(s []byte) (buf []byte, err error) {
	buflen := base64.StdEncoding.DecodedLen(len(s))
	buf = make([]byte, buflen)
	n, err := base64.StdEncoding.Decode(buf, s)
	buf = buf[:n]
	return
}