Ansel 0.0
A darktable fork - bloat + design vision
Loading...
Searching...
No Matches
find-null-checks.py
Go to the documentation of this file.
1#!/usr/bin/env python3
2
3import sys
4from pathlib import Path
5from clang.cindex import Index, CursorKind, TypeKind, Config
6
7import regex as re
8
9# Config.set_library_file("/usr/lib/llvm-17/lib/libclang.so")
10
# File suffixes that process_folder treats as C/C++ sources to rewrite.
EXTS = {".c", ".h", ".cpp", ".hpp"}
# Spelling of the null-pointer macro that the rewriting helpers look for.
NULL_NAME = "NULL"
13
def rewrite_condition(cond: str) -> str:
    """Rewrite every identifier-like token of *cond* via ``_rewrite_token_repl``.

    A token is an optional leading ``!`` followed by an identifier, an
    optional ``->member`` and an optional ``[index]`` suffix.  Each match is
    handed to ``_rewrite_token_repl``, which decides on the replacement text.

    NOTE(review): this function itself performs no pointer-type check; every
    token matching the pattern is passed to the replacement callback.
    """
    token_re = re.compile(r'(!?)([A-Za-z_]\w*(?:->\w*)?(?:\[[^\]]+\])?)')
    return token_re.sub(_rewrite_token_repl, cond)
23
24# -------------------------
25# extract condition text safely
26# -------------------------
27
def get_text(src, extent):
    """Return the part of the byte string *src* covered by *extent* as text.

    *extent* must expose ``start.offset`` and ``end.offset`` (byte offsets).
    Bytes that are not valid UTF-8 are replaced instead of raising.
    """
    begin = extent.start.offset
    finish = extent.end.offset
    return src[begin:finish].decode("utf-8", "replace")
30
31# -------------------------
32# process file
33# -------------------------
34
def find_decl_in_cursor(cursor, name, visited=None):
    """Depth-first search below *cursor* for a parameter/variable named *name*.

    Returns the first descendant whose spelling equals *name* and whose kind
    is ``PARM_DECL`` or ``VAR_DECL``, or ``None`` when there is none.
    *visited* is the set of cursor keys already expanded; callers normally
    leave it at ``None``.
    """
    if visited is None:
        visited = set()

    # Key identifying this cursor so that it is never expanded twice:
    # (kind, start offset, end offset, spelling).
    extent = cursor.extent
    key = (
        cursor.kind,
        getattr(extent, 'start', None) and extent.start.offset,
        getattr(extent, 'end', None) and extent.end.offset,
        cursor.spelling,
    )
    if key in visited:
        return None
    visited.add(key)

    for child in cursor.get_children():
        if child.spelling == name and child.kind in (CursorKind.PARM_DECL, CursorKind.VAR_DECL):
            return child
        found = find_decl_in_cursor(child, name, visited)
        if found is not None:
            return found
    return None
53
54
def is_pointer_decl(decl):
    """Return True when the type of declaration cursor *decl* is a pointer.

    The declared type is checked first; when it is not itself a pointer, its
    canonical type is checked as well so that typedef'd pointers are caught.
    """
    ty = decl.type
    if ty.kind == TypeKind.POINTER:
        return True

    # Check the canonical type in case of typedefs.
    get_canon = getattr(ty, 'get_canonical', None)
    if get_canon:
        return get_canon().kind == TypeKind.POINTER

    return False
67
68
def collect_decls(root):
    """Collect PARM_DECL and VAR_DECL cursors under *root* into a name -> cursor dict.

    The tree is walked iteratively with an explicit stack.  Declarations
    with an empty spelling are ignored; when a name occurs more than once,
    the cursor visited last replaces the earlier one.
    """
    decls = {}
    pending = [root]
    while pending:
        node = pending.pop()
        if node.kind in (CursorKind.PARM_DECL, CursorKind.VAR_DECL) and node.spelling:
            decls[node.spelling] = node
        pending.extend(node.get_children())
    return decls
83
84
def collect_decl_refs(n, call_refs):
    """Add the name of every DECL_REF_EXPR at or below *n* to the set *call_refs*.

    The spelling of the referenced declaration is preferred; the spelling of
    the reference itself is used when no referenced cursor is available.
    """
    if n.kind == CursorKind.DECL_REF_EXPR:
        ref = getattr(n, 'referenced', None)
        name = ref.spelling if ref else n.spelling
        if name:
            call_refs.add(name)

    for child in n.get_children():
        collect_decl_refs(child, call_refs)
