/* * Copyright (C) 2024 Jonni Liljamo * * This file is licensed under AGPL-3.0-or-later, see NOTICE and LICENSE for * more information. */ package dns import ( "crypto/hmac" "crypto/sha1" "crypto/sha256" "crypto/sha512" "encoding/base64" "encoding/hex" "hash" "strings" "git.src.quest/~liljamo/felu/internal/db" "github.com/miekg/dns" ) // tsigProvider is an implementation of TsigProvider type tsigProvider struct{} func (ts tsigProvider) Generate(msg []byte, t *dns.TSIG) ([]byte, error) { var key string if index := strings.IndexByte(t.Hdr.Name, '.'); index >= 0 { var err error key, err = db.FetchDomainTsigKey(t.Hdr.Name[:index]) if err != nil { return nil, dns.ErrSecret } } else { return nil, dns.ErrSecret } return tsigHMACProvider(key).Generate(msg, t) } func (ts tsigProvider) Verify(msg []byte, t *dns.TSIG) error { var key string if index := strings.IndexByte(t.Hdr.Name, '.'); index >= 0 { var err error key, err = db.FetchDomainTsigKey(t.Hdr.Name[:index]) if err != nil { return dns.ErrSecret } } else { return dns.ErrSecret } return tsigHMACProvider(key).Verify(msg, t) } // tsigHMACProvider is a carbon copy of tsigHMACProvider from here: // https://github.com/miekg/dns/blob/b77d1ed8e9282cadf21c4124f53a660fed55c8ca/tsig.go#L35 // But that's not exported, so thats why we copy it. type tsigHMACProvider string func (key tsigHMACProvider) Generate(msg []byte, t *dns.TSIG) ([]byte, error) { rawsecret, err := fromBase64([]byte(key)) if err != nil { return nil, err } var h hash.Hash switch dns.CanonicalName(t.Algorithm) { case dns.HmacSHA1: h = hmac.New(sha1.New, rawsecret) case dns.HmacSHA224: h = hmac.New(sha256.New224, rawsecret) case dns.HmacSHA256: h = hmac.New(sha256.New, rawsecret) case dns.HmacSHA384: h = hmac.New(sha512.New384, rawsecret) case dns.HmacSHA512: h = hmac.New(sha512.New, rawsecret) default: return nil, dns.ErrKeyAlg } h.Write(msg) return h.Sum(nil), nil } func (key tsigHMACProvider) Verify(msg []byte, t *dns.TSIG) error { b, err := key.Generate(msg, t) if err != nil { return err } mac, err := hex.DecodeString(t.MAC) if err != nil { return err } if !hmac.Equal(b, mac) { return dns.ErrSig } return nil } // https://github.com/miekg/dns/blob/b77d1ed8e9282cadf21c4124f53a660fed55c8ca/msg_helpers.go#L161 func fromBase64(s []byte) (buf []byte, err error) { buflen := base64.StdEncoding.DecodedLen(len(s)) buf = make([]byte, buflen) n, err := base64.StdEncoding.Decode(buf, s) buf = buf[:n] return }