# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
+"""Manage a mirror's index files.
+
+@type TRACKED_FILES: C{list} of C{string}
+@var TRACKED_FILES: the file names of files that contain index information
+"""
+
# Disable the FutureWarning from the apt module
import warnings
warnings.simplefilter("ignore", FutureWarning)
TRACKED_FILES = ['release', 'sources', 'packages']
class PackageFileList(DictMixin):
- """Manages a list of package files belonging to a backend.
+ """Manages a list of index files belonging to a mirror.
+ @type cache_dir: L{twisted.python.filepath.FilePath}
+ @ivar cache_dir: the directory to use for storing all files
@type packages: C{shelve dictionary}
- @ivar packages: the files stored for this backend
+ @ivar packages: the files tracked for this mirror
"""
def __init__(self, cache_dir):
+ """Initialize the list by opening the dictionary."""
self.cache_dir = cache_dir
self.cache_dir.restat(False)
if not self.cache_dir.exists():
self.open()
def open(self):
- """Open the persistent dictionary of files in this backend."""
+ """Open the persistent dictionary of files for this mirror."""
if self.packages is None:
self.packages = shelve.open(self.cache_dir.child('packages.db').path)
Called from the mirror manager when files get updated so we can update our
fake lists and sources.list.
+
+ @type cache_path: C{string}
+ @param cache_path: the location of the file within the mirror
+ @type file_path: L{twisted.python.filepath.FilePath}
+ @param file_path: The location of the file in the file system
+ @rtype: C{boolean}
+ @return: whether the file is an index file
"""
filename = cache_path.split('/')[-1]
if filename.lower() in TRACKED_FILES:
return False
def check_files(self):
- """Check all files in the database to make sure they exist."""
+ """Check all files in the database to remove any that don't exist."""
files = self.packages.keys()
for f in files:
self.packages[f].restat(False)
log.msg("File in packages database has been deleted: "+f)
del self.packages[f]
- # Standard dictionary implementation so this class can be used like a dictionary.
+ #{ Dictionary interface details
def __getitem__(self, key): return self.packages[key]
def __setitem__(self, key, item): self.packages[key] = item
def __delitem__(self, key): del self.packages[key]
def keys(self): return self.packages.keys()
class AptPackages:
- """Uses python-apt to answer queries about packages.
-
- Makes a fake configuration for python-apt for each backend.
+ """Answers queries about packages available from a mirror.
+
+ Uses the python-apt tools to parse and provide information about the
+ files that are available on a single mirror.
+
+ @ivar DEFAULT_APT_CONFIG: the default configuration parameters to use for apt
+ @ivar essential_dirs: directories that must be created for apt to work
+ @ivar essential_files: files that must be created for apt to work
+ @type cache_dir: L{twisted.python.filepath.FilePath}
+ @ivar cache_dir: the directory to use for storing all files
+ @type unload_delay: C{int}
+ @ivar unload_delay: the time to wait before unloading the apt cache
+ @ivar apt_config: the configuration parameters to use for apt
+ @type packages: L{PackageFileList}
+ @ivar packages: the persistent storage of tracked apt index files
+ @type loaded: C{boolean}
+ @ivar loaded: whether the apt cache is currently loaded
+ @type loading: L{twisted.internet.defer.Deferred}
+ @ivar loading: if the cache is currently being loaded, this will be
+ called when it is loaded, otherwise it is None
+ @type unload_later: L{twisted.internet.interfaces.IDelayedCall}
+ @ivar unload_later: the delayed call to unload the apt cache
+ @type indexrecords: C{dictionary}
+ @ivar indexrecords: the hashes of index files for the mirror, keys are
+ mirror directories, values are dictionaries with keys the path to the
+ index file in the mirror directory and values are dictionaries with
+ keys the hash type and values the hash
+ @type cache: C{apt_pkg.GetCache()}
+ @ivar cache: the apt cache of the mirror
+ @type records: C{apt_pkg.GetPkgRecords()}
+ @ivar records: the apt package records for all binary packages in a mirror
+ @type srcrecords: C{apt_pkg.GetPkgSrcRecords()}
+ @ivar srcrecords: the apt package records for all source packages in a mirror
"""
DEFAULT_APT_CONFIG = {
def __init__(self, cache_dir, unload_delay):
"""Construct a new packages manager.
- @param cache_dir: cache directory from config file
+ @param cache_dir: directory to use to store files for this mirror
"""
self.cache_dir = cache_dir
self.unload_delay = unload_delay
self.apt_config = deepcopy(self.DEFAULT_APT_CONFIG)
+ # Create the necessary files and directories for apt
for dir in self.essential_dirs:
path = self.cache_dir.preauthChild(dir)
if not path.exists():
self.apt_config['Dir'] = self.cache_dir.path
self.apt_config['Dir::State::status'] = self.cache_dir.preauthChild(self.apt_config['Dir::State']).preauthChild(self.apt_config['Dir::State::status']).path
self.packages = PackageFileList(cache_dir)
- self.loaded = 0
+ self.loaded = False
self.loading = None
self.unload_later = None
self.cleanup()
def addRelease(self, cache_path, file_path):
- """Dirty hack until python-apt supports apt-pkg/indexrecords.h
+ """Add a Release file's info to the list of index files.
+
+ Dirty hack until python-apt supports apt-pkg/indexrecords.h
(see Bug #456141)
"""
self.indexrecords[cache_path] = {}
read_packages = False
f = file_path.open('r')
+ # Use python-debian routines to parse the file for hashes
rel = deb822.Release(f, fields = ['MD5Sum', 'SHA1', 'SHA256'])
for hash_type in rel:
for file in rel[hash_type]:
f.close()
def file_updated(self, cache_path, file_path):
- """A file in the backend has changed, manage it.
+ """A file in the mirror has changed or been added.
- If this affects us, unload our apt database
+ If this affects us, unload our apt database.
+ @see: L{PackageFileList.update_file}
"""
if self.packages.update_file(cache_path, file_path):
self.unload()
def load(self):
- """Make sure the package is initialized and loaded."""
+ """Make sure the package cache is initialized and loaded."""
+ # Reset the pending unload call
if self.unload_later and self.unload_later.active():
self.unload_later.reset(self.unload_delay)
else:
self.unload_later = reactor.callLater(self.unload_delay, self.unload)
+
+ # Make sure it's not already being loaded
if self.loading is None:
log.msg('Loading the packages cache')
self.loading = threads.deferToThread(self._load)
return loadResult
def _load(self):
- """Regenerates the fake configuration and load the packages cache."""
+ """Regenerates the fake configuration and loads the packages caches."""
if self.loaded: return True
+
+ # Modify the default configuration to create the fake one.
apt_pkg.InitSystem()
self.cache_dir.preauthChild(self.apt_config['Dir::State']
).preauthChild(self.apt_config['Dir::State::Lists']).remove()
deb_src_added = False
self.packages.check_files()
self.indexrecords = {}
+
+ # Create an entry in sources.list for each needed index file
for f in self.packages:
# we should probably clear old entries from self.packages and
# take into account the recorded mtime as optimization
else:
self.srcrecords = None
- self.loaded = 1
+ self.loaded = True
return True
def unload(self):
self.unload_later = None
if self.loaded:
log.msg('Unloading the packages cache')
+ # This should save memory
del self.cache
del self.records
del self.srcrecords
del self.indexrecords
- self.loaded = 0
+ self.loaded = False
def cleanup(self):
"""Cleanup and close any loaded caches."""
def findHash(self, path):
"""Find the hash for a given path in this mirror.
- Returns a deferred so it can make sure the cache is loaded first.
+ @type path: C{string}
+ @param path: the path within the mirror of the file to lookup
+ @rtype: L{twisted.internet.defer.Deferred}
+ @return: a deferred so it can make sure the cache is loaded first
"""
d = defer.Deferred()
return d
def _findHash_error(self, failure, path, d):
- """An error occurred while trying to find a hash."""
+ """An error occurred, return an empty hash."""
log.msg('An error occurred while looking up a hash for: %s' % path)
log.err(failure)
d.callback(HashObject())
+ return failure
def _findHash(self, loadResult, path, d):
- """Really find the hash for a path.
+ """Search the records for the hash of a path.
- Have to pass the returned loadResult on in case other calls to this
- function are pending.
+ @type loadResult: C{boolean}
+ @param loadResult: whether apt's cache was successfully loaded
+ @type path: C{string}
+ @param path: the path within the mirror of the file to lookup
+ @type d: L{twisted.internet.defer.Deferred}
+ @param d: the deferred to callback with the result
"""
if not loadResult:
d.callback(HashObject())
return loadResult
+ h = HashObject()
+
# First look for the path in the cache of index files
for release in self.indexrecords:
if path.startswith(release[:-7]):
for indexFile in self.indexrecords[release]:
if release[:-7] + indexFile == path:
- h = HashObject()
h.setFromIndexRecord(self.indexrecords[release][indexFile])
d.callback(h)
return loadResult
for verFile in version.FileList:
if self.records.Lookup(verFile):
if '/' + self.records.FileName == path:
- h = HashObject()
h.setFromPkgRecord(self.records, size)
d.callback(h)
return loadResult
if self.srcrecords.Lookup(package):
for f in self.srcrecords.Files:
if path == '/' + f[2]:
- h = HashObject()
h.setFromSrcRecord(f)
d.callback(h)
return loadResult
- d.callback(HashObject())
+ d.callback(h)
+
+ # Have to pass the returned loadResult on in case other calls to this function are pending.
return loadResult
class TestAptPackages(unittest.TestCase):
releaseFile = ''
def setUp(self):
+ """Initializes the cache with files found in the traditional apt location."""
self.client = AptPackages(FilePath('/tmp/.apt-dht'), 300)
+ # Find the largest index files that are for 'main'
self.packagesFile = os.popen('ls -Sr /var/lib/apt/lists/ | grep -E "_main_.*Packages$" | tail -n 1').read().rstrip('\n')
self.sourcesFile = os.popen('ls -Sr /var/lib/apt/lists/ | grep -E "_main_.*Sources$" | tail -n 1').read().rstrip('\n')
+
+ # Find the Release file corresponding to the found Packages file
for f in os.walk('/var/lib/apt/lists').next()[2]:
if f[-7:] == "Release" and self.packagesFile.startswith(f[:-7]):
self.releaseFile = f
break
-
+
+ # Add all the found files to the PackageFileList
self.client.file_updated(self.releaseFile[self.releaseFile.find('_dists_'):].replace('_','/'),
FilePath('/var/lib/apt/lists/' + self.releaseFile))
self.client.file_updated(self.packagesFile[self.packagesFile.find('_dists_'):].replace('_','/'),
FilePath('/var/lib/apt/lists/' + self.sourcesFile))
def test_pkg_hash(self):
+ """Tests loading the binary package records cache."""
self.client._load()
self.client.records.Lookup(self.client.cache['dpkg'].VersionList[0].FileList[0])
"Hashes don't match: %s != %s" % (self.client.records.SHA1Hash, pkg_hash))
def test_src_hash(self):
+ """Tests loading the source package records cache."""
self.client._load()
self.client.srcrecords.Lookup('dpkg')
self.failUnless(f[0] in src_hashes, "Couldn't find %s in: %r" % (f[0], src_hashes))
def test_index_hash(self):
+ """Tests loading the cache of index file information."""
self.client._load()
indexhash = self.client.indexrecords[self.releaseFile[self.releaseFile.find('_dists_'):].replace('_','/')]['main/binary-i386/Packages.bz2']['SHA1'][0]
"%s hashes don't match: %s != %s" % (path, found_hash.hexexpected(), true_hash))
def test_findIndexHash(self):
+ """Tests finding the hash of a single index file."""
lastDefer = defer.Deferred()
idx_hash = os.popen('grep -A 3000 -E "^SHA1:" ' +
return lastDefer
def test_findPkgHash(self):
+ """Tests finding the hash of a single binary package."""
lastDefer = defer.Deferred()
pkg_hash = os.popen('grep -A 30 -E "^Package: dpkg$" ' +
return lastDefer
def test_findSrcHash(self):
+ """Tests finding the hash of a single source package."""
lastDefer = defer.Deferred()
src_dir = '/' + os.popen('grep -A 30 -E "^Package: dpkg$" ' +
return lastDefer
def test_multipleFindHash(self):
+ """Tests finding the hash of an index file, binary package, source package, and another index file."""
lastDefer = defer.Deferred()
+ # Lookup a Packages.bz2 file
idx_hash = os.popen('grep -A 3000 -E "^SHA1:" ' +
'/var/lib/apt/lists/' + self.releaseFile +
' | grep -E " main/binary-i386/Packages.bz2$"'
d = self.client.findHash(idx_path)
d.addCallback(self.verifyHash, idx_path, idx_hash)
+ # Lookup the binary 'dpkg' package
pkg_hash = os.popen('grep -A 30 -E "^Package: dpkg$" ' +
'/var/lib/apt/lists/' + self.packagesFile +
' | grep -E "^SHA1:" | head -n 1' +
d = self.client.findHash(pkg_path)
d.addCallback(self.verifyHash, pkg_path, pkg_hash)
+ # Lookup the source 'dpkg' package
src_dir = '/' + os.popen('grep -A 30 -E "^Package: dpkg$" ' +
'/var/lib/apt/lists/' + self.sourcesFile +
' | grep -E "^Directory:" | head -n 1' +
d = self.client.findHash(src_dir + '/' + src_paths[i])
d.addCallback(self.verifyHash, src_dir + '/' + src_paths[i], src_hashes[i])
+ # Lookup a Sources.bz2 file
idx_hash = os.popen('grep -A 3000 -E "^SHA1:" ' +
'/var/lib/apt/lists/' + self.releaseFile +
' | grep -E " main/source/Sources.bz2$"'
+"""Manage a cache of downloaded files.
+
+@var DECOMPRESS_EXTS: a list of file extensions that need to be decompressed
+@var DECOMPRESS_FILES: a list of file names that need to be decompressed
+"""
+
from bz2 import BZ2Decompressor
from zlib import decompressobj, MAX_WBITS
from gzip import FCOMMENT, FEXTRA, FHCRC, FNAME, FTEXT
from Hash import HashObject
-aptpkg_dir='apt-packages'
-
DECOMPRESS_EXTS = ['.gz', '.bz2']
DECOMPRESS_FILES = ['release', 'sources', 'packages']
class ProxyFileStream(stream.SimpleStream):
- """Saves a stream to a file while providing a new stream."""
+ """Saves a stream to a file while providing a new stream.
+
+ Also optionally decompresses the file while it is being downloaded.
+
+ @type stream: L{twisted.web2.stream.IByteStream}
+ @ivar stream: the input stream being read
+ @type outFile: L{twisted.python.filepath.FilePath}
+ @ivar outFile: the file being written
+ @type hash: L{Hash.HashObject}
+ @ivar hash: the hash object for the file
+ @type gzfile: C{file}
+ @ivar gzfile: the open file to write decompressed gzip data to
+ @type gzdec: L{zlib.decompressobj}
+ @ivar gzdec: the decompressor to use for the compressed gzip data
+ @type gzheader: C{boolean}
+ @ivar gzheader: whether the gzip header still needs to be removed from
+ the zlib compressed data
+ @type bz2file: C{file}
+ @ivar bz2file: the open file to write decompressed bz2 data to
+ @type bz2dec: L{bz2.BZ2Decompressor}
+ @ivar bz2dec: the decompressor to use for the compressed bz2 data
+ @type length: C{int}
+ @ivar length: the length of the original (compressed) file
+ @type doneDefer: L{twisted.internet.defer.Deferred}
+ @ivar doneDefer: the deferred that will fire when done streaming
+
+ @group Stream implementation: read, close
+
+ """
def __init__(self, stream, outFile, hash, decompress = None, decFile = None):
"""Initializes the proxy.
- @type stream: C{twisted.web2.stream.IByteStream}
+ @type stream: L{twisted.web2.stream.IByteStream}
@param stream: the input stream to read from
- @type outFile: C{twisted.python.FilePath}
+ @type outFile: L{twisted.python.filepath.FilePath}
@param outFile: the file to write to
@type hash: L{Hash.HashObject}
@param hash: the hash object to use for the file
self.bz2file = decFile.open('w')
self.bz2dec = BZ2Decompressor()
self.length = self.stream.length
- self.start = 0
self.doneDefer = defer.Deferred()
def _done(self):
- """Close the output file."""
+ """Close all the output files, return the result."""
if not self.outFile.closed:
self.outFile.close()
self.hash.digest()
if self.gzfile:
+ # Finish the decompression
data_dec = self.gzdec.flush()
self.gzfile.write(data_dec)
self.gzfile.close()
if self.outFile.closed:
return None
+ # Read data from the stream, deal with the possible deferred
data = self.stream.read()
if isinstance(data, defer.Deferred):
data.addCallbacks(self._write, self._done)
return data
def _write(self, data):
- """Write the stream data to the file and return it for others to use."""
+ """Write the stream data to the file and return it for others to use.
+
+ Also optionally decompresses it.
+ """
if data is None:
self._done()
return data
+ # Write and hash the streamed data
self.outFile.write(data)
self.hash.update(data)
+
if self.gzfile:
+ # Decompress the zlib portion of the file
if self.gzheader:
+ # Remove the gzip header junk
self.gzheader = False
new_data = self._remove_gzip_header(data)
dec_data = self.gzdec.decompress(new_data)
dec_data = self.gzdec.decompress(data)
self.gzfile.write(dec_data)
if self.bz2file:
+ # Decompress the bz2 file
dec_data = self.bz2dec.decompress(data)
self.bz2file.write(dec_data)
+
return data
def _remove_gzip_header(self, data):
+ """Remove the gzip header from the zlib compressed data."""
+ # Read, check & discard the header fields
if data[:2] != '\037\213':
raise IOError, 'Not a gzipped file'
if ord(data[2]) != 8:
skip = 10
if flag & FEXTRA:
- # Read & discard the extra field, if present
+ # Read & discard the extra field
xlen = ord(data[10])
xlen = xlen + 256*ord(data[11])
skip = skip + 2 + xlen
skip += 1
if flag & FHCRC:
skip += 2 # Read & discard the 16-bit header CRC
+
return data[skip:]
def close(self):
self.stream.close()
class CacheManager:
- """Manages all requests for cached objects."""
+ """Manages all downloaded files and requests for cached objects.
+
+ @type cache_dir: L{twisted.python.filepath.FilePath}
+ @ivar cache_dir: the directory to use for storing all files
+ @type other_dirs: C{list} of L{twisted.python.filepath.FilePath}
+ @ivar other_dirs: the other directories that have shared files in them
+ @type all_dirs: C{list} of L{twisted.python.filepath.FilePath}
+ @ivar all_dirs: all the directories that have cached files in them
+ @type db: L{db.DB}
+ @ivar db: the database to use for tracking files and hashes
+ @type manager: L{apt_dht.AptDHT}
+ @ivar manager: the main program object to send requests to
+ @type scanning: C{list} of L{twisted.python.filepath.FilePath}
+ @ivar scanning: all the directories that are currently being scanned or waiting to be scanned
+ """
def __init__(self, cache_dir, db, other_dirs = [], manager = None):
+ """Initialize the instance and remove any untracked files from the DB.
+
+ @type cache_dir: L{twisted.python.filepath.FilePath}
+ @param cache_dir: the directory to use for storing all files
+ @type db: L{db.DB}
+ @param db: the database to use for tracking files and hashes
+ @type other_dirs: C{list} of L{twisted.python.filepath.FilePath}
+ @param other_dirs: the other directories that have shared files in them
+ (optional, defaults to only using the cache directory)
+ @type manager: L{apt_dht.AptDHT}
+ @param manager: the main program object to send requests to
+ (optional, defaults to not calling back with cached files)
+ """
self.cache_dir = cache_dir
self.other_dirs = other_dirs
self.all_dirs = self.other_dirs[:]
# Init the database, remove old files
self.db.removeUntrackedFiles(self.all_dirs)
-
+ #{ Scanning directories
def scanDirectories(self):
"""Scan the cache directories, hashing new and rehashing changed files."""
assert not self.scanning, "a directory scan is already under way"
self._scanDirectories()
def _scanDirectories(self, result = None, walker = None):
- # Need to start waling a new directory
+ """Walk each directory looking for cached files.
+
+ @param result: the result of a DHT store request, not used (optional)
+ @param walker: the walker to use to traverse the current directory
+ (optional, defaults to creating a new walker from the first
+ directory in the L{CacheManager.scanning} list)
+ """
+ # Need to start walking a new directory
if walker is None:
# If there are any left, get them
if self.scanning:
df.addErrback(log.err)
def _doneHashing(self, result, file, walker):
-
+ """If successful, add the hashed file to the DB and inform the main program."""
if isinstance(result, HashObject):
log.msg('hash check of %s completed with hash: %s' % (file.path, result.hexdigest()))
+
+ # Only set a URL if this is a downloaded file
url = None
if self.scanning[0] == self.cache_dir:
url = 'http:/' + file.path[len(self.cache_dir.path):]
+
+ # Store the hashed file in the database
new_hash = self.db.storeFile(file, result.digest())
+
+ # Tell the main program to handle the new cache file
df = self.manager.new_cached_file(file, result, new_hash, url, True)
if df is None:
reactor.callLater(0, self._scanDirectories, None, walker)
else:
df.addBoth(self._scanDirectories, walker)
else:
+ # Must have returned an error
log.msg('hash check of %s failed' % file.path)
log.err(result)
reactor.callLater(0, self._scanDirectories, None, walker)
+ #{ Downloading files
def save_file(self, response, hash, url):
- """Save a downloaded file to the cache and stream it."""
+ """Save a downloaded file to the cache and stream it.
+
+ @type response: L{twisted.web2.http.Response}
+ @param response: the response from the download
+ @type hash: L{Hash.HashObject}
+ @param hash: the hash object containing the expected hash for the file
+ @param url: the URI of the actual mirror request
+ @rtype: L{twisted.web2.http.Response}
+ @return: the final response from the download
+ """
if response.code != 200:
log.msg('File was not found (%r): %s' % (response, url))
return response
log.msg('Returning file: %s' % url)
-
+
+ # Set the destination path for the file
parsed = urlparse(url)
destFile = self.cache_dir.preauthChild(parsed[1] + parsed[2])
log.msg('Saving returned %r byte file to cache: %s' % (response.stream.length, destFile.path))
+ # Make sure there's a free place for the file
if destFile.exists():
log.msg('File already exists, removing: %s' % destFile.path)
destFile.remove()
elif not destFile.parent().exists():
destFile.parent().makedirs()
-
+
+ # Determine whether it needs to be decompressed and how
root, ext = os.path.splitext(destFile.basename())
if root.lower() in DECOMPRESS_FILES and ext.lower() in DECOMPRESS_EXTS:
ext = ext.lower()
ext = None
decFile = None
+ # Create the new stream from the old one.
orig_stream = response.stream
response.stream = ProxyFileStream(orig_stream, destFile, hash, ext, decFile)
response.stream.doneDefer.addCallback(self._save_complete, url, destFile,
response.headers.getHeader('Last-Modified'),
- ext, decFile)
+ decFile)
response.stream.doneDefer.addErrback(self.save_error, url)
+
+ # Return the modified response with the new stream
return response
- def _save_complete(self, hash, url, destFile, modtime = None, ext = None, decFile = None):
- """Update the modification time and AptPackages."""
+ def _save_complete(self, hash, url, destFile, modtime = None, decFile = None):
+ """Update the modification time and inform the main program.
+
+ @type hash: L{Hash.HashObject}
+ @param hash: the hash object containing the expected hash for the file
+ @param url: the URI of the actual mirror request
+ @type destFile: C{twisted.python.FilePath}
+ @param destFile: the file where the download was written to
+ @type modtime: C{int}
+ @param modtime: the modified time of the cached file (seconds since epoch)
+ (optional, defaults to not setting the modification time of the file)
+ @type decFile: C{twisted.python.FilePath}
+ @param decFile: the file where the decompressed download was written to
+ (optional, defaults to the file not having been compressed)
+ """
if modtime:
os.utime(destFile.path, (modtime, modtime))
- if ext:
+ if decFile:
os.utime(decFile.path, (modtime, modtime))
result = hash.verify()
if self.manager:
self.manager.new_cached_file(destFile, hash, new_hash, url)
- if ext:
- self.manager.new_cached_file(decFile, None, False, url[:-len(ext)])
+ if decFile:
+ ext_len = len(destFile.path) - len(decFile.path)
+ self.manager.new_cached_file(decFile, None, False, url[:-ext_len])
else:
log.msg("Hashes don't match %s != %s: %s" % (hash.hexexpected(), hash.hexdigest(), url))
destFile.remove()
- if ext:
+ if decFile:
decFile.remove()
def save_error(self, failure, url):
+"""Manage all download requests to a single site."""
+
from math import exp
from datetime import datetime, timedelta
self._lastResponse = None
self._responseTimes = []
+ #{ Manage the request queue
def connect(self):
+ """Connect to the peer."""
assert self.closed and not self.connecting
self.connecting = True
d = protocol.ClientCreator(reactor, HTTPClientProtocol, self).connectTCP(self.host, self.port)
d.addCallback(self.connected)
def connected(self, proto):
+ """Begin processing the queued requests."""
self.closed = False
self.connecting = False
self.proto = proto
self.processQueue()
def close(self):
+ """Close the connection to the peer."""
if not self.closed:
self.proto.transport.loseConnection()
def submitRequest(self, request):
+ """Add a new request to the queue.
+
+ @type request: L{twisted.web2.client.http.ClientRequest}
+ @return: deferred that will fire with the completed request
+ """
request.submissionTime = datetime.now()
request.deferRequest = defer.Deferred()
self.request_queue.append(request)
return request.deferRequest
def processQueue(self):
+ """Check the queue to see if new requests can be sent to the peer."""
if not self.request_queue:
return
if self.connecting:
req.deferResponse.addCallbacks(self.requestComplete, self.requestError)
def requestComplete(self, resp):
+ """Process a completed request."""
self._processLastResponse()
req = self.response_queue.pop(0)
log.msg('%s of %s completed with code %d' % (req.method, req.uri, resp.code))
req.deferRequest.callback(resp)
def requestError(self, error):
+ """Process a request that ended with an error."""
self._processLastResponse()
req = self.response_queue.pop(0)
log.msg('Download of %s generated error %r' % (req.uri, error))
log.msg('Hash error from peer (%s, %d): %r' % (self.host, self.port, error))
self._errors += 1
- # The IHTTPClientManager interface functions
+ #{ IHTTPClientManager interface
def clientBusy(self, proto):
+ """Save the busy state."""
self.busy = True
def clientIdle(self, proto):
+ """Try to send a new request."""
self._processLastResponse()
self.busy = False
self.processQueue()
def clientPipelining(self, proto):
+ """Try to send a new request."""
self.pipeline = True
self.processQueue()
def clientGone(self, proto):
+ """Mark sent requests as errors."""
self._processLastResponse()
for req in self.response_queue:
req.deferRequest.errback(ProtocolError('lost connection'))
if self.request_queue:
self.processQueue()
- # The downloading request interface functions
+ #{ Downloading request interface
def setCommonHeaders(self):
+ """Get the common HTTP headers for all requests."""
headers = http_headers.Headers()
headers.setHeader('Host', self.host)
headers.setHeader('User-Agent', 'apt-dht/%s (twisted/%s twisted.web2/%s)' %
return headers
def get(self, path, method="GET", modtime=None):
+ """Add a new request to the queue.
+
+ @type path: C{string}
+ @param path: the path to request from the peer
+ @type method: C{string}
+ @param method: the HTTP method to use, 'GET' or 'HEAD'
+ (optional, defaults to 'GET')
+ @type modtime: C{int}
+ @param modtime: the modification time to use for an 'If-Modified-Since'
+ header, as seconds since the epoch
+ (optional, defaults to not sending that header)
+ """
headers = self.setCommonHeaders()
if modtime:
headers.setHeader('If-Modified-Since', modtime)
return self.submitRequest(ClientRequest(method, path, headers, None))
def getRange(self, path, rangeStart, rangeEnd, method="GET"):
+ """Add a new request with a Range header to the queue.
+
+ @type path: C{string}
+ @param path: the path to request from the peer
+ @type rangeStart: C{int}
+ @param rangeStart: the byte to begin the request at
+ @type rangeEnd: C{int}
+ @param rangeEnd: the byte to end the request at (inclusive)
+ @type method: C{string}
+ @param method: the HTTP method to use, 'GET' or 'HEAD'
+ (optional, defaults to 'GET')
+ """
headers = self.setCommonHeaders()
headers.setHeader('Range', ('bytes', [(rangeStart, rangeEnd)]))
return self.submitRequest(ClientRequest(method, path, headers, None))
- # Functions that return information about the peer
+ #{ Peer information
def isIdle(self):
+ """Check whether the peer is idle or not."""
return not self.busy and not self.request_queue and not self.response_queue
def _processLastResponse(self):
+ """Save the download time of the last request for speed calculations."""
if self._lastResponse is not None:
now = datetime.now()
self._downloadSpeeds.append((now, now - self._lastResponse[0], self._lastResponse[1]))
stream_mod.readStream(resp.stream, print_).addCallback(printdone)
def test_download(self):
+ """Tests a normal download."""
host = 'www.ietf.org'
self.client = Peer(host, 80)
self.timeout = 10
return d
def test_head(self):
+ """Tests a 'HEAD' request."""
host = 'www.ietf.org'
self.client = Peer(host, 80)
self.timeout = 10
return d
def test_multiple_downloads(self):
+ """Tests multiple downloads with queueing and connection closing."""
host = 'www.ietf.org'
self.client = Peer(host, 80)
self.timeout = 120
d.addCallback(self.gotResp, num, expect)
if last:
d.addBoth(lastDefer.callback)
-
+
+ # 3 quick requests
newRequest("/rfc/rfc0006.txt", 1, 1776)
newRequest("/rfc/rfc2362.txt", 2, 159833)
newRequest("/rfc/rfc0801.txt", 3, 40824)
+
+ # This one will probably be queued
self.pending_calls.append(reactor.callLater(1, newRequest, '/rfc/rfc0013.txt', 4, 1070))
+
+ # Connection should still be open, but idle
self.pending_calls.append(reactor.callLater(10, newRequest, '/rfc/rfc0022.txt', 5, 4606))
+
+ # Connection should be closed
self.pending_calls.append(reactor.callLater(30, newRequest, '/rfc/rfc0048.txt', 6, 41696))
self.pending_calls.append(reactor.callLater(31, newRequest, '/rfc/rfc3261.txt', 7, 647976))
self.pending_calls.append(reactor.callLater(32, newRequest, '/rfc/rfc0014.txt', 8, 27))
self.pending_calls.append(reactor.callLater(32, newRequest, '/rfc/rfc0001.txt', 9, 21088))
+
+ # Now it should definitely be closed
self.pending_calls.append(reactor.callLater(62, newRequest, '/rfc/rfc2801.txt', 0, 598794, True))
return lastDefer
def test_multiple_quick_downloads(self):
+ """Tests lots of multiple downloads with queueing."""
host = 'www.ietf.org'
self.client = Peer(host, 80)
self.timeout = 30
log.msg('Response Time is: %r' % self.client.responseTime())
def test_peer_info(self):
+ """Test retrieving the peer info during a download."""
host = 'www.ietf.org'
self.client = Peer(host, 80)
self.timeout = 120
return lastDefer
def test_range(self):
+ """Test a Range request."""
host = 'www.ietf.org'
self.client = Peer(host, 80)
self.timeout = 10
+"""Serve local requests from apt and remote requests from peers."""
+
from urllib import unquote_plus
from binascii import b2a_hex
from apt_dht_Khashmir.bencode import bencode
class FileDownloader(static.File):
+ """Modified to make it suitable for apt requests.
+
+ Tries to find requests in the cache. Found files are first checked for
+ freshness before being sent. Requests for unfound and stale files are
+ forwarded to the main program for downloading.
+
+ @type manager: L{apt_dht.AptDHT}
+ @ivar manager: the main program to query
+ """
def __init__(self, path, manager, defaultType="text/plain", ignoredExts=(), processors=None, indexNames=None):
self.manager = manager
self.processors, self.indexNames[:])
class FileUploaderStream(stream.FileStream):
+ """Modified to make it suitable for streaming to peers.
+
+ Streams the file in small chunks to make it easier to throttle the
+ streaming to peers.
+
+ @ivar CHUNK_SIZE: the size of chunks of data to send at a time
+ """
CHUNK_SIZE = 4*1024
if length == 0:
self.f = None
return None
+
+ # Remove the SendFileBuffer and mmap use, just use string reads and writes
readSize = min(length, self.CHUNK_SIZE)
class FileUploader(static.File):
+ """Modified to make it suitable for peer requests.
+
+ Uses the modified L{FileUploaderStream} to stream the file for throttling,
+ and doesn't do any listing of directory contents.
+ """
def render(self, req):
if not self.fp.exists():
return responsecode.NOT_FOUND
if self.fp.isdir():
+ # Don't try to render a directory listing
return responsecode.NOT_FOUND
try:
raise
response = http.Response()
+ # Use the modified FileStream
response.stream = FileUploaderStream(f, 0, self.fp.getsize())
for (header, value) in (
return response
class TopLevel(resource.Resource):
+ """The HTTP server for all requests, both from peers and apt.
+
+ @type directory: L{twisted.python.filepath.FilePath}
+ @ivar directory: the directory to check for cached files
+ @type db: L{db.DB}
+ @ivar db: the database to use for looking up files and hashes
+ @type manager: L{apt_dht.AptDHT}
+ @ivar manager: the main program object to send requests to
+ @type factory: L{twisted.web2.channel.HTTPFactory} or L{policies.ThrottlingFactory}
+    @ivar factory: the factory to use to serve HTTP requests
+
+ """
+
addSlash = True
def __init__(self, directory, db, manager):
+ """Initialize the instance.
+
+ @type directory: L{twisted.python.filepath.FilePath}
+ @param directory: the directory to check for cached files
+ @type db: L{db.DB}
+ @param db: the database to use for looking up files and hashes
+ @type manager: L{apt_dht.AptDHT}
+ @param manager: the main program object to send requests to
+ """
self.directory = directory
self.db = db
self.manager = manager
self.factory = None
def getHTTPFactory(self):
+ """Initialize and get the factory for this HTTP server."""
if self.factory is None:
self.factory = channel.HTTPFactory(server.Site(self),
**{'maxPipeline': 10,
return self.factory
def render(self, ctx):
+ """Render a web page with descriptive statistics."""
return http.Response(
200,
{'content-type': http_headers.MimeType('text', 'html')},
<p>TODO: eventually some stats will be shown here.</body></html>""")
def locateChild(self, request, segments):
+ """Process the incoming request."""
log.msg('Got HTTP request for %s from %s' % (request.uri, request.remoteAddr))
name = segments[0]
+
+ # If the request is for a shared file (from a peer)
if name == '~':
if len(segments) != 2:
log.msg('Got a malformed request from %s' % request.remoteAddr)
return None, ()
+
+ # Find the file in the database
hash = unquote_plus(segments[1])
files = self.db.lookupHash(hash)
if files:
+ # If it is a file, return it
if 'path' in files[0]:
log.msg('Sharing %s with %s' % (files[0]['path'].path, request.remoteAddr))
return FileUploader(files[0]['path'].path), ()
else:
+ # It's not for a file, but for a piece string, so return that
log.msg('Sending torrent string %s to %s' % (b2a_hex(hash), request.remoteAddr))
return static.Data(bencode({'t': files[0]['pieces']}), 'application/x-bencoded'), ()
else:
log.msg('Hash could not be found in database: %s' % hash)
-
+
+ # Only local requests (apt) get past this point
if request.remoteAddr.host != "127.0.0.1":
log.msg('Blocked illegal access to %s from %s' % (request.uri, request.remoteAddr))
return None, ()
if len(name) > 1:
+ # It's a request from apt
return FileDownloader(self.directory.path, self.manager), segments[0:]
else:
+ # Will render the statistics page
return self, ()
log.msg('Got a malformed request for "%s" from %s' % (request.uri, request.remoteAddr))
+"""Hash and store hash information for a file.
+
+@var PIECE_SIZE: the piece size to use for hashing pieces of files
+
+"""
+
from binascii import b2a_hex, a2b_hex
import sys
"""An error has occurred while hashing a file."""
class HashObject:
- """Manages hashes and hashing for a file."""
+ """Manages hashes and hashing for a file.
- """The priority ordering of hashes, and how to extract them."""
+ @ivar ORDER: the priority ordering of hashes, and how to extract them
+
+ """
+
ORDER = [ {'name': 'sha1',
'length': 20,
'AptPkgRecord': 'SHA1Hash',
]
def __init__(self, digest = None, size = None, pieces = ''):
+ """Initialize the hash object."""
self.hashTypeNum = 0 # Use the first if nothing else matters
if sys.version_info < (2, 5):
# sha256 is not available in python before 2.5, remove it
self.done = True
self.result = None
- #### Methods for returning the expected hash
- def expected(self):
- """Get the expected hash."""
- return self.expHash
-
- def hexexpected(self):
- """Get the expected hash in hex format."""
- if self.expHex is None and self.expHash is not None:
- self.expHex = b2a_hex(self.expHash)
- return self.expHex
-
- #### Methods for hashing data
+ #{ Hashing data
def new(self, force = False):
"""Generate a new hashing object suitable for hashing a file.
- @param force: set to True to force creating a new hasher even if
+ @param force: set to True to force creating a new object even if
the hash has been verified already
"""
- if self.result is None or force == True:
+ if self.result is None or force:
self.result = None
self.done = False
self.fileHasher = self._new()
self.fileHasher.update(data)
self.size += len(data)
+ def hashInThread(self, file):
+ """Hashes a file in a separate thread, returning a deferred that will callback with the result."""
+ file.restat(False)
+ if not file.exists():
+ df = defer.Deferred()
+ df.errback(HashError("file not found"))
+ return df
+
+ df = threads.deferToThread(self._hashInThread, file)
+ return df
+
+ def _hashInThread(self, file):
+ """Hashes a file, returning itself as the result."""
+ f = file.open()
+ self.new(force = True)
+ data = f.read(4096)
+ while data:
+ self.update(data)
+ data = f.read(4096)
+ self.digest()
+ return self
+
+ #{ Checking hashes of data
def pieceDigests(self):
"""Get the piece hashes of the added file data."""
self.digest()
self.result = (self.fileHash == self.expHash and self.size == self.expSize)
return self.result
- def hashInThread(self, file):
- """Hashes a file in a separate thread, callback with the result."""
- file.restat(False)
- if not file.exists():
- df = defer.Deferred()
- df.errback(HashError("file not found"))
- return df
-
- df = threads.deferToThread(self._hashInThread, file)
- return df
+ #{ Expected hash
+ def expected(self):
+ """Get the expected hash."""
+ return self.expHash
- def _hashInThread(self, file):
- """Hashes a file, returning itself as the result."""
- f = file.open()
- self.new(force = True)
- data = f.read(4096)
- while data:
- self.update(data)
- data = f.read(4096)
- self.digest()
- return self
-
- #### Methods for setting the expected hash
+ def hexexpected(self):
+ """Get the expected hash in hex format."""
+ if self.expHex is None and self.expHash is not None:
+ self.expHex = b2a_hex(self.expHash)
+ return self.expHex
+
+ #{ Setting the expected hash
def set(self, hashType, hashHex, size):
"""Initialize the hash object.
skip = "skippingme"
def test_failure(self):
+ """Tests that the hash object fails when treated badly."""
h = HashObject()
h.set(h.ORDER[0], b2a_hex('12345678901234567890'), '0')
self.failUnlessRaises(HashError, h.digest)
self.failUnlessRaises(HashError, h.update, 'gfgf')
def test_pieces(self):
+ """Tests the hashing of large files into pieces."""
h = HashObject()
h.new()
h.update('1234567890'*120*1024)
self.failUnless(pieces[2] == 'M[\xbf\xee\xaa+\x19\xbaV\xf699\r\x17o\xcb\x8e\xcfP\x19')
def test_sha1(self):
+ """Test hashing using the SHA1 hash."""
h = HashObject()
found = False
for hashType in h.ORDER:
self.failUnless(h.verify() == True)
def test_md5(self):
+ """Test hashing using the MD5 hash."""
h = HashObject()
found = False
for hashType in h.ORDER:
self.failUnless(h.verify() == True)
def test_sha256(self):
+ """Test hashing using the SHA256 hash."""
h = HashObject()
found = False
for hashType in h.ORDER:
+"""Manage the multiple mirrors that may be requested.
+
+@var aptpkg_dir: the name of the directory to use for mirror files
+"""
+
from urlparse import urlparse
import os
"""Exception raised when there's a problem with the mirror."""
class MirrorManager:
- """Manages all requests for mirror objects."""
+ """Manages all requests for mirror information.
+
+ @type cache_dir: L{twisted.python.filepath.FilePath}
+ @ivar cache_dir: the directory to use for storing all files
+ @type unload_delay: C{int}
+ @ivar unload_delay: the time to wait before unloading the apt cache
+ @type apt_caches: C{dictionary}
+    @ivar apt_caches: the available mirrors
+ """
def __init__(self, cache_dir, unload_delay):
self.cache_dir = cache_dir
self.apt_caches = {}
def extractPath(self, url):
+ """Break the full URI down into the site, base directory and path.
+
+ Site is the host and port of the mirror. Base directory is the
+ directory to the mirror location (usually just '/debian'). Path is
+ the remaining path to get to the file.
+
+ E.g. http://ftp.debian.org/debian/dists/sid/binary-i386/Packages.bz2
+ would return ('ftp.debian.org:80', '/debian',
+ '/dists/sid/binary-i386/Packages.bz2').
+
+ @param url: the URI of the file's location on the mirror
+ @rtype: (C{string}, C{string}, C{string})
+ @return: the site, base directory and path to the file
+ """
+ # Extract the host and port
parsed = urlparse(url)
host, port = splitHostPort(parsed[0], parsed[1])
site = host + ":" + str(port)
path = parsed[2]
-
+
+ # Try to find the base directory (most can be found this way)
i = max(path.rfind('/dists/'), path.rfind('/pool/'))
if i >= 0:
baseDir = path[:i]
else:
# Uh oh, this is not good
log.msg("Couldn't find a good base directory for path: %s" % (site + path))
+
+ # Try to find an existing cache that starts with this one
+ # (fallback to using an empty base directory)
baseDir = ''
if site in self.apt_caches:
longest_match = 0
return site, baseDir, path
def init(self, site, baseDir):
+ """Make sure an L{AptPackages} exists for this mirror."""
if site not in self.apt_caches:
self.apt_caches[site] = {}
self.apt_caches[site][baseDir] = AptPackages(site_cache, self.unload_delay)
def updatedFile(self, url, file_path):
+ """A file in the mirror has changed or been added.
+
+ @see: L{AptPackages.PackageFileList.update_file}
+ """
site, baseDir, path = self.extractPath(url)
self.init(site, baseDir)
self.apt_caches[site][baseDir].file_updated(path, file_path)
def findHash(self, url):
+ """Find the hash for a given url.
+
+ @param url: the URI of the file's location on the mirror
+ @rtype: L{twisted.internet.defer.Deferred}
+ @return: a deferred that will fire with the returned L{Hash.HashObject}
+ """
site, baseDir, path = self.extractPath(url)
if site in self.apt_caches and baseDir in self.apt_caches[site]:
return self.apt_caches[site][baseDir].findHash(path)
self.client = MirrorManager(FilePath('/tmp/.apt-dht'), 300)
def test_extractPath(self):
+ """Test extracting the site and base directory from various mirrors."""
site, baseDir, path = self.client.extractPath('http://ftp.us.debian.org/debian/dists/unstable/Release')
self.failUnless(site == "ftp.us.debian.org:80", "no match: %s" % site)
self.failUnless(baseDir == "/debian", "no match: %s" % baseDir)
"%s hashes don't match: %s != %s" % (path, found_hash.hexexpected(), true_hash))
def test_findHash(self):
+ """Tests finding the hash of an index file, binary package, source package, and another index file."""
+ # Find the largest index files that are for 'main'
self.packagesFile = os.popen('ls -Sr /var/lib/apt/lists/ | grep -E "_main_.*Packages$" | tail -n 1').read().rstrip('\n')
self.sourcesFile = os.popen('ls -Sr /var/lib/apt/lists/ | grep -E "_main_.*Sources$" | tail -n 1').read().rstrip('\n')
+
+ # Find the Release file corresponding to the found Packages file
for f in os.walk('/var/lib/apt/lists').next()[2]:
if f[-7:] == "Release" and self.packagesFile.startswith(f[:-7]):
self.releaseFile = f
break
+ # Add all the found files to the mirror
self.client.updatedFile('http://' + self.releaseFile.replace('_','/'),
FilePath('/var/lib/apt/lists/' + self.releaseFile))
self.client.updatedFile('http://' + self.releaseFile[:self.releaseFile.find('_dists_')+1].replace('_','/') +
lastDefer = defer.Deferred()
+ # Lookup a Packages.bz2 file
idx_hash = os.popen('grep -A 3000 -E "^SHA1:" ' +
'/var/lib/apt/lists/' + self.releaseFile +
' | grep -E " main/binary-i386/Packages.bz2$"'
d = self.client.findHash(idx_path)
d.addCallback(self.verifyHash, idx_path, idx_hash)
+ # Lookup the binary 'dpkg' package
pkg_hash = os.popen('grep -A 30 -E "^Package: dpkg$" ' +
'/var/lib/apt/lists/' + self.packagesFile +
' | grep -E "^SHA1:" | head -n 1' +
d = self.client.findHash(pkg_path)
d.addCallback(self.verifyHash, pkg_path, pkg_hash)
+ # Lookup the source 'dpkg' package
src_dir = os.popen('grep -A 30 -E "^Package: dpkg$" ' +
'/var/lib/apt/lists/' + self.sourcesFile +
' | grep -E "^Directory:" | head -n 1' +
d = self.client.findHash(src_path)
d.addCallback(self.verifyHash, src_path, src_hashes[i])
+ # Lookup a Sources.bz2 file
idx_hash = os.popen('grep -A 3000 -E "^SHA1:" ' +
'/var/lib/apt/lists/' + self.releaseFile +
' | grep -E " main/source/Sources.bz2$"'
+"""Manage a set of peers and the requests to them."""
+
from random import choice
from urlparse import urlparse, urlunparse
from urllib import quote_plus
from util import uncompact
class PeerManager:
+ """Manage a set of peers and the requests to them.
+
+ @type clients: C{dictionary}
+ @ivar clients: the available peers that have been previously contacted
+ """
+
def __init__(self):
+ """Initialize the instance."""
self.clients = {}
def get(self, hash, mirror, peers = [], method="GET", modtime=None):
"""Download from a list of peers or fallback to a mirror.
+ @type hash: L{Hash.HashObject}
+ @param hash: the hash object containing the expected hash for the file
+ @param mirror: the URI of the file on the mirror
@type peers: C{list} of C{string}
- @param peers: a list of the peers where the file can be found
+ @param peers: a list of the peer info where the file can be found
+ (optional, defaults to downloading from the mirror)
+ @type method: C{string}
+ @param method: the HTTP method to use, 'GET' or 'HEAD'
+ (optional, defaults to 'GET')
+ @type modtime: C{int}
+ @param modtime: the modification time to use for an 'If-Modified-Since'
+ header, as seconds since the epoch
+ (optional, defaults to not sending that header)
"""
if peers:
+ # Choose one of the peers at random
compact_peer = choice(peers)
peer = uncompact(compact_peer['c'])
log.msg('Downloading from peer %r' % (peer, ))
return self.getPeer(site, path, method, modtime)
def getPeer(self, site, path, method="GET", modtime=None):
+ """Create a new peer if necessary and forward the request to it.
+
+ @type site: (C{string}, C{int})
+ @param site: the IP address and port of the peer
+ @type path: C{string}
+ @param path: the path to the file on the peer
+ @type method: C{string}
+ @param method: the HTTP method to use, 'GET' or 'HEAD'
+ (optional, defaults to 'GET')
+ @type modtime: C{int}
+ @param modtime: the modification time to use for an 'If-Modified-Since'
+ header, as seconds since the epoch
+ (optional, defaults to not sending that header)
+ """
if site not in self.clients:
self.clients[site] = Peer(site[0], site[1])
return self.clients[site].get(path, method, modtime)
def close(self):
+ """Close all the connections to peers."""
for site in self.clients:
self.clients[site].close()
self.clients = {}
stream_mod.readStream(resp.stream, print_).addCallback(printdone)
def test_download(self):
+ """Tests a normal download."""
self.manager = PeerManager()
self.timeout = 10
return d
def test_head(self):
+ """Tests a 'HEAD' request."""
self.manager = PeerManager()
self.timeout = 10
return d
def test_multiple_downloads(self):
+ """Tests multiple downloads with queueing and connection closing."""
self.manager = PeerManager()
self.timeout = 120
lastDefer = defer.Deferred()
"""The main apt-dht modules.
+To run apt-dht, you probably want to do something like::
+
+ from apt_dht.apt_dht import AptDHT
+ myapp = AptDHT(myDHT)
+
+where myDHT is a DHT that implements interfaces.IDHT.
+
Diagram of the interaction between the given modules::
+---------------+ +-----------------------------------+ +-------------
| AptDHT | | DHT | | Internet
| |--->|join DHT|----|--\
| |--->|loadConfig | | | Another
- | |--->|getValue | | | Peer
+ | |--->|getValue | | | Node
| |--->|storeValue DHT|<---|--/
| |--->|leave | |
| | +-----------------------------------+ |
| |--->|get |--->|get HTTP|----|---> Mirror
| | | |--->|getRange | |
| |--->|close |--->|close HTTP|----|--\
- | | +-------------+ +----------------+ | |
- | | +-----------------------------------+ | | Another
- | | | HTTPServer | | | Peer
- | |--->|getHTTPFactory HTTP|<---|--/
+ | | +-------------+ +----------------+ | | Another
+ | | +-----------------------------------+ | | Peer
+ | | | HTTPServer HTTP|<---|--/
+ | |--->|getHTTPFactory | +-------------
|check_freshness|<---| | +-------------
- | get_resp|<---| | +-------------
- | /----|--->|setDirectories HTTP|<---|HTTP Request
- | | | +-----------------------------------+ |
- | | | +---------------+ +--------------+ | Local Net
- | | | | CacheManager | | ProxyFile- | | (apt)
- | | |--->|scanDirectories| | Stream* | |
- | setDirectories|<---| |--->|__init__ HTTP|--->|HTTP Response
- | |--->|save_file | | | +-------------
+ | get_resp|<---| HTTP|<---|HTTP Request
+ | | +-----------------------------------+ |
+ | | +---------------+ +--------------+ | Local Net
+ | | | CacheManager | | ProxyFile- | | (apt)
+ | |--->|scanDirectories| | Stream* | |
+ | |--->|save_file |--->|__init__ HTTP|--->|HTTP Response
| |--->|save_error | | | +-------------
+ | | | | | | +-------------
|new_cached_file|<---| | | file|--->|write file
| | +---------------+ +--------------+ |
| | +---------------+ +--------------+ | Filesystem
| | | MirrorManager | | AptPackages* | |
- | |--->|updatedFile |--->|file_updated |--->|write file
- | |--->|findHash |--->|findHash | |
+ | |--->|updatedFile |--->|file_updated | |
+ | |--->|findHash |--->|findHash file|<---|read file
+---------------+ +---------------+ +--------------+ +-------------
"""
+"""The main program code.
+
+@var DHT_PIECES: the maximum number of pieces to store with our contact info
+ in the DHT
+@var TORRENT_PIECES: the maximum number of pieces to store as a separate entry
+ in the DHT
+@var download_dir: the name of the directory to use for downloaded files
+
+"""
+
from binascii import b2a_hex
from urlparse import urlunparse
import os, re, sha
download_dir = 'cache'
class AptDHT:
+ """The main code object that does all of the work.
+
+ Contains all of the sub-components that do all the low-level work, and
+ coordinates communication between them.
+
+ @type cache_dir: L{twisted.python.filepath.FilePath}
+ @ivar cache_dir: the directory to use for storing all files
+ @type db: L{db.DB}
+ @ivar db: the database to use for tracking files and hashes
+ @type dht: L{interfaces.IDHT}
+ @ivar dht: the DHT instance to use
+ @type http_server: L{HTTPServer.TopLevel}
+ @ivar http_server: the web server that will handle all requests from apt
+ and from other peers
+ @type peers: L{PeerManager.PeerManager}
+ @ivar peers: the manager of all downloads from mirrors and other peers
+ @type mirrors: L{MirrorManager.MirrorManager}
+ @ivar mirrors: the manager of downloaded information about mirrors which
+ can be queried to get hashes from file names
+ @type cache: L{CacheManager.CacheManager}
+ @ivar cache: the manager of all downloaded files
+ @type my_contact: C{string}
+ @ivar my_contact: the 6-byte compact peer representation of this peer's
+ download information (IP address and port)
+ """
+
def __init__(self, dht):
+ """Initialize all the sub-components.
+
+ @type dht: L{interfaces.IDHT}
+ @param dht: the DHT instance to use
+ """
log.msg('Initializing the main apt_dht application')
self.cache_dir = FilePath(config.get('DEFAULT', 'cache_dir'))
if not self.cache_dir.child(download_dir).exists():
other_dirs = [FilePath(f) for f in config.getstringlist('DEFAULT', 'OTHER_DIRS')]
self.cache = CacheManager(self.cache_dir.child(download_dir), self.db, other_dirs, self)
self.my_contact = None
-
+
+ #{ DHT maintenance
def joinComplete(self, result):
+ """Complete the DHT join process and determine our download information.
+
+ Called by the DHT when the join has been completed with information
+ on the external IP address and port of this peer.
+ """
my_addr = findMyIPAddr(result,
config.getint(config.get('DEFAULT', 'DHT'), 'PORT'),
config.getboolean('DEFAULT', 'LOCAL_OK'))
reactor.callLater(60, self.refreshFiles)
def joinError(self, failure):
+ """Joining the DHT has failed."""
log.msg("joining DHT failed miserably")
log.err(failure)
raise RuntimeError, "IP address for this machine could not be found"
else:
reactor.callLater(60, self.refreshFiles)
- def check_freshness(self, req, path, modtime, resp):
- log.msg('Checking if %s is still fresh' % path)
- d = self.peers.get('', path, method = "HEAD", modtime = modtime)
- d.addCallback(self.check_freshness_done, req, path, resp)
+ #{ Main workflow
+ def check_freshness(self, req, url, modtime, resp):
+ """Send a HEAD to the mirror to check if the response from the cache is still valid.
+
+ @type req: L{twisted.web2.http.Request}
+ @param req: the initial request sent to the HTTP server by apt
+ @param url: the URI of the actual mirror request
+ @type modtime: C{int}
+ @param modtime: the modified time of the cached file (seconds since epoch)
+ @type resp: L{twisted.web2.http.Response}
+ @param resp: the response from the cache to be sent to apt
+ @rtype: L{twisted.internet.defer.Deferred}
+ @return: a deferred that will be called back with the correct response
+ """
+ log.msg('Checking if %s is still fresh' % url)
+ d = self.peers.get('', url, method = "HEAD", modtime = modtime)
+ d.addCallback(self.check_freshness_done, req, url, resp)
return d
- def check_freshness_done(self, resp, req, path, orig_resp):
+ def check_freshness_done(self, resp, req, url, orig_resp):
+ """Process the returned response from the mirror.
+
+ @type resp: L{twisted.web2.http.Response}
+ @param resp: the response from the mirror to the HEAD request
+ @type req: L{twisted.web2.http.Request}
+ @param req: the initial request sent to the HTTP server by apt
+ @param url: the URI of the actual mirror request
+ @type orig_resp: L{twisted.web2.http.Response}
+ @param orig_resp: the response from the cache to be sent to apt
+ """
if resp.code == 304:
- log.msg('Still fresh, returning: %s' % path)
+ log.msg('Still fresh, returning: %s' % url)
return orig_resp
else:
- log.msg('Stale, need to redownload: %s' % path)
- return self.get_resp(req, path)
+ log.msg('Stale, need to redownload: %s' % url)
+ return self.get_resp(req, url)
- def get_resp(self, req, path):
+ def get_resp(self, req, url):
+ """Lookup a hash for the file in the local mirror info.
+
+ Starts the process of getting a response to an uncached apt request.
+
+ @type req: L{twisted.web2.http.Request}
+ @param req: the initial request sent to the HTTP server by apt
+ @param url: the URI of the actual mirror request
+ @rtype: L{twisted.internet.defer.Deferred}
+ @return: a deferred that will be called back with the response
+ """
d = defer.Deferred()
- log.msg('Trying to find hash for %s' % path)
- findDefer = self.mirrors.findHash(path)
+ log.msg('Trying to find hash for %s' % url)
+ findDefer = self.mirrors.findHash(url)
findDefer.addCallbacks(self.findHash_done, self.findHash_error,
- callbackArgs=(req, path, d), errbackArgs=(req, path, d))
+ callbackArgs=(req, url, d), errbackArgs=(req, url, d))
findDefer.addErrback(log.err)
return d
- def findHash_error(self, failure, req, path, d):
+ def findHash_error(self, failure, req, url, d):
+ """Process the error in hash lookup by returning an empty L{HashObject}."""
log.err(failure)
- self.findHash_done(HashObject(), req, path, d)
+ self.findHash_done(HashObject(), req, url, d)
+
+ def findHash_done(self, hash, req, url, d):
+ """Use the returned hash to lookup the file in the cache.
+
+ If the hash was not found, the workflow skips down to download from
+ the mirror (L{lookupHash_done}).
- def findHash_done(self, hash, req, path, d):
+ @type hash: L{Hash.HashObject}
+ @param hash: the hash object containing the expected hash for the file
+ """
if hash.expected() is None:
- log.msg('Hash for %s was not found' % path)
- self.lookupHash_done([], hash, path, d)
+ log.msg('Hash for %s was not found' % url)
+ self.lookupHash_done([], hash, url, d)
else:
- log.msg('Found hash %s for %s' % (hash.hexexpected(), path))
+ log.msg('Found hash %s for %s' % (hash.hexexpected(), url))
# Lookup hash in cache
locations = self.db.lookupHash(hash.expected(), filesOnly = True)
- self.getCachedFile(hash, req, path, d, locations)
+ self.getCachedFile(hash, req, url, d, locations)
- def getCachedFile(self, hash, req, path, d, locations):
+ def getCachedFile(self, hash, req, url, d, locations):
+ """Try to return the file from the cache, otherwise move on to a DHT lookup.
+
+ @type locations: C{list} of C{dictionary}
+ @param locations: the files in the cache that match the hash,
+ the dictionary contains a key 'path' whose value is a
+ L{twisted.python.filepath.FilePath} object for the file.
+ """
if not locations:
- log.msg('Failed to return file from cache: %s' % path)
- self.lookupHash(hash, path, d)
+ log.msg('Failed to return file from cache: %s' % url)
+ self.lookupHash(hash, url, d)
return
# Get the first possible location from the list
# Get it's response
resp = static.File(file.path).renderHTTP(req)
if isinstance(resp, defer.Deferred):
- resp.addBoth(self._getCachedFile, hash, req, path, d, locations)
+ resp.addBoth(self._getCachedFile, hash, req, url, d, locations)
else:
- self._getCachedFile(resp, hash, req, path, d, locations)
+ self._getCachedFile(resp, hash, req, url, d, locations)
- def _getCachedFile(self, resp, hash, req, path, d, locations):
+ def _getCachedFile(self, resp, hash, req, url, d, locations):
+ """Check the returned response to be sure it is valid."""
if isinstance(resp, failure.Failure):
log.msg('Got error trying to get cached file')
log.err()
# Try the next possible location
- self.getCachedFile(hash, req, path, d, locations)
+ self.getCachedFile(hash, req, url, d, locations)
return
log.msg('Cached response: %r' % resp)
d.callback(resp)
else:
# Try the next possible location
- self.getCachedFile(hash, req, path, d, locations)
+ self.getCachedFile(hash, req, url, d, locations)
- def lookupHash(self, hash, path, d):
- log.msg('Looking up hash in DHT for file: %s' % path)
+ def lookupHash(self, hash, url, d):
+ """Lookup the hash in the DHT."""
+ log.msg('Looking up hash in DHT for file: %s' % url)
key = hash.expected()
lookupDefer = self.dht.getValue(key)
- lookupDefer.addCallback(self.lookupHash_done, hash, path, d)
+ lookupDefer.addCallback(self.lookupHash_done, hash, url, d)
- def lookupHash_done(self, values, hash, path, d):
+ def lookupHash_done(self, values, hash, url, d):
+ """Start the download of the file.
+
+ The download will be from peers if the DHT lookup succeeded, or
+ from the mirror otherwise.
+
+ @type values: C{list} of C{dictionary}
+ @param values: the returned values from the DHT containing peer
+ download information
+ """
if not values:
- log.msg('Peers for %s were not found' % path)
- getDefer = self.peers.get(hash, path)
- getDefer.addCallback(self.cache.save_file, hash, path)
- getDefer.addErrback(self.cache.save_error, path)
+ log.msg('Peers for %s were not found' % url)
+ getDefer = self.peers.get(hash, url)
+ getDefer.addCallback(self.cache.save_file, hash, url)
+ getDefer.addErrback(self.cache.save_error, url)
getDefer.addCallbacks(d.callback, d.errback)
else:
- log.msg('Found peers for %s: %r' % (path, values))
+ log.msg('Found peers for %s: %r' % (url, values))
# Download from the found peers
- getDefer = self.peers.get(hash, path, values)
- getDefer.addCallback(self.check_response, hash, path)
- getDefer.addCallback(self.cache.save_file, hash, path)
- getDefer.addErrback(self.cache.save_error, path)
+ getDefer = self.peers.get(hash, url, values)
+ getDefer.addCallback(self.check_response, hash, url)
+ getDefer.addCallback(self.cache.save_file, hash, url)
+ getDefer.addErrback(self.cache.save_error, url)
getDefer.addCallbacks(d.callback, d.errback)
- def check_response(self, response, hash, path):
+ def check_response(self, response, hash, url):
+ """Check the response from peers, and download from the mirror if it is not."""
if response.code < 200 or response.code >= 300:
- log.msg('Download from peers failed, going to direct download: %s' % path)
- getDefer = self.peers.get(hash, path)
+ log.msg('Download from peers failed, going to direct download: %s' % url)
+ getDefer = self.peers.get(hash, url)
return getDefer
return response
def new_cached_file(self, file_path, hash, new_hash, url = None, forceDHT = False):
- """Add a newly cached file to the appropriate places.
+ """Add a newly cached file to the mirror info and/or the DHT.
If the file was downloaded, set url to the path it was downloaded for.
Doesn't add a file to the DHT unless a hash was found for it
(but does add it anyway if forceDHT is True).
+
+ @type file_path: L{twisted.python.filepath.FilePath}
+ @param file_path: the location of the file in the local cache
+ @type hash: L{Hash.HashObject}
+ @param hash: the original (expected) hash object containing also the
+ hash of the downloaded file
+ @type new_hash: C{boolean}
+    @param new_hash: whether the hash was new to this peer, and so should
+ be added to the DHT
+ @type url: C{string}
+ @param url: the URI of the location of the file in the mirror
+ (optional, defaults to not adding the file to the mirror info)
+ @type forceDHT: C{boolean}
+ @param forceDHT: whether to force addition of the file to the DHT
+ even if the hash was not found in a mirror
+ (optional, defaults to False)
"""
if url:
self.mirrors.updatedFile(url, file_path)
return None
def store(self, hash):
- """Add a file to the DHT."""
+ """Add a key/value pair for the file to the DHT.
+
+ Sets the key and value from the hash information, and tries to add
+ it to the DHT.
+ """
key = hash.digest()
value = {'c': self.my_contact}
pieces = hash.pieceDigests()
+
+ # Determine how to store any piece data
if len(pieces) <= 1:
pass
elif len(pieces) <= DHT_PIECES:
+ # Short enough to be stored with our peer contact info
value['t'] = {'t': ''.join(pieces)}
elif len(pieces) <= TORRENT_PIECES:
+ # Short enough to be stored in a separate key in the DHT
+            s = sha.new(''.join(pieces))
value['h'] = s.digest()
else:
+ # Too long, must be served up by our peer HTTP server
+            s = sha.new(''.join(pieces))
value['l'] = s.digest()
+
storeDefer = self.dht.storeValue(key, value)
storeDefer.addCallback(self.store_done, hash)
return storeDefer
def store_done(self, result, hash):
+ """Add a key/value pair for the pieces of the file to the DHT (if necessary)."""
log.msg('Added %s to the DHT: %r' % (hash.hexdigest(), result))
pieces = hash.pieceDigests()
if len(pieces) > DHT_PIECES and len(pieces) <= TORRENT_PIECES:
+ # Add the piece data key and value to the DHT
+            s = sha.new(''.join(pieces))
key = s.digest()
value = {'t': ''.join(pieces)}
+
storeDefer = self.dht.storeValue(key, value)
storeDefer.addCallback(self.store_torrent_done, key)
return storeDefer
return result
def store_torrent_done(self, result, key):
+ """Adding the file to the DHT is complete, and so is the workflow."""
+        log.msg('Added torrent string %s to the DHT: %r' % (b2a_hex(key), result))
return result
\ No newline at end of file
+"""Loading of configuration files and parameters.
+
+@type version: L{twisted.python.versions.Version}
+@var version: the version of this program
+@type DEFAULT_CONFIG_FILES: C{list} of C{string}
+@var DEFAULT_CONFIG_FILES: the default config files to load (in order)
+@var DEFAULTS: the default config parameter values for the main program
+@var DHT_DEFAULTS: the default config parameter values for the default DHT
+
+"""
+
import os, sys
from ConfigParser import SafeConfigParser
from twisted.python import log, versions
class ConfigError(Exception):
+ """Errors that occur in the loading of configuration variables."""
def __init__(self, message):
self.message = message
def __str__(self):
return repr(self.message)
version = versions.Version('apt-dht', 0, 0, 0)
+
+# Set the home parameter
home = os.path.expandvars('${HOME}')
if home == '${HOME}' or not os.path.isdir(home):
home = os.path.expanduser('~')
if not os.path.isdir(home):
home = os.path.abspath(os.path.dirname(sys.argv[0]))
+
DEFAULT_CONFIG_FILES=['/etc/apt-dht/apt-dht.conf',
home + '/.apt-dht/apt-dht.conf']
}
class AptDHTConfigParser(SafeConfigParser):
+ """Adds 'gettime' and 'getstringlist' to ConfigParser objects.
+
+ @ivar time_multipliers: the 'gettime' suffixes and the multipliers needed
+ to convert them to seconds
"""
- Adds 'gettime' to ConfigParser to interpret the suffixes.
- """
+
time_multipliers={
's': 1, #seconds
'm': 60, #minutes
}
def gettime(self, section, option):
+ """Read the config parameter as a time value."""
mult = 1
value = self.get(section, option)
if len(value) == 0:
mult = self.time_multipliers[suffix]
value = value[:-1]
return int(value)*mult
+
def getstring(self, section, option):
+ """Read the config parameter as a string."""
return self.get(section,option)
+
def getstringlist(self, section, option):
+ """Read the multi-line config parameter as a list of strings."""
return self.get(section,option).split()
+
def optionxform(self, option):
+        """Use all uppercase in the config parameter names."""
return option.upper()
+# Initialize the default config parameters
config = AptDHTConfigParser(DEFAULTS)
config.add_section(config.get('DEFAULT', 'DHT'))
for k in DHT_DEFAULTS:
+"""An sqlite database for storing persistent files and hashes."""
+
from datetime import datetime, timedelta
from pysqlite2 import dbapi2 as sqlite
from binascii import a2b_base64, b2a_base64
assert sqlite.version_info >= (2, 1)
class DBExcept(Exception):
+ """An error occurred in accessing the database."""
pass
class khash(str):
"""Dummy class to convert all hashes to base64 for storing in the DB."""
-
+
+# Initialize the database to work with 'khash' objects (binary strings)
sqlite.register_adapter(khash, b2a_base64)
sqlite.register_converter("KHASH", a2b_base64)
sqlite.register_converter("khash", a2b_base64)
sqlite.enable_callback_tracebacks(True)
class DB:
- """Database access for storing persistent data."""
+ """An sqlite database for storing persistent files and hashes.
+
+ @type db: L{twisted.python.filepath.FilePath}
+ @ivar db: the database file to use
+ @type conn: L{pysqlite2.dbapi2.Connection}
+ @ivar conn: an open connection to the sqlite database
+ """
def __init__(self, db):
+ """Load or create the database file.
+
+ @type db: L{twisted.python.filepath.FilePath}
+ @param db: the database file to use
+ """
self.db = db
self.db.restat(False)
if self.db.exists():
self.conn.row_factory = sqlite.Row
def _loadDB(self):
+        """Open a new connection to the existing database file."""
try:
self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
except:
raise DBExcept, "Couldn't open DB", traceback.format_exc()
def _createNewDB(self):
+ """Open a connection to a new database and create the necessary tables."""
if not self.db.parent().exists():
self.db.parent().makedirs()
self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
self.conn.commit()
def _removeChanged(self, file, row):
+ """If the file has changed or is missing, remove it from the DB.
+
+ @type file: L{twisted.python.filepath.FilePath}
+ @param file: the file to check
+ @type row: C{dictionary}-like object
+ @param row: contains the expected 'size' and 'mtime' of the file
+ @rtype: C{boolean}
+ @return: True if the file is unchanged, False if it is changed,
+ and None if it is missing
+ """
res = None
if row:
file.restat(False)
if file.exists():
+ # Compare the current with the expected file properties
res = (row['size'] == file.getsize() and row['mtime'] == file.getmtime())
if not res:
+ # Remove the file from the database
c = self.conn.cursor()
c.execute("DELETE FROM files WHERE path = ?", (file.path, ))
self.conn.commit()
def storeFile(self, file, hash, pieces = ''):
"""Store or update a file in the database.
+ @type file: L{twisted.python.filepath.FilePath}
+ @param file: the file to check
+ @type hash: C{string}
+ @param hash: the hash of the file
+ @type pieces: C{string}
+ @param pieces: the concatenated list of the hashes of the pieces of
+ the file (optional, defaults to the empty string)
@return: True if the hash was not in the database before
(so it needs to be added to the DHT)
"""
+ # Hash the pieces to get the piecehash
piecehash = ''
if pieces:
s = sha.new().update(pieces)
piecehash = sha.digest()
+
+ # Check the database for the hash
c = self.conn.cursor()
c.execute("SELECT hashID, piecehash FROM hashes WHERE hash = ?", (khash(hash), ))
row = c.fetchone()
new_hash = False
hashID = row['hashID']
else:
+ # Add the new hash to the database
c = self.conn.cursor()
c.execute("INSERT OR REPLACE INTO hashes (hash, pieces, piecehash, refreshed) VALUES (?, ?, ?, ?)",
(khash(hash), khash(pieces), khash(piecehash), datetime.now()))
self.conn.commit()
new_hash = True
hashID = c.lastrowid
-
+
+ # Add the file to the database
file.restat()
c.execute("INSERT OR REPLACE INTO files (path, hashID, size, mtime) VALUES (?, ?, ?, ?)",
(file.path, hashID, file.getsize(), file.getmtime()))
If it has changed or is missing, it is removed from the database.
+ @type file: L{twisted.python.filepath.FilePath}
+ @param file: the file to check
@return: dictionary of info for the file, False if changed, or
None if not in database or missing
"""
@return: list of dictionaries of info for the found files
"""
+ # Try to find the hash in the files table
c = self.conn.cursor()
c.execute("SELECT path, size, mtime, refreshed, pieces FROM files JOIN hashes USING (hashID) WHERE hash = ?", (khash(hash), ))
row = c.fetchone()
files = []
while row:
+ # Save the file to the list of found files
file = FilePath(row['path'])
res = self._removeChanged(file, row)
if res:
row = c.fetchone()
if not filesOnly and not files:
+ # No files were found, so check the piecehashes as well
c.execute("SELECT refreshed, pieces, piecehash FROM hashes WHERE piecehash = ?", (khash(hash), ))
row = c.fetchone()
if row:
def isUnchanged(self, file):
"""Check if a file in the file system has changed.
- If it has changed, it is removed from the table.
+ If it has changed, it is removed from the database.
@return: True if unchanged, False if changed, None if not in database
"""
return self._removeChanged(file, row)
def refreshHash(self, hash):
- """Refresh the publishing time all files with a hash."""
+ """Refresh the publishing time of a hash."""
c = self.conn.cursor()
c.execute("UPDATE hashes SET refreshed = ? WHERE hash = ?", (datetime.now(), khash(hash)))
c.close()
"""
t = datetime.now() - timedelta(seconds=expireAfter)
- # First find the hashes that need refreshing
+ # Find all the hashes that need refreshing
c = self.conn.cursor()
c.execute("SELECT hashID, hash, pieces FROM hashes WHERE refreshed < ?", (t, ))
row = c.fetchone()
valid = True
row = c.fetchone()
if not valid:
+ # Remove hashes for which no files are still available
del expired[hash['hash']]
c.execute("DELETE FROM hashes WHERE hashID = ?", (hash['hashID'], ))
return expired
def removeUntrackedFiles(self, dirs):
- """Find files that are no longer tracked and so should be removed.
-
- Also removes the entries from the table.
+ """Remove files that are no longer tracked by the program.
+ @type dirs: C{list} of L{twisted.python.filepath.FilePath}
+ @param dirs: a list of the directories that we are tracking
@return: list of files that were removed
"""
assert len(dirs) >= 1
+
+ # Create a list of globs and an SQL statement for the directories
newdirs = []
sql = "WHERE"
for dir in dirs:
sql += " path NOT GLOB ? AND"
sql = sql[:-4]
+ # Get a listing of all the files that will be removed
c = self.conn.cursor()
c.execute("SELECT path FROM files " + sql, newdirs)
row = c.fetchone()
removed.append(FilePath(row['path']))
row = c.fetchone()
+ # Delete all the removed files from the database
if removed:
c.execute("DELETE FROM files " + sql, newdirs)
self.conn.commit()
+
return removed
def close(self):
+ """Close the database connection."""
self.conn.close()
class TestDB(unittest.TestCase):
self.store.storeFile(self.file, self.hash)
def test_openExistingDB(self):
+ """Tests opening an existing database."""
self.store.close()
self.store = None
sleep(1)
self.failUnless(res)
def test_getFile(self):
+ """Tests retrieving a file from the database."""
res = self.store.getFile(self.file)
self.failUnless(res)
self.failUnlessEqual(res['hash'], self.hash)
def test_lookupHash(self):
+ """Tests looking up a hash in the database."""
res = self.store.lookupHash(self.hash)
self.failUnless(res)
self.failUnlessEqual(len(res), 1)
self.failUnlessEqual(res[0]['path'].path, self.file.path)
def test_isUnchanged(self):
+ """Tests checking if a file in the database is unchanged."""
res = self.store.isUnchanged(self.file)
self.failUnless(res)
sleep(2)
self.failUnless(res is None)
def test_expiry(self):
+ """Tests retrieving the files from the database that have expired."""
res = self.store.expiredHashes(1)
self.failUnlessEqual(len(res.keys()), 0)
sleep(2)
self.store.storeFile(file, self.hash)
def test_multipleHashes(self):
+ """Tests looking up a hash with multiple files in the database."""
self.build_dirs()
res = self.store.expiredHashes(1)
self.failUnlessEqual(len(res.keys()), 0)
self.failUnlessEqual(len(res.keys()), 0)
def test_removeUntracked(self):
+ """Tests removing untracked files from the database."""
self.build_dirs()
res = self.store.removeUntrackedFiles(self.dirs)
self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
-"""
-Some interfaces that are used by the apt-dht classes.
-
-"""
+"""Some interfaces that are used by the apt-dht classes."""
from zope.interface import Interface
-## Copyright 2002-2003 Andrew Loewenstern, All Rights Reserved
-# see LICENSE.txt for license information
+
+"""Some utility functions for use in the apt-dht program.
+
+@var isLocal: a compiled regular expression suitable for testing if an
+ IP address is from a known local or private range
+"""
import os, re
log.msg("got addrs: %r" % (addrs,))
my_addr = None
+ # Try to find an address using the ifconfig function
try:
ifconfig = os.popen("/sbin/ifconfig |/bin/grep inet|"+
"/usr/bin/awk '{print $2}' | "+
except:
ifconfig = []
- # Get counts for all the non-local addresses returned
+ # Get counts for all the non-local addresses returned from ifconfig
addr_count = {}
for addr in ifconfig:
if local_ok or not isLocal.match(addr):
addr_count.setdefault(addr, 0)
addr_count[addr] += 1
+ # If only one was found, use it as a starting point
local_addrs = addr_count.keys()
if len(local_addrs) == 1:
my_addr = local_addrs[0]
log.msg('Found remote address from ifconfig: %r' % (my_addr,))
- # Get counts for all the non-local addresses returned
+ # Get counts for all the non-local addresses returned from the DHT
addr_count = {}
port_count = {}
for addr in addrs:
popular_count = port_count[port]
elif port_count[port] == popular_count:
popular_port.append(port)
-
+
+ # Check to make sure the port isn't being changed
port = intended_port
if len(port_count.keys()) > 1:
log.msg('Problem, multiple ports have been found: %r' % (port_count,))
else:
log.msg('Port was not found')
+ # If one is popular, use that address
if len(popular_addr) == 1:
log.msg('Found popular address: %r' % (popular_addr[0],))
if my_addr and my_addr != popular_addr[0]:
return my_addr
def ipAddrFromChicken():
+    """Retrieve a possible IP address from the ipchicken website."""
import urllib
ip_search = re.compile('\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}')
try:
return []
def uncompact(s):
- """Extract the contatc info from a compact peer representation.
+ """Extract the contact info from a compact peer representation.
@type s: C{string}
@param s: the compact representation
port = 61234
def test_compact(self):
+ """Make sure compacting is reversed correctly by uncompacting."""
d = uncompact(compact(self.ip, self.port))
self.failUnlessEqual(d[0], self.ip)
self.failUnlessEqual(d[1], self.port)
+"""The main interface to the Khashmir DHT.
+
+@var khashmir_dir: the name of the directory to use for DHT files
+"""
+
from datetime import datetime
import os, sha, random
"""Represents errors that occur in the DHT."""
class DHT:
+ """The main interface instance to the Khashmir DHT.
+
+ @type config: C{dictionary}
+ @ivar config: the DHT configuration values
+ @type cache_dir: C{string}
+ @ivar cache_dir: the directory to use for storing files
+ @type bootstrap: C{list} of C{string}
+ @ivar bootstrap: the nodes to contact to bootstrap into the system
+ @type bootstrap_node: C{boolean}
+ @ivar bootstrap_node: whether this node is a bootstrap node
+ @type joining: L{twisted.internet.defer.Deferred}
+    @ivar joining: if a join is underway, the deferred that will signal its end
+ @type joined: C{boolean}
+ @ivar joined: whether the DHT network has been successfully joined
+ @type outstandingJoins: C{int}
+ @ivar outstandingJoins: the number of bootstrap nodes that have yet to respond
+ @type foundAddrs: C{list} of (C{string}, C{int})
+    @ivar foundAddrs: the IP address and port that were returned by bootstrap nodes
+ @type storing: C{dictionary}
+    @ivar storing: keys are keys for which store requests are active, values
+        are dictionaries mapping each value being stored to the deferred to
+        call when its store completes
+ @type retrieving: C{dictionary}
+ @ivar retrieving: keys are the keys for which getValue requests are active,
+ values are lists of the deferreds waiting for the requests
+ @type retrieved: C{dictionary}
+ @ivar retrieved: keys are the keys for which getValue requests are active,
+        values are lists of the values returned so far
+ @type config_parser: L{apt_dht.apt_dht_conf.AptDHTConfigParser}
+ @ivar config_parser: the configuration info for the main program
+ @type section: C{string}
+ @ivar section: the section of the configuration info that applies to the DHT
+ @type khashmir: L{khashmir.Khashmir}
+ @ivar khashmir: the khashmir DHT instance to use
+ """
implements(IDHT)
def __init__(self):
+ """Initialize the DHT."""
self.config = None
self.cache_dir = ''
self.bootstrap = []
self.config_parser = config
self.section = section
self.config = {}
+
+ # Get some initial values
self.cache_dir = os.path.join(self.config_parser.get(section, 'cache_dir'), khashmir_dir)
if not os.path.exists(self.cache_dir):
os.makedirs(self.cache_dir)
self.bootstrap = self.config_parser.getstringlist(section, 'BOOTSTRAP')
self.bootstrap_node = self.config_parser.getboolean(section, 'BOOTSTRAP_NODE')
for k in self.config_parser.options(section):
+ # The numbers in the config file
if k in ['K', 'HASH_LENGTH', 'CONCURRENT_REQS', 'STORE_REDUNDANCY',
'RETRIEVE_VALUES', 'MAX_FAILURES', 'PORT']:
self.config[k] = self.config_parser.getint(section, k)
+ # The times in the config file
elif k in ['CHECKPOINT_INTERVAL', 'MIN_PING_INTERVAL',
'BUCKET_STALENESS', 'KEY_EXPIRE']:
self.config[k] = self.config_parser.gettime(section, k)
+ # The booleans in the config file
elif k in ['SPEW']:
self.config[k] = self.config_parser.getboolean(section, k)
+ # Everything else is a string
else:
self.config[k] = self.config_parser.get(section, k)
if self.joining:
raise DHTError, "a join is already in progress"
+ # Create the new khashmir instance
self.khashmir = Khashmir(self.config, self.cache_dir)
self.joining = defer.Deferred()
for node in self.bootstrap:
host, port = node.rsplit(':', 1)
port = int(port)
+
+ # Translate host names into IP addresses
if isIPAddress(host):
self._join_gotIP(host, port)
else:
return self.joining
def _join_gotIP(self, ip, port):
- """Called after an IP address has been found for a single bootstrap node."""
+        """Join the DHT using a single bootstrap node's IP address."""
self.outstandingJoins += 1
self.khashmir.addContact(ip, port, self._join_single, self._join_error)
def _join_single(self, addr):
- """Called when a single bootstrap node has been added."""
+ """Process the response from the bootstrap node.
+
+ Finish the join by contacting close nodes.
+ """
self.outstandingJoins -= 1
if addr:
self.foundAddrs.append(addr)
log.msg('Got back from bootstrap node: %r' % (addr,))
def _join_error(self, failure = None):
- """Called when a single bootstrap node has failed."""
+ """Process an error in contacting the bootstrap node.
+
+ If no bootstrap nodes remain, finish the process by contacting
+ close nodes.
+ """
self.outstandingJoins -= 1
log.msg("bootstrap node could not be reached")
if self.outstandingJoins <= 0:
self.khashmir.findCloseNodes(self._join_complete, self._join_complete)
def _join_complete(self, result):
- """Called when the tables have been initialized with nodes."""
+ """End the joining process and return the addresses found for this node."""
if not self.joined and len(result) > 0:
self.joined = True
if self.joining and self.outstandingJoins <= 0:
df.errback(DHTError('could not find any nodes to bootstrap to'))
def getAddrs(self):
+ """Get the list of addresses returned by bootstrap nodes for this node."""
return self.foundAddrs
def leave(self):
self.khashmir.shutdown()
def _normKey(self, key, bits=None, bytes=None):
+ """Normalize the length of keys used in the DHT."""
bits = self.config["HASH_LENGTH"]
if bits is not None:
bytes = (bits - 1) // 8 + 1
else:
if bytes is None:
raise DHTError, "you must specify one of bits or bytes for normalization"
+
+ # Extend short keys with null bytes
if len(key) < bytes:
key = key + '\000'*(bytes - len(key))
+ # Truncate long keys
elif len(key) > bytes:
key = key[:bytes]
return key
return d
def _getValue(self, key, result):
+ """Process a returned list of values from the DHT."""
+ # Save the list of values to return when it is complete
if result:
self.retrieved.setdefault(key, []).extend([bdecode(r) for r in result])
else:
+ # Empty list, the get is complete, return the result
final_result = []
if key in self.retrieved:
final_result = self.retrieved[key]
return d
def _storeValue(self, key, bvalue, result):
+ """Process the response from the DHT."""
if key in self.storing and bvalue in self.storing[key]:
+ # Check if the store succeeded
if len(result) > 0:
self.storing[key][bvalue].callback(result)
else:
del self.storing[key]
class TestSimpleDHT(unittest.TestCase):
- """Unit tests for the DHT."""
+ """Simple 2-node unit tests for the DHT."""
timeout = 2
DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
pass
class TestMultiDHT(unittest.TestCase):
+ """More complicated 20-node tests for the DHT."""
timeout = 60
num = 20
+
+"""The apt-dht implementation of the Khashmir DHT.
+
+These modules implement a modified Khashmir, which is a kademlia-like
+Distributed Hash Table available at::
+
+ http://khashmir.sourceforge.net/
+
+The protocol for the implementation's communication is described here::
+
+ http://www.camrdale.org/apt-dht/protocol.html
+
+To run the DHT you probably want to do something like::
+
+ from apt_dht_Khashmir import DHT
+ myDHT = DHT.DHT()
+ myDHT.loadConfig(config, section)
+ myDHT.join()
+
+at which point you should be up and running and connected to others in the DHT.
+
+"""
## Copyright 2002-2004 Andrew Loewenstern, All Rights Reserved
# see LICENSE.txt for license information
+"""Details of how to perform actions on remote peers."""
+
from twisted.internet import reactor
from twisted.python import log
from util import uncompact
class ActionBase:
- """ base class for some long running asynchronous proccesses like finding nodes or values """
+    """Base class for some long-running asynchronous processes like finding nodes or values.
+
+ @type caller: L{khashmir.Khashmir}
+ @ivar caller: the DHT instance that is performing the action
+ @type target: C{string}
+ @ivar target: the target of the action, usually a DHT key
+ @type config: C{dictionary}
+ @ivar config: the configuration variables for the DHT
+ @type action: C{string}
+ @ivar action: the name of the action to call on remote nodes
+ @type num: C{long}
+ @ivar num: the target key in integer form
+ @type queried: C{dictionary}
+ @ivar queried: the nodes that have been queried for this action,
+ keys are node IDs, values are the node itself
+ @type answered: C{dictionary}
+ @ivar answered: the nodes that have answered the queries
+ @type found: C{dictionary}
+ @ivar found: nodes that have been found so far by the action
+ @type sorted_nodes: C{list} of L{node.Node}
+    @ivar sorted_nodes: a sorted list of nodes by their proximity to the key
+ @type results: C{dictionary}
+ @ivar results: keys are the results found so far by the action
+ @type desired_results: C{int}
+ @ivar desired_results: the minimum number of results that are needed
+ before the action should stop
+ @type callback: C{method}
+ @ivar callback: the method to call with the results
+ @type outstanding: C{int}
+ @ivar outstanding: the number of requests currently outstanding
+ @type outstanding_results: C{int}
+ @ivar outstanding_results: the number of results that are expected from
+ the requests that are currently outstanding
+ @type finished: C{boolean}
+ @ivar finished: whether the action is done
+ @type sort: C{method}
+ @ivar sort: used to sort nodes by their proximity to the target
+ """
+
def __init__(self, caller, target, callback, config, action, num_results = None):
- """Initialize the action."""
+ """Initialize the action.
+
+ @type caller: L{khashmir.Khashmir}
+ @param caller: the DHT instance that is performing the action
+ @type target: C{string}
+ @param target: the target of the action, usually a DHT key
+ @type callback: C{method}
+ @param callback: the method to call with the results
+ @type config: C{dictionary}
+ @param config: the configuration variables for the DHT
+ @type action: C{string}
+ @param action: the name of the action to call on remote nodes
+ @type num_results: C{int}
+ @param num_results: the minimum number of results that are needed before
+ the action should stop (optional, defaults to getting all the results)
+
+ """
+
self.caller = caller
self.target = target
self.config = config
self.callback = callback
self.outstanding = 0
self.outstanding_results = 0
- self.finished = 0
+ self.finished = False
def sort(a, b, num=self.num):
"""Sort nodes relative to the ID we are looking for."""
return -1
return 0
self.sort = sort
-
+
+ #{ Main operation
def goWithNodes(self, nodes):
"""Start the action's process with a list of nodes to contact."""
for node in nodes:
if self.desired_results and ((len(self.results) >= abs(self.desired_results)) or
(self.desired_results < 0 and
len(self.answered) >= self.config['STORE_REDUNDANCY'])):
- self.finished=1
+ self.finished = True
result = self.generateResult()
reactor.callLater(0, self.callback, *result)
len(self.results) + self.outstanding_results >= abs(self.desired_results)):
return
+ # Loop for each node that should be processed
for node in self.getNodesToProcess():
+ # Don't send requests twice or to ourself
if node.id not in self.queried and node.id != self.caller.node.id:
self.queried[node.id] = 1
# If no requests are outstanding, then we are done
if self.outstanding == 0:
- self.finished = 1
+ self.finished = True
result = self.generateResult()
reactor.callLater(0, self.callback, *result)
self.schedule()
def handleGotNodes(self, nodes):
- """Process any received node contact info in the response."""
+ """Process any received node contact info in the response.
+
+ Not called by default, but suitable for being called by
+ L{processResponse} in a recursive node search.
+ """
for compact_node in nodes:
node_contact = uncompact(compact_node)
node = self.caller.Node(node_contact)
self.sorted_nodes = self.found.values()
self.sorted_nodes.sort(self.sort)
- # The methods below are meant to be subclassed by actions
+ #{ Subclass for specific actions
def getNodesToProcess(self):
"""Generate a list of nodes to process next.
self.handleGotNodes(dict['nodes'])
def generateResult(self, nodes):
- """Create the result to return to the callback function."""
+ """Create the final result to return to the L{callback} function."""
return []
class FindValue(ActionBase):
- """Find the closest nodes to the key and check their values."""
+ """Find the closest nodes to the key and check for values."""
def __init__(self, caller, target, callback, config, action="findValue"):
ActionBase.__init__(self, caller, target, callback, config, action)
class GetValue(ActionBase):
+ """Retrieve values from a list of nodes."""
+
def __init__(self, caller, target, local_results, num_results, callback, config, action="getValue"):
+ """Initialize the action with the locally available results.
+
+ @type local_results: C{list} of C{string}
+ @param local_results: the values that were available in this node
+ """
ActionBase.__init__(self, caller, target, callback, config, action, num_results)
if local_results:
for result in local_results:
self.results[result] = 1
def getNodesToProcess(self):
- """Nodes are never added, always return the same thing."""
+ """Nodes are never added, always return the same sorted node list."""
return self.sorted_nodes
def generateArgs(self, node):
- """Args include the number of values to request."""
+ """Arguments include the number of values to request."""
if node.num_values > 0:
# Request all desired results from each node, just to be sure.
num_values = abs(self.desired_results) - len(self.results)
raise ValueError, "Don't try and get values from this node because it doesn't have any"
def processResponse(self, dict):
- """Save the returned values, calling the callback each time there are new ones."""
+ """Save the returned values, calling the L{callback} each time there are new ones."""
if dict.has_key('values'):
def x(y, z=self.results):
if not z.has_key(y):
reactor.callLater(0, self.callback, self.target, v)
def generateResult(self):
- """Results have all been returned, now send the empty list to end it."""
+ """Results have all been returned, now send the empty list to end the action."""
return (self.target, [])
class StoreValue(ActionBase):
+ """Store a value in a list of nodes."""
+
def __init__(self, caller, target, value, num_results, callback, config, action="storeValue"):
+ """Initialize the action with the value to store.
+
+ @type value: C{string}
+ @param value: the value to store in the nodes
+ """
ActionBase.__init__(self, caller, target, callback, config, action, num_results)
self.value = value
def getNodesToProcess(self):
- """Nodes are never added, always return the same thing."""
+ """Nodes are never added, always return the same sorted list."""
return self.sorted_nodes
def generateArgs(self, node):
- """Args include the value to request and the node's token."""
+ """Args include the value to store and the node's token."""
if node.token:
return (self.target, self.value, node.token), 1
else:
decode_func['u'] = decode_unicode
decode_func['t'] = decode_datetime
-def bdecode(x, sloppy = 0):
+def bdecode(x, sloppy = False):
"""Bdecode a string of data.
@type x: C{string}
+"""An sqlite database for storing nodes and key/value pairs."""
+
from datetime import datetime, timedelta
from pysqlite2 import dbapi2 as sqlite
from binascii import a2b_base64, b2a_base64
class dht_value(str):
"""Dummy class to convert all DHT values to base64 for storing in the DB."""
-
+
+# Initialize the database to work with 'khash' objects (binary strings)
sqlite.register_adapter(khash, b2a_base64)
sqlite.register_converter("KHASH", a2b_base64)
sqlite.register_converter("khash", a2b_base64)
+
+# Initialize the database to work with DHT values (binary strings)
sqlite.register_adapter(dht_value, b2a_base64)
sqlite.register_converter("DHT_VALUE", a2b_base64)
sqlite.register_converter("dht_value", a2b_base64)
class DB:
- """Database access for storing persistent data."""
+ """An sqlite database for storing persistent node info and key/value pairs.
+
+ @type db: C{string}
+ @ivar db: the database file to use
+ @type conn: L{pysqlite2.dbapi2.Connection}
+ @ivar conn: an open connection to the sqlite database
+ """
def __init__(self, db):
+ """Load or create the database file.
+
+ @type db: C{string}
+ @param db: the database file to use
+ """
self.db = db
try:
os.stat(db)
sqlite.register_converter("text", str)
else:
self.conn.text_factory = str
-
+
+ #{ Loading the DB
def _loadDB(self, db):
+        """Open a new connection to the existing database file."""
try:
self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
except:
raise DBExcept, "Couldn't open DB", traceback.format_exc()
def _createNewDB(self, db):
+ """Open a connection to a new database and create the necessary tables."""
self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
c = self.conn.cursor()
- c.execute("CREATE TABLE kv (key KHASH, value DHT_VALUE, last_refresh TIMESTAMP, PRIMARY KEY (key, value))")
+ c.execute("CREATE TABLE kv (key KHASH, value DHT_VALUE, last_refresh TIMESTAMP, "+
+ "PRIMARY KEY (key, value))")
c.execute("CREATE INDEX kv_key ON kv(key)")
c.execute("CREATE INDEX kv_last_refresh ON kv(last_refresh)")
c.execute("CREATE TABLE nodes (id KHASH PRIMARY KEY, host TEXT, port NUMBER)")
c.execute("CREATE TABLE self (num NUMBER PRIMARY KEY, id KHASH)")
self.conn.commit()
+ def close(self):
+ self.conn.close()
+
+ #{ This node's ID
def getSelfNode(self):
+ """Retrieve this node's ID from a previous run of the program."""
c = self.conn.cursor()
c.execute('SELECT id FROM self WHERE num = 0')
id = c.fetchone()
return None
def saveSelfNode(self, id):
+ """Store this node's ID for a subsequent run of the program."""
c = self.conn.cursor()
c.execute("INSERT OR REPLACE INTO self VALUES (0, ?)", (khash(id),))
self.conn.commit()
+ #{ Routing table
def dumpRoutingTable(self, buckets):
- """
- save routing table nodes to the database
- """
+ """Save routing table nodes to the database."""
c = self.conn.cursor()
c.execute("DELETE FROM nodes WHERE id NOT NULL")
for bucket in buckets:
self.conn.commit()
def getRoutingTable(self):
- """
- load routing table nodes from database
- it's usually a good idea to call refreshTable(force=1) after loading the table
- """
+ """Load routing table nodes from database."""
c = self.conn.cursor()
c.execute("SELECT * FROM nodes")
return c.fetchall()
-
+
+ #{ Key/value pairs
def retrieveValues(self, key):
"""Retrieve values from the database."""
c = self.conn.cursor()
c.execute("DELETE FROM kv WHERE last_refresh < ?", (t, ))
self.conn.commit()
- def close(self):
- self.conn.close()
-
class TestDB(unittest.TestCase):
"""Tests for the khashmir database."""
## Copyright 2002-2003 Andrew Loewenstern, All Rights Reserved
# see LICENSE.txt for license information
+"""Functions to deal with hashes (node IDs and keys)."""
+
from sha import sha
from os import urandom
from twisted.trial import unittest
def intify(hstr):
- """20 bit hash, big-endian -> long python integer"""
+ """Convert a hash (big-endian) to a long python integer."""
assert len(hstr) == 20
return long(hstr.encode('hex'), 16)
def stringify(num):
- """long int -> 20-character string"""
+ """Convert a long python integer to a hash."""
str = hex(num)[2:]
if str[-1] == 'L':
str = str[:-1]
return (20 - len(str)) *'\x00' + str
def distance(a, b):
- """distance between two 160-bit hashes expressed as 20-character strings"""
+ """Calculate the distance between two hashes expressed as strings."""
return intify(a) ^ intify(b)
-
def newID():
- """returns a new pseudorandom globally unique ID string"""
+ """Get a new pseudorandom globally unique hash string."""
h = sha()
h.update(urandom(20))
return h.digest()
def newIDInRange(min, max):
+ """Get a new pseudorandom globally unique hash string in the range."""
return stringify(randRange(min,max))
def randRange(min, max):
+ """Get a new pseudorandom globally unique hash number in the range."""
return min + intify(newID()) % (max - min)
def newTID():
+ """Get a new pseudorandom transaction ID number."""
return randRange(-2**30, 2**30)
class TestNewID(unittest.TestCase):
+ """Test the newID function."""
def testLength(self):
self.failUnlessEqual(len(newID()), 20)
def testHundreds(self):
self.testLength
class TestIntify(unittest.TestCase):
+ """Test the intify function."""
known = [('\0' * 20, 0),
('\xff' * 20, 2L**160 - 1),
]
self.testEndianessOnce()
class TestDisantance(unittest.TestCase):
+ """Test the distance function."""
known = [
(("\0" * 20, "\xff" * 20), 2**160L -1),
((sha("foo").digest(), sha("foo").digest()), 0),
self.failUnlessEqual(distance(x,y) ^ distance(y, z), distance(x, z))
class TestRandRange(unittest.TestCase):
+ """Test the randRange function."""
def testOnce(self):
a = intify(newID())
b = intify(newID())
## Copyright 2002-2004 Andrew Loewenstern, All Rights Reserved
# see LICENSE.txt for license information
+"""The main Khashmir program."""
+
import warnings
warnings.simplefilter("ignore", DeprecationWarning)
from actions import FindNode, FindValue, GetValue, StoreValue
import krpc
-# this is the base class, has base functionality and find node, no key-value mappings
class KhashmirBase(protocol.Factory):
+ """The base Khashmir class, with base functionality and find node, no key-value mappings.
+
+ @type _Node: L{node.Node}
+ @ivar _Node: the knode implementation to use for this class of DHT
+ @type config: C{dictionary}
+ @ivar config: the configuration parameters for the DHT
+ @type port: C{int}
+ @ivar port: the port to listen on
+ @type store: L{db.DB}
+ @ivar store: the database to store nodes and key/value pairs in
+ @type node: L{node.Node}
+ @ivar node: this node
+ @type table: L{ktable.KTable}
+ @ivar table: the routing table
+ @type token_secrets: C{list} of C{string}
+ @ivar token_secrets: the current secrets to use to create tokens
+ @type udp: L{krpc.hostbroker}
+ @ivar udp: the factory for the KRPC protocol
+ @type listenport: L{twisted.internet.interfaces.IListeningPort}
+ @ivar listenport: the UDP listening port
+ @type next_checkpoint: L{twisted.internet.interfaces.IDelayedCall}
+ @ivar next_checkpoint: the delayed call for the next checkpoint
+ """
+
_Node = KNodeBase
+
def __init__(self, config, cache_dir='/tmp'):
+ """Initialize the Khashmir class and call the L{setup} method.
+
+ @type config: C{dictionary}
+ @param config: the configuration parameters for the DHT
+ @type cache_dir: C{string}
+ @param cache_dir: the directory to store all files in
+ (optional, defaults to the /tmp directory)
+ """
self.config = None
self.setup(config, cache_dir)
def setup(self, config, cache_dir):
+ """Setup all the Khashmir sub-modules.
+
+ @type config: C{dictionary}
+ @param config: the configuration parameters for the DHT
+ @type cache_dir: C{string}
+ @param cache_dir: the directory to store all files in
+ """
self.config = config
self.port = config['PORT']
self.store = DB(os.path.join(cache_dir, 'khashmir.' + str(self.port) + '.db'))
self.node = self._loadSelfNode('', self.port)
self.table = KTable(self.node, config)
self.token_secrets = [newID()]
- #self.app = service.Application("krpc")
+
+ # Start listening
self.udp = krpc.hostbroker(self, config)
self.udp.protocol = krpc.KRPC
self.listenport = reactor.listenUDP(self.port, self.udp)
+
+ # Load the routing table and begin checkpointing
self._loadRoutingTable()
- self.refreshTable(force=1)
- self.next_checkpoint = reactor.callLater(60, self.checkpoint, (1,))
+ self.refreshTable(force = True)
+ self.next_checkpoint = reactor.callLater(60, self.checkpoint)
def Node(self, id, host = None, port = None):
- """Create a new node."""
+ """Create a new node.
+
+ @see: L{node.Node.__init__}
+ """
n = self._Node(id, host, port)
n.table = self.table
n.conn = self.udp.connectionForAddr((n.host, n.port))
return n
def __del__(self):
+ """Stop listening for packets."""
self.listenport.stopListening()
def _loadSelfNode(self, host, port):
+ """Create this node, loading any previously saved one."""
id = self.store.getSelfNode()
if not id:
id = newID()
return self._Node(id, host, port)
- def checkpoint(self, auto=0):
+ def checkpoint(self):
+ """Perform some periodic maintenance operations."""
+ # Create a new token secret
self.token_secrets.insert(0, newID())
if len(self.token_secrets) > 3:
self.token_secrets.pop()
+
+ # Save some parameters for reloading
self.store.saveSelfNode(self.node.id)
self.store.dumpRoutingTable(self.table.buckets)
+
+ # DHT maintenance
self.store.expireValues(self.config['KEY_EXPIRE'])
self.refreshTable()
- if auto:
- self.next_checkpoint = reactor.callLater(randrange(int(self.config['CHECKPOINT_INTERVAL'] * .9),
- int(self.config['CHECKPOINT_INTERVAL'] * 1.1)),
- self.checkpoint, (1,))
+
+ self.next_checkpoint = reactor.callLater(randrange(int(self.config['CHECKPOINT_INTERVAL'] * .9),
+ int(self.config['CHECKPOINT_INTERVAL'] * 1.1)),
+ self.checkpoint)
def _loadRoutingTable(self):
- """
- load routing table nodes from database
- it's usually a good idea to call refreshTable(force=1) after loading the table
+ """Load the previous routing table nodes from the database.
+
+ It's usually a good idea to call refreshTable(force = True) after
+ loading the table.
"""
nodes = self.store.getRoutingTable()
for rec in nodes:
n = self.Node(rec[0], rec[1], int(rec[2]))
- self.table.insertNode(n, contacted=0)
+ self.table.insertNode(n, contacted = False)
-
- #######
- ####### LOCAL INTERFACE - use these methods!
+ #{ Local interface
def addContact(self, host, port, callback=None, errback=None):
- """
- ping this node and add the contact info to the table on pong!
+ """Ping this node and add the contact info to the table on pong.
+
+ @type host: C{string}
+ @param host: the IP address of the node to contact
+ @type port: C{int}
+        @param port: the port of the node to contact
+ @type callback: C{method}
+ @param callback: the method to call with the results, it must take 1
+ parameter, the contact info returned by the node
+ (optional, defaults to doing nothing with the results)
+ @type errback: C{method}
+ @param errback: the method to call if an error occurs
+ (optional, defaults to calling the callback with None)
"""
n = self.Node(NULL_ID, host, port)
self.sendJoin(n, callback=callback, errback=errback)
- ## this call is async!
def findNode(self, id, callback, errback=None):
- """ returns the contact info for node, or the k closest nodes, from the global table """
- # get K nodes out of local table/cache, or the node we want
+ """Find the contact info for the K closest nodes in the global table.
+
+ @type id: C{string}
+ @param id: the target ID to find the K closest nodes of
+ @type callback: C{method}
+ @param callback: the method to call with the results, it must take 1
+ parameter, the list of K closest nodes
+ @type errback: C{method}
+ @param errback: the method to call if an error occurs
+ (optional, defaults to doing nothing when an error occurs)
+ """
+ # Get K nodes out of local table/cache
nodes = self.table.findNodes(id)
d = Deferred()
if errback:
d.addCallbacks(callback, errback)
else:
d.addCallback(callback)
- if len(nodes) == 1 and nodes[0].id == id :
+
+ # If the target ID was found
+ if len(nodes) == 1 and nodes[0].id == id:
d.callback(nodes)
else:
- # create our search state
+ # Start the finding nodes action
state = FindNode(self, id, d.callback, self.config)
reactor.callLater(0, state.goWithNodes, nodes)
- def insertNode(self, n, contacted=1):
- """
- insert a node in our local table, pinging oldest contact in bucket, if necessary
+ def insertNode(self, node, contacted = True):
+ """Try to insert a node in our local table, pinging oldest contact if necessary.
- If all you have is a host/port, then use addContact, which calls this method after
- receiving the PONG from the remote node. The reason for the seperation is we can't insert
- a node into the table without it's peer-ID. That means of course the node passed into this
- method needs to be a properly formed Node object with a valid ID.
+ If all you have is a host/port, then use L{addContact}, which calls this
+ method after receiving the PONG from the remote node. The reason for
+        the separation is we can't insert a node into the table without its
+ node ID. That means of course the node passed into this method needs
+ to be a properly formed Node object with a valid ID.
+
+ @type node: L{node.Node}
+ @param node: the new node to try and insert
+ @type contacted: C{boolean}
+ @param contacted: whether the new node is known to be good, i.e.
+ responded to a request (optional, defaults to True)
"""
- old = self.table.insertNode(n, contacted=contacted)
+ old = self.table.insertNode(node, contacted=contacted)
if (old and old.id != self.node.id and
(datetime.now() - old.lastSeen) >
timedelta(seconds=self.config['MIN_PING_INTERVAL'])):
- # the bucket is full, check to see if old node is still around and if so, replace it
- ## these are the callbacks used when we ping the oldest node in a bucket
- def _staleNodeHandler(oldnode=old, newnode = n):
- """ called if the pinged node never responds """
- self.table.replaceStaleNode(old, newnode)
+ def _staleNodeHandler(oldnode = old, newnode = node):
+ """The pinged node never responded, so replace it."""
+ self.table.replaceStaleNode(oldnode, newnode)
def _notStaleNodeHandler(dict, old=old):
- """ called when we get a pong from the old node """
+ """Got a pong from the old node, so update it."""
dict = dict['rsp']
if dict['id'] == old.id:
self.table.justSeenNode(old.id)
+ # Bucket is full, check to see if old node is still available
df = old.ping(self.node.id)
df.addCallbacks(_notStaleNodeHandler, _staleNodeHandler)
def sendJoin(self, node, callback=None, errback=None):
+ """Join the DHT by pinging a bootstrap node.
+
+ @type node: L{node.Node}
+ @param node: the node to send the join to
+ @type callback: C{method}
+ @param callback: the method to call with the results, it must take 1
+ parameter, the contact info returned by the node
+ (optional, defaults to doing nothing with the results)
+ @type errback: C{method}
+ @param errback: the method to call if an error occurs
+ (optional, defaults to calling the callback with None)
"""
- ping a node
- """
- df = node.join(self.node.id)
- ## these are the callbacks we use when we issue a PING
+
def _pongHandler(dict, node=node, self=self, callback=callback):
+ """Node responded properly, callback with response."""
n = self.Node(dict['rsp']['id'], dict['_krpc_sender'][0], dict['_krpc_sender'][1])
self.insertNode(n)
if callback:
callback((dict['rsp']['ip_addr'], dict['rsp']['port']))
+
def _defaultPong(err, node=node, table=self.table, callback=callback, errback=errback):
+ """Error occurred, fail node and errback or callback with error."""
table.nodeFailed(node)
if errback:
errback()
- else:
+ elif callback:
callback(None)
- df.addCallbacks(_pongHandler,_defaultPong)
+ df = node.join(self.node.id)
+ df.addCallbacks(_pongHandler, _defaultPong)
def findCloseNodes(self, callback=lambda a: None, errback = None):
- """
- This does a findNode on the ID one away from our own.
- This will allow us to populate our table with nodes on our network closest to our own.
- This is called as soon as we start up with an empty table
+ """Perform a findNode on the ID one away from our own.
+
+ This will allow us to populate our table with nodes on our network
+ closest to our own. This is called as soon as we start up with an
+ empty table.
+
+ @type callback: C{method}
+ @param callback: the method to call with the results, it must take 1
+ parameter, the list of K closest nodes
+ (optional, defaults to doing nothing with the results)
+ @type errback: C{method}
+ @param errback: the method to call if an error occurs
+ (optional, defaults to doing nothing when an error occurs)
"""
id = self.node.id[:-1] + chr((ord(self.node.id[-1]) + 1) % 256)
self.findNode(id, callback, errback)
- def refreshTable(self, force=0):
- """
- force=1 will refresh table regardless of last bucket access time
+ def refreshTable(self, force = False):
+ """Check all the buckets for those that need refreshing.
+
+ @param force: refresh all buckets regardless of last bucket access time
+ (optional, defaults to False)
"""
def callback(nodes):
pass
for bucket in self.table.buckets:
if force or (datetime.now() - bucket.lastAccessed >
timedelta(seconds=self.config['BUCKET_STALENESS'])):
+ # Choose a random ID in the bucket and try and find it
id = newIDInRange(bucket.min, bucket.max)
self.findNode(id, callback)
def stats(self):
- """
- Returns (num_contacts, num_nodes)
- num_contacts: number contacts in our routing table
- num_nodes: number of nodes estimated in the entire dht
+ """Collect some statistics about the DHT.
+
+ @rtype: (C{int}, C{int})
+        @return: the number of contacts in our routing table, and the estimated
+ number of nodes in the entire DHT
"""
num_contacts = reduce(lambda a, b: a + len(b.l), self.table.buckets, 0)
num_nodes = self.config['K'] * (2**(len(self.table.buckets) - 1))
pass
self.store.close()
- #### Remote Interface - called by remote nodes
+ #{ Remote interface
def krpc_ping(self, id, _krpc_sender):
+ """Pong with our ID.
+
+ @type id: C{string}
+ @param id: the node ID of the sender node
+ @type _krpc_sender: (C{string}, C{int})
+ @param _krpc_sender: the sender node's IP address and port
+ """
n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
- self.insertNode(n, contacted=0)
+ self.insertNode(n, contacted = False)
+
return {"id" : self.node.id}
def krpc_join(self, id, _krpc_sender):
+ """Add the node by responding with its address and port.
+
+ @type id: C{string}
+ @param id: the node ID of the sender node
+ @type _krpc_sender: (C{string}, C{int})
+ @param _krpc_sender: the sender node's IP address and port
+ """
n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
- self.insertNode(n, contacted=0)
+ self.insertNode(n, contacted = False)
+
return {"ip_addr" : _krpc_sender[0], "port" : _krpc_sender[1], "id" : self.node.id}
def krpc_find_node(self, target, id, _krpc_sender):
+ """Find the K closest nodes to the target in the local routing table.
+
+ @type target: C{string}
+ @param target: the target ID to find nodes for
+ @type id: C{string}
+ @param id: the node ID of the sender node
+ @type _krpc_sender: (C{string}, C{int})
+ @param _krpc_sender: the sender node's IP address and port
+ """
n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
- self.insertNode(n, contacted=0)
+ self.insertNode(n, contacted = False)
+
nodes = self.table.findNodes(target)
nodes = map(lambda node: node.contactInfo(), nodes)
token = sha(self.token_secrets[0] + _krpc_sender[0]).digest()
return {"nodes" : nodes, "token" : token, "id" : self.node.id}
-## This class provides read-only access to the DHT, valueForKey
-## you probably want to use this mixin and provide your own write methods
class KhashmirRead(KhashmirBase):
+ """The read-only Khashmir class, which can only retrieve (not store) key/value mappings."""
+
_Node = KNodeRead
- ## also async
+ #{ Local interface
def findValue(self, key, callback, errback=None):
- """ returns the contact info for nodes that have values for the key, from the global table """
- # get K nodes out of local table/cache
+ """Get the nodes that have values for the key from the global table.
+
+ @type key: C{string}
+ @param key: the target key to find the values for
+ @type callback: C{method}
+ @param callback: the method to call with the results, it must take 1
+ parameter, the list of nodes with values
+ @type errback: C{method}
+ @param errback: the method to call if an error occurs
+ (optional, defaults to doing nothing when an error occurs)
+ """
+ # Get K nodes out of local table/cache
nodes = self.table.findNodes(key)
d = Deferred()
if errback:
else:
d.addCallback(callback)
- # create our search state
+ # Search for others starting with the locally found ones
state = FindValue(self, key, d.callback, self.config)
reactor.callLater(0, state.goWithNodes, nodes)
- def valueForKey(self, key, callback, searchlocal = 1):
- """ returns the values found for key in global table
- callback will be called with a list of values for each peer that returns unique values
- final callback will be an empty list - probably should change to 'more coming' arg
+ def valueForKey(self, key, callback, searchlocal = True):
+ """Get the values found for key in global table.
+
+ Callback will be called with a list of values for each peer that
+ returns unique values. The final callback will be an empty list.
+
+ @type key: C{string}
+ @param key: the target key to get the values for
+ @type callback: C{method}
+ @param callback: the method to call with the results, it must take 2
+ parameters: the key, and the values found
+ @type searchlocal: C{boolean}
+ @param searchlocal: whether to also look for any local values
"""
- # get locals
+ # Get any local values
if searchlocal:
l = self.store.retrieveValues(key)
if len(l) > 0:
l = []
def _getValueForKey(nodes, key=key, local_values=l, response=callback, self=self):
- # create our search state
+ """Use the found nodes to send requests for values to."""
state = GetValue(self, key, local_values, self.config['RETRIEVE_VALUES'], response, self.config)
reactor.callLater(0, state.goWithNodes, nodes)
- # this call is asynch
+ # First lookup nodes that have values for the key
self.findValue(key, _getValueForKey)
- #### Remote Interface - called by remote nodes
+ #{ Remote interface
def krpc_find_value(self, key, id, _krpc_sender):
+ """Find the number of values stored locally for the key, and the K closest nodes.
+
+ @type key: C{string}
+ @param key: the target key to find the values and nodes for
+ @type id: C{string}
+ @param id: the node ID of the sender node
+ @type _krpc_sender: (C{string}, C{int})
+ @param _krpc_sender: the sender node's IP address and port
+ """
n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
- self.insertNode(n, contacted=0)
+ self.insertNode(n, contacted = False)
nodes = self.table.findNodes(key)
nodes = map(lambda node: node.contactInfo(), nodes)
return {'nodes' : nodes, 'num' : num_values, "id": self.node.id}
def krpc_get_value(self, key, num, id, _krpc_sender):
+ """Retrieve the values stored locally for the key.
+
+ @type key: C{string}
+ @param key: the target key to retrieve the values for
+ @type num: C{int}
+ @param num: the maximum number of values to retrieve, or 0 to
+ retrieve all of them
+ @type id: C{string}
+ @param id: the node ID of the sender node
+ @type _krpc_sender: (C{string}, C{int})
+ @param _krpc_sender: the sender node's IP address and port
+ """
n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
- self.insertNode(n, contacted=0)
+ self.insertNode(n, contacted = False)
l = self.store.retrieveValues(key)
if num == 0 or num >= len(l):
shuffle(l)
return {'values' : l[:num], "id": self.node.id}
-### provides a generic write method, you probably don't want to deploy something that allows
-### arbitrary value storage
+
class KhashmirWrite(KhashmirRead):
+ """The read-write Khashmir class, which can store and retrieve key/value mappings."""
+
_Node = KNodeWrite
- ## async, callback indicates nodes we got a response from (but no guarantee they didn't drop it on the floor)
+
+ #{ Local interface
def storeValueForKey(self, key, value, callback=None):
- """ stores the value and origination time for key in the global table, returns immediately, no status
- in this implementation, peers respond but don't indicate status to storing values
- a key can have many values
+ """Stores the value for the key in the global table.
+
+ No status in this implementation, peers respond but don't indicate
+ status of storing values.
+
+ @type key: C{string}
+ @param key: the target key to store the value for
+ @type value: C{string}
+ @param value: the value to store with the key
+ @type callback: C{method}
+ @param callback: the method to call with the results, it must take 3
+ parameters: the key, the value stored, and the result of the store
+ (optional, defaults to doing nothing with the results)
"""
def _storeValueForKey(nodes, key=key, value=value, response=callback, self=self):
+ """Use the returned K closest nodes to store the key at."""
if not response:
- # default callback
def _storedValueHandler(key, value, sender):
+ """Default callback that does nothing."""
pass
- response=_storedValueHandler
+ response = _storedValueHandler
action = StoreValue(self, key, value, self.config['STORE_REDUNDANCY'], response, self.config)
reactor.callLater(0, action.goWithNodes, nodes)
- # this call is asynch
+ # First find the K closest nodes to operate on.
self.findNode(key, _storeValueForKey)
- #### Remote Interface - called by remote nodes
+ #{ Remote interface
def krpc_store_value(self, key, value, token, id, _krpc_sender):
+ """Store the value locally with the key.
+
+ @type key: C{string}
+ @param key: the target key to store the value for
+ @type value: C{string}
+ @param value: the value to store with the key
+ @param token: the token to confirm that this peer contacted us previously
+ @type id: C{string}
+ @param id: the node ID of the sender node
+ @type _krpc_sender: (C{string}, C{int})
+ @param _krpc_sender: the sender node's IP address and port
+ """
n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
- self.insertNode(n, contacted=0)
+ self.insertNode(n, contacted = False)
for secret in self.token_secrets:
this_token = sha(secret + _krpc_sender[0]).digest()
if token == this_token:
return {"id" : self.node.id}
raise krpc.KrpcError, (krpc.KRPC_ERROR_INVALID_TOKEN, 'token is invalid, do a find_nodes to get a fresh one')
-# the whole shebang, for testing
+
class Khashmir(KhashmirWrite):
+ """The default Khashmir class (currently the read-write L{KhashmirWrite})."""
_Node = KNodeWrite
+
class SimpleTests(unittest.TestCase):
timeout = 10
'KEY_EXPIRE': 3600, 'SPEW': False, }
def setUp(self):
- krpc.KRPC.noisy = 0
d = self.DHT_DEFAULTS.copy()
d['PORT'] = 4044
self.a = Khashmir(d)
## Copyright 2002-2004 Andrew Loewenstern, All Rights Reserved
# see LICENSE.txt for license information
+"""Represents a khashmir node in the DHT."""
+
from twisted.python import log
from node import Node, NULL_ID
class KNodeBase(Node):
+ """A basic node that can only be pinged and help find other nodes."""
+
def checkSender(self, dict):
+ """Check the sender's info to make sure it meets expectations."""
try:
senderid = dict['rsp']['id']
except KeyError:
return dict
def errBack(self, err):
+ """Log an error that has occurred."""
log.err(err)
return err
def ping(self, id):
+ """Ping the node."""
df = self.conn.sendRequest('ping', {"id":id})
df.addErrback(self.errBack)
df.addCallback(self.checkSender)
return df
def join(self, id):
+ """Use the node to bootstrap into the system."""
df = self.conn.sendRequest('join', {"id":id})
df.addErrback(self.errBack)
df.addCallback(self.checkSender)
return df
def findNode(self, id, target):
+ """Request the nearest nodes to the target that the node knows about."""
df = self.conn.sendRequest('find_node', {"target" : target, "id": id})
df.addErrback(self.errBack)
df.addCallback(self.checkSender)
return df
class KNodeRead(KNodeBase):
+ """More advanced node that can also find and send values."""
+
def findValue(self, id, key):
+ """Request the nearest nodes to the key that the node knows about."""
df = self.conn.sendRequest('find_value', {"key" : key, "id" : id})
df.addErrback(self.errBack)
df.addCallback(self.checkSender)
return df
def getValue(self, id, key, num):
+ """Request the values that the node has for the key."""
df = self.conn.sendRequest('get_value', {"key" : key, "num": num, "id" : id})
df.addErrback(self.errBack)
df.addCallback(self.checkSender)
return df
class KNodeWrite(KNodeRead):
+ """Most advanced node that can also store values."""
+
def storeValue(self, id, key, value, token):
+ """Store a value in the node."""
df = self.conn.sendRequest('store_value', {"key" : key, "value" : value, "token" : token, "id": id})
df.addErrback(self.errBack)
df.addCallback(self.checkSender)
## Copyright 2002-2003 Andrew Loewenstern, All Rights Reserved
# see LICENSE.txt for license information
+"""The KRPC communication protocol implementation.
+
+@var KRPC_TIMEOUT: the number of seconds after which requests timeout
+@var UDP_PACKET_LIMIT: the maximum number of bytes that can be sent in a
+ UDP packet without fragmentation
+
+@var KRPC_ERROR: the code for a generic error
+@var KRPC_ERROR_SERVER_ERROR: the code for a server error
+@var KRPC_ERROR_MALFORMED_PACKET: the code for a malformed packet error
+@var KRPC_ERROR_METHOD_UNKNOWN: the code for a method unknown error
+@var KRPC_ERROR_MALFORMED_REQUEST: the code for a malformed request error
+@var KRPC_ERROR_INVALID_TOKEN: the code for an invalid token error
+@var KRPC_ERROR_RESPONSE_TOO_LONG: the code for a response too long error
+
+@var KRPC_ERROR_INTERNAL: the code for an internal error
+@var KRPC_ERROR_RECEIVED_UNKNOWN: the code for an unknown message type error
+@var KRPC_ERROR_TIMEOUT: the code for a timeout error
+@var KRPC_ERROR_PROTOCOL_STOPPED: the code for a stopped protocol error
+
+@var TID: the identifier for the transaction ID
+@var REQ: the identifier for a request packet
+@var RSP: the identifier for a response packet
+@var TYP: the identifier for the type of packet
+@var ARG: the identifier for the argument to the request
+@var ERR: the identifier for an error packet
+
+@group Remote node error codes: KRPC_ERROR, KRPC_ERROR_SERVER_ERROR,
+ KRPC_ERROR_MALFORMED_PACKET, KRPC_ERROR_METHOD_UNKNOWN,
+ KRPC_ERROR_MALFORMED_REQUEST, KRPC_ERROR_INVALID_TOKEN,
+ KRPC_ERROR_RESPONSE_TOO_LONG
+@group Local node error codes: KRPC_ERROR_INTERNAL, KRPC_ERROR_RECEIVED_UNKNOWN,
+ KRPC_ERROR_TIMEOUT, KRPC_ERROR_PROTOCOL_STOPPED
+@group Command identifiers: TID, REQ, RSP, TYP, ARG, ERR
+
+"""
+
from bencode import bencode, bdecode
from time import asctime
from math import ceil
ERR = 'e'
class KrpcError(Exception):
+ """An error occurred in the KRPC protocol."""
pass
def verifyMessage(msg):
if type(msg[TID]) != str:
raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "transaction id is not a string")
-class hostbroker(protocol.DatagramProtocol):
+class hostbroker(protocol.DatagramProtocol):
+ """The factory for the KRPC protocol.
+
+ @type server: L{khashmir.Khashmir}
+ @ivar server: the main Khashmir program
+ @type config: C{dictionary}
+ @ivar config: the configuration parameters for the DHT
+ @type connections: C{dictionary}
+ @ivar connections: all the connections that have ever been made to the
+ protocol, keys are IP address and port pairs, values are L{KRPC}
+ protocols for the addresses
+ @ivar protocol: the protocol to use to handle incoming connections
+ (added externally)
+ @type addr: (C{string}, C{int})
+ @ivar addr: the IP address and port of this node
+ """
+
def __init__(self, server, config):
+ """Initialize the factory.
+
+ @type server: L{khashmir.Khashmir}
+ @param server: the main DHT program
+ @type config: C{dictionary}
+ @param config: the configuration parameters for the DHT
+ """
self.server = server
self.config = config
# this should be changed to storage that drops old entries
self.connections = {}
def datagramReceived(self, datagram, addr):
- #print `addr`, `datagram`
- #if addr != self.addr:
+ """Optionally create a new protocol object, and handle the new datagram.
+
+ @type datagram: C{string}
+ @param datagram: the data received from the transport.
+ @type addr: (C{string}, C{int})
+ @param addr: source IP address and port of datagram.
+ """
c = self.connectionForAddr(addr)
c.datagramReceived(datagram, addr)
#if c.idle():
# del self.connections[addr]
def connectionForAddr(self, addr):
+ """Get a protocol object for the source.
+
+ @type addr: (C{string}, C{int})
+ @param addr: source IP address and port of datagram.
+ """
+ # Don't connect to ourself
if addr == self.addr:
- raise Exception
+            raise KrpcError
+
+ # Create a new protocol object if necessary
if not self.connections.has_key(addr):
conn = self.protocol(addr, self.server, self.transport, self.config['SPEW'])
self.connections[addr] = conn
return conn
def makeConnection(self, transport):
+ """Make a connection to a transport and save our address."""
protocol.DatagramProtocol.makeConnection(self, transport)
tup = transport.getHost()
self.addr = (tup.host, tup.port)
def stopProtocol(self):
+ """Stop all the open connections."""
for conn in self.connections.values():
conn.stop()
protocol.DatagramProtocol.stopProtocol(self)
-## connection
class KRPC:
+ """The KRPC protocol implementation.
+
+ @ivar transport: the transport to use for the protocol
+ @type factory: L{khashmir.Khashmir}
+ @ivar factory: the main Khashmir program
+ @type addr: (C{string}, C{int})
+ @ivar addr: the IP address and port of the source node
+ @type noisy: C{boolean}
+ @ivar noisy: whether to log additional details of the protocol
+ @type tids: C{dictionary}
+ @ivar tids: the transaction IDs outstanding for requests, keys are the
+ transaction ID of the request, values are the deferreds to call with
+ the results
+ @type stopped: C{boolean}
+ @ivar stopped: whether the protocol has been stopped
+ """
+
def __init__(self, addr, server, transport, spew = False):
+ """Initialize the protocol.
+
+ @type addr: (C{string}, C{int})
+ @param addr: the IP address and port of the source node
+ @type server: L{khashmir.Khashmir}
+ @param server: the main Khashmir program
+ @param transport: the transport to use for the protocol
+ @type spew: C{boolean}
+ @param spew: whether to log additional details of the protocol
+ (optional, defaults to False)
+ """
self.transport = transport
self.factory = server
self.addr = addr
self.stopped = False
def datagramReceived(self, data, addr):
+ """Process the new datagram.
+
+ @type data: C{string}
+ @param data: the data received from the transport.
+ @type addr: (C{string}, C{int})
+ @param addr: source IP address and port of datagram.
+ """
if self.stopped:
if self.noisy:
log.msg("stopped, dropping message from %r: %s" % (addr, data))
- # bdecode
+
+ # Bdecode the message
try:
msg = bdecode(data)
except Exception, e:
log.err(e)
return
+ # Make sure the remote node isn't trying anything funny
try:
verifyMessage(msg)
except Exception, e:
if self.noisy:
log.msg("%d received from %r: %s" % (self.factory.port, addr, msg))
- # look at msg type
+
+ # Process it based on its type
if msg[TYP] == REQ:
ilen = len(data)
- # if request
- # tell factory to handle
+
+ # Requests are handled by the factory
f = getattr(self.factory ,"krpc_" + msg[REQ], None)
msg[ARG]['_krpc_sender'] = self.addr
if f and callable(f):
else:
olen = self._sendResponse(addr, msg[TID], RSP, ret)
else:
- # unknown method
+ # Request for unknown method
log.msg("ERROR: don't know about method %s" % msg[REQ])
olen = self._sendResponse(addr, msg[TID], ERR,
[KRPC_ERROR_METHOD_UNKNOWN, "unknown method "+str(msg[REQ])])
log.msg("%s >>> %s - %s %s %s" % (addr, self.factory.node.port,
ilen, msg[REQ], olen))
elif msg[TYP] == RSP:
- # if response
- # lookup tid
+ # Responses get processed by their TID's deferred
if self.tids.has_key(msg[TID]):
df = self.tids[msg[TID]]
# callback
if self.noisy:
log.msg('timeout: %r' % msg[RSP]['id'])
elif msg[TYP] == ERR:
- # if error
- # lookup tid
+ # Errors get processed by their TID's deferred's errback
if self.tids.has_key(msg[TID]):
df = self.tids[msg[TID]]
del(self.tids[msg[TID]])
log.msg("Got an error for an unknown request: %r" % (msg[ERR], ))
pass
else:
+ # Received an unknown message type
if self.noisy:
log.msg("unknown message type: %r" % msg)
- # unknown message type
if msg[TID] in self.tids:
df = self.tids[msg[TID]]
del(self.tids[msg[TID]])
"Received an unknown message type: %r" % msg[TYP]))
def _sendResponse(self, addr, tid, msgType, response):
+ """Helper function for sending responses to nodes.
+
+ @type addr: (C{string}, C{int})
+ @param addr: the IP address and port to send the response to
+ @param tid: the transaction ID of the request
+ @param msgType: the type of message to respond with
+ @param response: the arguments for the response
+ """
if not response:
response = {}
try:
+ # Create the response message
msg = {TID : tid, TYP : msgType, msgType : response}
if self.noisy:
out = bencode(msg)
+ # Make sure it's not too long
if len(out) > UDP_PACKET_LIMIT:
+ # Can we remove some values to shorten it?
if 'values' in response:
# Save the original list of values
orig_values = response['values']
return len(out)
def sendRequest(self, method, args):
+ """Send a request to the remote node.
+
+ @type method: C{string}
+ @param method: the method name to call on the remote node
+ @param args: the arguments to send to the remote node's method
+ """
if self.stopped:
raise KrpcError, (KRPC_ERROR_PROTOCOL_STOPPED, "cannot send, connection has been stopped")
- # make message
- # send it
+
+ # Create the request message
msg = {TID : newID(), TYP : REQ, REQ : method, ARG : args}
if self.noisy:
log.msg("%d sending to %r: %s" % (self.factory.port, self.addr, msg))
data = bencode(msg)
+
+ # Create the deferred and save it with the TID
d = Deferred()
self.tids[msg[TID]] = d
+
+ # Schedule a later timeout call
def timeOut(tids = self.tids, id = msg[TID], method = method, addr = self.addr):
+ """Call the deferred's errback if a timeout occurs."""
if tids.has_key(id):
df = tids[id]
del(tids[id])
df.errback(KrpcError(KRPC_ERROR_TIMEOUT, "timeout waiting for '%s' from %r" % (method, addr)))
later = reactor.callLater(KRPC_TIMEOUT, timeOut)
+
+ # Cancel the timeout call if a response is received
def dropTimeOut(dict, later_call = later):
+ """Cancel the timeout call when a response is received."""
if later_call.active():
later_call.cancel()
return dict
d.addBoth(dropTimeOut)
+
self.transport.write(data, self.addr)
return d
df.errback(KrpcError(KRPC_ERROR_PROTOCOL_STOPPED, 'connection has been stopped while waiting for response'))
self.tids = {}
self.stopped = True
-
+
+#{ For testing the KRPC protocol
def connectionForAddr(host, port):
return host
def gotLongRsp(self, dict):
# Not quite accurate, but good enough
self.failUnless(len(bencode(dict))-10 < UDP_PACKET_LIMIT)
-
\ No newline at end of file
## Copyright 2002-2003 Andrew Loewenstern, All Rights Reserved
# see LICENSE.txt for license information
+"""The routing table and buckets for a kademlia-like DHT."""
+
from datetime import datetime
from bisect import bisect_left
from node import Node, NULL_ID
class KTable:
- """local routing table for a kademlia like distributed hash table"""
+ """Local routing table for a kademlia-like distributed hash table.
+
+ @type node: L{node.Node}
+ @ivar node: the local node
+ @type config: C{dictionary}
+ @ivar config: the configuration parameters for the DHT
+ @type buckets: C{list} of L{KBucket}
+ @ivar buckets: the buckets of nodes in the routing table
+ """
+
def __init__(self, node, config):
+ """Initialize the first empty bucket of everything.
+
+ @type node: L{node.Node}
+ @param node: the local node
+ @type config: C{dictionary}
+ @param config: the configuration parameters for the DHT
+ """
# this is the root node, a.k.a. US!
assert node.id != NULL_ID
self.node = node
self.buckets = [KBucket([], 0L, 2L**self.config['HASH_LENGTH'])]
def _bucketIndexForInt(self, num):
- """the index of the bucket that should hold int"""
+ """Find the index of the bucket that should hold the node's ID number."""
return bisect_left(self.buckets, num)
def findNodes(self, id):
+ """Find the K nodes in our own local table closest to the ID.
+
+ @type id: C{string}, C{int} or L{node.Node}
+ @param id: the ID to find nodes that are close to
+ @raise TypeError: if id does not properly identify an ID
"""
- return K nodes in our own local table closest to the ID.
- """
-
+
+ # Get the ID number from the input
if isinstance(id, str):
num = khash.intify(id)
elif isinstance(id, Node):
nodes = []
i = self._bucketIndexForInt(num)
- # if this node is already in our table then return it
+ # If this node is already in our table then return it
try:
index = self.buckets[i].l.index(num)
except ValueError:
else:
return [self.buckets[i].l[index]]
- # don't have the node, get the K closest nodes
+ # Don't have the node, get the K closest nodes from the appropriate bucket
nodes = nodes + self.buckets[i].l
+
+ # Make sure we have enough
if len(nodes) < self.config['K']:
- # need more nodes
+ # Look in adjoining buckets for nodes
min = i - 1
max = i + 1
while len(nodes) < self.config['K'] and (min >= 0 or max < len(self.buckets)):
- #ASw: note that this requires K be even
+ # Add the adjoining buckets' nodes to the list
if min >= 0:
nodes = nodes + self.buckets[min].l
if max < len(self.buckets):
min = min - 1
max = max + 1
+ # Sort the found nodes by proximity to the id and return the closest K
nodes.sort(lambda a, b, num=num: cmp(num ^ a.num, num ^ b.num))
return nodes[:self.config['K']]
def _splitBucket(self, a):
+ """Split a bucket in two.
+
+ @type a: L{KBucket}
+ @param a: the bucket to split
+ """
+ # Create a new bucket with half the (upper) range of the current bucket
diff = (a.max - a.min) / 2
b = KBucket([], a.max - diff, a.max)
self.buckets.insert(self.buckets.index(a.min) + 1, b)
+
+ # Reduce the input bucket's (upper) range
a.max = a.max - diff
- # transfer nodes to new bucket
+
+ # Transfer nodes to the new bucket
for anode in a.l[:]:
if anode.num >= a.max:
a.l.remove(anode)
b.l.append(anode)
- def replaceStaleNode(self, stale, new):
- """this is used by clients to replace a node returned by insertNode after
- it fails to respond to a Pong message"""
+ def replaceStaleNode(self, stale, new = None):
+ """Replace a stale node in a bucket with a new one.
+
+ This is used by clients to replace a node returned by insertNode after
+ it fails to respond to a ping.
+
+ @type stale: L{node.Node}
+ @param stale: the stale node to remove from the bucket
+ @type new: L{node.Node}
+ @param new: the new node to add in its place (optional, defaults to
+ not adding any node in the old node's place)
+ """
+ # Find the stale node's bucket
i = self._bucketIndexForInt(stale.num)
try:
it = self.buckets[i].l.index(stale.num)
except ValueError:
return
+ # Remove the stale node and insert the new one
del(self.buckets[i].l[it])
if new:
self.buckets[i].l.append(new)
- def insertNode(self, node, contacted=1):
- """
- this insert the node, returning None if successful, returns the oldest node in the bucket if it's full
- the caller responsible for pinging the returned node and calling replaceStaleNode if it is found to be stale!!
- contacted means that yes, we contacted THEM and we know the node is reachable
+ def insertNode(self, node, contacted = True):
+ """Try to insert a node in the routing table.
+
+ This inserts the node, returning None if successful, otherwise returns
+ the oldest node in the bucket if it's full. The caller is then
+ responsible for pinging the returned node and calling replaceStaleNode
+ if it doesn't respond. contacted means that yes, we contacted THEM and
+ we know the node is reachable.
+
+ @type node: L{node.Node}
+ @param node: the new node to try and insert
+ @type contacted: C{boolean}
+ @param contacted: whether the new node is known to be good, i.e.
+ responded to a request (optional, defaults to True)
+ @rtype: L{node.Node}
+ @return: None if successful (the bucket wasn't full), otherwise returns the oldest node in the bucket
"""
assert node.id != NULL_ID
if node.id == self.node.id: return
- # get the bucket for this node
+
+ # Get the bucket for this node
i = self. _bucketIndexForInt(node.num)
- # check to see if node is in the bucket already
+
+ # Check to see if node is in the bucket already
try:
it = self.buckets[i].l.index(node.num)
except ValueError:
- # no
pass
else:
+ # The node is already in the bucket
if contacted:
+ # It responded, so update it
node.updateLastSeen()
# move node to end of bucket
xnode = self.buckets[i].l[it]
self.buckets[i].touch()
return
- # we don't have this node, check to see if the bucket is full
+ # We don't have this node, check to see if the bucket is full
if len(self.buckets[i].l) < self.config['K']:
- # no, append this node and return
+ # Not full, append this node and return
if contacted:
node.updateLastSeen()
self.buckets[i].l.append(node)
self.buckets[i].touch()
return
- # bucket is full, check to see if self.node is in the bucket
+ # Bucket is full, check to see if the local node is not in the bucket
if not (self.buckets[i].min <= self.node < self.buckets[i].max):
+ # Local node not in the bucket, can't split it, return the oldest node
return self.buckets[i].l[0]
- # this bucket is full and contains our node, split the bucket
+ # Make sure our table isn't FULL, this is really unlikely
if len(self.buckets) >= self.config['HASH_LENGTH']:
- # our table is FULL, this is really unlikely
log.err("Hash Table is FULL! Increase K!")
return
+ # This bucket is full and contains our node, split the bucket
self._splitBucket(self.buckets[i])
- # now that the bucket is split and balanced, try to insert the node again
+ # Now that the bucket is split and balanced, try to insert the node again
return self.insertNode(node)
def justSeenNode(self, id):
- """call this any time you get a message from a node
- it will update it in the table if it's there """
+ """Mark a node as just having been seen.
+
+ Call this any time you get a message from a node, it will update it
+ in the table if it's there.
+
+ @type id: C{string}, C{int} or L{node.Node}
+ @param id: the node ID to mark as just having been seen
+ @rtype: C{datetime.datetime}
+ @return: the old lastSeen time of the node, or None if it's not in the table
+ """
try:
n = self.findNodes(id)[0]
except IndexError:
return tstamp
def invalidateNode(self, n):
+ """Remove the node from the routing table.
+
+ Forget about node n. Use this when you know that a node is invalid.
"""
- forget about node n - use when you know that node is invalid
- """
- self.replaceStaleNode(n, None)
+ self.replaceStaleNode(n)
def nodeFailed(self, node):
- """ call this when a node fails to respond to a message, to invalidate that node """
+ """Mark a node as having failed once, and remove it if it has failed too much."""
try:
n = self.findNodes(node.num)[0]
except IndexError:
self.invalidateNode(n)
class KBucket:
+ """Single bucket of nodes in a kademlia-like routing table.
+
+ @type l: C{list} of L{node.Node}
+ @ivar l: the nodes that are in this bucket
+ @type min: C{long}
+ @ivar min: the minimum node ID that can be in this bucket
+ @type max: C{long}
+ @ivar max: the maximum node ID that can be in this bucket
+ @type lastAccessed: C{datetime.datetime}
+ @ivar lastAccessed: the last time a node in this bucket was successfully contacted
+ """
+
def __init__(self, contents, min, max):
+ """Initialize the bucket with nodes.
+
+ @type contents: C{list} of L{node.Node}
+ @param contents: the nodes to store in the bucket
+ @type min: C{long}
+ @param min: the minimum node ID that can be in this bucket
+ @type max: C{long}
+ @param max: the maximum node ID that can be in this bucket
+ """
self.l = contents
self.min = min
self.max = max
self.lastAccessed = datetime.now()
def touch(self):
+ """Update the L{lastAccessed} time."""
self.lastAccessed = datetime.now()
def getNodeWithInt(self, num):
+ """Get the node in the bucket with that number.
+
+ @type num: C{long}
+ @param num: the node ID to look for
+ @raise ValueError: if the node ID is not in the bucket
+ @rtype: L{node.Node}
+ @return: the node
+ """
if num in self.l: return num
else: raise ValueError
def __repr__(self):
return "<KBucket %d items (%d to %d)>" % (len(self.l), self.min, self.max)
- ## Comparators
- # necessary for bisecting list of buckets with a hash expressed as an integer or a distance
- # compares integer or node object with the bucket's range
+ #{ Comparators to bisect/index a list of buckets (by their range) with either a node or a long
def __lt__(self, a):
if isinstance(a, Node): a = a.num
return self.max <= a
return self.min >= a or self.max < a
class TestKTable(unittest.TestCase):
+ """Unit tests for the routing table."""
+
def setUp(self):
self.a = Node(khash.newID(), '127.0.0.1', 2002)
self.t = KTable(self.a, {'HASH_LENGTH': 160, 'K': 8, 'MAX_FAILURES': 3})
## Copyright 2002-2003 Andrew Loewenstern, All Rights Reserved
# see LICENSE.txt for license information
+"""Represents a node in the DHT.
+
+@type NULL_ID: C{string}
+@var NULL_ID: the node ID to use until one is known
+"""
+
from datetime import datetime, MINYEAR
from types import InstanceType
NULL_ID = 20 * '\0'
class Node:
- """encapsulate contact info"""
+ """Encapsulate a node's contact info.
+
+ @ivar conn: the connection to the remote node (added externally)
+ @ivar table: the routing table (added externally)
+ @type fails: C{int}
+ @ivar fails: number of times this node has failed in a row
+ @type lastSeen: C{datetime.datetime}
+ @ivar lastSeen: the last time a response was received from this node
+ @type id: C{string}
+ @ivar id: the node's ID in the DHT
+ @type num: C{long}
+ @ivar num: the node's ID in number form
+ @type host: C{string}
+ @ivar host: the IP address of the node
+ @type port: C{int}
+ @ivar port: the port of the node
+ @type token: C{string}
+ @ivar token: the last received token from the node
+ @type num_values: C{int}
+ @ivar num_values: the number of values the node has for the key in the
+ currently executing action
+ """
+
def __init__(self, id, host = None, port = None):
+ """Initialize the node.
+
+ @type id: C{string} or C{dictionary}
+ @param id: the node's ID in the DHT, or a dictionary containing the
+ node's id, host and port
+ @type host: C{string}
+ @param host: the IP address of the node
+ (optional, but must be specified if id is not a dictionary)
+ @type port: C{int}
+ @param port: the port of the node
+ (optional, but must be specified if id is not a dictionary)
+ """
self.fails = 0
self.lastSeen = datetime(MINYEAR, 1, 1)
self._contactInfo = None
def updateLastSeen(self):
+ """Updates the last contact time of the node and resets the number of failures."""
self.lastSeen = datetime.now()
self.fails = 0
def updateToken(self, token):
+ """Update the token for the node."""
self.token = token
def updateNumValues(self, num_values):
+ """Update how many values the node has in the current search for a value."""
self.num_values = num_values
def msgFailed(self):
+ """Log a failed attempt to contact this node.
+
+ @rtype: C{int}
+ @return: the number of consecutive failures this node has
+ """
self.fails = self.fails + 1
return self.fails
def contactInfo(self):
+ """Get the compact contact info for the node."""
if self._contactInfo is None:
self._contactInfo = compact(self.id, self.host, self.port)
return self._contactInfo
def __repr__(self):
return `(self.id, self.host, self.port)`
- ## these comparators let us bisect/index a list full of nodes with either a node or an int/long
+ #{ Comparators to bisect/index a list of nodes with either a node or a long
def __lt__(self, a):
if type(a) == InstanceType:
a = a.num
class TestNode(unittest.TestCase):
+ """Unit tests for the node implementation."""
def setUp(self):
self.node = Node(khash.newID(), '127.0.0.1', 2002)
def testUpdateLastSeen(self):
## Copyright 2002-2003 Andrew Loewenstern, All Rights Reserved
# see LICENSE.txt for license information
+"""Some utility functions for use in apt-dht's khashmir DHT."""
+
from twisted.trial import unittest
def bucket_stats(l):
- """given a list of khashmir instances, finds min, max, and average number of nodes in tables"""
+ """Given a list of khashmir instances, finds min, max, and average number of nodes in tables."""
max = avg = 0
min = None
def count(buckets):