Reorder the main application to find cached hashes before checking freshness.
[quix0rs-apt-p2p.git] / apt_p2p / apt_p2p.py
1
2 """The main program code.
3
4 @var DHT_PIECES: the maximum number of pieces to store with our contact info
5     in the DHT
6 @var TORRENT_PIECES: the maximum number of pieces to store as a separate entry
7     in the DHT
8 @var download_dir: the name of the directory to use for downloaded files
9 @var peer_dir: the name of the directory to use for peer downloads
10 """
11
12 from binascii import b2a_hex
13 from urlparse import urlunparse
14 from urllib import unquote
15 import os, re, sha
16
17 from twisted.internet import defer, reactor, protocol
18 from twisted.web2 import server, http, http_headers, static
19 from twisted.python import log, failure
20 from twisted.python.filepath import FilePath
21
22 from interfaces import IDHT, IDHTStats
23 from apt_p2p_conf import config
24 from PeerManager import PeerManager
25 from HTTPServer import TopLevel
26 from MirrorManager import MirrorManager
27 from CacheManager import CacheManager
28 from Hash import HashObject
29 from db import DB
30 from stats import StatsLogger
31 from util import findMyIPAddr, compact
32
33 DHT_PIECES = 4
34 TORRENT_PIECES = 70
35
36 download_dir = 'cache'
37 peer_dir = 'peers'
38
39 class AptP2P(protocol.Factory):
40     """The main code object that does all of the work.
41     
42     Contains all of the sub-components that do all the low-level work, and
43     coordinates communication between them.
44     
45     @type dhtClass: L{interfaces.IDHT}
46     @ivar dhtClass: the DHT class to use
47     @type cache_dir: L{twisted.python.filepath.FilePath}
48     @ivar cache_dir: the directory to use for storing all files
49     @type db: L{db.DB}
50     @ivar db: the database to use for tracking files and hashes
51     @type dht: L{interfaces.IDHT}
52     @ivar dht: the DHT instance
53     @type stats: L{stats.StatsLogger}
54     @ivar stats: the statistics logger to record sent data to
55     @type http_server: L{HTTPServer.TopLevel}
56     @ivar http_server: the web server that will handle all requests from apt
57         and from other peers
58     @type peers: L{PeerManager.PeerManager}
59     @ivar peers: the manager of all downloads from mirrors and other peers
60     @type mirrors: L{MirrorManager.MirrorManager}
61     @ivar mirrors: the manager of downloaded information about mirrors which
62         can be queried to get hashes from file names
63     @type cache: L{CacheManager.CacheManager}
64     @ivar cache: the manager of all downloaded files
65     @type my_contact: C{string}
66     @ivar my_contact: the 6-byte compact peer representation of this peer's
67         download information (IP address and port)
68     """
69     
70     def __init__(self, dhtClass):
71         """Initialize all the sub-components.
72         
73         @type dhtClass: L{interfaces.IDHT}
74         @param dhtClass: the DHT class to use
75         """
76         log.msg('Initializing the main apt_p2p application')
77         self.dhtClass = dhtClass
78
79     #{ Factory interface
80     def startFactory(self):
81         reactor.callLater(0, self._startFactory)
82         
83     def _startFactory(self):
84         log.msg('Starting the main apt_p2p application')
85         self.cache_dir = FilePath(config.get('DEFAULT', 'CACHE_DIR'))
86         if not self.cache_dir.child(download_dir).exists():
87             self.cache_dir.child(download_dir).makedirs()
88         if not self.cache_dir.child(peer_dir).exists():
89             self.cache_dir.child(peer_dir).makedirs()
90         self.db = DB(self.cache_dir.child('apt-p2p.db'))
91         self.dht = self.dhtClass()
92         self.dht.loadConfig(config, config.get('DEFAULT', 'DHT'))
93         self.dht.join().addCallbacks(self.joinComplete, self.joinError)
94         self.stats = StatsLogger(self.db)
95         self.http_server = TopLevel(self.cache_dir.child(download_dir), self.db, self)
96         self.http_server.getHTTPFactory().startFactory()
97         self.peers = PeerManager(self.cache_dir.child(peer_dir), self.dht, self.stats)
98         self.mirrors = MirrorManager(self.cache_dir)
99         self.cache = CacheManager(self.cache_dir.child(download_dir), self.db, self)
100         self.my_contact = None
101         
102     def stopFactory(self):
103         log.msg('Stoppping the main apt_p2p application')
104         self.http_server.getHTTPFactory().stopFactory()
105         self.stats.save()
106         self.db.close()
107     
108     def buildProtocol(self, addr):
109         return self.http_server.getHTTPFactory().buildProtocol(addr)
110         
111     #{ DHT Maintenance
112     def joinComplete(self, result):
113         """Complete the DHT join process and determine our download information.
114         
115         Called by the DHT when the join has been completed with information
116         on the external IP address and port of this peer.
117         """
118         my_addr = findMyIPAddr(result,
119                                config.getint(config.get('DEFAULT', 'DHT'), 'PORT'),
120                                config.getboolean('DEFAULT', 'LOCAL_OK'))
121         if not my_addr:
122             raise RuntimeError, "IP address for this machine could not be found"
123         self.my_contact = compact(my_addr, config.getint('DEFAULT', 'PORT'))
124         self.cache.scanDirectories()
125         reactor.callLater(60, self.refreshFiles)
126
127     def joinError(self, failure):
128         """Joining the DHT has failed."""
129         log.msg("joining DHT failed miserably")
130         log.err(failure)
131         raise RuntimeError, "IP address for this machine could not be found"
132     
133     def refreshFiles(self):
134         """Refresh any files in the DHT that are about to expire."""
135         expireAfter = config.gettime('DEFAULT', 'KEY_REFRESH')
136         hashes = self.db.expiredHashes(expireAfter)
137         if len(hashes.keys()) > 0:
138             log.msg('Refreshing the keys of %d DHT values' % len(hashes.keys()))
139         self._refreshFiles(None, hashes)
140         
141     def _refreshFiles(self, result, hashes):
142         if result is not None:
143             log.msg('Storage resulted in: %r' % result)
144
145         if hashes:
146             raw_hash = hashes.keys()[0]
147             self.db.refreshHash(raw_hash)
148             hash = HashObject(raw_hash, pieces = hashes[raw_hash]['pieces'])
149             del hashes[raw_hash]
150             storeDefer = self.store(hash)
151             storeDefer.addBoth(self._refreshFiles, hashes)
152         else:
153             reactor.callLater(60, self.refreshFiles)
154     
155     def getStats(self):
156         """Retrieve and format the statistics for the program.
157         
158         @rtype: C{string}
159         @return: the formatted HTML page containing the statistics
160         """
161         out = '<html><body>\n\n'
162         out += self.stats.formatHTML(self.my_contact)
163         out += '\n\n'
164         if IDHTStats.implementedBy(self.dhtClass):
165             out += self.dht.getStats()
166         out += '\n</body></html>\n'
167         return out
168
169     #{ Main workflow
170     def get_resp(self, req, url, orig_resp = None):
171         """Lookup a hash for the file in the local mirror info.
172         
173         Starts the process of getting a response to an apt request.
174         
175         @type req: L{twisted.web2.http.Request}
176         @param req: the initial request sent to the HTTP server by apt
177         @param url: the URI of the actual mirror request
178         @type orig_resp: L{twisted.web2.http.Response}
179         @param orig_resp: the response from the cache to be sent to apt
180             (optional, ignored if missing)
181         @rtype: L{twisted.internet.defer.Deferred}
182         @return: a deferred that will be called back with the response
183         """
184         d = defer.Deferred()
185         
186         log.msg('Trying to find hash for %s' % url)
187         findDefer = self.mirrors.findHash(unquote(url))
188         
189         findDefer.addCallbacks(self.findHash_done, self.findHash_error, 
190                                callbackArgs=(req, url, orig_resp, d),
191                                errbackArgs=(req, url, orig_resp, d))
192         findDefer.addErrback(log.err)
193         return d
194     
195     def findHash_error(self, failure, req, url, orig_resp, d):
196         """Process the error in hash lookup by returning an empty L{HashObject}."""
197         log.err(failure)
198         self.findHash_done(HashObject(), req, url, orig_resp, d)
199         
200     def findHash_done(self, hash, req, url, orig_resp, d):
201         """Use the returned hash to lookup the file in the cache.
202         
203         If the hash was not found, the workflow skips down to download from
204         the mirror (L{startDownload}), or checks the freshness of an old
205         response if there is one.
206         
207         @type hash: L{Hash.HashObject}
208         @param hash: the hash object containing the expected hash for the file
209         """
210         if hash.expected() is None:
211             log.msg('Hash for %s was not found' % url)
212             # Send the old response or get a new one
213             if orig_resp:
214                 self.check_freshness(req, url, orig_resp, d)
215             else:
216                 self.startDownload([], req, hash, url, d)
217         else:
218             log.msg('Found hash %s for %s' % (hash.hexexpected(), url))
219             
220             # Lookup hash in cache
221             locations = self.db.lookupHash(hash.expected(), filesOnly = True)
222             self.getCachedFile(hash, req, url, d, locations)
223
224     def check_freshness(self, req, url, orig_resp, d):
225         """Send a HEAD to the mirror to check if the response from the cache is still valid.
226         
227         @type req: L{twisted.web2.http.Request}
228         @param req: the initial request sent to the HTTP server by apt
229         @param url: the URI of the actual mirror request
230         @type orig_resp: L{twisted.web2.http.Response}
231         @param orig_resp: the response from the cache to be sent to apt
232         @rtype: L{twisted.internet.defer.Deferred}
233         @return: a deferred that will be called back with the correct response
234         """
235         log.msg('Checking if %s is still fresh' % url)
236         modtime = orig_resp.headers.getHeader('Last-Modified')
237         headDefer = self.peers.get(HashObject(), url, method = "HEAD",
238                                    modtime = modtime)
239         headDefer.addCallbacks(self.check_freshness_done,
240                                self.check_freshness_error,
241                                callbackArgs = (req, url, orig_resp, d),
242                                errbackArgs = (req, url, d))
243     
244     def check_freshness_done(self, resp, req, url, orig_resp, d):
245         """Return the fresh response, if stale start to redownload.
246         
247         @type resp: L{twisted.web2.http.Response}
248         @param resp: the response from the mirror to the HEAD request
249         @type req: L{twisted.web2.http.Request}
250         @param req: the initial request sent to the HTTP server by apt
251         @param url: the URI of the actual mirror request
252         @type orig_resp: L{twisted.web2.http.Response}
253         @param orig_resp: the response from the cache to be sent to apt
254         """
255         if resp.code == 304:
256             log.msg('Still fresh, returning: %s' % url)
257             d.callback(orig_resp)
258         else:
259             log.msg('Stale, need to redownload: %s' % url)
260             self.startDownload([], req, HashObject(), url, d)
261     
262     def check_freshness_error(self, err, req, url, d):
263         """Mirror request failed, continue with download.
264         
265         @param err: the response from the mirror to the HEAD request
266         @type req: L{twisted.web2.http.Request}
267         @param req: the initial request sent to the HTTP server by apt
268         @param url: the URI of the actual mirror request
269         """
270         log.err(err)
271         self.startDownload([], req, HashObject(), url, d)
272     
273     def getCachedFile(self, hash, req, url, d, locations):
274         """Try to return the file from the cache, otherwise move on to a DHT lookup.
275         
276         @type locations: C{list} of C{dictionary}
277         @param locations: the files in the cache that match the hash,
278             the dictionary contains a key 'path' whose value is a
279             L{twisted.python.filepath.FilePath} object for the file.
280         """
281         if not locations:
282             log.msg('Failed to return file from cache: %s' % url)
283             self.lookupHash(req, hash, url, d)
284             return
285         
286         # Get the first possible location from the list
287         file = locations.pop(0)['path']
288         log.msg('Returning cached file: %s' % file.path)
289         
290         # Get it's response
291         resp = static.File(file.path).renderHTTP(req)
292         if isinstance(resp, defer.Deferred):
293             resp.addBoth(self._getCachedFile, hash, req, url, d, locations)
294         else:
295             self._getCachedFile(resp, hash, req, url, d, locations)
296         
297     def _getCachedFile(self, resp, hash, req, url, d, locations):
298         """Check the returned response to be sure it is valid."""
299         if isinstance(resp, failure.Failure):
300             log.msg('Got error trying to get cached file')
301             log.err()
302             # Try the next possible location
303             self.getCachedFile(hash, req, url, d, locations)
304             return
305             
306         log.msg('Cached response: %r' % resp)
307         
308         if resp.code >= 200 and resp.code < 400:
309             d.callback(resp)
310         else:
311             # Try the next possible location
312             self.getCachedFile(hash, req, url, d, locations)
313
314     def lookupHash(self, req, hash, url, d):
315         """Lookup the hash in the DHT."""
316         log.msg('Looking up hash in DHT for file: %s' % url)
317         key = hash.expected()
318         lookupDefer = self.dht.getValue(key)
319         lookupDefer.addBoth(self.startDownload, req, hash, url, d)
320
321     def startDownload(self, values, req, hash, url, d):
322         """Start the download of the file.
323         
324         The download will be from peers if the DHT lookup succeeded, or
325         from the mirror otherwise.
326         
327         @type values: C{list} of C{dictionary}
328         @param values: the returned values from the DHT containing peer
329             download information
330         """
331         # Remove some headers Apt sets in the request
332         req.headers.removeHeader('If-Modified-Since')
333         req.headers.removeHeader('Range')
334         req.headers.removeHeader('If-Range')
335         
336         if not isinstance(values, list) or not values:
337             if not isinstance(values, list):
338                 log.msg('DHT lookup for %s failed with error %r' % (url, values))
339             else:
340                 log.msg('Peers for %s were not found' % url)
341             getDefer = self.peers.get(hash, url)
342             getDefer.addCallback(self.cache.save_file, hash, url)
343             getDefer.addErrback(self.cache.save_error, url)
344             getDefer.addCallbacks(d.callback, d.errback)
345         else:
346             log.msg('Found peers for %s: %r' % (url, values))
347             # Download from the found peers
348             getDefer = self.peers.get(hash, url, values)
349             getDefer.addCallback(self.check_response, hash, url)
350             getDefer.addCallback(self.cache.save_file, hash, url)
351             getDefer.addErrback(self.cache.save_error, url)
352             getDefer.addCallbacks(d.callback, d.errback)
353             
354     def check_response(self, response, hash, url):
355         """Check the response from peers, and download from the mirror if it is not."""
356         if response.code < 200 or response.code >= 300:
357             log.msg('Download from peers failed, going to direct download: %s' % url)
358             getDefer = self.peers.get(hash, url)
359             return getDefer
360         return response
361         
362     def new_cached_file(self, file_path, hash, new_hash, url = None, forceDHT = False):
363         """Add a newly cached file to the mirror info and/or the DHT.
364         
365         If the file was downloaded, set url to the path it was downloaded for.
366         Doesn't add a file to the DHT unless a hash was found for it
367         (but does add it anyway if forceDHT is True).
368         
369         @type file_path: L{twisted.python.filepath.FilePath}
370         @param file_path: the location of the file in the local cache
371         @type hash: L{Hash.HashObject}
372         @param hash: the original (expected) hash object containing also the
373             hash of the downloaded file
374         @type new_hash: C{boolean}
375         @param new_hash: whether the has was new to this peer, and so should
376             be added to the DHT
377         @type url: C{string}
378         @param url: the URI of the location of the file in the mirror
379             (optional, defaults to not adding the file to the mirror info)
380         @type forceDHT: C{boolean}
381         @param forceDHT: whether to force addition of the file to the DHT
382             even if the hash was not found in a mirror
383             (optional, defaults to False)
384         """
385         if url:
386             self.mirrors.updatedFile(url, file_path)
387         
388         if self.my_contact and hash and new_hash and (hash.expected() is not None or forceDHT):
389             return self.store(hash)
390         return None
391             
392     def store(self, hash):
393         """Add a key/value pair for the file to the DHT.
394         
395         Sets the key and value from the hash information, and tries to add
396         it to the DHT.
397         """
398         key = hash.digest()
399         value = {'c': self.my_contact}
400         pieces = hash.pieceDigests()
401         
402         # Determine how to store any piece data
403         if len(pieces) <= 1:
404             pass
405         elif len(pieces) <= DHT_PIECES:
406             # Short enough to be stored with our peer contact info
407             value['t'] = {'t': ''.join(pieces)}
408         elif len(pieces) <= TORRENT_PIECES:
409             # Short enough to be stored in a separate key in the DHT
410             value['h'] = sha.new(''.join(pieces)).digest()
411         else:
412             # Too long, must be served up by our peer HTTP server
413             value['l'] = sha.new(''.join(pieces)).digest()
414
415         storeDefer = self.dht.storeValue(key, value)
416         storeDefer.addCallbacks(self.store_done, self.store_error,
417                                 callbackArgs = (hash, ), errbackArgs = (hash.digest(), ))
418         return storeDefer
419
420     def store_done(self, result, hash):
421         """Add a key/value pair for the pieces of the file to the DHT (if necessary)."""
422         log.msg('Added %s to the DHT: %r' % (hash.hexdigest(), result))
423         pieces = hash.pieceDigests()
424         if len(pieces) > DHT_PIECES and len(pieces) <= TORRENT_PIECES:
425             # Add the piece data key and value to the DHT
426             key = sha.new(''.join(pieces)).digest()
427             value = {'t': ''.join(pieces)}
428
429             storeDefer = self.dht.storeValue(key, value)
430             storeDefer.addCallbacks(self.store_torrent_done, self.store_error,
431                                     callbackArgs = (key, ), errbackArgs = (key, ))
432             return storeDefer
433         return result
434
435     def store_torrent_done(self, result, key):
436         """Adding the file to the DHT is complete, and so is the workflow."""
437         log.msg('Added torrent string %s to the DHT: %r' % (b2a_hex(key), result))
438         return result
439
440     def store_error(self, err, key):
441         """Adding to the DHT failed."""
442         log.msg('An error occurred adding %s to the DHT: %r' % (b2a_hex(key), err))
443         return err
444