93
94
def collect_calls(n, call_refs):
    """Add to *call_refs* every declaration name referenced inside a call under *n*.

    The walk stops descending at the first CALL_EXPR on each path, because
    ``collect_decl_refs`` already covers that whole subtree.
    """
    if n.kind == CursorKind.CALL_EXPR:
        collect_decl_refs(n, call_refs)
        return

    for child in n.get_children():
        collect_calls(child, call_refs)
101
102
def collect_call_extents(n, call_extents):
    """Append the ``(start, end)`` byte offsets of every CALL_EXPR at or below *n*.

    Nested calls are recorded too, since the walk continues below each call.
    """
    if n.kind == CursorKind.CALL_EXPR:
        extent = n.extent
        call_extents.append((extent.start.offset, extent.end.offset))

    for child in n.get_children():
        collect_call_extents(child, call_extents)
109
110
def collect_call_args(n, call_args_names):
    """Add to the set *call_args_names* every name referenced inside any call under *n*.

    For each CALL_EXPR found, all DECL_REF_EXPR names in its subtree are
    collected (this includes the callee reference itself).
    """
    if n.kind == CursorKind.CALL_EXPR:
        collect_decl_refs(n, call_args_names)

    for child in n.get_children():
        collect_call_args(child, call_args_names)
117
118
119# cache for global declarations per translation unit
120_global_decls_cache = {}
121
def node_text(src_bytes, node):
    """Return the source text covered by *node*'s extent, decoded as UTF-8.

    Offsets are clamped so that a negative start or an end before the start
    yields an empty/shortened slice instead of a wrapped-around one.
    Undecodable bytes are replaced.
    """
    begin = max(0, node.extent.start.offset)
    finish = max(begin, node.extent.end.offset)
    return src_bytes[begin:finish].decode("utf-8", "replace")
126
127
128def _pos_in_string(s: str, pos: int) -> bool:
129 """Return True if position `pos` is inside a single- or double-quoted literal in s."""
130 in_double = False
131 in_single = False
132 i = 0
133 while i < pos and i < len(s):
134 ch = s[i]
135 if ch == '\\':
136 i += 2
137 continue
138 if not in_single and ch == '"':
139 in_double = not in_double
140 elif not in_double and ch == "'":
141 in_single = not in_single
142 i += 1
143 return in_double or in_single
144
145
def find_literal_ranges(text: str):
    """Return ``(start, end)`` index pairs of the quoted literals in *text*.

    Both single- and double-quoted literals are recognised.  ``start`` is the
    index of the opening quote and ``end`` is one past the closing quote.
    A backslash inside a literal skips the following character.  An
    unterminated literal extends to the end of *text*.
    """
    ranges = []
    length = len(text)
    i = 0
    while i < length:
        quote = text[i]
        if quote != '"' and quote != "'":
            i += 1
            continue

        start = i
        i += 1
        while i < length:
            if text[i] == '\\':
                i += 2
                continue
            if text[i] == quote:
                i += 1
                break
            i += 1
        # A trailing backslash can push the index past the end of the text;
        # clamp so that every reported range lies within the text.
        i = min(i, length)
        ranges.append((start, i))
    return ranges
169
170
def is_pos_in_literal(pos: int, lit_ranges) -> bool:
    """Return True when *pos* falls inside any half-open ``(start, end)`` range."""
    return any(start <= pos < end for start, end in lit_ranges)
