[PATCH] Create DC DNS entries at domain join

Andrew Bartlett abartlet at samba.org
Fri Jun 9 04:48:39 UTC 2017


On Fri, 2017-06-09 at 15:49 +1200, Andrew Bartlett via samba-technical
wrote:
> Attached is the patch to have Samba create DNS entries at domain
> join. 
> 
> Please review and push.
> 
> Thanks!
> 
> 
> Andrew Bartlett
> 
> http://git.catalyst.net.nz/gitweb?p=samba.git;a=shortlog;h=refs/heads
> /dns-at-domain-join-for-master

Attached is the patch set with Garming's review.  Garming has it in his
autobuild queue as I've already got one going.

Andrew Bartlett

-- 
Andrew Bartlett
https://samba.org/~abartlet/
Authentication Developer, Samba Team         https://samba.org
Samba Development and Support, Catalyst IT   
https://catalyst.net.nz/services/samba



-------------- next part --------------
From b4a503756155db72a3a2dfd2ab6b023760f309ef Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Tue, 23 May 2017 15:56:55 +1200
Subject: [PATCH 01/24] dsdb: Improve error messages when
 dsdb_set_schema_from_ldif() fails

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
Reviewed-by: Garming Sam <garming at catalyst.net.nz>
---
 source4/dsdb/schema/schema_set.c | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/source4/dsdb/schema/schema_set.c b/source4/dsdb/schema/schema_set.c
index e6d5ce627ce..fd48d279af2 100644
--- a/source4/dsdb/schema/schema_set.c
+++ b/source4/dsdb/schema/schema_set.c
@@ -893,6 +893,8 @@ WERROR dsdb_set_schema_from_ldif(struct ldb_context *ldb,
 	ret = dsdb_set_schema(ldb, schema, true);
 	if (ret != LDB_SUCCESS) {
 		status = WERR_FOOBAR;
+		DEBUG(0,("ERROR: dsdb_set_schema() failed with %s / %s\n",
+			 ldb_strerror(ret), ldb_errstring(ldb)));
 		goto failed;
 	}
 
-- 
2.11.0


From e3f365d629b945e7afe84841cd19943c14398218 Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Fri, 17 Feb 2017 18:24:27 +1300
Subject: [PATCH 02/24] samba_dnsupdate: Ensure we only force "server" under
 resolv_wrapper

This ensures that nsupdate can use a nameserver in /etc/resolv.conf that is a
cache or forwarder, rather than the AD DC directly.

This avoids a regression introduced by commit
e85ef1dbfef4b16c35cac80c0efc563d8cd1ba3e, which forced the nameservers
to be the ones listed in /etc/resolv.conf.

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
Reviewed-by: Garming Sam <garming at catalyst.net.nz>
---
 source4/scripting/bin/samba_dnsupdate | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/source4/scripting/bin/samba_dnsupdate b/source4/scripting/bin/samba_dnsupdate
index d382758168b..ba167da2876 100755
--- a/source4/scripting/bin/samba_dnsupdate
+++ b/source4/scripting/bin/samba_dnsupdate
@@ -430,8 +430,19 @@ def call_nsupdate(d, op="add"):
 
     (tmp_fd, tmpfile) = tempfile.mkstemp()
     f = os.fdopen(tmp_fd, 'w')
-    if d.nameservers != []:
+
+    # Getting this line right is really important.  When we are under
+    # resolv_wrapper, then we want to use RESOLV_CONF and the
+    # nameserver therein. The issue is that this parameter forces us
+    # to only ever use that server, and not some other server that the
+    # NS record may point to, even as we get a ticket to that other
+    # server.
+    #
+    # Therefore we must not set this in production.
+
+    if os.getenv('RESOLV_CONF') and d.nameservers != []:
         f.write('server %s\n' % d.nameservers[0])
+
     if d.type == "A":
         f.write("update %s %s %u A %s\n" % (op, normalised_name, default_ttl, d.ip))
     if d.type == "AAAA":
-- 
2.11.0


From 09b06b842c6d5520c29848390b673ba84b9dc25e Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Mon, 27 Feb 2017 16:51:45 +1300
Subject: [PATCH 03/24] pydns: Fix leak of talloc_stackframe() in python
 bindings

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
Reviewed-by: Garming Sam <garming at catalyst.net.nz>
---
 source4/dns_server/pydns.c | 23 ++++++++++++++++++++---
 1 file changed, 20 insertions(+), 3 deletions(-)

diff --git a/source4/dns_server/pydns.c b/source4/dns_server/pydns.c
index 9842f24edfd..18c3c2953d9 100644
--- a/source4/dns_server/pydns.c
+++ b/source4/dns_server/pydns.c
@@ -124,12 +124,14 @@ static PyObject *py_dsdb_dns_lookup(PyObject *self, PyObject *args)
 
 	status = dns_common_zones(samdb, frame, &zones_list);
 	if (!NT_STATUS_IS_OK(status)) {
+		talloc_free(frame);
 		PyErr_SetNTSTATUS(status);
 		return NULL;
 	}
 
 	werr = dns_common_name2dn(samdb, zones_list, frame, dns_name, &dn);
 	if (!W_ERROR_IS_OK(werr)) {
+		talloc_free(frame);
 		PyErr_SetWERROR(werr);
 		return NULL;
 	}
@@ -141,16 +143,19 @@ static PyObject *py_dsdb_dns_lookup(PyObject *self, PyObject *args)
 				 &num_records,
 				 NULL);
 	if (!W_ERROR_IS_OK(werr)) {
+		talloc_free(frame);
 		PyErr_SetWERROR(werr);
 		return NULL;
 	}
 
-	return py_dnsp_DnssrvRpcRecord_get_list(records, num_records);
+	ret = py_dnsp_DnssrvRpcRecord_get_list(records, num_records);
+	talloc_free(frame);
+	return ret;
 }
 
 static PyObject *py_dsdb_dns_extract(PyObject *self, PyObject *args)
 {
-	PyObject *py_dns_el;
+	PyObject *py_dns_el, *ret;
 	TALLOC_CTX *frame;
 	WERROR werr;
 	struct ldb_message_element *dns_el;
@@ -175,11 +180,14 @@ static PyObject *py_dsdb_dns_extract(PyObject *self, PyObject *args)
 				  &records,
 				  &num_records);
 	if (!W_ERROR_IS_OK(werr)) {
+		talloc_free(frame);
 		PyErr_SetWERROR(werr);
 		return NULL;
 	}
 
-	return py_dnsp_DnssrvRpcRecord_get_list(records, num_records);
+	ret = py_dnsp_DnssrvRpcRecord_get_list(records, num_records);
+	talloc_free(frame);
+	return ret;
 }
 
 static PyObject *py_dsdb_dns_replace(PyObject *self, PyObject *args)
@@ -213,12 +221,14 @@ static PyObject *py_dsdb_dns_replace(PyObject *self, PyObject *args)
 	status = dns_common_zones(samdb, frame, &zones_list);
 	if (!NT_STATUS_IS_OK(status)) {
 		PyErr_SetNTSTATUS(status);
+		talloc_free(frame);
 		return NULL;
 	}
 
 	werr = dns_common_name2dn(samdb, zones_list, frame, dns_name, &dn);
 	if (!W_ERROR_IS_OK(werr)) {
 		PyErr_SetWERROR(werr);
+		talloc_free(frame);
 		return NULL;
 	}
 
@@ -226,6 +236,7 @@ static PyObject *py_dsdb_dns_replace(PyObject *self, PyObject *args)
 						frame,
 						&records, &num_records);
 	if (ret != 0) {
+		talloc_free(frame);
 		return NULL;
 	}
 
@@ -238,9 +249,11 @@ static PyObject *py_dsdb_dns_replace(PyObject *self, PyObject *args)
 				  num_records);
 	if (!W_ERROR_IS_OK(werr)) {
 		PyErr_SetWERROR(werr);
+		talloc_free(frame);
 		return NULL;
 	}
 
+	talloc_free(frame);
 	Py_RETURN_NONE;
 }
 
@@ -275,6 +288,7 @@ static PyObject *py_dsdb_dns_replace_by_dn(PyObject *self, PyObject *args)
 						frame,
 						&records, &num_records);
 	if (ret != 0) {
+		talloc_free(frame);
 		return NULL;
 	}
 
@@ -287,9 +301,12 @@ static PyObject *py_dsdb_dns_replace_by_dn(PyObject *self, PyObject *args)
 				  num_records);
 	if (!W_ERROR_IS_OK(werr)) {
 		PyErr_SetWERROR(werr);
+		talloc_free(frame);
 		return NULL;
 	}
 
+	talloc_free(frame);
+
 	Py_RETURN_NONE;
 }
 
-- 
2.11.0


From 81bc3c16df6cb7f1dc76d088cac0fa59bcb6ffaa Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Mon, 27 Feb 2017 17:09:56 +1300
Subject: [PATCH 04/24] pydns: Also return the DN of the LDB object when
 finding a DNS record

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
Reviewed-by: Garming Sam <garming at catalyst.net.nz>
---
 python/samba/remove_dc.py  | 4 ++--
 source4/dns_server/pydns.c | 5 +++--
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/python/samba/remove_dc.py b/python/samba/remove_dc.py
index 61b5937ba7a..4c8ee892464 100644
--- a/python/samba/remove_dc.py
+++ b/python/samba/remove_dc.py
@@ -97,7 +97,7 @@ def remove_dns_references(samdb, logger, dnsHostName):
     dnsHostNameUpper = dnsHostName.upper()
 
     try:
-        primary_recs = samdb.dns_lookup(dnsHostName)
+        (dn, primary_recs) = samdb.dns_lookup(dnsHostName)
     except RuntimeError as (enum, estr):
         if enum == werror.WERR_DNS_ERROR_NAME_DOES_NOT_EXIST:
               return
@@ -140,7 +140,7 @@ def remove_dns_references(samdb, logger, dnsHostName):
     for a_name in a_names_to_remove_from:
         try:
             logger.debug("checking for DNS records to remove on %s" % a_name)
-            a_recs = samdb.dns_lookup(a_name)
+            (a_rec_dn, a_recs) = samdb.dns_lookup(a_name)
         except RuntimeError as (enum, estr):
             if enum == werror.WERR_DNS_ERROR_NAME_DOES_NOT_EXIST:
                 return
diff --git a/source4/dns_server/pydns.c b/source4/dns_server/pydns.c
index 18c3c2953d9..3de9739f1f1 100644
--- a/source4/dns_server/pydns.c
+++ b/source4/dns_server/pydns.c
@@ -105,7 +105,7 @@ static int py_dnsp_DnssrvRpcRecord_get_array(PyObject *value,
 static PyObject *py_dsdb_dns_lookup(PyObject *self, PyObject *args)
 {
 	struct ldb_context *samdb;
-	PyObject *py_ldb;
+	PyObject *py_ldb, *ret, *pydn;
 	char *dns_name;
 	TALLOC_CTX *frame;
 	NTSTATUS status;
@@ -149,8 +149,10 @@ static PyObject *py_dsdb_dns_lookup(PyObject *self, PyObject *args)
 	}
 
 	ret = py_dnsp_DnssrvRpcRecord_get_list(records, num_records);
+	pydn = pyldb_Dn_FromDn(dn);
 	talloc_free(frame);
-	return ret;
+	/* "N" hands our new references to the tuple, so pydn and ret do not leak */
+	return Py_BuildValue("(NN)", pydn, ret);
 }
 
 static PyObject *py_dsdb_dns_extract(PyObject *self, PyObject *args)
-- 
2.11.0


From 43b74239950373093905d2c53276ef5287e02760 Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Tue, 28 Feb 2017 14:15:12 +1300
Subject: [PATCH 05/24] python: Allow sd_utils to take a Dn object, not just a
 string DN

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
Reviewed-by: Garming Sam <garming at catalyst.net.nz>
---
 python/samba/sd_utils.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/python/samba/sd_utils.py b/python/samba/sd_utils.py
