git-am: do not lose already edited final-commit when resuming.
[git/jnareb-git/bp-gitweb.git] / gitMergeCommon.py
blob1b5bddd467c3d2265e8204bb90c80a2c53742b4a
1 import sys, re, os, traceback
2 from sets import Set
def die(*args):
    '''Report a fatal error and terminate the process.

    The arguments are written to standard error with printList() and the
    process then exits with status 2.
    '''
    printList(args, file=sys.stderr)
    raise SystemExit(2)
def printList(list, file=None):
    '''Write every element of list to file on a single line.

    Each element is converted with str() and followed by one space
    (including the last element); a newline ends the line.

    file defaults to sys.stdout as it is at call time, not as it was when
    this module was imported, so later redirections of sys.stdout are
    honoured. The parameter names shadow builtins but are kept so that
    callers passing them by keyword keep working.
    '''
    if file is None:
        file = sys.stdout
    for x in list:
        file.write(str(x))
        file.write(' ')
    file.write('\n')
14 import subprocess
16 # Debugging machinery
17 # -------------------
# Set to a true value to turn on debug() output. Only functions registered
# with addDebug() print anything even then.
DEBUG = 0
# Names of the functions whose debug() calls are printed when DEBUG is set.
functionsToDebug = Set()

def addDebug(func):
    '''Enable debug() output for func.

    func is either a function name (a string) or a function object, in
    which case its name is registered.
    '''
    if isinstance(func, str):
        name = func
    else:
        # __name__ is available on every function (and other callables such
        # as builtins); func_name is a function-only Python 2 alias.
        name = func.__name__
    functionsToDebug.add(name)
def debug(*args):
    '''Print args with printList() when debugging is enabled for the caller.

    Nothing is printed unless DEBUG is true and the name of the function
    that called debug() has been registered with addDebug().
    '''
    if not DEBUG:
        return
    # extract_stack() ends with this frame; the entry before it is the
    # caller, and the third field of an entry is its function name.
    callerName = traceback.extract_stack()[-2][2]
    if callerName not in functionsToDebug:
        return
    printList(args)
34 # Program execution
35 # -----------------
class ProgramError(Exception):
    '''Raised when an external program cannot be run or fails.

    progStr -- the command line as a single string, used in the message.
    error   -- the error text: runProgram() passes the OS error message when
               the program could not be started, otherwise the program's
               captured output.
    '''
    def __init__(self, progStr, error):
        # Initialise the base class too, so e.args holds the arguments.
        Exception.__init__(self, progStr, error)
        self.progStr = progStr
        self.error = error

    def __str__(self):
        # '%s' formatting also copes with a non-string error (an OSError's
        # strerror attribute can be None), where '+' raised TypeError.
        return '%s: %s' % (self.progStr, self.error)
addDebug('runProgram')
def runProgram(prog, input=None, returnCode=False, env=None, pipeOutput=True):
    '''Run an external program and return its output.

    prog       -- a command string, which is run through the shell, or a
                  list of arguments, which is executed directly.
    input      -- optional string fed to the program's standard input;
                  standard input is closed in either case.
    returnCode -- if true, return [output, exitCode] instead of just the
                  output, and do not raise on a non-zero exit status.
    env        -- optional environment mapping for the program.
    pipeOutput -- if true (the default), capture standard output with
                  standard error merged into it; if false, the program
                  writes to the inherited stdout/stderr and the returned
                  output is ''.

    Raises ProgramError if the program cannot be started, or if it exits
    with a non-zero status while returnCode is false (the error text is
    then the captured output).
    '''
    debug('runProgram prog:', str(prog), 'input:', str(input))
    useShell = isinstance(prog, str)
    if useShell:
        progStr = prog
    else:
        progStr = ' '.join(prog)

    if pipeOutput:
        stdout = subprocess.PIPE
        stderr = subprocess.STDOUT
    else:
        stdout = None
        stderr = None

    try:
        pop = subprocess.Popen(prog,
                               shell=useShell,
                               stderr=stderr,
                               stdout=stdout,
                               stdin=subprocess.PIPE,
                               env=env)
    except OSError:
        # sys.exc_info() instead of "except OSError, e" keeps this block
        # parseable by old Python 2 releases and by Python 3 alike.
        e = sys.exc_info()[1]
        debug('strerror:', e.strerror)
        raise ProgramError(progStr, e.strerror)

    # communicate() writes the input and reads the output concurrently.
    # Writing all of the input before reading any output can deadlock: once
    # the output pipe is full the program stops reading its input and the
    # write never completes. communicate() also closes stdin (so the
    # program sees end-of-file) and waits for the program to exit.
    captured = pop.communicate(input)[0]
    if pipeOutput:
        out = captured
    else:
        out = ''
    code = pop.wait()

    if code != 0 and not returnCode:
        debug('error output:', out)
        debug('prog:', prog)
        raise ProgramError(progStr, out)
    if returnCode:
        return [out, code]
    return out
