add clear method

This commit is contained in:
long2ice
2020-11-03 18:08:06 +08:00
parent dc2ac9cc90
commit e483e0dc55
6 changed files with 40 additions and 2 deletions

View File

@@ -12,7 +12,7 @@ app = FastAPI()
ret = 0 ret = 0
@cache(expire=1) @cache(namespace="test", expire=1)
async def get_ret(): async def get_ret():
global ret global ret
ret = ret + 1 ret = ret + 1
@@ -20,11 +20,16 @@ async def get_ret():
@app.get("/") @app.get("/")
@cache(expire=2) @cache(namespace="test", expire=2)
async def index(request: Request, response: Response): async def index(request: Request, response: Response):
return dict(ret=await get_ret()) return dict(ret=await get_ret())
@app.get("/clear")
async def clear():
return await FastAPICache.clear(namespace="test")
@app.on_event("startup") @app.on_event("startup")
async def startup(): async def startup():
FastAPICache.init(InMemoryBackend(), prefix="fastapi-cache") FastAPICache.init(InMemoryBackend(), prefix="fastapi-cache")

View File

@@ -50,3 +50,8 @@ class FastAPICache:
@classmethod @classmethod
def get_key_builder(cls): def get_key_builder(cls):
return cls._key_builder return cls._key_builder
@classmethod
async def clear(cls, namespace: str = None, key: str = None):
namespace = cls._prefix + ":" + namespace if namespace else None
return await cls._backend.clear(namespace, key)

View File

@@ -14,3 +14,7 @@ class Backend:
@abc.abstractmethod @abc.abstractmethod
async def set(self, key: str, value: str, expire: int = None): async def set(self, key: str, value: str, expire: int = None):
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod
async def clear(self, namespace: str = None, key: str = None) -> int:
raise NotImplementedError

View File

@@ -1,4 +1,5 @@
import time import time
from copy import copy
from dataclasses import dataclass from dataclasses import dataclass
from threading import Lock from threading import Lock
from typing import Dict, Optional, Tuple from typing import Dict, Optional, Tuple
@@ -44,3 +45,16 @@ class InMemoryBackend(Backend):
async def set(self, key: str, value: str, expire: int = None): async def set(self, key: str, value: str, expire: int = None):
with self._lock: with self._lock:
self._store[key] = Value(value, self._now + expire) self._store[key] = Value(value, self._now + expire)
async def clear(self, namespace: str = None, key: str = None) -> int:
count = 0
if namespace:
keys = list(self._store.keys())
for key in keys:
if key.startswith(namespace):
del self._store[key]
count += 1
elif key:
del self._store[key]
count += 1
return count

View File

@@ -17,3 +17,6 @@ class MemcacheBackend(Backend):
async def set(self, key: str, value: str, expire: int = None): async def set(self, key: str, value: str, expire: int = None):
return await self.mcache.set(key.encode(), value.encode(), exptime=expire) return await self.mcache.set(key.encode(), value.encode(), exptime=expire)
async def clear(self, namespace: str = None, key: str = None):
raise NotImplementedError

View File

@@ -20,3 +20,10 @@ class RedisBackend(Backend):
async def set(self, key: str, value: str, expire: int = None): async def set(self, key: str, value: str, expire: int = None):
return await self.redis.set(key, value, expire=expire) return await self.redis.set(key, value, expire=expire)
async def clear(self, namespace: str = None, key: str = None) -> int:
if namespace:
lua = f"for i, name in ipairs(redis.call('KEYS', '{namespace}:*')) do redis.call('DEL', name); end"
return await self.redis.eval(lua)
elif key:
return await self.redis.delete(key)