index 7592a2982a4..568829f9c36 100644
--- a/python/samba/sd_utils.py
+++ b/python/samba/sd_utils.py
@@ -37,7 +37,11 @@ class SDUtils(object):
             or security.descriptor object
         """
         m = Message()
-        m.dn = Dn(self.ldb, object_dn)
+        if isinstance(object_dn, Dn):
+            m.dn = object_dn
+        else:
+            m.dn = Dn(self.ldb, object_dn)
+
         assert(isinstance(sd, str) or isinstance(sd, security.descriptor))
         if isinstance(sd, str):
             tmp_desc = security.descriptor.from_sddl(sd, self.domain_sid)
-- 
2.11.0


From 6e62b8b06f43cd906d1d8090f917e0c8d0ae5687 Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Mon, 10 Apr 2017 16:06:13 +1200
Subject: [PATCH 06/24] pydsdb_dns: Use TypeError not LdbError for mismatched
 types

This avoids the samba-tool command-handling code failing when it tries to parse an LdbError.

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
Reviewed-by: Garming Sam <garming at catalyst.net.nz>
---
 source4/dns_server/pydns.c | 15 +++------------
 1 file changed, 3 insertions(+), 12 deletions(-)

diff --git a/source4/dns_server/pydns.c b/source4/dns_server/pydns.c
index 3de9739f1f1..7fc8f0c8811 100644
--- a/source4/dns_server/pydns.c
+++ b/source4/dns_server/pydns.c
@@ -32,27 +32,18 @@
 /* FIXME: These should be in a header file somewhere */
 #define PyErr_LDB_OR_RAISE(py_ldb, ldb) \
 	if (!py_check_dcerpc_type(py_ldb, "ldb", "Ldb")) { \
-		PyErr_SetString(py_ldb_get_exception(), "Ldb connection object required"); \
+		PyErr_SetString(PyExc_TypeError, "Ldb connection object required"); \
 		return NULL; \
 	} \
 	ldb = pyldb_Ldb_AsLdbContext(py_ldb);
 
 #define PyErr_LDB_DN_OR_RAISE(py_ldb_dn, dn) \
 	if (!py_check_dcerpc_type(py_ldb_dn, "ldb", "Dn")) { \
-		PyErr_SetString(py_ldb_get_exception(), "ldb Dn object required"); \
+		PyErr_SetString(PyExc_TypeError, "ldb Dn object required"); \
 		return NULL; \
 	} \
 	dn = pyldb_Dn_AsDn(py_ldb_dn);
 
-static PyObject *py_ldb_get_exception(void)
-{
-	PyObject *mod = PyImport_ImportModule("ldb");
-	if (mod == NULL)
-		return NULL;
-
-	return PyObject_GetAttrString(mod, "LdbError");
-}
-
 static PyObject *py_dnsp_DnssrvRpcRecord_get_list(struct dnsp_DnssrvRpcRecord *records,
 						  uint16_t num_records)
 {
@@ -168,7 +159,7 @@ static PyObject *py_dsdb_dns_extract(PyObject *self, PyObject *args)
 	}
 
 	if (!py_check_dcerpc_type(py_dns_el, "ldb", "MessageElement")) {
-		PyErr_SetString(py_ldb_get_exception(),
+		PyErr_SetString(PyExc_TypeError,
 				"ldb MessageElement object required");
 		return NULL;
 	}
-- 
2.11.0


From e3b8b0b4a064e7593746739afb4ae9d010b0c245 Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Fri, 9 Jun 2017 16:05:31 +1200
Subject: [PATCH 07/24] pydsdb_dns: Allow the partition DN to be specified into
 py_dsdb_dns_lookup

This allows lookups to be confined to one partition, which in turn avoids issues
when running this against MS Windows, whose behaviour does not match Samba's
for the all-partitions search done in dns_common_zones().

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
Reviewed-by: Garming Sam <garming at catalyst.net.nz>
---
 python/samba/samdb.py                 |  8 ++++++--
 source4/dns_server/dns_server.c       |  2 +-
 source4/dns_server/dnsserver_common.c | 17 ++++++++++++++---
 source4/dns_server/dnsserver_common.h |  7 +++++++
 source4/dns_server/pydns.c            | 26 ++++++++++++++++++++------
 5 files changed, 48 insertions(+), 12 deletions(-)

diff --git a/python/samba/samdb.py b/python/samba/samdb.py
index 19dd8e9a6ad..b3a4b384926 100644
--- a/python/samba/samdb.py
+++ b/python/samba/samdb.py
@@ -927,9 +927,13 @@ accountExpires: %u
         res = self.search(base="", scope=ldb.SCOPE_BASE, attrs=["serverName"])
         return res[0]["serverName"][0]
 
-    def dns_lookup(self, dns_name):
+    def dns_lookup(self, dns_name, dns_partition=None):
         '''Do a DNS lookup in the database, returns the NDR database structures'''
-        return dsdb_dns.lookup(self, dns_name)
+        if dns_partition is None:
+            return dsdb_dns.lookup(self, dns_name)
+        else:
+            return dsdb_dns.lookup(self, dns_name,
+                                   dns_partition=dns_partition)
 
     def dns_extract(self, el):
         '''Return the NDR database structures from a dnsRecord element'''
diff --git a/source4/dns_server/dns_server.c b/source4/dns_server/dns_server.c
index 5e9527d1f72..d4f5f27d0bb 100644
--- a/source4/dns_server/dns_server.c
+++ b/source4/dns_server/dns_server.c
@@ -764,7 +764,7 @@ static NTSTATUS dns_server_reload_zones(struct dns_server *dns)
 	struct dns_server_zone *new_list = NULL;
 	struct dns_server_zone *old_list = NULL;
 	struct dns_server_zone *old_zone;
-	status = dns_common_zones(dns->samdb, dns, &new_list);
+	status = dns_common_zones(dns->samdb, dns, NULL, &new_list);
 	if (!NT_STATUS_IS_OK(status)) {
 		return status;
 	}
diff --git a/source4/dns_server/dnsserver_common.c b/source4/dns_server/dnsserver_common.c
index 7aac7e22855..fbfa5fa4eae 100644
--- a/source4/dns_server/dnsserver_common.c
+++ b/source4/dns_server/dnsserver_common.c
@@ -560,6 +560,7 @@ static int dns_common_sort_zones(struct ldb_message **m1, struct ldb_message **m
 
 NTSTATUS dns_common_zones(struct ldb_context *samdb,
 			  TALLOC_CTX *mem_ctx,
+			  struct ldb_dn *base_dn,
 			  struct dns_server_zone **zones_ret)
 {
 	int ret;
@@ -569,9 +570,19 @@ NTSTATUS dns_common_zones(struct ldb_context *samdb,
 	struct dns_server_zone *new_list = NULL;
 	TALLOC_CTX *frame = talloc_stackframe();
 
-	/* TODO: this search does not work against windows */
-	ret = dsdb_search(samdb, frame, &res, NULL, LDB_SCOPE_SUBTREE,
-			  attrs, DSDB_SEARCH_SEARCH_ALL_PARTITIONS, "(objectClass=dnsZone)");
+	if (base_dn) {
+		/* This search will work against windows */
+		ret = dsdb_search(samdb, frame, &res,
+				  base_dn, LDB_SCOPE_SUBTREE,
+				  attrs, 0, "(objectClass=dnsZone)");
+	} else {
+		/* TODO: this search does not work against windows */
+		ret = dsdb_search(samdb, frame, &res, NULL,
+				  LDB_SCOPE_SUBTREE,
+				  attrs,
+				  DSDB_SEARCH_SEARCH_ALL_PARTITIONS,
+				  "(objectClass=dnsZone)");
+	}
 	if (ret != LDB_SUCCESS) {
 		TALLOC_FREE(frame);
 		return NT_STATUS_INTERNAL_DB_CORRUPTION;
diff --git a/source4/dns_server/dnsserver_common.h b/source4/dns_server/dnsserver_common.h
index 57d5d9f3c15..293831f0acb 100644
--- a/source4/dns_server/dnsserver_common.h
+++ b/source4/dns_server/dnsserver_common.h
@@ -62,7 +62,14 @@ WERROR dns_common_name2dn(struct ldb_context *samdb,
 			  TALLOC_CTX *mem_ctx,
 			  const char *name,
 			  struct ldb_dn **_dn);
+
+/*
+ * For this routine, base_dn is generally NULL.  The exception comes
+ * from the python bindings to support setting ACLs on DNS objects
+ * when joining Windows
+ */
 NTSTATUS dns_common_zones(struct ldb_context *samdb,
 			  TALLOC_CTX *mem_ctx,
+			  struct ldb_dn *base_dn,
 			  struct dns_server_zone **zones_ret);
 #endif /* __DNSSERVER_COMMON_H__ */
diff --git a/source4/dns_server/pydns.c b/source4/dns_server/pydns.c
index 7fc8f0c8811..cb41faa1441 100644
--- a/source4/dns_server/pydns.c
+++ b/source4/dns_server/pydns.c
@@ -93,27 +93,40 @@ static int py_dnsp_DnssrvRpcRecord_get_array(PyObject *value,
 	return 0;
 }
 
-static PyObject *py_dsdb_dns_lookup(PyObject *self, PyObject *args)
+static PyObject *py_dsdb_dns_lookup(PyObject *self,
+				    PyObject *args, PyObject *kwargs)
 {
 	struct ldb_context *samdb;
 	PyObject *py_ldb, *ret, *pydn;
+	PyObject *py_dns_partition = NULL;
 	char *dns_name;
 	TALLOC_CTX *frame;
 	NTSTATUS status;
 	WERROR werr;
 	struct dns_server_zone *zones_list;
-	struct ldb_dn *dn;
+	struct ldb_dn *dn, *dns_partition = NULL;
 	struct dnsp_DnssrvRpcRecord *records;
 	uint16_t num_records;
+	const char * const kwnames[] = { "ldb", "dns_name",
+					 "dns_partition", NULL };
 
-	if (!PyArg_ParseTuple(args, "Os", &py_ldb, &dns_name)) {
+	if (!PyArg_ParseTupleAndKeywords(args, kwargs, "Os|O",
+					 discard_const_p(char *, kwnames),
+					 &py_ldb, &dns_name,
+					 &py_dns_partition)) {
 		return NULL;
 	}
 	PyErr_LDB_OR_RAISE(py_ldb, samdb);
 
+	if (py_dns_partition) {
+		PyErr_LDB_DN_OR_RAISE(py_dns_partition,
+				      dns_partition);
+	}
+
 	frame = talloc_stackframe();
 
-	status = dns_common_zones(samdb, frame, &zones_list);
+	status = dns_common_zones(samdb, frame, dns_partition,
+				  &zones_list);
 	if (!NT_STATUS_IS_OK(status)) {
 		talloc_free(frame);
 		PyErr_SetNTSTATUS(status);
@@ -210,7 +223,7 @@ static PyObject *py_dsdb_dns_replace(PyObject *self, PyObject *args)
 
 	frame = talloc_stackframe();
 
-	status = dns_common_zones(samdb, frame, &zones_list);
+	status = dns_common_zones(samdb, frame, NULL, &zones_list);
 	if (!NT_STATUS_IS_OK(status)) {
 		PyErr_SetNTSTATUS(status);
 		talloc_free(frame);
@@ -305,7 +318,8 @@ static PyObject *py_dsdb_dns_replace_by_dn(PyObject *self, PyObject *args)
 static PyMethodDef py_dsdb_dns_methods[] = {
 
 	{ "lookup", (PyCFunction)py_dsdb_dns_lookup,
-		METH_VARARGS, "Get the DNS database entries for a DNS name"},
+	        METH_VARARGS|METH_KEYWORDS,
+	        "Get the DNS database entries for a DNS name"},
 	{ "replace", (PyCFunction)py_dsdb_dns_replace,
 		METH_VARARGS, "Replace the DNS database entries for a DNS name"},
 	{ "replace_by_dn", (PyCFunction)py_dsdb_dns_replace_by_dn,
-- 
2.11.0


From 81576ea4871eef6dd8cb0a2af3440809ba69fbf1 Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Mon, 10 Apr 2017 16:10:00 +1200
Subject: [PATCH 08/24] join.py: Do not expose the old machine password over
 NTLM if -k yes was set

This makes the test for a valid machine account stricter (as a Kerberos error could
cause this to fail and so skip the validation), but we never wish to use NTLM
if the administrator disabled it on the command line.

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
Reviewed-by: Garming Sam <garming at catalyst.net.nz>
---
 python/samba/join.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/python/samba/join.py b/python/samba/join.py
index 6a924359407..3e70db08d2a 100644
--- a/python/samba/join.py
+++ b/python/samba/join.py
@@ -209,6 +209,7 @@ class dc_join(object):
         creds.guess(ctx.lp)
         try:
             creds.set_machine_account(ctx.lp)
+            creds.set_kerberos_state(ctx.creds.get_kerberos_state())
             machine_samdb = SamDB(url="ldap://%s" % ctx.server,
                                   session_info=system_session(),
                                 credentials=creds, lp=ctx.lp)
-- 
2.11.0


From 62935ff40791444a35454551435d512cc28a1eb8 Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Mon, 10 Apr 2017 17:10:27 +1200
Subject: [PATCH 09/24] samba_dnsupdate: Make nsupdate use the server given by
 the SOA record

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
Reviewed-by: Garming Sam <garming at catalyst.net.nz>
---
 source4/scripting/bin/samba_dnsupdate | 19 ++++++++++++++++---
 1 file changed, 16 insertions(+), 3 deletions(-)

diff --git a/source4/scripting/bin/samba_dnsupdate b/source4/scripting/bin/samba_dnsupdate
index ba167da2876..80a5a6f484d 100755
--- a/source4/scripting/bin/samba_dnsupdate
+++ b/source4/scripting/bin/samba_dnsupdate
@@ -237,7 +237,7 @@ def hostname_match(h1, h2):
     h2 = str(h2)
     return h1.lower().rstrip('.') == h2.lower().rstrip('.')
 
-def check_one_dns_name(name, name_type, d=None):
+def get_resolver(d=None):
     resolv_conf = os.getenv('RESOLV_CONF')
     if not resolv_conf:
         resolv_conf = '/etc/resolv.conf'
@@ -245,7 +245,12 @@ def check_one_dns_name(name, name_type, d=None):
 
     if d is not None and d.nameservers != []:
         resolver.nameservers = d.nameservers
-    elif d is not None:
+
+    return resolver
+
+def check_one_dns_name(name, name_type, d=None):
+    resolver = get_resolver(d)
+    if d is not None and len(d.nameservers) == 0:
         d.nameservers = resolver.nameservers
 
     ans = resolver.query(name, name_type)
@@ -438,10 +443,18 @@ def call_nsupdate(d, op="add"):
     # NS record may point to, even as we get a ticket to that other
     # server.
     #
-    # Therefore we must not set this in production.
+    # Therefore we must not set this in production, instead we want
+    # to find the name of a SOA for the zone and use that server.
 
     if os.getenv('RESOLV_CONF') and d.nameservers != []:
         f.write('server %s\n' % d.nameservers[0])
+    else:
+        resolver = get_resolver(d)
+        zone = dns.resolver.zone_for_name(normalised_name,
+                                          resolver=resolver)
+        soa = resolver.query(zone, "SOA")
+
+        f.write('server %s\n' % soa[0].mname)
 
     if d.type == "A":
         f.write("update %s %s %u A %s\n" % (op, normalised_name, default_ttl, d.ip))
-- 
2.11.0


From 5c2a503622cd0c20c534bc841fc6a57f98309961 Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Mon, 10 Apr 2017 17:13:46 +1200
Subject: [PATCH 10/24] samba_dnsupdate: Try to get a ticket to the SOA, not the
 NS servers

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
Reviewed-by: Garming Sam <garming at catalyst.net.nz>
---
 source4/scripting/bin/samba_dnsupdate | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/source4/scripting/bin/samba_dnsupdate b/source4/scripting/bin/samba_dnsupdate
index 80a5a6f484d..28343bf17d5 100755
--- a/source4/scripting/bin/samba_dnsupdate
+++ b/source4/scripting/bin/samba_dnsupdate
@@ -137,10 +137,12 @@ def get_credentials(lp):
         if opts.use_file is not None:
             return
 
-        # Now confirm we can get a ticket to a DNS server
-        ans = check_one_dns_name(sub_vars['DNSDOMAIN'] + '.', 'NS')
+        # Now confirm we can get a ticket to the DNS server
+        ans = check_one_dns_name(sub_vars['DNSDOMAIN'] + '.', 'SOA')
+
+        # Actually there is only one
         for i in range(len(ans)):
-            target_hostname = str(ans[i].target).rstrip('.')
+            target_hostname = str(ans[i].mname).rstrip('.')
             settings = {}
             settings["lp_ctx"] = lp
             settings["target_hostname"] = target_hostname
-- 
2.11.0


From d6458353e3533263d7887f8e02801caac1852f3a Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Thu, 8 Jun 2017 15:54:22 +1200
Subject: [PATCH 11/24] selftest: confirm we clobber the MNAME in the SOA query
 in the DNS server

All RW DCs should be their own master DNS server.

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
Reviewed-by: Garming Sam <garming at catalyst.net.nz>
---
 python/samba/tests/dns.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/python/samba/tests/dns.py b/python/samba/tests/dns.py
index b8a2481ae36..93a7a7a2b32 100644
--- a/python/samba/tests/dns.py
+++ b/python/samba/tests/dns.py
@@ -235,6 +235,24 @@ class TestSimpleQueries(DNSTest):
         self.assertEquals(response.answers[0].rdata,
                           self.server_ip)
 
+    def test_one_SOA_query(self):
+        "create a query packet containing one query record for the SOA"
+        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
+        questions = []
+
+        name = "%s" % (self.get_dns_domain())
+        q = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
+        print "asking for ", q.name
+        questions.append(q)
+
+        self.finish_name_packet(p, questions)
+        response = self.dns_transaction_udp(p)
+        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
+        self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
+        self.assertEquals(response.ancount, 1)
+        self.assertEquals(response.answers[0].rdata.mname.upper(),
+                          ("%s.%s" % (self.server, self.get_dns_domain())).upper())
+
     def test_one_a_query_tcp(self):
         "create a query packet containing one query record via TCP"
         p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
-- 
2.11.0


From 1c697e8fea89ac4fd382e8b0e1cf59a9dccf901b Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Thu, 8 Jun 2017 16:20:42 +1200
Subject: [PATCH 12/24] selftest: run dns tests in multiple envs

This will let us check the negative behaviour: that updates against RODCs fail
and un-authenticated updates fail.

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
Reviewed-by: Garming Sam <garming at catalyst.net.nz>
---
 python/samba/tests/dns.py | 154 ++++++++++++++++++++++++++++++----------------
 selftest/knownfail.d/dns  |  55 +++++++++++++++++
 source4/selftest/tests.py |   2 +
 3 files changed, 159 insertions(+), 52 deletions(-)
 create mode 100644 selftest/knownfail.d/dns

diff --git a/python/samba/tests/dns.py b/python/samba/tests/dns.py
index 93a7a7a2b32..43eccddd957 100644
--- a/python/samba/tests/dns.py
+++ b/python/samba/tests/dns.py
@@ -26,7 +26,7 @@ from samba.tests import TestCase
 from samba.dcerpc import dns, dnsp, dnsserver
 from samba.netcmd.dns import TXTRecord, dns_record_match, data_to_dns_record
 from samba.tests.subunitrun import SubunitOptions, TestProgram
-from samba import werror
+from samba import werror, WERRORError
 import samba.getopt as options
 import optparse
 
@@ -800,57 +800,70 @@ class TestComplexQueries(DNSTest):
 
     def setUp(self):
         super(TestComplexQueries, self).setUp()
-        name = "cname_test.%s" % self.get_dns_domain()
-        rdata = "%s.%s" % (self.server, self.get_dns_domain())
-        self.make_dns_update(name, rdata, dns.DNS_QTYPE_CNAME)
 
     def tearDown(self):
         super(TestComplexQueries, self).tearDown()
-        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
-        updates = []
 
-        name = self.get_dns_domain()
+    def test_one_a_query(self):
+        "create a query packet containing one query record"
 
-        u = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
-        updates.append(u)
-        self.finish_name_packet(p, updates)
+        name = "cname_test.%s" % self.get_dns_domain()
+        rdata = "%s.%s" % (self.server, self.get_dns_domain())
+        self.make_dns_update(name, rdata, dns.DNS_QTYPE_CNAME)
 
-        updates = []
-        r = dns.res_rec()
-        r.name = "cname_test.%s" % self.get_dns_domain()
-        r.rr_type = dns.DNS_QTYPE_CNAME
-        r.rr_class = dns.DNS_QCLASS_NONE
-        r.ttl = 0
-        r.length = 0xffff
-        r.rdata = "%s.%s" % (self.server, self.get_dns_domain())
-        updates.append(r)
-        p.nscount = len(updates)
-        p.nsrecs = updates
+        try:
 
-        response = self.dns_transaction_udp(p)
-        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
+            # Create the record
+            name = "cname_test.%s" % self.get_dns_domain()
+            rdata = "%s.%s" % (self.server, self.get_dns_domain())
+            self.make_dns_update(name, rdata, dns.DNS_QTYPE_CNAME)
 
-    def test_one_a_query(self):
-        "create a query packet containing one query record"
-        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
-        questions = []
+            p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
+            questions = []
 
-        name = "cname_test.%s" % self.get_dns_domain()
-        q = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN)
-        print "asking for ", q.name
-        questions.append(q)
+            # Check the record
+            name = "cname_test.%s" % self.get_dns_domain()
+            q = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN)
+            print "asking for ", q.name
+            questions.append(q)
 
-        self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
-        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
-        self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
-        self.assertEquals(response.ancount, 2)
-        self.assertEquals(response.answers[0].rr_type, dns.DNS_QTYPE_CNAME)
-        self.assertEquals(response.answers[0].rdata, "%s.%s" %
-                          (self.server, self.get_dns_domain()))
-        self.assertEquals(response.answers[1].rr_type, dns.DNS_QTYPE_A)
-        self.assertEquals(response.answers[1].rdata,
-                          self.server_ip)
+            self.finish_name_packet(p, questions)
+            response = self.dns_transaction_udp(p)
+            self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
+            self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
+            self.assertEquals(response.ancount, 2)
+            self.assertEquals(response.answers[0].rr_type, dns.DNS_QTYPE_CNAME)
+            self.assertEquals(response.answers[0].rdata, "%s.%s" %
+                              (self.server, self.get_dns_domain()))
+            self.assertEquals(response.answers[1].rr_type, dns.DNS_QTYPE_A)
+            self.assertEquals(response.answers[1].rdata,
+                              self.server_ip)
+
+        finally:
+            # Delete the record
+            p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
+            updates = []
+
+            name = self.get_dns_domain()
+
+            u = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
+            updates.append(u)
+            self.finish_name_packet(p, updates)
+
+            updates = []
+            r = dns.res_rec()
+            r.name = "cname_test.%s" % self.get_dns_domain()
+            r.rr_type = dns.DNS_QTYPE_CNAME
+            r.rr_class = dns.DNS_QCLASS_NONE
+            r.ttl = 0
+            r.length = 0xffff
+            r.rdata = "%s.%s" % (self.server, self.get_dns_domain())
+            updates.append(r)
+            p.nscount = len(updates)
+            p.nsrecs = updates
+
+            response = self.dns_transaction_udp(p)
+            self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
 
     def test_cname_two_chain(self):
         name0 = "cnamechain0.%s" % self.get_dns_domain()
@@ -1012,14 +1025,17 @@ class TestZones(DNSTest):
         zone_create.fAllowUpdate = dnsp.DNS_ZONE_UPDATE_SECURE
         zone_create.fAging = 0
         zone_create.dwDpFlags = dnsserver.DNS_DP_DOMAIN_DEFAULT
-        self.rpc_conn.DnssrvOperation2(dnsserver.DNS_CLIENT_VERSION_LONGHORN,
-                                       0,
-                                       self.server_ip,
-                                       None,
-                                       0,
-                                       'ZoneCreate',
-                                       dnsserver.DNSSRV_TYPEID_ZONE_CREATE,
-                                       zone_create)
+        try:
+            self.rpc_conn.DnssrvOperation2(dnsserver.DNS_CLIENT_VERSION_LONGHORN,
+                                           0,
+                                           self.server_ip,
+                                           None,
+                                           0,
+                                           'ZoneCreate',
+                                           dnsserver.DNSSRV_TYPEID_ZONE_CREATE,
+                                           zone_create)
+        except WERRORError as e:
+            self.fail(str(e))
 
     def delete_zone(self, zone):
         self.rpc_conn.DnssrvOperation2(dnsserver.DNS_CLIENT_VERSION_LONGHORN,
@@ -1080,7 +1096,10 @@ class TestRPCRoundtrip(DNSTest):
             self.rpc_conn.DnssrvUpdateRecord2(dnsserver.DNS_CLIENT_VERSION_LONGHORN,
                                      0, self.server_ip, self.get_dns_domain(),
                                      name, add_rec_buf, None)
+        except WERRORError as e:
+            self.fail(str(e))
 
+        try:
             self.check_query_txt(prefix, txt)
         finally:
             self.rpc_conn.DnssrvUpdateRecord2(dnsserver.DNS_CLIENT_VERSION_LONGHORN,
@@ -1132,6 +1151,10 @@ class TestRPCRoundtrip(DNSTest):
                                      0, self.server_ip, self.get_dns_domain(),
                                      name, add_rec_buf, None)
 
+        except WERRORError as e:
+            self.fail(str(e))
+
+        try:
             self.check_query_txt(prefix, txt)
         finally:
             self.rpc_conn.DnssrvUpdateRecord2(dnsserver.DNS_CLIENT_VERSION_LONGHORN,
@@ -1150,6 +1173,10 @@ class TestRPCRoundtrip(DNSTest):
                                      0, self.server_ip, self.get_dns_domain(),
                                      name, add_rec_buf, None)
 
+        except WERRORError as e:
+            self.fail(str(e))
+
+        try:
             self.check_query_txt(prefix, txt)
         finally:
             self.rpc_conn.DnssrvUpdateRecord2(dnsserver.DNS_CLIENT_VERSION_LONGHORN,
@@ -1167,7 +1194,10 @@ class TestRPCRoundtrip(DNSTest):
             self.rpc_conn.DnssrvUpdateRecord2(dnsserver.DNS_CLIENT_VERSION_LONGHORN,
                                      0, self.server_ip, self.get_dns_domain(),
                                      name, add_rec_buf, None)
+        except WERRORError as e:
+            self.fail(str(e))
 
+        try:
             self.check_query_txt(prefix, txt)
         finally:
             self.rpc_conn.DnssrvUpdateRecord2(dnsserver.DNS_CLIENT_VERSION_LONGHORN,
@@ -1210,7 +1240,11 @@ class TestRPCRoundtrip(DNSTest):
                                      0, self.server_ip, self.get_dns_domain(),
                                      name, add_rec_buf, None)
 
-            self.check_query_txt(prefix, ['NULL'])
+        except WERRORError as e:
+            self.fail(str(e))
+
+        try:
+           self.check_query_txt(prefix, ['NULL'])
         finally:
             self.rpc_conn.DnssrvUpdateRecord2(dnsserver.DNS_CLIENT_VERSION_LONGHORN,
                                               0, self.server_ip, self.get_dns_domain(),
@@ -1241,7 +1275,11 @@ class TestRPCRoundtrip(DNSTest):
                                      0, self.server_ip, self.get_dns_domain(),
                                      name, add_rec_buf, None)
 
-            self.check_query_txt(prefix, txt)
+        except WERRORError as e:
+            self.fail(str(e))
+
+        try:
+           self.check_query_txt(prefix, txt)
         finally:
             self.rpc_conn.DnssrvUpdateRecord2(dnsserver.DNS_CLIENT_VERSION_LONGHORN,
                                               0, self.server_ip, self.get_dns_domain(),
@@ -1275,7 +1313,12 @@ class TestRPCRoundtrip(DNSTest):
                                      0, self.server_ip, self.get_dns_domain(),
                                      name, add_rec_buf, None)
 
+        except WERRORError as e:
+            self.fail(str(e))
+
+        try:
             self.check_query_txt(prefix, txt)
+
         finally:
             self.rpc_conn.DnssrvUpdateRecord2(dnsserver.DNS_CLIENT_VERSION_LONGHORN,
                                               0, self.server_ip, self.get_dns_domain(),
@@ -1311,6 +1354,10 @@ class TestRPCRoundtrip(DNSTest):
                                      0, self.server_ip, self.get_dns_domain(),
                                      name, add_rec_buf, None)
 
+        except WERRORError as e:
+            self.fail(str(e))
+
+        try:
             self.check_query_txt(prefix, txt)
         finally:
             self.rpc_conn.DnssrvUpdateRecord2(dnsserver.DNS_CLIENT_VERSION_LONGHORN,
@@ -1341,7 +1388,10 @@ class TestRPCRoundtrip(DNSTest):
             self.rpc_conn.DnssrvUpdateRecord2(dnsserver.DNS_CLIENT_VERSION_LONGHORN,
                                      0, self.server_ip, self.get_dns_domain(),
                                      name, add_rec_buf, None)
+        except WERRORError as e:
+            self.fail(str(e))
 
+        try:
             self.check_query_txt(prefix, txt)
         finally:
             self.rpc_conn.DnssrvUpdateRecord2(dnsserver.DNS_CLIENT_VERSION_LONGHORN,
diff --git a/selftest/knownfail.d/dns b/selftest/knownfail.d/dns
new file mode 100644
index 00000000000..6553c1fffe0
--- /dev/null
+++ b/selftest/knownfail.d/dns
@@ -0,0 +1,55 @@
+# These tests are expected to fail because we want to ensure that
+# unauthenticated updates are not permitted against the default
+# configuration, nor against an RODC
+
+samba.tests.dns.__main__.TestDNSUpdates.test_delete_record\(rodc:local\)
+samba.tests.dns.__main__.TestDNSUpdates.test_readd_record\(rodc:local\)
+samba.tests.dns.__main__.TestDNSUpdates.test_update_add_mx_record\(rodc:local\)
+samba.tests.dns.__main__.TestDNSUpdates.test_update_add_txt_record\(rodc:local\)
+samba.tests.dns.__main__.TestInvalidQueries.test_one_a_query\(rodc:local\)
+samba.tests.dns.__main__.TestRPCRoundtrip.test_update_add_empty_txt_records\(rodc:local\)
+samba.tests.dns.__main__.TestRPCRoundtrip.test_update_add_hex_char_txt_record\(rodc:local\)
+samba.tests.dns.__main__.TestRPCRoundtrip.test_update_add_null_char_txt_record\(rodc:local\)
+samba.tests.dns.__main__.TestRPCRoundtrip.test_update_add_null_padded_txt_record\(rodc:local\)
+samba.tests.dns.__main__.TestRPCRoundtrip.test_update_add_slash_txt_record\(rodc:local\)
+samba.tests.dns.__main__.TestRPCRoundtrip.test_update_add_two_txt_records\(rodc:local\)
+samba.tests.dns.__main__.TestDNSUpdates.test_delete_record\(vampire_dc:local\)
+samba.tests.dns.__main__.TestDNSUpdates.test_readd_record\(vampire_dc:local\)
+samba.tests.dns.__main__.TestDNSUpdates.test_update_add_mx_record\(vampire_dc:local\)
+samba.tests.dns.__main__.TestDNSUpdates.test_update_add_txt_record\(vampire_dc:local\)
+samba.tests.dns.__main__.TestInvalidQueries.test_one_a_query\(vampire_dc:local\)
+samba.tests.dns.__main__.TestRPCRoundtrip.test_update_add_empty_txt_records\(vampire_dc:local\)
+samba.tests.dns.__main__.TestRPCRoundtrip.test_update_add_hex_char_txt_record\(vampire_dc:local\)
+samba.tests.dns.__main__.TestRPCRoundtrip.test_update_add_null_char_txt_record\(vampire_dc:local\)
+samba.tests.dns.__main__.TestRPCRoundtrip.test_update_add_null_padded_txt_record\(vampire_dc:local\)
+samba.tests.dns.__main__.TestRPCRoundtrip.test_update_add_slash_txt_record\(vampire_dc:local\)
+samba.tests.dns.__main__.TestRPCRoundtrip.test_update_add_two_txt_records\(vampire_dc:local\)
+samba.tests.dns.__main__.TestComplexQueries.test_cname_two_chain\(rodc:local\)
+samba.tests.dns.__main__.TestComplexQueries.test_one_a_query\(rodc:local\)
+samba.tests.dns.__main__.TestRPCRoundtrip.test_update_add_empty_rpc_to_dns\(rodc:local\)
+samba.tests.dns.__main__.TestRPCRoundtrip.test_update_add_hex_rpc_to_dns\(rodc:local\)
+samba.tests.dns.__main__.TestRPCRoundtrip.test_update_add_null_char_rpc_to_dns\(rodc:local\)
+samba.tests.dns.__main__.TestRPCRoundtrip.test_update_add_padding_rpc_to_dns\(rodc:local\)
+samba.tests.dns.__main__.TestRPCRoundtrip.test_update_add_slash_rpc_to_dns\(rodc:local\)
+samba.tests.dns.__main__.TestRPCRoundtrip.test_update_add_two_rpc_to_dns\(rodc:local\)
+samba.tests.dns.__main__.TestRPCRoundtrip.test_update_add_txt_rpc_to_dns\(rodc:local\)
+samba.tests.dns.__main__.TestZones.test_soa_query\(rodc:local\)
+samba.tests.dns.__main__.TestComplexQueries.test_cname_two_chain\(vampire_dc:local\)
+samba.tests.dns.__main__.TestComplexQueries.test_one_a_query\(vampire_dc:local\)
+
+# The SOA override should not pass against the RODC, it must not overstamp
+samba.tests.dns.__main__.TestSimpleQueries.test_one_SOA_query\(rodc:local\)
+
+# The very first DC will have DNS records, but subsequent DCs only get entries into
+# the dns_hosts_file in our selftest env
+samba.tests.dns.__main__.TestSimpleQueries.test_one_SOA_query\(vampire_dc:local\)
+samba.tests.dns.__main__.TestSimpleQueries.test_one_a_query\(vampire_dc:local\)
+samba.tests.dns.__main__.TestSimpleQueries.test_one_a_query_tcp\(vampire_dc:local\)
+samba.tests.dns.__main__.TestSimpleQueries.test_one_mx_query\(vampire_dc:local\)
+samba.tests.dns.__main__.TestSimpleQueries.test_qtype_all_query\(vampire_dc:local\)
+samba.tests.dns.__main__.TestSimpleQueries.test_soa_hostname_query\(vampire_dc:local\)
+samba.tests.dns.__main__.TestSimpleQueries.test_one_a_query\(rodc:local\)
+samba.tests.dns.__main__.TestSimpleQueries.test_one_a_query_tcp\(rodc:local\)
+samba.tests.dns.__main__.TestSimpleQueries.test_one_mx_query\(rodc:local\)
+samba.tests.dns.__main__.TestSimpleQueries.test_qtype_all_query\(rodc:local\)
+samba.tests.dns.__main__.TestSimpleQueries.test_soa_hostname_query\(rodc:local\)
diff --git a/source4/selftest/tests.py b/source4/selftest/tests.py
index 7c601c35af4..071660bb418 100755
--- a/source4/selftest/tests.py
+++ b/source4/selftest/tests.py
@@ -361,6 +361,8 @@ for f in sorted(os.listdir(os.path.join(samba4srcdir, "../pidl/tests"))):
 
 # DNS tests
 plantestsuite_loadlist("samba.tests.dns", "fl2003dc:local", [python, os.path.join(srcdir(), "python/samba/tests/dns.py"), '$SERVER', '$SERVER_IP', '--machine-pass', '-U"$USERNAME%$PASSWORD"', '--workgroup=$DOMAIN', '$LOADLIST', '$LISTOPT'])
+plantestsuite_loadlist("samba.tests.dns", "rodc:local", [python, os.path.join(srcdir(), "python/samba/tests/dns.py"), '$SERVER', '$SERVER_IP', '--machine-pass', '-U"$USERNAME%$PASSWORD"', '--workgroup=$DOMAIN', '$LOADLIST', '$LISTOPT'])
+plantestsuite_loadlist("samba.tests.dns", "vampire_dc:local", [python, os.path.join(srcdir(), "python/samba/tests/dns.py"), '$SERVER', '$SERVER_IP', '--machine-pass', '-U"$USERNAME%$PASSWORD"', '--workgroup=$DOMAIN', '$LOADLIST', '$LISTOPT'])
 
 plantestsuite_loadlist("samba.tests.dns_forwarder", "fl2003dc:local", [python, os.path.join(srcdir(), "python/samba/tests/dns_forwarder.py"), '$SERVER', '$SERVER_IP', '$DNS_FORWARDER1', '$DNS_FORWARDER2', '--machine-pass', '-U"$USERNAME%$PASSWORD"', '--workgroup=$DOMAIN', '$LOADLIST', '$LISTOPT'])
 
-- 
2.11.0


From 3d48e3521aa0e732eda03d6fb5576c7dae7420bc Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Tue, 11 Apr 2017 12:43:22 +1200
Subject: [PATCH 13/24] dns_server: clobber MNAME in the SOA

Otherwise, we always report the first server on which the AD domain was created/provisioned,
which does not match AD behaviour.  AD is multi-master, so every RW server is a master.

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
Reviewed-by: Garming Sam <garming at catalyst.net.nz>
---
 python/samba/samdb.py                 |  2 +-
 selftest/knownfail.d/dns              |  1 -
 source4/dns_server/dlz_bind9.c        |  2 +-
 source4/dns_server/dnsserver_common.c | 53 +++++++++++++++++++++++++++++++++--
 source4/dns_server/dnsserver_common.h |  3 +-
 source4/dns_server/pydns.c            |  8 ++++--
 6 files changed, 60 insertions(+), 9 deletions(-)

diff --git a/python/samba/samdb.py b/python/samba/samdb.py
index b3a4b384926..e0021563a23 100644
--- a/python/samba/samdb.py
+++ b/python/samba/samdb.py
@@ -937,7 +937,7 @@ accountExpires: %u
 
     def dns_extract(self, el):
         '''Return the NDR database structures from a dnsRecord element'''
-        return dsdb_dns.extract(el)
+        return dsdb_dns.extract(self, el)
 
     def dns_replace(self, dns_name, new_records):
         '''Do a DNS modification on the database, sets the NDR database
diff --git a/selftest/knownfail.d/dns b/selftest/knownfail.d/dns
index 6553c1fffe0..c40041d1892 100644
--- a/selftest/knownfail.d/dns
+++ b/selftest/knownfail.d/dns
@@ -42,7 +42,6 @@ samba.tests.dns.__main__.TestSimpleQueries.test_one_SOA_query\(rodc:local\)
 
 # The very first DC will have DNS records, but subsequent DCs only get entries into
 # the dns_hosts_file in our selftest env
-samba.tests.dns.__main__.TestSimpleQueries.test_one_SOA_query\(vampire_dc:local\)
 samba.tests.dns.__main__.TestSimpleQueries.test_one_a_query\(vampire_dc:local\)
 samba.tests.dns.__main__.TestSimpleQueries.test_one_a_query_tcp\(vampire_dc:local\)
 samba.tests.dns.__main__.TestSimpleQueries.test_one_mx_query\(vampire_dc:local\)
diff --git a/source4/dns_server/dlz_bind9.c b/source4/dns_server/dlz_bind9.c
index 897699a6317..7096f4749b2 100644
--- a/source4/dns_server/dlz_bind9.c
+++ b/source4/dns_server/dlz_bind9.c
@@ -997,7 +997,7 @@ _PUBLIC_ isc_result_t dlz_allnodes(const char *zone, void *dbdata,
 			return ISC_R_NOMEMORY;
 		}
 
-		werr = dns_common_extract(el, el_ctx, &recs, &num_recs);
+		werr = dns_common_extract(state->samdb, el, el_ctx, &recs, &num_recs);
 		if (!W_ERROR_IS_OK(werr)) {
 			state->log(ISC_LOG_ERROR, "samba_dlz: failed to parse dnsRecord for %s, %s",
 				   ldb_dn_get_linearized(dn), win_errstr(werr));
diff --git a/source4/dns_server/dnsserver_common.c b/source4/dns_server/dnsserver_common.c
index fbfa5fa4eae..d0c0a2fdbb4 100644
--- a/source4/dns_server/dnsserver_common.c
+++ b/source4/dns_server/dnsserver_common.c
@@ -69,7 +69,8 @@ uint8_t werr_to_dns_err(WERROR werr)
 	return DNS_RCODE_SERVFAIL;
 }
 
-WERROR dns_common_extract(const struct ldb_message_element *el,
+WERROR dns_common_extract(struct ldb_context *samdb,
+			  const struct ldb_message_element *el,
 			  TALLOC_CTX *mem_ctx,
 			  struct dnsp_DnssrvRpcRecord **records,
 			  uint16_t *num_records)
@@ -86,9 +87,13 @@ WERROR dns_common_extract(const struct ldb_message_element *el,
 		return WERR_NOT_ENOUGH_MEMORY;
 	}
 	for (ri = 0; ri < el->num_values; ri++) {
+		bool am_rodc;
+		int ret;
+		const char *attrs[] = { "dnsHostName", NULL };
+		const char *dnsHostName;
 		struct ldb_val *v = &el->values[ri];
 		enum ndr_err_code ndr_err;
-
+		struct ldb_result *res = NULL;
 		ndr_err = ndr_pull_struct_blob(v, recs, &recs[ri],
 				(ndr_pull_flags_fn_t)ndr_pull_dnsp_DnssrvRpcRecord);
 		if (!NDR_ERR_CODE_IS_SUCCESS(ndr_err)) {
@@ -96,7 +101,49 @@ WERROR dns_common_extract(const struct ldb_message_element *el,
 			DEBUG(0, ("Failed to grab dnsp_DnssrvRpcRecord\n"));
 			return DNS_ERR(SERVER_FAILURE);
 		}
+
+		/*
+		 * In AD, except on an RODC (where we should list a random RWDC),
+		 * we should over-stamp the MNAME with our own hostname
+		 */
+		if (recs[ri].wType != DNS_TYPE_SOA) {
+			continue;
+		}
+
+		ret = samdb_rodc(samdb, &am_rodc);
+		if (ret != LDB_SUCCESS) {
+			DEBUG(0, ("Failed to confirm we are not an RODC: %s\n",
+				  ldb_errstring(samdb)));
+			return DNS_ERR(SERVER_FAILURE);
+		}
+
+		if (am_rodc) {
+			continue;
+		}
+
+		ret = dsdb_search_dn(samdb, mem_ctx, &res, NULL,
+				     attrs, 0);
+
+		if (ret != LDB_SUCCESS || res->count != 1) {
+			DEBUG(0, ("Failed to get rootDSE for dnsHostName: %s\n",
+				  ldb_errstring(samdb)));
+			return DNS_ERR(SERVER_FAILURE);
+		}
+
+		dnsHostName
+			= ldb_msg_find_attr_as_string(res->msgs[0],
+						      "dnsHostName",
+						      NULL);
+
+		if (dnsHostName == NULL) {
+			DEBUG(0, ("Failed to get dnsHostName from rootDSE\n"));
+			return DNS_ERR(SERVER_FAILURE);
+		}
+
+		recs[ri].data.soa.mname
+			= talloc_steal(recs, dnsHostName);
 	}
+
 	*records = recs;
 	*num_records = el->num_values;
 	return WERR_OK;
@@ -189,7 +236,7 @@ WERROR dns_common_lookup(struct ldb_context *samdb,
 		}
 	}
 
-	werr = dns_common_extract(el, mem_ctx, records, num_records);
+	werr = dns_common_extract(samdb, el, mem_ctx, records, num_records);
 	TALLOC_FREE(msg);
 	if (!W_ERROR_IS_OK(werr)) {
 		return werr;
diff --git a/source4/dns_server/dnsserver_common.h b/source4/dns_server/dnsserver_common.h
index 293831f0acb..b615e2dcfae 100644
--- a/source4/dns_server/dnsserver_common.h
+++ b/source4/dns_server/dnsserver_common.h
@@ -35,7 +35,8 @@ struct dns_server_zone {
 	struct ldb_dn *dn;
 };
 
-WERROR dns_common_extract(const struct ldb_message_element *el,
+WERROR dns_common_extract(struct ldb_context *samdb,
+			  const struct ldb_message_element *el,
 			  TALLOC_CTX *mem_ctx,
 			  struct dnsp_DnssrvRpcRecord **records,
 			  uint16_t *num_records);
diff --git a/source4/dns_server/pydns.c b/source4/dns_server/pydns.c
index cb41faa1441..63fa80e92b3 100644
--- a/source4/dns_server/pydns.c
+++ b/source4/dns_server/pydns.c
@@ -160,17 +160,21 @@ static PyObject *py_dsdb_dns_lookup(PyObject *self,
 
 static PyObject *py_dsdb_dns_extract(PyObject *self, PyObject *args)
 {
+	struct ldb_context *samdb;
 	PyObject *py_dns_el, *ret;
+	PyObject *py_ldb = NULL;
 	TALLOC_CTX *frame;
 	WERROR werr;
 	struct ldb_message_element *dns_el;
 	struct dnsp_DnssrvRpcRecord *records;
 	uint16_t num_records;
 
-	if (!PyArg_ParseTuple(args, "O", &py_dns_el)) {
+	if (!PyArg_ParseTuple(args, "OO", &py_ldb, &py_dns_el)) {
 		return NULL;
 	}
 
+	PyErr_LDB_OR_RAISE(py_ldb, samdb);
+
 	if (!py_check_dcerpc_type(py_dns_el, "ldb", "MessageElement")) {
 		PyErr_SetString(PyExc_TypeError,
 				"ldb MessageElement object required");
@@ -180,7 +184,7 @@ static PyObject *py_dsdb_dns_extract(PyObject *self, PyObject *args)
 
 	frame = talloc_stackframe();
 
-	werr = dns_common_extract(dns_el,
+	werr = dns_common_extract(samdb, dns_el,
 				  frame,
 				  &records,
 				  &num_records);
-- 
2.11.0


From 6a22362d12a8c4b0a412f9f948f2116b58ccd166 Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Tue, 11 Apr 2017 14:14:15 +1200
Subject: [PATCH 14/24] samba_dnsupdate: Extend possible server list to all NS
 servers for the zone

This should eventually be removed, but for now this unblocks samba_dnsupdate operation
in existing domains that have lost the original Samba DC

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
Reviewed-by: Garming Sam <garming at catalyst.net.nz>
---
 source4/scripting/bin/samba_dnsupdate | 98 ++++++++++++++++++++++++-----------
 1 file changed, 69 insertions(+), 29 deletions(-)

diff --git a/source4/scripting/bin/samba_dnsupdate b/source4/scripting/bin/samba_dnsupdate
index 28343bf17d5..eb6d4c2ad86 100755
--- a/source4/scripting/bin/samba_dnsupdate
+++ b/source4/scripting/bin/samba_dnsupdate
@@ -121,6 +121,64 @@ for i in IPs:
 if opts.verbose:
     print "IPs: %s" % IPs
 
+def get_possible_rw_dns_server(creds, domain):
+    """Return candidate read-write DNS server hostnames for domain:
+       the SOA MNAME first, then the target of every NS record.  Old
+       Samba domains (4.6 and prior) do not maintain the SOA, hence the
+       NS fallback.  A host may be listed twice; creds is unused."""
+
+    hostnames = []
+    ans_soa = check_one_dns_name(domain, 'SOA')
+
+    # Actually there is only one
+    for i in range(len(ans_soa)):
+        hostnames.append(str(ans_soa[i].mname).rstrip('.'))
+
+    # This is not strictly legit, but old Samba domains may have an
+    # unmaintained SOA record, so go for any NS that we can get a
+    # ticket to.
+    ans_ns = check_one_dns_name(domain, 'NS')
+
+    # A zone may have several NS records; keep them all, in order
+    for i in range(len(ans_ns)):
+        hostnames.append(str(ans_ns[i].target).rstrip('.'))
+
+    return hostnames
+
+def get_krb5_rw_dns_server(creds, domain):
+    """Return the first read-write DNS server we can obtain a Kerberos
+       ticket for, trying the SOA MNAME first.  Old Samba domains (4.6
+       and prior) do not maintain the SOA, so fall back to the NS
+       servers.  Re-raises the last RuntimeError if every candidate
+       fails; returns None if there are no candidates at all.
+    """
+
+    rw_dns_servers = get_possible_rw_dns_server(creds, domain)
+    # Try each candidate in order; the first one that works wins
+    for i in range(len(rw_dns_servers)):
+        target_hostname = str(rw_dns_servers[i])
+        settings = {}
+        settings["lp_ctx"] = lp
+        settings["target_hostname"] = target_hostname
+
+        gensec_client = gensec.Security.start_client(settings)
+        gensec_client.set_credentials(creds)
+        gensec_client.set_target_service("DNS")
+        gensec_client.set_target_hostname(target_hostname)
+        gensec_client.want_feature(gensec.FEATURE_SEAL)
+        gensec_client.start_mech_by_sasl_name("GSSAPI")
+        server_to_client = ""
+        try:
+            (client_finished, client_to_server) = gensec_client.update(server_to_client)
+            if opts.verbose:
+                print "Successfully obtained Kerberos ticket to DNS/%s as %s" \
+                    % (target_hostname, creds.get_username())
+            return target_hostname
+        except RuntimeError:
+            # Only raise an exception if they all failed
+            if i != len(rw_dns_servers) - 1:
+                continue
+            raise
 
 def get_credentials(lp):
     """# get credentials if we haven't got them already."""
