add cache for the parser

This commit is contained in:
Tan, Kian-ting 2025-09-09 21:05:32 +08:00
parent 0ff09027a2
commit f4f7df05b8

View file

@ -3,7 +3,8 @@ using Match
# Outcome of a successful parse step.
#   matched - what the parser recognised (a single token, or a list/tree built from
#             the results of sub-parsers)
#   pos     - index into `txt` of the next token that has not been consumed yet
#   txt     - the token sequence being parsed; carried along unchanged
# The former `remained` field (a copy of the unconsumed tail) is replaced by the
# `pos`/`txt` pair, matching the three-argument constructor calls used by the parsers.
struct ParserResult
    matched
    pos
    txt
end
# Either a `ParserResult` or `nothing`; the parser functions in this file return
# `nothing` to signal a failed match.
OptParserResult = Union{Nothing,ParserResult}
@ -12,18 +13,20 @@ struct Psr
fun
end
# Memoization table shared by the parser combinators. Entries are stored under
# `(objectid(parser), pos)` keys and hold whatever that parser returned at `pos`
# (a `ParserResult`, or `nothing` for a failed match).
# NOTE(review): the key carries no reference to the token list, so entries recorded
# for one input would also be served for another — confirm the table is reset
# between parses.
parserCache = Dict()
# Build a parser that accepts one token whose first component equals `c`.
# Tokens are indexable pairs; the parser function takes the token list `txt` and the
# current index `pos`, and returns a `ParserResult` advanced by one token on success
# or `nothing` when the input is exhausted or the token does not match.
function strng(c)
    function matchText(txt, pos)
        # Nothing left to read at `pos`.
        length(txt) >= pos || return nothing
        token = txt[pos]
        return token[1] == c ? ParserResult(token, pos + 1, txt) : nothing
    end
    return Psr(matchText)
end
# Build a parser that accepts one token whose second component equals `t`.
# The parser function takes the token list `txt` and the current index `pos`, and
# returns a `ParserResult` advanced by one token on success or `nothing` when the
# input is exhausted or the token does not match.
function typ(t)
    function matchType(txt, pos)
        # Nothing left to read at `pos`.
        length(txt) >= pos || return nothing
        token = txt[pos]
        return token[2] == t ? ParserResult(token, pos + 1, txt) : nothing
    end
    return Psr(matchType)
end
@ -34,27 +37,56 @@ function then(a, b)
if a == nothing
return a
else
return b.fun(a.remained)
key = (objectid(b), a.pos)
if !(key in keys(parserCache))
res = b.fun(a.txt, a.pos)
global parserCache[key] = res
return res
else
return parserCache[key]
end
end
end
# Infix form of `choice`: `a | b` builds the parser that tries `a` first, then `b`.
(|)(a::Psr, b::Psr) = choice(a, b)
# Ordered choice: build a parser that runs `a` at the current position and falls back
# to `b` only when `a` fails (returns `nothing`).
# The outcome of each alternative is memoized in `parserCache` under
# `(objectid(parser), pos)`, so an alternative is evaluated at most once per position.
# NOTE(review): the cache key does not include `txt`; confirm the cache is cleared
# before a different token list is parsed.
function choice(a, b)
    # Run `parser` at `pos`, reusing a previously recorded outcome when there is one.
    # Failed matches (`nothing`) are recorded too, hence `haskey` rather than `get`.
    function cached(parser, txt, pos)
        key = (objectid(parser), pos)
        if haskey(parserCache, key)
            return parserCache[key]
        end
        res = parser.fun(txt, pos)
        parserCache[key] = res
        return res
    end

    function inner(txt, pos)
        aRes = cached(a, txt, pos)
        # Decide on the memoized outcome: invoking `a.fun` a second time here would
        # redo the whole sub-parse and defeat the cache.
        return aRes === nothing ? cached(b, txt, pos) : aRes
    end
    return Psr(inner)
