from __future__ import annotations
import asyncio
import pickle
from typing import TYPE_CHECKING, Generic, TypeVar
from zarr.storage import WrapperStore
if TYPE_CHECKING:
from typing import Any
from zarr.abc.store import ByteRequest
from zarr.core.buffer.core import BufferPrototype
import pytest
from zarr.abc.store import (
ByteRequest,
OffsetByteRequest,
RangeByteRequest,
Store,
SuffixByteRequest,
)
from zarr.core.buffer import Buffer, default_buffer_prototype
from zarr.core.sync import _collect_aiterator
from zarr.storage._utils import _normalize_byte_range_index
from zarr.testing.utils import assert_bytes_equal
__all__ = ["StoreTests"]
S = TypeVar("S", bound=Store)
B = TypeVar("B", bound=Buffer)
class StoreTests(Generic[S, B]):
    """
    Generic, reusable test suite for ``Store`` implementations.

    Subclasses must set the ``store_cls`` and ``buffer_cls`` class attributes
    and implement :meth:`set` and :meth:`get`, which read/write the backend
    directly so the store methods under test are not used to verify themselves.
    """

    # Concrete subclasses must provide the store class under test and the
    # buffer class used to wrap raw bytes.
    store_cls: type[S]
    buffer_cls: type[B]

    async def set(self, store: S, key: str, value: Buffer) -> None:
        """
        Insert a value into a storage backend, with a specific key.

        This should not use any store methods. Bypassing the store methods
        allows them to be tested.
        """
        raise NotImplementedError

    async def get(self, store: S, key: str) -> Buffer:
        """
        Retrieve a value from a storage backend, by key.

        This should not use any store methods. Bypassing the store methods
        allows them to be tested.
        """
        raise NotImplementedError

    @pytest.fixture
    def store_kwargs(self) -> dict[str, Any]:
        """Keyword arguments used to open the store under test."""
        return {"read_only": False}

    @pytest.fixture
    async def store(self, store_kwargs: dict[str, Any]) -> Store:
        """An open store instance built from ``store_kwargs``."""
        return await self.store_cls.open(**store_kwargs)

    def test_store_type(self, store: S) -> None:
        """The store fixture is an instance of ``Store`` and of ``store_cls``."""
        assert isinstance(store, Store)
        assert isinstance(store, self.store_cls)

    def test_store_eq(self, store: S, store_kwargs: dict[str, Any]) -> None:
        """Stores compare equal to themselves and to stores built with the same inputs."""
        # check self equality
        assert store == store
        # check store equality with same inputs
        # asserting this is important for being able to compare (de)serialized stores
        store2 = self.store_cls(**store_kwargs)
        assert store == store2

    def test_serializable_store(self, store: S) -> None:
        """A store round-trips through pickle and compares equal afterwards."""
        foo = pickle.dumps(store)
        assert pickle.loads(foo) == store

    def test_store_read_only(self, store: S) -> None:
        """``read_only`` reflects the open mode and cannot be reassigned."""
        assert not store.read_only
        with pytest.raises(AttributeError):
            store.read_only = False  # type: ignore[misc]

    @pytest.mark.parametrize("read_only", [True, False])
    async def test_store_open_read_only(
        self, store_kwargs: dict[str, Any], read_only: bool
    ) -> None:
        """Opening with ``read_only`` produces an open store with that flag set."""
        store_kwargs["read_only"] = read_only
        store = await self.store_cls.open(**store_kwargs)
        assert store._is_open
        assert store.read_only == read_only

    async def test_read_only_store_raises(self, store_kwargs: dict[str, Any]) -> None:
        """Mutating operations on a read-only store raise ``ValueError``."""
        kwargs = {**store_kwargs, "read_only": True}
        store = await self.store_cls.open(**kwargs)
        assert store.read_only
        # set
        with pytest.raises(ValueError):
            await store.set("foo", self.buffer_cls.from_bytes(b"bar"))
        # delete
        with pytest.raises(ValueError):
            await store.delete("foo")

    def test_store_repr(self, store: S) -> None:
        # Subclasses must assert on their store's specific repr.
        raise NotImplementedError

    def test_store_supports_writes(self, store: S) -> None:
        # Subclasses must assert on their store's write capability flag.
        raise NotImplementedError

    def test_store_supports_partial_writes(self, store: S) -> None:
        # Subclasses must assert on their store's partial-write capability flag.
        raise NotImplementedError

    def test_store_supports_listing(self, store: S) -> None:
        # Subclasses must assert on their store's listing capability flag.
        raise NotImplementedError

    @pytest.mark.parametrize("key", ["c/0", "foo/c/0.0", "foo/0/0"])
    @pytest.mark.parametrize("data", [b"\x01\x02\x03\x04", b""])
    @pytest.mark.parametrize(
        "byte_range", [None, RangeByteRequest(1, 4), OffsetByteRequest(1), SuffixByteRequest(1)]
    )
    async def test_get(self, store: S, key: str, data: bytes, byte_range: ByteRequest) -> None:
        """
        Ensure that data can be read from the store using the store.get method.
        """
        data_buf = self.buffer_cls.from_bytes(data)
        await self.set(store, key, data_buf)
        observed = await store.get(key, prototype=default_buffer_prototype(), byte_range=byte_range)
        start, stop = _normalize_byte_range_index(data_buf, byte_range=byte_range)
        expected = data_buf[start:stop]
        assert_bytes_equal(observed, expected)

    async def test_get_many(self, store: S) -> None:
        """
        Ensure that multiple keys can be retrieved at once with the _get_many method.
        """
        keys = tuple(map(str, range(10)))
        values = tuple(f"{k}".encode() for k in keys)
        for k, v in zip(keys, values, strict=False):
            await self.set(store, k, self.buffer_cls.from_bytes(v))
        observed_buffers = await _collect_aiterator(
            store._get_many(
                zip(
                    keys,
                    (default_buffer_prototype(),) * len(keys),
                    (None,) * len(keys),
                    strict=False,
                )
            )
        )
        observed_kvs = sorted(((k, b.to_bytes()) for k, b in observed_buffers))  # type: ignore[union-attr]
        expected_kvs = sorted(((k, b) for k, b in zip(keys, values, strict=False)))
        assert observed_kvs == expected_kvs

    @pytest.mark.parametrize("key", ["zarr.json", "c/0", "foo/c/0.0", "foo/0/0"])
    @pytest.mark.parametrize("data", [b"\x01\x02\x03\x04", b""])
    async def test_set(self, store: S, key: str, data: bytes) -> None:
        """
        Ensure that data can be written to the store using the store.set method.
        """
        assert not store.read_only
        data_buf = self.buffer_cls.from_bytes(data)
        await store.set(key, data_buf)
        observed = await self.get(store, key)
        assert_bytes_equal(observed, data_buf)

    async def test_set_many(self, store: S) -> None:
        """
        Test that a dict of key : value pairs can be inserted into the store via the
        `_set_many` method.
        """
        keys = ["zarr.json", "c/0", "foo/c/0.0", "foo/0/0"]
        data_buf = [self.buffer_cls.from_bytes(k.encode()) for k in keys]
        store_dict = dict(zip(keys, data_buf, strict=True))
        await store._set_many(store_dict.items())
        for k, v in store_dict.items():
            assert (await self.get(store, k)).to_bytes() == v.to_bytes()

    @pytest.mark.parametrize(
        "key_ranges",
        [
            [],
            [("zarr.json", RangeByteRequest(0, 2))],
            [("c/0", RangeByteRequest(0, 2)), ("zarr.json", None)],
            [
                ("c/0/0", RangeByteRequest(0, 2)),
                ("c/0/1", SuffixByteRequest(2)),
                ("c/0/2", OffsetByteRequest(2)),
            ],
        ],
    )
    async def test_get_partial_values(
        self, store: S, key_ranges: list[tuple[str, ByteRequest]]
    ) -> None:
        """``get_partial_values`` agrees with ``get`` for each requested byte range."""
        # put all of the data
        for key, _ in key_ranges:
            await self.set(store, key, self.buffer_cls.from_bytes(bytes(key, encoding="utf-8")))
        # read back just part of it
        observed_maybe = await store.get_partial_values(
            prototype=default_buffer_prototype(), key_ranges=key_ranges
        )
        observed: list[Buffer] = []
        expected: list[Buffer] = []
        for obs in observed_maybe:
            assert obs is not None
            observed.append(obs)
        for idx in range(len(observed)):
            key, byte_range = key_ranges[idx]
            result = await store.get(
                key, prototype=default_buffer_prototype(), byte_range=byte_range
            )
            assert result is not None
            expected.append(result)
        assert all(
            obs.to_bytes() == exp.to_bytes() for obs, exp in zip(observed, expected, strict=True)
        )

    async def test_exists(self, store: S) -> None:
        """``exists`` is False for absent keys and True after a set."""
        assert not await store.exists("foo")
        await store.set("foo/zarr.json", self.buffer_cls.from_bytes(b"bar"))
        assert await store.exists("foo/zarr.json")

    async def test_delete(self, store: S) -> None:
        """``delete`` removes a key; skipped for stores without delete support."""
        if not store.supports_deletes:
            pytest.skip("store does not support deletes")
        await store.set("foo/zarr.json", self.buffer_cls.from_bytes(b"bar"))
        assert await store.exists("foo/zarr.json")
        await store.delete("foo/zarr.json")
        assert not await store.exists("foo/zarr.json")

    async def test_delete_dir(self, store: S) -> None:
        """``delete_dir`` removes only keys under the prefix, not sibling prefixes."""
        if not store.supports_deletes:
            pytest.skip("store does not support deletes")
        await store.set("zarr.json", self.buffer_cls.from_bytes(b"root"))
        await store.set("foo-bar/zarr.json", self.buffer_cls.from_bytes(b"root"))
        await store.set("foo/zarr.json", self.buffer_cls.from_bytes(b"bar"))
        await store.set("foo/c/0", self.buffer_cls.from_bytes(b"chunk"))
        await store.delete_dir("foo")
        assert await store.exists("zarr.json")
        # "foo-bar" shares the "foo" string prefix but is a distinct directory
        assert await store.exists("foo-bar/zarr.json")
        assert not await store.exists("foo/zarr.json")
        assert not await store.exists("foo/c/0")

    async def test_is_empty(self, store: S) -> None:
        """``is_empty`` distinguishes real prefixes from string prefixes of keys."""
        assert await store.is_empty("")
        await self.set(
            store, "foo/bar", self.buffer_cls.from_bytes(bytes("something", encoding="utf-8"))
        )
        assert not await store.is_empty("")
        assert await store.is_empty("fo")  # partial component is not a prefix match
        assert not await store.is_empty("foo/")
        assert not await store.is_empty("foo")
        assert await store.is_empty("spam/")

    async def test_clear(self, store: S) -> None:
        """``clear`` leaves the store empty."""
        await self.set(
            store, "key", self.buffer_cls.from_bytes(bytes("something", encoding="utf-8"))
        )
        await store.clear()
        assert await store.is_empty("")

    async def test_list(self, store: S) -> None:
        """``list`` yields every key that was inserted, in any order."""
        assert await _collect_aiterator(store.list()) == ()
        prefix = "foo"
        data = self.buffer_cls.from_bytes(b"")
        store_dict = {
            prefix + "/zarr.json": data,
            **{prefix + f"/c/{idx}": data for idx in range(10)},
        }
        await store._set_many(store_dict.items())
        expected_sorted = sorted(store_dict.keys())
        observed = await _collect_aiterator(store.list())
        observed_sorted = sorted(observed)
        assert observed_sorted == expected_sorted

    async def test_list_prefix(self, store: S) -> None:
        """
        Test that the `list_prefix` method works as intended. Given a prefix, it should return
        all the keys in storage that start with this prefix.
        """
        prefixes = ("", "a/", "a/b/", "a/b/c/")
        data = self.buffer_cls.from_bytes(b"")
        fname = "zarr.json"
        store_dict = {p + fname: data for p in prefixes}
        await store._set_many(store_dict.items())
        for prefix in prefixes:
            observed = tuple(sorted(await _collect_aiterator(store.list_prefix(prefix))))
            expected: tuple[str, ...] = ()
            for key in store_dict:
                if key.startswith(prefix):
                    expected += (key,)
            expected = tuple(sorted(expected))
            assert observed == expected

    async def test_list_dir(self, store: S) -> None:
        """``list_dir`` yields the immediate children of a prefix, with or without a trailing slash."""
        root = "foo"
        store_dict = {
            root + "/zarr.json": self.buffer_cls.from_bytes(b"bar"),
            root + "/c/1": self.buffer_cls.from_bytes(b"\x01"),
        }
        assert await _collect_aiterator(store.list_dir("")) == ()
        assert await _collect_aiterator(store.list_dir(root)) == ()
        await store._set_many(store_dict.items())
        keys_observed = await _collect_aiterator(store.list_dir(root))
        keys_expected = {k.removeprefix(root + "/").split("/")[0] for k in store_dict}
        assert sorted(keys_observed) == sorted(keys_expected)
        keys_observed = await _collect_aiterator(store.list_dir(root + "/"))
        assert sorted(keys_expected) == sorted(keys_observed)

    async def test_set_if_not_exists(self, store: S) -> None:
        """``set_if_not_exists`` preserves existing values and writes new keys."""
        key = "k"
        data_buf = self.buffer_cls.from_bytes(b"0000")
        await self.set(store, key, data_buf)
        new = self.buffer_cls.from_bytes(b"1111")
        await store.set_if_not_exists("k", new)  # no error
        result = await store.get(key, default_buffer_prototype())
        assert result == data_buf
        await store.set_if_not_exists("k2", new)  # no error
        result = await store.get("k2", default_buffer_prototype())
        assert result == new
