Main maigret function refactoring

Author: Soxoj
Date: 2021-05-05 18:02:13 +03:00
Parent: 2fb1f19948
Commit: 3cbb9df7b3

6 changed files with 115 additions and 85 deletions
+1 -1
@@ -37,7 +37,7 @@ SUPPORTED_IDS = (
     "uidme_uguid",
 )
 
-unsupported_characters = "#"
+BAD_CHARS = "#"
 
 
 async def get_response(request_future, logger) -> Tuple[str, int, Optional[CheckError]]:
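
Note: the rename from unsupported_characters to BAD_CHARS is purely mechanical; the constant still backs the same set-intersection check performed in main() before a username is searched. A minimal sketch of that check (the has_bad_chars helper is illustrative, not part of the codebase):

BAD_CHARS = "#"

def has_bad_chars(username: str) -> bool:
    # A username containing characters that sites generally reject
    # cannot be searched, so main() skips it with a warning.
    return bool(set(BAD_CHARS).intersection(set(username)))

assert has_bad_chars("user#name") is True
assert has_bad_chars("username") is False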
+3 -2
@@ -1,6 +1,7 @@
 from typing import Dict, List, Any
 
 from .result import QueryResult
+from .types import QueryResultWrapper
 
 # error got as a result of completed search query
@@ -104,9 +105,9 @@ def solution_of(err_type) -> str:
     return ERRORS_TYPES.get(err_type, '')
 
 
-def extract_and_group(search_res: dict) -> List[Dict[str, Any]]:
+def extract_and_group(search_res: QueryResultWrapper) -> List[Dict[str, Any]]:
     errors_counts: Dict[str, int] = {}
-    for r in search_res:
+    for r in search_res.values():
         if r and isinstance(r, dict) and r.get('status'):
             if not isinstance(r['status'], QueryResult):
                 continue
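
Note: the new signature moves the .values() call inside extract_and_group, so callers hand over the whole QueryResultWrapper (in effect, a mapping of site name to per-site result dict). A minimal sketch of the new call shape (the empty wrapper and the assert are illustrative):

from maigret import errors

search_results: dict = {}  # site name -> per-site result dict

# Before this commit callers passed search_results.values();
# the function now takes the wrapper and iterates .values() itself.
errs = errors.extract_and_group(search_results)
assert errs == []  # no results, so nothing to group
for e in errs:
    if errors.is_important(e):
        print(e)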
+72 -61
@@ -8,6 +8,7 @@ import os
 import sys
 import platform
 from argparse import ArgumentParser, RawDescriptionHelpFormatter
+from typing import List, Tuple
 
 import requests
 from socid_extractor import extract, parse, __version__ as socid_version
@@ -16,7 +17,7 @@ from .checking import (
     timeout_check,
     SUPPORTED_IDS,
     self_check,
-    unsupported_characters,
+    BAD_CHARS,
     maigret,
 )
 from . import errors
@@ -33,13 +34,14 @@ from .report import (
) )
from .sites import MaigretDatabase from .sites import MaigretDatabase
from .submit import submit_dialog from .submit import submit_dialog
from .types import QueryResultWrapper
from .utils import get_dict_ascii_tree from .utils import get_dict_ascii_tree
__version__ = '0.2.1' __version__ = '0.2.1'
def notify_about_errors(search_results, query_notify): def notify_about_errors(search_results: QueryResultWrapper, query_notify):
errs = errors.extract_and_group(search_results.values()) errs = errors.extract_and_group(search_results)
was_errs_displayed = False was_errs_displayed = False
for e in errs: for e in errs:
if not errors.is_important(e): if not errors.is_important(e):
@@ -58,6 +60,58 @@ def notify_about_errors(search_results, query_notify):
         )
 
 
+def extract_ids_from_page(url, logger, timeout=5) -> dict:
+    results = {}
+    # url, headers
+    reqs: List[Tuple[str, set]] = [(url, set())]
+    try:
+        # temporary workaround for URL mutations MVP
+        from socid_extractor import mutate_url
+
+        reqs += list(mutate_url(url))
+    except Exception as e:
+        logger.warning(e)
+
+    for req in reqs:
+        url, headers = req
+        print(f'Scanning webpage by URL {url}...')
+        page, _ = parse(url, cookies_str='', headers=headers, timeout=timeout)
+        logger.debug(page)
+        info = extract(page)
+        if not info:
+            print('Nothing extracted')
+        else:
+            print(get_dict_ascii_tree(info.items(), new_line=False), ' ')
+        for k, v in info.items():
+            if 'username' in k:
+                results[v] = 'username'
+            if k in SUPPORTED_IDS:
+                results[v] = k
+
+    return results
+
+
+def extract_ids_from_results(results: QueryResultWrapper, db: MaigretDatabase) -> dict:
+    ids_results = {}
+    for website_name in results:
+        dictionary = results[website_name]
+        # TODO: fix no site data issue
+        if not dictionary:
+            continue
+
+        new_usernames = dictionary.get('ids_usernames')
+        if new_usernames:
+            for u, utype in new_usernames.items():
+                ids_results[u] = utype
+
+        for url in dictionary.get('ids_links', []):
+            for s in db.sites:
+                u = s.detect_username(url)
+                if u:
+                    ids_results[u] = 'username'
+
+    return ids_results
+
+
 def setup_arguments_parser():
     version_string = '\n'.join(
         [
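
Note: both helpers extracted here are importable on their own, which is what the new tests rely on. A usage sketch for extract_ids_from_page (it fetches the URL over the network; the standard-library logger is an assumption, any object with warning()/debug() methods works):

import logging

from maigret.maigret import extract_ids_from_page

logger = logging.getLogger('maigret')

# Parses the page (plus socid_extractor URL mutations, when available) and
# returns found identifiers mapped to their types, e.g. {'test': 'username'}.
found = extract_ids_from_page('https://www.reddit.com/user/test', logger, timeout=10)
for identifier, id_type in found.items():
    print(identifier, id_type)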
@@ -392,31 +446,8 @@ async def main():
         print("Using the proxy: " + args.proxy)
 
     if args.parse_url:
-        # url, headers
-        reqs = [(args.parse_url, set())]
-        try:
-            # temporary workaround for URL mutations MVP
-            from socid_extractor import mutate_url
-
-            reqs += list(mutate_url(args.parse_url))
-        except Exception as e:
-            logger.warning(e)
-            pass
-
-        for req in reqs:
-            url, headers = req
-            print(f'Scanning webpage by URL {url}...')
-            page, _ = parse(url, cookies_str='', headers=headers)
-            info = extract(page)
-            if not info:
-                print('Nothing extracted')
-            else:
-                print(get_dict_ascii_tree(info.items(), new_line=False), ' ')
-            for k, v in info.items():
-                if 'username' in k:
-                    usernames[v] = 'username'
-                if k in SUPPORTED_IDS:
-                    usernames[v] = k
+        extracted_ids = extract_ids_from_page(args.parse_url, logger, timeout=args.timeout)
+        usernames.update(extracted_ids)
 
     if args.tags:
         args.tags = list(set(str(args.tags).split(',')))
@@ -471,6 +502,7 @@ async def main():
print('Updates will be applied only for current search session.') print('Updates will be applied only for current search session.')
print(db.get_scan_stats(site_data)) print(db.get_scan_stats(site_data))
# Database statistics
if args.stats: if args.stats:
print(db.get_db_stats(db.sites_dict)) print(db.get_db_stats(db.sites_dict))
@@ -480,11 +512,6 @@ async def main():
     # Define one report filename template
     report_filepath_tpl = os.path.join(args.folderoutput, 'report_{username}{postfix}')
 
-    # Database stats
-    # TODO: verbose info about filtered sites
-    # enabled_count = len(list(filter(lambda x: not x.disabled, site_data.values())))
-    # print(f'Sites in database, enabled/total: {enabled_count}/{len(site_data)}')
-
     if usernames == {}:
         # magic params to exit after init
         query_notify.warning('No usernames to check, exiting.')
@@ -493,14 +520,14 @@ async def main():
     if not site_data:
         query_notify.warning('No sites to check, exiting!')
         sys.exit(2)
-    else:
-        query_notify.warning(
-            f'Starting a search on top {len(site_data)} sites from the Maigret database...'
-        )
-        if not args.all_sites:
-            query_notify.warning(
-                'You can run search by full list of sites with flag `-a`', '!'
-            )
+
+    query_notify.warning(
+        f'Starting a search on top {len(site_data)} sites from the Maigret database...'
+    )
+    if not args.all_sites:
+        query_notify.warning(
+            'You can run search by full list of sites with flag `-a`', '!'
+        )
 
     already_checked = set()
     general_results = []
@@ -511,8 +538,8 @@ async def main():
         if username.lower() in already_checked:
             continue
-        else:
-            already_checked.add(username.lower())
+
+        already_checked.add(username.lower())
 
         if username in args.ignore_ids_list:
             query_notify.warning(
@@ -521,10 +548,7 @@ async def main():
             continue
 
         # check for characters do not supported by sites generally
-        found_unsupported_chars = set(unsupported_characters).intersection(
-            set(username)
-        )
+        found_unsupported_chars = set(BAD_CHARS).intersection(set(username))
         if found_unsupported_chars:
             pretty_chars_str = ','.join(
                 map(lambda s: f'"{s}"', found_unsupported_chars)
@@ -558,22 +582,9 @@ async def main():
         general_results.append((username, id_type, results))
 
         # TODO: tests
-        for website_name in results:
-            dictionary = results[website_name]
-            # TODO: fix no site data issue
-            if not dictionary or not recursive_search_enabled:
-                continue
-
-            new_usernames = dictionary.get('ids_usernames')
-            if new_usernames:
-                for u, utype in new_usernames.items():
-                    usernames[u] = utype
-
-            for url in dictionary.get('ids_links', []):
-                for s in db.sites:
-                    u = s.detect_username(url)
-                    if u:
-                        usernames[u] = 'username'
+        if recursive_search_enabled:
+            extracted_ids = extract_ids_from_results(results, db)
+            usernames.update(extracted_ids)
 
         # reporting for a one username
         if args.xmind:
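
Note: recursive search now delegates to extract_ids_from_results, whose contract the new tests pin down: per-site result dicts may carry 'ids_usernames' (identifier -> type) and 'ids_links' (profile URLs resolved back to usernames via each site's detect_username()). A sketch with values borrowed from the tests (the database-loading line is an assumption mirroring the test fixtures):

from maigret.maigret import extract_ids_from_results
from maigret.sites import MaigretDatabase

db = MaigretDatabase().load_from_file('maigret/resources/data.json')

usernames = {'Facebook': 'username'}  # the queue main() works through
results = {
    'Reddit': {
        'username': 'Facebook',
        'ids_usernames': {'test1': 'yandex_public_id'},
        'ids_links': ['https://www.reddit.com/user/test2'],
    },
}

new_ids = extract_ids_from_results(results, db)
# -> {'test1': 'yandex_public_id', 'test2': 'username'}
usernames.update(new_ids)  # picked up on the next search round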
-1
@@ -3,7 +3,6 @@ import io
 import json
 import logging
 import os
-from argparse import ArgumentTypeError
 from datetime import datetime
 from typing import Dict, Any
+1 -1
@@ -12,7 +12,7 @@ from maigret.maigret import setup_arguments_parser
 CUR_PATH = os.path.dirname(os.path.realpath(__file__))
 JSON_FILE = os.path.join(CUR_PATH, '../maigret/resources/data.json')
 TEST_JSON_FILE = os.path.join(CUR_PATH, 'db.json')
 
-empty_mark = Mark('', [], {})
+empty_mark = Mark('', (), {})
 
 
 def by_slow_marker(item):
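
Note: pytest's Mark models @pytest.mark.<name>(*args, **kwargs) and expects its positional arguments as a tuple, which is presumably why the fixture switches from [] to (). A sketch (the import path is an internal pytest detail and may vary across versions):

from _pytest.mark.structures import Mark  # import path is an assumption

# args is a tuple of positional marker arguments, kwargs a mapping;
# an empty marker therefore uses () rather than [].
empty_mark = Mark('', (), {})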
+38 -19
@@ -4,11 +4,31 @@ import asyncio
 import pytest
 from mock import Mock
 
-from maigret.maigret import self_check, maigret
+from maigret.maigret import self_check, maigret, extract_ids_from_page, extract_ids_from_results
 from maigret.sites import MaigretSite
 from maigret.result import QueryResult, QueryStatus
 
+RESULTS_EXAMPLE = {
+    'Reddit': {
+        'cookies': None,
+        'parsing_enabled': False,
+        'url_main': 'https://www.reddit.com/',
+        'username': 'Facebook',
+    },
+    'GooglePlayStore': {
+        'cookies': None,
+        'http_status': 200,
+        'is_similar': False,
+        'parsing_enabled': False,
+        'rank': 1,
+        'url_main': 'https://play.google.com/store',
+        'url_user': 'https://play.google.com/store/apps/developer?id=Facebook',
+        'username': 'Facebook',
+    },
+}
+
 
 @pytest.mark.slow
 def test_self_check_db_positive_disable(test_db):
     logger = Mock()
@@ -113,21 +133,20 @@ def test_maigret_results(test_db):
     assert results['Reddit'].get('future') is None
     del results['GooglePlayStore']['future']
 
-    assert results == {
-        'Reddit': {
-            'cookies': None,
-            'parsing_enabled': False,
-            'url_main': 'https://www.reddit.com/',
-            'username': 'Facebook',
-        },
-        'GooglePlayStore': {
-            'cookies': None,
-            'http_status': 200,
-            'is_similar': False,
-            'parsing_enabled': False,
-            'rank': 1,
-            'url_main': 'https://play.google.com/store',
-            'url_user': 'https://play.google.com/store/apps/developer?id=Facebook',
-            'username': 'Facebook',
-        },
-    }
+    assert results == RESULTS_EXAMPLE
+
+
+@pytest.mark.slow
+def test_extract_ids_from_page(test_db):
+    logger = Mock()
+    found_ids = extract_ids_from_page('https://www.reddit.com/user/test', logger)
+    assert found_ids == {'test': 'username'}
+
+
+def test_extract_ids_from_results(test_db):
+    TEST_EXAMPLE = dict(RESULTS_EXAMPLE)
+    TEST_EXAMPLE['Reddit']['ids_usernames'] = {'test1': 'yandex_public_id'}
+    TEST_EXAMPLE['Reddit']['ids_links'] = ['https://www.reddit.com/user/test2']
+
+    found_ids = extract_ids_from_results(TEST_EXAMPLE, test_db)
+    assert found_ids == {'test1': 'yandex_public_id', 'test2': 'username'}
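
Note: dict(RESULTS_EXAMPLE) in the new test makes a shallow copy, so TEST_EXAMPLE['Reddit'] is the same object as RESULTS_EXAMPLE['Reddit'] and the added ids_* keys leak into the shared fixture. If other tests compare against RESULTS_EXAMPLE afterwards, a deep copy avoids the coupling (a sketch, not part of the commit; RESULTS_EXAMPLE as defined above):

import copy

TEST_EXAMPLE = dict(RESULTS_EXAMPLE)           # shallow: nested dicts are shared
assert TEST_EXAMPLE['Reddit'] is RESULTS_EXAMPLE['Reddit']

TEST_EXAMPLE = copy.deepcopy(RESULTS_EXAMPLE)  # deep: nested dicts are duplicated
TEST_EXAMPLE['Reddit']['ids_usernames'] = {'test1': 'yandex_public_id'}
assert 'ids_usernames' not in RESULTS_EXAMPLE['Reddit']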