[PATCH][WIP] Create DC DNS entries at domain join

Andrew Bartlett abartlet at samba.org
Thu Jun 8 10:07:00 UTC 2017


On Thu, 2017-06-08 at 21:13 +1200, Andrew Bartlett via samba-technical
wrote:
> 
> 
> Attached is the current patches, taking the approach as above, but now
> with tests to show that the entries are created.

As Gary pointed out, that was the wrong set of patches.  Here are the
referenced patches.

> Not here - the patches on my workstation (drat) - are also tests to
> assert that subsequent modification using DNS is possible, using the
> machine account. 
> 
> This also gives us a good framework for improvements here in the
> future. 
> 
> The only other thing blocking this from being put up for review is that
> Garming asked that I test the MNAME over-stamp in an environment where
> it would actually do something, and it is taking a little longer to get
> the tests and knownfail entries set up.  
> 
> Regardless, any further comments most welcome as I would hope to seek a
> formal review tomorrow.
> 
> http://git.catalyst.net.nz/gw?p=samba.git;a=shortlog;h=refs/heads/dns-at-domain-join
> 
> Thanks,
> 
> Andrew Bartlett
-- 
Andrew Bartlett                       http://samba.org/~abartlet/
Authentication Developer, Samba Team  http://samba.org
Samba Developer, Catalyst IT          http://catalyst.net.nz/services/samba
-------------- next part --------------
From 5a10cc0969956bfd15b34320a0c35188b98f7de4 Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Tue, 23 May 2017 15:56:55 +1200
Subject: [PATCH 01/22] dsdb: Improve error messages when
 dsdb_set_schema_from_ldif() fails

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
---
 source4/dsdb/schema/schema_set.c | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/source4/dsdb/schema/schema_set.c b/source4/dsdb/schema/schema_set.c
index e6d5ce6..fd48d27 100644
--- a/source4/dsdb/schema/schema_set.c
+++ b/source4/dsdb/schema/schema_set.c
@@ -893,6 +893,8 @@ WERROR dsdb_set_schema_from_ldif(struct ldb_context *ldb,
 	ret = dsdb_set_schema(ldb, schema, true);
 	if (ret != LDB_SUCCESS) {
 		status = WERR_FOOBAR;
+		DEBUG(0,("ERROR: dsdb_set_schema() failed with %s / %s\n",
+			 ldb_strerror(ret), ldb_errstring(ldb)));
 		goto failed;
 	}
 
-- 
2.9.4


From ed196d93d312a51a4b7752d7f8f90c84c87f71c3 Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Fri, 17 Feb 2017 18:24:27 +1300
Subject: [PATCH 02/22] samba_dnsupdate: Ensure we only force "server" under
 resolv_wrapper

This ensures that nsupdate can use a nameserver in /etc/resolv.conf that is a
cache or forwarder, rather than the AD DC directly.

This avoids a regression from forcing the nameservers to the
/etc/resolv.conf nameservers in
e85ef1dbfef4b16c35cac80c0efc563d8cd1ba3e

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
---
 source4/scripting/bin/samba_dnsupdate | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/source4/scripting/bin/samba_dnsupdate b/source4/scripting/bin/samba_dnsupdate
index d382758..ba167da 100755
--- a/source4/scripting/bin/samba_dnsupdate
+++ b/source4/scripting/bin/samba_dnsupdate
@@ -430,8 +430,19 @@ def call_nsupdate(d, op="add"):
 
     (tmp_fd, tmpfile) = tempfile.mkstemp()
     f = os.fdopen(tmp_fd, 'w')
-    if d.nameservers != []:
+
+    # Getting this line right is really important.  When we are under
+    # resolv_wrapper, then we want to use RESOLV_CONF and the
+    # nameserver therein. The issue is that this parameter forces us
+    # to only ever use that server, and not some other server that the
+    # NS record may point to, even as we get a ticket to that other
+    # server.
+    #
+    # Therefore we must not set this in production.
+
+    if os.getenv('RESOLV_CONF') and d.nameservers != []:
         f.write('server %s\n' % d.nameservers[0])
+
     if d.type == "A":
         f.write("update %s %s %u A %s\n" % (op, normalised_name, default_ttl, d.ip))
     if d.type == "AAAA":
-- 
2.9.4


From 1a8c4fa0f34d3399ec45fff367b9d11e35bec970 Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Mon, 27 Feb 2017 16:51:45 +1300
Subject: [PATCH 03/22] pydns: Fix leak of talloc_stackframe() in python
 bindings

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
---
 source4/dns_server/pydns.c | 23 ++++++++++++++++++++---
 1 file changed, 20 insertions(+), 3 deletions(-)

diff --git a/source4/dns_server/pydns.c b/source4/dns_server/pydns.c
index 9842f24..18c3c29 100644
--- a/source4/dns_server/pydns.c
+++ b/source4/dns_server/pydns.c
@@ -124,12 +124,14 @@ static PyObject *py_dsdb_dns_lookup(PyObject *self, PyObject *args)
 
 	status = dns_common_zones(samdb, frame, &zones_list);
 	if (!NT_STATUS_IS_OK(status)) {
+		talloc_free(frame);
 		PyErr_SetNTSTATUS(status);
 		return NULL;
 	}
 
 	werr = dns_common_name2dn(samdb, zones_list, frame, dns_name, &dn);
 	if (!W_ERROR_IS_OK(werr)) {
+		talloc_free(frame);
 		PyErr_SetWERROR(werr);
 		return NULL;
 	}
@@ -141,16 +143,19 @@ static PyObject *py_dsdb_dns_lookup(PyObject *self, PyObject *args)
 				 &num_records,
 				 NULL);
 	if (!W_ERROR_IS_OK(werr)) {
+		talloc_free(frame);
 		PyErr_SetWERROR(werr);
 		return NULL;
 	}
 
-	return py_dnsp_DnssrvRpcRecord_get_list(records, num_records);
+	ret = py_dnsp_DnssrvRpcRecord_get_list(records, num_records);
+	talloc_free(frame);
+	return ret;
 }
 
 static PyObject *py_dsdb_dns_extract(PyObject *self, PyObject *args)
 {
-	PyObject *py_dns_el;
+	PyObject *py_dns_el, *ret;
 	TALLOC_CTX *frame;
 	WERROR werr;
 	struct ldb_message_element *dns_el;
@@ -175,11 +180,14 @@ static PyObject *py_dsdb_dns_extract(PyObject *self, PyObject *args)
 				  &records,
 				  &num_records);
 	if (!W_ERROR_IS_OK(werr)) {
+		talloc_free(frame);
 		PyErr_SetWERROR(werr);
 		return NULL;
 	}
 
-	return py_dnsp_DnssrvRpcRecord_get_list(records, num_records);
+	ret = py_dnsp_DnssrvRpcRecord_get_list(records, num_records);
+	talloc_free(frame);
+	return ret;
 }
 
 static PyObject *py_dsdb_dns_replace(PyObject *self, PyObject *args)
@@ -213,12 +221,14 @@ static PyObject *py_dsdb_dns_replace(PyObject *self, PyObject *args)
 	status = dns_common_zones(samdb, frame, &zones_list);
 	if (!NT_STATUS_IS_OK(status)) {
 		PyErr_SetNTSTATUS(status);
+		talloc_free(frame);
 		return NULL;
 	}
 
 	werr = dns_common_name2dn(samdb, zones_list, frame, dns_name, &dn);
 	if (!W_ERROR_IS_OK(werr)) {
 		PyErr_SetWERROR(werr);
+		talloc_free(frame);
 		return NULL;
 	}
 
@@ -226,6 +236,7 @@ static PyObject *py_dsdb_dns_replace(PyObject *self, PyObject *args)
 						frame,
 						&records, &num_records);
 	if (ret != 0) {
+		talloc_free(frame);
 		return NULL;
 	}
 
@@ -238,9 +249,11 @@ static PyObject *py_dsdb_dns_replace(PyObject *self, PyObject *args)
 				  num_records);
 	if (!W_ERROR_IS_OK(werr)) {
 		PyErr_SetWERROR(werr);
+		talloc_free(frame);
 		return NULL;
 	}
 
+	talloc_free(frame);
 	Py_RETURN_NONE;
 }
 
@@ -275,6 +288,7 @@ static PyObject *py_dsdb_dns_replace_by_dn(PyObject *self, PyObject *args)
 						frame,
 						&records, &num_records);
 	if (ret != 0) {
+		talloc_free(frame);
 		return NULL;
 	}
 
@@ -287,9 +301,12 @@ static PyObject *py_dsdb_dns_replace_by_dn(PyObject *self, PyObject *args)
 				  num_records);
 	if (!W_ERROR_IS_OK(werr)) {
 		PyErr_SetWERROR(werr);
+		talloc_free(frame);
 		return NULL;
 	}
 
+	talloc_free(frame);
+
 	Py_RETURN_NONE;
 }
 
-- 
2.9.4


From ab59eb658a90eb7564b1aa52e4af1938afecedb2 Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Mon, 27 Feb 2017 17:09:56 +1300
Subject: [PATCH 04/22] pydns: Also return the DN of the LDB object when
 finding a DNS record

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
---
 python/samba/remove_dc.py  | 4 ++--
 source4/dns_server/pydns.c | 5 +++--
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/python/samba/remove_dc.py b/python/samba/remove_dc.py
index 61b5937..4c8ee89 100644
--- a/python/samba/remove_dc.py
+++ b/python/samba/remove_dc.py
@@ -97,7 +97,7 @@ def remove_dns_references(samdb, logger, dnsHostName):
     dnsHostNameUpper = dnsHostName.upper()
 
     try:
-        primary_recs = samdb.dns_lookup(dnsHostName)
+        (dn, primary_recs) = samdb.dns_lookup(dnsHostName)
     except RuntimeError as (enum, estr):
         if enum == werror.WERR_DNS_ERROR_NAME_DOES_NOT_EXIST:
               return
@@ -140,7 +140,7 @@ def remove_dns_references(samdb, logger, dnsHostName):
     for a_name in a_names_to_remove_from:
         try:
             logger.debug("checking for DNS records to remove on %s" % a_name)
-            a_recs = samdb.dns_lookup(a_name)
+            (a_rec_dn, a_recs) = samdb.dns_lookup(a_name)
         except RuntimeError as (enum, estr):
             if enum == werror.WERR_DNS_ERROR_NAME_DOES_NOT_EXIST:
                 return
diff --git a/source4/dns_server/pydns.c b/source4/dns_server/pydns.c
index 18c3c29..3de9739 100644
--- a/source4/dns_server/pydns.c
+++ b/source4/dns_server/pydns.c
@@ -105,7 +105,7 @@ static int py_dnsp_DnssrvRpcRecord_get_array(PyObject *value,
 static PyObject *py_dsdb_dns_lookup(PyObject *self, PyObject *args)
 {
 	struct ldb_context *samdb;
-	PyObject *py_ldb;
+	PyObject *py_ldb, *ret, *pydn;
 	char *dns_name;
 	TALLOC_CTX *frame;
 	NTSTATUS status;
@@ -149,8 +149,9 @@ static PyObject *py_dsdb_dns_lookup(PyObject *self, PyObject *args)
 	}
 
 	ret = py_dnsp_DnssrvRpcRecord_get_list(records, num_records);
+	pydn = pyldb_Dn_FromDn(dn);
 	talloc_free(frame);
-	return ret;
+	return Py_BuildValue("(NN)", pydn, ret);
 }
 
 static PyObject *py_dsdb_dns_extract(PyObject *self, PyObject *args)
-- 
2.9.4


From 690681a416385ea961292fa5f35db68bb063dd3c Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Tue, 28 Feb 2017 14:15:12 +1300
Subject: [PATCH 05/22] python: Allow sd_utils to take a Dn object, not just a
 string DN

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
---
 python/samba/sd_utils.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/python/samba/sd_utils.py b/python/samba/sd_utils.py
