diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 000000000..a58571825 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,6 @@ +# .coveragerc to control coverage.py +[run] +omit = + env/* + test_* + \ No newline at end of file diff --git a/howdoi/howdoi.py b/howdoi/howdoi.py index 09074c695..7eb295f08 100755 --- a/howdoi/howdoi.py +++ b/howdoi/howdoi.py @@ -32,6 +32,7 @@ from requests.exceptions import ConnectionError from requests.exceptions import SSLError +from .stats import Stats # Handle imports for Python 2 and 3 if sys.version < '3': @@ -124,6 +125,14 @@ def _print_dbg(x): print("[DEBUG] " + x) # noqa: E302 else: cache = FileSystemCache(CACHE_DIR, CACHE_ENTRY_MAX, default_timeout=0) +DEFAULT_STORE_DIR = appdirs.user_cache_dir('howdoi-stats') + +if os.getenv('HOWDOI_DISABLE_STATS_COLLECTIONS'): + stats_cache = NullCache() +else: + stats_cache = FileSystemCache(DEFAULT_STORE_DIR, default_timeout=0) + +stats_obj = Stats(stats_cache) howdoi_session = requests.session() @@ -353,6 +362,7 @@ def _get_links_with_cache(query): question_links = _get_questions(links) cache.set(cache_key, question_links or CACHE_EMPTY_VAL) + stats_obj.process_discovered_links(question_links) return question_links @@ -447,7 +457,7 @@ def _get_cache_key(args): return str(args) + __version__ -def format_stash_item(fields, index = -1): +def format_stash_item(fields, index=-1): title = fields['alias'] description = fields['desc'] item_num = index + 1 @@ -467,12 +477,13 @@ def format_stash_item(fields, index = -1): description=description) -def print_stash(stash_list = []): +def print_stash(stash_list=[]): if len(stash_list) == 0: stash_list = ['\nSTASH LIST:'] commands = keep_utils.read_commands() if commands is None or len(commands.items()) == 0: - print('No commands found in stash. Add a command with "howdoi --{stash_save} ".'.format(stash_save=STASH_SAVE)) + print( + 'No commands found in stash. Add a command with "howdoi --{stash_save} ".'.format(stash_save=STASH_SAVE)) return for cmd, fields in commands.items(): stash_list.append(format_stash_item(fields)) @@ -483,7 +494,7 @@ def print_stash(stash_list = []): def _get_stash_key(args): stash_args = {} - ignore_keys = [STASH_SAVE, STASH_VIEW, STASH_REMOVE, STASH_EMPTY, 'tags'] # ignore these for stash key. + ignore_keys = [STASH_SAVE, STASH_VIEW, STASH_REMOVE, STASH_EMPTY, 'tags'] # ignore these for stash key. for key in args: if not (key in ignore_keys): stash_args[key] = args[key] @@ -543,9 +554,13 @@ def howdoi(raw_query): if _is_help_query(args['query']): return _get_help_instructions() + '\n' + stats_obj.process_args(args) + res = cache.get(cache_key) if res: + stats_obj.record_cache_hit() + stats_obj.process_response(res) return _parse_cmd(args, res) try: @@ -557,6 +572,7 @@ def howdoi(raw_query): res = {'error': 'Unable to reach {search_engine}. 
Do you need to use a proxy?\n'.format( search_engine=args['search_engine'])} + stats_obj.process_response(res) return _parse_cmd(args, res) @@ -584,10 +600,11 @@ def get_parser(): action='store_true'), parser.add_argument('--empty', help='empty your stash', action='store_true') + parser.add_argument('--stats', help='view your howdoi usage statistics', action='store_true') return parser -def prompt_stash_remove(args, stash_list, view_stash = True): +def prompt_stash_remove(args, stash_list, view_stash=True): if view_stash: print_stash(stash_list) @@ -642,12 +659,17 @@ def command_line_runner(): if args[STASH_REMOVE] and len(args['query']) == 0: commands = keep_utils.read_commands() if commands is None or len(commands.items()) == 0: - print('No commands found in stash. Add a command with "howdoi --{stash_save} ".'.format(stash_save=STASH_SAVE)) + print( 'No commands found in stash. Add a command with "howdoi --{stash_save} ".'.format(stash_save=STASH_SAVE)) return stash_list = [{'command': cmd, 'fields': field} for cmd, field in commands.items()] prompt_stash_remove(args, stash_list) return + if args['stats']: + stats_obj.render_stats() + return + if not args['query']: parser.print_help() return diff --git a/howdoi/stats.py b/howdoi/stats.py new file mode 100644 index 000000000..b7dc2ce80 --- /dev/null +++ b/howdoi/stats.py @@ -0,0 +1,306 @@ +import collections +import sys +from datetime import datetime, timedelta +from time import time + +import appdirs +from cachelib import FileSystemCache + +from .utils import get_top_n_key_val_pairs_from_dict, safe_divide + +FIRST_INSTALL_DATE_KEY = 'FIRST_INSTALL_DATE_KEY' +CACHE_HIT_KEY = 'CACHE_HIT_KEY' +TOTAL_REQUESTS_KEY = 'TOTAL_REQUESTS_KEY' +DISCOVERED_LINKS_KEY = 'DISCOVERED_LINKS_KEY' +ERROR_RESULT_KEY = 'ERROR_RESULT_KEY' +SUCCESS_RESULT_KEY = 'SUCCESS_RESULT_KEY' +DATE_KEY = 'DATE_KEY' +HOUR_OF_DAY_KEY = 'HOUR_OF_DAY_KEY' +QUERY_KEY = 'QUERY_KEY' +QUERY_WORD_KEY = 'QUERY_WORD' +DATESTRING_FORMAT = "%Y-%m-%d" +TIMESTRING_FORMAT = "%H:%M:%S" +SEARCH_ENGINE_KEY = 'SEARCH_ENGINE_KEY' +DATETIME_STRING_FORMAT = " ".join((DATESTRING_FORMAT, TIMESTRING_FORMAT)) + +TERMGRAPH_DEFAULT_ARGS = {'filename': '-', 'title': None, 'width': 50, 'format': '{:<5.1f}', 'suffix': '', 'no_labels': False, 'no_values': False, 'color': None, 'vertical': False, 'stacked': False, + 'histogram': False, 'bins': 5, 'different_scale': False, 'calendar': False, 'start_dt': None, 'custom_tick': '', 'delim': '', 'verbose': False, 'label_before': False, 'version': False} + +Report = collections.namedtuple('Report', ['group', 'content']) + + +def can_use_termgraph(): + return sys.version_info >= (3, 6) + + +if can_use_termgraph(): + from termgraph import termgraph + + +def draw_horizontal_graph(data, labels, custom_args=None): + if can_use_termgraph(): + assert len(data) == len(labels) + if custom_args is None: + custom_args = {} + args = {} + args.update(TERMGRAPH_DEFAULT_ARGS) + args.update(custom_args) + termgraph.chart([], [[datapoint] for datapoint in data], args, [str(label) for label in labels]) + + +class StatsReporter: + def __init__(self, args, colors=[]): + self.termgraph_args = args + self.COLORS = colors + self._report_group_map = collections.OrderedDict() + + def add(self, report): + assert isinstance(report, Report) + if report.group not in self._report_group_map: + self._report_group_map[report.group] = [] + + self._report_group_map[report.group].append(report) + + def render_report(self, report): + if callable(report.content): + report.content() + elif
isinstance(report.content, str): + print(report.content) + + def render_report_separator(self, length, separator_char="*"): + separation_string = separator_char*length + print(separation_string) + + def report(self): + for key in self._report_group_map: + self.render_report_separator(70) + for report in self._report_group_map[key]: + self.render_report(report) + + +class Stats: + def __init__(self, cache): + self.DISALLOWED_WORDS = set(['in', 'a', 'an', 'for', 'on']) + self.cache = cache + self.sr = StatsReporter(TERMGRAPH_DEFAULT_ARGS) + if not self.cache.has(FIRST_INSTALL_DATE_KEY): + self.cache.clear() + self.cache.set(FIRST_INSTALL_DATE_KEY, datetime.today().strftime(DATESTRING_FORMAT)) + + def load_time_stats(self): + # TODO - Add heatmap. + sr = self.sr + days_since_first_install = self.get_days_since_first_install() or 0 + total_request_count = self[TOTAL_REQUESTS_KEY] or 0 + + sr.add(Report( + 'time-related-stats', 'You have been using howdoi for {} days.'.format(days_since_first_install))) + + sr.add( + Report( + 'time-related-stats', 'You have made an average of {} queries per day.'.format( + safe_divide(total_request_count, days_since_first_install)) + ) + ) + hour_of_day_map = self[HOUR_OF_DAY_KEY] + + if total_request_count > 0 and hour_of_day_map: + most_active_hour_of_day = max(hour_of_day_map, key=lambda hour: hour_of_day_map[hour]) + + sr.add( + Report( + 'time-related-stats', 'You are most active between {}:00 and {}:00.'.format( + most_active_hour_of_day, most_active_hour_of_day+1 + ) + ) + ) + + keys, values = [], [] + for k in hour_of_day_map: + lower_time_bound = str(k) + ":00" + upper_time_bound = str(k+1) + ":00" if k+1 < 24 else "00:00" + keys.append(lower_time_bound + "-" + upper_time_bound) + values.append(hour_of_day_map[k]) + + sr.add( + Report( + 'time-related-stats', lambda: draw_horizontal_graph(data=values, labels=keys, custom_args={ + 'suffix': ' uses', 'format': '{:<1d}'}) + ) + ) + + def load_request_stats(self): + sr = self.sr + total_request_count = self[TOTAL_REQUESTS_KEY] or 0 + cached_request_count = self[CACHE_HIT_KEY] or 0 + outbound_request_count = total_request_count - cached_request_count + successful_requests = self[SUCCESS_RESULT_KEY] or 0 + failed_requests = self[ERROR_RESULT_KEY] or 0 + + sr.add( + Report('network-request-stats', 'Of the {} requests you have made using howdoi, {} have been saved by howdoi\'s cache.'.format( + total_request_count, cached_request_count)) + ) + + sr.add( + Report('network-request-stats', 'Also, {} requests have succeeded, while {} have failed due to connection issues or some other problem.'.format( + successful_requests, failed_requests)) + ) + + if total_request_count > 0: + sr.add( + Report( + 'network-request-stats', lambda: draw_horizontal_graph( + data=[safe_divide(outbound_request_count*100, total_request_count), + safe_divide(cached_request_count*100, total_request_count)], + labels=['Outbound Requests', 'Cache Saved Requests'], + custom_args={'suffix': '%', } + ) + ) + ) + + if successful_requests+failed_requests > 0: + sr.add( + Report('network-request-stats', lambda: draw_horizontal_graph( + data=[safe_divide(successful_requests*100, successful_requests+failed_requests), + safe_divide(failed_requests*100, successful_requests+failed_requests)], + labels=['Successful Requests', 'Failed Requests'], + custom_args={'suffix': '%', } + ) + ) + ) + + def load_search_engine_stats(self): + sr = self.sr + search_engine_frequency_map = self[SEARCH_ENGINE_KEY]
+ if search_engine_frequency_map: + max_search_engine_key = max(search_engine_frequency_map, + key=lambda engine: search_engine_frequency_map[engine]) + sr.add( + Report( + 'search-engine-stats', 'Your most used search engine is {}.'.format( + max_search_engine_key.title() + ) + ) + ) + + search_engine_keys = [] + search_engine_values = [] + for k in search_engine_frequency_map: + search_engine_keys.append(k) + search_engine_values.append(search_engine_frequency_map[k]) + + sr.add( + Report( + 'search-engine-stats', lambda: draw_horizontal_graph( + data=search_engine_values, labels=search_engine_keys, custom_args={'suffix': ' uses', 'format': '{:<1d}'}) + ) + ) + + def load_query_related_stats(self): + sr = self.sr + query_map = self[QUERY_KEY] + query_words_map = self[QUERY_WORD_KEY] + top_5_query_key_vals = get_top_n_key_val_pairs_from_dict(query_map, 5) + + top_5_query_words_key_vals = get_top_n_key_val_pairs_from_dict(query_words_map, 5) + + if len(top_5_query_key_vals) > 0: + most_common_query = top_5_query_key_vals[0][0] + sr.add( + Report( + 'query-stats', 'The query you\'ve made most often is {}.'.format( + most_common_query + ) + ) + ) + if len(top_5_query_words_key_vals) > 0: + most_common_query_word = top_5_query_words_key_vals[0][0] + sr.add( + Report( + 'query-stats', 'The most common word in your queries is {}.'.format( + most_common_query_word + ) + ) + ) + + data = [val for _, val in top_5_query_words_key_vals] + labels = [key for key, _ in top_5_query_words_key_vals] + + sr.add( + Report('query-stats', lambda: draw_horizontal_graph(data=data, labels=labels, + custom_args={'suffix': ' uses', 'format': '{:<1d}'}) + )) + + def render_stats(self): + + self.load_time_stats() + self.load_request_stats() + self.load_search_engine_stats() + self.load_query_related_stats() + self.sr.report() + + def get_days_since_first_install(self): + first_install_date = self.cache.get(FIRST_INSTALL_DATE_KEY) + delta = datetime.today() - datetime.strptime(first_install_date, DATESTRING_FORMAT) + return delta.days + + def record_cache_hit(self): + self.cache.inc(CACHE_HIT_KEY) + + def increment_total_requests(self): + self.cache.inc(TOTAL_REQUESTS_KEY) + + def __getitem__(self, key): + return self.cache.get(key) + + def add_value_to_stats_count_map(self, key, value): + stats_map = self.cache.get(key) + if stats_map is None: + stats_map = collections.Counter() + stats_map[value] += 1 + self.cache.set(key, stats_map) + + def process_query_string(self, querystring): + if not querystring: + return + querystring = querystring.strip() + self.add_value_to_stats_count_map(QUERY_KEY, querystring) + + words = querystring.split(" ") + for word in words: + word = word.lower() + if word not in self.DISALLOWED_WORDS: + self.add_value_to_stats_count_map(QUERY_WORD_KEY, word) + + def increment_current_date_count(self): + curr_date_string = datetime.today().strftime(DATESTRING_FORMAT) + self.add_value_to_stats_count_map(DATE_KEY, curr_date_string) + + def increment_current_hour_of_day_count(self): + curr_hour_of_day = datetime.now().hour + self.add_value_to_stats_count_map(HOUR_OF_DAY_KEY, curr_hour_of_day) + + def process_search_engine(self, search_engine): + if search_engine: + self.add_value_to_stats_count_map(SEARCH_ENGINE_KEY, search_engine) + + def process_discovered_links(self, links): + if not links: + return + for link in links: + self.add_value_to_stats_count_map(DISCOVERED_LINKS_KEY, link) + + def process_args(self, args): + self.increment_total_requests() +
self.process_search_engine(args.get('search_engine')) + self.increment_current_date_count() + self.increment_current_hour_of_day_count() + self.process_query_string(args.get('query')) + + def process_response(self, res): + key = ERROR_RESULT_KEY if self._is_error_response(res) else SUCCESS_RESULT_KEY + self.cache.inc(key) + + def _is_error_response(self, res): + return not res or (isinstance(res, dict) and res.get('error')) diff --git a/howdoi/utils.py b/howdoi/utils.py new file mode 100644 index 000000000..f0371b332 --- /dev/null +++ b/howdoi/utils.py @@ -0,0 +1,17 @@ +import heapq + + +def get_top_n_key_val_pairs_from_dict(dict_, N): + top_n_key_value_pairs = [] + if isinstance(dict_, dict): + for key in dict_: + heapq.heappush(top_n_key_value_pairs, (dict_[key], key)) + if len(top_n_key_value_pairs) > N: + heapq.heappop(top_n_key_value_pairs) + + top_n_key_value_pairs.sort(reverse=True) + return [(k, v) for v, k in top_n_key_value_pairs] + + +def safe_divide(numerator, denominator): + return numerator/denominator if denominator != 0 else 0 diff --git a/requirements.txt b/requirements.txt index 8a66c4551..638080df9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,5 @@ pyquery==1.4.1 requests==2.24.0 cachelib==0.1 appdirs==1.4.4 -keep==2.9 \ No newline at end of file +keep==2.9 +termgraph==0.4.2 \ No newline at end of file diff --git a/setup.py b/setup.py index 05563c6a2..b4b564c8a 100644 --- a/setup.py +++ b/setup.py @@ -75,5 +75,6 @@ def read(*names): 'cachelib==0.1', 'appdirs', 'keep', + 'termgraph' ] + extra_dependencies(), ) diff --git a/test_howdoi.py b/test_howdoi.py index e1bd72a2b..fa6749bc8 100644 --- a/test_howdoi.py +++ b/test_howdoi.py @@ -37,7 +37,7 @@ def setUp(self): # ensure no cache is used during testing.
howdoi.cache = NullCache() - + howdoi.stats_obj.cache = NullCache() self.queries = ['format date bash', 'print stack trace python', 'convert mp4 to animated gif', diff --git a/test_stats.py b/test_stats.py new file mode 100644 index 000000000..1cbb566b7 --- /dev/null +++ b/test_stats.py @@ -0,0 +1,185 @@ +#!/usr/bin/env python + +"""Tests for howdoi's stats module.""" +import os +import io +import shutil +import unittest +import unittest.mock +from datetime import datetime +from tempfile import mkdtemp, mkstemp + +from cachelib import FileSystemCache, NullCache + +from howdoi import howdoi + +from howdoi.stats import (DATE_KEY, DATESTRING_FORMAT, DISCOVERED_LINKS_KEY, + HOUR_OF_DAY_KEY, QUERY_KEY, QUERY_WORD_KEY, + SEARCH_ENGINE_KEY, Stats, Report, StatsReporter, TERMGRAPH_DEFAULT_ARGS, CACHE_HIT_KEY, TOTAL_REQUESTS_KEY, ERROR_RESULT_KEY, SUCCESS_RESULT_KEY) + + +class StatsTest(unittest.TestCase): + def setUp(self): + self.cache_dir = mkdtemp(prefix='howdoi_test') + cache = FileSystemCache(self.cache_dir, default_timeout=0) + self.stats_obj = Stats(cache) + howdoi.stats_obj = self.stats_obj + + self.howdoi_args = [{'query': 'print stack trace python', 'search_engine': 'google'}, { 'query': 'convert mp4 to animated gif', 'search_engine': 'bing'}, {'query': 'create tar archive', 'search_engine': 'google'}] + + self.result_links = ['https://stackoverflow.com/questions/2068372/fastest-way-to-list-all-primes-below-n', 'https://stackoverflow.com/questions/13427890/how-can-i-find-all-prime-numbers-in-a-given-range', 'https://stackoverflow.com/questions/453793/which-is-the-fastest-algorithm-to-find-prime-numbers', 'https://stackoverflow.com/questions/18928095/fastest-way-to-find-all-primes-under-4-billion', ] + + self.error_howdoi_results = [{"error": "Sorry, couldn\'t find any help with that topic\n"}, { "error": "Failed to establish network connection\n"}] + self.success_howdoi_results = [{'answer': 'https://github.com//.git\n', + 'link': 'https://stackoverflow.com/questions/14762034/push-to-github-without-a-password-using-ssh-key', 'position': 1}] + + def tearDown(self): + shutil.rmtree(self.cache_dir) + + def test_days_since_first_install_is_correct(self): + self.assertEqual(self.stats_obj.get_days_since_first_install(), 0) + + def test_querystring_processing(self): + for args in self.howdoi_args: + self.stats_obj.process_query_string(args['query']) + + self.assertIsNotNone(self.stats_obj[QUERY_KEY]) + self.assertIsNotNone(self.stats_obj[QUERY_WORD_KEY]) + for args in self.howdoi_args: + query = args['query'] + self.assertEqual(self.stats_obj[QUERY_KEY][query], 1) + + self.assertEqual(self.stats_obj[QUERY_WORD_KEY]['python'], 1) + self.assertEqual(self.stats_obj[QUERY_WORD_KEY]['archive'], 1) + self.assertEqual(self.stats_obj[QUERY_WORD_KEY]['on'], 0) + + def test_increment_current_date(self): + self.stats_obj.increment_current_date_count() + self.stats_obj.increment_current_date_count() + self.stats_obj.increment_current_date_count() + + curr_date_string = datetime.today().strftime(DATESTRING_FORMAT) + + self.assertIsNotNone(self.stats_obj[DATE_KEY]) + self.assertEqual(self.stats_obj[DATE_KEY][curr_date_string], 3) + + def test_increment_current_hour_of_day(self): + self.stats_obj.increment_current_hour_of_day_count() + self.stats_obj.increment_current_hour_of_day_count() + self.stats_obj.increment_current_hour_of_day_count() + + curr_hour_of_day = datetime.now().hour + self.assertIsNotNone(self.stats_obj[HOUR_OF_DAY_KEY]) +
self.assertEqual(self.stats_obj[HOUR_OF_DAY_KEY][curr_hour_of_day], 3) + + def test_increment_queries_cache_hits(self): + self.stats_obj.record_cache_hit() + self.stats_obj.record_cache_hit() + self.stats_obj.record_cache_hit() + + self.assertEqual(self.stats_obj[CACHE_HIT_KEY], 3) + + def test_total_request_count(self): + for args in self.howdoi_args: + self.stats_obj.process_args(args) + + self.assertEqual(self.stats_obj[TOTAL_REQUESTS_KEY], len(self.howdoi_args)) + + def test_process_search_engine(self): + self.stats_obj.process_search_engine('google') + self.stats_obj.process_search_engine('google') + self.stats_obj.process_search_engine('bing') + self.stats_obj.process_search_engine('bing') + + stored_search_engine_map = self.stats_obj[SEARCH_ENGINE_KEY] + + self.assertEqual(stored_search_engine_map['google'], 2) + self.assertEqual(stored_search_engine_map['bing'], 2) + self.assertEqual(stored_search_engine_map['duckduckgo'], 0) + + def test_processes_discovered_links(self): + self.stats_obj.process_discovered_links(self.result_links) + + stored_links_map = self.stats_obj[DISCOVERED_LINKS_KEY] + + for link in self.result_links: + self.assertEqual(stored_links_map[link], self.result_links.count(link)) + + def test_counts_valid_responses(self): + for response in self.success_howdoi_results: + self.stats_obj.process_response(response) + + self.assertEqual(self.stats_obj[SUCCESS_RESULT_KEY], len(self.success_howdoi_results)) + + def test_counts_error_responses(self): + for response in self.error_howdoi_results: + self.stats_obj.process_response(response) + + self.assertEqual(self.stats_obj[ERROR_RESULT_KEY], len(self.error_howdoi_results)) + + @unittest.mock.patch('sys.stdout', new_callable=io.StringIO) + def test_render_stats(self, mock_stdout): + for args in self.howdoi_args: + self.stats_obj.process_args(args) + + self.stats_obj.render_stats() + + stdout_value = mock_stdout.getvalue() + self.assertIn('0 days', stdout_value) + self.assertIn('print stack trace python', stdout_value) + self.assertIn("You are most active between", stdout_value) + self.assertIn("queries per day", stdout_value) + self.assertIn("*****", stdout_value) + + +class StatsReporterTest(unittest.TestCase): + def setUp(self): + self.sr = StatsReporter(TERMGRAPH_DEFAULT_ARGS) + + def test_add_report(self): + report_group_name = 'time-stats-group' + + self.sr.add(Report(report_group_name, 'sample stat report')) + self.assertIn(report_group_name, self.sr._report_group_map) + self.assertEqual(len(self.sr._report_group_map[report_group_name]), 1) + + def test_add_invalid_report_throws_exception(self): + with self.assertRaises(AssertionError): + self.sr.add('sample stat report') + + @unittest.mock.patch('sys.stdout', new_callable=io.StringIO) + def test_text_reports_are_rendered_correctly(self, mock_stdout): + sample_text_report = Report('report-group-1', 'this is a sample stat') + self.sr.render_report(sample_text_report) + self.assertEqual(mock_stdout.getvalue(), 'this is a sample stat\n') + + @unittest.mock.patch('sys.stdout', new_callable=io.StringIO) + def test_callable_reports_are_rendered_correctly(self, mock_stdout): + sample_callable_report = Report('report-group-1', lambda: print('this is a callable stat')) + self.sr.render_report(sample_callable_report) + self.assertEqual(mock_stdout.getvalue(), 'this is a callable stat\n') + + @unittest.mock.patch('sys.stdout', new_callable=io.StringIO) + def test_report_separator_render_valid_separator(self, mock_stdout): + self.sr.render_report_separator(20, '*') +
self.assertEqual(mock_stdout.getvalue(), "*"*20 + "\n") + + @unittest.mock.patch('sys.stdout', new_callable=io.StringIO) + def test_overall_report(self, mock_stdout): + sample_callable_report = Report('report-group-1', lambda: print('this is a callable stat')) + sample_text_report = Report('report-group-1', 'this is a sample stat') + + self.sr.add(sample_callable_report) + self.sr.add(sample_text_report) + + self.sr.report() + + self.assertIn("callable", mock_stdout.getvalue()) + self.assertIn("sample", mock_stdout.getvalue()) + + +if __name__ == '__main__': + unittest.main() diff --git a/test_utils.py b/test_utils.py new file mode 100644 index 000000000..99c939ad1 --- /dev/null +++ b/test_utils.py @@ -0,0 +1,24 @@ +import unittest +import math + +from howdoi import utils + + +class TestUtils(unittest.TestCase): + def test_get_top_n_key_val_pairs_from_dict_returns_correct_result(self): + dictionary1 = {'a': 1, 'b': 2, 'd': 2, 'c': 11} + top_2 = utils.get_top_n_key_val_pairs_from_dict(dictionary1, 2) + + self.assertEqual(len(top_2), 2) + self.assertEqual(top_2[0], ('c', 11)) + self.assertEqual(top_2[1], ('d', 2)) + + def test_safe_division(self): + ans1 = utils.safe_divide(10, 3) + self.assertTrue(math.isclose(ans1, 3.333, abs_tol=0.001)) + ans2 = utils.safe_divide(10, 0) + self.assertEqual(ans2, 0) + + +if __name__ == '__main__': + unittest.main()
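
A few usage notes for reviewers follow. First, on the HOWDOI_DISABLE_STATS_COLLECTIONS opt-out wired into howdoi.py: stats_cache is chosen at module import time, so the variable must already be set when howdoi is first imported. A minimal sketch:

```python
import os

# The opt-out must be set before howdoi is imported, because howdoi.py
# chooses between NullCache and FileSystemCache at import time.
os.environ['HOWDOI_DISABLE_STATS_COLLECTIONS'] = '1'

from howdoi import howdoi  # noqa: E402  stats_obj is now backed by a NullCache
```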
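Second, a minimal sketch for exercising the new Stats pipeline end to end without touching the network. It uses only names introduced in this diff; the temporary cache directory is illustrative, and it assumes the termgraph dependency from requirements.txt is installed:

```python
from tempfile import mkdtemp

from cachelib import FileSystemCache

from howdoi.stats import Stats, TOTAL_REQUESTS_KEY

cache = FileSystemCache(mkdtemp(prefix='howdoi_stats_demo'), default_timeout=0)
stats = Stats(cache)

# Record two fake invocations, the same way howdoi() feeds process_args.
stats.process_args({'query': 'print stack trace python', 'search_engine': 'google'})
stats.process_args({'query': 'create tar archive', 'search_engine': 'google'})
stats.process_response({'answer': 'traceback.print_exc()', 'position': 1})  # counted as a success

print(stats[TOTAL_REQUESTS_KEY])  # -> 2
stats.render_stats()              # grouped report, with termgraph bar charts
```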
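StatsReporter and Report (stats.py) form a small grouped-rendering helper: reports are grouped by name, each group is preceded by a 70-character separator, and a report's content may be either plain text or a zero-argument callable invoked at render time, which is how the graphs are drawn lazily. A sketch, under the same termgraph assumption:

```python
from howdoi.stats import (Report, StatsReporter, TERMGRAPH_DEFAULT_ARGS,
                          draw_horizontal_graph)

sr = StatsReporter(TERMGRAPH_DEFAULT_ARGS)
sr.add(Report('demo-group', 'Text content is printed as-is.'))
# Callable content runs only when report() is called.
sr.add(Report('demo-group', lambda: draw_horizontal_graph(data=[3, 1], labels=['google', 'bing'])))
sr.report()  # separator line, then both reports in insertion order
```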
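Finally, the utils.py helpers: get_top_n_key_val_pairs_from_dict keeps a min-heap of (value, key) pairs bounded at N entries, so a top-N scan costs O(len(dict_) * log N) rather than a full sort, and safe_divide backstops the day-zero average in load_time_stats. For example:

```python
from howdoi.utils import get_top_n_key_val_pairs_from_dict, safe_divide

counts = {'python': 11, 'bash': 2, 'gif': 2, 'tar': 1}
print(get_top_n_key_val_pairs_from_dict(counts, 2))  # -> [('python', 11), ('gif', 2)]
print(get_top_n_key_val_pairs_from_dict(None, 2))    # -> [] (non-dict input yields no pairs)
print(safe_divide(7, 0))                             # -> 0 instead of ZeroDivisionError
```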