diff --git a/hakuin/Extractor.py b/hakuin/Extractor.py index c1acf03..a6373a3 100644 --- a/hakuin/Extractor.py +++ b/hakuin/Extractor.py @@ -2,8 +2,6 @@ import hakuin import hakuin.search_algorithms as alg import hakuin.collectors as coll -from hakuin.utils import CHARSET_DIGITS - class Extractor: @@ -200,12 +198,10 @@ class Extractor: list: list of floats in the column ''' ctx = coll.Context(table=table, column=column) - res = coll.BinaryTextCollector( + return coll.FloatCollector( requester=self.requester, dbms=self.dbms, - charset=CHARSET_DIGITS, ).run(ctx) - return [float(v) if v is not None else None for v in res] def extract_column_bytes(self, table, column): diff --git a/hakuin/collectors.py b/hakuin/collectors.py index 0e282f2..edd9c93 100644 --- a/hakuin/collectors.py +++ b/hakuin/collectors.py @@ -3,7 +3,7 @@ from abc import ABCMeta, abstractmethod from collections import Counter import hakuin -from hakuin.utils import tokenize, EOS, ASCII_MAX, UNICODE_MAX, BYTE_MAX +from hakuin.utils import tokenize, EOS, ASCII_MAX, UNICODE_MAX, BYTE_MAX, CHARSET_DIGITS from hakuin.utils.huffman import make_tree from hakuin.search_algorithms import BinarySearch, TreeSearch, NumericBinarySearch @@ -116,6 +116,7 @@ class Collector(metaclass=ABCMeta): return self.requester.request(ctx, query) + class IntCollector(Collector): '''Collector for integer columns''' def collect_row(self, ctx): @@ -129,12 +130,35 @@ class IntCollector(Collector): ).run(ctx) + +class FloatCollector(Collector): + '''Collector for integer columns''' + def collect_row(self, ctx): + ctx.s = '' + while True: + c = self.collect_one(ctx) + if c == EOS: + return ctx.s + ctx.s += c + + return float(ctx.s) + + + def collect_one(self, ctx): + return BinarySearch( + requester=self.requester, + query_cb=self.dbms.q_float_char_in_set, + values=CHARSET_DIGITS, + ).run(ctx) + + + class BytesCollector(Collector): '''Collector for bytes columns''' def collect_row(self, ctx): ctx.s = b'' while True: - b = self.collect_byte(ctx) + b = self.collect_one(ctx) if b == EOS: return ctx.s ctx.s += b @@ -142,7 +166,7 @@ class BytesCollector(Collector): return ctx.s - def collect_byte(self, ctx): + def collect_one(self, ctx): res = NumericBinarySearch( requester=self.requester, query_cb=self.dbms.q_byte_lt, @@ -154,6 +178,7 @@ class BytesCollector(Collector): return EOS if res == BYTE_MAX + 1 else res.to_bytes(1, 'big') + class TextCollector(Collector): '''Collector for text columns.''' def __init__(self, requester, dbms, charset=None): diff --git a/hakuin/dbms/DBMS.py b/hakuin/dbms/DBMS.py index 06a0ca6..353ddbb 100644 --- a/hakuin/dbms/DBMS.py +++ b/hakuin/dbms/DBMS.py @@ -114,8 +114,8 @@ class DBMS(metaclass=ABCMeta): query = self.jj.get_template('char_in_set.jinja').render(ctx=ctx, values=values, has_eos=has_eos) return self.normalize(query) - def q_char_lt(self, ctx, n, has_eos): - query = self.jj.get_template('char_lt.jinja').render(ctx=ctx, n=n, has_eos=has_eos) + def q_char_lt(self, ctx, n): + query = self.jj.get_template('char_lt.jinja').render(ctx=ctx, n=n) return self.normalize(query) def q_string_in_set(self, ctx, values): @@ -126,6 +126,9 @@ class DBMS(metaclass=ABCMeta): query = self.jj.get_template('int_lt.jinja').render(ctx=ctx, n=n) return self.normalize(query) + def q_float_char_in_set(self, ctx, values): + return self.q_char_in_set(ctx, values) + def q_byte_lt(self, ctx, n): query = self.jj.get_template('byte_lt.jinja').render(ctx=ctx, n=n) return self.normalize(query)