mirror of
https://github.com/pruzko/hakuin
synced 2024-11-08 13:59:15 +01:00
float extraction (text-based binary search)
This commit is contained in:
parent
34e11cf2ee
commit
5983e5fb42
@ -2,8 +2,6 @@ import hakuin
|
|||||||
import hakuin.search_algorithms as alg
|
import hakuin.search_algorithms as alg
|
||||||
import hakuin.collectors as coll
|
import hakuin.collectors as coll
|
||||||
|
|
||||||
from hakuin.utils import CHARSET_DIGITS
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Extractor:
|
class Extractor:
|
||||||
@ -200,12 +198,10 @@ class Extractor:
|
|||||||
list: list of floats in the column
|
list: list of floats in the column
|
||||||
'''
|
'''
|
||||||
ctx = coll.Context(table=table, column=column)
|
ctx = coll.Context(table=table, column=column)
|
||||||
res = coll.BinaryTextCollector(
|
return coll.FloatCollector(
|
||||||
requester=self.requester,
|
requester=self.requester,
|
||||||
dbms=self.dbms,
|
dbms=self.dbms,
|
||||||
charset=CHARSET_DIGITS,
|
|
||||||
).run(ctx)
|
).run(ctx)
|
||||||
return [float(v) if v is not None else None for v in res]
|
|
||||||
|
|
||||||
|
|
||||||
def extract_column_bytes(self, table, column):
|
def extract_column_bytes(self, table, column):
|
||||||
|
@ -3,7 +3,7 @@ from abc import ABCMeta, abstractmethod
|
|||||||
from collections import Counter
|
from collections import Counter
|
||||||
|
|
||||||
import hakuin
|
import hakuin
|
||||||
from hakuin.utils import tokenize, EOS, ASCII_MAX, UNICODE_MAX, BYTE_MAX
|
from hakuin.utils import tokenize, EOS, ASCII_MAX, UNICODE_MAX, BYTE_MAX, CHARSET_DIGITS
|
||||||
from hakuin.utils.huffman import make_tree
|
from hakuin.utils.huffman import make_tree
|
||||||
from hakuin.search_algorithms import BinarySearch, TreeSearch, NumericBinarySearch
|
from hakuin.search_algorithms import BinarySearch, TreeSearch, NumericBinarySearch
|
||||||
|
|
||||||
@ -116,6 +116,7 @@ class Collector(metaclass=ABCMeta):
|
|||||||
return self.requester.request(ctx, query)
|
return self.requester.request(ctx, query)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class IntCollector(Collector):
|
class IntCollector(Collector):
|
||||||
'''Collector for integer columns'''
|
'''Collector for integer columns'''
|
||||||
def collect_row(self, ctx):
|
def collect_row(self, ctx):
|
||||||
@ -129,12 +130,35 @@ class IntCollector(Collector):
|
|||||||
).run(ctx)
|
).run(ctx)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class FloatCollector(Collector):
|
||||||
|
'''Collector for integer columns'''
|
||||||
|
def collect_row(self, ctx):
|
||||||
|
ctx.s = ''
|
||||||
|
while True:
|
||||||
|
c = self.collect_one(ctx)
|
||||||
|
if c == EOS:
|
||||||
|
return ctx.s
|
||||||
|
ctx.s += c
|
||||||
|
|
||||||
|
return float(ctx.s)
|
||||||
|
|
||||||
|
|
||||||
|
def collect_one(self, ctx):
|
||||||
|
return BinarySearch(
|
||||||
|
requester=self.requester,
|
||||||
|
query_cb=self.dbms.q_float_char_in_set,
|
||||||
|
values=CHARSET_DIGITS,
|
||||||
|
).run(ctx)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class BytesCollector(Collector):
|
class BytesCollector(Collector):
|
||||||
'''Collector for bytes columns'''
|
'''Collector for bytes columns'''
|
||||||
def collect_row(self, ctx):
|
def collect_row(self, ctx):
|
||||||
ctx.s = b''
|
ctx.s = b''
|
||||||
while True:
|
while True:
|
||||||
b = self.collect_byte(ctx)
|
b = self.collect_one(ctx)
|
||||||
if b == EOS:
|
if b == EOS:
|
||||||
return ctx.s
|
return ctx.s
|
||||||
ctx.s += b
|
ctx.s += b
|
||||||
@ -142,7 +166,7 @@ class BytesCollector(Collector):
|
|||||||
return ctx.s
|
return ctx.s
|
||||||
|
|
||||||
|
|
||||||
def collect_byte(self, ctx):
|
def collect_one(self, ctx):
|
||||||
res = NumericBinarySearch(
|
res = NumericBinarySearch(
|
||||||
requester=self.requester,
|
requester=self.requester,
|
||||||
query_cb=self.dbms.q_byte_lt,
|
query_cb=self.dbms.q_byte_lt,
|
||||||
@ -154,6 +178,7 @@ class BytesCollector(Collector):
|
|||||||
return EOS if res == BYTE_MAX + 1 else res.to_bytes(1, 'big')
|
return EOS if res == BYTE_MAX + 1 else res.to_bytes(1, 'big')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class TextCollector(Collector):
|
class TextCollector(Collector):
|
||||||
'''Collector for text columns.'''
|
'''Collector for text columns.'''
|
||||||
def __init__(self, requester, dbms, charset=None):
|
def __init__(self, requester, dbms, charset=None):
|
||||||
|
@ -114,8 +114,8 @@ class DBMS(metaclass=ABCMeta):
|
|||||||
query = self.jj.get_template('char_in_set.jinja').render(ctx=ctx, values=values, has_eos=has_eos)
|
query = self.jj.get_template('char_in_set.jinja').render(ctx=ctx, values=values, has_eos=has_eos)
|
||||||
return self.normalize(query)
|
return self.normalize(query)
|
||||||
|
|
||||||
def q_char_lt(self, ctx, n, has_eos):
|
def q_char_lt(self, ctx, n):
|
||||||
query = self.jj.get_template('char_lt.jinja').render(ctx=ctx, n=n, has_eos=has_eos)
|
query = self.jj.get_template('char_lt.jinja').render(ctx=ctx, n=n)
|
||||||
return self.normalize(query)
|
return self.normalize(query)
|
||||||
|
|
||||||
def q_string_in_set(self, ctx, values):
|
def q_string_in_set(self, ctx, values):
|
||||||
@ -126,6 +126,9 @@ class DBMS(metaclass=ABCMeta):
|
|||||||
query = self.jj.get_template('int_lt.jinja').render(ctx=ctx, n=n)
|
query = self.jj.get_template('int_lt.jinja').render(ctx=ctx, n=n)
|
||||||
return self.normalize(query)
|
return self.normalize(query)
|
||||||
|
|
||||||
|
def q_float_char_in_set(self, ctx, values):
|
||||||
|
return self.q_char_in_set(ctx, values)
|
||||||
|
|
||||||
def q_byte_lt(self, ctx, n):
|
def q_byte_lt(self, ctx, n):
|
||||||
query = self.jj.get_template('byte_lt.jinja').render(ctx=ctx, n=n)
|
query = self.jj.get_template('byte_lt.jinja').render(ctx=ctx, n=n)
|
||||||
return self.normalize(query)
|
return self.normalize(query)
|
||||||
|
Loading…
Reference in New Issue
Block a user