[PATCH] Cache SAMR QueryDisplayInfo, EnumDomainGroups and EnumDomainUsers

Gary Lockyer gary at catalyst.net.nz
Wed Nov 14 20:10:01 UTC 2018


Updated patch set, incorporating feedback from Matthias, and Andrew's via
GitLab.

Gary.

On 4/11/18 07:53, Matthias Dieter Wallnöfer wrote:
> Nice work Gary, but this introduces some signed/unsigned integer
> mix-ups, to which we should pay more attention than in the past.
> 
> Please find the three occurrences which I would change into "unsigned"
> (I know it's confusing that the LDB API ldb_* works with "unsigned" and
> the SAMDB one with plain "int" since it returns also non-success codes).
> 
> Matthias
> 
> Gary Lockyer via samba-technical wrote:
>> static NTSTATUS load_guid_cache(
>> +	struct samr_guid_cache *cache,
>> +	struct samr_domain_state *d_state,
>> +	int ldb_cnt,
> ^^^
>> +	struct ldb_message **res)
>> +{
>> +	NTSTATUS status = NT_STATUS_OK;
>> +	int i;
> ^^^
>>  
>>  /*
>>    samr_Connect
>> @@ -384,6 +439,7 @@ static NTSTATUS dcesrv_samr_OpenDomain(struct dcesrv_call_state *dce_call, TALLO
>>  	const char * const dom_attrs[] = { "cn", NULL};
>>  	struct ldb_message **dom_msgs;
>>  	int ret;
>> +	int i;
> ^^^
> 
> 
-------------- next part --------------
From de9b6c3697a76bc2ada6dc7c1afb8f7bde72dc9c Mon Sep 17 00:00:00 2001
From: Gary Lockyer <gary at catalyst.net.nz>
Date: Tue, 9 Oct 2018 11:09:20 +1300
Subject: [PATCH 1/7] tests samr: Extra tests for samr_QueryDisplayInfo

Add extra tests that check the content returned by samr_QueryDisplayInfo,
which is not currently tested against the AD DC.  Also add tests for the
result caching added in the following commit.

Signed-off-by: Gary Lockyer <gary at catalyst.net.nz>
---
 python/samba/tests/dcerpc/sam.py | 388 ++++++++++++++++++++++++++++++-
 selftest/knownfail.d/samr        |  11 +
 2 files changed, 398 insertions(+), 1 deletion(-)
 create mode 100644 selftest/knownfail.d/samr

diff --git a/python/samba/tests/dcerpc/sam.py b/python/samba/tests/dcerpc/sam.py
index 9c432e3a37e..904af6b0d98 100644
--- a/python/samba/tests/dcerpc/sam.py
+++ b/python/samba/tests/dcerpc/sam.py
@@ -19,9 +19,20 @@
 
 """Tests for samba.dcerpc.sam."""
 
-from samba.dcerpc import samr, security
+from samba.dcerpc import samr, security, lsa
 from samba.tests import RpcInterfaceTestCase
+from samba.tests import env_loadparm, delete_force
 
+from samba.credentials import Credentials
+from samba.auth import system_session
+from samba.samdb import SamDB
+from samba.dsdb import (
+    ATYPE_NORMAL_ACCOUNT,
+    ATYPE_WORKSTATION_TRUST,
+    GTYPE_SECURITY_UNIVERSAL_GROUP,
+    GTYPE_SECURITY_GLOBAL_GROUP)
+from samba import generate_random_password
+import os
 
 # FIXME: Pidl should be doing this for us
 def toArray(handle, array, num_entries):
@@ -33,6 +44,32 @@ class SamrTests(RpcInterfaceTestCase):
     def setUp(self):
         super(SamrTests, self).setUp()
         self.conn = samr.samr("ncalrpc:", self.get_loadparm())
+        self.open_samdb()
+        self.open_domain_handle()
+
+    #
+    # Open the samba database
+    #
+    def open_samdb(self):
+        self.lp = env_loadparm()
+        self.domain = os.environ["DOMAIN"]
+        self.creds = Credentials()
+        self.creds.guess(self.lp)
+        self.session = system_session()
+        self.samdb = SamDB(
+            session_info=self.session, credentials=self.creds, lp=self.lp)
+
+    #
+    # Open a SAMR Domain handle
+    def open_domain_handle(self):
+        self.handle = self.conn.Connect2(
+            None, security.SEC_FLAG_MAXIMUM_ALLOWED)
+
+        self.domain_sid = self.conn.LookupDomain(
+            self.handle, lsa.String(self.domain))
+
+        self.domain_handle = self.conn.OpenDomain(
+            self.handle, security.SEC_FLAG_MAXIMUM_ALLOWED, self.domain_sid)
 
     def test_connect5(self):
         (level, info, handle) = self.conn.Connect5(None, 0, 1, samr.ConnectInfo1())
@@ -45,3 +82,352 @@ class SamrTests(RpcInterfaceTestCase):
         handle = self.conn.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
         domains = toArray(*self.conn.EnumDomains(handle, 0, 4294967295))
         self.conn.Close(handle)
+
+    # Create groups based on the id list supplied, the id is used to
+    # form a unique name and description.
+    #
+    # returns a list of the created dn's, which can be passed to delete_dns
+    # to clean up after the test has run.
+    def create_groups(self, ids):
+        dns = []
+        for i in ids:
+            name = "SAMR_GRP%d" % i
+            dn = "cn=%s,cn=Users,%s" % (name, self.samdb.domain_dn())
+            delete_force(self.samdb, dn)
+
+            self.samdb.newgroup(name)
+            dns.append(dn)
+        return dns
+
+    # Create user accounts based on the id list supplied, the id is used to
+    # form a unique name and description.
+    #
+    # returns a list of the created dn's, which can be passed to delete_dns
+    # to clean up after the test has run.
+    def create_users(self, ids):
+        dns = []
+        for i in ids:
+            name = "SAMR_USER%d" % i
+            dn = "cn=%s,CN=USERS,%s" % (name, self.samdb.domain_dn())
+            delete_force(self.samdb, dn)
+            password = generate_random_password(32, 32)
+            self.samdb.newuser(
+                name,
+                password,
+                description="Description for " + name,
+                givenname="given%dname" % i,
+                surname="surname%d" % i)
+            dns.append(dn)
+        return dns
+
+    # Create computer accounts based on the id list supplied, the id is used to
+    # form a unique name and description.
+    #
+    # returns a list of the created dn's, which can be passed to delete_dns
+    # to clean up after the test has run.
+    def create_computers(self, ids):
+        dns = []
+        for i in ids:
+            name = "SAMR_CMP%d" % i
+            dn = "cn=%s,cn=COMPUTERS,%s" % (name, self.samdb.domain_dn())
+            delete_force(self.samdb, dn)
+
+            self.samdb.newcomputer(name, description="Description of " + name)
+            dns.append(dn)
+        return dns
+
+    # Delete the specified dn's.
+    #
+    # Used to clean up entries created by individual tests.
+    #
+    def delete_dns(self, dns):
+        for dn in dns:
+            delete_force(self.samdb, dn)
+
+    # Common tests for QueryDisplayInfo
+    #
+    def _test_QueryDisplayInfo(
+            self, level, check_results, select, attributes, add_elements):
+        #
+        # Get the expected results by querying the samdb database directly.
+        # We do this rather than use a list of expected results as this runs
+        # with other tests so we do not have a known fixed list of elements
+        expected = self.samdb.search(expression=select, attrs=attributes)
+        self.assertTrue(len(expected) > 0)
+
+        #
+        # Perform QueryDisplayInfo with max results greater than the expected
+        # number of results.
+        (ts, rs, actual) = self.conn.QueryDisplayInfo(
+            self.domain_handle, level, 0, 1024, 4294967295)
+
+        self.assertEquals(len(expected), ts)
+        self.assertEquals(len(expected), rs)
+        check_results(expected, actual.entries)
+
+        #
+        # Perform QueryDisplayInfo with max results set to the number of
+        # results returned from the first query, should return the same results
+        (ts1, rs1, actual1) = self.conn.QueryDisplayInfo(
+            self.domain_handle, level, 0, rs, 4294967295)
+        self.assertEquals(ts, ts1)
+        self.assertEquals(rs, rs1)
+        check_results(expected, actual1.entries)
+
+        #
+        # Perform QueryDisplayInfo and get the last two results.
+        # Note: We are assuming there are at least three entries
+        self.assertTrue(ts > 2)
+        (ts2, rs2, actual2) = self.conn.QueryDisplayInfo(
+            self.domain_handle, level, (ts - 2), 2, 4294967295)
+        self.assertEquals(ts, ts2)
+        self.assertEquals(2, rs2)
+        check_results(list(expected)[-2:], actual2.entries)
+
+        #
+        # Perform QueryDisplayInfo and get the first two results.
+        # Note: We are assuming there are at least three entries
+        self.assertTrue(ts > 2)
+        (ts2, rs2, actual2) = self.conn.QueryDisplayInfo(
+            self.domain_handle, level, 0, 2, 4294967295)
+        self.assertEquals(ts, ts2)
+        self.assertEquals(2, rs2)
+        check_results(list(expected)[:2], actual2.entries)
+
+        #
+        # Perform QueryDisplayInfo and get two results in the middle of the
+        # list (indices 1 and 2), i.e. neither the first nor the last entry.
+        # Note: We are assuming there are at least four entries
+        self.assertTrue(ts > 3)
+        (ts2, rs2, actual2) = self.conn.QueryDisplayInfo(
+            self.domain_handle, level, 1, 2, 4294967295)
+        self.assertEquals(ts, ts2)
+        self.assertEquals(2, rs2)
+        check_results(list(expected)[1:3], actual2.entries)
+
+        #
+        # To check that cached values are being returned rather than the
+        # results being re-read from disk we add elements, and request all
+        # but the first result.
+        #
+        dns = add_elements([1000, 1002, 1003, 1004])
+
+        #
+        # Perform QueryDisplayInfo and get all but the first result.
+        # We should be using the cached results so the entries we just added
+        # should not be present
+        (ts3, rs3, actual3) = self.conn.QueryDisplayInfo(
+            self.domain_handle, level, 1, 1024, 4294967295)
+        self.assertEquals(ts, ts3)
+        self.assertEquals(len(expected) - 1, rs3)
+        check_results(list(expected)[1:], actual3.entries)
+
+        #
+        # Perform QueryDisplayInfo and get all the results.
+        # As the start index is zero we should reread the data from disk and
+        # the added entries should be there
+        new = self.samdb.search(expression=select, attrs=attributes)
+        (ts4, rs4, actual4) = self.conn.QueryDisplayInfo(
+            self.domain_handle, level, 0, 1024, 4294967295)
+        self.assertEquals(len(expected) + len(dns), ts4)
+        self.assertEquals(len(expected) + len(dns), rs4)
+        check_results(new, actual4.entries)
+
+        # Delete the added DNs and query all but the first entry.
+        # This should ensure the cached results are used and that the
+        # missing entry code is triggered.
+        self.delete_dns(dns)
+        (ts5, rs5, actual5) = self.conn.QueryDisplayInfo(
+            self.domain_handle, level, 1, 1024, 4294967295)
+        self.assertEquals(len(expected) + len(dns), ts5)
+        # The deleted results will be filtered from the result set so should
+        # be missing from the returned results.
+        # Note: depending on the GUID order, the first result in the cache may
+        #       be a deleted entry, in which case the results will contain all
+        #       the expected elements, otherwise the first expected result will
+        #       be missing.
+        if rs5 == len(expected):
+            check_results(expected, actual5.entries)
+        elif rs5 == (len(expected) - 1):
+            check_results(list(expected)[1:], actual5.entries)
+        else:
+            self.fail("Incorrect number of entries {0}".format(rs5))
+
+        #
+        # Perform QueryDisplayInfo specifying an index past the end of the
+        # available data.
+        # Should return no data.
+        (ts6, rs6, actual6) = self.conn.QueryDisplayInfo(
+            self.domain_handle, level, ts5, 1, 4294967295)
+        self.assertEquals(ts5, ts6)
+        self.assertEquals(0, rs6)
+
+        self.conn.Close(self.handle)
+
+    # Test for QueryDisplayInfo, Level 1
+    # Returns the sAMAccountName, displayName and description for all
+    # the user accounts.
+    #
+    def test_QueryDisplayInfo_level_1(self):
+        def check_results(expected, actual):
+            # Assume the QueryDisplayInfo and ldb.search return their results
+            # in the same order
+            for (e, a) in zip(expected, actual):
+                self.assertTrue(isinstance(a, samr.DispEntryGeneral))
+                self.assertEquals(str(e["sAMAccountName"]),
+                                  str(a.account_name))
+
+                # The displayName and description are optional.
+                # In the expected results they will be missing, in
+                # samr.DispEntryGeneral the corresponding attribute will have a
+                # length of zero.
+                #
+                if a.full_name.length == 0:
+                    self.assertFalse("displayName" in e)
+                else:
+                    self.assertEquals(str(e["displayName"]), str(a.full_name))
+
+                if a.description.length == 0:
+                    self.assertFalse("description" in e)
+                else:
+                    self.assertEquals(str(e["description"]),
+                                      str(a.description))
+        # Create four user accounts
+        # to ensure that we have the minimum needed for the tests.
+        dns = self.create_users([1, 2, 3, 4])
+
+        select = "(&(objectclass=user)(sAMAccountType={0}))".format(
+            ATYPE_NORMAL_ACCOUNT)
+        attributes = ["sAMAccountName", "displayName", "description"]
+        self._test_QueryDisplayInfo(
+            1, check_results, select, attributes, self.create_users)
+
+        self.delete_dns(dns)
+
+    # Test for QueryDisplayInfo, Level 2
+    # Returns the sAMAccountName and description for all
+    # the computer accounts.
+    #
+    def test_QueryDisplayInfo_level_2(self):
+        def check_results(expected, actual):
+            # Assume the QueryDisplayInfo and ldb.search return their results
+            # in the same order
+            for (e, a) in zip(expected, actual):
+                self.assertTrue(isinstance(a, samr.DispEntryFull))
+                self.assertEquals(str(e["sAMAccountName"]),
+                                  str(a.account_name))
+
+                # The description is optional.
+                # In the expected results they will be missing, in
+                # samr.DispEntryGeneral the corresponding attribute will have a
+                # length of zero.
+                #
+                if a.description.length == 0:
+                    self.assertFalse("description" in e)
+                else:
+                    self.assertEquals(str(e["description"]),
+                                      str(a.description))
+
+        # Create four computer accounts
+        # to ensure that we have the minimum needed for the tests.
+        dns = self.create_computers([1, 2, 3, 4])
+
+        select = "(&(objectclass=user)(sAMAccountType={0}))".format(
+            ATYPE_WORKSTATION_TRUST)
+        attributes = ["sAMAccountName", "description"]
+        self._test_QueryDisplayInfo(
+            2, check_results, select, attributes, self.create_computers)
+
+        self.delete_dns(dns)
+
+    # Test for QueryDisplayInfo, Level 3
+    # Returns the sAMAccountName and description for all
+    # the groups.
+    #
+    def test_QueryDisplayInfo_level_3(self):
+        def check_results(expected, actual):
+            # Assume the QueryDisplayInfo and ldb.search return their results
+            # in the same order
+            for (e, a) in zip(expected, actual):
+                self.assertTrue(isinstance(a, samr.DispEntryFullGroup))
+                self.assertEquals(str(e["sAMAccountName"]),
+                                  str(a.account_name))
+
+                # The description is optional.
+                # In the expected results they will be missing, in
+                # samr.DispEntryGeneral the corresponding attribute will have a
+                # length of zero.
+                #
+                if a.description.length == 0:
+                    self.assertFalse("description" in e)
+                else:
+                    self.assertEquals(str(e["description"]),
+                                      str(a.description))
+
+        # Create four groups
+        # to ensure that we have the minimum needed for the tests.
+        dns = self.create_groups([1, 2, 3, 4])
+
+        select = "(&(|(groupType=%d)(groupType=%d))(objectClass=group))" % (
+            GTYPE_SECURITY_UNIVERSAL_GROUP,
+            GTYPE_SECURITY_GLOBAL_GROUP)
+        attributes = ["sAMAccountName", "description"]
+        self._test_QueryDisplayInfo(
+            3, check_results, select, attributes, self.create_groups)
+
+        self.delete_dns(dns)
+
+    # Test for QueryDisplayInfo, Level 4
+    # Returns the sAMAccountName (as an ASCII string)
+    # for all the user accounts.
+    #
+    def test_QueryDisplayInfo_level_4(self):
+        def check_results(expected, actual):
+            # Assume the QueryDisplayInfo and ldb.search return their results
+            # in the same order
+            for (e, a) in zip(expected, actual):
+                self.assertTrue(isinstance(a, samr.DispEntryAscii))
+                self.assertTrue(
+                    isinstance(a.account_name, lsa.AsciiStringLarge))
+                self.assertEquals(
+                    str(e["sAMAccountName"]), str(a.account_name.string))
+
+        # Create four user accounts
+        # to ensure that we have the minimum needed for the tests.
+        dns = self.create_users([1, 2, 3, 4])
+
+        select = "(&(objectclass=user)(sAMAccountType={0}))".format(
+            ATYPE_NORMAL_ACCOUNT)
+        attributes = ["sAMAccountName", "displayName", "description"]
+        self._test_QueryDisplayInfo(
+            4, check_results, select, attributes, self.create_users)
+
+        self.delete_dns(dns)
+
+    # Test for QueryDisplayInfo, Level 5
+    # Returns the sAMAccountName (as an ASCII string)
+    # for all the groups.
+    #
+    def test_QueryDisplayInfo_level_5(self):
+        def check_results(expected, actual):
+            # Assume the QueryDisplayInfo and ldb.search return their results
+            # in the same order
+            for (e, a) in zip(expected, actual):
+                self.assertTrue(isinstance(a, samr.DispEntryAscii))
+                self.assertTrue(
+                    isinstance(a.account_name, lsa.AsciiStringLarge))
+                self.assertEquals(
+                    str(e["sAMAccountName"]), str(a.account_name.string))
+
+        # Create four groups
+        # to ensure that we have the minimum needed for the tests.
+        dns = self.create_groups([1, 2, 3, 4])
+
+        select = "(&(|(groupType=%d)(groupType=%d))(objectClass=group))" % (
+            GTYPE_SECURITY_UNIVERSAL_GROUP,
+            GTYPE_SECURITY_GLOBAL_GROUP)
+        attributes = ["sAMAccountName", "description"]
+        self._test_QueryDisplayInfo(
+            5, check_results, select, attributes, self.create_groups)
+
+        self.delete_dns(dns)
diff --git a/selftest/knownfail.d/samr b/selftest/knownfail.d/samr
new file mode 100644
index 00000000000..310f44d1dd2
--- /dev/null
+++ b/selftest/knownfail.d/samr
@@ -0,0 +1,11 @@
+samba.tests.dcerpc.sam.samba.tests.dcerpc.sam.SamrTests.test_QueryDisplayInfo_level_1\(ad_dc_ntvfs:local\)
+samba.tests.dcerpc.sam.samba.tests.dcerpc.sam.SamrTests.test_QueryDisplayInfo_level_2\(ad_dc_ntvfs:local\)
+samba.tests.dcerpc.sam.samba.tests.dcerpc.sam.SamrTests.test_QueryDisplayInfo_level_3\(ad_dc_ntvfs:local\)
+samba.tests.dcerpc.sam.samba.tests.dcerpc.sam.SamrTests.test_QueryDisplayInfo_level_4\(ad_dc_ntvfs:local\)
+samba.tests.dcerpc.sam.samba.tests.dcerpc.sam.SamrTests.test_QueryDisplayInfo_level_5\(ad_dc_ntvfs:local\)
+samba.tests.dcerpc.sam.python3.samba.tests.dcerpc.sam.SamrTests.test_QueryDisplayInfo_level_1\(ad_dc_ntvfs:local\)
+samba.tests.dcerpc.sam.python3.samba.tests.dcerpc.sam.SamrTests.test_QueryDisplayInfo_level_2\(ad_dc_ntvfs:local\)
+samba.tests.dcerpc.sam.python3.samba.tests.dcerpc.sam.SamrTests.test_QueryDisplayInfo_level_3\(ad_dc_ntvfs:local\)
+samba.tests.dcerpc.sam.python3.samba.tests.dcerpc.sam.SamrTests.test_QueryDisplayInfo_level_4\(ad_dc_ntvfs:local\)
+samba.tests.dcerpc.sam.python3.samba.tests.dcerpc.sam.SamrTests.test_QueryDisplayInfo_level_5\(ad_dc_ntvfs:local\)
+
-- 
2.17.1


From 3aca7a23782706102aa9abfdeca1774932003b58 Mon Sep 17 00:00:00 2001
From: Gary Lockyer <gary at catalyst.net.nz>
Date: Tue, 9 Oct 2018 11:11:12 +1300
Subject: [PATCH 2/7] tests samr: remove PEP8 warnings

Remove PEP8 warnings from the samr tests.

Signed-off-by: Gary Lockyer <gary at catalyst.net.nz>
---
 python/samba/tests/dcerpc/sam.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/python/samba/tests/dcerpc/sam.py b/python/samba/tests/dcerpc/sam.py
index 904af6b0d98..67f62d4d2b5 100644
--- a/python/samba/tests/dcerpc/sam.py
+++ b/python/samba/tests/dcerpc/sam.py
@@ -34,6 +34,7 @@ from samba.dsdb import (
 from samba import generate_random_password
 import os
 
+
 # FIXME: Pidl should be doing this for us
 def toArray(handle, array, num_entries):
     return [(entry.idx, entry.name) for entry in array.entries[:num_entries]]
@@ -72,7 +73,8 @@ class SamrTests(RpcInterfaceTestCase):
             self.handle, security.SEC_FLAG_MAXIMUM_ALLOWED, self.domain_sid)
 
     def test_connect5(self):
-        (level, info, handle) = self.conn.Connect5(None, 0, 1, samr.ConnectInfo1())
+        (level, info, handle) =\
+            self.conn.Connect5(None, 0, 1, samr.ConnectInfo1())
 
     def test_connect2(self):
         handle = self.conn.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
@@ -80,7 +82,7 @@ class SamrTests(RpcInterfaceTestCase):
 
     def test_EnumDomains(self):
         handle = self.conn.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
-        domains = toArray(*self.conn.EnumDomains(handle, 0, 4294967295))
+        toArray(*self.conn.EnumDomains(handle, 0, 4294967295))
         self.conn.Close(handle)
 
     # Create groups based on the id list supplied, the id is used to
-- 
2.17.1


From 5419d01d1c80376531cecae64d5ac7d24b8bacaf Mon Sep 17 00:00:00 2001
From: Gary Lockyer <gary at catalyst.net.nz>
Date: Wed, 10 Oct 2018 09:20:25 +1300
Subject: [PATCH 3/7] source4 samr: cache samr_QueryDisplayInfo results

Add a cache of the GUIDs that matched the last samr_QueryDisplayInfo call
made on a domain handle.  The cache is discarded if the requested start
index is zero, or if the requested level does not match the cached level.

The cache is maintained in the guid_caches array of the dcesrv_handle.

Note: currently this cache exists for the lifetime of the RPC
      handle.

Signed-off-by: Gary Lockyer <gary at catalyst.net.nz>
---
 selftest/knownfail.d/samr             |  11 -
 source4/rpc_server/samr/dcesrv_samr.c | 439 +++++++++++++++++---------
 source4/rpc_server/samr/dcesrv_samr.h |  15 +
 3 files changed, 310 insertions(+), 155 deletions(-)
 delete mode 100644 selftest/knownfail.d/samr

diff --git a/selftest/knownfail.d/samr b/selftest/knownfail.d/samr
deleted file mode 100644
index 310f44d1dd2..00000000000
--- a/selftest/knownfail.d/samr
+++ /dev/null
@@ -1,11 +0,0 @@
-samba.tests.dcerpc.sam.samba.tests.dcerpc.sam.SamrTests.test_QueryDisplayInfo_level_1\(ad_dc_ntvfs:local\)
-samba.tests.dcerpc.sam.samba.tests.dcerpc.sam.SamrTests.test_QueryDisplayInfo_level_2\(ad_dc_ntvfs:local\)
-samba.tests.dcerpc.sam.samba.tests.dcerpc.sam.SamrTests.test_QueryDisplayInfo_level_3\(ad_dc_ntvfs:local\)
-samba.tests.dcerpc.sam.samba.tests.dcerpc.sam.SamrTests.test_QueryDisplayInfo_level_4\(ad_dc_ntvfs:local\)
-samba.tests.dcerpc.sam.samba.tests.dcerpc.sam.SamrTests.test_QueryDisplayInfo_level_5\(ad_dc_ntvfs:local\)
-samba.tests.dcerpc.sam.python3.samba.tests.dcerpc.sam.SamrTests.test_QueryDisplayInfo_level_1\(ad_dc_ntvfs:local\)
-samba.tests.dcerpc.sam.python3.samba.tests.dcerpc.sam.SamrTests.test_QueryDisplayInfo_level_2\(ad_dc_ntvfs:local\)
-samba.tests.dcerpc.sam.python3.samba.tests.dcerpc.sam.SamrTests.test_QueryDisplayInfo_level_3\(ad_dc_ntvfs:local\)
-samba.tests.dcerpc.sam.python3.samba.tests.dcerpc.sam.SamrTests.test_QueryDisplayInfo_level_4\(ad_dc_ntvfs:local\)
-samba.tests.dcerpc.sam.python3.samba.tests.dcerpc.sam.SamrTests.test_QueryDisplayInfo_level_5\(ad_dc_ntvfs:local\)
-
diff --git a/source4/rpc_server/samr/dcesrv_samr.c b/source4/rpc_server/samr/dcesrv_samr.c
index 3df0c51dfee..3ae53d7cf87 100644
--- a/source4/rpc_server/samr/dcesrv_samr.c
+++ b/source4/rpc_server/samr/dcesrv_samr.c
@@ -145,7 +145,62 @@ static NTSTATUS dcesrv_interface_samr_bind(struct dcesrv_call_state *dce_call,
 	}								\
 } while (0)
 
+/*
+ * Clear a GUID cache
+ */
+static void clear_guid_cache(struct samr_guid_cache *cache)
+{
+	cache->handle = 0;
+	cache->size = 0;
+	TALLOC_FREE(cache->entries);
+}
 
+/*
+ * initialize a GUID cache
+ */
+static void initialize_guid_cache(struct samr_guid_cache *cache)
+{
+	cache->handle = 0;
+	cache->size = 0;
+	cache->entries = NULL;
+}
+
+static NTSTATUS load_guid_cache(
+	struct samr_guid_cache *cache,
+	struct samr_domain_state *d_state,
+	unsigned ldb_cnt,
+	struct ldb_message **res)
+{
+	NTSTATUS status = NT_STATUS_OK;
+	unsigned i;
+	TALLOC_CTX *frame = talloc_stackframe();
+
+	clear_guid_cache(cache);
+
+	/*
+	 * Size the cache for ldb_cnt GUIDs (array allocated on d_state).
+	 */
+	cache->handle = 0;
+	cache->size = ldb_cnt;
+	cache->entries = talloc_array(d_state, struct GUID, ldb_cnt);
+	if (cache->entries == NULL) {
+		clear_guid_cache(cache);
+		status = NT_STATUS_NO_MEMORY;
+		goto exit;
+	}
+
+	/*
+	 * Copy the objectGUID of each matching message into the cache;
+	 * only the GUIDs are cached, to reduce the memory overhead of
+	 * the result cache.
+	 */
+	for (i = 0; i < ldb_cnt; i++) {
+		cache->entries[i] = samdb_result_guid(res[i], "objectGUID");
+	}
+exit:
+	TALLOC_FREE(frame);
+	return status;
+}
 
 /*
   samr_Connect
@@ -384,6 +439,7 @@ static NTSTATUS dcesrv_samr_OpenDomain(struct dcesrv_call_state *dce_call, TALLO
 	const char * const dom_attrs[] = { "cn", NULL};
 	struct ldb_message **dom_msgs;
 	int ret;
+	unsigned i;
 
 	ZERO_STRUCTP(r->out.domain_handle);
 
@@ -435,6 +491,10 @@ static NTSTATUS dcesrv_samr_OpenDomain(struct dcesrv_call_state *dce_call, TALLO
 
 	d_state->lp_ctx = dce_call->conn->dce_ctx->lp_ctx;
 
+	for (i = 0; i < SAMR_LAST_CACHE; i++) {
+		initialize_guid_cache(&d_state->guid_caches[i]);
+	}
+
 	h_domain = dcesrv_handle_new(dce_call->context, SAMR_HANDLE_DOMAIN);
 	if (!h_domain) {
 		talloc_free(d_state);
@@ -3707,10 +3767,13 @@ static NTSTATUS dcesrv_samr_GetGroupsForUser(struct dcesrv_call_state *dce_call,
 	return NT_STATUS_OK;
 }
 
-
 /*
-  samr_QueryDisplayInfo
-*/
+ * samr_QueryDisplayInfo
+ *
+ * A cache of the GUID's matching the last query is maintained
+ * in the SAMR_QUERY_DISPLAY_INFO_CACHE entry of the guid_caches array
+ * held in the domain handle's state (d_state).
+ */
 static NTSTATUS dcesrv_samr_QueryDisplayInfo(struct dcesrv_call_state *dce_call, TALLOC_CTX *mem_ctx,
 		       struct samr_QueryDisplayInfo *r)
 {
@@ -3718,77 +3781,148 @@ static NTSTATUS dcesrv_samr_QueryDisplayInfo(struct dcesrv_call_state *dce_call,
 	struct samr_domain_state *d_state;
 	struct ldb_result *res;
 	unsigned int i;
-	uint32_t count;
-	const char * const attrs[] = { "objectSid", "sAMAccountName",
-		"displayName", "description", "userAccountControl",
-		"pwdLastSet", NULL };
+	unsigned results = 0;
+	unsigned count = 0;
+	const char *const guid_attr[] = {"objectGUID", NULL};
+	const char *const attrs[] = {
+	    "objectSID", "sAMAccountName", "displayName", "description", NULL};
 	struct samr_DispEntryFull *entriesFull = NULL;
 	struct samr_DispEntryFullGroup *entriesFullGroup = NULL;
 	struct samr_DispEntryAscii *entriesAscii = NULL;
 	struct samr_DispEntryGeneral *entriesGeneral = NULL;
 	const char *filter;
 	int ret;
+	NTSTATUS status;
+	struct samr_guid_cache *cache = NULL;
 
 	DCESRV_PULL_HANDLE(h, r->in.domain_handle, SAMR_HANDLE_DOMAIN);
 
 	d_state = h->data;
 
-	switch (r->in.level) {
-	case 1:
-	case 4:
-		filter = talloc_asprintf(mem_ctx, "(&(objectclass=user)"
-					 "(sAMAccountType=%d))",
-					 ATYPE_NORMAL_ACCOUNT);
-		break;
-	case 2:
-		filter = talloc_asprintf(mem_ctx, "(&(objectclass=user)"
-					 "(sAMAccountType=%d))",
-					 ATYPE_WORKSTATION_TRUST);
-		break;
-	case 3:
-	case 5:
-		filter = talloc_asprintf(mem_ctx,
-					 "(&(|(groupType=%d)(groupType=%d))"
-					 "(objectClass=group))",
-					 GTYPE_SECURITY_UNIVERSAL_GROUP,
-					 GTYPE_SECURITY_GLOBAL_GROUP);
-		break;
-	default:
-		return NT_STATUS_INVALID_INFO_CLASS;
-	}
+	cache = &d_state->guid_caches[SAMR_QUERY_DISPLAY_INFO_CACHE];
+	/*
+	 * Can the cached results be used?
+	 * The cache is discarded if the start index is zero, or the requested
+	 * level is different from that in the cache.
+	 */
+	if ((r->in.start_idx == 0) || (r->in.level != cache->handle)) {
+		/*
+		 * The cached results can not be used, so will need to query
+		 * the database.
+		 */
 
-	/* search for all requested objects in all domains. This could
-	   possibly be cached and resumed based on resume_key */
-	ret = dsdb_search(d_state->sam_ctx, mem_ctx, &res, ldb_get_default_basedn(d_state->sam_ctx),
-			  LDB_SCOPE_SUBTREE, attrs, 0, "%s", filter);
-	if (ret != LDB_SUCCESS) {
-		return NT_STATUS_INTERNAL_DB_CORRUPTION;
+		/*
+		 * Get the search filter for the current level
+		 */
+		switch (r->in.level) {
+		case 1:
+		case 4:
+			filter = talloc_asprintf(mem_ctx,
+						 "(&(objectclass=user)"
+						 "(sAMAccountType=%d))",
+						 ATYPE_NORMAL_ACCOUNT);
+			break;
+		case 2:
+			filter = talloc_asprintf(mem_ctx,
+						 "(&(objectclass=user)"
+						 "(sAMAccountType=%d))",
+						 ATYPE_WORKSTATION_TRUST);
+			break;
+		case 3:
+		case 5:
+			filter =
+			    talloc_asprintf(mem_ctx,
+					    "(&(|(groupType=%d)(groupType=%d))"
+					    "(objectClass=group))",
+					    GTYPE_SECURITY_UNIVERSAL_GROUP,
+					    GTYPE_SECURITY_GLOBAL_GROUP);
+			break;
+		default:
+			return NT_STATUS_INVALID_INFO_CLASS;
+		}
+		clear_guid_cache(cache);
+
+		/*
+		 * search for all requested objects in all domains.
+		 */
+		ret = dsdb_search(d_state->sam_ctx,
+				  mem_ctx,
+				  &res,
+				  ldb_get_default_basedn(d_state->sam_ctx),
+				  LDB_SCOPE_SUBTREE,
+				  guid_attr,
+				  0,
+				  "%s",
+				  filter);
+		if (ret != LDB_SUCCESS) {
+			return NT_STATUS_INTERNAL_DB_CORRUPTION;
+		}
+		if ((res->count == 0) || (r->in.max_entries == 0)) {
+			return NT_STATUS_OK;
+		}
+
+		status = load_guid_cache(cache, d_state, res->count, res->msgs);
+		TALLOC_FREE(res);
+		if (!NT_STATUS_IS_OK(status)) {
+			return status;
+		}
+		cache->handle = r->in.level;
 	}
-	if ((res->count == 0) || (r->in.max_entries == 0)) {
+	*r->out.total_size = cache->size;
+
+	/*
+	 * if there are no entries or the requested start index is greater
+	 * than the number of entries, we return an empty response.
+	 */
+	if (r->in.start_idx >= cache->size) {
+		*r->out.returned_size = 0;
+		switch(r->in.level) {
+		case 1:
+			r->out.info->info1.count = *r->out.returned_size;
+			r->out.info->info1.entries = NULL;
+			break;
+		case 2:
+			r->out.info->info2.count = *r->out.returned_size;
+			r->out.info->info2.entries = NULL;
+			break;
+		case 3:
+			r->out.info->info3.count = *r->out.returned_size;
+			r->out.info->info3.entries = NULL;
+			break;
+		case 4:
+			r->out.info->info4.count = *r->out.returned_size;
+			r->out.info->info4.entries = NULL;
+			break;
+		case 5:
+			r->out.info->info5.count = *r->out.returned_size;
+			r->out.info->info5.entries = NULL;
+			break;
+		}
 		return NT_STATUS_OK;
 	}
 
+	/*
+	 * Allocate an array of the appropriate result structures for the
+	 * current query level.
+	 */
+	results = MIN((cache->size - r->in.start_idx), r->in.max_entries);
 	switch (r->in.level) {
 	case 1:
-		entriesGeneral = talloc_array(mem_ctx,
-					      struct samr_DispEntryGeneral,
-					      res->count);
+		entriesGeneral = talloc_array(
+		    mem_ctx, struct samr_DispEntryGeneral, results);
 		break;
 	case 2:
-		entriesFull = talloc_array(mem_ctx,
-					   struct samr_DispEntryFull,
-					   res->count);
+		entriesFull =
+		    talloc_array(mem_ctx, struct samr_DispEntryFull, results);
 		break;
 	case 3:
-		entriesFullGroup = talloc_array(mem_ctx,
-						struct samr_DispEntryFullGroup,
-						res->count);
+		entriesFullGroup = talloc_array(
+		    mem_ctx, struct samr_DispEntryFullGroup, results);
 		break;
 	case 4:
 	case 5:
-		entriesAscii = talloc_array(mem_ctx,
-					    struct samr_DispEntryAscii,
-					    res->count);
+		entriesAscii =
+		    talloc_array(mem_ctx, struct samr_DispEntryAscii, results);
 		break;
 	}
 
@@ -3796,135 +3930,152 @@ static NTSTATUS dcesrv_samr_QueryDisplayInfo(struct dcesrv_call_state *dce_call,
 	    (entriesAscii == NULL) && (entriesFullGroup == NULL))
 		return NT_STATUS_NO_MEMORY;
 
+	/*
+	 * Process the list of result GUID's.
+	 * Read the details of each object and populate the result structure
+	 * for the current level.
+	 */
 	count = 0;
-
-	for (i = 0; i < res->count; i++) {
+	for (i = 0; i < results; i++) {
 		struct dom_sid *objectsid;
+		struct ldb_result *rec;
+		const unsigned idx = r->in.start_idx + i;
 
-		objectsid = samdb_result_dom_sid(mem_ctx, res->msgs[i],
-						 "objectSid");
-		if (objectsid == NULL)
+		/*
+		 * Read an object from disk using the GUID as the key
+		 *
+		 * If the object can not be read, or it does not have a SID
+		 * it is ignored.  In this case the number of entries returned
+		 * will be less than the requested size, there will also be
+		 * a gap in the idx numbers in the returned elements e.g. if
+		 * there are 3 GUIDs a, b, c in the cache and b is deleted from
+		 * disk then details for a, and c will be returned with
+		 * idx values of 1 and 3 respectively.
+		 *
+		 */
+		ret = dsdb_search_by_dn_guid(d_state->sam_ctx,
+					     mem_ctx,
+					     &rec,
+					     &cache->entries[idx],
+					     attrs,
+					     0);
+		if (ret == LDB_ERR_NO_SUCH_OBJECT) {
+			char *guid_str =
+			    GUID_string(mem_ctx, &cache->entries[idx]);
+			DBG_WARNING("GUID [%s] not found\n", guid_str);
+			continue;
+		} else if (ret != LDB_SUCCESS) {
+			clear_guid_cache(cache);
+			return NT_STATUS_INTERNAL_DB_CORRUPTION;
+		}
+		objectsid =
+		    samdb_result_dom_sid(mem_ctx, rec->msgs[0], "objectSID");
+		if (objectsid == NULL) {
+			char *guid_str =
+			    GUID_string(mem_ctx, &cache->entries[idx]);
+			DBG_WARNING("objectSID for GUID [%s] not found\n",
+				    guid_str);
 			continue;
+		}
 
+		/*
+		 * Populate the result structure for the current object
+		 */
 		switch(r->in.level) {
 		case 1:
-			entriesGeneral[count].idx = count + 1;
+
+			entriesGeneral[count].idx = idx + 1;
 			entriesGeneral[count].rid =
-				objectsid->sub_auths[objectsid->num_auths-1];
+			    objectsid->sub_auths[objectsid->num_auths - 1];
 			entriesGeneral[count].acct_flags =
-				samdb_result_acct_flags(res->msgs[i], NULL);
+			    samdb_result_acct_flags(rec->msgs[0], NULL);
 			entriesGeneral[count].account_name.string =
-				ldb_msg_find_attr_as_string(res->msgs[i],
-							    "sAMAccountName", "");
+			    ldb_msg_find_attr_as_string(
+				rec->msgs[0], "sAMAccountName", "");
 			entriesGeneral[count].full_name.string =
-				ldb_msg_find_attr_as_string(res->msgs[i],
-							    "displayName", "");
+			    ldb_msg_find_attr_as_string(
+				rec->msgs[0], "displayName", "");
 			entriesGeneral[count].description.string =
-				ldb_msg_find_attr_as_string(res->msgs[i],
-							    "description", "");
+			    ldb_msg_find_attr_as_string(
+				rec->msgs[0], "description", "");
 			break;
 		case 2:
-			entriesFull[count].idx = count + 1;
+			entriesFull[count].idx = idx + 1;
 			entriesFull[count].rid =
-				objectsid->sub_auths[objectsid->num_auths-1];
+			    objectsid->sub_auths[objectsid->num_auths - 1];
 
-			/* No idea why we need to or in ACB_NORMAL here, but this is what Win2k3 seems to do... */
+			/*
+			 * No idea why we need to or in ACB_NORMAL here,
+			 * but this is what Win2k3 seems to do...
+			 */
 			entriesFull[count].acct_flags =
-				samdb_result_acct_flags(res->msgs[i],
-							NULL) | ACB_NORMAL;
+			    samdb_result_acct_flags(rec->msgs[0], NULL) |
+			    ACB_NORMAL;
 			entriesFull[count].account_name.string =
-				ldb_msg_find_attr_as_string(res->msgs[i],
-							    "sAMAccountName", "");
+			    ldb_msg_find_attr_as_string(
+				rec->msgs[0], "sAMAccountName", "");
 			entriesFull[count].description.string =
-				ldb_msg_find_attr_as_string(res->msgs[i],
-							    "description", "");
+			    ldb_msg_find_attr_as_string(
+				rec->msgs[0], "description", "");
 			break;
 		case 3:
-			entriesFullGroup[count].idx = count + 1;
+			entriesFullGroup[count].idx = idx + 1;
 			entriesFullGroup[count].rid =
-				objectsid->sub_auths[objectsid->num_auths-1];
-			/* We get a "7" here for groups */
-			entriesFullGroup[count].acct_flags
-				= SE_GROUP_MANDATORY | SE_GROUP_ENABLED_BY_DEFAULT | SE_GROUP_ENABLED;
+			    objectsid->sub_auths[objectsid->num_auths - 1];
+			/*
+			 * We get a "7" here for groups
+			 */
+			entriesFullGroup[count].acct_flags =
+			    SE_GROUP_MANDATORY | SE_GROUP_ENABLED_BY_DEFAULT |
+			    SE_GROUP_ENABLED;
 			entriesFullGroup[count].account_name.string =
-				ldb_msg_find_attr_as_string(res->msgs[i],
-							    "sAMAccountName", "");
+			    ldb_msg_find_attr_as_string(
+				rec->msgs[0], "sAMAccountName", "");
 			entriesFullGroup[count].description.string =
-				ldb_msg_find_attr_as_string(res->msgs[i],
-							    "description", "");
+			    ldb_msg_find_attr_as_string(
+				rec->msgs[0], "description", "");
 			break;
 		case 4:
 		case 5:
-			entriesAscii[count].idx = count + 1;
+			entriesAscii[count].idx = idx + 1;
 			entriesAscii[count].account_name.string =
-				ldb_msg_find_attr_as_string(res->msgs[i],
-							    "sAMAccountName", "");
+			    ldb_msg_find_attr_as_string(
+				rec->msgs[0], "sAMAccountName", "");
 			break;
 		}
-
-		count += 1;
+		count++;
 	}
 
-	*r->out.total_size = count;
-
-	if (r->in.start_idx >= count) {
-		*r->out.returned_size = 0;
-		switch(r->in.level) {
-		case 1:
-			r->out.info->info1.count = *r->out.returned_size;
-			r->out.info->info1.entries = NULL;
-			break;
-		case 2:
-			r->out.info->info2.count = *r->out.returned_size;
-			r->out.info->info2.entries = NULL;
-			break;
-		case 3:
-			r->out.info->info3.count = *r->out.returned_size;
-			r->out.info->info3.entries = NULL;
-			break;
-		case 4:
-			r->out.info->info4.count = *r->out.returned_size;
-			r->out.info->info4.entries = NULL;
-			break;
-		case 5:
-			r->out.info->info5.count = *r->out.returned_size;
-			r->out.info->info5.entries = NULL;
-			break;
-		}
-	} else {
-		*r->out.returned_size = MIN(count - r->in.start_idx,
-					   r->in.max_entries);
-		switch(r->in.level) {
-		case 1:
-			r->out.info->info1.count = *r->out.returned_size;
-			r->out.info->info1.entries =
-				&(entriesGeneral[r->in.start_idx]);
-			break;
-		case 2:
-			r->out.info->info2.count = *r->out.returned_size;
-			r->out.info->info2.entries =
-				&(entriesFull[r->in.start_idx]);
-			break;
-		case 3:
-			r->out.info->info3.count = *r->out.returned_size;
-			r->out.info->info3.entries =
-				&(entriesFullGroup[r->in.start_idx]);
-			break;
-		case 4:
-			r->out.info->info4.count = *r->out.returned_size;
-			r->out.info->info4.entries =
-				&(entriesAscii[r->in.start_idx]);
-			break;
-		case 5:
-			r->out.info->info5.count = *r->out.returned_size;
-			r->out.info->info5.entries =
-				&(entriesAscii[r->in.start_idx]);
-			break;
-		}
+	/*
+	 * Build the response based on the request level.
+	 */
+	*r->out.returned_size = count;
+	switch(r->in.level) {
+	case 1:
+		r->out.info->info1.count = count;
+		r->out.info->info1.entries = entriesGeneral;
+		break;
+	case 2:
+		r->out.info->info2.count = count;
+		r->out.info->info2.entries = entriesFull;
+		break;
+	case 3:
+		r->out.info->info3.count = count;
+		r->out.info->info3.entries = entriesFullGroup;
+		break;
+	case 4:
+		r->out.info->info4.count = count;
+		r->out.info->info4.entries = entriesAscii;
+		break;
+	case 5:
+		r->out.info->info5.count = count;
+		r->out.info->info5.entries = entriesAscii;
+		break;
 	}
 
-	return (*r->out.returned_size < (count - r->in.start_idx)) ?
-		STATUS_MORE_ENTRIES : NT_STATUS_OK;
+	return ((r->in.start_idx + results) < cache->size)
+		   ? STATUS_MORE_ENTRIES
+		   : NT_STATUS_OK;
 }
 
 
diff --git a/source4/rpc_server/samr/dcesrv_samr.h b/source4/rpc_server/samr/dcesrv_samr.h
index 261bd052efe..f08bac053c8 100644
--- a/source4/rpc_server/samr/dcesrv_samr.h
+++ b/source4/rpc_server/samr/dcesrv_samr.h
@@ -42,6 +42,20 @@ struct samr_connect_state {
 	uint32_t access_mask;
 };
 
+/*
+ * Cached object GUIDs for one SAMR query (see enum samr_guid_cache_id)
+ */
+struct samr_guid_cache {
+	unsigned handle;
+	unsigned size;		/* number of GUIDs in entries */
+	struct GUID *entries;	/* cached GUIDs, indexed by result position */
+};
+
+enum samr_guid_cache_id {
+	SAMR_QUERY_DISPLAY_INFO_CACHE,
+	SAMR_LAST_CACHE
+};
+
 /*
   state associated with a samr_OpenDomain() operation
 */
@@ -55,6 +69,7 @@ struct samr_domain_state {
 	enum server_role role;
 	bool builtin;
 	struct loadparm_context *lp_ctx;
+	struct samr_guid_cache guid_caches[SAMR_LAST_CACHE];
 };
 
 /*
-- 
2.17.1


From 240775f002a0bc8da6081b1183fe631d02091448 Mon Sep 17 00:00:00 2001
From: Gary Lockyer <gary at catalyst.net.nz>
Date: Fri, 12 Oct 2018 11:21:10 +1300
Subject: [PATCH 4/7] test samr: Extra tests for samr_EnumDomainGroups

Add extra tests to test the content returned by samr_EnumDomainGroups,
and tests for the result caching added in the following commit.

Signed-off-by: Gary Lockyer <gary at catalyst.net.nz>
---
 python/samba/tests/dcerpc/sam.py | 169 +++++++++++++++++++++++++++++++
 selftest/knownfail.d/samr        |   3 +
 2 files changed, 172 insertions(+)
 create mode 100644 selftest/knownfail.d/samr

diff --git a/python/samba/tests/dcerpc/sam.py b/python/samba/tests/dcerpc/sam.py
index 67f62d4d2b5..e7fe21e9c30 100644
--- a/python/samba/tests/dcerpc/sam.py
+++ b/python/samba/tests/dcerpc/sam.py
@@ -32,6 +32,7 @@ from samba.dsdb import (
     GTYPE_SECURITY_UNIVERSAL_GROUP,
     GTYPE_SECURITY_GLOBAL_GROUP)
 from samba import generate_random_password
+from samba.ndr import ndr_unpack
 import os
 
 
@@ -40,6 +41,24 @@ def toArray(handle, array, num_entries):
     return [(entry.idx, entry.name) for entry in array.entries[:num_entries]]
 
 
+# Return the RID of an ldb message's objectSID; the message is assumed
+# to carry an objectSID attribute.
+#
+def rid(msg):
+    packed_sid = msg["objectSID"][0]
+    _, msg_rid = ndr_unpack(security.dom_sid, packed_sid).split()
+    return msg_rid
+
+
+# Return the max_size for an EnumDomainUsers or EnumDomainGroups request
+# that will return num_entries entries.  The server returns at most
+# 1 + (max_size / SAMR_ENUM_USERS_MULTIPLIER) entries, where the
+# multiplier is the w2k3 element size of 54.
+#
+def calc_max_size(num_entries):
+    return 54 * (num_entries - 1)
+
+
 class SamrTests(RpcInterfaceTestCase):
 
     def setUp(self):
@@ -72,6 +91,18 @@ class SamrTests(RpcInterfaceTestCase):
         self.domain_handle = self.conn.OpenDomain(
             self.handle, security.SEC_FLAG_MAXIMUM_ALLOWED, self.domain_sid)
 
+    # Filter a list of records, keeping only those whose objectSID is in
+    # the current domain.
+    #
+    def filter_domain(self, unfiltered):
+        def domain_part(msg):
+            sid = ndr_unpack(security.dom_sid, msg["objectSID"][0])
+            return sid.split()[0]
+
+        dom_sid = security.dom_sid(self.samdb.get_domain_sid())
+        return [msg for msg in unfiltered
+                if domain_part(msg) == dom_sid]
+
     def test_connect5(self):
         (level, info, handle) =\
             self.conn.Connect5(None, 0, 1, samr.ConnectInfo1())
@@ -433,3 +464,141 @@ class SamrTests(RpcInterfaceTestCase):
             5, check_results, select, attributes, self.create_groups)
 
         self.delete_dns(dns)
+
+    def test_EnumDomainUsers(self):
+        def check_results(expected, actual):
+            for (e, a) in zip(expected, actual):
+                self.assertTrue(isinstance(a, samr.SamEntry))
+                self.assertEquals(
+                    str(e["sAMAccountName"]), str(a.name.string))
+
+        # Create four groups
+        # to ensure that we have the minimum needed for the tests.
+        dns = self.create_groups([1, 2, 3, 4])
+
+        #
+        # Get the expected results by querying the samdb database directly.
+        # We do this rather than use a list of expected results as this runs
+        # with other tests so we do not have a known fixed list of elements
+        select = "(&(|(groupType=%d)(groupType=%d))(objectClass=group))" % (
+            GTYPE_SECURITY_UNIVERSAL_GROUP,
+            GTYPE_SECURITY_GLOBAL_GROUP)
+        attributes = ["sAMAccountName", "objectSID"]
+        unfiltered = self.samdb.search(expression=select, attrs=attributes)
+        filtered = self.filter_domain(unfiltered)
+        self.assertTrue(len(filtered) > 4)
+
+        # Sort the expected results by rid
+        expected = sorted(list(filtered), key=rid)
+
+        #
+        # Perform EnumDomainGroups with max size greater than the expected
+        # number of results. Allow for an extra 10 entries
+        #
+        max_size = calc_max_size(len(expected) + 10)
+        (resume_handle, actual, num_entries) = self.conn.EnumDomainGroups(
+            self.domain_handle, 0, max_size)
+        self.assertEquals(len(expected), num_entries)
+        check_results(expected, actual.entries)
+
+        #
+        # Perform EnumDomainGroups with max_size set so that it contains
+        # 4 entries.
+        #
+        max_size = calc_max_size(4)
+        (resume_handle, actual, num_entries) = self.conn.EnumDomainGroups(
+            self.domain_handle, 0, max_size)
+        self.assertEquals(4, num_entries)
+        check_results(expected[:4], actual.entries)
+
+        #
+        # Try calling with a resume_handle past the last entry.
+        # Should return no results and a resume handle of 0
+        max_size = calc_max_size(1)
+        rh = len(expected)
+        self.conn.Close(self.handle)
+        (resume_handle, a, num_entries) = self.conn.EnumDomainGroups(
+            self.domain_handle, rh, max_size)
+
+        self.assertEquals(0, num_entries)
+        self.assertEquals(0, resume_handle)
+
+        #
+        # Enumerate through the domain groups one element at a time.
+        #
+        max_size = calc_max_size(1)
+        actual = []
+        (resume_handle, a, num_entries) = self.conn.EnumDomainGroups(
+            self.domain_handle, 0, max_size)
+        while resume_handle:
+            self.assertEquals(1, num_entries)
+            actual.append(a.entries[0])
+            (resume_handle, a, num_entries) = self.conn.EnumDomainGroups(
+                self.domain_handle, resume_handle, max_size)
+        if num_entries:
+            actual.append(a.entries[0])
+        self.assertEquals(len(expected), len(actual))
+        check_results(expected, actual)
+
+        #
+        # Obtain a resume_handle, then add DB entries: the cache must be used
+        #
+        actual = []
+        max_size = calc_max_size(1)
+        (resume_handle, a, num_entries) = self.conn.EnumDomainGroups(
+            self.domain_handle, 0, max_size)
+        extra_dns = self.create_groups([1000, 1002, 1003, 1004])
+        while resume_handle:
+            self.assertEquals(1, num_entries)
+            actual.append(a.entries[0])
+            (resume_handle, a, num_entries) = self.conn.EnumDomainGroups(
+                self.domain_handle, resume_handle, max_size)
+        if num_entries:
+            actual.append(a.entries[0])
+
+        self.assertEquals(len(expected), len(actual))
+        check_results(expected, actual)
+
+        #
+        # Perform EnumDomainGroups, we should read the newly added groups
+        #
+        max_size = calc_max_size(len(expected) + len(extra_dns) + 10)
+        (resume_handle, actual, num_entries) = self.conn.EnumDomainGroups(
+            self.domain_handle, 0, max_size)
+        self.assertEquals(len(expected) + len(extra_dns), num_entries)
+
+        #
+        # Get a new expected result set by querying the database directly
+        unfiltered01 = self.samdb.search(expression=select, attrs=attributes)
+        filtered01 = self.filter_domain(unfiltered01)
+        self.assertTrue(len(filtered01) > len(expected))
+
+        # Sort the expected results by rid
+        expected01 = sorted(list(filtered01), key=rid)
+
+        #
+        # Now check that we read the new entries.
+        #
+        check_results(expected01, actual.entries)
+
+        #
+        # Check that deleted results are handled correctly.
+        # Obtain a new resume_handle and delete entries from the DB.
+        #
+        actual = []
+        max_size = calc_max_size(1)
+        (resume_handle, a, num_entries) = self.conn.EnumDomainGroups(
+            self.domain_handle, 0, max_size)
+        self.delete_dns(extra_dns)
+        while resume_handle and num_entries:
+            self.assertEquals(1, num_entries)
+            actual.append(a.entries[0])
+            (resume_handle, a, num_entries) = self.conn.EnumDomainGroups(
+                self.domain_handle, resume_handle, max_size)
+        if num_entries:
+            actual.append(a.entries[0])
+
+        self.assertEquals(len(expected), len(actual))
+        check_results(expected, actual)
+
+        self.delete_dns(dns)
diff --git a/selftest/knownfail.d/samr b/selftest/knownfail.d/samr
new file mode 100644
index 00000000000..4c1d62aac16
--- /dev/null
+++ b/selftest/knownfail.d/samr
@@ -0,0 +1,3 @@
+^samba.tests.dcerpc.sam.samba.tests.dcerpc.sam.SamrTests.test_EnumDomainUsers\(ad_dc_ntvfs:local\)
+^samba.tests.dcerpc.sam.python3.samba.tests.dcerpc.sam.SamrTests.test_EnumDomainUsers\(ad_dc_ntvfs:local\)
+
-- 
2.17.1


From 3364400ee4b17f1b7fefe1d537da4f065b7919ae Mon Sep 17 00:00:00 2001
From: Gary Lockyer <gary at catalyst.net.nz>
Date: Thu, 18 Oct 2018 10:16:24 +1300
Subject: [PATCH 5/7] source4 samr: cache samr_EnumDomainGroups results

Add a cache of GUIDs that matched the last samr_EnumDomainGroups call made on a
domain handle.  The cache is cleared if resume_handle is zero, and when the
final results are returned to the caller.

Signed-off-by: Gary Lockyer <gary at catalyst.net.nz>
---
 selftest/knownfail.d/samr             |   3 -
 source4/rpc_server/samr/dcesrv_samr.c | 265 ++++++++++++++++++++------
 source4/rpc_server/samr/dcesrv_samr.h |   1 +
 3 files changed, 213 insertions(+), 56 deletions(-)
 delete mode 100644 selftest/knownfail.d/samr

diff --git a/selftest/knownfail.d/samr b/selftest/knownfail.d/samr
deleted file mode 100644
index 4c1d62aac16..00000000000
--- a/selftest/knownfail.d/samr
+++ /dev/null
@@ -1,3 +0,0 @@
-^samba.tests.dcerpc.sam.samba.tests.dcerpc.sam.SamrTests.test_EnumDomainUsers\(ad_dc_ntvfs:local\)
-^samba.tests.dcerpc.sam.python3.samba.tests.dcerpc.sam.SamrTests.test_EnumDomainUsers\(ad_dc_ntvfs:local\)
-
diff --git a/source4/rpc_server/samr/dcesrv_samr.c b/source4/rpc_server/samr/dcesrv_samr.c
index 3ae53d7cf87..bfc900535a4 100644
--- a/source4/rpc_server/samr/dcesrv_samr.c
+++ b/source4/rpc_server/samr/dcesrv_samr.c
@@ -1125,6 +1125,63 @@ static int compare_SamEntry(struct samr_SamEntry *e1, struct samr_SamEntry *e2)
 	return e1->idx - e2->idx;
 }
 
+/*
+ * Sort comparator ordering ldb messages by the RID of their objectSid.
+ *
+ * Messages with no objectSid, or whose RID can not be extracted, sort
+ * after every message that has a RID.  Two such messages compare equal,
+ * which keeps the ordering consistent for qsort.
+ */
+static int compare_msgRid(struct ldb_message **m1, struct ldb_message **m2)
+{
+	struct dom_sid *sid1 = NULL;
+	struct dom_sid *sid2 = NULL;
+	uint32_t rid1 = 0;
+	uint32_t rid2 = 0;
+	bool have_rid1 = false;
+	bool have_rid2 = false;
+	int res = 0;
+	NTSTATUS status;
+	TALLOC_CTX *frame = talloc_stackframe();
+
+	/*
+	 * Extract both RIDs up front, so that every combination of
+	 * present/missing RIDs is decided in one place below.
+	 */
+	sid1 = samdb_result_dom_sid(frame, *m1, "objectSid");
+	if (sid1 != NULL) {
+		status = dom_sid_split_rid(NULL, sid1, NULL, &rid1);
+		have_rid1 = NT_STATUS_IS_OK(status);
+	}
+
+	sid2 = samdb_result_dom_sid(frame, *m2, "objectSid");
+	if (sid2 != NULL) {
+		status = dom_sid_split_rid(NULL, sid2, NULL, &rid2);
+		have_rid2 = NT_STATUS_IS_OK(status);
+	}
+
+	/*
+	 * A positive result sorts m1 after m2, so an entry without a RID
+	 * must compare greater than any entry that has one.
+	 */
+	if (!have_rid1 && !have_rid2) {
+		res = 0;
+	} else if (!have_rid1) {
+		res = 1;
+	} else if (!have_rid2) {
+		res = -1;
+	} else if (rid1 == rid2) {
+		res = 0;
+	} else if (rid1 > rid2) {
+		res = 1;
+	} else {
+		res = -1;
+	}
+
+	TALLOC_FREE(frame);
+	return res;
+}
+
 /*
   samr_EnumDomainGroups
 */
@@ -1134,11 +1191,17 @@ static NTSTATUS dcesrv_samr_EnumDomainGroups(struct dcesrv_call_state *dce_call,
 	struct dcesrv_handle *h;
 	struct samr_domain_state *d_state;
 	struct ldb_message **res;
-	int i, ldb_cnt;
-	uint32_t first, count;
+	int i;
+	unsigned count;
+	unsigned results;
+	unsigned max_entries;
+	unsigned remaining_entries;
+	uint32_t resume_handle;
 	struct samr_SamEntry *entries;
 	const char * const attrs[] = { "objectSid", "sAMAccountName", NULL };
+	const char * const cache_attrs[] = { "objectSid", "objectGUID", NULL };
 	struct samr_SamArray *sam;
+	struct samr_guid_cache *cache = NULL;
 
 	*r->out.resume_handle = 0;
 	*r->out.sam = NULL;
@@ -1147,77 +1210,173 @@ static NTSTATUS dcesrv_samr_EnumDomainGroups(struct dcesrv_call_state *dce_call,
 	DCESRV_PULL_HANDLE(h, r->in.domain_handle, SAMR_HANDLE_DOMAIN);
 
 	d_state = h->data;
+	cache = &d_state->guid_caches[SAMR_ENUM_DOMAIN_GROUPS_CACHE];
 
-	/* search for all domain groups in this domain. This could possibly be
-	   cached and resumed based on resume_key */
-	ldb_cnt = samdb_search_domain(d_state->sam_ctx, mem_ctx,
-				      d_state->domain_dn, &res, attrs,
-				      d_state->domain_sid,
-				      "(&(|(groupType=%d)(groupType=%d))(objectClass=group))",
-				      GTYPE_SECURITY_UNIVERSAL_GROUP,
-				      GTYPE_SECURITY_GLOBAL_GROUP);
-	if (ldb_cnt < 0) {
-		return NT_STATUS_INTERNAL_DB_CORRUPTION;
-	}
+	/*
+	 * If the resume_handle is zero, query the database and cache the
+	 * matching GUID's
+	 */
+	if (*r->in.resume_handle == 0) {
+		NTSTATUS status;
+		int ldb_cnt;
+		clear_guid_cache(cache);
+		/*
+		 * search for all domain groups in this domain.
+		 */
+		ldb_cnt = samdb_search_domain(
+		    d_state->sam_ctx,
+		    mem_ctx,
+		    d_state->domain_dn,
+		    &res,
+		    cache_attrs,
+		    d_state->domain_sid,
+		    "(&(|(groupType=%d)(groupType=%d))(objectClass=group))",
+		    GTYPE_SECURITY_UNIVERSAL_GROUP,
+		    GTYPE_SECURITY_GLOBAL_GROUP);
+		if (ldb_cnt < 0) {
+			return NT_STATUS_INTERNAL_DB_CORRUPTION;
+		}
+		/*
+		 * Sort the results into RID order, while the spec states there
+		 * is no order, Windows appears to sort the results by RID and
+		 * so it is possible that there are clients that depend on
+		 * this ordering
+		 */
+		TYPESAFE_QSORT(res, ldb_cnt, compare_msgRid);
 
-	/* convert to SamEntry format */
-	entries = talloc_array(mem_ctx, struct samr_SamEntry, ldb_cnt);
-	if (!entries) {
-		return NT_STATUS_NO_MEMORY;
+		/*
+		 * cache the sorted GUID's
+		 */
+		status = load_guid_cache(cache, d_state, ldb_cnt, res);
+		TALLOC_FREE(res);
+		if (!NT_STATUS_IS_OK(status)) {
+			return status;
+		}
+		cache->handle = 0;
 	}
 
-	count = 0;
-
-	for (i=0;i<ldb_cnt;i++) {
-		struct dom_sid *group_sid;
 
-		group_sid = samdb_result_dom_sid(mem_ctx, res[i],
-						 "objectSid");
-		if (group_sid == NULL) {
-			return NT_STATUS_INTERNAL_DB_CORRUPTION;
+	/*
+	 * If the resume handle is out of range we return an empty response
+	 * and invalidate the cache.
+	 *
+	 * From the specification:
+	 * Servers SHOULD validate that EnumerationContext is an expected
+	 * value for the server's implementation. Windows does NOT validate
+	 * the input, though the result of malformed information merely results
+	 * in inconsistent output to the client.
+	 */
+	if (*r->in.resume_handle >= cache->size) {
+		clear_guid_cache(cache);
+		sam = talloc(mem_ctx, struct samr_SamArray);
+		if (!sam) {
+			return NT_STATUS_NO_MEMORY;
 		}
+		sam->entries = NULL;
+		sam->count = 0;
 
-		entries[count].idx =
-			group_sid->sub_auths[group_sid->num_auths-1];
-		entries[count].name.string =
-			ldb_msg_find_attr_as_string(res[i], "sAMAccountName", "");
-		count += 1;
+		*r->out.sam = sam;
+		*r->out.resume_handle = 0;
+		return NT_STATUS_OK;
 	}
 
-	/* sort the results by rid */
-	TYPESAFE_QSORT(entries, count, compare_SamEntry);
 
-	/* find the first entry to return */
-	for (first=0;
-	     first<count && entries[first].idx <= *r->in.resume_handle;
-	     first++) ;
+	/*
+	 * Calculate the number of entries to return limit by max_size.
+	 * Note that we use the w2k3 element size value of 54
+	 */
+	max_entries = 1 + (r->in.max_size/SAMR_ENUM_USERS_MULTIPLIER);
+	remaining_entries = cache->size - *r->in.resume_handle;
+	results = MIN(remaining_entries, max_entries);
 
-	/* return the rest, limit by max_size. Note that we
-	   use the w2k3 element size value of 54 */
-	*r->out.num_entries = count - first;
-	*r->out.num_entries = MIN(*r->out.num_entries,
-				 1+(r->in.max_size/SAMR_ENUM_USERS_MULTIPLIER));
+	/*
+	 * Process the list of result GUID's.
+	 * Read the details of each object and populate the Entries
+	 * for the current level.
+	 */
+	count = 0;
+	resume_handle = *r->in.resume_handle;
+	entries = talloc_array(mem_ctx, struct samr_SamEntry, results);
+	if (entries == NULL) {
+		clear_guid_cache(cache);
+		return NT_STATUS_NO_MEMORY;
+	}
+	for (i = 0; i < results; i++) {
+		struct dom_sid *sid;
+		struct ldb_result *rec;
+		const unsigned idx = *r->in.resume_handle + i;
+		int ret;
+		const char *name = NULL;
+
+		resume_handle++;
+		/*
+		 * Read an object from disk using the GUID as the key
+		 *
+		 * If the object can not be read, or it does not have a SID
+		 * it is ignored.
+		 *
+		 * As a consequence of this, if all the remaining GUID's
+		 * have been deleted an empty result will be returned.
+		 * i.e. even if the previous call returned a non zero
+		 * resume_handle it is possible for no results to be returned.
+		 *
+		 */
+		ret = dsdb_search_by_dn_guid(d_state->sam_ctx,
+					     mem_ctx,
+					     &rec,
+					     &cache->entries[idx],
+					     attrs,
+					     0);
+		if (ret == LDB_ERR_NO_SUCH_OBJECT) {
+			char *guid_str =
+			    GUID_string(mem_ctx, &cache->entries[idx]);
+			DBG_WARNING("GUID [%s] not found\n", guid_str);
+			continue;
+		} else if (ret != LDB_SUCCESS) {
+			clear_guid_cache(cache);
+			return NT_STATUS_INTERNAL_DB_CORRUPTION;
+		}
+		sid = samdb_result_dom_sid(mem_ctx, rec->msgs[0], "objectSID");
+		if (sid == NULL) {
+			char *guid_str =
+			    GUID_string(mem_ctx, &cache->entries[idx]);
+			DBG_WARNING("objectSID for GUID [%s] not found\n",
+				    guid_str);
+			continue;
+		}
+		entries[count].idx = sid->sub_auths[sid->num_auths - 1];
+		name = ldb_msg_find_attr_as_string(
+		    rec->msgs[0], "sAMAccountName", "");
+		entries[count].name.string = talloc_strdup(entries, name);
+		count++;
+	}
 
 	sam = talloc(mem_ctx, struct samr_SamArray);
 	if (!sam) {
+		clear_guid_cache(cache);
 		return NT_STATUS_NO_MEMORY;
 	}
 
-	sam->entries = entries+first;
-	sam->count = *r->out.num_entries;
+	sam->entries = entries;
+	sam->count = count;
 
 	*r->out.sam = sam;
+	*r->out.resume_handle = resume_handle;
+	*r->out.num_entries = count;
 
-	if (first == count) {
+	/*
+	 * Signal no more results by returning zero resume handle,
+	 * the cache is also cleared at this point
+	 */
+	if (*r->out.resume_handle >= cache->size) {
+		*r->out.resume_handle = 0;
+		clear_guid_cache(cache);
 		return NT_STATUS_OK;
 	}
-
-	if (*r->out.num_entries < count - first) {
-		*r->out.resume_handle = entries[first+*r->out.num_entries-1].idx;
-		return STATUS_MORE_ENTRIES;
-	}
-
-	return NT_STATUS_OK;
+	/*
+	 * There are more results to be returned.
+	 */
+	return STATUS_MORE_ENTRIES;
 }
 
 
@@ -3783,7 +3942,7 @@ static NTSTATUS dcesrv_samr_QueryDisplayInfo(struct dcesrv_call_state *dce_call,
 	unsigned int i;
 	unsigned results = 0;
 	unsigned count = 0;
-	const char *const guid_attr[] = {"objectGUID", NULL};
+	const char *const cache_attrs[] = {"objectGUID", NULL};
 	const char *const attrs[] = {
 	    "objectSID", "sAMAccountName", "displayName", "description", NULL};
 	struct samr_DispEntryFull *entriesFull = NULL;
@@ -3850,7 +4009,7 @@ static NTSTATUS dcesrv_samr_QueryDisplayInfo(struct dcesrv_call_state *dce_call,
 				  &res,
 				  ldb_get_default_basedn(d_state->sam_ctx),
 				  LDB_SCOPE_SUBTREE,
-				  guid_attr,
+				  cache_attrs,
 				  0,
 				  "%s",
 				  filter);
diff --git a/source4/rpc_server/samr/dcesrv_samr.h b/source4/rpc_server/samr/dcesrv_samr.h
index f08bac053c8..04866aca757 100644
--- a/source4/rpc_server/samr/dcesrv_samr.h
+++ b/source4/rpc_server/samr/dcesrv_samr.h
@@ -53,6 +53,7 @@ struct samr_guid_cache {
 
 enum samr_guid_cache_id {
 	SAMR_QUERY_DISPLAY_INFO_CACHE,
+	SAMR_ENUM_DOMAIN_GROUPS_CACHE,
 	SAMR_LAST_CACHE
 };
 
-- 
2.17.1


From d3ba4b49fc8e6764b2cc7b813d9a8b36a9e3b5ef Mon Sep 17 00:00:00 2001
From: Gary Lockyer <gary at catalyst.net.nz>
Date: Thu, 18 Oct 2018 13:53:55 +1300
Subject: [PATCH 6/7] tests samr: Extra tests for samr_EnumDomainUsers

Add extra tests to test the content returned by samr_EnumDomainUsers,
and tests for the result caching added in the following commit.

Signed-off-by: Gary Lockyer <gary at catalyst.net.nz>
---
 python/samba/tests/dcerpc/sam.py | 146 ++++++++++++++++++++++++++++++-
 selftest/knownfail.d/samr        |   4 +
 2 files changed, 149 insertions(+), 1 deletion(-)
 create mode 100644 selftest/knownfail.d/samr

diff --git a/python/samba/tests/dcerpc/sam.py b/python/samba/tests/dcerpc/sam.py
index e7fe21e9c30..401cb7f07a5 100644
--- a/python/samba/tests/dcerpc/sam.py
+++ b/python/samba/tests/dcerpc/sam.py
@@ -465,7 +465,7 @@ class SamrTests(RpcInterfaceTestCase):
 
         self.delete_dns(dns)
 
-    def test_EnumDomainUsers(self):
+    def test_EnumDomainGroups(self):
         def check_results(expected, actual):
             for (e, a) in zip(expected, actual):
                 self.assertTrue(isinstance(a, samr.SamEntry))
@@ -602,3 +602,147 @@ class SamrTests(RpcInterfaceTestCase):
         check_results(expected, actual)
 
         self.delete_dns(dns)
+
+    def test_EnumDomainUsers(self):
+        def check_results(expected, actual):
+            for (e, a) in zip(expected, actual):
+                self.assertTrue(isinstance(a, samr.SamEntry))
+                self.assertEquals(
+                    str(e["sAMAccountName"]), str(a.name.string))
+
+        # Create four users
+        # to ensure that we have the minimum needed for the tests.
+        dns = self.create_users([1, 2, 3, 4])
+
+        #
+        # Get the expected results by querying the samdb database directly.
+        # We do this rather than use a list of expected results as this runs
+        # with other tests, so we do not have a known fixed list of elements.
+        select = "(objectClass=user)"
+        attributes = ["sAMAccountName", "objectSID", "userAccountControl"]
+        unfiltered = self.samdb.search(expression=select, attrs=attributes)
+        filtered = self.filter_domain(unfiltered)
+        self.assertTrue(len(filtered) > 4)
+
+        # Sort the expected results by rid
+        expected = sorted(list(filtered), key=rid)
+
+        #
+        # Perform EnumDomainUsers with max_size greater than required for the
+        # expected number of results. We should get all the results.
+        #
+        max_size = calc_max_size(len(expected) + 10)
+        (resume_handle, actual, num_entries) = self.conn.EnumDomainUsers(
+            self.domain_handle, 0, 0, max_size)
+        self.assertEquals(len(expected), num_entries)
+        check_results(expected, actual.entries)
+
+        #
+        # Perform EnumDomainUsers with max_size set so that it returns
+        # exactly 4 entries.
+        max_size = calc_max_size(4)
+        (resume_handle, actual, num_entries) = self.conn.EnumDomainUsers(
+            self.domain_handle, 0, 0, max_size)
+        self.assertEquals(4, num_entries)
+        check_results(expected[:4], actual.entries)
+
+        #
+        # Try calling with resume_handle >= the number of entries.
+        # Should return no results and a resume handle of 0
+        rh = len(expected)
+        max_size = calc_max_size(1)
+        self.conn.Close(self.handle)
+        (resume_handle, a, num_entries) = self.conn.EnumDomainUsers(
+            self.domain_handle, rh, 0, max_size)
+
+        self.assertEquals(0, num_entries)
+        self.assertEquals(0, resume_handle)
+
+        #
+        # Enumerate through the domain users one element at a time.
+        # We should get all the results.
+        #
+        actual = []
+        max_size = calc_max_size(1)
+        (resume_handle, a, num_entries) = self.conn.EnumDomainUsers(
+            self.domain_handle, 0, 0, max_size)
+        while resume_handle:
+            self.assertEquals(1, num_entries)
+            actual.append(a.entries[0])
+            (resume_handle, a, num_entries) = self.conn.EnumDomainUsers(
+                self.domain_handle, resume_handle, 0, max_size)
+        if num_entries:
+            actual.append(a.entries[0])
+
+        self.assertEquals(len(expected), len(actual))
+        check_results(expected, actual)
+
+        #
+        # Check that the cached results are being returned.
+        # Obtain a new resume_handle and then insert new entries
+        # into the DB. As the entries were added after the results were cached
+        # they should not show up in the returned results.
+        #
+        actual = []
+        max_size = calc_max_size(1)
+        (resume_handle, a, num_entries) = self.conn.EnumDomainUsers(
+            self.domain_handle, 0, 0, max_size)
+        extra_dns = self.create_users([1000, 1002, 1003, 1004])
+        while resume_handle:
+            self.assertEquals(1, num_entries)
+            actual.append(a.entries[0])
+            (resume_handle, a, num_entries) = self.conn.EnumDomainUsers(
+                self.domain_handle, resume_handle, 0, max_size)
+        if num_entries:
+            actual.append(a.entries[0])
+
+        self.assertEquals(len(expected), len(actual))
+        check_results(expected, actual)
+
+        #
+        # Perform EnumDomainUsers, we should read the newly added users.
+        # As resume_handle is zero, the results will be read from disk.
+        #
+        max_size = calc_max_size(len(expected) + len(extra_dns) + 10)
+        (resume_handle, actual, num_entries) = self.conn.EnumDomainUsers(
+            self.domain_handle, 0, 0, max_size)
+        self.assertEquals(len(expected) + len(extra_dns), num_entries)
+
+        #
+        # Get a new expected result set by querying the database directly
+        unfiltered01 = self.samdb.search(expression=select, attrs=attributes)
+        filtered01 = self.filter_domain(unfiltered01)
+        self.assertTrue(len(filtered01) > len(expected))
+
+        # Sort the expected results by rid
+        expected01 = sorted(list(filtered01), key=rid)
+
+        #
+        # Now check that we read the new entries.
+        #
+        self.assertEquals(len(expected01), num_entries)
+        check_results(expected01, actual.entries)
+
+        #
+        # Check that deleted results are handled correctly.
+        # Obtain a new resume_handle and delete entries from the DB.
+        # We will not see the deleted entries in the result set, as details
+        # need to be read from disk. Only the object GUIDs are cached.
+        #
+        actual = []
+        max_size = calc_max_size(1)
+        (resume_handle, a, num_entries) = self.conn.EnumDomainUsers(
+            self.domain_handle, 0, 0, max_size)
+        self.delete_dns(extra_dns)
+        while resume_handle and num_entries:
+            self.assertEquals(1, num_entries)
+            actual.append(a.entries[0])
+            (resume_handle, a, num_entries) = self.conn.EnumDomainUsers(
+                self.domain_handle, resume_handle, 0, max_size)
+        if num_entries:
+            actual.append(a.entries[0])
+
+        self.assertEquals(len(expected), len(actual))
+        check_results(expected, actual)
+
+        self.delete_dns(dns)
diff --git a/selftest/knownfail.d/samr b/selftest/knownfail.d/samr
new file mode 100644
index 00000000000..e4a5a8bc7eb
--- /dev/null
+++ b/selftest/knownfail.d/samr
@@ -0,0 +1,4 @@
+^samba.tests.dcerpc.sam.samba.tests.dcerpc.sam.SamrTests.test_EnumDomainUsers
+^samba.tests.dcerpc.sam.python3.samba.tests.dcerpc.sam.SamrTests.test_EnumDomainUsers
+
+
-- 
2.17.1


From 069171b8569b28ed1f83e63c98d3d4f5c8f7920d Mon Sep 17 00:00:00 2001
From: Gary Lockyer <gary at catalyst.net.nz>
Date: Thu, 18 Oct 2018 13:54:31 +1300
Subject: [PATCH 7/7] source4 samr: cache samr_EnumDomainUsers results

Add a cache of the GUIDs that matched the last samr_EnumDomainUsers call made on a
domain handle.  The cache is cleared if resume_handle is zero, and when the
final results are returned to the caller.

The existing code repeated the database query for each chunk requested.

Signed-off-by: Gary Lockyer <gary at catalyst.net.nz>
---
 selftest/knownfail.d/samr             |   4 -
 source4/rpc_server/samr/dcesrv_samr.c | 203 ++++++++++++++++++++------
 source4/rpc_server/samr/dcesrv_samr.h |   1 +
 3 files changed, 156 insertions(+), 52 deletions(-)
 delete mode 100644 selftest/knownfail.d/samr

diff --git a/selftest/knownfail.d/samr b/selftest/knownfail.d/samr
deleted file mode 100644
index e4a5a8bc7eb..00000000000
--- a/selftest/knownfail.d/samr
+++ /dev/null
@@ -1,4 +0,0 @@
-^samba.tests.dcerpc.sam.samba.tests.dcerpc.sam.SamrTests.test_EnumDomainUsers
-^samba.tests.dcerpc.sam.python3.samba.tests.dcerpc.sam.SamrTests.test_EnumDomainUsers
-
-
diff --git a/source4/rpc_server/samr/dcesrv_samr.c b/source4/rpc_server/samr/dcesrv_samr.c
index bfc900535a4..6cb3c64518a 100644
--- a/source4/rpc_server/samr/dcesrv_samr.c
+++ b/source4/rpc_server/samr/dcesrv_samr.c
@@ -1490,12 +1490,18 @@ static NTSTATUS dcesrv_samr_EnumDomainUsers(struct dcesrv_call_state *dce_call,
 	struct dcesrv_handle *h;
 	struct samr_domain_state *d_state;
 	struct ldb_message **res;
-	int i, ldb_cnt;
-	uint32_t first, count;
+	unsigned i; /* unsigned: compared against "results", avoids sign mix-up */
+	unsigned count;
+	unsigned results;
+	unsigned max_entries;
+	unsigned remaining_entries;
+	uint32_t resume_handle;
 	struct samr_SamEntry *entries;
 	const char * const attrs[] = { "objectSid", "sAMAccountName",
 		"userAccountControl", NULL };
+	const char *const cache_attrs[] = {"objectSid", "objectGUID", NULL};
 	struct samr_SamArray *sam;
+	struct samr_guid_cache *cache = NULL;
 
 	*r->out.resume_handle = 0;
 	*r->out.sam = NULL;
@@ -1504,73 +1510,174 @@ static NTSTATUS dcesrv_samr_EnumDomainUsers(struct dcesrv_call_state *dce_call,
 	DCESRV_PULL_HANDLE(h, r->in.domain_handle, SAMR_HANDLE_DOMAIN);
 
 	d_state = h->data;
+	cache = &d_state->guid_caches[SAMR_ENUM_DOMAIN_USERS_CACHE];
 
-	/* search for all domain users in this domain. This could possibly be
-	   cached and resumed on resume_key */
-	ldb_cnt = samdb_search_domain(d_state->sam_ctx, mem_ctx,
-				      d_state->domain_dn,
-				      &res, attrs,
-				      d_state->domain_sid,
-				      "(objectClass=user)");
-	if (ldb_cnt < 0) {
-		return NT_STATUS_INTERNAL_DB_CORRUPTION;
+	/*
+	 * If the resume_handle is zero, query the database and cache the
+	 * matching GUID's
+	 */
+	if (*r->in.resume_handle == 0) {
+		NTSTATUS status;
+		int ldb_cnt;
+		clear_guid_cache(cache);
+		/*
+		 * search for all domain users in this domain.
+		 */
+		ldb_cnt = samdb_search_domain(d_state->sam_ctx,
+					      mem_ctx,
+					      d_state->domain_dn,
+					      &res,
+					      cache_attrs,
+					      d_state->domain_sid,
+					      "(objectClass=user)");
+		if (ldb_cnt < 0) {
+			return NT_STATUS_INTERNAL_DB_CORRUPTION;
+		}
+		/*
+		 * Sort the results into RID order, while the spec states there
+		 * is no order, Windows appears to sort the results by RID and
+		 * so it is possible that there are clients that depend on
+		 * this ordering
+		 */
+		TYPESAFE_QSORT(res, ldb_cnt, compare_msgRid);
+
+		/*
+		 * cache the sorted GUID's
+		 */
+		status = load_guid_cache(cache, d_state, ldb_cnt, res);
+		TALLOC_FREE(res);
+		if (!NT_STATUS_IS_OK(status)) {
+			return status;
+		}
+		cache->handle = 0;
 	}
 
-	/* convert to SamEntry format */
-	entries = talloc_array(mem_ctx, struct samr_SamEntry, ldb_cnt);
-	if (!entries) {
-		return NT_STATUS_NO_MEMORY;
+	/*
+	 * If the resume handle is out of range we return an empty response
+	 * and invalidate the cache.
+	 *
+	 * From the specification:
+	 * Servers SHOULD validate that EnumerationContext is an expected
+	 * value for the server's implementation. Windows does NOT validate
+	 * the input, though the result of malformed information merely results
+	 * in inconsistent output to the client.
+	 */
+	if (*r->in.resume_handle >= cache->size) {
+		clear_guid_cache(cache);
+		sam = talloc(mem_ctx, struct samr_SamArray);
+		if (!sam) {
+			return NT_STATUS_NO_MEMORY;
+		}
+		sam->entries = NULL;
+		sam->count = 0;
+
+		*r->out.sam = sam;
+		*r->out.resume_handle = 0;
+		return NT_STATUS_OK;
 	}
 
+	/*
+	 * Calculate the number of entries to return limit by max_size.
+	 * Note that we use the w2k3 element size value of 54
+	 */
+	max_entries = 1 + (r->in.max_size / SAMR_ENUM_USERS_MULTIPLIER);
+	remaining_entries = cache->size - *r->in.resume_handle;
+	results = MIN(remaining_entries, max_entries);
+
+	/*
+	 * Process the list of result GUID's.
+	 * Read the details of each object and populate the Entries
+	 * for the current level.
+	 */
 	count = 0;
+	resume_handle = *r->in.resume_handle;
+	entries = talloc_array(mem_ctx, struct samr_SamEntry, results);
+	if (entries == NULL) {
+		clear_guid_cache(cache);
+		return NT_STATUS_NO_MEMORY;
+	}
+	for (i = 0; i < results; i++) {
+		struct dom_sid *sid;
+		struct ldb_result *rec;
+		const unsigned idx = *r->in.resume_handle + i;
+		int ret;
+		const char *name = NULL;
 
-	for (i=0;i<ldb_cnt;i++) {
-		/* Check if a mask has been requested */
-		if (r->in.acct_flags
-		    && ((samdb_result_acct_flags(res[i], NULL) & r->in.acct_flags) == 0)) {
+		resume_handle++;
+		/*
+		 * Read an object from disk using the GUID as the key
+		 *
+		 * If the object can not be read, or it does not have a SID
+		 * it is ignored.
+		 *
+		 * As a consequence of this, if all the remaining GUID's
+		 * have been deleted an empty result will be returned.
+		 * i.e. even if the previous call returned a non zero
+		 * resume_handle it is possible for no results to be returned.
+		 *
+		 */
+		ret = dsdb_search_by_dn_guid(d_state->sam_ctx,
+					     mem_ctx,
+					     &rec,
+					     &cache->entries[idx],
+					     attrs,
+					     0);
+		if (ret == LDB_ERR_NO_SUCH_OBJECT) {
+			char *guid_str =
+			    GUID_string(mem_ctx, &cache->entries[idx]);
+			DBG_WARNING("GUID [%s] not found\n", guid_str);
+			continue;
+		} else if (ret != LDB_SUCCESS) {
+			clear_guid_cache(cache);
+			return NT_STATUS_INTERNAL_DB_CORRUPTION;
+		}
+		sid = samdb_result_dom_sid(mem_ctx, rec->msgs[0], "objectSID");
+		if (sid == NULL) {
+			char *guid_str =
+			    GUID_string(mem_ctx, &cache->entries[idx]);
+			DBG_WARNING("objectSID for GUID [%s] not found\n",
+				    guid_str);
 			continue;
 		}
-		entries[count].idx = samdb_result_rid_from_sid(mem_ctx, res[i],
-							       "objectSid", 0);
-		entries[count].name.string = ldb_msg_find_attr_as_string(res[i],
-								 "sAMAccountName", "");
-		count += 1;
+		if (r->in.acct_flags &&
+		    ((samdb_result_acct_flags(rec->msgs[0], NULL) &
+		      r->in.acct_flags) == 0)) {
+			continue;
+		}
+		entries[count].idx = samdb_result_rid_from_sid(
+		    mem_ctx, rec->msgs[0], "objectSid", 0);
+		name = ldb_msg_find_attr_as_string(
+		    rec->msgs[0], "sAMAccountName", "");
+		entries[count].name.string = talloc_strdup(entries, name);
+		count++;
 	}
 
-	/* sort the results by rid */
-	TYPESAFE_QSORT(entries, count, compare_SamEntry);
-
-	/* find the first entry to return */
-	for (first=0;
-	     first<count && entries[first].idx <= *r->in.resume_handle;
-	     first++) ;
-
-	/* return the rest, limit by max_size. Note that we
-	   use the w2k3 element size value of 54 */
-	*r->out.num_entries = count - first;
-	*r->out.num_entries = MIN(*r->out.num_entries,
-				 1+(r->in.max_size/SAMR_ENUM_USERS_MULTIPLIER));
-
 	sam = talloc(mem_ctx, struct samr_SamArray);
 	if (!sam) {
+		clear_guid_cache(cache);
 		return NT_STATUS_NO_MEMORY;
 	}
 
-	sam->entries = entries+first;
-	sam->count = *r->out.num_entries;
+	sam->entries = entries;
+	sam->count = count;
 
 	*r->out.sam = sam;
+	*r->out.resume_handle = resume_handle;
+	*r->out.num_entries = count;
 
-	if (first == count) {
+	/*
+	 * Signal no more results by returning zero resume handle,
+	 * the cache is also cleared at this point
+	 */
+	if (*r->out.resume_handle >= cache->size) {
+		*r->out.resume_handle = 0;
+		clear_guid_cache(cache);
 		return NT_STATUS_OK;
 	}
-
-	if (*r->out.num_entries < count - first) {
-		*r->out.resume_handle = entries[first+*r->out.num_entries-1].idx;
-		return STATUS_MORE_ENTRIES;
-	}
-
-	return NT_STATUS_OK;
+	/*
+	 * There are more results to be returned.
+	 */
+	return STATUS_MORE_ENTRIES;
 }
 
 
diff --git a/source4/rpc_server/samr/dcesrv_samr.h b/source4/rpc_server/samr/dcesrv_samr.h
index 04866aca757..d530fc67bd9 100644
--- a/source4/rpc_server/samr/dcesrv_samr.h
+++ b/source4/rpc_server/samr/dcesrv_samr.h
@@ -54,6 +54,7 @@ struct samr_guid_cache {
 enum samr_guid_cache_id {
 	SAMR_QUERY_DISPLAY_INFO_CACHE,
 	SAMR_ENUM_DOMAIN_GROUPS_CACHE,
+	SAMR_ENUM_DOMAIN_USERS_CACHE,
 	SAMR_LAST_CACHE
 };
 
-- 
2.17.1

-------------- next part --------------
A non-text attachment was scrubbed...
Name: signature.asc
Type: application/pgp-signature
Size: 488 bytes
Desc: OpenPGP digital signature
URL: <http://lists.samba.org/pipermail/samba-technical/attachments/20181115/e9aa165c/signature.sig>


More information about the samba-technical mailing list