index 7592a29..568829f 100644
--- a/python/samba/sd_utils.py
+++ b/python/samba/sd_utils.py
@@ -37,7 +37,11 @@ class SDUtils(object):
             or security.descriptor object
         """
         m = Message()
-        m.dn = Dn(self.ldb, object_dn)
+        if isinstance(object_dn, Dn):
+            m.dn = object_dn
+        else:
+            m.dn = Dn(self.ldb, object_dn)
+
         assert(isinstance(sd, str) or isinstance(sd, security.descriptor))
         if isinstance(sd, str):
             tmp_desc = security.descriptor.from_sddl(sd, self.domain_sid)
-- 
2.9.4


From e0250f0ebaf687508cadbc924c177a0408ecf333 Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Fri, 17 Feb 2017 18:23:23 +1300
Subject: [PATCH 06/22] join.py Add DNS records at domain join time

This avoids issues getting replication going after the DC first starts,
because the rest of the domain does not have to wait for samba_dnsupdate
to run successfully.

We do not just run samba_dnsupdate as we want to strictly
operate against the DC we just joined:
 - We do not want to query another DNS server
 - We do not want to obtain a Kerberos ticket for the new DC
   (as the KDC we select may not be the DC we just joined,
   and so may not be in sync with the password we just set)
 - We do not wish to set the _ldap records until we have started
 - We do not wish to use NTLM (the --use-samba-tool mode forces
   NTLM)

The downside to using DCE/RPC rather than DNS is that these will
be regarded as static entries, and (against Windows) have the ACL
assigned for static entries.  However, this is still better than no
DNS at all.

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
---
 python/samba/join.py | 200 ++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 197 insertions(+), 3 deletions(-)

diff --git a/python/samba/join.py b/python/samba/join.py
index 6a92435..f961d6b 100644
--- a/python/samba/join.py
+++ b/python/samba/join.py
@@ -22,8 +22,8 @@ from samba.auth import system_session
 from samba.samdb import SamDB
 from samba import gensec, Ldb, drs_utils, arcfour_encrypt, string_to_byte_array
 import ldb, samba, sys, uuid
-from samba.ndr import ndr_pack
-from samba.dcerpc import security, drsuapi, misc, nbt, lsa, drsblobs
+from samba.ndr import ndr_pack, ndr_unpack
+from samba.dcerpc import security, drsuapi, misc, nbt, lsa, drsblobs, dnsserver, dnsp
 from samba.dsdb import DS_DOMAIN_FUNCTION_2003
 from samba.credentials import Credentials, DONT_USE_KERBEROS
 from samba.provision import secretsdb_self_join, provision, provision_fill, FILL_DRS, FILL_SUBDOMAIN
@@ -35,6 +35,9 @@ from samba.provision.sambadns import setup_bind9_dns
 from samba import read_and_sub_file
 from samba import werror
 from base64 import b64encode
+from samba import WERRORError
+from samba.dnsserver import ARecord, AAAARecord, PTRRecord, CNameRecord, NSRecord, MXRecord, SOARecord, SRVRecord, TXTRecord
+from samba import sd_utils
 import logging
 import talloc
 import random
@@ -184,6 +187,12 @@ class dc_join(object):
         ctx.adminpass = None
         ctx.partition_dn = None
 
+        ctx.dns_a_dn = None
+        ctx.dns_cname_dn = None
+
+        # Do not normally register 127. addresses but allow override for selftest
+        ctx.force_all_ips = False
+
     def del_noerror(ctx, dn, recursive=False):
         if recursive:
             try:
@@ -289,6 +298,13 @@ class dc_join(object):
 
             lsaconn.DeleteTrustedDomain(pol_handle, info.info_ex.sid)
 
+        if ctx.dns_a_dn:
+            ctx.del_noerror(ctx.dns_a_dn)
+
+        if ctx.dns_cname_dn:
+            ctx.del_noerror(ctx.dns_cname_dn)
+
+
 
     def promote_possible(ctx):
         """confirm that the account is just a bare NT4 BDC or a member server, so can be safely promoted"""
@@ -687,12 +703,16 @@ class dc_join(object):
                                      newpassword=ctx.acct_pass.encode('utf-8'))
 
             res = ctx.samdb.search(base=ctx.acct_dn, scope=ldb.SCOPE_BASE,
-                                   attrs=["msDS-KeyVersionNumber"])
+                                   attrs=["msDS-KeyVersionNumber",
+                                          "objectSID"])
             if "msDS-KeyVersionNumber" in res[0]:
                 ctx.key_version_number = int(res[0]["msDS-KeyVersionNumber"][0])
             else:
                 ctx.key_version_number = None
 
+            ctx.new_dc_account_sid = ndr_unpack(security.dom_sid,
+                                                res[0]["objectSid"][0])
+
             print("Enabling account")
             m = ldb.Message()
             m.dn = ldb.Dn(ctx.samdb, ctx.acct_dn)
@@ -969,6 +989,175 @@ class dc_join(object):
 
         ctx.drsuapi.DsReplicaUpdateRefs(ctx.drsuapi_handle, 1, r)
 
+    def join_add_dns_records(ctx):
+        """Remotely add this DC's A/AAAA records and its _msdcs CNAME
+           record to the target DC over DCE/RPC.  We assume that if we
+           replicate DNS that the server holds the DNS roles and can accept
+           updates.
+
+           This avoids issues getting replication going after the DC
+           first starts as the rest of the domain does not have to
+           wait for samba_dnsupdate to run successfully.
+
+           Specifically, we add the records implied by the DsReplicaUpdateRefs
+           call above.  Old A/AAAA records on our name are removed first;
+           DNS RPC errors other than NAME_DOES_NOT_EXIST are re-raised.
+
+           We do not just run samba_dnsupdate as we want to strictly
+           operate against the DC we just joined:
+            - We do not want to query another DNS server
+            - We do not want to obtain a Kerberos ticket
+              (as the KDC we select may not be the DC we just joined,
+              and so may not be in sync with the password we just set)
+            - We do not wish to set the _ldap records until we have started
+            - We do not wish to use NTLM (--use-samba-tool forces NTLM)
+        """
+
+        client_version = dnsserver.DNS_CLIENT_VERSION_LONGHORN
+        # Enumerate only the node's own authoritative records, no children
+        select_flags = dnsserver.DNS_RPC_VIEW_AUTHORITY_DATA |\
+                       dnsserver.DNS_RPC_VIEW_NO_CHILDREN
+
+        zone = ctx.dnsdomain
+        msdcs_zone = "_msdcs.%s" % ctx.dnsforest
+        name = ctx.myname
+        msdcs_cname = str(ctx.ntds_guid)
+        cname_target = "%s.%s" % (name, zone)
+        IPs = samba.interface_ips(ctx.lp, ctx.force_all_ips)
+
+        ctx.logger.info("Adding %d remote DNS records for %s.%s" % \
+                        (len(IPs), name, zone))
+
+        binding_options = "sign"
+        dns_conn = dnsserver.dnsserver("ncacn_ip_tcp:%s[%s]" % (ctx.server, binding_options),
+                                      ctx.lp, ctx.creds)
+
+        # Cleared below if the server reports our name does not exist yet
+        name_found = True
+
+        sd_helper = sd_utils.SDUtils(ctx.samdb)
+
+        change_owner_sd = security.descriptor()
+        change_owner_sd.owner_sid = ctx.new_dc_account_sid
+        change_owner_sd.group_sid = security.dom_sid("%s-%d" %
+                                                     (str(ctx.domsid),
+                                                      security.DOMAIN_RID_DCS))
+
+        # Remove any old A/AAAA records from the primary DNS name
+        try:
+            (buflen, res) \
+                = dns_conn.DnssrvEnumRecords2(client_version,
+                                              0,
+                                              ctx.server,
+                                              zone,
+                                              name,
+                                              None,
+                                              dnsp.DNS_TYPE_ALL,
+                                              select_flags,
+                                              None,
+                                              None)
+        except WERRORError as e:
+            if e.args[0] != werror.WERR_DNS_ERROR_NAME_DOES_NOT_EXIST:
+                raise
+            name_found = False
+
+        if name_found:
+            for rec in res.rec:
+                for record in rec.records:
+                    if record.wType == dnsp.DNS_TYPE_A or \
+                       record.wType == dnsp.DNS_TYPE_AAAA:
+                        # Delete: add buffer is None, delete buffer is set
+                        del_rec_buf = dnsserver.DNS_RPC_RECORD_BUF()
+                        del_rec_buf.rec = record
+                        try:
+                            dns_conn.DnssrvUpdateRecord2(client_version,
+                                                         0,
+                                                         ctx.server,
+                                                         zone,
+                                                         name,
+                                                         None,
+                                                         del_rec_buf)
+                        except WERRORError as e:
+                            if e.args[0] == werror.WERR_DNS_ERROR_NAME_DOES_NOT_EXIST:
+                                pass
+                            else:
+                                raise
+
+        for IP in IPs:
+            if ':' in IP:
+                ctx.logger.info("Adding DNS AAAA record %s.%s for IPv6 IP: %s"
+                                % (name, zone, IP))
+                rec = AAAARecord(IP)
+            else:
+                ctx.logger.info("Adding DNS A record %s.%s for IPv4 IP: %s"
+                                % (name, zone, IP))
+                rec = ARecord(IP)
+
+            # Add: add buffer is set, delete buffer is None
+            add_rec_buf = dnsserver.DNS_RPC_RECORD_BUF()
+            add_rec_buf.rec = rec
+            dns_conn.DnssrvUpdateRecord2(client_version,
+                                         0,
+                                         ctx.server,
+                                         zone,
+                                         name,
+                                         add_rec_buf,
+                                         None)
+
+        if IPs:
+            domaindns_zone_dn = ldb.Dn(ctx.samdb, ctx.domaindns_zone)
+            (ctx.dns_a_dn, ldap_record) \
+                = ctx.samdb.dns_lookup("%s.%s" % (name, zone),
+                                       dns_partition=domaindns_zone_dn)
+
+            # Make the DC own the DNS record, not the administrator
+            sd_helper.modify_sd_on_dn(ctx.dns_a_dn, change_owner_sd,
+                                      controls=["sd_flags:1:%d"
+                                                % (security.SECINFO_OWNER
+                                                   | security.SECINFO_GROUP)])
+
+            # Add the CNAME <ntds_guid>.<msdcs_zone> -> <name>.<zone>
+            # (only done when at least one A/AAAA record was added above)
+            ctx.logger.info("Adding DNS CNAME record %s.%s for %s"
+                            % (msdcs_cname, msdcs_zone, cname_target))
+
+            add_rec_buf = dnsserver.DNS_RPC_RECORD_BUF()
+            rec = CNameRecord(cname_target)
+            add_rec_buf.rec = rec
+            dns_conn.DnssrvUpdateRecord2(client_version,
+                                         0,
+                                         ctx.server,
+                                         msdcs_zone,
+                                         msdcs_cname,
+                                         add_rec_buf,
+                                         None)
+
+            forestdns_zone_dn = ldb.Dn(ctx.samdb, ctx.forestdns_zone)
+            (ctx.dns_cname_dn, ldap_record) \
+                = ctx.samdb.dns_lookup("%s.%s" % (msdcs_cname, msdcs_zone),
+                                       dns_partition=forestdns_zone_dn)
+
+            # Make the DC own the DNS record, not the administrator
+            sd_helper.modify_sd_on_dn(ctx.dns_cname_dn, change_owner_sd,
+                                      controls=["sd_flags:1:%d"
+                                                % (security.SECINFO_OWNER
+                                                   | security.SECINFO_GROUP)])
+
+        ctx.logger.info("All other DNS records (like _ldap SRV records) " +
+                        "will be created by samba_dnsupdate on first startup")
+
+
+    def join_replicate_new_dns_records(ctx):
+        """Non-full replication of the DNS partitions listed in ctx.nc_list,
+        run after join_add_dns_records() to pull in the new records."""
+        for nc in (ctx.domaindns_zone, ctx.forestdns_zone):
+            if nc in ctx.nc_list:
+                print("Replicating new DNS records in %s" % str(nc))
+                ctx.repl.replicate(nc, ctx.source_dsa_invocation_id,
+                                   ctx.ntds_guid, rodc=ctx.RODC,
+                                   replica_flags=ctx.replica_flags,
+                                   full_sync=False)
+
     def join_finalise(ctx):
         """Finalise the join, mark us synchronised and setup secrets db."""
 
@@ -1185,6 +1374,11 @@ class dc_join(object):
                 ctx.join_add_objects2()
                 ctx.join_provision_own_domain()
                 ctx.join_setup_trusts()
+
+            if not ctx.clone_only and ctx.dns_backend != "NONE":
+                ctx.join_add_dns_records()
+                ctx.join_replicate_new_dns_records()
+
             ctx.join_finalise()
         except:
             try:
-- 
2.9.4


From b051873129184d7e91ba6e112b195864bdd19f06 Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Mon, 10 Apr 2017 16:06:13 +1200
Subject: [PATCH 07/22] pydsdb_dns: Use TypeError not LdbError for mismatched
 types

This avoids the samba-tool command handling code blowing up when trying to parse an LdbError

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
---
 source4/dns_server/dns_server.c       |  2 +-
 source4/dns_server/dnsserver_common.c | 17 ++++++++++++++---
 source4/dns_server/dnsserver_common.h |  7 +++++++
 source4/dns_server/pydns.c            | 15 +++------------
 4 files changed, 25 insertions(+), 16 deletions(-)

diff --git a/source4/dns_server/dns_server.c b/source4/dns_server/dns_server.c
index 5e9527d..d4f5f27 100644
--- a/source4/dns_server/dns_server.c
+++ b/source4/dns_server/dns_server.c
@@ -764,7 +764,7 @@ static NTSTATUS dns_server_reload_zones(struct dns_server *dns)
 	struct dns_server_zone *new_list = NULL;
 	struct dns_server_zone *old_list = NULL;
 	struct dns_server_zone *old_zone;
-	status = dns_common_zones(dns->samdb, dns, &new_list);
+	status = dns_common_zones(dns->samdb, dns, NULL, &new_list);
 	if (!NT_STATUS_IS_OK(status)) {
 		return status;
 	}
diff --git a/source4/dns_server/dnsserver_common.c b/source4/dns_server/dnsserver_common.c
index 7aac7e2..fbfa5fa 100644
--- a/source4/dns_server/dnsserver_common.c
+++ b/source4/dns_server/dnsserver_common.c
@@ -560,6 +560,7 @@ static int dns_common_sort_zones(struct ldb_message **m1, struct ldb_message **m
 
 NTSTATUS dns_common_zones(struct ldb_context *samdb,
 			  TALLOC_CTX *mem_ctx,
+			  struct ldb_dn *base_dn,
 			  struct dns_server_zone **zones_ret)
 {
 	int ret;
@@ -569,9 +570,19 @@ NTSTATUS dns_common_zones(struct ldb_context *samdb,
 	struct dns_server_zone *new_list = NULL;
 	TALLOC_CTX *frame = talloc_stackframe();
 
-	/* TODO: this search does not work against windows */
-	ret = dsdb_search(samdb, frame, &res, NULL, LDB_SCOPE_SUBTREE,
-			  attrs, DSDB_SEARCH_SEARCH_ALL_PARTITIONS, "(objectClass=dnsZone)");
+	if (base_dn) {
+		/* This search will work against windows */
+		ret = dsdb_search(samdb, frame, &res,
+				  base_dn, LDB_SCOPE_SUBTREE,
+				  attrs, 0, "(objectClass=dnsZone)");
+	} else {
+		/* TODO: this search does not work against windows */
+		ret = dsdb_search(samdb, frame, &res, NULL,
+				  LDB_SCOPE_SUBTREE,
+				  attrs,
+				  DSDB_SEARCH_SEARCH_ALL_PARTITIONS,
+				  "(objectClass=dnsZone)");
+	}
 	if (ret != LDB_SUCCESS) {
 		TALLOC_FREE(frame);
 		return NT_STATUS_INTERNAL_DB_CORRUPTION;
diff --git a/source4/dns_server/dnsserver_common.h b/source4/dns_server/dnsserver_common.h
index 57d5d9f..293831f 100644
--- a/source4/dns_server/dnsserver_common.h
+++ b/source4/dns_server/dnsserver_common.h
@@ -62,7 +62,14 @@ WERROR dns_common_name2dn(struct ldb_context *samdb,
 			  TALLOC_CTX *mem_ctx,
 			  const char *name,
 			  struct ldb_dn **_dn);
+
+/*
+ * For this routine, base_dn is generally NULL.  The exception comes
+ * from the python bindings to support setting ACLs on DNS objects
+ * when joining Windows
+ */
 NTSTATUS dns_common_zones(struct ldb_context *samdb,
 			  TALLOC_CTX *mem_ctx,
+			  struct ldb_dn *base_dn,
 			  struct dns_server_zone **zones_ret);
 #endif /* __DNSSERVER_COMMON_H__ */
diff --git a/source4/dns_server/pydns.c b/source4/dns_server/pydns.c
index 3de9739..7fc8f0c 100644
--- a/source4/dns_server/pydns.c
+++ b/source4/dns_server/pydns.c
@@ -32,27 +32,18 @@
 /* FIXME: These should be in a header file somewhere */
 #define PyErr_LDB_OR_RAISE(py_ldb, ldb) \
 	if (!py_check_dcerpc_type(py_ldb, "ldb", "Ldb")) { \
-		PyErr_SetString(py_ldb_get_exception(), "Ldb connection object required"); \
+		PyErr_SetString(PyExc_TypeError, "Ldb connection object required"); \
 		return NULL; \
 	} \
 	ldb = pyldb_Ldb_AsLdbContext(py_ldb);
 
 #define PyErr_LDB_DN_OR_RAISE(py_ldb_dn, dn) \
 	if (!py_check_dcerpc_type(py_ldb_dn, "ldb", "Dn")) { \
-		PyErr_SetString(py_ldb_get_exception(), "ldb Dn object required"); \
+		PyErr_SetString(PyExc_TypeError, "ldb Dn object required"); \
 		return NULL; \
 	} \
 	dn = pyldb_Dn_AsDn(py_ldb_dn);
 
-static PyObject *py_ldb_get_exception(void)
-{
-	PyObject *mod = PyImport_ImportModule("ldb");
-	if (mod == NULL)
-		return NULL;
-
-	return PyObject_GetAttrString(mod, "LdbError");
-}
-
 static PyObject *py_dnsp_DnssrvRpcRecord_get_list(struct dnsp_DnssrvRpcRecord *records,
 						  uint16_t num_records)
 {
@@ -168,7 +159,7 @@ static PyObject *py_dsdb_dns_extract(PyObject *self, PyObject *args)
 	}
 
 	if (!py_check_dcerpc_type(py_dns_el, "ldb", "MessageElement")) {
-		PyErr_SetString(py_ldb_get_exception(),
+		PyErr_SetString(PyExc_TypeError,
 				"ldb MessageElement object required");
 		return NULL;
 	}
-- 
2.9.4


From df47517641bbe3d80d530b33c9b2a8227acac5e6 Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Mon, 10 Apr 2017 16:08:39 +1200
Subject: [PATCH 08/22] pydsdb_dns: Allow the partition DN to be specified into
 py_dsdb_dns_lookup

This allows lookups to be confined to one partition, which in turn avoids issues
when running this against MS Windows, which does not match Samba behaviour
for dns_common_zones()

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
---
 python/samba/samdb.py      |  8 ++++++--
 source4/dns_server/pydns.c | 26 ++++++++++++++++++++------
 2 files changed, 26 insertions(+), 8 deletions(-)

diff --git a/python/samba/samdb.py b/python/samba/samdb.py
index 19dd8e9..b3a4b38 100644
--- a/python/samba/samdb.py
+++ b/python/samba/samdb.py
@@ -927,9 +927,13 @@ accountExpires: %u
         res = self.search(base="", scope=ldb.SCOPE_BASE, attrs=["serverName"])
         return res[0]["serverName"][0]
 
-    def dns_lookup(self, dns_name):
+    def dns_lookup(self, dns_name, dns_partition=None):
         '''Do a DNS lookup in the database, returns the NDR database structures'''
-        return dsdb_dns.lookup(self, dns_name)
+        if dns_partition is None:
+            return dsdb_dns.lookup(self, dns_name)
+        else:
+            return dsdb_dns.lookup(self, dns_name,
+                                   dns_partition=dns_partition)
 
     def dns_extract(self, el):
         '''Return the NDR database structures from a dnsRecord element'''
diff --git a/source4/dns_server/pydns.c b/source4/dns_server/pydns.c
index 7fc8f0c..cb41faa 100644
--- a/source4/dns_server/pydns.c
+++ b/source4/dns_server/pydns.c
@@ -93,27 +93,40 @@ static int py_dnsp_DnssrvRpcRecord_get_array(PyObject *value,
 	return 0;
 }
 
-static PyObject *py_dsdb_dns_lookup(PyObject *self, PyObject *args)
+static PyObject *py_dsdb_dns_lookup(PyObject *self,
+				    PyObject *args, PyObject *kwargs)
 {
 	struct ldb_context *samdb;
 	PyObject *py_ldb, *ret, *pydn;
+	PyObject *py_dns_partition = NULL;
 	char *dns_name;
 	TALLOC_CTX *frame;
 	NTSTATUS status;
 	WERROR werr;
 	struct dns_server_zone *zones_list;
-	struct ldb_dn *dn;
+	struct ldb_dn *dn, *dns_partition = NULL;
 	struct dnsp_DnssrvRpcRecord *records;
 	uint16_t num_records;
+	const char * const kwnames[] = { "ldb", "dns_name",
+					 "dns_partition", NULL };
 
-	if (!PyArg_ParseTuple(args, "Os", &py_ldb, &dns_name)) {
+	if (!PyArg_ParseTupleAndKeywords(args, kwargs, "Os|O",
+					 discard_const_p(char *, kwnames),
+					 &py_ldb, &dns_name,
+					 &py_dns_partition)) {
 		return NULL;
 	}
 	PyErr_LDB_OR_RAISE(py_ldb, samdb);
 
+	if (py_dns_partition) {
+		PyErr_LDB_DN_OR_RAISE(py_dns_partition,
+				      dns_partition);
+	}
+
 	frame = talloc_stackframe();
 
-	status = dns_common_zones(samdb, frame, &zones_list);
+	status = dns_common_zones(samdb, frame, dns_partition,
+				  &zones_list);
 	if (!NT_STATUS_IS_OK(status)) {
 		talloc_free(frame);
 		PyErr_SetNTSTATUS(status);
@@ -210,7 +223,7 @@ static PyObject *py_dsdb_dns_replace(PyObject *self, PyObject *args)
 
 	frame = talloc_stackframe();
 
-	status = dns_common_zones(samdb, frame, &zones_list);
+	status = dns_common_zones(samdb, frame, NULL, &zones_list);
 	if (!NT_STATUS_IS_OK(status)) {
 		PyErr_SetNTSTATUS(status);
 		talloc_free(frame);
@@ -305,7 +318,8 @@ static PyObject *py_dsdb_dns_replace_by_dn(PyObject *self, PyObject *args)
 static PyMethodDef py_dsdb_dns_methods[] = {
 
 	{ "lookup", (PyCFunction)py_dsdb_dns_lookup,
-		METH_VARARGS, "Get the DNS database entries for a DNS name"},
+	        METH_VARARGS|METH_KEYWORDS,
+	        "Get the DNS database entries for a DNS name"},
 	{ "replace", (PyCFunction)py_dsdb_dns_replace,
 		METH_VARARGS, "Replace the DNS database entries for a DNS name"},
 	{ "replace_by_dn", (PyCFunction)py_dsdb_dns_replace_by_dn,
-- 
2.9.4


From f2b3919be590fd093501a3b428127ee46ec3a5f5 Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Mon, 10 Apr 2017 16:10:00 +1200
Subject: [PATCH 09/22] join.py: Do not expose the old machine password over
 NTLM if -k yes was set

This makes the test for a valid machine account stricter (as a Kerberos error could
cause this to fail and so skip the validation), but we never wish to use NTLM
if the administrator disabled it on the command line.

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
---
 python/samba/join.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/python/samba/join.py b/python/samba/join.py
index f961d6b..1ceb954 100644
--- a/python/samba/join.py
+++ b/python/samba/join.py
@@ -218,6 +218,7 @@ class dc_join(object):
         creds.guess(ctx.lp)
         try:
             creds.set_machine_account(ctx.lp)
+            creds.set_kerberos_state(ctx.creds.get_kerberos_state())
             machine_samdb = SamDB(url="ldap://%s" % ctx.server,
                                   session_info=system_session(),
                                 credentials=creds, lp=ctx.lp)
-- 
2.9.4


From 1e03d31ef5b14074e2866a6e0aedbd985a949322 Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Mon, 10 Apr 2017 17:10:27 +1200
Subject: [PATCH 10/22] samba_dnsupdate: Make nsupdate use the server given by
 the SOA record

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
---
 source4/scripting/bin/samba_dnsupdate | 19 ++++++++++++++++---
 1 file changed, 16 insertions(+), 3 deletions(-)

diff --git a/source4/scripting/bin/samba_dnsupdate b/source4/scripting/bin/samba_dnsupdate
index ba167da..80a5a6f 100755
--- a/source4/scripting/bin/samba_dnsupdate
+++ b/source4/scripting/bin/samba_dnsupdate
@@ -237,7 +237,7 @@ def hostname_match(h1, h2):
     h2 = str(h2)
     return h1.lower().rstrip('.') == h2.lower().rstrip('.')
 
-def check_one_dns_name(name, name_type, d=None):
+def get_resolver(d=None):
     resolv_conf = os.getenv('RESOLV_CONF')
     if not resolv_conf:
         resolv_conf = '/etc/resolv.conf'
@@ -245,7 +245,12 @@ def check_one_dns_name(name, name_type, d=None):
 
     if d is not None and d.nameservers != []:
         resolver.nameservers = d.nameservers
-    elif d is not None:
+
+    return resolver
+
+def check_one_dns_name(name, name_type, d=None):
+    resolver = get_resolver(d)
+    if d is not None and len(d.nameservers) == 0:
         d.nameservers = resolver.nameservers
 
     ans = resolver.query(name, name_type)
@@ -438,10 +443,18 @@ def call_nsupdate(d, op="add"):
     # NS record may point to, even as we get a ticket to that other
     # server.
     #
-    # Therefore we must not set this in production.
+    # Therefore we must not set this in production, instead we want
+    # to find the name of a SOA for the zone and use that server.
 
     if os.getenv('RESOLV_CONF') and d.nameservers != []:
         f.write('server %s\n' % d.nameservers[0])
+    else:
+        resolver = get_resolver(d)
+        zone = dns.resolver.zone_for_name(normalised_name,
+                                          resolver=resolver)
+        soa = resolver.query(zone, "SOA")
+
+        f.write('server %s\n' % soa[0].mname)
 
     if d.type == "A":
         f.write("update %s %s %u A %s\n" % (op, normalised_name, default_ttl, d.ip))
-- 
2.9.4


From eaa9047a84f89381bf3b46c3fcb3dc315c00981c Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Mon, 10 Apr 2017 17:13:46 +1200
Subject: [PATCH 11/22] samba_dnsupdate: Try to get ticket to the SOA, not the
 NS servers

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
---
 source4/scripting/bin/samba_dnsupdate | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/source4/scripting/bin/samba_dnsupdate b/source4/scripting/bin/samba_dnsupdate
index 80a5a6f..28343bf 100755
--- a/source4/scripting/bin/samba_dnsupdate
+++ b/source4/scripting/bin/samba_dnsupdate
@@ -137,10 +137,12 @@ def get_credentials(lp):
         if opts.use_file is not None:
             return
 
-        # Now confirm we can get a ticket to a DNS server
-        ans = check_one_dns_name(sub_vars['DNSDOMAIN'] + '.', 'NS')
+        # Now confirm we can get a ticket to the DNS server
+        ans = check_one_dns_name(sub_vars['DNSDOMAIN'] + '.', 'SOA')
+
+        # Actually there is only one
         for i in range(len(ans)):
-            target_hostname = str(ans[i].target).rstrip('.')
+            target_hostname = str(ans[i].mname).rstrip('.')
             settings = {}
             settings["lp_ctx"] = lp
             settings["target_hostname"] = target_hostname
-- 
2.9.4


From 9b71b55b893ca5fc03864ec9e963d2bba1835952 Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Tue, 11 Apr 2017 12:43:22 +1200
Subject: [PATCH 12/22] dns_server: clobber MNAME in the SOA

Otherwise, we always report the server on which the AD domain was first created/provisioned,
which does not match AD behaviour.  AD is multi-master, so every RW server is a master.

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
---
 python/samba/samdb.py                 |  2 +-
 python/samba/tests/dns.py             | 18 ++++++++++++
 source4/dns_server/dlz_bind9.c        |  2 +-
 source4/dns_server/dnsserver_common.c | 53 +++++++++++++++++++++++++++++++++--
 source4/dns_server/dnsserver_common.h |  3 +-
 source4/dns_server/pydns.c            |  8 ++++--
 6 files changed, 78 insertions(+), 8 deletions(-)

diff --git a/python/samba/samdb.py b/python/samba/samdb.py
index b3a4b38..e002156 100644
--- a/python/samba/samdb.py
+++ b/python/samba/samdb.py
@@ -937,7 +937,7 @@ accountExpires: %u
 
     def dns_extract(self, el):
         '''Return the NDR database structures from a dnsRecord element'''
-        return dsdb_dns.extract(el)
+        return dsdb_dns.extract(self, el)
 
     def dns_replace(self, dns_name, new_records):
         '''Do a DNS modification on the database, sets the NDR database
diff --git a/python/samba/tests/dns.py b/python/samba/tests/dns.py
index b8a2481..93a7a7a 100644
--- a/python/samba/tests/dns.py
+++ b/python/samba/tests/dns.py
@@ -235,6 +235,24 @@ class TestSimpleQueries(DNSTest):
         self.assertEquals(response.answers[0].rdata,
                           self.server_ip)
 
+    def test_one_SOA_query(self):
+        "create a query packet containing one query record for the SOA"
+        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
+        questions = []
+
+        name = "%s" % (self.get_dns_domain())
+        q = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
+        print "asking for ", q.name
+        questions.append(q)
+
+        self.finish_name_packet(p, questions)
+        response = self.dns_transaction_udp(p)
+        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
+        self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
+        self.assertEquals(response.ancount, 1)
+        self.assertEquals(response.answers[0].rdata.mname.upper(),
+                          ("%s.%s" % (self.server, self.get_dns_domain())).upper())
+
     def test_one_a_query_tcp(self):
         "create a query packet containing one query record via TCP"
         p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
diff --git a/source4/dns_server/dlz_bind9.c b/source4/dns_server/dlz_bind9.c
index 897699a..7096f47 100644
--- a/source4/dns_server/dlz_bind9.c
+++ b/source4/dns_server/dlz_bind9.c
@@ -997,7 +997,7 @@ _PUBLIC_ isc_result_t dlz_allnodes(const char *zone, void *dbdata,
 			return ISC_R_NOMEMORY;
 		}
 
-		werr = dns_common_extract(el, el_ctx, &recs, &num_recs);
+		werr = dns_common_extract(state->samdb, el, el_ctx, &recs, &num_recs);
 		if (!W_ERROR_IS_OK(werr)) {
 			state->log(ISC_LOG_ERROR, "samba_dlz: failed to parse dnsRecord for %s, %s",
 				   ldb_dn_get_linearized(dn), win_errstr(werr));
diff --git a/source4/dns_server/dnsserver_common.c b/source4/dns_server/dnsserver_common.c
index fbfa5fa..d0c0a2f 100644
--- a/source4/dns_server/dnsserver_common.c
+++ b/source4/dns_server/dnsserver_common.c
@@ -69,7 +69,8 @@ uint8_t werr_to_dns_err(WERROR werr)
 	return DNS_RCODE_SERVFAIL;
 }
 
-WERROR dns_common_extract(const struct ldb_message_element *el,
+WERROR dns_common_extract(struct ldb_context *samdb,
+			  const struct ldb_message_element *el,
 			  TALLOC_CTX *mem_ctx,
 			  struct dnsp_DnssrvRpcRecord **records,
 			  uint16_t *num_records)
@@ -86,9 +87,13 @@ WERROR dns_common_extract(const struct ldb_message_element *el,
 		return WERR_NOT_ENOUGH_MEMORY;
 	}
 	for (ri = 0; ri < el->num_values; ri++) {
+		bool am_rodc;
+		int ret;
+		const char *attrs[] = { "dnsHostName", NULL };
+		const char *dnsHostName;
 		struct ldb_val *v = &el->values[ri];
 		enum ndr_err_code ndr_err;
-
+		struct ldb_result *res = NULL;
 		ndr_err = ndr_pull_struct_blob(v, recs, &recs[ri],
 				(ndr_pull_flags_fn_t)ndr_pull_dnsp_DnssrvRpcRecord);
 		if (!NDR_ERR_CODE_IS_SUCCESS(ndr_err)) {
@@ -96,7 +101,49 @@ WERROR dns_common_extract(const struct ldb_message_element *el,
 			DEBUG(0, ("Failed to grab dnsp_DnssrvRpcRecord\n"));
 			return DNS_ERR(SERVER_FAILURE);
 		}
+
+		/*
+		 * In AD, except on an RODC (where we should list a random RWDC,
+		 * but for now leave the stored value alone), over-stamp the SOA
+		 * MNAME with our own dnsHostName, as read from the rootDSE.
+		 */
+		if (recs[ri].wType != DNS_TYPE_SOA) {
+			continue;
+		}
+
+		ret = samdb_rodc(samdb, &am_rodc);
+		if (ret != LDB_SUCCESS) {
+			DEBUG(0, ("Failed to confirm we are not an RODC: %s\n",
+				  ldb_errstring(samdb)));
+			return DNS_ERR(SERVER_FAILURE);
+		}
+
+		if (am_rodc) {
+			continue;
+		}
+
+		ret = dsdb_search_dn(samdb, mem_ctx, &res, NULL, attrs, 0);
+		/* res stays NULL unless the search succeeded: test ret first */
+		if (ret != LDB_SUCCESS || res->count != 1) {
+			DEBUG(0, ("Failed to get rootDSE for dnsHostName: %s\n",
+				  ldb_errstring(samdb)));
+			return DNS_ERR(SERVER_FAILURE);
+		}
+
+		dnsHostName
+			= ldb_msg_find_attr_as_string(res->msgs[0],
+						      "dnsHostName",
+						      NULL);
+
+		if (dnsHostName == NULL) {
+			DEBUG(0, ("Failed to get dnsHostName from rootDSE\n"));
+			return DNS_ERR(SERVER_FAILURE);
+		}
+
+		recs[ri].data.soa.mname
+			= talloc_steal(recs, dnsHostName);
 	}
+
 	*records = recs;
 	*num_records = el->num_values;
 	return WERR_OK;
@@ -189,7 +236,7 @@ WERROR dns_common_lookup(struct ldb_context *samdb,
 		}
 	}
 
-	werr = dns_common_extract(el, mem_ctx, records, num_records);
+	werr = dns_common_extract(samdb, el, mem_ctx, records, num_records);
 	TALLOC_FREE(msg);
 	if (!W_ERROR_IS_OK(werr)) {
 		return werr;
diff --git a/source4/dns_server/dnsserver_common.h b/source4/dns_server/dnsserver_common.h
index 293831f..b615e2d 100644
--- a/source4/dns_server/dnsserver_common.h
+++ b/source4/dns_server/dnsserver_common.h
@@ -35,7 +35,8 @@ struct dns_server_zone {
 	struct ldb_dn *dn;
 };
 
-WERROR dns_common_extract(const struct ldb_message_element *el,
+WERROR dns_common_extract(struct ldb_context *samdb,
+			  const struct ldb_message_element *el,
 			  TALLOC_CTX *mem_ctx,
 			  struct dnsp_DnssrvRpcRecord **records,
 			  uint16_t *num_records);
diff --git a/source4/dns_server/pydns.c b/source4/dns_server/pydns.c
index cb41faa..63fa80e 100644
--- a/source4/dns_server/pydns.c
+++ b/source4/dns_server/pydns.c
@@ -160,17 +160,21 @@ static PyObject *py_dsdb_dns_lookup(PyObject *self,
 
 static PyObject *py_dsdb_dns_extract(PyObject *self, PyObject *args)
 {
+	struct ldb_context *samdb;
 	PyObject *py_dns_el, *ret;
+	PyObject *py_ldb = NULL;
 	TALLOC_CTX *frame;
 	WERROR werr;
 	struct ldb_message_element *dns_el;
 	struct dnsp_DnssrvRpcRecord *records;
 	uint16_t num_records;
 
-	if (!PyArg_ParseTuple(args, "O", &py_dns_el)) {
+	if (!PyArg_ParseTuple(args, "OO", &py_ldb, &py_dns_el)) {
 		return NULL;
 	}
 
+	PyErr_LDB_OR_RAISE(py_ldb, samdb);
+
 	if (!py_check_dcerpc_type(py_dns_el, "ldb", "MessageElement")) {
 		PyErr_SetString(PyExc_TypeError,
 				"ldb MessageElement object required");
@@ -180,7 +184,7 @@ static PyObject *py_dsdb_dns_extract(PyObject *self, PyObject *args)
 
 	frame = talloc_stackframe();
 
-	werr = dns_common_extract(dns_el,
+	werr = dns_common_extract(samdb, dns_el,
 				  frame,
 				  &records,
 				  &num_records);
-- 
2.9.4


From ad8a292f062eb86e5143af264cb7e43543b2236c Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Tue, 11 Apr 2017 14:14:15 +1200
Subject: [PATCH 13/22] samba_dnsupdate: Extend possible server list to all NS
 servers for the zone

This should eventually be removed, but for now this unblocks samba_dnsupdate operation
in existing domains that have lost the original Samba DC.

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
---
 source4/scripting/bin/samba_dnsupdate | 98 ++++++++++++++++++++++++-----------
 1 file changed, 69 insertions(+), 29 deletions(-)

diff --git a/source4/scripting/bin/samba_dnsupdate b/source4/scripting/bin/samba_dnsupdate
index 28343bf..eb6d4c2 100755
--- a/source4/scripting/bin/samba_dnsupdate
+++ b/source4/scripting/bin/samba_dnsupdate
@@ -121,6 +121,64 @@ for i in IPs:
 if opts.verbose:
     print "IPs: %s" % IPs
 
+def get_possible_rw_dns_server(creds, domain):
+    """Get a list of possible read-write DNS servers, starting with
+       the SOA.  The SOA is the correct answer, but old Samba domains
+       (4.6 and prior) do not maintain this value, so add NS servers
+       as well"""
+
+    hostnames = []
+    ans_soa = check_one_dns_name(domain, 'SOA')
+
+    # Actually there is only one
+    for i in range(len(ans_soa)):
+        hostnames.append(str(ans_soa[i].mname).rstrip('.'))
+
+    # This is not strictly legit, but old Samba domains may have an
+    # unmaintained SOA record, so go for any NS that we can get a
+    # ticket to.
+    ans_ns = check_one_dns_name(domain, 'NS')
+
+    # There may be several NS records; add every target in turn
+    for i in range(len(ans_ns)):
+        hostnames.append(str(ans_ns[i].target).rstrip('.'))
+
+    return hostnames
+
+def get_krb5_rw_dns_server(creds, domain):
+    """Return the first read-write DNS server that we can obtain a
+       ticket for, starting with the SOA.  The SOA is the correct answer,
+       but old Samba domains (4.6 and prior) do not maintain this value,
+       so continue with the NS servers as well until we get one that
+       the KDC will issue a ticket to.  Re-raises the last error if all fail.
+    """
+
+    rw_dns_servers = get_possible_rw_dns_server(creds, domain)
+    # Candidates are the SOA MNAME first, then every NS target
+    for i in range(len(rw_dns_servers)):
+        target_hostname = str(rw_dns_servers[i])
+        settings = {}
+        settings["lp_ctx"] = lp
+        settings["target_hostname"] = target_hostname
+
+        gensec_client = gensec.Security.start_client(settings)
+        gensec_client.set_credentials(creds)
+        gensec_client.set_target_service("DNS")
+        gensec_client.set_target_hostname(target_hostname)
+        gensec_client.want_feature(gensec.FEATURE_SEAL)
+        gensec_client.start_mech_by_sasl_name("GSSAPI")
+        server_to_client = ""
+        try:
+            (client_finished, client_to_server) = gensec_client.update(server_to_client)
+            if opts.verbose:
+                print "Successfully obtained Kerberos ticket to DNS/%s as %s" \
+                    % (target_hostname, creds.get_username())
+            return target_hostname
+        except RuntimeError:
+            # Only raise an exception if they all failed
+            if i != len(rw_dns_servers) - 1:
+                continue
+            raise
 
 def get_credentials(lp):
     """# get credentials if we haven't got them already."""
