DEVELOPMENT ENVIRONMENT

~liljamo/felu

ref: 9e41bc5486e2858c81a59c9302089d3a480ec4e3 felu/internal/dns/tsigprovider.go -rw-r--r-- 2.5 KiB
9e41bc54Jonni Liljamo feat: sane SOA ttls 16 days ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
/*
 * Copyright (C) 2024 Jonni Liljamo <jonni@liljamo.com>
 *
 * This file is licensed under AGPL-3.0-or-later, see NOTICE and LICENSE for
 * more information.
 */

package dns

import (
	"crypto/hmac"
	"crypto/sha1"
	"crypto/sha256"
	"crypto/sha512"
	"encoding/base64"
	"encoding/hex"
	"hash"
	"strings"

	"git.src.quest/~liljamo/felu/internal/db"
	"github.com/miekg/dns"
)

// tsigProvider is an implementation of dns.TsigProvider whose secrets come
// from the database: for each TSIG record it fetches the secret with
// db.FetchDomainTsigKey (keyed by the first label of the TSIG key name) and
// delegates the actual HMAC computation to tsigHMACProvider.
type tsigProvider struct{}

// Generate computes the TSIG MAC over msg using the secret stored for the
// domain named by the first label of the TSIG key name (t.Hdr.Name).
//
// It returns dns.ErrSecret when no secret can be resolved for that key name;
// otherwise it returns whatever tsigHMACProvider.Generate returns.
func (ts tsigProvider) Generate(msg []byte, t *dns.TSIG) ([]byte, error) {
	key, err := ts.secret(t.Hdr.Name)
	if err != nil {
		return nil, err
	}
	return key.Generate(msg, t)
}

// Verify checks the MAC carried in t against msg using the secret stored for
// the domain named by the first label of the TSIG key name (t.Hdr.Name).
//
// It returns dns.ErrSecret when no secret can be resolved for that key name;
// otherwise it returns whatever tsigHMACProvider.Verify returns (nil on a
// matching MAC).
func (ts tsigProvider) Verify(msg []byte, t *dns.TSIG) error {
	key, err := ts.secret(t.Hdr.Name)
	if err != nil {
		return err
	}
	return key.Verify(msg, t)
}

// secret resolves a TSIG key name to the secret used by tsigHMACProvider.
// The domain looked up is everything before the first '.' of keyName; a name
// containing no '.' is rejected. Both that case and any database error are
// reported as dns.ErrSecret; the underlying database error is not propagated.
//
// NOTE(review): the label is passed to the database as-is (no case folding);
// confirm the lookup is case-insensitive, since DNS names are.
func (ts tsigProvider) secret(keyName string) (tsigHMACProvider, error) {
	index := strings.IndexByte(keyName, '.')
	if index < 0 {
		return "", dns.ErrSecret
	}
	key, err := db.FetchDomainTsigKey(keyName[:index])
	if err != nil {
		return "", dns.ErrSecret
	}
	return tsigHMACProvider(key), nil
}

// tsigHMACProvider holds a base64-encoded TSIG secret and generates/verifies
// HMAC-based TSIG MACs with it. It is a carbon copy of the unexported
// tsigHMACProvider from miekg/dns:
// https://github.com/miekg/dns/blob/b77d1ed8e9282cadf21c4124f53a660fed55c8ca/tsig.go#L35
// The upstream type is not exported, which is why it is copied here.
type tsigHMACProvider string

// Generate returns the HMAC of msg keyed with the base64-decoded secret held
// in key, using the hash function named by t.Algorithm.
//
// A malformed secret yields the base64 decoding error (checked before the
// algorithm); an algorithm other than HMAC-SHA1/224/256/384/512 yields
// dns.ErrKeyAlg.
func (key tsigHMACProvider) Generate(msg []byte, t *dns.TSIG) ([]byte, error) {
	secret, err := fromBase64([]byte(key))
	if err != nil {
		return nil, err
	}

	algorithm := dns.CanonicalName(t.Algorithm)

	var newHash func() hash.Hash
	switch algorithm {
	case dns.HmacSHA1:
		newHash = sha1.New
	case dns.HmacSHA224:
		newHash = sha256.New224
	case dns.HmacSHA256:
		newHash = sha256.New
	case dns.HmacSHA384:
		newHash = sha512.New384
	case dns.HmacSHA512:
		newHash = sha512.New
	default:
		return nil, dns.ErrKeyAlg
	}

	mac := hmac.New(newHash, secret)
	mac.Write(msg) // hash.Hash.Write is documented to never return an error
	return mac.Sum(nil), nil
}

// Verify recomputes the MAC of msg with key.Generate and compares it, in
// constant time via hmac.Equal, against the hex-encoded MAC in t.MAC.
//
// Errors from Generate and from hex decoding of t.MAC are returned unchanged;
// a mismatch yields dns.ErrSig, and a match yields nil.
func (key tsigHMACProvider) Verify(msg []byte, t *dns.TSIG) error {
	expected, err := key.Generate(msg, t)
	if err != nil {
		return err
	}

	received, err := hex.DecodeString(t.MAC)
	switch {
	case err != nil:
		return err
	case hmac.Equal(expected, received):
		return nil
	default:
		return dns.ErrSig
	}
}

// fromBase64 decodes s as standard (padded) base64.
//
// On error it still returns the bytes that were successfully decoded before
// the failure, alongside the error. Copied from miekg/dns:
// https://github.com/miekg/dns/blob/b77d1ed8e9282cadf21c4124f53a660fed55c8ca/msg_helpers.go#L161
func fromBase64(s []byte) (buf []byte, err error) {
	decoded := make([]byte, base64.StdEncoding.DecodedLen(len(s)))
	written, decodeErr := base64.StdEncoding.Decode(decoded, s)
	return decoded[:written], decodeErr
}