Skip to content

Commit 5a00fbf

Browse files
authored
Merge pull request #51 from Zxilly/master
fix: fix db_class not vaild
2 parents d834662 + 9de8561 commit 5a00fbf

2 files changed

Lines changed: 37 additions & 5 deletions

File tree

casbin_sqlalchemy_adapter/adapter.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,12 @@ def __init__(self, engine, db_class=None, filtered=False):
5454

5555
if db_class is None:
5656
db_class = CasbinRule
57+
else:
58+
for attr in ("ptype", "v0", "v1", "v2", "v3", "v4", "v5"):
59+
if not hasattr(db_class, attr):
60+
raise Exception(f"{attr} not found in custom DatabaseClass.")
61+
Base.metadata = db_class.metadata
62+
5763
self._db_class = db_class
5864
self.session_local = sessionmaker(bind=self._engine)
5965

tests/test_adapter.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1+
import os
2+
from unittest import TestCase
3+
4+
import casbin
5+
from sqlalchemy import create_engine, Column, Integer, String
6+
from sqlalchemy.orm import sessionmaker
7+
18
from casbin_sqlalchemy_adapter import Adapter
29
from casbin_sqlalchemy_adapter import Base
310
from casbin_sqlalchemy_adapter import CasbinRule
411
from casbin_sqlalchemy_adapter.adapter import Filter
5-
from sqlalchemy import create_engine
6-
from sqlalchemy.orm import sessionmaker
7-
from unittest import TestCase
8-
import casbin
9-
import os
1012

1113

1214
def get_fixture(path):
@@ -35,6 +37,30 @@ def get_enforcer():
3537

3638

3739
class TestConfig(TestCase):
40+
def test_custom_db_class(self):
41+
class CustomRule(Base):
42+
__tablename__ = "casbin_rule2"
43+
44+
id = Column(Integer, primary_key=True)
45+
ptype = Column(String(255))
46+
v0 = Column(String(255))
47+
v1 = Column(String(255))
48+
v2 = Column(String(255))
49+
v3 = Column(String(255))
50+
v4 = Column(String(255))
51+
v5 = Column(String(255))
52+
not_exist = Column(String(255))
53+
54+
engine = create_engine("sqlite://")
55+
adapter = Adapter(engine, CustomRule)
56+
57+
session = sessionmaker(bind=engine)
58+
Base.metadata.create_all(engine)
59+
s = session()
60+
s.add(CustomRule(not_exist="NotNone"))
61+
s.commit()
62+
self.assertEqual(s.query(CustomRule).all()[0].not_exist, "NotNone")
63+
3864
def test_enforcer_basic(self):
3965
e = get_enforcer()
4066

0 commit comments

Comments
 (0)