@@ -138,33 +196,8 @@ def get_credentials(lp):
             return
 
         # Now confirm we can get a ticket to the DNS server
-        ans = check_one_dns_name(sub_vars['DNSDOMAIN'] + '.', 'SOA')
-
-        # Actually there is only one
-        for i in range(len(ans)):
-            target_hostname = str(ans[i].mname).rstrip('.')
-            settings = {}
-            settings["lp_ctx"] = lp
-            settings["target_hostname"] = target_hostname
-
-            gensec_client = gensec.Security.start_client(settings)
-            gensec_client.set_credentials(creds)
-            gensec_client.set_target_service("DNS")
-            gensec_client.set_target_hostname(target_hostname)
-            gensec_client.want_feature(gensec.FEATURE_SEAL)
-            gensec_client.start_mech_by_sasl_name("GSSAPI")
-            server_to_client = ""
-            try:
-                (client_finished, client_to_server) = gensec_client.update(server_to_client)
-                if opts.verbose:
-                    print "Successfully obtained Kerberos ticket to DNS/%s as %s" \
-                            % (target_hostname, creds.get_username())
-                return
-            except RuntimeError:
-                # Only raise an exception if they all failed
-                if i != len(ans) - 1:
-                    pass
-                raise
+        get_krb5_rw_dns_server(creds, sub_vars['DNSDOMAIN'] + '.')
+        return creds
 
     except RuntimeError as e:
         os.unlink(ccachename)
