[PATCH 2/9] python/samba/samba3: handle ntdb files.

Rusty Russell rusty at rustcorp.com.au
Thu Apr 11 01:42:09 MDT 2013


Upgrading old Samba 3 instances seems like a place where we don't have
to read ntdb files, but Andrew Bartlett points out that you can run a
Samba 4.0 (or even 4.1) 'classic' domain and later want to migrate it
to the AD DC.

So make this upgrade code generic: if it finds an ntdb file, read
that; otherwise, read the tdb file. Callers now pass the database path
without the .tdb/.ntdb extension.

Cc: Jelmer Vernooij <jelmer at samba.org>
Signed-off-by: Rusty Russell <rusty at rustcorp.com.au>
---
 python/samba/samba3/__init__.py |   98 ++++++++++++++++++++-------------------
 python/samba/tests/samba3.py    |    4 +-
 2 files changed, 53 insertions(+), 49 deletions(-)

diff --git a/python/samba/samba3/__init__.py b/python/samba/samba3/__init__.py
index acccff4..0165909 100644
--- a/python/samba/samba3/__init__.py
+++ b/python/samba/samba3/__init__.py
@@ -25,37 +25,41 @@ REGISTRY_DB_VERSION = 1
 import os
 import struct
 import tdb
+import ntdb  # NOTE(review): unconditional; module import fails without the ntdb bindings — confirm always built
 
 import passdb
 import param as s3param
 
 
-def fetch_uint32(tdb, key):
+def fetch_uint32(db, key):
     try:
-        data = tdb[key]
+        data = db[key]  # db: an open tdb.Tdb or ntdb.Ntdb handle
     except KeyError:
         return None
     assert len(data) == 4
     return struct.unpack("<L", data)[0]
 
 
-def fetch_int32(tdb, key):
+def fetch_int32(db, key):
     try:
-        data = tdb[key]
+        data = db[key]  # db: an open tdb.Tdb or ntdb.Ntdb handle
     except KeyError:
         return None
     assert len(data) == 4
     return struct.unpack("<l", data)[0]
 
 
-class TdbDatabase(object):
-    """Simple Samba 3 TDB database reader."""
+class DbDatabase(object):
+    """Simple read-only Samba 3 database reader for TDB or NTDB files."""
     def __init__(self, file):
         """Open a file.
 
-        :param file: Path of the file to open.
+        :param file: Path without extension; "<file>.ntdb" is opened if it exists, else "<file>.tdb".
         """
-        self.tdb = tdb.Tdb(file, flags=os.O_RDONLY)
+        if os.path.exists(file + ".ntdb"):  # prefer the ntdb file when one exists
+            self.db = ntdb.Ntdb(file + ".ntdb", flags=os.O_RDONLY)
+        else:
+            self.db = tdb.Tdb(file + ".tdb", flags=os.O_RDONLY)  # fall back to the classic tdb file
         self._check_version()
 
     def _check_version(self):
@@ -63,10 +67,10 @@ class TdbDatabase(object):
 
     def close(self):
         """Close resources associated with this object."""
-        self.tdb.close()
+        self.db.close()  # self.db is the tdb.Tdb or ntdb.Ntdb handle opened in __init__
 
 
-class Registry(TdbDatabase):
+class Registry(DbDatabase):
     """Simple read-only support for reading the Samba3 registry.
 
     :note: This object uses the same syntax for registry key paths as
@@ -80,7 +84,7 @@ class Registry(TdbDatabase):
 
     def keys(self):
         """Return list with all the keys."""
-        return [k.rstrip("\x00") for k in self.tdb.iterkeys() if not k.startswith(REGISTRY_VALUE_PREFIX)]
+        return [k.rstrip("\x00") for k in self.db.iterkeys() if not k.startswith(REGISTRY_VALUE_PREFIX)]  # skip value records
 
     def subkeys(self, key):
         """Retrieve the subkeys for the specified key.