176
177
def find_call_arg_ranges(text: str, lit_ranges):
    """Return ``(start, end)`` ranges of the argument lists of call-like text.

    Every ``identifier(`` outside a literal is treated as the start of a
    call.  ``start`` is the index right after the opening parenthesis and
    ``end`` the index of the matching closing parenthesis.  Parentheses
    inside the literal ranges given by *lit_ranges* are ignored.  Calls whose
    closing parenthesis is missing produce no range.  Nested calls yield
    their own (inner) range in addition to the outer one.
    """
    ranges = []
    length = len(text)
    for m in re.finditer(r'\b[A-Za-z_]\w*\s*\(', text):
        open_pos = m.end() - 1
        # "identifier(" that is itself part of a string literal is not a call.
        if any(a <= open_pos < b for a, b in lit_ranges):
            continue

        depth = 0
        i = open_pos + 1
        while i < length:
            # Jump over a literal in one step so that parentheses inside it
            # do not disturb the nesting depth.
            lit_end = next((b for a, b in lit_ranges if a <= i < b), None)
            if lit_end is not None:
                i = lit_end
                continue

            ch = text[i]
            if ch == '(':
                depth += 1
            elif ch == ')':
                if depth == 0:
                    ranges.append((open_pos + 1, i))
                    break
                depth -= 1
            i += 1
    return ranges
202
203
def is_pos_in_call_args(pos: int, call_ranges) -> bool:
    """Return True when *pos* falls inside any half-open ``(start, end)`` call range."""
    return any(start <= pos < end for start, end in call_ranges)
209
class ReplEq:
    """Regex replacement callback producing ``IS_NULL_PTR(var)``.

    Matches that start inside a literal range or a call-argument range are
    returned unchanged.
    """

    def __init__(self, var, lit_ranges, call_ranges):
        self.var = var
        self.lit_ranges = lit_ranges
        self.call_ranges = call_ranges

    def __call__(self, m):
        pos = m.start()
        if is_pos_in_literal(pos, self.lit_ranges) or is_pos_in_call_args(pos, self.call_ranges):
            return m.group(0)
        return f'IS_NULL_PTR({self.var})'
220
221
class ReplNeq:
    """Regex replacement callback producing ``!IS_NULL_PTR(var)``.

    Matches that start inside a literal range or a call-argument range are
    returned unchanged.
    """

    def __init__(self, var, lit_ranges, call_ranges):
        self.var = var
        self.lit_ranges = lit_ranges
        self.call_ranges = call_ranges

    def __call__(self, m):
        pos = m.start()
        if is_pos_in_literal(pos, self.lit_ranges) or is_pos_in_call_args(pos, self.call_ranges):
            return m.group(0)
        return f'!IS_NULL_PTR({self.var})'
232
233
235 def __init__(self, lit_ranges, call_ranges):
236 self.lit_ranges = lit_ranges
237 self.call_ranges = call_ranges
238
239 def __call__(self, m):
240 if is_pos_in_literal(m.start(), self.lit_ranges) or is_pos_in_call_args(m.start(), self.call_ranges):
241 return m.group(0)
242 return f'IS_NULL_PTR({m.group(1)})'
243
244
class ReplNotVar:
    """Regex replacement callback turning a matched ``!var`` into ``IS_NULL_PTR(var)``.

    Matches that start inside a literal range or a call-argument range are
    returned unchanged.
    """

    def __init__(self, var, lit_ranges, call_ranges):
        self.var = var
        self.lit_ranges = lit_ranges
        self.call_ranges = call_ranges

    def __call__(self, m):
        pos = m.start()
        if is_pos_in_literal(pos, self.lit_ranges) or is_pos_in_call_args(pos, self.call_ranges):
            return m.group(0)
        return f'IS_NULL_PTR({self.var})'
