@@ -10,22 +10,47 @@ class Adapter(AsyncAdapter):
1010
1111 def __init__ (
1212 self ,
13- uri ,
14- dbname ,
13+ uri = None ,
14+ dbname = None ,
1515 collection = "casbin_rule" ,
1616 filtered = False ,
17+ client = None ,
18+ db_name = None ,
1719 ):
1820 """Create an adapter for Mongodb
1921
2022 Args:
21- uri (str): This should be the same requiement as pymongo Client's 'uri' parameter.
23+ uri (str, optional ): This should be the same requiement as pymongo Client's 'uri' parameter.
2224 See https://pymongo.readthedocs.io/en/stable/api/pymongo/mongo_client.html#pymongo.mongo_client.MongoClient.
23- dbname (str): Database to store policy.
25+ Required if client is not provided.
26+ dbname (str, optional): Database to store policy. Required if client is not provided.
2427 collection (str, optional): Collection of the choosen database. Defaults to "casbin_rule".
2528 filtered (bool, optional): Whether to use filtered query. Defaults to False.
29+ client (AsyncMongoClient, optional): An existing AsyncMongoClient instance to reuse. If provided, uri is ignored.
30+ db_name (str, optional): Database name to use with the provided client. Takes precedence over dbname.
31+
32+ Note:
33+ When both client and uri are provided, client takes precedence and uri is ignored.
2634 """
27- client = AsyncMongoClient (uri )
28- db = client [dbname ]
35+ # Support both db_name and dbname for backward compatibility
36+ database_name = db_name if db_name is not None else dbname
37+
38+ if client is not None :
39+ # Use the provided client
40+ if database_name is None :
41+ raise ValueError (
42+ "db_name or dbname must be provided when using an existing client"
43+ )
44+ mongo_client = client
45+ else :
46+ # Create a new client from URI
47+ if uri is None :
48+ raise ValueError ("uri must be provided when client is not specified" )
49+ if database_name is None :
50+ raise ValueError ("dbname must be provided when client is not specified" )
51+ mongo_client = AsyncMongoClient (uri )
52+
53+ db = mongo_client [database_name ]
2954 self ._collection = db [collection ]
3055 self ._filtered = filtered
3156
0 commit comments