@@ -88,7 +92,7 @@ class Registry(TdbDatabase):
         :param key: Key path.
         :return: list with key names
         """
-        data = self.tdb.get("%s\x00" % key)
+        data = self.db.get("%s\x00" % key)  # key records are stored NUL-terminated
         if data is None:
             return []
         (num, ) = struct.unpack("<L", data[0:4])
@@ -104,7 +108,7 @@ class Registry(TdbDatabase):
         :param key: Key to retrieve values for.
         :return: Dictionary with value names as key, tuple with type and
             data as value."""
-        data = self.tdb.get("%s/%s\x00" % (REGISTRY_VALUE_PREFIX, key))
+        data = self.db.get("%s/%s\x00" % (REGISTRY_VALUE_PREFIX, key))  # value records: "<REGISTRY_VALUE_PREFIX>/<key>\0"
         if data is None:
             return {}
         ret = {}
@@ -135,15 +139,15 @@ IDMAP_USER_PREFIX = "UID "
 # idmap version determines auto-conversion
 IDMAP_VERSION_V2 = 2
 
-class IdmapDatabase(TdbDatabase):
+class IdmapDatabase(DbDatabase):
     """Samba 3 ID map database reader."""
 
     def _check_version(self):
-        assert fetch_int32(self.tdb, "IDMAP_VERSION\0") == IDMAP_VERSION_V2
+        assert fetch_int32(self.db, "IDMAP_VERSION\0") == IDMAP_VERSION_V2  # NOTE(review): skipped under python -O
 
     def ids(self):
         """Retrieve a list of all ids in this database."""
-        for k in self.tdb.iterkeys():
+        for k in self.db.iterkeys():
             if k.startswith(IDMAP_USER_PREFIX):
                 yield k.rstrip("\0").split(" ")
             if k.startswith(IDMAP_GROUP_PREFIX):
@@ -151,13 +155,13 @@ class IdmapDatabase(TdbDatabase):
 
     def uids(self):
         """Retrieve a list of all uids in this database."""
-        for k in self.tdb.iterkeys():
+        for k in self.db.iterkeys():  # matching keys look like "<IDMAP_USER_PREFIX><uid>\0"
             if k.startswith(IDMAP_USER_PREFIX):
                 yield int(k[len(IDMAP_USER_PREFIX):].rstrip("\0"))
 
     def gids(self):
         """Retrieve a list of all gids in this database."""
-        for k in self.tdb.iterkeys():
+        for k in self.db.iterkeys():  # matching keys look like "<IDMAP_GROUP_PREFIX><gid>\0"
             if k.startswith(IDMAP_GROUP_PREFIX):
                 yield int(k[len(IDMAP_GROUP_PREFIX):].rstrip("\0"))
 
@@ -167,7 +171,7 @@ class IdmapDatabase(TdbDatabase):
         :param xid: UID or GID to retrive SID for.
         :param id_type: Type of id specified - 'UID' or 'GID'
         """
-        data = self.tdb.get("%s %s\0" % (id_type, str(xid)))
+        data = self.db.get("%s %s\0" % (id_type, str(xid)))  # key: "<id_type> <xid>\0"
         if data is None:
             return data
         return data.rstrip("\0")
@@ -178,43 +182,43 @@ class IdmapDatabase(TdbDatabase):
         :param uid: UID to retrieve SID for.
         :return: A SID or None if no mapping was found.
         """
-        data = self.tdb.get("%s%d\0" % (IDMAP_USER_PREFIX, uid))
+        data = self.db.get("%s%d\0" % (IDMAP_USER_PREFIX, uid))
         if data is None:
             return data
         return data.rstrip("\0")
 
     def get_group_sid(self, gid):
-        data = self.tdb.get("%s%d\0" % (IDMAP_GROUP_PREFIX, gid))
+        data = self.db.get("%s%d\0" % (IDMAP_GROUP_PREFIX, gid))  # None if the gid has no mapping
         if data is None:
             return data
         return data.rstrip("\0")
 
     def get_user_hwm(self):
         """Obtain the user high-water mark."""