end
function many0(parser)
function many0Aux(s)
function many0Aux(txt, pos)
result = []
tmp = parser.fun(s)
tmp = parser.fun(txt, pos)
p = pos
while tmp != nothing
s = tmp.remained
p = tmp.pos
result = push!(result, tmp.matched)
tmp = parser.fun(s)
tmp = parser.fun(txt, p)
end
return ParserResult(result, s)
return ParserResult(result, p, txt)
end
return Psr(many0Aux)
@ -62,23 +94,23 @@ function many0(parser)
end
function seq(parserLst)
function seqAux(s)
function seqAux(txt, pos)
result = []
isNothing = false
tmp = nothing
for p in parserLst
tmp = p.fun(s)
tmp = p.fun(txt, pos)
if tmp == nothing
return nothing
else
s = tmp.remained
pos = tmp.pos
result = push!(result, tmp.matched)
end
end
return ParserResult(result, s)
return ParserResult(result, pos, txt)
end
return Psr(seqAux)
@ -214,12 +246,12 @@ atom = typ("int") | typ("id") | typ("bool")
function fnArgItemAux(input)
function fnArgItemAux(input, pos)
rawFunc = seq([typ("comma"), typ("id")])
rawRes = rawFunc.fun(input)
rawRes = rawFunc.fun(input, pos)
if rawRes != nothing
matched = rawRes.matched[2]
res = ParserResult(matched, rawRes.remained)
res = ParserResult(matched, rawRes.pos, rawRes.txt)
return res
else
return nothing
@ -227,12 +259,12 @@ function fnArgItemAux(input)
end
fnArgItem = Psr(fnArgItemAux)
function fnArgsAux(input)
function fnArgsAux(input, pos)
rawFunc = seq([typ("id"), many0(fnArgItem)])
res = rawFunc.fun(input)
res = rawFunc.fun(input, pos)
if res != nothing
matched = vcat([res.matched[1]], res.matched[2])
res = ParserResult(matched, res.remained)
res = ParserResult(matched, res.pos, res.txt)
return res
else
return nothing
@ -240,47 +272,47 @@ function fnArgsAux(input)
end
fnArgs = Psr(fnArgsAux)
# Parse a lambda: "(" fnArgs ")" lambda body, starting at token index `pos` of `input`.
# On success the result's `matched` is [("%lambda", "id"), <args>, <body>];
# returns `nothing` when the sequence does not match.
function funcAux(input, pos)
    rawFunc = seq([typ("lParen"), fnArgs, typ("rParen"), typ("lambda"), body])
    rawRes = rawFunc.fun(input, pos)
    rawRes === nothing && return nothing
    # Keep only the argument list (2nd item) and the body (5th item).
    matched = [("%lambda", "id"), rawRes.matched[2], rawRes.matched[5]]
    return ParserResult(matched, rawRes.pos, rawRes.txt)
end
# Parse a parenthesised expression: "(" exp ")", starting at token index `pos`.
# On success the result's `matched` is [("%wrapper", "id"), <exp>];
# returns `nothing` when the sequence does not match.
function longUnitAux(input, pos)
    rawFunc = seq([typ("lParen"), exp, typ("rParen")])
    rawRes = rawFunc.fun(input, pos)
    rawRes === nothing && return nothing
    # The "%wrapper" node marks the parenthesised group (fix for the tree-fix problem).
    matched = [("%wrapper", "id"), rawRes.matched[2]]
    return ParserResult(matched, rawRes.pos, rawRes.txt)
end
# unit = func | longUnit | atom
# Tries the alternatives in that order at token index `pos` of `input` and returns the
# first successful result, or `nothing` if none matches.
function unitAux(input, pos)
    fun = Psr(funcAux)
    longUnit = Psr(longUnitAux)
    rawFunc = fun | longUnit | atom
    return rawFunc.fun(input, pos)
