summaryrefslogtreecommitdiffstats
path: root/scripts/codeconverter/codeconverter/patching.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/codeconverter/codeconverter/patching.py')
-rw-r--r--scripts/codeconverter/codeconverter/patching.py397
1 files changed, 397 insertions, 0 deletions
diff --git a/scripts/codeconverter/codeconverter/patching.py b/scripts/codeconverter/codeconverter/patching.py
new file mode 100644
index 0000000000..627a1a1b04
--- /dev/null
+++ b/scripts/codeconverter/codeconverter/patching.py
@@ -0,0 +1,397 @@
+# Copyright (C) 2020 Red Hat Inc.
+#
+# Authors:
+# Eduardo Habkost <ehabkost@redhat.com>
+#
+# This work is licensed under the terms of the GNU GPL, version 2. See
+# the COPYING file in the top-level directory.
+from typing import IO, Match, NamedTuple, Optional, Literal, Iterable, Type, Dict, List, Any, TypeVar, NewType, Tuple
+from pathlib import Path
+from itertools import chain
+from tempfile import NamedTemporaryFile
+import os
+import re
+import subprocess
+from io import StringIO
+
+import logging
+logger = logging.getLogger(__name__)
+DBG = logger.debug
+INFO = logger.info
+WARN = logger.warning
+ERROR = logger.error
+
+from .utils import *
+
+T = TypeVar('T')
+
+class Patch(NamedTuple):
+ # start inside file.original_content
+ start: int
+ # end position inside file.original_content
+ end: int
+ # replacement string for file.original_content[start:end]
+ replacement: str
+
+IdentifierType = Literal['type', 'symbol', 'include', 'constant']
+class RequiredIdentifier(NamedTuple):
+ type: IdentifierType
+ name: str
+
+class FileMatch:
+ """Base class for regex matches
+
+ Subclasses just need to set the `regexp` class attribute
+ """
+ regexp: Optional[str] = None
+
+ def __init__(self, f: 'FileInfo', m: Match) -> None:
+ self.file: 'FileInfo' = f
+ self.match: Match = m
+
+ @property
+ def name(self) -> str:
+ if 'name' not in self.match.groupdict():
+ return '[no name]'
+ return self.group('name')
+
+ @classmethod
+ def compiled_re(klass):
+ return re.compile(klass.regexp, re.MULTILINE)
+
+ def start(self) -> int:
+ return self.match.start()
+
+ def end(self) -> int:
+ return self.match.end()
+
+ def line_col(self) -> LineAndColumn:
+ return self.file.line_col(self.start())
+
+ def group(self, *args):
+ return self.match.group(*args)
+
+ def log(self, level, fmt, *args) -> None:
+ pos = self.line_col()
+ logger.log(level, '%s:%d:%d: '+fmt, self.file.filename, pos.line, pos.col, *args)
+
+ def debug(self, fmt, *args) -> None:
+ self.log(logging.DEBUG, fmt, *args)
+
+ def info(self, fmt, *args) -> None:
+ self.log(logging.INFO, fmt, *args)
+
+ def warn(self, fmt, *args) -> None:
+ self.log(logging.WARNING, fmt, *args)
+
+ def error(self, fmt, *args) -> None:
+ self.log(logging.ERROR, fmt, *args)
+
+ def sub(self, original: str, replacement: str) -> str:
+ """Replace content
+
+ XXX: this won't use the match position, but will just
+ replace all strings that look like the original match.
+ This should be enough for all the patterns used in this
+ script.
+ """
+ return original.replace(self.group(0), replacement)
+
+ def sanity_check(self) -> None:
+ """Sanity check match, and print warnings if necessary"""
+ pass
+
+ def replacement(self) -> Optional[str]:
+ """Return replacement text for pattern, to use new code conventions"""
+ return None
+
+ def make_patch(self, replacement: str) -> 'Patch':
+ """Make patch replacing the content of this match"""
+ return Patch(self.start(), self.end(), replacement)
+
+ def make_subpatch(self, start: int, end: int, replacement: str) -> 'Patch':
+ return Patch(self.start() + start, self.start() + end, replacement)
+
+ def make_removal_patch(self) -> 'Patch':
+ """Make patch removing contents of match completely"""
+ return self.make_patch('')
+
+ def append(self, s: str) -> 'Patch':
+ """Make patch appending string after this match"""
+ return Patch(self.end(), self.end(), s)
+
+ def prepend(self, s: str) -> 'Patch':
+ """Make patch prepending string before this match"""
+ return Patch(self.start(), self.start(), s)
+
+ def gen_patches(self) -> Iterable['Patch']:
+ """Patch source code contents to use new code patterns"""
+ replacement = self.replacement()
+ if replacement is not None:
+ yield self.make_patch(replacement)
+
+ @classmethod
+ def has_replacement_rule(klass) -> bool:
+ return (klass.gen_patches is not FileMatch.gen_patches
+ or klass.replacement is not FileMatch.replacement)
+
+ def contains(self, other: 'FileMatch') -> bool:
+ return other.start() >= self.start() and other.end() <= self.end()
+
+ def __repr__(self) -> str:
+ start = self.file.line_col(self.start())
+ end = self.file.line_col(self.end() - 1)
+ return '<%s %s at %d:%d-%d:%d: %r>' % (self.__class__.__name__,
+ self.name,
+ start.line, start.col,
+ end.line, end.col, self.group(0)[:100])
+
+ def required_identifiers(self) -> Iterable[RequiredIdentifier]:
+ """Can be implemented by subclasses to keep track of identifier references
+
+ This method will be used by the code that moves declarations around the file,
+ to make sure we find the right spot for them.
+ """
+ raise NotImplementedError()
+
+ def provided_identifiers(self) -> Iterable[RequiredIdentifier]:
+ """Can be implemented by subclasses to keep track of identifier references
+
+ This method will be used by the code that moves declarations around the file,
+ to make sure we find the right spot for them.
+ """
+ raise NotImplementedError()
+
+ @classmethod
+ def find_matches(klass, content: str) -> Iterable[Match]:
+ """Generate match objects for class
+
+ Might be reimplemented by subclasses if they
+ intend to look for matches using a different method.
+ """
+ return klass.compiled_re().finditer(content)
+
+ @property
+ def allfiles(self) -> 'FileList':
+ return self.file.allfiles
+
+def all_subclasses(c: Type[FileMatch]) -> Iterable[Type[FileMatch]]:
+ for sc in c.__subclasses__():
+ yield sc
+ yield from all_subclasses(sc)
+
+def match_class_dict() -> Dict[str, Type[FileMatch]]:
+ d = dict((t.__name__, t) for t in all_subclasses(FileMatch))
+ return d
+
+def names(matches: Iterable[FileMatch]) -> Iterable[str]:
+ return [m.name for m in matches]
+
+class PatchingError(Exception):
+ pass
+
+class OverLappingPatchesError(PatchingError):
+ pass
+
+def apply_patches(s: str, patches: Iterable[Patch]) -> str:
+ """Apply a sequence of patches to string
+
+ >>> apply_patches('abcdefg', [Patch(2,2,'xxx'), Patch(0, 1, 'yy')])
+ 'yybxxxcdefg'
+ """
+ r = StringIO()
+ last = 0
+ for p in sorted(patches):
+ DBG("Applying patch at position %d (%s) - %d (%s): %r",
+ p.start, line_col(s, p.start),
+ p.end, line_col(s, p.end),
+ p.replacement)
+ if last > p.start:
+ raise OverLappingPatchesError("Overlapping patch at position %d (%s), last patch at %d (%s)" % \
+ (p.start, line_col(s, p.start), last, line_col(s, last)))
+ r.write(s[last:p.start])
+ r.write(p.replacement)
+ last = p.end
+ r.write(s[last:])
+ return r.getvalue()
+
+class RegexpScanner:
+ def __init__(self) -> None:
+ self.match_index: Dict[Type[Any], List[FileMatch]] = {}
+ self.match_name_index: Dict[Tuple[Type[Any], str, str], Optional[FileMatch]] = {}
+
+ def _find_matches(self, klass: Type[Any]) -> Iterable[FileMatch]:
+ raise NotImplementedError()
+
+ def matches_of_type(self, t: Type[T]) -> List[T]:
+ if t not in self.match_index:
+ self.match_index[t] = list(self._find_matches(t))
+ return self.match_index[t] # type: ignore
+
+ def find_match(self, t: Type[T], name: str, group: str='name') -> Optional[T]:
+ indexkey = (t, name, group)
+ if indexkey in self.match_name_index:
+ return self.match_name_index[indexkey] # type: ignore
+ r: Optional[T] = None
+ for m in self.matches_of_type(t):
+ assert isinstance(m, FileMatch)
+ if m.group(group) == name:
+ r = m # type: ignore
+ self.match_name_index[indexkey] = r # type: ignore
+ return r
+
+ def reset_index(self) -> None:
+ self.match_index.clear()
+ self.match_name_index.clear()
+
+class FileInfo(RegexpScanner):
+ filename: Path
+ original_content: Optional[str] = None
+
+ def __init__(self, files: 'FileList', filename: os.PathLike, force:bool=False) -> None:
+ super().__init__()
+ self.allfiles = files
+ self.filename = Path(filename)
+ self.patches: List[Patch] = []
+ self.force = force
+
+ def __repr__(self) -> str:
+ return f'<FileInfo {repr(self.filename)}>'
+
+ def line_col(self, start: int) -> LineAndColumn:
+ """Return line and column for a match object inside original_content"""
+ return line_col(self.original_content, start)
+
+ def _find_matches(self, klass: Type[Any]) -> List[FileMatch]:
+ """Build FileMatch objects for each match of regexp"""
+ if not hasattr(klass, 'regexp') or klass.regexp is None:
+ return []
+ assert hasattr(klass, 'regexp')
+ DBG("%s: scanning for %s", self.filename, klass.__name__)
+ DBG("regexp: %s", klass.regexp)
+ matches = [klass(self, m) for m in klass.find_matches(self.original_content)]
+ DBG('%s: %d matches found for %s: %s', self.filename, len(matches),
+ klass.__name__,' '.join(names(matches)))
+ return matches
+
+ def find_match(self, t: Type[T], name: str, group: str='name') -> Optional[T]:
+ for m in self.matches_of_type(t):
+ assert isinstance(m, FileMatch)
+ if m.group(group) == name:
+ return m # type: ignore
+ return None
+
+ def reset_content(self, s:str):
+ self.original_content = s
+ self.patches.clear()
+ self.reset_index()
+ self.allfiles.reset_index()
+
+ def load(self) -> None:
+ if self.original_content is not None:
+ return
+ with open(self.filename, 'rt') as f:
+ self.reset_content(f.read())
+
+ @property
+ def all_matches(self) -> Iterable[FileMatch]:
+ lists = list(self.match_index.values())
+ return (m for l in lists
+ for m in l)
+
+ def scan_for_matches(self, class_names: Optional[List[str]]=None) -> None:
+ DBG("class names: %r", class_names)
+ class_dict = match_class_dict()
+ if class_names is None:
+ DBG("default class names")
+ class_names = list(name for name,klass in class_dict.items()
+ if klass.has_replacement_rule())
+ DBG("class_names: %r", class_names)
+ for cn in class_names:
+ matches = self.matches_of_type(class_dict[cn])
+ if len(matches) > 0:
+ DBG('%s: %d matches found for %s: %s', self.filename,
+ len(matches), cn, ' '.join(names(matches)))
+
+ def gen_patches(self) -> None:
+ for m in self.all_matches:
+ for i,p in enumerate(m.gen_patches()):
+ DBG("patch %d generated by %r:", i, m)
+ DBG("replace contents at %s-%s with %r",
+ self.line_col(p.start), self.line_col(p.end), p.replacement)
+ self.patches.append(p)
+
+ def patch_content(self, max_passes=0, class_names: Optional[List[str]]=None) -> None:
+ """Multi-pass content patching loop
+
+ We run multiple passes because there are rules that will
+ delete init functions once they become empty.
+ """
+ passes = 0
+ total_patches = 0
+ DBG("max_passes: %r", max_passes)
+ while not max_passes or max_passes <= 0 or passes < max_passes:
+ passes += 1
+ self.scan_for_matches(class_names)
+ self.gen_patches()
+ DBG("patch content: pass %d: %d patches generated", passes, len(self.patches))
+ total_patches += len(self.patches)
+ if not self.patches:
+ break
+ try:
+ self.apply_patches()
+ except PatchingError:
+ logger.exception("%s: failed to patch file", self.filename)
+ DBG("%s: %d patches applied total in %d passes", self.filename, total_patches, passes)
+
+ def apply_patches(self) -> None:
+ """Replace self.original_content after applying patches from self.patches"""
+ self.reset_content(self.get_patched_content())
+
+ def get_patched_content(self) -> str:
+ assert self.original_content is not None
+ return apply_patches(self.original_content, self.patches)
+
+ def write_to_file(self, f: IO[str]) -> None:
+ f.write(self.get_patched_content())
+
+ def write_to_filename(self, filename: os.PathLike) -> None:
+ with open(filename, 'wt') as of:
+ self.write_to_file(of)
+
+ def patch_inplace(self) -> None:
+ newfile = self.filename.with_suffix('.changed')
+ self.write_to_filename(newfile)
+ os.rename(newfile, self.filename)
+
+ def show_diff(self) -> None:
+ with NamedTemporaryFile('wt') as f:
+ self.write_to_file(f)
+ f.flush()
+ subprocess.call(['diff', '-u', self.filename, f.name])
+
+ def ref(self):
+ return TypeInfoReference
+
+class FileList(RegexpScanner):
+ def __init__(self):
+ super().__init__()
+ self.files: List[FileInfo] = []
+
+ def extend(self, *args, **kwargs):
+ self.files.extend(*args, **kwargs)
+
+ def __iter__(self):
+ return iter(self.files)
+
+ def _find_matches(self, klass: Type[Any]) -> Iterable[FileMatch]:
+ return chain(*(f._find_matches(klass) for f in self.files))
+
+ def find_file(self, name) -> Optional[FileInfo]:
+ """Get file with path ending with @name"""
+ nameparts = Path(name).parts
+ for f in self.files:
+ if f.filename.parts[:len(nameparts)] == nameparts:
+ return f
+ else:
+ return None \ No newline at end of file