df003993cf80676c50df08f3ea57e36a424917f3
[quix0rs-apt-p2p.git] / apt_p2p / PeerManager.py
1
2 """Manage a set of peers and the requests to them."""
3
4 from random import choice
5 from urlparse import urlparse, urlunparse
6 from urllib import quote_plus
7 from binascii import b2a_hex, a2b_hex
8 import sha
9
10 from twisted.internet import reactor, defer
11 from twisted.python import log
12 from twisted.trial import unittest
13 from twisted.web2 import stream
14 from twisted.web2.http import splitHostPort
15
16 from HTTPDownloader import Peer
17 from util import uncompact
18 from hash import PIECE_SIZE
19 from apt_p2p_Khashmir.bencode import bdecode
20
21 class GrowingFileStream(stream.FileStream):
22     """Modified to stream data from a file as it becomes available.
23     
24     @ivar CHUNK_SIZE: the maximum size of chunks of data to send at a time
25     @ivar deferred: waiting for the result of the last read attempt
26     @ivar available: the number of bytes that are currently available to read
27     @ivar position: the current position in the file where the next read will begin
28     @ivar finished: True when no more data will be coming available
29     """
30
31     CHUNK_SIZE = 4*1024
32
33     def __init__(self, f):
34         stream.FileStream.__init__(self, f)
35         self.length = None
36         self.deferred = None
37         self.available = 0L
38         self.position = 0L
39         self.finished = False
40
41     def updateAvaliable(self, newlyAvailable):
42         """Update the number of bytes that are available.
43         
44         Call it with 0 to trigger reading of a fully read file.
45         
46         @param newlyAvailable: the number of bytes that just became available
47         """
48         assert not self.finished
49         self.available += newlyAvailable
50         
51         # If a read is pending, let it go
52         if self.deferred and self.position < self.available:
53             # Try to read some data from the file
54             length = self.available - self.position
55             readSize = min(length, self.CHUNK_SIZE)
56             self.f.seek(self.position)
57             b = self.f.read(readSize)
58             bytesRead = len(b)
59             
60             # Check if end of file was reached
61             if bytesRead:
62                 self.position += bytesRead
63                 deferred = self.deferred
64                 self.deferred = None
65                 deferred.callback(b)
66
67     def allAvailable(self):
68         """Indicate that no more data is coming available."""
69         self.finished = True
70
71         # If a read is pending, let it go
72         if self.deferred:
73             if self.position < self.available:
74                 # Try to read some data from the file
75                 length = self.available - self.position
76                 readSize = min(length, self.CHUNK_SIZE)
77                 self.f.seek(self.position)
78                 b = self.f.read(readSize)
79                 bytesRead = len(b)
80     
81                 # Check if end of file was reached
82                 if bytesRead:
83                     self.position += bytesRead
84                     deferred = self.deferred
85                     self.deferred = None
86                     deferred.callback(b)
87                 else:
88                     # We're done
89                     deferred.callback(None)
90             else:
91                 # We're done
92                 deferred.callback(None)
93         
94     def read(self, sendfile=False):
95         assert not self.deferred, "A previous read is still deferred."
96
97         if self.f is None:
98             return None
99
100         length = self.available - self.position
101         readSize = min(length, self.CHUNK_SIZE)
102
103         # If we don't have any available, we're done or deferred
104         if readSize <= 0:
105             if self.finished:
106                 return None
107             else:
108                 self.deferred = defer.Deferred()
109                 return self.deferred
110
111         # Try to read some data from the file
112         self.f.seek(self.position)
113         b = self.f.read(readSize)
114         bytesRead = len(b)
115         if not bytesRead:
116             # End of file was reached, we're done or deferred
117             if self.finished:
118                 return None
119             else:
120                 self.deferred = defer.Deferred()
121                 return self.deferred
122         else:
123             self.position += bytesRead
124             return b
125
126 class StreamToFile(defer.Deferred):
127     """Saves a stream to a file.
128     
129     @type stream: L{twisted.web2.stream.IByteStream}
130     @ivar stream: the input stream being read
131     @type outFile: L{twisted.python.filepath.FilePath}
132     @ivar outFile: the file being written
133     @type hash: L{Hash.HashObject}
134     @ivar hash: the hash object for the file
135     @type length: C{int}
136     @ivar length: the length of the original (compressed) file
137     @type doneDefer: L{twisted.internet.defer.Deferred}
138     @ivar doneDefer: the deferred that will fire when done streaming
139     """
140     
141     def __init__(self, inputStream, outFile, hash, start, length):
142         """Initializes the file.
143         
144         @type inputStream: L{twisted.web2.stream.IByteStream}
145         @param inputStream: the input stream to read from
146         @type outFile: L{twisted.python.filepath.FilePath}
147         @param outFile: the file to write to
148         @type hash: L{Hash.HashObject}
149         @param hash: the hash object to use for the file
150         """
151         self.stream = inputStream
152         self.outFile = outFile.open('w')
153         self.hash = hash
154         self.hash.new()
155         self.length = self.stream.length
156         
157     def run(self):
158         """Start the streaming."""
159         self.doneDefer = stream.readStream(self.stream, _gotData)
160         self.doneDefer.addCallbacks(self._done, self._error)
161         return self.doneDefer
162
163     def _done(self):
164         """Close all the output files, return the result."""
165         if not self.outFile.closed:
166             self.outFile.close()
167             self.hash.digest()
168             self.doneDefer.callback(self.hash)
169     
170     def _gotData(self, data):
171         self.peers[site]['pieces'] += data
172
173     def read(self):
174         """Read some data from the stream."""
175         if self.outFile.closed:
176             return None
177         
178         # Read data from the stream, deal with the possible deferred
179         data = self.stream.read()
180         if isinstance(data, defer.Deferred):
181             data.addCallbacks(self._write, self._done)
182             return data
183         
184         self._write(data)
185         return data
186     
187     def _write(self, data):
188         """Write the stream data to the file and return it for others to use.
189         
190         Also optionally decompresses it.
191         """
192         if data is None:
193             self._done()
194             return data
195         
196         # Write and hash the streamed data
197         self.outFile.write(data)
198         self.hash.update(data)
199         
200         return data
201     
202     def close(self):
203         """Clean everything up and return None to future reads."""
204         self.length = 0
205         self._done()
206         self.stream.close()
207
208
209 class FileDownload:
210     """Manage a download from a list of peers or a mirror.
211     
212     
213     """
214     
215     def __init__(self, manager, hash, mirror, compact_peers, file):
216         """Initialize the instance and check for piece hashes.
217         
218         @type hash: L{Hash.HashObject}
219         @param hash: the hash object containing the expected hash for the file
220         @param mirror: the URI of the file on the mirror
221         @type compact_peers: C{list} of C{string}
222         @param compact_peers: a list of the peer info where the file can be found
223         @type file: L{twisted.python.filepath.FilePath}
224         @param file: the temporary file to use to store the downloaded file
225         """
226         self.manager = manager
227         self.hash = hash
228         self.mirror = mirror
229         self.compact_peers = compact_peers
230         
231         self.path = '/~/' + quote_plus(hash.expected())
232         self.pieces = None
233         self.started = False
234         
235         file.restat(False)
236         if file.exists():
237             file.remove()
238         self.file = file.open('w')
239
240     def run(self):
241         """Start the downloading process."""
242         self.defer = defer.Deferred()
243         self.peers = {}
244         no_pieces = 0
245         pieces_string = {}
246         pieces_hash = {}
247         pieces_dl_hash = {}
248
249         for compact_peer in self.compact_peers:
250             # Build a list of all the peers for this download
251             site = uncompact(compact_peer['c'])
252             peer = manager.getPeer(site)
253             self.peers.setdefault(site, {})['peer'] = peer
254
255             # Extract any piece information from the peers list
256             if 't' in compact_peer:
257                 self.peers[site]['t'] = compact_peer['t']['t']
258                 pieces_string.setdefault(compact_peer['t']['t'], 0)
259                 pieces_string[compact_peer['t']['t']] += 1
260             elif 'h' in compact_peer:
261                 self.peers[site]['h'] = compact_peer['h']
262                 pieces_hash.setdefault(compact_peer['h'], 0)
263                 pieces_hash[compact_peer['h']] += 1
264             elif 'l' in compact_peer:
265                 self.peers[site]['l'] = compact_peer['l']
266                 pieces_dl_hash.setdefault(compact_peer['l'], 0)
267                 pieces_dl_hash[compact_peer['l']] += 1
268             else:
269                 no_pieces += 1
270         
271         # Select the most popular piece info
272         max_found = max(no_pieces, max(pieces_string.values()),
273                         max(pieces_hash.values()), max(pieces_dl_hash.values()))
274
275         if max_found < len(self.peers):
276             log.msg('Misleading piece information found, using most popular %d of %d peers' % 
277                     (max_found, len(self.peers)))
278
279         if max_found == no_pieces:
280             # The file is not split into pieces
281             self.pieces = []
282             self.startDownload()
283         elif max_found == max(pieces_string.values()):
284             # Small number of pieces in a string
285             for pieces, num in pieces_string.items():
286                 # Find the most popular piece string
287                 if num == max_found:
288                     self.pieces = [pieces[x:x+20] for x in xrange(0, len(pieces), 20)]
289                     self.startDownload()
290                     break
291         elif max_found == max(pieces_hash.values()):
292             # Medium number of pieces stored in the DHT
293             for pieces, num in pieces_hash.items():
294                 # Find the most popular piece hash to lookup
295                 if num == max_found:
296                     self.getDHTPieces(pieces)
297                     break
298         elif max_found == max(pieces_dl_hash.values()):
299             # Large number of pieces stored in peers
300             for pieces, num in pieces_hash.items():
301                 # Find the most popular piece hash to download
302                 if num == max_found:
303                     self.getPeerPieces(pieces)
304                     break
305         return self.defer
306
307     #{ Downloading the piece hashes
308     def getDHTPieces(self, key):
309         """Retrieve the piece information from the DHT.
310         
311         @param key: the key to lookup in the DHT
312         """
313         # Remove any peers with the wrong piece hash
314         #for site in self.peers.keys():
315         #    if self.peers[site].get('h', '') != key:
316         #        del self.peers[site]
317
318         # Start the DHT lookup
319         lookupDefer = self.manager.dht.getValue(key)
320         lookupDefer.addCallback(self._getDHTPieces, key)
321         
322     def _getDHTPieces(self, results, key):
323         """Check the retrieved values."""
324         for result in results:
325             # Make sure the hash matches the key
326             result_hash = sha.new(result.get('t', '')).digest()
327             if result_hash == key:
328                 pieces = result['t']
329                 self.pieces = [pieces[x:x+20] for x in xrange(0, len(pieces), 20)]
330                 log.msg('Retrieved %d piece hashes from the DHT' % len(self.pieces))
331                 self.startDownload()
332                 return
333             
334         # Continue without the piece hashes
335         log.msg('Could not retrieve the piece hashes from the DHT')
336         self.pieces = []
337         self.startDownload()
338
339     def getPeerPieces(self, key, failedSite = None):
340         """Retrieve the piece information from the peers.
341         
342         @param key: the key to request from the peers
343         """
344         if failedSite is None:
345             self.outstanding = 0
346             # Remove any peers with the wrong piece hash
347             #for site in self.peers.keys():
348             #    if self.peers[site].get('l', '') != key:
349             #        del self.peers[site]
350         else:
351             self.peers[failedSite]['failed'] = True
352             self.outstanding -= 1
353
354         if self.pieces is None:
355             # Send a request to one or more peers
356             for site in self.peers:
357                 if self.peers[site].get('failed', False) != True:
358                     path = '/~/' + quote_plus(key)
359                     lookupDefer = self.peers[site]['peer'].get(path)
360                     lookupDefer.addCallbacks(self._getPeerPieces, self._gotPeerError,
361                                              callbackArgs=(key, site), errbackArgs=(key, site))
362                     self.outstanding += 1
363                     if self.outstanding >= 3:
364                         break
365         
366         if self.pieces is None and self.outstanding == 0:
367             # Continue without the piece hashes
368             log.msg('Could not retrieve the piece hashes from the peers')
369             self.pieces = []
370             self.startDownload()
371         
372     def _getPeerPieces(self, response, key, site):
373         """Process the retrieved response from the peer."""
374         if response.code != 200:
375             # Request failed, try a different peer
376             self.getPeerPieces(key, site)
377         else:
378             # Read the response stream to a string
379             self.peers[site]['pieces'] = ''
380             def _gotPeerPiece(data, self = self, site = site):
381                 self.peers[site]['pieces'] += data
382             df = stream.readStream(response.stream, _gotPeerPiece)
383             df.addCallbacks(self._gotPeerPieces, self._gotPeerError,
384                             callbackArgs=(key, site), errbackArgs=(key, site))
385
386     def _gotPeerError(self, err, key, site):
387         """Peer failed, try again."""
388         log.err(err)
389         self.getPeerPieces(key, site)
390
391     def _gotPeerPieces(self, result, key, site):
392         """Check the retrieved pieces from the peer."""
393         if self.pieces is not None:
394             # Already done
395             return
396         
397         try:
398             result = bdecode(self.peers[site]['pieces'])
399         except:
400             log.err()
401             self.getPeerPieces(key, site)
402             return
403             
404         result_hash = sha.new(result.get('t', '')).digest()
405         if result_hash == key:
406             pieces = result['t']
407             self.pieces = [pieces[x:x+20] for x in xrange(0, len(pieces), 20)]
408             log.msg('Retrieved %d piece hashes from the peer' % len(self.pieces))
409             self.startDownload()
410         else:
411             log.msg('Peer returned a piece string that did not match')
412             self.getPeerPieces(key, site)
413
414     #{ Downloading the file
415     def sort(self):
416         """Sort the peers by their rank (highest ranked at the end)."""
417         def sort(a, b):
418             """Sort peers by their rank."""
419             if a.rank > b.rank:
420                 return 1
421             elif a.rank < b.rank:
422                 return -1
423             return 0
424         self.peerlist.sort(sort)
425
426     def startDownload(self):
427         """Start the download from the peers."""
428         # Don't start twice
429         if self.started:
430             return
431         
432         self.started = True
433         assert self.pieces is not None, "You must initialize the piece hashes first"
434         self.peerlist = [self.peers[site]['peer'] for site in self.peers]
435         
436         # Special case if there's only one good peer left
437         if len(self.peerlist) == 1:
438             log.msg('Downloading from peer %r' % (self.peerlist[0], ))
439             self.defer.callback(self.peerlist[0].get(self.path))
440             return
441         
442         self.sort()
443         self.outstanding = 0
444         self.next_piece = 0
445         
446         while self.outstanding < 3 and self.peerlist and self.next_piece < len(self.pieces):
447             peer = self.peerlist.pop()
448             piece = self.next_piece
449             self.next_piece += 1
450             
451             self.outstanding += 1
452             df = peer.getRange(self.path, piece*PIECE_SIZE, (piece+1)*PIECE_SIZE - 1)
453             df.addCallbacks(self._gotPiece, self._gotError,
454                             callbackArgs=(piece, peer), errbackArgs=(piece, peer))
455     
456     def _gotPiece(self, response, piece, peer):
457         """Process the retrieved piece from the peer."""
458         if response.code != 206:
459             # Request failed, try a different peer
460             self.getPeerPieces(key, site)
461         else:
462             # Read the response stream to the file
463             df = StreamToFile(response.stream, self.file, self.hash, piece*PIECE_SIZE, PIECE_SIZE).run()
464             df.addCallbacks(self._gotPeerPieces, self._gotPeerError,
465                             callbackArgs=(key, site), errbackArgs=(key, site))
466
467     def _gotError(self, err, piece, peer):
468         """Peer failed, try again."""
469         log.err(err)
470
471         
472 class PeerManager:
473     """Manage a set of peers and the requests to them.
474     
475     @type clients: C{dictionary}
476     @ivar clients: the available peers that have been previously contacted
477     """
478
479     def __init__(self, cache_dir, dht):
480         """Initialize the instance."""
481         self.cache_dir = cache_dir
482         self.cache_dir.restat(False)
483         if not self.cache_dir.exists():
484             self.cache_dir.makedirs()
485         self.dht = dht
486         self.clients = {}
487         
488     def get(self, hash, mirror, peers = [], method="GET", modtime=None):
489         """Download from a list of peers or fallback to a mirror.
490         
491         @type hash: L{Hash.HashObject}
492         @param hash: the hash object containing the expected hash for the file
493         @param mirror: the URI of the file on the mirror
494         @type peers: C{list} of C{string}
495         @param peers: a list of the peer info where the file can be found
496             (optional, defaults to downloading from the mirror)
497         @type method: C{string}
498         @param method: the HTTP method to use, 'GET' or 'HEAD'
499             (optional, defaults to 'GET')
500         @type modtime: C{int}
501         @param modtime: the modification time to use for an 'If-Modified-Since'
502             header, as seconds since the epoch
503             (optional, defaults to not sending that header)
504         """
505         if not peers or method != "GET" or modtime is not None:
506             log.msg('Downloading (%s) from mirror %s' % (method, mirror))
507             parsed = urlparse(mirror)
508             assert parsed[0] == "http", "Only HTTP is supported, not '%s'" % parsed[0]
509             site = splitHostPort(parsed[0], parsed[1])
510             path = urlunparse(('', '') + parsed[2:])
511             peer = self.getPeer(site)
512             return peer.get(path, method, modtime)
513         elif len(peers) == 1:
514             site = uncompact(peers[0]['c'])
515             log.msg('Downloading from peer %r' % (site, ))
516             path = '/~/' + quote_plus(hash.expected())
517             peer = self.getPeer(site)
518             return peer.get(path)
519         else:
520             tmpfile = self.cache_dir.child(hash.hexexpected())
521             return FileDownload(self, hash, mirror, peers, tmpfile).run()
522         
523     def getPeer(self, site):
524         """Create a new peer if necessary and return it.
525         
526         @type site: (C{string}, C{int})
527         @param site: the IP address and port of the peer
528         """
529         if site not in self.clients:
530             self.clients[site] = Peer(site[0], site[1])
531         return self.clients[site]
532     
533     def close(self):
534         """Close all the connections to peers."""
535         for site in self.clients:
536             self.clients[site].close()
537         self.clients = {}
538
539 class TestPeerManager(unittest.TestCase):
540     """Unit tests for the PeerManager."""
541     
542     manager = None
543     pending_calls = []
544     
545     def gotResp(self, resp, num, expect):
546         self.failUnless(resp.code >= 200 and resp.code < 300, "Got a non-200 response: %r" % resp.code)
547         if expect is not None:
548             self.failUnless(resp.stream.length == expect, "Length was incorrect, got %r, expected %r" % (resp.stream.length, expect))
549         def print_(n):
550             pass
551         def printdone(n):
552             pass
553         stream.readStream(resp.stream, print_).addCallback(printdone)
554     
555     def test_download(self):
556         """Tests a normal download."""
557         self.manager = PeerManager()
558         self.timeout = 10
559         
560         host = 'www.ietf.org'
561         d = self.manager.get('', 'http://' + host + '/rfc/rfc0013.txt')
562         d.addCallback(self.gotResp, 1, 1070)
563         return d
564         
565     def test_head(self):
566         """Tests a 'HEAD' request."""
567         self.manager = PeerManager()
568         self.timeout = 10
569         
570         host = 'www.ietf.org'
571         d = self.manager.get('', 'http://' + host + '/rfc/rfc0013.txt', method = "HEAD")
572         d.addCallback(self.gotResp, 1, 0)
573         return d
574         
575     def test_multiple_downloads(self):
576         """Tests multiple downloads with queueing and connection closing."""
577         self.manager = PeerManager()
578         self.timeout = 120
579         lastDefer = defer.Deferred()
580         
581         def newRequest(host, path, num, expect, last=False):
582             d = self.manager.get('', 'http://' + host + ':' + str(80) + path)
583             d.addCallback(self.gotResp, num, expect)
584             if last:
585                 d.addBoth(lastDefer.callback)
586                 
587         newRequest('www.ietf.org', "/rfc/rfc0006.txt", 1, 1776)
588         newRequest('www.ietf.org', "/rfc/rfc2362.txt", 2, 159833)
589         newRequest('www.google.ca', "/", 3, None)
590         self.pending_calls.append(reactor.callLater(1, newRequest, 'www.sfu.ca', '/', 4, None))
591         self.pending_calls.append(reactor.callLater(10, newRequest, 'www.ietf.org', '/rfc/rfc0048.txt', 5, 41696))
592         self.pending_calls.append(reactor.callLater(30, newRequest, 'www.ietf.org', '/rfc/rfc0022.txt', 6, 4606))
593         self.pending_calls.append(reactor.callLater(31, newRequest, 'www.sfu.ca', '/studentcentral/index.html', 7, None))
594         self.pending_calls.append(reactor.callLater(32, newRequest, 'www.ietf.org', '/rfc/rfc0014.txt', 8, 27))
595         self.pending_calls.append(reactor.callLater(32, newRequest, 'www.ietf.org', '/rfc/rfc0001.txt', 9, 21088))
596         self.pending_calls.append(reactor.callLater(62, newRequest, 'www.google.ca', '/intl/en/options/', 0, None, True))
597         return lastDefer
598         
599     def tearDown(self):
600         for p in self.pending_calls:
601             if p.active():
602                 p.cancel()
603         self.pending_calls = []
604         if self.manager:
605             self.manager.close()
606             self.manager = None