end
unit = Psr(unitAux)
function argItemAux(input)
function argItemAux(input, pos)
rawFunc = seq([typ("comma"), exp])
rawRes = rawFunc.fun(input)
rawRes = rawFunc.fun(input, pos)
if rawRes != nothing
matched = rawRes.matched[2]
res = ParserResult(matched, rawRes.remained)
res = ParserResult(matched, rawRes.pos, rawRes.txt)
return res
else
return nothing
@ -288,42 +320,42 @@ function argItemAux(input)
end
argItem = Psr(argItemAux)
# Parse a call argument list: exp followed by zero or more `argItem`s, starting at
# token index `pos`. On success `matched` is the flat list of all argument
# expressions; returns `nothing` when no leading expression is found.
function argsAux(input, pos)
    rawFunc = seq([exp, many0(argItem)])
    res = rawFunc.fun(input, pos)
    res === nothing && return nothing
    # Flatten [first, [rest...]] into [first, rest...].
    matched = vcat([res.matched[1]], res.matched[2])
    return ParserResult(matched, res.pos, res.txt)
end
args = Psr(argsAux)
# Parse a call: unit "(" args ")", starting at token index `pos`.
# On success `matched` is [("%call", "id"), <callee>, <args>];
# returns `nothing` when the sequence does not match.
function longFactorAux(input, pos)
    rawFunc = seq([unit, typ("lParen"), args, typ("rParen")])
    rawRes = rawFunc.fun(input, pos)
    rawRes === nothing && return nothing
    # Keep the callee (1st item) and the argument list (3rd item).
    matched = [("%call", "id"), rawRes.matched[1], rawRes.matched[3]]
    return ParserResult(matched, rawRes.pos, rawRes.txt)
end
# factor = longFactor | unit
# Tries a call first, then a plain unit, at token index `pos` of `input`; returns the
# first successful result or `nothing`.
function factorAux(input, pos)
    longFactor = Psr(longFactorAux)
    rawFunc = longFactor | unit
    return rawFunc.fun(input, pos)
end
factor = Psr(factorAux)
function longUnaryAux(input)
function longUnaryAux(input, pos)
rawFunc = seq([(typ("minus") | typ("not")), factor])
rawRes = rawFunc.fun(input)
rawRes = rawFunc.fun(input, pos)
if rawRes != nothing
rator = rawRes.matched[1]
rand = rawRes.matched[2]
@ -332,25 +364,25 @@ function longUnaryAux(input)
else
matched = [rator, rand]
end
res = ParserResult(matched, rawRes.remained)
res = ParserResult(matched, rawRes.pos, rawRes.txt)
return res
else
return nothing
end
end
# unary = longUnary | factor
# Tries a prefixed (minus / not) form first, then a plain factor, at token index `pos`
# of `input`; returns the first successful result or `nothing`.
function unaryAux(input, pos)
    longUnary = Psr(longUnaryAux)
    rawFunc = longUnary | factor
    return rawFunc.fun(input, pos)
end
unary = Psr(unaryAux)
function longTermAux(input)
function longTermAux(input, pos)
rawFunc = seq([unary, (typ("mul") | typ("div")), term])
rawRes = rawFunc.fun(input)
rawRes = rawFunc.fun(input, pos)
if rawRes != nothing
#correct the tree a /(b / c) -> (a / b) / c
leftRator = rawRes.matched[2]
@ -369,26 +401,26 @@ function longTermAux(input)
end
#matched = [("%prime", "id"), leftRator, [a, bc]]
res = ParserResult(matched, rawRes.remained)
res = ParserResult(matched, rawRes.pos, rawRes.txt)
return res
else
return nothing
end
end
# term = longTerm | unary
# Tries the mul/div form first, then a plain unary, at token index `pos` of `input`;
# returns the first successful result or `nothing`.
function termAux(input, pos)
    longTerm = Psr(longTermAux)
    rawFunc = longTerm | unary
    return rawFunc.fun(input, pos)
