Fix an error in the actions that allowed for the result to be sent twice.
[quix0rs-apt-p2p.git] / apt_p2p / apt_p2p.py
1
2 """The main program code.
3
4 @var DHT_PIECES: the maximum number of pieces to store with our contact info
5     in the DHT
6 @var TORRENT_PIECES: the maximum number of pieces to store as a separate entry
7     in the DHT
8 @var download_dir: the name of the directory to use for downloaded files
9 @var peer_dir: the name of the directory to use for peer downloads
10 """
11
12 from binascii import b2a_hex
13 from urlparse import urlunparse
14 from urllib import unquote
15 import os, re, sha
16
17 from twisted.internet import defer, reactor, protocol
18 from twisted.web2 import server, http, http_headers, static
19 from twisted.python import log, failure
20 from twisted.python.filepath import FilePath
21
22 from interfaces import IDHT, IDHTStats
23 from apt_p2p_conf import config
24 from PeerManager import PeerManager
25 from HTTPServer import TopLevel
26 from MirrorManager import MirrorManager
27 from CacheManager import CacheManager
28 from Hash import HashObject
29 from db import DB
30 from stats import StatsLogger
31 from util import findMyIPAddr, compact
32
33 DHT_PIECES = 4
34 TORRENT_PIECES = 70
35
36 download_dir = 'cache'
37 peer_dir = 'peers'
38
39 class AptP2P(protocol.Factory):
40     """The main code object that does all of the work.
41     
42     Contains all of the sub-components that do all the low-level work, and
43     coordinates communication between them.
44     
45     @type dhtClass: L{interfaces.IDHT}
46     @ivar dhtClass: the DHT class to use
47     @type cache_dir: L{twisted.python.filepath.FilePath}
48     @ivar cache_dir: the directory to use for storing all files
49     @type db: L{db.DB}
50     @ivar db: the database to use for tracking files and hashes
51     @type dht: L{interfaces.IDHT}
52     @ivar dht: the DHT instance
53     @type stats: L{stats.StatsLogger}
54     @ivar stats: the statistics logger to record sent data to
55     @type http_server: L{HTTPServer.TopLevel}
56     @ivar http_server: the web server that will handle all requests from apt
57         and from other peers
58     @type peers: L{PeerManager.PeerManager}
59     @ivar peers: the manager of all downloads from mirrors and other peers
60     @type mirrors: L{MirrorManager.MirrorManager}
61     @ivar mirrors: the manager of downloaded information about mirrors which
62         can be queried to get hashes from file names
63     @type cache: L{CacheManager.CacheManager}
64     @ivar cache: the manager of all downloaded files
65     @type my_contact: C{string}
66     @ivar my_contact: the 6-byte compact peer representation of this peer's
67         download information (IP address and port)
68     """
69     
70     def __init__(self, dhtClass):
71         """Initialize all the sub-components.
72         
73         @type dhtClass: L{interfaces.IDHT}
74         @param dhtClass: the DHT class to use
75         """
76         log.msg('Initializing the main apt_p2p application')
77         self.dhtClass = dhtClass
78
79     #{ Factory interface
80     def startFactory(self):
81         reactor.callLater(0, self._startFactory)
82         
83     def _startFactory(self):
84         log.msg('Starting the main apt_p2p application')
85         self.cache_dir = FilePath(config.get('DEFAULT', 'CACHE_DIR'))
86         if not self.cache_dir.child(download_dir).exists():
87             self.cache_dir.child(download_dir).makedirs()
88         if not self.cache_dir.child(peer_dir).exists():
89             self.cache_dir.child(peer_dir).makedirs()
90         self.db = DB(self.cache_dir.child('apt-p2p.db'))
91         self.dht = self.dhtClass()
92         self.dht.loadConfig(config, config.get('DEFAULT', 'DHT'))
93         self.dht.join().addCallbacks(self.joinComplete, self.joinError)
94         self.stats = StatsLogger(self.db)
95         self.http_server = TopLevel(self.cache_dir.child(download_dir), self.db, self)
96         self.http_server.getHTTPFactory().startFactory()
97         self.peers = PeerManager(self.cache_dir.child(peer_dir), self.dht, self.stats)
98         self.mirrors = MirrorManager(self.cache_dir)
99         self.cache = CacheManager(self.cache_dir.child(download_dir), self.db, self)
100         self.my_contact = None
101         
102     def stopFactory(self):
103         log.msg('Stoppping the main apt_p2p application')
104         self.http_server.getHTTPFactory().stopFactory()
105         self.mirrors.cleanup()
106         self.stats.save()
107         self.db.close()
108     
109     def buildProtocol(self, addr):
110         return self.http_server.getHTTPFactory().buildProtocol(addr)
111         
112     #{ DHT Maintenance
113     def joinComplete(self, result):
114         """Complete the DHT join process and determine our download information.
115         
116         Called by the DHT when the join has been completed with information
117         on the external IP address and port of this peer.
118         """
119         my_addr = findMyIPAddr(result,
120                                config.getint(config.get('DEFAULT', 'DHT'), 'PORT'),
121                                config.getboolean('DEFAULT', 'LOCAL_OK'))
122         if not my_addr:
123             raise RuntimeError, "IP address for this machine could not be found"
124         self.my_contact = compact(my_addr, config.getint('DEFAULT', 'PORT'))
125         self.cache.scanDirectories()
126         reactor.callLater(60, self.refreshFiles)
127
128     def joinError(self, failure):
129         """Joining the DHT has failed."""
130         log.msg("joining DHT failed miserably")
131         log.err(failure)
132         raise RuntimeError, "IP address for this machine could not be found"
133     
134     def refreshFiles(self):
135         """Refresh any files in the DHT that are about to expire."""
136         expireAfter = config.gettime('DEFAULT', 'KEY_REFRESH')
137         hashes = self.db.expiredHashes(expireAfter)
138         if len(hashes.keys()) > 0:
139             log.msg('Refreshing the keys of %d DHT values' % len(hashes.keys()))
140         self._refreshFiles(None, hashes)
141         
142     def _refreshFiles(self, result, hashes):
143         if result is not None:
144             log.msg('Storage resulted in: %r' % result)
145
146         if hashes:
147             raw_hash = hashes.keys()[0]
148             self.db.refreshHash(raw_hash)
149             hash = HashObject(raw_hash, pieces = hashes[raw_hash]['pieces'])
150             del hashes[raw_hash]
151             storeDefer = self.store(hash)
152             storeDefer.addBoth(self._refreshFiles, hashes)
153         else:
154             reactor.callLater(60, self.refreshFiles)
155     
156     def getStats(self):
157         """Retrieve and format the statistics for the program.
158         
159         @rtype: C{string}
160         @return: the formatted HTML page containing the statistics
161         """
162         out = '<html><body>\n\n'
163         out += self.stats.formatHTML(self.my_contact)
164         out += '\n\n'
165         if IDHTStats.implementedBy(self.dhtClass):
166             out += self.dht.getStats()
167         out += '\n</body></html>\n'
168         return out
169
170     #{ Main workflow
171     def get_resp(self, req, url, orig_resp = None):
172         """Lookup a hash for the file in the local mirror info.
173         
174         Starts the process of getting a response to an apt request.
175         
176         @type req: L{twisted.web2.http.Request}
177         @param req: the initial request sent to the HTTP server by apt
178         @param url: the URI of the actual mirror request
179         @type orig_resp: L{twisted.web2.http.Response}
180         @param orig_resp: the response from the cache to be sent to apt
181             (optional, ignored if missing)
182         @rtype: L{twisted.internet.defer.Deferred}
183         @return: a deferred that will be called back with the response
184         """
185         d = defer.Deferred()
186         
187         log.msg('Trying to find hash for %s' % url)
188         findDefer = self.mirrors.findHash(unquote(url))
189         
190         findDefer.addCallbacks(self.findHash_done, self.findHash_error, 
191                                callbackArgs=(req, url, orig_resp, d),
192                                errbackArgs=(req, url, orig_resp, d))
193         findDefer.addErrback(log.err)
194         return d
195     
196     def findHash_error(self, failure, req, url, orig_resp, d):
197         """Process the error in hash lookup by returning an empty L{HashObject}."""
198         log.err(failure)
199         self.findHash_done(HashObject(), req, url, orig_resp, d)
200         
201     def findHash_done(self, hash, req, url, orig_resp, d):
202         """Use the returned hash to lookup the file in the cache.
203         
204         If the hash was not found, the workflow skips down to download from
205         the mirror (L{startDownload}), or checks the freshness of an old
206         response if there is one.
207         
208         @type hash: L{Hash.HashObject}
209         @param hash: the hash object containing the expected hash for the file
210         """
211         if hash.expected() is None:
212             log.msg('Hash for %s was not found' % url)
213             # Send the old response or get a new one
214             if orig_resp:
215                 self.check_freshness(req, url, orig_resp, d)
216             else:
217                 self.startDownload([], req, hash, url, d)
218         else:
219             log.msg('Found hash %s for %s' % (hash.hexexpected(), url))
220             
221             # Lookup hash in cache
222             locations = self.db.lookupHash(hash.expected(), filesOnly = True)
223             self.getCachedFile(hash, req, url, d, locations)
224
225     def check_freshness(self, req, url, orig_resp, d):
226         """Send a HEAD to the mirror to check if the response from the cache is still valid.
227         
228         @type req: L{twisted.web2.http.Request}
229         @param req: the initial request sent to the HTTP server by apt
230         @param url: the URI of the actual mirror request
231         @type orig_resp: L{twisted.web2.http.Response}
232         @param orig_resp: the response from the cache to be sent to apt
233         @rtype: L{twisted.internet.defer.Deferred}
234         @return: a deferred that will be called back with the correct response
235         """
236         log.msg('Checking if %s is still fresh' % url)
237         modtime = orig_resp.headers.getHeader('Last-Modified')
238         headDefer = self.peers.get(HashObject(), url, method = "HEAD",
239                                    modtime = modtime)
240         headDefer.addCallbacks(self.check_freshness_done,
241                                self.check_freshness_error,
242                                callbackArgs = (req, url, orig_resp, d),
243                                errbackArgs = (req, url, d))
244     
245     def check_freshness_done(self, resp, req, url, orig_resp, d):
246         """Return the fresh response, if stale start to redownload.
247         
248         @type resp: L{twisted.web2.http.Response}
249         @param resp: the response from the mirror to the HEAD request
250         @type req: L{twisted.web2.http.Request}
251         @param req: the initial request sent to the HTTP server by apt
252         @param url: the URI of the actual mirror request
253         @type orig_resp: L{twisted.web2.http.Response}
254         @param orig_resp: the response from the cache to be sent to apt
255         """
256         if resp.code == 304:
257             log.msg('Still fresh, returning: %s' % url)
258             d.callback(orig_resp)
259         else:
260             log.msg('Stale, need to redownload: %s' % url)
261             self.startDownload([], req, HashObject(), url, d)
262     
263     def check_freshness_error(self, err, req, url, d):
264         """Mirror request failed, continue with download.
265         
266         @param err: the response from the mirror to the HEAD request
267         @type req: L{twisted.web2.http.Request}
268         @param req: the initial request sent to the HTTP server by apt
269         @param url: the URI of the actual mirror request
270         """
271         log.err(err)
272         self.startDownload([], req, HashObject(), url, d)
273     
274     def getCachedFile(self, hash, req, url, d, locations):
275         """Try to return the file from the cache, otherwise move on to a DHT lookup.
276         
277         @type locations: C{list} of C{dictionary}
278         @param locations: the files in the cache that match the hash,
279             the dictionary contains a key 'path' whose value is a
280             L{twisted.python.filepath.FilePath} object for the file.
281         """
282         if not locations:
283             log.msg('Failed to return file from cache: %s' % url)
284             self.lookupHash(req, hash, url, d)
285             return
286         
287         # Get the first possible location from the list
288         file = locations.pop(0)['path']
289         log.msg('Returning cached file: %s' % file.path)
290         
291         # Get it's response
292         resp = static.File(file.path).renderHTTP(req)
293         if isinstance(resp, defer.Deferred):
294             resp.addBoth(self._getCachedFile, hash, req, url, d, locations)
295         else:
296             self._getCachedFile(resp, hash, req, url, d, locations)
297         
298     def _getCachedFile(self, resp, hash, req, url, d, locations):
299         """Check the returned response to be sure it is valid."""
300         if isinstance(resp, failure.Failure):
301             log.msg('Got error trying to get cached file')
302             log.err()
303             # Try the next possible location
304             self.getCachedFile(hash, req, url, d, locations)
305             return
306             
307         log.msg('Cached response: %r' % resp)
308         
309         if resp.code >= 200 and resp.code < 400:
310             d.callback(resp)
311         else:
312             # Try the next possible location
313             self.getCachedFile(hash, req, url, d, locations)
314
315     def lookupHash(self, req, hash, url, d):
316         """Lookup the hash in the DHT."""
317         log.msg('Looking up hash in DHT for file: %s' % url)
318         key = hash.expected()
319         lookupDefer = self.dht.getValue(key)
320         lookupDefer.addBoth(self.startDownload, req, hash, url, d)
321
322     def startDownload(self, values, req, hash, url, d):
323         """Start the download of the file.
324         
325         The download will be from peers if the DHT lookup succeeded, or
326         from the mirror otherwise.
327         
328         @type values: C{list} of C{dictionary}
329         @param values: the returned values from the DHT containing peer
330             download information
331         """
332         # Remove some headers Apt sets in the request
333         req.headers.removeHeader('If-Modified-Since')
334         req.headers.removeHeader('Range')
335         req.headers.removeHeader('If-Range')
336         
337         if not isinstance(values, list) or not values:
338             if not isinstance(values, list):
339                 log.msg('DHT lookup for %s failed with error %r' % (url, values))
340             else:
341                 log.msg('Peers for %s were not found' % url)
342             getDefer = self.peers.get(hash, url)
343             getDefer.addCallback(self.cache.save_file, hash, url)
344             getDefer.addErrback(self.cache.save_error, url)
345             getDefer.addCallbacks(d.callback, d.errback)
346         else:
347             log.msg('Found peers for %s: %r' % (url, values))
348             # Download from the found peers
349             getDefer = self.peers.get(hash, url, values)
350             getDefer.addCallback(self.check_response, hash, url)
351             getDefer.addCallback(self.cache.save_file, hash, url)
352             getDefer.addErrback(self.cache.save_error, url)
353             getDefer.addCallbacks(d.callback, d.errback)
354             
355     def check_response(self, response, hash, url):
356         """Check the response from peers, and download from the mirror if it is not."""
357         if response.code < 200 or response.code >= 300:
358             log.msg('Download from peers failed, going to direct download: %s' % url)
359             getDefer = self.peers.get(hash, url)
360             return getDefer
361         return response
362         
363     def new_cached_file(self, file_path, hash, new_hash, url = None, forceDHT = False):
364         """Add a newly cached file to the mirror info and/or the DHT.
365         
366         If the file was downloaded, set url to the path it was downloaded for.
367         Doesn't add a file to the DHT unless a hash was found for it
368         (but does add it anyway if forceDHT is True).
369         
370         @type file_path: L{twisted.python.filepath.FilePath}
371         @param file_path: the location of the file in the local cache
372         @type hash: L{Hash.HashObject}
373         @param hash: the original (expected) hash object containing also the
374             hash of the downloaded file
375         @type new_hash: C{boolean}
376         @param new_hash: whether the has was new to this peer, and so should
377             be added to the DHT
378         @type url: C{string}
379         @param url: the URI of the location of the file in the mirror
380             (optional, defaults to not adding the file to the mirror info)
381         @type forceDHT: C{boolean}
382         @param forceDHT: whether to force addition of the file to the DHT
383             even if the hash was not found in a mirror
384             (optional, defaults to False)
385         """
386         if url:
387             self.mirrors.updatedFile(url, file_path)
388         
389         if self.my_contact and hash and new_hash and (hash.expected() is not None or forceDHT):
390             return self.store(hash)
391         return None
392             
393     def store(self, hash):
394         """Add a key/value pair for the file to the DHT.
395         
396         Sets the key and value from the hash information, and tries to add
397         it to the DHT.
398         """
399         key = hash.digest()
400         value = {'c': self.my_contact}
401         pieces = hash.pieceDigests()
402         
403         # Determine how to store any piece data
404         if len(pieces) <= 1:
405             pass
406         elif len(pieces) <= DHT_PIECES:
407             # Short enough to be stored with our peer contact info
408             value['t'] = {'t': ''.join(pieces)}
409         elif len(pieces) <= TORRENT_PIECES:
410             # Short enough to be stored in a separate key in the DHT
411             value['h'] = sha.new(''.join(pieces)).digest()
412         else:
413             # Too long, must be served up by our peer HTTP server
414             value['l'] = sha.new(''.join(pieces)).digest()
415
416         storeDefer = self.dht.storeValue(key, value)
417         storeDefer.addCallbacks(self.store_done, self.store_error,
418                                 callbackArgs = (hash, ), errbackArgs = (hash.digest(), ))
419         return storeDefer
420
421     def store_done(self, result, hash):
422         """Add a key/value pair for the pieces of the file to the DHT (if necessary)."""
423         log.msg('Added %s to the DHT: %r' % (hash.hexdigest(), result))
424         pieces = hash.pieceDigests()
425         if len(pieces) > DHT_PIECES and len(pieces) <= TORRENT_PIECES:
426             # Add the piece data key and value to the DHT
427             key = sha.new(''.join(pieces)).digest()
428             value = {'t': ''.join(pieces)}
429
430             storeDefer = self.dht.storeValue(key, value)
431             storeDefer.addCallbacks(self.store_torrent_done, self.store_error,
432                                     callbackArgs = (key, ), errbackArgs = (key, ))
433             return storeDefer
434         return result
435
436     def store_torrent_done(self, result, key):
437         """Adding the file to the DHT is complete, and so is the workflow."""
438         log.msg('Added torrent string %s to the DHT: %r' % (b2a_hex(key), result))
439         return result
440
441     def store_error(self, err, key):
442         """Adding to the DHT failed."""
443         log.msg('An error occurred adding %s to the DHT: %r' % (b2a_hex(key), err))
444         return err
445