class LatencyStore(WrapperStore[Store]):
    """
    Store wrapper that injects artificial latency into ``get`` and ``set``.

    Wraps an existing store instance and sleeps for a configurable duration
    before delegating each call, which is useful for performance testing.
    """

    # Seconds slept before delegating to the wrapped store's method.
    get_latency: float
    set_latency: float

    def __init__(self, cls: Store, *, get_latency: float = 0, set_latency: float = 0) -> None:
        self._store = cls
        self.set_latency = float(set_latency)
        self.get_latency = float(get_latency)

    async def set(self, key: str, value: Buffer) -> None:
        """
        Add latency to the ``set`` method.

        Calls ``asyncio.sleep(self.set_latency)`` before invoking the wrapped ``set`` method.

        Parameters
        ----------
        key : str
            The key to set
        value : Buffer
            The value to set

        Returns
        -------
        None
        """
        await asyncio.sleep(self.set_latency)
        await self._store.set(key, value)

    async def get(
        self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None
    ) -> Buffer | None:
        """
        Add latency to the ``get`` method.

        Calls ``asyncio.sleep(self.get_latency)`` before invoking the wrapped ``get`` method.

        Parameters
        ----------
        key : str
            The key to get
        prototype : BufferPrototype
            The BufferPrototype to use.
        byte_range : ByteRequest, optional
            An optional byte range.

        Returns
        -------
        buffer : Buffer or None
        """
        await asyncio.sleep(self.get_latency)
        result = await self._store.get(key, prototype=prototype, byte_range=byte_range)
        return result