Fix trimming page size logic while adding to a page

This commit is contained in:
Rishabh Singh Ahluwalia 2023-03-26 10:04:05 -07:00
parent 38a5dbbf3c
commit 8e197a09f9
2 changed files with 134 additions and 13 deletions

View file

@ -65,19 +65,39 @@ class TinyIndexMetadata:
values = json.loads(data[constant_length:].decode('utf8'))
return TinyIndexMetadata(**values)
# Find the optimal amount of data that fits onto a page
# We do this by leveraging binary search to quickly find the index where:
# - index+1 cannot fit onto a page
# - <=index can fit on a page
def _binary_search_fitting_size(compressor: ZstdCompressor, page_size: int, items:list[T], lo:int, hi:int):
# Base case: our binary search has gone too far
if lo > hi:
return -1, None
# Check the midpoint to see if it will fit onto a page
mid = (lo+hi)//2
compressed_data = compressor.compress(json.dumps(items[:mid]).encode('utf8'))
size = len(compressed_data)
if size > page_size:
# We cannot fit this much data into a page
# Reduce the hi boundary, and try again
return _binary_search_fitting_size(compressor,page_size,items,lo,mid-1)
else:
# We can fit this data into a page, but maybe we can fit more data
# Try to see if we have a better match
potential_target, potential_data = _binary_search_fitting_size(compressor,page_size,items,mid+1,hi)
if potential_target != -1:
# We found a larger index that can still fit onto a page, so use that
return potential_target, potential_data
else:
# No better match, use our index
return mid, compressed_data
def _trim_items_to_page(compressor: ZstdCompressor, page_size: int, items:list[T]):
# Find max number of items that fit on a page
return _binary_search_fitting_size(compressor,page_size,items,0,len(items))
def _get_page_data(compressor: ZstdCompressor, page_size: int, items: list[T]):
bytes_io = BytesIO()
stream_writer = compressor.stream_writer(bytes_io, write_size=128)
num_fitting = 0
for i, item in enumerate(items):
serialised_data = json.dumps(item) + '\n'
stream_writer.write(serialised_data.encode('utf8'))
stream_writer.flush()
if len(bytes_io.getvalue()) > page_size:
break
num_fitting = i + 1
num_fitting, serialised_data = _trim_items_to_page(compressor, page_size, items)
compressed_data = compressor.compress(json.dumps(items[:num_fitting]).encode('utf8'))
assert len(compressed_data) < page_size, "The data shouldn't get bigger"

View file

@ -1,8 +1,9 @@
from pathlib import Path
from tempfile import TemporaryDirectory
from mwmbl.tinysearchengine.indexer import Document, TinyIndex
from mwmbl.tinysearchengine.indexer import Document, TinyIndex, _binary_search_fitting_size, astuple, _trim_items_to_page, _get_page_data, _pad_to_page_size
from zstandard import ZstdDecompressor, ZstdCompressor, ZstdError
import json
def test_create_index():
num_pages = 10
@ -14,3 +15,103 @@ def test_create_index():
for i in range(num_pages):
page = indexer.get_page(i)
assert page == []
def test_binary_search_fitting_size_all_fit():
items = [1,2,3,4,5,6,7,8,9]
compressor = ZstdCompressor()
page_size = 4096
count_fit, data = _binary_search_fitting_size(compressor,page_size,items,0,len(items))
# We should fit everything
assert count_fit == len(items)
def test_binary_search_fitting_size_subset_fit():
items = [1,2,3,4,5,6,7,8,9]
compressor = ZstdCompressor()
page_size = 15
count_fit, data = _binary_search_fitting_size(compressor,page_size,items,0,len(items))
# We should not fit everything
assert count_fit < len(items)
def test_binary_search_fitting_size_none_fit():
items = [1,2,3,4,5,6,7,8,9]
compressor = ZstdCompressor()
page_size = 5
count_fit, data = _binary_search_fitting_size(compressor,page_size,items,0,len(items))
# We should not fit anything
assert count_fit == -1
assert data is None
def test_get_page_data_single_doc():
document1 = Document(title='title1',url='url1',extract='extract1',score=1.0)
documents = [document1]
items = [astuple(value) for value in documents]
compressor = ZstdCompressor()
page_size = 4096
# Trim data
num_fitting,trimmed_data = _trim_items_to_page(compressor,4096,items)
# We should be able to fit the 1 item into a page
assert num_fitting == 1
# Compare the trimmed data to the actual data we're persisting
# We need to pad the trimmmed data, then it should be equal to the data we persist
padded_trimmed_data = _pad_to_page_size(trimmed_data, page_size)
serialized_data = _get_page_data(compressor,page_size,items)
assert serialized_data == padded_trimmed_data
def test_get_page_data_many_docs_all_fit():
# Build giant documents item
documents = []
documents_len = 500
page_size = 4096
for x in range(documents_len):
txt = 'text{}'.format(x)
document = Document(title=txt,url=txt,extract=txt,score=x)
documents.append(document)
items = [astuple(value) for value in documents]
# Trim the items
compressor = ZstdCompressor()
num_fitting,trimmed_data = _trim_items_to_page(compressor,page_size,items)
# We should be able to fit all items
assert num_fitting == documents_len
# Compare the trimmed data to the actual data we're persisting
# We need to pad the trimmed data, then it should be equal to the data we persist
serialized_data = _get_page_data(compressor,page_size,items)
padded_trimmed_data = _pad_to_page_size(trimmed_data, page_size)
assert serialized_data == padded_trimmed_data
def test_get_page_data_many_docs_subset_fit():
# Build giant documents item
documents = []
documents_len = 5000
page_size = 4096
for x in range(documents_len):
txt = 'text{}'.format(x)
document = Document(title=txt,url=txt,extract=txt,score=x)
documents.append(document)
items = [astuple(value) for value in documents]
# Trim the items
compressor = ZstdCompressor()
num_fitting,trimmed_data = _trim_items_to_page(compressor,page_size,items)
# We should be able to fit a subset of the items onto the page
assert num_fitting > 1
assert num_fitting < documents_len
# Compare the trimmed data to the actual data we're persisting
# We need to pad the trimmed data, then it should be equal to the data we persist
serialized_data = _get_page_data(compressor,page_size,items)
padded_trimmed_data = _pad_to_page_size(trimmed_data, page_size)
assert serialized_data == padded_trimmed_data