Use the new DB in the main code.
[quix0rs-apt-p2p.git] / apt_dht / apt_dht.py
1
2 from binascii import b2a_hex
3 from urlparse import urlunparse
4 import os, re
5
6 from twisted.internet import defer
7 from twisted.web2 import server, http, http_headers
8 from twisted.python import log
9
10 from apt_dht_conf import config
11 from PeerManager import PeerManager
12 from HTTPServer import TopLevel
13 from MirrorManager import MirrorManager
14 from Hash import HashObject
15 from db import DB
16
17 class AptDHT:
18     def __init__(self, dht):
19         log.msg('Initializing the main apt_dht application')
20         self.db = DB(config.get('DEFAULT', 'cache_dir') + '/.apt-dht.db')
21         self.dht = dht
22         self.dht.loadConfig(config, config.get('DEFAULT', 'DHT'))
23         self.dht.join().addCallbacks(self.joinComplete, self.joinError)
24         self.http_server = TopLevel(config.get('DEFAULT', 'cache_dir'), self)
25         self.http_site = server.Site(self.http_server)
26         self.peers = PeerManager()
27         self.mirrors = MirrorManager(config.get('DEFAULT', 'cache_dir'), self)
28         self.my_addr = None
29         self.isLocal = re.compile('^(192\.168\.[0-9]{1,3}\.[0-9]{1,3})|'+
30                                   '(10\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3})|'+
31                                   '(172\.0?([1][6-9])|([2][0-9])|([3][0-1])\.[0-9]{1,3}\.[0-9]{1,3})|'+
32                                   '(127\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3})$')
33     
34     def getSite(self):
35         return self.http_site
36
37     def joinComplete(self, addrs):
38         log.msg("got addrs: %r" % (addrs,))
39         
40         try:
41             ifconfig = os.popen("/sbin/ifconfig |/bin/grep inet|"+
42                                 "/usr/bin/awk '{print $2}' | "+
43                                 "sed -e s/.*://", "r").read().strip().split('\n')
44         except:
45             ifconfig = []
46
47         # Get counts for all the non-local addresses returned
48         addr_count = {}
49         for addr in ifconfig:
50             if not self.isLocal.match(addr):
51                 addr_count.setdefault(addr, 0)
52                 addr_count[addr] += 1
53         
54         local_addrs = addr_count.keys()    
55         if len(local_addrs) == 1:
56             self.my_addr = local_addrs[0]
57             log.msg('Found remote address from ifconfig: %r' % (self.my_addr,))
58         
59         # Get counts for all the non-local addresses returned
60         addr_count = {}
61         port_count = {}
62         for addr in addrs:
63             if not self.isLocal.match(addr[0]):
64                 addr_count.setdefault(addr[0], 0)
65                 addr_count[addr[0]] += 1
66                 port_count.setdefault(addr[1], 0)
67                 port_count[addr[1]] += 1
68         
69         # Find the most popular address
70         popular_addr = []
71         popular_count = 0
72         for addr in addr_count:
73             if addr_count[addr] > popular_count:
74                 popular_addr = [addr]
75                 popular_count = addr_count[addr]
76             elif addr_count[addr] == popular_count:
77                 popular_addr.append(addr)
78         
79         # Find the most popular port
80         popular_port = []
81         popular_count = 0
82         for port in port_count:
83             if port_count[port] > popular_count:
84                 popular_port = [port]
85                 popular_count = port_count[port]
86             elif port_count[port] == popular_count:
87                 popular_port.append(port)
88                 
89         port = config.getint(config.get('DEFAULT', 'DHT'), 'PORT')
90         if len(port_count.keys()) > 1:
91             log.msg('Problem, multiple ports have been found: %r' % (port_count,))
92             if port not in port_count.keys():
93                 log.msg('And none of the ports found match the intended one')
94         elif len(port_count.keys()) == 1:
95             port = port_count.keys()[0]
96         else:
97             log.msg('Port was not found')
98
99         if len(popular_addr) == 1:
100             log.msg('Found popular address: %r' % (popular_addr[0],))
101             if self.my_addr and self.my_addr != popular_addr[0]:
102                 log.msg('But the popular address does not match: %s != %s' % (popular_addr[0], self.my_addr))
103             self.my_addr = popular_addr[0]
104         elif len(popular_addr) > 1:
105             log.msg('Found multiple popular addresses: %r' % (popular_addr,))
106             if self.my_addr and self.my_addr not in popular_addr:
107                 log.msg('And none of the addresses found match the ifconfig one')
108         else:
109             log.msg('No non-local addresses found: %r' % (popular_addr,))
110             
111         if not self.my_addr:
112             log.err(RuntimeError("Remote IP Address could not be found for this machine"))
113
114     def ipAddrFromChicken(self):
115         import urllib
116         ip_search = re.compile('\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}')
117         try:
118              f = urllib.urlopen("http://www.ipchicken.com")
119              data = f.read()
120              f.close()
121              current_ip = ip_search.findall(data)
122              return current_ip
123         except Exception:
124              return []
125
126     def joinError(self, failure):
127         log.msg("joining DHT failed miserably")
128         log.err(failure)
129     
130     def check_freshness(self, path, modtime, resp):
131         log.msg('Checking if %s is still fresh' % path)
132         d = self.peers.get([path], "HEAD", modtime)
133         d.addCallback(self.check_freshness_done, path, resp)
134         return d
135     
136     def check_freshness_done(self, resp, path, orig_resp):
137         if resp.code == 304:
138             log.msg('Still fresh, returning: %s' % path)
139             return orig_resp
140         else:
141             log.msg('Stale, need to redownload: %s' % path)
142             return self.get_resp(path)
143     
144     def get_resp(self, path):
145         d = defer.Deferred()
146         
147         log.msg('Trying to find hash for %s' % path)
148         findDefer = self.mirrors.findHash(path)
149         
150         findDefer.addCallbacks(self.findHash_done, self.findHash_error, 
151                                callbackArgs=(path, d), errbackArgs=(path, d))
152         findDefer.addErrback(log.err)
153         return d
154     
155     def findHash_error(self, failure, path, d):
156         log.err(failure)
157         self.findHash_done(HashObject(), path, d)
158         
159     def findHash_done(self, hash, path, d):
160         if hash.expected() is None:
161             log.msg('Hash for %s was not found' % path)
162             self.lookupHash_done([], hash, path, d)
163         else:
164             log.msg('Found hash %s for %s' % (hash.hexexpected(), path))
165             # Lookup hash from DHT
166             key = hash.normexpected(bits = config.getint(config.get('DEFAULT', 'DHT'), 'HASH_LENGTH'))
167             lookupDefer = self.dht.getValue(key)
168             lookupDefer.addCallback(self.lookupHash_done, hash, path, d)
169             
170     def lookupHash_done(self, locations, hash, path, d):
171         if not locations:
172             log.msg('Peers for %s were not found' % path)
173             getDefer = self.peers.get([path])
174             getDefer.addCallback(self.mirrors.save_file, hash, path)
175             getDefer.addErrback(self.mirrors.save_error, path)
176             getDefer.addCallbacks(d.callback, d.errback)
177         else:
178             log.msg('Found peers for %s: %r' % (path, locations))
179             # Download from the found peers
180             getDefer = self.peers.get(locations)
181             getDefer.addCallback(self.check_response, hash, path)
182             getDefer.addCallback(self.mirrors.save_file, hash, path)
183             getDefer.addErrback(self.mirrors.save_error, path)
184             getDefer.addCallbacks(d.callback, d.errback)
185             
186     def check_response(self, response, hash, path):
187         if response.code < 200 or response.code >= 300:
188             log.msg('Download from peers failed, going to direct download: %s' % path)
189             getDefer = self.peers.get([path])
190             return getDefer
191         return response
192         
193     def cached_file(self, hash, url, file_path):
194         assert file_path.startswith(config.get('DEFAULT', 'cache_dir'))
195         urlpath, newdir = self.db.storeFile(file_path, hash.digest(), config.get('DEFAULT', 'cache_dir'))
196         log.msg('now avaliable at %s: %s' % (urlpath, url))
197
198         if self.my_addr:
199             site = self.my_addr + ':' + str(config.getint('DEFAULT', 'PORT'))
200             full_path = urlunparse(('http', site, urlpath, None, None, None))
201             key = hash.norm(bits = config.getint(config.get('DEFAULT', 'DHT'), 'HASH_LENGTH'))
202             storeDefer = self.dht.storeValue(key, full_path)
203             storeDefer.addCallback(self.store_done, full_path)
204             storeDefer.addErrback(log.err)
205
206     def store_done(self, result, path):
207         log.msg('Added %s to the DHT: %r' % (path, result))
208