255
256
def transform_condition(cond_text: str, pointer_vars: list) -> str:
    """Rewrite NULL tests on *pointer_vars* inside *cond_text* to ``IS_NULL_PTR``.

    Three passes are applied, each over all variables:

    1. ``var == NULL`` / ``NULL == var`` -> ``IS_NULL_PTR(var)`` and
       ``var != NULL`` / ``NULL != var`` -> ``!IS_NULL_PTR(var)``;
    2. ``!var`` -> ``IS_NULL_PTR(var)`` (not when ``var`` is followed by
       ``->``, ``.``, ``[`` or ``(``);
    3. a remaining bare ``var`` -> ``!IS_NULL_PTR(var)``.

    Text inside string/char literals and inside call argument lists is never
    touched.  Returns the rewritten condition (equal to *cond_text* when
    nothing matched).
    """

    def ranges_for(text):
        # Offsets shift after every substitution, so the protected ranges
        # must always be computed from the string that is about to be scanned.
        lits = find_literal_ranges(text)
        return lits, find_call_arg_ranges(text, lits)

    s = cond_text

    # Pass 1: explicit comparisons against NULL, in either operand order.
    for var in pointer_vars:
        name = re.escape(var)
        comparisons = (
            (r'\b' + name + r"\s*==\s*NULL\b", ReplEq),
            (r'\bNULL\s*==\s*' + name + r"\b", ReplEq),
            (r'\b' + name + r"\s*!=\s*NULL\b", ReplNeq),
            (r'\bNULL\s*!=\s*' + name + r"\b", ReplNeq),
        )
        for pattern, repl_cls in comparisons:
            lit_ranges, call_ranges = ranges_for(s)
            s = re.sub(pattern, repl_cls(var, lit_ranges, call_ranges), s)

    # Pass 2: explicit negation '!var'.  The look-behind avoids '=!' and the
    # look-ahead avoids member access, subscripts and negated calls.
    for var in pointer_vars:
        pat = r'(?<![=])!\s*' + re.escape(var) + r'\b(?!\s*(?:->|\.|\[|\())'
        lit_ranges, call_ranges = ranges_for(s)
        s = re.sub(pat, ReplNotVar(var, lit_ranges, call_ranges), s)

    # Pass 3: bare occurrences of var used as a truth value.
    for var in pointer_vars:
        lit_ranges, call_ranges = ranges_for(s)
        spans = []
        for m in re.finditer(r'\b' + re.escape(var) + r'\b', s):
            start, end = m.start(), m.end()
            # Inside a literal or a call argument list (this also covers
            # the argument of an IS_NULL_PTR(...) produced earlier).
            if is_pos_in_literal(start, lit_ranges) or is_pos_in_call_args(start, call_ranges):
                continue
            if _pos_in_string(s, start):
                continue

            before = s[max(0, start - 10):start]
            after = s[end:end + 10]
            # Member access before or after, or a subscript after.
            if re.search(r'(->|\.)\s*$', before) or re.match(r'\s*(->|\.|\[)', after):
                continue
            # Unary '!' directly in front ('!ptr' or '! ptr').
            if re.search(r'!\s*$', before):
                continue
            # Already wrapped: the identifier is the argument of IS_NULL_PTR(.
            if re.search(r'IS_NULL_PTR\(\s*$', s[max(0, start - 16):start]):
                continue
            # Operand of a comparison: wrapping it would change the meaning
            # of the comparison, so leave it alone.
            if re.match(r"\s*(?:==|!=|<=|>=|<|>)", s[end:]):
                continue
            if re.search(r"(?:==|!=|<=|>=|<|>)\s*$", before):
                continue
            # Dereferenced identifier ('*var'): the pointee is tested, not
            # the pointer.
            if re.search(r'\*\s*$', before):
                continue
            spans.append((start, end))

        if not spans:
            continue

        # Rebuild the string piecewise; spans are in ascending order.
        parts = []
        last = 0
        for start, end in spans:
            parts.append(s[last:start])
            parts.append(f'!IS_NULL_PTR({var})')
            last = end
        parts.append(s[last:])
        s = ''.join(parts)

    return s
332
333
def collect_ast_replacements(cond_node, src_bytes):
    """Walk the condition AST and return edits as ``(start, end, bytes, old, new)``.

    The walk is delegated to ``_collect_ast_walk``, which handles:

    - ``var == NULL`` and ``NULL == var`` -> ``IS_NULL_PTR(var)``
    - ``var != NULL`` and ``NULL != var`` -> ``!IS_NULL_PTR(var)``
    - ``!var`` where var is a bare DeclRefExpr -> ``IS_NULL_PTR(var)``
    - bare ``var`` (DeclRefExpr) used as boolean -> ``!IS_NULL_PTR(var)``

    The raw edit list is returned without de-duplication or sorting; the
    caller is expected to sort and apply it.
    """
    edits = []
    _collect_ast_walk(cond_node, None, cond_node.translation_unit, src_bytes, edits)
    return edits
