diff --git a/cecli/args.py b/cecli/args.py index d51421d9012..eb05b51231e 100644 --- a/cecli/args.py +++ b/cecli/args.py @@ -121,6 +121,7 @@ def get_parser(default_config_files, git_root): ) group.add_argument( "--model-overrides", + "--model-settings", metavar="MODEL_OVERRIDES_JSON", help=( 'Specify model tag overrides directly as JSON/YAML string (e.g., \'{"gpt-4o": {"high":' diff --git a/cecli/commands/agent_model.py b/cecli/commands/agent_model.py index 04f9c5ad8f6..ceaa02ebf5f 100644 --- a/cecli/commands/agent_model.py +++ b/cecli/commands/agent_model.py @@ -108,7 +108,7 @@ async def execute(cls, io, coder, args, **kwargs): @classmethod def get_completions(cls, io, coder, args) -> List[str]: """Get completion options for agent-model command.""" - return models.get_chat_model_names() + return models.get_chat_model_names(query=args) @classmethod def get_help(cls) -> str: diff --git a/cecli/commands/editor_model.py b/cecli/commands/editor_model.py index beb97a86ae9..b76190a56d5 100644 --- a/cecli/commands/editor_model.py +++ b/cecli/commands/editor_model.py @@ -107,7 +107,7 @@ async def execute(cls, io, coder, args, **kwargs): @classmethod def get_completions(cls, io, coder, args) -> List[str]: """Get completion options for editor-model command.""" - return models.get_chat_model_names() + return models.get_chat_model_names(query=args) @classmethod def get_help(cls) -> str: diff --git a/cecli/commands/model.py b/cecli/commands/model.py index 8d4212b2f27..83bc18e6e5f 100644 --- a/cecli/commands/model.py +++ b/cecli/commands/model.py @@ -111,7 +111,7 @@ async def execute(cls, io, coder, args, **kwargs): @classmethod def get_completions(cls, io, coder, args) -> List[str]: """Get completion options for model command.""" - return models.get_chat_model_names() + return models.get_chat_model_names(query=args) @classmethod def get_help(cls) -> str: diff --git a/cecli/commands/models.py b/cecli/commands/models.py index 57beb1b4f24..a1022b3e18b 100644 --- a/cecli/commands/models.py +++ b/cecli/commands/models.py @@ -24,7 +24,7 @@ async def execute(cls, io, coder, args, **kwargs): @classmethod def get_completions(cls, io, coder, args) -> List[str]: """Get completion options for models command.""" - return models.get_chat_model_names() + return models.get_chat_model_names(query=args) @classmethod def get_help(cls) -> str: diff --git a/cecli/commands/weak_model.py b/cecli/commands/weak_model.py index 97f75c0c2f3..acff8a48e30 100644 --- a/cecli/commands/weak_model.py +++ b/cecli/commands/weak_model.py @@ -107,7 +107,7 @@ async def execute(cls, io, coder, args, **kwargs): @classmethod def get_completions(cls, io, coder, args) -> List[str]: """Get completion options for weak-model command.""" - return models.get_chat_model_names() + return models.get_chat_model_names(query=args) @classmethod def get_help(cls) -> str: diff --git a/cecli/helpers/hashline.py b/cecli/helpers/hashline.py index 9368ace5be9..110c8ebab25 100644 --- a/cecli/helpers/hashline.py +++ b/cecli/helpers/hashline.py @@ -327,18 +327,46 @@ def get_hashline_diff( elif operation == "insert": find_text = "" # For insert operations, we need to calculate hashlines for the text to insert - # The text should be hashed starting at the line after the end line + # with surrounding context for proper neighborhood-based hashing if text: - # Insert after the end line, so start hashline at found_end + 2 (1-indexed) - replace_text = hashline(text, start_line=found_end + 2) + original_lines = original_content.splitlines() + text_lines = text.splitlines() + # Get up to 3 lines of context before (ending at found_end) and after the insertion point + ctx_before = original_lines[max(0, found_end - 2) : found_end + 1] + ctx_after = original_lines[found_end + 1 : min(len(original_lines), found_end + 4)] + # Build a mini document with context so HashPos computes correct neighborhood hashes + mini_lines = ctx_before + text_lines + ctx_after + mini_text = "\n".join(mini_lines) + hashed_mini = hashline(mini_text) + hashed_mini_lines = hashed_mini.splitlines(keepends=True) + # Extract only the replacement text portion's hashlines + replace_lines_hashed = hashed_mini_lines[ + len(ctx_before) : len(ctx_before) + len(text_lines) + ] + replace_text = "".join(replace_lines_hashed) else: replace_text = "" # For replace operation, we're replacing the range elif operation == "replace": find_text = original_range_content - # For replace operations, the replacement text should be hashed starting at the start line + # For replace operations, the replacement text should be hashed + # with surrounding context for proper neighborhood-based hashing if text: - replace_text = hashline(text, start_line=found_start + 1) + original_lines = original_content.splitlines() + text_lines = text.splitlines() + # Get up to 3 lines of context before and after the range + ctx_before = original_lines[max(0, found_start - 3) : found_start] + ctx_after = original_lines[found_end + 1 : min(len(original_lines), found_end + 4)] + # Build a mini document with context so HashPos computes correct neighborhood hashes + mini_lines = ctx_before + text_lines + ctx_after + mini_text = "\n".join(mini_lines) + hashed_mini = hashline(mini_text) + hashed_mini_lines = hashed_mini.splitlines(keepends=True) + # Extract only the replacement text portion's hashlines + replace_lines_hashed = hashed_mini_lines[ + len(ctx_before) : len(ctx_before) + len(text_lines) + ] + replace_text = "".join(replace_lines_hashed) else: replace_text = "" else: diff --git a/cecli/helpers/hashpos/hashpos.py b/cecli/helpers/hashpos/hashpos.py index 63897932ee6..dc26801ce26 100644 --- a/cecli/helpers/hashpos/hashpos.py +++ b/cecli/helpers/hashpos/hashpos.py @@ -5,8 +5,6 @@ class HashPos: B64 = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789~_" - # The actual coprime period (64 * 63) - PERIOD = 4032 # Regex pattern for HashPos format: {4-char-hash}:: HASH_PREFIX_RE = re.compile(r"^([0-9a-zA-Z\~_@]{4})::") # Regex for normalization: 4 hash chars optionally followed by '::' @@ -18,70 +16,53 @@ def __init__(self, source_text: str = ""): self.lines = source_text.splitlines() self.total = len(self.lines) - def _get_content_bits(self, text: str) -> int: - return xxhash.xxh3_64_intdigest(text.encode("utf-8")) & 0xFFF - - def _get_anchor_bits(self, line_idx: int) -> int: - a1 = (line_idx * 53 + 13) % 64 - a2 = (line_idx * 59 + 31) % 63 - return (a1 << 6) | a2 - - def _spread_bits(self, x: int) -> int: + def _get_region_bits(self, line_idx: int) -> tuple[int, int]: """ - Spreads 12 bits of x into 24 bits by inserting a 0 between each bit. - Input: 000000000000abcdefghijkl (12 bits) - Output: 0a0b0c0d0e0f0g0h0i0j0k0l (24 bits) + Uses line_idx modulo 16 (4 bits) to get two 2-bit flags (b1, b2). + This guarantees up to 16 consecutive repeating lines get unique spatial anchors. """ - x &= 0xFFF # Ensure we only have 12 bits - # Shift bits by 8, mask keeps the blocks separated - # x starts: 000000000000 abcdefgh ijkl - x = (x | (x << 8)) & 0x00FF00FF # 0000abcd efgh0000 00000000 ijkl... - # Shift by 4, then 2, then 1 to create 1-bit gaps - x = (x | (x << 4)) & 0x0F0F0F0F - x = (x | (x << 2)) & 0x33333333 - x = (x | (x << 1)) & 0x55555555 # Result: 0a0b0c0d0e0f0g0h0i0j0k0l - return x + mod_val = line_idx % 16 + + # Split the 4-bit modulo value into two separate 2-bit flags + b1 = (mod_val >> 2) & 3 # Top 2 bits (mask with 0b11) + b2 = mod_val & 3 # Bottom 2 bits + return b1, b2 - def _compact_bits(self, x: int) -> int: + def _get_neighborhood_hash(self, line_idx: int) -> int: """ - The inverse of spread: pulls every other bit back together. - Input: 0a0b0c0d0e0f0g0h0i0j0k0l (24 bits) - Output: 000000000000abcdefghijkl (12 bits) + Creates a 20-bit digest using the current line and the 3 lines + before and after it. """ - x &= 0x55555555 # Mask to ensure we only look at the "active" bits - x = (x | (x >> 1)) & 0x33333333 - x = (x | (x >> 2)) & 0x0F0F0F0F - x = (x | (x >> 4)) & 0x00FF00FF - x = (x | (x >> 8)) & 0x0000FFFF # Result: abcdefghijkl - return x + start = max(0, line_idx - 3) + end = min(self.total, line_idx + 4) + + context_window = "\n".join(self.lines[start:end]) + full_hash = xxhash.xxh3_64_intdigest(context_window.encode("utf-8")) + + # Isolate exactly 20 bits + return full_hash & 0xFFFFF - def _interleave(self, content: int, anchor: int) -> int: + def generate_private_id(self, text: str) -> str: """ - Weaves content and anchor bits together. - Content bits occupy the 'odd' positions, Anchor bits occupy the 'even'. + Generates a fast 12-bit (3 hex chars) hash based purely on the line text. """ - # Spread content bits and shift by 1 to put them in positions 1, 3, 5... - # Spread anchor bits and leave them in positions 0, 2, 4... - return (self._spread_bits(content) << 1) | self._spread_bits(anchor) + bits = xxhash.xxh3_64_intdigest(text.encode("utf-8")) & 0xFFF + return f"{bits:03x}" - def _deinterleave(self, mixed: int) -> tuple[int, int]: + def generate_public_id(self, text: str, line_idx: int) -> str: """ - Extracts content and anchor bits from a 24-bit interleaved integer. + Generates a 4-char Base64 ID combining modulo buckets and context hash. + Layout: [2-bit b1] [10-bit Hash A] [2-bit b2] [10-bit Hash B] """ - # To get content: shift right by 1, then compact - content = self._compact_bits(mixed >> 1) - # To get anchor: just compact (the mask inside _compact_bits handles the rest) - anchor = self._compact_bits(mixed) - return content, anchor + b1, b2 = self._get_region_bits(line_idx) + neighborhood_hash = self._get_neighborhood_hash(line_idx) - def generate_private_id(self, text: str) -> str: - bits = self._get_content_bits(text) - return f"{bits:03x}" + # Split the 20-bit hash into two 10-bit halves + hash_a = (neighborhood_hash >> 10) & 0x3FF + hash_b = neighborhood_hash & 0x3FF - def generate_public_id(self, text: str, line_idx: int) -> str: - content_bits = self._get_content_bits(text) - anchor_bits = self._get_anchor_bits(line_idx) - packed = self._interleave(content_bits, anchor_bits) + # Construct the mixed 24-bit integer + packed = (b1 << 22) | (hash_a << 12) | (b2 << 10) | hash_b res = "" for _ in range(4): @@ -90,11 +71,22 @@ def generate_public_id(self, text: str, line_idx: int) -> str: return res def unpack_public_id(self, public_id: str) -> tuple[int, int]: + """ + Reverses the Public ID back into its (Modulo 16, Neighborhood Hash) values. + """ packed = 0 for i, char in enumerate(public_id): packed |= self.B64.index(char) << (6 * i) - return self._deinterleave(packed) + b1 = (packed >> 22) & 3 + hash_a = (packed >> 12) & 0x3FF + b2 = (packed >> 10) & 3 + hash_b = packed & 0x3FF + + mod_val = (b1 << 2) | b2 + neighborhood_hash = (hash_a << 10) | hash_b + + return mod_val, neighborhood_hash def format_content(self, use_private_ids: bool = False, start_line: int = 1) -> str: formatted_lines = [] @@ -102,44 +94,46 @@ def format_content(self, use_private_ids: bool = False, start_line: int = 1) -> prefix = ( self.generate_private_id(line) if use_private_ids - else self.generate_public_id(line, i + start_line) + else self.generate_public_id(line, i) ) formatted_lines.append(f"{prefix}::{line}") return "\n".join(formatted_lines) def resolve_to_lines(self, public_id: str, start_line: int = 1) -> list[int]: - target_content, target_anchor = self.unpack_public_id(public_id) - content_matches = [] - perfect_matches = [] + target_mod, target_hash = self.unpack_public_id(public_id) + matches = [] + # Find all lines whose neighborhood hash matches our target for i, line in enumerate(self.lines): - if self._get_content_bits(line) == target_content: - current_anchor = self._get_anchor_bits(i + start_line) - if current_anchor == target_anchor: - perfect_matches.append(i) - else: - dist = abs(current_anchor - target_anchor) - # Use the actual coprime period for the circular logic - dist = min(dist, self.PERIOD - dist) + if self._get_neighborhood_hash(i) == target_hash: + matches.append(i) + + if not matches: + return [] + + # If perfectly unique, return it immediately + if len(matches) == 1: + return matches - # ~1% chance of collision around 10 items - if dist <= 1: - content_matches.append((dist, i)) + # Distance Heuristic: If multiple matches exist (e.g. repeated code blocks), + # prioritize the one whose modulo is closest to the target modulo. + # We use circular distance since mod 16 wraps around (0 is adjacent to 15). + def modulo_distance(idx: int) -> int: + current_mod = idx % 16 + dist = abs(current_mod - target_mod) + return min(dist, 16 - dist) - if perfect_matches: - return perfect_matches + matches.sort(key=modulo_distance) - content_matches.sort(key=lambda x: x[0]) - return [match[1] for match in content_matches] + return matches def resolve_range(self, start_id: str, end_id: str) -> tuple[int, int]: """ Resolves a block range from two Public IDs. Logic: - 1. Resolve all candidates for both IDs. - 2. Find the pair of (start, end) that are logically ordered and - have the lowest combined distance score. + 1. Resolve all candidates for both IDs (sorted by best match). + 2. Find the pair of (start, end) that are logically ordered. 3. Returns (start_index, end_index) """ starts = self.resolve_to_lines(start_id) @@ -148,13 +142,9 @@ def resolve_range(self, start_id: str, end_id: str) -> tuple[int, int]: if not starts or not ends: raise ValueError(f"Could not resolve IDs: {start_id}..{end_id}") - # If both have 'perfect' matches that are logically ordered, use them immediately - # Note: resolve_to_lines returns perfect matches first. for s in starts: for e in ends: if s <= e: - # Return the first logical pair found - # (This prioritizes perfect matches or closest heuristics) return s, e raise ValueError( diff --git a/cecli/helpers/hashpos/hashpos.v1.bak b/cecli/helpers/hashpos/hashpos.v1.bak new file mode 100644 index 00000000000..63897932ee6 --- /dev/null +++ b/cecli/helpers/hashpos/hashpos.v1.bak @@ -0,0 +1,238 @@ +import re + +import xxhash + + +class HashPos: + B64 = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789~_" + # The actual coprime period (64 * 63) + PERIOD = 4032 + # Regex pattern for HashPos format: {4-char-hash}:: + HASH_PREFIX_RE = re.compile(r"^([0-9a-zA-Z\~_@]{4})::") + # Regex for normalization: 4 hash chars optionally followed by '::' + NORMALIZE_RE = re.compile(r"^([0-9a-zA-Z\~_@]{4})(?:)?::") + # Regex for a raw 4-character fragment + FRAGMENT_RE = re.compile(r"^[0-9a-zA-Z\~_@]{4}$") + + def __init__(self, source_text: str = ""): + self.lines = source_text.splitlines() + self.total = len(self.lines) + + def _get_content_bits(self, text: str) -> int: + return xxhash.xxh3_64_intdigest(text.encode("utf-8")) & 0xFFF + + def _get_anchor_bits(self, line_idx: int) -> int: + a1 = (line_idx * 53 + 13) % 64 + a2 = (line_idx * 59 + 31) % 63 + return (a1 << 6) | a2 + + def _spread_bits(self, x: int) -> int: + """ + Spreads 12 bits of x into 24 bits by inserting a 0 between each bit. + Input: 000000000000abcdefghijkl (12 bits) + Output: 0a0b0c0d0e0f0g0h0i0j0k0l (24 bits) + """ + x &= 0xFFF # Ensure we only have 12 bits + # Shift bits by 8, mask keeps the blocks separated + # x starts: 000000000000 abcdefgh ijkl + x = (x | (x << 8)) & 0x00FF00FF # 0000abcd efgh0000 00000000 ijkl... + # Shift by 4, then 2, then 1 to create 1-bit gaps + x = (x | (x << 4)) & 0x0F0F0F0F + x = (x | (x << 2)) & 0x33333333 + x = (x | (x << 1)) & 0x55555555 # Result: 0a0b0c0d0e0f0g0h0i0j0k0l + return x + + def _compact_bits(self, x: int) -> int: + """ + The inverse of spread: pulls every other bit back together. + Input: 0a0b0c0d0e0f0g0h0i0j0k0l (24 bits) + Output: 000000000000abcdefghijkl (12 bits) + """ + x &= 0x55555555 # Mask to ensure we only look at the "active" bits + x = (x | (x >> 1)) & 0x33333333 + x = (x | (x >> 2)) & 0x0F0F0F0F + x = (x | (x >> 4)) & 0x00FF00FF + x = (x | (x >> 8)) & 0x0000FFFF # Result: abcdefghijkl + return x + + def _interleave(self, content: int, anchor: int) -> int: + """ + Weaves content and anchor bits together. + Content bits occupy the 'odd' positions, Anchor bits occupy the 'even'. + """ + # Spread content bits and shift by 1 to put them in positions 1, 3, 5... + # Spread anchor bits and leave them in positions 0, 2, 4... + return (self._spread_bits(content) << 1) | self._spread_bits(anchor) + + def _deinterleave(self, mixed: int) -> tuple[int, int]: + """ + Extracts content and anchor bits from a 24-bit interleaved integer. + """ + # To get content: shift right by 1, then compact + content = self._compact_bits(mixed >> 1) + # To get anchor: just compact (the mask inside _compact_bits handles the rest) + anchor = self._compact_bits(mixed) + return content, anchor + + def generate_private_id(self, text: str) -> str: + bits = self._get_content_bits(text) + return f"{bits:03x}" + + def generate_public_id(self, text: str, line_idx: int) -> str: + content_bits = self._get_content_bits(text) + anchor_bits = self._get_anchor_bits(line_idx) + packed = self._interleave(content_bits, anchor_bits) + + res = "" + for _ in range(4): + res += self.B64[packed % 64] + packed //= 64 + return res + + def unpack_public_id(self, public_id: str) -> tuple[int, int]: + packed = 0 + for i, char in enumerate(public_id): + packed |= self.B64.index(char) << (6 * i) + + return self._deinterleave(packed) + + def format_content(self, use_private_ids: bool = False, start_line: int = 1) -> str: + formatted_lines = [] + for i, line in enumerate(self.lines): + prefix = ( + self.generate_private_id(line) + if use_private_ids + else self.generate_public_id(line, i + start_line) + ) + formatted_lines.append(f"{prefix}::{line}") + return "\n".join(formatted_lines) + + def resolve_to_lines(self, public_id: str, start_line: int = 1) -> list[int]: + target_content, target_anchor = self.unpack_public_id(public_id) + content_matches = [] + perfect_matches = [] + + for i, line in enumerate(self.lines): + if self._get_content_bits(line) == target_content: + current_anchor = self._get_anchor_bits(i + start_line) + if current_anchor == target_anchor: + perfect_matches.append(i) + else: + dist = abs(current_anchor - target_anchor) + # Use the actual coprime period for the circular logic + dist = min(dist, self.PERIOD - dist) + + # ~1% chance of collision around 10 items + if dist <= 1: + content_matches.append((dist, i)) + + if perfect_matches: + return perfect_matches + + content_matches.sort(key=lambda x: x[0]) + return [match[1] for match in content_matches] + + def resolve_range(self, start_id: str, end_id: str) -> tuple[int, int]: + """ + Resolves a block range from two Public IDs. + + Logic: + 1. Resolve all candidates for both IDs. + 2. Find the pair of (start, end) that are logically ordered and + have the lowest combined distance score. + 3. Returns (start_index, end_index) + """ + starts = self.resolve_to_lines(start_id) + ends = self.resolve_to_lines(end_id) + + if not starts or not ends: + raise ValueError(f"Could not resolve IDs: {start_id}..{end_id}") + + # If both have 'perfect' matches that are logically ordered, use them immediately + # Note: resolve_to_lines returns perfect matches first. + for s in starts: + for e in ends: + if s <= e: + # Return the first logical pair found + # (This prioritizes perfect matches or closest heuristics) + return s, e + + raise ValueError( + f"Found matches for {start_id} and {end_id}, but no logically ordered range or unique" + " matches." + ) + + @staticmethod + def strip_prefix(text: str) -> str: + r""" + Remove HashPos prefixes from the start of every line. + + Removes prefixes that match the pattern: "{4-char-hash}" + where the hash is exactly 4 characters from the set [0-9a-zA-Z\~_@] followed by '::'. + + Args: + text: Input text with HashPos prefixes + + Returns: + String with HashPos prefixes removed from each line + """ + lines = text.splitlines(keepends=True) + result_lines = [] + for line in lines: + # Remove the HashPos prefix if present + stripped_line = HashPos.HASH_PREFIX_RE.sub("", line, count=1) + result_lines.append(stripped_line) + + return "".join(result_lines) + + @staticmethod + def extract_prefix(line: str) -> str: + """ + Extract the hash prefix from a line if it has a HashPos prefix. + + Args: + line: A line of text that may contain a HashPos prefix + + Returns: + The hash prefix (4 characters) if found, otherwise empty string + """ + match = HashPos.HASH_PREFIX_RE.match(line) + if match: + return match.group(1) + return "" + + @staticmethod + def normalize(hashpos_str: str) -> str: + """ + Normalize a HashPos string to the 4-character hash fragment. + + Accepts HashPos strings in "{hash_prefix}::" format or a raw "{hash_prefix}" fragment. + Also extracts HashPos from strings that contain content after the HashPos, + e.g., "H7M5::Line 1" + + Args: + hashpos_str: HashPos string in various formats + + Returns: + str: The 4-character hash fragment + + Raises: + ValueError: If format is invalid + """ + if hashpos_str is None: + raise ValueError("HashPos string cannot be None") + + # Check if it's already a raw fragment + if HashPos.FRAGMENT_RE.match(hashpos_str): + return hashpos_str + + match = HashPos.NORMALIZE_RE.match(hashpos_str) + if match: + return match.group(1) + + # If no pattern matches, raise error + raise ValueError( + f"Invalid HashPos format '{hashpos_str}'. " + r"Expected \"{hash_prefix}\" " + r"where hash_prefix is exactly 4 characters from the set [0-9a-zA-Z\~_@]." + ) diff --git a/cecli/helpers/monorepo/config.py b/cecli/helpers/monorepo/config.py index af8f4b39ad6..281fe33c1b8 100644 --- a/cecli/helpers/monorepo/config.py +++ b/cecli/helpers/monorepo/config.py @@ -70,6 +70,15 @@ def resolve_workspace_config(config_arg: Optional[str] = None) -> Optional[Any]: return workspace_conf +def load_workspace_config_file(path: Path) -> Dict[str, Any]: + """Load and validate a repo-local ``.cecli.workspaces.yml`` file.""" + from cecli.helpers.monorepo.local_workspace import load_workspace_file + + config = load_workspace_file(path) + validate_config(config) + return config + + def load_workspace_config( config_arg: Optional[str] = None, name: Optional[str] = None ) -> Dict[str, Any]: @@ -108,7 +117,10 @@ def load_workspace_config( def validate_config(config: Dict[str, Any]) -> None: """ - Minimal validation of required fields. + Validate workspace config shape. + + Each project must have a ``name`` and exactly one of ``path`` (local git + root) or ``repo`` (clone URL). At most one project may set ``primary: true``. """ if not config: return @@ -120,12 +132,23 @@ def validate_config(config: Dict[str, Any]) -> None: config["projects"] = [] project_names = set() + primary_count = 0 for project in config["projects"]: - if "name" not in project or "repo" not in project: - raise ValueError("Each project must have a 'name' and 'repo' URL") + if "name" not in project: + raise ValueError("Each project must have a 'name'") + has_path = bool(project.get("path")) + has_repo = bool(project.get("repo")) + if has_path == has_repo: + raise ValueError( + f"Project '{project['name']}' must have exactly one of 'path' or 'repo'" + ) + if project.get("primary"): + primary_count += 1 if project["name"] in project_names: raise ValueError(f"Duplicate project name: {project['name']}") project_names.add(project["name"]) + if primary_count > 1: + raise ValueError("Only one project may be marked primary: true") def find_active_workspace_name(config_arg: Optional[str] = None) -> Optional[str]: diff --git a/cecli/helpers/monorepo/local_workspace.py b/cecli/helpers/monorepo/local_workspace.py new file mode 100644 index 00000000000..51df9b486c9 --- /dev/null +++ b/cecli/helpers/monorepo/local_workspace.py @@ -0,0 +1,233 @@ +""" +Repo-local multi-project workspaces (``path:`` git roots). + +Cecli already supports **clone** workspaces under ``~/.cecli/workspaces/`` with +``repo:`` URLs and paths like ``{project}/main/{file}``. This module adds +**local** layout: projects point at existing directories on disk, and tracked +paths are prefixed as ``{project}/{file}`` (no ``/main/`` segment). + +Config file names (at the workspace root — usually the primary project directory): + +- ``.cecli.workspaces.yml`` +- ``.cecli.workspaces.yaml`` + +Each project must have exactly one of ``path`` (absolute local git root) or +``repo`` (clone URL; handled by existing clone workspace code). +""" + +from __future__ import annotations + +import json +import subprocess +from pathlib import Path +from typing import Any + +import yaml + +WORKSPACE_FILENAMES = (".cecli.workspaces.yml", ".cecli.workspaces.yaml") +METADATA_NAME = ".cecli/.workspace-meta.json" + + +def find_workspace_config_file(start: Path) -> Path | None: + """Return the nearest ``.cecli.workspaces.yml`` walking up from *start*.""" + current = Path(start).resolve() + if current.is_file(): + current = current.parent + while True: + for name in WORKSPACE_FILENAMES: + candidate = current / name + if candidate.is_file(): + return candidate + parent = current.parent + if parent == current: + break + current = parent + return None + + +def load_workspace_file(path: Path) -> dict[str, Any]: + raw = yaml.safe_load(path.read_text(encoding="utf-8")) or {} + if not isinstance(raw, dict): + raise ValueError("Workspace file must be a mapping") + if "name" not in raw: + raw["name"] = path.parent.name or "workspace" + if "projects" not in raw: + raw["projects"] = [] + return raw + + +def primary_project(config: dict[str, Any]) -> dict[str, Any] | None: + projects = config.get("projects") or [] + for proj in projects: + if proj.get("primary"): + return proj + if len(projects) == 1: + return projects[0] + return projects[0] if projects else None + + +def project_git_root(workspace_root: Path, project: dict[str, Any], *, layout: str) -> Path | None: + name = project.get("name") + if not name: + return None + path_val = project.get("path") + if path_val: + root = Path(str(path_val)).expanduser().resolve() + if not root.is_dir(): + return None + try: + subprocess.check_output( + ["git", "-C", str(root), "rev-parse", "--show-toplevel"], + stderr=subprocess.DEVNULL, + ) + return root + except Exception: + return None + if layout != "clone": + return None + clone_root = workspace_root / name / "main" + return clone_root if clone_root.is_dir() else None + + +def project_path_prefix(project: dict[str, Any], *, layout: str) -> str: + name = str(project.get("name") or "") + if layout == "clone": + return f"{name}/main" + return name + + +def resolve_workspace_file_path( + workspace_root: Path, + workspace_rel: str, + config: dict[str, Any], + *, + layout: str, +) -> tuple[Path, Path, str] | None: + """ + Map a workspace-relative path to ``(project_git_root, absolute_file, path_in_project_repo)``. + """ + rel = workspace_rel.replace("\\", "/").lstrip("/") + if not rel: + return None + parts = Path(rel).parts + if not parts: + return None + projects = config.get("projects") or [] + by_name = {str(p.get("name")): p for p in projects if p.get("name")} + + # Clone layout: name/main/rest + if layout == "clone" and len(parts) >= 2 and parts[1] == "main": + proj = by_name.get(parts[0]) + if not proj: + return None + git_root = project_git_root(workspace_root, proj, layout=layout) + if not git_root: + return None + in_repo = "/".join(parts[2:]) if len(parts) > 2 else "" + abs_path = git_root / in_repo if in_repo else git_root + return git_root, abs_path, in_repo + + # Local layout: name/rest or bare path under primary-only tree + if parts[0] in by_name: + proj = by_name[parts[0]] + git_root = project_git_root(workspace_root, proj, layout=layout) + if not git_root: + return None + in_repo = "/".join(parts[1:]) if len(parts) > 1 else "" + abs_path = git_root / in_repo if in_repo else git_root + return git_root, abs_path, in_repo + + primary = primary_project(config) + if primary: + git_root = project_git_root(workspace_root, primary, layout=layout) + if git_root: + in_repo = rel + return git_root, git_root / in_repo, in_repo + return None + + +def union_tracked_files( + workspace_root: Path, + config: dict[str, Any], + *, + layout: str, + ignored_file=None, +) -> list[str]: + """All tracked files as workspace-relative paths.""" + out: list[str] = [] + for proj in config.get("projects") or []: + name = proj.get("name") + if not name: + continue + git_root = project_git_root(workspace_root, proj, layout=layout) + if not git_root: + continue + prefix = project_path_prefix(proj, layout=layout) + try: + lines = subprocess.check_output( + ["git", "-C", str(git_root), "ls-files"], + stderr=subprocess.DEVNULL, + encoding="utf-8", + ).splitlines() + except Exception: + continue + for line in lines: + if not line.strip(): + continue + rel = f"{prefix}/{line}" if prefix else line + rel = rel.replace("\\", "/") + if ignored_file and ignored_file(rel): + continue + out.append(rel) + return out + + +def project_head_shas( + workspace_root: Path, + config: dict[str, Any], + *, + layout: str, +) -> list[str]: + shas: list[str] = [] + for proj in config.get("projects") or []: + name = proj.get("name") + if not name: + continue + git_root = project_git_root(workspace_root, proj, layout=layout) + if not git_root: + shas.append(f"{name}:unknown") + continue + try: + sha = subprocess.check_output( + ["git", "-C", str(git_root), "rev-parse", "HEAD"], + stderr=subprocess.DEVNULL, + encoding="utf-8", + ).strip() + shas.append(f"{name}:{sha}") + except Exception: + shas.append(f"{name}:unknown") + return shas + + +def write_workspace_metadata(workspace_root: Path, config: dict[str, Any], *, layout: str) -> None: + meta_dir = workspace_root / ".cecli" + meta_dir.mkdir(parents=True, exist_ok=True) + payload = {**config, "_layout": layout} + (meta_dir / ".workspace-meta.json").write_text( + json.dumps(payload, indent=2), + encoding="utf-8", + ) + + +def read_workspace_metadata(workspace_root: Path) -> tuple[dict[str, Any], str] | None: + legacy = workspace_root / ".cecli-workspace.json" + modern = workspace_root / METADATA_NAME + path = modern if modern.is_file() else legacy if legacy.is_file() else None + if not path: + return None + try: + data = json.loads(path.read_text(encoding="utf-8")) + layout = data.pop("_layout", "clone") + return data, layout + except Exception: + return None diff --git a/cecli/helpers/responses.py b/cecli/helpers/responses.py index 36f8f68633e..0bc9cecf909 100644 --- a/cecli/helpers/responses.py +++ b/cecli/helpers/responses.py @@ -3,6 +3,7 @@ import time from typing import List, Optional +import json_repair from litellm.types.utils import ChatCompletionMessageToolCall, Function from cecli import utils @@ -368,7 +369,8 @@ def parse_tool_arguments(args_string: str) -> dict: if isinstance(lone, dict): return lone try: - single = json.loads(chunks[0]) + json_string = json_repair.repair_json(chunks[0], ensure_ascii=False) + single = json.loads(json_string) except json.JSONDecodeError as err: return {"@error": f"Malformed JSON arguments: {err}"} return single if isinstance(single, dict) else {} diff --git a/cecli/main.py b/cecli/main.py index 17b96a8f8af..85893163897 100644 --- a/cecli/main.py +++ b/cecli/main.py @@ -719,6 +719,19 @@ def get_io(pretty): for fname in loaded_dotenvs: io.tool_output(f"Loaded {fname}") all_files = args.files + (args.file or []) + + # Check for arguments starting with '--' that are likely + # unrecognized or misspelled parameters, not file arguments + filtered_files = [] + for f in all_files: + if f.startswith("--"): + # Extract the parameter name: everything between '--' and '=' or end + param = f[2:].split("=")[0].split()[0] + io.tool_warning(f"The parameter --{param} does not exist.") + else: + filtered_files.append(f) + + all_files = filtered_files all_files = utils.expand_glob_patterns(all_files) fnames = [str(Path(fn).resolve()) for fn in all_files] read_patterns = args.read or [] diff --git a/cecli/models.py b/cecli/models.py index 34d75410a78..8ace100ef39 100644 --- a/cecli/models.py +++ b/cecli/models.py @@ -592,11 +592,25 @@ def configure_model_settings(self, model): valid_model_settings_fields = {f.name for f in fields(ModelSettings)} + # Detect structured keys: api_settings, api, llm_settings, or llm keys indicate the new format + has_structured_keys = any( + k in self.override_kwargs + for k in ( + "api_settings", + "api-settings", + "llm_settings", + "llm-settings", + "api", + "llm", + "agent", + ) + ) + for key, value in self.override_kwargs.items(): - if key == "model_settings" or key == "model-settings": + if key in ("agent", "model_settings", "model-settings"): if not isinstance(value, dict): raise ValueError( - f"override_kwargs 'model_settings' must be a dict, got {type(value)}" + f"override_kwargs '{key}' must be a dict, got {type(value)}" ) for setting_key, setting_value in value.items(): if setting_key not in valid_model_settings_fields: @@ -605,6 +619,26 @@ def configure_model_settings(self, model): f"Must be one of: {sorted(valid_model_settings_fields)}" ) setattr(self, setting_key, setting_value) + elif has_structured_keys and key in ("api", "api_settings", "api-settings"): + # api_settings: merge each sub-key into extra_params + if not isinstance(value, dict): + raise ValueError( + f"override_kwargs '{key}' must be a dict, got {type(value)}" + ) + for api_key, api_value in value.items(): + if isinstance(api_value, dict) and isinstance( + self.extra_params.get(api_key), dict + ): + self.extra_params[api_key] = {**self.extra_params[api_key], **api_value} + else: + self.extra_params[api_key] = api_value + elif has_structured_keys and key in ("llm", "llm_settings", "llm-settings"): + # llm_settings: merge into self.info + if not isinstance(value, dict): + raise ValueError( + f"override_kwargs '{key}' must be a dict, got {type(value)}" + ) + self.info = {**self.info, **value} elif isinstance(value, dict) and isinstance(self.extra_params.get(key), dict): self.extra_params[key] = {**self.extra_params[key], **value} else: @@ -1557,7 +1591,7 @@ async def check_for_dependencies(io, model_name): ) -def get_chat_model_names(): +def get_chat_model_names(query: str = "") -> list: chat_models = set() model_metadata = list(litellm.model_cost.items()) model_metadata += list(model_info_manager.local_model_metadata.items()) @@ -1575,7 +1609,38 @@ def get_chat_model_names(): fq_model = f"{provider}/{orig_model}" chat_models.add(fq_model) chat_models.add(orig_model) - return sorted(chat_models) + + sorted_models = sorted(chat_models) + + # Fuzzy match against the query when one is provided + if query: + try: + from ngram import NGram + from rapidfuzz import fuzz, process + + score_cutoff = int(0.3 * 100) + results = process.extract( + query, + sorted_models, + scorer=fuzz.partial_ratio, + limit=20, + score_cutoff=score_cutoff, + ) + match_names = [match for match, score, _ in results] + + # Re-rank with ngram trigram similarity when result set is small + if len(match_names) < 100: + ng = NGram(match_names, N=3) + reranked = ng.search(query, threshold=0.0) + match_names = [item for item, score in reranked] + + return match_names + except ImportError: + # Fall back to simple substring matching if fuzzy libs unavailable + query_lower = query.lower() + return [m for m in sorted_models if query_lower in m.lower()] + + return sorted_models def fuzzy_match_models(name): diff --git a/cecli/repo.py b/cecli/repo.py index 889e91e3043..9a55146c111 100644 --- a/cecli/repo.py +++ b/cecli/repo.py @@ -98,6 +98,7 @@ def __init__( self.is_workspace = False self.workspace_path = None self.workspace_config = {} + self.workspace_layout = "clone" self.workspace_ignore_specs = {} self.workspace_ignore_ts = {} # Workspace detection and config loading occurs later in __init__ @@ -131,23 +132,35 @@ def __init__( if num_repos == 0: raise FileNotFoundError if num_repos > 1: - self.io.tool_error("Files are in different git repos.") - raise FileNotFoundError + from cecli.helpers.monorepo.config import load_workspace_config_file + from cecli.helpers.monorepo.local_workspace import ( + find_workspace_config_file, + ) - self._init_repo_path = repo_paths.pop() + ws_file = find_workspace_config_file(Path(repo_paths[0])) + if not ws_file: + self.io.tool_error( + "Files are in different git repos. Add a .cecli.workspaces.yml at a" + " common ancestor with path: entries for each project." + ) + raise FileNotFoundError + self.workspace_config = load_workspace_config_file(ws_file) + primary = next( + (p for p in self.workspace_config.get("projects", []) if p.get("primary")), + None, + ) + if primary and primary.get("path"): + self._init_repo_path = str(Path(str(primary["path"])).expanduser().resolve()) + else: + self._init_repo_path = str(Path(repo_paths[0]).resolve()) + else: + self._init_repo_path = repo_paths.pop() # Detect if we're in a workspace self.workspace_path = self._detect_workspace_path(self._init_repo_path) if self.workspace_path: self.is_workspace = True - - try: - from cecli.helpers.monorepo.config import load_workspace_config - - self.workspace_config = load_workspace_config(name=self.workspace_path.name) - except Exception: - self.workspace_config = {} - + self._load_workspace_config() self.refresh_cecli_ignore() self.init_repo() @@ -170,9 +183,44 @@ def init_repo(self): self.repo = git.Repo(self._init_repo_path, odbt=git.GitCmdObjectDB) self.root = utils.safe_abs_path(self.repo.working_tree_dir) + def _load_workspace_config(self) -> None: + from cecli.helpers.monorepo.config import ( + load_workspace_config, + load_workspace_config_file, + ) + from cecli.helpers.monorepo.local_workspace import ( + find_workspace_config_file, + read_workspace_metadata, + write_workspace_metadata, + ) + + ws_file = find_workspace_config_file(Path(self.workspace_path)) + if ws_file: + self.workspace_layout = "local" + self.workspace_config = load_workspace_config_file(ws_file) + write_workspace_metadata( + Path(self.workspace_path), self.workspace_config, layout="local" + ) + return + meta = read_workspace_metadata(Path(self.workspace_path)) + if meta: + self.workspace_config, self.workspace_layout = meta + return + self.workspace_layout = "clone" + try: + self.workspace_config = load_workspace_config(name=Path(self.workspace_path).name) + except Exception: + self.workspace_config = {} + def _detect_workspace_path(self, start_path: str): """Check if current directory is within a workspace""" + from cecli.helpers.monorepo.local_workspace import find_workspace_config_file + current = Path(start_path).resolve() + ws_file = find_workspace_config_file(current) + if ws_file: + return ws_file.parent.resolve() + workspace_root = Path("~/.cecli/workspaces").expanduser() # Walk up directory tree looking for workspace root @@ -267,6 +315,9 @@ async def commit(self, fnames=None, context=None, message=None, coder_edits=Fals - User commit with explicit no-committer: coder_edits=False, --no-attribute-committer -> Author=You, Committer=You """ + if self.is_workspace and getattr(self, "workspace_layout", "clone") == "local": + return await self._commit_local_workspace(fnames, context, message, coder_edits, coder) + if not fnames and not self.repo.is_dirty(): return @@ -592,6 +643,57 @@ def get_tracked_files(self): return res + async def _commit_local_workspace( + self, fnames=None, context=None, message=None, coder_edits=False, coder=None + ): + from collections import defaultdict + + from cecli.helpers.monorepo.local_workspace import resolve_workspace_file_path + + layout = getattr(self, "workspace_layout", "local") + config = self.workspace_config or {} + readonly = { + str(p.get("name")) + for p in config.get("projects", []) + if p.get("readonly") and p.get("name") + } + + by_root: dict[str, list[str]] = defaultdict(list) + if fnames: + for fname in fnames: + resolved = resolve_workspace_file_path( + Path(self.workspace_path), str(fname), config, layout=layout + ) + if not resolved: + continue + git_root, _abs_path, in_repo = resolved + parts = Path(str(fname)).parts + if parts and parts[0] in readonly: + continue + if in_repo: + by_root[str(git_root)].append(in_repo) + else: + for proj in config.get("projects", []): + name = proj.get("name") + if not name or name in readonly: + continue + from cecli.helpers.monorepo.local_workspace import project_git_root + + git_root = project_git_root(Path(self.workspace_path), proj, layout=layout) + if not git_root: + continue + sub = GitRepo(self.io, [str(git_root)], None) + for rel in sub.get_dirty_files() or []: + by_root[str(git_root)].append(rel) + + last = None + for root, rels in by_root.items(): + sub = GitRepo(self.io, [root], None) + last = await sub.commit( + rels, context=context, message=message, coder_edits=coder_edits, coder=coder + ) + return last + def get_workspace_files(self): """ If in a workspace, return all tracked files from all projects. @@ -601,6 +703,35 @@ def get_workspace_files(self): return self.get_tracked_files() import hashlib + + layout = getattr(self, "workspace_layout", "clone") + config = self.workspace_config or {} + if not config.get("projects"): + return self.get_tracked_files() + + from cecli.helpers.monorepo.local_workspace import ( + project_head_shas, + union_tracked_files, + ) + + project_shas = project_head_shas(Path(self.workspace_path), config, layout=layout) + cache_key = hashlib.sha1(",".join(project_shas).encode()).hexdigest() + + if hasattr(self, "_workspace_files_cache"): + cached_key, cached_files = self._workspace_files_cache + if cached_key == cache_key: + return cached_files + + if layout == "local": + all_files = union_tracked_files( + Path(self.workspace_path), + config, + layout=layout, + ignored_file=self.ignored_file, + ) + self._workspace_files_cache = (cache_key, all_files) + return all_files + import json import subprocess @@ -614,36 +745,8 @@ def get_workspace_files(self): except Exception: return self.get_tracked_files() - # Generate a cache key based on the SHAs of all project HEADs - # This is similar to how base_coder uses staged files hash - projects = config.get("projects", []) - project_shas = [] - for proj in projects: - proj_name = proj.get("name") - if not proj_name: - continue - proj_root = self.workspace_path / proj_name / "main" - if not proj_root.exists(): - continue - try: - sha = subprocess.check_output( - ["git", "-C", str(proj_root), "rev-parse", "HEAD"], - stderr=subprocess.DEVNULL, - encoding="utf-8", - ).strip() - project_shas.append(f"{proj_name}:{sha}") - except Exception: - project_shas.append(f"{proj_name}:unknown") - - cache_key = hashlib.sha1(",".join(project_shas).encode()).hexdigest() - - if hasattr(self, "_workspace_files_cache"): - cached_key, cached_files = self._workspace_files_cache - if cached_key == cache_key: - return cached_files - all_files = [] - for proj in projects: + for proj in config.get("projects", []): proj_name = proj.get("name") if not proj_name: continue @@ -778,37 +881,45 @@ def _get_gitignore_spec(self, dir_path): self.gitignore_spec_cache[dir_path] = spec return spec + def _resolve_path_in_repo(self, path): + """Resolve *path* under this repo root (not process cwd).""" + file_path = Path(path) + if not file_path.is_absolute(): + file_path = (Path(self.root) / file_path).resolve() + else: + file_path = file_path.resolve() + return file_path + def _is_gitignored_by_pathspec(self, path): """Check if a file is ignored by any .gitignore file using pathspec.""" if not self.repo: return False try: - file_path = Path(path).resolve() - if not file_path.is_relative_to(self.root): + file_path = self._resolve_path_in_repo(path) + root = Path(self.root).resolve() + if not file_path.is_relative_to(root): return False # Walk up from file's directory to root current_dir = file_path.parent - relative_path = file_path.relative_to(self.root) + relative_path = file_path.relative_to(root) # Check each directory level - while current_dir.is_relative_to(self.root): + while current_dir.is_relative_to(root): spec = self._get_gitignore_spec(current_dir) # Get path relative to the directory containing the .gitignore - if current_dir == Path(self.root).resolve(): + if current_dir == root: path_to_check = str(relative_path) else: - path_to_check = str( - relative_path.relative_to(current_dir.relative_to(self.root)) - ) + path_to_check = str(relative_path.relative_to(current_dir.relative_to(root))) if spec.match_file(path_to_check): return True # Move up one directory - if current_dir == Path(self.root).resolve(): + if current_dir == root: break current_dir = current_dir.parent @@ -865,7 +976,8 @@ def ignored_file_raw(self, fname): ): # Check against project-specific spec # The spec expects paths relative to the project root (usually proj/main/) - if len(parts) > 2 and parts[1] == "main": + layout = getattr(self, "workspace_layout", "clone") + if layout == "clone" and len(parts) > 2 and parts[1] == "main": proj_rel_path = str(Path(*parts[2:])) else: proj_rel_path = str(Path(*parts[1:])) @@ -974,6 +1086,19 @@ def path_in_repo(self, path): return self.normalize_path(path) in tracked_files def abs_root_path(self, path): + if self.is_workspace and getattr(self, "workspace_layout", "clone") == "local": + from cecli.helpers.monorepo.local_workspace import ( + resolve_workspace_file_path, + ) + + resolved = resolve_workspace_file_path( + Path(self.workspace_path), + str(path), + self.workspace_config or {}, + layout="local", + ) + if resolved: + return utils.safe_abs_path(resolved[1]) res = Path(self.root) / path return utils.safe_abs_path(res) diff --git a/cecli/tools/_yield.py b/cecli/tools/_yield.py index a95b7343118..b575cfa9efd 100644 --- a/cecli/tools/_yield.py +++ b/cecli/tools/_yield.py @@ -1,9 +1,10 @@ import asyncio -import json import logging from cecli.tools.utils.base_tool import BaseTool +from cecli.tools.utils.helpers import ToolError from cecli.tools.utils.output import color_markers, tool_footer, tool_header +from cecli.tools.validations import ToolValidations logger = logging.getLogger(__name__) @@ -150,10 +151,18 @@ async def execute(cls, coder, **kwargs): @classmethod def format_output(cls, coder, mcp_server, tool_response): color_start, color_end = color_markers(coder) - params = json.loads(tool_response.function.arguments) + # Output header tool_header(coder=coder, mcp_server=mcp_server, tool_response=tool_response) + try: + params = ToolValidations.validate_params( + tool_response.function.arguments, cls.VALIDATIONS, cls.SCHEMA + ) + except ToolError: + coder.io.tool_error("Invalid Tool JSON") + return + summary = params.get("summary") if summary: coder.io.tool_output("") @@ -161,4 +170,4 @@ def format_output(cls, coder, mcp_server, tool_response): coder.io.tool_output(summary) coder.io.tool_output("") - tool_footer(coder=coder, tool_response=tool_response) + tool_footer(coder=coder, tool_response=tool_response, params=params) diff --git a/cecli/tools/command.py b/cecli/tools/command.py index 127b6deca3b..09ba627bc94 100644 --- a/cecli/tools/command.py +++ b/cecli/tools/command.py @@ -1,12 +1,13 @@ # Import necessary functions -import json import os import platform from cecli.helpers.background_commands import BackgroundCommandManager from cecli.run_cmd import run_cmd_subprocess from cecli.tools.utils.base_tool import BaseTool +from cecli.tools.utils.helpers import ToolError from cecli.tools.utils.output import color_markers, tool_footer, tool_header +from cecli.tools.validations import ToolValidations class Tool(BaseTool): @@ -392,15 +393,17 @@ def format_output(cls, coder, mcp_server, tool_response): """Format output for Command tool.""" color_start, color_end = color_markers(coder) + # Output header + tool_header(coder=coder, mcp_server=mcp_server, tool_response=tool_response) + try: - params = json.loads(tool_response.function.arguments) - except json.JSONDecodeError: + params = ToolValidations.validate_params( + tool_response.function.arguments, cls.VALIDATIONS, cls.SCHEMA + ) + except ToolError: coder.io.tool_error("Invalid Tool JSON") return - # Output header - tool_header(coder=coder, mcp_server=mcp_server, tool_response=tool_response) - command = params.get("command", "") background = params.get("background", False) stop = params.get("stop", False) @@ -430,4 +433,4 @@ def format_output(cls, coder, mcp_server, tool_response): coder.io.tool_output("") # Output footer - tool_footer(coder=coder, tool_response=tool_response) + tool_footer(coder=coder, tool_response=tool_response, params=params) diff --git a/cecli/tools/context_manager.py b/cecli/tools/context_manager.py index 0a18bf969bc..e426a46714d 100644 --- a/cecli/tools/context_manager.py +++ b/cecli/tools/context_manager.py @@ -1,4 +1,3 @@ -import json import os import re import time @@ -7,6 +6,7 @@ from cecli.tools.utils.base_tool import BaseTool from cecli.tools.utils.helpers import ToolError, parse_arg_as_list from cecli.tools.utils.output import color_markers, tool_footer, tool_header +from cecli.tools.validations import ToolValidations class Tool(BaseTool): @@ -121,14 +121,17 @@ def format_output(cls, coder, mcp_server, tool_response): """Format output for ContextManager tool.""" color_start, color_end = color_markers(coder) + # Output header + tool_header(coder=coder, mcp_server=mcp_server, tool_response=tool_response) + try: - params = json.loads(tool_response.function.arguments) - except json.JSONDecodeError: + params = ToolValidations.validate_params( + tool_response.function.arguments, cls.VALIDATIONS, cls.SCHEMA + ) + except ToolError: coder.io.tool_error("Invalid Tool JSON") return - tool_header(coder=coder, mcp_server=mcp_server, tool_response=tool_response) - # Define action display names action_names = { "create": "create", @@ -145,7 +148,7 @@ def format_output(cls, coder, mcp_server, tool_response): file_list = ", ".join(files) coder.io.tool_output(f"{color_start}{display_name}:{color_end} {file_list}") - tool_footer(coder=coder, tool_response=tool_response) + tool_footer(coder=coder, tool_response=tool_response, params=params) @classmethod def _remove(cls, coder, file_path): diff --git a/cecli/tools/delegate.py b/cecli/tools/delegate.py index 1fa6a5313ff..244fe0e44d9 100644 --- a/cecli/tools/delegate.py +++ b/cecli/tools/delegate.py @@ -1,16 +1,20 @@ """Delegate tool - allows the primary agent to spawn sub-agents.""" import asyncio -import json from cecli.tools.utils.base_tool import BaseTool +from cecli.tools.utils.helpers import ToolError from cecli.tools.utils.output import color_markers, tool_footer, tool_header +from cecli.tools.validations import ToolValidations class Tool(BaseTool): NORM_NAME = "delegate" TRACK_INVOCATIONS = True - LIST_PARAMS = ["delegations"] + VALIDATIONS = { + "delegations": ["coerce_list"], + "delegations[]": ["coerce_dict"], + } SCHEMA = { "type": "function", "function": { @@ -100,14 +104,17 @@ def format_output(cls, coder, mcp_server, tool_response): """Format output for Delegate tool - show each delegation's agent and task.""" color_start, color_end = color_markers(coder) + # Output header + tool_header(coder=coder, mcp_server=mcp_server, tool_response=tool_response) + try: - params = json.loads(tool_response.function.arguments) - except json.JSONDecodeError: + params = ToolValidations.validate_params( + tool_response.function.arguments, cls.VALIDATIONS, cls.SCHEMA + ) + except ToolError: coder.io.tool_error("Invalid Tool JSON") return - tool_header(coder=coder, mcp_server=mcp_server, tool_response=tool_response) - delegations = params.get("delegations", []) if delegations: coder.io.tool_output("") @@ -120,4 +127,4 @@ def format_output(cls, coder, mcp_server, tool_response): if i < len(delegations) - 1: coder.io.tool_output("") - tool_footer(coder=coder, tool_response=tool_response) + tool_footer(coder=coder, tool_response=tool_response, params=params) diff --git a/cecli/tools/edit_text.py b/cecli/tools/edit_text.py index 8f5ed549322..a8eeabca75f 100644 --- a/cecli/tools/edit_text.py +++ b/cecli/tools/edit_text.py @@ -1,5 +1,3 @@ -import json - from cecli.helpers.hashline import ( ContentHashError, apply_hashline_operations, @@ -15,6 +13,7 @@ validate_file_for_edit, ) from cecli.tools.utils.output import color_markers, tool_footer, tool_header +from cecli.tools.validations import ToolValidations VALID_OPERATIONS = {"replace", "delete", "insert"} OPERATION_NOUNS = { @@ -27,7 +26,10 @@ class Tool(BaseTool): NORM_NAME = "edittext" TRACK_INVOCATIONS = False - LIST_PARAMS = ["edits"] + VALIDATIONS = { + "edits": ["coerce_list"], + "edits[]": ["coerce_dict"], + } SCHEMA = { "type": "function", "function": { @@ -111,8 +113,22 @@ def execute( Each edit object must include its own file_path. """ if not coder.edit_allowed: - raise ToolError( - "Please call `ReadRange` first to make sure edits are appropriately scoped" + from cecli.helpers.conversation import ConversationService, MessageTag + + ConversationService.get_manager(coder).add_message( + message_dict=dict( + role="user", + content=( + "Please call `ReadRange` on files you intend to edit to" + " make sure edits are appropriately targeted." + ), + ), + tag=MessageTag.CUR, + hash_key=("edit_text", "reminder"), + promotion=ConversationService.get_manager(coder).DEFAULT_TAG_PROMOTION_VALUE, + mark_for_delete=0, + mark_for_demotion=1, + force=True, ) tool_name = "EditText" @@ -232,7 +248,7 @@ def execute( if new_content != original_content: file_successful_edits += len(successful_ops) else: - raise ToolError("Invalid Edit - Edit Results In Same Content") + raise ToolError("Invalid Edit - Update content hash bounds") if len(failed_ops): for failed_op in failed_ops: @@ -370,12 +386,16 @@ def execute( def format_output(cls, coder, mcp_server, tool_response): color_start, color_end = color_markers(coder) + # Output header + tool_header(coder=coder, mcp_server=mcp_server, tool_response=tool_response) + try: - params = json.loads(tool_response.function.arguments) - except json.JSONDecodeError: + params = ToolValidations.validate_params( + tool_response.function.arguments, cls.VALIDATIONS, cls.SCHEMA + ) + except ToolError: coder.io.tool_error("Invalid Tool JSON") - - tool_header(coder=coder, mcp_server=mcp_server, tool_response=tool_response) + return # Group edits by file_path for display edits_by_file = {} @@ -450,4 +470,4 @@ def format_output(cls, coder, mcp_server, tool_response): coder.io.tool_output(range_info) coder.io.tool_output("") - tool_footer(coder=coder, tool_response=tool_response) + tool_footer(coder=coder, tool_response=tool_response, params=params) diff --git a/cecli/tools/explore_code.py b/cecli/tools/explore_code.py index c95f8df8acd..2f9d9621197 100644 --- a/cecli/tools/explore_code.py +++ b/cecli/tools/explore_code.py @@ -1,9 +1,9 @@ -import json import os from cecli.tools.utils.base_tool import BaseTool from cecli.tools.utils.helpers import ToolError from cecli.tools.utils.output import color_markers, tool_footer, tool_header +from cecli.tools.validations import ToolValidations cwd = os.getcwd() @@ -19,7 +19,10 @@ class Tool(BaseTool): NORM_NAME = "explorecode" - LIST_PARAMS = ["queries"] + VALIDATIONS = { + "queries": ["coerce_list"], + "queries[]": ["coerce_dict"], + } SCHEMA = { "type": "function", "function": { @@ -297,14 +300,17 @@ def format_output(cls, coder, mcp_server, tool_response): """Format output for ExploreCode tool.""" color_start, color_end = color_markers(coder) + # Output header + tool_header(coder=coder, mcp_server=mcp_server, tool_response=tool_response) + try: - params = json.loads(tool_response.function.arguments) - except json.JSONDecodeError: + params = ToolValidations.validate_params( + tool_response.function.arguments, cls.VALIDATIONS, cls.SCHEMA + ) + except ToolError: coder.io.tool_error("Invalid Tool JSON") return - # Output header - tool_header(coder=coder, mcp_server=mcp_server, tool_response=tool_response) queries = params.get("queries", []) if queries: coder.io.tool_output("") @@ -321,4 +327,4 @@ def format_output(cls, coder, mcp_server, tool_response): coder.io.tool_output("") # Output footer - tool_footer(coder=coder, tool_response=tool_response) + tool_footer(coder=coder, tool_response=tool_response, params=params) diff --git a/cecli/tools/grep.py b/cecli/tools/grep.py index deb9db27d60..a925bffc606 100644 --- a/cecli/tools/grep.py +++ b/cecli/tools/grep.py @@ -1,4 +1,3 @@ -import json import shutil from pathlib import Path @@ -7,12 +6,17 @@ from cecli.helpers.hashline import strip_hashline from cecli.run_cmd import run_cmd_subprocess from cecli.tools.utils.base_tool import BaseTool +from cecli.tools.utils.helpers import ToolError from cecli.tools.utils.output import color_markers, tool_footer, tool_header +from cecli.tools.validations import ToolValidations class Tool(BaseTool): NORM_NAME = "grep" - LIST_PARAMS = ["searches"] + VALIDATIONS = { + "searches": ["coerce_list"], + "searches[]": ["coerce_dict"], + } SCHEMA = { "type": "function", "function": { @@ -236,14 +240,17 @@ def format_output(cls, coder, mcp_server, tool_response): """Format output for Grep tool.""" color_start, color_end = color_markers(coder) + # Output header + tool_header(coder=coder, mcp_server=mcp_server, tool_response=tool_response) + try: - params = json.loads(tool_response.function.arguments) - except json.JSONDecodeError: + params = ToolValidations.validate_params( + tool_response.function.arguments, cls.VALIDATIONS, cls.SCHEMA + ) + except ToolError: coder.io.tool_error("Invalid Tool JSON") return - tool_header(coder=coder, mcp_server=mcp_server, tool_response=tool_response) - # Output each search operation with the requested format searches = params.get("searches", []) if searches: @@ -278,4 +285,4 @@ def format_output(cls, coder, mcp_server, tool_response): coder.io.tool_output(formatted_query) coder.io.tool_output("") - tool_footer(coder=coder, tool_response=tool_response) + tool_footer(coder=coder, tool_response=tool_response, params=params) diff --git a/cecli/tools/ls.py b/cecli/tools/ls.py index 06236289aef..0f6905ade8a 100644 --- a/cecli/tools/ls.py +++ b/cecli/tools/ls.py @@ -1,8 +1,9 @@ -import json import os from cecli.tools.utils.base_tool import BaseTool +from cecli.tools.utils.helpers import ToolError from cecli.tools.utils.output import color_markers, tool_footer, tool_header +from cecli.tools.validations import ToolValidations class Tool(BaseTool): @@ -95,14 +96,17 @@ def format_output(cls, coder, mcp_server, tool_response): """Format output for Ls tool.""" color_start, color_end = color_markers(coder) + # Output header + tool_header(coder=coder, mcp_server=mcp_server, tool_response=tool_response) + try: - params = json.loads(tool_response.function.arguments) - except json.JSONDecodeError: + params = ToolValidations.validate_params( + tool_response.function.arguments, cls.VALIDATIONS, cls.SCHEMA + ) + except ToolError: coder.io.tool_error("Invalid Tool JSON") return - tool_header(coder=coder, mcp_server=mcp_server, tool_response=tool_response) - # Output the directory parameter with the requested format directory = params.get("path", "") if directory: @@ -111,4 +115,4 @@ def format_output(cls, coder, mcp_server, tool_response): coder.io.tool_output(formatted_query) coder.io.tool_output("") - tool_footer(coder=coder, tool_response=tool_response) + tool_footer(coder=coder, tool_response=tool_response, params=params) diff --git a/cecli/tools/read_range.py b/cecli/tools/read_range.py index 52a95fe7bbc..413a111b61f 100644 --- a/cecli/tools/read_range.py +++ b/cecli/tools/read_range.py @@ -1,4 +1,3 @@ -import json import os from typing import Dict, List @@ -11,12 +10,16 @@ resolve_paths, ) from cecli.tools.utils.output import color_markers, tool_footer, tool_header +from cecli.tools.validations import ToolValidations class Tool(BaseTool): NORM_NAME = "readrange" TRACK_INVOCATIONS = False - LIST_PARAMS = ["show"] + VALIDATIONS = { + "show": ["coerce_list"], + "show[]": ["coerce_dict"], + } SCHEMA = { "type": "function", "function": { @@ -191,16 +194,25 @@ def execute(cls, coder, show, **kwargs): num_lines = len(lines) if num_lines == 0: - # Handle empty file case - output_lines = [f"File {rel_path} is empty."] - if show_index > 0: - all_outputs.append("") - all_outputs.extend(output_lines) + new_context_details.append( + "\n".join( + [ + f"File {rel_path} is empty.", + ( + "Next: use EditText with start_line @000 and end_line @000 to" + " write content, or ContextManager to scaffold — do not call" + " ReadRange again on this empty file." + ), + ] + ) + ) + new_context_retrieved.append(rel_path) + cls._last_read_turn[abs_path] = coder.turn_count continue # 4. Determine line range start_line_idx = -1 end_line_idx = -1 - found_by = "" + # found_by = "" if start_text is not None and end_text is not None: if start_text.isdigit() and end_text.isdigit(): @@ -360,7 +372,7 @@ def execute(cls, coder, show, **kwargs): # Store the found indices for future disambiguation cls._last_invocation[abs_path] = {"start_idx": s_idx, "end_idx": e_idx} - found_by = f"range '{start_text}' to '{end_text}'" + # found_by = f"range '{start_text}' to '{end_text}'" try: padding_int = int(padding) @@ -387,23 +399,23 @@ def execute(cls, coder, show, **kwargs): # 6. Format output for this operation # Use rel_path for user-facing messages - output_lines = [f"Displaying context around {found_by} in {rel_path}:"] + # output_lines = [f"Displaying context around {found_by} in {rel_path}:"] # Generate hashline for the entire file hashed_content = hashline(content) hashed_lines = hashed_content.splitlines() # Extract the context window from hashed lines - context_hashed_lines = hashed_lines[start_line_idx : end_line_idx + 1] + # context_hashed_lines = hashed_lines[start_line_idx : end_line_idx + 1] - for i in range(start_line_idx, end_line_idx + 1): - hashed_line = context_hashed_lines[i - start_line_idx] - output_lines.append(hashed_line) + # for i in range(start_line_idx, end_line_idx + 1): + # hashed_line = context_hashed_lines[i - start_line_idx] + # output_lines.append(hashed_line) # Add separator between multiple show operations - if show_index > 0: - all_outputs.append("") - all_outputs.extend(output_lines) + # if show_index > 0: + # all_outputs.append("") + # all_outputs.extend(output_lines) # Update the conversation cache with the displayed range # Note: start_line_idx and end_line_idx are 0-based, convert to 1-based for hashline @@ -513,6 +525,9 @@ def execute(cls, coder, show, **kwargs): " the relevant files." ) + if all_outputs: + result_parts.append("\n".join(all_outputs)) + if error_outputs: coder.io.tool_error(f"Errors encountered for {len(error_outputs)} operation(s)") @@ -596,14 +611,17 @@ def format_output(cls, coder, mcp_server, tool_response): """Format output for ReadRange tool.""" color_start, color_end = color_markers(coder) + # Output header + tool_header(coder=coder, mcp_server=mcp_server, tool_response=tool_response) + try: - params = json.loads(tool_response.function.arguments) - except json.JSONDecodeError: + params = ToolValidations.validate_params( + tool_response.function.arguments, cls.VALIDATIONS, cls.SCHEMA + ) + except ToolError: coder.io.tool_error("Invalid Tool JSON") return - tool_header(coder=coder, mcp_server=mcp_server, tool_response=tool_response) - show_ops = params.get("show", []) if show_ops: coder.io.tool_output("") @@ -620,7 +638,7 @@ def format_output(cls, coder, mcp_server, tool_response): coder.io.tool_output(formatted_query) coder.io.tool_output("") - tool_footer(coder=coder, tool_response=tool_response) + tool_footer(coder=coder, tool_response=tool_response, params=params) @classmethod def format_error(cls, coder, error_text, file_path, start_text, end_text, operation_index): diff --git a/cecli/tools/thinking.py b/cecli/tools/thinking.py index 05a2ffa239b..a52d6e735e5 100644 --- a/cecli/tools/thinking.py +++ b/cecli/tools/thinking.py @@ -1,7 +1,7 @@ -import json - from cecli.tools.utils.base_tool import BaseTool +from cecli.tools.utils.helpers import ToolError from cecli.tools.utils.output import color_markers, tool_footer, tool_header +from cecli.tools.validations import ToolValidations class Tool(BaseTool): @@ -40,13 +40,21 @@ def execute(cls, coder, content, **kwargs): @classmethod def format_output(cls, coder, mcp_server, tool_response): color_start, color_end = color_markers(coder) - params = json.loads(tool_response.function.arguments) + # Output header tool_header(coder=coder, mcp_server=mcp_server, tool_response=tool_response) + try: + params = ToolValidations.validate_params( + tool_response.function.arguments, cls.VALIDATIONS, cls.SCHEMA + ) + except ToolError: + coder.io.tool_error("Invalid Tool JSON") + return + coder.io.tool_output("") coder.io.tool_output(f"{color_start}Thoughts:{color_end}") coder.io.tool_output(params["content"]) coder.io.tool_output("") - tool_footer(coder=coder, tool_response=tool_response) + tool_footer(coder=coder, tool_response=tool_response, params=params) diff --git a/cecli/tools/update_todo_list.py b/cecli/tools/update_todo_list.py index 8d9395b22e7..e7864bc0ad4 100644 --- a/cecli/tools/update_todo_list.py +++ b/cecli/tools/update_todo_list.py @@ -1,11 +1,15 @@ from cecli.tools.utils.base_tool import BaseTool from cecli.tools.utils.helpers import ToolError, format_tool_result, handle_tool_error from cecli.tools.utils.output import tool_footer, tool_header +from cecli.tools.validations import ToolValidations class Tool(BaseTool): NORM_NAME = "updatetodolist" - LIST_PARAMS = ["tasks"] + VALIDATIONS = { + "tasks": ["coerce_list"], + "tasks[]": ["coerce_dict"], + } SCHEMA = { "type": "function", "function": { @@ -43,17 +47,6 @@ class Tool(BaseTool): " Defaults to False." ), }, - "change_id": { - "type": "string", - "description": "Optional change ID for tracking.", - }, - "dry_run": { - "type": "boolean", - "description": ( - "Whether to perform a dry run without actually updating the file." - " Defaults to False." - ), - }, }, "required": ["tasks"], }, @@ -187,16 +180,21 @@ def execute(cls, coder, tasks, append=False, change_id=None, dry_run=False, **kw @classmethod def format_output(cls, coder, mcp_server, tool_response): - import json - from cecli.tools.utils.output import color_markers color_start, color_end = color_markers(coder) + # Output header tool_header(coder=coder, mcp_server=mcp_server, tool_response=tool_response) - # Parse the parameters to display formatted todo list - params = json.loads(tool_response.function.arguments) + try: + params = ToolValidations.validate_params( + tool_response.function.arguments, cls.VALIDATIONS, cls.SCHEMA + ) + except ToolError: + coder.io.tool_error("Invalid Tool JSON") + return + tasks = params.get("tasks", []) if tasks: @@ -235,4 +233,4 @@ def format_output(cls, coder, mcp_server, tool_response): coder.io.tool_output("") - tool_footer(coder=coder, tool_response=tool_response) + tool_footer(coder=coder, tool_response=tool_response, params=params) diff --git a/cecli/tools/utils/base_tool.py b/cecli/tools/utils/base_tool.py index 2ed174594d2..a9fb39c709a 100644 --- a/cecli/tools/utils/base_tool.py +++ b/cecli/tools/utils/base_tool.py @@ -1,7 +1,8 @@ from abc import ABC, abstractmethod -from cecli.tools.utils.helpers import handle_tool_error, normalize_json_array +from cecli.tools.utils.helpers import handle_tool_error from cecli.tools.utils.output import print_tool_response +from cecli.tools.validations import ToolValidations class BaseTool(ABC): @@ -12,8 +13,8 @@ class BaseTool(ABC): NORM_NAME = None SCHEMA = None - # Parameters to run normalization checks on - LIST_PARAMS = [] + # Declarative validations (maps param paths to lists of validation method names) + VALIDATIONS = {} # Invocation tracking for detecting repeated tool calls _invocations = {} # Dict to store last 3 invocations per tool @@ -122,16 +123,15 @@ def process_response(cls, coder, params): coder, tool_name, ValueError(error_msg), add_traceback=False ) - for param in cls.LIST_PARAMS: - if param in params: - params[param] = normalize_json_array(params[param], param_name=param) - # Add current invocation to history (keeping only last 3) if params: cls._invocations[tool_name].append((current_params_tuple, params)) if len(cls._invocations[tool_name]) > 3: cls._invocations[tool_name] = cls._invocations[tool_name][-3:] + # Apply declarative validations from VALIDATIONS dict + params = ToolValidations.validate_params(params, cls.VALIDATIONS, cls.SCHEMA) + try: return cls.execute(coder, **params) except Exception as e: @@ -139,7 +139,12 @@ def process_response(cls, coder, params): @classmethod def format_output(cls, coder, mcp_server, tool_response): - print_tool_response(coder=coder, mcp_server=mcp_server, tool_response=tool_response) + params = ToolValidations.validate_params( + tool_response.function.arguments, cls.VALIDATIONS, cls.SCHEMA + ) + print_tool_response( + coder=coder, mcp_server=mcp_server, tool_response=tool_response, params=params + ) @classmethod def on_duplicate_request(cls, coder, **kwargs): diff --git a/cecli/tools/utils/output.py b/cecli/tools/utils/output.py index 4c8f84dce4f..67d51011466 100644 --- a/cecli/tools/utils/output.py +++ b/cecli/tools/utils/output.py @@ -2,7 +2,7 @@ import re -def print_tool_response(coder, mcp_server, tool_response): +def print_tool_response(coder, mcp_server, tool_response, params=None): """ Format the output for display. Prints a Header to identify the tool, a body for the relevant information @@ -13,12 +13,12 @@ def print_tool_response(coder, mcp_server, tool_response): mcp_server: An mcp server instance tool_response: a tool_response dictionary """ - tool_header(coder=coder, mcp_server=mcp_server, tool_response=tool_response) - tool_body(coder=coder, tool_response=tool_response) - tool_footer(coder=coder, tool_response=tool_response) + tool_header(coder=coder, mcp_server=mcp_server, tool_response=tool_response, params=params) + tool_body(coder=coder, tool_response=tool_response, params=params) + tool_footer(coder=coder, tool_response=tool_response, params=params) -def tool_header(coder, mcp_server, tool_response): +def tool_header(coder, mcp_server, tool_response, params=None): """ Prints the header for the tool call output @@ -35,7 +35,7 @@ def tool_header(coder, mcp_server, tool_response): ) -def tool_body(coder, tool_response): +def tool_body(coder, tool_response, params=None): """ Prints the output body of a tool call as the raw json returned from the model @@ -52,7 +52,7 @@ def tool_body(coder, tool_response): coder.io.tool_output(f"{color_start}Arguments:{color_end} {raw_args}") -def tool_body_unwrapped(coder, tool_response): +def tool_body_unwrapped(coder, tool_response, params=None): """ Prints the output body of a tool call with the argument and content sections separated @@ -65,7 +65,7 @@ def tool_body_unwrapped(coder, tool_response): color_start, color_end = color_markers(coder) try: - args_dict = json.loads(tool_response.function.arguments) + args_dict = params if params else json.loads(tool_response.function.arguments) first_key = True for key, value in args_dict.items(): # Convert explicit \\n sequences to actual newlines using regex @@ -90,7 +90,7 @@ def tool_body_unwrapped(coder, tool_response): coder.io.tool_output(f"{color_start}Arguments:{color_end} {raw_args}") -def tool_footer(coder, tool_response): +def tool_footer(coder, tool_response, params=None): """ Prints the output footer of a tool call, generally a new line But can include id's if ran in verbose mode diff --git a/cecli/tools/validations/__init__.py b/cecli/tools/validations/__init__.py new file mode 100644 index 00000000000..618fb9d4c49 --- /dev/null +++ b/cecli/tools/validations/__init__.py @@ -0,0 +1,5 @@ +"""Tool parameter validation module for BaseTool subclasses.""" + +from cecli.tools.validations.validations import ToolValidations + +__all__ = ["ToolValidations"] diff --git a/cecli/tools/validations/validations.py b/cecli/tools/validations/validations.py new file mode 100644 index 00000000000..531d295d59b --- /dev/null +++ b/cecli/tools/validations/validations.py @@ -0,0 +1,330 @@ +""" +Tool parameter validation module for BaseTool subclasses. + +Provides a framework for declarative parameter validation via VALIDATIONS dicts +on tool classes, along with built-in validation methods. + +The VALIDATIONS dict maps parameter paths (dot-separated, optionally with [] +for list iteration) to lists of validation method names that are executed +sequentially on the parameter value. + +Example:: + + VALIDATIONS = { + "delegations": ["coerce_list"], + "delegations[]": ["coerce_dict"], + "edits": ["coerce_list"], + "edits[].file_path": ["coerce_str"], + } +""" + +from __future__ import annotations + +import json + +import json_repair + +from cecli.helpers import responses +from cecli.tools.utils.helpers import ToolError + + +class ToolValidations: + """ + Registry of validation methods for tool parameters. + + Each classmethod in this class can be referenced by name in a tool's + VALIDATIONS dict. The ``validate_params`` classmethod orchestrates + the application of validations based on the dict. + """ + + @classmethod + def validate_params(cls, params: dict, validations: dict, schema: dict | None = None) -> dict: + """ + Apply validations to *params* according to the *validations* dict. + + Parameters are modified in place and also returned for convenience. + + Args: + params: The raw tool parameters dict. + validations: A VALIDATIONS dict mapping parameter paths to + lists of validation method names. + schema: The tool's SCHEMA dict (used for context, currently + reserved for future use). + + Returns: + The (possibly mutated) *params* dict. + """ + if isinstance(params, str): + params = json_repair.loads(params) + + if not isinstance(params, (dict, list)): + raise ToolError("Invalid Tool Input - Unparsable JSON") + + # Apply basic structural corrections before declarative validations + params = cls._basic_validations(params, schema) + + if not validations: + return params + + for raw_key, method_names in validations.items(): + # Determine whether the key targets list items (trailing "[]") + iterate_over_list = raw_key.endswith("[]") + clean_key = raw_key.rstrip("[]") + + # Split on dots to get the navigation path into params + path = clean_key.split(".") if clean_key else [] + + if not path: + continue + + if iterate_over_list: + cls._apply_validations_to_list_items(params, path, method_names) + else: + cls._apply_validations_to_value(params, path, method_names) + + return params + + @classmethod + def _basic_validations(cls, params: object, schema: dict | None = None) -> dict: + """ + Apply basic structural corrections to *params* based on *schema*. + + If the schema declares exactly one property of type ``array``: + - If *params* is a bare list, wrap it as ``{param_name: [...]}``. + - If *params* is a dict that doesn't contain the expected key, + wrap the dict in a list under that key: + ``{param_name: [{key: val, ...}]}``. + + Returns the (possibly corrected) *params* dict. + """ + if not schema or "function" not in schema: + return params + + function_schema = schema["function"] + if "parameters" not in function_schema: + return params + + parameters = function_schema["parameters"] + properties = parameters.get("properties", {}) + + # Only auto-correct when there is exactly one property and it is an array + if len(properties) == 1: + single_param_name = next(iter(properties.keys())) + param_schema = properties[single_param_name] + if param_schema.get("type") == "array": + # Case 1: LLM emitted the array directly (bare list) + if isinstance(params, list): + return {single_param_name: params} + # Case 2: LLM emitted a dict missing the expected key → wrap it + if isinstance(params, dict) and single_param_name not in params: + return {single_param_name: [params]} + + return params + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + @staticmethod + def _get_nested(params: dict, path: list[str]) -> tuple[dict | list, str] | None: + """ + Navigate *params* along *path* and return ``(container, last_key)``. + + Returns ``None`` if any intermediate key is missing or isn't a dict. + """ + current: dict | list = params + for key in path[:-1]: + if isinstance(current, dict) and key in current: + current = current[key] + else: + return None + if isinstance(current, dict): + return current, path[-1] + return None + + @staticmethod + def _set_nested(params: dict, path: list[str], value: object) -> None: + """Set *value* at the nested location described by *path* in *params*.""" + current = params + for key in path[:-1]: + if key not in current: + current[key] = {} + current = current[key] + current[path[-1]] = value + + @classmethod + def _apply_validations_to_value( + cls, params: dict, path: list[str], method_names: list[str] + ) -> None: + """Apply the named validations sequentially to the value at *path*.""" + result = cls._get_nested(params, path) + if result is None: + return + container, last_key = result + if last_key not in container: + return + value = container[last_key] + + for method_name in method_names: + method = getattr(cls, method_name, None) + if method is None: + raise ToolError(f"Unknown validation method: {method_name}") + value = method(value) + if value is None: + # Validation chose to drop the value entirely + container[last_key] = value + return + + container[last_key] = value + + @classmethod + def _apply_validations_to_list_items( + cls, params: dict, path: list[str], method_names: list[str] + ) -> None: + """Apply validations to each item of the list found at *path*.""" + result = cls._get_nested(params, path) + if result is None: + return + container, last_key = result + if last_key not in container: + return + items = container[last_key] + + if not isinstance(items, list): + return + + new_items: list = [] + for item in items: + for method_name in method_names: + method = getattr(cls, method_name, None) + if method is None: + raise ToolError(f"Unknown validation method: {method_name}") + item = method(item) + if item is None: + break + if item is not None: + new_items.append(item) + + container[last_key] = new_items + + # ------------------------------------------------------------------ + # Built-in validation methods + # ------------------------------------------------------------------ + + @classmethod + def coerce_list(cls, item: object) -> list: + """ + Coerce *item* into a list. + + * If *item* is already a list it is returned as-is (after checking + for char-split JSON arrays). + * If *item* is a string it is parsed as JSON. A JSON array is + returned directly; a JSON object is wrapped in a list. + * If *item* is a dict it is wrapped in a list. + * Otherwise an empty list is returned. + """ + if isinstance(item, list): + # Check for per-character-split JSON arrays first + coerced = responses.try_join_char_split_json_array(item) + if coerced is not None: + return coerced + # Single-element wrapping a JSON string of an array/object + if len(item) == 1 and isinstance(item[0], str): + if item[0].strip().startswith(("[", "{", '"')): + item = item[0] + else: + return item + else: + return item + + if isinstance(item, str): + text = item.strip() + if not text: + return [] + parsed = responses.try_parse_json_value(text) + if isinstance(parsed, list): + return parsed + if isinstance(parsed, dict): + return [parsed] + + parsed = json_repair.loads(text, skip_json_loads=True) + + if isinstance(parsed, list): + return parsed + + return [] + + if isinstance(item, dict): + return [item] + + return [] + + @classmethod + def coerce_dict(cls, item: object) -> dict | None: + """ + Coerce *item* into a dict. + + * If *item* is already a dict it is returned as-is. + * If *item* is a string it is parsed as JSON; returns the dict if + successful, otherwise ``None``. + * All other types return ``None``. + """ + if isinstance(item, dict): + return item + if isinstance(item, str): + text = item.strip() + if not text: + return None + parsed = responses.try_parse_json_value(text) + if isinstance(parsed, dict): + return parsed + # Fallback: try json repaid json.loads + try: + parsed = json_repair.loads(text, skip_json_loads=True) + except (json.JSONDecodeError, ValueError): + return None + + if isinstance(parsed, dict): + return parsed + + return None + + @classmethod + def coerce_str(cls, item: object) -> str | None: + """Coerce *item* to a string, returning ``None`` if not possible.""" + if isinstance(item, str): + return item + if item is None: + return None + try: + return str(item) + except (ValueError, TypeError): + return None + + @classmethod + def coerce_int(cls, item: object) -> int | None: + """Coerce *item* to an int, returning ``None`` if not possible.""" + if isinstance(item, int) and not isinstance(item, bool): + return item + if isinstance(item, (float, str)): + try: + return int(item) + except (ValueError, TypeError): + return None + return None + + @classmethod + def coerce_bool(cls, item: object) -> bool | None: + """Coerce *item* to a bool, returning ``None`` if not possible.""" + if isinstance(item, bool): + return item + if isinstance(item, str): + low = item.strip().lower() + if low in ("true", "1", "yes"): + return True + if low in ("false", "0", "no"): + return False + return None + if isinstance(item, int): + return bool(item) + return None diff --git a/cecli/tui/app.py b/cecli/tui/app.py index a53cf6ad339..6d7a9ea320c 100644 --- a/cecli/tui/app.py +++ b/cecli/tui/app.py @@ -1535,9 +1535,18 @@ def _get_suggestions(self, text: str) -> list[str]: else: # Use standard command completions (no file fallback) try: - cmd_completions = commands.get_completions(cmd_name, coder=active_coder) + cmd_completions = commands.get_completions( + cmd_name, args=arg_prefix, coder=active_coder + ) if cmd_completions: - if arg_prefix: + exempt_from_substring_matching = { + "/model", + "/models", + "/agent-model", + "/editor-model", + "/weak-model", + } + if arg_prefix and cmd_name not in exempt_from_substring_matching: suggestions = [ c for c in cmd_completions if arg_prefix_lower in str(c).lower() ] diff --git a/cecli/website/docs/config/model-aliases.md b/cecli/website/docs/config/model-aliases.md index 4d46098cf4d..a4e137b3209 100644 --- a/cecli/website/docs/config/model-aliases.md +++ b/cecli/website/docs/config/model-aliases.md @@ -95,109 +95,6 @@ for alias, model in sorted(MODEL_ALIASES.items()): - `sonnet`: anthropic/claude-sonnet-4-20250514 -## Advanced Model Settings - -CECLI/Cecli supports model names with colon-separated suffixes (e.g., `gpt-5:high`) that map to additional configuration parameters defined in the relevant config.yml file. This allows you to create named configurations for different use cases. These configurations map precisely to the LiteLLM `completion()` method parameters [here](https://docs.litellm.ai/docs/completion/input), though more are supported for specific models and providers. Any key under the `model_settings` key will override the model parameters defined in files like `.cecli.model.settings.yml` (more information [here](https://cecli.dev/docs/config/adv-model-settings.html)) - -### Configuration File - -Add a structure like the following to your config.yml file or create a `.cecli.model.overrides.yml` file (or specify a different file with `--model-overrides-file` if there are global defaults you want): - -```yaml -model-overrides: - gpt-5: - high: # Use with: --model gpt-5:high - temperature: 0.8 - top_p: 0.9 - extra_body: - reasoning_effort: high - low: # Use with: --model gpt-5:low - temperature: 0.2 - top_p: 0.5 - creative: # Use with: --model gpt-5:creative - temperature: 0.9 - top_p: 0.95 - frequency_penalty: 0.5 - - claude-4-5-sonnet: - fast: # Use with: --model claude-3-5-sonnet:fast - temperature: 0.3 - detailed: # Use with: --model claude-3-5-sonnet:detailed - temperature: 0.7 - thinking_tokens: 4096 -``` - -### Usage - -You can use these suffixes with any model argument: - -```bash -# Main model with high reasoning effort (using file) -cecli --model gpt-5:high --model-overrides-file .cecli.model.overrides.yml - -# Main model with high reasoning effort (using direct JSON/YAML) -cecli --model gpt-5:high --model-overrides '{"gpt-5": {"high": {"temperature": 0.8, "top_p": 0.9, "extra_body": {"reasoning_effort": "high"}}}}' - -# Different configurations for main and weak models -cecli --model claude-3-5-sonnet:detailed --weak-model claude-3-5-sonnet:fast - -# Editor model with creative settings -cecli --model gpt-5 --editor-model gpt-5:creative -``` - -### How It Works - -1. When you specify a model with a suffix (e.g., `gpt-5:high`), cecli splits it into the base model name (`gpt-5`) and suffix (`high`). -2. It looks up the suffix in the overrides file for that model. -3. The corresponding configuration parameters are applied to the model's API calls. -4. The parameters are deep-merged into the model's existing settings, with overrides taking precedence. - - -### Default Overrides - -In addition to suffix-based overrides, you can define **default overrides** that apply directly to a model by name without requiring a colon-separated suffix. Use the special `defaults` key within your `model-overrides` configuration: - -```yaml -model-overrides: - defaults: - gpt-5: - temperature: 0.7 - top_p: 0.9 - claude-4-5-sonnet: - temperature: 1 - model_settings: - cache_control: true -``` - -When you run `cecli --model gpt-5`, the default overrides specified under `defaults` are applied automatically. This is useful for setting baseline parameters for specific models without creating a named configuration. - -Default overrides work alongside suffix-based overrides. If both a default override and a suffix override match the same parameter, the suffix override takes precedence: - -```bash -# Applies default overrides for gpt-5 -cecli --model gpt-5 - -# Applies suffix-based overrides for gpt-5:high, merged on top of defaults -cecli --model gpt-5:high -``` - -```yaml -model-overrides: - defaults: - gpt-5: - temperature: 0.7 - gpt-5: - high: - temperature: 0.9 # Overrides the default of 0.7 -``` - -### Priority - -Model overrides work alongside aliases. For example, you can use: -- `cecli --model fast:high` (if `fast` is an alias for `gpt-5-mini`) -- `cecli --model sonnet:detailed` (if `sonnet` is an alias for `anthropic/claude-sonnet-4-20250514`) - -The suffix is applied after alias resolution. ## Priority diff --git a/cecli/website/docs/config/model-configuration.md b/cecli/website/docs/config/model-configuration.md new file mode 100644 index 00000000000..301b199c6af --- /dev/null +++ b/cecli/website/docs/config/model-configuration.md @@ -0,0 +1,132 @@ +--- +parent: Configuration +nav_order: 900 +description: Configure model overrides, alias-based suffixes, and structured override groups. +--- + +## Model Configuration & Overrides + +CECLI allows you to customize and override LLM configurations to fine-tune their behavior, API parameters, and metadata. You can organize these overrides into three logical configuration groups, and apply them either as **defaults** (by model name) or via **suffixes** (e.g., `gpt-5:high`). + +--- + +## Core Configuration Groups + +For advanced configurations, you can organize override parameters into three logical groups: `api`, `llm`, and `agent`. +### 1. `api` +Values under `api` are merged directly into the model's API request parameters (`headers`). This is useful for configuring provider-specific API options, temperature, or custom headers. For the full list of supported parameters, see the [LiteLLM completion input documentation](https://docs.litellm.ai/docs/completion/input). +- **Common parameters**: `temperature`, `top_p`, `max_tokens`, `parallel_tool_calls`, `extra_body` (e.g., `thinking: true` or `reasoning_effort: "high"`). + +### 2. `llm` +Values under `llm` are merged into the model's info dictionary (`self.info`). This allows you to override or augment model metadata and capabilities. For a comprehensive list of available model metadata fields, see the [LiteLLM model prices and context window reference](https://github.com/BerriAI/litellm/blob/litellm_internal_staging/model_prices_and_context_window.json). +- **Common parameters**: `supports_vision`, `supports_function_calling`, token limits, or pricing information. + +### 3. `agent` +Values under `agent` modify CECLI's internal `ModelSettings` fields. This controls how CECLI interacts with the model and manages the workspace. For all supported fields, the `ModelSettings` class in [models.py](https://github.com/cecli-dev/cecli/blob/main/cecli/models.py) contains the most comprehensive list. +- **Common parameters**: `edit_format`, `use_repo_map`, `cache_control`, `caches_by_default`. + +--- + +## Application Methods + +You can apply these configuration groups in two ways: + +### 1. Default Overrides +Default overrides apply automatically to a model by name without requiring any suffix. Use the special `defaults` key within your `model-overrides` configuration. + +When you run `cecli --model gpt-5`, any default overrides specified under `defaults` for `gpt-5` are applied automatically. + +### 2. Suffix-Based Overrides +Suffix-based overrides allow you to define named configurations for different use cases using a colon-separated suffix (e.g., `gpt-5:high` or `claude-3-5-sonnet:fast`). + +When you specify a model with a suffix, CECLI splits it into the base model name and the suffix, looks up the suffix configuration, and merges it on top of any default settings. + +--- + +## Configuration File Example + +You can define these overrides in your `config.yml` file, a `.cecli.model.overrides.yml` file, or a custom file specified via `--model-overrides-file`. + +```yaml +model-overrides: + # 1. Default overrides (applied automatically by model name) + defaults: + openai/gpt-5.5: + api: + temperature: 0.7 + top_p: 0.9 + anthropic.claude-sonnet-4-6: + api: + temperature: 1.0 + llm: + supports_vision: true + supports_function_calling: true + agent: + cache_control: true + + # 2. Suffix-based overrides (applied when using model:suffix) + openai/gpt-5.5: + high: + api: + temperature: 0.8 + top_p: 0.9 + extra_body: + reasoning_effort: high + low: + api: + temperature: 0.2 + top_p: 0.5 + creative: + api: + temperature: 0.9 + top_p: 0.95 + frequency_penalty: 0.5 + + anthropic.claude-sonnet-4-6: + fast: + api: + temperature: 0.3 + detailed: + api: + temperature: 0.7 + thinking_tokens: 4096 +``` + +--- + +## Usage & CLI Examples + +You can reference these configurations in any model argument on the command line: + +```bash +# Applies default overrides for gpt-5 +cecli --model gpt-5 + +# Applies suffix-based overrides for gpt-5:high, merged on top of defaults +cecli --model gpt-5:high --model-overrides-file .cecli.model.overrides.yml + +# Different configurations for main and weak models +cecli --model claude-3-5-sonnet:detailed --weak-model claude-3-5-sonnet:fast + +# Editor model with creative settings +cecli --model gpt-5 --editor-model gpt-5:creative + +# Direct JSON/YAML overrides via CLI +cecli --model gpt-5:high --model-overrides '{"gpt-5": {"high": {"api": {"temperature": 0.8}}}}' +``` + +--- + +## Resolution & Priority + +When resolving model configurations, CECLI applies overrides in the following order of precedence (highest priority first): + +1. **Suffix-Based Overrides**: Specific suffix configurations (e.g., `:high`) override default settings. +2. **Default Overrides**: Settings defined under the `defaults` key for the model. +3. **Base Model Settings**: The model's built-in or system-defined parameters. + +### Alias Resolution +If you use a model alias (e.g., `fast` as an alias for `gpt-5-mini`), the alias is resolved to the base model name **before** any suffixes or overrides are applied. + +For example: +- `cecli --model fast:high` resolves `fast` to `gpt-5-mini`, then applies the `high` suffix overrides defined for `gpt-5-mini`. diff --git a/cecli/website/docs/config/workspaces.md b/cecli/website/docs/config/workspaces.md index fc0fd736c93..822aead0b5a 100644 --- a/cecli/website/docs/config/workspaces.md +++ b/cecli/website/docs/config/workspaces.md @@ -5,7 +5,11 @@ description: Workspaces allow you to work across multiple related repositories s --- # Workspaces -Workspaces allow you to manage multiple git repositories within a single monorepo-like folder structure, enabling development across multiple related projects. +Workspaces allow you to manage multiple git repositories within a single monorepo-like folder structure, enabling development across multiple related projects. `cecli` supports two workspace modes: + +**clone** workspaces (remote `repo:` URLs cloned into `~/.cecli/workspaces/`) + +**local** workspaces (existing on-disk git roots referenced by absolute `path:`) ## Configuration @@ -15,6 +19,8 @@ You can configure workspaces in multiple locations. `cecli` searches for configu 2. **Local Workspaces File**: `.cecli.workspaces.yml` or `.cecli.workspaces.yaml` in the current directory. 3. **Global Workspaces File**: `~/.cecli/workspaces.yml` or `~/.cecli/workspaces.yaml`. +4. **Repo-Local Config File**: `.cecli.workspaces.yml` or `.cecli.workspaces.yaml` placed at a common ancestor of your project directories. `cecli` discovers this file by walking up from any project path, enabling a **local** workspace layout without cloning into `~/.cecli/workspaces/`. + ### Example Configuration ```yaml @@ -34,6 +40,38 @@ workspaces: ignore: "~/.cecli/backend.ignore" # Optional: Path to a custom ignore file for this project ``` +### Local Workspace Configuration + +For **local** workspaces, place a `.cecli.workspaces.yml` file at a common ancestor of your project directories. Each project references an existing git root via `path:` instead of a remote `repo:` URL. + +```yaml +# .cecli.workspaces.yml +name: my-workspace +projects: + - name: app + path: /abs/path/to/app + primary: true # At most one project can be primary + - name: lib + path: /abs/path/to/lib + readonly: true # Prevents commits to this project +``` + +**Validation rules:** + +- Each project must have a `name` and **exactly one** of `path` (local git root) or `repo` (clone URL). +- At most one project can be marked `primary: true`. +- Projects with `readonly: true` are excluded from commits. + +### Path Layout + +The workspace layout determines how file paths are structured within the workspace: + +| Layout | Prefix | Example | +|--------|--------|--------| +| **clone** (repo-based) | `{project}/main/{file}` | `app/main/src/main.py` | +| **local** (path-based) | `{project}/{file}` | `app/src/main.py` | + + ### Multiple Workspaces You can define a list of workspaces. Use the `active: true` flag to specify which one should be used by default when running `cecli` without the `--workspace-name` argument. **Note: At most one workspace can be marked as active.** @@ -61,9 +99,9 @@ cecli --workspace-name my-workspace cecli --workspaces path/to/workspaces.yml --workspace-name my-workspace ``` -If the workspace does not exist, `cecli` will create the directory structure at `~/.cecli/workspaces/my-workspace/` and clone the configured repositories. +If the workspace does not exist, `cecli` will create the directory structure at `~/.cecli/workspaces/my-workspace/` and clone the configured repositories. For **local** workspaces, the configured `path:` directories are used in-place — no cloning occurs. -### Workspace Structure +### Clone Workspace Structure ``` ~/.cecli/workspaces/ @@ -74,6 +112,17 @@ If the workspace does not exist, `cecli` will create the directory structure at └── worktrees/ # Additional worktrees ``` +### Local Workspace Structure + +Local workspaces do **not** create a `~/.cecli/workspaces/` directory. Instead, the config file directory itself serves as the workspace root, with metadata stored at: + +``` +.cecli/ +└── .workspace-meta.json +``` + +The project directories exist at their configured `path:` locations on disk. + ## Arguments `--workspaces `: Provide a JSON/YAML configuration or file path for workspace initialization. diff --git a/requirements.txt b/requirements.txt index a22e283dc96..d9800f1c66c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -161,6 +161,10 @@ jiter==0.12.0 # via # -c requirements/common-constraints.txt # openai +json-repair==0.60.1 + # via + # -c requirements/common-constraints.txt + # -r requirements/requirements.in json5==0.12.1 # via # -c requirements/common-constraints.txt diff --git a/requirements/common-constraints.txt b/requirements/common-constraints.txt index 402f53e540d..a2c112f9acc 100644 --- a/requirements/common-constraints.txt +++ b/requirements/common-constraints.txt @@ -175,6 +175,8 @@ joblib==1.5.2 # via # nltk # scikit-learn +json-repair==0.60.1 + # via -r requirements/requirements.in json5==0.12.1 # via -r requirements/requirements.in jsonschema==4.25.1 diff --git a/requirements/requirements.in b/requirements/requirements.in index acc9e3ae3fd..6aa96211a99 100644 --- a/requirements/requirements.in +++ b/requirements/requirements.in @@ -33,6 +33,7 @@ truststore xxhash>=3.6.0 py-cymbal>=0.1.24 cryptography>=42.0.0 +json-repair>=0.60.1 # File system lookup aids marisa-trie>=1.0 diff --git a/tests/helpers/monorepo/LOCAL_WORKSPACE.md b/tests/helpers/monorepo/LOCAL_WORKSPACE.md new file mode 100644 index 00000000000..dd4311195cb --- /dev/null +++ b/tests/helpers/monorepo/LOCAL_WORKSPACE.md @@ -0,0 +1,66 @@ +# PR: Local `path:` projects for cecli workspaces + +## Summary + +Extends cecli’s existing **clone** workspace mode (`repo:` URLs under `~/.cecli/workspaces/`, paths like `project/main/file.py`) with **local** layout: multiple git roots on disk referenced by absolute `path:` in a repo-local config file. + +## Motivation + +IDE clients (e.g. BrightVision) open a **primary git repo** but need agent context across **sibling repos** without cloning into `~/.cecli/workspaces/`. Submodule-only setups are a different layout; this PR adds an explicit, reviewable config surface. + +## Config + +Place at the workspace root (walked up from any listed project path): + +```yaml +# .cecli.workspaces.yml +name: my-workspace +projects: + - name: app + path: /abs/path/to/app + primary: true + - name: lib + path: /abs/path/to/lib + readonly: true +``` + +Rules (enforced in `validate_config`): + +- Each project: `name` + **exactly one** of `path` or `repo` +- At most one `primary: true` + +## Path layout + +| Layout | Prefix | Example | +|--------|--------|---------| +| **local** (this PR) | `{project}/{file}` | `app/src/main.py` | +| **clone** (existing) | `{project}/main/{file}` | `app/main/src/main.py` | + +## Behavior changes + +| Area | Change | +|------|--------| +| `GitRepo.__init__` | Multiple git roots allowed when `.cecli.workspaces.yml` is found on a common ancestor | +| `get_workspace_files` | Local layout unions `git ls-files` from each `path:` root | +| `commit` | Local layout commits per underlying repo (`_commit_local_workspace`) | +| `abs_root_path` | Resolves prefixed paths to the correct project root | + +Clone workspaces and `.cecli-workspace.json` metadata are unchanged. + +## Tests + +- `tests/helpers/monorepo/test_config.py` — validation (`path` / `repo` XOR) +- `tests/helpers/monorepo/test_local_workspace.py` — helpers + `GitRepo` integration +- Existing `test_repomap_workspace.py`, `test_workspace.py`, etc. — still pass (clone layout) + +Run: + +```bash +pytest tests/helpers/monorepo -q +``` + +## Non-goals (follow-up PRs) + +- Auto-registering git submodules into the workspace registry +- Combining submodule `RepoSet` with local YAML in one facade +- New global config file formats (reuse `.cecli.workspaces.yml` only) diff --git a/tests/helpers/monorepo/test_config.py b/tests/helpers/monorepo/test_config.py index bade92cb61b..0d85669753c 100644 --- a/tests/helpers/monorepo/test_config.py +++ b/tests/helpers/monorepo/test_config.py @@ -13,11 +13,36 @@ def test_validate_config_no_name(): validate_config({"projects": []}) -def test_validate_config_invalid_project(): - with pytest.raises(ValueError, match="Each project must have a 'name' and 'repo' URL"): +def test_validate_config_invalid_project_missing_source(): + with pytest.raises(ValueError, match="exactly one of 'path' or 'repo'"): validate_config({"name": "test", "projects": [{"name": "p1"}]}) +def test_validate_config_invalid_project_both_sources(): + with pytest.raises(ValueError, match="exactly one of 'path' or 'repo'"): + validate_config( + { + "name": "test", + "projects": [ + { + "name": "p1", + "path": "/tmp/p1", + "repo": "https://github.com/org/r.git", + } + ], + } + ) + + +def test_validate_config_path_project(): + validate_config( + { + "name": "local", + "projects": [{"name": "app", "path": "/abs/app", "primary": True}], + } + ) + + def test_validate_config_duplicate_project(): with pytest.raises(ValueError, match="Duplicate project name: p1"): validate_config( diff --git a/tests/helpers/monorepo/test_local_workspace.py b/tests/helpers/monorepo/test_local_workspace.py new file mode 100644 index 00000000000..db29df4b06d --- /dev/null +++ b/tests/helpers/monorepo/test_local_workspace.py @@ -0,0 +1,235 @@ +"""Tests for repo-local workspaces (``path:`` git roots, ``.cecli.workspaces.yml``).""" + +from __future__ import annotations + +import json +import subprocess +from pathlib import Path + +import pytest +import yaml + +from cecli.helpers.monorepo.config import load_workspace_config_file, validate_config +from cecli.helpers.monorepo.local_workspace import ( + find_workspace_config_file, + load_workspace_file, + primary_project, + project_path_prefix, + read_workspace_metadata, + resolve_workspace_file_path, + union_tracked_files, + write_workspace_metadata, +) +from cecli.io import InputOutput +from cecli.repo import GitRepo +from cecli.utils import make_repo + + +def _init_git_repo(path: Path, readme: str = "# repo\n") -> None: + make_repo(path) + readme_path = path / "README.md" + readme_path.write_text(readme, encoding="utf-8") + subprocess.run(["git", "add", "README.md"], cwd=path, check=True, capture_output=True) + subprocess.run( + ["git", "commit", "-m", "init", "--no-gpg-sign"], + cwd=path, + check=True, + capture_output=True, + ) + + +@pytest.fixture +def two_path_projects(tmp_path: Path): + """ + Workspace root with ``.cecli.workspaces.yml`` and two sibling git checkouts. + + Layout:: + + ws/ + .cecli.workspaces.yml + app/ (git) + lib/ (git) + """ + ws = tmp_path / "ws" + app = ws / "app" + lib = ws / "lib" + app.mkdir(parents=True) + lib.mkdir(parents=True) + _init_git_repo(app, "# app\n") + _init_git_repo(lib, "# lib\n") + + config = { + "name": "pair", + "projects": [ + {"name": "app", "path": str(app.resolve()), "primary": True}, + {"name": "lib", "path": str(lib.resolve())}, + ], + } + (ws / ".cecli.workspaces.yml").write_text( + yaml.dump(config, sort_keys=False), + encoding="utf-8", + ) + return ws, config, app, lib + + +class TestValidateConfigPathProjects: + def test_path_only_project_valid(self): + validate_config( + { + "name": "local", + "projects": [{"name": "app", "path": "/tmp/app", "primary": True}], + } + ) + + def test_repo_only_project_valid(self): + validate_config( + { + "name": "clone", + "projects": [{"name": "p1", "repo": "https://github.com/org/r.git"}], + } + ) + + def test_missing_path_and_repo(self): + with pytest.raises(ValueError, match="exactly one of 'path' or 'repo'"): + validate_config({"name": "test", "projects": [{"name": "p1"}]}) + + def test_both_path_and_repo(self): + with pytest.raises(ValueError, match="exactly one of 'path' or 'repo'"): + validate_config( + { + "name": "test", + "projects": [ + { + "name": "p1", + "path": "/tmp/a", + "repo": "https://github.com/org/r.git", + } + ], + } + ) + + def test_multiple_primary(self): + with pytest.raises(ValueError, match="Only one project may be marked primary"): + validate_config( + { + "name": "test", + "projects": [ + {"name": "a", "path": "/a", "primary": True}, + {"name": "b", "path": "/b", "primary": True}, + ], + } + ) + + +class TestLocalWorkspaceHelpers: + def test_find_workspace_config_file_walks_up(self, two_path_projects): + ws, _config, app, _lib = two_path_projects + expected = (ws / ".cecli.workspaces.yml").resolve() + assert find_workspace_config_file(ws).resolve() == expected + # YAML lives at workspace root; project checkout is a subdirectory. + assert find_workspace_config_file(app).resolve() == expected + assert find_workspace_config_file(app / "README.md").resolve() == expected + + def test_load_workspace_config_file(self, two_path_projects): + ws, _config, _app, _lib = two_path_projects + loaded = load_workspace_config_file(ws / ".cecli.workspaces.yml") + assert loaded["name"] == "pair" + assert len(loaded["projects"]) == 2 + + def test_union_tracked_files(self, two_path_projects): + ws, config, _app, _lib = two_path_projects + files = union_tracked_files(ws, config, layout="local") + assert "app/README.md" in files + assert "lib/README.md" in files + + def test_resolve_workspace_file_path_prefixed(self, two_path_projects): + ws, config, _app, _lib = two_path_projects + resolved = resolve_workspace_file_path(ws, "lib/README.md", config, layout="local") + assert resolved is not None + git_root, abs_path, in_repo = resolved + assert in_repo == "README.md" + assert abs_path.name == "README.md" + assert git_root.name == "lib" + + def test_project_path_prefix_local_vs_clone(self): + proj = {"name": "app"} + assert project_path_prefix(proj, layout="local") == "app" + assert project_path_prefix(proj, layout="clone") == "app/main" + + def test_primary_project_explicit_and_implicit(self): + cfg = { + "projects": [ + {"name": "a", "path": "/a"}, + {"name": "b", "path": "/b", "primary": True}, + ] + } + assert primary_project(cfg)["name"] == "b" + + single = {"projects": [{"name": "only", "path": "/only"}]} + assert primary_project(single)["name"] == "only" + + def test_workspace_metadata_roundtrip(self, two_path_projects): + ws, config, _app, _lib = two_path_projects + write_workspace_metadata(ws, config, layout="local") + meta = read_workspace_metadata(ws) + assert meta is not None + loaded, layout = meta + assert layout == "local" + assert loaded["name"] == config["name"] + meta_path = ws / ".cecli" / ".workspace-meta.json" + assert meta_path.is_file() + on_disk = json.loads(meta_path.read_text(encoding="utf-8")) + assert on_disk.get("_layout") == "local" + + +class TestGitRepoLocalWorkspace: + def test_detects_local_workspace_and_unions_files(self, two_path_projects): + ws, _config, app, lib = two_path_projects + io = InputOutput(yes=True) + repo = GitRepo(io, [str(app / "README.md"), str(lib / "README.md")], None) + + assert repo.is_workspace + assert repo.workspace_layout == "local" + assert repo.workspace_path == ws.resolve() + + files = repo.get_workspace_files() + assert "app/README.md" in files + assert "lib/README.md" in files + + def test_abs_root_path_resolves_prefixed_path(self, two_path_projects): + _ws, _config, app, _lib = two_path_projects + io = InputOutput(yes=True) + repo = GitRepo(io, [str(app)], None) + + abs_path = Path(repo.abs_root_path("app/README.md")) + assert abs_path == (app / "README.md").resolve() + + def test_without_workspace_file_multi_repo_fails(self, tmp_path: Path): + root = tmp_path / "orphan" + a = root / "a" + b = root / "b" + a.mkdir(parents=True) + b.mkdir(parents=True) + _init_git_repo(a) + _init_git_repo(b) + io = InputOutput(yes=True) + with pytest.raises(FileNotFoundError): + GitRepo(io, [str(a / "README.md"), str(b / "README.md")], None) + + def test_load_workspace_file_defaults(self, tmp_path: Path): + path = tmp_path / ".cecli.workspaces.yml" + path.write_text("projects: []\n", encoding="utf-8") + loaded = load_workspace_file(path) + assert "name" in loaded + assert loaded["projects"] == [] + + +class TestGitRepoLocalWorkspaceNoYaml: + def test_single_repo_without_yaml_is_not_local_workspace(self, tmp_path: Path): + repo_dir = tmp_path / "solo" + repo_dir.mkdir() + _init_git_repo(repo_dir) + io = InputOutput(yes=True) + repo = GitRepo(io, [str(repo_dir)], None) + assert find_workspace_config_file(repo_dir) is None + assert getattr(repo, "workspace_layout", "clone") != "local" or not repo.is_workspace diff --git a/tests/tools/test_get_lines.py b/tests/tools/test_get_lines.py index 8c7fd7705b3..1bbeb0d3b6c 100644 --- a/tests/tools/test_get_lines.py +++ b/tests/tools/test_get_lines.py @@ -142,3 +142,31 @@ def test_multiline_pattern_search(coder_with_file): assert "Retrieved context for 1 operation(s)" in result coder.io.tool_error.assert_not_called() + + +def test_empty_file_includes_edit_hint(tmp_path): + empty = tmp_path / "pubspec.yaml" + empty.write_text("") + coder = DummyCoder(tmp_path) + + from unittest.mock import patch + + with patch("cecli.helpers.conversation.ConversationService") as conv: + conv.get_files.return_value.clear_ranges = Mock() + conv.get_files.return_value.push_range = Mock() + conv.get_chunks.return_value.add_file_context_messages = Mock() + result = read_range.Tool.execute( + coder, + show=[ + { + "file_path": "pubspec.yaml", + "start_text": "@000", + "end_text": "@000", + } + ], + ) + + assert "pubspec.yaml is empty" in result + assert "EditText" in result + assert "readrange again" in result.lower() + coder.io.tool_error.assert_not_called() diff --git a/tests/tools/test_insert_block.py b/tests/tools/test_insert_block.py index 054232732c7..9e5ae2b855e 100644 --- a/tests/tools/test_insert_block.py +++ b/tests/tools/test_insert_block.py @@ -121,7 +121,7 @@ def test_mutually_exclusive_parameters_raise(coder_with_file): ) assert result.startswith("Error in EditText:") - assert "Invalid Edit - Edit Results In Same Content" in result + assert "Invalid Edit - Update content hash bounds" in result assert file_path.read_text().startswith("first line") coder.io.tool_error.assert_called() diff --git a/tests/tools/test_tool_arguments.py b/tests/tools/test_tool_arguments.py index c06ff1d5d83..c63a84135a7 100644 --- a/tests/tools/test_tool_arguments.py +++ b/tests/tools/test_tool_arguments.py @@ -98,7 +98,7 @@ def test_grep_format_output_empty_searches_does_not_crash_tool_footer(): mcp_server=SimpleNamespace(name="Local"), tool_response=tool_response, ) - assert coder.io.tool_error.called + assert not coder.io.tool_error.called def test_try_join_char_split_json_array_reconstructs_array(): diff --git a/tests/tools/validations.py b/tests/tools/validations.py new file mode 100644 index 00000000000..95f58bc183b --- /dev/null +++ b/tests/tools/validations.py @@ -0,0 +1,463 @@ +"""Tests for the ToolValidations class.""" + +from __future__ import annotations + +from cecli.tools.validations import ToolValidations + +# ========================================================================= +# _basic_validations tests +# ========================================================================= + + +class TestBasicValidations: + """Structural corrections: bare list → dict wrapping.""" + + def test_bare_list_wraps_into_single_array_property(self): + """A bare list param should be wrapped when schema has one array property.""" + schema = { + "function": { + "parameters": { + "properties": { + "delegations": { + "type": "array", + "items": {"type": "object"}, + } + } + } + } + } + params = [{"name": "a1", "prompt": "do stuff"}] + result = ToolValidations._basic_validations(params, schema) + assert result == {"delegations": params} + + def test_dict_params_pass_through(self): + """Already-wrapped dict params should pass through unchanged.""" + schema = { + "function": { + "parameters": { + "properties": { + "delegations": { + "type": "array", + "items": {"type": "object"}, + } + } + } + } + } + params = {"delegations": [{"name": "a1", "prompt": "do stuff"}]} + result = ToolValidations._basic_validations(params, schema) + assert result is params + + def test_multi_property_schema_does_not_wrap(self): + """Bare list should not wrap when schema has multiple properties.""" + schema = { + "function": { + "parameters": { + "properties": { + "add": {"type": "array"}, + "remove": {"type": "array"}, + } + } + } + } + params = ["file_a.py", "file_b.py"] + result = ToolValidations._basic_validations(params, schema) + assert result is params + + def test_non_array_property_does_not_wrap(self): + """Bare list should not wrap when the single property is not an array.""" + schema = { + "function": { + "parameters": { + "properties": { + "name": {"type": "string"}, + } + } + } + } + params = ["some", "strings"] + result = ToolValidations._basic_validations(params, schema) + assert result is params + + def test_no_schema_does_not_wrap(self): + """Bare list should passthrough when schema is None.""" + result = ToolValidations._basic_validations([1, 2, 3], None) + assert result == [1, 2, 3] + + def test_empty_properties_does_not_wrap(self): + """Bare list should passthrough when schema has no properties.""" + schema = {"function": {"parameters": {"properties": {}}}} + result = ToolValidations._basic_validations([1, 2, 3], schema) + assert result == [1, 2, 3] + + +# ========================================================================= +# coerce_list tests +# ========================================================================= + + +class TestCoerceList: + """List coercion: strings, dicts, and edge cases.""" + + def test_actual_list_passthrough(self): + """A proper list should pass through unchanged.""" + data = [{"a": 1}, {"b": 2}] + result = ToolValidations.coerce_list(data) + assert result == data + + def test_json_string_array(self): + """A JSON string containing an array should be parsed.""" + result = ToolValidations.coerce_list('[{"a": 1}, {"b": 2}]') + assert result == [{"a": 1}, {"b": 2}] + + def test_json_string_dict(self): + """A JSON string containing a dict should be wrapped in a list.""" + result = ToolValidations.coerce_list('{"task": "hello"}') + assert result == [{"task": "hello"}] + + def test_bare_dict_wraps_in_list(self): + """A bare dict should be wrapped in a list.""" + result = ToolValidations.coerce_list({"task": "hello"}) + assert result == [{"task": "hello"}] + + def test_empty_string_returns_empty_list(self): + """An empty string should return an empty list.""" + assert ToolValidations.coerce_list("") == [] + assert ToolValidations.coerce_list(" ") == [] + + def test_integer_returns_empty_list(self): + """A non-list, non-dict, non-string input should return empty list.""" + assert ToolValidations.coerce_list(42) == [] + assert ToolValidations.coerce_list(None) == [] + + def test_char_split_json_array(self): + """A char-split JSON array should be reconstructed.""" + items = ["[", "{", '"', "t", "a", "s", "k", '"', ":", " ", '"', "x", '"', "}", "]"] + result = ToolValidations.coerce_list(items) + assert result == [{"task": "x"}] + + def test_single_item_wrapping_json_string_list(self): + """A single-element list wrapping a JSON array string should unwrap.""" + result = ToolValidations.coerce_list(['[{"a": 1}]']) + assert result == [{"a": 1}] + + def test_single_item_wrapping_json_string_dict(self): + """A single-element list wrapping a JSON dict string should unwrap.""" + result = ToolValidations.coerce_list(['{"a": 1}']) + assert result == [{"a": 1}] + + +# ========================================================================= +# coerce_dict tests +# ========================================================================= + + +class TestCoerceDict: + """Dict coercion: strings, dicts, and edge cases.""" + + def test_actual_dict_passthrough(self): + """A proper dict should pass through unchanged.""" + data = {"name": "test", "prompt": "do stuff"} + result = ToolValidations.coerce_dict(data) + assert result is data + + def test_json_string_object(self): + """A JSON string object should be parsed into a dict.""" + result = ToolValidations.coerce_dict('{"name": "test", "prompt": "do stuff"}') + assert result == {"name": "test", "prompt": "do stuff"} + + def test_empty_string_returns_none(self): + """An empty string should return None.""" + assert ToolValidations.coerce_dict("") is None + assert ToolValidations.coerce_dict(" ") is None + + def test_invalid_json_string_returns_none(self): + """An invalid JSON string should return None.""" + assert ToolValidations.coerce_dict("{broken") is None + assert ToolValidations.coerce_dict("hello") is None + + def test_integer_returns_none(self): + """A non-dict, non-string input should return None.""" + assert ToolValidations.coerce_dict(42) is None + + def test_list_returns_none(self): + """A list input should return None.""" + assert ToolValidations.coerce_dict([1, 2, 3]) is None + + def test_none_returns_none(self): + """None input should return None.""" + assert ToolValidations.coerce_dict(None) is None + + +# ========================================================================= +# coerce_str tests +# ========================================================================= + + +class TestCoerceStr: + """String coercion.""" + + def test_string_passthrough(self): + """A string should pass through unchanged.""" + assert ToolValidations.coerce_str("hello") == "hello" + + def test_integer_to_string(self): + """An integer should be converted to string.""" + assert ToolValidations.coerce_str(42) == "42" + + def test_float_to_string(self): + """A float should be converted to string.""" + assert ToolValidations.coerce_str(3.14) == "3.14" + + def test_none_returns_none(self): + """None should return None.""" + assert ToolValidations.coerce_str(None) is None + + +# ========================================================================= +# coerce_int tests +# ========================================================================= + + +class TestCoerceInt: + """Integer coercion.""" + + def test_int_passthrough(self): + """An integer should pass through unchanged.""" + assert ToolValidations.coerce_int(42) == 42 + + def test_string_number(self): + """A numeric string should be converted to int.""" + assert ToolValidations.coerce_int("42") == 42 + + def test_float_truncates(self): + """A float should be truncated to int.""" + assert ToolValidations.coerce_int(3.99) == 3 + + def test_invalid_string_returns_none(self): + """A non-numeric string should return None.""" + assert ToolValidations.coerce_int("hello") is None + + def test_none_returns_none(self): + """None should return None.""" + assert ToolValidations.coerce_int(None) is None + + def test_bool_returns_none(self): + """A boolean should return None (bool is a subclass of int but we exclude it).""" + assert ToolValidations.coerce_int(True) is None + + +# ========================================================================= +# coerce_bool tests +# ========================================================================= + + +class TestCoerceBool: + """Boolean coercion.""" + + def test_bool_passthrough(self): + """A boolean should pass through unchanged.""" + assert ToolValidations.coerce_bool(True) is True + assert ToolValidations.coerce_bool(False) is False + + def test_string_true_variants(self): + """Truthy strings should be coerced to True.""" + assert ToolValidations.coerce_bool("true") is True + assert ToolValidations.coerce_bool("True") is True + assert ToolValidations.coerce_bool("1") is True + assert ToolValidations.coerce_bool("yes") is True + + def test_string_false_variants(self): + """Falsy strings should be coerced to False.""" + assert ToolValidations.coerce_bool("false") is False + assert ToolValidations.coerce_bool("False") is False + assert ToolValidations.coerce_bool("0") is False + assert ToolValidations.coerce_bool("no") is False + + def test_integer_truthy(self): + """Truthy integers should be coerced to True.""" + assert ToolValidations.coerce_bool(1) is True + assert ToolValidations.coerce_bool(5) is True + + def test_integer_falsy(self): + """Falsy integers should be coerced to False.""" + assert ToolValidations.coerce_bool(0) is False + + def test_invalid_string_returns_none(self): + """An unrecognised truthy/falsy string should return None.""" + assert ToolValidations.coerce_bool("maybe") is None + assert ToolValidations.coerce_bool("") is None + + +# ========================================================================= +# validate_params integration tests +# ========================================================================= + + +class TestValidateParams: + """Full workflow: validate_params orchestrator.""" + + # ---- empty / None validations ---- + + def test_empty_validations_returns_params(self): + """An empty VALIDATIONS dict should return params unchanged.""" + params = {"key": "value"} + result = ToolValidations.validate_params(params, {}) + assert result == {"key": "value"} + + def test_none_validations_returns_params(self): + """A None VALIDATIONS dict should return params unchanged.""" + params = {"key": "value"} + result = ToolValidations.validate_params(params, None) + assert result == {"key": "value"} + + # ---- simple keys ---- + + def test_simple_key_coerce_list(self): + """A simple key should apply validation to the top-level param value.""" + params = {"delegations": '[{"name": "a1"}]'} + result = ToolValidations.validate_params( + params, + {"delegations": ["coerce_list"]}, + ) + assert result == {"delegations": [{"name": "a1"}]} + + def test_simple_key_coerce_dict(self): + """A simple key should coerce a param value to dict.""" + params = {"item": '{"key": "val"}'} + result = ToolValidations.validate_params( + params, + {"item": ["coerce_dict"]}, + ) + assert result == {"item": {"key": "val"}} + + # ---- [] iteration ---- + + def test_list_iteration_coerce_dict(self): + """A [] key should apply validation to each list item.""" + params = { + "delegations": [ + '{"name": "a1", "prompt": "do x"}', + '{"name": "a2", "prompt": "do y"}', + ] + } + result = ToolValidations.validate_params( + params, + {"delegations[]": ["coerce_dict"]}, + ) + assert result == { + "delegations": [ + {"name": "a1", "prompt": "do x"}, + {"name": "a2", "prompt": "do y"}, + ] + } + + def test_list_iteration_skips_null_items(self): + """Items that fail validation and return None should be dropped.""" + params = { + "items": [ + '{"valid": "json"}', + "not json", + '{"also": "valid"}', + ] + } + result = ToolValidations.validate_params( + params, + {"items[]": ["coerce_dict"]}, + ) + # The invalid JSON string returns None and is dropped + assert result == { + "items": [ + {"valid": "json"}, + {"also": "valid"}, + ] + } + + def test_list_iteration_empty_list(self): + """An empty list should remain empty after iteration.""" + params = {"items": []} + result = ToolValidations.validate_params( + params, + {"items[]": ["coerce_dict"]}, + ) + assert result == {"items": []} + + # ---- chained validations ---- + + def test_chained_validations(self): + """Multiple validation methods should be applied in sequence.""" + params = {"count": "42"} + result = ToolValidations.validate_params( + params, + {"count": ["coerce_str", "coerce_int"]}, + ) + # coerce_str("42") → "42", coerce_int("42") → 42 + assert result == {"count": 42} + + # ---- integration with _basic_validations ---- + + def test_bare_list_gets_wrapped_and_validated(self): + """A bare list param should be wrapped, then validated per item.""" + schema = { + "function": { + "parameters": { + "properties": { + "delegations": { + "type": "array", + "items": {"type": "object"}, + } + } + } + } + } + bare_list = ['{"name": "a1", "prompt": "do x"}'] + result = ToolValidations.validate_params( + bare_list, + {"delegations[]": ["coerce_dict"]}, + schema, + ) + assert result == { + "delegations": [ + {"name": "a1", "prompt": "do x"}, + ] + } + + def test_bare_list_with_empty_validations(self): + """Even with empty VALIDATIONS, _basic_validations should still wrap.""" + schema = { + "function": { + "parameters": { + "properties": { + "delegations": { + "type": "array", + "items": {"type": "object"}, + } + } + } + } + } + bare_list = [{"name": "a1", "prompt": "do x"}] + result = ToolValidations.validate_params(bare_list, {}, schema) + assert result == {"delegations": bare_list} + + # ---- key not present ---- + + def test_validation_key_not_in_params(self): + """If the validation key doesn't exist in params, nothing should happen.""" + params = {"other": "value"} + result = ToolValidations.validate_params( + params, + {"missing_key": ["coerce_list"]}, + ) + assert result == {"other": "value"} + + # ---- non-list target for [] ---- + + def test_list_iteration_on_non_list_does_nothing(self): + """If the param for a [] key is not a list, it should be left alone.""" + params = {"items": "not a list"} + result = ToolValidations.validate_params( + params, + {"items[]": ["coerce_dict"]}, + ) + assert result == {"items": "not a list"}