91 # Code for computing common ancestors
92 # -----------------------------------
# Last integer handed out by getUniqueId(); 0 means none yet.
currentId = 0

def getUniqueId():
    '''Return the next integer of a module-wide counter: 1, 2, 3, ...

    Virtual commits (see Commit) use these integers as their ids.
    '''
    global currentId
    nextId = currentId + 1
    currentId = nextId
    return nextId
# Real commit and tree names are 40-digit lowercase hex SHA1 strings; the
# 'virtual' commit objects have SHAs which are positive integers instead.
# \Z rather than $: '$' also matches just before a trailing newline, which
# let unstripped SHA1 strings pass validation.
shaRE = re.compile(r'^[0-9a-f]{40}\Z')
def isSha(obj):
    '''Return True if obj is a SHA1 string or a virtual commit id.

    A SHA1 string is exactly 40 lowercase hex digits; a virtual id is an
    int >= 1. "type(...) is" checks are used, so bools and other int or str
    subclasses are rejected.
    '''
    return (type(obj) is str and bool(shaRE.match(obj))) or \
           (type(obj) is int and obj >= 1)
class Commit:
    '''A node of the commit graph.

    Real commits are identified by their SHA1 string. Their tree and the
    first line of their message are read lazily with git-cat-file the first
    time they are needed. Virtual commits (sha=None) get an integer id from
    getUniqueId() and must be given their tree up front.

    Attributes:
      sha          -- SHA1 string, or an int for virtual commits
      parents      -- list of parent Commit objects (buildGraph() stores
                      parent SHA1 strings here until it fixes them up)
      children     -- list of child Commit objects, filled in by the graph
      virtual      -- True for commits that do not exist in the repository
      firstLineMsg -- first line of the commit message; None until loaded
    '''
    def __init__(self, sha, parents, tree=None):
        self.parents = parents
        self.firstLineMsg = None
        self.children = []
        # Always define _tree, so tree() fails with its assertion rather
        # than an AttributeError if no tree is ever found.
        self._tree = None

        if tree:
            tree = tree.rstrip()
            assert isSha(tree)
            self._tree = tree

        if not sha:
            # Virtual commit: no object exists in the repository, so it needs
            # a synthetic id and an explicitly supplied tree.
            self.sha = getUniqueId()
            self.virtual = True
            self.firstLineMsg = 'virtual commit'
            assert isSha(tree)
        else:
            self.virtual = False
            self.sha = sha.rstrip()
            assert isSha(self.sha)

    def tree(self):
        '''Return the SHA1 of this commit's tree, loading it if needed.'''
        self.getInfo()
        assert self._tree is not None
        return self._tree

    def shortInfo(self):
        '''Return "<sha> <first message line>" for display.'''
        self.getInfo()
        return str(self.sha) + ' ' + self.firstLineMsg

    def __str__(self):
        return self.shortInfo()

    def getInfo(self):
        '''Load the tree and first message line of a real commit.

        Does nothing for virtual commits, or when the information has
        already been loaded (firstLineMsg is not None).
        '''
        if self.virtual or self.firstLineMsg is not None:
            return
        info = runProgram(['git-cat-file', 'commit', self.sha])
        inMessage = False
        for l in info.split('\n'):
            if inMessage:
                self.firstLineMsg = l
                break
            if l.startswith('tree '):
                self._tree = l[5:].rstrip()
            elif l == '':
                # The first empty line ends the header; the message follows.
                inMessage = True
        if self.firstLineMsg is None:
            # No message line was found. Use '' so shortInfo() still works
            # and git-cat-file is not rerun on every call.
            self.firstLineMsg = ''
