Use FilePath everywhere and create new CacheManager module.
author Cameron Dale <camrdale@gmail.com>
Mon, 14 Jan 2008 06:49:16 +0000 (22:49 -0800)
committer Cameron Dale <camrdale@gmail.com>
Mon, 14 Jan 2008 06:49:16 +0000 (22:49 -0800)
All cache-related functionality has been moved from MirrorManager to the
new CacheManager.

apt_dht/AptPackages.py
apt_dht/CacheManager.py [new file with mode: 0644]
apt_dht/HTTPServer.py
apt_dht/MirrorManager.py
apt_dht/apt_dht.py
apt_dht/db.py
apt_dht_Khashmir/DHT.py
apt_dht_Khashmir/khashmir.py

diff --git a/apt_dht/AptPackages.py b/apt_dht/AptPackages.py
index 1987f8c5fe7e506f0f4d29a07f11f0196201d1ba..11a1c78bd8304ab3297dd9fd8ac45482b9140a21 100644
@@ -10,6 +10,7 @@ from UserDict import DictMixin
 
 from twisted.internet import threads, defer
 from twisted.python import log
+from twisted.python.filepath import FilePath
 from twisted.trial import unittest
 
 import apt_pkg, apt_inst
@@ -30,15 +31,16 @@ class PackageFileList(DictMixin):
     
     def __init__(self, cache_dir):
         self.cache_dir = cache_dir
-        if not os.path.exists(self.cache_dir):
-            os.makedirs(self.cache_dir)
+        self.cache_dir.restat(False)
+        if not self.cache_dir.exists():
+            self.cache_dir.makedirs()
         self.packages = None
         self.open()
 
     def open(self):
         """Open the persistent dictionary of files in this backend."""
         if self.packages is None:
-            self.packages = shelve.open(self.cache_dir+'/packages.db')
+            self.packages = shelve.open(self.cache_dir.child('packages.db').path)
 
     def close(self):
         """Close the persistent dictionary."""
@@ -62,7 +64,8 @@ class PackageFileList(DictMixin):
         """Check all files in the database to make sure they exist."""
         files = self.packages.keys()
         for f in files:
-            if not os.path.exists(self.packages[f]):
+            self.packages[f].restat(False)
+            if not self.packages[f].exists():
                 log.msg("File in packages database has been deleted: "+f)
                 del self.packages[f]
 
@@ -124,19 +127,16 @@ class AptPackages:
         self.apt_config = deepcopy(self.DEFAULT_APT_CONFIG)
 
         for dir in self.essential_dirs:
-            path = os.path.join(self.cache_dir, dir)
-            if not os.path.exists(path):
-                os.makedirs(path)
+            path = self.cache_dir.preauthChild(dir)
+            if not path.exists():
+                path.makedirs()
         for file in self.essential_files:
-            path = os.path.join(self.cache_dir, file)
-            if not os.path.exists(path):
-                f = open(path,'w')
-                f.close()
-                del f
+            path = self.cache_dir.preauthChild(file)
+            if not path.exists():
+                path.touch()
                 
-        self.apt_config['Dir'] = self.cache_dir
-        self.apt_config['Dir::State::status'] = os.path.join(self.cache_dir, 
-                      self.apt_config['Dir::State'], self.apt_config['Dir::State::status'])
+        self.apt_config['Dir'] = self.cache_dir.path
+        self.apt_config['Dir::State::status'] = self.cache_dir.preauthChild(self.apt_config['Dir::State']).preauthChild(self.apt_config['Dir::State::status']).path
         self.packages = PackageFileList(cache_dir)
         self.loaded = 0
         self.loading = None
@@ -152,7 +152,7 @@ class AptPackages:
         self.indexrecords[cache_path] = {}
 
         read_packages = False
-        f = open(file_path, 'r')
+        f = file_path.open('r')
         
         for line in f:
             line = line.rstrip()
@@ -204,13 +204,14 @@ class AptPackages:
         """Regenerates the fake configuration and load the packages cache."""
         if self.loaded: return True
         apt_pkg.InitSystem()
-        rmtree(os.path.join(self.cache_dir, self.apt_config['Dir::State'], 
-                            self.apt_config['Dir::State::Lists']))
-        os.makedirs(os.path.join(self.cache_dir, self.apt_config['Dir::State'], 
-                                 self.apt_config['Dir::State::Lists'], 'partial'))
-        sources_filename = os.path.join(self.cache_dir, self.apt_config['Dir::Etc'], 
-                                        self.apt_config['Dir::Etc::sourcelist'])
-        sources = open(sources_filename, 'w')
+        self.cache_dir.preauthChild(self.apt_config['Dir::State']
+                     ).preauthChild(self.apt_config['Dir::State::Lists']).remove()
+        self.cache_dir.preauthChild(self.apt_config['Dir::State']
+                     ).preauthChild(self.apt_config['Dir::State::Lists']
+                     ).child('partial').makedirs()
+        sources_file = self.cache_dir.preauthChild(self.apt_config['Dir::Etc']
+                               ).preauthChild(self.apt_config['Dir::Etc::sourcelist'])
+        sources = sources_file.open('w')
         sources_count = 0
         deb_src_added = False
         self.packages.check_files()
@@ -218,9 +219,9 @@ class AptPackages:
         for f in self.packages:
             # we should probably clear old entries from self.packages and
             # take into account the recorded mtime as optimization
-            filepath = self.packages[f]
+            file = self.packages[f]
             if f.split('/')[-1] == "Release":
-                self.addRelease(f, filepath)
+                self.addRelease(f, file)
             fake_uri='http://apt-dht'+f
             fake_dirname = '/'.join(fake_uri.split('/')[:-1])
             if f.endswith('Sources'):
@@ -228,26 +229,24 @@ class AptPackages:
                 source_line='deb-src '+fake_dirname+'/ /'
             else:
                 source_line='deb '+fake_dirname+'/ /'
-            listpath=(os.path.join(self.cache_dir, self.apt_config['Dir::State'], 
-                                   self.apt_config['Dir::State::Lists'], 
-                                   apt_pkg.URItoFileName(fake_uri)))
+            listpath = self.cache_dir.preauthChild(self.apt_config['Dir::State']
+                                    ).preauthChild(self.apt_config['Dir::State::Lists']
+                                    ).child(apt_pkg.URItoFileName(fake_uri))
             sources.write(source_line+'\n')
             log.msg("Sources line: " + source_line)
             sources_count = sources_count + 1
 
-            try:
+            if listpath.exists():
                 #we should empty the directory instead
-                os.unlink(listpath)
-            except:
-                pass
-            os.symlink(filepath, listpath)
+                listpath.remove()
+            os.symlink(file.path, listpath.path)
         sources.close()
 
         if sources_count == 0:
-            log.msg("No Packages files available for %s backend"%(self.cache_dir))
+            log.msg("No Packages files available for %s backend"%(self.cache_dir.path))
             return False
 
-        log.msg("Loading Packages database for "+self.cache_dir)
+        log.msg("Loading Packages database for "+self.cache_dir.path)
         for key, value in self.apt_config.items():
             apt_pkg.Config[key] = value
 
@@ -355,7 +354,7 @@ class TestAptPackages(unittest.TestCase):
     releaseFile = ''
     
     def setUp(self):
-        self.client = AptPackages('/tmp/.apt-dht')
+        self.client = AptPackages(FilePath('/tmp/.apt-dht'))
     
         self.packagesFile = os.popen('ls -Sr /var/lib/apt/lists/ | grep -E "_main_.*Packages$" | tail -n 1').read().rstrip('\n')
         self.sourcesFile = os.popen('ls -Sr /var/lib/apt/lists/ | grep -E "_main_.*Sources$" | tail -n 1').read().rstrip('\n')
@@ -365,11 +364,11 @@ class TestAptPackages(unittest.TestCase):
                 break
         
         self.client.file_updated(self.releaseFile[self.releaseFile.find('_dists_'):].replace('_','/'), 
-                                 '/var/lib/apt/lists/' + self.releaseFile)
+                                 FilePath('/var/lib/apt/lists/' + self.releaseFile))
         self.client.file_updated(self.packagesFile[self.packagesFile.find('_dists_'):].replace('_','/'), 
-                                 '/var/lib/apt/lists/' + self.packagesFile)
+                                 FilePath('/var/lib/apt/lists/' + self.packagesFile))
         self.client.file_updated(self.sourcesFile[self.sourcesFile.find('_dists_'):].replace('_','/'), 
-                                 '/var/lib/apt/lists/' + self.sourcesFile)
+                                 FilePath('/var/lib/apt/lists/' + self.sourcesFile))
     
     def test_pkg_hash(self):
         self.client._load()
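
PackageFileList now keeps FilePath objects as the values in its shelve
database, so only shelve.open() still needs a string (via .path). A
rough sketch of the resulting pattern, with made-up paths:

    import shelve
    from twisted.python.filepath import FilePath

    cache_dir = FilePath('/tmp/.apt-dht')
    if not cache_dir.exists():
        cache_dir.makedirs()

    # the backing file itself must still be named by a plain string
    packages = shelve.open(cache_dir.child('packages.db').path)

    # values are pickled FilePath objects, keyed by the mirror path
    packages['/dists/unstable/Release'] = FilePath('/var/lib/apt/lists/x')

    # check_files() above amounts to this sweep over the stored paths
    for key in packages.keys():
        packages[key].restat(False)
        if not packages[key].exists():
            del packages[key]
    packages.close()
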
diff --git a/apt_dht/CacheManager.py b/apt_dht/CacheManager.py
new file mode 100644
index 0000000..3601619
--- /dev/null
+++ b/apt_dht/CacheManager.py
@@ -0,0 +1,241 @@
+
+from bz2 import BZ2Decompressor
+from zlib import decompressobj, MAX_WBITS
+from gzip import FCOMMENT, FEXTRA, FHCRC, FNAME, FTEXT
+from urlparse import urlparse
+import os
+
+from twisted.python import log
+from twisted.python.filepath import FilePath
+from twisted.internet import defer
+from twisted.trial import unittest
+from twisted.web2 import stream
+from twisted.web2.http import splitHostPort
+
+from AptPackages import AptPackages
+
+aptpkg_dir='apt-packages'
+
+DECOMPRESS_EXTS = ['.gz', '.bz2']
+DECOMPRESS_FILES = ['release', 'sources', 'packages']
+
+class ProxyFileStream(stream.SimpleStream):
+    """Saves a stream to a file while providing a new stream."""
+    
+    def __init__(self, stream, outFile, hash, decompress = None, decFile = None):
+        """Initializes the proxy.
+        
+        @type stream: C{twisted.web2.stream.IByteStream}
+        @param stream: the input stream to read from
+        @type outFile: C{twisted.python.filepath.FilePath}
+        @param outFile: the file to write to
+        @type hash: L{Hash.HashObject}
+        @param hash: the hash object to use for the file
+        @type decompress: C{string}
+        @param decompress: also decompress the file as this type
+            (currently only '.gz' and '.bz2' are supported)
+        @type decFile: C{twisted.python.filepath.FilePath}
+        @param decFile: the file to write the decompressed data to
+        """
+        self.stream = stream
+        self.outFile = outFile.open('w')
+        self.hash = hash
+        self.hash.new()
+        self.gzfile = None
+        self.bz2file = None
+        if decompress == ".gz":
+            self.gzheader = True
+            self.gzfile = decFile.open('w')
+            self.gzdec = decompressobj(-MAX_WBITS)
+        elif decompress == ".bz2":
+            self.bz2file = decFile.open('w')
+            self.bz2dec = BZ2Decompressor()
+        self.length = self.stream.length
+        self.start = 0
+        self.doneDefer = defer.Deferred()
+
+    def _done(self):
+        """Close the output file."""
+        if not self.outFile.closed:
+            self.outFile.close()
+            self.hash.digest()
+            if self.gzfile:
+                data_dec = self.gzdec.flush()
+                self.gzfile.write(data_dec)
+                self.gzfile.close()
+                self.gzfile = None
+            if self.bz2file:
+                self.bz2file.close()
+                self.bz2file = None
+                
+            self.doneDefer.callback(self.hash)
+    
+    def read(self):
+        """Read some data from the stream."""
+        if self.outFile.closed:
+            return None
+        
+        data = self.stream.read()
+        if isinstance(data, defer.Deferred):
+            data.addCallbacks(self._write, self._done)
+            return data
+        
+        self._write(data)
+        return data
+    
+    def _write(self, data):
+        """Write the stream data to the file and return it for others to use."""
+        if data is None:
+            self._done()
+            return data
+        
+        self.outFile.write(data)
+        self.hash.update(data)
+        if self.gzfile:
+            if self.gzheader:
+                self.gzheader = False
+                new_data = self._remove_gzip_header(data)
+                dec_data = self.gzdec.decompress(new_data)
+            else:
+                dec_data = self.gzdec.decompress(data)
+            self.gzfile.write(dec_data)
+        if self.bz2file:
+            dec_data = self.bz2dec.decompress(data)
+            self.bz2file.write(dec_data)
+        return data
+    
+    def _remove_gzip_header(self, data):
+        if data[:2] != '\037\213':
+            raise IOError, 'Not a gzipped file'
+        if ord(data[2]) != 8:
+            raise IOError, 'Unknown compression method'
+        flag = ord(data[3])
+        # modtime = self.fileobj.read(4)
+        # extraflag = self.fileobj.read(1)
+        # os = self.fileobj.read(1)
+
+        skip = 10
+        if flag & FEXTRA:
+            # Read & discard the extra field, if present
+            xlen = ord(data[10])
+            xlen = xlen + 256*ord(data[11])
+            skip = skip + 2 + xlen
+        if flag & FNAME:
+            # Read and discard a null-terminated string containing the filename
+            while True:
+                if not data[skip] or data[skip] == '\000':
+                    break
+                skip += 1
+            skip += 1
+        if flag & FCOMMENT:
+            # Read and discard a null-terminated string containing a comment
+            while True:
+                if not data[skip] or data[skip] == '\000':
+                    break
+                skip += 1
+            skip += 1
+        if flag & FHCRC:
+            skip += 2     # Read & discard the 16-bit header CRC
+        return data[skip:]
+
+    def close(self):
+        """Clean everything up and return None to future reads."""
+        self.length = 0
+        self._done()
+        self.stream.close()
+
+class CacheManager:
+    """Manages all requests for cached objects."""
+    
+    def __init__(self, cache_dir, db, manager = None):
+        self.cache_dir = cache_dir
+        self.db = db
+        self.manager = manager
+    
+    def save_file(self, response, hash, url):
+        """Save a downloaded file to the cache and stream it."""
+        if response.code != 200:
+            log.msg('File was not found (%r): %s' % (response, url))
+            return response
+        
+        log.msg('Returning file: %s' % url)
+        
+        parsed = urlparse(url)
+        destFile = self.cache_dir.preauthChild(parsed[1] + parsed[2])
+        log.msg('Saving returned %r byte file to cache: %s' % (response.stream.length, destFile.path))
+        
+        if destFile.exists():
+            log.msg('File already exists, removing: %s' % destFile.path)
+            destFile.remove()
+        elif not destFile.parent().exists():
+            destFile.parent().makedirs()
+            
+        root, ext = os.path.splitext(destFile.basename())
+        if root.lower() in DECOMPRESS_FILES and ext.lower() in DECOMPRESS_EXTS:
+            ext = ext.lower()
+            decFile = destFile.sibling(root)
+            log.msg('Decompressing to: %s' % decFile.path)
+            if decFile.exists():
+                log.msg('File already exists, removing: %s' % decFile.path)
+                decFile.remove()
+        else:
+            ext = None
+            decFile = None
+            
+        orig_stream = response.stream
+        response.stream = ProxyFileStream(orig_stream, destFile, hash, ext, decFile)
+        response.stream.doneDefer.addCallback(self._save_complete, url, destFile,
+                                              response.headers.getHeader('Last-Modified'),
+                                              ext, decFile)
+        response.stream.doneDefer.addErrback(self.save_error, url)
+        return response
+
+    def _save_complete(self, hash, url, destFile, modtime = None, ext = None, decFile = None):
+        """Update the modification time and AptPackages."""
+        if modtime:
+            os.utime(destFile.path, (modtime, modtime))
+            if ext:
+                os.utime(decFile.path, (modtime, modtime))
+        
+        result = hash.verify()
+        if result or result is None:
+            if result:
+                log.msg('Hashes match: %s' % url)
+            else:
+                log.msg('Hashed file to %s: %s' % (hash.hexdigest(), url))
+                
+            urlpath, newdir = self.db.storeFile(destFile, hash.digest(), self.cache_dir)
+            log.msg('now available at %s: %s' % (urlpath, url))
+
+            if self.manager:
+                self.manager.new_cached_file(url, destFile, hash, urlpath)
+                if ext:
+                    self.manager.new_cached_file(url[:-len(ext)], decFile, None, urlpath)
+        else:
+            log.msg("Hashes don't match %s != %s: %s" % (hash.hexexpected(), hash.hexdigest(), url))
+            destFile.remove()
+            if ext:
+                decFile.remove()
+
+    def save_error(self, failure, url):
+        """An error has occurred in downloading or saving the file."""
+        log.msg('Error occurred downloading %s' % url)
+        log.err(failure)
+        return failure
+
+class TestCacheManager(unittest.TestCase):
+    """Unit tests for the cache manager."""
+    
+    timeout = 20
+    pending_calls = []
+    client = None
+    
+    def setUp(self):
+        self.client = CacheManager(FilePath('/tmp/.apt-dht'), None)  # db unused in setUp
+        
+    def tearDown(self):
+        for p in self.pending_calls:
+            if p.active():
+                p.cancel()
+        self.client = None
+        
\ No newline at end of file
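
ProxyFileStream decompresses gzip data incrementally with zlib in raw
deflate mode (the negative window size), which is why
_remove_gzip_header() must strip the header by hand before the first
chunk is fed in; note this assumes the whole header arrives in the
first read. A self-contained sketch of the same trick on a complete
buffer:

    import gzip, StringIO
    from zlib import decompressobj, MAX_WBITS

    # make some gzipped bytes to work with
    buf = StringIO.StringIO()
    gz = gzip.GzipFile(fileobj=buf, mode='wb')
    gz.write('hello world')
    gz.close()
    data = buf.getvalue()

    # -MAX_WBITS selects a raw deflate stream, so the gzip header must
    # go; with no FEXTRA/FNAME/FCOMMENT/FHCRC flags set it is 10 bytes
    dec = decompressobj(-MAX_WBITS)
    print dec.decompress(data[10:]) + dec.flush()   # 'hello world'
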
diff --git a/apt_dht/HTTPServer.py b/apt_dht/HTTPServer.py
index 7e2ac68ac7419355805e7a96dbd6d87247168851..0e53c237ef175bdca4934bf335b05300ca8e337b 100644
@@ -1,4 +1,3 @@
-import os.path, time
 
 from twisted.python import log
 from twisted.internet import defer
@@ -65,14 +64,14 @@ class TopLevel(resource.Resource):
         name = segments[0]
         if name in self.subdirs:
             log.msg('Sharing %s with %s' % (request.uri, request.remoteAddr))
-            return static.File(self.subdirs[name]), segments[1:]
+            return static.File(self.subdirs[name].path), segments[1:]
         
         if request.remoteAddr.host != "127.0.0.1":
             log.msg('Blocked illegal access to %s from %s' % (request.uri, request.remoteAddr))
             return None, ()
             
         if len(name) > 1:
-            return FileDownloader(self.directory, self.manager), segments[0:]
+            return FileDownloader(self.directory.path, self.manager), segments[0:]
         else:
             return self, ()
         
diff --git a/apt_dht/MirrorManager.py b/apt_dht/MirrorManager.py
index c41dbe23c8439bda7314794046564c7a1a28940d..84065c9477c7c7afba5447ffde7262b94b7715b7 100644
 
-from bz2 import BZ2Decompressor
-from zlib import decompressobj, MAX_WBITS
-from gzip import FCOMMENT, FEXTRA, FHCRC, FNAME, FTEXT
 from urlparse import urlparse
 import os
 
-from twisted.python import log, filepath
+from twisted.python import log
+from twisted.python.filepath import FilePath
 from twisted.internet import defer
 from twisted.trial import unittest
-from twisted.web2 import stream
 from twisted.web2.http import splitHostPort
 
 from AptPackages import AptPackages
 
-aptpkg_dir='.apt-dht'
-
-DECOMPRESS_EXTS = ['.gz', '.bz2']
-DECOMPRESS_FILES = ['release', 'sources', 'packages']
+aptpkg_dir='apt-packages'
 
 class MirrorError(Exception):
     """Exception raised when there's a problem with the mirror."""
 
-class ProxyFileStream(stream.SimpleStream):
-    """Saves a stream to a file while providing a new stream."""
-    
-    def __init__(self, stream, outFile, hash, decompress = None, decFile = None):
-        """Initializes the proxy.
-        
-        @type stream: C{twisted.web2.stream.IByteStream}
-        @param stream: the input stream to read from
-        @type outFile: C{twisted.python.filepath.FilePath}
-        @param outFile: the file to write to
-        @type hash: L{Hash.HashObject}
-        @param hash: the hash object to use for the file
-        @type decompress: C{string}
-        @param decompress: also decompress the file as this type
-            (currently only '.gz' and '.bz2' are supported)
-        @type decFile: C{twisted.python.filepath.FilePath}
-        @param decFile: the file to write the decompressed data to
-        """
-        self.stream = stream
-        self.outFile = outFile.open('w')
-        self.hash = hash
-        self.hash.new()
-        self.gzfile = None
-        self.bz2file = None
-        if decompress == ".gz":
-            self.gzheader = True
-            self.gzfile = decFile.open('w')
-            self.gzdec = decompressobj(-MAX_WBITS)
-        elif decompress == ".bz2":
-            self.bz2file = decFile.open('w')
-            self.bz2dec = BZ2Decompressor()
-        self.length = self.stream.length
-        self.start = 0
-        self.doneDefer = defer.Deferred()
-
-    def _done(self):
-        """Close the output file."""
-        if not self.outFile.closed:
-            self.outFile.close()
-            self.hash.digest()
-            if self.gzfile:
-                data_dec = self.gzdec.flush()
-                self.gzfile.write(data_dec)
-                self.gzfile.close()
-                self.gzfile = None
-            if self.bz2file:
-                self.bz2file.close()
-                self.bz2file = None
-                
-            self.doneDefer.callback(self.hash)
-    
-    def read(self):
-        """Read some data from the stream."""
-        if self.outFile.closed:
-            return None
-        
-        data = self.stream.read()
-        if isinstance(data, defer.Deferred):
-            data.addCallbacks(self._write, self._done)
-            return data
-        
-        self._write(data)
-        return data
-    
-    def _write(self, data):
-        """Write the stream data to the file and return it for others to use."""
-        if data is None:
-            self._done()
-            return data
-        
-        self.outFile.write(data)
-        self.hash.update(data)
-        if self.gzfile:
-            if self.gzheader:
-                self.gzheader = False
-                new_data = self._remove_gzip_header(data)
-                dec_data = self.gzdec.decompress(new_data)
-            else:
-                dec_data = self.gzdec.decompress(data)
-            self.gzfile.write(dec_data)
-        if self.bz2file:
-            dec_data = self.bz2dec.decompress(data)
-            self.bz2file.write(dec_data)
-        return data
-    
-    def _remove_gzip_header(self, data):
-        if data[:2] != '\037\213':
-            raise IOError, 'Not a gzipped file'
-        if ord(data[2]) != 8:
-            raise IOError, 'Unknown compression method'
-        flag = ord(data[3])
-        # modtime = self.fileobj.read(4)
-        # extraflag = self.fileobj.read(1)
-        # os = self.fileobj.read(1)
-
-        skip = 10
-        if flag & FEXTRA:
-            # Read & discard the extra field, if present
-            xlen = ord(data[10])
-            xlen = xlen + 256*ord(data[11])
-            skip = skip + 2 + xlen
-        if flag & FNAME:
-            # Read and discard a null-terminated string containing the filename
-            while True:
-                if not data[skip] or data[skip] == '\000':
-                    break
-                skip += 1
-            skip += 1
-        if flag & FCOMMENT:
-            # Read and discard a null-terminated string containing a comment
-            while True:
-                if not data[skip] or data[skip] == '\000':
-                    break
-                skip += 1
-            skip += 1
-        if flag & FHCRC:
-            skip += 2     # Read & discard the 16-bit header CRC
-        return data[skip:]
-
-    def close(self):
-        """Clean everything up and return None to future reads."""
-        self.length = 0
-        self._done()
-        self.stream.close()
-
 class MirrorManager:
     """Manages all requests for mirror objects."""
     
-    def __init__(self, cache_dir, manager = None):
-        self.manager = manager
+    def __init__(self, cache_dir):
         self.cache_dir = cache_dir
-        self.cache = filepath.FilePath(self.cache_dir)
         self.apt_caches = {}
     
     def extractPath(self, url):
@@ -190,7 +57,8 @@ class MirrorManager:
             self.apt_caches[site] = {}
             
         if baseDir not in self.apt_caches[site]:
-            site_cache = os.path.join(self.cache_dir, aptpkg_dir, 'mirrors', site + baseDir.replace('/', '_'))
+            site_cache = self.cache_dir.child(aptpkg_dir).child('mirrors').child(site + baseDir.replace('/', '_'))
+            site_cache.makedirs()
             self.apt_caches[site][baseDir] = AptPackages(site_cache)
     
     def updatedFile(self, url, file_path):
@@ -206,73 +74,6 @@ class MirrorManager:
         d.errback(MirrorError("Site Not Found"))
         return d
     
-    def save_file(self, response, hash, url):
-        """Save a downloaded file to the cache and stream it."""
-        if response.code != 200:
-            log.msg('File was not found (%r): %s' % (response, url))
-            return response
-        
-        log.msg('Returning file: %s' % url)
-        
-        parsed = urlparse(url)
-        destFile = self.cache.preauthChild(parsed[1] + parsed[2])
-        log.msg('Saving returned %r byte file to cache: %s' % (response.stream.length, destFile.path))
-        
-        if destFile.exists():
-            log.msg('File already exists, removing: %s' % destFile.path)
-            destFile.remove()
-        else:
-            destFile.parent().makedirs()
-            
-        root, ext = os.path.splitext(destFile.basename())
-        if root.lower() in DECOMPRESS_FILES and ext.lower() in DECOMPRESS_EXTS:
-            ext = ext.lower()
-            decFile = destFile.sibling(root)
-            log.msg('Decompressing to: %s' % decFile.path)
-            if decFile.exists():
-                log.msg('File already exists, removing: %s' % decFile.path)
-                decFile.remove()
-        else:
-            ext = None
-            decFile = None
-            
-        orig_stream = response.stream
-        response.stream = ProxyFileStream(orig_stream, destFile, hash, ext, decFile)
-        response.stream.doneDefer.addCallback(self.save_complete, url, destFile,
-                                              response.headers.getHeader('Last-Modified'),
-                                              ext, decFile)
-        response.stream.doneDefer.addErrback(self.save_error, url)
-        return response
-
-    def save_complete(self, hash, url, destFile, modtime = None, ext = None, decFile = None):
-        """Update the modification time and AptPackages."""
-        if modtime:
-            os.utime(destFile.path, (modtime, modtime))
-            if ext:
-                os.utime(decFile.path, (modtime, modtime))
-        
-        result = hash.verify()
-        if result or result is None:
-            if result:
-                log.msg('Hashes match: %s' % url)
-            else:
-                log.msg('Hashed file to %s: %s' % (hash.hexdigest(), url))
-                
-            self.updatedFile(url, destFile.path)
-            if ext:
-                self.updatedFile(url[:-len(ext)], decFile.path)
-            
-            if self.manager:
-                self.manager.cached_file(hash, url, destFile.path)
-        else:
-            log.msg("Hashes don't match %s != %s: %s" % (hash.hexexpected(), hash.hexdigest(), url))
-
-    def save_error(self, failure, url):
-        """An error has occurred in downloadign or saving the file."""
-        log.msg('Error occurred downloading %s' % url)
-        log.err(failure)
-        return failure
-
 class TestMirrorManager(unittest.TestCase):
     """Unit tests for the mirror manager."""
     
@@ -281,7 +82,7 @@ class TestMirrorManager(unittest.TestCase):
     client = None
     
     def setUp(self):
-        self.client = MirrorManager('/tmp')
+        self.client = MirrorManager(FilePath('/tmp/.apt-dht'))
         
     def test_extractPath(self):
         site, baseDir, path = self.client.extractPath('http://ftp.us.debian.org/debian/dists/unstable/Release')
@@ -312,13 +113,13 @@ class TestMirrorManager(unittest.TestCase):
                 break
         
         self.client.updatedFile('http://' + self.releaseFile.replace('_','/'), 
-                                '/var/lib/apt/lists/' + self.releaseFile)
+                                FilePath('/var/lib/apt/lists/' + self.releaseFile))
         self.client.updatedFile('http://' + self.releaseFile[:self.releaseFile.find('_dists_')+1].replace('_','/') +
                                 self.packagesFile[self.packagesFile.find('_dists_')+1:].replace('_','/'), 
-                                '/var/lib/apt/lists/' + self.packagesFile)
+                                FilePath('/var/lib/apt/lists/' + self.packagesFile))
         self.client.updatedFile('http://' + self.releaseFile[:self.releaseFile.find('_dists_')+1].replace('_','/') +
                                 self.sourcesFile[self.sourcesFile.find('_dists_')+1:].replace('_','/'), 
-                                '/var/lib/apt/lists/' + self.sourcesFile)
+                                FilePath('/var/lib/apt/lists/' + self.sourcesFile))
 
         lastDefer = defer.Deferred()
         
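
MirrorManager now only tracks the per-site AptPackages caches, one per
(site, baseDir) pair, flattening the base directory into the cache
directory's name. A sketch of the keying with illustrative values (the
existence guard is an assumption; the hunk above calls makedirs()
unconditionally):

    from twisted.python.filepath import FilePath

    cache_dir = FilePath('/var/cache/apt-dht')   # assumed location
    site, baseDir = 'ftp.us.debian.org:80', '/debian'

    # one cache directory per mirror, with '/' flattened to '_'
    site_cache = cache_dir.child('apt-packages').child('mirrors'
                    ).child(site + baseDir.replace('/', '_'))
    if not site_cache.exists():
        site_cache.makedirs()
    print site_cache.path
    # /var/cache/apt-dht/apt-packages/mirrors/ftp.us.debian.org:80_debian
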
diff --git a/apt_dht/apt_dht.py b/apt_dht/apt_dht.py
index f25e2347a39609169f9a855f555646119887be4b..3679685e5f70ad873b2076bad04f21265b6786e8 100644
@@ -6,26 +6,34 @@ import os, re
 from twisted.internet import defer
 from twisted.web2 import server, http, http_headers
 from twisted.python import log
+from twisted.python.filepath import FilePath
 
 from apt_dht_conf import config
 from PeerManager import PeerManager
 from HTTPServer import TopLevel
 from MirrorManager import MirrorManager
+from CacheManager import CacheManager
 from Hash import HashObject
 from db import DB
 from util import findMyIPAddr
 
+download_dir = 'cache'
+
 class AptDHT:
     def __init__(self, dht):
         log.msg('Initializing the main apt_dht application')
-        self.db = DB(config.get('DEFAULT', 'cache_dir') + '/.apt-dht.db')
+        self.cache_dir = FilePath(config.get('DEFAULT', 'cache_dir'))
+        if not self.cache_dir.child(download_dir).exists():
+            self.cache_dir.child(download_dir).makedirs()
+        self.db = DB(self.cache_dir.child('apt-dht.db'))
         self.dht = dht
         self.dht.loadConfig(config, config.get('DEFAULT', 'DHT'))
         self.dht.join().addCallbacks(self.joinComplete, self.joinError)
-        self.http_server = TopLevel(config.get('DEFAULT', 'cache_dir'), self)
+        self.http_server = TopLevel(self.cache_dir.child(download_dir), self)
         self.http_site = server.Site(self.http_server)
         self.peers = PeerManager()
-        self.mirrors = MirrorManager(config.get('DEFAULT', 'cache_dir'), self)
+        self.mirrors = MirrorManager(self.cache_dir)
+        self.cache = CacheManager(self.cache_dir.child(download_dir), self.db, self)
         self.my_addr = None
     
     def getSite(self):
@@ -39,6 +47,7 @@ class AptDHT:
     def joinError(self, failure):
         log.msg("joining DHT failed miserably")
         log.err(failure)
+        raise RuntimeError, "IP address for this machine could not be found"
     
     def check_freshness(self, path, modtime, resp):
         log.msg('Checking if %s is still fresh' % path)
@@ -84,16 +93,16 @@ class AptDHT:
         if not locations:
             log.msg('Peers for %s were not found' % path)
             getDefer = self.peers.get([path])
-            getDefer.addCallback(self.mirrors.save_file, hash, path)
-            getDefer.addErrback(self.mirrors.save_error, path)
+            getDefer.addCallback(self.cache.save_file, hash, path)
+            getDefer.addErrback(self.cache.save_error, path)
             getDefer.addCallbacks(d.callback, d.errback)
         else:
             log.msg('Found peers for %s: %r' % (path, locations))
             # Download from the found peers
             getDefer = self.peers.get(locations)
             getDefer.addCallback(self.check_response, hash, path)
-            getDefer.addCallback(self.mirrors.save_file, hash, path)
-            getDefer.addErrback(self.mirrors.save_error, path)
+            getDefer.addCallback(self.cache.save_file, hash, path)
+            getDefer.addErrback(self.cache.save_error, path)
             getDefer.addCallbacks(d.callback, d.errback)
             
     def check_response(self, response, hash, path):
@@ -103,12 +112,10 @@ class AptDHT:
             return getDefer
         return response
         
-    def cached_file(self, hash, url, file_path):
-        assert file_path.startswith(config.get('DEFAULT', 'cache_dir'))
-        urlpath, newdir = self.db.storeFile(file_path, hash.digest(), config.get('DEFAULT', 'cache_dir'))
-        log.msg('now avaliable at %s: %s' % (urlpath, url))
-
-        if self.my_addr:
+    def new_cached_file(self, url, file_path, hash, urlpath):
+        self.mirrors.updatedFile(url, file_path)
+        
+        if self.my_addr and hash:
             site = self.my_addr + ':' + str(config.getint('DEFAULT', 'PORT'))
             full_path = urlunparse(('http', site, urlpath, None, None, None))
             key = hash.norm(bits = config.getint(config.get('DEFAULT', 'DHT'), 'HASH_LENGTH'))
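
The download path is a Deferred pipeline: peers.get() yields a
response, save_file() tees it into the cache, and save_error() logs
the failure and returns it so it keeps propagating to the caller's
errback. A minimal runnable sketch of that wiring (the stand-ins for
peers.get() and the cache methods are hypothetical):

    from twisted.internet import defer

    def save_file(response, hash, path):
        print 'caching %s' % path   # stand-in for CacheManager.save_file
        return response             # pass the response down the chain

    def save_error(failure, path):
        print 'download failed: %s' % path
        return failure              # returning it keeps it an error

    d = defer.Deferred()            # what the caller waits on
    getDefer = defer.Deferred()     # stands in for self.peers.get(...)
    getDefer.addCallback(save_file, None, '/dists/unstable/Release')
    getDefer.addErrback(save_error, '/dists/unstable/Release')
    getDefer.addCallbacks(d.callback, d.errback)
    getDefer.callback('response')   # simulate the download finishing
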
diff --git a/apt_dht/db.py b/apt_dht/db.py
index 9725aa88f10c36e0bd137de1cb5051850d7cde91..1d2e34273ad98161e365ce54ab7b4d5ecb7f76f0 100644
@@ -5,6 +5,7 @@ from binascii import a2b_base64, b2a_base64
 from time import sleep
 import os
 
+from twisted.python.filepath import FilePath
 from twisted.trial import unittest
 
 assert sqlite.version_info >= (2, 1)
@@ -25,24 +26,25 @@ class DB:
     
     def __init__(self, db):
         self.db = db
-        try:
-            os.stat(db)
-        except OSError:
-            self._createNewDB(db)
+        self.db.restat(False)
+        if self.db.exists():
+            self._loadDB()
         else:
-            self._loadDB(db)
+            self._createNewDB()
         self.conn.text_factory = str
         self.conn.row_factory = sqlite.Row
         
-    def _loadDB(self, db):
+    def _loadDB(self):
         try:
-            self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
+            self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
         except:
             import traceback
             raise DBExcept, "Couldn't open DB", traceback.format_exc()
         
-    def _createNewDB(self, db):
-        self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
+    def _createNewDB(self):
+        if not self.db.parent().exists():
+            self.db.parent().makedirs()
+        self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
         c = self.conn.cursor()
         c.execute("CREATE TABLE files (path TEXT PRIMARY KEY, hash KHASH, urldir INTEGER, dirlength INTEGER, size NUMBER, mtime NUMBER, refreshed TIMESTAMP)")
         c.execute("CREATE INDEX files_urldir ON files(urldir)")
@@ -52,49 +54,43 @@ class DB:
         c.close()
         self.conn.commit()
 
-    def _removeChanged(self, path, row):
+    def _removeChanged(self, file, row):
         res = None
         if row:
-            try:
-                stat = os.stat(path)
-            except:
-                stat = None
-            if stat:
-                res = (row['size'] == stat.st_size and row['mtime'] == stat.st_mtime)
+            file.restat(False)
+            if file.exists():
+                res = (row['size'] == file.getsize() and row['mtime'] == file.getmtime())
             if not res:
                 c = self.conn.cursor()
-                c.execute("DELETE FROM files WHERE path = ?", (path, ))
+                c.execute("DELETE FROM files WHERE path = ?", (file.path, ))
                 self.conn.commit()
                 c.close()
         return res
         
-    def storeFile(self, path, hash, directory):
+    def storeFile(self, file, hash, directory):
         """Store or update a file in the database.
         
         @return: the urlpath to access the file, and whether a
             new url top-level directory was needed
         """
-        path = os.path.abspath(path)
-        directory = os.path.abspath(directory)
-        assert path.startswith(directory)
-        stat = os.stat(path)
+        file.restat()
         c = self.conn.cursor()
-        c.execute("SELECT dirs.urldir AS urldir, dirs.path AS directory FROM dirs JOIN files USING (urldir) WHERE files.path = ?", (path, ))
+        c.execute("SELECT dirs.urldir AS urldir, dirs.path AS directory FROM dirs JOIN files USING (urldir) WHERE files.path = ?", (file.path, ))
         row = c.fetchone()
         if row and directory == row['directory']:
             c.execute("UPDATE files SET hash = ?, size = ?, mtime = ?, refreshed = ?", 
-                      (khash(hash), stat.st_size, stat.st_mtime, datetime.now()))
+                      (khash(hash), file.getsize(), file.getmtime(), datetime.now()))
             newdir = False
             urldir = row['urldir']
         else:
             urldir, newdir = self.findDirectory(directory)
             c.execute("INSERT OR REPLACE INTO files VALUES(?, ?, ?, ?, ?, ?, ?)",
-                      (path, khash(hash), urldir, len(directory), stat.st_size, stat.st_mtime, datetime.now()))
+                      (file.path, khash(hash), urldir, len(directory.path), file.getsize(), file.getmtime(), datetime.now()))
         self.conn.commit()
         c.close()
-        return '/~' + str(urldir) + path[len(directory):], newdir
+        return '/~' + str(urldir) + file.path[len(directory.path):], newdir
         
-    def getFile(self, path):
+    def getFile(self, file):
         """Get a file from the database.
         
         If it has changed or is missing, it is removed from the database.
@@ -102,45 +98,43 @@ class DB:
         @return: dictionary of info for the file, False if changed, or
             None if not in database or missing
         """
-        path = os.path.abspath(path)
         c = self.conn.cursor()
-        c.execute("SELECT hash, urldir, dirlength, size, mtime FROM files WHERE path = ?", (path, ))
+        c.execute("SELECT hash, urldir, dirlength, size, mtime FROM files WHERE path = ?", (file.path, ))
         row = c.fetchone()
-        res = self._removeChanged(path, row)
+        res = self._removeChanged(file, row)
         if res:
             res = {}
             res['hash'] = row['hash']
-            res['urlpath'] = '/~' + str(row['urldir']) + path[row['dirlength']:]
+            res['size'] = row['size']
+            res['urlpath'] = '/~' + str(row['urldir']) + file.path[row['dirlength']:]
         c.close()
         return res
         
-    def isUnchanged(self, path):
+    def isUnchanged(self, file):
         """Check if a file in the file system has changed.
         
         If it has changed, it is removed from the table.
         
         @return: True if unchanged, False if changed, None if not in database
         """
-        path = os.path.abspath(path)
         c = self.conn.cursor()
-        c.execute("SELECT size, mtime FROM files WHERE path = ?", (path, ))
+        c.execute("SELECT size, mtime FROM files WHERE path = ?", (file.path, ))
         row = c.fetchone()
-        return self._removeChanged(path, row)
+        return self._removeChanged(file, row)
 
-    def refreshFile(self, path):
+    def refreshFile(self, file):
         """Refresh the publishing time of a file.
         
         If it has changed or is missing, it is removed from the table.
         
         @return: True if unchanged, False if changed, None if not in database
         """
-        path = os.path.abspath(path)
         c = self.conn.cursor()
-        c.execute("SELECT size, mtime FROM files WHERE path = ?", (path, ))
+        c.execute("SELECT size, mtime FROM files WHERE path = ?", (file.path, ))
         row = c.fetchone()
-        res = self._removeChanged(path, row)
+        res = self._removeChanged(file, row)
         if res:
-            c.execute("UPDATE files SET refreshed = ? WHERE path = ?", (datetime.now(), path))
+            c.execute("UPDATE files SET refreshed = ? WHERE path = ?", (datetime.now(), file.path))
         return res
     
     def expiredFiles(self, expireAfter):
@@ -156,7 +150,7 @@ class DB:
         row = c.fetchone()
         expired = {}
         while row:
-            res = self._removeChanged(row['path'], row)
+            res = self._removeChanged(FilePath(row['path']), row)
             if res:
                 expired.setdefault(row['hash'], []).append('/~' + str(row['urldir']) + row['path'][row['dirlength']:])
             row = c.fetchone()
@@ -174,7 +168,7 @@ class DB:
         newdirs = []
         sql = "WHERE"
         for dir in dirs:
-            newdirs.append(os.path.abspath(dir) + os.sep + '*')
+            newdirs.append(dir.child('*').path)
             sql += " path NOT GLOB ? AND"
         sql = sql[:-4]
 
@@ -183,7 +177,7 @@ class DB:
         row = c.fetchone()
         removed = []
         while row:
-            removed.append(row['path'])
+            removed.append(FilePath(row['path']))
             row = c.fetchone()
 
         if removed:
@@ -196,9 +190,8 @@ class DB:
         
         @return: the index of the url directory, and whether it is new or not
         """
-        directory = os.path.abspath(directory)
         c = self.conn.cursor()
-        c.execute("SELECT min(urldir) AS urldir FROM dirs WHERE path = ?", (directory, ))
+        c.execute("SELECT min(urldir) AS urldir FROM dirs WHERE path = ?", (directory.path, ))
         row = c.fetchone()
         c.close()
         if row['urldir']:
@@ -206,7 +199,7 @@ class DB:
 
         # Not found, need to add a new one
         c = self.conn.cursor()
-        c.execute("INSERT INTO dirs (path) VALUES (?)", (directory, ))
+        c.execute("INSERT INTO dirs (path) VALUES (?)", (directory.path, ))
         self.conn.commit()
         urldir = c.lastrowid
         c.close()
@@ -219,7 +212,7 @@ class DB:
         row = c.fetchone()
         dirs = {}
         while row:
-            dirs['~' + str(row['urldir'])] = row['path']
+            dirs['~' + str(row['urldir'])] = FilePath(row['path'])
             row = c.fetchone()
         c.close()
         return dirs
@@ -238,23 +231,34 @@ class TestDB(unittest.TestCase):
     """Tests for the khashmir database."""
     
     timeout = 5
-    db = '/tmp/khashmir.db'
-    path = '/tmp/khashmir.test'
+    db = FilePath('/tmp/khashmir.db')
+    file = FilePath('/tmp/apt-dht/khashmir.test')
     hash = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
-    directory = '/tmp/'
+    directory = FilePath('/tmp/apt-dht/')
     urlpath = '/~1/khashmir.test'
-    dirs = ['/tmp/apt-dht/top1', '/tmp/apt-dht/top2/sub1', '/tmp/apt-dht/top2/sub2/']
+    testfile = 'tmp/khashmir.test'
+    dirs = [FilePath('/tmp/apt-dht/top1'),
+            FilePath('/tmp/apt-dht/top2/sub1'),
+            FilePath('/tmp/apt-dht/top2/sub2/')]
 
     def setUp(self):
-        f = open(self.path, 'w')
-        f.write('fgfhds')
-        f.close()
-        os.utime(self.path, None)
+        if not self.file.parent().exists():
+            self.file.parent().makedirs()
+        self.file.setContent('fgfhds')
+        self.file.touch()
         self.store = DB(self.db)
-        self.store.storeFile(self.path, self.hash, self.directory)
+        self.store.storeFile(self.file, self.hash, self.directory)
+
+    def test_openExistingDB(self):
+        self.store.close()
+        self.store = None
+        sleep(1)
+        self.store = DB(self.db)
+        res = self.store.isUnchanged(self.file)
+        self.failUnless(res)
 
     def test_getFile(self):
-        res = self.store.getFile(self.path)
+        res = self.store.getFile(self.file)
         self.failUnless(res)
         self.failUnlessEqual(res['hash'], self.hash)
         self.failUnlessEqual(res['urlpath'], self.urlpath)
@@ -264,17 +268,17 @@ class TestDB(unittest.TestCase):
         self.failUnless(res)
         self.failUnlessEqual(len(res.keys()), 1)
         self.failUnlessEqual(res.keys()[0], '~1')
-        self.failUnlessEqual(res['~1'], os.path.abspath(self.directory))
+        self.failUnlessEqual(res['~1'], self.directory)
         
     def test_isUnchanged(self):
-        res = self.store.isUnchanged(self.path)
+        res = self.store.isUnchanged(self.file)
         self.failUnless(res)
         sleep(2)
-        os.utime(self.path, None)
-        res = self.store.isUnchanged(self.path)
+        self.file.touch()
+        res = self.store.isUnchanged(self.file)
         self.failUnless(res == False)
-        os.unlink(self.path)
-        res = self.store.isUnchanged(self.path)
+        self.file.remove()
+        res = self.store.isUnchanged(self.file)
         self.failUnless(res == None)
         
     def test_expiry(self):
@@ -286,35 +290,34 @@ class TestDB(unittest.TestCase):
         self.failUnlessEqual(res.keys()[0], self.hash)
         self.failUnlessEqual(len(res[self.hash]), 1)
         self.failUnlessEqual(res[self.hash][0], self.urlpath)
-        res = self.store.refreshFile(self.path)
+        res = self.store.refreshFile(self.file)
         self.failUnless(res)
         res = self.store.expiredFiles(1)
         self.failUnlessEqual(len(res.keys()), 0)
         
     def build_dirs(self):
         for dir in self.dirs:
-            path = os.path.join(dir, self.path[1:])
-            os.makedirs(os.path.dirname(path))
-            f = open(path, 'w')
-            f.write(path)
-            f.close()
-            os.utime(path, None)
-            self.store.storeFile(path, self.hash, dir)
+            file = dir.preauthChild(self.testfile)
+            if not file.parent().exists():
+                file.parent().makedirs()
+            file.setContent(file.path)
+            file.touch()
+            self.store.storeFile(file, self.hash, dir)
     
     def test_removeUntracked(self):
         self.build_dirs()
         res = self.store.removeUntrackedFiles(self.dirs)
         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
-        self.failUnlessEqual(res[0], self.path, 'Got removed paths: %r' % res)
+        self.failUnlessEqual(res[0], self.file, 'Got removed paths: %r' % res)
         res = self.store.removeUntrackedFiles(self.dirs)
         self.failUnlessEqual(len(res), 0, 'Got removed paths: %r' % res)
         res = self.store.removeUntrackedFiles(self.dirs[1:])
         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
-        self.failUnlessEqual(res[0], os.path.join(self.dirs[0], self.path[1:]), 'Got removed paths: %r' % res)
+        self.failUnlessEqual(res[0], self.dirs[0].preauthChild(self.testfile), 'Got removed paths: %r' % res)
         res = self.store.removeUntrackedFiles(self.dirs[:1])
         self.failUnlessEqual(len(res), 2, 'Got removed paths: %r' % res)
-        self.failUnlessIn(os.path.join(self.dirs[1], self.path[1:]), res, 'Got removed paths: %r' % res)
-        self.failUnlessIn(os.path.join(self.dirs[2], self.path[1:]), res, 'Got removed paths: %r' % res)
+        self.failUnlessIn(self.dirs[1].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)
+        self.failUnlessIn(self.dirs[2].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)
         
     def test_reconcileDirectories(self):
         self.build_dirs()
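Note the switch from os.path.join() to preauthChild() rather than child() in build_dirs: child() accepts only a single path segment and raises InsecurePath when the name contains a separator, while preauthChild() walks a multi-segment relative path, still refusing to escape the base directory. A short illustration (paths hypothetical):

    from twisted.python.filepath import FilePath, InsecurePath

    base = FilePath('/tmp/apt-dht/top1')
    try:
        base.child('tmp/khashmir.test')        # more than one segment
    except InsecurePath:
        pass                                   # child() rejects it
    f = base.preauthChild('tmp/khashmir.test') # multi-segment path is fine
    assert f.path == '/tmp/apt-dht/top1/tmp/khashmir.test'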
@@ -338,17 +341,13 @@ class TestDB(unittest.TestCase):
         res = self.store.getAllDirectories()
         self.failUnless(res)
         self.failUnlessEqual(len(res.keys()), 1)
-        res = self.store.removeUntrackedFiles(['/what'])
+        res = self.store.removeUntrackedFiles([FilePath('/what')])
         res = self.store.reconcileDirectories()
         self.failUnlessEqual(res, True)
         res = self.store.getAllDirectories()
         self.failUnlessEqual(len(res.keys()), 0)
         
     def tearDown(self):
-        for root, dirs, files in os.walk('/tmp/apt-dht', topdown=False):
-            for name in files:
-                os.remove(os.path.join(root, name))
-            for name in dirs:
-                os.rmdir(os.path.join(root, name))
+        self.directory.remove()
         self.store.close()
-        os.unlink(self.db)
+        self.db.remove()
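The tearDown() simplification works because FilePath.remove() deletes a directory tree recursively, which is what the old bottom-up os.walk() loop spelled out by hand. A sketch (path hypothetical):

    from twisted.python.filepath import FilePath

    tree = FilePath('/tmp/apt-dht-example')
    tree.child('sub').makedirs()
    tree.child('sub').child('f').setContent('x')
    tree.remove()          # removes files and subdirectories, then the dir
    tree.restat(False)     # refresh the cached stat before re-checking
    assert not tree.exists()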
apt_dht_Khashmir/DHT.py
index 235c8d0c7a5264e10f26283352fe6459e964b752..d16f59837811e0bc561718a130c2fe0ddcaf3832 100644 (file)
@@ -11,6 +11,8 @@ from zope.interface import implements
 from apt_dht.interfaces import IDHT
 from khashmir import Khashmir
 
+khashmir_dir = 'apt-dht-Khashmir'
+
 class DHTError(Exception):
     """Represents errors that occur in the DHT."""
 
@@ -36,7 +38,9 @@ class DHT:
         self.config_parser = config
         self.section = section
         self.config = {}
-        self.cache_dir = self.config_parser.get(section, 'cache_dir')
+        self.cache_dir = os.path.join(self.config_parser.get(section, 'cache_dir'), khashmir_dir)
+        if not os.path.exists(self.cache_dir):
+            os.makedirs(self.cache_dir)
         self.bootstrap = self.config_parser.getstringlist(section, 'BOOTSTRAP')
         self.bootstrap_node = self.config_parser.getboolean(section, 'BOOTSTRAP_NODE')
         for k in self.config_parser.options(section):
@@ -182,7 +186,7 @@ class TestSimpleDHT(unittest.TestCase):
                     'STORE_REDUNDANCY': 3, 'MAX_FAILURES': 3,
                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
                     'KEINITIAL_DELAY': 15, 'KE_DELAY': 1200,
-                    'KE_AGE': 3600, 'SPEW': True, }
+                    'KE_AGE': 3600, 'SPEW': False, }
 
     def setUp(self):
         self.a = DHT()
apt_dht_Khashmir/khashmir.py
index 5d14c7bd34a85ba1c76c772695e731d38f06e6a8..ae11dd7ef07158642637df47810aebb6a8c8ee6e 100644 (file)
@@ -30,7 +30,7 @@ class KhashmirBase(protocol.Factory):
     def setup(self, config, cache_dir):
         self.config = config
         self.port = config['PORT']
-        self.store = DB(os.path.join(cache_dir, '.khashmir.' + str(self.port) + '.db'))
+        self.store = DB(os.path.join(cache_dir, 'khashmir.' + str(self.port) + '.db'))
         self.node = self._loadSelfNode('', self.port)
         self.table = KTable(self.node, config)
         #self.app = service.Application("krpc")
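Taken together with the DHT.py hunk above, the khashmir database now lands in its own subdirectory of the cache, and the rename drops the leading dot so the file is no longer hidden. With hypothetical values for the cache directory and port:

    import os

    # mirrors DHT.setup(): cache_dir gains the khashmir_dir component
    cache_dir = os.path.join('/var/cache/apt-dht', 'apt-dht-Khashmir')
    # mirrors KhashmirBase.setup(): one database per listening port
    port = 9977   # illustrative port number
    db = os.path.join(cache_dir, 'khashmir.' + str(port) + '.db')
    assert db == '/var/cache/apt-dht/apt-dht-Khashmir/khashmir.9977.db'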