mwmbl/mwmbl/indexer/batch_cache.py

"""
Store for local batches.

We store them in a directory on the local machine.
"""
import gzip
import json
import os
from logging import getLogger
from multiprocessing.pool import ThreadPool
from pathlib import Path
from urllib.parse import urlparse

from pydantic import ValidationError

from mwmbl.crawler.batch import HashedBatch
from mwmbl.database import Database
from mwmbl.indexer.indexdb import IndexDatabase, BatchStatus
from mwmbl.retry import retry_requests

logger = getLogger(__name__)


class BatchCache:
    # Number of concurrent downloads used by retrieve_batches.
    num_threads = 20

    def __init__(self, repo_path):
        os.makedirs(repo_path, exist_ok=True)
        self.path = repo_path
    def get_cached(self, batch_urls: list[str]) -> dict[str, HashedBatch]:
        """Load previously fetched batches from the local cache, skipping
        any that are missing on disk or fail validation."""
        batches = {}
        for url in batch_urls:
            path = self.get_path_from_url(url)
            try:
                with gzip.open(path) as gzip_file:
                    data = gzip_file.read()
            except FileNotFoundError:
                logger.exception(f"Missing batch file: {path}")
                continue
            try:
                batch = HashedBatch.parse_raw(data)
            except ValidationError:
                logger.exception(f"Unable to parse batch, skipping: '{data}'")
                continue
            batches[url] = batch
        return batches
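
    # Flow of retrieve_batches below: ensure the index tables exist, fetch
    # the URLs of batches still marked REMOTE, download them concurrently
    # with retrieve_batch, then mark them LOCAL so they are not re-fetched.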
    def retrieve_batches(self, num_batches):
        with Database() as db:
            index_db = IndexDatabase(db.connection)
            index_db.create_tables()

        with Database() as db:
            index_db = IndexDatabase(db.connection)
            batches = index_db.get_batches_by_status(BatchStatus.REMOTE, num_batches)
            logger.info(f"Found {len(batches)} remote batches")
            if len(batches) == 0:
                return
            urls = [batch.url for batch in batches]
            with ThreadPool(self.num_threads) as pool:
                results = pool.imap_unordered(self.retrieve_batch, urls)
                total_processed = sum(results)
            logger.info(f"Processed batches with {total_processed} items")
            index_db.update_batch_status(urls, BatchStatus.LOCAL)

    def retrieve_batch(self, url):
        """Download a single batch, store it locally if it has any items,
        and return the item count (0 if the batch fails validation)."""
        data = json.loads(gzip.decompress(retry_requests.get(url).content))
        try:
            batch = HashedBatch.parse_obj(data)
        except ValidationError:
            logger.info(f"Failed to validate batch {data}")
            return 0
        if len(batch.items) > 0:
            self.store(batch, url)
        return len(batch.items)

    def store(self, batch, url):
        path = self.get_path_from_url(url)
        logger.debug(f"Storing local batch at {path}")
        os.makedirs(path.parent, exist_ok=True)
        with open(path, 'wb') as output_file:
            data = gzip.compress(batch.json().encode('utf8'))
            output_file.write(data)

    def get_path_from_url(self, url) -> Path:
        # Mirror the URL's path under the cache root, e.g. a batch at
        # https://example.com/batches/abc.json.gz is cached at
        # <repo_path>/batches/abc.json.gz.
        url_path = urlparse(url).path
        return Path(self.path) / url_path.lstrip('/')
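

# A minimal usage sketch (assumes a configured mwmbl Database and reachable
# batch URLs; the repo path and batch URL here are hypothetical placeholders):
if __name__ == "__main__":
    cache = BatchCache("./batch-cache")

    # Download up to 100 batches currently marked REMOTE, caching them on
    # disk and flipping their index status to LOCAL.
    cache.retrieve_batches(100)

    # Re-load a previously fetched batch from the on-disk cache by its URL.
    for url, batch in cache.get_cached(["https://example.com/batches/abc.json.gz"]).items():
        print(url, len(batch.items))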