Minor update to the multiple peer downloading (still not working).
[quix0rs-apt-p2p.git] / apt_p2p / PeerManager.py
1
2 """Manage a set of peers and the requests to them."""
3
4 from random import choice
5 from urlparse import urlparse, urlunparse
6 from urllib import quote_plus
7 from binascii import b2a_hex, a2b_hex
8 import sha
9
10 from twisted.internet import reactor, defer
11 from twisted.python import log
12 from twisted.trial import unittest
13 from twisted.web2 import stream
14 from twisted.web2.http import splitHostPort
15
16 from HTTPDownloader import Peer
17 from util import uncompact
18 from hash import PIECE_SIZE
19 from apt_p2p_Khashmir.bencode import bdecode
20
21 class GrowingFileStream(stream.FileStream):
22     """Modified to stream data from a file as it becomes available.
23     
24     @ivar CHUNK_SIZE: the maximum size of chunks of data to send at a time
25     @ivar deferred: waiting for the result of the last read attempt
26     @ivar available: the number of bytes that are currently available to read
27     @ivar position: the current position in the file where the next read will begin
28     @ivar finished: True when no more data will be coming available
29     """
30
31     CHUNK_SIZE = 4*1024
32
33     def __init__(self, f):
34         stream.FileStream.__init__(self, f)
35         self.length = None
36         self.deferred = None
37         self.available = 0L
38         self.position = 0L
39         self.finished = False
40
41     def updateAvaliable(self, newlyAvailable):
42         """Update the number of bytes that are available.
43         
44         Call it with 0 to trigger reading of a fully read file.
45         
46         @param newlyAvailable: the number of bytes that just became available
47         """
48         assert not self.finished
49         self.available += newlyAvailable
50         
51         # If a read is pending, let it go
52         if self.deferred and self.position < self.available:
53             # Try to read some data from the file
54             length = self.available - self.position
55             readSize = min(length, self.CHUNK_SIZE)
56             self.f.seek(self.position)
57             b = self.f.read(readSize)
58             bytesRead = len(b)
59             
60             # Check if end of file was reached
61             if bytesRead:
62                 self.position += bytesRead
63                 deferred = self.deferred
64                 self.deferred = None
65                 deferred.callback(b)
66
67     def allAvailable(self):
68         """Indicate that no more data is coming available."""
69         self.finished = True
70
71         # If a read is pending, let it go
72         if self.deferred:
73             if self.position < self.available:
74                 # Try to read some data from the file
75                 length = self.available - self.position
76                 readSize = min(length, self.CHUNK_SIZE)
77                 self.f.seek(self.position)
78                 b = self.f.read(readSize)
79                 bytesRead = len(b)
80     
81                 # Check if end of file was reached
82                 if bytesRead:
83                     self.position += bytesRead
84                     deferred = self.deferred
85                     self.deferred = None
86                     deferred.callback(b)
87                 else:
88                     # We're done
89                     deferred.callback(None)
90             else:
91                 # We're done
92                 deferred.callback(None)
93         
94     def read(self, sendfile=False):
95         assert not self.deferred, "A previous read is still deferred."
96
97         if self.f is None:
98             return None
99
100         length = self.available - self.position
101         readSize = min(length, self.CHUNK_SIZE)
102
103         # If we don't have any available, we're done or deferred
104         if readSize <= 0:
105             if self.finished:
106                 return None
107             else:
108                 self.deferred = defer.Deferred()
109                 return self.deferred
110
111         # Try to read some data from the file
112         self.f.seek(self.position)
113         b = self.f.read(readSize)
114         bytesRead = len(b)
115         if not bytesRead:
116             # End of file was reached, we're done or deferred
117             if self.finished:
118                 return None
119             else:
120                 self.deferred = defer.Deferred()
121                 return self.deferred
122         else:
123             self.position += bytesRead
124             return b
125
126 class StreamToFile(defer.Deferred):
127     """Saves a stream to a file.
128     
129     @type stream: L{twisted.web2.stream.IByteStream}
130     @ivar stream: the input stream being read
131     @type outFile: L{twisted.python.filepath.FilePath}
132     @ivar outFile: the file being written
133     @type hash: L{Hash.HashObject}
134     @ivar hash: the hash object for the file
135     @type length: C{int}
136     @ivar length: the length of the original (compressed) file
137     @type doneDefer: L{twisted.internet.defer.Deferred}
138     @ivar doneDefer: the deferred that will fire when done streaming
139     """
140     
141     def __init__(self, inputStream, outFile, hash, start, length):
142         """Initializes the file.
143         
144         @type inputStream: L{twisted.web2.stream.IByteStream}
145         @param inputStream: the input stream to read from
146         @type outFile: L{twisted.python.filepath.FilePath}
147         @param outFile: the file to write to
148         @type hash: L{Hash.HashObject}
149         @param hash: the hash object to use for the file
150         """
151         self.stream = inputStream
152         self.outFile = outFile
153         self.hash = hash
154         self.hash.new()
155         self.length = self.stream.length
156         self.doneDefer = None
157         
158     def run(self):
159         """Start the streaming."""
160         self.doneDefer = stream.readStream(self.stream, _gotData)
161         self.doneDefer.addCallbacks(self._done, self._error)
162         return self.doneDefer
163
164     def _done(self):
165         """Close all the output files, return the result."""
166         if not self.outFile.closed:
167             self.outFile.close()
168             self.hash.digest()
169             self.doneDefer.callback(self.hash)
170     
171     def _gotData(self, data):
172         if self.outFile.closed:
173             return
174         
175         if data is None:
176             self._done()
177         
178         # Write and hash the streamed data
179         self.outFile.write(data)
180         self.hash.update(data)
181         
182 class FileDownload:
183     """Manage a download from a list of peers or a mirror.
184     
185     
186     """
187     
188     def __init__(self, manager, hash, mirror, compact_peers, file):
189         """Initialize the instance and check for piece hashes.
190         
191         @type hash: L{Hash.HashObject}
192         @param hash: the hash object containing the expected hash for the file
193         @param mirror: the URI of the file on the mirror
194         @type compact_peers: C{list} of C{string}
195         @param compact_peers: a list of the peer info where the file can be found
196         @type file: L{twisted.python.filepath.FilePath}
197         @param file: the temporary file to use to store the downloaded file
198         """
199         self.manager = manager
200         self.hash = hash
201         self.mirror = mirror
202         self.compact_peers = compact_peers
203         
204         self.path = '/~/' + quote_plus(hash.expected())
205         self.pieces = None
206         self.started = False
207         
208         file.restat(False)
209         if file.exists():
210             file.remove()
211         self.file = file.open('w')
212
213     def run(self):
214         """Start the downloading process."""
215         self.defer = defer.Deferred()
216         self.peers = {}
217         no_pieces = 0
218         pieces_string = {}
219         pieces_hash = {}
220         pieces_dl_hash = {}
221
222         for compact_peer in self.compact_peers:
223             # Build a list of all the peers for this download
224             site = uncompact(compact_peer['c'])
225             peer = manager.getPeer(site)
226             self.peers.setdefault(site, {})['peer'] = peer
227
228             # Extract any piece information from the peers list
229             if 't' in compact_peer:
230                 self.peers[site]['t'] = compact_peer['t']['t']
231                 pieces_string.setdefault(compact_peer['t']['t'], 0)
232                 pieces_string[compact_peer['t']['t']] += 1
233             elif 'h' in compact_peer:
234                 self.peers[site]['h'] = compact_peer['h']
235                 pieces_hash.setdefault(compact_peer['h'], 0)
236                 pieces_hash[compact_peer['h']] += 1
237             elif 'l' in compact_peer:
238                 self.peers[site]['l'] = compact_peer['l']
239                 pieces_dl_hash.setdefault(compact_peer['l'], 0)
240                 pieces_dl_hash[compact_peer['l']] += 1
241             else:
242                 no_pieces += 1
243         
244         # Select the most popular piece info
245         max_found = max(no_pieces, max(pieces_string.values()),
246                         max(pieces_hash.values()), max(pieces_dl_hash.values()))
247
248         if max_found < len(self.peers):
249             log.msg('Misleading piece information found, using most popular %d of %d peers' % 
250                     (max_found, len(self.peers)))
251
252         if max_found == no_pieces:
253             # The file is not split into pieces
254             self.pieces = []
255             self.startDownload()
256         elif max_found == max(pieces_string.values()):
257             # Small number of pieces in a string
258             for pieces, num in pieces_string.items():
259                 # Find the most popular piece string
260                 if num == max_found:
261                     self.pieces = [pieces[x:x+20] for x in xrange(0, len(pieces), 20)]
262                     self.startDownload()
263                     break
264         elif max_found == max(pieces_hash.values()):
265             # Medium number of pieces stored in the DHT
266             for pieces, num in pieces_hash.items():
267                 # Find the most popular piece hash to lookup
268                 if num == max_found:
269                     self.getDHTPieces(pieces)
270                     break
271         elif max_found == max(pieces_dl_hash.values()):
272             # Large number of pieces stored in peers
273             for pieces, num in pieces_hash.items():
274                 # Find the most popular piece hash to download
275                 if num == max_found:
276                     self.getPeerPieces(pieces)
277                     break
278         return self.defer
279
280     #{ Downloading the piece hashes
281     def getDHTPieces(self, key):
282         """Retrieve the piece information from the DHT.
283         
284         @param key: the key to lookup in the DHT
285         """
286         # Remove any peers with the wrong piece hash
287         #for site in self.peers.keys():
288         #    if self.peers[site].get('h', '') != key:
289         #        del self.peers[site]
290
291         # Start the DHT lookup
292         lookupDefer = self.manager.dht.getValue(key)
293         lookupDefer.addCallback(self._getDHTPieces, key)
294         
295     def _getDHTPieces(self, results, key):
296         """Check the retrieved values."""
297         for result in results:
298             # Make sure the hash matches the key
299             result_hash = sha.new(result.get('t', '')).digest()
300             if result_hash == key:
301                 pieces = result['t']
302                 self.pieces = [pieces[x:x+20] for x in xrange(0, len(pieces), 20)]
303                 log.msg('Retrieved %d piece hashes from the DHT' % len(self.pieces))
304                 self.startDownload()
305                 return
306             
307         # Continue without the piece hashes
308         log.msg('Could not retrieve the piece hashes from the DHT')
309         self.pieces = []
310         self.startDownload()
311
312     def getPeerPieces(self, key, failedSite = None):
313         """Retrieve the piece information from the peers.
314         
315         @param key: the key to request from the peers
316         """
317         if failedSite is None:
318             self.outstanding = 0
319             # Remove any peers with the wrong piece hash
320             #for site in self.peers.keys():
321             #    if self.peers[site].get('l', '') != key:
322             #        del self.peers[site]
323         else:
324             self.peers[failedSite]['failed'] = True
325             self.outstanding -= 1
326
327         if self.pieces is None:
328             # Send a request to one or more peers
329             for site in self.peers:
330                 if self.peers[site].get('failed', False) != True:
331                     path = '/~/' + quote_plus(key)
332                     lookupDefer = self.peers[site]['peer'].get(path)
333                     lookupDefer.addCallbacks(self._getPeerPieces, self._gotPeerError,
334                                              callbackArgs=(key, site), errbackArgs=(key, site))
335                     self.outstanding += 1
336                     if self.outstanding >= 3:
337                         break
338         
339         if self.pieces is None and self.outstanding == 0:
340             # Continue without the piece hashes
341             log.msg('Could not retrieve the piece hashes from the peers')
342             self.pieces = []
343             self.startDownload()
344         
345     def _getPeerPieces(self, response, key, site):
346         """Process the retrieved response from the peer."""
347         if response.code != 200:
348             # Request failed, try a different peer
349             self.getPeerPieces(key, site)
350         else:
351             # Read the response stream to a string
352             self.peers[site]['pieces'] = ''
353             def _gotPeerPiece(data, self = self, site = site):
354                 self.peers[site]['pieces'] += data
355             df = stream.readStream(response.stream, _gotPeerPiece)
356             df.addCallbacks(self._gotPeerPieces, self._gotPeerError,
357                             callbackArgs=(key, site), errbackArgs=(key, site))
358
359     def _gotPeerError(self, err, key, site):
360         """Peer failed, try again."""
361         log.err(err)
362         self.getPeerPieces(key, site)
363
364     def _gotPeerPieces(self, result, key, site):
365         """Check the retrieved pieces from the peer."""
366         if self.pieces is not None:
367             # Already done
368             return
369         
370         try:
371             result = bdecode(self.peers[site]['pieces'])
372         except:
373             log.err()
374             self.getPeerPieces(key, site)
375             return
376             
377         result_hash = sha.new(result.get('t', '')).digest()
378         if result_hash == key:
379             pieces = result['t']
380             self.pieces = [pieces[x:x+20] for x in xrange(0, len(pieces), 20)]
381             log.msg('Retrieved %d piece hashes from the peer' % len(self.pieces))
382             self.startDownload()
383         else:
384             log.msg('Peer returned a piece string that did not match')
385             self.getPeerPieces(key, site)
386
387     #{ Downloading the file
388     def sort(self):
389         """Sort the peers by their rank (highest ranked at the end)."""
390         def sort(a, b):
391             """Sort peers by their rank."""
392             if a.rank > b.rank:
393                 return 1
394             elif a.rank < b.rank:
395                 return -1
396             return 0
397         self.peerlist.sort(sort)
398
399     def startDownload(self):
400         """Start the download from the peers."""
401         # Don't start twice
402         if self.started:
403             return
404         
405         self.started = True
406         assert self.pieces is not None, "You must initialize the piece hashes first"
407         self.peerlist = [self.peers[site]['peer'] for site in self.peers]
408         
409         # Special case if there's only one good peer left
410         if len(self.peerlist) == 1:
411             log.msg('Downloading from peer %r' % (self.peerlist[0], ))
412             self.defer.callback(self.peerlist[0].get(self.path))
413             return
414         
415         self.sort()
416         self.outstanding = 0
417         self.next_piece = 0
418         
419         while self.outstanding < 3 and self.peerlist and self.next_piece < len(self.pieces):
420             peer = self.peerlist.pop()
421             piece = self.next_piece
422             self.next_piece += 1
423             
424             self.outstanding += 1
425             df = peer.getRange(self.path, piece*PIECE_SIZE, (piece+1)*PIECE_SIZE - 1)
426             df.addCallbacks(self._gotPiece, self._gotError,
427                             callbackArgs=(piece, peer), errbackArgs=(piece, peer))
428     
429     def _gotPiece(self, response, piece, peer):
430         """Process the retrieved piece from the peer."""
431         if response.code != 206:
432             # Request failed, try a different peer
433             self.getPeerPieces(key, site)
434         else:
435             # Read the response stream to the file
436             df = StreamToFile(response.stream, self.file, self.hash, piece*PIECE_SIZE, PIECE_SIZE).run()
437             df.addCallbacks(self._gotPeerPieces, self._gotPeerError,
438                             callbackArgs=(key, site), errbackArgs=(key, site))
439
440     def _gotError(self, err, piece, peer):
441         """Peer failed, try again."""
442         log.err(err)
443
444         
445 class PeerManager:
446     """Manage a set of peers and the requests to them.
447     
448     @type clients: C{dictionary}
449     @ivar clients: the available peers that have been previously contacted
450     """
451
452     def __init__(self, cache_dir, dht):
453         """Initialize the instance."""
454         self.cache_dir = cache_dir
455         self.cache_dir.restat(False)
456         if not self.cache_dir.exists():
457             self.cache_dir.makedirs()
458         self.dht = dht
459         self.clients = {}
460         
461     def get(self, hash, mirror, peers = [], method="GET", modtime=None):
462         """Download from a list of peers or fallback to a mirror.
463         
464         @type hash: L{Hash.HashObject}
465         @param hash: the hash object containing the expected hash for the file
466         @param mirror: the URI of the file on the mirror
467         @type peers: C{list} of C{string}
468         @param peers: a list of the peer info where the file can be found
469             (optional, defaults to downloading from the mirror)
470         @type method: C{string}
471         @param method: the HTTP method to use, 'GET' or 'HEAD'
472             (optional, defaults to 'GET')
473         @type modtime: C{int}
474         @param modtime: the modification time to use for an 'If-Modified-Since'
475             header, as seconds since the epoch
476             (optional, defaults to not sending that header)
477         """
478         if not peers or method != "GET" or modtime is not None:
479             log.msg('Downloading (%s) from mirror %s' % (method, mirror))
480             parsed = urlparse(mirror)
481             assert parsed[0] == "http", "Only HTTP is supported, not '%s'" % parsed[0]
482             site = splitHostPort(parsed[0], parsed[1])
483             path = urlunparse(('', '') + parsed[2:])
484             peer = self.getPeer(site)
485             return peer.get(path, method, modtime)
486         elif len(peers) == 1:
487             site = uncompact(peers[0]['c'])
488             log.msg('Downloading from peer %r' % (site, ))
489             path = '/~/' + quote_plus(hash.expected())
490             peer = self.getPeer(site)
491             return peer.get(path)
492         else:
493             tmpfile = self.cache_dir.child(hash.hexexpected())
494             return FileDownload(self, hash, mirror, peers, tmpfile).run()
495         
496     def getPeer(self, site):
497         """Create a new peer if necessary and return it.
498         
499         @type site: (C{string}, C{int})
500         @param site: the IP address and port of the peer
501         """
502         if site not in self.clients:
503             self.clients[site] = Peer(site[0], site[1])
504         return self.clients[site]
505     
506     def close(self):
507         """Close all the connections to peers."""
508         for site in self.clients:
509             self.clients[site].close()
510         self.clients = {}
511
512 class TestPeerManager(unittest.TestCase):
513     """Unit tests for the PeerManager."""
514     
515     manager = None
516     pending_calls = []
517     
518     def gotResp(self, resp, num, expect):
519         self.failUnless(resp.code >= 200 and resp.code < 300, "Got a non-200 response: %r" % resp.code)
520         if expect is not None:
521             self.failUnless(resp.stream.length == expect, "Length was incorrect, got %r, expected %r" % (resp.stream.length, expect))
522         def print_(n):
523             pass
524         def printdone(n):
525             pass
526         stream.readStream(resp.stream, print_).addCallback(printdone)
527     
528     def test_download(self):
529         """Tests a normal download."""
530         self.manager = PeerManager()
531         self.timeout = 10
532         
533         host = 'www.ietf.org'
534         d = self.manager.get('', 'http://' + host + '/rfc/rfc0013.txt')
535         d.addCallback(self.gotResp, 1, 1070)
536         return d
537         
538     def test_head(self):
539         """Tests a 'HEAD' request."""
540         self.manager = PeerManager()
541         self.timeout = 10
542         
543         host = 'www.ietf.org'
544         d = self.manager.get('', 'http://' + host + '/rfc/rfc0013.txt', method = "HEAD")
545         d.addCallback(self.gotResp, 1, 0)
546         return d
547         
548     def test_multiple_downloads(self):
549         """Tests multiple downloads with queueing and connection closing."""
550         self.manager = PeerManager()
551         self.timeout = 120
552         lastDefer = defer.Deferred()
553         
554         def newRequest(host, path, num, expect, last=False):
555             d = self.manager.get('', 'http://' + host + ':' + str(80) + path)
556             d.addCallback(self.gotResp, num, expect)
557             if last:
558                 d.addBoth(lastDefer.callback)
559                 
560         newRequest('www.ietf.org', "/rfc/rfc0006.txt", 1, 1776)
561         newRequest('www.ietf.org', "/rfc/rfc2362.txt", 2, 159833)
562         newRequest('www.google.ca', "/", 3, None)
563         self.pending_calls.append(reactor.callLater(1, newRequest, 'www.sfu.ca', '/', 4, None))
564         self.pending_calls.append(reactor.callLater(10, newRequest, 'www.ietf.org', '/rfc/rfc0048.txt', 5, 41696))
565         self.pending_calls.append(reactor.callLater(30, newRequest, 'www.ietf.org', '/rfc/rfc0022.txt', 6, 4606))
566         self.pending_calls.append(reactor.callLater(31, newRequest, 'www.sfu.ca', '/studentcentral/index.html', 7, None))
567         self.pending_calls.append(reactor.callLater(32, newRequest, 'www.ietf.org', '/rfc/rfc0014.txt', 8, 27))
568         self.pending_calls.append(reactor.callLater(32, newRequest, 'www.ietf.org', '/rfc/rfc0001.txt', 9, 21088))
569         self.pending_calls.append(reactor.callLater(62, newRequest, 'www.google.ca', '/intl/en/options/', 0, None, True))
570         return lastDefer
571         
572     def tearDown(self):
573         for p in self.pending_calls:
574             if p.active():
575                 p.cancel()
576         self.pending_calls = []
577         if self.manager:
578             self.manager.close()
579             self.manager = None