351
352
def collect_call_args_tokens(cond_node):
    """Collect identifier names that appear as arguments of any call-like token sequence
    inside cond_node by tokenizing the source via libclang. This handles macros
    and other call-like constructs that do not appear as CALL_EXPR in the AST.

    A call-like sequence is an IDENTIFIER token directly followed by ``(``.
    Every IDENTIFIER token up to the matching ``)`` is collected, including
    the names of nested calls.  Returns a set of spellings.
    """
    tu = cond_node.translation_unit
    tokens = list(tu.get_tokens(extent=cond_node.extent))
    count = len(tokens)
    args = set()

    i = 0
    while i < count - 1:
        # Token kinds are compared by name so that the TokenKind enum does
        # not have to be imported here.
        is_ident = getattr(tokens[i].kind, 'name', None) == 'IDENTIFIER'
        if not (is_ident and getattr(tokens[i + 1], 'spelling', '') == '('):
            i += 1
            continue

        # Scan from the first token after the opening parenthesis to the
        # matching closing one.
        depth = 0
        j = i + 2
        while j < count:
            spelling = tokens[j].spelling
            if spelling == '(':
                depth += 1
            elif spelling == ')':
                if depth == 0:
                    break
                depth -= 1
            elif getattr(tokens[j].kind, 'name', None) == 'IDENTIFIER':
                args.add(spelling)
            j += 1
        # Resume at the closing parenthesis (or at the end of the tokens).
        i = j

    return args
397
398
def _rewrite_token_repl(m):
    """Replacement callback for ``rewrite_condition``.

    Group 1 of *m* is the optional ``!`` and group 2 the token.  ``NULL`` is
    returned as is (without the ``!``); a negated token becomes
    ``IS_NULL_PTR(tok)`` and a plain token ``!IS_NULL_PTR(tok)``.
    """
    neg = m.group(1)
    tok = m.group(2)

    # Skip NULL comparisons.
    if tok == NULL_NAME:
        return tok

    return f"IS_NULL_PTR({tok})" if neg else f"!IS_NULL_PTR({tok})"
411
412
def _collect_ast_walk(n, parent, tu, src_bytes, edits):
    """Recursively collect ``IS_NULL_PTR`` edits for the AST rooted at *n*.

    *parent* is the cursor *n* was reached from (``None`` for the root),
    *tu* the translation unit used for tokenizing.  Edits are appended to
    *edits* as ``(start, end, new_bytes, old_text, new_text)`` tuples.
    A node that is replaced as a whole is not descended into.

    NOTE(review): the statement nesting of this function was reconstructed
    from a listing without indentation -- confirm against the original file.
    """
    kind = n.kind

    def add_edit(new):
        # The edit always replaces the whole extent of the current node.
        edits.append((n.extent.start.offset, n.extent.end.offset,
                      new.encode('utf-8'), node_text(src_bytes, n), new))

    # Binary operator: look for == or != with NULL.
    if kind == CursorKind.BINARY_OPERATOR:
        # The operator is detected from the tokens of the node's extent; the
        # first '==' / '!=' token found is taken.
        ops = [t.spelling for t in tu.get_tokens(extent=n.extent) if t.spelling in ('==', '!=')]
        if ops:
            decl_child = None
            null_found = False
            for ch in n.get_children():
                if ch.kind == CursorKind.DECL_REF_EXPR:
                    ref = getattr(ch, 'referenced', None)
                    if ref and is_pointer_decl(ref):
                        decl_child = ch
                # Cheap NULL detection in the child's text.
                if 'NULL' in node_text(src_bytes, ch):
                    null_found = True

            if decl_child is not None and null_found:
                var_text = node_text(src_bytes, decl_child).strip()
                if ops[0] == '==':
                    add_edit(f'IS_NULL_PTR({var_text})')
                else:
                    add_edit(f'!IS_NULL_PTR({var_text})')
                return  # don't descend into replaced node

    # Unary operator: handle leading '!' applied to a bare DeclRefExpr.
    if kind == CursorKind.UNARY_OPERATOR:
        tokens = list(tu.get_tokens(extent=n.extent))
        if tokens and tokens[0].spelling == '!':
            children = list(n.get_children())
            if len(children) == 1 and children[0].kind == CursorKind.DECL_REF_EXPR:
                ref = getattr(children[0], 'referenced', None)
                if ref and is_pointer_decl(ref):
                    var_text = node_text(src_bytes, children[0]).strip()
                    add_edit(f'IS_NULL_PTR({var_text})')
                    return

    # DeclRefExpr: bare identifier used as condition.
    if kind == CursorKind.DECL_REF_EXPR:
        ref = getattr(n, 'referenced', None)
        if ref and is_pointer_decl(ref):
            # Leave references alone whose direct parent is a call, member
            # access, array subscript or binary operator.
            untouched_parents = (
                CursorKind.CALL_EXPR,
                CursorKind.MEMBER_REF_EXPR,
                CursorKind.ARRAY_SUBSCRIPT_EXPR,
                CursorKind.BINARY_OPERATOR,
            )
            if parent is None or parent.kind not in untouched_parents:
                var_text = node_text(src_bytes, n).strip()
                add_edit(f'!IS_NULL_PTR({var_text})')

    for ch in n.get_children():
        _collect_ast_walk(ch, n, tu, src_bytes, edits)