end
term = Psr(termAux)
function longEqCheckeeAux(input)
function longEqCheckeeAux(input, pos)
rawFunc = seq([term, (typ("plus") | typ("minus")), eqCheckee])
rawRes = rawFunc.fun(input)
rawRes = rawFunc.fun(input, pos)
if rawRes != nothing
#correct the tree a -(b - c) -> (a - b) - c
leftRator = rawRes.matched[2]
@ -407,89 +439,89 @@ function longEqCheckeeAux(input)
end
#matched = [("%prime", "id"), leftRator, [a, bc]]
res = ParserResult(matched, rawRes.remained)
res = ParserResult(matched, rawRes.pos, rawRes.txt)
return res
else
return nothing
end
end
# eqCheckee = longEqCheckee | term
# Tries the plus/minus form first, then a plain term, at token index `pos` of `input`;
# returns the first successful result or `nothing`.
function eqCheckeeAux(input, pos)
    longEqCheckee = Psr(longEqCheckeeAux)
    rawFunc = longEqCheckee | term
    return rawFunc.fun(input, pos)
end
eqCheckee = Psr(eqCheckeeAux)
# Parse a comparison: eqCheckee (eq | ne | lt | gt | le | ge) eqCheckee, starting at
# token index `pos`. On success `matched` is [<operator>, <left>, <right>];
# returns `nothing` when the sequence does not match.
function longlogicalAux(input, pos)
    comparison = typ("eq") | typ("ne") | typ("lt") | typ("gt") | typ("le") | typ("ge")
    rawFunc = seq([eqCheckee, comparison, eqCheckee])
    rawRes = rawFunc.fun(input, pos)
    rawRes === nothing && return nothing
    rator = rawRes.matched[2]
    l = rawRes.matched[1]
    r = rawRes.matched[3]
    # Prefix form: operator first, then both operands.
    return ParserResult([rator, l, r], rawRes.pos, rawRes.txt)
end
# logical = longLogical | eqCheckee
# Tries a comparison first, then a plain eqCheckee, at token index `pos` of `input`;
# returns the first successful result or `nothing`.
function logicalAux(input, pos)
    longLogical = Psr(longlogicalAux)
    rawFunc = longLogical | eqCheckee
    return rawFunc.fun(input, pos)
end
logical = Psr(logicalAux)
# Parse a boolean combination: logical (and | or) logical, starting at token index
# `pos`. On success `matched` is [<operator>, <left>, <right>];
# returns `nothing` when the sequence does not match.
function longSubExpAux(input, pos)
    rawFunc = seq([logical, (typ("and") | typ("or")), logical])
    rawRes = rawFunc.fun(input, pos)
    rawRes === nothing && return nothing
    rator = rawRes.matched[2]
    l = rawRes.matched[1]
    r = rawRes.matched[3]
    # Prefix form: operator first, then both operands.
    return ParserResult([rator, l, r], rawRes.pos, rawRes.txt)
end
# subExp = longSubExp | logical
# Tries the and/or form first, then a plain logical, at token index `pos` of `input`;
# returns the first successful result or `nothing`.
function subExpAux(input, pos)
    longSubExp = Psr(longSubExpAux)
    rawFunc = longSubExp | logical
    return rawFunc.fun(input, pos)
end
subExp = Psr(subExpAux)
# Parse a conditional: if exp then body else body, starting at token index `pos`.
# On success `matched` is [("%if", "id"), <cond>, <then-branch>, <else-branch>];
# returns `nothing` when the sequence does not match.
function longExpAux(input, pos)
    rawFunc = seq([typ("if"), exp,
                   typ("then"), body,
                   typ("else"), body])
    rawRes = rawFunc.fun(input, pos)
    rawRes === nothing && return nothing
    # Items 1, 3 and 5 are the keywords; keep only the three sub-trees.
    cond = rawRes.matched[2]
    branch1 = rawRes.matched[4]
    branch2 = rawRes.matched[6]
    matched = [("%if", "id"), cond, branch1, branch2]
    return ParserResult(matched, rawRes.pos, rawRes.txt)
end
# exp = longExp | subExp
# Tries an if-expression first, then a subExp, at token index `pos` of `input`;
# returns the first successful result or `nothing`.
function expAux(input, pos)
    longExp = Psr(longExpAux)
    rawFunc = longExp | subExp
    return rawFunc.fun(input, pos)
