]> git.mxchange.org Git - fba.git/commitdiff
WIP(?):
authorRoland Häder <roland@mxchange.org>
Sat, 10 Jun 2023 17:32:52 +0000 (19:32 +0200)
committerRoland Häder <roland@mxchange.org>
Sat, 10 Jun 2023 17:32:52 +0000 (19:32 +0200)
- moved csrf.determine() out of central functions, it was causing too many CSRF
  checks and slowed everything down

fba/federation.py
fba/instances.py
fba/locking.py
fba/network.py
fba/networks/lemmy.py
fba/networks/mastodon.py
fba/networks/misskey.py
fba/networks/peertube.py

index 77fd00a1df7bb476c311a26f02381ac2f4bc09dc..92669c50cea2356f67e0ec8d0a1ba47889946154 100644 (file)
@@ -18,6 +18,7 @@ import validators
 
 from fba import blacklist
 from fba import config
+from fba import csrf
 from fba import fba
 from fba import instances
 from fba import network
@@ -119,10 +120,14 @@ def fetch_peers(domain: str, software: str) -> list:
         # DEBUG: print(f"DEBUG: Invoking peertube.fetch_peers({domain}) ...")
         return peertube.fetch_peers(domain)
 
+    # DEBUG: print(f"DEBUG: Checking CSRF for domain='{domain}'")
+    headers = csrf.determine(domain, dict())
+
     # DEBUG: print(f"DEBUG: Fetching peers from '{domain}',software='{software}' ...")
     data = network.get_json_api(
         domain,
         "/api/v1/instance/peers",
+        headers,
         (config.get("connection_timeout"), config.get("read_timeout"))
     )
     # DEBUG: print(f"DEBUG: data[]='{type(data)}'")
@@ -175,6 +180,9 @@ def fetch_nodeinfo(domain: str, path: str = None) -> list:
         # DEBUG: print("DEBUG: nodeinfo()={len(nodeinfo))} - EXIT!")
         return nodeinfo
 