@@ -452,11 +485,18 @@ def call_nsupdate(d, op="add"):
         f.write('server %s\n' % d.nameservers[0])
     else:
         resolver = get_resolver(d)
+
+        # Locate the zone for this name
         zone = dns.resolver.zone_for_name(normalised_name,
                                           resolver=resolver)
-        soa = resolver.query(zone, "SOA")
 
-        f.write('server %s\n' % soa[0].mname)
+        # Now find the SOA, or if we can't get a ticket to the SOA,
+        # any server with an NS record we can get a ticket for.
+        #
+        # Thanks to the Kerberos credentials cache this is not
+        # expensive inside the loop
+        server = get_krb5_rw_dns_server(creds, zone)
+        f.write('server %s\n' % server)
 
     if d.type == "A":
         f.write("update %s %s %u A %s\n" % (op, normalised_name, default_ttl, d.ip))
-- 
2.11.0


From 26c9ea7eaf23ba18ebf43c80110c8f3f2552aac6 Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Tue, 11 Apr 2017 14:23:49 +1200
Subject: [PATCH 15/24] samba_dnsupdate: fix "samba-tool" fallback error
 handling

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
Reviewed-by: Garming Sam <garming at catalyst.net.nz>
---
 source4/scripting/bin/samba_dnsupdate | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/source4/scripting/bin/samba_dnsupdate b/source4/scripting/bin/samba_dnsupdate
index eb6d4c2ad86..d9948a6f9b8 100755
--- a/source4/scripting/bin/samba_dnsupdate
+++ b/source4/scripting/bin/samba_dnsupdate
@@ -626,7 +626,7 @@ def call_samba_tool(d, op="add", zone=None):
                 sys.exit(1)
             error_count = error_count + 1
             if opts.verbose:
-                print("Failed 'samba-tool dns' based update: %s" % (str(d)))
+                print("Failed 'samba-tool dns' based update of %s" % (str(d)))
     except Exception, estr:
         if opts.fail_immediately:
             sys.exit(1)
-- 
2.11.0


From 239196378bb4a66dec0d7e5fb5c3c5bc8ea41b4b Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Wed, 31 May 2017 13:57:25 +1200
Subject: [PATCH 16/24] selftest: move make_txt_record() onto self in
 samba.tests.dns

This will help unifying dns.py and dns_tkey.py to use common subclasses

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
Reviewed-by: Garming Sam <garming at catalyst.net.nz>
---
 python/samba/tests/dns.py | 28 ++++++++++++++--------------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/python/samba/tests/dns.py b/python/samba/tests/dns.py
index 43eccddd957..50d7fb60c2f 100644
--- a/python/samba/tests/dns.py
+++ b/python/samba/tests/dns.py
@@ -60,14 +60,6 @@ server_name = args[0]
 server_ip = args[1]
 creds.set_krb_forwardable(credentials.NO_KRB_FORWARDABLE)
 
-def make_txt_record(records):
-    rdata_txt = dns.txt_record()
-    s_list = dnsp.string_list()
-    s_list.count = len(records)
-    s_list.str = records
-    rdata_txt.txt = s_list
-    return rdata_txt
-
 class DNSTest(TestCase):
 
     def setUp(self):
@@ -131,6 +123,14 @@ class DNSTest(TestCase):
         q.question_class = qclass
         return q
 
+    def make_txt_record(self, records):
+        rdata_txt = dns.txt_record()
+        s_list = dnsp.string_list()
+        s_list.count = len(records)
+        s_list.str = records
+        rdata_txt.txt = s_list
+        return rdata_txt
+
     def get_dns_domain(self):
         "Helper to get dns domain"
         return self.creds.get_realm().lower()
@@ -193,7 +193,7 @@ class DNSTest(TestCase):
         r.rr_class = dns.DNS_QCLASS_IN
         r.ttl = 900
         r.length = 0xffff
-        rdata = make_txt_record(txt_array)
+        rdata = self.make_txt_record(txt_array)
         r.rdata = rdata
         updates.append(r)
         p.nscount = len(updates)
@@ -560,7 +560,7 @@ class TestDNSUpdates(DNSTest):
         r.rr_class = dns.DNS_QCLASS_IN
         r.ttl = 900
         r.length = 0xffff
-        rdata = make_txt_record(['"This is a test"'])
+        rdata = self.make_txt_record(['"This is a test"'])
         r.rdata = rdata
         updates.append(r)
         p.nscount = len(updates)
@@ -596,7 +596,7 @@ class TestDNSUpdates(DNSTest):
         r.rr_class = dns.DNS_QCLASS_NONE
         r.ttl = 0
         r.length = 0xffff
-        rdata = make_txt_record(['"This is a test"'])
+        rdata = self.make_txt_record(['"This is a test"'])
         r.rdata = rdata
         updates.append(r)
         p.nscount = len(updates)
@@ -638,7 +638,7 @@ class TestDNSUpdates(DNSTest):
         r.rr_class = dns.DNS_QCLASS_IN
         r.ttl = 900
         r.length = 0xffff
-        rdata = make_txt_record(['"This is a test"'])
+        rdata = self.make_txt_record(['"This is a test"'])
         r.rdata = rdata
         updates.append(r)
         p.nscount = len(updates)
@@ -674,7 +674,7 @@ class TestDNSUpdates(DNSTest):
         r.rr_class = dns.DNS_QCLASS_NONE
         r.ttl = 0
         r.length = 0xffff
-        rdata = make_txt_record(['"This is a test"'])
+        rdata = self.make_txt_record(['"This is a test"'])
         r.rdata = rdata
         updates.append(r)
         p.nscount = len(updates)
@@ -711,7 +711,7 @@ class TestDNSUpdates(DNSTest):
         r.rr_class = dns.DNS_QCLASS_IN
         r.ttl = 900
         r.length = 0xffff
-        rdata = make_txt_record(['"This is a test"'])
+        rdata = self.make_txt_record(['"This is a test"'])
         r.rdata = rdata
         updates.append(r)
         p.nscount = len(updates)
-- 
2.11.0


From 870c51bd87d19a8613f9536a03b296167bf275c2 Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Fri, 9 Jun 2017 10:00:09 +1200
Subject: [PATCH 17/24] selftest: merge DNSTest boilerplate

This will help unifying dns.py and dns_tkey.py to use common subclasses

The code was originally copied, but has since diverged.  This handles
that divergence.

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
Reviewed-by: Garming Sam <garming at catalyst.net.nz>
---
 python/samba/tests/dns.py      | 202 ++++++++++++++++++++++++++---------------
 python/samba/tests/dns_tkey.py |  82 +++++++++++++----
 2 files changed, 196 insertions(+), 88 deletions(-)

diff --git a/python/samba/tests/dns.py b/python/samba/tests/dns.py
index 50d7fb60c2f..ac82d5051aa 100644
--- a/python/samba/tests/dns.py
+++ b/python/samba/tests/dns.py
@@ -63,12 +63,8 @@ creds.set_krb_forwardable(credentials.NO_KRB_FORWARDABLE)
 class DNSTest(TestCase):
 
     def setUp(self):
-        global server, server_ip, lp, creds
         super(DNSTest, self).setUp()
-        self.server = server_name
-        self.server_ip = server_ip
-        self.lp = lp
-        self.creds = creds
+        self.timeout = None
 
     def errstr(self, errcode):
         "Return a readable error code"
@@ -84,30 +80,42 @@ class DNSTest(TestCase):
             "NXRRSET",
             "NOTAUTH",
             "NOTZONE",
+            "0x0B",
+            "0x0C",
+            "0x0D",
+            "0x0E",
+            "0x0F",
+            "BADSIG",
+            "BADKEY"
         ]
 
         return string_codes[errcode]
 
+    def assert_rcode_equals(self, rcode, expected):
+        "Helper function to check return code"
+        self.assertEquals(rcode, expected, "Expected RCODE %s, got %s" %
+                          (self.errstr(expected), self.errstr(rcode)))
 
     def assert_dns_rcode_equals(self, packet, rcode):
         "Helper function to check return code"
         p_errcode = packet.operation & 0x000F
         self.assertEquals(p_errcode, rcode, "Expected RCODE %s, got %s" %
-                            (self.errstr(rcode), self.errstr(p_errcode)))
+                          (self.errstr(rcode), self.errstr(p_errcode)))
 
     def assert_dns_opcode_equals(self, packet, opcode):
         "Helper function to check opcode"
         p_opcode = packet.operation & 0x7800
         self.assertEquals(p_opcode, opcode, "Expected OPCODE %s, got %s" %
-                            (opcode, p_opcode))
+                          (opcode, p_opcode))
 
     def make_name_packet(self, opcode, qid=None):
         "Helper creating a dns.name_packet"
         p = dns.name_packet()
         if qid is None:
-            p.id = random.randint(0x0, 0xffff)
+            p.id = random.randint(0x0, 0xff00)
         p.operation = opcode
         p.questions = []
+        p.additional = []
         return p
 
     def finish_name_packet(self, packet, questions):
@@ -135,10 +143,12 @@ class DNSTest(TestCase):
         "Helper to get dns domain"
         return self.creds.get_realm().lower()
 
-    def dns_transaction_udp(self, packet, host=server_ip,
-                            dump=False, timeout=timeout):
+    def dns_transaction_udp(self, packet, host,
+                            dump=False, timeout=None):
         "send a DNS query and read the reply"
         s = None
+        if timeout is None:
+            timeout = self.timeout
         try:
             send_packet = ndr.ndr_pack(packet)
             if dump:
@@ -146,19 +156,22 @@ class DNSTest(TestCase):
             s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
             s.settimeout(timeout)
             s.connect((host, 53))
-            s.send(send_packet, 0)
+            s.sendall(send_packet, 0)
             recv_packet = s.recv(2048, 0)
             if dump:
                 print self.hexdump(recv_packet)
-            return ndr.ndr_unpack(dns.name_packet, recv_packet)
+            response = ndr.ndr_unpack(dns.name_packet, recv_packet)
+            return (response, recv_packet)
         finally:
             if s is not None:
                 s.close()
 
-    def dns_transaction_tcp(self, packet, host=server_ip,
-                            dump=False, timeout=timeout):
-        "send a DNS query and read the reply"
+    def dns_transaction_tcp(self, packet, host,
+                            dump=False, timeout=None):
+        "send a DNS query and read the reply, also return the raw packet"
         s = None
+        if timeout is None:
+            timeout = self.timeout
         try:
             send_packet = ndr.ndr_pack(packet)
             if dump:
@@ -168,14 +181,22 @@ class DNSTest(TestCase):
             s.connect((host, 53))
             tcp_packet = struct.pack('!H', len(send_packet))
             tcp_packet += send_packet
-            s.send(tcp_packet, 0)
+            s.sendall(tcp_packet)
+
             recv_packet = s.recv(0xffff + 2, 0)
             if dump:
                 print self.hexdump(recv_packet)
-            return ndr.ndr_unpack(dns.name_packet, recv_packet[2:])
+            response = ndr.ndr_unpack(dns.name_packet, recv_packet[2:])
+
         finally:
-                if s is not None:
-                    s.close()
+            if s is not None:
+                s.close()
+
+        # unpacking and packing again should produce same bytestream
+        my_packet = ndr.ndr_pack(response)
+        self.assertEquals(my_packet, recv_packet[2:])
+
+        return (response, recv_packet[2:])
 
     def make_txt_update(self, prefix, txt_array):
         p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
@@ -210,12 +231,21 @@ class DNSTest(TestCase):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=self.server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assertEquals(response.ancount, 1)
         self.assertEquals(response.answers[0].rdata.txt.str, txt_array)
 
+
 class TestSimpleQueries(DNSTest):
+    def setUp(self):
+        super(TestSimpleQueries, self).setUp()
+        global server, server_ip, lp, creds, timeout
+        self.server = server_name
+        self.server_ip = server_ip
+        self.lp = lp
+        self.creds = creds
+        self.timeout = timeout
 
     def test_one_a_query(self):
         "create a query packet containing one query record"
