Skip to content

Commit 5c61fb5

Browse files
authored
Enhance function call handling to support nested functions (#245)
* Enhance function call handling to support nested functions, improving modularity and reusability in QASM programs. Update tests to validate new functionality. * update changelog * restructure changelog
1 parent e79c20f commit 5c61fb5

5 files changed

Lines changed: 117 additions & 66 deletions

File tree

CHANGELOG.md

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,27 @@ Types of changes:
1818
- A new discussion template for issues in pyqasm ([#213](https://github.com/qBraid/pyqasm/pull/213))
1919
- A github workflow for validating `CHANGELOG` updates in a PR ([#214](https://github.com/qBraid/pyqasm/pull/214))
2020
- Added `unroll` command support in PYQASM CLI with options skipping files, overwriting originals files, and specifying output paths.([#224](https://github.com/qBraid/pyqasm/pull/224))
21+
- Added `Duration`,`Stretch` type, `Delay` and `Box` support for `OPENQASM3` code in pyqasm. ([#231](https://github.com/qBraid/pyqasm/pull/231))
22+
###### Example:
23+
```qasm
24+
OPENQASM 3.0;
25+
include "stdgates.inc";
26+
qubit[3] q;
27+
duration t1 = 200dt;
28+
duration t2 = 300ns;
29+
stretch s1;
30+
delay[t1] q[0];
31+
delay[t2] q[1];
32+
delay[s1] q[0], q[2];
33+
box [t2] {
34+
h q[0];
35+
cx q[0], q[1];
36+
delay[100ns] q[2];
37+
}
38+
```
39+
- Added a new `QasmModule.compare` method to compare two QASM modules, providing a detailed report of differences in gates, qubits, and measurements. This method is useful for comparing two identifying differences in QASM programs, their structure and operations. ([#233](https://github.com/qBraid/pyqasm/pull/233))
2140
- Added `.github/copilot-instructions.md` to the repository to document coding standards and design principles for pyqasm. This file provides detailed guidance on documentation, static typing, formatting, error handling, and adherence to the QASM specification for all code contributions. ([#234](https://github.com/qBraid/pyqasm/pull/234))
41+
- Added support for custom include statements in `OPENQASM3` code in pyqasm. This allows users to include custom files or libraries in their QASM programs, enhancing modularity and reusability of code. ([#236](https://github.com/qBraid/pyqasm/pull/236))
2242
- Added support for `Angle`,`extern` and `Complex` type in `OPENQASM3` code in pyqasm. ([#239](https://github.com/qBraid/pyqasm/pull/239))
2343
###### Example:
2444
```qasm
@@ -43,31 +63,13 @@ Types of changes:
4363
extern func6(bit[4]) -> bit[4];
4464
bit[4] be1 = func6(bd);
4565
```
46-
- Added a new `QasmModule.compare` method to compare two QASM modules, providing a detailed report of differences in gates, qubits, and measurements. This method is useful for comparing two identifying differences in QASM programs, their structure and operations. ([#233](https://github.com/qBraid/pyqasm/pull/233))
4766

4867
### Improved / Modified
4968
- Added `slots=True` parameter to the data classes in `elements.py` to improve memory efficiency ([#218](https://github.com/qBraid/pyqasm/pull/218))
5069
- Updated the documentation to include core features in the `README` ([#219](https://github.com/qBraid/pyqasm/pull/219))
5170
- Added support to `device qubit` resgister consolidation.([#222](https://github.com/qBraid/pyqasm/pull/222))
5271
- Updated the scoping of variables in `QasmVisitor` using a `ScopeManager`. This change is introduced to ensure that the `QasmVisitor` and the `PulseVisitor` can share the same `ScopeManager` instance, allowing for consistent variable scoping across different visitors. No change in the user API is expected. ([#232](https://github.com/qBraid/pyqasm/pull/232))
53-
- Added `Duration`,`Stretch` type, `Delay` and `Box` support for `OPENQASM3` code in pyqasm. ([#231](https://github.com/qBraid/pyqasm/pull/231))
54-
###### Example:
55-
```qasm
56-
OPENQASM 3.0;
57-
include "stdgates.inc";
58-
qubit[3] q;
59-
duration t1 = 200dt;
60-
duration t2 = 300ns;
61-
stretch s1;
62-
delay[t1] q[0];
63-
delay[t2] q[1];
64-
delay[s1] q[0], q[2];
65-
box [t2] {
66-
h q[0];
67-
cx q[0], q[1];
68-
delay[100ns] q[2];
69-
}
70-
```
72+
- Enhance function call handling by adding support for nested functions. This change allows for more complex function definitions and calls, enabling better modularity and reusability of code within QASM programs. ([#245](https://github.com/qBraid/pyqasm/pull/245))
7173

7274
### Deprecated
7375

@@ -78,10 +80,12 @@ Types of changes:
7880
- Fixed depth calculation for decomposable gates by computing depth of each constituent quantum gate.([#211](https://github.com/qBraid/pyqasm/pull/211))
7981
- Optimized statement copying in `_visit_function_call` with shallow-copy fallback to deepcopy and added `max_loop_iters` loop‐limit check in for loops.([#223](https://github.com/qBraid/pyqasm/pull/223))
8082

83+
8184
### Dependencies
8285
- Add `pillow<11.3.0` dependency for test and visualization to avoid CI errors in Linux builds ([#226](https://github.com/qBraid/pyqasm/pull/226))
8386
- Added `tabulate` to the testing dependencies to support new comparison table tests. ([#216](https://github.com/qBraid/pyqasm/pull/216))
84-
87+
- Update `docutils` requirement from <0.22 to <0.23 ([#241](https://github.com/qBraid/pyqasm/pull/241))
88+
- Bumps `actions/download-artifact` version from 4 to 5 ([#243](https://github.com/qBraid/pyqasm/pull/243))
8589
### Other
8690

8791
## Past Release Notes

src/pyqasm/subroutines.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,7 @@ def process_quantum_arg( # pylint: disable=too-many-locals
456456
cls,
457457
formal_arg,
458458
actual_arg,
459+
actual_qreg_size_map,
459460
formal_qreg_size_map,
460461
duplicate_qubit_map,
461462
qubit_transform_map,
@@ -468,6 +469,7 @@ def process_quantum_arg( # pylint: disable=too-many-locals
468469
Args:
469470
formal_arg (Qasm3Expression): The formal argument in the function signature.
470471
actual_arg (Qasm3Expression): The actual argument passed to the function.
472+
actual_qreg_size_map (dict): The map of actual quantum register sizes.
471473
formal_qreg_size_map (dict): The map of formal quantum register sizes.
472474
duplicate_qubit_map (dict): The map of duplicate qubit registers.
473475
qubit_transform_map (dict): The map of qubit register transformations.
@@ -504,7 +506,9 @@ def process_quantum_arg( # pylint: disable=too-many-locals
504506
# we expect that actual arg is qubit type only
505507
# note that we ONLY check in global scope as
506508
# we always map the qubit arguments to the global scope
507-
if actual_arg_name not in cls.visitor_obj._global_qreg_size_map:
509+
actual_arg_var = cls.visitor_obj._scope_manager.get_from_visible_scope(actual_arg_name)
510+
511+
if actual_arg_var is None or not actual_arg_var.is_qubit:
508512
# Check if the actual argument is a qubit register
509513
is_literal = actual_arg_name is None
510514
arg_desc = (
@@ -522,7 +526,7 @@ def process_quantum_arg( # pylint: disable=too-many-locals
522526
)
523527

524528
actual_qids, actual_qubits_size = Qasm3Transformer.get_target_qubits(
525-
actual_arg, cls.visitor_obj._global_qreg_size_map, actual_arg_name
529+
actual_arg, actual_qreg_size_map, actual_arg_name
526530
)
527531

528532
if formal_qubit_size != actual_qubits_size:

src/pyqasm/transformer.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,8 @@ def get_branch_params(
343343
def transform_function_qubits(
344344
cls,
345345
q_op: QuantumGate | QuantumBarrier | QuantumReset | QuantumPhase,
346-
qubit_map: dict[tuple, tuple],
346+
qubit_transform_map: dict[tuple, tuple],
347+
qubit_sizes: dict[str, int],
347348
) -> list[IndexedIdentifier]:
348349
"""Transform the qubits of a function call to the actual qubits.
349350
@@ -356,15 +357,17 @@ def transform_function_qubits(
356357
Returns:
357358
None
358359
"""
359-
expanded_op_qubits = cls.visitor_obj._get_op_bits(q_op)
360+
expanded_op_qubits = cls.visitor_obj._get_op_bits(q_op, function_qubit_sizes=qubit_sizes)
360361

361362
transformed_qubits = []
362363
for qubit in expanded_op_qubits:
363364
formal_qreg_name = qubit.name.name
364365
formal_qreg_idx = qubit.indices[0][0].value
365366

366367
# replace the formal qubit with the actual qubit
367-
actual_qreg_name, actual_qreg_idx = qubit_map[(formal_qreg_name, formal_qreg_idx)]
368+
actual_qreg_name, actual_qreg_idx = qubit_transform_map[
369+
(formal_qreg_name, formal_qreg_idx)
370+
]
368371
transformed_qubits.append(
369372
IndexedIdentifier(
370373
Identifier(actual_qreg_name),

src/pyqasm/visitor.py

Lines changed: 64 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,10 @@ def _visit_quantum_register(
218218

219219
# pylint: disable-next=too-many-locals,too-many-branches
220220
def _get_op_bits(
221-
self, operation: Any, qubits: bool = True
221+
self,
222+
operation: Any,
223+
qubits: bool = True,
224+
function_qubit_sizes: Optional[dict[str, int]] = None,
222225
) -> list[qasm3_ast.IndexedIdentifier]:
223226
"""Get the quantum / classical bits for the operation.
224227
@@ -258,19 +261,28 @@ def _get_op_bits(
258261
else:
259262
reg_name = bit.name
260263

264+
max_register_size = 0
261265
reg_var = self._scope_manager.get_from_visible_scope(reg_name)
262266
if reg_var is None:
263-
err_msg = (
264-
f"Missing {'qubit' if qubits else 'clbit'} register declaration "
265-
f"for '{reg_name}' in {type(operation).__name__}"
266-
)
267-
raise_qasm3_error(
268-
err_msg,
269-
error_node=operation,
270-
span=operation.span,
271-
)
272-
assert isinstance(reg_var, Variable)
273-
max_register_size = reg_var.base_size
267+
if function_qubit_sizes is None:
268+
err_msg = (
269+
f"Missing {'qubit' if qubits else 'clbit'} register declaration "
270+
f"for '{reg_name}' in {type(operation).__name__}"
271+
)
272+
raise_qasm3_error(
273+
err_msg,
274+
error_node=operation,
275+
span=operation.span,
276+
)
277+
# we are trying to replace the qubits inside a nested function
278+
assert function_qubit_sizes is not None
279+
reg_size = function_qubit_sizes.get(reg_name, None)
280+
if reg_size is not None:
281+
max_register_size = reg_size
282+
283+
if reg_var:
284+
assert isinstance(reg_var, Variable)
285+
max_register_size = reg_var.base_size
274286

275287
if isinstance(bit, qasm3_ast.IndexedIdentifier):
276288
if isinstance(bit.indices[0], qasm3_ast.DiscreteSet):
@@ -293,7 +305,7 @@ def _get_op_bits(
293305
else:
294306
bit_ids = list(range(max_register_size))
295307

296-
if reg_var.is_alias:
308+
if reg_var and reg_var.is_alias:
297309
original_reg_name, _ = self._alias_qubit_labels[(reg_name, bit_ids[0])]
298310
bit_ids = [
299311
self._alias_qubit_labels[(reg_name, bit_id)][1] # gives (original_reg, index)
@@ -596,14 +608,19 @@ def _visit_reset(self, statement: qasm3_ast.QuantumReset) -> list[qasm3_ast.Quan
596608
"""
597609
logger.debug("Visiting reset statement '%s'", str(statement))
598610
if len(self._function_qreg_size_map) > 0: # atleast in SOME function scope
599-
# transform qubits to use the global qreg identifiers
600-
statement.qubits = (
601-
Qasm3Transformer.transform_function_qubits( # type: ignore[assignment]
602-
statement,
603-
self._function_qreg_transform_map[-1],
611+
# since we may have multiple function scopes, we need to transform the qubits
612+
# to use the global qreg identifiers
613+
for transform_map, size_map in zip(
614+
reversed(self._function_qreg_transform_map), reversed(self._function_qreg_size_map)
615+
):
616+
statement.qubits = (
617+
Qasm3Transformer.transform_function_qubits( # type: ignore[assignment]
618+
statement,
619+
transform_map,
620+
size_map,
621+
)
604622
)
605-
)
606-
qubit_ids = self._get_op_bits(statement, True)
623+
qubit_ids = self._get_op_bits(statement, qubits=True)
607624

608625
unrolled_resets = []
609626
for qid in qubit_ids:
@@ -648,17 +665,22 @@ def _visit_barrier( # pylint: disable=too-many-locals, too-many-branches
648665
"""
649666
# if barrier is applied to ALL qubits at once, we are fine
650667
if len(self._function_qreg_size_map) > 0: # atleast in SOME function scope
651-
# transform qubits to use the global qreg identifiers
668+
# we have multiple function scopes, so we need to transform the qubits
669+
# to use the global qreg identifiers
652670

653671
# since we are changing the qubits to IndexedIdentifiers, we need to supress the
654672
# error for the type checker
655-
barrier.qubits = (
656-
Qasm3Transformer.transform_function_qubits( # type: ignore [assignment]
657-
barrier,
658-
self._function_qreg_transform_map[-1],
673+
for transform_map, size_map in zip(
674+
reversed(self._function_qreg_transform_map), reversed(self._function_qreg_size_map)
675+
):
676+
barrier.qubits = (
677+
Qasm3Transformer.transform_function_qubits( # type: ignore [assignment]
678+
barrier,
679+
transform_map,
680+
size_map,
681+
)
659682
)
660-
)
661-
barrier_qubits = self._get_op_bits(barrier)
683+
barrier_qubits = self._get_op_bits(barrier, qubits=True)
662684
unrolled_barriers = []
663685
max_involved_depth = 0
664686
for qubit in barrier_qubits:
@@ -763,7 +785,7 @@ def _unroll_multiple_target_qubits(
763785
Returns:
764786
The list of all targets that the unrolled gate should act on.
765787
"""
766-
op_qubits = self._get_op_bits(operation)
788+
op_qubits = self._get_op_bits(operation, qubits=True)
767789
if len(op_qubits) <= 0 or len(op_qubits) % gate_qubit_count != 0:
768790
raise_qasm3_error(
769791
f"Invalid number of qubits {len(op_qubits)} for operation {operation.name.name}",
@@ -981,7 +1003,7 @@ def _visit_custom_gate_operation(
9811003
gate_name: str = operation.name.name
9821004
gate_definition: qasm3_ast.QuantumGateDefinition = self._custom_gates[gate_name]
9831005
op_qubits: list[qasm3_ast.IndexedIdentifier] = self._get_op_bits(
984-
operation
1006+
operation, qubits=True
9851007
) # type: ignore [assignment]
9861008

9871009
Qasm3Validator.validate_gate_call(operation, gate_definition, len(op_qubits))
@@ -1223,12 +1245,14 @@ def _visit_generic_gate_operation( # pylint: disable=too-many-branches, too-man
12231245
):
12241246
# we are in SOME function scope
12251247
# transform qubits to use the global qreg identifiers
1226-
operation.qubits = (
1227-
Qasm3Transformer.transform_function_qubits( # type: ignore [assignment]
1228-
operation,
1229-
self._function_qreg_transform_map[-1],
1248+
for transform_map, size_map in zip(
1249+
reversed(self._function_qreg_transform_map), reversed(self._function_qreg_size_map)
1250+
):
1251+
operation.qubits = (
1252+
Qasm3Transformer.transform_function_qubits( # type: ignore [assignment]
1253+
operation, transform_map, size_map
1254+
)
12301255
)
1231-
)
12321256

12331257
operation.qubits = self._get_op_bits(operation, qubits=True) # type: ignore
12341258

@@ -2176,6 +2200,11 @@ def _visit_function_call(
21762200
duplicate_qubit_detect_map: dict = {}
21772201
qubit_transform_map: dict = {} # {(formal arg, idx) : (actual arg, idx)}
21782202
formal_qreg_size_map: dict = {}
2203+
actual_qreg_size_map: dict = (
2204+
self._function_qreg_size_map[-1]
2205+
if self._function_qreg_size_map
2206+
else self._global_qreg_size_map
2207+
)
21792208

21802209
quantum_vars, classical_vars = [], []
21812210
for actual_arg, formal_arg in zip(statement.arguments, subroutine_def.arguments):
@@ -2190,6 +2219,7 @@ def _visit_function_call(
21902219
Qasm3SubroutineProcessor.process_quantum_arg(
21912220
formal_arg,
21922221
actual_arg,
2222+
actual_qreg_size_map,
21932223
formal_qreg_size_map,
21942224
duplicate_qubit_detect_map,
21952225
qubit_transform_map,

tests/qasm3/subroutines/test_subroutines.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -277,31 +277,41 @@ def my_function(qubit a, float[32] b) {
277277
check_single_qubit_rotation_op(result.unrolled_ast, 2, [0, 0], [3.14, 6.28], "rx")
278278

279279

280-
@pytest.mark.skip(reason="Not implemented nested functions yet")
281280
def test_function_call_from_within_fn():
282281
"""Test that a function call from within another function is correctly converted."""
283282
qasm_str = """OPENQASM 3.0;
284283
include "stdgates.inc";
285-
286-
def my_function(qubit q1) {
284+
def my_function(qubit q1, float[32] a) {
287285
h q1;
286+
rx(a) q1;
288287
return;
289288
}
290289
291-
def my_function_2(qubit[2] q2) {
292-
my_function(q2[1]);
290+
def my_function_2(qubit[2] q2, float[32] param) {
291+
my_function(q2[1], param);
293292
return;
294293
}
294+
295+
def my_function_3(qubit[2] q3) {
296+
float[32] a = 3.14;
297+
my_function_2(q3, a);
298+
my_function(q3[1], a);
299+
return;
300+
}
301+
295302
qubit[2] q;
296-
my_function_2(q);
303+
float[32] r = 3.14;
304+
my_function_2(q, r);
305+
my_function_3(q);
297306
"""
298307

299308
result = loads(qasm_str)
300309
result.unroll()
301310
assert result.num_clbits == 0
302311
assert result.num_qubits == 2
303312

304-
check_single_qubit_gate_op(result.unrolled_ast, 1, [1], "h")
313+
check_single_qubit_gate_op(result.unrolled_ast, 3, [1, 1, 1], "h")
314+
check_single_qubit_rotation_op(result.unrolled_ast, 3, [1] * 3, [3.14] * 3, "rx")
305315

306316

307317
@pytest.mark.skip(reason="Bug: qubit in function scope conflicts with global scope")

0 commit comments

Comments
 (0)