Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 11 additions & 38 deletions provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,72 +199,45 @@ func (p *Provider) AppendRecords(ctx context.Context, zone string, records []lib

// SetRecords sets the records in the zone, either by updating existing records or creating new ones.
// It returns the updated records.
//
// Caveat: This method will fail if there are more than 500 RRsets in the zone. See package
// documentation for more detail.
func (p *Provider) SetRecords(ctx context.Context, zone string, records []libdns.Record) ([]libdns.Record, error) {
err := acquireWriteToken(ctx)
if err != nil {
return nil, fmt.Errorf("waiting for inflight requests to finish: %v", err)
}
defer releaseWriteToken()

// Build the desired state
rrsets := make(map[rrKey]*rrSet)
// Group records by rrKey (name, type)
rrsetMap := make(map[rrKey]*rrSet)
for _, r := range records {
rr := r.RR()
key := rrKey{rrSetSubname(rr), rr.Type}
rrset := rrsets[key]
rrset := rrsetMap[key]
if rrset == nil {
rrset = &rrSet{
Subname: key.Subname,
Type: key.Type,
Records: nil,
TTL: rrSetTTL(rr),
}
rrsets[key] = rrset
rrsetMap[key] = rrset
}
rrset.Records = append(rrset.Records, rrSetRecord(rr))
}

// Fetch existing rrSets and compare to desired state
existing, err := p.listRRSets(ctx, zone)
if err != nil {
return nil, fmt.Errorf("listing RRSets: %v", err)
}
for _, g := range existing {
key := rrKey{g.Subname, g.Type}
w := rrsets[key]
switch {
case w == nil:
// rrset exists, but not in the input, delete it by adding it to rrsets and set
// records to an empty slice to represent the deletion.
// See https://desec.readthedocs.io/en/latest/dns/rrsets.html#deleting-an-rrset
w0 := g
w0.Records = []string{}
rrsets[key] = &w0
case g.equal(w):
// rrset exists and is equal to the one we want; skip it in the update.
delete(rrsets, key)
}
}

// Generate updates to arrive at desired state.
update := make([]rrSet, 0, len(rrsets))
var ret []libdns.Record
for _, rrset := range rrsets {
update = append(update, *rrset)

// Add all records being set here. This ignores records that are being deleted, because
// those are represented as an rrset without any records.
// Build list of RRSets to pass to the API and list of libdns records
// to return from the function
rrsetList := make([]rrSet, 0, len(rrsetMap))
ret := make([]libdns.Record, 0, len(records))
for _, rrset := range rrsetMap {
rrsetList = append(rrsetList, *rrset)
records0, err := libdnsRecords(*rrset)
if err != nil {
return nil, fmt.Errorf("parsing RRSet: %v", err)
}
ret = append(ret, records0...)
}

if err := p.putRRSets(ctx, zone, update); err != nil {
if err := p.putRRSets(ctx, zone, rrsetList); err != nil {
return nil, fmt.Errorf("writing RRSets: %v", err)
}
return ret, nil
Expand Down
139 changes: 133 additions & 6 deletions provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,10 +293,12 @@ func TestSetRecords(t *testing.T) {
ctx := setup(t, `[
{"subname": "www", "type": "A", "ttl": 3600, "records": ["127.0.1.1", "127.0.1.2"]},
{"subname": "", "type": "TXT", "ttl": 3600, "records": ["\"will be overridden\""]},
{"subname": "sub", "type": "TXT", "ttl": 3600, "records": ["\"will stay the same\""]},
{"subname": "www", "type": "HTTPS", "ttl": 3600, "records": ["1 . alpn=\"h2\""]},
{"subname": "_sip._tcp.sub", "type": "SRV", "ttl": 3600, "records": ["1 100 5061 sip.example.com."]},
{"subname": "_ftp._tcp", "type": "URI", "ttl": 3600, "records": ["1 2 \"ftp://example.com/arst\""]},
{"subname": "", "type": "MX", "ttl": 3600, "records": ["0 mx0.example.com.", "10 mx1.example.com."]}
{"subname": "", "type": "MX", "ttl": 3600, "records": ["0 mx0.example.com.", "10 mx1.example.com."]},
{"subname": "www", "type": "NS", "ttl": 3600, "records": ["ns0.example.com.", "ns1.example.com."]}
]`)

p := &desec.Provider{
Expand Down Expand Up @@ -369,13 +371,14 @@ func TestSetRecords(t *testing.T) {
},
}

created, err := p.SetRecords(ctx, *domain+".", records)
recordsSet, err := p.SetRecords(ctx, *domain+".", records)
if err != nil {
t.Fatal(err)
}

// The records that already existed are not returned by SetRecords.
wantCreated := []libdns.Record{
// All set records, including the ones that already existed, should be returned
// by SetRecords.
wantSet := []libdns.Record{
libdns.Address{
Name: "@",
IP: netip.MustParseAddr("127.0.0.3"),
Expand All @@ -391,6 +394,16 @@ func TestSetRecords(t *testing.T) {
IP: netip.MustParseAddr("127.0.0.1"),
TTL: 3600 * time.Second,
},
libdns.ServiceBinding{
Scheme: "https",
Name: "www",
Target: ".",
Params: libdns.SvcParams{
"alpn": []string{"h2"},
},
Priority: 1,
TTL: 3600 * time.Second,
},
libdns.Address{
Name: "www",
IP: netip.MustParseAddr("127.0.0.2"),
Expand All @@ -401,16 +414,130 @@ func TestSetRecords(t *testing.T) {
IP: netip.MustParseAddr("127.0.0.5"),
TTL: 3600 * time.Second,
},
libdns.MX{
Name: "@",
Target: "mx0.example.com.",
TTL: 3600 * time.Second,
Preference: 0,
},
libdns.MX{
Name: "@",
Target: "mx1.example.com.",
TTL: 3600 * time.Second,
Preference: 10,
},
libdns.SRV{
Service: "sip",
Transport: "tcp",
Name: "sub",
Target: "sip.example.com.",
TTL: 3600 * time.Second,
Priority: 1,
Weight: 100,
Port: 5061,
},
libdns.RR{
Type: "URI",
Name: "_ftp._tcp",
Data: `1 2 "ftp://example.com/arst"`,
TTL: 3600 * time.Second,
},
}
if diff := cmp.Diff(wantCreated, created, cmpRecord); diff != "" {

if diff := cmp.Diff(wantSet, recordsSet, cmpRecord); diff != "" {
t.Fatalf("p.SetRecords() unexpected diff [-want +got]: %s", diff)
}

wantCurrent := []libdns.Record{
// Records for (name, type) pairs which were not present in the SetRecords input
// should be unaffected.
libdns.TXT{
Name: "sub",
Text: `will stay the same`,
TTL: time.Second * 3600,
},
libdns.NS{
Name: "www",
Target: "ns0.example.com.",
TTL: time.Second * 3600,
},
libdns.NS{
Name: "www",
Target: "ns1.example.com.",
TTL: time.Second * 3600,
},
// Records for (name, type) pairs which were present in the SetRecords input
// should match the output of SetRecords.
libdns.Address{
Name: "@",
IP: netip.MustParseAddr("127.0.0.3"),
TTL: time.Second * 3601,
},
libdns.TXT{
Name: "@",
Text: `hello dns!`,
TTL: time.Second * 3600,
},
libdns.Address{
Name: "www",
IP: netip.MustParseAddr("127.0.0.1"),
TTL: 3600 * time.Second,
},
libdns.ServiceBinding{
Scheme: "https",
Name: "www",
Target: ".",
Params: libdns.SvcParams{
"alpn": []string{"h2"},
},
Priority: 1,
TTL: 3600 * time.Second,
},
libdns.Address{
Name: "www",
IP: netip.MustParseAddr("127.0.0.2"),
TTL: 3600 * time.Second,
},
libdns.Address{
Name: "subsub.sub",
IP: netip.MustParseAddr("127.0.0.5"),
TTL: 3600 * time.Second,
},
libdns.MX{
Name: "@",
Target: "mx0.example.com.",
TTL: 3600 * time.Second,
Preference: 0,
},
libdns.MX{
Name: "@",
Target: "mx1.example.com.",
TTL: 3600 * time.Second,
Preference: 10,
},
libdns.SRV{
Service: "sip",
Transport: "tcp",
Name: "sub",
Target: "sip.example.com.",
TTL: 3600 * time.Second,
Priority: 1,
Weight: 100,
Port: 5061,
},
libdns.RR{
Type: "URI",
Name: "_ftp._tcp",
Data: `1 2 "ftp://example.com/arst"`,
TTL: 3600 * time.Second,
},
}

got, err := p.GetRecords(ctx, *domain+".")
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(records, got, cmpRecord); diff != "" {
if diff := cmp.Diff(wantCurrent, got, cmpRecord); diff != "" {
t.Fatalf("p.GetRecords() unexpected diff [-want +got]: %s", diff)
}
}
Expand Down