@@ -138,33 +196,8 @@ def get_credentials(lp):
             return
 
         # Now confirm we can get a ticket to the DNS server
-        ans = check_one_dns_name(sub_vars['DNSDOMAIN'] + '.', 'SOA')
-
-        # Actually there is only one
-        for i in range(len(ans)):
-            target_hostname = str(ans[i].mname).rstrip('.')
-            settings = {}
-            settings["lp_ctx"] = lp
-            settings["target_hostname"] = target_hostname
-
-            gensec_client = gensec.Security.start_client(settings)
-            gensec_client.set_credentials(creds)
-            gensec_client.set_target_service("DNS")
-            gensec_client.set_target_hostname(target_hostname)
-            gensec_client.want_feature(gensec.FEATURE_SEAL)
-            gensec_client.start_mech_by_sasl_name("GSSAPI")
-            server_to_client = ""
-            try:
-                (client_finished, client_to_server) = gensec_client.update(server_to_client)
-                if opts.verbose:
-                    print "Successfully obtained Kerberos ticket to DNS/%s as %s" \
-                            % (target_hostname, creds.get_username())
-                return
-            except RuntimeError:
-                # Only raise an exception if they all failed
-                if i != len(ans) - 1:
-                    pass
-                raise
+        get_krb5_rw_dns_server(creds, sub_vars['DNSDOMAIN'] + '.')
+        return creds
 
     except RuntimeError as e:
         os.unlink(ccachename)