-        return fetch_uint32(self.tdb, IDMAP_HWM_USER)
+        return fetch_uint32(self.db, IDMAP_HWM_USER)  # None if the key is absent
 
     def get_group_hwm(self):
         """Obtain the group high-water mark."""
-        return fetch_uint32(self.tdb, IDMAP_HWM_GROUP)
+        return fetch_uint32(self.db, IDMAP_HWM_GROUP)  # None if the key is absent
 
 
-class SecretsDatabase(TdbDatabase):
+class SecretsDatabase(DbDatabase):
     """Samba 3 Secrets database reader."""
 
     def get_auth_password(self):
-        return self.tdb.get("SECRETS/AUTH_PASSWORD")
+        return self.db.get("SECRETS/AUTH_PASSWORD")
 
     def get_auth_domain(self):
-        return self.tdb.get("SECRETS/AUTH_DOMAIN")
+        return self.db.get("SECRETS/AUTH_DOMAIN")
 
     def get_auth_user(self):
-        return self.tdb.get("SECRETS/AUTH_USER")
+        return self.db.get("SECRETS/AUTH_USER")
 
     def get_domain_guid(self, host):
-        return self.tdb.get("SECRETS/DOMGUID/%s" % host)
+        return self.db.get("SECRETS/DOMGUID/%s" % host)
 
     def ldap_dns(self):
-        for k in self.tdb.iterkeys():
+        for k in self.db.iterkeys():  # yields each key's suffix after "SECRETS/LDAP_BIND_PW/", NUL-stripped
             if k.startswith("SECRETS/LDAP_BIND_PW/"):
                 yield k[len("SECRETS/LDAP_BIND_PW/"):].rstrip("\0")
 