@@ -228,7 +258,7 @@ class TestSimpleQueries(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         self.assertEquals(response.ancount, 1)
@@ -246,7 +276,7 @@ class TestSimpleQueries(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         self.assertEquals(response.ancount, 1)
@@ -264,7 +294,7 @@ class TestSimpleQueries(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_tcp(p)
+        (response, response_packet) = self.dns_transaction_tcp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         self.assertEquals(response.ancount, 1)
@@ -282,7 +312,7 @@ class TestSimpleQueries(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         self.assertEquals(response.ancount, 0)
@@ -296,7 +326,7 @@ class TestSimpleQueries(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXDOMAIN)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         self.assertEquals(response.ancount, 0)
@@ -316,7 +346,7 @@ class TestSimpleQueries(DNSTest):
 
         self.finish_name_packet(p, questions)
         try:
-            response = self.dns_transaction_udp(p)
+            (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
             self.assert_dns_rcode_equals(response, dns.DNS_RCODE_FORMERR)
         except socket.timeout:
             # Windows chooses not to respond to incorrectly formatted queries.
@@ -336,7 +366,7 @@ class TestSimpleQueries(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
 
         num_answers = 1
         dc_ipv6 = os.getenv('SERVER_IPV6')
@@ -362,7 +392,7 @@ class TestSimpleQueries(DNSTest):
 
         self.finish_name_packet(p, questions)
         try:
-            response = self.dns_transaction_udp(p)
+            (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
             self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NOTIMP)
         except socket.timeout:
             # Windows chooses not to respond to incorrectly formatted queries.
@@ -381,7 +411,7 @@ class TestSimpleQueries(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         # We don't get SOA records for single hosts
@@ -399,7 +429,7 @@ class TestSimpleQueries(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         self.assertEquals(response.ancount, 1)
@@ -407,6 +437,14 @@ class TestSimpleQueries(DNSTest):
 
 
 class TestDNSUpdates(DNSTest):
+    def setUp(self):
+        super(TestDNSUpdates, self).setUp()
+        global server, server_ip, lp, creds, timeout
+        self.server = server_name
+        self.server_ip = server_ip
+        self.lp = lp
+        self.creds = creds
+        self.timeout = timeout
 
     def test_two_updates(self):
         "create two update requests"
@@ -423,7 +461,7 @@ class TestDNSUpdates(DNSTest):
 
         self.finish_name_packet(p, updates)
         try:
-            response = self.dns_transaction_udp(p)
+            (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
             self.assert_dns_rcode_equals(response, dns.DNS_RCODE_FORMERR)
         except socket.timeout:
             # Windows chooses not to respond to incorrectly formatted queries.
@@ -442,7 +480,7 @@ class TestDNSUpdates(DNSTest):
         updates.append(u)
 
         self.finish_name_packet(p, updates)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NOTIMP)
 
     def test_update_prereq_with_non_null_ttl(self):
@@ -469,7 +507,7 @@ class TestDNSUpdates(DNSTest):
         p.answers = prereqs
 
         try:
-            response = self.dns_transaction_udp(p)
+            (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
             self.assert_dns_rcode_equals(response, dns.DNS_RCODE_FORMERR)
         except socket.timeout:
             # Windows chooses not to respond to incorrectly formatted queries.
@@ -501,7 +539,7 @@ class TestDNSUpdates(DNSTest):
         p.ancount = len(prereqs)
         p.answers = prereqs
 
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXRRSET)
 
     def test_update_prereq_nonexisting_name(self):
@@ -527,14 +565,14 @@ class TestDNSUpdates(DNSTest):
         p.ancount = len(prereqs)
         p.answers = prereqs
 
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXRRSET)
 
     def test_update_add_txt_record(self):
         "test adding records works"
         prefix, txt = 'textrec', ['"This is a test"']
         p = self.make_txt_update(prefix, txt)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.check_query_txt(prefix, txt)
 
@@ -566,7 +604,7 @@ class TestDNSUpdates(DNSTest):
         p.nscount = len(updates)
         p.nsrecs = updates
 
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
 
         # Now check the record is around
@@ -576,7 +614,7 @@ class TestDNSUpdates(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
 
         # Now delete the record
@@ -602,7 +640,7 @@ class TestDNSUpdates(DNSTest):
         p.nscount = len(updates)
         p.nsrecs = updates
 
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
 
         # And finally check it's gone
@@ -613,7 +651,7 @@ class TestDNSUpdates(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXDOMAIN)
 
     def test_readd_record(self):
@@ -644,7 +682,7 @@ class TestDNSUpdates(DNSTest):
         p.nscount = len(updates)
         p.nsrecs = updates
 
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
 
         # Now check the record is around
@@ -654,7 +692,7 @@ class TestDNSUpdates(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
 
         # Now delete the record
@@ -680,7 +718,7 @@ class TestDNSUpdates(DNSTest):
         p.nscount = len(updates)
         p.nsrecs = updates
 
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
 
         # check it's gone
@@ -691,7 +729,7 @@ class TestDNSUpdates(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXDOMAIN)
 
         # recreate the record
@@ -717,7 +755,7 @@ class TestDNSUpdates(DNSTest):
         p.nscount = len(updates)
         p.nsrecs = updates
 
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
 
         # Now check the record is around
@@ -727,7 +765,7 @@ class TestDNSUpdates(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
 
     def test_update_add_mx_record(self):
@@ -756,7 +794,7 @@ class TestDNSUpdates(DNSTest):
         p.nscount = len(updates)
         p.nsrecs = updates
 
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
 
         p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
@@ -767,7 +805,7 @@ class TestDNSUpdates(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assertEqual(response.ancount, 1)
         ans = response.answers[0]
@@ -795,22 +833,22 @@ class TestComplexQueries(DNSTest):
         updates = [r]
         p.nscount = 1
         p.nsrecs = updates
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
 
     def setUp(self):
         super(TestComplexQueries, self).setUp()
 
-    def tearDown(self):
-        super(TestComplexQueries, self).tearDown()
+        global server, server_ip, lp, creds, timeout
+        self.server = server_name
+        self.server_ip = server_ip
+        self.lp = lp
+        self.creds = creds
+        self.timeout = timeout
 
     def test_one_a_query(self):
         "create a query packet containing one query record"
 
-        name = "cname_test.%s" % self.get_dns_domain()
-        rdata = "%s.%s" % (self.server, self.get_dns_domain())
-        self.make_dns_update(name, rdata, dns.DNS_QTYPE_CNAME)
-
         try:
 
             # Create the record
@@ -828,7 +866,7 @@ class TestComplexQueries(DNSTest):
             questions.append(q)
 
             self.finish_name_packet(p, questions)
-            response = self.dns_transaction_udp(p)
+            (response, response_packet) = self.dns_transaction_udp(p, host=self.server_ip)
             self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
             self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
             self.assertEquals(response.ancount, 2)
@@ -862,7 +900,7 @@ class TestComplexQueries(DNSTest):
             p.nscount = len(updates)
             p.nsrecs = updates
 
-            response = self.dns_transaction_udp(p)
+            (response, response_packet) = self.dns_transaction_udp(p, host=self.server_ip)
             self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
 
     def test_cname_two_chain(self):
@@ -880,7 +918,7 @@ class TestComplexQueries(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         self.assertEquals(response.ancount, 3)
@@ -921,7 +959,7 @@ class TestComplexQueries(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
 
@@ -938,6 +976,14 @@ class TestComplexQueries(DNSTest):
         self.assertEquals(response.answers[1].rdata, name0)
 
 class TestInvalidQueries(DNSTest):
+    def setUp(self):
+        super(TestInvalidQueries, self).setUp()
+        global server, server_ip, lp, creds, timeout
+        self.server = server_name
+        self.server_ip = server_ip
+        self.lp = lp
+        self.creds = creds
+        self.timeout = timeout
 
     def test_one_a_query(self):
         "send 0 bytes follows by create a query packet containing one query record"
@@ -960,7 +1006,7 @@ class TestInvalidQueries(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=self.server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         self.assertEquals(response.ancount, 1)
@@ -1006,6 +1052,13 @@ class TestInvalidQueries(DNSTest):
 class TestZones(DNSTest):
     def setUp(self):
         super(TestZones, self).setUp()
+        global server, server_ip, lp, creds, timeout
+        self.server = server_name
+        self.server_ip = server_ip
+        self.lp = lp
+        self.creds = creds
+        self.timeout = timeout
+
         self.zone = "test.lan"
         self.rpc_conn = dnsserver.dnsserver("ncacn_ip_tcp:%s[sign]" % (self.server_ip),
                                             self.lp, self.creds)
@@ -1056,21 +1109,21 @@ class TestZones(DNSTest):
         questions.append(q)
         self.finish_name_packet(p, questions)
 
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         # Windows returns OK while BIND logically seems to return NXDOMAIN
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXDOMAIN)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         self.assertEquals(response.ancount, 0)
 
         self.create_zone(zone)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         self.assertEquals(response.ancount, 1)
         self.assertEquals(response.answers[0].rr_type, dns.DNS_QTYPE_SOA)
 
         self.delete_zone(zone)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXDOMAIN)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         self.assertEquals(response.ancount, 0)
@@ -1078,6 +1131,11 @@ class TestZones(DNSTest):
 class TestRPCRoundtrip(DNSTest):
     def setUp(self):
         super(TestRPCRoundtrip, self).setUp()
+        global server, server_ip, lp, creds
+        self.server = server_name
+        self.server_ip = server_ip
+        self.lp = lp
+        self.creds = creds
         self.rpc_conn = dnsserver.dnsserver("ncacn_ip_tcp:%s[sign]" % (self.server_ip),
                                             self.lp, self.creds)
 
@@ -1110,7 +1168,7 @@ class TestRPCRoundtrip(DNSTest):
         "test adding records works"
         prefix, txt = 'pad1textrec', ['"This is a test"', '', '']
         p = self.make_txt_update(prefix, txt)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.check_query_txt(prefix, txt)
         self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip,
@@ -1120,7 +1178,7 @@ class TestRPCRoundtrip(DNSTest):
 
         prefix, txt = 'pad2textrec', ['"This is a test"', '', '', 'more text']
         p = self.make_txt_update(prefix, txt)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.check_query_txt(prefix, txt)
         self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip,
@@ -1130,7 +1188,7 @@ class TestRPCRoundtrip(DNSTest):
 
         prefix, txt = 'pad3textrec', ['', '', '"This is a test"']
         p = self.make_txt_update(prefix, txt)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.check_query_txt(prefix, txt)
         self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip,
@@ -1209,7 +1267,7 @@ class TestRPCRoundtrip(DNSTest):
         "test adding records works"
         prefix, txt = 'nulltextrec', ['NULL\x00BYTE']
         p = self.make_txt_update(prefix, txt)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.check_query_txt(prefix, ['NULL'])
         self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip,
@@ -1219,7 +1277,7 @@ class TestRPCRoundtrip(DNSTest):
 
         prefix, txt = 'nulltextrec2', ['NULL\x00BYTE', 'NULL\x00BYTE']
         p = self.make_txt_update(prefix, txt)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.check_query_txt(prefix, ['NULL', 'NULL'])
         self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip,
@@ -1254,7 +1312,7 @@ class TestRPCRoundtrip(DNSTest):
         "test adding records works"
         prefix, txt = 'hextextrec', ['HIGH\xFFBYTE']
         p = self.make_txt_update(prefix, txt)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.check_query_txt(prefix, txt)
         self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip,
@@ -1289,7 +1347,7 @@ class TestRPCRoundtrip(DNSTest):
         "test adding records works"
         prefix, txt = 'slashtextrec', ['Th\\=is=is a test']
         p = self.make_txt_update(prefix, txt)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.check_query_txt(prefix, txt)
         self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip,
@@ -1329,7 +1387,7 @@ class TestRPCRoundtrip(DNSTest):
         prefix, txt = 'textrec2', ['"This is a test"',
                                    '"and this is a test, too"']
         p = self.make_txt_update(prefix, txt)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.check_query_txt(prefix, txt)
         self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip,
@@ -1368,7 +1426,7 @@ class TestRPCRoundtrip(DNSTest):
         "test adding two txt records works"
         prefix, txt = 'emptytextrec', []
         p = self.make_txt_update(prefix, txt)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.check_query_txt(prefix, txt)
         self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip,
diff --git a/python/samba/tests/dns_tkey.py b/python/samba/tests/dns_tkey.py
index f424e079eb8..21f5d560dd2 100644
--- a/python/samba/tests/dns_tkey.py
+++ b/python/samba/tests/dns_tkey.py
@@ -29,6 +29,7 @@ from samba import credentials
 from samba.dcerpc import dns, dnsp
 from samba.tests.subunitrun import SubunitOptions, TestProgram
 from samba import gensec, tests
+from samba.tests import TestCase
 
 parser = optparse.OptionParser("dns.py <server name> <server ip> [options]")
 sambaopts = options.SambaOptions(parser)
@@ -56,21 +57,11 @@ server_name = args[0]
 server_ip = args[1]
 
 
-class DNSTest(tests.TestCase):
+class DNSTest(TestCase):
+
     def setUp(self):
         super(DNSTest, self).setUp()
-        self.server = server_name
-        self.server_ip = server_ip
-        self.settings = {}
-        self.settings["lp_ctx"] = self.lp_ctx = tests.env_loadparm()
-        self.settings["target_hostname"] = self.server
-
-        self.creds = credentials.Credentials()
-        self.creds.guess(self.lp_ctx)
-        self.creds.set_username(tests.env_get_var_value('USERNAME'))
-        self.creds.set_password(tests.env_get_var_value('PASSWORD'))
-        self.creds.set_kerberos_state(credentials.MUST_USE_KERBEROS)
-        self.newrecname = "tkeytsig.%s" % self.get_dns_domain()
+        self.timeout = None
 
     def errstr(self, errcode):
         "Return a readable error code"
@@ -150,9 +141,11 @@ class DNSTest(tests.TestCase):
         return self.creds.get_realm().lower()
 
     def dns_transaction_udp(self, packet, host,
-                            dump=False, timeout=timeout):
+                            dump=False, timeout=None):
         "send a DNS query and read the reply"
         s = None
+        if timeout is None:
+            timeout = self.timeout
         try:
             send_packet = ndr.ndr_pack(packet)
             if dump:
@@ -171,9 +164,11 @@ class DNSTest(tests.TestCase):
                 s.close()
 
     def dns_transaction_tcp(self, packet, host,
-                            dump=False, timeout=timeout):
+                            dump=False, timeout=None):
         "send a DNS query and read the reply, also return the raw packet"
         s = None
+        if timeout is None:
+            timeout = self.timeout
         try:
             send_packet = ndr.ndr_pack(packet)
             if dump:
@@ -200,6 +195,61 @@ class DNSTest(tests.TestCase):
 
         return (response, recv_packet[2:])
 
+    def make_txt_update(self, prefix, txt_array):
+        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
+        updates = []
+
+        name = self.get_dns_domain()
+        u = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
+        updates.append(u)
+        self.finish_name_packet(p, updates)
+
+        updates = []
+        r = dns.res_rec()
+        r.name = "%s.%s" % (prefix, self.get_dns_domain())
+        r.rr_type = dns.DNS_QTYPE_TXT
+        r.rr_class = dns.DNS_QCLASS_IN
+        r.ttl = 900
+        r.length = 0xffff
+        rdata = self.make_txt_record(txt_array)
+        r.rdata = rdata
+        updates.append(r)
+        p.nscount = len(updates)
+        p.nsrecs = updates
+
+        return p
+
+    def check_query_txt(self, prefix, txt_array):
+        name = "%s.%s" % (prefix, self.get_dns_domain())
+        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
+        questions = []
+
+        q = self.make_name_question(name, dns.DNS_QTYPE_TXT, dns.DNS_QCLASS_IN)
+        questions.append(q)
+
+        self.finish_name_packet(p, questions)
+        (response, response_packet) = self.dns_transaction_udp(p, host=self.server_ip)
+        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
+        self.assertEquals(response.ancount, 1)
+        self.assertEquals(response.answers[0].rdata.txt.str, txt_array)
+
+
+class DNSTKeyTest(DNSTest):
+    def setUp(self):
+        super(DNSTKeyTest, self).setUp()
+        self.server = server_name
+        self.server_ip = server_ip
+        self.settings = {}
+        self.settings["lp_ctx"] = self.lp_ctx = tests.env_loadparm()
+        self.settings["target_hostname"] = self.server
+
+        self.creds = credentials.Credentials()
+        self.creds.guess(self.lp_ctx)
+        self.creds.set_username(tests.env_get_var_value('USERNAME'))
+        self.creds.set_password(tests.env_get_var_value('PASSWORD'))
+        self.creds.set_kerberos_state(credentials.MUST_USE_KERBEROS)
+        self.newrecname = "tkeytsig.%s" % self.get_dns_domain()
+
     def tkey_trans(self):
         "Do a TKEY transaction and establish a gensec context"
 
@@ -410,7 +460,7 @@ class DNSTest(tests.TestCase):
         return p
 
 
-class TestDNSUpdates(DNSTest):
+class TestDNSUpdates(DNSTKeyTest):
     def test_tkey(self):
         "test DNS TKEY handshake"
 
-- 
2.11.0


From cc3540cda498b46aafbba7e98b3739e38046f06f Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Thu, 1 Jun 2017 13:26:37 +1200
Subject: [PATCH 18/24] selftest: Create new common base class for dns.py and
 dns_tkey.py

This will allow more DNS tests to be written in the future with less
code duplication.
---
 python/samba/tests/dns.py      | 179 +----------------
 python/samba/tests/dns_base.py | 431 +++++++++++++++++++++++++++++++++++++++++
 python/samba/tests/dns_tkey.py | 405 +-------------------------------------
 3 files changed, 435 insertions(+), 580 deletions(-)
 create mode 100644 python/samba/tests/dns_base.py

diff --git a/python/samba/tests/dns.py b/python/samba/tests/dns.py
index ac82d5051aa..1b5b64da3a4 100644
--- a/python/samba/tests/dns.py
+++ b/python/samba/tests/dns.py
@@ -22,11 +22,11 @@ import random
 import socket
 import samba.ndr as ndr
 from samba import credentials, param
-from samba.tests import TestCase
 from samba.dcerpc import dns, dnsp, dnsserver
 from samba.netcmd.dns import TXTRecord, dns_record_match, data_to_dns_record
 from samba.tests.subunitrun import SubunitOptions, TestProgram
 from samba import werror, WERRORError
+from samba.tests.dns_base import DNSTest
 import samba.getopt as options
 import optparse
 
@@ -60,183 +60,6 @@ server_name = args[0]
 server_ip = args[1]
 creds.set_krb_forwardable(credentials.NO_KRB_FORWARDABLE)
 
-class DNSTest(TestCase):
-
-    def setUp(self):
-        super(DNSTest, self).setUp()
-        self.timeout = None
-
-    def errstr(self, errcode):
-        "Return a readable error code"
-        string_codes = [
-            "OK",
-            "FORMERR",
-            "SERVFAIL",
-            "NXDOMAIN",
-            "NOTIMP",
-            "REFUSED",
-            "YXDOMAIN",
-            "YXRRSET",
-            "NXRRSET",
-            "NOTAUTH",
-            "NOTZONE",
-            "0x0B",
-            "0x0C",
-            "0x0D",
-            "0x0E",
-            "0x0F",
-            "BADSIG",
-            "BADKEY"
-        ]
-
-        return string_codes[errcode]
-
-    def assert_rcode_equals(self, rcode, expected):
-        "Helper function to check return code"
-        self.assertEquals(rcode, expected, "Expected RCODE %s, got %s" %
-                          (self.errstr(expected), self.errstr(rcode)))
-
-    def assert_dns_rcode_equals(self, packet, rcode):
-        "Helper function to check return code"
-        p_errcode = packet.operation & 0x000F
-        self.assertEquals(p_errcode, rcode, "Expected RCODE %s, got %s" %
-                          (self.errstr(rcode), self.errstr(p_errcode)))
-
-    def assert_dns_opcode_equals(self, packet, opcode):
-        "Helper function to check opcode"
-        p_opcode = packet.operation & 0x7800
-        self.assertEquals(p_opcode, opcode, "Expected OPCODE %s, got %s" %
-                          (opcode, p_opcode))
-
-    def make_name_packet(self, opcode, qid=None):
-        "Helper creating a dns.name_packet"
-        p = dns.name_packet()
-        if qid is None:
-            p.id = random.randint(0x0, 0xff00)
-        p.operation = opcode
-        p.questions = []
-        p.additional = []
-        return p
-
-    def finish_name_packet(self, packet, questions):
-        "Helper to finalize a dns.name_packet"
-        packet.qdcount = len(questions)
-        packet.questions = questions
-
-    def make_name_question(self, name, qtype, qclass):
-        "Helper creating a dns.name_question"
-        q = dns.name_question()
-        q.name = name
-        q.question_type = qtype
-        q.question_class = qclass
-        return q
-
-    def make_txt_record(self, records):
-        rdata_txt = dns.txt_record()
-        s_list = dnsp.string_list()
-        s_list.count = len(records)
-        s_list.str = records
-        rdata_txt.txt = s_list
-        return rdata_txt
-
-    def get_dns_domain(self):
-        "Helper to get dns domain"
-        return self.creds.get_realm().lower()
-
-    def dns_transaction_udp(self, packet, host,
-                            dump=False, timeout=None):
-        "send a DNS query and read the reply"
-        s = None
-        if timeout is None:
-            timeout = self.timeout
-        try:
-            send_packet = ndr.ndr_pack(packet)
-            if dump:
-                print self.hexdump(send_packet)
-            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
-            s.settimeout(timeout)
-            s.connect((host, 53))
-            s.sendall(send_packet, 0)
-            recv_packet = s.recv(2048, 0)
-            if dump:
-                print self.hexdump(recv_packet)
-            response = ndr.ndr_unpack(dns.name_packet, recv_packet)
-            return (response, recv_packet)
-        finally:
-            if s is not None:
-                s.close()
-
-    def dns_transaction_tcp(self, packet, host,
-                            dump=False, timeout=None):
-        "send a DNS query and read the reply, also return the raw packet"
-        s = None
-        if timeout is None:
-            timeout = self.timeout
-        try:
-            send_packet = ndr.ndr_pack(packet)
-            if dump:
-                print self.hexdump(send_packet)
-            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
-            s.settimeout(timeout)
-            s.connect((host, 53))
-            tcp_packet = struct.pack('!H', len(send_packet))
-            tcp_packet += send_packet
-            s.sendall(tcp_packet)
-
-            recv_packet = s.recv(0xffff + 2, 0)
-            if dump:
-                print self.hexdump(recv_packet)
-            response = ndr.ndr_unpack(dns.name_packet, recv_packet[2:])
-
-        finally:
-            if s is not None:
-                s.close()
-
-        # unpacking and packing again should produce same bytestream
-        my_packet = ndr.ndr_pack(response)
-        self.assertEquals(my_packet, recv_packet[2:])
-
-        return (response, recv_packet[2:])
-
-    def make_txt_update(self, prefix, txt_array):
-        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
-        updates = []
-
-        name = self.get_dns_domain()
-        u = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
-        updates.append(u)
-        self.finish_name_packet(p, updates)
-
-        updates = []
-        r = dns.res_rec()
-        r.name = "%s.%s" % (prefix, self.get_dns_domain())
-        r.rr_type = dns.DNS_QTYPE_TXT
-        r.rr_class = dns.DNS_QCLASS_IN
-        r.ttl = 900
-        r.length = 0xffff
-        rdata = self.make_txt_record(txt_array)
-        r.rdata = rdata
-        updates.append(r)
-        p.nscount = len(updates)
-        p.nsrecs = updates
-
-        return p
-
-    def check_query_txt(self, prefix, txt_array):
-        name = "%s.%s" % (prefix, self.get_dns_domain())
-        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
-        questions = []
-
-        q = self.make_name_question(name, dns.DNS_QTYPE_TXT, dns.DNS_QCLASS_IN)
-        questions.append(q)
-
-        self.finish_name_packet(p, questions)
-        (response, response_packet) = self.dns_transaction_udp(p, host=self.server_ip)
-        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
-        self.assertEquals(response.ancount, 1)
-        self.assertEquals(response.answers[0].rdata.txt.str, txt_array)
-
-
 class TestSimpleQueries(DNSTest):
     def setUp(self):
         super(TestSimpleQueries, self).setUp()
diff --git a/python/samba/tests/dns_base.py b/python/samba/tests/dns_base.py
new file mode 100644
index 00000000000..3d5aa8e25b0
--- /dev/null
+++ b/python/samba/tests/dns_base.py
@@ -0,0 +1,431 @@
+# Unix SMB/CIFS implementation.
+# Copyright (C) Kai Blin  <kai at samba.org> 2011
+# Copyright (C) Ralph Boehme  <slow at samba.org> 2016
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation; either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+#
+
+from samba.tests import TestCase
+from samba.dcerpc import dns, dnsp
+from samba import gensec, tests
+from samba import credentials
+import struct
+import samba.ndr as ndr
+import random
+import socket
+import uuid
+import time
+
+class DNSTest(TestCase):
+    """Common DNS selftest helpers. Subclasses must provide self.creds
+    (read by get_dns_domain) and self.server_ip (read by check_query_txt)."""
+    def setUp(self):
+        super(DNSTest, self).setUp()
+        # None leaves sockets in blocking mode (see socket.settimeout)
+        self.timeout = None
+
+    def errstr(self, errcode):
+        "Return a readable name for a DNS RCODE (hex for unknown codes)"
+        string_codes = [
+            "OK", "FORMERR", "SERVFAIL", "NXDOMAIN", "NOTIMP", "REFUSED",
+            "YXDOMAIN", "YXRRSET", "NXRRSET", "NOTAUTH", "NOTZONE",
+            "0x0B", "0x0C", "0x0D", "0x0E", "0x0F",
+            "BADSIG", "BADKEY"
+        ]
+        if 0 <= errcode < len(string_codes):
+            return string_codes[errcode]
+        # Don't let an IndexError hide the assertion being reported
+        return "0x%02X" % errcode
+
+    def assert_rcode_equals(self, rcode, expected):
+        "Helper function to check return code"
+        self.assertEquals(rcode, expected, "Expected RCODE %s, got %s" %
+                          (self.errstr(expected), self.errstr(rcode)))
+
+    def assert_dns_rcode_equals(self, packet, rcode):
+        "Check the RCODE (low 4 bits of packet.operation) of a packet"
+        p_errcode = packet.operation & 0x000F
+        self.assertEquals(p_errcode, rcode, "Expected RCODE %s, got %s" %
+                          (self.errstr(rcode), self.errstr(p_errcode)))
+
+    def assert_dns_opcode_equals(self, packet, opcode):
+        "Check the OPCODE bits (mask 0x7800) of packet.operation"
+        p_opcode = packet.operation & 0x7800
+        self.assertEquals(p_opcode, opcode, "Expected OPCODE %s, got %s" %
+                          (opcode, p_opcode))
+
+    def make_name_packet(self, opcode, qid=None):
+        """Create an empty dns.name_packet with the given opcode.
+
+        :param qid: packet id to use; a random id is chosen when None
+        """
+        p = dns.name_packet()
+        if qid is None:
+            qid = random.randint(0x0, 0xff00)
+        p.id = qid
+        p.operation = opcode
+        p.questions = []
+        p.additional = []
+        return p
+
+    def finish_name_packet(self, packet, questions):
+        "Helper to finalize a dns.name_packet"
+        packet.qdcount = len(questions)
+        packet.questions = questions
+
+    def make_name_question(self, name, qtype, qclass):
+        "Helper creating a dns.name_question"
+        q = dns.name_question()
+        q.name = name
+        q.question_type = qtype
+        q.question_class = qclass
+        return q
+
+    def make_txt_record(self, records):
+        "Build a dns.txt_record whose string list is `records`"
+        rdata_txt = dns.txt_record()
+        s_list = dnsp.string_list()
+        s_list.count = len(records)
+        s_list.str = records
+        rdata_txt.txt = s_list
+        return rdata_txt
+
+    def get_dns_domain(self):
+        "Return the realm of self.creds, lower-cased, as the DNS domain"
+        return self.creds.get_realm().lower()
+
+    def dns_transaction_udp(self, packet, host,
+                            dump=False, timeout=None):
+        "send a DNS query and read the reply"
+        s = None
+        if timeout is None:
+            timeout = self.timeout
+        try:
+            send_packet = ndr.ndr_pack(packet)
+            if dump:
+                print self.hexdump(send_packet)
+            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
+            s.settimeout(timeout)
+            s.connect((host, 53))
+            s.sendall(send_packet, 0)
+            recv_packet = s.recv(2048, 0)
+            if dump:
+                print self.hexdump(recv_packet)
+            response = ndr.ndr_unpack(dns.name_packet, recv_packet)
+            return (response, recv_packet)
+        finally:
+            if s is not None:
+                s.close()
+
+    def dns_transaction_tcp(self, packet, host,
+                            dump=False, timeout=None):
+        "send a DNS query and read the reply, also return the raw packet"
+        s = None
+        if timeout is None:
+            timeout = self.timeout
+        try:
+            send_packet = ndr.ndr_pack(packet)
+            if dump:
+                print self.hexdump(send_packet)
+            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
+            s.settimeout(timeout)
+            s.connect((host, 53))
+            tcp_packet = struct.pack('!H', len(send_packet))
+            tcp_packet += send_packet
+            s.sendall(tcp_packet)
+
+            # TCP is a byte stream and recv() may return a partial reply:
+            # read until the 2-byte length prefix and that many bytes arrive
+            recv_packet = s.recv(0xffff + 2, 0)
+            while (len(recv_packet) < 2 or len(recv_packet) <
+                   2 + struct.unpack('!H', recv_packet[:2])[0]):
+                chunk = s.recv(0xffff + 2, 0)
+                if not chunk:
+                    break  # peer closed early; ndr_unpack below will fail
+                recv_packet += chunk
+            if dump:
+                print self.hexdump(recv_packet)
+            response = ndr.ndr_unpack(dns.name_packet, recv_packet[2:])
+        finally:
+            if s is not None:
+                s.close()
+
+        # unpacking and packing again should produce same bytestream
+        my_packet = ndr.ndr_pack(response)
+        self.assertEquals(my_packet, recv_packet[2:])
+
+        return (response, recv_packet[2:])
+
+    def make_txt_update(self, prefix, txt_array):
+        "Build an UPDATE adding TXT txt_array at <prefix>.<dns domain>"
+        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
+        updates = []
+        name = self.get_dns_domain()
+        u = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
+        updates.append(u)
+        self.finish_name_packet(p, updates)
+
+        updates = []
+        r = dns.res_rec()
+        r.name = "%s.%s" % (prefix, self.get_dns_domain())
+        r.rr_type = dns.DNS_QTYPE_TXT
+        r.rr_class = dns.DNS_QCLASS_IN
+        r.ttl = 900
+        r.length = 0xffff
+        rdata = self.make_txt_record(txt_array)
+        r.rdata = rdata
+        updates.append(r)
+        p.nscount = len(updates)
+        p.nsrecs = updates
+        return p
+
+    def check_query_txt(self, prefix, txt_array):
+        "Assert <prefix>.<dns domain> has exactly one TXT answer, txt_array"
+        name = "%s.%s" % (prefix, self.get_dns_domain())
+        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
+        questions = []
+        q = self.make_name_question(name, dns.DNS_QTYPE_TXT, dns.DNS_QCLASS_IN)
+        questions.append(q)
+        self.finish_name_packet(p, questions)
+        (response, response_packet) = self.dns_transaction_udp(p, host=self.server_ip)
+        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
+        self.assertEquals(response.ancount, 1)
+        self.assertEquals(response.answers[0].rdata.txt.str, txt_array)
+
+
+class DNSTKeyTest(DNSTest):
+    """DNS tests that need a GSS-TSIG key negotiated with TKEY.
+
+    Subclasses must set self.server and self.server_ip before calling
+    setUp(), since setUp() reads self.server.
+    """
+    def setUp(self):
+        super(DNSTKeyTest, self).setUp()
+        self.settings = {}
+        self.settings["lp_ctx"] = self.lp_ctx = tests.env_loadparm()
+        self.settings["target_hostname"] = self.server
+
+        self.creds = credentials.Credentials()
+        self.creds.guess(self.lp_ctx)
+        self.creds.set_username(tests.env_get_var_value('USERNAME'))
+        self.creds.set_password(tests.env_get_var_value('PASSWORD'))
+        self.creds.set_kerberos_state(credentials.MUST_USE_KERBEROS)
+        self.newrecname = "tkeytsig.%s" % self.get_dns_domain()
+
+    def tkey_trans(self, creds=None):
+        """Do a TKEY transaction and establish a gensec context.
+
+        :param creds: credentials to use; defaults to self.creds
+
+        Sets self.key_name and self.g (the gensec context that
+        sign_packet() and verify_packet() use afterwards).
+        """
+        if creds is None:
+            creds = self.creds
+
+        self.key_name = "%s.%s" % (uuid.uuid4(), self.get_dns_domain())
+
+        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
+        q = self.make_name_question(self.key_name,
+                                    dns.DNS_QTYPE_TKEY,
+                                    dns.DNS_QCLASS_IN)
+        questions = []
+        questions.append(q)
+        self.finish_name_packet(p, questions)
+
+        r = dns.res_rec()
+        r.name = self.key_name
+        r.rr_type = dns.DNS_QTYPE_TKEY
+        r.rr_class = dns.DNS_QCLASS_IN
+        r.ttl = 0
+        r.length = 0xffff
+        rdata = dns.tkey_record()
+        rdata.algorithm = "gss-tsig"
+        # Read the clock once so inception/expiration are exactly 1h apart
+        now = int(time.time())
+        rdata.inception = now
+        rdata.expiration = now + 60*60
+        rdata.mode = dns.DNS_TKEY_MODE_GSSAPI
+        rdata.error = 0
+        rdata.other_size = 0
+
+        self.g = gensec.Security.start_client(self.settings)
+        self.g.set_credentials(creds)
+        self.g.set_target_service("dns")
+        self.g.set_target_hostname(self.server)
+        self.g.want_feature(gensec.FEATURE_SIGN)
+        self.g.start_mech_by_name("spnego")
+
+        # Start with an empty input; the output is the client's first
+        # token, sent to the server as the TKEY key data.
+        (finished, client_to_server) = self.g.update("")
+        self.assertFalse(finished)
+
+        data = [ord(x) for x in list(client_to_server)]
+        rdata.key_data = data
+        rdata.key_size = len(data)
+        r.rdata = rdata
+
+        additional = [r]
+        p.arcount = 1
+        p.additional = additional
+
+        (response, response_packet) = self.dns_transaction_tcp(p, self.server_ip)
+        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
+
+        tkey_record = response.answers[0].rdata
+        data = [chr(x) for x in tkey_record.key_data]
+        server_to_client = ''.join(data)
+        (finished, client_to_server) = self.g.update(server_to_client)
+        self.assertTrue(finished)
+
+        self.verify_packet(response, response_packet)
+
+    def verify_packet(self, response, response_packet, request_mac=""):
+        """Verify the TSIG MAC of a signed reply with self.g.
+
+        :param response: unpacked reply; additional[0] must be the TSIG RR
+        :param response_packet: raw wire bytes of the same reply
+        :param request_mac: prefix prepended to the data being verified
+        """
+        self.assertEqual(response.additional[0].rr_type, dns.DNS_QTYPE_TSIG)
+
+        tsig_record = response.additional[0].rdata
+        mac = ''.join([chr(x) for x in tsig_record.mac])
+
+        # Cut off tsig record from dns response packet for MAC verification
+        # and reset additional record count. Assumes an uncompressed owner
+        # name (len + 2 bytes) and the 10-byte type/class/ttl/rdlength fields.
+        key_name_len = len(self.key_name) + 2
+        tsig_record_len = len(ndr.ndr_pack(tsig_record)) + key_name_len + 10
+
+        response_packet_list = list(response_packet)
+        del response_packet_list[-tsig_record_len:]
+        # Byte 11 is the low byte of ARCOUNT; assumes TSIG was the only
+        # additional record
+        response_packet_list[11] = chr(0)
+        response_packet_wo_tsig = ''.join(response_packet_list)
+
+        fake_tsig_packet = self._fake_tsig(self.key_name,
+                                           tsig_record.time_prefix,
+                                           tsig_record.time,
+                                           tsig_record.algorithm_name,
+                                           tsig_record.fudge)
+        data = request_mac + response_packet_wo_tsig + fake_tsig_packet
+        self.g.check_packet(data, data, mac)
+
+    def _fake_tsig(self, key_name, time_prefix, tsig_time, algorithm, fudge):
+        "Return the packed TSIG variables appended to the data that is MACed"
+        fake_tsig = dns.fake_tsig_rec()
+        fake_tsig.name = key_name
+        fake_tsig.rr_class = dns.DNS_QCLASS_ANY
+        fake_tsig.ttl = 0
+        fake_tsig.time_prefix = time_prefix
+        fake_tsig.time = tsig_time
+        fake_tsig.algorithm_name = algorithm
+        fake_tsig.fudge = fudge
+        fake_tsig.error = 0
+        fake_tsig.other_size = 0
+        return ndr.ndr_pack(fake_tsig)
+
+    def _add_tsig_record(self, packet, key_name, tsig_time, mac_list):
+        "Attach a gss-tsig TSIG record carrying mac_list to packet"
+        rdata = dns.tsig_record()
+        rdata.algorithm_name = "gss-tsig"
+        rdata.time_prefix = 0
+        rdata.time = tsig_time
+        rdata.fudge = 300
+        rdata.original_id = packet.id
+        rdata.error = 0
+        rdata.other_size = 0
+        rdata.mac = mac_list
+        rdata.mac_size = len(mac_list)
+
+        r = dns.res_rec()
+        r.name = key_name
+        r.rr_type = dns.DNS_QTYPE_TSIG
+        r.rr_class = dns.DNS_QCLASS_ANY
+        r.ttl = 0
+        r.length = 0xffff
+        r.rdata = rdata
+
+        packet.additional = [r]
+        packet.arcount = 1
+
+    def sign_packet(self, packet, key_name):
+        """Sign a packet with self.g, calculate a MAC and add TSIG record.
+
+        Returns the raw MAC string.
+        """
+        packet_data = ndr.ndr_pack(packet)
+        now = int(time.time())
+        fake_tsig_packet = self._fake_tsig(key_name, 0, now, "gss-tsig", 300)
+
+        data = packet_data + fake_tsig_packet
+        mac = self.g.sign_packet(data, data)
+        mac_list = [ord(x) for x in list(mac)]
+
+        self._add_tsig_record(packet, key_name, now, mac_list)
+        return mac
+
+    def bad_sign_packet(self, packet, key_name):
+        '''Add a TSIG record whose MAC is the fixed string "badmac",
+        i.e. a signature that cannot verify'''
+        mac_list = [ord(x) for x in list("badmac")]
+        self._add_tsig_record(packet, key_name, int(time.time()), mac_list)
+
+    def search_record(self, name):
+        "Query the TXT records of name over UDP and return the RCODE"
+        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
+        questions = []
+        q = self.make_name_question(name, dns.DNS_QTYPE_TXT, dns.DNS_QCLASS_IN)
+        questions.append(q)
+        self.finish_name_packet(p, questions)
+        (response, response_packet) = self.dns_transaction_udp(p, self.server_ip)
+        return response.operation & 0x000F
+
+    def make_update_request(self, delete=False):
+        """Create a DNS update request for a TXT record at self.newrecname.
+
+        :param delete: build a deletion (class NONE, ttl 0) rather than an
+                       addition (class IN, ttl 900)
+        """
+        rr_class = dns.DNS_QCLASS_IN
+        ttl = 900
+
+        if delete:
+            rr_class = dns.DNS_QCLASS_NONE
+            ttl = 0
+
+        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
+        q = self.make_name_question(self.get_dns_domain(),
+                                    dns.DNS_QTYPE_SOA,
+                                    dns.DNS_QCLASS_IN)
+        questions = []
+        questions.append(q)
+        self.finish_name_packet(p, questions)
+
+        updates = []
+        r = dns.res_rec()
+        r.name = self.newrecname
+        r.rr_type = dns.DNS_QTYPE_TXT
+        r.rr_class = rr_class
+        r.ttl = ttl
+        r.length = 0xffff
+        rdata = self.make_txt_record(['"This is a test"'])
+        r.rdata = rdata
+        updates.append(r)
+        p.nscount = len(updates)
+        p.nsrecs = updates
+        return p
diff --git a/python/samba/tests/dns_tkey.py b/python/samba/tests/dns_tkey.py
index 21f5d560dd2..4afa7fefba5 100644
--- a/python/samba/tests/dns_tkey.py
+++ b/python/samba/tests/dns_tkey.py
@@ -28,8 +28,7 @@ import samba.getopt as options
 from samba import credentials
 from samba.dcerpc import dns, dnsp
 from samba.tests.subunitrun import SubunitOptions, TestProgram
-from samba import gensec, tests
-from samba.tests import TestCase
+from samba.tests.dns_base import DNSTKeyTest
 
 parser = optparse.OptionParser("dns.py <server name> <server ip> [options]")
 sambaopts = options.SambaOptions(parser)
@@ -57,410 +56,12 @@ server_name = args[0]
 server_ip = args[1]
 
 
-class DNSTest(TestCase):
-
-    def setUp(self):
-        super(DNSTest, self).setUp()
-        self.timeout = None
-
-    def errstr(self, errcode):
-        "Return a readable error code"
-        string_codes = [
-            "OK",
-            "FORMERR",
-            "SERVFAIL",
-            "NXDOMAIN",
-            "NOTIMP",
-            "REFUSED",
-            "YXDOMAIN",
-            "YXRRSET",
-            "NXRRSET",
-            "NOTAUTH",
-            "NOTZONE",
-            "0x0B",
-            "0x0C",
-            "0x0D",
-            "0x0E",
-            "0x0F",
-            "BADSIG",
-            "BADKEY"
-        ]
-
-        return string_codes[errcode]
-
-    def assert_rcode_equals(self, rcode, expected):
-        "Helper function to check return code"
-        self.assertEquals(rcode, expected, "Expected RCODE %s, got %s" %
-                          (self.errstr(expected), self.errstr(rcode)))
-
-    def assert_dns_rcode_equals(self, packet, rcode):
-        "Helper function to check return code"
-        p_errcode = packet.operation & 0x000F
-        self.assertEquals(p_errcode, rcode, "Expected RCODE %s, got %s" %
-                          (self.errstr(rcode), self.errstr(p_errcode)))
-
-    def assert_dns_opcode_equals(self, packet, opcode):
-        "Helper function to check opcode"
-        p_opcode = packet.operation & 0x7800
-        self.assertEquals(p_opcode, opcode, "Expected OPCODE %s, got %s" %
-                          (opcode, p_opcode))
-
-    def make_name_packet(self, opcode, qid=None):
-        "Helper creating a dns.name_packet"
-        p = dns.name_packet()
-        if qid is None:
-            p.id = random.randint(0x0, 0xff00)
-        p.operation = opcode
-        p.questions = []
-        p.additional = []
-        return p
-
-    def finish_name_packet(self, packet, questions):
-        "Helper to finalize a dns.name_packet"
-        packet.qdcount = len(questions)
-        packet.questions = questions
-
-    def make_name_question(self, name, qtype, qclass):
-        "Helper creating a dns.name_question"
-        q = dns.name_question()
-        q.name = name
-        q.question_type = qtype
-        q.question_class = qclass
-        return q
-
-    def make_txt_record(self, records):
-        rdata_txt = dns.txt_record()
-        s_list = dnsp.string_list()
-        s_list.count = len(records)
-        s_list.str = records
-        rdata_txt.txt = s_list
-        return rdata_txt
-
-    def get_dns_domain(self):
-        "Helper to get dns domain"
-        return self.creds.get_realm().lower()
-
-    def dns_transaction_udp(self, packet, host,
-                            dump=False, timeout=None):
-        "send a DNS query and read the reply"
-        s = None
-        if timeout is None:
-            timeout = self.timeout
-        try:
-            send_packet = ndr.ndr_pack(packet)
-            if dump:
-                print self.hexdump(send_packet)
-            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
-            s.settimeout(timeout)
-            s.connect((host, 53))
-            s.sendall(send_packet, 0)
-            recv_packet = s.recv(2048, 0)
-            if dump:
-                print self.hexdump(recv_packet)
-            response = ndr.ndr_unpack(dns.name_packet, recv_packet)
-            return (response, recv_packet)
-        finally:
-            if s is not None:
-                s.close()
-
-    def dns_transaction_tcp(self, packet, host,
-                            dump=False, timeout=None):
-        "send a DNS query and read the reply, also return the raw packet"
-        s = None
-        if timeout is None:
-            timeout = self.timeout
-        try:
-            send_packet = ndr.ndr_pack(packet)
-            if dump:
-                print self.hexdump(send_packet)
-            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
-            s.settimeout(timeout)
-            s.connect((host, 53))
-            tcp_packet = struct.pack('!H', len(send_packet))
-            tcp_packet += send_packet
-            s.sendall(tcp_packet)
-
-            recv_packet = s.recv(0xffff + 2, 0)
-            if dump:
-                print self.hexdump(recv_packet)
-            response = ndr.ndr_unpack(dns.name_packet, recv_packet[2:])
-
-        finally:
-            if s is not None:
-                s.close()
-
-        # unpacking and packing again should produce same bytestream
-        my_packet = ndr.ndr_pack(response)
-        self.assertEquals(my_packet, recv_packet[2:])
-
-        return (response, recv_packet[2:])
-
-    def make_txt_update(self, prefix, txt_array):
-        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
-        updates = []
-
-        name = self.get_dns_domain()
-        u = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
-        updates.append(u)
-        self.finish_name_packet(p, updates)
-
-        updates = []
-        r = dns.res_rec()
-        r.name = "%s.%s" % (prefix, self.get_dns_domain())
-        r.rr_type = dns.DNS_QTYPE_TXT
-        r.rr_class = dns.DNS_QCLASS_IN
-        r.ttl = 900
-        r.length = 0xffff
-        rdata = self.make_txt_record(txt_array)
-        r.rdata = rdata
-        updates.append(r)
-        p.nscount = len(updates)
-        p.nsrecs = updates
-
-        return p
-
-    def check_query_txt(self, prefix, txt_array):
-        name = "%s.%s" % (prefix, self.get_dns_domain())
-        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
-        questions = []
-
-        q = self.make_name_question(name, dns.DNS_QTYPE_TXT, dns.DNS_QCLASS_IN)
-        questions.append(q)
-
-        self.finish_name_packet(p, questions)
-        (response, response_packet) = self.dns_transaction_udp(p, host=self.server_ip)
-        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
-        self.assertEquals(response.ancount, 1)
-        self.assertEquals(response.answers[0].rdata.txt.str, txt_array)
-
-
-class DNSTKeyTest(DNSTest):
+class TestDNSUpdates(DNSTKeyTest):
     def setUp(self):
-        super(DNSTKeyTest, self).setUp()
         self.server = server_name
         self.server_ip = server_ip
-        self.settings = {}
-        self.settings["lp_ctx"] = self.lp_ctx = tests.env_loadparm()
-        self.settings["target_hostname"] = self.server
-
-        self.creds = credentials.Credentials()
-        self.creds.guess(self.lp_ctx)
-        self.creds.set_username(tests.env_get_var_value('USERNAME'))
-        self.creds.set_password(tests.env_get_var_value('PASSWORD'))
-        self.creds.set_kerberos_state(credentials.MUST_USE_KERBEROS)
-        self.newrecname = "tkeytsig.%s" % self.get_dns_domain()
-
-    def tkey_trans(self):
-        "Do a TKEY transaction and establish a gensec context"
-
-        self.key_name = "%s.%s" % (uuid.uuid4(), self.get_dns_domain())
-
-        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
-        q = self.make_name_question(self.key_name,
-                                    dns.DNS_QTYPE_TKEY,
-                                    dns.DNS_QCLASS_IN)
-        questions = []
-        questions.append(q)
-        self.finish_name_packet(p, questions)
-
-        r = dns.res_rec()
-        r.name = self.key_name
-        r.rr_type = dns.DNS_QTYPE_TKEY
-        r.rr_class = dns.DNS_QCLASS_IN
-        r.ttl = 0
-        r.length = 0xffff
-        rdata = dns.tkey_record()
-        rdata.algorithm = "gss-tsig"
-        rdata.inception = int(time.time())
-        rdata.expiration = int(time.time()) + 60*60
-        rdata.mode = dns.DNS_TKEY_MODE_GSSAPI
-        rdata.error = 0
-        rdata.other_size = 0
-
-        self.g = gensec.Security.start_client(self.settings)
-        self.g.set_credentials(self.creds)
-        self.g.set_target_service("dns")
-        self.g.set_target_hostname(self.server)
-        self.g.want_feature(gensec.FEATURE_SIGN)
-        self.g.start_mech_by_name("spnego")
-
-        finished = False
-        client_to_server = ""
-
-        (finished, server_to_client) = self.g.update(client_to_server)
-        self.assertFalse(finished)
-
-        data = [ord(x) for x in list(server_to_client)]
-        rdata.key_data = data
-        rdata.key_size = len(data)
-        r.rdata = rdata
-
-        additional = [r]
-        p.arcount = 1
-        p.additional = additional
-
-        (response, response_packet) = self.dns_transaction_tcp(p, self.server_ip)
-        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
-
-        tkey_record = response.answers[0].rdata
-        data = [chr(x) for x in tkey_record.key_data]
-        server_to_client = ''.join(data)
-        (finished, client_to_server) = self.g.update(server_to_client)
-        self.assertTrue(finished)
-
-        self.verify_packet(response, response_packet)
-
-    def verify_packet(self, response, response_packet, request_mac=""):
-        self.assertEqual(response.additional[0].rr_type, dns.DNS_QTYPE_TSIG)
+        super(TestDNSUpdates, self).setUp()
 
-        tsig_record = response.additional[0].rdata
-        mac = ''.join([chr(x) for x in tsig_record.mac])
-
-        # Cut off tsig record from dns response packet for MAC verification
-        # and reset additional record count.
-        key_name_len = len(self.key_name) + 2
-        tsig_record_len = len(ndr.ndr_pack(tsig_record)) + key_name_len + 10
-
-        response_packet_list = list(response_packet)
-        del response_packet_list[-tsig_record_len:]
-        response_packet_list[11] = chr(0)
-        response_packet_wo_tsig = ''.join(response_packet_list)
-
-        fake_tsig = dns.fake_tsig_rec()
-        fake_tsig.name = self.key_name
-        fake_tsig.rr_class = dns.DNS_QCLASS_ANY
-        fake_tsig.ttl = 0
-        fake_tsig.time_prefix = tsig_record.time_prefix
-        fake_tsig.time = tsig_record.time
-        fake_tsig.algorithm_name = tsig_record.algorithm_name
-        fake_tsig.fudge = tsig_record.fudge
-        fake_tsig.error = 0
-        fake_tsig.other_size = 0
-        fake_tsig_packet = ndr.ndr_pack(fake_tsig)
-
-        data = request_mac + response_packet_wo_tsig + fake_tsig_packet
-        self.g.check_packet(data, data, mac)
-
-    def sign_packet(self, packet, key_name):
-        "Sign a packet, calculate a MAC and add TSIG record"
-        packet_data = ndr.ndr_pack(packet)
-
-        fake_tsig = dns.fake_tsig_rec()
-        fake_tsig.name = key_name
-        fake_tsig.rr_class = dns.DNS_QCLASS_ANY
-        fake_tsig.ttl = 0
-        fake_tsig.time_prefix = 0
-        fake_tsig.time = int(time.time())
-        fake_tsig.algorithm_name = "gss-tsig"
-        fake_tsig.fudge = 300
-        fake_tsig.error = 0
-        fake_tsig.other_size = 0
-        fake_tsig_packet = ndr.ndr_pack(fake_tsig)
-
-        data = packet_data + fake_tsig_packet
-        mac = self.g.sign_packet(data, data)
-        mac_list = [ord(x) for x in list(mac)]
-
-        rdata = dns.tsig_record()
-        rdata.algorithm_name = "gss-tsig"
-        rdata.time_prefix = 0
-        rdata.time = fake_tsig.time
-        rdata.fudge = 300
-        rdata.original_id = packet.id
-        rdata.error = 0
-        rdata.other_size = 0
-        rdata.mac = mac_list
-        rdata.mac_size = len(mac_list)
-
-        r = dns.res_rec()
-        r.name = key_name
-        r.rr_type = dns.DNS_QTYPE_TSIG
-        r.rr_class = dns.DNS_QCLASS_ANY
-        r.ttl = 0
-        r.length = 0xffff
-        r.rdata = rdata
-
-        additional = [r]
-        packet.additional = additional
-        packet.arcount = 1
-
-        return mac
-
-    def bad_sign_packet(self, packet, key_name):
-        '''Add bad signature for a packet by bitflipping
-        the final byte in the MAC'''
-
-        mac_list = [ord(x) for x in list("badmac")]
-
-        rdata = dns.tsig_record()
-        rdata.algorithm_name = "gss-tsig"
-        rdata.time_prefix = 0
-        rdata.time = int(time.time())
-        rdata.fudge = 300
-        rdata.original_id = packet.id
-        rdata.error = 0
-        rdata.other_size = 0
-        rdata.mac = mac_list
-        rdata.mac_size = len(mac_list)
-
-        r = dns.res_rec()
-        r.name = key_name
-        r.rr_type = dns.DNS_QTYPE_TSIG
-        r.rr_class = dns.DNS_QCLASS_ANY
-        r.ttl = 0
-        r.length = 0xffff
-        r.rdata = rdata
-
-        additional = [r]
-        packet.additional = additional
-        packet.arcount = 1
-
-    def search_record(self, name):
-        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
-        questions = []
-
-        q = self.make_name_question(name, dns.DNS_QTYPE_TXT, dns.DNS_QCLASS_IN)
-        questions.append(q)
-
-        self.finish_name_packet(p, questions)
-        (response, response_packet) = self.dns_transaction_udp(p, self.server_ip)
-        return response.operation & 0x000F
-
-    def make_update_request(self, delete=False):
-        "Create a DNS update request"
-
-        rr_class = dns.DNS_QCLASS_IN
-        ttl = 900
-
-        if delete:
-            rr_class = dns.DNS_QCLASS_NONE
-            ttl = 0
-
-        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
-        q = self.make_name_question(self.get_dns_domain(),
-                                    dns.DNS_QTYPE_SOA,
-                                    dns.DNS_QCLASS_IN)
-        questions = []
-        questions.append(q)
-        self.finish_name_packet(p, questions)
-
-        updates = []
-        r = dns.res_rec()
-        r.name = self.newrecname
-        r.rr_type = dns.DNS_QTYPE_TXT
-        r.rr_class = rr_class
-        r.ttl = ttl
-        r.length = 0xffff
-        rdata = self.make_txt_record(['"This is a test"'])
-        r.rdata = rdata
-        updates.append(r)
-        p.nscount = len(updates)
-        p.nsrecs = updates
-
-        return p
-
-
-class TestDNSUpdates(DNSTKeyTest):
     def test_tkey(self):
         "test DNS TKEY handshake"
 
-- 
2.11.0


From 2e7a151226bff08020a5a67b4249d891d9b13b1a Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Thu, 1 Jun 2017 15:15:25 +1200
Subject: [PATCH 19/24] selftest: Use TestCaseInTempDir as base class in dns
 tests

This will help when we add a new join test based on this code

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
Reviewed-by: Garming Sam <garming at catalyst.net.nz>
---
 python/samba/tests/dns_base.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/samba/tests/dns_base.py b/python/samba/tests/dns_base.py
index 3d5aa8e25b0..2a40d999c36 100644
--- a/python/samba/tests/dns_base.py
+++ b/python/samba/tests/dns_base.py
@@ -16,7 +16,7 @@
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 #
 
-from samba.tests import TestCase
+from samba.tests import TestCaseInTempDir
 from samba.dcerpc import dns, dnsp
 from samba import gensec, tests
 from samba import credentials
@@ -27,7 +27,7 @@ import socket
 import uuid
 import time
 
-class DNSTest(TestCase):
+class DNSTest(TestCaseInTempDir):
 
     def setUp(self):
         super(DNSTest, self).setUp()
-- 
2.11.0


From b0f40752ae89754813da2501fb27f250cf647e0e Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Tue, 6 Jun 2017 15:21:50 +1200
Subject: [PATCH 20/24] provision: Move default handler for site=None down into
 dc_join object creation

This makes this code easier to call from a test script

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
Reviewed-by: Garming Sam <garming at catalyst.net.nz>
---
 python/samba/join.py          | 3 +++
 python/samba/netcmd/domain.py | 3 ---
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/python/samba/join.py b/python/samba/join.py
index 3e70db08d2a..681275cd72d 100644
--- a/python/samba/join.py
+++ b/python/samba/join.py
@@ -53,6 +53,9 @@ class dc_join(object):
                  netbios_name=None, targetdir=None, domain=None,
                  machinepass=None, use_ntvfs=False, dns_backend=None,
                  promote_existing=False, clone_only=False):
+        if site is None:
+            site = "Default-First-Site-Name"
+
         ctx.clone_only=clone_only
 
         ctx.logger = logger
diff --git a/python/samba/netcmd/domain.py b/python/samba/netcmd/domain.py
index 4bd99ba6ff5..e3a0e4921f2 100644
--- a/python/samba/netcmd/domain.py
+++ b/python/samba/netcmd/domain.py
@@ -551,9 +551,6 @@ class cmd_domain_dcpromo(Command):
         creds = credopts.get_credentials(lp)
         net = Net(creds, lp, server=credopts.ipaddress)
 
-        if site is None:
-            site = "Default-First-Site-Name"
-
         logger = self.get_logger()
         if verbose:
             logger.setLevel(logging.DEBUG)
-- 
2.11.0


From adcb68dca484b50852ed97892d3c9b834ff91f63 Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Tue, 6 Jun 2017 15:22:35 +1200
Subject: [PATCH 21/24] provision: Allow removing an existing account when
 force=True is set

This allows a practical override for use in test scripts

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
Reviewed-by: Garming Sam <garming at catalyst.net.nz>
---
 python/samba/join.py | 45 +++++++++++++++++++++++----------------------
 1 file changed, 23 insertions(+), 22 deletions(-)

diff --git a/python/samba/join.py b/python/samba/join.py
index 681275cd72d..a76772a5b0f 100644
--- a/python/samba/join.py
+++ b/python/samba/join.py
@@ -201,32 +201,33 @@ class dc_join(object):
         except Exception:
             pass
 
-    def cleanup_old_accounts(ctx):
+    def cleanup_old_accounts(ctx, force=False):
         res = ctx.samdb.search(base=ctx.samdb.get_default_basedn(),
                                expression='sAMAccountName=%s' % ldb.binary_encode(ctx.samname),
                                attrs=["msDS-krbTgtLink", "objectSID"])
         if len(res) == 0:
             return
 
-        creds = Credentials()
-        creds.guess(ctx.lp)
-        try:
-            creds.set_machine_account(ctx.lp)
-            creds.set_kerberos_state(ctx.creds.get_kerberos_state())
-            machine_samdb = SamDB(url="ldap://%s" % ctx.server,
-                                  session_info=system_session(),
-                                credentials=creds, lp=ctx.lp)
-        except:
-            pass
-        else:
-            token_res = machine_samdb.search(scope=ldb.SCOPE_BASE, base="", attrs=["tokenGroups"])
-            if token_res[0]["tokenGroups"][0] \
-               == res[0]["objectSID"][0]:
-                raise DCJoinException("Not removing account %s which "
-                                   "looks like a Samba DC account "
-                                   "maching the password we already have.  "
-                                   "To override, remove secrets.ldb and secrets.tdb"
-                                % ctx.samname)
+        if not force:
+            creds = Credentials()
+            creds.guess(ctx.lp)
+            try:
+                creds.set_machine_account(ctx.lp)
+                creds.set_kerberos_state(ctx.creds.get_kerberos_state())
+                machine_samdb = SamDB(url="ldap://%s" % ctx.server,
+                                      session_info=system_session(),
+                                    credentials=creds, lp=ctx.lp)
+            except:
+                pass
+            else:
+                token_res = machine_samdb.search(scope=ldb.SCOPE_BASE, base="", attrs=["tokenGroups"])
+                if token_res[0]["tokenGroups"][0] \
+                   == res[0]["objectSID"][0]:
+                    raise DCJoinException("Not removing account %s which "
+                                       "looks like a Samba DC account "
+                                       "maching the password we already have.  "
+                                       "To override, remove secrets.ldb and secrets.tdb"
+                                    % ctx.samname)
 
         ctx.del_noerror(res[0].dn, recursive=True)
 
@@ -253,11 +254,11 @@ class dc_join(object):
                                 ldb.binary_encode("dns/%s" % ctx.dnshostname)))
 
 
-    def cleanup_old_join(ctx):
+    def cleanup_old_join(ctx, force=False):
         """Remove any DNs from a previous join."""
         # find the krbtgt link
         if not ctx.subdomain:
-            ctx.cleanup_old_accounts()
+            ctx.cleanup_old_accounts(force=force)
 
         if ctx.connection_dn is not None:
             ctx.del_noerror(ctx.connection_dn)
-- 
2.11.0


From cdb0d22334cdd7b2a58b944be6ab19572760f190 Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Thu, 1 Jun 2017 17:11:57 +1200
Subject: [PATCH 22/24] selftest: Test join.py and confirm that the DNS record
 is created

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
Reviewed-by: Garming Sam <garming at catalyst.net.nz>
---
 python/samba/tests/join.py       | 113 +++++++++++++++++++++++++++++++++++++++
 selftest/knownfail.d/dns-at-join |   1 +
 source4/selftest/tests.py        |   3 ++
 3 files changed, 117 insertions(+)
 create mode 100644 python/samba/tests/join.py
 create mode 100644 selftest/knownfail.d/dns-at-join

diff --git a/python/samba/tests/join.py b/python/samba/tests/join.py
new file mode 100644
index 00000000000..f18c9fd95ce
--- /dev/null
+++ b/python/samba/tests/join.py
@@ -0,0 +1,113 @@
+# Test joining as a DC and check the join was done right
+#
+# Copyright (C) Andrew Bartlett <abartlet at samba.org> 2017
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation; either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+#
+
+import samba
+import sys
+import shutil
+import os
+from samba.tests.dns_base import DNSTest
+from samba.join import dc_join
+from samba.dcerpc import drsuapi, misc, dns
+
+def get_logger(name="subunit"):
+    """Get a logger object."""
+    import logging
+    logger = logging.getLogger(name)
+    logger.addHandler(logging.StreamHandler(sys.stderr))
+    return logger
+
+class JoinTestCase(DNSTest):
+    def setUp(self):
+        super(JoinTestCase, self).setUp()
+        self.server = samba.tests.env_get_var_value("SERVER")
+        self.server_ip = samba.tests.env_get_var_value("SERVER_IP")
+        self.lp = samba.tests.env_loadparm()
+        self.creds = self.get_credentials()
+        self.netbios_name = "jointest1"
+        logger = get_logger()
+
+        self.join_ctx = dc_join(server=self.server, creds=self.creds, lp=self.get_loadparm(),
+                                netbios_name=self.netbios_name,
+                                targetdir=self.tempdir,
+                                domain=None, logger=logger,
+                                dns_backend="SAMBA_INTERNAL")
+        self.join_ctx.userAccountControl = (samba.dsdb.UF_SERVER_TRUST_ACCOUNT |
+                                            samba.dsdb.UF_TRUSTED_FOR_DELEGATION)
+
+        self.join_ctx.replica_flags |= (drsuapi.DRSUAPI_DRS_WRIT_REP |
+                                        drsuapi.DRSUAPI_DRS_FULL_SYNC_IN_PROGRESS)
+        self.join_ctx.domain_replica_flags = self.join_ctx.replica_flags
+        self.join_ctx.secure_channel_type = misc.SEC_CHAN_BDC
+
+        self.join_ctx.cleanup_old_join()
+
+        self.join_ctx.force_all_ips = True
+
+    def tearDown(self):
+        try:
+            paths = self.join_ctx.paths
+        except AttributeError:
+            paths = None
+
+        if paths is not None:
+            shutil.rmtree(paths.private_dir)
+            shutil.rmtree(paths.state_dir)
+            shutil.rmtree(os.path.join(self.tempdir, "etc"))
+            shutil.rmtree(os.path.join(self.tempdir, "msg.lock"))
+            os.unlink(os.path.join(self.tempdir, "names.tdb"))
+
+        self.join_ctx.cleanup_old_join(force=True)
+
+        super(JoinTestCase, self).tearDown()
+
+
+    def test_join(self):
+
+        self.join_ctx.do_join()
+
+        "create a query packet containing one query record via TCP"
+        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
+        questions = []
+
+        name = self.join_ctx.dnshostname
+        q = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN)
+        questions.append(q)
+
+        # Get expected IPs
+        IPs = samba.interface_ips(self.lp)
+
+        self.finish_name_packet(p, questions)
+        (response, response_packet) = self.dns_transaction_tcp(p, host=self.server_ip)
+        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
+        self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
+        self.assertEquals(response.ancount, len(IPs))
+
+        questions = []
+        name = "%s._msdcs.%s" % (self.join_ctx.ntds_guid, self.join_ctx.dnsforest)
+        q = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN)
+        questions.append(q)
+
+        self.finish_name_packet(p, questions)
+        (response, response_packet) = self.dns_transaction_tcp(p, host=self.server_ip)
+        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
+        self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
+
+        self.assertEquals(response.ancount, 1 + len(IPs))
+        self.assertEquals(response.answers[0].rr_type, dns.DNS_QTYPE_CNAME)
+        self.assertEquals(response.answers[0].rdata, self.join_ctx.dnshostname)
+        self.assertEquals(response.answers[1].rr_type, dns.DNS_QTYPE_A)
diff --git a/selftest/knownfail.d/dns-at-join b/selftest/knownfail.d/dns-at-join
new file mode 100644
index 00000000000..57072e7f6c8
--- /dev/null
+++ b/selftest/knownfail.d/dns-at-join
@@ -0,0 +1 @@
+samba.tests.join.python\(ad_dc_ntvfs\).samba.tests.join.JoinTestCase.test_join\(ad_dc_ntvfs\)
\ No newline at end of file
diff --git a/source4/selftest/tests.py b/source4/selftest/tests.py
index 071660bb418..43c218de87e 100755
--- a/source4/selftest/tests.py
+++ b/source4/selftest/tests.py
@@ -823,6 +823,9 @@ for env in ['ad_dc_ntvfs']:
                            name="samba4.drs.repl_rodc.python(%s)" % env,
                            environ={'DC1': "$DC_SERVER", 'DC2': '$DC_SERVER'},
                            extra_args=['-U$DOMAIN/$DC_USERNAME%$DC_PASSWORD'])
+    planoldpythontestsuite(env, "samba.tests.join",
+                           name="samba.tests.join.python(%s)" % env,
+                           extra_args=['-U$DOMAIN/$DC_USERNAME%$DC_PASSWORD'])
 
 planoldpythontestsuite("chgdcpass:local", "samba.tests.blackbox.samba_dnsupdate",
                        environ={'DNS_SERVER_IP': '$SERVER_IP'})
-- 
2.11.0


From a25cf5f047c5cd02f18ab6706eb4f4659915028b Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Thu, 8 Jun 2017 15:25:23 +1200
Subject: [PATCH 23/24] selftest: Add test confirming join-created DNS entries
 can be modified as the DC

This ensures that samba_dnsupdate can run in the long term against the new DNS entries

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
Reviewed-by: Garming Sam <garming at catalyst.net.nz>
---
 python/samba/tests/join.py       | 74 ++++++++++++++++++++++++++++++++++++----
 selftest/knownfail.d/dns-at-join |  3 +-
 2 files changed, 70 insertions(+), 7 deletions(-)

diff --git a/python/samba/tests/join.py b/python/samba/tests/join.py
index f18c9fd95ce..1f9fab1d72a 100644
--- a/python/samba/tests/join.py
+++ b/python/samba/tests/join.py
@@ -20,9 +20,10 @@ import samba
 import sys
 import shutil
 import os
-from samba.tests.dns_base import DNSTest
+from samba.tests.dns_base import DNSTKeyTest
 from samba.join import dc_join
 from samba.dcerpc import drsuapi, misc, dns
+from samba.credentials import Credentials
 
 def get_logger(name="subunit"):
     """Get a logger object."""
@@ -31,11 +32,11 @@ def get_logger(name="subunit"):
     logger.addHandler(logging.StreamHandler(sys.stderr))
     return logger
 
-class JoinTestCase(DNSTest):
+class JoinTestCase(DNSTKeyTest):
     def setUp(self):
-        super(JoinTestCase, self).setUp()
         self.server = samba.tests.env_get_var_value("SERVER")
         self.server_ip = samba.tests.env_get_var_value("SERVER_IP")
+        super(JoinTestCase, self).setUp()
         self.lp = samba.tests.env_loadparm()
         self.creds = self.get_credentials()
         self.netbios_name = "jointest1"
@@ -58,6 +59,8 @@ class JoinTestCase(DNSTest):
 
         self.join_ctx.force_all_ips = True
 
+        self.join_ctx.do_join()
+
     def tearDown(self):
         try:
             paths = self.join_ctx.paths
@@ -76,9 +79,7 @@ class JoinTestCase(DNSTest):
         super(JoinTestCase, self).tearDown()
 
 
-    def test_join(self):
-
-        self.join_ctx.do_join()
+    def test_join_makes_records(self):
 
         "create a query packet containing one query record via TCP"
         p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
@@ -111,3 +112,64 @@ class JoinTestCase(DNSTest):
         self.assertEquals(response.answers[0].rr_type, dns.DNS_QTYPE_CNAME)
         self.assertEquals(response.answers[0].rdata, self.join_ctx.dnshostname)
         self.assertEquals(response.answers[1].rr_type, dns.DNS_QTYPE_A)
+
+
+    def test_join_records_can_update(self):
+        dc_creds = Credentials()
+        dc_creds.guess(self.join_ctx.lp)
+        dc_creds.set_machine_account(self.join_ctx.lp)
+
+        self.tkey_trans(creds=dc_creds)
+
+        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
+        q = self.make_name_question(self.join_ctx.dnsdomain,
+                                    dns.DNS_QTYPE_SOA,
+                                    dns.DNS_QCLASS_IN)
+        questions = []
+        questions.append(q)
+        self.finish_name_packet(p, questions)
+
+        updates = []
+        # Delete the old expected IPs
+        IPs = samba.interface_ips(self.lp)
+        for IP in IPs[1:]:
+            if ":" in IP:
+                r = dns.res_rec()
+                r.name = self.join_ctx.dnshostname
+                r.rr_type = dns.DNS_QTYPE_AAAA
+                r.rr_class = dns.DNS_QCLASS_NONE
+                r.ttl = 0
+                r.length = 0xffff
+                rdata = IP
+            else:
+                r = dns.res_rec()
+                r.name = self.join_ctx.dnshostname
+                r.rr_type = dns.DNS_QTYPE_A
+                r.rr_class = dns.DNS_QCLASS_NONE
+                r.ttl = 0
+                r.length = 0xffff
+                rdata = IP
+
+            r.rdata = rdata
+            updates.append(r)
+
+        p.nscount = len(updates)
+        p.nsrecs = updates
+
+        mac = self.sign_packet(p, self.key_name)
+        (response, response_p) = self.dns_transaction_udp(p, self.server_ip)
+        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
+        self.verify_packet(response, response_p, mac)
+
+        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
+        questions = []
+
+        name = self.join_ctx.dnshostname
+        q = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN)
+        questions.append(q)
+
+        self.finish_name_packet(p, questions)
+        (response, response_packet) = self.dns_transaction_tcp(p, host=self.server_ip)
+        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
+        self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
+        self.assertEquals(response.ancount, 1)
diff --git a/selftest/knownfail.d/dns-at-join b/selftest/knownfail.d/dns-at-join
index 57072e7f6c8..3737f1fb239 100644
--- a/selftest/knownfail.d/dns-at-join
+++ b/selftest/knownfail.d/dns-at-join
@@ -1 +1,2 @@
-samba.tests.join.python\(ad_dc_ntvfs\).samba.tests.join.JoinTestCase.test_join\(ad_dc_ntvfs\)
\ No newline at end of file
+samba.tests.join.python\(ad_dc_ntvfs\).samba.tests.join.JoinTestCase.test_join_makes_records\(ad_dc_ntvfs\)
+samba.tests.join.python\(ad_dc_ntvfs\).samba.tests.join.JoinTestCase.test_join_records_can_update\(ad_dc_ntvfs\)
-- 
2.11.0


From cec6f99e2902fbcb4815e42ef2c8dc080cf8db26 Mon Sep 17 00:00:00 2001
From: Andrew Bartlett <abartlet at samba.org>
Date: Fri, 17 Feb 2017 18:23:23 +1300
Subject: [PATCH 24/24] join.py Add DNS records at domain join time

This avoids issues getting replication going after the DC first starts,
as the rest of the domain does not have to wait for samba_dnsupdate to
run successfully.

We do not just run samba_dnsupdate as we want to strictly
operate against the DC we just joined:
 - We do not want to query another DNS server
 - We do not want to obtain a Kerberos ticket for the new DC
   (as the KDC we select may not be the DC we just joined,
   and so may not be in sync with the password we just set)
 - We do not wish to set the _ldap records until we have started
 - We do not wish to use NTLM (the --use-samba-tool mode forces
   NTLM)

The downside to using DCE/RPC rather than DNS is that these will
be regarded as static entries, and (against Windows) have the ACL
assigned for static entries.  However, this is still better than no
DNS at all.

Because some tests want a DNS record matching their own name,
this fixes some tests and removes their entries from knownfail.

Signed-off-by: Andrew Bartlett <abartlet at samba.org>
Reviewed-by: Garming Sam <garming at catalyst.net.nz>
---
 python/samba/join.py             | 200 ++++++++++++++++++++++++++++++++++++++-
 selftest/knownfail.d/dns         |  13 +--
 selftest/knownfail.d/dns-at-join |   2 -
 3 files changed, 200 insertions(+), 15 deletions(-)
 delete mode 100644 selftest/knownfail.d/dns-at-join

diff --git a/python/samba/join.py b/python/samba/join.py
index a76772a5b0f..fa87f0bb32f 100644
--- a/python/samba/join.py
+++ b/python/samba/join.py
@@ -22,8 +22,8 @@ from samba.auth import system_session
 from samba.samdb import SamDB
 from samba import gensec, Ldb, drs_utils, arcfour_encrypt, string_to_byte_array
 import ldb, samba, sys, uuid
-from samba.ndr import ndr_pack
-from samba.dcerpc import security, drsuapi, misc, nbt, lsa, drsblobs
+from samba.ndr import ndr_pack, ndr_unpack
+from samba.dcerpc import security, drsuapi, misc, nbt, lsa, drsblobs, dnsserver, dnsp
 from samba.dsdb import DS_DOMAIN_FUNCTION_2003
 from samba.credentials import Credentials, DONT_USE_KERBEROS
 from samba.provision import secretsdb_self_join, provision, provision_fill, FILL_DRS, FILL_SUBDOMAIN
@@ -35,6 +35,9 @@ from samba.provision.sambadns import setup_bind9_dns
 from samba import read_and_sub_file
 from samba import werror
 from base64 import b64encode
+from samba import WERRORError
+from samba.dnsserver import ARecord, AAAARecord, PTRRecord, CNameRecord, NSRecord, MXRecord, SOARecord, SRVRecord, TXTRecord
+from samba import sd_utils
 import logging
 import talloc
 import random
@@ -187,6 +190,12 @@ class dc_join(object):
         ctx.adminpass = None
         ctx.partition_dn = None
 
+        ctx.dns_a_dn = None
+        ctx.dns_cname_dn = None
+
+        # Do not normally register 127. addresses but allow override for selftest
+        ctx.force_all_ips = False
+
     def del_noerror(ctx, dn, recursive=False):
         if recursive:
             try:
@@ -294,6 +303,13 @@ class dc_join(object):
 
             lsaconn.DeleteTrustedDomain(pol_handle, info.info_ex.sid)
 
+        if ctx.dns_a_dn:
+            ctx.del_noerror(ctx.dns_a_dn)
+
+        if ctx.dns_cname_dn:
+            ctx.del_noerror(ctx.dns_cname_dn)
+
+
 
     def promote_possible(ctx):
         """confirm that the account is just a bare NT4 BDC or a member server, so can be safely promoted"""
@@ -692,12 +708,16 @@ class dc_join(object):
                                      newpassword=ctx.acct_pass.encode('utf-8'))
 
             res = ctx.samdb.search(base=ctx.acct_dn, scope=ldb.SCOPE_BASE,
-                                   attrs=["msDS-KeyVersionNumber"])
+                                   attrs=["msDS-KeyVersionNumber",
+                                          "objectSID"])
             if "msDS-KeyVersionNumber" in res[0]:
                 ctx.key_version_number = int(res[0]["msDS-KeyVersionNumber"][0])
             else:
                 ctx.key_version_number = None
 
+            ctx.new_dc_account_sid = ndr_unpack(security.dom_sid,
+                                                res[0]["objectSid"][0])
+
             print("Enabling account")
             m = ldb.Message()
             m.dn = ldb.Dn(ctx.samdb, ctx.acct_dn)
@@ -974,6 +994,175 @@ class dc_join(object):
 
         ctx.drsuapi.DsReplicaUpdateRefs(ctx.drsuapi_handle, 1, r)
 
+    def join_add_dns_records(ctx):
+        """Remotely add this DC's DNS records to the target DC.  We assume that if we
+           replicate DNS that the server holds the DNS roles and can accept
+           updates.
+
+           This avoids issues getting replication going after the DC
+           first starts as the rest of the domain does not have to
+           wait for samba_dnsupdate to run successfully.
+
+           Specifically, we add the A/AAAA records for our hostname and the
+           <ntds_guid>._msdcs CNAME implied by the DsReplicaUpdateRefs call above.
+
+           We do not just run samba_dnsupdate as we want to strictly
+           operate against the DC we just joined:
+            - We do not want to query another DNS server
+            - We do not want to obtain a Kerberos ticket
+              (as the KDC we select may not be the DC we just joined,
+              and so may not be in sync with the password we just set)
+            - We do not wish to set the _ldap records until we have started
+            - We do not wish to use NTLM (the --use-samba-tool mode forces
+              NTLM)
+
+        """
+
+        client_version = dnsserver.DNS_CLIENT_VERSION_LONGHORN
+        # Enumerate only the node's own authoritative data, not child names
+        select_flags = (dnsserver.DNS_RPC_VIEW_AUTHORITY_DATA |
+                        dnsserver.DNS_RPC_VIEW_NO_CHILDREN)
+
+        zone = ctx.dnsdomain
+        msdcs_zone = "_msdcs.%s" % ctx.dnsforest
+        name = ctx.myname
+        msdcs_cname = str(ctx.ntds_guid)
+        cname_target = "%s.%s" % (name, zone)
+        IPs = samba.interface_ips(ctx.lp, ctx.force_all_ips)
+
+        ctx.logger.info("Adding %d remote DNS records for %s.%s" %
+                        (len(IPs), name, zone))
+
+        binding_options = "sign"
+        dns_conn = dnsserver.dnsserver("ncacn_ip_tcp:%s[%s]" % (ctx.server, binding_options),
+                                      ctx.lp, ctx.creds)
+
+
+        name_found = True
+
+        sd_helper = sd_utils.SDUtils(ctx.samdb)
+
+        change_owner_sd = security.descriptor()
+        change_owner_sd.owner_sid = ctx.new_dc_account_sid
+        change_owner_sd.group_sid = security.dom_sid("%s-%d" %
+                                                     (str(ctx.domsid),
+                                                      security.DOMAIN_RID_DCS))
+
+        # Remove any old A/AAAA records from the primary DNS name
+        try:
+            (buflen, res) \
+                = dns_conn.DnssrvEnumRecords2(client_version,
+                                              0,
+                                              ctx.server,
+                                              zone,
+                                              name,
+                                              None,
+                                              dnsp.DNS_TYPE_ALL,
+                                              select_flags,
+                                              None,
+                                              None)
+        except WERRORError as e:
+            if e.args[0] != werror.WERR_DNS_ERROR_NAME_DOES_NOT_EXIST:
+                raise
+            name_found = False
+
+        if name_found:
+            for rec in res.rec:
+                for record in rec.records:
+                    if record.wType == dnsp.DNS_TYPE_A or \
+                       record.wType == dnsp.DNS_TYPE_AAAA:
+                        # delete record
+                        del_rec_buf = dnsserver.DNS_RPC_RECORD_BUF()
+                        del_rec_buf.rec = record
+                        try:
+                            dns_conn.DnssrvUpdateRecord2(client_version,
+                                                         0,
+                                                         ctx.server,
+                                                         zone,
+                                                         name,
+                                                         None,
+                                                         del_rec_buf)
+                        except WERRORError as e:
+                            if e.args[0] == werror.WERR_DNS_ERROR_NAME_DOES_NOT_EXIST:
+                                pass
+                            else:
+                                raise
+
+        for IP in IPs:
+            if ':' in IP:
+                ctx.logger.info("Adding DNS AAAA record %s.%s for IPv6 IP: %s"
+                                % (name, zone, IP))
+                rec = AAAARecord(IP)
+            else:
+                ctx.logger.info("Adding DNS A record %s.%s for IPv4 IP: %s"
+                                % (name, zone, IP))
+                rec = ARecord(IP)
+
+            # Add record
+            add_rec_buf = dnsserver.DNS_RPC_RECORD_BUF()
+            add_rec_buf.rec = rec
+            dns_conn.DnssrvUpdateRecord2(client_version,
+                                         0,
+                                         ctx.server,
+                                         zone,
+                                         name,
+                                         add_rec_buf,
+                                         None)
+
+        if IPs:
+            domaindns_zone_dn = ldb.Dn(ctx.samdb, ctx.domaindns_zone)
+            (ctx.dns_a_dn, ldap_record) \
+                = ctx.samdb.dns_lookup("%s.%s" % (name, zone),
+                                       dns_partition=domaindns_zone_dn)
+
+            # Make the DC own the DNS record, not the administrator
+            sd_helper.modify_sd_on_dn(ctx.dns_a_dn, change_owner_sd,
+                                      controls=["sd_flags:1:%d"
+                                                % (security.SECINFO_OWNER
+                                                   | security.SECINFO_GROUP)])
+
+
+            # Add the <ntds_guid>._msdcs CNAME pointing at our hostname
+            ctx.logger.info("Adding DNS CNAME record %s.%s for %s"
+                            % (msdcs_cname, msdcs_zone, cname_target))
+
+            add_rec_buf = dnsserver.DNS_RPC_RECORD_BUF()
+            rec = CNameRecord(cname_target)
+            add_rec_buf.rec = rec
+            dns_conn.DnssrvUpdateRecord2(client_version,
+                                         0,
+                                         ctx.server,
+                                         msdcs_zone,
+                                         msdcs_cname,
+                                         add_rec_buf,
+                                         None)
+
+            forestdns_zone_dn = ldb.Dn(ctx.samdb, ctx.forestdns_zone)
+            (ctx.dns_cname_dn, ldap_record) \
+                = ctx.samdb.dns_lookup("%s.%s" % (msdcs_cname, msdcs_zone),
+                                       dns_partition=forestdns_zone_dn)
+
+            # Make the DC own the DNS record, not the administrator
+            sd_helper.modify_sd_on_dn(ctx.dns_cname_dn, change_owner_sd,
+                                      controls=["sd_flags:1:%d"
+                                                % (security.SECINFO_OWNER
+                                                   | security.SECINFO_GROUP)])
+
+        ctx.logger.info("All other DNS records (like _ldap SRV records) "
+                        "will be created by samba_dnsupdate on first startup")
+
+    def join_replicate_new_dns_records(ctx):
+        """Replicate any DNS zone partitions we hold from the source DC so
+           that the records just added remotely by join_add_dns_records()
+           are pulled into the new DC's copy."""
+        for nc in (ctx.domaindns_zone, ctx.forestdns_zone):
+            if nc in ctx.nc_list:
+                ctx.logger.info("Replicating new DNS records in %s" % nc)
+                ctx.repl.replicate(nc, ctx.source_dsa_invocation_id,
+                                   ctx.ntds_guid, rodc=ctx.RODC,
+                                   replica_flags=ctx.replica_flags,
+                                   full_sync=False)
+
     def join_finalise(ctx):
         """Finalise the join, mark us synchronised and setup secrets db."""
 
@@ -1190,6 +1379,11 @@ class dc_join(object):
                 ctx.join_add_objects2()
                 ctx.join_provision_own_domain()
                 ctx.join_setup_trusts()
+
+            if not ctx.clone_only and ctx.dns_backend != "NONE":
+                ctx.join_add_dns_records()
+                ctx.join_replicate_new_dns_records()
+
             ctx.join_finalise()
         except:
             try:
diff --git a/selftest/knownfail.d/dns b/selftest/knownfail.d/dns
index c40041d1892..cb3003240ea 100644
--- a/selftest/knownfail.d/dns
+++ b/selftest/knownfail.d/dns
@@ -36,19 +36,12 @@ samba.tests.dns.__main__.TestRPCRoundtrip.test_update_add_txt_rpc_to_dns\(rodc:l
 samba.tests.dns.__main__.TestZones.test_soa_query\(rodc:local\)
 samba.tests.dns.__main__.TestComplexQueries.test_cname_two_chain\(vampire_dc:local\)
 samba.tests.dns.__main__.TestComplexQueries.test_one_a_query\(vampire_dc:local\)
-
-# The SOA override should not pass against the RODC, it must not overstamp
-samba.tests.dns.__main__.TestSimpleQueries.test_one_SOA_query\(rodc:local\)
-
-# The very first DC will have DNS records, but subsequent DCs only get entries into
-# the dns_hosts_file in our selftest env
 samba.tests.dns.__main__.TestSimpleQueries.test_one_a_query\(vampire_dc:local\)
 samba.tests.dns.__main__.TestSimpleQueries.test_one_a_query_tcp\(vampire_dc:local\)
-samba.tests.dns.__main__.TestSimpleQueries.test_one_mx_query\(vampire_dc:local\)
 samba.tests.dns.__main__.TestSimpleQueries.test_qtype_all_query\(vampire_dc:local\)
-samba.tests.dns.__main__.TestSimpleQueries.test_soa_hostname_query\(vampire_dc:local\)
 samba.tests.dns.__main__.TestSimpleQueries.test_one_a_query\(rodc:local\)
 samba.tests.dns.__main__.TestSimpleQueries.test_one_a_query_tcp\(rodc:local\)
-samba.tests.dns.__main__.TestSimpleQueries.test_one_mx_query\(rodc:local\)
 samba.tests.dns.__main__.TestSimpleQueries.test_qtype_all_query\(rodc:local\)
-samba.tests.dns.__main__.TestSimpleQueries.test_soa_hostname_query\(rodc:local\)
+
+# The SOA override should not pass against the RODC: it must not overstamp
+samba.tests.dns.__main__.TestSimpleQueries.test_one_SOA_query\(rodc:local\)
diff --git a/selftest/knownfail.d/dns-at-join b/selftest/knownfail.d/dns-at-join
deleted file mode 100644
index 3737f1fb239..00000000000
--- a/selftest/knownfail.d/dns-at-join
+++ /dev/null
@@ -1,2 +0,0 @@
-samba.tests.join.python\(ad_dc_ntvfs\).samba.tests.join.JoinTestCase.test_join_makes_records\(ad_dc_ntvfs\)
-samba.tests.join.python\(ad_dc_ntvfs\).samba.tests.join.JoinTestCase.test_join_records_can_update\(ad_dc_ntvfs\)
-- 
2.11.0



More information about the samba-technical mailing list