+"""The main interface to the Khashmir DHT.
+
+@var khashmir_dir: the name of the directory to use for DHT files
+"""
+
+from datetime import datetime
import os, sha, random
from twisted.internet import defer, reactor
from twisted.internet.abstract import isIPAddress
+from twisted.python import log
from twisted.trial import unittest
from zope.interface import implements
from apt_dht.interfaces import IDHT
from khashmir import Khashmir
+from bencode import bencode, bdecode
+
+khashmir_dir = 'apt-dht-Khashmir'
class DHTError(Exception):
"""Represents errors that occur in the DHT."""
class DHT:
+ """The main interface instance to the Khashmir DHT.
+
+ @type config: C{dictionary}
+ @ivar config: the DHT configuration values
+ @type cache_dir: C{string}
+ @ivar cache_dir: the directory to use for storing files
+ @type bootstrap: C{list} of C{string}
+ @ivar bootstrap: the nodes to contact to bootstrap into the system
+ @type bootstrap_node: C{boolean}
+ @ivar bootstrap_node: whether this node is a bootstrap node
+ @type joining: L{twisted.internet.defer.Deferred}
+ @ivar joining: if a join is underway, the deferred that will signal it's end
+ @type joined: C{boolean}
+ @ivar joined: whether the DHT network has been successfully joined
+ @type outstandingJoins: C{int}
+ @ivar outstandingJoins: the number of bootstrap nodes that have yet to respond
+ @type foundAddrs: C{list} of (C{string}, C{int})
+ @ivar foundAddrs: the IP address an port that were returned by bootstrap nodes
+ @type storing: C{dictionary}
+ @ivar storing: keys are keys for which store requests are active, values
+ are dictionaries with keys the values being stored and values the
+ deferred to call when complete
+ @type retrieving: C{dictionary}
+ @ivar retrieving: keys are the keys for which getValue requests are active,
+ values are lists of the deferreds waiting for the requests
+ @type retrieved: C{dictionary}
+ @ivar retrieved: keys are the keys for which getValue requests are active,
+ values are list of the values returned so far
+ @type config_parser: L{apt_dht.apt_dht_conf.AptDHTConfigParser}
+ @ivar config_parser: the configuration info for the main program
+ @type section: C{string}
+ @ivar section: the section of the configuration info that applies to the DHT
+ @type khashmir: L{khashmir.Khashmir}
+ @ivar khashmir: the khashmir DHT instance to use
+ """
implements(IDHT)
def __init__(self):
+ """Initialize the DHT."""
self.config = None
self.cache_dir = ''
self.bootstrap = []
self.bootstrap_node = False
self.joining = None
self.joined = False
+ self.outstandingJoins = 0
+ self.foundAddrs = []
self.storing = {}
self.retrieving = {}
self.retrieved = {}
self.config_parser = config
self.section = section
self.config = {}
- self.cache_dir = self.config_parser.get('DEFAULT', 'cache_dir')
+
+ # Get some initial values
+ self.cache_dir = os.path.join(self.config_parser.get(section, 'cache_dir'), khashmir_dir)
+ if not os.path.exists(self.cache_dir):
+ os.makedirs(self.cache_dir)
self.bootstrap = self.config_parser.getstringlist(section, 'BOOTSTRAP')
self.bootstrap_node = self.config_parser.getboolean(section, 'BOOTSTRAP_NODE')
for k in self.config_parser.options(section):
+ # The numbers in the config file
if k in ['K', 'HASH_LENGTH', 'CONCURRENT_REQS', 'STORE_REDUNDANCY',
- 'MAX_FAILURES', 'PORT']:
+ 'RETRIEVE_VALUES', 'MAX_FAILURES', 'PORT']:
self.config[k] = self.config_parser.getint(section, k)
+ # The times in the config file
elif k in ['CHECKPOINT_INTERVAL', 'MIN_PING_INTERVAL',
- 'BUCKET_STALENESS', 'KEINITIAL_DELAY', 'KE_DELAY', 'KE_AGE']:
+ 'BUCKET_STALENESS', 'KEY_EXPIRE']:
self.config[k] = self.config_parser.gettime(section, k)
+ # The booleans in the config file
+ elif k in ['SPEW']:
+ self.config[k] = self.config_parser.getboolean(section, k)
+ # Everything else is a string
else:
self.config[k] = self.config_parser.get(section, k)
- if 'PORT' not in self.config:
- self.config['PORT'] = self.config_parser.getint('DEFAULT', 'PORT')
def join(self):
"""See L{apt_dht.interfaces.IDHT}."""
if self.joining:
raise DHTError, "a join is already in progress"
+ # Create the new khashmir instance
self.khashmir = Khashmir(self.config, self.cache_dir)
self.joining = defer.Deferred()
for node in self.bootstrap:
host, port = node.rsplit(':', 1)
port = int(port)
+
+ # Translate host names into IP addresses
if isIPAddress(host):
self._join_gotIP(host, port)
else:
return self.joining
def _join_gotIP(self, ip, port):
- """Called after an IP address has been found for a single bootstrap node."""
- self.khashmir.addContact(ip, port, self._join_single)
+ """Join the DHT using a single bootstrap nodes IP address."""
+ self.outstandingJoins += 1
+ self.khashmir.addContact(ip, port, self._join_single, self._join_error)
- def _join_single(self):
- """Called when a single bootstrap node has been added."""
- self.khashmir.findCloseNodes(self._join_complete)
+ def _join_single(self, addr):
+ """Process the response from the bootstrap node.
+
+ Finish the join by contacting close nodes.
+ """
+ self.outstandingJoins -= 1
+ if addr:
+ self.foundAddrs.append(addr)
+ if addr or self.outstandingJoins <= 0:
+ self.khashmir.findCloseNodes(self._join_complete, self._join_complete)
+ log.msg('Got back from bootstrap node: %r' % (addr,))
+ def _join_error(self, failure = None):
+ """Process an error in contacting the bootstrap node.
+
+ If no bootstrap nodes remain, finish the process by contacting
+ close nodes.
+ """
+ self.outstandingJoins -= 1
+ log.msg("bootstrap node could not be reached")
+ if self.outstandingJoins <= 0:
+ self.khashmir.findCloseNodes(self._join_complete, self._join_complete)
+
def _join_complete(self, result):
- """Called when the tables have been initialized with nodes."""
- if not self.joined:
+ """End the joining process and return the addresses found for this node."""
+ if not self.joined and len(result) > 0:
self.joined = True
- if len(result) > 0 or self.bootstrap_node:
- df = self.joining
- self.joining = None
- df.callback(result)
+ if self.joining and self.outstandingJoins <= 0:
+ df = self.joining
+ self.joining = None
+ if self.joined or self.bootstrap_node:
+ self.joined = True
+ df.callback(self.foundAddrs)
else:
- df = self.joining
- self.joining = None
df.errback(DHTError('could not find any nodes to bootstrap to'))
+ def getAddrs(self):
+ """Get the list of addresses returned by bootstrap nodes for this node."""
+ return self.foundAddrs
+
def leave(self):
"""See L{apt_dht.interfaces.IDHT}."""
if self.config is None:
self.joined = False
self.khashmir.shutdown()
+ def _normKey(self, key, bits=None, bytes=None):
+ """Normalize the length of keys used in the DHT."""
+ bits = self.config["HASH_LENGTH"]
+ if bits is not None:
+ bytes = (bits - 1) // 8 + 1
+ else:
+ if bytes is None:
+ raise DHTError, "you must specify one of bits or bytes for normalization"
+
+ # Extend short keys with null bytes
+ if len(key) < bytes:
+ key = key + '\000'*(bytes - len(key))
+ # Truncate long keys
+ elif len(key) > bytes:
+ key = key[:bytes]
+ return key
+
def getValue(self, key):
"""See L{apt_dht.interfaces.IDHT}."""
if self.config is None:
raise DHTError, "configuration not loaded"
if not self.joined:
raise DHTError, "have not joined a network yet"
+
+ key = self._normKey(key)
d = defer.Deferred()
if key not in self.retrieving:
return d
def _getValue(self, key, result):
+ """Process a returned list of values from the DHT."""
+ # Save the list of values to return when it is complete
if result:
- self.retrieved.setdefault(key, []).extend(result)
+ self.retrieved.setdefault(key, []).extend([bdecode(r) for r in result])
else:
+ # Empty list, the get is complete, return the result
final_result = []
if key in self.retrieved:
final_result = self.retrieved[key]
if not self.joined:
raise DHTError, "have not joined a network yet"
- if key in self.storing and value in self.storing[key]:
+ key = self._normKey(key)
+ bvalue = bencode(value)
+
+ if key in self.storing and bvalue in self.storing[key]:
raise DHTError, "already storing that key with the same value"
d = defer.Deferred()
- self.khashmir.storeValueForKey(key, value, self._storeValue)
- self.storing.setdefault(key, {})[value] = d
+ self.khashmir.storeValueForKey(key, bvalue, self._storeValue)
+ self.storing.setdefault(key, {})[bvalue] = d
return d
- def _storeValue(self, key, value, result):
- if key in self.storing and value in self.storing[key]:
+ def _storeValue(self, key, bvalue, result):
+ """Process the response from the DHT."""
+ if key in self.storing and bvalue in self.storing[key]:
+ # Check if the store succeeded
if len(result) > 0:
- self.storing[key][value].callback(result)
+ self.storing[key][bvalue].callback(result)
else:
- self.storing[key][value].errback(DHTError('could not store value %s in key %s' % (value, key)))
- del self.storing[key][value]
+ self.storing[key][bvalue].errback(DHTError('could not store value %s in key %s' % (bvalue, key)))
+ del self.storing[key][bvalue]
if len(self.storing[key].keys()) == 0:
del self.storing[key]
class TestSimpleDHT(unittest.TestCase):
- """Unit tests for the DHT."""
+ """Simple 2-node unit tests for the DHT."""
timeout = 2
DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
- 'CHECKPOINT_INTERVAL': 900, 'CONCURRENT_REQS': 4,
- 'STORE_REDUNDANCY': 3, 'MAX_FAILURES': 3,
+ 'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
+ 'STORE_REDUNDANCY': 3, 'RETRIEVE_VALUES': -10000,
+ 'MAX_FAILURES': 3,
'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
- 'KEINITIAL_DELAY': 15, 'KE_DELAY': 1200,
- 'KE_AGE': 3600, }
+ 'KEY_EXPIRE': 3600, 'SPEW': False, }
def setUp(self):
self.a = DHT()
d.addCallback(self.lastDefer.callback)
return self.lastDefer
+ def test_normKey(self):
+ h = self.a._normKey('12345678901234567890')
+ self.failUnless(h == '12345678901234567890')
+ h = self.a._normKey('12345678901234567')
+ self.failUnless(h == '12345678901234567\000\000\000')
+ h = self.a._normKey('1234567890123456789012345')
+ self.failUnless(h == '12345678901234567890')
+ h = self.a._normKey('1234567890123456789')
+ self.failUnless(h == '1234567890123456789\000')
+ h = self.a._normKey('123456789012345678901')
+ self.failUnless(h == '12345678901234567890')
+
def value_stored(self, result, value):
self.stored -= 1
if self.stored == 0:
def tearDown(self):
self.a.leave()
try:
- os.unlink(self.a.khashmir.db)
+ os.unlink(self.a.khashmir.store.db)
except:
pass
self.b.leave()
try:
- os.unlink(self.b.khashmir.db)
+ os.unlink(self.b.khashmir.store.db)
except:
pass
class TestMultiDHT(unittest.TestCase):
+ """More complicated 20-node tests for the DHT."""
timeout = 60
num = 20
DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
- 'CHECKPOINT_INTERVAL': 900, 'CONCURRENT_REQS': 4,
- 'STORE_REDUNDANCY': 3, 'MAX_FAILURES': 3,
+ 'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
+ 'STORE_REDUNDANCY': 3, 'RETRIEVE_VALUES': -10000,
+ 'MAX_FAILURES': 3,
'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
- 'KEINITIAL_DELAY': 15, 'KE_DELAY': 1200,
- 'KE_AGE': 3600, }
+ 'KEY_EXPIRE': 3600, 'SPEW': False, }
def setUp(self):
self.l = []
d.addCallback(self.node_join, 1)
return self.lastDefer
- def value_stored(self, result, value):
- self.stored -= 1
- if self.stored == 0:
+ def store_values(self, result, i = 0, j = 0):
+ if j > i:
+ j -= i+1
+ i += 1
+ if i == len(self.l):
self.get_values()
-
- def store_values(self, result):
- self.stored = 0
- for i in range(len(self.l)):
- for j in range(0, i+1):
- self.stored += 1
- d = self.l[j].storeValue(sha.new(str(self.startport+i)).digest(), str((self.startport+i)*(j+1)))
- d.addCallback(self.value_stored, self.startport+i)
+ else:
+ d = self.l[j].storeValue(sha.new(str(self.startport+i)).digest(), str((self.startport+i)*(j+1)))
+ d.addCallback(self.store_values, i, j+1)
- def check_values(self, result, values):
- self.checked -= 1
- self.failUnless(len(result) == len(values))
- for v in result:
- self.failUnless(v in values)
- if self.checked == 0:
+ def get_values(self, result = None, check = None, i = 0, j = 0):
+ if result is not None:
+ self.failUnless(len(result) == len(check))
+ for v in result:
+ self.failUnless(v in check)
+ if j >= len(self.l):
+ j -= len(self.l)
+ i += 1
+ if i == len(self.l):
self.lastDefer.callback(1)
-
- def get_values(self):
- self.checked = 0
- for i in range(len(self.l)):
- for j in random.sample(xrange(len(self.l)), 4):
- self.checked += 1
- d = self.l[i].getValue(sha.new(str(self.startport+j)).digest())
- check = []
- for k in range(self.startport+j, (self.startport+j)*(j+1)+1, self.startport+j):
- check.append(str(k))
- d.addCallback(self.check_values, check)
+ else:
+ d = self.l[i].getValue(sha.new(str(self.startport+j)).digest())
+ check = []
+ for k in range(self.startport+j, (self.startport+j)*(j+1)+1, self.startport+j):
+ check.append(str(k))
+ d.addCallback(self.get_values, check, i, j + random.randrange(1, min(len(self.l), 10)))
def store_join(self, result, next_node):
d = self.l[next_node].join()
for i in self.l:
try:
i.leave()
- os.unlink(i.khashmir.db)
+ os.unlink(i.khashmir.store.db)
except:
pass