@@ -452,11 +485,18 @@ def call_nsupdate(d, op="add"):
         f.write('server %s\n' % d.nameservers[0])
     else:
         resolver = get_resolver(d)
+
+        # Locate the zone for this name
         zone = dns.resolver.zone_for_name(normalised_name,
                                           resolver=resolver)
-        soa = resolver.query(zone, "SOA")
 
-        f.write('server %s\n' % soa[0].mname)
+        # Now find the SOA, or if we can't get a ticket to the SOA,
+        # any server with an NS record we can get a ticket for.
+        #
+        # Thanks to the Kerberos credentials cache this is not
+        # expensive inside the loop
+        server = get_krb5_rw_dns_server(creds, zone)
+        f.write('server %s\n' % server)
 
     if d.type == "A":
         f.write("update %s %s %u A %s\n" % (op, normalised_name, default_ttl, d.ip))
-- 
2.9.4


From ac382c01338bcced5c794b429c4796bcac0b75f5 Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Tue, 11 Apr 2017 14:23:49 +1200
Subject: [PATCH 14/22] samba_dnsupdate: fix "samba-tool" fallback error
 handling

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
---
 source4/scripting/bin/samba_dnsupdate | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/source4/scripting/bin/samba_dnsupdate b/source4/scripting/bin/samba_dnsupdate
index eb6d4c2..d9948a6 100755
--- a/source4/scripting/bin/samba_dnsupdate
+++ b/source4/scripting/bin/samba_dnsupdate
@@ -626,7 +626,7 @@ def call_samba_tool(d, op="add", zone=None):
                 sys.exit(1)
             error_count = error_count + 1
             if opts.verbose:
-                print("Failed 'samba-tool dns' based update: %s" % (str(d)))
+                print("Failed 'samba-tool dns' based update of %s" % (str(d)))
     except Exception, estr:
         if opts.fail_immediately:
             sys.exit(1)
-- 
2.9.4


From 324b4a98c39be3e020a51c66b1bec79ccae93982 Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Wed, 31 May 2017 13:57:25 +1200
Subject: [PATCH 15/22] selftest: move make_txt_record() onto self in
 samba.tests.dns

This will help unify dns.py and dns_tkey.py to use common subclasses.

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
---
 python/samba/tests/dns.py | 28 ++++++++++++++--------------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/python/samba/tests/dns.py b/python/samba/tests/dns.py
index 93a7a7a..edcf2f0 100644
--- a/python/samba/tests/dns.py
+++ b/python/samba/tests/dns.py
@@ -60,14 +60,6 @@ server_name = args[0]
 server_ip = args[1]
 creds.set_krb_forwardable(credentials.NO_KRB_FORWARDABLE)
 
-def make_txt_record(records):
-    rdata_txt = dns.txt_record()
-    s_list = dnsp.string_list()
-    s_list.count = len(records)
-    s_list.str = records
-    rdata_txt.txt = s_list
-    return rdata_txt
-
 class DNSTest(TestCase):
 
     def setUp(self):
@@ -131,6 +123,14 @@ class DNSTest(TestCase):
         q.question_class = qclass
         return q
 
+    def make_txt_record(self, records):
+        rdata_txt = dns.txt_record()
+        s_list = dnsp.string_list()
+        s_list.count = len(records)
+        s_list.str = records
+        rdata_txt.txt = s_list
+        return rdata_txt
+
     def get_dns_domain(self):
         "Helper to get dns domain"
         return self.creds.get_realm().lower()
@@ -193,7 +193,7 @@ class DNSTest(TestCase):
         r.rr_class = dns.DNS_QCLASS_IN
         r.ttl = 900
         r.length = 0xffff
-        rdata = make_txt_record(txt_array)
+        rdata = self.make_txt_record(txt_array)
         r.rdata = rdata
         updates.append(r)
         p.nscount = len(updates)
@@ -560,7 +560,7 @@ class TestDNSUpdates(DNSTest):
         r.rr_class = dns.DNS_QCLASS_IN
         r.ttl = 900
         r.length = 0xffff
-        rdata = make_txt_record(['"This is a test"'])
+        rdata = self.make_txt_record(['"This is a test"'])
         r.rdata = rdata
         updates.append(r)
         p.nscount = len(updates)
@@ -596,7 +596,7 @@ class TestDNSUpdates(DNSTest):
         r.rr_class = dns.DNS_QCLASS_NONE
         r.ttl = 0
         r.length = 0xffff
-        rdata = make_txt_record(['"This is a test"'])
+        rdata = self.make_txt_record(['"This is a test"'])
         r.rdata = rdata
         updates.append(r)
         p.nscount = len(updates)
@@ -638,7 +638,7 @@ class TestDNSUpdates(DNSTest):
         r.rr_class = dns.DNS_QCLASS_IN
         r.ttl = 900
         r.length = 0xffff
