diff options
Diffstat (limited to 'script/parser')
-rw-r--r-- | script/parser/ast.lua | 1751 | ||||
-rw-r--r-- | script/parser/calcline.lua | 94 | ||||
-rw-r--r-- | script/parser/compile.lua | 561 | ||||
-rw-r--r-- | script/parser/grammar.lua | 538 | ||||
-rw-r--r-- | script/parser/guide.lua | 3884 | ||||
-rw-r--r-- | script/parser/init.lua | 12 | ||||
-rw-r--r-- | script/parser/lines.lua | 45 | ||||
-rw-r--r-- | script/parser/luadoc.lua | 991 | ||||
-rw-r--r-- | script/parser/parse.lua | 49 | ||||
-rw-r--r-- | script/parser/relabel.lua | 361 | ||||
-rw-r--r-- | script/parser/split.lua | 9 |
11 files changed, 8295 insertions, 0 deletions
diff --git a/script/parser/ast.lua b/script/parser/ast.lua new file mode 100644 index 00000000..d8614eae --- /dev/null +++ b/script/parser/ast.lua @@ -0,0 +1,1751 @@ +local tonumber = tonumber +local stringChar = string.char +local utf8Char = utf8.char +local tableUnpack = table.unpack +local mathType = math.type +local tableRemove = table.remove +local pairs = pairs +local tableSort = table.sort + +_ENV = nil + +local State +local PushError +local PushDiag +local PushComment + +-- goto 单独处理 +local RESERVED = { + ['and'] = true, + ['break'] = true, + ['do'] = true, + ['else'] = true, + ['elseif'] = true, + ['end'] = true, + ['false'] = true, + ['for'] = true, + ['function'] = true, + ['if'] = true, + ['in'] = true, + ['local'] = true, + ['nil'] = true, + ['not'] = true, + ['or'] = true, + ['repeat'] = true, + ['return'] = true, + ['then'] = true, + ['true'] = true, + ['until'] = true, + ['while'] = true, +} + +local VersionOp = { + ['&'] = {'Lua 5.3', 'Lua 5.4'}, + ['~'] = {'Lua 5.3', 'Lua 5.4'}, + ['|'] = {'Lua 5.3', 'Lua 5.4'}, + ['<<'] = {'Lua 5.3', 'Lua 5.4'}, + ['>>'] = {'Lua 5.3', 'Lua 5.4'}, + ['//'] = {'Lua 5.3', 'Lua 5.4'}, +} + +local function checkOpVersion(op) + local versions = VersionOp[op.type] + if not versions then + return + end + for i = 1, #versions do + if versions[i] == State.version then + return + end + end + PushError { + type = 'UNSUPPORT_SYMBOL', + start = op.start, + finish = op.finish, + version = versions, + info = { + version = State.version, + } + } +end + +local function checkMissEnd(start) + if not State.MissEndErr then + return + end + local err = State.MissEndErr + State.MissEndErr = nil + local _, finish = State.lua:find('[%w_]+', start) + if not finish then + return + end + err.info.related = { + { + start = start, + finish = finish, + } + } + PushError { + type = 'MISS_END', + start = start, + finish = finish, + } +end + +local function getSelect(vararg, index) + return { + type = 'select', + start = vararg.start, + finish = 
vararg.finish, + vararg = vararg, + index = index, + } +end + +local function getValue(values, i) + if not values then + return nil, nil + end + local value = values[i] + if not value then + local last = values[#values] + if not last then + return nil, nil + end + if last.type == 'call' or last.type == 'varargs' then + return getSelect(last, i - #values + 1) + end + return nil, nil + end + if value.type == 'call' or value.type == 'varargs' then + value = getSelect(value, 1) + end + return value +end + +local function createLocal(key, effect, value, attrs) + if not key then + return nil + end + key.type = 'local' + key.effect = effect + key.value = value + key.attrs = attrs + if value then + key.range = value.finish + end + return key +end + +local function createCall(args, start, finish) + if args then + args.type = 'callargs' + args.start = start + args.finish = finish + end + return { + type = 'call', + start = start, + finish = finish, + args = args, + } +end + +local function packList(start, list, finish) + local lastFinish = start + local wantName = true + local count = 0 + for i = 1, #list do + local ast = list[i] + if ast.type == ',' then + if wantName or i == #list then + PushError { + type = 'UNEXPECT_SYMBOL', + start = ast.start, + finish = ast.finish, + info = { + symbol = ',', + } + } + end + wantName = true + else + if not wantName then + PushError { + type = 'MISS_SYMBOL', + start = lastFinish, + finish = ast.start - 1, + info = { + symbol = ',', + } + } + end + wantName = false + count = count + 1 + list[count] = list[i] + end + lastFinish = ast.finish + 1 + end + for i = count + 1, #list do + list[i] = nil + end + list.type = 'list' + list.start = start + list.finish = finish - 1 + return list +end + +local BinaryLevel = { + ['or'] = 1, + ['and'] = 2, + ['<='] = 3, + ['>='] = 3, + ['<'] = 3, + ['>'] = 3, + ['~='] = 3, + ['=='] = 3, + ['|'] = 4, + ['~'] = 5, + ['&'] = 6, + ['<<'] = 7, + ['>>'] = 7, + ['..'] = 8, + ['+'] = 9, + ['-'] = 9, + ['*'] = 
10, + ['//'] = 10, + ['/'] = 10, + ['%'] = 10, + ['^'] = 11, +} + +local BinaryForward = { + [01] = true, + [02] = true, + [03] = true, + [04] = true, + [05] = true, + [06] = true, + [07] = true, + [08] = false, + [09] = true, + [10] = true, + [11] = false, +} + +local Defs = { + Nil = function (pos) + return { + type = 'nil', + start = pos, + finish = pos + 2, + } + end, + True = function (pos) + return { + type = 'boolean', + start = pos, + finish = pos + 3, + [1] = true, + } + end, + False = function (pos) + return { + type = 'boolean', + start = pos, + finish = pos + 4, + [1] = false, + } + end, + ShortComment = function (start, text, finish) + PushComment { + start = start, + finish = finish - 1, + text = text, + } + end, + LongComment = function (beforeEq, afterEq, str, missPos) + if missPos then + local endSymbol = ']' .. ('='):rep(afterEq-beforeEq) .. ']' + local s, _, w = str:find('(%][%=]*%])[%c%s]*$') + if s then + PushError { + type = 'ERR_LCOMMENT_END', + start = missPos - #str + s - 1, + finish = missPos - #str + s + #w - 2, + info = { + symbol = endSymbol, + }, + fix = { + title = 'FIX_LCOMMENT_END', + { + start = missPos - #str + s - 1, + finish = missPos - #str + s + #w - 2, + text = endSymbol, + } + }, + } + end + PushError { + type = 'MISS_SYMBOL', + start = missPos, + finish = missPos, + info = { + symbol = endSymbol, + }, + fix = { + title = 'ADD_LCOMMENT_END', + { + start = missPos, + finish = missPos, + text = endSymbol, + } + }, + } + end + end, + CLongComment = function (start1, finish1, start2, finish2) + PushError { + type = 'ERR_C_LONG_COMMENT', + start = start1, + finish = finish2 - 1, + fix = { + title = 'FIX_C_LONG_COMMENT', + { + start = start1, + finish = finish1 - 1, + text = '--[[', + }, + { + start = start2, + finish = finish2 - 1, + text = '--]]' + }, + } + } + end, + CCommentPrefix = function (start, finish) + PushError { + type = 'ERR_COMMENT_PREFIX', + start = start, + finish = finish - 1, + fix = { + title = 
'FIX_COMMENT_PREFIX', + { + start = start, + finish = finish - 1, + text = '--', + }, + } + } + end, + String = function (start, quote, str, finish) + return { + type = 'string', + start = start, + finish = finish - 1, + [1] = str, + [2] = quote, + } + end, + LongString = function (beforeEq, afterEq, str, missPos) + if missPos then + local endSymbol = ']' .. ('='):rep(afterEq-beforeEq) .. ']' + local s, _, w = str:find('(%][%=]*%])[%c%s]*$') + if s then + PushError { + type = 'ERR_LSTRING_END', + start = missPos - #str + s - 1, + finish = missPos - #str + s + #w - 2, + info = { + symbol = endSymbol, + }, + fix = { + title = 'FIX_LSTRING_END', + { + start = missPos - #str + s - 1, + finish = missPos - #str + s + #w - 2, + text = endSymbol, + } + }, + } + end + PushError { + type = 'MISS_SYMBOL', + start = missPos, + finish = missPos, + info = { + symbol = endSymbol, + }, + fix = { + title = 'ADD_LSTRING_END', + { + start = missPos, + finish = missPos, + text = endSymbol, + } + }, + } + end + return '[' .. ('='):rep(afterEq-beforeEq) .. 
'[', str + end, + Char10 = function (char) + char = tonumber(char) + if not char or char < 0 or char > 255 then + return '' + end + return stringChar(char) + end, + Char16 = function (pos, char) + if State.version == 'Lua 5.1' then + PushError { + type = 'ERR_ESC', + start = pos-1, + finish = pos, + version = {'Lua 5.2', 'Lua 5.3', 'Lua 5.4', 'LuaJIT'}, + info = { + version = State.version, + } + } + return char + end + return stringChar(tonumber(char, 16)) + end, + CharUtf8 = function (pos, char) + if State.version ~= 'Lua 5.3' + and State.version ~= 'Lua 5.4' + and State.version ~= 'LuaJIT' + then + PushError { + type = 'ERR_ESC', + start = pos-3, + finish = pos-2, + version = {'Lua 5.3', 'Lua 5.4', 'LuaJIT'}, + info = { + version = State.version, + } + } + return char + end + if #char == 0 then + PushError { + type = 'UTF8_SMALL', + start = pos-3, + finish = pos, + } + return '' + end + local v = tonumber(char, 16) + if not v then + for i = 1, #char do + if not tonumber(char:sub(i, i), 16) then + PushError { + type = 'MUST_X16', + start = pos + i - 1, + finish = pos + i - 1, + } + end + end + return '' + end + if State.version == 'Lua 5.4' then + if v < 0 or v > 0x7FFFFFFF then + PushError { + type = 'UTF8_MAX', + start = pos-3, + finish = pos+#char, + info = { + min = '00000000', + max = '7FFFFFFF', + } + } + end + else + if v < 0 or v > 0x10FFFF then + PushError { + type = 'UTF8_MAX', + start = pos-3, + finish = pos+#char, + version = v <= 0x7FFFFFFF and 'Lua 5.4' or nil, + info = { + min = '000000', + max = '10FFFF', + } + } + end + end + if v >= 0 and v <= 0x10FFFF then + return utf8Char(v) + end + return '' + end, + Number = function (start, number, finish) + local n = tonumber(number) + if n then + State.LastNumber = { + type = 'number', + start = start, + finish = finish - 1, + [1] = n, + } + return State.LastNumber + else + PushError { + type = 'MALFORMED_NUMBER', + start = start, + finish = finish - 1, + } + State.LastNumber = { + type = 'number', + 
start = start, + finish = finish - 1, + [1] = 0, + } + return State.LastNumber + end + end, + FFINumber = function (start, symbol) + local lastNumber = State.LastNumber + if mathType(lastNumber[1]) == 'float' then + PushError { + type = 'UNKNOWN_SYMBOL', + start = start, + finish = start + #symbol - 1, + info = { + symbol = symbol, + } + } + lastNumber[1] = 0 + return + end + if State.version ~= 'LuaJIT' then + PushError { + type = 'UNSUPPORT_SYMBOL', + start = start, + finish = start + #symbol - 1, + version = 'LuaJIT', + info = { + version = State.version, + } + } + lastNumber[1] = 0 + end + end, + ImaginaryNumber = function (start, symbol) + local lastNumber = State.LastNumber + if State.version ~= 'LuaJIT' then + PushError { + type = 'UNSUPPORT_SYMBOL', + start = start, + finish = start + #symbol - 1, + version = 'LuaJIT', + info = { + version = State.version, + } + } + end + lastNumber[1] = 0 + end, + Name = function (start, str, finish) + local isKeyWord + if RESERVED[str] then + isKeyWord = true + elseif str == 'goto' then + if State.version ~= 'Lua 5.1' and State.version ~= 'LuaJIT' then + isKeyWord = true + end + end + if isKeyWord then + PushError { + type = 'KEYWORD', + start = start, + finish = finish - 1, + } + end + return { + type = 'name', + start = start, + finish = finish - 1, + [1] = str, + } + end, + GetField = function (dot, field) + local obj = { + type = 'getfield', + field = field, + dot = dot, + start = dot.start, + finish = (field or dot).finish, + } + if field then + field.type = 'field' + field.parent = obj + end + return obj + end, + GetIndex = function (start, index, finish) + local obj = { + type = 'getindex', + start = start, + finish = finish - 1, + index = index, + } + if index then + index.parent = obj + end + return obj + end, + GetMethod = function (colon, method) + local obj = { + type = 'getmethod', + method = method, + colon = colon, + start = colon.start, + finish = (method or colon).finish, + } + if method then + 
method.type = 'method' + method.parent = obj + end + return obj + end, + Single = function (unit) + unit.type = 'getname' + return unit + end, + Simple = function (units) + local last = units[1] + for i = 2, #units do + local current = units[i] + current.node = last + current.start = last.start + last.next = current + last = units[i] + end + return last + end, + SimpleCall = function (call) + if call.type ~= 'call' and call.type ~= 'getmethod' then + PushError { + type = 'EXP_IN_ACTION', + start = call.start, + finish = call.finish, + } + end + return call + end, + BinaryOp = function (start, op) + return { + type = op, + start = start, + finish = start + #op - 1, + } + end, + UnaryOp = function (start, op) + return { + type = op, + start = start, + finish = start + #op - 1, + } + end, + Unary = function (first, ...) + if not ... then + return nil + end + local list = {first, ...} + local e = list[#list] + for i = #list - 1, 1, -1 do + local op = list[i] + checkOpVersion(op) + e = { + type = 'unary', + op = op, + start = op.start, + finish = e.finish, + [1] = e, + } + end + return e + end, + SubBinary = function (op, symb) + if symb then + return op, symb + end + PushError { + type = 'MISS_EXP', + start = op.start, + finish = op.finish, + } + end, + Binary = function (first, op, second, ...) + if not first then + return second + end + if not op then + return first + end + if not ... 
then + checkOpVersion(op) + return { + type = 'binary', + op = op, + start = first.start, + finish = second.finish, + [1] = first, + [2] = second, + } + end + local list = {first, op, second, ...} + local ops = {} + for i = 2, #list, 2 do + ops[#ops+1] = i + end + tableSort(ops, function (a, b) + local op1 = list[a] + local op2 = list[b] + local lv1 = BinaryLevel[op1.type] + local lv2 = BinaryLevel[op2.type] + if lv1 == lv2 then + local forward = BinaryForward[lv1] + if forward then + return op1.start > op2.start + else + return op1.start < op2.start + end + else + return lv1 < lv2 + end + end) + local final + for i = #ops, 1, -1 do + local n = ops[i] + local op = list[n] + local left = list[n-1] + local right = list[n+1] + local exp = { + type = 'binary', + op = op, + start = left.start, + finish = right and right.finish or op.finish, + [1] = left, + [2] = right, + } + local leftIndex, rightIndex + if list[left] then + leftIndex = list[left[1]] + else + leftIndex = n - 1 + end + if list[right] then + rightIndex = list[right[2]] + else + rightIndex = n + 1 + end + + list[leftIndex] = exp + list[rightIndex] = exp + list[left] = leftIndex + list[right] = rightIndex + list[exp] = n + final = exp + + checkOpVersion(op) + end + return final + end, + Paren = function (start, exp, finish) + if exp and exp.type == 'paren' then + exp.start = start + exp.finish = finish - 1 + return exp + end + return { + type = 'paren', + start = start, + finish = finish - 1, + exp = exp + } + end, + VarArgs = function (dots) + dots.type = 'varargs' + return dots + end, + PackLoopArgs = function (start, list, finish) + local list = packList(start, list, finish) + if #list == 0 then + PushError { + type = 'MISS_LOOP_MIN', + start = finish, + finish = finish, + } + elseif #list == 1 then + PushError { + type = 'MISS_LOOP_MAX', + start = finish, + finish = finish, + } + end + return list + end, + PackInNameList = function (start, list, finish) + local list = packList(start, list, finish) + if 
#list == 0 then + PushError { + type = 'MISS_NAME', + start = start, + finish = finish, + } + end + return list + end, + PackInExpList = function (start, list, finish) + local list = packList(start, list, finish) + if #list == 0 then + PushError { + type = 'MISS_EXP', + start = start, + finish = finish, + } + end + return list + end, + PackExpList = function (start, list, finish) + local list = packList(start, list, finish) + return list + end, + PackNameList = function (start, list, finish) + local list = packList(start, list, finish) + return list + end, + Call = function (start, args, finish) + return createCall(args, start, finish-1) + end, + COMMA = function (start) + return { + type = ',', + start = start, + finish = start, + } + end, + SEMICOLON = function (start) + return { + type = ';', + start = start, + finish = start, + } + end, + DOTS = function (start) + return { + type = '...', + start = start, + finish = start + 2, + } + end, + COLON = function (start) + return { + type = ':', + start = start, + finish = start, + } + end, + DOT = function (start) + return { + type = '.', + start = start, + finish = start, + } + end, + Function = function (functionStart, functionFinish, args, actions, endStart, endFinish) + actions.type = 'function' + actions.start = functionStart + actions.finish = endFinish - 1 + actions.args = args + actions.keyword= { + functionStart, functionFinish - 1, + endStart, endFinish - 1, + } + checkMissEnd(functionStart) + return actions + end, + NamedFunction = function (functionStart, functionFinish, name, args, actions, endStart, endFinish) + actions.type = 'function' + actions.start = functionStart + actions.finish = endFinish - 1 + actions.args = args + actions.keyword= { + functionStart, functionFinish - 1, + endStart, endFinish - 1, + } + checkMissEnd(functionStart) + if not name then + return + end + if name.type == 'getname' then + name.type = 'setname' + name.value = actions + elseif name.type == 'getfield' then + name.type = 
'setfield' + name.value = actions + elseif name.type == 'getmethod' then + name.type = 'setmethod' + name.value = actions + end + name.range = actions.finish + name.vstart = functionStart + return name + end, + LocalFunction = function (start, functionStart, functionFinish, name, args, actions, endStart, endFinish) + actions.type = 'function' + actions.start = start + actions.finish = endFinish - 1 + actions.args = args + actions.keyword= { + functionStart, functionFinish - 1, + endStart, endFinish - 1, + } + checkMissEnd(start) + + if not name then + return + end + + if name.type ~= 'getname' then + PushError { + type = 'UNEXPECT_LFUNC_NAME', + start = name.start, + finish = name.finish, + } + return + end + + local loc = createLocal(name, name.start, actions) + loc.localfunction = true + loc.vstart = functionStart + + return loc + end, + Table = function (start, tbl, finish) + tbl.type = 'table' + tbl.start = start + tbl.finish = finish - 1 + local wantField = true + local lastStart = start + 1 + local fieldCount = 0 + for i = 1, #tbl do + local field = tbl[i] + if field.type == ',' or field.type == ';' then + if wantField then + PushError { + type = 'MISS_EXP', + start = lastStart, + finish = field.start - 1, + } + end + wantField = true + lastStart = field.finish + 1 + else + if not wantField then + PushError { + type = 'MISS_SEP_IN_TABLE', + start = lastStart, + finish = field.start - 1, + } + end + wantField = false + lastStart = field.finish + 1 + fieldCount = fieldCount + 1 + tbl[fieldCount] = field + end + end + for i = fieldCount + 1, #tbl do + tbl[i] = nil + end + return tbl + end, + NewField = function (start, field, value, finish) + local obj = { + type = 'tablefield', + start = start, + finish = finish-1, + field = field, + value = value, + } + if field then + field.type = 'field' + field.parent = obj + end + return obj + end, + NewIndex = function (start, index, value, finish) + local obj = { + type = 'tableindex', + start = start, + finish = 
finish-1, + index = index, + value = value, + } + if index then + index.parent = obj + end + return obj + end, + FuncArgs = function (start, args, finish) + args.type = 'funcargs' + args.start = start + args.finish = finish - 1 + local lastStart = start + 1 + local wantName = true + local argCount = 0 + for i = 1, #args do + local arg = args[i] + local argAst = arg + if argAst.type == ',' then + if wantName then + PushError { + type = 'MISS_NAME', + start = lastStart, + finish = argAst.start-1, + } + end + wantName = true + else + if not wantName then + PushError { + type = 'MISS_SYMBOL', + start = lastStart-1, + finish = argAst.start-1, + info = { + symbol = ',', + } + } + end + wantName = false + argCount = argCount + 1 + + if argAst.type == '...' then + args[argCount] = arg + if i < #args then + local a = args[i+1] + local b = args[#args] + PushError { + type = 'ARGS_AFTER_DOTS', + start = a.start, + finish = b.finish, + } + end + break + else + args[argCount] = createLocal(arg, arg.start) + end + end + lastStart = argAst.finish + 1 + end + for i = argCount + 1, #args do + args[i] = nil + end + if wantName and argCount > 0 then + PushError { + type = 'MISS_NAME', + start = lastStart, + finish = finish - 1, + } + end + return args + end, + Set = function (start, keys, values, finish) + for i = 1, #keys do + local key = keys[i] + if key.type == 'getname' then + key.type = 'setname' + key.value = getValue(values, i) + elseif key.type == 'getfield' then + key.type = 'setfield' + key.value = getValue(values, i) + elseif key.type == 'getindex' then + key.type = 'setindex' + key.value = getValue(values, i) + end + if key.value then + key.range = key.value.finish + end + end + if values then + for i = #keys+1, #values do + local value = values[i] + PushDiag('redundant-value', { + start = value.start, + finish = value.finish, + max = #keys, + passed = #values, + }) + end + end + return tableUnpack(keys) + end, + LocalAttr = function (attrs) + if #attrs == 0 then + return 
nil + end + for i = 1, #attrs do + local attr = attrs[i] + local attrAst = attr + attrAst.type = 'localattr' + if State.version ~= 'Lua 5.4' then + PushError { + type = 'UNSUPPORT_SYMBOL', + start = attrAst.start, + finish = attrAst.finish, + version = 'Lua 5.4', + info = { + version = State.version, + } + } + elseif attrAst[1] ~= 'const' and attrAst[1] ~= 'close' then + PushError { + type = 'UNKNOWN_TAG', + start = attrAst.start, + finish = attrAst.finish, + info = { + tag = attrAst[1], + } + } + elseif i > 1 then + PushError { + type = 'MULTI_TAG', + start = attrAst.start, + finish = attrAst.finish, + info = { + tag = attrAst[1], + } + } + end + end + attrs.start = attrs[1].start + attrs.finish = attrs[#attrs].finish + return attrs + end, + LocalName = function (name, attrs) + if not name then + return name + end + name.attrs = attrs + return name + end, + Local = function (start, keys, values, finish) + for i = 1, #keys do + local key = keys[i] + local attrs = key.attrs + key.attrs = nil + local value = getValue(values, i) + createLocal(key, finish, value, attrs) + end + if values then + for i = #keys+1, #values do + local value = values[i] + PushDiag('redundant-value', { + start = value.start, + finish = value.finish, + max = #keys, + passed = #values, + }) + end + end + return tableUnpack(keys) + end, + Do = function (start, actions, endA, endB) + actions.type = 'do' + actions.start = start + actions.finish = endB - 1 + actions.keyword= { + start, start + #'do' - 1, + endA , endB - 1, + } + checkMissEnd(start) + return actions + end, + Break = function (start, finish) + return { + type = 'break', + start = start, + finish = finish - 1, + } + end, + Return = function (start, exps, finish) + exps.type = 'return' + exps.start = start + exps.finish = finish - 1 + return exps + end, + Label = function (start, name, finish) + if State.version == 'Lua 5.1' then + PushError { + type = 'UNSUPPORT_SYMBOL', + start = start, + finish = finish - 1, + version = {'Lua 5.2', 
'Lua 5.3', 'Lua 5.4', 'LuaJIT'}, + info = { + version = State.version, + } + } + return + end + if not name then + return nil + end + name.type = 'label' + return name + end, + GoTo = function (start, name, finish) + if State.version == 'Lua 5.1' then + PushError { + type = 'UNSUPPORT_SYMBOL', + start = start, + finish = finish - 1, + version = {'Lua 5.2', 'Lua 5.3', 'Lua 5.4', 'LuaJIT'}, + info = { + version = State.version, + } + } + return + end + if not name then + return nil + end + name.type = 'goto' + return name + end, + IfBlock = function (ifStart, ifFinish, exp, thenStart, thenFinish, actions, finish) + actions.type = 'ifblock' + actions.start = ifStart + actions.finish = finish - 1 + actions.filter = exp + actions.keyword= { + ifStart, ifFinish - 1, + thenStart, thenFinish - 1, + } + return actions + end, + ElseIfBlock = function (elseifStart, elseifFinish, exp, thenStart, thenFinish, actions, finish) + actions.type = 'elseifblock' + actions.start = elseifStart + actions.finish = finish - 1 + actions.filter = exp + actions.keyword= { + elseifStart, elseifFinish - 1, + thenStart, thenFinish - 1, + } + return actions + end, + ElseBlock = function (elseStart, elseFinish, actions, finish) + actions.type = 'elseblock' + actions.start = elseStart + actions.finish = finish - 1 + actions.keyword= { + elseStart, elseFinish - 1, + } + return actions + end, + If = function (start, blocks, endStart, endFinish) + blocks.type = 'if' + blocks.start = start + blocks.finish = endFinish - 1 + local hasElse + for i = 1, #blocks do + local block = blocks[i] + if i == 1 and block.type ~= 'ifblock' then + PushError { + type = 'MISS_SYMBOL', + start = block.start, + finish = block.start, + info = { + symbol = 'if', + } + } + end + if hasElse then + PushError { + type = 'BLOCK_AFTER_ELSE', + start = block.start, + finish = block.finish, + } + end + if block.type == 'elseblock' then + hasElse = true + end + end + checkMissEnd(start) + return blocks + end, + Loop = function 
(forA, forB, arg, steps, doA, doB, blockStart, block, endA, endB) + local loc = createLocal(arg, blockStart, steps[1]) + block.type = 'loop' + block.start = forA + block.finish = endB - 1 + block.loc = loc + block.max = steps[2] + block.step = steps[3] + block.keyword= { + forA, forB - 1, + doA , doB - 1, + endA, endB - 1, + } + checkMissEnd(forA) + return block + end, + In = function (forA, forB, keys, inA, inB, exp, doA, doB, blockStart, block, endA, endB) + local func = tableRemove(exp, 1) + block.type = 'in' + block.start = forA + block.finish = endB - 1 + block.keys = keys + block.keyword= { + forA, forB - 1, + inA , inB - 1, + doA , doB - 1, + endA, endB - 1, + } + + local values + if func then + local call = createCall(exp, func.finish + 1, exp.finish) + call.node = func + call.start = func.start + func.next = call + values = { call } + keys.range = call.finish + end + for i = 1, #keys do + local loc = keys[i] + if values then + createLocal(loc, blockStart, getValue(values, i)) + else + createLocal(loc, blockStart) + end + end + checkMissEnd(forA) + return block + end, + While = function (whileA, whileB, filter, doA, doB, block, endA, endB) + block.type = 'while' + block.start = whileA + block.finish = endB - 1 + block.filter = filter + block.keyword= { + whileA, whileB - 1, + doA , doB - 1, + endA , endB - 1, + } + checkMissEnd(whileA) + return block + end, + Repeat = function (repeatA, repeatB, block, untilA, untilB, filter, finish) + block.type = 'repeat' + block.start = repeatA + block.finish = finish + block.filter = filter + block.keyword= { + repeatA, repeatB - 1, + untilA , untilB - 1, + } + return block + end, + Lua = function (start, actions, finish) + actions.type = 'main' + actions.start = start + actions.finish = finish - 1 + return actions + end, + + -- 捕获错误 + UnknownSymbol = function (start, symbol) + PushError { + type = 'UNKNOWN_SYMBOL', + start = start, + finish = start + #symbol - 1, + info = { + symbol = symbol, + } + } + return + end, + 
UnknownAction = function (start, symbol) + PushError { + type = 'UNKNOWN_SYMBOL', + start = start, + finish = start + #symbol - 1, + info = { + symbol = symbol, + } + } + end, + DirtyName = function (pos) + PushError { + type = 'MISS_NAME', + start = pos, + finish = pos, + } + return nil + end, + DirtyExp = function (pos) + PushError { + type = 'MISS_EXP', + start = pos, + finish = pos, + } + return nil + end, + MissExp = function (pos) + PushError { + type = 'MISS_EXP', + start = pos, + finish = pos, + } + end, + MissExponent = function (start, finish) + PushError { + type = 'MISS_EXPONENT', + start = start, + finish = finish - 1, + } + end, + MissQuote1 = function (pos) + PushError { + type = 'MISS_SYMBOL', + start = pos, + finish = pos, + info = { + symbol = '"' + } + } + end, + MissQuote2 = function (pos) + PushError { + type = 'MISS_SYMBOL', + start = pos, + finish = pos, + info = { + symbol = "'" + } + } + end, + MissEscX = function (pos) + PushError { + type = 'MISS_ESC_X', + start = pos-2, + finish = pos+1, + } + end, + MissTL = function (pos) + PushError { + type = 'MISS_SYMBOL', + start = pos, + finish = pos, + info = { + symbol = '{', + } + } + end, + MissTR = function (pos) + PushError { + type = 'MISS_SYMBOL', + start = pos, + finish = pos, + info = { + symbol = '}', + } + } + end, + MissBR = function (pos) + PushError { + type = 'MISS_SYMBOL', + start = pos, + finish = pos, + info = { + symbol = ']', + } + } + end, + MissPL = function (pos) + PushError { + type = 'MISS_SYMBOL', + start = pos, + finish = pos, + info = { + symbol = '(', + } + } + end, + MissPR = function (pos) + PushError { + type = 'MISS_SYMBOL', + start = pos, + finish = pos, + info = { + symbol = ')', + } + } + end, + ErrEsc = function (pos) + PushError { + type = 'ERR_ESC', + start = pos-1, + finish = pos, + } + end, + MustX16 = function (pos, str) + PushError { + type = 'MUST_X16', + start = pos, + finish = pos + #str - 1, + } + end, + MissAssign = function (pos) + PushError { + 
type = 'MISS_SYMBOL', + start = pos, + finish = pos, + info = { + symbol = '=', + } + } + end, + MissTableSep = function (pos) + PushError { + type = 'MISS_SYMBOL', + start = pos, + finish = pos, + info = { + symbol = ',' + } + } + end, + MissField = function (pos) + PushError { + type = 'MISS_FIELD', + start = pos, + finish = pos, + } + end, + MissMethod = function (pos) + PushError { + type = 'MISS_METHOD', + start = pos, + finish = pos, + } + end, + MissLabel = function (pos) + PushError { + type = 'MISS_SYMBOL', + start = pos, + finish = pos, + info = { + symbol = '::', + } + } + end, + MissEnd = function (pos) + State.MissEndErr = PushError { + type = 'MISS_SYMBOL', + start = pos, + finish = pos, + info = { + symbol = 'end', + } + } + return pos, pos + end, + MissDo = function (pos) + PushError { + type = 'MISS_SYMBOL', + start = pos, + finish = pos, + info = { + symbol = 'do', + } + } + return pos, pos + end, + MissComma = function (pos) + PushError { + type = 'MISS_SYMBOL', + start = pos, + finish = pos, + info = { + symbol = ',', + } + } + end, + MissIn = function (pos) + PushError { + type = 'MISS_SYMBOL', + start = pos, + finish = pos, + info = { + symbol = 'in', + } + } + return pos, pos + end, + MissUntil = function (pos) + PushError { + type = 'MISS_SYMBOL', + start = pos, + finish = pos, + info = { + symbol = 'until', + } + } + return pos, pos + end, + MissThen = function (pos) + PushError { + type = 'MISS_SYMBOL', + start = pos, + finish = pos, + info = { + symbol = 'then', + } + } + return pos, pos + end, + MissName = function (pos) + PushError { + type = 'MISS_NAME', + start = pos, + finish = pos, + } + end, + ExpInAction = function (start, exp, finish) + PushError { + type = 'EXP_IN_ACTION', + start = start, + finish = finish - 1, + } + -- 当exp为nil时,不能返回任何值,否则会产生带洞的actionlist + if exp then + return exp + else + return + end + end, + MissIf = function (start, block) + PushError { + type = 'MISS_SYMBOL', + start = start, + finish = start, + info = { 
+ symbol = 'if', + } + } + return block + end, + MissGT = function (start) + PushError { + type = 'MISS_SYMBOL', + start = start, + finish = start, + info = { + symbol = '>' + } + } + end, + -- `==` was written where an assignment was expected: report it and offer `=` as the quick fix. + ErrAssign = function (start, finish) + PushError { + type = 'ERR_ASSIGN_AS_EQ', + start = start, + finish = finish - 1, + fix = { + title = 'FIX_ASSIGN_AS_EQ', + { + start = start, + finish = finish - 1, + text = '=', + } + } + } + end, + -- `=` was written where a binary operator was expected: report it, offer `==` as the quick fix, + -- and return `==` as the operator text for the binary-operator capture. + ErrEQ = function (start, finish) + PushError { + type = 'ERR_EQ_AS_ASSIGN', + start = start, + finish = finish - 1, + fix = { + title = 'FIX_EQ_AS_ASSIGN', + { + start = start, + finish = finish - 1, + text = '==', + } + } + } + return '==' + end, + -- `!=` is not a Lua operator: report it and offer `~=` as the quick fix. + -- The returned operator text must be `~=` as well, so that the operator handed back + -- agrees with the suggested fix instead of turning the inequality into an equality test. + ErrUEQ = function (start, finish) + PushError { + type = 'ERR_UEQ', + start = start, + finish = finish - 1, + fix = { + title = 'FIX_UEQ', + { + start = start, + finish = finish - 1, + text = '~=', + } + } + } + return '~=' + end, + -- `do` was written where `then` was expected: report it, offer `then`, and pass the keyword range through. + ErrThen = function (start, finish) + PushError { + type = 'ERR_THEN_AS_DO', + start = start, + finish = finish - 1, + fix = { + title = 'FIX_THEN_AS_DO', + { + start = start, + finish = finish - 1, + text = 'then', + } + } + } + return start, finish + end, + -- `then` was written where `do` was expected: report it, offer `do`, and pass the keyword range through. + ErrDo = function (start, finish) + PushError { + type = 'ERR_DO_AS_THEN', + start = start, + finish = finish - 1, + fix = { + title = 'FIX_DO_AS_THEN', + { + start = start, + finish = finish - 1, + text = 'do', + } + } + } + return start, finish + end, +} + +-- Bind the parse state and its three reporting callbacks to the module-level upvalues. +local function init(state) + State = state + PushError = state.pushError + PushDiag = state.pushDiag + PushComment = state.pushComment +end + +-- Drop the parse state and replace the reporting callbacks with no-op functions. +local function close() + State = nil + PushError = function (...) end + PushDiag = function (...) end + PushComment = function (...) 
end +end + +return { + defs = Defs, + init = init, + close = close, +} diff --git a/script/parser/calcline.lua b/script/parser/calcline.lua new file mode 100644 index 00000000..2e944167 --- /dev/null +++ b/script/parser/calcline.lua @@ -0,0 +1,94 @@ +local m = require 'lpeglabel' +local util = require 'utility' + +local row +local fl +local NL = (m.P'\r\n' + m.S'\r\n') * m.Cp() / function (pos) + row = row + 1 + fl = pos +end +local ROWCOL = (NL + m.P(1))^0 +local function rowcol(str, n) + row = 1 + fl = 1 + ROWCOL:match(str:sub(1, n)) + local col = n - fl + 1 + return row, col +end + +local function rowcol_utf8(str, n) + row = 1 + fl = 1 + ROWCOL:match(str:sub(1, n)) + return row, util.utf8Len(str, fl, n) +end + +local function position(str, _row, _col) + local cur = 1 + local row = 1 + while true do + if row == _row then + return cur + _col - 1 + elseif row > _row then + return cur - 1 + end + local pos = str:find('[\r\n]', cur) + if not pos then + return #str + end + row = row + 1 + if str:sub(pos, pos+1) == '\r\n' then + cur = pos + 2 + else + cur = pos + 1 + end + end +end + +local function position_utf8(str, _row, _col) + local cur = 1 + local row = 1 + while true do + if row == _row then + return utf8.offset(str, _col, cur) + elseif row > _row then + return cur - 1 + end + local pos = str:find('[\r\n]', cur) + if not pos then + return #str + end + row = row + 1 + if str:sub(pos, pos+1) == '\r\n' then + cur = pos + 2 + else + cur = pos + 1 + end + end +end + +local NL = m.P'\r\n' + m.S'\r\n' + +local function line(str, row) + local count = 0 + local res + local LINE = m.Cmt((1 - NL)^0, function (_, _, c) + count = count + 1 + if count == row then + res = c + return false + end + return true + end) + local MATCH = (LINE * NL)^0 * LINE + MATCH:match(str) + return res +end + +return { + rowcol = rowcol, + rowcol_utf8 = rowcol_utf8, + position = position, + position_utf8 = position_utf8, + line = line, +} diff --git a/script/parser/compile.lua 
b/script/parser/compile.lua new file mode 100644 index 00000000..2c7172e8 --- /dev/null +++ b/script/parser/compile.lua @@ -0,0 +1,561 @@ +local guide = require 'parser.guide' +local type = type + +local specials = { + ['_G'] = true, + ['rawset'] = true, + ['rawget'] = true, + ['setmetatable'] = true, + ['require'] = true, + ['dofile'] = true, + ['loadfile'] = true, + ['pcall'] = true, + ['xpcall'] = true, +} + +_ENV = nil + +local LocalLimit = 200 +local pushError, Compile, CompileBlock, Block, GoToTag, ENVMode, Compiled, LocalCount, Version, Root, Options + +local function addRef(node, obj) + if not node.ref then + node.ref = {} + end + node.ref[#node.ref+1] = obj + obj.node = node +end + +local function addSpecial(name, obj) + if not Root.specials then + Root.specials = {} + end + if not Root.specials[name] then + Root.specials[name] = {} + end + Root.specials[name][#Root.specials[name]+1] = obj + obj.special = name +end + +local vmMap = { + ['getname'] = function (obj) + local loc = guide.getLocal(obj, obj[1], obj.start) + if loc then + obj.type = 'getlocal' + obj.loc = loc + addRef(loc, obj) + if loc.special then + addSpecial(loc.special, obj) + end + else + obj.type = 'getglobal' + local node = guide.getLocal(obj, ENVMode, obj.start) + if node then + addRef(node, obj) + end + local name = obj[1] + if specials[name] then + addSpecial(name, obj) + elseif Options and Options.special then + local asName = Options.special[name] + if specials[asName] then + addSpecial(asName, obj) + end + end + end + return obj + end, + ['getfield'] = function (obj) + Compile(obj.node, obj) + end, + ['call'] = function (obj) + Compile(obj.node, obj) + Compile(obj.args, obj) + end, + ['callargs'] = function (obj) + for i = 1, #obj do + Compile(obj[i], obj) + end + end, + ['binary'] = function (obj) + Compile(obj[1], obj) + Compile(obj[2], obj) + end, + ['unary'] = function (obj) + Compile(obj[1], obj) + end, + ['varargs'] = function (obj) + local func = guide.getParentFunction(obj) 
+ if func then + local index, vararg = guide.getFunctionVarArgs(func) + if not index then + pushError { + type = 'UNEXPECT_DOTS', + start = obj.start, + finish = obj.finish, + } + end + if vararg then + if not vararg.ref then + vararg.ref = {} + end + vararg.ref[#vararg.ref+1] = obj + end + end + end, + ['paren'] = function (obj) + Compile(obj.exp, obj) + end, + ['getindex'] = function (obj) + Compile(obj.node, obj) + Compile(obj.index, obj) + end, + ['setindex'] = function (obj) + Compile(obj.node, obj) + Compile(obj.index, obj) + Compile(obj.value, obj) + end, + ['getmethod'] = function (obj) + Compile(obj.node, obj) + Compile(obj.method, obj) + end, + ['setmethod'] = function (obj) + Compile(obj.node, obj) + Compile(obj.method, obj) + local value = obj.value + value.localself = { + type = 'local', + start = 0, + finish = 0, + method = obj, + effect = obj.finish, + tag = 'self', + [1] = 'self', + } + Compile(value, obj) + end, + ['function'] = function (obj) + local lastBlock = Block + local LastLocalCount = LocalCount + Block = obj + LocalCount = 0 + if obj.localself then + Compile(obj.localself, obj) + obj.localself = nil + end + Compile(obj.args, obj) + for i = 1, #obj do + Compile(obj[i], obj) + end + Block = lastBlock + LocalCount = LastLocalCount + end, + ['funcargs'] = function (obj) + for i = 1, #obj do + Compile(obj[i], obj) + end + end, + ['table'] = function (obj) + for i = 1, #obj do + Compile(obj[i], obj) + end + end, + ['tablefield'] = function (obj) + Compile(obj.value, obj) + end, + ['tableindex'] = function (obj) + Compile(obj.index, obj) + Compile(obj.value, obj) + end, + ['index'] = function (obj) + Compile(obj.index, obj) + end, + ['select'] = function (obj) + local vararg = obj.vararg + if vararg.parent then + if not vararg.extParent then + vararg.extParent = {} + end + vararg.extParent[#vararg.extParent+1] = obj + else + Compile(vararg, obj) + end + end, + ['setname'] = function (obj) + Compile(obj.value, obj) + local loc = 
guide.getLocal(obj, obj[1], obj.start) + if loc then + obj.type = 'setlocal' + obj.loc = loc + addRef(loc, obj) + if loc.attrs then + local const + for i = 1, #loc.attrs do + local attr = loc.attrs[i][1] + if attr == 'const' + or attr == 'close' then + const = true + break + end + end + if const then + pushError { + type = 'SET_CONST', + start = obj.start, + finish = obj.finish, + } + end + end + else + obj.type = 'setglobal' + local node = guide.getLocal(obj, ENVMode, obj.start) + if node then + addRef(node, obj) + end + local name = obj[1] + if specials[name] then + addSpecial(name, obj) + elseif Options and Options.special then + local asName = Options.special[name] + if specials[asName] then + addSpecial(asName, obj) + end + end + end + end, + ['local'] = function (obj) + local attrs = obj.attrs + if attrs then + for i = 1, #attrs do + Compile(attrs[i], obj) + end + end + if Block then + if not Block.locals then + Block.locals = {} + end + Block.locals[#Block.locals+1] = obj + LocalCount = LocalCount + 1 + if LocalCount > LocalLimit then + pushError { + type = 'LOCAL_LIMIT', + start = obj.start, + finish = obj.finish, + } + end + end + if obj.localfunction then + obj.localfunction = nil + end + Compile(obj.value, obj) + if obj.value and obj.value.special then + addSpecial(obj.value.special, obj) + end + end, + ['setfield'] = function (obj) + Compile(obj.node, obj) + Compile(obj.value, obj) + end, + ['do'] = function (obj) + local lastBlock = Block + Block = obj + CompileBlock(obj, obj) + if Block.locals then + LocalCount = LocalCount - #Block.locals + end + Block = lastBlock + end, + ['return'] = function (obj) + for i = 1, #obj do + Compile(obj[i], obj) + end + if Block and Block[#Block] ~= obj then + pushError { + type = 'ACTION_AFTER_RETURN', + start = obj.start, + finish = obj.finish, + } + end + local func = guide.getParentFunction(obj) + if func then + if not func.returns then + func.returns = {} + end + func.returns[#func.returns+1] = obj + end + end, + 
['label'] = function (obj) + local block = guide.getBlock(obj) + if block then + if not block.labels then + block.labels = {} + end + local name = obj[1] + local label = guide.getLabel(block, name) + if label then + if Version == 'Lua 5.4' + or block == guide.getBlock(label) then + pushError { + type = 'REDEFINED_LABEL', + start = obj.start, + finish = obj.finish, + relative = { + { + label.start, + label.finish, + } + } + } + end + end + block.labels[name] = obj + end + end, + ['goto'] = function (obj) + GoToTag[#GoToTag+1] = obj + end, + ['if'] = function (obj) + for i = 1, #obj do + Compile(obj[i], obj) + end + end, + ['ifblock'] = function (obj) + local lastBlock = Block + Block = obj + Compile(obj.filter, obj) + CompileBlock(obj, obj) + if Block.locals then + LocalCount = LocalCount - #Block.locals + end + Block = lastBlock + end, + ['elseifblock'] = function (obj) + local lastBlock = Block + Block = obj + Compile(obj.filter, obj) + CompileBlock(obj, obj) + if Block.locals then + LocalCount = LocalCount - #Block.locals + end + Block = lastBlock + end, + ['elseblock'] = function (obj) + local lastBlock = Block + Block = obj + CompileBlock(obj, obj) + if Block.locals then + LocalCount = LocalCount - #Block.locals + end + Block = lastBlock + end, + ['loop'] = function (obj) + local lastBlock = Block + Block = obj + Compile(obj.loc, obj) + Compile(obj.max, obj) + Compile(obj.step, obj) + CompileBlock(obj, obj) + if Block.locals then + LocalCount = LocalCount - #Block.locals + end + Block = lastBlock + end, + ['in'] = function (obj) + local lastBlock = Block + Block = obj + local keys = obj.keys + for i = 1, #keys do + Compile(keys[i], obj) + end + CompileBlock(obj, obj) + if Block.locals then + LocalCount = LocalCount - #Block.locals + end + Block = lastBlock + end, + ['while'] = function (obj) + local lastBlock = Block + Block = obj + Compile(obj.filter, obj) + CompileBlock(obj, obj) + if Block.locals then + LocalCount = LocalCount - #Block.locals + end + Block = 
lastBlock + end, + ['repeat'] = function (obj) + local lastBlock = Block + Block = obj + CompileBlock(obj, obj) + Compile(obj.filter, obj) + if Block.locals then + LocalCount = LocalCount - #Block.locals + end + Block = lastBlock + end, + ['break'] = function (obj) + local block = guide.getBreakBlock(obj) + if block then + if not block.breaks then + block.breaks = {} + end + block.breaks[#block.breaks+1] = obj + else + pushError { + type = 'BREAK_OUTSIDE', + start = obj.start, + finish = obj.finish, + } + end + end, + ['main'] = function (obj) + Block = obj + Compile({ + type = 'local', + start = 0, + finish = 0, + effect = 0, + tag = '_ENV', + special= '_G', + [1] = ENVMode, + }, obj) + --- _ENV 是上值,不计入局部变量计数 + LocalCount = 0 + CompileBlock(obj, obj) + Block = nil + end, +} + +function CompileBlock(obj, parent) + for i = 1, #obj do + local act = obj[i] + local f = vmMap[act.type] + if f then + act.parent = parent + f(act) + end + end +end + +function Compile(obj, parent) + if not obj then + return nil + end + if Compiled[obj] then + return + end + Compiled[obj] = true + obj.parent = parent + local f = vmMap[obj.type] + if not f then + return + end + f(obj) +end + +local function compileGoTo(obj) + local name = obj[1] + local label = guide.getLabel(obj, name) + if not label then + pushError { + type = 'NO_VISIBLE_LABEL', + start = obj.start, + finish = obj.finish, + info = { + label = name, + } + } + return + end + if not label.ref then + label.ref = {} + end + label.ref[#label.ref+1] = obj + obj.node = label + + -- 如果有局部变量在 goto 与 label 之间声明, + -- 并在 label 之后使用,则算作语法错误 + + -- 如果 label 在 goto 之前声明,那么不会有中间声明的局部变量 + if obj.start > label.start then + return + end + + local block = guide.getBlock(obj) + local locals = block and block.locals + if not locals then + return + end + + for i = 1, #locals do + local loc = locals[i] + -- 检查局部变量声明位置为 goto 与 label 之间 + if loc.start < obj.start or loc.finish > label.finish then + goto CONTINUE + end + -- 检查局部变量的使用位置在 label 之后 + 
local refs = loc.ref + if not refs then + goto CONTINUE + end + for j = 1, #refs do + local ref = refs[j] + if ref.finish > label.finish then + pushError { + type = 'JUMP_LOCAL_SCOPE', + start = obj.start, + finish = obj.finish, + info = { + loc = loc[1], + }, + relative = { + { + start = label.start, + finish = label.finish, + }, + { + start = loc.start, + finish = loc.finish, + } + }, + } + return + end + end + ::CONTINUE:: + end +end + +-- Resolve every `goto` recorded in GoToTag during compilation. +local function PostCompile() + for i = 1, #GoToTag do + compileGoTo(GoToTag[i]) + end +end + +-- Entry point: parse `lua` with `self:parse`, then compile the resulting AST. +-- Returns `nil, err` when parsing fails, otherwise the parse state. +-- The environment name is '@fenv' for 'Lua 5.1' and 'LuaJIT', '_ENV' for every other version. +-- Module-level compile state (Compiled, GoToTag, LocalCount, Version, Root, Options) is reset on each call. +return function (self, lua, mode, version, options) + local state, err = self:parse(lua, mode, version) + if not state then + return nil, err + end + pushError = state.pushError + if version == 'Lua 5.1' or version == 'LuaJIT' then + ENVMode = '@fenv' + else + ENVMode = '_ENV' + end + Compiled = {} + GoToTag = {} + LocalCount = 0 + Version = version + Root = state.ast + Options = options + state.ENVMode = ENVMode + -- Only index the parse result when it is a table: attaching the state outside this guard + -- raised on any other value before the guard could be reached. + if type(state.ast) == 'table' then + Root.state = state + Compile(state.ast) + end + PostCompile() + Compiled = nil + GoToTag = nil + return state +end diff --git a/script/parser/grammar.lua b/script/parser/grammar.lua new file mode 100644 index 00000000..06dae246 --- /dev/null +++ b/script/parser/grammar.lua @@ -0,0 +1,538 @@ +local re = require 'parser.relabel' +local m = require 'lpeglabel' +local ast = require 'parser.ast' + +local scriptBuf = '' +local compiled = {} +local defs = ast.defs + +-- `goto` may be used as a name; whether that is legal is handled later +local RESERVED = { + ['and'] = true, + ['break'] = true, + ['do'] = true, + ['else'] = true, + ['elseif'] = true, + ['end'] = true, + ['false'] = true, + ['for'] = true, + ['function'] = true, + ['if'] = true, + ['in'] = true, + ['local'] = true, + ['nil'] = true, + ['not'] = true, + ['or'] = true, + ['repeat'] = true, + ['return'] = true, + ['then'] = true, + ['true'] = true, + ['until'] = true, + ['while'] = true, +} + +defs.nl = (m.P'\r\n' + m.S'\r\n') +defs.s = m.S' \t' +defs.S = - defs.s +defs.ea = '\a' 
+defs.eb = '\b' +defs.ef = '\f' +defs.en = '\n' +defs.er = '\r' +defs.et = '\t' +defs.ev = '\v' +defs['nil'] = m.Cp() / function () return nil end +defs['false'] = m.Cp() / function () return false end +defs.NotReserved = function (_, _, str) + if RESERVED[str] then + return false + end + return true +end +defs.Reserved = function (_, _, str) + if RESERVED[str] then + return true + end + return false +end +defs.None = function () end +defs.np = m.Cp() / function (n) return n+1 end + +m.setmaxstack(1000) + +local eof = re.compile '!. / %{SYNTAX_ERROR}' + +local function grammar(tag) + return function (script) + scriptBuf = script .. '\r\n' .. scriptBuf + compiled[tag] = re.compile(scriptBuf, defs) * eof + end +end + +local function errorpos(pos, err) + return { + type = 'UNKNOWN', + start = pos or 0, + finish = pos or 0, + err = err, + } +end + +grammar 'Comment' [[ +Comment <- LongComment + / '--' ShortComment +LongComment <- ('--[' {} {:eq: '='* :} {} '[' + {(!CommentClose .)*} + (CommentClose / {})) + -> LongComment + / ( + {} '/*' {} + (!'*/' .)* + {} '*/' {} + ) + -> CLongComment +CommentClose <- ']' =eq ']' +ShortComment <- ({} {(!%nl .)*} {}) + -> ShortComment +]] + +grammar 'Sp' [[ +Sp <- (Comment / %nl / %s)* +Sps <- (Comment / %nl / %s)+ +]] + +grammar 'Common' [[ +Word <- [a-zA-Z0-9_] +Cut <- !Word +X16 <- [a-fA-F0-9] +Rest <- (!%nl .)* + +AND <- Sp {'and'} Cut +BREAK <- Sp 'break' Cut +FALSE <- Sp 'false' Cut +GOTO <- Sp 'goto' Cut +LOCAL <- Sp 'local' Cut +NIL <- Sp 'nil' Cut +NOT <- Sp 'not' Cut +OR <- Sp {'or'} Cut +RETURN <- Sp 'return' Cut +TRUE <- Sp 'true' Cut + +DO <- Sp {} 'do' {} Cut + / Sp({} 'then' {} Cut) -> ErrDo +IF <- Sp {} 'if' {} Cut +ELSE <- Sp {} 'else' {} Cut +ELSEIF <- Sp {} 'elseif' {} Cut +END <- Sp {} 'end' {} Cut +FOR <- Sp {} 'for' {} Cut +FUNCTION <- Sp {} 'function' {} Cut +IN <- Sp {} 'in' {} Cut +REPEAT <- Sp {} 'repeat' {} Cut +THEN <- Sp {} 'then' {} Cut + / Sp({} 'do' {} Cut) -> ErrThen +UNTIL <- Sp {} 'until' {} Cut 
+WHILE <- Sp {} 'while' {} Cut + + +Esc <- '\' -> '' + EChar +EChar <- 'a' -> ea + / 'b' -> eb + / 'f' -> ef + / 'n' -> en + / 'r' -> er + / 't' -> et + / 'v' -> ev + / '\' + / '"' + / "'" + / %nl + / ('z' (%nl / %s)*) -> '' + / ({} 'x' {X16 X16}) -> Char16 + / ([0-9] [0-9]? [0-9]?) -> Char10 + / ('u{' {} {Word*} '}') -> CharUtf8 + -- 错误处理 + / 'x' {} -> MissEscX + / 'u' !'{' {} -> MissTL + / 'u{' Word* !'}' {} -> MissTR + / {} -> ErrEsc + +BOR <- Sp {'|'} +BXOR <- Sp {'~'} !'=' +BAND <- Sp {'&'} +Bshift <- Sp {BshiftList} +BshiftList <- '<<' + / '>>' +Concat <- Sp {'..'} +Adds <- Sp {AddsList} +AddsList <- '+' + / '-' +Muls <- Sp {MulsList} +MulsList <- '*' + / '//' + / '/' + / '%' +Unary <- Sp {} {UnaryList} +UnaryList <- NOT + / '#' + / '-' + / '~' !'=' +POWER <- Sp {'^'} + +BinaryOp <-( Sp {} {'or'} Cut + / Sp {} {'and'} Cut + / Sp {} {'<=' / '>=' / '<'!'<' / '>'!'>' / '~=' / '=='} + / Sp {} ({} '=' {}) -> ErrEQ + / Sp {} ({} '!=' {}) -> ErrUEQ + / Sp {} {'|'} + / Sp {} {'~'} + / Sp {} {'&'} + / Sp {} {'<<' / '>>'} + / Sp {} {'..'} !'.' + / Sp {} {'+' / '-'} + / Sp {} {'*' / '//' / '/' / '%'} + / Sp {} {'^'} + )-> BinaryOp +UnaryOp <-( Sp {} {'not' Cut / '#' / '~' !'=' / '-' !'-'} + )-> UnaryOp + +PL <- Sp '(' +PR <- Sp ')' +BL <- Sp '[' !'[' !'=' +BR <- Sp ']' +TL <- Sp '{' +TR <- Sp '}' +COMMA <- Sp ({} ',') + -> COMMA +SEMICOLON <- Sp ({} ';') + -> SEMICOLON +DOTS <- Sp ({} '...') + -> DOTS +DOT <- Sp ({} '.' 
!'.') + -> DOT +COLON <- Sp ({} ':' !':') + -> COLON +LABEL <- Sp '::' +ASSIGN <- Sp '=' !'=' +AssignOrEQ <- Sp ({} '==' {}) + -> ErrAssign + / Sp '=' + +DirtyBR <- BR / {} -> MissBR +DirtyTR <- TR / {} -> MissTR +DirtyPR <- PR / {} -> MissPR +DirtyLabel <- LABEL / {} -> MissLabel +NeedEnd <- END / {} -> MissEnd +NeedDo <- DO / {} -> MissDo +NeedAssign <- ASSIGN / {} -> MissAssign +NeedComma <- COMMA / {} -> MissComma +NeedIn <- IN / {} -> MissIn +NeedUntil <- UNTIL / {} -> MissUntil +NeedThen <- THEN / {} -> MissThen +]] + +grammar 'Nil' [[ +Nil <- Sp ({} -> Nil) NIL +]] + +grammar 'Boolean' [[ +Boolean <- Sp ({} -> True) TRUE + / Sp ({} -> False) FALSE +]] + +grammar 'String' [[ +String <- Sp ({} StringDef {}) + -> String +StringDef <- {'"'} + {~(Esc / !%nl !'"' .)*~} -> 1 + ('"' / {} -> MissQuote1) + / {"'"} + {~(Esc / !%nl !"'" .)*~} -> 1 + ("'" / {} -> MissQuote2) + / ('[' {} {:eq: '='* :} {} '[' %nl? + {(!StringClose .)*} -> 1 + (StringClose / {})) + -> LongString +StringClose <- ']' =eq ']' +]] + +grammar 'Number' [[ +Number <- Sp ({} {NumberDef} {}) -> Number + NumberSuffix? + ErrNumber? +NumberDef <- Number16 / Number10 +NumberSuffix<- ({} {[uU]? [lL] [lL]}) -> FFINumber + / ({} {[iI]}) -> ImaginaryNumber +ErrNumber <- ({} {([0-9a-zA-Z] / '.')+}) -> UnknownSymbol + +Number10 <- Float10 Float10Exp? + / Integer10 Float10? Float10Exp? +Integer10 <- [0-9]+ ('.' [0-9]*)? +Float10 <- '.' [0-9]+ +Float10Exp <- [eE] [+-]? [0-9]+ + / ({} [eE] [+-]? {}) -> MissExponent + +Number16 <- '0' [xX] Float16 Float16Exp? + / '0' [xX] Integer16 Float16? Float16Exp? +Integer16 <- X16+ ('.' X16*)? + / ({} {Word*}) -> MustX16 +Float16 <- '.' X16+ + / '.' ({} {Word*}) -> MustX16 +Float16Exp <- [pP] [+-]? [0-9]+ + / ({} [pP] [+-]? 
{}) -> MissExponent +]] + +grammar 'Name' [[ +Name <- Sp ({} NameBody {}) + -> Name +NameBody <- {[a-zA-Z_] [a-zA-Z0-9_]*} +FreeName <- Sp ({} {NameBody=>NotReserved} {}) + -> Name +KeyWord <- Sp NameBody=>Reserved +MustName <- Name / DirtyName +DirtyName <- {} -> DirtyName +]] + +grammar 'Exp' [[ +Exp <- (UnUnit BinUnit*) + -> Binary +BinUnit <- (BinaryOp UnUnit?) + -> SubBinary +UnUnit <- ExpUnit + / (UnaryOp+ (ExpUnit / MissExp)) + -> Unary +ExpUnit <- Nil + / Boolean + / String + / Number + / Dots + / Table + / Function + / Simple + +Simple <- {| Prefix (Sp Suffix)* |} + -> Simple +Prefix <- Sp ({} PL DirtyExp DirtyPR {}) + -> Paren + / Single +Single <- FreeName + -> Single +Suffix <- SuffixWithoutCall + / ({} PL SuffixCall DirtyPR {}) + -> Call +SuffixCall <- Sp ({} {| (COMMA / Exp)+ |} {}) + -> PackExpList + / %nil +SuffixWithoutCall + <- (DOT (Name / MissField)) + -> GetField + / ({} BL DirtyExp DirtyBR {}) + -> GetIndex + / (COLON (Name / MissMethod) NeedCall) + -> GetMethod + / ({} {| Table |} {}) + -> Call + / ({} {| String |} {}) + -> Call +NeedCall <- (!(Sp CallStart) {} -> MissPL)? 
+MissField <- {} -> MissField +MissMethod <- {} -> MissMethod +CallStart <- PL + / TL + / '"' + / "'" + / '[' '='* '[' + +DirtyExp <- Exp + / {} -> DirtyExp +MaybeExp <- Exp / MissExp +MissExp <- {} -> MissExp +ExpList <- Sp {| MaybeExp (Sp ',' MaybeExp)* |} + +Dots <- DOTS + -> VarArgs + +Table <- Sp ({} TL {| TableField* |} DirtyTR {}) + -> Table +TableField <- COMMA + / SEMICOLON + / NewIndex + / NewField + / Exp +Index <- BL DirtyExp DirtyBR +NewIndex <- Sp ({} Index NeedAssign DirtyExp {}) + -> NewIndex +NewField <- Sp ({} MustName ASSIGN DirtyExp {}) + -> NewField + +Function <- FunctionBody + -> Function +FuncArgs <- Sp ({} PL {| FuncArg+ |} DirtyPR {}) + -> FuncArgs + / PL DirtyPR %nil +FuncArgsMiss<- {} -> MissPL DirtyPR %nil +FuncArg <- DOTS + / Name + / COMMA +FunctionBody<- FUNCTION FuncArgs + {| (!END Action)* |} + NeedEnd + / FUNCTION FuncArgsMiss + {| %nil |} + NeedEnd + +-- 纯占位,修改了 `relabel.lua` 使重复定义不抛错 +Action <- !END . +]] + +grammar 'Action' [[ +Action <- Sp (CrtAction / UnkAction) +CrtAction <- Semicolon + / Do + / Break + / Return + / Label + / GoTo + / If + / For + / While + / Repeat + / NamedFunction + / LocalFunction + / Local + / Set + / Call + / ExpInAction +UnkAction <- ({} {Word+}) + -> UnknownAction + / ({} '//' {} (LongComment / ShortComment)) + -> CCommentPrefix + / ({} {. 
(!Sps !CrtAction .)*}) + -> UnknownAction +ExpInAction <- Sp ({} Exp {}) + -> ExpInAction + +Semicolon <- Sp ';' +SimpleList <- {| Simple (Sp ',' Simple)* |} + +Do <- Sp ({} + 'do' Cut + {| (!END Action)* |} + NeedEnd) + -> Do + +Break <- Sp ({} BREAK {}) + -> Break + +Return <- Sp ({} RETURN ReturnExpList {}) + -> Return +ReturnExpList + <- Sp {| Exp (Sp ',' MaybeExp)* |} + / Sp {| !Exp !',' |} + / ExpList + +Label <- Sp ({} LABEL MustName DirtyLabel {}) + -> Label + +GoTo <- Sp ({} GOTO MustName {}) + -> GoTo + +If <- Sp ({} {| IfHead IfBody* |} NeedEnd) + -> If + +IfHead <- Sp (IfPart {}) -> IfBlock + / Sp (ElseIfPart {}) -> ElseIfBlock + / Sp (ElsePart {}) -> ElseBlock +IfBody <- Sp (ElseIfPart {}) -> ElseIfBlock + / Sp (ElsePart {}) -> ElseBlock +IfPart <- IF DirtyExp NeedThen + {| (!ELSEIF !ELSE !END Action)* |} +ElseIfPart <- ELSEIF DirtyExp NeedThen + {| (!ELSEIF !ELSE !END Action)* |} +ElsePart <- ELSE + {| (!ELSEIF !ELSE !END Action)* |} + +For <- Loop / In + +Loop <- LoopBody + -> Loop +LoopBody <- FOR LoopArgs NeedDo + {} {| (!END Action)* |} + NeedEnd +LoopArgs <- MustName AssignOrEQ + ({} {| (COMMA / !DO !END Exp)* |} {}) + -> PackLoopArgs + +In <- InBody + -> In +InBody <- FOR InNameList NeedIn InExpList NeedDo + {} {| (!END Action)* |} + NeedEnd +InNameList <- ({} {| (COMMA / !IN !DO !END Name)* |} {}) + -> PackInNameList +InExpList <- ({} {| (COMMA / !DO !DO !END Exp)* |} {}) + -> PackInExpList + +While <- WhileBody + -> While +WhileBody <- WHILE DirtyExp NeedDo + {| (!END Action)* |} + NeedEnd + +Repeat <- (RepeatBody {}) + -> Repeat +RepeatBody <- REPEAT + {| (!UNTIL Action)* |} + NeedUntil DirtyExp + +LocalAttr <- {| (Sp '<' Sp MustName Sp LocalAttrEnd)+ |} + -> LocalAttr +LocalAttrEnd<- '>' / {} -> MissGT +Local <- Sp ({} LOCAL LocalNameList ((AssignOrEQ ExpList) / %nil) {}) + -> Local +Set <- Sp ({} SimpleList AssignOrEQ ExpList {}) + -> Set +LocalNameList + <- {| LocalName (Sp ',' LocalName)* |} +LocalName <- (MustName LocalAttr?) 
+ -> LocalName + +Call <- Simple + -> SimpleCall + +LocalFunction + <- Sp ({} LOCAL FunctionNamedBody) + -> LocalFunction + +NamedFunction + <- FunctionNamedBody + -> NamedFunction +FunctionNamedBody + <- FUNCTION FuncName FuncArgs + {| (!END Action)* |} + NeedEnd + / FUNCTION FuncName FuncArgsMiss + {| %nil |} + NeedEnd +FuncName <- {| Single (Sp SuffixWithoutCall)* |} + -> Simple + / {} -> MissName %nil +]] + +grammar 'Lua' [[ +Lua <- Head? + ({} {| Action* |} {}) -> Lua + Sp +Head <- '#' (!%nl .)* +]] + +return function (self, lua, mode) + local gram = compiled[mode] or compiled['Lua'] + local r, _, pos = gram:match(lua) + if not r then + local err = errorpos(pos) + return nil, err + end + + return r +end diff --git a/script/parser/guide.lua b/script/parser/guide.lua new file mode 100644 index 00000000..6ef239f1 --- /dev/null +++ b/script/parser/guide.lua @@ -0,0 +1,3884 @@ +local util = require 'utility' +local error = error +local type = type +local next = next +local tostring = tostring +local print = print +local ipairs = ipairs +local tableInsert = table.insert +local tableUnpack = table.unpack +local tableRemove = table.remove +local tableMove = table.move +local tableSort = table.sort +local tableConcat = table.concat +local mathType = math.type +local pairs = pairs +local setmetatable = setmetatable +local assert = assert +local select = select +local osClock = os.clock +local DEVELOP = _G.DEVELOP +local log = log +local _G = _G + +local function logWarn(...) + log.warn(...) 
+end + +_ENV = nil + +local m = {} + +local blockTypes = { + ['while'] = true, + ['in'] = true, + ['loop'] = true, + ['repeat'] = true, + ['do'] = true, + ['function'] = true, + ['ifblock'] = true, + ['elseblock'] = true, + ['elseifblock'] = true, + ['main'] = true, +} + +local breakBlockTypes = { + ['while'] = true, + ['in'] = true, + ['loop'] = true, + ['repeat'] = true, +} + +m.childMap = { + ['main'] = {'#', 'docs'}, + ['repeat'] = {'#', 'filter'}, + ['while'] = {'filter', '#'}, + ['in'] = {'keys', '#'}, + ['loop'] = {'loc', 'max', 'step', '#'}, + ['if'] = {'#'}, + ['ifblock'] = {'filter', '#'}, + ['elseifblock'] = {'filter', '#'}, + ['elseblock'] = {'#'}, + ['setfield'] = {'node', 'field', 'value'}, + ['setglobal'] = {'value'}, + ['local'] = {'attrs', 'value'}, + ['setlocal'] = {'value'}, + ['return'] = {'#'}, + ['do'] = {'#'}, + ['select'] = {'vararg'}, + ['table'] = {'#'}, + ['tableindex'] = {'index', 'value'}, + ['tablefield'] = {'field', 'value'}, + ['function'] = {'args', '#'}, + ['funcargs'] = {'#'}, + ['setmethod'] = {'node', 'method', 'value'}, + ['getmethod'] = {'node', 'method'}, + ['setindex'] = {'node', 'index', 'value'}, + ['getindex'] = {'node', 'index'}, + ['paren'] = {'exp'}, + ['call'] = {'node', 'args'}, + ['callargs'] = {'#'}, + ['getfield'] = {'node', 'field'}, + ['list'] = {'#'}, + ['binary'] = {1, 2}, + ['unary'] = {1}, + + ['doc'] = {'#'}, + ['doc.class'] = {'class', 'extends'}, + ['doc.type'] = {'#types', '#enums', 'name'}, + ['doc.alias'] = {'alias', 'extends'}, + ['doc.param'] = {'param', 'extends'}, + ['doc.return'] = {'#returns'}, + ['doc.field'] = {'field', 'extends'}, + ['doc.generic'] = {'#generics'}, + ['doc.generic.object'] = {'generic', 'extends'}, + ['doc.vararg'] = {'vararg'}, + ['doc.type.table'] = {'key', 'value'}, + ['doc.type.function'] = {'#args', '#returns'}, + ['doc.overload'] = {'overload'}, +} + +m.actionMap = { + ['main'] = {'#'}, + ['repeat'] = {'#'}, + ['while'] = {'#'}, + ['in'] = {'#'}, + ['loop'] = {'#'}, + 
['if'] = {'#'}, + ['ifblock'] = {'#'}, + ['elseifblock'] = {'#'}, + ['elseblock'] = {'#'}, + ['do'] = {'#'}, + ['function'] = {'#'}, + ['funcargs'] = {'#'}, +} + +local TypeSort = { + ['boolean'] = 1, + ['string'] = 2, + ['integer'] = 3, + ['number'] = 4, + ['table'] = 5, + ['function'] = 6, + ['nil'] = 999, +} + +local NIL = setmetatable({'<nil>'}, { __tostring = function () return 'nil' end }) + +--- 是否是字面量 +function m.isLiteral(obj) + local tp = obj.type + return tp == 'nil' + or tp == 'boolean' + or tp == 'string' + or tp == 'number' + or tp == 'table' +end + +--- 获取字面量 +function m.getLiteral(obj) + local tp = obj.type + if tp == 'boolean' then + return obj[1] + elseif tp == 'string' then + return obj[1] + elseif tp == 'number' then + return obj[1] + end + return nil +end + +--- 寻找父函数 +function m.getParentFunction(obj) + for _ = 1, 1000 do + obj = obj.parent + if not obj then + break + end + local tp = obj.type + if tp == 'function' or tp == 'main' then + return obj + end + end + return nil +end + +--- 寻找所在区块 +function m.getBlock(obj) + for _ = 1, 1000 do + if not obj then + return nil + end + local tp = obj.type + if blockTypes[tp] then + return obj + end + obj = obj.parent + end + error('guide.getBlock overstack') +end + +--- 寻找所在父区块 +function m.getParentBlock(obj) + for _ = 1, 1000 do + obj = obj.parent + if not obj then + return nil + end + local tp = obj.type + if blockTypes[tp] then + return obj + end + end + error('guide.getParentBlock overstack') +end + +--- 寻找所在可break的父区块 +function m.getBreakBlock(obj) + for _ = 1, 1000 do + obj = obj.parent + if not obj then + return nil + end + local tp = obj.type + if breakBlockTypes[tp] then + return obj + end + if tp == 'function' then + return nil + end + end + error('guide.getBreakBlock overstack') +end + +--- 寻找doc的主体 +function m.getDocState(obj) + for _ = 1, 1000 do + local parent = obj.parent + if not parent then + return obj + end + if parent.type == 'doc' then + return obj + end + obj = parent + end + 
--- Find the root block: starting at `obj` itself, follow the `parent` links
--- until a node of type `main` is met and return it.
--- Returns `nil` when the chain ends before a `main` node is found, and
--- raises when no answer is reached within 1000 steps.
function m.getRoot(obj)
    local node = obj
    local steps = 0
    while steps < 1000 do
        steps = steps + 1
        if node.type == 'main' then
            return node
        end
        node = node.parent
        if not node then
            return nil
        end
    end
    error('guide.getRoot overstack')
end
--- Get the label named `name` that is visible from `block`.
---
--- The search starts at the block containing `block` (via `m.getBlock`) and
--- moves outward through parent blocks. It gives up with `nil` once a
--- `function` block has been inspected, or when there is no further parent
--- block.
---@param block table
---@param name string {comment = 'label name'}
---@return table|nil label
function m.getLabel(block, name)
    block = m.getBlock(block)
    for _ = 1, 1000 do
        if not block then
            return nil
        end
        local labels = block.labels
        if labels then
            local label = labels[name]
            if label then
                return label
            end
        end
        -- Do not look past the enclosing function.
        if block.type == 'function' then
            return nil
        end
        block = m.getParentBlock(block)
    end
    -- The message previously named `guide.getLocal`, which pointed at the
    -- wrong function when this guard fired.
    error('guide.getLabel overstack')
end
--- Append the direct children of `obj` to the end of `list`.
---
--- `map` gives, for each node type, the keys that hold children
--- (`m.childMap` has this shape). Keys are handled as follows:
---   * `'#'`     : append every element of `obj`'s array part
---   * a key for which `obj[key]` is truthy: append that single value
---   * `'#name'` : append every element of the array `obj.name`
--- Node types that are absent from `map` contribute nothing.
---@param list table
---@param obj table
---@param map table
function m.addChilds(list, obj, map)
    local keys = map[obj.type]
    if not keys then
        return
    end
    for i = 1, #keys do
        local key = keys[i]
        if key == '#' then
            -- Distinct index name: the original reused `i` here and
            -- shadowed the outer loop variable.
            for j = 1, #obj do
                list[#list+1] = obj[j]
            end
        elseif obj[key] then
            list[#list+1] = obj[key]
        elseif type(key) == 'string'
        and    key:sub(1, 1) == '#' then
            local childs = obj[key:sub(2)]
            -- The named array may be missing on a node; treat that as
            -- "no children" instead of raising on `#nil`.
            if childs then
                for j = 1, #childs do
                    list[#list+1] = childs[j]
                end
            end
        end
    end
end
--- Invoke `callback` for every entry registered under `name` in the
--- `specials` table of the root block that contains `ast`.
---
--- Does nothing when no root can be found, when the root has no `specials`
--- table, or when nothing is registered under `name`.
---@param ast table
---@param name string
---@param callback function
function m.eachSpecialOf(ast, name, callback)
    local root = m.getRoot(ast)
    -- `m.getRoot` returns nil when the parent chain ends before a `main`
    -- node, so guard before indexing it.
    if not root or not root.specials then
        return
    end
    local specials = root.specials[name]
    if not specials then
        return
    end
    for i = 1, #specials do
        callback(specials[i])
    end
end
--- Get the offset that corresponds to a (row, col) position.
---
--- * `row < 1` yields 0.
--- * `row` beyond the last line yields the `finish` of the last line
---   (0 when `lines` is empty).
--- * Otherwise `col` is clamped to the line: a negative `col` yields the
---   line's `start`, a `col` larger than the line length yields its `finish`.
---@param lines table
---@param row integer
---@param col integer
---@return integer {name = 'offset'}
function m.offsetOf(lines, row, col)
    if row < 1 then
        return 0
    end
    if row > #lines then
        local lastLine = lines[#lines]
        -- Empty `lines`: there is no last line to read `finish` from.
        if not lastLine then
            return 0
        end
        return lastLine.finish
    end
    local line = lines[row]
    local len = line.finish - line.start + 1
    if col < 0 then
        return line.start
    elseif col > len then
        return line.finish
    else
        return line.start + col - 1
    end
end
--- Build the key name of a literal-like node.
---
--- The result is a type tag, followed by `|` and the value when the node
--- carries one:
---   * `field` / `method` : `'s|' .. obj[1]`
---   * `string`           : `'s|<value>'`, or `'s'` when `obj[1]` is nil
---   * `number`           : `'n|<util.viewLiteral(value)>'`, or `'n'`
---   * `boolean`          : `'b|true'` / `'b|false'`, or `'b'` when nil
--- Returns `nil` for a nil `obj` and for every other node type.
function m.getKeyNameOfLiteral(obj)
    if not obj then
        return nil
    end
    local tp = obj.type
    if tp == 'field'
    or tp == 'method' then
        return 's|' .. obj[1]
    elseif tp == 'string' then
        local s = obj[1]
        if s then
            return 's|' .. s
        else
            return 's'
        end
    elseif tp == 'number' then
        local n = obj[1]
        if n then
            return ('n|%s'):format(util.viewLiteral(n))
        else
            return 'n'
        end
    elseif tp == 'boolean' then
        local b = obj[1]
        -- Compare against nil explicitly: a truthiness test would treat a
        -- literal `false` like a missing value and drop it from the key.
        if b ~= nil then
            return 'b|' .. tostring(b)
        else
            return 'b'
        end
    end
    return nil
end
--- Compute the block paths that lead to `a` and to `b` (goto is ignored).
--- Every path element is a block.
---
--- Result by case:
---   * `a` ends before `b` starts: `'before'` plus two lists of blocks
---   * `a` starts after `b` ends:  `'after'` plus two lists of blocks
---   * the two ranges overlap:     `'equal'` plus two empty lists
---   * `sameFunction` is set and the parent functions differ: `false`
---   * no shared block is found:   `nil`
---
--- The first list belongs to whichever node comes first in the text, the
--- second list to the later one (so in `'after'` mode the first list is the
--- path to `b`). Both lists run from just below the shared base block down
--- to the node's nearest parent block.
---
--- When `sameFunction` is falsy, climbing stops at the first enclosing
--- `function` block; otherwise it continues up to `main`.
---@param a table
---@param b table
---@param sameFunction boolean
---@return string|boolean mode
---@return table pathA?
---@return table pathB?
function m.getPath(a, b, sameFunction)
    -- First check that both sides live in the same function.
    if sameFunction and m.getParentFunction(a) ~= m.getParentFunction(b) then
        return false
    end
    local mode
    local objA
    local objB
    if a.finish < b.start then
        mode = 'before'
        objA = a
        objB = b
    elseif a.start > b.finish then
        mode = 'after'
        objA = b
        objB = a
    else
        return 'equal', {}, {}
    end
    local pathA = {}
    local pathB = {}
    for _ = 1, 1000 do
        objA = m.getParentBlock(objA)
        pathA[#pathA+1] = objA
        if (not sameFunction and objA.type == 'function') or objA.type == 'main' then
            break
        end
    end
    for _ = 1, 1000 do
        objB = m.getParentBlock(objB)
        pathB[#pathB+1] = objB
        -- Fixed: this condition used to test `objA.type`, so the climb for
        -- B stopped after one step whenever A's climb had ended on a function.
        if (not sameFunction and objB.type == 'function') or objB.type == 'main' then
            break
        end
    end
    -- pathA: {1, 2, 3, 4, 5}
    -- pathB: {5, 6, 2, 3}
    -- Locate the outermost block of pathB inside pathA.
    local top = #pathB
    local start
    for i = #pathA, 1, -1 do
        local currentBlock = pathA[i]
        if currentBlock == pathB[top] then
            start = i
            break
        end
    end
    if not start then
        return nil
    end
    -- pathA: {   1, 2, 3}
    -- pathB: {5, 6, 2, 3}
    -- Walk inward while both paths agree; `extra` is the first index of
    -- pathA at which they diverge (0 when they never do).
    local extra = 0
    local align = top - start
    for i = start, 1, -1 do
        local currentA = pathA[i]
        local currentB = pathB[i+align]
        if currentA ~= currentB then
            extra = i
            break
        end
    end
    -- pathA: {1}
    local resultA = {}
    for i = extra, 1, -1 do
        resultA[#resultA+1] = pathA[i]
    end
    -- pathB: {5, 6}
    local resultB = {}
    for i = extra + align, 1, -1 do
        resultB[#resultB+1] = pathB[i]
    end
    return mode, resultA, resultB
end
--- Single-step, syntax-only search for the fields of a local.
---
--- Scans `loc.ref` and collects:
---   * every reference of type `setglobal` / `getglobal`, and
---   * for every `getlocal` reference, its `next` node when that node is a
---     field, method or index access (`set*` / `get*`).
--- Returns an empty list when the local has no `ref` table.
---@param loc table
---@return table results
local function stepFieldOfLocal(loc)
    local results = {}
    local refs = loc.ref
    -- `ref` may be absent; `stepRefOfLocal` guards the same way.
    if not refs then
        return results
    end
    for i = 1, #refs do
        local ref = refs[i]
        if ref.type == 'setglobal'
        or ref.type == 'getglobal' then
            results[#results+1] = ref
        elseif ref.type == 'getlocal' then
            local nxt = ref.next
            if nxt then
                if nxt.type == 'setfield'
                or nxt.type == 'getfield'
                or nxt.type == 'setmethod'
                or nxt.type == 'getmethod'
                or nxt.type == 'setindex'
                or nxt.type == 'getindex' then
                    results[#results+1] = nxt
                end
            end
        end
    end
    return results
end
--- Collect the chain of nodes behind an expression such as `a.b.c`, walking
--- from `obj` towards the root of the chain, then hand the collected list to
--- `convertSimpleList`.
---
--- Returns `nil` when a node type that is not handled below is met, when a
--- `paren` node has no inner expression, or when the step limit is hit.
---@param obj table
---@param max integer|nil
local function buildSimpleList(obj, max)
    local list = {}
    local cur = obj
    local limit = max and (max + 1) or 11
    -- NOTE(review): with `max == nil` the loop runs 1..11 and returns nil at
    -- step 11. With an explicit `max` it runs 1..max, so `i == limit`
    -- (max + 1) is never reached and an over-long chain falls out of the
    -- loop and is converted as collected. Confirm that this is intended.
    for i = 1, max or limit do
        if i == limit then
            return nil
        end
        -- Parentheses are transparent: unwrap to the inner expression.
        while cur.type == 'paren' do
            cur = cur.exp
            if not cur then
                return nil
            end
        end
        if cur.type == 'setfield'
        or cur.type == 'getfield'
        or cur.type == 'setmethod'
        or cur.type == 'getmethod'
        or cur.type == 'setindex'
        or cur.type == 'getindex' then
            -- Access node: record it and continue with the accessed node.
            list[i] = cur
            cur = cur.node
        elseif cur.type == 'tablefield'
        or     cur.type == 'tableindex' then
            -- Table entry: continue from the grandparent node; when that is
            -- a `return`, record the entry's parent as well and stop.
            list[i] = cur
            cur = cur.parent.parent
            if cur.type == 'return' then
                list[i+1] = list[i].parent
                break
            end
        elseif cur.type == 'getlocal'
        or     cur.type == 'setlocal'
        or     cur.type == 'local' then
            list[i] = cur
            break
        elseif cur.type == 'setglobal'
        or     cur.type == 'getglobal' then
            list[i] = cur
            break
        elseif cur.type == 'select'
        or     cur.type == 'table' then
            list[i] = cur
            break
        elseif cur.type == 'string' then
            list[i] = cur
            break
        elseif cur.type == 'doc.class.name'
        or     cur.type == 'doc.type.name'
        or     cur.type == 'doc.class' then
            list[i] = cur
            break
        elseif cur.type == 'function'
        or     cur.type == 'main' then
            -- Chain ends here; the function/main node itself is not recorded.
            break
        else
            return nil
        end
    end
    return convertSimpleList(list)
end
--- Tell whether `source` refers to a global.
---
--- True in two situations:
---   * `source` is a `setglobal` / `getglobal` whose `node` is tagged `_ENV`;
---   * `source` (or, for a `field` node, its parent) is a `getfield` /
---     `setfield` whose `node` has `special == '_G'`.
--- False otherwise.
function m.isGlobal(source)
    local tp = source.type
    if tp == 'setglobal' or tp == 'getglobal' then
        local env = source.node
        if env and env.tag == '_ENV' then
            return true
        end
    end
    -- A bare field name is judged by the access node that owns it.
    local target = source
    if tp == 'field' then
        target = source.parent
    end
    local targetType = target.type
    if targetType ~= 'getfield' and targetType ~= 'setfield' then
        return false
    end
    local owner = target.node
    if not owner then
        return false
    end
    return owner.special == '_G'
end
--- From one argument of a call, get the call node and the argument's index.
---
--- Returns `nil` when the parent of `callarg` is missing or is not a
--- `callargs` node. Otherwise returns the parent of that `callargs` node and
--- the position of `callarg` inside it (`nil` when it is not found there).
function m.getCallAndArgIndex(callarg)
    local argList = callarg.parent
    if not (argList and argList.type == 'callargs') then
        return nil
    end
    local position
    local count = #argList
    local i = 1
    while i <= count and not position do
        if argList[i] == callarg then
            position = i
        end
        i = i + 1
    end
    return argList.parent, position
end
--- Get the value node associated with `obj`.
---
--- Surrounding `paren` nodes are unwrapped first. Then, in order:
---   * `boolean` / `number` / `integer` / `string` nodes are their own value;
---   * a node with a `value` field yields that field;
---   * `field` / `method` nodes yield the `value` of their parent;
---   * a `call` whose callee has `special == 'rawset'` yields its third
---     argument (`nil` when the call has no argument list);
---   * a `select` node is its own value.
--- Returns `nil` in every other case.
function m.getObjectValue(obj)
    while obj.type == 'paren' do
        obj = obj.exp
        if not obj then
            return nil
        end
    end
    if obj.type == 'boolean'
    or obj.type == 'number'
    or obj.type == 'integer'
    or obj.type == 'string' then
        return obj
    end
    if obj.value then
        return obj.value
    end
    if obj.type == 'field'
    or obj.type == 'method' then
        return obj.parent.value
    end
    if obj.type == 'call' then
        if obj.node.special == 'rawset' then
            -- `args` is treated as optional elsewhere in this file
            -- (`obj.args and obj.args[2]`, `if not call.args`), so guard it
            -- here too instead of indexing nil.
            return obj.args and obj.args[3]
        end
    end
    if obj.type == 'select' then
        return obj
    end
    return nil
end
--- Resolve a `doc.type` node to the definitions of each of its type units.
---
--- Every entry of `doc.types` is passed through `stepRefOfDocType` in `'def'`
--- mode and the outcomes are concatenated into one list.
--- Returns nothing while `status.cache.searchingBindedDoc` is set, and
--- nothing when `doc` is not a `doc.type` node.
function m.checkSameSimpleByDocType(status, doc)
    local busy = status.cache.searchingBindedDoc
    if busy or doc.type ~= 'doc.type' then
        return
    end
    local merged = {}
    for _, unit in ipairs(doc.types) do
        for _, def in ipairs(stepRefOfDocType(status, unit, 'def')) do
            merged[#merged+1] = def
        end
    end
    return merged
end
--- Search the references of every source bound to the doc node `obj`
--- (`obj.bindSources`) and push what is found onto `queue`.
---
--- `status.cache.searchingBindedDoc` is set for the duration of the search
--- and cleared afterwards; the function does nothing when that flag is
--- already set or when `obj` has no `bindSources`.
--- Each result is queued as `{ obj = res, start = start, force = true }`.
function m.checkSameSimpleOfRefByDocSource(status, obj, start, queue, mode)
    if status.cache.searchingBindedDoc then
        return
    end
    local sources = obj.bindSources
    if not sources then
        return
    end
    status.cache.searchingBindedDoc = true
    local subStatus = m.status(status)
    -- Each distinct source is searched once only.
    local visited = {}
    for _, source in ipairs(sources) do
        if not visited[source] then
            visited[source] = true
            m.searchRefs(subStatus, source, mode)
        end
    end
    status.cache.searchingBindedDoc = nil
    for _, found in ipairs(subStatus.results) do
        local entry = {
            obj   = found,
            start = start,
            force = true,
        }
        queue[#queue+1] = entry
    end
end
--- When `ref` is a `self` local (or a `getlocal` / `setlocal` whose declared
--- node is tagged `self`), return the `node` of the method that local
--- belongs to (`selfNode.method.node`).
---
--- `mark` records the self nodes already handled: a self node is answered
--- once, later requests for it return `nil`.
--- Returns nothing when `ref` is not related to a `self` local.
function m.searchSameMethodCrossSelf(ref, mark)
    local selfNode
    if ref.tag == 'self' then
        selfNode = ref
    elseif ref.type == 'getlocal' or ref.type == 'setlocal' then
        local declared = ref.node
        if declared.tag == 'self' then
            selfNode = declared
        end
    end
    if not selfNode then
        return
    end
    if mark[selfNode] then
        return nil
    end
    mark[selfNode] = true
    return selfNode.method.node
end
--- Use the `doc.return` annotations bound to the function `source` to
--- resolve its `index`-th return value.
---
--- The `returns` of all bound `doc.return` docs are counted in order; the
--- `index`-th one is resolved with `m.checkSameSimpleByDocType` and the
--- resulting nodes are appended to `results`.
--- Returns `true` when types were merged, nothing otherwise (`source` is not
--- a function, has no bound docs, has no such return, or it does not
--- resolve). The `call` parameter is not used.
local function checkSameSimpleAndMergeFunctionReturnsByDoc(status, results, source, index, call)
    if not source or source.type ~= 'function' then
        return
    end
    local docs = source.bindDocs
    if not docs then
        return
    end
    -- Number the documented returns across all `doc.return` entries and
    -- keep the one sitting at position `index`.
    local seen = 0
    local wanted
    for _, doc in ipairs(docs) do
        if doc.type == 'doc.return' then
            for _, rtn in ipairs(doc.returns) do
                seen = seen + 1
                if seen == index then
                    wanted = rtn
                end
            end
        end
    end
    if not wanted then
        return
    end
    local resolved = m.checkSameSimpleByDocType(status, wanted)
    if not resolved then
        return
    end
    for _, node in ipairs(resolved) do
        results[#results+1] = node
    end
    return true
end
then + results[#results+1] = exp + end + end + end + end + end + end + return results +end + +function m.checkSameSimpleInCall(status, ref, start, queue, mode) + local func, args, index = m.getCallValue(ref) + if not func then + return + end + if m.checkCallMark(status, func.parent, true) then + return + end + status.cache.crossCallCount = status.cache.crossCallCount or 0 + if status.cache.crossCallCount >= 5 then + return + end + status.cache.crossCallCount = status.cache.crossCallCount + 1 + -- Handle the case where the assigned value is a setmetatable() call + m.checkSameSimpleInValueOfSetMetaTable(status, func, start, queue) + -- Handle the case where the assigned value is a func() call + local objs = m.checkSameSimpleInCallInSameFile(status, func, args, index) + if status.interface.call then + local cobjs = status.interface.call(func, args, index) + if cobjs then + for _, obj in ipairs(cobjs) do + if not m.checkReturnMark(status, obj) then + objs[#objs+1] = obj + end + end + end + end + m.cleanResults(objs) + local newStatus = m.status(status) + for _, obj in ipairs(objs) do + m.searchRefs(newStatus, obj, mode) + queue[#queue+1] = { + obj = obj, + start = start, + force = true, + } + end + status.cache.crossCallCount = status.cache.crossCallCount - 1 + for _, obj in ipairs(newStatus.results) do + queue[#queue+1] = { + obj = obj, + start = start, + force = true, + } + end +end + +--- Appends the enclosing call to `results` when `ref` is the callee of a +--- `rawset(_G, ...)` call; does nothing for any other reference. +local function searchRawset(ref, results) + if m.getKeyName(ref) ~= 's|rawset' then + return + end + local call = ref.parent + -- `ref` must be the function being called, not one of the arguments + if not call or call.type ~= 'call' or call.node ~= ref then + return + end + if not call.args then + return + end + local arg1 = call.args[1] + -- an empty argument list leaves arg1 nil, so check before indexing it + if not arg1 or arg1.special ~= '_G' then + -- Only a direct `_G` is recognised; surely nobody writes `rawset(_G._G._G, 'xxx', value)` + return + end + results[#results+1] = call +end + +local function searchG(ref, results) + while ref and m.getKeyName(ref) == 's|_G' do + results[#results+1] = ref + ref = ref.next + end + if ref then + results[#results+1] = ref + searchRawset(ref, results) + end +end + +local function searchEnvRef(ref, results) + if ref.type == 
'setglobal' + or ref.type == 'getglobal' then + results[#results+1] = ref + searchG(ref, results) + elseif ref.type == 'getlocal' then + results[#results+1] = ref.next + searchG(ref.next, results) + end +end + +function m.findGlobals(ast) + local root = m.getRoot(ast) + local results = {} + local env = m.getENV(root) + if env.ref then + for _, ref in ipairs(env.ref) do + searchEnvRef(ref, results) + end + end + return results +end + +function m.findGlobalsOfName(ast, name) + local root = m.getRoot(ast) + local results = {} + local globals = m.findGlobals(root) + for _, global in ipairs(globals) do + if m.getKeyName(global) == name then + results[#results+1] = global + end + end + return results +end + +function m.checkSameSimpleInGlobal(status, name, source, start, queue) + if not name then + return + end + local objs + if status.interface.global then + objs = status.interface.global(name) + else + objs = m.findGlobalsOfName(source, name) + end + if objs then + for _, obj in ipairs(objs) do + queue[#queue+1] = { + obj = obj, + start = start, + force = true, + } + end + end +end + +function m.checkValueMark(status, a, b) + if not status.cache.valueMark then + status.cache.valueMark = {} + end + if status.cache.valueMark[a] + or status.cache.valueMark[b] then + return true + end + status.cache.valueMark[a] = true + status.cache.valueMark[b] = true + return false +end + +function m.checkCallMark(status, a, mark) + if not status.cache.callMark then + status.cache.callMark = {} + end + if mark then + status.cache.callMark[a] = mark + else + return status.cache.callMark[a] + end + return false +end + +function m.checkReturnMark(status, a, mark) + if not status.cache.returnMark then + status.cache.returnMark = {} + end + if mark then + status.cache.returnMark[a] = mark + else + return status.cache.returnMark[a] + end + return false +end + +function m.searchSameFieldsInValue(status, ref, start, queue, mode) + local value = m.getObjectValue(ref) + if not value then + return 
+ end + if m.checkValueMark(status, ref, value) then + return + end + local newStatus = m.status(status) + m.searchRefs(newStatus, value, mode) + for _, res in ipairs(newStatus.results) do + queue[#queue+1] = { + obj = res, + start = start, + force = true, + } + end + queue[#queue+1] = { + obj = value, + start = start, + force = true, + } + -- 检查形如 a = f() 的分支情况 + m.checkSameSimpleInCall(status, value, start, queue, mode) +end + +function m.checkSameSimpleAsTableField(status, ref, start, queue) + if not status.deep then + --return + end + local parent = ref.parent + if not parent or parent.type ~= 'tablefield' then + return + end + if m.checkValueMark(status, parent, ref) then + return + end + local newStatus = m.status(status) + m.searchRefs(newStatus, parent.field, 'ref') + for _, res in ipairs(newStatus.results) do + queue[#queue+1] = { + obj = res, + start = start, + force = true, + } + end +end + +function m.checkSearchLevel(status) + status.cache.back = status.cache.back or 0 + if status.cache.back >= (status.interface.searchLevel or 0) then + -- TODO 限制向前搜索的次数 + --return true + end + status.cache.back = status.cache.back + 1 + return false +end + +function m.checkSameSimpleAsReturn(status, ref, start, queue) + if not status.deep then + return + end + if not ref.parent or ref.parent.type ~= 'return' then + return + end + if ref.parent.parent.type ~= 'main' then + return + end + if m.checkSearchLevel(status) then + return + end + local newStatus = m.status(status) + m.searchRefsAsFunctionReturn(newStatus, ref, 'ref') + for _, res in ipairs(newStatus.results) do + if not m.checkCallMark(status, res) then + queue[#queue+1] = { + obj = res, + start = start, + force = true, + } + end + end +end + +function m.checkSameSimpleAsSetValue(status, ref, start, queue) + if ref.type == 'select' then + return + end + local parent = ref.parent + if not parent then + return + end + if m.getObjectValue(parent) ~= ref then + return + end + if m.checkValueMark(status, ref, 
parent) then + return + end + if m.checkSearchLevel(status) then + return + end + local obj + if parent.type == 'local' + or parent.type == 'setglobal' + or parent.type == 'setlocal' then + obj = parent + elseif parent.type == 'setfield' then + obj = parent.field + elseif parent.type == 'setmethod' then + obj = parent.method + end + if not obj then + return + end + local newStatus = m.status(status) + m.searchRefs(newStatus, obj, 'ref') + for _, res in ipairs(newStatus.results) do + queue[#queue+1] = { + obj = res, + start = start, + force = true, + } + end +end + +function m.checkSameSimpleInString(status, ref, start, queue, mode) + -- 特殊处理 ('xxx').xxx 的形式 + if ref.type ~= 'string' then + return + end + if not status.interface.docType then + return + end + if status.cache.searchingBindedDoc then + return + end + local newStatus = m.status(status) + local docs = status.interface.docType('string*') + local mark = {} + for i = 1, #docs do + local doc = docs[i] + m.searchFields(newStatus, doc) + end + for _, res in ipairs(newStatus.results) do + if mark[res] then + goto CONTINUE + end + mark[res] = true + queue[#queue+1] = { + obj = res, + start = start + 1, + } + ::CONTINUE:: + end + return true +end + +function m.pushResult(status, mode, ref, simple) + local results = status.results + if mode == 'def' then + if ref.type == 'setglobal' + or ref.type == 'setlocal' + or ref.type == 'local' then + results[#results+1] = ref + elseif ref.type == 'setfield' + or ref.type == 'tablefield' then + results[#results+1] = ref + elseif ref.type == 'setmethod' then + results[#results+1] = ref + elseif ref.type == 'setindex' + or ref.type == 'tableindex' then + results[#results+1] = ref + elseif ref.type == 'call' then + if ref.node.special == 'rawset' then + results[#results+1] = ref + end + elseif ref.type == 'function' then + results[#results+1] = ref + elseif ref.type == 'table' then + results[#results+1] = ref + elseif ref.type == 'doc.type.function' + or ref.type == 
'doc.class.name' + or ref.type == 'doc.field' then + results[#results+1] = ref + end + if ref.parent and ref.parent.type == 'return' then + if m.getParentFunction(ref) ~= m.getParentFunction(simple.node) then + results[#results+1] = ref + end + end + elseif mode == 'ref' then + if ref.type == 'setfield' + or ref.type == 'getfield' + or ref.type == 'tablefield' then + results[#results+1] = ref + elseif ref.type == 'setmethod' + or ref.type == 'getmethod' then + results[#results+1] = ref + elseif ref.type == 'setindex' + or ref.type == 'getindex' + or ref.type == 'tableindex' then + results[#results+1] = ref + elseif ref.type == 'setglobal' + or ref.type == 'getglobal' + or ref.type == 'local' + or ref.type == 'setlocal' + or ref.type == 'getlocal' then + results[#results+1] = ref + elseif ref.type == 'function' then + results[#results+1] = ref + elseif ref.type == 'table' then + results[#results+1] = ref + elseif ref.type == 'call' then + if ref.node.special == 'rawset' + or ref.node.special == 'rawget' then + results[#results+1] = ref + end + elseif ref.type == 'doc.type.function' + or ref.type == 'doc.class.name' + or ref.type == 'doc.field' then + results[#results+1] = ref + end + if ref.parent and ref.parent.type == 'return' then + results[#results+1] = ref + end + elseif mode == 'field' then + if ref.type == 'setfield' + or ref.type == 'getfield' + or ref.type == 'tablefield' then + results[#results+1] = ref + elseif ref.type == 'setmethod' + or ref.type == 'getmethod' then + results[#results+1] = ref + elseif ref.type == 'setindex' + or ref.type == 'getindex' + or ref.type == 'tableindex' then + results[#results+1] = ref + elseif ref.type == 'setglobal' + or ref.type == 'getglobal' then + results[#results+1] = ref + elseif ref.type == 'function' then + results[#results+1] = ref + elseif ref.type == 'table' then + results[#results+1] = ref + elseif ref.type == 'call' then + if ref.node.special == 'rawset' + or ref.node.special == 'rawget' then + 
results[#results+1] = ref + end + elseif ref.type == 'doc.type.function' + or ref.type == 'doc.class.name' + or ref.type == 'doc.field' then + results[#results+1] = ref + end + end +end + +function m.checkSameSimpleName(ref, sm) + if sm == '*' then + return true + end + if m.getSimpleName(ref) == sm then + return true + end + if ref.type == 'doc.type' + and ref.array == true then + return true + end + return false +end + +function m.checkSameSimple(status, simple, data, mode, queue) + local ref = data.obj + local start = data.start + local force = data.force + if start > #simple then + return + end + for i = start, #simple do + local sm = simple[i] + if not force and not m.checkSameSimpleName(ref, sm) then + return + end + force = false + local cmode = mode + if i < #simple then + cmode = 'ref' + end + -- 检查 doc + local skipInfer = m.checkSameSimpleByBindDocs(status, ref, i, queue, cmode) + or m.checkSameSimpleByDoc(status, ref, i, queue, cmode) + if not skipInfer then + -- 穿透 self:func 与 mt:func + m.searchSameFieldsCrossMethod(status, ref, i, queue) + -- 穿透赋值 + m.searchSameFieldsInValue(status, ref, i, queue, cmode) + -- 检查自己是字面量表的情况 + m.checkSameSimpleInValueOfTable(status, ref, i, queue) + -- 检查自己作为 setmetatable 第一个参数的情况 + m.checkSameSimpleInArg1OfSetMetaTable(status, ref, i, queue) + -- 检查自己作为 setmetatable 调用的情况 + m.checkSameSimpleInValueOfCallMetaTable(status, ref, i, queue) + -- 检查自己是特殊变量的分支的情况 + m.checkSameSimpleInSpecialBranch(status, ref, i, queue) + -- 检查自己是字面量字符串的分支情况 + m.checkSameSimpleInString(status, ref, i, queue, cmode) + if cmode == 'ref' then + -- 检查形如 { a = f } 的情况 + m.checkSameSimpleAsTableField(status, ref, i, queue) + -- 检查形如 return m 的情况 + m.checkSameSimpleAsReturn(status, ref, i, queue) + -- 检查形如 a = f 的情况 + m.checkSameSimpleAsSetValue(status, ref, i, queue) + end + end + if i == #simple then + break + end + ref = m.getNextRef(ref) + if not ref then + return + end + end + m.pushResult(status, mode, ref, simple) + local value = 
m.getObjectValue(ref) + if value then + m.pushResult(status, mode, value, simple) + end +end + +function m.searchSameFields(status, simple, mode) + local queue = {} + if simple.mode == 'global' then + -- 全局变量开头 + m.checkSameSimpleInGlobal(status, simple[1], simple.node, 1, queue) + elseif simple.mode == 'local' then + -- 局部变量开头 + queue[1] = { + obj = simple.node, + start = 1, + } + local refs = simple.node.ref + if refs then + for i = 1, #refs do + queue[#queue+1] = { + obj = refs[i], + start = 1, + } + end + end + else + queue[1] = { + obj = simple.node, + start = 1, + } + end + local max = 0 + for i = 1, 1e6 do + local data = queue[i] + if not data then + return + end + if not status.lock[data.obj] then + status.lock[data.obj] = true + max = max + 1 + status.cache.count = status.cache.count + 1 + m.checkSameSimple(status, simple, data, mode, queue) + if max >= 10000 then + logWarn('Queue too large!') + break + end + end + end +end + +function m.getCallerInSameFile(status, func) + -- 搜索所有所在函数的调用者 + local funcRefs = m.status(status) + m.searchRefOfValue(funcRefs, func) + + local calls = {} + if #funcRefs.results == 0 then + return calls + end + for _, res in ipairs(funcRefs.results) do + local call = res.parent + if call.type == 'call' then + calls[#calls+1] = call + end + end + return calls +end + +function m.getCallerCrossFiles(status, main) + if status.interface.link then + return status.interface.link(main.uri) + end + return {} +end + +function m.searchRefsAsFunctionReturn(status, obj, mode) + if mode == 'def' then + return + end + if m.checkReturnMark(status, obj, true) then + return + end + status.results[#status.results+1] = obj + -- 搜索所在函数 + local currentFunc = m.getParentFunction(obj) + local rtn = obj.parent + if rtn.type ~= 'return' then + return + end + -- 看看他是第几个返回值 + local index + for i = 1, #rtn do + if obj == rtn[i] then + index = i + break + end + end + if not index then + return + end + local calls + if currentFunc.type == 'main' then + calls = 
m.getCallerCrossFiles(status, currentFunc) + else + calls = m.getCallerInSameFile(status, currentFunc) + end + -- 搜索调用者的返回值 + if #calls == 0 then + return + end + local selects = {} + for i = 1, #calls do + local parent = calls[i].parent + if parent.type == 'select' and parent.index == index then + selects[#selects+1] = parent.parent + end + local extParent = calls[i].extParent + if extParent then + for j = 1, #extParent do + local ext = extParent[j] + if ext.type == 'select' and ext.index == index then + selects[#selects+1] = ext.parent + end + end + end + end + -- 搜索调用者的引用 + for i = 1, #selects do + m.searchRefs(status, selects[i], 'ref') + end +end + +function m.searchRefsAsFunctionSet(status, obj, mode) + local parent = obj.parent + if not parent then + return + end + if parent.type == 'local' + or parent.type == 'setlocal' + or parent.type == 'setglobal' + or parent.type == 'setfield' + or parent.type == 'setmethod' + or parent.type == 'tablefield' then + m.searchRefs(status, parent, mode) + elseif parent.type == 'setindex' + or parent.type == 'tableindex' then + if parent.index == obj then + m.searchRefs(status, parent, mode) + end + end +end + +function m.searchRefsAsFunction(status, obj, mode) + if obj.type ~= 'function' + and obj.type ~= 'table' then + return + end + m.searchRefsAsFunctionSet(status, obj, mode) + -- 检查自己作为返回函数时的引用 + m.searchRefsAsFunctionReturn(status, obj, mode) +end + +function m.cleanResults(results) + local mark = {} + for i = #results, 1, -1 do + local res = results[i] + if res.tag == 'self' + or mark[res] then + results[i] = results[#results] + results[#results] = nil + else + mark[res] = true + end + end +end + +--function m.getRefCache(status, obj, mode) +-- local cache = status.interface.cache and status.interface.cache() +-- if not cache then +-- return +-- end +-- if m.isGlobal(obj) then +-- obj = m.getKeyName(obj) +-- end +-- if not cache[mode] then +-- cache[mode] = {} +-- end +-- local sourceCache = cache[mode][obj] +-- if 
sourceCache then +-- return sourceCache +-- end +-- sourceCache = {} +-- cache[mode][obj] = sourceCache +-- return nil, function (results) +-- for i = 1, #results do +-- sourceCache[i] = results[i] +-- end +-- end +--end + +function m.getRefCache(status, obj, mode) + local cache, globalCache + if status.depth == 1 + and status.deep then + globalCache = status.interface.cache and status.interface.cache() or {} + end + cache = status.cache.refCache or {} + status.cache.refCache = cache + if m.isGlobal(obj) then + obj = m.getKeyName(obj) + end + if not cache[mode] then + cache[mode] = {} + end + if globalCache and not globalCache[mode] then + globalCache[mode] = {} + end + local sourceCache = globalCache and globalCache[mode][obj] or cache[mode][obj] + if sourceCache then + return sourceCache + end + sourceCache = {} + cache[mode][obj] = sourceCache + if globalCache then + globalCache[mode][obj] = sourceCache + end + return nil, function (results) + for i = 1, #results do + sourceCache[i] = results[i] + end + end +end + +function m.searchRefs(status, obj, mode) + local cache, makeCache = m.getRefCache(status, obj, mode) + if cache then + for i = 1, #cache do + status.results[#status.results+1] = cache[i] + end + return + end + + -- 检查单步引用 + local res = m.getStepRef(status, obj, mode) + if res then + for i = 1, #res do + status.results[#status.results+1] = res[i] + end + end + -- 检查simple + if status.depth <= 100 then + local simple = m.getSimple(obj) + if simple then + m.searchSameFields(status, simple, mode) + end + else + if m.debugMode then + error('status.depth overflow') + elseif DEVELOP then + --log.warn(debug.traceback('status.depth overflow')) + logWarn('status.depth overflow') + end + end + + m.cleanResults(status.results) + + if makeCache then + makeCache(status.results) + end +end + +function m.searchRefOfValue(status, obj) + local var = obj.parent + if var.type == 'local' + or var.type == 'set' then + return m.searchRefs(status, var, 'ref') + end +end + 
+function m.allocInfer(o) + if type(o.type) == 'table' then + local infers = {} + for i = 1, #o.type do + infers[i] = { + type = o.type[i], + value = o.value, + source = o.source, + } + end + return infers + else + return { + [1] = o, + } + end +end + +function m.mergeTypes(types) + local results = {} + local mark = {} + local hasAny + -- 这里把 any 去掉 + for i = 1, #types do + local tp = types[i] + if tp == 'any' then + hasAny = true + end + if not mark[tp] and tp ~= 'any' then + mark[tp] = true + results[#results+1] = tp + end + end + if #results == 0 then + return 'any' + end + -- 只有显性的 nil 与 any 时,取 any + if #results == 1 then + if results[1] == 'nil' and hasAny then + return 'any' + else + return results[1] + end + end + -- 同时包含 number 与 integer 时,去掉 integer + if mark['number'] and mark['integer'] then + for i = 1, #results do + if results[i] == 'integer' then + tableRemove(results, i) + break + end + end + end + tableSort(results, function (a, b) + local sa = TypeSort[a] or 100 + local sb = TypeSort[b] or 100 + return sa < sb + end) + return tableConcat(results, '|') +end + +function m.viewInferType(infers) + if not infers then + return 'any' + end + local mark = {} + local types = {} + local hasDoc + for i = 1, #infers do + local infer = infers[i] + local src = infer.source + if src.type == 'doc.class' + or src.type == 'doc.class.name' + or src.type == 'doc.type.name' + or src.type == 'doc.type.array' + or src.type == 'doc.type.generic' then + if infer.type ~= 'any' then + hasDoc = true + break + end + end + end + if hasDoc then + for i = 1, #infers do + local infer = infers[i] + local src = infer.source + if src.type == 'doc.class' + or src.type == 'doc.class.name' + or src.type == 'doc.type.name' + or src.type == 'doc.type.array' + or src.type == 'doc.type.generic' + or src.type == 'doc.type.enum' + or src.type == 'doc.resume' then + local tp = infer.type or 'any' + if not mark[tp] then + types[#types+1] = tp + end + mark[tp] = true + end + end + else + for i 
= 1, #infers do + local tp = infers[i].type or 'any' + if not mark[tp] then + types[#types+1] = tp + end + mark[tp] = true + end + end + return m.mergeTypes(types) +end + +function m.checkTrue(status, source) + local newStatus = m.status(status) + m.searchInfer(newStatus, source) + -- The result assumed so far + local current + for _, infer in ipairs(newStatus.results) do + -- The result from this infer + local new + if infer.type == 'nil' then + new = false + elseif infer.type == 'boolean' then + if infer.value == true then + new = true + elseif infer.value == false then + new = false + end + end + if new ~= nil then + if current == nil then + current = new + else + -- If two results contradict each other, return nil to signal "undetermined" + if new ~= current then + return nil + end + end + end + end + return current +end + +--- Get the literal value of a specific type +function m.getInferLiteral(status, source, type) + local newStatus = m.status(status) + m.searchInfer(newStatus, source) + for _, infer in ipairs(newStatus.results) do + if infer.value ~= nil then + if type == nil or infer.type == type then + return infer.value + end + end + end + return nil +end + +--- Whether the inferred results of `source` contain the given type. +--- The inference runs on a child status (like `m.checkTrue` and +--- `m.getInferLiteral`), so the caller's `status.results` is neither +--- modified nor scanned. +---@return boolean +function m.hasType(status, source, type) + local newStatus = m.status(status) + m.searchInfer(newStatus, source) + for _, infer in ipairs(newStatus.results) do + if infer.type == type then + return true + end + end + return false +end + +function m.isSameValue(status, a, b) + local statusA = m.status(status) + m.searchInfer(statusA, a) + local statusB = m.status(status) + m.searchInfer(statusB, b) + local infers = {} + for _, infer in ipairs(statusA.results) do + local literal = infer.value + if literal then + infers[literal] = false + end + end + for _, infer in ipairs(statusB.results) do + local literal = infer.value + if literal then + if infers[literal] == nil then + return false + end + infers[literal] = true + end + end + for k, v in pairs(infers) do + if v == false then + return false + end + end + return true +end + +function m.inferCheckLiteralTableWithDocVararg(status, source) + if #source ~= 1 then + return + end + local vararg = 
source[1] + if vararg.type ~= 'varargs' then + return + end + local results = m.getVarargDocType(status, source) + status.results[#status.results+1] = { + type = m.viewInferType(results) .. '[]', + source = source, + } + return true +end + +function m.inferCheckLiteral(status, source) + if source.type == 'string' then + status.results = m.allocInfer { + type = 'string', + value = source[1], + source = source, + } + return true + elseif source.type == 'nil' then + status.results = m.allocInfer { + type = 'nil', + value = NIL, + source = source, + } + return true + elseif source.type == 'boolean' then + status.results = m.allocInfer { + type = 'boolean', + value = source[1], + source = source, + } + return true + elseif source.type == 'number' then + if mathType(source[1]) == 'integer' then + status.results = m.allocInfer { + type = 'integer', + value = source[1], + source = source, + } + return true + else + status.results = m.allocInfer { + type = 'number', + value = source[1], + source = source, + } + return true + end + elseif source.type == 'integer' then + status.results = m.allocInfer { + type = 'integer', + source = source, + } + return true + elseif source.type == 'table' then + if m.inferCheckLiteralTableWithDocVararg(status, source) then + return true + end + status.results = m.allocInfer { + type = 'table', + source = source, + } + return true + elseif source.type == 'function' then + status.results = m.allocInfer { + type = 'function', + source = source, + } + return true + elseif source.type == '...' 
then + status.results = m.allocInfer { + type = '...', + source = source, + } + return true + end +end + +local function getDocAliasExtends(status, name) + if not status.interface.docType then + return nil + end + for _, doc in ipairs(status.interface.docType(name)) do + if doc.type == 'doc.alias.name' then + return m.viewInferType(m.getDocTypeNames(status, doc.parent.extends)) + end + end + return nil +end + +local function getDocTypeUnitName(status, unit, genericCallback) + local typeName + if unit.type == 'doc.type.name' then + typeName = getDocAliasExtends(status, unit[1]) or unit[1] + elseif unit.type == 'doc.type.function' then + typeName = 'function' + elseif unit.type == 'doc.type.array' then + typeName = getDocTypeUnitName(status, unit.node, genericCallback) .. '[]' + elseif unit.type == 'doc.type.generic' then + typeName = ('%s<%s, %s>'):format( + getDocTypeUnitName(status, unit.node, genericCallback), + m.viewInferType(m.getDocTypeNames(status, unit.key, genericCallback)), + m.viewInferType(m.getDocTypeNames(status, unit.value, genericCallback)) + ) + end + if unit.typeGeneric then + if genericCallback then + typeName = genericCallback(typeName, unit) + or ('<%s>'):format(typeName) + else + typeName = ('<%s>'):format(typeName) + end + end + return typeName +end + +function m.getDocTypeNames(status, doc, genericCallback) + local results = {} + if not doc then + return results + end + for _, unit in ipairs(doc.types) do + local typeName = getDocTypeUnitName(status, unit, genericCallback) + results[#results+1] = { + type = typeName, + source = unit, + } + end + for _, enum in ipairs(doc.enums) do + results[#results+1] = { + type = enum[1], + source = enum, + } + end + for _, resume in ipairs(doc.resumes) do + if not resume.additional then + results[#results+1] = { + type = resume[1], + source = resume, + } + end + end + return results +end + +function m.inferCheckDoc(status, source) + if source.type == 'doc.class.name' then + 
status.results[#status.results+1] = { + type = source[1], + source = source, + } + return true + end + if source.type == 'doc.class' then + status.results[#status.results+1] = { + type = source.class[1], + source = source, + } + return true + end + if source.type == 'doc.type' then + local results = m.getDocTypeNames(status, source) + for _, res in ipairs(results) do + status.results[#status.results+1] = res + end + return true + end + if source.type == 'doc.field' then + local results = m.getDocTypeNames(status, source.extends) + for _, res in ipairs(results) do + status.results[#status.results+1] = res + end + return true + end +end + +function m.getVarargDocType(status, source) + local func = m.getParentFunction(source) + if not func then + return + end + if not func.args then + return + end + for _, arg in ipairs(func.args) do + if arg.type == '...' then + if arg.bindDocs then + for _, doc in ipairs(arg.bindDocs) do + if doc.type == 'doc.vararg' then + return m.getDocTypeNames(status, doc.vararg) + end + end + end + end + end +end + +function m.inferCheckUpDocOfVararg(status, source) + if not source.vararg then + return + end + local results = m.getVarargDocType(status, source) + if not results then + return + end + for _, res in ipairs(results) do + status.results[#status.results+1] = res + end + return true +end + +function m.inferCheckUpDoc(status, source) + if m.inferCheckUpDocOfVararg(status, source) then + return true + end + local parent = source.parent + if parent then + if parent.type == 'local' + or parent.type == 'setlocal' + or parent.type == 'setglobal' then + source = parent + end + if parent.type == 'setfield' + or parent.type == 'tablefield' then + if parent.field == source + or parent.value == source then + source = parent + end + end + if parent.type == 'setmethod' then + if parent.method == source + or parent.value == source then + source = parent + end + end + if parent.type == 'setindex' + or parent.type == 'tableindex' then + if 
parent.index == source + or parent.value == source then + source = parent + end + end + end + local binds = source.bindDocs + if not binds then + return + end + status.results = {} + for _, doc in ipairs(binds) do + if doc.type == 'doc.class' then + status.results[#status.results+1] = { + type = doc.class[1], + source = doc, + } + -- ---@class Class + -- local x = { field = 1 } + -- 这种情况下,将字面量表接受为Class的定义 + if source.value and source.value.type == 'table' then + status.results[#status.results+1] = { + type = source.value.type, + source = source.value, + } + end + return true + elseif doc.type == 'doc.type' then + local results = m.getDocTypeNames(status, doc) + for _, res in ipairs(results) do + status.results[#status.results+1] = res + end + return true + elseif doc.type == 'doc.param' then + -- function (x) 的情况 + if source.type == 'local' + and m.getName(source) == doc.param[1] then + if source.parent.type == 'funcargs' + or source.parent.type == 'in' + or source.parent.type == 'loop' then + local results = m.getDocTypeNames(status, doc.extends) + for _, res in ipairs(results) do + status.results[#status.results+1] = res + end + return true + end + end + end + end +end + +function m.inferCheckFieldDoc(status, source) + -- 检查 string[] 的情况 + if source.type == 'getindex' then + local node = source.node + if not node then + return + end + local newStatus = m.status(status) + m.searchInfer(newStatus, node) + local ok + for _, infer in ipairs(newStatus.results) do + local src = infer.source + if src.type == 'doc.type.array' then + ok = true + status.results[#status.results+1] = { + type = infer.type:gsub('%[%]$', ''), + source = src.node, + } + end + end + return ok + end +end + +function m.inferCheckUnary(status, source) + if source.type ~= 'unary' then + return + end + local op = source.op + if op.type == 'not' then + local checkTrue = m.checkTrue(status, source[1]) + local value = nil + if checkTrue == true then + value = false + elseif checkTrue == false then + 
value = true + end + status.results = m.allocInfer { + type = 'boolean', + value = value, + source = source, + } + return true + elseif op.type == '#' then + status.results = m.allocInfer { + type = 'integer', + source = source, + } + return true + elseif op.type == '~' then + local l = m.getInferLiteral(status, source[1], 'integer') + status.results = m.allocInfer { + type = 'integer', + value = l and ~l or nil, + source = source, + } + return true + elseif op.type == '-' then + local v = m.getInferLiteral(status, source[1], 'integer') + if v then + status.results = m.allocInfer { + type = 'integer', + value = - v, + source = source, + } + return true + end + v = m.getInferLiteral(status, source[1], 'number') + status.results = m.allocInfer { + type = 'number', + value = v and -v or nil, + source = source, + } + return true + end +end + +--- Looks up the literal values of operands `a` and `b` and picks a numeric type name. +--- Returns 'integer' only when both operands have the 'integer' type and +--- neither has the 'number' type; otherwise returns 'number'. +---@return string typeName, any v1, any v2 +local function mathCheck(status, a, b) + local v1 = m.getInferLiteral(status, a, 'integer') + or m.getInferLiteral(status, a, 'number') + -- both lookups for v2 must read operand `b` + local v2 = m.getInferLiteral(status, b, 'integer') + or m.getInferLiteral(status, b, 'number') + local int = m.hasType(status, a, 'integer') + and m.hasType(status, b, 'integer') + and not m.hasType(status, a, 'number') + and not m.hasType(status, b, 'number') + return int and 'integer' or 'number', v1, v2 +end + +function m.inferCheckBinary(status, source) + if source.type ~= 'binary' then + return + end + local op = source.op + if op.type == 'and' then + local isTrue = m.checkTrue(status, source[1]) + if isTrue == true then + m.searchInfer(status, source[2]) + return true + elseif isTrue == false then + m.searchInfer(status, source[1]) + return true + else + m.searchInfer(status, source[1]) + m.searchInfer(status, source[2]) + return true + end + elseif op.type == 'or' then + local isTrue = m.checkTrue(status, source[1]) + if isTrue == true then + m.searchInfer(status, source[1]) + return true + elseif isTrue == false then + m.searchInfer(status, source[2]) + return true + else + 
m.searchInfer(status, source[1]) + m.searchInfer(status, source[2]) + return true + end + elseif op.type == '==' then + local value = m.isSameValue(status, source[1], source[2]) + if value ~= nil then + status.results = m.allocInfer { + type = 'boolean', + value = value, + source = source, + } + return true + end + --local isSame = m.isSameDef(status, source[1], source[2]) + --if isSame == true then + -- value = true + --else + -- value = nil + --end + status.results = m.allocInfer { + type = 'boolean', + value = value, + source = source, + } + return true + elseif op.type == '~=' then + local value = m.isSameValue(status, source[1], source[2]) + if value ~= nil then + status.results = m.allocInfer { + type = 'boolean', + value = not value, + source = source, + } + return true + end + --local isSame = m.isSameDef(status, source[1], source[2]) + --if isSame == true then + -- value = false + --else + -- value = nil + --end + status.results = m.allocInfer { + type = 'boolean', + value = value, + source = source, + } + return true + elseif op.type == '<=' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + or m.getInferLiteral(status, source[1], 'number') + local v2 = m.getInferLiteral(status, source[2], 'integer') + or m.getInferLiteral(status, source[2], 'number') + local v + if v1 and v2 then + v = v1 <= v2 + end + status.results = m.allocInfer { + type = 'boolean', + value = v, + source = source, + } + return true + elseif op.type == '>=' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + or m.getInferLiteral(status, source[1], 'number') + local v2 = m.getInferLiteral(status, source[2], 'integer') + or m.getInferLiteral(status, source[2], 'number') + local v + if v1 and v2 then + v = v1 >= v2 + end + status.results = m.allocInfer { + type = 'boolean', + value = v, + source = source, + } + return true + elseif op.type == '<' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + or m.getInferLiteral(status, source[1], 
'number') + local v2 = m.getInferLiteral(status, source[2], 'integer') + or m.getInferLiteral(status, source[2], 'number') + local v + if v1 and v2 then + v = v1 < v2 + end + status.results = m.allocInfer { + type = 'boolean', + value = v, + source = source, + } + return true + elseif op.type == '>' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + or m.getInferLiteral(status, source[1], 'number') + local v2 = m.getInferLiteral(status, source[2], 'integer') + or m.getInferLiteral(status, source[2], 'number') + local v + if v1 and v2 then + v = v1 > v2 + end + status.results = m.allocInfer { + type = 'boolean', + value = v, + source = source, + } + return true + elseif op.type == '|' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + local v2 = m.getInferLiteral(status, source[2], 'integer') + local v + if v1 and v2 then + v = v1 | v2 + end + status.results = m.allocInfer { + type = 'integer', + value = v, + source = source, + } + return true + elseif op.type == '~' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + local v2 = m.getInferLiteral(status, source[2], 'integer') + local v + if v1 and v2 then + v = v1 ~ v2 + end + status.results = m.allocInfer { + type = 'integer', + value = v, + source = source, + } + return true + elseif op.type == '&' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + local v2 = m.getInferLiteral(status, source[2], 'integer') + local v + if v1 and v2 then + v = v1 & v2 + end + status.results = m.allocInfer { + type = 'integer', + value = v, + source = source, + } + return true + elseif op.type == '<<' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + local v2 = m.getInferLiteral(status, source[2], 'integer') + local v + if v1 and v2 then + v = v1 << v2 + end + status.results = m.allocInfer { + type = 'integer', + value = v, + source = source, + } + return true + elseif op.type == '>>' then + local v1 = m.getInferLiteral(status, source[1], 'integer') 
+ local v2 = m.getInferLiteral(status, source[2], 'integer') + local v + if v1 and v2 then + v = v1 >> v2 + end + status.results = m.allocInfer { + type = 'integer', + value = v, + source = source, + } + return true + elseif op.type == '..' then + local v1 = m.getInferLiteral(status, source[1], 'string') + local v2 = m.getInferLiteral(status, source[2], 'string') + local v + if v1 and v2 then + v = v1 .. v2 + end + status.results = m.allocInfer { + type = 'string', + value = v, + source = source, + } + return true + elseif op.type == '^' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + or m.getInferLiteral(status, source[1], 'number') + local v2 = m.getInferLiteral(status, source[2], 'integer') + or m.getInferLiteral(status, source[2], 'number') + local v + if v1 and v2 then + v = v1 ^ v2 + end + status.results = m.allocInfer { + type = 'number', + value = v, + source = source, + } + return true + elseif op.type == '/' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + or m.getInferLiteral(status, source[1], 'number') + local v2 = m.getInferLiteral(status, source[2], 'integer') + or m.getInferLiteral(status, source[2], 'number') + local v + if v1 and v2 then + v = v1 > v2 + end + status.results = m.allocInfer { + type = 'number', + value = v, + source = source, + } + return true + -- 其他数学运算根据2侧的值决定,当2侧的值均为整数时返回整数 + elseif op.type == '+' then + local int, v1, v2 = mathCheck(status, source[1], source[2]) + status.results = m.allocInfer{ + type = int, + value = (v1 and v2) and (v1 + v2) or nil, + source = source, + } + return true + elseif op.type == '-' then + local int, v1, v2 = mathCheck(status, source[1], source[2]) + status.results = m.allocInfer{ + type = int, + value = (v1 and v2) and (v1 - v2) or nil, + source = source, + } + return true + elseif op.type == '*' then + local int, v1, v2 = mathCheck(status, source[1], source[2]) + status.results = m.allocInfer { + type = int, + value = (v1 and v2) and (v1 * v2) or nil, + 
source = source, + } + return true + elseif op.type == '%' then + local int, v1, v2 = mathCheck(status, source[1], source[2]) + status.results = m.allocInfer { + type = int, + value = (v1 and v2) and (v1 % v2) or nil, + source = source, + } + return true + elseif op.type == '//' then + local int, v1, v2 = mathCheck(status, source[1], source[2]) + status.results = m.allocInfer { + type = int, + value = (v1 and v2) and (v1 // v2) or nil, + source = source, + } + return true + end +end + +function m.inferByDef(status, obj) + if not status.cache.inferedDef then + status.cache.inferedDef = {} + end + if status.cache.inferedDef[obj] then + return + end + status.cache.inferedDef[obj] = true + local mark = {} + local newStatus = m.status(status, status.interface) + m.searchRefs(newStatus, obj, 'def') + for _, src in ipairs(newStatus.results) do + local inferStatus = m.status(newStatus) + m.searchInfer(inferStatus, src) + for _, infer in ipairs(inferStatus.results) do + if not mark[infer.source] then + mark[infer.source] = true + status.results[#status.results+1] = infer + end + end + end +end + +local function inferBySetOfLocal(status, source) + if status.cache[source] then + return + end + status.cache[source] = true + local newStatus = m.status(status) + if source.value then + m.searchInfer(newStatus, source.value) + end + if source.ref then + for _, ref in ipairs(source.ref) do + if ref.type == 'setlocal' then + break + end + m.searchInfer(newStatus, ref) + end + for _, infer in ipairs(newStatus.results) do + status.results[#status.results+1] = infer + end + end +end + +function m.inferBySet(status, source) + if #status.results ~= 0 then + return + end + if source.type == 'local' then + inferBySetOfLocal(status, source) + elseif source.type == 'setlocal' + or source.type == 'getlocal' then + inferBySetOfLocal(status, source.node) + end +end + +function m.inferByCall(status, source) + if #status.results ~= 0 then + return + end + if not source.parent then + return + end 
+ if source.parent.type ~= 'call' then + return + end + if source.parent.node == source then + status.results[#status.results+1] = { + type = 'function', + source = source, + } + return + end +end + +function m.inferByGetTable(status, source) + if #status.results ~= 0 then + return + end + if source.type == 'field' + or source.type == 'method' then + source = source.parent + end + local next = source.next + if not next then + return + end + if next.type == 'getfield' + or next.type == 'getindex' + or next.type == 'getmethod' + or next.type == 'setfield' + or next.type == 'setindex' + or next.type == 'setmethod' then + status.results[#status.results+1] = { + type = 'table', + source = source, + } + end +end + +function m.inferByUnary(status, source) + if #status.results ~= 0 then + return + end + local parent = source.parent + if not parent or parent.type ~= 'unary' then + return + end + local op = parent.op + if op.type == '#' then + status.results[#status.results+1] = { + type = 'string', + source = source + } + status.results[#status.results+1] = { + type = 'table', + source = source + } + elseif op.type == '~' then + status.results[#status.results+1] = { + type = 'integer', + source = source + } + elseif op.type == '-' then + status.results[#status.results+1] = { + type = 'number', + source = source + } + end +end + +function m.inferByBinary(status, source) + if #status.results ~= 0 then + return + end + local parent = source.parent + if not parent or parent.type ~= 'binary' then + return + end + local op = parent.op + if op.type == '<=' + or op.type == '>=' + or op.type == '<' + or op.type == '>' + or op.type == '^' + or op.type == '/' + or op.type == '+' + or op.type == '-' + or op.type == '*' + or op.type == '%' then + status.results[#status.results+1] = { + type = 'number', + source = source, + } + elseif op.type == '|' + or op.type == '~' + or op.type == '&' + or op.type == '<<' + or op.type == '>>' + -- 整数的可能性比较高 + or op.type == '//' then + 
status.results[#status.results+1] = { + type = 'integer', + source = source, + } + elseif op.type == '..' then + status.results[#status.results+1] = { + type = 'string', + source = source, + } + end +end + +local function mergeFunctionReturnsByDoc(status, source, index, call) + if not source or source.type ~= 'function' then + return + end + if not source.bindDocs then + return + end + local returns = {} + for _, doc in ipairs(source.bindDocs) do + if doc.type == 'doc.return' then + for _, rtn in ipairs(doc.returns) do + returns[#returns+1] = rtn + end + end + end + local rtn = returns[index] + if not rtn then + return + end + local results = m.getDocTypeNames(status, rtn, function (typeName, typeUnit) + if not source.args or not call.args then + return + end + local name = typeUnit[1] + local generics = typeUnit.typeGeneric[name] + if not generics then + return + end + local first = generics[1] + if not first or first == typeUnit then + return + end + local docParam = m.getParentType(first, 'doc.param') + local paramName = docParam.param[1] + for i, arg in ipairs(source.args) do + if arg[1] == paramName then + local callArg = call.args[i] + if not callArg then + return + end + return m.viewInferType(m.searchInfer(status, callArg)) + end + end + end) + if #results == 0 then + return + end + for _, res in ipairs(results) do + status.results[#status.results+1] = res + end + return true +end + +local function mergeDocTypeFunctionReturns(status, source, index) + if not source.bindDocs then + return + end + for _, doc in ipairs(source.bindDocs) do + if doc.type == 'doc.type' then + for _, typeUnit in ipairs(doc.types) do + if typeUnit.type == 'doc.type.function' then + local rtn = typeUnit.returns[index] + if rtn then + local results = m.getDocTypeNames(status, rtn) + for _, res in ipairs(results) do + status.results[#status.results+1] = res + end + end + end + end + end + end +end + +local function mergeFunctionReturns(status, source, index, call) + local returns = 
-- Remainder of mergeFunctionReturns: its header and `local returns =` precede this span.
    source.returns
    if not returns then
        return
    end
    for i = 1, #returns do
        local rtn = returns[i]
        if rtn[index] then
            if rtn[index].type == 'call' then
                -- Mark the nested call before descending into it, so the same
                -- call is not followed again.
                if not m.checkReturnMark(status, rtn[index]) then
                    m.checkReturnMark(status, rtn[index], true)
                    m.inferByCallReturnAndIndex(status, rtn[index], index)
                end
            else
                local newStatus = m.status(status)
                m.searchInfer(newStatus, rtn[index])
                if #newStatus.results == 0 then
                    -- Nothing could be inferred for this return expression.
                    status.results[#status.results+1] = {
                        type   = 'any',
                        source = rtn[index],
                    }
                else
                    for _, infer in ipairs(newStatus.results) do
                        status.results[#status.results+1] = infer
                    end
                end
            end
        end
    end
end

--- Infer the `index`-th result of `call` from the definitions of its callee.
--- Doc-based returns are tried first; the function bodies' `return`
--- statements are used only when no doc-based result was reported.
function m.inferByCallReturnAndIndex(status, call, index)
    local node = call.node
    local newStatus = m.status(nil, status.interface)
    m.searchRefs(newStatus, node, 'def')
    local hasDocReturn
    for _, src in ipairs(newStatus.results) do
        if mergeDocTypeFunctionReturns(status, src, index) then
            hasDocReturn = true
        elseif mergeFunctionReturnsByDoc(status, src.value, index, call) then
            hasDocReturn = true
        end
    end
    if not hasDocReturn then
        for _, src in ipairs(newStatus.results) do
            if src.value and src.value.type == 'function' then
                if not m.checkReturnMark(status, src.value, true) then
                    mergeFunctionReturns(status, src.value, index, call)
                end
            end
        end
    end
end

--- Infer `source` when it is a call (first result) or a `select` over a call
--- (the selected result), either directly or through `source.value`.
function m.inferByCallReturn(status, source)
    if source.type == 'call' then
        m.inferByCallReturnAndIndex(status, source, 1)
        return
    end
    if source.type ~= 'select' then
        if source.value and source.value.type == 'select' then
            source = source.value
        else
            return
        end
    end
    if not source.vararg or source.vararg.type ~= 'call' then
        return
    end
    m.inferByCallReturnAndIndex(status, source.vararg, source.index)
end

--- Infer a value selected from the results of `pcall(f, ...)` or
--- `xpcall(f, msgh, ...)` from the return statements of `f`.
--- Both functions return a status first and then the results of `f`, so the
--- k-th selected value maps to the (k-1)-th return value of `f`.
function m.inferByPCallReturn(status, source)
    if source.type ~= 'select' then
        if source.value and source.value.type == 'select' then
            source = source.value
        else
            return
        end
    end
    local call = source.vararg
    if not call or call.type ~= 'call' then
        return
    end
    local specialName = call.node.special
    if specialName ~= 'pcall' and specialName ~= 'xpcall' then
        return
    end
    -- The protected function is the first argument of both pcall and xpcall;
    -- a call written without arguments has nothing to infer from.
    local func = call.args and call.args[1]
    if not func then
        return
    end
    -- Fixed: xpcall used `source.index - 2`, which shifted every result by one.
    local index = source.index - 1
    if index < 1 then
        -- Index 1 is the status flag, not a return value of `func`.
        return
    end
    local newStatus = m.status(nil, status.interface)
    m.searchRefs(newStatus, func, 'def')
    for _, src in ipairs(newStatus.results) do
        if src.value and src.value.type == 'function' then
            mergeFunctionReturns(status, src.value, index)
        end
    end
end

--- Remove duplicate infers in place; two infers are duplicates when they share
--- both `type` and `source`. The order of the remaining entries may change.
function m.cleanInfers(infers)
    local mark = {}
    for i = #infers, 1, -1 do
        local infer = infers[i]
        local key = ('%s|%p'):format(infer.type, infer.source)
        if mark[key] then
            -- Entries after `i` were already checked, so the last one can
            -- safely take the duplicate's slot.
            infers[i] = infers[#infers]
            infers[#infers] = nil
        else
            mark[key] = true
        end
    end
end

-- Head of m.searchInfer: the remaining fallback inferences follow this span.
function m.searchInfer(status, obj)
    -- Unwrap parentheses.
    while obj.type == 'paren' do
        obj = obj.exp
        if not obj then
            return
        end
    end
    -- Follow the object to its value until it no longer changes.
    while true do
        local value = m.getObjectValue(obj)
        if not value or value == obj then
            break
        end
        obj = value
    end

    local cache, makeCache = m.getRefCache(status, obj, 'infer')
    if cache then
        for i = 1, #cache do
            status.results[#status.results+1] = cache[i]
        end
        return
    end

    if DEVELOP then
        status.cache.clock = status.cache.clock or osClock()
    end

    -- Guard against re-entering the same object during one search.
    if not status.cache.lockInfer then
        status.cache.lockInfer = {}
    end
    if status.cache.lockInfer[obj] then
        return
    end
    status.cache.lockInfer[obj] = true

    -- Direct checks: the first one that succeeds decides the result.
    local checked = m.inferCheckDoc(status, obj)
                 or m.inferCheckUpDoc(status, obj)
                 or m.inferCheckFieldDoc(status, obj)
                 or m.inferCheckLiteral(status, obj)
                 or m.inferCheckUnary(status, obj)
                 or m.inferCheckBinary(status, obj)
    if checked then
        m.cleanInfers(status.results)
        if makeCache then
            makeCache(status.results)
        end
        return
    end

    if status.deep then
        m.inferByDef(status, obj)
    end
    m.inferBySet(status, obj)
m.inferByCall(status, obj) + m.inferByGetTable(status, obj) + m.inferByUnary(status, obj) + m.inferByBinary(status, obj) + m.inferByCallReturn(status, obj) + m.inferByPCallReturn(status, obj) + m.cleanInfers(status.results) + if makeCache then + makeCache(status.results) + end +end + +--- 请求对象的引用,包括 `a.b.c` 形式 +--- 与 `return function` 形式。 +--- 不穿透 `setmetatable` ,考虑由 +--- 业务层进行反向 def 搜索。 +function m.requestReference(obj, interface, deep) + local status = m.status(nil, interface) + status.deep = deep + -- 根据 field 搜索引用 + m.searchRefs(status, obj, 'ref') + + m.searchRefsAsFunction(status, obj, 'ref') + + if m.debugMode then + print('count:', status.cache.count) + end + + return status.results, status.cache.count +end + +--- 请求对象的定义,包括 `a.b.c` 形式 +--- 与 `return function` 形式。 +--- 穿透 `setmetatable` 。 +function m.requestDefinition(obj, interface, deep) + local status = m.status(nil, interface) + status.deep = deep + -- 根据 field 搜索定义 + m.searchRefs(status, obj, 'def') + + return status.results, status.cache.count +end + +--- 请求对象的域 +function m.requestFields(obj, interface, deep) + local status = m.status(nil, interface) + status.deep = deep + + m.searchFields(status, obj) + + return status.results, status.cache.count +end + +--- 请求对象的类型推测 +function m.requestInfer(obj, interface, deep) + local status = m.status(nil, interface) + status.deep = deep + m.searchInfer(status, obj) + + return status.results, status.cache.count +end + +return m diff --git a/script/parser/init.lua b/script/parser/init.lua new file mode 100644 index 00000000..ba40d145 --- /dev/null +++ b/script/parser/init.lua @@ -0,0 +1,12 @@ +local api = { + grammar = require 'parser.grammar', + parse = require 'parser.parse', + compile = require 'parser.compile', + split = require 'parser.split', + calcline = require 'parser.calcline', + lines = require 'parser.lines', + guide = require 'parser.guide', + luadoc = require 'parser.luadoc', +} + +return api diff --git a/script/parser/lines.lua 
b/script/parser/lines.lua new file mode 100644 index 00000000..ee6b4f41 --- /dev/null +++ b/script/parser/lines.lua @@ -0,0 +1,45 @@ +local m = require 'lpeglabel' + +_ENV = nil + +local function Line(start, line, range, finish) + line.start = start + line.finish = finish - 1 + line.range = range - 1 + return line +end + +local function Space(...) + local line = {...} + local sp = 0 + local tab = 0 + for i = 1, #line do + if line[i] == ' ' then + sp = sp + 1 + elseif line[i] == '\t' then + tab = tab + 1 + end + line[i] = nil + end + line.sp = sp + line.tab = tab + return line +end + +local parser = m.P{ +'Lines', +Lines = m.Ct(m.V'Line'^0 * m.V'LastLine'), +Line = m.Cp() * m.V'Indent' * (1 - m.V'Nl')^0 * m.Cp() * m.V'Nl' * m.Cp() / Line, +LastLine= m.Cp() * m.V'Indent' * (1 - m.V'Nl')^0 * m.Cp() * m.Cp() / Line, +Nl = m.P'\r\n' + m.S'\r\n', +Indent = m.C(m.S' \t')^0 / Space, +} + +return function (self, text) + local lines, err = parser:match(text) + if not lines then + return nil, err + end + + return lines +end diff --git a/script/parser/luadoc.lua b/script/parser/luadoc.lua new file mode 100644 index 00000000..b31c4baf --- /dev/null +++ b/script/parser/luadoc.lua @@ -0,0 +1,991 @@ +local m = require 'lpeglabel' +local re = require 'parser.relabel' +local lines = require 'parser.lines' +local guide = require 'parser.guide' + +local TokenTypes, TokenStarts, TokenFinishs, TokenContents +local Ci, Offset, pushError, Ct, NextComment +local parseType +local Parser = re.compile([[ +Main <- (Token / Sp)* +Sp <- %s+ +X16 <- [a-fA-F0-9] +Word <- [a-zA-Z0-9_] +Token <- Name / String / Symbol +Name <- ({} {[a-zA-Z0-9_] [a-zA-Z0-9_.*]*} {}) + -> Name +String <- ({} StringDef {}) + -> String +StringDef <- '"' + {~(Esc / !'"' .)*~} -> 1 + ('"'?) + / "'" + {~(Esc / !"'" .)*~} -> 1 + ("'"?) 
+ / ('[' {:eq: '='* :} '[' + {(!StringClose .)*} -> 1 + (StringClose?)) +StringClose <- ']' =eq ']' +Esc <- '\' -> '' + EChar +EChar <- 'a' -> ea + / 'b' -> eb + / 'f' -> ef + / 'n' -> en + / 'r' -> er + / 't' -> et + / 'v' -> ev + / '\' + / '"' + / "'" + / %nl + / ('z' (%nl / %s)*) -> '' + / ('x' {X16 X16}) -> Char16 + / ([0-9] [0-9]? [0-9]?) -> Char10 + / ('u{' {Word*} '}') -> CharUtf8 +Symbol <- ({} { + ':' + / '|' + / ',' + / '[]' + / '<' + / '>' + / '(' + / ')' + / '?' + / '...' + / '+' + } {}) + -> Symbol +]], { + s = m.S' \t', + ea = '\a', + eb = '\b', + ef = '\f', + en = '\n', + er = '\r', + et = '\t', + ev = '\v', + Char10 = function (char) + char = tonumber(char) + if not char or char < 0 or char > 255 then + return '' + end + return string.char(char) + end, + Char16 = function (char) + return string.char(tonumber(char, 16)) + end, + CharUtf8 = function (char) + if #char == 0 then + return '' + end + local v = tonumber(char, 16) + if not v then + return '' + end + if v >= 0 and v <= 0x10FFFF then + return utf8.char(v) + end + return '' + end, + Name = function (start, content, finish) + Ci = Ci + 1 + TokenTypes[Ci] = 'name' + TokenStarts[Ci] = start + TokenFinishs[Ci] = finish - 1 + TokenContents[Ci] = content + end, + String = function (start, content, finish) + Ci = Ci + 1 + TokenTypes[Ci] = 'string' + TokenStarts[Ci] = start + TokenFinishs[Ci] = finish - 1 + TokenContents[Ci] = content + end, + Symbol = function (start, content, finish) + Ci = Ci + 1 + TokenTypes[Ci] = 'symbol' + TokenStarts[Ci] = start + TokenFinishs[Ci] = finish - 1 + TokenContents[Ci] = content + end, +}) + +local function trim(str) + return str:match '^%s*(%S+)%s*$' +end + +local function parseTokens(text, offset) + Ct = offset + Ci = 0 + Offset = offset + TokenTypes = {} + TokenStarts = {} + TokenFinishs = {} + TokenContents = {} + Parser:match(text) + Ci = 0 +end + +local function peekToken() + return TokenTypes[Ci+1], TokenContents[Ci+1] +end + +local function nextToken() + Ci = 
Ci + 1 + if not TokenTypes[Ci] then + Ci = Ci - 1 + return nil + end + return TokenTypes[Ci], TokenContents[Ci] +end + +local function checkToken(tp, content, offset) + offset = offset or 0 + return TokenTypes[Ci + offset] == tp + and TokenContents[Ci + offset] == content +end + +local function getStart() + if Ci == 0 then + return Offset + end + return TokenStarts[Ci] + Offset +end + +local function getFinish() + if Ci == 0 then + return Offset + end + return TokenFinishs[Ci] + Offset +end + +local function try(callback) + local savePoint = Ci + -- rollback + local suc = callback() + if not suc then + Ci = savePoint + end + return suc +end + +local function parseName(tp, parent) + local nameTp, nameText = peekToken() + if nameTp ~= 'name' then + return nil + end + nextToken() + local class = { + type = tp, + start = getStart(), + finish = getFinish(), + parent = parent, + [1] = nameText, + } + return class +end + +local function parseClass(parent) + local result = { + type = 'doc.class', + parent = parent, + } + result.class = parseName('doc.class.name', result) + if not result.class then + pushError { + type = 'LUADOC_MISS_CLASS_NAME', + start = getFinish(), + finish = getFinish(), + } + return nil + end + result.start = getStart() + result.finish = getFinish() + if not peekToken() then + return result + end + nextToken() + if not checkToken('symbol', ':') then + pushError { + type = 'LUADOC_MISS_EXTENDS_SYMBOL', + start = result.finish + 1, + finish = getStart() - 1, + } + return result + end + result.extends = parseName('doc.extends.name', result) + if not result.extends then + pushError { + type = 'LUADOC_MISS_CLASS_EXTENDS_NAME', + start = getFinish(), + finish = getFinish(), + } + return result + end + result.finish = getFinish() + return result +end + +local function nextSymbolOrError(symbol) + if checkToken('symbol', symbol, 1) then + nextToken() + return true + end + pushError { + type = 'LUADOC_MISS_SYMBOL', + start = getFinish(), + finish = getFinish(), 
+ info = { + symbol = symbol, + } + } + return false +end + +local function parseTypeUnitArray(node) + if not checkToken('symbol', '[]', 1) then + return nil + end + nextToken() + local result = { + type = 'doc.type.array', + start = node.start, + finish = getFinish(), + node = node, + } + return result +end + +local function parseTypeUnitGeneric(node) + if not checkToken('symbol', '<', 1) then + return nil + end + if not nextSymbolOrError('<') then + return nil + end + local key = parseType(node) + if not key or not nextSymbolOrError(',') then + return nil + end + local value = parseType(node) + if not value then + return nil + end + nextSymbolOrError('>') + local result = { + type = 'doc.type.generic', + start = node.start, + finish = getFinish(), + node = node, + key = key, + value = value, + } + return result +end + +local function parseTypeUnitFunction() + local typeUnit = { + type = 'doc.type.function', + start = getStart(), + args = {}, + returns = {}, + } + if not nextSymbolOrError('(') then + return nil + end + while true do + if checkToken('symbol', ')', 1) then + nextToken() + break + end + local arg = { + type = 'doc.type.arg', + parent = typeUnit, + } + arg.name = parseName('doc.type.name', arg) + if not arg.name then + pushError { + type = 'LUADOC_MISS_ARG_NAME', + start = getFinish(), + finish = getFinish(), + } + break + end + if not arg.start then + arg.start = arg.name.start + end + if checkToken('symbol', '?', 1) then + nextToken() + arg.optional = true + end + arg.finish = getFinish() + if not nextSymbolOrError(':') then + break + end + arg.extends = parseType(arg) + if not arg.extends then + break + end + arg.finish = getFinish() + typeUnit.args[#typeUnit.args+1] = arg + if checkToken('symbol', ',', 1) then + nextToken() + else + nextSymbolOrError(')') + break + end + end + if checkToken('symbol', ':', 1) then + nextToken() + while true do + local rtn = parseType(typeUnit) + if not rtn then + break + end + if checkToken('symbol', '?', 1) then + 
nextToken() + rtn.optional = true + end + typeUnit.returns[#typeUnit.returns+1] = rtn + if checkToken('symbol', ',', 1) then + nextToken() + else + break + end + end + end + typeUnit.finish = getFinish() + return typeUnit +end + +local function parseTypeUnit(parent, content) + local result + if content == 'fun' then + result = parseTypeUnitFunction() + end + if not result then + result = { + type = 'doc.type.name', + start = getStart(), + finish = getFinish(), + [1] = content, + } + end + if not result then + return nil + end + result.parent = parent + while true do + local newResult = parseTypeUnitArray(result) + or parseTypeUnitGeneric(result) + if not newResult then + break + end + result = newResult + end + return result +end + +local function parseResume() + local result = { + type = 'doc.resume' + } + + if checkToken('symbol', '>', 1) then + nextToken() + result.default = true + end + + if checkToken('symbol', '+', 1) then + nextToken() + result.additional = true + end + + local tp = peekToken() + if tp ~= 'string' then + pushError { + type = 'LUADOC_MISS_STRING', + start = getFinish(), + finish = getFinish(), + } + return nil + end + local _, str = nextToken() + result[1] = str + result.start = getStart() + result.finish = getFinish() + return result +end + +function parseType(parent) + local result = { + type = 'doc.type', + parent = parent, + types = {}, + enums = {}, + resumes = {}, + } + result.start = getStart() + while true do + local tp, content = peekToken() + if not tp then + break + end + if tp == 'name' then + nextToken() + local typeUnit = parseTypeUnit(result, content) + if not typeUnit then + break + end + result.types[#result.types+1] = typeUnit + if not result.start then + result.start = typeUnit.start + end + elseif tp == 'string' then + nextToken() + local typeEnum = { + type = 'doc.type.enum', + start = getStart(), + finish = getFinish(), + parent = result, + [1] = content, + } + result.enums[#result.enums+1] = typeEnum + if not 
result.start then + result.start = typeEnum.start + end + elseif tp == 'symbol' and content == '...' then + nextToken() + local vararg = { + type = 'doc.type.name', + start = getStart(), + finish = getFinish(), + parent = result, + [1] = content, + } + result.types[#result.types+1] = vararg + if not result.start then + result.start = vararg.start + end + end + if not checkToken('symbol', '|', 1) then + break + end + nextToken() + end + result.finish = getFinish() + + while true do + local nextComm = NextComment('peek') + if nextComm and nextComm.text:sub(1, 2) == '-|' then + NextComment() + local finishPos = nextComm.text:find('#', 3) or #nextComm.text + parseTokens(nextComm.text:sub(3, finishPos), nextComm.start + 1) + local resume = parseResume() + if resume then + resume.comment = nextComm.text:match('#%s*(.+)', 3) + result.resumes[#result.resumes+1] = resume + result.finish = resume.finish + end + else + break + end + end + + if #result.types == 0 and #result.enums == 0 and #result.resumes == 0 then + pushError { + type = 'LUADOC_MISS_TYPE_NAME', + start = getFinish(), + finish = getFinish(), + } + return nil + end + return result +end + +local function parseAlias() + local result = { + type = 'doc.alias', + } + result.alias = parseName('doc.alias.name', result) + if not result.alias then + pushError { + type = 'LUADOC_MISS_ALIAS_NAME', + start = getFinish(), + finish = getFinish(), + } + return nil + end + result.start = getStart() + result.extends = parseType(result) + if not result.extends then + pushError { + type = 'LUADOC_MISS_ALIAS_EXTENDS', + start = getFinish(), + finish = getFinish(), + } + return nil + end + result.finish = getFinish() + return result +end + +local function parseParam() + local result = { + type = 'doc.param', + } + result.param = parseName('doc.param.name', result) + if not result.param then + pushError { + type = 'LUADOC_MISS_PARAM_NAME', + start = getFinish(), + finish = getFinish(), + } + return nil + end + if 
checkToken('symbol', '?', 1) then + nextToken() + result.optional = true + end + result.start = result.param.start + result.finish = getFinish() + result.extends = parseType(result) + if not result.extends then + pushError { + type = 'LUADOC_MISS_PARAM_EXTENDS', + start = getFinish(), + finish = getFinish(), + } + return result + end + result.finish = getFinish() + return result +end + +local function parseReturn() + local result = { + type = 'doc.return', + returns = {}, + } + while true do + local docType = parseType(result) + if not docType then + break + end + if not result.start then + result.start = docType.start + end + if checkToken('symbol', '?', 1) then + nextToken() + docType.optional = true + end + docType.name = parseName('doc.return.name', docType) + result.returns[#result.returns+1] = docType + if not checkToken('symbol', ',', 1) then + break + end + nextToken() + end + if #result.returns == 0 then + return nil + end + result.finish = getFinish() + return result +end + +local function parseField() + local result = { + type = 'doc.field', + } + try(function () + local tp, value = nextToken() + if tp == 'name' then + if value == 'public' + or value == 'protected' + or value == 'private' then + result.visible = value + result.start = getStart() + return true + end + end + return false + end) + result.field = parseName('doc.field.name', result) + if not result.field then + pushError { + type = 'LUADOC_MISS_FIELD_NAME', + start = getFinish(), + finish = getFinish(), + } + return nil + end + if not result.start then + result.start = result.field.start + end + result.extends = parseType(result) + if not result.extends then + pushError { + type = 'LUADOC_MISS_FIELD_EXTENDS', + start = getFinish(), + finish = getFinish(), + } + return nil + end + result.finish = getFinish() + return result +end + +local function parseGeneric() + local result = { + type = 'doc.generic', + generics = {}, + } + while true do + local object = { + type = 'doc.generic.object', + 
parent = result, + } + object.generic = parseName('doc.generic.name', object) + if not object.generic then + pushError { + type = 'LUADOC_MISS_GENERIC_NAME', + start = getFinish(), + finish = getFinish(), + } + return nil + end + object.start = object.generic.start + if not result.start then + result.start = object.start + end + if checkToken('symbol', ':', 1) then + nextToken() + object.extends = parseType(object) + end + object.finish = getFinish() + result.generics[#result.generics+1] = object + if not checkToken('symbol', ',', 1) then + break + end + nextToken() + end + result.finish = getFinish() + return result +end + +local function parseVararg() + local result = { + type = 'doc.vararg', + } + result.vararg = parseType(result) + if not result.vararg then + pushError { + type = 'LUADOC_MISS_VARARG_TYPE', + start = getFinish(), + finish = getFinish(), + } + return + end + result.start = result.vararg.start + result.finish = result.vararg.finish + return result +end + +local function parseOverload() + local tp, name = peekToken() + if tp ~= 'name' or name ~= 'fun' then + pushError { + type = 'LUADOC_MISS_FUN_AFTER_OVERLOAD', + start = getFinish(), + finish = getFinish(), + } + return nil + end + nextToken() + local result = { + type = 'doc.overload', + } + result.overload = parseTypeUnitFunction() + if not result.overload then + return nil + end + result.overload.parent = result + result.start = result.overload.start + result.finish = result.overload.finish + return result +end + +local function parseDeprecated() + return { + type = 'doc.deprecated', + start = getFinish(), + finish = getFinish(), + } +end + +local function parseMeta() + return { + type = 'doc.meta', + start = getFinish(), + finish = getFinish(), + } +end + +local function parseVersion() + local result = { + type = 'doc.version', + versions = {}, + } + while true do + local tp, text = nextToken() + if not tp then + pushError { + type = 'LUADOC_MISS_VERSION', + start = getStart(), + finish = 
getFinish(), + } + break + end + if not result.start then + result.start = getStart() + end + local version = { + type = 'doc.version.unit', + start = getStart(), + } + if tp == 'symbol' then + if text == '>' then + version.ge = true + elseif text == '<' then + version.le = true + end + tp, text = nextToken() + end + if tp ~= 'name' then + pushError { + type = 'LUADOC_MISS_VERSION', + start = getStart(), + finish = getFinish(), + } + break + end + version.version = tonumber(text) or text + version.finish = getFinish() + result.versions[#result.versions+1] = version + if not checkToken('symbol', ',', 1) then + break + end + nextToken() + end + if #result.versions == 0 then + return nil + end + result.finish = getFinish() + return result +end + +local function convertTokens() + local tp, text = nextToken() + if not tp then + return + end + if tp ~= 'name' then + pushError { + type = 'LUADOC_MISS_CATE_NAME', + start = getStart(), + finish = getFinish(), + } + return nil + end + if text == 'class' then + return parseClass() + elseif text == 'type' then + return parseType() + elseif text == 'alias' then + return parseAlias() + elseif text == 'param' then + return parseParam() + elseif text == 'return' then + return parseReturn() + elseif text == 'field' then + return parseField() + elseif text == 'generic' then + return parseGeneric() + elseif text == 'vararg' then + return parseVararg() + elseif text == 'overload' then + return parseOverload() + elseif text == 'deprecated' then + return parseDeprecated() + elseif text == 'meta' then + return parseMeta() + elseif text == 'version' then + return parseVersion() + end +end + +local function buildLuaDoc(comment) + local text = comment.text + if text:sub(1, 1) ~= '-' then + return + end + if text:sub(2, 2) ~= '@' then + return { + type = 'doc.comment', + start = comment.start, + finish = comment.finish, + comment = comment, + } + end + local finishPos = text:find('@', 3) + local doc, lastComment + if finishPos then + doc = 
text:sub(3, finishPos - 1) + lastComment = text:sub(finishPos) + else + doc = text:sub(3) + end + + parseTokens(doc, comment.start + 1) + local result = convertTokens() + if result then + result.comment = lastComment + end + + return result +end + +local function isNextLine(lns, binded, doc) + if not binded then + return false + end + local lastDoc = binded[#binded] + local lastRow = guide.positionOf(lns, lastDoc.finish) + local newRow = guide.positionOf(lns, doc.start) + return newRow - lastRow == 1 +end + +local function bindGeneric(binded) + local generics = {} + for _, doc in ipairs(binded) do + if doc.type == 'doc.generic' then + for _, obj in ipairs(doc.generics) do + local name = obj.generic[1] + generics[name] = {} + end + elseif doc.type == 'doc.param' + or doc.type == 'doc.return' then + guide.eachSourceType(doc, 'doc.type.name', function (src) + local name = src[1] + if generics[name] then + generics[name][#generics[name]+1] = src + src.typeGeneric = generics + end + end) + end + end +end + +local function bindDoc(state, lns, binded) + if not binded then + return + end + local lastDoc = binded[#binded] + if not lastDoc then + return + end + local bindSources = {} + for _, doc in ipairs(binded) do + doc.bindGroup = binded + doc.bindSources = bindSources + end + bindGeneric(binded) + local row = guide.positionOf(lns, lastDoc.finish) + local start, finish = guide.lineRange(lns, row + 1) + if start >= finish then + -- 空行 + return + end + guide.eachSourceBetween(state.ast, start, finish, function (src) + if src.start and src.start < start then + return + end + if src.type == 'local' + or src.type == 'setlocal' + or src.type == 'setglobal' + or src.type == 'setfield' + or src.type == 'setmethod' + or src.type == 'setindex' + or src.type == 'tablefield' + or src.type == 'tableindex' + or src.type == 'function' + or src.type == '...' 
then + src.bindDocs = binded + bindSources[#bindSources+1] = src + end + end) +end + +local function bindDocs(state) + local lns = lines(nil, state.lua) + local binded + for _, doc in ipairs(state.ast.docs) do + if not isNextLine(lns, binded, doc) then + bindDoc(state, lns, binded) + binded = {} + state.ast.docs.groups[#state.ast.docs.groups+1] = binded + end + binded[#binded+1] = doc + end + bindDoc(state, lns, binded) +end + +return function (_, state) + local ast = state.ast + local comments = state.comms + table.sort(comments, function (a, b) + return a.start < b.start + end) + ast.docs = { + type = 'doc', + parent = ast, + groups = {}, + } + + pushError = state.pushError + + local ci = 1 + NextComment = function (peek) + local comment = comments[ci] + if not peek then + ci = ci + 1 + end + return comment + end + + while true do + local comment = NextComment() + if not comment then + break + end + local doc = buildLuaDoc(comment) + if doc then + ast.docs[#ast.docs+1] = doc + doc.parent = ast.docs + if ast.start > doc.start then + ast.start = doc.start + end + if ast.finish < doc.finish then + ast.finish = doc.finish + end + end + end + + if #ast.docs == 0 then + return + end + + bindDocs(state) +end diff --git a/script/parser/parse.lua b/script/parser/parse.lua new file mode 100644 index 00000000..f813cc59 --- /dev/null +++ b/script/parser/parse.lua @@ -0,0 +1,49 @@ +local ast = require 'parser.ast' + +return function (self, lua, mode, version) + local errs = {} + local diags = {} + local comms = {} + local state = { + version = version, + lua = lua, + root = {}, + errs = errs, + diags = diags, + comms = comms, + pushError = function (err) + if err.finish < err.start then + err.finish = err.start + end + local last = errs[#errs] + if last then + if last.start <= err.start and last.finish >= err.finish then + return + end + end + err.level = err.level or 'error' + errs[#errs+1] = err + return err + end, + pushDiag = function (code, info) + if not diags[code] 
then + diags[code] = {} + end + diags[code][#diags[code]+1] = info + end, + pushComment = function (comment) + comms[#comms+1] = comment + end + } + ast.init(state) + local suc, res, err = xpcall(self.grammar, debug.traceback, self, lua, mode) + ast.close() + if not suc then + return nil, res + end + if not res then + state.pushError(err) + end + state.ast = res + return state +end diff --git a/script/parser/relabel.lua b/script/parser/relabel.lua new file mode 100644 index 00000000..ac902403 --- /dev/null +++ b/script/parser/relabel.lua @@ -0,0 +1,361 @@ +-- $Id: re.lua,v 1.44 2013/03/26 20:11:40 roberto Exp $ + +-- imported functions and modules +local tonumber, type, print, error = tonumber, type, print, error +local pcall = pcall +local setmetatable = setmetatable +local tinsert, concat = table.insert, table.concat +local rep = string.rep +local m = require"lpeglabel" + +-- 'm' will be used to parse expressions, and 'mm' will be used to +-- create expressions; that is, 're' runs on 'm', creating patterns +-- on 'mm' +local mm = m + +-- pattern's metatable +local mt = getmetatable(mm.P(0)) + + + +-- No more global accesses after this point +_ENV = nil + + +local any = m.P(1) +local dummy = mm.P(false) + + +local errinfo = { + NoPatt = "no pattern found", + ExtraChars = "unexpected characters after the pattern", + + ExpPatt1 = "expected a pattern after '/'", + + ExpPatt2 = "expected a pattern after '&'", + ExpPatt3 = "expected a pattern after '!'", + + ExpPatt4 = "expected a pattern after '('", + ExpPatt5 = "expected a pattern after ':'", + ExpPatt6 = "expected a pattern after '{~'", + ExpPatt7 = "expected a pattern after '{|'", + + ExpPatt8 = "expected a pattern after '<-'", + + ExpPattOrClose = "expected a pattern or closing '}' after '{'", + + ExpNumName = "expected a number, '+', '-' or a name (no space) after '^'", + ExpCap = "expected a string, number, '{}' or name after '->'", + + ExpName1 = "expected the name of a rule after '=>'", + ExpName2 = "expected 
the name of a rule after '=' (no space)", + ExpName3 = "expected the name of a rule after '<' (no space)", + + ExpLab1 = "expected a label after '{'", + + ExpNameOrLab = "expected a name or label after '%' (no space)", + + ExpItem = "expected at least one item after '[' or '^'", + + MisClose1 = "missing closing ')'", + MisClose2 = "missing closing ':}'", + MisClose3 = "missing closing '~}'", + MisClose4 = "missing closing '|}'", + MisClose5 = "missing closing '}'", -- for the captures + + MisClose6 = "missing closing '>'", + MisClose7 = "missing closing '}'", -- for the labels + + MisClose8 = "missing closing ']'", + + MisTerm1 = "missing terminating single quote", + MisTerm2 = "missing terminating double quote", +} + +local function expect (pattern, label) + return pattern + m.T(label) +end + + +-- Pre-defined names +local Predef = { nl = m.P"\n" } + + +local mem +local fmem +local gmem + + +local function updatelocale () + mm.locale(Predef) + Predef.a = Predef.alpha + Predef.c = Predef.cntrl + Predef.d = Predef.digit + Predef.g = Predef.graph + Predef.l = Predef.lower + Predef.p = Predef.punct + Predef.s = Predef.space + Predef.u = Predef.upper + Predef.w = Predef.alnum + Predef.x = Predef.xdigit + Predef.A = any - Predef.a + Predef.C = any - Predef.c + Predef.D = any - Predef.d + Predef.G = any - Predef.g + Predef.L = any - Predef.l + Predef.P = any - Predef.p + Predef.S = any - Predef.s + Predef.U = any - Predef.u + Predef.W = any - Predef.w + Predef.X = any - Predef.x + mem = {} -- restart memoization + fmem = {} + gmem = {} + local mt = {__mode = "v"} + setmetatable(mem, mt) + setmetatable(fmem, mt) + setmetatable(gmem, mt) +end + + +updatelocale() + + + +local I = m.P(function (s,i) print(i, s:sub(1, i-1)); return i end) + + +local function getdef (id, defs) + local c = defs and defs[id] + if not c then + error("undefined name: " .. 
id) + end + return c +end + + +local function mult (p, n) + local np = mm.P(true) + while n >= 1 do + if n%2 >= 1 then np = np * p end + p = p * p + n = n/2 + end + return np +end + +local function equalcap (s, i, c) + if type(c) ~= "string" then return nil end + local e = #c + i + if s:sub(i, e - 1) == c then return e else return nil end +end + + +local S = (Predef.space + "--" * (any - Predef.nl)^0)^0 + +local name = m.C(m.R("AZ", "az", "__") * m.R("AZ", "az", "__", "09")^0) + +local arrow = S * "<-" + +-- a defined name only have meaning in a given environment +local Def = name * m.Carg(1) + +local num = m.C(m.R"09"^1) * S / tonumber + +local String = "'" * m.C((any - "'" - m.P"\n")^0) * expect("'", "MisTerm1") + + '"' * m.C((any - '"' - m.P"\n")^0) * expect('"', "MisTerm2") + + +local defined = "%" * Def / function (c,Defs) + local cat = Defs and Defs[c] or Predef[c] + if not cat then + error("name '" .. c .. "' undefined") + end + return cat +end + +local Range = m.Cs(any * (m.P"-"/"") * (any - "]")) / mm.R + +local item = defined + Range + m.C(any - m.P"\n") + +local Class = + "[" + * (m.C(m.P"^"^-1)) -- optional complement symbol + * m.Cf(expect(item, "ExpItem") * (item - "]")^0, mt.__add) + / function (c, p) return c == "^" and any - p or p end + * expect("]", "MisClose8") + +local function adddef (t, k, exp) + if t[k] then + -- TODO 改了一下这里的代码,重复定义不会抛错 + --error("'"..k.."' already defined as a rule") + else + t[k] = exp + end + return t +end + +local function firstdef (n, r) return adddef({n}, n, r) end + + +local function NT (n, b) + if not b then + error("rule '"..n.."' used outside a grammar") + else return mm.V(n) + end +end + + +local exp = m.P{ "Exp", + Exp = S * ( m.V"Grammar" + + m.Cf(m.V"Seq" * (S * "/" * expect(S * m.V"Seq", "ExpPatt1"))^0, mt.__add) ); + Seq = m.Cf(m.Cc(m.P"") * m.V"Prefix" * (S * m.V"Prefix")^0, mt.__mul); + Prefix = "&" * expect(S * m.V"Prefix", "ExpPatt2") / mt.__len + + "!" 
* expect(S * m.V"Prefix", "ExpPatt3") / mt.__unm + + m.V"Suffix"; + Suffix = m.Cf(m.V"Primary" * + ( S * ( m.P"+" * m.Cc(1, mt.__pow) + + m.P"*" * m.Cc(0, mt.__pow) + + m.P"?" * m.Cc(-1, mt.__pow) + + "^" * expect( m.Cg(num * m.Cc(mult)) + + m.Cg(m.C(m.S"+-" * m.R"09"^1) * m.Cc(mt.__pow) + + name * m.Cc"lab" + ), + "ExpNumName") + + "->" * expect(S * ( m.Cg((String + num) * m.Cc(mt.__div)) + + m.P"{}" * m.Cc(nil, m.Ct) + + m.Cg(Def / getdef * m.Cc(mt.__div)) + ), + "ExpCap") + + "=>" * expect(S * m.Cg(Def / getdef * m.Cc(m.Cmt)), + "ExpName1") + ) + )^0, function (a,b,f) if f == "lab" then return a + mm.T(b) else return f(a,b) end end ); + Primary = "(" * expect(m.V"Exp", "ExpPatt4") * expect(S * ")", "MisClose1") + + String / mm.P + + Class + + defined + + "%" * expect(m.P"{", "ExpNameOrLab") + * expect(S * m.V"Label", "ExpLab1") + * expect(S * "}", "MisClose7") / mm.T + + "{:" * (name * ":" + m.Cc(nil)) * expect(m.V"Exp", "ExpPatt5") + * expect(S * ":}", "MisClose2") + / function (n, p) return mm.Cg(p, n) end + + "=" * expect(name, "ExpName2") + / function (n) return mm.Cmt(mm.Cb(n), equalcap) end + + m.P"{}" / mm.Cp + + "{~" * expect(m.V"Exp", "ExpPatt6") + * expect(S * "~}", "MisClose3") / mm.Cs + + "{|" * expect(m.V"Exp", "ExpPatt7") + * expect(S * "|}", "MisClose4") / mm.Ct + + "{" * expect(m.V"Exp", "ExpPattOrClose") + * expect(S * "}", "MisClose5") / mm.C + + m.P"." 
* m.Cc(any) + + (name * -arrow + "<" * expect(name, "ExpName3") + * expect(">", "MisClose6")) * m.Cb("G") / NT; + Label = num + name; + Definition = name * arrow * expect(m.V"Exp", "ExpPatt8"); + Grammar = m.Cg(m.Cc(true), "G") + * m.Cf(m.V"Definition" / firstdef * (S * m.Cg(m.V"Definition"))^0, + adddef) / mm.P; +} + +local pattern = S * m.Cg(m.Cc(false), "G") * expect(exp, "NoPatt") / mm.P + * S * expect(-any, "ExtraChars") + +local function lineno (s, i) + if i == 1 then return 1, 1 end + local adjustment = 0 + -- report the current line if at end of line, not the next + if s:sub(i,i) == '\n' then + i = i-1 + adjustment = 1 + end + local rest, num = s:sub(1,i):gsub("[^\n]*\n", "") + local r = #rest + return 1 + num, (r ~= 0 and r or 1) + adjustment +end + +local function calcline (s, i) + if i == 1 then return 1, 1 end + local rest, line = s:sub(1,i):gsub("[^\n]*\n", "") + local col = #rest + return 1 + line, col ~= 0 and col or 1 +end + + +local function splitlines(str) + local t = {} + local function helper(line) tinsert(t, line) return "" end + helper((str:gsub("(.-)\r?\n", helper))) + return t +end + +local function compile (p, defs) + if mm.type(p) == "pattern" then return p end -- already compiled + p = p .. " " -- for better reporting of column numbers in errors when at EOF + local ok, cp, label, poserr = pcall(function() return pattern:match(p, 1, defs) end) + if not ok and cp then + if type(cp) == "string" then + cp = cp:gsub("^[^:]+:[^:]+: ", "") + end + error(cp, 3) + end + if not cp then + local lines = splitlines(p) + local line, col = lineno(p, poserr) + local err = {} + tinsert(err, "L" .. line .. ":C" .. col .. ": " .. errinfo[label]) + tinsert(err, lines[line]) + tinsert(err, rep(" ", col-1) .. "^") + error("syntax error(s) in pattern\n" .. 
concat(err, "\n"), 3) + end + return cp +end + +local function match (s, p, i) + local cp = mem[p] + if not cp then + cp = compile(p) + mem[p] = cp + end + return cp:match(s, i or 1) +end + +local function find (s, p, i) + local cp = fmem[p] + if not cp then + cp = compile(p) / 0 + cp = mm.P{ mm.Cp() * cp * mm.Cp() + 1 * mm.V(1) } + fmem[p] = cp + end + local i, e = cp:match(s, i or 1) + if i then return i, e - 1 + else return i + end +end + +local function gsub (s, p, rep) + local g = gmem[p] or {} -- ensure gmem[p] is not collected while here + gmem[p] = g + local cp = g[rep] + if not cp then + cp = compile(p) + cp = mm.Cs((cp / rep + 1)^0) + g[rep] = cp + end + return cp:match(s) +end + + +-- exported names +local re = { + compile = compile, + match = match, + find = find, + gsub = gsub, + updatelocale = updatelocale, + calcline = calcline +} + +return re diff --git a/script/parser/split.lua b/script/parser/split.lua new file mode 100644 index 00000000..6ce4a4e7 --- /dev/null +++ b/script/parser/split.lua @@ -0,0 +1,9 @@ +local m = require 'lpeglabel' + +local NL = m.P'\r\n' + m.S'\r\n' +local LINE = m.C(1 - NL) + +return function (str) + local MATCH = m.Ct((LINE * NL)^0 * LINE) + return MATCH:match(str) +end |