/*
* Copyright (C) 2024 Jonni Liljamo <jonni@liljamo.com>
*
* This file is licensed under AGPL-3.0-or-later, see NOTICE and LICENSE for
* more information.
*/
package dns
import (
"crypto/hmac"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"encoding/base64"
"encoding/hex"
"hash"
"strings"
"git.src.quest/~liljamo/felu/internal/db"
"github.com/miekg/dns"
)
// tsigProvider implements dns.TsigProvider by resolving the TSIG secret from
// the database on every call, instead of using a static secret map.
//
// The secret is looked up by the first label of the TSIG key name
// (t.Hdr.Name). For example, a key named "example.some.zone." is looked up
// via db.FetchDomainTsigKey("example").
type tsigProvider struct{}

// Compile-time check that tsigProvider satisfies dns.TsigProvider.
var _ dns.TsigProvider = tsigProvider{}

// Generate computes the TSIG MAC for msg with the secret stored for the key
// named in t.Hdr.Name. It returns dns.ErrSecret when no usable secret exists.
func (ts tsigProvider) Generate(msg []byte, t *dns.TSIG) ([]byte, error) {
	key, err := lookupTsigHMACKey(t.Hdr.Name)
	if err != nil {
		return nil, err
	}
	return key.Generate(msg, t)
}

// Verify checks the TSIG MAC carried in t against msg, using the secret
// stored for the key named in t.Hdr.Name. It returns dns.ErrSecret when no
// usable secret exists; otherwise the error comes from the HMAC verification.
func (ts tsigProvider) Verify(msg []byte, t *dns.TSIG) error {
	key, err := lookupTsigHMACKey(t.Hdr.Name)
	if err != nil {
		return err
	}
	return key.Verify(msg, t)
}

// lookupTsigHMACKey maps a TSIG key name to its stored HMAC secret.
//
// Every failure is reported as dns.ErrSecret. This deliberately hides whether
// the lookup or the database failed. The following are rejected:
//   - names without a '.' or with an empty first label ("." or ".zone.");
//   - database lookup errors;
//   - an empty stored secret, because an HMAC keyed with an empty secret
//     would accept signatures forged by anyone.
func lookupTsigHMACKey(name string) (tsigHMACProvider, error) {
	index := strings.IndexByte(name, '.')
	if index <= 0 {
		return "", dns.ErrSecret
	}
	key, err := db.FetchDomainTsigKey(name[:index])
	if err != nil || key == "" {
		return "", dns.ErrSecret
	}
	return tsigHMACProvider(key), nil
}
// tsigHMACProvider signs and verifies TSIG MACs with a single shared secret.
// The secret is held as a base64-encoded string and decoded on each call.
//
// It is a copy of the unexported tsigHMACProvider from miekg/dns:
// https://github.com/miekg/dns/blob/b77d1ed8e9282cadf21c4124f53a660fed55c8ca/tsig.go#L35
// The upstream type is not exported, so it is duplicated here. This lets
// tsigProvider build one from a secret fetched from the database per request.
type tsigHMACProvider string
// Generate returns the HMAC of msg, keyed with the base64-decoded secret.
//
// The hash is chosen from t.Algorithm, compared in canonical (lower-case,
// fully-qualified) form. Supported algorithms are HMAC-SHA1, HMAC-SHA224,
// HMAC-SHA256, HMAC-SHA384 and HMAC-SHA512; any other yields dns.ErrKeyAlg.
// A secret that is not valid base64 is reported before the algorithm is
// examined.
func (key tsigHMACProvider) Generate(msg []byte, t *dns.TSIG) ([]byte, error) {
	secret, err := fromBase64([]byte(key))
	if err != nil {
		return nil, err
	}

	var newHash func() hash.Hash
	switch dns.CanonicalName(t.Algorithm) {
	case dns.HmacSHA1:
		newHash = sha1.New
	case dns.HmacSHA224:
		newHash = sha256.New224
	case dns.HmacSHA256:
		newHash = sha256.New
	case dns.HmacSHA384:
		newHash = sha512.New384
	case dns.HmacSHA512:
		newHash = sha512.New
	default:
		return nil, dns.ErrKeyAlg
	}

	mac := hmac.New(newHash, secret)
	// hash.Hash.Write never returns an error, so the result is ignored.
	mac.Write(msg)
	return mac.Sum(nil), nil
}
// Verify recomputes the MAC for msg and compares it with the hex-encoded MAC
// carried in t.MAC. The comparison uses hmac.Equal to avoid leaking timing
// information.
//
// It returns any error from Generate, the hex decoding error if t.MAC is
// malformed, dns.ErrSig if the MACs differ, and nil on a match.
func (key tsigHMACProvider) Verify(msg []byte, t *dns.TSIG) error {
	expected, err := key.Generate(msg, t)
	if err != nil {
		return err
	}

	received, err := hex.DecodeString(t.MAC)
	switch {
	case err != nil:
		return err
	case hmac.Equal(expected, received):
		return nil
	default:
		return dns.ErrSig
	}
}
// fromBase64 decodes standard (padded) base64 from s.
//
// It returns the bytes decoded so far together with any decoding error, so
// the slice may be partially filled when err is non-nil. Copied from:
// https://github.com/miekg/dns/blob/b77d1ed8e9282cadf21c4124f53a660fed55c8ca/msg_helpers.go#L161
func fromBase64(s []byte) ([]byte, error) {
	enc := base64.StdEncoding
	out := make([]byte, enc.DecodedLen(len(s)))
	written, err := enc.Decode(out, s)
	return out[:written], err
}