import json
import os
from dataclasses import dataclass, asdict, field
from enum import IntEnum
from io import UnsupportedOperation
from logging import getLogger
from mmap import mmap, PROT_READ, PROT_WRITE
from typing import TypeVar, Generic, Callable, List, Optional

import mmh3
from zstandard import ZstdDecompressor, ZstdCompressor, ZstdError

VERSION = 1
METADATA_CONSTANT = b'mwmbl-tiny-search'
METADATA_SIZE = 4096

PAGE_SIZE = 4096

logger = getLogger(__name__)


def astuple(dc):
    """
    Convert a dataclass instance to a tuple, truncating trailing values that are None.
    """
    value = tuple(dc.__dict__.values())
    while value and value[-1] is None:
        value = value[:-1]
    return value
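
# Example (illustrative values): for a document with no term or state, the two
# trailing None fields are dropped from the serialised tuple:
#
#     doc = Document('Python', 'https://python.org', 'A programming language', 1.0)
#     astuple(doc)  # -> ('Python', 'https://python.org', 'A programming language', 1.0)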


class DocumentState(IntEnum):
    CURATED = 0
    VALIDATED = 1


@dataclass
class Document:
    title: str
    url: str
    extract: str
    score: float
    term: Optional[str] = None
    state: Optional[int] = None


@dataclass
class TokenizedDocument(Document):
    tokens: List[str] = field(default_factory=list)


T = TypeVar('T')


class PageError(Exception):
    pass


@dataclass
class TinyIndexMetadata:
    version: int
    page_size: int
    num_pages: int
    item_factory: str

    def to_bytes(self) -> bytes:
        metadata_bytes = METADATA_CONSTANT + json.dumps(asdict(self)).encode('utf8')
        assert len(metadata_bytes) <= METADATA_SIZE
        return metadata_bytes

    @staticmethod
    def from_bytes(data: bytes):
        constant_length = len(METADATA_CONSTANT)
        metadata_constant = data[:constant_length]
        if metadata_constant != METADATA_CONSTANT:
            raise ValueError("This doesn't seem to be an index file")

        values = json.loads(data[constant_length:].decode('utf8'))
        return TinyIndexMetadata(**values)
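
# A round-trip sketch (illustrative values): metadata survives serialisation,
# so from_bytes(to_bytes()) reconstructs the original dataclass:
#
#     meta = TinyIndexMetadata(version=1, page_size=4096, num_pages=25600, item_factory='Document')
#     TinyIndexMetadata.from_bytes(meta.to_bytes()) == meta  # -> True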


# Find the largest prefix of items that fits onto a page once compressed.
# We use binary search to quickly find the index where:
# - items[:index + 1] cannot fit onto a page
# - items[:index] can fit onto a page
def _binary_search_fitting_size(compressor: ZstdCompressor, page_size: int, items: list[T], lo: int, hi: int):
    # Base case: our binary search has gone too far
    if lo > hi:
        return -1, None
    # Check the midpoint to see if it will fit onto a page
    mid = (lo + hi) // 2
    compressed_data = compressor.compress(json.dumps(items[:mid]).encode('utf8'))
    size = len(compressed_data)
    if size > page_size:
        # We cannot fit this much data into a page:
        # reduce the hi boundary and try again
        return _binary_search_fitting_size(compressor, page_size, items, lo, mid - 1)
    # We can fit this data into a page, but maybe we can fit more:
    # see whether a larger prefix also fits
    potential_target, potential_data = _binary_search_fitting_size(compressor, page_size, items, mid + 1, hi)
    if potential_target != -1:
        # We found a larger prefix that can still fit onto a page, so use that
        return potential_target, potential_data
    # No better match, use our index
    return mid, compressed_data
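
# A trace sketch (assuming, for illustration, that only the first three of four
# items compress to <= page_size): starting from lo=0, hi=4 the search tries
# mid=2 (fits), then mid=3 (fits), then mid=4 (too big), and returns
# (3, <compressed items[:3]>).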


def _trim_items_to_page(compressor: ZstdCompressor, page_size: int, items: list[T]):
    # Find the maximum number of items that fit on a page
    return _binary_search_fitting_size(compressor, page_size, items, 0, len(items))


def _get_page_data(compressor: ZstdCompressor, page_size: int, items: list[T]):
    num_fitting, compressed_data = _trim_items_to_page(compressor, page_size, items)
    if num_fitting == -1:
        raise PageError(f"Even an empty item list does not fit in the page size ({page_size})")

    # The binary search has already compressed the fitting prefix, so reuse
    # that result instead of compressing items[:num_fitting] a second time
    assert len(compressed_data) <= page_size, "The data shouldn't get bigger"
    return _pad_to_page_size(compressed_data, page_size)


def _pad_to_page_size(data: bytes, page_size: int):
    page_length = len(data)
    if page_length > page_size:
        raise PageError(f"Data is too big ({page_length}) for page size ({page_size})")
    padding = b'\x00' * (page_size - page_length)
    return data + padding
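
# Example (illustrative sizes): a 10-byte payload padded to a 16-byte page is
# the payload followed by six zero bytes:
#
#     _pad_to_page_size(b'0123456789', 16)  # -> b'0123456789\x00\x00\x00\x00\x00\x00'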


class TinyIndex(Generic[T]):
    def __init__(self, item_factory: Callable[..., T], index_path, mode='r'):
        if mode not in {'r', 'w'}:
            raise ValueError(f"Mode should be one of 'r' or 'w', got {mode}")

        with open(index_path, 'rb') as index_file:
            metadata_page = index_file.read(METADATA_SIZE)

        metadata_bytes = metadata_page.rstrip(b'\x00')
        metadata = TinyIndexMetadata.from_bytes(metadata_bytes)
        if metadata.item_factory != item_factory.__name__:
            raise ValueError(f"Metadata item factory '{metadata.item_factory}' in the index "
                             f"does not match the passed item factory: '{item_factory.__name__}'")

        self.item_factory = item_factory
        self.index_path = index_path
        self.mode = mode

        self.num_pages = metadata.num_pages
        self.page_size = metadata.page_size
        self.compressor = ZstdCompressor()
        self.decompressor = ZstdDecompressor()
        logger.info(f"Loaded index with {self.num_pages} pages and {self.page_size} page size")
        self.index_file = None
        self.mmap = None

    def __enter__(self):
        # Only open the file for writing when in write mode, so that
        # read-only index files can still be read
        file_mode = 'r+b' if self.mode == 'w' else 'rb'
        self.index_file = open(self.index_path, file_mode)
        prot = PROT_READ if self.mode == 'r' else PROT_READ | PROT_WRITE
        self.mmap = mmap(self.index_file.fileno(), 0, prot=prot)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.mmap.close()
        self.index_file.close()

    def retrieve(self, key: str) -> List[T]:
        index = self.get_key_page_index(key)
        logger.debug(f"Retrieving index {index}")
        return self.get_page(index)

    def get_key_page_index(self, key) -> int:
        key_hash = mmh3.hash(key, signed=False)
        return key_hash % self.num_pages

    def get_page(self, i) -> list[T]:
        """
        Get the page at index i, decompress it and deserialise it using JSON.
        """
        results = self._get_page_tuples(i)
        return [self.item_factory(*item) for item in results]

    def _get_page_tuples(self, i):
        page_data = self.mmap[i * self.page_size + METADATA_SIZE:(i + 1) * self.page_size + METADATA_SIZE]
        try:
            decompressed_data = self.decompressor.decompress(page_data)
        except ZstdError:
            logger.exception(f"Error decompressing page data, content: {page_data}")
            return []
        return json.loads(decompressed_data.decode('utf8'))
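
    # Note: many keys hash to the same page, so retrieve() returns every item
    # stored on that page, not only those for the given key. A sketch with
    # illustrative values:
    #
    #     index.get_key_page_index('python')  # e.g. 42
    #     index.retrieve('python')            # all items stored on page 42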

    def store_in_page(self, page_index: int, values: list[T]):
        value_tuples = [astuple(value) for value in values]
        self._write_page(value_tuples, page_index)

    def _write_page(self, data, i: int):
        """
        Serialise the data using JSON, compress it and store it at page index i.
        If the data is too big, only the first items in the list are stored and the rest are discarded.
        """
        if self.mode != 'w':
            raise UnsupportedOperation("The file is open in read mode, you cannot write")

        page_data = _get_page_data(self.compressor, self.page_size, data)
        logger.debug(f"Got page data of length {len(page_data)}")
        self.mmap[i * self.page_size + METADATA_SIZE:(i + 1) * self.page_size + METADATA_SIZE] = page_data

    @staticmethod
    def create(item_factory: Callable[..., T], index_path: str, num_pages: int, page_size: int):
        if os.path.isfile(index_path):
            raise FileExistsError(f"Index file '{index_path}' already exists")

        metadata = TinyIndexMetadata(VERSION, page_size, num_pages, item_factory.__name__)
        metadata_bytes = metadata.to_bytes()
        metadata_padded = _pad_to_page_size(metadata_bytes, METADATA_SIZE)

        compressor = ZstdCompressor()
        page_bytes = _get_page_data(compressor, page_size, [])

        with open(index_path, 'wb') as index_file:
            index_file.write(metadata_padded)
            for _ in range(num_pages):
                index_file.write(page_bytes)

        return TinyIndex(item_factory, index_path=index_path)
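

# A minimal end-to-end usage sketch (the index path and document values are
# illustrative, not part of this module):
#
#     TinyIndex.create(Document, 'index.tinysearch', num_pages=1024, page_size=PAGE_SIZE)
#     with TinyIndex(Document, 'index.tinysearch', mode='w') as indexer:
#         page = indexer.get_key_page_index('python')
#         indexer.store_in_page(page, [Document('Python', 'https://python.org', 'A language', 1.0)])
#     with TinyIndex(Document, 'index.tinysearch') as searcher:
#         print(searcher.retrieve('python'))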