diff --git a/examples/main.py b/examples/main.py index 4b1a9cc..3f53a02 100644 --- a/examples/main.py +++ b/examples/main.py @@ -12,7 +12,7 @@ app = FastAPI() ret = 0 -@cache(expire=1) +@cache(namespace="test", expire=1) async def get_ret(): global ret ret = ret + 1 @@ -20,11 +20,16 @@ async def get_ret(): @app.get("/") -@cache(expire=2) +@cache(namespace="test", expire=2) async def index(request: Request, response: Response): return dict(ret=await get_ret()) +@app.get("/clear") +async def clear(): + return await FastAPICache.clear(namespace="test") + + @app.on_event("startup") async def startup(): FastAPICache.init(InMemoryBackend(), prefix="fastapi-cache") diff --git a/fastapi_cache/__init__.py b/fastapi_cache/__init__.py index 1f050de..079aeb4 100644 --- a/fastapi_cache/__init__.py +++ b/fastapi_cache/__init__.py @@ -50,3 +50,8 @@ class FastAPICache: @classmethod def get_key_builder(cls): return cls._key_builder + + @classmethod + async def clear(cls, namespace: str = None, key: str = None): + namespace = cls._prefix + ":" + namespace if namespace else None + return await cls._backend.clear(namespace, key) diff --git a/fastapi_cache/backends/__init__.py b/fastapi_cache/backends/__init__.py index 0586110..8adaa31 100644 --- a/fastapi_cache/backends/__init__.py +++ b/fastapi_cache/backends/__init__.py @@ -14,3 +14,7 @@ class Backend: @abc.abstractmethod async def set(self, key: str, value: str, expire: int = None): raise NotImplementedError + + @abc.abstractmethod + async def clear(self, namespace: str = None, key: str = None) -> int: + raise NotImplementedError diff --git a/fastapi_cache/backends/inmemory.py b/fastapi_cache/backends/inmemory.py index 3cac3d4..c906609 100644 --- a/fastapi_cache/backends/inmemory.py +++ b/fastapi_cache/backends/inmemory.py @@ -1,4 +1,5 @@ import time +from copy import copy from dataclasses import dataclass from threading import Lock from typing import Dict, Optional, Tuple @@ -44,3 +45,16 @@ class InMemoryBackend(Backend): async def set(self, key: str, value: str, expire: int = None): with self._lock: self._store[key] = Value(value, self._now + expire) + + async def clear(self, namespace: str = None, key: str = None) -> int: + count = 0 + if namespace: + keys = list(self._store.keys()) + for key in keys: + if key.startswith(namespace): + del self._store[key] + count += 1 + elif key: + del self._store[key] + count += 1 + return count diff --git a/fastapi_cache/backends/mencache.py b/fastapi_cache/backends/mencache.py index fe79027..2e8d78d 100644 --- a/fastapi_cache/backends/mencache.py +++ b/fastapi_cache/backends/mencache.py @@ -17,3 +17,6 @@ class MemcacheBackend(Backend): async def set(self, key: str, value: str, expire: int = None): return await self.mcache.set(key.encode(), value.encode(), exptime=expire) + + async def clear(self, namespace: str = None, key: str = None): + raise NotImplementedError diff --git a/fastapi_cache/backends/redis.py b/fastapi_cache/backends/redis.py index 7188500..be67d3f 100644 --- a/fastapi_cache/backends/redis.py +++ b/fastapi_cache/backends/redis.py @@ -20,3 +20,10 @@ class RedisBackend(Backend): async def set(self, key: str, value: str, expire: int = None): return await self.redis.set(key, value, expire=expire) + + async def clear(self, namespace: str = None, key: str = None) -> int: + if namespace: + lua = f"for i, name in ipairs(redis.call('KEYS', '{namespace}:*')) do redis.call('DEL', name); end" + return await self.redis.eval(lua) + elif key: + return await self.redis.delete(key)