+    # DEBUG: print(f"DEBUG: Checking CSRF for domain='{domain}'")
+    headers = csrf.determine(domain, dict())
+
     request_paths = [
        "/nodeinfo/2.1.json",
        "/nodeinfo/2.1",
@@ -193,6 +201,7 @@ def fetch_nodeinfo(domain: str, path: str = None) -> list:
         data = network.get_json_api(
             domain,
             request,
+            headers,
             (config.get("nodeinfo_connection_timeout"), config.get("nodeinfo_read_timeout"))
         )
 
@@ -215,10 +224,14 @@ def fetch_wellknown_nodeinfo(domain: str) -> list:
     elif domain == "":
         raise ValueError("Parameter 'domain' is empty")
 
+    # DEBUG: print(f"DEBUG: Checking CSRF for domain='{domain}'")
+    headers = csrf.determine(domain, dict())
+
     # DEBUG: print("DEBUG: Fetching .well-known info for domain:", domain)
     data = network.get_json_api(
         domain,
         "/.well-known/nodeinfo",
+        headers,
         (config.get("nodeinfo_connection_timeout"), config.get("nodeinfo_read_timeout"))
     )
 
index 27ea9d39c3d0c5886f04df41ce22aa9bb2941212..7e3d2012949f18fbef5c0c5da28842693493fd71 100644 (file)
@@ -49,8 +49,6 @@ _pending = {
     "last_status_code"   : {},
     # Last error details
     "last_error_details" : {},
-    # Whether CSRF tokens are present
-    "has_csrf"           : {},
 }
 
 def set_data(key: str, domain: str, value: any):
@@ -258,37 +256,37 @@ def update_last_nodeinfo(domain: str):
     # DEBUG: print("DEBUG: EXIT!")
 
 def update_last_error(domain: str, error: dict):
-    print("DEBUG: domain,error[]:", domain, type(error))
+    # DEBUG: print("DEBUG: domain,error[]:", domain, type(error))
     if not isinstance(domain, str):
         raise ValueError(f"Parameter domain[]={type(domain)} is not 'str'")
     elif domain == "":
         raise ValueError("Parameter 'domain' is empty")
 
-    print("DEBUG: BEFORE error[]:", type(error))
+    # DEBUG: print("DEBUG: BEFORE error[]:", type(error))
     if isinstance(error, BaseException) or isinstance(error, json.decoder.JSONDecodeError):
         error = f"error[{type(error)}]='{str(error)}'"
-    print("DEBUG: AFTER error[]:", type(error))
+    # DEBUG: print("DEBUG: AFTER error[]:", type(error))
 
     if isinstance(error, str):
-        print(f"DEBUG: Setting last_error_details='{error}'")
+        # DEBUG: print(f"DEBUG: Setting last_error_details='{error}'")
         set_data("last_status_code"  , domain, 999)
         set_data("last_error_details", domain, error)
     elif isinstance(error, requests.models.Response):
-        print(f"DEBUG: Setting last_error_details='{error.reason}'")
+        # DEBUG: print(f"DEBUG: Setting last_error_details='{error.reason}'")
         set_data("last_status_code"  , domain, error.status_code)
         set_data("last_error_details", domain, error.reason)
     else:
-        print(f"DEBUG: Setting last_error_details='{error['error_message']}'")
+        # DEBUG: print(f"DEBUG: Setting last_error_details='{error['error_message']}'")
         set_data("last_status_code"  , domain, error["status_code"])
         set_data("last_error_details", domain, error["error_message"])
 
     # Running pending updated
-    print(f"DEBUG: Invoking update_data({domain}) ...")
+    # DEBUG: print(f"DEBUG: Invoking update_data({domain}) ...")
     update_data(domain)
 
     fba.log_error(domain, error)
 
-    print("DEBUG: EXIT!")
+    # DEBUG: print("DEBUG: EXIT!")
 
 def is_registered(domain: str) -> bool:
     # DEBUG: print(f"DEBUG: domain='{domain}' - CALLED!")
index ddcbf7c2c88b289d7339bd130824780d916fc590..27f5eac10232046e8b8e168521423ede06f5ea52 100644 (file)
@@ -25,25 +25,25 @@ LOCK = None
 
 def acquire():
     global LOCK
-    print("DEBUG: CALLED!")
+    # DEBUG: print("DEBUG: CALLED!")
 
     try:
-        print(f"DEBUG: Acquiring lock: '{lockfile}'")
+        # DEBUG: print(f"DEBUG: Acquiring lock: '{lockfile}'")
         LOCK = zc.lockfile.LockFile(lockfile)
-        print("DEBUG: Lock obtained.")
+        # DEBUG: print("DEBUG: Lock obtained.")
 
     except zc.lockfile.LockError:
         print(f"ERROR: Cannot aquire lock: '{lockfile}'")
         sys.exit(100)
 
-    print("DEBUG: EXIT!")
+    # DEBUG: print("DEBUG: EXIT!")
 
 def release():
-    print("DEBUG: CALLED!")
+    # DEBUG: print("DEBUG: CALLED!")
     if LOCK is not None:
-        print("DEBUG: Releasing lock ...")
+        # DEBUG: print("DEBUG: Releasing lock ...")
         LOCK.close()
-        print(f"DEBUG: Deleting lockfile='{lockfile}' ...")
+        # DEBUG: print(f"DEBUG: Deleting lockfile='{lockfile}' ...")
         os.remove(lockfile)
 
-    print("DEBUG: EXIT!")
+    # DEBUG: print("DEBUG: EXIT!")
index e9fd9c956b81d6409a0a168d4756fd96466a2a8e..ec7e4bce4ab26510cf518976d4471f7c55a090a5 100644 (file)
@@ -19,7 +19,6 @@ import reqto
 import requests
 
 from fba import config
-from fba import csrf
 from fba import fba
 from fba import instances
 
@@ -49,9 +48,6 @@ def post_json_api(domain: str, path: str, data: str, headers: dict = {}) -> dict
     elif not isinstance(headers, dict):
         raise ValueError(f"headers[]={type(headers)} is not 'list'")
 
-    # DEBUG: print(f"DEBUG: Determining if CSRF header needs to be sent for domain='{domain}' ...")
-    headers = csrf.determine(domain, {**api_headers, **headers})
-
     json_reply = {
         "status_code": 200,
     }
@@ -61,7 +57,7 @@ def post_json_api(domain: str, path: str, data: str, headers: dict = {}) -> dict
         response = reqto.post(
             f"https://{domain}{path}",
             data=data,
-            headers=headers,
+            headers={**api_headers, **headers},
             timeout=(config.get("connection_timeout"), config.get("read_timeout"))
         )
 
@@ -116,8 +112,8 @@ def fetch_api_url(url: str, timeout: tuple) -> dict:
     # DEBUG: print(f"DEBUG: Returning json_reply({len(json_reply)})=[]:{type(json_reply)}")
     return json_reply
 
-def get_json_api(domain: str, path: str, timeout: tuple) -> dict:
-    # DEBUG: print(f"DEBUG: domain='{domain}',path='{path}',data='{data}',timeout()={len(timeout)} - CALLED!")
+def get_json_api(domain: str, path: str, headers: dict, timeout: tuple) -> dict:
+    # DEBUG: print(f"DEBUG: domain='{domain}',path='{path}',timeout()={len(timeout)} - CALLED!")
     if not isinstance(domain, str):
         raise ValueError(f"Parameter domain[]={type(domain)} is not 'str'")
     elif domain == "":
@@ -126,12 +122,11 @@ def get_json_api(domain: str, path: str, timeout: tuple) -> dict:
         raise ValueError(f"path[]={type(path)} is not 'str'")
     elif path == "":
         raise ValueError("Parameter 'path' cannot be empty")
+    elif not isinstance(headers, dict):
+        raise ValueError(f"headers[]={type(headers)} is not 'list'")
     elif not isinstance(timeout, tuple):
         raise ValueError(f"timeout[]={type(timeout)} is not 'tuple'")
 
-    # DEBUG: print(f"DEBUG: Determining if CSRF header needs to be sent for domain='{domain}' ...")
-    headers = csrf.determine(domain, api_headers)
-
     json_reply = {
         "status_code": 200,
     }
@@ -140,7 +135,7 @@ def get_json_api(domain: str, path: str, timeout: tuple) -> dict:
         # DEBUG: print(f"DEBUG: Sending GET to domain='{domain}',path='{path}',timeout({len(timeout)})={timeout}")
         response = reqto.get(
             f"https://{domain}{path}",
-            headers=headers,
+            headers={**api_headers, **headers},
             timeout=timeout
         )
 
index 5963227fbb80e1c16cbf86822899826e6edf7c1f..df63ae0e6d0765bac4f1ac601fb127df2f03a313 100644 (file)
@@ -15,6 +15,7 @@
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 
 from fba import config
+from fba import csrf
 from fba import federation
 from fba import instances
 from fba import network
@@ -26,12 +27,16 @@ def fetch_peers(domain: str) -> list:
     elif domain == "":
         raise ValueError("Parameter 'domain' is empty")
 
+    print(f"DEBUG: Checking CSRF for domain='{domain}'")
+    headers = csrf.determine(domain, dict())
+
     peers = list()
     try:
         # DEBUG: print(f"DEBUG: domain='{domain}' is Lemmy, fetching JSON ...")
         data = network.get_json_api(
             domain,
             "/api/v3/site",
+            headers,
             (config.get("connection_timeout"), config.get("read_timeout"))
         )
 
index 7a1f3eb6b07b9c16e3867ffa6d87e2cd95a0a7c8..6b1b967f7418a97380faa2953d23b69324943b8c 100644 (file)
@@ -23,6 +23,7 @@ import validators
 from fba import blacklist
 from fba import blocks
 from fba import config
+from fba import csrf
 from fba import fba
 from fba import instances
 from fba import network
@@ -129,6 +130,9 @@ def fetch_blocks(domain: str, origin: str, nodeinfo_url: str):
     elif nodeinfo_url == "":
         raise ValueError("Parameter 'nodeinfo_url' is empty")
 
+    print(f"DEBUG: Checking CSRF for domain='{domain}'")
+    headers = csrf.determine(domain, dict())
+
     try:
         # json endpoint for newer mastodongs
         blockdict = list()
@@ -143,6 +147,7 @@ def fetch_blocks(domain: str, origin: str, nodeinfo_url: str):
         data = network.get_json_api(
             domain,
             "/api/v1/instance/domain_blocks",
+            headers,
             (config.get("connection_timeout"), config.get("read_timeout"))
         )
 
index 62aee29757ea38e8cc5bf9d5be4d5697300401a0..52c32e3e53aa15985c11f13365b7828e9bde7685 100644 (file)
@@ -20,6 +20,7 @@ import requests
 
 from fba import blacklist
 from fba import config
+from fba import csrf
 from fba import instances
 from fba import network
 
@@ -34,9 +35,10 @@ def fetch_peers(domain: str) -> list:
         raise ValueError("Parameter 'domain' is empty")
 
     # DEBUG: print(f"DEBUG: domain='{domain}' is misskey, sending API POST request ...")
-    peers = list()
-    offset = 0
-    step = config.get("misskey_limit")
+    peers   = list()
+    offset  = 0
+    step    = config.get("misskey_limit")
+    headers = csrf.determine(domain, {"Origin": domain})
 
     # iterating through all "suspended" (follow-only in its terminology)
     # instances page-by-page, since that troonware doesn't support
@@ -48,18 +50,14 @@ def fetch_peers(domain: str) -> list:
                 "sort" : "+pubAt",
                 "host" : None,
                 "limit": step
-            }), {
-                "Origin": domain
-            })
+            }), headers)
         else:
             fetched = network.post_json_api(domain, "/api/federation/instances", json.dumps({
                 "sort"  : "+pubAt",
                 "host"  : None,
                 "limit" : step,
                 "offset": offset - 1
-            }), {
-                "Origin": domain
-            })
+            }), headers)
 
         # DEBUG: print(f"DEBUG: fetched()={len(fetched)}")
         if len(fetched) == 0:
@@ -123,14 +121,16 @@ def fetch_blocks(domain: str) -> dict:
     elif domain == "":
         raise ValueError("Parameter 'domain' is empty")
 
-    # DEBUG: print("DEBUG: Fetching misskey blocks from domain:", domain)
+    # DEBUG: print(f"DEBUG: Fetching misskey blocks from domain={domain}")
     blocklist = {
         "suspended": [],
         "blocked"  : []
     }
 
-    offset = 0
-    step = config.get("misskey_limit")
+    offset  = 0
+    step    = config.get("misskey_limit")
+    headers = csrf.determine(domain, {"Origin": domain})
+
     while True:
         # iterating through all "suspended" (follow-only in its terminology)
         # instances page-by-page, since that troonware doesn't support
@@ -144,9 +144,7 @@ def fetch_blocks(domain: str) -> dict:
                     "host"     : None,
                     "suspended": True,
                     "limit"    : step
-                }), {
-                    "Origin": domain
-                })
+                }), headers)
             else:
                 # DEBUG: print("DEBUG: Sending JSON API request to domain,step,offset:", domain, step, offset)
                 fetched = network.post_json_api(domain, "/api/federation/instances", json.dumps({
@@ -155,9 +153,7 @@ def fetch_blocks(domain: str) -> dict:
                     "suspended": True,
                     "limit"    : step,
                     "offset"   : offset - 1
-                }), {
-                    "Origin": domain
-                })
+                }), headers)
 
             # DEBUG: print("DEBUG: fetched():", len(fetched))
             if len(fetched) == 0:
@@ -199,25 +195,21 @@ def fetch_blocks(domain: str) -> dict:
         try:
             if offset == 0:
                 # DEBUG: print("DEBUG: Sending JSON API request to domain,step,offset:", domain, step, offset)
-                fetched = network.post_json_api(domain, "/api/federation/instances", json.dumps({
+                fetched = network.post_json_api(domain, "/api/federation/instances", headers, json.dumps({
                     "sort"   : "+pubAt",
                     "host"   : None,
                     "blocked": True,
                     "limit"  : step
-                }), {
-                    "Origin": domain
-                })
+                }))
             else:
                 # DEBUG: print("DEBUG: Sending JSON API request to domain,step,offset:", domain, step, offset)
-                fetched = network.post_json_api(domain, "/api/federation/instances", json.dumps({
+                fetched = network.post_json_api(domain, "/api/federation/instances", headers, json.dumps({
                     "sort"   : "+pubAt",
                     "host"   : None,
                     "blocked": True,
                     "limit"  : step,
                     "offset" : offset - 1
-                }), {
-                    "Origin": domain
-                })
+                }))
 
             # DEBUG: print("DEBUG: fetched():", len(fetched))
             if len(fetched) == 0:
index 83640da5eebc24d8714627008e432d9e7cad7098..96fc880b0234ad303ae8e5a2636e439747c7b2fa 100644 (file)
@@ -15,6 +15,7 @@
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 
 from fba import config
+from fba import csrf
 from fba import instances
 from fba import network
 
@@ -28,12 +29,17 @@ def fetch_peers(domain: str) -> list:
     print(f"DEBUG: domain='{domain}' is a PeerTube, fetching JSON ...")
     peers = list()
     start = 0
+
+    print(f"DEBUG: Checking CSRF for domain='{domain}'")
+    headers = csrf.determine(domain, dict())
+
     for mode in ["followers", "following"]:
         print(f"DEBUG: domain='{domain}',mode='{mode}'")
         while True:
             data = network.get_json_api(
                 domain,
                 "/api/v1/server/{mode}?start={start}&count=100",
+                headers,
                 (config.get("connection_timeout"), config.get("read_timeout"))
             )