481
482
def _pick_condition_node(node, src_bytes):
    """Return the child cursor holding the condition of if-statement *node*, or None."""
    expr_kinds = (
        CursorKind.BINARY_OPERATOR,
        CursorKind.UNARY_OPERATOR,
        CursorKind.DECL_REF_EXPR,
        CursorKind.CALL_EXPR,
        CursorKind.PAREN_EXPR,
        CursorKind.COMPOUND_STMT,
    )
    if_start = node.extent.start.offset
    if_end = node.extent.end.offset

    cond_node = None
    for c in node.get_children():
        # Ensure the child is fully inside the if extent.
        if c.extent.start.offset < if_start or c.extent.end.offset > if_end:
            continue
        # Take the first expression-like child.
        if c.kind in expr_kinds or c.kind.is_expression():
            cond_node = c
            break
    if cond_node is None:
        return None

    # Safety check: there must be a '(' shortly before and a ')' shortly
    # after the candidate (within 256 bytes, inside the if extent);
    # otherwise discard it to avoid touching non-condition code.
    start = cond_node.extent.start.offset
    prefix = src_bytes[max(if_start, start - 256):start].decode('utf-8', 'replace')
    if prefix.rfind('(') == -1:
        return None
    end = cond_node.extent.end.offset
    suffix = src_bytes[end:min(if_end, end + 256)].decode('utf-8', 'replace')
    if ')' not in suffix:
        return None
    return cond_node


def process_if(node, src_bytes, path, edits):
    """Inspect one if-statement and queue a rewrite of its condition.

    Pointer variables used in the condition are determined from the
    declarations of the enclosing function and of the translation unit.
    Names that also occur inside a call (AST- or token-level) are excluded.
    When the textual transform changes the condition, an edit
    ``(start, end, new_bytes, old_text, new_text)`` is appended to *edits*.
    Diagnostics are printed for every if-statement with pointer variables.
    """
    text = get_text(src_bytes, node.extent)
    cond_node = _pick_condition_node(node, src_bytes)
    cond_text = None
    pointer_vars = []

    if cond_node is not None:
        cond_text = node_text(src_bytes, cond_node)
        # Identifier tokens outside of string/char literals.
        idents = [
            m.group(0)
            for m in re.finditer(r"[A-Za-z_]\w*", cond_text)
            if not _pos_in_string(cond_text, m.start())
        ]
        skip = {NULL_NAME, 'IS_NULL_PTR', 'sizeof'}

        # Find the enclosing function for scope resolution.
        # NOTE(review): relies on semantic_parent being set for the
        # condition cursor -- confirm with libclang.
        func = cond_node.semantic_parent
        while func is not None and func.kind != CursorKind.FUNCTION_DECL:
            func = func.semantic_parent

        tu = node.translation_unit

        # Declarations of the function (computed each time) and of the
        # translation unit (cached).  The cache is keyed by the TU's file
        # name: `tu.cursor` creates a fresh object on each access, so its
        # id() is not a stable key.
        func_decls = collect_decls(func) if func is not None else {}
        tu_key = tu.spelling
        if tu_key not in _global_decls_cache:
            _global_decls_cache[tu_key] = collect_decls(tu.cursor)
        global_decls = _global_decls_cache[tu_key]

        found = set()
        for ident in idents:
            if ident in skip:
                continue
            decl = func_decls.get(ident)
            if decl is None:
                decl = global_decls.get(ident)
            if decl is not None and is_pointer_decl(decl):
                found.add(ident)

        # Names passed to any CALL_EXPR inside the condition (AST) ...
        call_arg_names = set()
        collect_call_args(cond_node, call_arg_names)
        # ... and names inside call-like token sequences (macros etc.).
        call_arg_names.update(collect_call_args_tokens(cond_node))

        pointer_vars = sorted(x for x in found if x not in call_arg_names)

        # Conservative textual transform (the AST-based variant exists but
        # is not used here).
        if pointer_vars:
            new_cond = transform_condition(cond_text, pointer_vars)
            if new_cond != cond_text:
                start = cond_node.extent.start.offset
                end = cond_node.extent.end.offset
                edits.append((start, end, new_cond.encode('utf-8'), cond_text, new_cond))

    if pointer_vars:
        print("[IF]", path, "\n", text)
        print("[IFCOND]", path, "\n", cond_text)
        print("[IFVARS]", path, "\n", pointer_vars)