class Graph:
    '''A commit graph: its Commit objects plus a sha -> Commit lookup table.'''
    def __init__(self):
        self.commits = []
        self.shaMap = {}

    def addNode(self, node):
        '''Add the Commit node, register it as a child of each of its
        parents, and return it.'''
        assert isinstance(node, Commit)
        self.shaMap[node.sha] = node
        self.commits.append(node)
        for p in node.parents:
            p.children.append(node)
        return node

    def reachableNodes(self, n1, n2):
        '''Return a dict whose keys are all commits reachable from n1 or n2
        through parent links (n1 and n2 included); every value is True.'''
        res = {}
        # Iterative walk that skips already-visited commits. A recursive
        # walk without a visited check re-traverses shared history once per
        # path (exponential on merge-heavy graphs) and can exceed the
        # recursion limit on long linear histories.
        stack = [n1, n2]
        while stack:
            n = stack.pop()
            if n in res:
                continue
            res[n] = True
            stack.extend(n.parents)
        return res

    def fixParents(self, node):
        '''Replace the parent SHA1 strings in node.parents with the matching
        Commit objects from shaMap.'''
        # Slice assignment updates the existing list object instead of
        # rebinding node.parents.
        node.parents[:] = [self.shaMap[p] for p in node.parents]
184 # addDebug('buildGraph')
def buildGraph(heads):
    '''Build a Graph of all commits reachable from heads.

    heads is a sequence of commit SHA1 strings (any iterable is accepted).
    Commits and their parent links come from "git-rev-list --parents"; in
    the returned graph every Commit's parents and children lists hold
    Commit objects.
    '''
    # Copy into a list: it is concatenated onto the git-rev-list argument
    # list below, which fails for tuples and other sequences.
    heads = list(heads)
    debug('buildGraph heads:', heads)
    for h in heads:
        assert isSha(h)

    g = Graph()

    out = runProgram(['git-rev-list', '--parents'] + heads)
    for l in out.split('\n'):
        if l == '':
            continue
        shas = l.split(' ')

        # Temporarily store the parent SHA1 strings in 'parents'; they are
        # replaced by the corresponding Commit objects once every commit
        # has been created.
        c = Commit(shas[0], shas[1:])
        g.commits.append(c)
        g.shaMap[c.sha] = c

    for c in g.commits:
        g.fixParents(c)

    for c in g.commits:
        for p in c.parents:
            p.children.append(c)
    return g
def writeEmptyTree():
    '''Write the empty tree to the object database and return its SHA1.

    git-write-tree is run with GIT_INDEX_FILE pointing at a temporary index
    file in $GIT_DIR (default '.git'). The file is deleted before the run,
    so no stale index contents are used, and deleted again afterwards, even
    if git-write-tree fails.
    '''
    tmpIndex = os.path.join(os.environ.get('GIT_DIR', '.git'),
                            'merge-tmp-index')

    def delTmpIndex():
        # Best effort: failures such as the file not existing are ignored.
        try:
            os.unlink(tmpIndex)
        except OSError:
            pass

    delTmpIndex()
    newEnv = os.environ.copy()
    newEnv['GIT_INDEX_FILE'] = tmpIndex
    try:
        return runProgram(['git-write-tree'], env=newEnv).rstrip()
    finally:
        # Runs even when runProgram raises ProgramError, so the temporary
        # index is never left behind.
        delTmpIndex()
def addCommonRoot(graph):
    '''Give every parentless commit in graph one shared virtual parent.

    A virtual commit whose tree is the empty tree is added to graph and
    becomes the only parent of each commit that had no parents; those
    commits become its children. Returns the new virtual commit.
    '''
    # Collect the roots before the virtual commit (itself parentless) is added.
    roots = [c for c in graph.commits if len(c.parents) == 0]

    superRoot = graph.addNode(
        Commit(sha=None, parents=[], tree=writeEmptyTree()))
    for root in roots:
        root.parents = [superRoot]
    superRoot.children = roots
    return superRoot
def getCommonAncestors(graph, commit1, commit2):
    '''Find the common ancestors for commit1 and commit2.

    Returns a list of the commits reachable from both commit1 and commit2
    (a commit counts as reachable from itself) that have no child which is
    also reachable from both. If the two commits share no history at all,
    a virtual common root is first added to graph with addCommonRoot() and
    that root is the only result.
    '''
    assert(isinstance(commit1, Commit) and isinstance(commit2, Commit))

    def ancestorsOf(start):
        # Depth-first walk over parent links; the result includes start.
        seen = Set()
        pending = [start]
        while pending:
            node = pending.pop()
            seen.add(node)
            for parent in node.parents:
                if parent not in seen:
                    pending.append(parent)
        return seen

    shared = ancestorsOf(commit1).intersection(ancestorsOf(commit2))
    if not shared:
        shared = [addCommonRoot(graph)]

    best = Set()
    for candidate in shared:
        sharedChildren = [child for child in candidate.children
                          if child in shared]
        if not sharedChildren:
            best.add(candidate)
    return list(best)