-        rdata = make_txt_record(['"This is a test"'])
+        rdata = self.make_txt_record(['"This is a test"'])
         r.rdata = rdata
         updates.append(r)
         p.nscount = len(updates)
@@ -674,7 +674,7 @@ class TestDNSUpdates(DNSTest):
         r.rr_class = dns.DNS_QCLASS_NONE
         r.ttl = 0
         r.length = 0xffff
-        rdata = make_txt_record(['"This is a test"'])
+        rdata = self.make_txt_record(['"This is a test"'])
         r.rdata = rdata
         updates.append(r)
         p.nscount = len(updates)
@@ -711,7 +711,7 @@ class TestDNSUpdates(DNSTest):
         r.rr_class = dns.DNS_QCLASS_IN
         r.ttl = 900
         r.length = 0xffff
-        rdata = make_txt_record(['"This is a test"'])
+        rdata = self.make_txt_record(['"This is a test"'])
         r.rdata = rdata
         updates.append(r)
         p.nscount = len(updates)
-- 
2.9.4


From 3b4420cbbeb2a4f8f1628082e1f7637249be8f6f Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Wed, 31 May 2017 13:56:18 +1200
Subject: [PATCH 16/22] selftest: merge DNSTest boilerplate

This will help unify dns.py and dns_tkey.py to use common subclasses.

The code was originally copied, but has since diverged.  This handles
that divergence.

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
---
 python/samba/tests/dns.py      | 198 +++++++++++++++++++++++++++--------------
 python/samba/tests/dns_tkey.py |  82 +++++++++++++----
 2 files changed, 198 insertions(+), 82 deletions(-)

diff --git a/python/samba/tests/dns.py b/python/samba/tests/dns.py
index edcf2f0..106e64f 100644
--- a/python/samba/tests/dns.py
+++ b/python/samba/tests/dns.py
@@ -63,12 +63,8 @@ creds.set_krb_forwardable(credentials.NO_KRB_FORWARDABLE)
 class DNSTest(TestCase):
 
     def setUp(self):
-        global server, server_ip, lp, creds
         super(DNSTest, self).setUp()
-        self.server = server_name
-        self.server_ip = server_ip
-        self.lp = lp
-        self.creds = creds
+        self.timeout = None
 
     def errstr(self, errcode):
         "Return a readable error code"
@@ -84,30 +80,42 @@ class DNSTest(TestCase):
             "NXRRSET",
             "NOTAUTH",
             "NOTZONE",
+            "0x0B",
+            "0x0C",
+            "0x0D",
+            "0x0E",
+            "0x0F",
+            "BADSIG",
+            "BADKEY"
         ]
 
         return string_codes[errcode]
 
+    def assert_rcode_equals(self, rcode, expected):
+        "Helper function to check return code"
+        self.assertEquals(rcode, expected, "Expected RCODE %s, got %s" %
+                          (self.errstr(expected), self.errstr(rcode)))
 
     def assert_dns_rcode_equals(self, packet, rcode):
         "Helper function to check return code"
         p_errcode = packet.operation & 0x000F
         self.assertEquals(p_errcode, rcode, "Expected RCODE %s, got %s" %
-                            (self.errstr(rcode), self.errstr(p_errcode)))
+                          (self.errstr(rcode), self.errstr(p_errcode)))
 
     def assert_dns_opcode_equals(self, packet, opcode):
         "Helper function to check opcode"
         p_opcode = packet.operation & 0x7800
         self.assertEquals(p_opcode, opcode, "Expected OPCODE %s, got %s" %
-                            (opcode, p_opcode))
+                          (opcode, p_opcode))
 
     def make_name_packet(self, opcode, qid=None):
         "Helper creating a dns.name_packet"
         p = dns.name_packet()
         if qid is None:
-            p.id = random.randint(0x0, 0xffff)
+            p.id = random.randint(0x0, 0xff00)
         p.operation = opcode
         p.questions = []
+        p.additional = []
         return p
 
     def finish_name_packet(self, packet, questions):
@@ -135,10 +143,12 @@ class DNSTest(TestCase):
         "Helper to get dns domain"
         return self.creds.get_realm().lower()
 
-    def dns_transaction_udp(self, packet, host=server_ip,
-                            dump=False, timeout=timeout):
+    def dns_transaction_udp(self, packet, host,
+                            dump=False, timeout=None):
         "send a DNS query and read the reply"
         s = None
+        if timeout is None:
+            timeout = self.timeout
         try:
             send_packet = ndr.ndr_pack(packet)
             if dump:
@@ -146,19 +156,22 @@ class DNSTest(TestCase):
             s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
             s.settimeout(timeout)
             s.connect((host, 53))
-            s.send(send_packet, 0)
+            s.sendall(send_packet, 0)
             recv_packet = s.recv(2048, 0)
             if dump:
                 print self.hexdump(recv_packet)
-            return ndr.ndr_unpack(dns.name_packet, recv_packet)
+            response = ndr.ndr_unpack(dns.name_packet, recv_packet)
+            return (response, recv_packet)
         finally:
             if s is not None:
                 s.close()
 
-    def dns_transaction_tcp(self, packet, host=server_ip,
-                            dump=False, timeout=timeout):
-        "send a DNS query and read the reply"
+    def dns_transaction_tcp(self, packet, host,
+                            dump=False, timeout=None):
+        "send a DNS query and read the reply, also return the raw packet"
         s = None
+        if timeout is None:
+            timeout = self.timeout
         try:
             send_packet = ndr.ndr_pack(packet)
             if dump:
@@ -168,14 +181,22 @@ class DNSTest(TestCase):
             s.connect((host, 53))
             tcp_packet = struct.pack('!H', len(send_packet))
             tcp_packet += send_packet
-            s.send(tcp_packet, 0)
+            s.sendall(tcp_packet)
+
             recv_packet = s.recv(0xffff + 2, 0)
             if dump:
                 print self.hexdump(recv_packet)
-            return ndr.ndr_unpack(dns.name_packet, recv_packet[2:])
+            response = ndr.ndr_unpack(dns.name_packet, recv_packet[2:])
+
         finally:
-                if s is not None:
-                    s.close()
+            if s is not None:
+                s.close()
+
+        # unpacking and packing again should produce same bytestream
+        my_packet = ndr.ndr_pack(response)
+        self.assertEquals(my_packet, recv_packet[2:])
+
+        return (response, recv_packet[2:])
 
     def make_txt_update(self, prefix, txt_array):
         p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
@@ -210,12 +231,21 @@ class DNSTest(TestCase):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=self.server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assertEquals(response.ancount, 1)
         self.assertEquals(response.answers[0].rdata.txt.str, txt_array)
 
+
 class TestSimpleQueries(DNSTest):
+    def setUp(self):
+        super(TestSimpleQueries, self).setUp()
+        global server, server_ip, lp, creds, timeout
+        self.server = server_name
+        self.server_ip = server_ip
+        self.lp = lp
+        self.creds = creds
+        self.timeout = timeout
 
     def test_one_a_query(self):
         "create a query packet containing one query record"
