|
1 | 1 | from typing_extensions import Generic, TypeVar |
2 | 2 |
|
3 | 3 | import pydantic_core |
| 4 | +import re |
| 5 | +import inspect |
4 | 6 |
|
5 | 7 | from typechat._internal.model import PromptSection, TypeChatLanguageModel |
6 | 8 | from typechat._internal.result import Failure, Result, Success |
@@ -123,3 +125,41 @@ def _create_repair_prompt(self, validation_error: str) -> str: |
123 | 125 | The following is a revised JSON object: |
124 | 126 | """ |
125 | 127 | return prompt |
| 128 | + |
| 129 | + def _convert_pythonic_comments_to_annotated_docs(schema_class, debug=False): |
| 130 | + |
| 131 | + schema_path = inspect.getfile(schema_class) |
| 132 | + |
| 133 | + with open(schema_path, 'r') as file: |
| 134 | + schema_class_source = file.read() |
| 135 | + |
| 136 | + if debug: |
| 137 | + print("File contents before modification:") |
| 138 | + print("--"*50) |
| 139 | + print(schema_class_source) |
| 140 | + print("--"*50) |
| 141 | + |
| 142 | + pattern = r"(\w+\s*:\s*.*?)(?=\s*#\s*(.+?)(?:\n|\Z))" |
| 143 | + commented_fields = re.findall(pattern, schema_class_source) |
| 144 | + annotated_fields = [] |
| 145 | + |
| 146 | + for field, comment in commented_fields: |
| 147 | + field_separator = field.split(":") |
| 148 | + field_name = field_separator[0].strip() |
| 149 | + field_type = field_separator[1].strip() |
| 150 | + |
| 151 | + annotated_fields.append( |
| 152 | + f"{field_name}: Annotated[{field_type}, Doc(\"{comment}\")]") |
| 153 | + |
| 154 | + for field, annotation in zip(commented_fields, annotated_fields): |
| 155 | + schema_class_source = schema_class_source.replace(field[0], annotation) |
| 156 | + |
| 157 | + if debug: |
| 158 | + print("File contents after modification:") |
| 159 | + print("--"*50) |
| 160 | + print(schema_class_source) |
| 161 | + print("--"*50) |
| 162 | + |
| 163 | + namespace = {} |
| 164 | + exec(schema_class_source, namespace) |
| 165 | + return namespace[schema_class.__name__] |
0 commit comments