Skip to content

Commit 81557ff

Browse files
authored
feat: code refactoring and bug fixing (#29)
* fix: typo in docstring * refactor: deletion of an unnecessary variable * refactor: DRY principle and correction E712 * refactor: Adding a check for the presence of created policies and a small code reduction * fix: correcting the variable - rows_created
1 parent 395e750 commit 81557ff

3 files changed

Lines changed: 16 additions & 26 deletions

File tree

casbin_adapter/adapter.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ def save_policy(self, model):
5353
for ptype, ast in model.model[sec].items():
5454
for rule in ast.policy:
5555
lines.append(self._create_policy_line(ptype, rule))
56-
CasbinRule.objects.using(self.db_alias).bulk_create(lines)
57-
return True
56+
rows_created = CasbinRule.objects.using(self.db_alias).bulk_create(lines)
57+
return len(rows_created) > 0
5858

5959
def add_policy(self, sec, ptype, rule):
6060
"""adds a policy rule to the storage."""
@@ -67,7 +67,7 @@ def remove_policy(self, sec, ptype, rule):
6767
for i, v in enumerate(rule):
6868
query_params["v{}".format(i)] = v
6969
rows_deleted, _ = CasbinRule.objects.using(self.db_alias).filter(**query_params).delete()
70-
return True if rows_deleted > 0 else False
70+
return rows_deleted > 0
7171

7272
def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
7373
"""removes policy rules that match the filter from the storage.
@@ -81,4 +81,4 @@ def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
8181
for i, v in enumerate(field_values):
8282
query_params["v{}".format(i + field_index)] = v
8383
rows_deleted, _ = CasbinRule.objects.using(self.db_alias).filter(**query_params).delete()
84-
return True if rows_deleted > 0 else False
84+
return rows_deleted > 0

casbin_adapter/enforcer.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def __init__(self, *args, **kwargs):
2121
logger.info("Deferring casbin enforcer initialisation until django is ready")
2222

2323
def _load(self):
24-
if self._initialized == False:
24+
if self._initialized is False:
2525
logger.info("Performing deferred casbin enforcer initialisation")
2626
self._initialized = True
2727
model = getattr(settings, "CASBIN_MODEL")
@@ -63,24 +63,15 @@ def __getattribute__(self, name):
6363
def initialize_enforcer(db_alias=None):
6464
try:
6565
row = None
66-
if db_alias:
67-
with connections[db_alias].cursor() as cursor:
68-
cursor.execute(
69-
"""
70-
SELECT app, name applied FROM django_migrations
71-
WHERE app = 'casbin_adapter' AND name = '0001_initial';
72-
"""
73-
)
74-
row = cursor.fetchone()
75-
else:
76-
with connection.cursor() as cursor:
77-
cursor.execute(
78-
"""
79-
SELECT app, name applied FROM django_migrations
80-
WHERE app = 'casbin_adapter' AND name = '0001_initial';
81-
"""
82-
)
83-
row = cursor.fetchone()
66+
connect = connections[db_alias] if db_alias else connection
67+
with connect.cursor() as cursor:
68+
cursor.execute(
69+
"""
70+
SELECT app, name applied FROM django_migrations
71+
WHERE app = 'casbin_adapter' AND name = '0001_initial';
72+
"""
73+
)
74+
row = cursor.fetchone()
8475

8576
if row:
8677
enforcer._load()

casbin_adapter/utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,9 @@
33

44
def import_class(name):
55
"""Import class from string
6-
e.g. `package.module.ClassToImport` returns the `ClasToImport` class"""
6+
e.g. `package.module.ClassToImport` returns the `ClassToImport` class"""
77
components = name.split(".")
88
module_name = ".".join(components[:-1])
99
class_name = components[-1]
1010
module = importlib.import_module(module_name)
11-
class_ = getattr(module, class_name)
12-
return class_
11+
return getattr(module, class_name)

0 commit comments

Comments
 (0)