end
exp = Psr(expAux)
@ -501,12 +533,12 @@ tyHead = tyOfArgs | tyOfFn | id
tyOfFn = "(" tyHead -> ty ")"
ty = id | tyOfFn """
function tyArgItemAux(input)
function tyArgItemAux(input, pos)
rawFunc = seq([typ("comma"), ty])
rawRes = rawFunc.fun(input)
rawRes = rawFunc.fun(input, pos)
if rawRes != nothing
matched = rawRes.matched[2]
res = ParserResult(matched, rawRes.remained)
res = ParserResult(matched, rawRes.pos, rawRes.txt)
return res
else
return nothing
@ -527,21 +559,21 @@ function tyOfArgsAux(input)
end
end
# tyHead = tyOfArgs | tyOfFn | id
# Tries the alternatives in that order at token index `pos` of `input` and returns the
# first successful result, or `nothing` if none matches.
function tyHeadAux(input, pos)
    tyOfArgs = Psr(tyOfArgsAux)
    tyOfFn = Psr(tyOfFnAux)
    rawFunc = tyOfArgs | tyOfFn | typ("id")
    return rawFunc.fun(input, pos)
end
function tyOfFnAux(input)
function tyOfFnAux(input, pos)
tyHead = Psr(tyHeadAux)
rawFunc = seq([typ("lParen"), tyHead, typ("funType"), ty, typ("rParen")])
rawRes = rawFunc.fun(input)
rawRes = rawFunc.fun(input, pos)
if rawRes != nothing
matched = [("%funType", "id"), rawRes.matched[2], rawRes.matched[4]]
res = ParserResult(matched, rawRes.remained)
res = ParserResult(matched, rawRes.pos, rawRes.txt)
return res
else
return nothing
@ -550,18 +582,18 @@ end
# ty = tyOfFn | id
# Tries a function type first, then a plain identifier, at token index `pos` of
# `input`; returns the first successful result or `nothing`.
function tyAux(input, pos)
    tyOfFn = Psr(tyOfFnAux)
    rawFunc = tyOfFn | typ("id")
    return rawFunc.fun(input, pos)
end
ty = Psr(tyAux)
function letExpAux(input)
function letExpAux(input, pos)
#id id "=" exp ";" body
rawFunc = seq([ty, typ("id"), typ("assign"), exp, typ("semicolon"), body])
rawRes = rawFunc.fun(input)
rawRes = rawFunc.fun(input, pos)
if rawRes != nothing
typ_matched = rawRes.matched[1]
var_matched = rawRes.matched[2]
@ -569,7 +601,7 @@ function letExpAux(input)
body_matched = rawRes.matched[6]
matched = [("%let", "id"), [typ_matched, var_matched], val_matched, body_matched]
res = ParserResult(matched, rawRes.remained)
res = ParserResult(matched, rawRes.pos, rawRes.txt)
return res
else
return nothing
@ -578,7 +610,20 @@ end
letExp = Psr(letExpAux)
body = letExp | exp
# body = letExp | exp
# Runs the choice at token index `pos` of `input`, memoizing the outcome in
# `parserCache` under `(objectid(<choice parser>), pos)`. A recorded failure
# (`nothing`) is reused as well, which is why presence is tested with `haskey`.
function bodyAux(input, pos)
    rawFunc = letExp | exp
    key = (objectid(rawFunc), pos)
    if haskey(parserCache, key)
        return parserCache[key]
    end
    res = rawFunc.fun(input, pos)
    # Indexed assignment mutates the existing Dict; no `global` declaration is needed.
    parserCache[key] = res
    return res
end
body = Psr(bodyAux)
function fixTree(item)
@ -603,7 +648,7 @@ function totalParse(prog)
zippedTokenList = collect(zip(matchedList, groupNameList))
withoutSpaceCmt = filter((x)-> (x[2] != "sp") & (x[2] != "cmt"), zippedTokenList)
initWrapped = ParserResult([], withoutSpaceCmt)
initWrapped = ParserResult([], 1, withoutSpaceCmt)
res = initWrapped >> body
res2 = fixTree(res.matched)
@ -613,4 +658,3 @@ end
end