589
590
def process_file(path):
    """Parse *path* as C11, collect condition edits and rewrite the file in place.

    The file is only written when at least one edit was collected.  Each
    applied edit and the final modification are reported on stdout.
    """
    index = Index.create()
    tu = index.parse(str(path), args=["-x", "c", "-std=c11"])

    # Work on bytes because libclang extents are byte offsets.
    src_bytes = path.read_bytes()
    edits = []
    visit(tu.cursor, src_bytes, path, edits)
    if not edits:
        return

    # Apply edits in descending order of start offset so that the offsets
    # of the edits still to be applied stay valid.
    edits.sort(key=lambda e: e[0], reverse=True)
    new_src = src_bytes
    for start, end, new_bytes, old_text, new_text in edits:
        new_src = new_src[:start] + new_bytes + new_src[end:]
        print("[EDIT]", path, "start=", start, "end=", end, "->", len(new_bytes), "bytes", "\n-", old_text, "\n+", new_text)
    path.write_bytes(new_src)
    print("[MODIFIED]", path)
612
613
def visit(node, src_bytes, path, edits):
    """Walk the AST below *node* and run ``process_if`` on every if-statement.

    Children of an if-statement are visited as well, so nested
    if-statements are processed too.
    """
    if node.kind == CursorKind.IF_STMT:
        process_if(node, src_bytes, path, edits)

    for child in node.get_children():
        visit(child, src_bytes, path, edits)
622
623
624# -------------------------
625# folder runner
626# -------------------------
627
def process_folder(folder):
    """Run ``process_file`` on every file below *folder* whose suffix is in ``EXTS``.

    The folder is searched recursively.  Directories are skipped even when
    their name ends in one of the suffixes.
    """
    for p in Path(folder).rglob("*"):
        if p.is_file() and p.suffix in EXTS:
            process_file(p)
632
633# -------------------------
634# main
635# -------------------------
636
if __name__ == "__main__":
    # Exactly one argument is expected: the folder to process.
    if len(sys.argv) != 2:
        print("usage: script.py <folder>")
        sys.exit(1)

    process_folder(sys.argv[1])
__init__(self, var, lit_ranges, call_ranges)
__init__(self, var, lit_ranges, call_ranges)
__init__(self, var, lit_ranges, call_ranges)
__init__(self, lit_ranges, call_ranges)
if(L< 0.5f) C
static const float const float const float min
const float max
str transform_condition(str cond_text, list pointer_vars)
bool is_pos_in_call_args(int pos, call_ranges)
collect_call_extents(n, call_extents)
collect_ast_replacements(cond_node, src_bytes)
collect_call_args(n, call_args_names)
find_literal_ranges(str text)
node_text(src_bytes, node)
find_decl_in_cursor(cursor, name, visited=None)
get_text(src, extent)
collect_call_args_tokens(cond_node)
str rewrite_condition(str cond)
bool is_pos_in_literal(int pos, lit_ranges)
visit(node, src_bytes, path, edits)
find_call_arg_ranges(str text, lit_ranges)
_collect_ast_walk(n, parent, tu, src_bytes, edits)
process_if(node, src_bytes, path, edits)
bool _pos_in_string(str s, int pos)
collect_decl_refs(n, call_refs)
collect_calls(n, call_refs)