Skip to content

Commit 9e8da7d

Browse files
authored
Handle Class Constructor with arguments using __init__ method (#2775)
1 parent f4c4b94 commit 9e8da7d

4 files changed

Lines changed: 282 additions & 61 deletions

File tree

integration_tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -834,6 +834,7 @@ RUN(NAME lambda_01 LABELS cpython llvm llvm_jit)
834834

835835
RUN(NAME c_mangling LABELS cpython llvm llvm_jit c)
836836
RUN(NAME class_01 LABELS cpython llvm llvm_jit)
837+
RUN(NAME class_02 LABELS cpython llvm llvm_jit)
837838

838839
# callback_04 is to test emulation. So just run with cpython
839840
RUN(NAME callback_04 IMPORT_PATH .. LABELS cpython)

integration_tests/class_02.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from lpython import i32
2+
class Character:
3+
def __init__(self:"Character", name:str, health:i32, attack_power:i32):
4+
self.name :str = name
5+
self.health :i32 = health
6+
self.attack_power : i32 = attack_power
7+
self.is_immortal : bool = False
8+
9+
def attack(self:"Character", other:"Character") -> str:
10+
other.health -= self.attack_power
11+
return self.name+" attacks "+ other.name+" for "+str(self.attack_power)+" damage."
12+
13+
def is_alive(self:"Character")->bool:
14+
if self.is_immortal:
15+
return True
16+
else:
17+
return self.health > 0
18+
19+
def main():
20+
hero : Character = Character("Hero", 10, 20)
21+
monster : Character = Character("Monster", 50, 15)
22+
print(hero.attack(monster))
23+
print(monster.health)
24+
assert monster.health == 30
25+
print(monster.is_alive())
26+
assert monster.is_alive() == True
27+
print("Hero gains temporary immortality")
28+
hero.is_immortal = True
29+
print(monster.attack(hero))
30+
print(hero.health)
31+
assert hero. health == -5
32+
print(hero.is_alive())
33+
assert hero.is_alive() == True
34+
print("Hero's immortality runs out")
35+
hero.is_immortal = False
36+
print(hero.is_alive())
37+
assert hero.is_alive() == False
38+
print("Restarting")
39+
hero = Character("Hero", 10, 20)
40+
print(hero.is_alive())
41+
assert hero.is_alive() == True
42+
43+
main()
44+

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3087,6 +3087,30 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
30873087
}
30883088
}
30893089

3090+
void instantiate_methods(const ASR::Struct_t &x) {
3091+
SymbolTable *current_scope_copy = current_scope;
3092+
current_scope = x.m_symtab;
3093+
for ( auto &item : x.m_symtab->get_scope() ) {
3094+
if ( is_a<ASR::Function_t>(*item.second) ) {
3095+
ASR::Function_t *v = down_cast<ASR::Function_t>(item.second);
3096+
instantiate_function(*v);
3097+
}
3098+
}
3099+
current_scope = current_scope_copy;
3100+
}
3101+
3102+
void visit_methods (const ASR::Struct_t &x) {
3103+
SymbolTable *current_scope_copy = current_scope;
3104+
current_scope = x.m_symtab;
3105+
for ( auto &item : x.m_symtab->get_scope() ) {
3106+
if ( is_a<ASR::Function_t>(*item.second) ) {
3107+
ASR::Function_t *v = down_cast<ASR::Function_t>(item.second);
3108+
visit_Function(*v);
3109+
}
3110+
}
3111+
current_scope = current_scope_copy;
3112+
}
3113+
30903114
void start_module_init_function_prototype(const ASR::Module_t &x) {
30913115
uint32_t h = get_hash((ASR::asr_t*)&x);
30923116
llvm::FunctionType *function_type = llvm::FunctionType::get(
@@ -3128,6 +3152,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
31283152
} else if (is_a<ASR::EnumType_t>(*item.second)) {
31293153
ASR::EnumType_t *et = down_cast<ASR::EnumType_t>(item.second);
31303154
visit_EnumType(*et);
3155+
} else if (is_a<ASR::Struct_t>(*item.second)) {
3156+
ASR::Struct_t *st = down_cast<ASR::Struct_t>(item.second);
3157+
instantiate_methods(*st);
31313158
}
31323159
}
31333160
finish_module_init_function_prototype(x);
@@ -4179,6 +4206,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
41794206
if (is_a<ASR::Function_t>(*item.second)) {
41804207
ASR::Function_t *s = ASR::down_cast<ASR::Function_t>(item.second);
41814208
visit_Function(*s);
4209+
} else if ( is_a<ASR::Struct_t>(*item.second) ) {
4210+
ASR::Struct_t *st = down_cast<ASR::Struct_t>(item.second);
4211+
visit_methods(*st);
41824212
}
41834213
}
41844214
}

0 commit comments

Comments
 (0)