@@ -228,7 +258,7 @@ class TestSimpleQueries(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         self.assertEquals(response.ancount, 1)
@@ -246,7 +276,7 @@ class TestSimpleQueries(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         self.assertEquals(response.ancount, 1)
@@ -264,7 +294,7 @@ class TestSimpleQueries(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_tcp(p)
+        (response, response_packet) = self.dns_transaction_tcp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         self.assertEquals(response.ancount, 1)
@@ -282,7 +312,7 @@ class TestSimpleQueries(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         self.assertEquals(response.ancount, 0)
@@ -296,7 +326,7 @@ class TestSimpleQueries(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXDOMAIN)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         self.assertEquals(response.ancount, 0)
@@ -316,7 +346,7 @@ class TestSimpleQueries(DNSTest):
 
         self.finish_name_packet(p, questions)
         try:
-            response = self.dns_transaction_udp(p)
+            (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
             self.assert_dns_rcode_equals(response, dns.DNS_RCODE_FORMERR)
         except socket.timeout:
             # Windows chooses not to respond to incorrectly formatted queries.
@@ -336,7 +366,7 @@ class TestSimpleQueries(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
 
         num_answers = 1
         dc_ipv6 = os.getenv('SERVER_IPV6')
@@ -362,7 +392,7 @@ class TestSimpleQueries(DNSTest):
 
         self.finish_name_packet(p, questions)
         try:
-            response = self.dns_transaction_udp(p)
+            (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
             self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NOTIMP)
         except socket.timeout:
             # Windows chooses not to respond to incorrectly formatted queries.
@@ -381,7 +411,7 @@ class TestSimpleQueries(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         # We don't get SOA records for single hosts
@@ -399,7 +429,7 @@ class TestSimpleQueries(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         self.assertEquals(response.ancount, 1)
@@ -407,6 +437,14 @@ class TestSimpleQueries(DNSTest):
 
 
 class TestDNSUpdates(DNSTest):
+    def setUp(self):
+        super(TestDNSUpdates, self).setUp()
+        global server_name, server_ip, lp, creds, timeout
+        self.server = server_name
+        self.server_ip = server_ip
+        self.lp = lp
+        self.creds = creds
+        self.timeout = timeout
 
     def test_two_updates(self):
         "create two update requests"
@@ -423,7 +461,7 @@ class TestDNSUpdates(DNSTest):
 
         self.finish_name_packet(p, updates)
         try:
-            response = self.dns_transaction_udp(p)
+            (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
             self.assert_dns_rcode_equals(response, dns.DNS_RCODE_FORMERR)
         except socket.timeout:
             # Windows chooses not to respond to incorrectly formatted queries.
@@ -442,7 +480,7 @@ class TestDNSUpdates(DNSTest):
         updates.append(u)
 
         self.finish_name_packet(p, updates)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NOTIMP)
 
     def test_update_prereq_with_non_null_ttl(self):
@@ -469,7 +507,7 @@ class TestDNSUpdates(DNSTest):
         p.answers = prereqs
 
         try:
-            response = self.dns_transaction_udp(p)
+            (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
             self.assert_dns_rcode_equals(response, dns.DNS_RCODE_FORMERR)
         except socket.timeout:
             # Windows chooses not to respond to incorrectly formatted queries.
@@ -501,7 +539,7 @@ class TestDNSUpdates(DNSTest):
         p.ancount = len(prereqs)
         p.answers = prereqs
 
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXRRSET)
 
     def test_update_prereq_nonexisting_name(self):
@@ -527,14 +565,14 @@ class TestDNSUpdates(DNSTest):
         p.ancount = len(prereqs)
         p.answers = prereqs
 
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXRRSET)
 
     def test_update_add_txt_record(self):
         "test adding records works"
         prefix, txt = 'textrec', ['"This is a test"']
         p = self.make_txt_update(prefix, txt)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.check_query_txt(prefix, txt)
 
@@ -566,7 +604,7 @@ class TestDNSUpdates(DNSTest):
         p.nscount = len(updates)
         p.nsrecs = updates
 
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
 
         # Now check the record is around
@@ -576,7 +614,7 @@ class TestDNSUpdates(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
 
         # Now delete the record
@@ -602,7 +640,7 @@ class TestDNSUpdates(DNSTest):
         p.nscount = len(updates)
         p.nsrecs = updates
 
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
 
         # And finally check it's gone
@@ -613,7 +651,7 @@ class TestDNSUpdates(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXDOMAIN)
 
     def test_readd_record(self):
@@ -644,7 +682,7 @@ class TestDNSUpdates(DNSTest):
         p.nscount = len(updates)
         p.nsrecs = updates
 
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
 
         # Now check the record is around
@@ -654,7 +692,7 @@ class TestDNSUpdates(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
 
         # Now delete the record
@@ -680,7 +718,7 @@ class TestDNSUpdates(DNSTest):
         p.nscount = len(updates)
         p.nsrecs = updates
 
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
 
         # check it's gone
@@ -691,7 +729,7 @@ class TestDNSUpdates(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXDOMAIN)
 
         # recreate the record
@@ -717,7 +755,7 @@ class TestDNSUpdates(DNSTest):
         p.nscount = len(updates)
         p.nsrecs = updates
 
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
 
         # Now check the record is around
@@ -727,7 +765,7 @@ class TestDNSUpdates(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
 
     def test_update_add_mx_record(self):
@@ -756,7 +794,7 @@ class TestDNSUpdates(DNSTest):
         p.nscount = len(updates)
         p.nsrecs = updates
 
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
 
         p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
@@ -767,7 +805,7 @@ class TestDNSUpdates(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assertEqual(response.ancount, 1)
         ans = response.answers[0]
@@ -795,11 +833,19 @@ class TestComplexQueries(DNSTest):
         updates = [r]
         p.nscount = 1
         p.nsrecs = updates
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
 
     def setUp(self):
         super(TestComplexQueries, self).setUp()
+
+        global server_name, server_ip, lp, creds, timeout
+        self.server = server_name
+        self.server_ip = server_ip
+        self.lp = lp
+        self.creds = creds
+        self.timeout = timeout
+
         name = "cname_test.%s" % self.get_dns_domain()
         rdata = "%s.%s" % (self.server, self.get_dns_domain())
         self.make_dns_update(name, rdata, dns.DNS_QTYPE_CNAME)
@@ -827,7 +873,7 @@ class TestComplexQueries(DNSTest):
         p.nscount = len(updates)
         p.nsrecs = updates
 
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
 
     def test_one_a_query(self):
@@ -841,7 +887,7 @@ class TestComplexQueries(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         self.assertEquals(response.ancount, 2)
@@ -867,7 +913,7 @@ class TestComplexQueries(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         self.assertEquals(response.ancount, 3)
@@ -908,7 +954,7 @@ class TestComplexQueries(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
 
@@ -925,6 +971,14 @@ class TestComplexQueries(DNSTest):
         self.assertEquals(response.answers[1].rdata, name0)
 
 class TestInvalidQueries(DNSTest):
+    def setUp(self):
+        super(TestInvalidQueries, self).setUp()
+        global server_name, server_ip, lp, creds, timeout
+        self.server = server_name
+        self.server_ip = server_ip
+        self.lp = lp
+        self.creds = creds
+        self.timeout = timeout
 
     def test_one_a_query(self):
         "send 0 bytes follows by create a query packet containing one query record"
@@ -947,7 +1001,7 @@ class TestInvalidQueries(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=self.server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         self.assertEquals(response.ancount, 1)
@@ -993,6 +1047,13 @@ class TestInvalidQueries(DNSTest):
 class TestZones(DNSTest):
     def setUp(self):
         super(TestZones, self).setUp()
+        global server_name, server_ip, lp, creds, timeout
+        self.server = server_name
+        self.server_ip = server_ip
+        self.lp = lp
+        self.creds = creds
+        self.timeout = timeout
+
         self.zone = "test.lan"
         self.rpc_conn = dnsserver.dnsserver("ncacn_ip_tcp:%s[sign]" % (self.server_ip),
                                             self.lp, self.creds)
@@ -1040,21 +1101,21 @@ class TestZones(DNSTest):
         questions.append(q)
         self.finish_name_packet(p, questions)
 
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         # Windows returns OK while BIND logically seems to return NXDOMAIN
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXDOMAIN)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         self.assertEquals(response.ancount, 0)
 
         self.create_zone(zone)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         self.assertEquals(response.ancount, 1)
         self.assertEquals(response.answers[0].rr_type, dns.DNS_QTYPE_SOA)
 
         self.delete_zone(zone)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXDOMAIN)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         self.assertEquals(response.ancount, 0)
@@ -1062,6 +1123,11 @@ class TestZones(DNSTest):
 class TestRPCRoundtrip(DNSTest):
     def setUp(self):
         super(TestRPCRoundtrip, self).setUp()
+        global server_name, server_ip, lp, creds
+        self.server = server_name
+        self.server_ip = server_ip
+        self.lp = lp
+        self.creds = creds
         self.rpc_conn = dnsserver.dnsserver("ncacn_ip_tcp:%s[sign]" % (self.server_ip),
                                             self.lp, self.creds)
 
@@ -1091,7 +1157,7 @@ class TestRPCRoundtrip(DNSTest):
         "test adding records works"
         prefix, txt = 'pad1textrec', ['"This is a test"', '', '']
         p = self.make_txt_update(prefix, txt)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.check_query_txt(prefix, txt)
         self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip,
@@ -1101,7 +1167,7 @@ class TestRPCRoundtrip(DNSTest):
 
         prefix, txt = 'pad2textrec', ['"This is a test"', '', '', 'more text']
         p = self.make_txt_update(prefix, txt)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.check_query_txt(prefix, txt)
         self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip,
@@ -1111,7 +1177,7 @@ class TestRPCRoundtrip(DNSTest):
 
         prefix, txt = 'pad3textrec', ['', '', '"This is a test"']
         p = self.make_txt_update(prefix, txt)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.check_query_txt(prefix, txt)
         self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip,
@@ -1179,7 +1245,7 @@ class TestRPCRoundtrip(DNSTest):
         "test adding records works"
         prefix, txt = 'nulltextrec', ['NULL\x00BYTE']
         p = self.make_txt_update(prefix, txt)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.check_query_txt(prefix, ['NULL'])
         self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip,
@@ -1189,7 +1255,7 @@ class TestRPCRoundtrip(DNSTest):
 
         prefix, txt = 'nulltextrec2', ['NULL\x00BYTE', 'NULL\x00BYTE']
         p = self.make_txt_update(prefix, txt)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.check_query_txt(prefix, ['NULL', 'NULL'])
         self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip,
@@ -1220,7 +1286,7 @@ class TestRPCRoundtrip(DNSTest):
         "test adding records works"
         prefix, txt = 'hextextrec', ['HIGH\xFFBYTE']
         p = self.make_txt_update(prefix, txt)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.check_query_txt(prefix, txt)
         self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip,
@@ -1251,7 +1317,7 @@ class TestRPCRoundtrip(DNSTest):
         "test adding records works"
         prefix, txt = 'slashtextrec', ['Th\\=is=is a test']
         p = self.make_txt_update(prefix, txt)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.check_query_txt(prefix, txt)
         self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip,
@@ -1286,7 +1352,7 @@ class TestRPCRoundtrip(DNSTest):
         prefix, txt = 'textrec2', ['"This is a test"',
                                    '"and this is a test, too"']
         p = self.make_txt_update(prefix, txt)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.check_query_txt(prefix, txt)
         self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip,
@@ -1321,7 +1387,7 @@ class TestRPCRoundtrip(DNSTest):
         "test adding two txt records works"
         prefix, txt = 'emptytextrec', []
         p = self.make_txt_update(prefix, txt)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.check_query_txt(prefix, txt)
         self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip,
diff --git a/python/samba/tests/dns_tkey.py b/python/samba/tests/dns_tkey.py
index f424e07..21f5d56 100644
--- a/python/samba/tests/dns_tkey.py
+++ b/python/samba/tests/dns_tkey.py
@@ -29,6 +29,7 @@ from samba import credentials
 from samba.dcerpc import dns, dnsp
 from samba.tests.subunitrun import SubunitOptions, TestProgram
 from samba import gensec, tests
+from samba.tests import TestCase
 
 parser = optparse.OptionParser("dns.py <server name> <server ip> [options]")
 sambaopts = options.SambaOptions(parser)
@@ -56,21 +57,11 @@ server_name = args[0]
 server_ip = args[1]
 
 
-class DNSTest(tests.TestCase):
+class DNSTest(TestCase):
+
     def setUp(self):
         super(DNSTest, self).setUp()
-        self.server = server_name
-        self.server_ip = server_ip
-        self.settings = {}
-        self.settings["lp_ctx"] = self.lp_ctx = tests.env_loadparm()
-        self.settings["target_hostname"] = self.server
-
-        self.creds = credentials.Credentials()
-        self.creds.guess(self.lp_ctx)
-        self.creds.set_username(tests.env_get_var_value('USERNAME'))
-        self.creds.set_password(tests.env_get_var_value('PASSWORD'))
-        self.creds.set_kerberos_state(credentials.MUST_USE_KERBEROS)
-        self.newrecname = "tkeytsig.%s" % self.get_dns_domain()
+        self.timeout = None
 
     def errstr(self, errcode):
         "Return a readable error code"
@@ -150,9 +141,11 @@ class DNSTest(tests.TestCase):
         return self.creds.get_realm().lower()
 
     def dns_transaction_udp(self, packet, host,
-                            dump=False, timeout=timeout):
+                            dump=False, timeout=None):
         "send a DNS query and read the reply"
         s = None
+        if timeout is None:
+            timeout = self.timeout
         try:
             send_packet = ndr.ndr_pack(packet)
             if dump:
@@ -171,9 +164,11 @@ class DNSTest(tests.TestCase):
                 s.close()
 
     def dns_transaction_tcp(self, packet, host,
-                            dump=False, timeout=timeout):
+                            dump=False, timeout=None):
         "send a DNS query and read the reply, also return the raw packet"
         s = None
+        if timeout is None:
+            timeout = self.timeout
         try:
             send_packet = ndr.ndr_pack(packet)
             if dump:
@@ -200,6 +195,63 @@ class DNSTest(tests.TestCase):
 
         return (response, recv_packet[2:])
 
+    def make_txt_update(self, prefix, txt_array):
+        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
+        updates = []
+
+        name = self.get_dns_domain()
+        u = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
+        updates.append(u)
+        self.finish_name_packet(p, updates)
+
+        updates = []
+        r = dns.res_rec()
+        r.name = "%s.%s" % (prefix, self.get_dns_domain())
+        r.rr_type = dns.DNS_QTYPE_TXT
+        r.rr_class = dns.DNS_QCLASS_IN
+        r.ttl = 900
+        r.length = 0xffff
+        rdata = self.make_txt_record(txt_array)
+        r.rdata = rdata
+        updates.append(r)
+        p.nscount = len(updates)
+        p.nsrecs = updates
+
+        return p
+
+    def check_query_txt(self, prefix, txt_array):
+        name = "%s.%s" % (prefix, self.get_dns_domain())
+        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
+        questions = []
+
+        q = self.make_name_question(name, dns.DNS_QTYPE_TXT, dns.DNS_QCLASS_IN)
+        questions.append(q)
+
+        self.finish_name_packet(p, questions)
+        (response, response_packet) = self.dns_transaction_udp(p, host=self.server_ip)
+        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
+        self.assertEquals(response.ancount, 1)
+        self.assertEquals(response.answers[0].rdata.txt.str, txt_array)
+
+
+class DNSTKeyTest(DNSTest):
+    def setUp(self):
+        super(DNSTKeyTest, self).setUp()
+        self.server = server_name
+        self.server_ip = server_ip
+        # DNSTest.setUp() resets this to None; use the module-level timeout
+        self.timeout = timeout
+        self.settings = {}
+        self.settings["lp_ctx"] = self.lp_ctx = tests.env_loadparm()
+        self.settings["target_hostname"] = self.server
+
+        self.creds = credentials.Credentials()
+        self.creds.guess(self.lp_ctx)
+        self.creds.set_username(tests.env_get_var_value('USERNAME'))
+        self.creds.set_password(tests.env_get_var_value('PASSWORD'))
+        self.creds.set_kerberos_state(credentials.MUST_USE_KERBEROS)
+        self.newrecname = "tkeytsig.%s" % self.get_dns_domain()
+
     def tkey_trans(self):
         "Do a TKEY transaction and establish a gensec context"
 
@@ -410,7 +460,7 @@ class DNSTest(tests.TestCase):
         return p
 
 
-class TestDNSUpdates(DNSTest):
+class TestDNSUpdates(DNSTKeyTest):
     def test_tkey(self):
         "test DNS TKEY handshake"
 
-- 
2.9.4


From 831f5cb413589474cd1b44227713a812aae53b0d Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Thu, 1 Jun 2017 13:26:37 +1200
Subject: [PATCH 17/22] selftest: Create new common base class for dns.py and
 dns_tkey.py

This will allow more DNS tests to be written in the future with less
code duplication.
---
 python/samba/tests/dns.py      | 179 +-----------------------------------
 python/samba/tests/dns_base.py | 200 +++++++++++++++++++++++++++++++++++++++++
 python/samba/tests/dns_tkey.py | 179 +-----------------------------------
 3 files changed, 202 insertions(+), 356 deletions(-)
 create mode 100644 python/samba/tests/dns_base.py

diff --git a/python/samba/tests/dns.py b/python/samba/tests/dns.py
index 106e64f..6f59d9b 100644
--- a/python/samba/tests/dns.py
+++ b/python/samba/tests/dns.py
@@ -22,11 +22,11 @@ import random
 import socket
 import samba.ndr as ndr
 from samba import credentials, param
-from samba.tests import TestCase
 from samba.dcerpc import dns, dnsp, dnsserver
 from samba.netcmd.dns import TXTRecord, dns_record_match, data_to_dns_record
 from samba.tests.subunitrun import SubunitOptions, TestProgram
 from samba import werror
+from samba.tests.dns_base import DNSTest
 import samba.getopt as options
 import optparse
 
@@ -60,183 +60,6 @@ server_name = args[0]
 server_ip = args[1]
 creds.set_krb_forwardable(credentials.NO_KRB_FORWARDABLE)
 
-class DNSTest(TestCase):
-
-    def setUp(self):
-        super(DNSTest, self).setUp()
-        self.timeout = None
-
-    def errstr(self, errcode):
-        "Return a readable error code"
-        string_codes = [
-            "OK",
-            "FORMERR",
-            "SERVFAIL",
-            "NXDOMAIN",
-            "NOTIMP",
-            "REFUSED",
-            "YXDOMAIN",
-            "YXRRSET",
-            "NXRRSET",
-            "NOTAUTH",
-            "NOTZONE",
-            "0x0B",
-            "0x0C",
-            "0x0D",
-            "0x0E",
-            "0x0F",
-            "BADSIG",
-            "BADKEY"
-        ]
-
-        return string_codes[errcode]
-
-    def assert_rcode_equals(self, rcode, expected):
-        "Helper function to check return code"
-        self.assertEquals(rcode, expected, "Expected RCODE %s, got %s" %
-                          (self.errstr(expected), self.errstr(rcode)))
-
-    def assert_dns_rcode_equals(self, packet, rcode):
-        "Helper function to check return code"
-        p_errcode = packet.operation & 0x000F
-        self.assertEquals(p_errcode, rcode, "Expected RCODE %s, got %s" %
-                          (self.errstr(rcode), self.errstr(p_errcode)))
-
-    def assert_dns_opcode_equals(self, packet, opcode):
-        "Helper function to check opcode"
-        p_opcode = packet.operation & 0x7800
-        self.assertEquals(p_opcode, opcode, "Expected OPCODE %s, got %s" %
-                          (opcode, p_opcode))
-
-    def make_name_packet(self, opcode, qid=None):
-        "Helper creating a dns.name_packet"
-        p = dns.name_packet()
-        if qid is None:
-            p.id = random.randint(0x0, 0xff00)
-        p.operation = opcode
-        p.questions = []
-        p.additional = []
-        return p
-
-    def finish_name_packet(self, packet, questions):
-        "Helper to finalize a dns.name_packet"
-        packet.qdcount = len(questions)
-        packet.questions = questions
-
-    def make_name_question(self, name, qtype, qclass):
-        "Helper creating a dns.name_question"
-        q = dns.name_question()
-        q.name = name
-        q.question_type = qtype
-        q.question_class = qclass
-        return q
-
-    def make_txt_record(self, records):
-        rdata_txt = dns.txt_record()
-        s_list = dnsp.string_list()
-        s_list.count = len(records)
-        s_list.str = records
-        rdata_txt.txt = s_list
-        return rdata_txt
-
-    def get_dns_domain(self):
-        "Helper to get dns domain"
-        return self.creds.get_realm().lower()
-
-    def dns_transaction_udp(self, packet, host,
-                            dump=False, timeout=None):
-        "send a DNS query and read the reply"
-        s = None
-        if timeout is None:
-            timeout = self.timeout
-        try:
-            send_packet = ndr.ndr_pack(packet)
-            if dump:
-                print self.hexdump(send_packet)
-            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
-            s.settimeout(timeout)
-            s.connect((host, 53))
-            s.sendall(send_packet, 0)
-            recv_packet = s.recv(2048, 0)
-            if dump:
-                print self.hexdump(recv_packet)
-            response = ndr.ndr_unpack(dns.name_packet, recv_packet)
-            return (response, recv_packet)
-        finally:
-            if s is not None:
-                s.close()
-
-    def dns_transaction_tcp(self, packet, host,
-                            dump=False, timeout=None):
-        "send a DNS query and read the reply, also return the raw packet"
-        s = None
-        if timeout is None:
-            timeout = self.timeout
-        try:
-            send_packet = ndr.ndr_pack(packet)
-            if dump:
-                print self.hexdump(send_packet)
-            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
-            s.settimeout(timeout)
-            s.connect((host, 53))
-            tcp_packet = struct.pack('!H', len(send_packet))
-            tcp_packet += send_packet
-            s.sendall(tcp_packet)
-
-            recv_packet = s.recv(0xffff + 2, 0)
-            if dump:
-                print self.hexdump(recv_packet)
-            response = ndr.ndr_unpack(dns.name_packet, recv_packet[2:])
-
-        finally:
-            if s is not None:
-                s.close()
-
-        # unpacking and packing again should produce same bytestream
-        my_packet = ndr.ndr_pack(response)
-        self.assertEquals(my_packet, recv_packet[2:])
-
-        return (response, recv_packet[2:])
-
-    def make_txt_update(self, prefix, txt_array):
-        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
-        updates = []
-
-        name = self.get_dns_domain()
-        u = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
-        updates.append(u)
-        self.finish_name_packet(p, updates)
-
-        updates = []
-        r = dns.res_rec()
-        r.name = "%s.%s" % (prefix, self.get_dns_domain())
-        r.rr_type = dns.DNS_QTYPE_TXT
-        r.rr_class = dns.DNS_QCLASS_IN
-        r.ttl = 900
-        r.length = 0xffff
-        rdata = self.make_txt_record(txt_array)
-        r.rdata = rdata
-        updates.append(r)
-        p.nscount = len(updates)
-        p.nsrecs = updates
-
-        return p
-
-    def check_query_txt(self, prefix, txt_array):
-        name = "%s.%s" % (prefix, self.get_dns_domain())
-        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
-        questions = []
-
-        q = self.make_name_question(name, dns.DNS_QTYPE_TXT, dns.DNS_QCLASS_IN)
-        questions.append(q)
-
-        self.finish_name_packet(p, questions)
-        (response, response_packet) = self.dns_transaction_udp(p, host=self.server_ip)
-        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
-        self.assertEquals(response.ancount, 1)
-        self.assertEquals(response.answers[0].rdata.txt.str, txt_array)
-
-
 class TestSimpleQueries(DNSTest):
     def setUp(self):
         super(TestSimpleQueries, self).setUp()
diff --git a/python/samba/tests/dns_base.py b/python/samba/tests/dns_base.py
new file mode 100644
index 0000000..25696c2
--- /dev/null
+++ b/python/samba/tests/dns_base.py
@@ -0,0 +1,244 @@
+# Unix SMB/CIFS implementation.
+# Copyright (C) Kai Blin  <kai at samba.org> 2011
+# Copyright (C) Ralph Boehme  <slow at samba.org> 2016
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation; either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+#
+
+from samba.tests import TestCase
+from samba.dcerpc import dns, dnsp
+import struct
+import samba.ndr as ndr
+import random
+import socket
+
+class DNSTest(TestCase):
+
+    def setUp(self):
+        super(DNSTest, self).setUp()
+        # Default socket timeout (seconds) for the dns_transaction_*
+        # helpers; None leaves the sockets without a timeout.
+        # Subclasses must also set self.creds (used by get_dns_domain)
+        # and self.server_ip (used by check_query_txt).
+        self.timeout = None
+
+    def errstr(self, errcode):
+        "Return a readable name for a DNS RCODE"
+        string_codes = [
+            "OK",
+            "FORMERR",
+            "SERVFAIL",
+            "NXDOMAIN",
+            "NOTIMP",
+            "REFUSED",
+            "YXDOMAIN",
+            "YXRRSET",
+            "NXRRSET",
+            "NOTAUTH",
+            "NOTZONE",
+            "0x0B",
+            "0x0C",
+            "0x0D",
+            "0x0E",
+            "0x0F",
+            "BADSIG",
+            "BADKEY"
+        ]
+
+        # Codes outside the table are shown in hex instead of raising
+        # IndexError, which would hide the assertion being reported.
+        if 0 <= errcode < len(string_codes):
+            return string_codes[errcode]
+        return "0x%02X" % errcode
+
+    def assert_rcode_equals(self, rcode, expected):
+        "Helper function to check return code"
+        self.assertEquals(rcode, expected, "Expected RCODE %s, got %s" %
+                          (self.errstr(expected), self.errstr(rcode)))
+
+    def assert_dns_rcode_equals(self, packet, rcode):
+        "Check the RCODE (packet.operation & 0x000F) of a DNS packet"
+        p_errcode = packet.operation & 0x000F
+        self.assertEquals(p_errcode, rcode, "Expected RCODE %s, got %s" %
+                          (self.errstr(rcode), self.errstr(p_errcode)))
+
+    def assert_dns_opcode_equals(self, packet, opcode):
+        "Check the OPCODE bits (packet.operation & 0x7800) of a DNS packet"
+        p_opcode = packet.operation & 0x7800
+        self.assertEquals(p_opcode, opcode, "Expected OPCODE %s, got %s" %
+                          (opcode, p_opcode))
+
+    def make_name_packet(self, opcode, qid=None):
+        """Helper creating a dns.name_packet
+
+        :param opcode: value stored in packet.operation
+        :param qid: query id to use; a random one is picked when None
+        """
+        p = dns.name_packet()
+        if qid is None:
+            qid = random.randint(0x0, 0xff00)
+        p.id = qid
+        p.operation = opcode
+        p.questions = []
+        p.additional = []
+        return p
+
+    def finish_name_packet(self, packet, questions):
+        "Helper to finalize a dns.name_packet: set questions and qdcount"
+        packet.qdcount = len(questions)
+        packet.questions = questions
+
+    def make_name_question(self, name, qtype, qclass):
+        "Helper creating a dns.name_question"
+        q = dns.name_question()
+        q.name = name
+        q.question_type = qtype
+        q.question_class = qclass
+        return q
+
+    def make_txt_record(self, records):
+        "Helper creating a dns.txt_record holding the given strings"
+        rdata_txt = dns.txt_record()
+        s_list = dnsp.string_list()
+        s_list.count = len(records)
+        s_list.str = records
+        rdata_txt.txt = s_list
+        return rdata_txt
+
+    def get_dns_domain(self):
+        "Helper to get dns domain (the lower-cased realm of self.creds)"
+        return self.creds.get_realm().lower()
+
+    def dns_transaction_udp(self, packet, host,
+                            dump=False, timeout=None):
+        """Send a DNS packet to host port 53 over UDP and read the reply.
+
+        The reply is read with a single recv() of at most 2048 bytes and
+        returned as a (response, raw_reply) tuple.
+        """
+        s = None
+        if timeout is None:
+            timeout = self.timeout
+        try:
+            send_packet = ndr.ndr_pack(packet)
+            if dump:
+                print self.hexdump(send_packet)
+            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
+            s.settimeout(timeout)
+            s.connect((host, 53))
+            s.sendall(send_packet, 0)
+            recv_packet = s.recv(2048, 0)
+            if dump:
+                print self.hexdump(recv_packet)
+            response = ndr.ndr_unpack(dns.name_packet, recv_packet)
+            return (response, recv_packet)
+        finally:
+            if s is not None:
+                s.close()
+
+    def _recv_exact(self, s, length):
+        "Read exactly length bytes from the stream socket s"
+        data = b""
+        while len(data) < length:
+            chunk = s.recv(length - len(data))
+            if not chunk:
+                raise socket.error("connection closed after %d of %d bytes" %
+                                   (len(data), length))
+            data += chunk
+        return data
+
+    def dns_transaction_tcp(self, packet, host,
+                            dump=False, timeout=None):
+        """Send a DNS packet to host port 53 over TCP and read the reply.
+
+        Messages are framed with a 2-byte big-endian length prefix.  The
+        reply is checked to re-pack to identical bytes and returned as a
+        (response, raw_reply_without_length_prefix) tuple.
+        """
+        s = None
+        if timeout is None:
+            timeout = self.timeout
+        try:
+            send_packet = ndr.ndr_pack(packet)
+            if dump:
+                print self.hexdump(send_packet)
+            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
+            s.settimeout(timeout)
+            s.connect((host, 53))
+            tcp_packet = struct.pack('!H', len(send_packet))
+            tcp_packet += send_packet
+            s.sendall(tcp_packet)
+
+            # recv() may return only part of the reply, so read the length
+            # prefix first and then exactly that many message bytes.
+            length_prefix = self._recv_exact(s, 2)
+            (reply_len,) = struct.unpack('!H', length_prefix)
+            recv_packet = length_prefix + self._recv_exact(s, reply_len)
+            if dump:
+                print self.hexdump(recv_packet)
+            response = ndr.ndr_unpack(dns.name_packet, recv_packet[2:])
+
+        finally:
+            if s is not None:
+                s.close()
+
+        # unpacking and packing again should produce same bytestream
+        my_packet = ndr.ndr_pack(response)
+        self.assertEquals(my_packet, recv_packet[2:])
+
+        return (response, recv_packet[2:])
+
+    def make_txt_update(self, prefix, txt_array):
+        """Build a DNS UPDATE packet adding a TXT record.
+
+        The zone section is an SOA question for the DNS domain; the update
+        section holds "<prefix>.<dns domain>" IN TXT txt_array, TTL 900.
+        """
+        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
+        zone = []
+
+        name = self.get_dns_domain()
+        u = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
+        zone.append(u)
+        self.finish_name_packet(p, zone)
+
+        updates = []
+        r = dns.res_rec()
+        r.name = "%s.%s" % (prefix, self.get_dns_domain())
+        r.rr_type = dns.DNS_QTYPE_TXT
+        r.rr_class = dns.DNS_QCLASS_IN
+        r.ttl = 900
+        r.length = 0xffff
+        rdata = self.make_txt_record(txt_array)
+        r.rdata = rdata
+        updates.append(r)
+        p.nscount = len(updates)
+        p.nsrecs = updates
+
+        return p
+
+    def check_query_txt(self, prefix, txt_array):
+        "Query <prefix>.<dns domain> TXT over UDP; assert it returns txt_array"
+        name = "%s.%s" % (prefix, self.get_dns_domain())
+        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
+        questions = []
+
+        q = self.make_name_question(name, dns.DNS_QTYPE_TXT, dns.DNS_QCLASS_IN)
+        questions.append(q)
+
+        self.finish_name_packet(p, questions)
+        (response, response_packet) = self.dns_transaction_udp(p, host=self.server_ip)
+        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
+        self.assertEquals(response.ancount, 1)
+        self.assertEquals(response.answers[0].rdata.txt.str, txt_array)
diff --git a/python/samba/tests/dns_tkey.py b/python/samba/tests/dns_tkey.py
index 21f5d56..130e42c 100644
--- a/python/samba/tests/dns_tkey.py
+++ b/python/samba/tests/dns_tkey.py
@@ -29,7 +29,7 @@ from samba import credentials
 from samba.dcerpc import dns, dnsp
 from samba.tests.subunitrun import SubunitOptions, TestProgram
 from samba import gensec, tests
-from samba.tests import TestCase
+from samba.tests.dns_base import DNSTest
 
 parser = optparse.OptionParser("dns.py <server name> <server ip> [options]")
 sambaopts = options.SambaOptions(parser)
@@ -57,183 +57,6 @@ server_name = args[0]
 server_ip = args[1]
 
 
-class DNSTest(TestCase):
-
-    def setUp(self):
-        super(DNSTest, self).setUp()
-        self.timeout = None
-
-    def errstr(self, errcode):
-        "Return a readable error code"
-        string_codes = [
-            "OK",
-            "FORMERR",
-            "SERVFAIL",
-            "NXDOMAIN",
-            "NOTIMP",
-            "REFUSED",
-            "YXDOMAIN",
-            "YXRRSET",
-            "NXRRSET",
-            "NOTAUTH",
-            "NOTZONE",
-            "0x0B",
-            "0x0C",
-            "0x0D",
-            "0x0E",
-            "0x0F",
-            "BADSIG",
-            "BADKEY"
-        ]
-
-        return string_codes[errcode]
-
-    def assert_rcode_equals(self, rcode, expected):
-        "Helper function to check return code"
-        self.assertEquals(rcode, expected, "Expected RCODE %s, got %s" %
-                          (self.errstr(expected), self.errstr(rcode)))
-
-    def assert_dns_rcode_equals(self, packet, rcode):
-        "Helper function to check return code"
-        p_errcode = packet.operation & 0x000F
-        self.assertEquals(p_errcode, rcode, "Expected RCODE %s, got %s" %
-                          (self.errstr(rcode), self.errstr(p_errcode)))
-
-    def assert_dns_opcode_equals(self, packet, opcode):
-        "Helper function to check opcode"
-        p_opcode = packet.operation & 0x7800
-        self.assertEquals(p_opcode, opcode, "Expected OPCODE %s, got %s" %
-                          (opcode, p_opcode))
-
-    def make_name_packet(self, opcode, qid=None):
-        "Helper creating a dns.name_packet"
-        p = dns.name_packet()
-        if qid is None:
-            p.id = random.randint(0x0, 0xff00)
-        p.operation = opcode
-        p.questions = []
-        p.additional = []
-        return p
-
-    def finish_name_packet(self, packet, questions):
-        "Helper to finalize a dns.name_packet"
-        packet.qdcount = len(questions)
-        packet.questions = questions
-
-    def make_name_question(self, name, qtype, qclass):
-        "Helper creating a dns.name_question"
-        q = dns.name_question()
-        q.name = name
-        q.question_type = qtype
-        q.question_class = qclass
-        return q
-
-    def make_txt_record(self, records):
-        rdata_txt = dns.txt_record()
-        s_list = dnsp.string_list()
-        s_list.count = len(records)
-        s_list.str = records
-        rdata_txt.txt = s_list
-        return rdata_txt
-
-    def get_dns_domain(self):
-        "Helper to get dns domain"
-        return self.creds.get_realm().lower()
-
-    def dns_transaction_udp(self, packet, host,
-                            dump=False, timeout=None):
-        "send a DNS query and read the reply"
-        s = None
-        if timeout is None:
-            timeout = self.timeout
-        try:
-            send_packet = ndr.ndr_pack(packet)
-            if dump:
-                print self.hexdump(send_packet)
-            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
-            s.settimeout(timeout)
-            s.connect((host, 53))
-            s.sendall(send_packet, 0)
-            recv_packet = s.recv(2048, 0)
-            if dump:
-                print self.hexdump(recv_packet)
-            response = ndr.ndr_unpack(dns.name_packet, recv_packet)
-            return (response, recv_packet)
-        finally:
-            if s is not None:
-                s.close()
-
-    def dns_transaction_tcp(self, packet, host,
-                            dump=False, timeout=None):
-        "send a DNS query and read the reply, also return the raw packet"
-        s = None
-        if timeout is None:
-            timeout = self.timeout
-        try:
-            send_packet = ndr.ndr_pack(packet)
-            if dump:
-                print self.hexdump(send_packet)
-            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
-            s.settimeout(timeout)
-            s.connect((host, 53))
-            tcp_packet = struct.pack('!H', len(send_packet))
-            tcp_packet += send_packet
-            s.sendall(tcp_packet)
-
-            recv_packet = s.recv(0xffff + 2, 0)
-            if dump:
-                print self.hexdump(recv_packet)
-            response = ndr.ndr_unpack(dns.name_packet, recv_packet[2:])
-
-        finally:
-            if s is not None:
-                s.close()
-
-        # unpacking and packing again should produce same bytestream
-        my_packet = ndr.ndr_pack(response)
-        self.assertEquals(my_packet, recv_packet[2:])
-
-        return (response, recv_packet[2:])
-
-    def make_txt_update(self, prefix, txt_array):
-        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
-        updates = []
-
-        name = self.get_dns_domain()
-        u = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
-        updates.append(u)
-        self.finish_name_packet(p, updates)
-
-        updates = []
-        r = dns.res_rec()
-        r.name = "%s.%s" % (prefix, self.get_dns_domain())
-        r.rr_type = dns.DNS_QTYPE_TXT
-        r.rr_class = dns.DNS_QCLASS_IN
-        r.ttl = 900
-        r.length = 0xffff
-        rdata = self.make_txt_record(txt_array)
-        r.rdata = rdata
-        updates.append(r)
-        p.nscount = len(updates)
-        p.nsrecs = updates
-
-        return p
-
-    def check_query_txt(self, prefix, txt_array):
-        name = "%s.%s" % (prefix, self.get_dns_domain())
-        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
-        questions = []
-
-        q = self.make_name_question(name, dns.DNS_QTYPE_TXT, dns.DNS_QCLASS_IN)
-        questions.append(q)
-
-        self.finish_name_packet(p, questions)
-        (response, response_packet) = self.dns_transaction_udp(p, host=self.server_ip)
-        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
-        self.assertEquals(response.ancount, 1)
-        self.assertEquals(response.answers[0].rdata.txt.str, txt_array)
-
-
 class DNSTKeyTest(DNSTest):
     def setUp(self):
         super(DNSTKeyTest, self).setUp()
-- 
2.9.4


From 85e727e77859ca8a6aae3330e376fdb68b88b567 Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Thu, 1 Jun 2017 15:15:25 +1200
Subject: [PATCH 18/22] selftest: Use TestCaseInTempDir as base class in dns
 tests

This will help when we add a new join test based on this code, as that
test uses the per-test temporary directory (self.tempdir) as its join
target directory.

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
---
 python/samba/tests/dns_base.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/samba/tests/dns_base.py b/python/samba/tests/dns_base.py
index 25696c2..c986932 100644
--- a/python/samba/tests/dns_base.py
+++ b/python/samba/tests/dns_base.py
@@ -16,14 +16,14 @@
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 #
 
-from samba.tests import TestCase
+from samba.tests import TestCaseInTempDir
 from samba.dcerpc import dns, dnsp
 import struct
 import samba.ndr as ndr
 import random
 import socket
 
-class DNSTest(TestCase):
+class DNSTest(TestCaseInTempDir):
 
     def setUp(self):
         super(DNSTest, self).setUp()
-- 
2.9.4


From 740dcadd5b2f666e912e455eb935a9746f79d120 Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Tue, 6 Jun 2017 15:21:50 +1200
Subject: [PATCH 19/22] provision: Move the default for site=None down into
 dc_join object creation

This makes this code easier to call from a test script

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
---
 python/samba/join.py          | 3 +++
 python/samba/netcmd/domain.py | 3 ---
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/python/samba/join.py b/python/samba/join.py
index 1ceb954..53cd602 100644
--- a/python/samba/join.py
+++ b/python/samba/join.py
@@ -56,6 +56,9 @@ class dc_join(object):
                  netbios_name=None, targetdir=None, domain=None,
                  machinepass=None, use_ntvfs=False, dns_backend=None,
                  promote_existing=False, clone_only=False):
+        if site is None:
+            site = "Default-First-Site-Name"
+
         ctx.clone_only=clone_only
 
         ctx.logger = logger
diff --git a/python/samba/netcmd/domain.py b/python/samba/netcmd/domain.py
index 4bd99ba..e3a0e49 100644
--- a/python/samba/netcmd/domain.py
+++ b/python/samba/netcmd/domain.py
@@ -551,9 +551,6 @@ class cmd_domain_dcpromo(Command):
         creds = credopts.get_credentials(lp)
         net = Net(creds, lp, server=credopts.ipaddress)
 
-        if site is None:
-            site = "Default-First-Site-Name"
-
         logger = self.get_logger()
         if verbose:
             logger.setLevel(logging.DEBUG)
-- 
2.9.4


From c49cde21397ebc69324dffb353c7f35b49d54aa2 Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Tue, 6 Jun 2017 15:22:35 +1200
Subject: [PATCH 20/22] provision: Allow removing an existing account when
 force=True is set

This allows a practical override for use in test scripts

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
---
 python/samba/join.py | 45 +++++++++++++++++++++++----------------------
 1 file changed, 23 insertions(+), 22 deletions(-)

diff --git a/python/samba/join.py b/python/samba/join.py
index 53cd602..eed535b 100644
--- a/python/samba/join.py
+++ b/python/samba/join.py
@@ -210,32 +210,33 @@ class dc_join(object):
         except Exception:
             pass
 
-    def cleanup_old_accounts(ctx):
+    def cleanup_old_accounts(ctx, force=False):
         res = ctx.samdb.search(base=ctx.samdb.get_default_basedn(),
                                expression='sAMAccountName=%s' % ldb.binary_encode(ctx.samname),
                                attrs=["msDS-krbTgtLink", "objectSID"])
         if len(res) == 0:
             return
 
-        creds = Credentials()
-        creds.guess(ctx.lp)
-        try:
-            creds.set_machine_account(ctx.lp)
-            creds.set_kerberos_state(ctx.creds.get_kerberos_state())
-            machine_samdb = SamDB(url="ldap://%s" % ctx.server,
-                                  session_info=system_session(),
-                                credentials=creds, lp=ctx.lp)
-        except:
-            pass
-        else:
-            token_res = machine_samdb.search(scope=ldb.SCOPE_BASE, base="", attrs=["tokenGroups"])
-            if token_res[0]["tokenGroups"][0] \
-               == res[0]["objectSID"][0]:
-                raise DCJoinException("Not removing account %s which "
-                                   "looks like a Samba DC account "
-                                   "maching the password we already have.  "
-                                   "To override, remove secrets.ldb and secrets.tdb"
-                                % ctx.samname)
+        if not force:
+            creds = Credentials()
+            creds.guess(ctx.lp)
+            try:
+                creds.set_machine_account(ctx.lp)
+                creds.set_kerberos_state(ctx.creds.get_kerberos_state())
+                machine_samdb = SamDB(url="ldap://%s" % ctx.server,
+                                      session_info=system_session(),
+                                    credentials=creds, lp=ctx.lp)
+            except:
+                pass
+            else:
+                token_res = machine_samdb.search(scope=ldb.SCOPE_BASE, base="", attrs=["tokenGroups"])
+                if token_res[0]["tokenGroups"][0] \
+                   == res[0]["objectSID"][0]:
+                    raise DCJoinException("Not removing account %s which "
+                                       "looks like a Samba DC account "
+                                       "maching the password we already have.  "
+                                       "To override, remove secrets.ldb and secrets.tdb"
+                                    % ctx.samname)
 
         ctx.del_noerror(res[0].dn, recursive=True)
 
@@ -262,11 +263,11 @@ class dc_join(object):
                                 ldb.binary_encode("dns/%s" % ctx.dnshostname)))
 
 
-    def cleanup_old_join(ctx):
+    def cleanup_old_join(ctx, force=False):
         """Remove any DNs from a previous join."""
         # find the krbtgt link
         if not ctx.subdomain:
-            ctx.cleanup_old_accounts()
+            ctx.cleanup_old_accounts(force=force)
 
         if ctx.connection_dn is not None:
             ctx.del_noerror(ctx.connection_dn)
-- 
2.9.4


From 56842433da73289935289926d113b8c95df57e01 Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Tue, 6 Jun 2017 15:22:59 +1200
Subject: [PATCH 21/22] join.py: Remove direct print to the console, use the
 logging class

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
---
 python/samba/join.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/samba/join.py b/python/samba/join.py
index eed535b..fa87f0b 100644
--- a/python/samba/join.py
+++ b/python/samba/join.py
@@ -1155,7 +1155,7 @@ class dc_join(object):
     def join_replicate_new_dns_records(ctx):
         for nc in (ctx.domaindns_zone, ctx.forestdns_zone):
             if nc in ctx.nc_list:
-                print "Replicating new DNS records in %s" % (str(nc))
+                ctx.logger.info("Replicating new DNS records in %s" % (str(nc)))
                 ctx.repl.replicate(nc, ctx.source_dsa_invocation_id,
                                    ctx.ntds_guid, rodc=ctx.RODC,
                                    replica_flags=ctx.replica_flags,
-- 
2.9.4


From b69ab7debf82270b44ff9faed8202828f0e939d2 Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Thu, 1 Jun 2017 17:11:57 +1200
Subject: [PATCH 22/22] selftest: Test join.py and confirm that the DNS records
 are created

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
---
 python/samba/tests/join.py | 114 +++++++++++++++++++++++++++++++++++++++++++++
 source4/selftest/tests.py  |   3 ++
 2 files changed, 117 insertions(+)
 create mode 100644 python/samba/tests/join.py

diff --git a/python/samba/tests/join.py b/python/samba/tests/join.py
new file mode 100644
index 0000000..8a97a9f
--- /dev/null
+++ b/python/samba/tests/join.py
@@ -0,0 +1,127 @@
+# Test joining as a DC and check the join was done right
+#
+# Copyright (C) Andrew Bartlett <abartlet at samba.org> 2017
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation; either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+#
+
+import samba
+import samba.dsdb
+import sys
+import shutil
+import os
+from samba.tests.dns_base import DNSTest
+from samba.join import dc_join
+from samba.dcerpc import drsuapi, misc, dns
+
+def get_logger(name="subunit"):
+    """Get a logger object that writes to stderr.
+
+    The stderr handler is attached only once, so calling this from every
+    setUp() does not make each message appear several times.
+    """
+    import logging
+    logger = logging.getLogger(name)
+    if not logger.handlers:
+        logger.addHandler(logging.StreamHandler(sys.stderr))
+    return logger
+
+class JoinTestCase(DNSTest):
+    """Join a new DC named "jointest1" via $SERVER and check, with DNS
+    queries sent to $SERVER_IP, that the join created its DNS records.
+    """
+    def setUp(self):
+        super(JoinTestCase, self).setUp()
+        self.server = samba.tests.env_get_var_value("SERVER")
+        self.server_ip = samba.tests.env_get_var_value("SERVER_IP")
+        self.lp = samba.tests.env_loadparm()
+        self.creds = self.get_credentials()
+        self.netbios_name = "jointest1"
+        logger = get_logger()
+
+        self.join_ctx = dc_join(server=self.server, creds=self.creds, lp=self.get_loadparm(),
+                                netbios_name=self.netbios_name,
+                                targetdir=self.tempdir,
+                                domain=None, logger=logger,
+                                dns_backend="SAMBA_INTERNAL")
+        # Join as a writable DC: server trust account, writable
+        # replication and a BDC secure channel.
+        self.join_ctx.userAccountControl = (samba.dsdb.UF_SERVER_TRUST_ACCOUNT |
+                                            samba.dsdb.UF_TRUSTED_FOR_DELEGATION)
+
+        self.join_ctx.replica_flags |= (drsuapi.DRSUAPI_DRS_WRIT_REP |
+                                        drsuapi.DRSUAPI_DRS_FULL_SYNC_IN_PROGRESS)
+        self.join_ctx.domain_replica_flags = self.join_ctx.replica_flags
+        self.join_ctx.secure_channel_type = misc.SEC_CHAN_BDC
+
+        # Remove any DNs left on the DC by a previous join of this name.
+        self.join_ctx.cleanup_old_join()
+
+        self.join_ctx.force_all_ips = True
+
+    def _remove_local_join_files(self):
+        "Remove the files a join created under self.tempdir, if it got that far"
+        paths = getattr(self.join_ctx, "paths", None)
+        if paths is None:
+            return
+        shutil.rmtree(paths.private_dir)
+        shutil.rmtree(paths.state_dir)
+        shutil.rmtree(os.path.join(self.tempdir, "etc"))
+        shutil.rmtree(os.path.join(self.tempdir, "msg.lock"))
+        os.unlink(os.path.join(self.tempdir, "names.tdb"))
+
+    def tearDown(self):
+        # Undo the join on the DC even if removing the local files fails
+        # (e.g. after a partial join), so no "jointest1" objects are left
+        # behind to break later runs or other tests.
+        try:
+            self._remove_local_join_files()
+        finally:
+            try:
+                self.join_ctx.cleanup_old_join(force=True)
+            finally:
+                super(JoinTestCase, self).tearDown()
+
+    def test_join(self):
+        "Join as a DC, then check the new DC's A and _msdcs CNAME records"
+        self.join_ctx.do_join()
+
+        # Expect one A answer per address from samba.interface_ips()
+        IPs = samba.interface_ips(self.lp)
+
+        # Query the A record(s) of the new DC's host name over TCP
+        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
+        name = self.join_ctx.dnshostname
+        q = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN)
+        self.finish_name_packet(p, [q])
+        (response, response_packet) = self.dns_transaction_tcp(p, host=self.server_ip)
+        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
+        self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
+        self.assertEquals(response.ancount, len(IPs))
+
+        # <ntds guid>._msdcs.<forest> should answer with a CNAME to the
+        # host name, followed by the A record(s) of that host name
+        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
+        name = "%s._msdcs.%s" % (self.join_ctx.ntds_guid, self.join_ctx.dnsforest)
+        q = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN)
+        self.finish_name_packet(p, [q])
+        (response, response_packet) = self.dns_transaction_tcp(p, host=self.server_ip)
+        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
+        self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
+
+        self.assertEquals(response.ancount, 1 + len(IPs))
+        self.assertEquals(response.answers[0].rr_type, dns.DNS_QTYPE_CNAME)
+        self.assertEquals(response.answers[0].rdata, self.join_ctx.dnshostname)
+        for answer in response.answers[1:]:
+            self.assertEquals(answer.rr_type, dns.DNS_QTYPE_A)
diff --git a/source4/selftest/tests.py b/source4/selftest/tests.py
index b9cdee9..f30eaa7 100755
--- a/source4/selftest/tests.py
+++ b/source4/selftest/tests.py
@@ -819,6 +819,9 @@ for env in ['ad_dc_ntvfs']:
                            name="samba4.drs.repl_rodc.python(%s)" % env,
                            environ={'DC1': "$DC_SERVER", 'DC2': '$DC_SERVER'},
                            extra_args=['-U$DOMAIN/$DC_USERNAME%$DC_PASSWORD'])
+    planoldpythontestsuite(env, "samba.tests.join",
+                           name="samba.tests.join.python(%s)" % env,
+                           extra_args=['-U$DOMAIN/$DC_USERNAME%$DC_PASSWORD'])
 
 planoldpythontestsuite("chgdcpass:local", "samba.tests.blackbox.samba_dnsupdate",
                        environ={'DNS_SERVER_IP': '$SERVER_IP'})
-- 
2.9.4



More information about the samba-technical mailing list