correct graph coloring and variable liveness analysis

This commit is contained in:
Tan, Kian-ting 2025-09-02 22:36:05 +08:00
parent eea281bc38
commit d52b14fd38
7 changed files with 221 additions and 55 deletions

3
.gitignore vendored Normal file
View file

@ -0,0 +1,3 @@
a.s
a.out
.gitignore

View file

@ -15,4 +15,4 @@ to make it executable, please use `gcc`: `gcc ./a.c -o output.out`
the example `.tc` file is in `./test` the example `.tc` file is in `./test`
## Known issues ## Known issues
- parser for a + b + c .. and a * b * c - all connected variable are pathized (I don't understand the tricky method to elimitated the path number)

View file

@ -14,3 +14,4 @@ assign_homes
5. **OK**assign homes (register allocation) 5. **OK**assign homes (register allocation)
6. **OK**patch instructions 6. **OK**patch instructions
7. **OK**prelude & conclusion 7. **OK**prelude & conclusion
8. **OK** (vartex coloring)

View file

@ -1,6 +1,8 @@
module TComp module TComp
include("./parser.jl") include("./parser.jl")
include("./vertexcoloring.jl")
using .Parser using .Parser
using .VertexColoring
using Match using Match
# For pass 2 # For pass 2
@ -15,7 +17,7 @@ prog = read(f, String)
#(prog) #(prog)
parsed = Parser.totalParse(prog) parsed = Parser.totalParse(prog)
#print(parsed) println("PARSED\n", Parser.prettyStringLisp(parsed))
tmp_var_no = 0 tmp_var_no = 0
@ -156,6 +158,8 @@ end
### PASS 3 assign x86 instruction ### PASS 3 assign x86 instruction
function assignInstruction(inp) function assignInstruction(inp)
println("INP", inp)
resList = [] resList = []
for i in inp for i in inp
@match i begin @match i begin
@ -166,7 +170,7 @@ function assignInstruction(inp)
push!(resList, ["movq", val, "%rax"]) push!(resList, ["movq", val, "%rax"])
end end
[("%let", "id"), [_ty, (id, "id")], [("%let", "id"), [_ty, (id, "id")],
[("%prime", "id"), (op, _), [(lhs, lhs_t), (rhs, rhs_t)]]] => [("%prime", "id"), (op, _), [(lhs, lhs_t), (rhs, rhs_t)]]] =>
begin begin
instr = "" instr = ""
ops = ["+", "-", "*", "/"] ops = ["+", "-", "*", "/"]
@ -205,13 +209,8 @@ function assignInstruction(inp)
line = ["movq", val, id] line = ["movq", val, id]
push!(resList, line) push!(resList, line)
end end
(c, "int") => push!(resList, ["movq", "\$" * c, "%rax"])
(c, "int") => begin (val, "id") => push!(resList, ["movq", val, "%rax"])
c_modified = "\$" * c
push!(resList, [c_modified])
end
(v, "id") => push!(resList, [v])
_ => println("Error") _ => println("Error")
end end
end end
@ -228,11 +227,78 @@ res2 = explicitControlRemoveComplex(res)
#println("PASS2", Parser.prettyStringLisp(res2)) #println("PASS2", Parser.prettyStringLisp(res2))
res3 = assignInstruction(res2) res3 = assignInstruction(res2)
#println("PASS3", res3) println("RES3:\n", res3)
# generate vertex graph for register allocation using vertex-coloring model
function generateVertexGraph(prog)
graph = []
progReversed = reverse(prog)
varRegex = r"(^[^\$].*)"
out = Set([])
in_ = Set([])
use = Set([])
def = Set([])
for i in progReversed
use = Set([])
def = Set([])
# only support binary operation(3-term)
out = in_ #prev. out-set becomes current in-set
@match i begin
["movq", orig, dest] => # dest = orig
begin
if match(varRegex, orig) != nothing
push!(use, orig)
end
if match(varRegex, dest) != nothing
push!(def, dest)
end
end
# dest = orig (+-*/) dest
[op, orig, dest] where (op in ["addq", "subq", "imulq", "divq"]) =>
begin
if match(varRegex, orig) != nothing
push!(use, orig)
end
if match(varRegex, dest) != nothing
push!(use, dest)
push!(def, dest)
end
end
_ => 0
end
in_ = union(use, setdiff(out, def))
#println("IN", in_)
#println("DEF", def)
#println("OUT", out)
in_list = collect(in_)
for j in range(1, length(in_list)-1)
for k in range(j+1,length(in_list))
path1 = [in_list[j], in_list[k]]
path2 = [in_list[k], in_list[j]]
if !(path1 in graph) & !(path2 in graph)
push!(graph, path1)
end
end
end
end
return graph
end
vertexGraph = generateVertexGraph(res3)
println(vertexGraph)
colorGraphDictWithMaxId = VertexColoring.vertexColoring(vertexGraph)
colorGraphDict = colorGraphDictWithMaxId[1]
colorMaxId = colorGraphDictWithMaxId[2]
registerList = ["%rax", "%rdx", "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15"]
# PASS4 assign home # PASS4 assign home
function assignHomes(inp) function assignHomes(inp, colorGraph, registerList)
varRegex = r"(^[^\$%].*)" varRegex = r"(^[^\$%].*)"
res = [] res = []
vars = [] vars = []
@ -250,8 +316,6 @@ function assignHomes(inp)
end end
end end
end end
#println("ALL_VAR", vars)
varsLength = length(vars) varsLength = length(vars)
for i in inp for i in inp
@ -259,18 +323,33 @@ function assignHomes(inp)
orig = i[2] orig = i[2]
dest = i[3] dest = i[3]
origIdx = findfirst(x -> x == orig,vars)
if origIdx != nothing if orig in keys(colorGraph)
realAddressIdx = varsLength - origIdx + 1 origColorId = colorGraph[orig]
realAddress = "-$(realAddressIdx * 8)(%rbp)" if origColorId <= length(registerList)
orig = realAddress orig = registerList[origColorId]
else
origIdx = findfirst(x -> x == orig,vars)
realAddressIdx = varsLength - origIdx + 1
realAddress = "-$(realAddressIdx * 8)(%rbp)"
orig = realAddress
end
elseif match(r"(^[^\$%].+)", orig) != nothing # isolated (unpathized) variable
orig = "%rax"
end end
destIdx = findfirst(x -> x == dest,vars) if dest in keys(colorGraph)
if destIdx != nothing destColorId = colorGraph[dest]
realAddressIdx = varsLength - destIdx + 1 if destColorId <= length(registerList)
realAddress = "-$(realAddressIdx * 8)(%rbp)" dest = registerList[destColorId]
dest = realAddress else
destIdx = findfirst(x -> x == dest,vars)
realAddressIdx = varsLength - destIdx + 1
realAddress = "-$(realAddressIdx * 8)(%rbp)"
dest = realAddress
end
elseif match(r"(^[^\$%].+)", dest) != nothing # isolated (unpathized) variable
dest = "%rax"
end end
push!(res, [instr, orig, dest]) push!(res, [instr, orig, dest])
@ -294,13 +373,20 @@ function patchInstruction(inp)
cmd2 = [inst, "%rax", dest] cmd2 = [inst, "%rax", dest]
push!(res, cmd2) push!(res, cmd2)
elseif (inst == "imulq") & (match(r"^%.+", dest) == nothing) elseif (inst == "imulq")
cmd1 = ["movq", dest, "%rax"] if (match(r"^%.+", dest) == nothing)
cmd2 = ["imulq", orig, "%rax"] cmd1 = ["movq", dest, "%rax"]
cmd3 = ["movq", "%rax", dest] cmd2 = ["imulq", orig, "%rax"]
push!(res, cmd1) cmd3 = ["movq", "%rax", dest]
push!(res, cmd2) push!(res, cmd1)
push!(res, cmd3) push!(res, cmd2)
push!(res, cmd3)
else #if dest is a %register
cmd1 = ["imulq", orig, dest] #result stored in %rax
cmd2 = ["movq", "%rax", dest]
push!(res, cmd1)
push!(res, cmd2)
end
else else
push!(res, i) push!(res, i)
@ -309,7 +395,7 @@ function patchInstruction(inp)
return res return res
end end
res4 = assignHomes(res3) res4 = assignHomes(res3, colorGraphDict, registerList)
res4_prog = res4[1] res4_prog = res4[1]
varNumber = res4[2] varNumber = res4[2]
res5 = patchInstruction(res4_prog) res5 = patchInstruction(res4_prog)
@ -317,8 +403,8 @@ res5 = patchInstruction(res4_prog)
## PASS6 add prelude and conclude ## PASS6 add prelude and conclude
function preludeConclude(prog, varNumber) function preludeConclude(prog, colorMaxId, registerList)
rspSubqMax = varNumber * 8 rspSubqMax = (colorMaxId < length(registerList)) ? 0 : ((colorMaxId - length(registerList)) * 8)
body = "start:\n" body = "start:\n"
@ -337,13 +423,13 @@ function preludeConclude(prog, varNumber)
pushq %rbp pushq %rbp
movq %rsp, %rbp\n""" * "\tsubq \$$rspSubqMax, %rsp\n\tjmp start\n\n" movq %rsp, %rbp\n""" * "\tsubq \$$rspSubqMax, %rsp\n\tjmp start\n\n"
conclude = """\nconclusion:\n""" * "\taddq \$$rspSubqMax, %rsp\n\tpopq %rbp\n\tretq" conclude = """\nconclusion:\n""" * "\taddq \$$rspSubqMax, %rsp\n\tpopq %rbp\n\tretq\n"
assemblyProg = prelude * body * conclude assemblyProg = prelude * body * conclude
return assemblyProg return assemblyProg
end end
res6 = preludeConclude(res5, varNumber) res6 = preludeConclude(res5, colorMaxId, registerList)
# println("PASS6",res6) # emit assembly code # println("PASS6",res6) # emit assembly code
f2 = open("./a.s", "w") f2 = open("./a.s", "w")
write(f2, res6) #write the assembly code write(f2, res6) #write the assembly code

View file

@ -1,4 +1,5 @@
module Parser module Parser
using Match
struct ParserResult struct ParserResult
matched matched
@ -85,7 +86,9 @@ end
patternList = [("int", "\\d+"), patternList = [
("cmt", "[#][^\\r\\n]+"),
("int", "\\d+"),
("id", "[_a-zA-Z][_0-9a-zA-Z]*"), ("id", "[_a-zA-Z][_0-9a-zA-Z]*"),
("lParen", "[\\(]"), ("lParen", "[\\(]"),
("rParen", "[\\)]"), ("rParen", "[\\)]"),
@ -184,8 +187,8 @@ func = "(" fn_args ")" "=>" body
unit = func | "(" exp ")" | atom unit = func | "(" exp ")" | atom
args = unit ("," unit)* args = unit ("," unit)*
factor = unit "(" args ")" factor = unit "(" args ")"
term = (factor (*|/) factor) | factor term = (factor (*|/) term) | factor
exp = (term (+|-) term) | term exp = (term [(+|-) exp]) | term
letexp = ty id "=" exp ";" body letexp = ty id "=" exp ";" body
body = exp | letexp body = exp | letexp
@ -233,10 +236,12 @@ function funcAux(input)
end end
function longUnitAux(input) function longUnitAux(input)
rawFunc = seq([typ("lParen"), exp, typ("rParen")]) rawFunc = seq([typ("lParen"), exp, typ("rParen")])
rawRes = rawFunc.fun(input) rawRes = rawFunc.fun(input)
if rawRes != nothing if rawRes != nothing
matched = rawRes.matched[2] #fix for tree fix problem
matched = [("%wrapper", "id"), rawRes.matched[2]]
res = ParserResult(matched, rawRes.remained) res = ParserResult(matched, rawRes.remained)
return res return res
else else
@ -300,10 +305,26 @@ end
factor = Psr(factorAux) factor = Psr(factorAux)
function longTermAux(input) function longTermAux(input)
rawFunc = seq([factor, (typ("mul") | typ("div")), factor]) rawFunc = seq([factor, (typ("mul") | typ("div")), term])
rawRes = rawFunc.fun(input) rawRes = rawFunc.fun(input)
if rawRes != nothing if rawRes != nothing
matched = [("%prime", "id"), rawRes.matched[2], [rawRes.matched[1], rawRes.matched[3]]] #correct the tree a /(b / c) -> (a / b) / c
leftRator = rawRes.matched[2]
a = rawRes.matched[1]
bc = rawRes.matched[3]
matched = @match bc begin
[("%prime", "id"), (rightRator, tRightRator), [b, c]] where ((rightRator == "*") || (rightRator == "/")) =>
begin
[("%prime", "id"), (rightRator, tRightRator), [[("%prime", "id"), leftRator, [a, b]], c]]
end
_ => begin
[("%prime", "id"), leftRator, [a, bc]]
end
end
#matched = [("%prime", "id"), leftRator, [a, bc]]
res = ParserResult(matched, rawRes.remained) res = ParserResult(matched, rawRes.remained)
return res return res
else else
@ -322,10 +343,26 @@ term = Psr(termAux)
function longExpAux(input) function longExpAux(input)
rawFunc = seq([term, (typ("plus") | typ("minus")), term]) rawFunc = seq([term, (typ("plus") | typ("minus")), exp])
rawRes = rawFunc.fun(input) rawRes = rawFunc.fun(input)
if rawRes != nothing if rawRes != nothing
matched = [("%prime", "id"), rawRes.matched[2], [rawRes.matched[1], rawRes.matched[3]]] #correct the tree a -(b - c) -> (a - b) - c
leftRator = rawRes.matched[2]
a = rawRes.matched[1]
bc = rawRes.matched[3]
matched = @match bc begin
[("%prime", "id"), (rightRator, tRightRator), [b, c]] where ((rightRator == "+") || (rightRator == "-")) =>
begin
[("%prime", "id"), (rightRator, tRightRator), [[("%prime", "id"), leftRator, [a, b]], c]]
end
_ => begin
[("%prime", "id"), leftRator, [a, bc]]
end
end
#matched = [("%prime", "id"), leftRator, [a, bc]]
res = ParserResult(matched, rawRes.remained) res = ParserResult(matched, rawRes.remained)
return res return res
else else
@ -426,6 +463,17 @@ letExp = Psr(letExpAux)
body = letExp | exp body = letExp | exp
function fixTree(item)
return @match item begin
(val, t) => item
[("%wrapper", "id"), inner] => fixTree(inner)
[vars...] => map(fixTree, item)
_ => println("parse Error in fixTree")
end
end
function totalParse(prog) function totalParse(prog)
isEntirelyMatched = match(matchEntirely, prog) isEntirelyMatched = match(matchEntirely, prog)
if isEntirelyMatched == false if isEntirelyMatched == false
@ -437,10 +485,13 @@ function totalParse(prog)
groupNameList = map(processKeys, collect(mI)) groupNameList = map(processKeys, collect(mI))
zippedTokenList = collect(zip(matchedList, groupNameList)) zippedTokenList = collect(zip(matchedList, groupNameList))
withoutSpaces = filter((x)-> x[2] != "sp", zippedTokenList) withoutSpaceCmt = filter((x)-> (x[2] != "sp") & (x[2] != "cmt"), zippedTokenList)
initWrapped = ParserResult([], withoutSpaces) initWrapped = ParserResult([], withoutSpaceCmt)
res = initWrapped >> body res = initWrapped >> body
return res.matched
res2 = fixTree(res.matched)
return res2
end end

View file

@ -1,4 +1,6 @@
graph = [['a', 'b'], ['b', 'c'], ['e', 'd'], ['e', 'a'], ['a', 'c'], ['b','e'], ['e','c']] module VertexColoring
#graph = [['a', 'b'], ['b', 'c'], ['e', 'd'], ['e', 'a'], ['a', 'c'], ['b','e'], ['e','c']]
function vertexColoring(graph) function vertexColoring(graph)
@ -30,19 +32,18 @@ function vertexColoring(graph)
color = Dict() color = Dict()
println(verticesList)
tmpId = -1
for i in verticesList for i in verticesList
i_adjacents = adjacentNodes[i] i_adjacents = adjacentNodes[i]
println(i_adjacents)
i_adjacents_color_set = Set(map(x -> getColor(x, color), collect(i_adjacents))) i_adjacents_color_set = Set(map(x -> getColor(x, color), collect(i_adjacents)))
i_adjacents_color_list = sort(collect(i_adjacents_color_set)) i_adjacents_color_list = sort(collect(i_adjacents_color_set))
if i_adjacents_color_list == [notDefined] if i_adjacents_color_list == [notDefined]
color[i] = 0 color[i] = 2
else else
tmpId = 0 tmpId = 2
for i in i_adjacents_color_list for i in i_adjacents_color_list
if tmpId == i if tmpId == i
tmpId += 1 tmpId += 1
@ -52,8 +53,26 @@ function vertexColoring(graph)
end end
end end
return color maxColorId = tmpId
#force gDict["%rax"] = 1
color["%rax"] = 1
return (color, maxColorId)
end end
println(vertexColoring(graph)) #println(vertexColoring(graph))
# Disabled = force the color id for %rax set to 1
"""function correctGraph(gDict)
raxOrigId = gDict["%rax"]
for node in keys(gDict)
if gDict[node] == raxOrigId
gDict[node] = 1
elseif gDict[node] == 1
gDict[node] = raxOrigId
end
end
return gDict
end"""
end

6
test/prog2.tc Normal file
View file

@ -0,0 +1,6 @@
int a = 10;
int a = 13;
int b = (12 + (0 - a));
int c = (14 + b);
int d = 20;
a - (b - c)