@@ -223,59 +227,59 @@ class SecretsDatabase(TdbDatabase):
 
         :return: Iterator over the names of domains in this database.
         """
-        for k in self.tdb.iterkeys():
+        for k in self.db.iterkeys():
             if k.startswith("SECRETS/SID/"):
                 yield k[len("SECRETS/SID/"):].rstrip("\0")
 
     def get_ldap_bind_pw(self, host):
-        return self.tdb.get("SECRETS/LDAP_BIND_PW/%s" % host)
+        return self.db.get("SECRETS/LDAP_BIND_PW/%s" % host)
 
     def get_afs_keyfile(self, host):
-        return self.tdb.get("SECRETS/AFS_KEYFILE/%s" % host)
+        return self.db.get("SECRETS/AFS_KEYFILE/%s" % host)
 
     def get_machine_sec_channel_type(self, host):
-        return fetch_uint32(self.tdb, "SECRETS/MACHINE_SEC_CHANNEL_TYPE/%s" % host)
+        return fetch_uint32(self.db, "SECRETS/MACHINE_SEC_CHANNEL_TYPE/%s" % host)  # None if the key is absent
 
     def get_machine_last_change_time(self, host):
-        return fetch_uint32(self.tdb, "SECRETS/MACHINE_LAST_CHANGE_TIME/%s" % host)
+        return fetch_uint32(self.db, "SECRETS/MACHINE_LAST_CHANGE_TIME/%s" % host)  # None if the key is absent
 
     def get_machine_password(self, host):
-        return self.tdb.get("SECRETS/MACHINE_PASSWORD/%s" % host)
+        return self.db.get("SECRETS/MACHINE_PASSWORD/%s" % host)
 
     def get_machine_acc(self, host):
-        return self.tdb.get("SECRETS/$MACHINE.ACC/%s" % host)
+        return self.db.get("SECRETS/$MACHINE.ACC/%s" % host)
 
     def get_domtrust_acc(self, host):
-        return self.tdb.get("SECRETS/$DOMTRUST.ACC/%s" % host)
+        return self.db.get("SECRETS/$DOMTRUST.ACC/%s" % host)
 
     def trusted_domains(self):
-        for k in self.tdb.iterkeys():
+        for k in self.db.iterkeys():  # names are the key suffix after "SECRETS/$DOMTRUST.ACC/"
             if k.startswith("SECRETS/$DOMTRUST.ACC/"):
                 yield k[len("SECRETS/$DOMTRUST.ACC/"):].rstrip("\0")
 
     def get_random_seed(self):
-        return self.tdb.get("INFO/random_seed")
+        return self.db.get("INFO/random_seed")
 
     def get_sid(self, host):
-        return self.tdb.get("SECRETS/SID/%s" % host.upper())
+        return self.db.get("SECRETS/SID/%s" % host.upper())  # lookup uses the upper-cased name
 
 
 SHARE_DATABASE_VERSION_V1 = 1
 SHARE_DATABASE_VERSION_V2 = 2
 
 
-class ShareInfoDatabase(TdbDatabase):
+class ShareInfoDatabase(DbDatabase):
     """Samba 3 Share Info database reader."""
 
     def _check_version(self):
-        assert fetch_int32(self.tdb, "INFO/version\0") in (SHARE_DATABASE_VERSION_V1, SHARE_DATABASE_VERSION_V2)
+        assert fetch_int32(self.db, "INFO/version\0") in (SHARE_DATABASE_VERSION_V1, SHARE_DATABASE_VERSION_V2)  # NOTE(review): skipped under python -O
 
     def get_secdesc(self, name):
         """Obtain the security descriptor on a particular share.
 
         :param name: Name of the share
         """
-        secdesc = self.tdb.get("SECDESC/%s" % name)
+        secdesc = self.db.get("SECDESC/%s" % name)
         # FIXME: Run ndr_pull_security_descriptor
         return secdesc
 
@@ -390,16 +394,16 @@ class Samba3(object):
         return passdb.PDB(self.lp.get('passdb backend'))
 
     def get_registry(self):
-        return Registry(self.statedir_path("registry.tdb"))
+        return Registry(self.statedir_path("registry"))  # .ntdb/.tdb extension is chosen by DbDatabase
 
     def get_secrets_db(self):
-        return SecretsDatabase(self.privatedir_path("secrets.tdb"))
+        return SecretsDatabase(self.privatedir_path("secrets"))  # .ntdb/.tdb extension is chosen by DbDatabase
 
     def get_shareinfo_db(self):
-        return ShareInfoDatabase(self.statedir_path("share_info.tdb"))
+        return ShareInfoDatabase(self.statedir_path("share_info"))  # .ntdb/.tdb extension is chosen by DbDatabase
 
     def get_idmap_db(self):
-        return IdmapDatabase(self.statedir_path("winbindd_idmap.tdb"))
+        return IdmapDatabase(self.statedir_path("winbindd_idmap"))  # .ntdb/.tdb extension is chosen by DbDatabase
 
     def get_wins_db(self):
         return WinsDatabase(self.statedir_path("wins.dat"))
diff --git a/python/samba/tests/samba3.py b/python/samba/tests/samba3.py
index 0a7f13c..51d76dd 100644
--- a/python/samba/tests/samba3.py
+++ b/python/samba/tests/samba3.py
@@ -39,7 +39,7 @@ class RegistryTestCase(TestCase):
 
     def setUp(self):
         super(RegistryTestCase, self).setUp()
-        self.registry = Registry(os.path.join(DATADIR, "registry.tdb"))
+        self.registry = Registry(os.path.join(DATADIR, "registry"))  # opens registry.ntdb if present, else registry.tdb
 
     def tearDown(self):
         self.registry.close()
@@ -194,7 +194,7 @@ class IdmapDbTestCase(TestCase):
     def setUp(self):
         super(IdmapDbTestCase, self).setUp()
         self.idmapdb = IdmapDatabase(os.path.join(DATADIR,
-            "winbindd_idmap.tdb"))
+            "winbindd_idmap"))  # opens winbindd_idmap.ntdb if present, else winbindd_idmap.tdb
 
     def test_user_hwm(self):
         self.assertEquals(10000, self.idmapdb.get_user_hwm())
-- 
1.7.10.4



More information about the samba-technical mailing list