summaryrefslogtreecommitdiff
path: root/script/parser
diff options
context:
space:
mode:
Diffstat (limited to 'script/parser')
-rw-r--r--script/parser/ast.lua1751
-rw-r--r--script/parser/calcline.lua94
-rw-r--r--script/parser/compile.lua561
-rw-r--r--script/parser/grammar.lua538
-rw-r--r--script/parser/guide.lua3884
-rw-r--r--script/parser/init.lua12
-rw-r--r--script/parser/lines.lua45
-rw-r--r--script/parser/luadoc.lua991
-rw-r--r--script/parser/parse.lua49
-rw-r--r--script/parser/relabel.lua361
-rw-r--r--script/parser/split.lua9
11 files changed, 8295 insertions, 0 deletions
diff --git a/script/parser/ast.lua b/script/parser/ast.lua
new file mode 100644
index 00000000..d8614eae
--- /dev/null
+++ b/script/parser/ast.lua
@@ -0,0 +1,1751 @@
+local tonumber = tonumber
+local stringChar = string.char
+local utf8Char = utf8.char
+local tableUnpack = table.unpack
+local mathType = math.type
+local tableRemove = table.remove
+local pairs = pairs
+local tableSort = table.sort
+
+_ENV = nil
+
+local State
+local PushError
+local PushDiag
+local PushComment
+
+-- goto 单独处理
+local RESERVED = {
+ ['and'] = true,
+ ['break'] = true,
+ ['do'] = true,
+ ['else'] = true,
+ ['elseif'] = true,
+ ['end'] = true,
+ ['false'] = true,
+ ['for'] = true,
+ ['function'] = true,
+ ['if'] = true,
+ ['in'] = true,
+ ['local'] = true,
+ ['nil'] = true,
+ ['not'] = true,
+ ['or'] = true,
+ ['repeat'] = true,
+ ['return'] = true,
+ ['then'] = true,
+ ['true'] = true,
+ ['until'] = true,
+ ['while'] = true,
+}
+
+local VersionOp = {
+ ['&'] = {'Lua 5.3', 'Lua 5.4'},
+ ['~'] = {'Lua 5.3', 'Lua 5.4'},
+ ['|'] = {'Lua 5.3', 'Lua 5.4'},
+ ['<<'] = {'Lua 5.3', 'Lua 5.4'},
+ ['>>'] = {'Lua 5.3', 'Lua 5.4'},
+ ['//'] = {'Lua 5.3', 'Lua 5.4'},
+}
+
+local function checkOpVersion(op)
+ local versions = VersionOp[op.type]
+ if not versions then
+ return
+ end
+ for i = 1, #versions do
+ if versions[i] == State.version then
+ return
+ end
+ end
+ PushError {
+ type = 'UNSUPPORT_SYMBOL',
+ start = op.start,
+ finish = op.finish,
+ version = versions,
+ info = {
+ version = State.version,
+ }
+ }
+end
+
+local function checkMissEnd(start)
+ if not State.MissEndErr then
+ return
+ end
+ local err = State.MissEndErr
+ State.MissEndErr = nil
+ local _, finish = State.lua:find('[%w_]+', start)
+ if not finish then
+ return
+ end
+ err.info.related = {
+ {
+ start = start,
+ finish = finish,
+ }
+ }
+ PushError {
+ type = 'MISS_END',
+ start = start,
+ finish = finish,
+ }
+end
+
+local function getSelect(vararg, index)
+ return {
+ type = 'select',
+ start = vararg.start,
+ finish = vararg.finish,
+ vararg = vararg,
+ index = index,
+ }
+end
+
+local function getValue(values, i)
+ if not values then
+ return nil, nil
+ end
+ local value = values[i]
+ if not value then
+ local last = values[#values]
+ if not last then
+ return nil, nil
+ end
+ if last.type == 'call' or last.type == 'varargs' then
+ return getSelect(last, i - #values + 1)
+ end
+ return nil, nil
+ end
+ if value.type == 'call' or value.type == 'varargs' then
+ value = getSelect(value, 1)
+ end
+ return value
+end
+
+local function createLocal(key, effect, value, attrs)
+ if not key then
+ return nil
+ end
+ key.type = 'local'
+ key.effect = effect
+ key.value = value
+ key.attrs = attrs
+ if value then
+ key.range = value.finish
+ end
+ return key
+end
+
+local function createCall(args, start, finish)
+ if args then
+ args.type = 'callargs'
+ args.start = start
+ args.finish = finish
+ end
+ return {
+ type = 'call',
+ start = start,
+ finish = finish,
+ args = args,
+ }
+end
+
+local function packList(start, list, finish)
+ local lastFinish = start
+ local wantName = true
+ local count = 0
+ for i = 1, #list do
+ local ast = list[i]
+ if ast.type == ',' then
+ if wantName or i == #list then
+ PushError {
+ type = 'UNEXPECT_SYMBOL',
+ start = ast.start,
+ finish = ast.finish,
+ info = {
+ symbol = ',',
+ }
+ }
+ end
+ wantName = true
+ else
+ if not wantName then
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = lastFinish,
+ finish = ast.start - 1,
+ info = {
+ symbol = ',',
+ }
+ }
+ end
+ wantName = false
+ count = count + 1
+ list[count] = list[i]
+ end
+ lastFinish = ast.finish + 1
+ end
+ for i = count + 1, #list do
+ list[i] = nil
+ end
+ list.type = 'list'
+ list.start = start
+ list.finish = finish - 1
+ return list
+end
+
+local BinaryLevel = {
+ ['or'] = 1,
+ ['and'] = 2,
+ ['<='] = 3,
+ ['>='] = 3,
+ ['<'] = 3,
+ ['>'] = 3,
+ ['~='] = 3,
+ ['=='] = 3,
+ ['|'] = 4,
+ ['~'] = 5,
+ ['&'] = 6,
+ ['<<'] = 7,
+ ['>>'] = 7,
+ ['..'] = 8,
+ ['+'] = 9,
+ ['-'] = 9,
+ ['*'] = 10,
+ ['//'] = 10,
+ ['/'] = 10,
+ ['%'] = 10,
+ ['^'] = 11,
+}
+
+local BinaryForward = {
+ [01] = true,
+ [02] = true,
+ [03] = true,
+ [04] = true,
+ [05] = true,
+ [06] = true,
+ [07] = true,
+ [08] = false,
+ [09] = true,
+ [10] = true,
+ [11] = false,
+}
+
+local Defs = {
+ Nil = function (pos)
+ return {
+ type = 'nil',
+ start = pos,
+ finish = pos + 2,
+ }
+ end,
+ True = function (pos)
+ return {
+ type = 'boolean',
+ start = pos,
+ finish = pos + 3,
+ [1] = true,
+ }
+ end,
+ False = function (pos)
+ return {
+ type = 'boolean',
+ start = pos,
+ finish = pos + 4,
+ [1] = false,
+ }
+ end,
+ ShortComment = function (start, text, finish)
+ PushComment {
+ start = start,
+ finish = finish - 1,
+ text = text,
+ }
+ end,
+ LongComment = function (beforeEq, afterEq, str, missPos)
+ if missPos then
+ local endSymbol = ']' .. ('='):rep(afterEq-beforeEq) .. ']'
+ local s, _, w = str:find('(%][%=]*%])[%c%s]*$')
+ if s then
+ PushError {
+ type = 'ERR_LCOMMENT_END',
+ start = missPos - #str + s - 1,
+ finish = missPos - #str + s + #w - 2,
+ info = {
+ symbol = endSymbol,
+ },
+ fix = {
+ title = 'FIX_LCOMMENT_END',
+ {
+ start = missPos - #str + s - 1,
+ finish = missPos - #str + s + #w - 2,
+ text = endSymbol,
+ }
+ },
+ }
+ end
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = missPos,
+ finish = missPos,
+ info = {
+ symbol = endSymbol,
+ },
+ fix = {
+ title = 'ADD_LCOMMENT_END',
+ {
+ start = missPos,
+ finish = missPos,
+ text = endSymbol,
+ }
+ },
+ }
+ end
+ end,
+ CLongComment = function (start1, finish1, start2, finish2)
+ PushError {
+ type = 'ERR_C_LONG_COMMENT',
+ start = start1,
+ finish = finish2 - 1,
+ fix = {
+ title = 'FIX_C_LONG_COMMENT',
+ {
+ start = start1,
+ finish = finish1 - 1,
+ text = '--[[',
+ },
+ {
+ start = start2,
+ finish = finish2 - 1,
+ text = '--]]'
+ },
+ }
+ }
+ end,
+ CCommentPrefix = function (start, finish)
+ PushError {
+ type = 'ERR_COMMENT_PREFIX',
+ start = start,
+ finish = finish - 1,
+ fix = {
+ title = 'FIX_COMMENT_PREFIX',
+ {
+ start = start,
+ finish = finish - 1,
+ text = '--',
+ },
+ }
+ }
+ end,
+ String = function (start, quote, str, finish)
+ return {
+ type = 'string',
+ start = start,
+ finish = finish - 1,
+ [1] = str,
+ [2] = quote,
+ }
+ end,
+ LongString = function (beforeEq, afterEq, str, missPos)
+ if missPos then
+ local endSymbol = ']' .. ('='):rep(afterEq-beforeEq) .. ']'
+ local s, _, w = str:find('(%][%=]*%])[%c%s]*$')
+ if s then
+ PushError {
+ type = 'ERR_LSTRING_END',
+ start = missPos - #str + s - 1,
+ finish = missPos - #str + s + #w - 2,
+ info = {
+ symbol = endSymbol,
+ },
+ fix = {
+ title = 'FIX_LSTRING_END',
+ {
+ start = missPos - #str + s - 1,
+ finish = missPos - #str + s + #w - 2,
+ text = endSymbol,
+ }
+ },
+ }
+ end
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = missPos,
+ finish = missPos,
+ info = {
+ symbol = endSymbol,
+ },
+ fix = {
+ title = 'ADD_LSTRING_END',
+ {
+ start = missPos,
+ finish = missPos,
+ text = endSymbol,
+ }
+ },
+ }
+ end
+ return '[' .. ('='):rep(afterEq-beforeEq) .. '[', str
+ end,
+ Char10 = function (char)
+ char = tonumber(char)
+ if not char or char < 0 or char > 255 then
+ return ''
+ end
+ return stringChar(char)
+ end,
+ Char16 = function (pos, char)
+ if State.version == 'Lua 5.1' then
+ PushError {
+ type = 'ERR_ESC',
+ start = pos-1,
+ finish = pos,
+ version = {'Lua 5.2', 'Lua 5.3', 'Lua 5.4', 'LuaJIT'},
+ info = {
+ version = State.version,
+ }
+ }
+ return char
+ end
+ return stringChar(tonumber(char, 16))
+ end,
+ CharUtf8 = function (pos, char)
+ if State.version ~= 'Lua 5.3'
+ and State.version ~= 'Lua 5.4'
+ and State.version ~= 'LuaJIT'
+ then
+ PushError {
+ type = 'ERR_ESC',
+ start = pos-3,
+ finish = pos-2,
+ version = {'Lua 5.3', 'Lua 5.4', 'LuaJIT'},
+ info = {
+ version = State.version,
+ }
+ }
+ return char
+ end
+ if #char == 0 then
+ PushError {
+ type = 'UTF8_SMALL',
+ start = pos-3,
+ finish = pos,
+ }
+ return ''
+ end
+ local v = tonumber(char, 16)
+ if not v then
+ for i = 1, #char do
+ if not tonumber(char:sub(i, i), 16) then
+ PushError {
+ type = 'MUST_X16',
+ start = pos + i - 1,
+ finish = pos + i - 1,
+ }
+ end
+ end
+ return ''
+ end
+ if State.version == 'Lua 5.4' then
+ if v < 0 or v > 0x7FFFFFFF then
+ PushError {
+ type = 'UTF8_MAX',
+ start = pos-3,
+ finish = pos+#char,
+ info = {
+ min = '00000000',
+ max = '7FFFFFFF',
+ }
+ }
+ end
+ else
+ if v < 0 or v > 0x10FFFF then
+ PushError {
+ type = 'UTF8_MAX',
+ start = pos-3,
+ finish = pos+#char,
+ version = v <= 0x7FFFFFFF and 'Lua 5.4' or nil,
+ info = {
+ min = '000000',
+ max = '10FFFF',
+ }
+ }
+ end
+ end
+ if v >= 0 and v <= 0x10FFFF then
+ return utf8Char(v)
+ end
+ return ''
+ end,
+ Number = function (start, number, finish)
+ local n = tonumber(number)
+ if n then
+ State.LastNumber = {
+ type = 'number',
+ start = start,
+ finish = finish - 1,
+ [1] = n,
+ }
+ return State.LastNumber
+ else
+ PushError {
+ type = 'MALFORMED_NUMBER',
+ start = start,
+ finish = finish - 1,
+ }
+ State.LastNumber = {
+ type = 'number',
+ start = start,
+ finish = finish - 1,
+ [1] = 0,
+ }
+ return State.LastNumber
+ end
+ end,
+ FFINumber = function (start, symbol)
+ local lastNumber = State.LastNumber
+ if mathType(lastNumber[1]) == 'float' then
+ PushError {
+ type = 'UNKNOWN_SYMBOL',
+ start = start,
+ finish = start + #symbol - 1,
+ info = {
+ symbol = symbol,
+ }
+ }
+ lastNumber[1] = 0
+ return
+ end
+ if State.version ~= 'LuaJIT' then
+ PushError {
+ type = 'UNSUPPORT_SYMBOL',
+ start = start,
+ finish = start + #symbol - 1,
+ version = 'LuaJIT',
+ info = {
+ version = State.version,
+ }
+ }
+ lastNumber[1] = 0
+ end
+ end,
+ ImaginaryNumber = function (start, symbol)
+ local lastNumber = State.LastNumber
+ if State.version ~= 'LuaJIT' then
+ PushError {
+ type = 'UNSUPPORT_SYMBOL',
+ start = start,
+ finish = start + #symbol - 1,
+ version = 'LuaJIT',
+ info = {
+ version = State.version,
+ }
+ }
+ end
+ lastNumber[1] = 0
+ end,
+ Name = function (start, str, finish)
+ local isKeyWord
+ if RESERVED[str] then
+ isKeyWord = true
+ elseif str == 'goto' then
+ if State.version ~= 'Lua 5.1' and State.version ~= 'LuaJIT' then
+ isKeyWord = true
+ end
+ end
+ if isKeyWord then
+ PushError {
+ type = 'KEYWORD',
+ start = start,
+ finish = finish - 1,
+ }
+ end
+ return {
+ type = 'name',
+ start = start,
+ finish = finish - 1,
+ [1] = str,
+ }
+ end,
+ GetField = function (dot, field)
+ local obj = {
+ type = 'getfield',
+ field = field,
+ dot = dot,
+ start = dot.start,
+ finish = (field or dot).finish,
+ }
+ if field then
+ field.type = 'field'
+ field.parent = obj
+ end
+ return obj
+ end,
+ GetIndex = function (start, index, finish)
+ local obj = {
+ type = 'getindex',
+ start = start,
+ finish = finish - 1,
+ index = index,
+ }
+ if index then
+ index.parent = obj
+ end
+ return obj
+ end,
+ GetMethod = function (colon, method)
+ local obj = {
+ type = 'getmethod',
+ method = method,
+ colon = colon,
+ start = colon.start,
+ finish = (method or colon).finish,
+ }
+ if method then
+ method.type = 'method'
+ method.parent = obj
+ end
+ return obj
+ end,
+ Single = function (unit)
+ unit.type = 'getname'
+ return unit
+ end,
+ Simple = function (units)
+ local last = units[1]
+ for i = 2, #units do
+ local current = units[i]
+ current.node = last
+ current.start = last.start
+ last.next = current
+ last = units[i]
+ end
+ return last
+ end,
+ SimpleCall = function (call)
+ if call.type ~= 'call' and call.type ~= 'getmethod' then
+ PushError {
+ type = 'EXP_IN_ACTION',
+ start = call.start,
+ finish = call.finish,
+ }
+ end
+ return call
+ end,
+ BinaryOp = function (start, op)
+ return {
+ type = op,
+ start = start,
+ finish = start + #op - 1,
+ }
+ end,
+ UnaryOp = function (start, op)
+ return {
+ type = op,
+ start = start,
+ finish = start + #op - 1,
+ }
+ end,
+ Unary = function (first, ...)
+ if not ... then
+ return nil
+ end
+ local list = {first, ...}
+ local e = list[#list]
+ for i = #list - 1, 1, -1 do
+ local op = list[i]
+ checkOpVersion(op)
+ e = {
+ type = 'unary',
+ op = op,
+ start = op.start,
+ finish = e.finish,
+ [1] = e,
+ }
+ end
+ return e
+ end,
+ SubBinary = function (op, symb)
+ if symb then
+ return op, symb
+ end
+ PushError {
+ type = 'MISS_EXP',
+ start = op.start,
+ finish = op.finish,
+ }
+ end,
+ Binary = function (first, op, second, ...)
+ if not first then
+ return second
+ end
+ if not op then
+ return first
+ end
+ if not ... then
+ checkOpVersion(op)
+ return {
+ type = 'binary',
+ op = op,
+ start = first.start,
+ finish = second.finish,
+ [1] = first,
+ [2] = second,
+ }
+ end
+ local list = {first, op, second, ...}
+ local ops = {}
+ for i = 2, #list, 2 do
+ ops[#ops+1] = i
+ end
+ tableSort(ops, function (a, b)
+ local op1 = list[a]
+ local op2 = list[b]
+ local lv1 = BinaryLevel[op1.type]
+ local lv2 = BinaryLevel[op2.type]
+ if lv1 == lv2 then
+ local forward = BinaryForward[lv1]
+ if forward then
+ return op1.start > op2.start
+ else
+ return op1.start < op2.start
+ end
+ else
+ return lv1 < lv2
+ end
+ end)
+ local final
+ for i = #ops, 1, -1 do
+ local n = ops[i]
+ local op = list[n]
+ local left = list[n-1]
+ local right = list[n+1]
+ local exp = {
+ type = 'binary',
+ op = op,
+ start = left.start,
+ finish = right and right.finish or op.finish,
+ [1] = left,
+ [2] = right,
+ }
+ local leftIndex, rightIndex
+ if list[left] then
+ leftIndex = list[left[1]]
+ else
+ leftIndex = n - 1
+ end
+ if list[right] then
+ rightIndex = list[right[2]]
+ else
+ rightIndex = n + 1
+ end
+
+ list[leftIndex] = exp
+ list[rightIndex] = exp
+ list[left] = leftIndex
+ list[right] = rightIndex
+ list[exp] = n
+ final = exp
+
+ checkOpVersion(op)
+ end
+ return final
+ end,
+ Paren = function (start, exp, finish)
+ if exp and exp.type == 'paren' then
+ exp.start = start
+ exp.finish = finish - 1
+ return exp
+ end
+ return {
+ type = 'paren',
+ start = start,
+ finish = finish - 1,
+ exp = exp
+ }
+ end,
+ VarArgs = function (dots)
+ dots.type = 'varargs'
+ return dots
+ end,
+ PackLoopArgs = function (start, list, finish)
+ local list = packList(start, list, finish)
+ if #list == 0 then
+ PushError {
+ type = 'MISS_LOOP_MIN',
+ start = finish,
+ finish = finish,
+ }
+ elseif #list == 1 then
+ PushError {
+ type = 'MISS_LOOP_MAX',
+ start = finish,
+ finish = finish,
+ }
+ end
+ return list
+ end,
+ PackInNameList = function (start, list, finish)
+ local list = packList(start, list, finish)
+ if #list == 0 then
+ PushError {
+ type = 'MISS_NAME',
+ start = start,
+ finish = finish,
+ }
+ end
+ return list
+ end,
+ PackInExpList = function (start, list, finish)
+ local list = packList(start, list, finish)
+ if #list == 0 then
+ PushError {
+ type = 'MISS_EXP',
+ start = start,
+ finish = finish,
+ }
+ end
+ return list
+ end,
+ PackExpList = function (start, list, finish)
+ local list = packList(start, list, finish)
+ return list
+ end,
+ PackNameList = function (start, list, finish)
+ local list = packList(start, list, finish)
+ return list
+ end,
+ Call = function (start, args, finish)
+ return createCall(args, start, finish-1)
+ end,
+ COMMA = function (start)
+ return {
+ type = ',',
+ start = start,
+ finish = start,
+ }
+ end,
+ SEMICOLON = function (start)
+ return {
+ type = ';',
+ start = start,
+ finish = start,
+ }
+ end,
+ DOTS = function (start)
+ return {
+ type = '...',
+ start = start,
+ finish = start + 2,
+ }
+ end,
+ COLON = function (start)
+ return {
+ type = ':',
+ start = start,
+ finish = start,
+ }
+ end,
+ DOT = function (start)
+ return {
+ type = '.',
+ start = start,
+ finish = start,
+ }
+ end,
+ Function = function (functionStart, functionFinish, args, actions, endStart, endFinish)
+ actions.type = 'function'
+ actions.start = functionStart
+ actions.finish = endFinish - 1
+ actions.args = args
+ actions.keyword= {
+ functionStart, functionFinish - 1,
+ endStart, endFinish - 1,
+ }
+ checkMissEnd(functionStart)
+ return actions
+ end,
+ NamedFunction = function (functionStart, functionFinish, name, args, actions, endStart, endFinish)
+ actions.type = 'function'
+ actions.start = functionStart
+ actions.finish = endFinish - 1
+ actions.args = args
+ actions.keyword= {
+ functionStart, functionFinish - 1,
+ endStart, endFinish - 1,
+ }
+ checkMissEnd(functionStart)
+ if not name then
+ return
+ end
+ if name.type == 'getname' then
+ name.type = 'setname'
+ name.value = actions
+ elseif name.type == 'getfield' then
+ name.type = 'setfield'
+ name.value = actions
+ elseif name.type == 'getmethod' then
+ name.type = 'setmethod'
+ name.value = actions
+ end
+ name.range = actions.finish
+ name.vstart = functionStart
+ return name
+ end,
+ LocalFunction = function (start, functionStart, functionFinish, name, args, actions, endStart, endFinish)
+ actions.type = 'function'
+ actions.start = start
+ actions.finish = endFinish - 1
+ actions.args = args
+ actions.keyword= {
+ functionStart, functionFinish - 1,
+ endStart, endFinish - 1,
+ }
+ checkMissEnd(start)
+
+ if not name then
+ return
+ end
+
+ if name.type ~= 'getname' then
+ PushError {
+ type = 'UNEXPECT_LFUNC_NAME',
+ start = name.start,
+ finish = name.finish,
+ }
+ return
+ end
+
+ local loc = createLocal(name, name.start, actions)
+ loc.localfunction = true
+ loc.vstart = functionStart
+
+ return loc
+ end,
+ Table = function (start, tbl, finish)
+ tbl.type = 'table'
+ tbl.start = start
+ tbl.finish = finish - 1
+ local wantField = true
+ local lastStart = start + 1
+ local fieldCount = 0
+ for i = 1, #tbl do
+ local field = tbl[i]
+ if field.type == ',' or field.type == ';' then
+ if wantField then
+ PushError {
+ type = 'MISS_EXP',
+ start = lastStart,
+ finish = field.start - 1,
+ }
+ end
+ wantField = true
+ lastStart = field.finish + 1
+ else
+ if not wantField then
+ PushError {
+ type = 'MISS_SEP_IN_TABLE',
+ start = lastStart,
+ finish = field.start - 1,
+ }
+ end
+ wantField = false
+ lastStart = field.finish + 1
+ fieldCount = fieldCount + 1
+ tbl[fieldCount] = field
+ end
+ end
+ for i = fieldCount + 1, #tbl do
+ tbl[i] = nil
+ end
+ return tbl
+ end,
+ NewField = function (start, field, value, finish)
+ local obj = {
+ type = 'tablefield',
+ start = start,
+ finish = finish-1,
+ field = field,
+ value = value,
+ }
+ if field then
+ field.type = 'field'
+ field.parent = obj
+ end
+ return obj
+ end,
+ NewIndex = function (start, index, value, finish)
+ local obj = {
+ type = 'tableindex',
+ start = start,
+ finish = finish-1,
+ index = index,
+ value = value,
+ }
+ if index then
+ index.parent = obj
+ end
+ return obj
+ end,
+ FuncArgs = function (start, args, finish)
+ args.type = 'funcargs'
+ args.start = start
+ args.finish = finish - 1
+ local lastStart = start + 1
+ local wantName = true
+ local argCount = 0
+ for i = 1, #args do
+ local arg = args[i]
+ local argAst = arg
+ if argAst.type == ',' then
+ if wantName then
+ PushError {
+ type = 'MISS_NAME',
+ start = lastStart,
+ finish = argAst.start-1,
+ }
+ end
+ wantName = true
+ else
+ if not wantName then
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = lastStart-1,
+ finish = argAst.start-1,
+ info = {
+ symbol = ',',
+ }
+ }
+ end
+ wantName = false
+ argCount = argCount + 1
+
+ if argAst.type == '...' then
+ args[argCount] = arg
+ if i < #args then
+ local a = args[i+1]
+ local b = args[#args]
+ PushError {
+ type = 'ARGS_AFTER_DOTS',
+ start = a.start,
+ finish = b.finish,
+ }
+ end
+ break
+ else
+ args[argCount] = createLocal(arg, arg.start)
+ end
+ end
+ lastStart = argAst.finish + 1
+ end
+ for i = argCount + 1, #args do
+ args[i] = nil
+ end
+ if wantName and argCount > 0 then
+ PushError {
+ type = 'MISS_NAME',
+ start = lastStart,
+ finish = finish - 1,
+ }
+ end
+ return args
+ end,
+ Set = function (start, keys, values, finish)
+ for i = 1, #keys do
+ local key = keys[i]
+ if key.type == 'getname' then
+ key.type = 'setname'
+ key.value = getValue(values, i)
+ elseif key.type == 'getfield' then
+ key.type = 'setfield'
+ key.value = getValue(values, i)
+ elseif key.type == 'getindex' then
+ key.type = 'setindex'
+ key.value = getValue(values, i)
+ end
+ if key.value then
+ key.range = key.value.finish
+ end
+ end
+ if values then
+ for i = #keys+1, #values do
+ local value = values[i]
+ PushDiag('redundant-value', {
+ start = value.start,
+ finish = value.finish,
+ max = #keys,
+ passed = #values,
+ })
+ end
+ end
+ return tableUnpack(keys)
+ end,
+ LocalAttr = function (attrs)
+ if #attrs == 0 then
+ return nil
+ end
+ for i = 1, #attrs do
+ local attr = attrs[i]
+ local attrAst = attr
+ attrAst.type = 'localattr'
+ if State.version ~= 'Lua 5.4' then
+ PushError {
+ type = 'UNSUPPORT_SYMBOL',
+ start = attrAst.start,
+ finish = attrAst.finish,
+ version = 'Lua 5.4',
+ info = {
+ version = State.version,
+ }
+ }
+ elseif attrAst[1] ~= 'const' and attrAst[1] ~= 'close' then
+ PushError {
+ type = 'UNKNOWN_TAG',
+ start = attrAst.start,
+ finish = attrAst.finish,
+ info = {
+ tag = attrAst[1],
+ }
+ }
+ elseif i > 1 then
+ PushError {
+ type = 'MULTI_TAG',
+ start = attrAst.start,
+ finish = attrAst.finish,
+ info = {
+ tag = attrAst[1],
+ }
+ }
+ end
+ end
+ attrs.start = attrs[1].start
+ attrs.finish = attrs[#attrs].finish
+ return attrs
+ end,
+ LocalName = function (name, attrs)
+ if not name then
+ return name
+ end
+ name.attrs = attrs
+ return name
+ end,
+ Local = function (start, keys, values, finish)
+ for i = 1, #keys do
+ local key = keys[i]
+ local attrs = key.attrs
+ key.attrs = nil
+ local value = getValue(values, i)
+ createLocal(key, finish, value, attrs)
+ end
+ if values then
+ for i = #keys+1, #values do
+ local value = values[i]
+ PushDiag('redundant-value', {
+ start = value.start,
+ finish = value.finish,
+ max = #keys,
+ passed = #values,
+ })
+ end
+ end
+ return tableUnpack(keys)
+ end,
+ Do = function (start, actions, endA, endB)
+ actions.type = 'do'
+ actions.start = start
+ actions.finish = endB - 1
+ actions.keyword= {
+ start, start + #'do' - 1,
+ endA , endB - 1,
+ }
+ checkMissEnd(start)
+ return actions
+ end,
+ Break = function (start, finish)
+ return {
+ type = 'break',
+ start = start,
+ finish = finish - 1,
+ }
+ end,
+ Return = function (start, exps, finish)
+ exps.type = 'return'
+ exps.start = start
+ exps.finish = finish - 1
+ return exps
+ end,
+ Label = function (start, name, finish)
+ if State.version == 'Lua 5.1' then
+ PushError {
+ type = 'UNSUPPORT_SYMBOL',
+ start = start,
+ finish = finish - 1,
+ version = {'Lua 5.2', 'Lua 5.3', 'Lua 5.4', 'LuaJIT'},
+ info = {
+ version = State.version,
+ }
+ }
+ return
+ end
+ if not name then
+ return nil
+ end
+ name.type = 'label'
+ return name
+ end,
+ GoTo = function (start, name, finish)
+ if State.version == 'Lua 5.1' then
+ PushError {
+ type = 'UNSUPPORT_SYMBOL',
+ start = start,
+ finish = finish - 1,
+ version = {'Lua 5.2', 'Lua 5.3', 'Lua 5.4', 'LuaJIT'},
+ info = {
+ version = State.version,
+ }
+ }
+ return
+ end
+ if not name then
+ return nil
+ end
+ name.type = 'goto'
+ return name
+ end,
+ IfBlock = function (ifStart, ifFinish, exp, thenStart, thenFinish, actions, finish)
+ actions.type = 'ifblock'
+ actions.start = ifStart
+ actions.finish = finish - 1
+ actions.filter = exp
+ actions.keyword= {
+ ifStart, ifFinish - 1,
+ thenStart, thenFinish - 1,
+ }
+ return actions
+ end,
+ ElseIfBlock = function (elseifStart, elseifFinish, exp, thenStart, thenFinish, actions, finish)
+ actions.type = 'elseifblock'
+ actions.start = elseifStart
+ actions.finish = finish - 1
+ actions.filter = exp
+ actions.keyword= {
+ elseifStart, elseifFinish - 1,
+ thenStart, thenFinish - 1,
+ }
+ return actions
+ end,
+ ElseBlock = function (elseStart, elseFinish, actions, finish)
+ actions.type = 'elseblock'
+ actions.start = elseStart
+ actions.finish = finish - 1
+ actions.keyword= {
+ elseStart, elseFinish - 1,
+ }
+ return actions
+ end,
+ If = function (start, blocks, endStart, endFinish)
+ blocks.type = 'if'
+ blocks.start = start
+ blocks.finish = endFinish - 1
+ local hasElse
+ for i = 1, #blocks do
+ local block = blocks[i]
+ if i == 1 and block.type ~= 'ifblock' then
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = block.start,
+ finish = block.start,
+ info = {
+ symbol = 'if',
+ }
+ }
+ end
+ if hasElse then
+ PushError {
+ type = 'BLOCK_AFTER_ELSE',
+ start = block.start,
+ finish = block.finish,
+ }
+ end
+ if block.type == 'elseblock' then
+ hasElse = true
+ end
+ end
+ checkMissEnd(start)
+ return blocks
+ end,
+ Loop = function (forA, forB, arg, steps, doA, doB, blockStart, block, endA, endB)
+ local loc = createLocal(arg, blockStart, steps[1])
+ block.type = 'loop'
+ block.start = forA
+ block.finish = endB - 1
+ block.loc = loc
+ block.max = steps[2]
+ block.step = steps[3]
+ block.keyword= {
+ forA, forB - 1,
+ doA , doB - 1,
+ endA, endB - 1,
+ }
+ checkMissEnd(forA)
+ return block
+ end,
+ In = function (forA, forB, keys, inA, inB, exp, doA, doB, blockStart, block, endA, endB)
+ local func = tableRemove(exp, 1)
+ block.type = 'in'
+ block.start = forA
+ block.finish = endB - 1
+ block.keys = keys
+ block.keyword= {
+ forA, forB - 1,
+ inA , inB - 1,
+ doA , doB - 1,
+ endA, endB - 1,
+ }
+
+ local values
+ if func then
+ local call = createCall(exp, func.finish + 1, exp.finish)
+ call.node = func
+ call.start = func.start
+ func.next = call
+ values = { call }
+ keys.range = call.finish
+ end
+ for i = 1, #keys do
+ local loc = keys[i]
+ if values then
+ createLocal(loc, blockStart, getValue(values, i))
+ else
+ createLocal(loc, blockStart)
+ end
+ end
+ checkMissEnd(forA)
+ return block
+ end,
+ While = function (whileA, whileB, filter, doA, doB, block, endA, endB)
+ block.type = 'while'
+ block.start = whileA
+ block.finish = endB - 1
+ block.filter = filter
+ block.keyword= {
+ whileA, whileB - 1,
+ doA , doB - 1,
+ endA , endB - 1,
+ }
+ checkMissEnd(whileA)
+ return block
+ end,
+ Repeat = function (repeatA, repeatB, block, untilA, untilB, filter, finish)
+ block.type = 'repeat'
+ block.start = repeatA
+ block.finish = finish
+ block.filter = filter
+ block.keyword= {
+ repeatA, repeatB - 1,
+ untilA , untilB - 1,
+ }
+ return block
+ end,
+ Lua = function (start, actions, finish)
+ actions.type = 'main'
+ actions.start = start
+ actions.finish = finish - 1
+ return actions
+ end,
+
+ -- 捕获错误
+ UnknownSymbol = function (start, symbol)
+ PushError {
+ type = 'UNKNOWN_SYMBOL',
+ start = start,
+ finish = start + #symbol - 1,
+ info = {
+ symbol = symbol,
+ }
+ }
+ return
+ end,
+ UnknownAction = function (start, symbol)
+ PushError {
+ type = 'UNKNOWN_SYMBOL',
+ start = start,
+ finish = start + #symbol - 1,
+ info = {
+ symbol = symbol,
+ }
+ }
+ end,
+ DirtyName = function (pos)
+ PushError {
+ type = 'MISS_NAME',
+ start = pos,
+ finish = pos,
+ }
+ return nil
+ end,
+ DirtyExp = function (pos)
+ PushError {
+ type = 'MISS_EXP',
+ start = pos,
+ finish = pos,
+ }
+ return nil
+ end,
+ MissExp = function (pos)
+ PushError {
+ type = 'MISS_EXP',
+ start = pos,
+ finish = pos,
+ }
+ end,
+ MissExponent = function (start, finish)
+ PushError {
+ type = 'MISS_EXPONENT',
+ start = start,
+ finish = finish - 1,
+ }
+ end,
+ MissQuote1 = function (pos)
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = pos,
+ finish = pos,
+ info = {
+ symbol = '"'
+ }
+ }
+ end,
+ MissQuote2 = function (pos)
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = pos,
+ finish = pos,
+ info = {
+ symbol = "'"
+ }
+ }
+ end,
+ MissEscX = function (pos)
+ PushError {
+ type = 'MISS_ESC_X',
+ start = pos-2,
+ finish = pos+1,
+ }
+ end,
+ MissTL = function (pos)
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = pos,
+ finish = pos,
+ info = {
+ symbol = '{',
+ }
+ }
+ end,
+ MissTR = function (pos)
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = pos,
+ finish = pos,
+ info = {
+ symbol = '}',
+ }
+ }
+ end,
+ MissBR = function (pos)
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = pos,
+ finish = pos,
+ info = {
+ symbol = ']',
+ }
+ }
+ end,
+ MissPL = function (pos)
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = pos,
+ finish = pos,
+ info = {
+ symbol = '(',
+ }
+ }
+ end,
+ MissPR = function (pos)
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = pos,
+ finish = pos,
+ info = {
+ symbol = ')',
+ }
+ }
+ end,
+ ErrEsc = function (pos)
+ PushError {
+ type = 'ERR_ESC',
+ start = pos-1,
+ finish = pos,
+ }
+ end,
+ MustX16 = function (pos, str)
+ PushError {
+ type = 'MUST_X16',
+ start = pos,
+ finish = pos + #str - 1,
+ }
+ end,
+ MissAssign = function (pos)
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = pos,
+ finish = pos,
+ info = {
+ symbol = '=',
+ }
+ }
+ end,
+ MissTableSep = function (pos)
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = pos,
+ finish = pos,
+ info = {
+ symbol = ','
+ }
+ }
+ end,
+ MissField = function (pos)
+ PushError {
+ type = 'MISS_FIELD',
+ start = pos,
+ finish = pos,
+ }
+ end,
+ MissMethod = function (pos)
+ PushError {
+ type = 'MISS_METHOD',
+ start = pos,
+ finish = pos,
+ }
+ end,
+ MissLabel = function (pos)
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = pos,
+ finish = pos,
+ info = {
+ symbol = '::',
+ }
+ }
+ end,
+ MissEnd = function (pos)
+ State.MissEndErr = PushError {
+ type = 'MISS_SYMBOL',
+ start = pos,
+ finish = pos,
+ info = {
+ symbol = 'end',
+ }
+ }
+ return pos, pos
+ end,
+ MissDo = function (pos)
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = pos,
+ finish = pos,
+ info = {
+ symbol = 'do',
+ }
+ }
+ return pos, pos
+ end,
+ MissComma = function (pos)
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = pos,
+ finish = pos,
+ info = {
+ symbol = ',',
+ }
+ }
+ end,
+ MissIn = function (pos)
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = pos,
+ finish = pos,
+ info = {
+ symbol = 'in',
+ }
+ }
+ return pos, pos
+ end,
+ MissUntil = function (pos)
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = pos,
+ finish = pos,
+ info = {
+ symbol = 'until',
+ }
+ }
+ return pos, pos
+ end,
+ MissThen = function (pos)
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = pos,
+ finish = pos,
+ info = {
+ symbol = 'then',
+ }
+ }
+ return pos, pos
+ end,
+ MissName = function (pos)
+ PushError {
+ type = 'MISS_NAME',
+ start = pos,
+ finish = pos,
+ }
+ end,
+ ExpInAction = function (start, exp, finish)
+ PushError {
+ type = 'EXP_IN_ACTION',
+ start = start,
+ finish = finish - 1,
+ }
+ -- 当exp为nil时,不能返回任何值,否则会产生带洞的actionlist
+ if exp then
+ return exp
+ else
+ return
+ end
+ end,
+ MissIf = function (start, block)
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = start,
+ finish = start,
+ info = {
+ symbol = 'if',
+ }
+ }
+ return block
+ end,
+ MissGT = function (start)
+ PushError {
+ type = 'MISS_SYMBOL',
+ start = start,
+ finish = start,
+ info = {
+ symbol = '>'
+ }
+ }
+ end,
+ ErrAssign = function (start, finish)
+ PushError {
+ type = 'ERR_ASSIGN_AS_EQ',
+ start = start,
+ finish = finish - 1,
+ fix = {
+ title = 'FIX_ASSIGN_AS_EQ',
+ {
+ start = start,
+ finish = finish - 1,
+ text = '=',
+ }
+ }
+ }
+ end,
+ ErrEQ = function (start, finish)
+ PushError {
+ type = 'ERR_EQ_AS_ASSIGN',
+ start = start,
+ finish = finish - 1,
+ fix = {
+ title = 'FIX_EQ_AS_ASSIGN',
+ {
+ start = start,
+ finish = finish - 1,
+ text = '==',
+ }
+ }
+ }
+ return '=='
+ end,
+ ErrUEQ = function (start, finish)
+ PushError {
+ type = 'ERR_UEQ',
+ start = start,
+ finish = finish - 1,
+ fix = {
+ title = 'FIX_UEQ',
+ {
+ start = start,
+ finish = finish - 1,
+ text = '~=',
+ }
+ }
+ }
+ return '=='
+ end,
+ ErrThen = function (start, finish)
+ PushError {
+ type = 'ERR_THEN_AS_DO',
+ start = start,
+ finish = finish - 1,
+ fix = {
+ title = 'FIX_THEN_AS_DO',
+ {
+ start = start,
+ finish = finish - 1,
+ text = 'then',
+ }
+ }
+ }
+ return start, finish
+ end,
+ ErrDo = function (start, finish)
+ PushError {
+ type = 'ERR_DO_AS_THEN',
+ start = start,
+ finish = finish - 1,
+ fix = {
+ title = 'FIX_DO_AS_THEN',
+ {
+ start = start,
+ finish = finish - 1,
+ text = 'do',
+ }
+ }
+ }
+ return start, finish
+ end,
+}
+
+local function init(state)
+ State = state
+ PushError = state.pushError
+ PushDiag = state.pushDiag
+ PushComment = state.pushComment
+end
+
+local function close()
+ State = nil
+ PushError = function (...) end
+ PushDiag = function (...) end
+ PushComment = function (...) end
+end
+
+return {
+ defs = Defs,
+ init = init,
+ close = close,
+}
diff --git a/script/parser/calcline.lua b/script/parser/calcline.lua
new file mode 100644
index 00000000..2e944167
--- /dev/null
+++ b/script/parser/calcline.lua
@@ -0,0 +1,94 @@
+local m = require 'lpeglabel'
+local util = require 'utility'
+
+local row
+local fl
+local NL = (m.P'\r\n' + m.S'\r\n') * m.Cp() / function (pos)
+ row = row + 1
+ fl = pos
+end
+local ROWCOL = (NL + m.P(1))^0
+local function rowcol(str, n)
+ row = 1
+ fl = 1
+ ROWCOL:match(str:sub(1, n))
+ local col = n - fl + 1
+ return row, col
+end
+
+local function rowcol_utf8(str, n)
+ row = 1
+ fl = 1
+ ROWCOL:match(str:sub(1, n))
+ return row, util.utf8Len(str, fl, n)
+end
+
+local function position(str, _row, _col)
+ local cur = 1
+ local row = 1
+ while true do
+ if row == _row then
+ return cur + _col - 1
+ elseif row > _row then
+ return cur - 1
+ end
+ local pos = str:find('[\r\n]', cur)
+ if not pos then
+ return #str
+ end
+ row = row + 1
+ if str:sub(pos, pos+1) == '\r\n' then
+ cur = pos + 2
+ else
+ cur = pos + 1
+ end
+ end
+end
+
+local function position_utf8(str, _row, _col)
+ local cur = 1
+ local row = 1
+ while true do
+ if row == _row then
+ return utf8.offset(str, _col, cur)
+ elseif row > _row then
+ return cur - 1
+ end
+ local pos = str:find('[\r\n]', cur)
+ if not pos then
+ return #str
+ end
+ row = row + 1
+ if str:sub(pos, pos+1) == '\r\n' then
+ cur = pos + 2
+ else
+ cur = pos + 1
+ end
+ end
+end
+
+local NL = m.P'\r\n' + m.S'\r\n'
+
+local function line(str, row)
+ local count = 0
+ local res
+ local LINE = m.Cmt((1 - NL)^0, function (_, _, c)
+ count = count + 1
+ if count == row then
+ res = c
+ return false
+ end
+ return true
+ end)
+ local MATCH = (LINE * NL)^0 * LINE
+ MATCH:match(str)
+ return res
+end
+
+return {
+ rowcol = rowcol,
+ rowcol_utf8 = rowcol_utf8,
+ position = position,
+ position_utf8 = position_utf8,
+ line = line,
+}
diff --git a/script/parser/compile.lua b/script/parser/compile.lua
new file mode 100644
index 00000000..2c7172e8
--- /dev/null
+++ b/script/parser/compile.lua
@@ -0,0 +1,561 @@
+local guide = require 'parser.guide'
+local type = type
+
+local specials = {
+ ['_G'] = true,
+ ['rawset'] = true,
+ ['rawget'] = true,
+ ['setmetatable'] = true,
+ ['require'] = true,
+ ['dofile'] = true,
+ ['loadfile'] = true,
+ ['pcall'] = true,
+ ['xpcall'] = true,
+}
+
+_ENV = nil
+
+local LocalLimit = 200
+local pushError, Compile, CompileBlock, Block, GoToTag, ENVMode, Compiled, LocalCount, Version, Root, Options
+
+local function addRef(node, obj)
+ if not node.ref then
+ node.ref = {}
+ end
+ node.ref[#node.ref+1] = obj
+ obj.node = node
+end
+
+local function addSpecial(name, obj)
+ if not Root.specials then
+ Root.specials = {}
+ end
+ if not Root.specials[name] then
+ Root.specials[name] = {}
+ end
+ Root.specials[name][#Root.specials[name]+1] = obj
+ obj.special = name
+end
+
+local vmMap = {
+ ['getname'] = function (obj)
+ local loc = guide.getLocal(obj, obj[1], obj.start)
+ if loc then
+ obj.type = 'getlocal'
+ obj.loc = loc
+ addRef(loc, obj)
+ if loc.special then
+ addSpecial(loc.special, obj)
+ end
+ else
+ obj.type = 'getglobal'
+ local node = guide.getLocal(obj, ENVMode, obj.start)
+ if node then
+ addRef(node, obj)
+ end
+ local name = obj[1]
+ if specials[name] then
+ addSpecial(name, obj)
+ elseif Options and Options.special then
+ local asName = Options.special[name]
+ if specials[asName] then
+ addSpecial(asName, obj)
+ end
+ end
+ end
+ return obj
+ end,
+ ['getfield'] = function (obj)
+ Compile(obj.node, obj)
+ end,
+ ['call'] = function (obj)
+ Compile(obj.node, obj)
+ Compile(obj.args, obj)
+ end,
+ ['callargs'] = function (obj)
+ for i = 1, #obj do
+ Compile(obj[i], obj)
+ end
+ end,
+ ['binary'] = function (obj)
+ Compile(obj[1], obj)
+ Compile(obj[2], obj)
+ end,
+ ['unary'] = function (obj)
+ Compile(obj[1], obj)
+ end,
+ ['varargs'] = function (obj)
+ local func = guide.getParentFunction(obj)
+ if func then
+ local index, vararg = guide.getFunctionVarArgs(func)
+ if not index then
+ pushError {
+ type = 'UNEXPECT_DOTS',
+ start = obj.start,
+ finish = obj.finish,
+ }
+ end
+ if vararg then
+ if not vararg.ref then
+ vararg.ref = {}
+ end
+ vararg.ref[#vararg.ref+1] = obj
+ end
+ end
+ end,
+ ['paren'] = function (obj)
+ Compile(obj.exp, obj)
+ end,
+ ['getindex'] = function (obj)
+ Compile(obj.node, obj)
+ Compile(obj.index, obj)
+ end,
+ ['setindex'] = function (obj)
+ Compile(obj.node, obj)
+ Compile(obj.index, obj)
+ Compile(obj.value, obj)
+ end,
+ ['getmethod'] = function (obj)
+ Compile(obj.node, obj)
+ Compile(obj.method, obj)
+ end,
+ ['setmethod'] = function (obj)
+ Compile(obj.node, obj)
+ Compile(obj.method, obj)
+ local value = obj.value
+ value.localself = {
+ type = 'local',
+ start = 0,
+ finish = 0,
+ method = obj,
+ effect = obj.finish,
+ tag = 'self',
+ [1] = 'self',
+ }
+ Compile(value, obj)
+ end,
+ ['function'] = function (obj)
+ local lastBlock = Block
+ local LastLocalCount = LocalCount
+ Block = obj
+ LocalCount = 0
+ if obj.localself then
+ Compile(obj.localself, obj)
+ obj.localself = nil
+ end
+ Compile(obj.args, obj)
+ for i = 1, #obj do
+ Compile(obj[i], obj)
+ end
+ Block = lastBlock
+ LocalCount = LastLocalCount
+ end,
+ ['funcargs'] = function (obj)
+ for i = 1, #obj do
+ Compile(obj[i], obj)
+ end
+ end,
+ ['table'] = function (obj)
+ for i = 1, #obj do
+ Compile(obj[i], obj)
+ end
+ end,
+ ['tablefield'] = function (obj)
+ Compile(obj.value, obj)
+ end,
+ ['tableindex'] = function (obj)
+ Compile(obj.index, obj)
+ Compile(obj.value, obj)
+ end,
+ ['index'] = function (obj)
+ Compile(obj.index, obj)
+ end,
+ ['select'] = function (obj)
+ local vararg = obj.vararg
+ if vararg.parent then
+ if not vararg.extParent then
+ vararg.extParent = {}
+ end
+ vararg.extParent[#vararg.extParent+1] = obj
+ else
+ Compile(vararg, obj)
+ end
+ end,
+ ['setname'] = function (obj)
+ Compile(obj.value, obj)
+ local loc = guide.getLocal(obj, obj[1], obj.start)
+ if loc then
+ obj.type = 'setlocal'
+ obj.loc = loc
+ addRef(loc, obj)
+ if loc.attrs then
+ local const
+ for i = 1, #loc.attrs do
+ local attr = loc.attrs[i][1]
+ if attr == 'const'
+ or attr == 'close' then
+ const = true
+ break
+ end
+ end
+ if const then
+ pushError {
+ type = 'SET_CONST',
+ start = obj.start,
+ finish = obj.finish,
+ }
+ end
+ end
+ else
+ obj.type = 'setglobal'
+ local node = guide.getLocal(obj, ENVMode, obj.start)
+ if node then
+ addRef(node, obj)
+ end
+ local name = obj[1]
+ if specials[name] then
+ addSpecial(name, obj)
+ elseif Options and Options.special then
+ local asName = Options.special[name]
+ if specials[asName] then
+ addSpecial(asName, obj)
+ end
+ end
+ end
+ end,
+ ['local'] = function (obj)
+ local attrs = obj.attrs
+ if attrs then
+ for i = 1, #attrs do
+ Compile(attrs[i], obj)
+ end
+ end
+ if Block then
+ if not Block.locals then
+ Block.locals = {}
+ end
+ Block.locals[#Block.locals+1] = obj
+ LocalCount = LocalCount + 1
+ if LocalCount > LocalLimit then
+ pushError {
+ type = 'LOCAL_LIMIT',
+ start = obj.start,
+ finish = obj.finish,
+ }
+ end
+ end
+ if obj.localfunction then
+ obj.localfunction = nil
+ end
+ Compile(obj.value, obj)
+ if obj.value and obj.value.special then
+ addSpecial(obj.value.special, obj)
+ end
+ end,
+ ['setfield'] = function (obj)
+ Compile(obj.node, obj)
+ Compile(obj.value, obj)
+ end,
+ ['do'] = function (obj)
+ local lastBlock = Block
+ Block = obj
+ CompileBlock(obj, obj)
+ if Block.locals then
+ LocalCount = LocalCount - #Block.locals
+ end
+ Block = lastBlock
+ end,
+ ['return'] = function (obj)
+ for i = 1, #obj do
+ Compile(obj[i], obj)
+ end
+ if Block and Block[#Block] ~= obj then
+ pushError {
+ type = 'ACTION_AFTER_RETURN',
+ start = obj.start,
+ finish = obj.finish,
+ }
+ end
+ local func = guide.getParentFunction(obj)
+ if func then
+ if not func.returns then
+ func.returns = {}
+ end
+ func.returns[#func.returns+1] = obj
+ end
+ end,
+ ['label'] = function (obj)
+ local block = guide.getBlock(obj)
+ if block then
+ if not block.labels then
+ block.labels = {}
+ end
+ local name = obj[1]
+ local label = guide.getLabel(block, name)
+ if label then
+ if Version == 'Lua 5.4'
+ or block == guide.getBlock(label) then
+ pushError {
+ type = 'REDEFINED_LABEL',
+ start = obj.start,
+ finish = obj.finish,
+ relative = {
+ {
+ label.start,
+ label.finish,
+ }
+ }
+ }
+ end
+ end
+ block.labels[name] = obj
+ end
+ end,
+ ['goto'] = function (obj)
+ GoToTag[#GoToTag+1] = obj
+ end,
+ ['if'] = function (obj)
+ for i = 1, #obj do
+ Compile(obj[i], obj)
+ end
+ end,
+ ['ifblock'] = function (obj)
+ local lastBlock = Block
+ Block = obj
+ Compile(obj.filter, obj)
+ CompileBlock(obj, obj)
+ if Block.locals then
+ LocalCount = LocalCount - #Block.locals
+ end
+ Block = lastBlock
+ end,
+ ['elseifblock'] = function (obj)
+ local lastBlock = Block
+ Block = obj
+ Compile(obj.filter, obj)
+ CompileBlock(obj, obj)
+ if Block.locals then
+ LocalCount = LocalCount - #Block.locals
+ end
+ Block = lastBlock
+ end,
+ ['elseblock'] = function (obj)
+ local lastBlock = Block
+ Block = obj
+ CompileBlock(obj, obj)
+ if Block.locals then
+ LocalCount = LocalCount - #Block.locals
+ end
+ Block = lastBlock
+ end,
+ ['loop'] = function (obj)
+ local lastBlock = Block
+ Block = obj
+ Compile(obj.loc, obj)
+ Compile(obj.max, obj)
+ Compile(obj.step, obj)
+ CompileBlock(obj, obj)
+ if Block.locals then
+ LocalCount = LocalCount - #Block.locals
+ end
+ Block = lastBlock
+ end,
+ ['in'] = function (obj)
+ local lastBlock = Block
+ Block = obj
+ local keys = obj.keys
+ for i = 1, #keys do
+ Compile(keys[i], obj)
+ end
+ CompileBlock(obj, obj)
+ if Block.locals then
+ LocalCount = LocalCount - #Block.locals
+ end
+ Block = lastBlock
+ end,
+ ['while'] = function (obj)
+ local lastBlock = Block
+ Block = obj
+ Compile(obj.filter, obj)
+ CompileBlock(obj, obj)
+ if Block.locals then
+ LocalCount = LocalCount - #Block.locals
+ end
+ Block = lastBlock
+ end,
+ ['repeat'] = function (obj)
+ local lastBlock = Block
+ Block = obj
+ CompileBlock(obj, obj)
+ Compile(obj.filter, obj)
+ if Block.locals then
+ LocalCount = LocalCount - #Block.locals
+ end
+ Block = lastBlock
+ end,
+ ['break'] = function (obj)
+ local block = guide.getBreakBlock(obj)
+ if block then
+ if not block.breaks then
+ block.breaks = {}
+ end
+ block.breaks[#block.breaks+1] = obj
+ else
+ pushError {
+ type = 'BREAK_OUTSIDE',
+ start = obj.start,
+ finish = obj.finish,
+ }
+ end
+ end,
+ ['main'] = function (obj)
+ Block = obj
+ Compile({
+ type = 'local',
+ start = 0,
+ finish = 0,
+ effect = 0,
+ tag = '_ENV',
+ special= '_G',
+ [1] = ENVMode,
+ }, obj)
+ --- _ENV 是上值,不计入局部变量计数
+ LocalCount = 0
+ CompileBlock(obj, obj)
+ Block = nil
+ end,
+}
+
+function CompileBlock(obj, parent)
+ for i = 1, #obj do
+ local act = obj[i]
+ local f = vmMap[act.type]
+ if f then
+ act.parent = parent
+ f(act)
+ end
+ end
+end
+
+function Compile(obj, parent)
+ if not obj then
+ return nil
+ end
+ if Compiled[obj] then
+ return
+ end
+ Compiled[obj] = true
+ obj.parent = parent
+ local f = vmMap[obj.type]
+ if not f then
+ return
+ end
+ f(obj)
+end
+
+local function compileGoTo(obj)
+ local name = obj[1]
+ local label = guide.getLabel(obj, name)
+ if not label then
+ pushError {
+ type = 'NO_VISIBLE_LABEL',
+ start = obj.start,
+ finish = obj.finish,
+ info = {
+ label = name,
+ }
+ }
+ return
+ end
+ if not label.ref then
+ label.ref = {}
+ end
+ label.ref[#label.ref+1] = obj
+ obj.node = label
+
+ -- 如果有局部变量在 goto 与 label 之间声明,
+ -- 并在 label 之后使用,则算作语法错误
+
+ -- 如果 label 在 goto 之前声明,那么不会有中间声明的局部变量
+ if obj.start > label.start then
+ return
+ end
+
+ local block = guide.getBlock(obj)
+ local locals = block and block.locals
+ if not locals then
+ return
+ end
+
+ for i = 1, #locals do
+ local loc = locals[i]
+ -- 检查局部变量声明位置为 goto 与 label 之间
+ if loc.start < obj.start or loc.finish > label.finish then
+ goto CONTINUE
+ end
+ -- 检查局部变量的使用位置在 label 之后
+ local refs = loc.ref
+ if not refs then
+ goto CONTINUE
+ end
+ for j = 1, #refs do
+ local ref = refs[j]
+ if ref.finish > label.finish then
+ pushError {
+ type = 'JUMP_LOCAL_SCOPE',
+ start = obj.start,
+ finish = obj.finish,
+ info = {
+ loc = loc[1],
+ },
+ relative = {
+ {
+ start = label.start,
+ finish = label.finish,
+ },
+ {
+ start = loc.start,
+ finish = loc.finish,
+ }
+ },
+ }
+ return
+ end
+ end
+ ::CONTINUE::
+ end
+end
+
+local function PostCompile()
+ for i = 1, #GoToTag do
+ compileGoTo(GoToTag[i])
+ end
+end
+
+return function (self, lua, mode, version, options)
+ local state, err = self:parse(lua, mode, version)
+ if not state then
+ return nil, err
+ end
+ pushError = state.pushError
+ if version == 'Lua 5.1' or version == 'LuaJIT' then
+ ENVMode = '@fenv'
+ else
+ ENVMode = '_ENV'
+ end
+ Compiled = {}
+ GoToTag = {}
+ LocalCount = 0
+ Version = version
+ Root = state.ast
+ Root.state = state
+ Options = options
+ state.ENVMode = ENVMode
+ if type(state.ast) == 'table' then
+ Compile(state.ast)
+ end
+ PostCompile()
+ Compiled = nil
+ GoToTag = nil
+ return state
+end
diff --git a/script/parser/grammar.lua b/script/parser/grammar.lua
new file mode 100644
index 00000000..06dae246
--- /dev/null
+++ b/script/parser/grammar.lua
@@ -0,0 +1,538 @@
+local re = require 'parser.relabel'
+local m = require 'lpeglabel'
+local ast = require 'parser.ast'
+
+local scriptBuf = ''
+local compiled = {}
+local defs = ast.defs
+
+-- goto 可以作为名字,合法性之后处理
+local RESERVED = {
+ ['and'] = true,
+ ['break'] = true,
+ ['do'] = true,
+ ['else'] = true,
+ ['elseif'] = true,
+ ['end'] = true,
+ ['false'] = true,
+ ['for'] = true,
+ ['function'] = true,
+ ['if'] = true,
+ ['in'] = true,
+ ['local'] = true,
+ ['nil'] = true,
+ ['not'] = true,
+ ['or'] = true,
+ ['repeat'] = true,
+ ['return'] = true,
+ ['then'] = true,
+ ['true'] = true,
+ ['until'] = true,
+ ['while'] = true,
+}
+
+defs.nl = (m.P'\r\n' + m.S'\r\n')
+defs.s = m.S' \t'
+defs.S = - defs.s
+defs.ea = '\a'
+defs.eb = '\b'
+defs.ef = '\f'
+defs.en = '\n'
+defs.er = '\r'
+defs.et = '\t'
+defs.ev = '\v'
+defs['nil'] = m.Cp() / function () return nil end
+defs['false'] = m.Cp() / function () return false end
+defs.NotReserved = function (_, _, str)
+ if RESERVED[str] then
+ return false
+ end
+ return true
+end
+defs.Reserved = function (_, _, str)
+ if RESERVED[str] then
+ return true
+ end
+ return false
+end
+defs.None = function () end
+defs.np = m.Cp() / function (n) return n+1 end
+
+m.setmaxstack(1000)
+
+local eof = re.compile '!. / %{SYNTAX_ERROR}'
+
+local function grammar(tag)
+ return function (script)
+ scriptBuf = script .. '\r\n' .. scriptBuf
+ compiled[tag] = re.compile(scriptBuf, defs) * eof
+ end
+end
+
+local function errorpos(pos, err)
+ return {
+ type = 'UNKNOWN',
+ start = pos or 0,
+ finish = pos or 0,
+ err = err,
+ }
+end
+
+grammar 'Comment' [[
+Comment <- LongComment
+ / '--' ShortComment
+LongComment <- ('--[' {} {:eq: '='* :} {} '['
+ {(!CommentClose .)*}
+ (CommentClose / {}))
+ -> LongComment
+ / (
+ {} '/*' {}
+ (!'*/' .)*
+ {} '*/' {}
+ )
+ -> CLongComment
+CommentClose <- ']' =eq ']'
+ShortComment <- ({} {(!%nl .)*} {})
+ -> ShortComment
+]]
+
+grammar 'Sp' [[
+Sp <- (Comment / %nl / %s)*
+Sps <- (Comment / %nl / %s)+
+]]
+
+grammar 'Common' [[
+Word <- [a-zA-Z0-9_]
+Cut <- !Word
+X16 <- [a-fA-F0-9]
+Rest <- (!%nl .)*
+
+AND <- Sp {'and'} Cut
+BREAK <- Sp 'break' Cut
+FALSE <- Sp 'false' Cut
+GOTO <- Sp 'goto' Cut
+LOCAL <- Sp 'local' Cut
+NIL <- Sp 'nil' Cut
+NOT <- Sp 'not' Cut
+OR <- Sp {'or'} Cut
+RETURN <- Sp 'return' Cut
+TRUE <- Sp 'true' Cut
+
+DO <- Sp {} 'do' {} Cut
+ / Sp({} 'then' {} Cut) -> ErrDo
+IF <- Sp {} 'if' {} Cut
+ELSE <- Sp {} 'else' {} Cut
+ELSEIF <- Sp {} 'elseif' {} Cut
+END <- Sp {} 'end' {} Cut
+FOR <- Sp {} 'for' {} Cut
+FUNCTION <- Sp {} 'function' {} Cut
+IN <- Sp {} 'in' {} Cut
+REPEAT <- Sp {} 'repeat' {} Cut
+THEN <- Sp {} 'then' {} Cut
+ / Sp({} 'do' {} Cut) -> ErrThen
+UNTIL <- Sp {} 'until' {} Cut
+WHILE <- Sp {} 'while' {} Cut
+
+
+Esc <- '\' -> ''
+ EChar
+EChar <- 'a' -> ea
+ / 'b' -> eb
+ / 'f' -> ef
+ / 'n' -> en
+ / 'r' -> er
+ / 't' -> et
+ / 'v' -> ev
+ / '\'
+ / '"'
+ / "'"
+ / %nl
+ / ('z' (%nl / %s)*) -> ''
+ / ({} 'x' {X16 X16}) -> Char16
+ / ([0-9] [0-9]? [0-9]?) -> Char10
+ / ('u{' {} {Word*} '}') -> CharUtf8
+ -- 错误处理
+ / 'x' {} -> MissEscX
+ / 'u' !'{' {} -> MissTL
+ / 'u{' Word* !'}' {} -> MissTR
+ / {} -> ErrEsc
+
+BOR <- Sp {'|'}
+BXOR <- Sp {'~'} !'='
+BAND <- Sp {'&'}
+Bshift <- Sp {BshiftList}
+BshiftList <- '<<'
+ / '>>'
+Concat <- Sp {'..'}
+Adds <- Sp {AddsList}
+AddsList <- '+'
+ / '-'
+Muls <- Sp {MulsList}
+MulsList <- '*'
+ / '//'
+ / '/'
+ / '%'
+Unary <- Sp {} {UnaryList}
+UnaryList <- NOT
+ / '#'
+ / '-'
+ / '~' !'='
+POWER <- Sp {'^'}
+
+BinaryOp <-( Sp {} {'or'} Cut
+ / Sp {} {'and'} Cut
+ / Sp {} {'<=' / '>=' / '<'!'<' / '>'!'>' / '~=' / '=='}
+ / Sp {} ({} '=' {}) -> ErrEQ
+ / Sp {} ({} '!=' {}) -> ErrUEQ
+ / Sp {} {'|'}
+ / Sp {} {'~'}
+ / Sp {} {'&'}
+ / Sp {} {'<<' / '>>'}
+ / Sp {} {'..'} !'.'
+ / Sp {} {'+' / '-'}
+ / Sp {} {'*' / '//' / '/' / '%'}
+ / Sp {} {'^'}
+ )-> BinaryOp
+UnaryOp <-( Sp {} {'not' Cut / '#' / '~' !'=' / '-' !'-'}
+ )-> UnaryOp
+
+PL <- Sp '('
+PR <- Sp ')'
+BL <- Sp '[' !'[' !'='
+BR <- Sp ']'
+TL <- Sp '{'
+TR <- Sp '}'
+COMMA <- Sp ({} ',')
+ -> COMMA
+SEMICOLON <- Sp ({} ';')
+ -> SEMICOLON
+DOTS <- Sp ({} '...')
+ -> DOTS
+DOT <- Sp ({} '.' !'.')
+ -> DOT
+COLON <- Sp ({} ':' !':')
+ -> COLON
+LABEL <- Sp '::'
+ASSIGN <- Sp '=' !'='
+AssignOrEQ <- Sp ({} '==' {})
+ -> ErrAssign
+ / Sp '='
+
+DirtyBR <- BR / {} -> MissBR
+DirtyTR <- TR / {} -> MissTR
+DirtyPR <- PR / {} -> MissPR
+DirtyLabel <- LABEL / {} -> MissLabel
+NeedEnd <- END / {} -> MissEnd
+NeedDo <- DO / {} -> MissDo
+NeedAssign <- ASSIGN / {} -> MissAssign
+NeedComma <- COMMA / {} -> MissComma
+NeedIn <- IN / {} -> MissIn
+NeedUntil <- UNTIL / {} -> MissUntil
+NeedThen <- THEN / {} -> MissThen
+]]
+
+grammar 'Nil' [[
+Nil <- Sp ({} -> Nil) NIL
+]]
+
+grammar 'Boolean' [[
+Boolean <- Sp ({} -> True) TRUE
+ / Sp ({} -> False) FALSE
+]]
+
+grammar 'String' [[
+String <- Sp ({} StringDef {})
+ -> String
+StringDef <- {'"'}
+ {~(Esc / !%nl !'"' .)*~} -> 1
+ ('"' / {} -> MissQuote1)
+ / {"'"}
+ {~(Esc / !%nl !"'" .)*~} -> 1
+ ("'" / {} -> MissQuote2)
+ / ('[' {} {:eq: '='* :} {} '[' %nl?
+ {(!StringClose .)*} -> 1
+ (StringClose / {}))
+ -> LongString
+StringClose <- ']' =eq ']'
+]]
+
+grammar 'Number' [[
+Number <- Sp ({} {NumberDef} {}) -> Number
+ NumberSuffix?
+ ErrNumber?
+NumberDef <- Number16 / Number10
+NumberSuffix<- ({} {[uU]? [lL] [lL]}) -> FFINumber
+ / ({} {[iI]}) -> ImaginaryNumber
+ErrNumber <- ({} {([0-9a-zA-Z] / '.')+}) -> UnknownSymbol
+
+Number10 <- Float10 Float10Exp?
+ / Integer10 Float10? Float10Exp?
+Integer10 <- [0-9]+ ('.' [0-9]*)?
+Float10 <- '.' [0-9]+
+Float10Exp <- [eE] [+-]? [0-9]+
+ / ({} [eE] [+-]? {}) -> MissExponent
+
+Number16 <- '0' [xX] Float16 Float16Exp?
+ / '0' [xX] Integer16 Float16? Float16Exp?
+Integer16 <- X16+ ('.' X16*)?
+ / ({} {Word*}) -> MustX16
+Float16 <- '.' X16+
+ / '.' ({} {Word*}) -> MustX16
+Float16Exp <- [pP] [+-]? [0-9]+
+ / ({} [pP] [+-]? {}) -> MissExponent
+]]
+
+grammar 'Name' [[
+Name <- Sp ({} NameBody {})
+ -> Name
+NameBody <- {[a-zA-Z_] [a-zA-Z0-9_]*}
+FreeName <- Sp ({} {NameBody=>NotReserved} {})
+ -> Name
+KeyWord <- Sp NameBody=>Reserved
+MustName <- Name / DirtyName
+DirtyName <- {} -> DirtyName
+]]
+
+grammar 'Exp' [[
+Exp <- (UnUnit BinUnit*)
+ -> Binary
+BinUnit <- (BinaryOp UnUnit?)
+ -> SubBinary
+UnUnit <- ExpUnit
+ / (UnaryOp+ (ExpUnit / MissExp))
+ -> Unary
+ExpUnit <- Nil
+ / Boolean
+ / String
+ / Number
+ / Dots
+ / Table
+ / Function
+ / Simple
+
+Simple <- {| Prefix (Sp Suffix)* |}
+ -> Simple
+Prefix <- Sp ({} PL DirtyExp DirtyPR {})
+ -> Paren
+ / Single
+Single <- FreeName
+ -> Single
+Suffix <- SuffixWithoutCall
+ / ({} PL SuffixCall DirtyPR {})
+ -> Call
+SuffixCall <- Sp ({} {| (COMMA / Exp)+ |} {})
+ -> PackExpList
+ / %nil
+SuffixWithoutCall
+ <- (DOT (Name / MissField))
+ -> GetField
+ / ({} BL DirtyExp DirtyBR {})
+ -> GetIndex
+ / (COLON (Name / MissMethod) NeedCall)
+ -> GetMethod
+ / ({} {| Table |} {})
+ -> Call
+ / ({} {| String |} {})
+ -> Call
+NeedCall <- (!(Sp CallStart) {} -> MissPL)?
+MissField <- {} -> MissField
+MissMethod <- {} -> MissMethod
+CallStart <- PL
+ / TL
+ / '"'
+ / "'"
+ / '[' '='* '['
+
+DirtyExp <- Exp
+ / {} -> DirtyExp
+MaybeExp <- Exp / MissExp
+MissExp <- {} -> MissExp
+ExpList <- Sp {| MaybeExp (Sp ',' MaybeExp)* |}
+
+Dots <- DOTS
+ -> VarArgs
+
+Table <- Sp ({} TL {| TableField* |} DirtyTR {})
+ -> Table
+TableField <- COMMA
+ / SEMICOLON
+ / NewIndex
+ / NewField
+ / Exp
+Index <- BL DirtyExp DirtyBR
+NewIndex <- Sp ({} Index NeedAssign DirtyExp {})
+ -> NewIndex
+NewField <- Sp ({} MustName ASSIGN DirtyExp {})
+ -> NewField
+
+Function <- FunctionBody
+ -> Function
+FuncArgs <- Sp ({} PL {| FuncArg+ |} DirtyPR {})
+ -> FuncArgs
+ / PL DirtyPR %nil
+FuncArgsMiss<- {} -> MissPL DirtyPR %nil
+FuncArg <- DOTS
+ / Name
+ / COMMA
+FunctionBody<- FUNCTION FuncArgs
+ {| (!END Action)* |}
+ NeedEnd
+ / FUNCTION FuncArgsMiss
+ {| %nil |}
+ NeedEnd
+
+-- 纯占位,修改了 `relabel.lua` 使重复定义不抛错
+Action <- !END .
+]]
+
+grammar 'Action' [[
+Action <- Sp (CrtAction / UnkAction)
+CrtAction <- Semicolon
+ / Do
+ / Break
+ / Return
+ / Label
+ / GoTo
+ / If
+ / For
+ / While
+ / Repeat
+ / NamedFunction
+ / LocalFunction
+ / Local
+ / Set
+ / Call
+ / ExpInAction
+UnkAction <- ({} {Word+})
+ -> UnknownAction
+ / ({} '//' {} (LongComment / ShortComment))
+ -> CCommentPrefix
+ / ({} {. (!Sps !CrtAction .)*})
+ -> UnknownAction
+ExpInAction <- Sp ({} Exp {})
+ -> ExpInAction
+
+Semicolon <- Sp ';'
+SimpleList <- {| Simple (Sp ',' Simple)* |}
+
+Do <- Sp ({}
+ 'do' Cut
+ {| (!END Action)* |}
+ NeedEnd)
+ -> Do
+
+Break <- Sp ({} BREAK {})
+ -> Break
+
+Return <- Sp ({} RETURN ReturnExpList {})
+ -> Return
+ReturnExpList
+ <- Sp {| Exp (Sp ',' MaybeExp)* |}
+ / Sp {| !Exp !',' |}
+ / ExpList
+
+Label <- Sp ({} LABEL MustName DirtyLabel {})
+ -> Label
+
+GoTo <- Sp ({} GOTO MustName {})
+ -> GoTo
+
+If <- Sp ({} {| IfHead IfBody* |} NeedEnd)
+ -> If
+
+IfHead <- Sp (IfPart {}) -> IfBlock
+ / Sp (ElseIfPart {}) -> ElseIfBlock
+ / Sp (ElsePart {}) -> ElseBlock
+IfBody <- Sp (ElseIfPart {}) -> ElseIfBlock
+ / Sp (ElsePart {}) -> ElseBlock
+IfPart <- IF DirtyExp NeedThen
+ {| (!ELSEIF !ELSE !END Action)* |}
+ElseIfPart <- ELSEIF DirtyExp NeedThen
+ {| (!ELSEIF !ELSE !END Action)* |}
+ElsePart <- ELSE
+ {| (!ELSEIF !ELSE !END Action)* |}
+
+For <- Loop / In
+
+Loop <- LoopBody
+ -> Loop
+LoopBody <- FOR LoopArgs NeedDo
+ {} {| (!END Action)* |}
+ NeedEnd
+LoopArgs <- MustName AssignOrEQ
+ ({} {| (COMMA / !DO !END Exp)* |} {})
+ -> PackLoopArgs
+
+In <- InBody
+ -> In
+InBody <- FOR InNameList NeedIn InExpList NeedDo
+ {} {| (!END Action)* |}
+ NeedEnd
+InNameList <- ({} {| (COMMA / !IN !DO !END Name)* |} {})
+ -> PackInNameList
+InExpList <- ({} {| (COMMA / !DO !DO !END Exp)* |} {})
+ -> PackInExpList
+
+While <- WhileBody
+ -> While
+WhileBody <- WHILE DirtyExp NeedDo
+ {| (!END Action)* |}
+ NeedEnd
+
+Repeat <- (RepeatBody {})
+ -> Repeat
+RepeatBody <- REPEAT
+ {| (!UNTIL Action)* |}
+ NeedUntil DirtyExp
+
+LocalAttr <- {| (Sp '<' Sp MustName Sp LocalAttrEnd)+ |}
+ -> LocalAttr
+LocalAttrEnd<- '>' / {} -> MissGT
+Local <- Sp ({} LOCAL LocalNameList ((AssignOrEQ ExpList) / %nil) {})
+ -> Local
+Set <- Sp ({} SimpleList AssignOrEQ ExpList {})
+ -> Set
+LocalNameList
+ <- {| LocalName (Sp ',' LocalName)* |}
+LocalName <- (MustName LocalAttr?)
+ -> LocalName
+
+Call <- Simple
+ -> SimpleCall
+
+LocalFunction
+ <- Sp ({} LOCAL FunctionNamedBody)
+ -> LocalFunction
+
+NamedFunction
+ <- FunctionNamedBody
+ -> NamedFunction
+FunctionNamedBody
+ <- FUNCTION FuncName FuncArgs
+ {| (!END Action)* |}
+ NeedEnd
+ / FUNCTION FuncName FuncArgsMiss
+ {| %nil |}
+ NeedEnd
+FuncName <- {| Single (Sp SuffixWithoutCall)* |}
+ -> Simple
+ / {} -> MissName %nil
+]]
+
+grammar 'Lua' [[
+Lua <- Head?
+ ({} {| Action* |} {}) -> Lua
+ Sp
+Head <- '#' (!%nl .)*
+]]
+
+return function (self, lua, mode)
+ local gram = compiled[mode] or compiled['Lua']
+ local r, _, pos = gram:match(lua)
+ if not r then
+ local err = errorpos(pos)
+ return nil, err
+ end
+
+ return r
+end
diff --git a/script/parser/guide.lua b/script/parser/guide.lua
new file mode 100644
index 00000000..6ef239f1
--- /dev/null
+++ b/script/parser/guide.lua
@@ -0,0 +1,3884 @@
+local util = require 'utility'
+local error = error
+local type = type
+local next = next
+local tostring = tostring
+local print = print
+local ipairs = ipairs
+local tableInsert = table.insert
+local tableUnpack = table.unpack
+local tableRemove = table.remove
+local tableMove = table.move
+local tableSort = table.sort
+local tableConcat = table.concat
+local mathType = math.type
+local pairs = pairs
+local setmetatable = setmetatable
+local assert = assert
+local select = select
+local osClock = os.clock
+local DEVELOP = _G.DEVELOP
+local log = log
+local _G = _G
+
+local function logWarn(...)
+ log.warn(...)
+end
+
+_ENV = nil
+
+local m = {}
+
+local blockTypes = {
+ ['while'] = true,
+ ['in'] = true,
+ ['loop'] = true,
+ ['repeat'] = true,
+ ['do'] = true,
+ ['function'] = true,
+ ['ifblock'] = true,
+ ['elseblock'] = true,
+ ['elseifblock'] = true,
+ ['main'] = true,
+}
+
+local breakBlockTypes = {
+ ['while'] = true,
+ ['in'] = true,
+ ['loop'] = true,
+ ['repeat'] = true,
+}
+
+m.childMap = {
+ ['main'] = {'#', 'docs'},
+ ['repeat'] = {'#', 'filter'},
+ ['while'] = {'filter', '#'},
+ ['in'] = {'keys', '#'},
+ ['loop'] = {'loc', 'max', 'step', '#'},
+ ['if'] = {'#'},
+ ['ifblock'] = {'filter', '#'},
+ ['elseifblock'] = {'filter', '#'},
+ ['elseblock'] = {'#'},
+ ['setfield'] = {'node', 'field', 'value'},
+ ['setglobal'] = {'value'},
+ ['local'] = {'attrs', 'value'},
+ ['setlocal'] = {'value'},
+ ['return'] = {'#'},
+ ['do'] = {'#'},
+ ['select'] = {'vararg'},
+ ['table'] = {'#'},
+ ['tableindex'] = {'index', 'value'},
+ ['tablefield'] = {'field', 'value'},
+ ['function'] = {'args', '#'},
+ ['funcargs'] = {'#'},
+ ['setmethod'] = {'node', 'method', 'value'},
+ ['getmethod'] = {'node', 'method'},
+ ['setindex'] = {'node', 'index', 'value'},
+ ['getindex'] = {'node', 'index'},
+ ['paren'] = {'exp'},
+ ['call'] = {'node', 'args'},
+ ['callargs'] = {'#'},
+ ['getfield'] = {'node', 'field'},
+ ['list'] = {'#'},
+ ['binary'] = {1, 2},
+ ['unary'] = {1},
+
+ ['doc'] = {'#'},
+ ['doc.class'] = {'class', 'extends'},
+ ['doc.type'] = {'#types', '#enums', 'name'},
+ ['doc.alias'] = {'alias', 'extends'},
+ ['doc.param'] = {'param', 'extends'},
+ ['doc.return'] = {'#returns'},
+ ['doc.field'] = {'field', 'extends'},
+ ['doc.generic'] = {'#generics'},
+ ['doc.generic.object'] = {'generic', 'extends'},
+ ['doc.vararg'] = {'vararg'},
+ ['doc.type.table'] = {'key', 'value'},
+ ['doc.type.function'] = {'#args', '#returns'},
+ ['doc.overload'] = {'overload'},
+}
+
+m.actionMap = {
+ ['main'] = {'#'},
+ ['repeat'] = {'#'},
+ ['while'] = {'#'},
+ ['in'] = {'#'},
+ ['loop'] = {'#'},
+ ['if'] = {'#'},
+ ['ifblock'] = {'#'},
+ ['elseifblock'] = {'#'},
+ ['elseblock'] = {'#'},
+ ['do'] = {'#'},
+ ['function'] = {'#'},
+ ['funcargs'] = {'#'},
+}
+
+local TypeSort = {
+ ['boolean'] = 1,
+ ['string'] = 2,
+ ['integer'] = 3,
+ ['number'] = 4,
+ ['table'] = 5,
+ ['function'] = 6,
+ ['nil'] = 999,
+}
+
+local NIL = setmetatable({'<nil>'}, { __tostring = function () return 'nil' end })
+
+--- 是否是字面量
+function m.isLiteral(obj)
+ local tp = obj.type
+ return tp == 'nil'
+ or tp == 'boolean'
+ or tp == 'string'
+ or tp == 'number'
+ or tp == 'table'
+end
+
+--- 获取字面量
+function m.getLiteral(obj)
+ local tp = obj.type
+ if tp == 'boolean' then
+ return obj[1]
+ elseif tp == 'string' then
+ return obj[1]
+ elseif tp == 'number' then
+ return obj[1]
+ end
+ return nil
+end
+
+--- 寻找父函数
+function m.getParentFunction(obj)
+ for _ = 1, 1000 do
+ obj = obj.parent
+ if not obj then
+ break
+ end
+ local tp = obj.type
+ if tp == 'function' or tp == 'main' then
+ return obj
+ end
+ end
+ return nil
+end
+
+--- 寻找所在区块
+function m.getBlock(obj)
+ for _ = 1, 1000 do
+ if not obj then
+ return nil
+ end
+ local tp = obj.type
+ if blockTypes[tp] then
+ return obj
+ end
+ obj = obj.parent
+ end
+ error('guide.getBlock overstack')
+end
+
+--- 寻找所在父区块
+function m.getParentBlock(obj)
+ for _ = 1, 1000 do
+ obj = obj.parent
+ if not obj then
+ return nil
+ end
+ local tp = obj.type
+ if blockTypes[tp] then
+ return obj
+ end
+ end
+ error('guide.getParentBlock overstack')
+end
+
+--- 寻找所在可break的父区块
+function m.getBreakBlock(obj)
+ for _ = 1, 1000 do
+ obj = obj.parent
+ if not obj then
+ return nil
+ end
+ local tp = obj.type
+ if breakBlockTypes[tp] then
+ return obj
+ end
+ if tp == 'function' then
+ return nil
+ end
+ end
+ error('guide.getBreakBlock overstack')
+end
+
+--- 寻找doc的主体
+function m.getDocState(obj)
+ for _ = 1, 1000 do
+ local parent = obj.parent
+ if not parent then
+ return obj
+ end
+ if parent.type == 'doc' then
+ return obj
+ end
+ obj = parent
+ end
+ error('guide.getDocState overstack')
+end
+
+--- 寻找所在父类型
+function m.getParentType(obj, want)
+ for _ = 1, 1000 do
+ obj = obj.parent
+ if not obj then
+ return nil
+ end
+ if want == obj.type then
+ return obj
+ end
+ end
+ error('guide.getParentType overstack')
+end
+
+--- 寻找根区块
+function m.getRoot(obj)
+ for _ = 1, 1000 do
+ if obj.type == 'main' then
+ return obj
+ end
+ local parent = obj.parent
+ if not parent then
+ return nil
+ end
+ obj = parent
+ end
+ error('guide.getRoot overstack')
+end
+
+function m.getUri(obj)
+ if obj.uri then
+ return obj.uri
+ end
+ local root = m.getRoot(obj)
+ if root then
+ return root.uri
+ end
+ return ''
+end
+
+function m.getENV(source, start)
+ if not start then
+ start = 1
+ end
+ return m.getLocal(source, '_ENV', start)
+ or m.getLocal(source, '@fenv', start)
+end
+
+--- 寻找函数的不定参数,返回不定参在第几个参数上,以及该参数对象。
+--- 如果函数是主函数,则返回`0, nil`。
+---@return table
+---@return integer
+function m.getFunctionVarArgs(func)
+ if func.type == 'main' then
+ return 0, nil
+ end
+ if func.type ~= 'function' then
+ return nil, nil
+ end
+ local args = func.args
+ if not args then
+ return nil, nil
+ end
+ for i = 1, #args do
+ local arg = args[i]
+ if arg.type == '...' then
+ return i, arg
+ end
+ end
+ return nil, nil
+end
+
+--- 获取指定区块中可见的局部变量
+---@param block table
+---@param name string {comment = '变量名'}
+---@param pos integer {comment = '可见位置'}
+function m.getLocal(block, name, pos)
+ block = m.getBlock(block)
+ for _ = 1, 1000 do
+ if not block then
+ return nil
+ end
+ local locals = block.locals
+ local res
+ if not locals then
+ goto CONTINUE
+ end
+ for i = 1, #locals do
+ local loc = locals[i]
+ if loc.effect > pos then
+ break
+ end
+ if loc[1] == name then
+ if not res or res.effect < loc.effect then
+ res = loc
+ end
+ end
+ end
+ if res then
+ return res, res
+ end
+ ::CONTINUE::
+ block = m.getParentBlock(block)
+ end
+ error('guide.getLocal overstack')
+end
+
+--- 获取指定区块中所有的可见局部变量名称
+function m.getVisibleLocals(block, pos)
+ local result = {}
+ m.eachSourceContain(m.getRoot(block), pos, function (source)
+ local locals = source.locals
+ if locals then
+ for i = 1, #locals do
+ local loc = locals[i]
+ local name = loc[1]
+ if loc.effect <= pos then
+ result[name] = loc
+ end
+ end
+ end
+ end)
+ return result
+end
+
+--- 获取指定区块中可见的标签
+---@param block table
+---@param name string {comment = '标签名'}
+function m.getLabel(block, name)
+ block = m.getBlock(block)
+ for _ = 1, 1000 do
+ if not block then
+ return nil
+ end
+ local labels = block.labels
+ if labels then
+ local label = labels[name]
+ if label then
+ return label
+ end
+ end
+ if block.type == 'function' then
+ return nil
+ end
+ block = m.getParentBlock(block)
+ end
+ error('guide.getLocal overstack')
+end
+
+function m.getStartFinish(source)
+ local start = source.start
+ local finish = source.finish
+ if not start then
+ local first = source[1]
+ if not first then
+ return nil, nil
+ end
+ local last = source[#source]
+ start = first.start
+ finish = last.finish
+ end
+ return start, finish
+end
+
+function m.getRange(source)
+ local start = source.vstart or source.start
+ local finish = source.range or source.finish
+ if not start then
+ local first = source[1]
+ if not first then
+ return nil, nil
+ end
+ local last = source[#source]
+ start = first.vstart or first.start
+ finish = last.range or last.finish
+ end
+ return start, finish
+end
+
+--- 判断source是否包含offset
+function m.isContain(source, offset)
+ local start, finish = m.getStartFinish(source)
+ if not start then
+ return false
+ end
+ return start <= offset and finish >= offset - 1
+end
+
+--- 判断offset在source的影响范围内
+---
+--- 主要针对赋值等语句时,key包含value
+function m.isInRange(source, offset)
+ local start, finish = m.getRange(source)
+ if not start then
+ return false
+ end
+ return start <= offset and finish >= offset - 1
+end
+
+function m.isBetween(source, tStart, tFinish)
+ local start, finish = m.getStartFinish(source)
+ if not start then
+ return false
+ end
+ return start <= tFinish and finish >= tStart - 1
+end
+
+function m.isBetweenRange(source, tStart, tFinish)
+ local start, finish = m.getRange(source)
+ if not start then
+ return false
+ end
+ return start <= tFinish and finish >= tStart - 1
+end
+
+--- 添加child
+function m.addChilds(list, obj, map)
+ local keys = map[obj.type]
+ if keys then
+ for i = 1, #keys do
+ local key = keys[i]
+ if key == '#' then
+ for i = 1, #obj do
+ list[#list+1] = obj[i]
+ end
+ elseif obj[key] then
+ list[#list+1] = obj[key]
+ elseif type(key) == 'string'
+ and key:sub(1, 1) == '#' then
+ key = key:sub(2)
+ for i = 1, #obj[key] do
+ list[#list+1] = obj[key][i]
+ end
+ end
+ end
+ end
+end
+
+--- 遍历所有包含offset的source
+function m.eachSourceContain(ast, offset, callback)
+ local list = { ast }
+ local mark = {}
+ while true do
+ local len = #list
+ if len == 0 then
+ return
+ end
+ local obj = list[len]
+ list[len] = nil
+ if not mark[obj] then
+ mark[obj] = true
+ if m.isInRange(obj, offset) then
+ if m.isContain(obj, offset) then
+ local res = callback(obj)
+ if res ~= nil then
+ return res
+ end
+ end
+ m.addChilds(list, obj, m.childMap)
+ end
+ end
+ end
+end
+
+--- 遍历所有在某个范围内的source
+function m.eachSourceBetween(ast, start, finish, callback)
+ local list = { ast }
+ local mark = {}
+ while true do
+ local len = #list
+ if len == 0 then
+ return
+ end
+ local obj = list[len]
+ list[len] = nil
+ if not mark[obj] then
+ mark[obj] = true
+ if m.isBetweenRange(obj, start, finish) then
+ if m.isBetween(obj, start, finish) then
+ local res = callback(obj)
+ if res ~= nil then
+ return res
+ end
+ end
+ m.addChilds(list, obj, m.childMap)
+ end
+ end
+ end
+end
+
+--- 遍历所有指定类型的source
+function m.eachSourceType(ast, type, callback)
+ local cache = ast.typeCache
+ if not cache then
+ cache = {}
+ ast.typeCache = cache
+ m.eachSource(ast, function (source)
+ local tp = source.type
+ if not tp then
+ return
+ end
+ local myCache = cache[tp]
+ if not myCache then
+ myCache = {}
+ cache[tp] = myCache
+ end
+ myCache[#myCache+1] = source
+ end)
+ end
+ local myCache = cache[type]
+ if not myCache then
+ return
+ end
+ for i = 1, #myCache do
+ callback(myCache[i])
+ end
+end
+
+--- 遍历所有的source
+function m.eachSource(ast, callback)
+ local list = { ast }
+ local mark = {}
+ local index = 1
+ while true do
+ local obj = list[index]
+ if not obj then
+ return
+ end
+ list[index] = false
+ index = index + 1
+ if not mark[obj] then
+ mark[obj] = true
+ callback(obj)
+ m.addChilds(list, obj, m.childMap)
+ end
+ end
+end
+
+--- 获取指定的 special
+function m.eachSpecialOf(ast, name, callback)
+ local root = m.getRoot(ast)
+ if not root.specials then
+ return
+ end
+ local specials = root.specials[name]
+ if not specials then
+ return
+ end
+ for i = 1, #specials do
+ callback(specials[i])
+ end
+end
+
+--- 获取偏移对应的坐标
+---@param lines table
+---@return integer {name = 'row'}
+---@return integer {name = 'col'}
+function m.positionOf(lines, offset)
+ if offset < 1 then
+ return 0, 0
+ end
+ local lastLine = lines[#lines]
+ if offset > lastLine.finish then
+ return #lines, lastLine.finish - lastLine.start + 1
+ end
+ local min = 1
+ local max = #lines
+ for _ = 1, 100 do
+ if max <= min then
+ local line = lines[min]
+ return min, offset - line.start + 1
+ end
+ local row = (max - min) // 2 + min
+ local line = lines[row]
+ if offset < line.start then
+ max = row - 1
+ elseif offset > line.finish then
+ min = row + 1
+ else
+ return row, offset - line.start + 1
+ end
+ end
+ error('Stack overflow!')
+end
+
+--- 获取坐标对应的偏移
+---@param lines table
+---@param row integer
+---@param col integer
+---@return integer {name = 'offset'}
+function m.offsetOf(lines, row, col)
+ if row < 1 then
+ return 0
+ end
+ if row > #lines then
+ local lastLine = lines[#lines]
+ return lastLine.finish
+ end
+ local line = lines[row]
+ local len = line.finish - line.start + 1
+ if col < 0 then
+ return line.start
+ elseif col > len then
+ return line.finish
+ else
+ return line.start + col - 1
+ end
+end
+
+function m.lineContent(lines, text, row, ignoreNL)
+ local line = lines[row]
+ if not line then
+ return ''
+ end
+ if ignoreNL then
+ return text:sub(line.start, line.range)
+ else
+ return text:sub(line.start, line.finish)
+ end
+end
+
+function m.lineRange(lines, row, ignoreNL)
+ local line = lines[row]
+ if not line then
+ return 0, 0
+ end
+ if ignoreNL then
+ return line.start, line.range
+ else
+ return line.start, line.finish
+ end
+end
+
+function m.getNameOfLiteral(obj)
+ if not obj then
+ return nil
+ end
+ local tp = obj.type
+ if tp == 'string' then
+ return obj[1]
+ end
+ return nil
+end
+
+function m.getName(obj)
+ local tp = obj.type
+ if tp == 'getglobal'
+ or tp == 'setglobal' then
+ return obj[1]
+ elseif tp == 'local'
+ or tp == 'getlocal'
+ or tp == 'setlocal' then
+ return obj[1]
+ elseif tp == 'getfield'
+ or tp == 'setfield'
+ or tp == 'tablefield' then
+ return obj.field and obj.field[1]
+ elseif tp == 'getmethod'
+ or tp == 'setmethod' then
+ return obj.method and obj.method[1]
+ elseif tp == 'getindex'
+ or tp == 'setindex'
+ or tp == 'tableindex' then
+ return m.getNameOfLiteral(obj.index)
+ elseif tp == 'field'
+ or tp == 'method' then
+ return obj[1]
+ elseif tp == 'doc.class' then
+ return obj.class[1]
+ elseif tp == 'doc.alias' then
+ return obj.alias[1]
+ elseif tp == 'doc.field' then
+ return obj.field[1]
+ end
+ return m.getNameOfLiteral(obj)
+end
+
+function m.getKeyNameOfLiteral(obj)
+ if not obj then
+ return nil
+ end
+ local tp = obj.type
+ if tp == 'field'
+ or tp == 'method' then
+ return 's|' .. obj[1]
+ elseif tp == 'string' then
+ local s = obj[1]
+ if s then
+ return 's|' .. s
+ else
+ return 's'
+ end
+ elseif tp == 'number' then
+ local n = obj[1]
+ if n then
+ return ('n|%s'):format(util.viewLiteral(obj[1]))
+ else
+ return 'n'
+ end
+ elseif tp == 'boolean' then
+ local b = obj[1]
+ if b then
+ return 'b|' .. tostring(b)
+ else
+ return 'b'
+ end
+ end
+ return nil
+end
+
+function m.getKeyName(obj)
+ if not obj then
+ return nil
+ end
+ local tp = obj.type
+ if tp == 'getglobal'
+ or tp == 'setglobal' then
+ return 's|' .. obj[1]
+ elseif tp == 'local'
+ or tp == 'getlocal'
+ or tp == 'setlocal' then
+ return 'l|' .. obj[1]
+ elseif tp == 'getfield'
+ or tp == 'setfield'
+ or tp == 'tablefield' then
+ if obj.field then
+ return 's|' .. obj.field[1]
+ end
+ elseif tp == 'getmethod'
+ or tp == 'setmethod' then
+ if obj.method then
+ return 's|' .. obj.method[1]
+ end
+ elseif tp == 'getindex'
+ or tp == 'setindex'
+ or tp == 'tableindex' then
+ return m.getKeyNameOfLiteral(obj.index)
+ elseif tp == 'field'
+ or tp == 'method' then
+ return 's|' .. obj[1]
+ elseif tp == 'doc.class' then
+ return 's|' .. obj.class[1]
+ elseif tp == 'doc.alias' then
+ return 's|' .. obj.alias[1]
+ elseif tp == 'doc.field' then
+ return 's|' .. obj.field[1]
+ end
+ return m.getKeyNameOfLiteral(obj)
+end
+
+function m.getSimpleName(obj)
+ if obj.type == 'call' then
+ local key = obj.args and obj.args[2]
+ return m.getKeyName(key)
+ elseif obj.type == 'table' then
+ return ('t|%p'):format(obj)
+ elseif obj.type == 'select' then
+ return ('v|%p'):format(obj)
+ elseif obj.type == 'string' then
+ return ('z|%p'):format(obj)
+ elseif obj.type == 'doc.class.name'
+ or obj.type == 'doc.type.name' then
+ return ('c|%s'):format(obj[1])
+ elseif obj.type == 'doc.class' then
+ return ('c|%s'):format(obj.class[1])
+ end
+ return m.getKeyName(obj)
+end
+
+--- 测试 a 到 b 的路径(不经过函数,不考虑 goto),
+--- 每个路径是一个 block 。
+---
+--- 如果 a 在 b 的前面,返回 `"before"` 加上 2个`list<block>`
+---
+--- 如果 a 在 b 的后面,返回 `"after"` 加上 2个`list<block>`
+---
+--- 否则返回 `false`
+---
+--- 返回的2个 `list` 分别为基准block到达 a 与 b 的路径。
+---@param a table
+---@param b table
+---@return string|boolean mode
+---@return table pathA?
+---@return table pathB?
+function m.getPath(a, b, sameFunction)
+ --- 首先测试双方在同一个函数内
+ if sameFunction and m.getParentFunction(a) ~= m.getParentFunction(b) then
+ return false
+ end
+ local mode
+ local objA
+ local objB
+ if a.finish < b.start then
+ mode = 'before'
+ objA = a
+ objB = b
+ elseif a.start > b.finish then
+ mode = 'after'
+ objA = b
+ objB = a
+ else
+ return 'equal', {}, {}
+ end
+ local pathA = {}
+ local pathB = {}
+ for _ = 1, 1000 do
+ objA = m.getParentBlock(objA)
+ pathA[#pathA+1] = objA
+ if (not sameFunction and objA.type == 'function') or objA.type == 'main' then
+ break
+ end
+ end
+ for _ = 1, 1000 do
+ objB = m.getParentBlock(objB)
+ pathB[#pathB+1] = objB
+ if (not sameFunction and objA.type == 'function') or objB.type == 'main' then
+ break
+ end
+ end
+ -- pathA: {1, 2, 3, 4, 5}
+ -- pathB: {5, 6, 2, 3}
+ local top = #pathB
+ local start
+ for i = #pathA, 1, -1 do
+ local currentBlock = pathA[i]
+ if currentBlock == pathB[top] then
+ start = i
+ break
+ end
+ end
+ if not start then
+ return nil
+ end
+ -- pathA: { 1, 2, 3}
+ -- pathB: {5, 6, 2, 3}
+ local extra = 0
+ local align = top - start
+ for i = start, 1, -1 do
+ local currentA = pathA[i]
+ local currentB = pathB[i+align]
+ if currentA ~= currentB then
+ extra = i
+ break
+ end
+ end
+ -- pathA: {1}
+ local resultA = {}
+ for i = extra, 1, -1 do
+ resultA[#resultA+1] = pathA[i]
+ end
+ -- pathB: {5, 6}
+ local resultB = {}
+ for i = extra + align, 1, -1 do
+ resultB[#resultB+1] = pathB[i]
+ end
+ return mode, resultA, resultB
+end
+
+-- 根据语法,单步搜索定义
+local function stepRefOfLocal(loc, mode)
+ local results = {}
+ if loc.start ~= 0 then
+ results[#results+1] = loc
+ end
+ local refs = loc.ref
+ if not refs then
+ return results
+ end
+ for i = 1, #refs do
+ local ref = refs[i]
+ if ref.start == 0 then
+ goto CONTINUE
+ end
+ if mode == 'def' then
+ if ref.type == 'local'
+ or ref.type == 'setlocal' then
+ results[#results+1] = ref
+ end
+ else
+ if ref.type == 'local'
+ or ref.type == 'setlocal'
+ or ref.type == 'getlocal' then
+ results[#results+1] = ref
+ end
+ end
+ ::CONTINUE::
+ end
+ return results
+end
+
+local function stepRefOfLabel(label, mode)
+ local results = { label }
+ if not label or mode == 'def' then
+ return results
+ end
+ local refs = label.ref
+ for i = 1, #refs do
+ local ref = refs[i]
+ results[#results+1] = ref
+ end
+ return results
+end
+
+local function stepRefOfDocType(status, obj, mode)
+ local results = {}
+ if obj.type == 'doc.class.name'
+ or obj.type == 'doc.type.name'
+ or obj.type == 'doc.alias.name'
+ or obj.type == 'doc.extends.name' then
+ local name = obj[1]
+ if not name or not status.interface.docType then
+ return results
+ end
+ local docs = status.interface.docType(name)
+ for i = 1, #docs do
+ local doc = docs[i]
+ if mode == 'def' then
+ if doc.type == 'doc.class.name'
+ or doc.type == 'doc.alias.name' then
+ results[#results+1] = doc
+ end
+ else
+ results[#results+1] = doc
+ end
+ end
+ else
+ results[#results+1] = obj
+ end
+ return results
+end
+
+function m.getStepRef(status, obj, mode)
+ if obj.type == 'getlocal'
+ or obj.type == 'setlocal' then
+ return stepRefOfLocal(obj.node, mode)
+ end
+ if obj.type == 'local' then
+ return stepRefOfLocal(obj, mode)
+ end
+ if obj.type == 'label' then
+ return stepRefOfLabel(obj, mode)
+ end
+ if obj.type == 'goto' then
+ return stepRefOfLabel(obj.node, mode)
+ end
+ if obj.type == 'doc.class.name'
+ or obj.type == 'doc.type.name'
+ or obj.type == 'doc.extends.name'
+ or obj.type == 'doc.alias.name' then
+ return stepRefOfDocType(status, obj, mode)
+ end
+ return nil
+end
+
+-- 根据语法,单步搜索field
+local function stepFieldOfLocal(loc)
+ local results = {}
+ local refs = loc.ref
+ for i = 1, #refs do
+ local ref = refs[i]
+ if ref.type == 'setglobal'
+ or ref.type == 'getglobal' then
+ results[#results+1] = ref
+ elseif ref.type == 'getlocal' then
+ local nxt = ref.next
+ if nxt then
+ if nxt.type == 'setfield'
+ or nxt.type == 'getfield'
+ or nxt.type == 'setmethod'
+ or nxt.type == 'getmethod'
+ or nxt.type == 'setindex'
+ or nxt.type == 'getindex' then
+ results[#results+1] = nxt
+ end
+ end
+ end
+ end
+ return results
+end
+local function stepFieldOfTable(tbl)
+ local result = {}
+ for i = 1, #tbl do
+ result[i] = tbl[i]
+ end
+ return result
+end
+function m.getStepField(obj)
+ if obj.type == 'getlocal'
+ or obj.type == 'setlocal' then
+ return stepFieldOfLocal(obj.node)
+ end
+ if obj.type == 'local' then
+ return stepFieldOfLocal(obj)
+ end
+ if obj.type == 'table' then
+ return stepFieldOfTable(obj)
+ end
+end
+
+local function convertSimpleList(list)
+ local simple = {}
+ for i = #list, 1, -1 do
+ local c = list[i]
+ if c.type == 'getglobal'
+ or c.type == 'setglobal' then
+ if c.special == '_G' then
+ simple.mode = 'global'
+ goto CONTINUE
+ end
+ local loc = c.node
+ if loc.special == '_G' then
+ simple.mode = 'global'
+ if not simple.node then
+ simple.node = c
+ end
+ else
+ simple.mode = 'local'
+ simple[#simple+1] = m.getSimpleName(loc)
+ if not simple.node then
+ simple.node = loc
+ end
+ end
+ elseif c.type == 'getlocal'
+ or c.type == 'setlocal' then
+ if c.special == '_G' then
+ simple.mode = 'global'
+ goto CONTINUE
+ end
+ simple.mode = 'local'
+ if not simple.node then
+ simple.node = c.node
+ end
+ elseif c.type == 'local' then
+ simple.mode = 'local'
+ if not simple.node then
+ simple.node = c
+ end
+ else
+ if not simple.node then
+ simple.node = c
+ end
+ end
+ simple[#simple+1] = m.getSimpleName(c)
+ ::CONTINUE::
+ end
+ if simple.mode == 'global' and #simple == 0 then
+ simple[1] = 's|_G'
+ simple.node = list[#list]
+ end
+ return simple
+end
+
+-- 搜索 `a.b.c` 的等价表达式
+local function buildSimpleList(obj, max)
+ local list = {}
+ local cur = obj
+ local limit = max and (max + 1) or 11
+ for i = 1, max or limit do
+ if i == limit then
+ return nil
+ end
+ while cur.type == 'paren' do
+ cur = cur.exp
+ if not cur then
+ return nil
+ end
+ end
+ if cur.type == 'setfield'
+ or cur.type == 'getfield'
+ or cur.type == 'setmethod'
+ or cur.type == 'getmethod'
+ or cur.type == 'setindex'
+ or cur.type == 'getindex' then
+ list[i] = cur
+ cur = cur.node
+ elseif cur.type == 'tablefield'
+ or cur.type == 'tableindex' then
+ list[i] = cur
+ cur = cur.parent.parent
+ if cur.type == 'return' then
+ list[i+1] = list[i].parent
+ break
+ end
+ elseif cur.type == 'getlocal'
+ or cur.type == 'setlocal'
+ or cur.type == 'local' then
+ list[i] = cur
+ break
+ elseif cur.type == 'setglobal'
+ or cur.type == 'getglobal' then
+ list[i] = cur
+ break
+ elseif cur.type == 'select'
+ or cur.type == 'table' then
+ list[i] = cur
+ break
+ elseif cur.type == 'string' then
+ list[i] = cur
+ break
+ elseif cur.type == 'doc.class.name'
+ or cur.type == 'doc.type.name'
+ or cur.type == 'doc.class' then
+ list[i] = cur
+ break
+ elseif cur.type == 'function'
+ or cur.type == 'main' then
+ break
+ else
+ return nil
+ end
+ end
+ return convertSimpleList(list)
+end
+
+function m.getSimple(obj, max)
+ local simpleList
+ if obj.type == 'getfield'
+ or obj.type == 'setfield'
+ or obj.type == 'getmethod'
+ or obj.type == 'setmethod'
+ or obj.type == 'getindex'
+ or obj.type == 'setindex'
+ or obj.type == 'local'
+ or obj.type == 'getlocal'
+ or obj.type == 'setlocal'
+ or obj.type == 'setglobal'
+ or obj.type == 'getglobal'
+ or obj.type == 'tablefield'
+ or obj.type == 'tableindex'
+ or obj.type == 'select'
+ or obj.type == 'table'
+ or obj.type == 'string'
+ or obj.type == 'doc.class.name'
+ or obj.type == 'doc.class'
+ or obj.type == 'doc.type.name' then
+ simpleList = buildSimpleList(obj, max)
+ elseif obj.type == 'field'
+ or obj.type == 'method' then
+ simpleList = buildSimpleList(obj.parent, max)
+ end
+ return simpleList
+end
+
+function m.status(parentStatus, interface)
+ local status = {
+ cache = parentStatus and parentStatus.cache or {
+ count = 0,
+ },
+ depth = parentStatus and (parentStatus.depth + 1) or 1,
+ interface = parentStatus and parentStatus.interface or {},
+ locks = parentStatus and parentStatus.locks or {},
+ deep = parentStatus and parentStatus.deep,
+ results = {},
+ }
+ status.lock = status.locks[status.depth] or {}
+ status.locks[status.depth] = status.lock
+ if interface then
+ for k, v in pairs(interface) do
+ status.interface[k] = v
+ end
+ end
+ local searchDepth = status.interface.getSearchDepth and status.interface.getSearchDepth() or 0
+ if status.depth >= searchDepth then
+ status.deep = false
+ end
+ return status
+end
+
+function m.copyStatusResults(a, b)
+ local ra = a.results
+ local rb = b.results
+ for i = 1, #rb do
+ ra[#ra+1] = rb[i]
+ end
+end
+
+function m.isGlobal(source)
+ if source.type == 'setglobal'
+ or source.type == 'getglobal' then
+ if source.node and source.node.tag == '_ENV' then
+ return true
+ end
+ end
+ if source.type == 'field' then
+ source = source.parent
+ end
+ if source.type == 'getfield'
+ or source.type == 'setfield' then
+ local node = source.node
+ if node and node.special == '_G' then
+ return true
+ end
+ end
+ return false
+end
+
+function m.isDoc(source)
+ return source.type:sub(1, 4) == 'doc.'
+end
+
+--- 根据函数的调用参数,获取:调用,参数索引
+function m.getCallAndArgIndex(callarg)
+ local callargs = callarg.parent
+ if not callargs or callargs.type ~= 'callargs' then
+ return nil
+ end
+ local index
+ for i = 1, #callargs do
+ if callargs[i] == callarg then
+ index = i
+ break
+ end
+ end
+ local call = callargs.parent
+ return call, index
+end
+
+--- 根据函数调用的返回值,获取:调用的函数,参数列表,自己是第几个返回值
+function m.getCallValue(source)
+ local value = m.getObjectValue(source) or source
+ if not value then
+ return
+ end
+ local call, index
+ if value.type == 'call' then
+ call = value
+ index = 1
+ elseif value.type == 'select' then
+ call = value.vararg
+ index = value.index
+ if call.type ~= 'call' then
+ return
+ end
+ else
+ return
+ end
+ return call.node, call.args, index
+end
+
+function m.getNextRef(ref)
+ local nextRef = ref.next
+ if nextRef then
+ if nextRef.type == 'setfield'
+ or nextRef.type == 'getfield'
+ or nextRef.type == 'setmethod'
+ or nextRef.type == 'getmethod'
+ or nextRef.type == 'setindex'
+ or nextRef.type == 'getindex' then
+ return nextRef
+ end
+ end
+ -- 穿透 rawget 与 rawset
+ local call, index = m.getCallAndArgIndex(ref)
+ if call then
+ if call.node.special == 'rawset' and index == 1 then
+ return call
+ end
+ if call.node.special == 'rawget' and index == 1 then
+ return call
+ end
+ end
+ -- doc.type.array
+ if ref.type == 'doc.type' then
+ local arrays = {}
+ for _, typeUnit in ipairs(ref.types) do
+ if typeUnit.type == 'doc.type.array' then
+ arrays[#arrays+1] = typeUnit.node
+ end
+ end
+ -- 返回一个 dummy
+ -- TODO 用弱表维护唯一性?
+ return {
+ type = 'doc.type',
+ start = ref.start,
+ finish = ref.finish,
+ types = arrays,
+ parent = ref.parent,
+ array = true,
+ enums = {},
+ resumes = {},
+ }
+ end
+
+ return nil
+end
+
+function m.checkSameSimpleInValueOfTable(status, value, start, queue)
+ if value.type ~= 'table' then
+ return
+ end
+ for i = 1, #value do
+ local field = value[i]
+ queue[#queue+1] = {
+ obj = field,
+ start = start + 1,
+ }
+ end
+end
+
+function m.searchFields(status, obj, key)
+ local simple = m.getSimple(obj)
+ if not simple then
+ return
+ end
+ simple[#simple+1] = key and ('s|' .. key) or '*'
+ m.searchSameFields(status, simple, 'field')
+ m.cleanResults(status.results)
+end
+
+function m.getObjectValue(obj)
+ while obj.type == 'paren' do
+ obj = obj.exp
+ if not obj then
+ return nil
+ end
+ end
+ if obj.type == 'boolean'
+ or obj.type == 'number'
+ or obj.type == 'integer'
+ or obj.type == 'string' then
+ return obj
+ end
+ if obj.value then
+ return obj.value
+ end
+ if obj.type == 'field'
+ or obj.type == 'method' then
+ return obj.parent.value
+ end
+ if obj.type == 'call' then
+ if obj.node.special == 'rawset' then
+ return obj.args[3]
+ end
+ end
+ if obj.type == 'select' then
+ return obj
+ end
+ return nil
+end
+
+function m.checkSameSimpleInValueInMetaTable(status, mt, start, queue)
+ local newStatus = m.status(status)
+ m.searchFields(newStatus, mt, '__index')
+ local refsStatus = m.status(status)
+ for i = 1, #newStatus.results do
+ local indexValue = m.getObjectValue(newStatus.results[i])
+ if indexValue then
+ m.searchRefs(refsStatus, indexValue, 'ref')
+ end
+ end
+ for i = 1, #refsStatus.results do
+ local obj = refsStatus.results[i]
+ queue[#queue+1] = {
+ obj = obj,
+ start = start,
+ force = true,
+ }
+ end
+end
+function m.checkSameSimpleInValueOfSetMetaTable(status, func, start, queue)
+ if not func or func.special ~= 'setmetatable' then
+ return
+ end
+ local call = func.parent
+ local args = call.args
+ local obj = args[1]
+ local mt = args[2]
+ if obj then
+ queue[#queue+1] = {
+ obj = obj,
+ start = start,
+ force = true,
+ }
+ end
+ if mt then
+ m.checkSameSimpleInValueInMetaTable(status, mt, start, queue)
+ end
+end
+
+function m.checkSameSimpleInValueOfCallMetaTable(status, call, start, queue)
+ if call.type == 'call' then
+ m.checkSameSimpleInValueOfSetMetaTable(status, call.node, start, queue)
+ end
+end
+
+function m.checkSameSimpleInSpecialBranch(status, obj, start, queue)
+ if not status.interface.index then
+ return
+ end
+ local results = status.interface.index(obj)
+ if not results then
+ return
+ end
+ for _, res in ipairs(results) do
+ queue[#queue+1] = {
+ obj = res,
+ start = start + 1,
+ }
+ end
+end
+
+function m.checkSameSimpleByDocType(status, doc)
+ if status.cache.searchingBindedDoc then
+ return
+ end
+ if doc.type ~= 'doc.type' then
+ return
+ end
+ local results = {}
+ for _, piece in ipairs(doc.types) do
+ local pieceResult = stepRefOfDocType(status, piece, 'def')
+ for _, res in ipairs(pieceResult) do
+ results[#results+1] = res
+ end
+ end
+ return results
+end
+
+function m.checkSameSimpleByBindDocs(status, obj, start, queue, mode)
+ if not obj.bindDocs then
+ return
+ end
+ if status.cache.searchingBindedDoc then
+ return
+ end
+ local skipInfer = false
+ local results = {}
+ for _, doc in ipairs(obj.bindDocs) do
+ if doc.type == 'doc.class' then
+ results[#results+1] = doc
+ elseif doc.type == 'doc.type' then
+ results[#results+1] = doc
+ elseif doc.type == 'doc.param' then
+ -- function (x) 的情况
+ if obj.type == 'local'
+ and m.getName(obj) == doc.param[1] then
+ if obj.parent.type == 'funcargs'
+ or obj.parent.type == 'in'
+ or obj.parent.type == 'loop' then
+ results[#results+1] = doc.extends
+ end
+ end
+ elseif doc.type == 'doc.field' then
+ results[#results+1] = doc
+ end
+ end
+ for _, res in ipairs(results) do
+ if res.type == 'doc.class'
+ or res.type == 'doc.type' then
+ queue[#queue+1] = {
+ obj = res,
+ start = start,
+ force = true,
+ }
+ skipInfer = true
+ end
+ if res.type == 'doc.type.function' then
+ queue[#queue+1] = {
+ obj = res,
+ start = start,
+ force = true,
+ }
+ elseif res.type == 'doc.field' then
+ queue[#queue+1] = {
+ obj = res,
+ start = start + 1,
+ }
+ end
+ end
+ return skipInfer
+end
+
+function m.checkSameSimpleOfRefByDocSource(status, obj, start, queue, mode)
+ if status.cache.searchingBindedDoc then
+ return
+ end
+ if not obj.bindSources then
+ return
+ end
+ status.cache.searchingBindedDoc = true
+ local mark = {}
+ local newStatus = m.status(status)
+ for _, ref in ipairs(obj.bindSources) do
+ if not mark[ref] then
+ mark[ref] = true
+ m.searchRefs(newStatus, ref, mode)
+ end
+ end
+ status.cache.searchingBindedDoc = nil
+ for _, res in ipairs(newStatus.results) do
+ queue[#queue+1] = {
+ obj = res,
+ start = start,
+ force = true,
+ }
+ end
+end
+
+function m.checkSameSimpleByDoc(status, obj, start, queue, mode)
+ if obj.type == 'doc.class.name'
+ or obj.type == 'doc.type.name' then
+ obj = m.getDocState(obj)
+ end
+ if obj.type == 'doc.class' then
+ local classStart
+ for _, doc in ipairs(obj.bindGroup) do
+ if doc == obj then
+ classStart = true
+ elseif doc.type == 'doc.class' then
+ classStart = false
+ end
+ if classStart and doc.type == 'doc.field' then
+ queue[#queue+1] = {
+ obj = doc,
+ start = start + 1,
+ }
+ end
+ end
+ m.checkSameSimpleOfRefByDocSource(status, obj, start, queue, mode)
+ if mode == 'ref' then
+ local pieceResult = stepRefOfDocType(status, obj.class, 'ref')
+ for _, res in ipairs(pieceResult) do
+ queue[#queue+1] = {
+ obj = res,
+ start = start,
+ force = true,
+ }
+ end
+ end
+ return true
+ elseif obj.type == 'doc.type' then
+ for _, piece in ipairs(obj.types) do
+ local pieceResult = stepRefOfDocType(status, piece, 'def')
+ for _, res in ipairs(pieceResult) do
+ queue[#queue+1] = {
+ obj = res,
+ start = start,
+ force = true,
+ }
+ end
+ end
+ if mode == 'ref' then
+ m.checkSameSimpleOfRefByDocSource(status, obj, start, queue, mode)
+ end
+ return true
+ elseif obj.type == 'doc.field' then
+ if mode ~= 'field' then
+ return m.checkSameSimpleByDoc(status, obj.extends, start, queue, mode)
+ end
+ end
+end
+
+function m.checkSameSimpleInArg1OfSetMetaTable(status, obj, start, queue)
+ local args = obj.parent
+ if not args or args.type ~= 'callargs' then
+ return
+ end
+ if args[1] ~= obj then
+ return
+ end
+ local mt = args[2]
+ if mt then
+ if m.checkValueMark(status, obj, mt) then
+ return
+ end
+ m.checkSameSimpleInValueInMetaTable(status, mt, start, queue)
+ end
+end
+
+function m.searchSameMethodCrossSelf(ref, mark)
+ local selfNode
+ if ref.tag == 'self' then
+ selfNode = ref
+ else
+ if ref.type == 'getlocal'
+ or ref.type == 'setlocal' then
+ local node = ref.node
+ if node.tag == 'self' then
+ selfNode = node
+ end
+ end
+ end
+ if selfNode then
+ if mark[selfNode] then
+ return nil
+ end
+ mark[selfNode] = true
+ return selfNode.method.node
+ end
+end
+
+function m.searchSameMethod(ref, mark)
+ if mark['method'] then
+ return nil
+ end
+ local nxt = ref.next
+ if not nxt then
+ return nil
+ end
+ if nxt.type == 'setmethod' then
+ mark['method'] = true
+ return ref
+ end
+ return nil
+end
+
+function m.searchSameFieldsCrossMethod(status, ref, start, queue)
+ local mark = status.cache.crossMethodMark
+ if not mark then
+ mark = {}
+ status.cache.crossMethodMark = mark
+ end
+ local method = m.searchSameMethod(ref, mark)
+ or m.searchSameMethodCrossSelf(ref, mark)
+ if not method then
+ return
+ end
+ local methodStatus = m.status(status)
+ m.searchRefs(methodStatus, method, 'ref')
+ for _, md in ipairs(methodStatus.results) do
+ queue[#queue+1] = {
+ obj = md,
+ start = start,
+ force = true,
+ }
+ local nxt = md.next
+ if not nxt then
+ goto CONTINUE
+ end
+ if nxt.type == 'setmethod' then
+ local func = nxt.value
+ if not func then
+ goto CONTINUE
+ end
+ local selfNode = func.locals and func.locals[1]
+ if not selfNode or not selfNode.ref then
+ goto CONTINUE
+ end
+ if mark[selfNode] then
+ goto CONTINUE
+ end
+ mark[selfNode] = true
+ for _, selfRef in ipairs(selfNode.ref) do
+ queue[#queue+1] = {
+ obj = selfRef,
+ start = start,
+ force = true,
+ }
+ end
+ end
+ ::CONTINUE::
+ end
+end
+
+local function checkSameSimpleAndMergeFunctionReturnsByDoc(status, results, source, index, call)
+ if not source or source.type ~= 'function' then
+ return
+ end
+ if not source.bindDocs then
+ return
+ end
+ local returns = {}
+ for _, doc in ipairs(source.bindDocs) do
+ if doc.type == 'doc.return' then
+ for _, rtn in ipairs(doc.returns) do
+ returns[#returns+1] = rtn
+ end
+ end
+ end
+ local rtn = returns[index]
+ if not rtn then
+ return
+ end
+ local types = m.checkSameSimpleByDocType(status, rtn)
+ if not types then
+ return
+ end
+ for _, res in ipairs(types) do
+ results[#results+1] = res
+ end
+ return true
+end
+
+local function checkSameSimpleAndMergeDocTypeFunctionReturns(status, results, source, index)
+ if not source.bindDocs then
+ return
+ end
+ for _, doc in ipairs(source.bindDocs) do
+ if doc.type == 'doc.type' then
+ for _, typeUnit in ipairs(doc.types) do
+ if typeUnit.type == 'doc.type.function' then
+ local rtn = typeUnit.returns[index]
+ if rtn then
+ local types = m.checkSameSimpleByDocType(status, rtn)
+ if types then
+ for _, res in ipairs(types) do
+ results[#results+1] = res
+ end
+ return true
+ end
+ end
+ end
+ end
+ end
+ end
+end
+
+function m.checkSameSimpleInCallInSameFile(status, func, args, index)
+ local newStatus = m.status(status)
+ m.searchRefs(newStatus, func, 'def')
+ local results = {}
+ for _, def in ipairs(newStatus.results) do
+ local hasDocReturn = checkSameSimpleAndMergeDocTypeFunctionReturns(status, results, def, index)
+ or checkSameSimpleAndMergeFunctionReturnsByDoc(status, results, def, index)
+ if not hasDocReturn then
+ local value = m.getObjectValue(def) or def
+ if value.type == 'function' then
+ local returns = value.returns
+ if returns then
+ for _, ret in ipairs(returns) do
+ local exp = ret[index]
+ if exp then
+ results[#results+1] = exp
+ end
+ end
+ end
+ end
+ end
+ end
+ return results
+end
+
+function m.checkSameSimpleInCall(status, ref, start, queue, mode)
+ local func, args, index = m.getCallValue(ref)
+ if not func then
+ return
+ end
+ if m.checkCallMark(status, func.parent, true) then
+ return
+ end
+ status.cache.crossCallCount = status.cache.crossCallCount or 0
+ if status.cache.crossCallCount >= 5 then
+ return
+ end
+ status.cache.crossCallCount = status.cache.crossCallCount + 1
+ -- 检查赋值是 semetatable() 的情况
+ m.checkSameSimpleInValueOfSetMetaTable(status, func, start, queue)
+ -- 检查赋值是 func() 的情况
+ local objs = m.checkSameSimpleInCallInSameFile(status, func, args, index)
+ if status.interface.call then
+ local cobjs = status.interface.call(func, args, index)
+ if cobjs then
+ for _, obj in ipairs(cobjs) do
+ if not m.checkReturnMark(status, obj) then
+ objs[#objs+1] = obj
+ end
+ end
+ end
+ end
+ m.cleanResults(objs)
+ local newStatus = m.status(status)
+ for _, obj in ipairs(objs) do
+ m.searchRefs(newStatus, obj, mode)
+ queue[#queue+1] = {
+ obj = obj,
+ start = start,
+ force = true,
+ }
+ end
+ status.cache.crossCallCount = status.cache.crossCallCount - 1
+ for _, obj in ipairs(newStatus.results) do
+ queue[#queue+1] = {
+ obj = obj,
+ start = start,
+ force = true,
+ }
+ end
+end
+
+local function searchRawset(ref, results)
+ if m.getKeyName(ref) ~= 's|rawset' then
+ return
+ end
+ local call = ref.parent
+ if call.type ~= 'call' or call.node ~= ref then
+ return
+ end
+ if not call.args then
+ return
+ end
+ local arg1 = call.args[1]
+ if arg1.special ~= '_G' then
+ -- 不会吧不会吧,不会真的有人写成 `rawset(_G._G._G, 'xxx', value)` 吧
+ return
+ end
+ results[#results+1] = call
+end
+
+local function searchG(ref, results)
+ while ref and m.getKeyName(ref) == 's|_G' do
+ results[#results+1] = ref
+ ref = ref.next
+ end
+ if ref then
+ results[#results+1] = ref
+ searchRawset(ref, results)
+ end
+end
+
+local function searchEnvRef(ref, results)
+ if ref.type == 'setglobal'
+ or ref.type == 'getglobal' then
+ results[#results+1] = ref
+ searchG(ref, results)
+ elseif ref.type == 'getlocal' then
+ results[#results+1] = ref.next
+ searchG(ref.next, results)
+ end
+end
+
+function m.findGlobals(ast)
+ local root = m.getRoot(ast)
+ local results = {}
+ local env = m.getENV(root)
+ if env.ref then
+ for _, ref in ipairs(env.ref) do
+ searchEnvRef(ref, results)
+ end
+ end
+ return results
+end
+
+function m.findGlobalsOfName(ast, name)
+ local root = m.getRoot(ast)
+ local results = {}
+ local globals = m.findGlobals(root)
+ for _, global in ipairs(globals) do
+ if m.getKeyName(global) == name then
+ results[#results+1] = global
+ end
+ end
+ return results
+end
+
+function m.checkSameSimpleInGlobal(status, name, source, start, queue)
+ if not name then
+ return
+ end
+ local objs
+ if status.interface.global then
+ objs = status.interface.global(name)
+ else
+ objs = m.findGlobalsOfName(source, name)
+ end
+ if objs then
+ for _, obj in ipairs(objs) do
+ queue[#queue+1] = {
+ obj = obj,
+ start = start,
+ force = true,
+ }
+ end
+ end
+end
+
+function m.checkValueMark(status, a, b)
+ if not status.cache.valueMark then
+ status.cache.valueMark = {}
+ end
+ if status.cache.valueMark[a]
+ or status.cache.valueMark[b] then
+ return true
+ end
+ status.cache.valueMark[a] = true
+ status.cache.valueMark[b] = true
+ return false
+end
+
+function m.checkCallMark(status, a, mark)
+ if not status.cache.callMark then
+ status.cache.callMark = {}
+ end
+ if mark then
+ status.cache.callMark[a] = mark
+ else
+ return status.cache.callMark[a]
+ end
+ return false
+end
+
+function m.checkReturnMark(status, a, mark)
+ if not status.cache.returnMark then
+ status.cache.returnMark = {}
+ end
+ if mark then
+ status.cache.returnMark[a] = mark
+ else
+ return status.cache.returnMark[a]
+ end
+ return false
+end
+
+function m.searchSameFieldsInValue(status, ref, start, queue, mode)
+ local value = m.getObjectValue(ref)
+ if not value then
+ return
+ end
+ if m.checkValueMark(status, ref, value) then
+ return
+ end
+ local newStatus = m.status(status)
+ m.searchRefs(newStatus, value, mode)
+ for _, res in ipairs(newStatus.results) do
+ queue[#queue+1] = {
+ obj = res,
+ start = start,
+ force = true,
+ }
+ end
+ queue[#queue+1] = {
+ obj = value,
+ start = start,
+ force = true,
+ }
+ -- 检查形如 a = f() 的分支情况
+ m.checkSameSimpleInCall(status, value, start, queue, mode)
+end
+
+function m.checkSameSimpleAsTableField(status, ref, start, queue)
+ if not status.deep then
+ --return
+ end
+ local parent = ref.parent
+ if not parent or parent.type ~= 'tablefield' then
+ return
+ end
+ if m.checkValueMark(status, parent, ref) then
+ return
+ end
+ local newStatus = m.status(status)
+ m.searchRefs(newStatus, parent.field, 'ref')
+ for _, res in ipairs(newStatus.results) do
+ queue[#queue+1] = {
+ obj = res,
+ start = start,
+ force = true,
+ }
+ end
+end
+
+function m.checkSearchLevel(status)
+ status.cache.back = status.cache.back or 0
+ if status.cache.back >= (status.interface.searchLevel or 0) then
+ -- TODO 限制向前搜索的次数
+ --return true
+ end
+ status.cache.back = status.cache.back + 1
+ return false
+end
+
+function m.checkSameSimpleAsReturn(status, ref, start, queue)
+ if not status.deep then
+ return
+ end
+ if not ref.parent or ref.parent.type ~= 'return' then
+ return
+ end
+ if ref.parent.parent.type ~= 'main' then
+ return
+ end
+ if m.checkSearchLevel(status) then
+ return
+ end
+ local newStatus = m.status(status)
+ m.searchRefsAsFunctionReturn(newStatus, ref, 'ref')
+ for _, res in ipairs(newStatus.results) do
+ if not m.checkCallMark(status, res) then
+ queue[#queue+1] = {
+ obj = res,
+ start = start,
+ force = true,
+ }
+ end
+ end
+end
+
+function m.checkSameSimpleAsSetValue(status, ref, start, queue)
+ if ref.type == 'select' then
+ return
+ end
+ local parent = ref.parent
+ if not parent then
+ return
+ end
+ if m.getObjectValue(parent) ~= ref then
+ return
+ end
+ if m.checkValueMark(status, ref, parent) then
+ return
+ end
+ if m.checkSearchLevel(status) then
+ return
+ end
+ local obj
+ if parent.type == 'local'
+ or parent.type == 'setglobal'
+ or parent.type == 'setlocal' then
+ obj = parent
+ elseif parent.type == 'setfield' then
+ obj = parent.field
+ elseif parent.type == 'setmethod' then
+ obj = parent.method
+ end
+ if not obj then
+ return
+ end
+ local newStatus = m.status(status)
+ m.searchRefs(newStatus, obj, 'ref')
+ for _, res in ipairs(newStatus.results) do
+ queue[#queue+1] = {
+ obj = res,
+ start = start,
+ force = true,
+ }
+ end
+end
+
+function m.checkSameSimpleInString(status, ref, start, queue, mode)
+ -- 特殊处理 ('xxx').xxx 的形式
+ if ref.type ~= 'string' then
+ return
+ end
+ if not status.interface.docType then
+ return
+ end
+ if status.cache.searchingBindedDoc then
+ return
+ end
+ local newStatus = m.status(status)
+ local docs = status.interface.docType('string*')
+ local mark = {}
+ for i = 1, #docs do
+ local doc = docs[i]
+ m.searchFields(newStatus, doc)
+ end
+ for _, res in ipairs(newStatus.results) do
+ if mark[res] then
+ goto CONTINUE
+ end
+ mark[res] = true
+ queue[#queue+1] = {
+ obj = res,
+ start = start + 1,
+ }
+ ::CONTINUE::
+ end
+ return true
+end
+
+function m.pushResult(status, mode, ref, simple)
+ local results = status.results
+ if mode == 'def' then
+ if ref.type == 'setglobal'
+ or ref.type == 'setlocal'
+ or ref.type == 'local' then
+ results[#results+1] = ref
+ elseif ref.type == 'setfield'
+ or ref.type == 'tablefield' then
+ results[#results+1] = ref
+ elseif ref.type == 'setmethod' then
+ results[#results+1] = ref
+ elseif ref.type == 'setindex'
+ or ref.type == 'tableindex' then
+ results[#results+1] = ref
+ elseif ref.type == 'call' then
+ if ref.node.special == 'rawset' then
+ results[#results+1] = ref
+ end
+ elseif ref.type == 'function' then
+ results[#results+1] = ref
+ elseif ref.type == 'table' then
+ results[#results+1] = ref
+ elseif ref.type == 'doc.type.function'
+ or ref.type == 'doc.class.name'
+ or ref.type == 'doc.field' then
+ results[#results+1] = ref
+ end
+ if ref.parent and ref.parent.type == 'return' then
+ if m.getParentFunction(ref) ~= m.getParentFunction(simple.node) then
+ results[#results+1] = ref
+ end
+ end
+ elseif mode == 'ref' then
+ if ref.type == 'setfield'
+ or ref.type == 'getfield'
+ or ref.type == 'tablefield' then
+ results[#results+1] = ref
+ elseif ref.type == 'setmethod'
+ or ref.type == 'getmethod' then
+ results[#results+1] = ref
+ elseif ref.type == 'setindex'
+ or ref.type == 'getindex'
+ or ref.type == 'tableindex' then
+ results[#results+1] = ref
+ elseif ref.type == 'setglobal'
+ or ref.type == 'getglobal'
+ or ref.type == 'local'
+ or ref.type == 'setlocal'
+ or ref.type == 'getlocal' then
+ results[#results+1] = ref
+ elseif ref.type == 'function' then
+ results[#results+1] = ref
+ elseif ref.type == 'table' then
+ results[#results+1] = ref
+ elseif ref.type == 'call' then
+ if ref.node.special == 'rawset'
+ or ref.node.special == 'rawget' then
+ results[#results+1] = ref
+ end
+ elseif ref.type == 'doc.type.function'
+ or ref.type == 'doc.class.name'
+ or ref.type == 'doc.field' then
+ results[#results+1] = ref
+ end
+ if ref.parent and ref.parent.type == 'return' then
+ results[#results+1] = ref
+ end
+ elseif mode == 'field' then
+ if ref.type == 'setfield'
+ or ref.type == 'getfield'
+ or ref.type == 'tablefield' then
+ results[#results+1] = ref
+ elseif ref.type == 'setmethod'
+ or ref.type == 'getmethod' then
+ results[#results+1] = ref
+ elseif ref.type == 'setindex'
+ or ref.type == 'getindex'
+ or ref.type == 'tableindex' then
+ results[#results+1] = ref
+ elseif ref.type == 'setglobal'
+ or ref.type == 'getglobal' then
+ results[#results+1] = ref
+ elseif ref.type == 'function' then
+ results[#results+1] = ref
+ elseif ref.type == 'table' then
+ results[#results+1] = ref
+ elseif ref.type == 'call' then
+ if ref.node.special == 'rawset'
+ or ref.node.special == 'rawget' then
+ results[#results+1] = ref
+ end
+ elseif ref.type == 'doc.type.function'
+ or ref.type == 'doc.class.name'
+ or ref.type == 'doc.field' then
+ results[#results+1] = ref
+ end
+ end
+end
+
+function m.checkSameSimpleName(ref, sm)
+ if sm == '*' then
+ return true
+ end
+ if m.getSimpleName(ref) == sm then
+ return true
+ end
+ if ref.type == 'doc.type'
+ and ref.array == true then
+ return true
+ end
+ return false
+end
+
+function m.checkSameSimple(status, simple, data, mode, queue)
+ local ref = data.obj
+ local start = data.start
+ local force = data.force
+ if start > #simple then
+ return
+ end
+ for i = start, #simple do
+ local sm = simple[i]
+ if not force and not m.checkSameSimpleName(ref, sm) then
+ return
+ end
+ force = false
+ local cmode = mode
+ if i < #simple then
+ cmode = 'ref'
+ end
+ -- 检查 doc
+ local skipInfer = m.checkSameSimpleByBindDocs(status, ref, i, queue, cmode)
+ or m.checkSameSimpleByDoc(status, ref, i, queue, cmode)
+ if not skipInfer then
+ -- 穿透 self:func 与 mt:func
+ m.searchSameFieldsCrossMethod(status, ref, i, queue)
+ -- 穿透赋值
+ m.searchSameFieldsInValue(status, ref, i, queue, cmode)
+ -- 检查自己是字面量表的情况
+ m.checkSameSimpleInValueOfTable(status, ref, i, queue)
+ -- 检查自己作为 setmetatable 第一个参数的情况
+ m.checkSameSimpleInArg1OfSetMetaTable(status, ref, i, queue)
+ -- 检查自己作为 setmetatable 调用的情况
+ m.checkSameSimpleInValueOfCallMetaTable(status, ref, i, queue)
+ -- 检查自己是特殊变量的分支的情况
+ m.checkSameSimpleInSpecialBranch(status, ref, i, queue)
+ -- 检查自己是字面量字符串的分支情况
+ m.checkSameSimpleInString(status, ref, i, queue, cmode)
+ if cmode == 'ref' then
+ -- 检查形如 { a = f } 的情况
+ m.checkSameSimpleAsTableField(status, ref, i, queue)
+ -- 检查形如 return m 的情况
+ m.checkSameSimpleAsReturn(status, ref, i, queue)
+ -- 检查形如 a = f 的情况
+ m.checkSameSimpleAsSetValue(status, ref, i, queue)
+ end
+ end
+ if i == #simple then
+ break
+ end
+ ref = m.getNextRef(ref)
+ if not ref then
+ return
+ end
+ end
+ m.pushResult(status, mode, ref, simple)
+ local value = m.getObjectValue(ref)
+ if value then
+ m.pushResult(status, mode, value, simple)
+ end
+end
+
+function m.searchSameFields(status, simple, mode)
+ local queue = {}
+ if simple.mode == 'global' then
+ -- 全局变量开头
+ m.checkSameSimpleInGlobal(status, simple[1], simple.node, 1, queue)
+ elseif simple.mode == 'local' then
+ -- 局部变量开头
+ queue[1] = {
+ obj = simple.node,
+ start = 1,
+ }
+ local refs = simple.node.ref
+ if refs then
+ for i = 1, #refs do
+ queue[#queue+1] = {
+ obj = refs[i],
+ start = 1,
+ }
+ end
+ end
+ else
+ queue[1] = {
+ obj = simple.node,
+ start = 1,
+ }
+ end
+ local max = 0
+ for i = 1, 1e6 do
+ local data = queue[i]
+ if not data then
+ return
+ end
+ if not status.lock[data.obj] then
+ status.lock[data.obj] = true
+ max = max + 1
+ status.cache.count = status.cache.count + 1
+ m.checkSameSimple(status, simple, data, mode, queue)
+ if max >= 10000 then
+ logWarn('Queue too large!')
+ break
+ end
+ end
+ end
+end
+
+function m.getCallerInSameFile(status, func)
+ -- 搜索所有所在函数的调用者
+ local funcRefs = m.status(status)
+ m.searchRefOfValue(funcRefs, func)
+
+ local calls = {}
+ if #funcRefs.results == 0 then
+ return calls
+ end
+ for _, res in ipairs(funcRefs.results) do
+ local call = res.parent
+ if call.type == 'call' then
+ calls[#calls+1] = call
+ end
+ end
+ return calls
+end
+
+function m.getCallerCrossFiles(status, main)
+ if status.interface.link then
+ return status.interface.link(main.uri)
+ end
+ return {}
+end
+
+function m.searchRefsAsFunctionReturn(status, obj, mode)
+ if mode == 'def' then
+ return
+ end
+ if m.checkReturnMark(status, obj, true) then
+ return
+ end
+ status.results[#status.results+1] = obj
+ -- 搜索所在函数
+ local currentFunc = m.getParentFunction(obj)
+ local rtn = obj.parent
+ if rtn.type ~= 'return' then
+ return
+ end
+ -- 看看他是第几个返回值
+ local index
+ for i = 1, #rtn do
+ if obj == rtn[i] then
+ index = i
+ break
+ end
+ end
+ if not index then
+ return
+ end
+ local calls
+ if currentFunc.type == 'main' then
+ calls = m.getCallerCrossFiles(status, currentFunc)
+ else
+ calls = m.getCallerInSameFile(status, currentFunc)
+ end
+ -- 搜索调用者的返回值
+ if #calls == 0 then
+ return
+ end
+ local selects = {}
+ for i = 1, #calls do
+ local parent = calls[i].parent
+ if parent.type == 'select' and parent.index == index then
+ selects[#selects+1] = parent.parent
+ end
+ local extParent = calls[i].extParent
+ if extParent then
+ for j = 1, #extParent do
+ local ext = extParent[j]
+ if ext.type == 'select' and ext.index == index then
+ selects[#selects+1] = ext.parent
+ end
+ end
+ end
+ end
+ -- 搜索调用者的引用
+ for i = 1, #selects do
+ m.searchRefs(status, selects[i], 'ref')
+ end
+end
+
+function m.searchRefsAsFunctionSet(status, obj, mode)
+ local parent = obj.parent
+ if not parent then
+ return
+ end
+ if parent.type == 'local'
+ or parent.type == 'setlocal'
+ or parent.type == 'setglobal'
+ or parent.type == 'setfield'
+ or parent.type == 'setmethod'
+ or parent.type == 'tablefield' then
+ m.searchRefs(status, parent, mode)
+ elseif parent.type == 'setindex'
+ or parent.type == 'tableindex' then
+ if parent.index == obj then
+ m.searchRefs(status, parent, mode)
+ end
+ end
+end
+
+function m.searchRefsAsFunction(status, obj, mode)
+ if obj.type ~= 'function'
+ and obj.type ~= 'table' then
+ return
+ end
+ m.searchRefsAsFunctionSet(status, obj, mode)
+ -- 检查自己作为返回函数时的引用
+ m.searchRefsAsFunctionReturn(status, obj, mode)
+end
+
+function m.cleanResults(results)
+ local mark = {}
+ for i = #results, 1, -1 do
+ local res = results[i]
+ if res.tag == 'self'
+ or mark[res] then
+ results[i] = results[#results]
+ results[#results] = nil
+ else
+ mark[res] = true
+ end
+ end
+end
+
+--function m.getRefCache(status, obj, mode)
+-- local cache = status.interface.cache and status.interface.cache()
+-- if not cache then
+-- return
+-- end
+-- if m.isGlobal(obj) then
+-- obj = m.getKeyName(obj)
+-- end
+-- if not cache[mode] then
+-- cache[mode] = {}
+-- end
+-- local sourceCache = cache[mode][obj]
+-- if sourceCache then
+-- return sourceCache
+-- end
+-- sourceCache = {}
+-- cache[mode][obj] = sourceCache
+-- return nil, function (results)
+-- for i = 1, #results do
+-- sourceCache[i] = results[i]
+-- end
+-- end
+--end
+
+function m.getRefCache(status, obj, mode)
+ local cache, globalCache
+ if status.depth == 1
+ and status.deep then
+ globalCache = status.interface.cache and status.interface.cache() or {}
+ end
+ cache = status.cache.refCache or {}
+ status.cache.refCache = cache
+ if m.isGlobal(obj) then
+ obj = m.getKeyName(obj)
+ end
+ if not cache[mode] then
+ cache[mode] = {}
+ end
+ if globalCache and not globalCache[mode] then
+ globalCache[mode] = {}
+ end
+ local sourceCache = globalCache and globalCache[mode][obj] or cache[mode][obj]
+ if sourceCache then
+ return sourceCache
+ end
+ sourceCache = {}
+ cache[mode][obj] = sourceCache
+ if globalCache then
+ globalCache[mode][obj] = sourceCache
+ end
+ return nil, function (results)
+ for i = 1, #results do
+ sourceCache[i] = results[i]
+ end
+ end
+end
+
+function m.searchRefs(status, obj, mode)
+ local cache, makeCache = m.getRefCache(status, obj, mode)
+ if cache then
+ for i = 1, #cache do
+ status.results[#status.results+1] = cache[i]
+ end
+ return
+ end
+
+ -- 检查单步引用
+ local res = m.getStepRef(status, obj, mode)
+ if res then
+ for i = 1, #res do
+ status.results[#status.results+1] = res[i]
+ end
+ end
+ -- 检查simple
+ if status.depth <= 100 then
+ local simple = m.getSimple(obj)
+ if simple then
+ m.searchSameFields(status, simple, mode)
+ end
+ else
+ if m.debugMode then
+ error('status.depth overflow')
+ elseif DEVELOP then
+ --log.warn(debug.traceback('status.depth overflow'))
+ logWarn('status.depth overflow')
+ end
+ end
+
+ m.cleanResults(status.results)
+
+ if makeCache then
+ makeCache(status.results)
+ end
+end
+
+function m.searchRefOfValue(status, obj)
+ local var = obj.parent
+ if var.type == 'local'
+ or var.type == 'set' then
+ return m.searchRefs(status, var, 'ref')
+ end
+end
+
+function m.allocInfer(o)
+ if type(o.type) == 'table' then
+ local infers = {}
+ for i = 1, #o.type do
+ infers[i] = {
+ type = o.type[i],
+ value = o.value,
+ source = o.source,
+ }
+ end
+ return infers
+ else
+ return {
+ [1] = o,
+ }
+ end
+end
+
+function m.mergeTypes(types)
+ local results = {}
+ local mark = {}
+ local hasAny
+ -- 这里把 any 去掉
+ for i = 1, #types do
+ local tp = types[i]
+ if tp == 'any' then
+ hasAny = true
+ end
+ if not mark[tp] and tp ~= 'any' then
+ mark[tp] = true
+ results[#results+1] = tp
+ end
+ end
+ if #results == 0 then
+ return 'any'
+ end
+ -- 只有显性的 nil 与 any 时,取 any
+ if #results == 1 then
+ if results[1] == 'nil' and hasAny then
+ return 'any'
+ else
+ return results[1]
+ end
+ end
+ -- 同时包含 number 与 integer 时,去掉 integer
+ if mark['number'] and mark['integer'] then
+ for i = 1, #results do
+ if results[i] == 'integer' then
+ tableRemove(results, i)
+ break
+ end
+ end
+ end
+ tableSort(results, function (a, b)
+ local sa = TypeSort[a] or 100
+ local sb = TypeSort[b] or 100
+ return sa < sb
+ end)
+ return tableConcat(results, '|')
+end
+
+function m.viewInferType(infers)
+ if not infers then
+ return 'any'
+ end
+ local mark = {}
+ local types = {}
+ local hasDoc
+ for i = 1, #infers do
+ local infer = infers[i]
+ local src = infer.source
+ if src.type == 'doc.class'
+ or src.type == 'doc.class.name'
+ or src.type == 'doc.type.name'
+ or src.type == 'doc.type.array'
+ or src.type == 'doc.type.generic' then
+ if infer.type ~= 'any' then
+ hasDoc = true
+ break
+ end
+ end
+ end
+ if hasDoc then
+ for i = 1, #infers do
+ local infer = infers[i]
+ local src = infer.source
+ if src.type == 'doc.class'
+ or src.type == 'doc.class.name'
+ or src.type == 'doc.type.name'
+ or src.type == 'doc.type.array'
+ or src.type == 'doc.type.generic'
+ or src.type == 'doc.type.enum'
+ or src.type == 'doc.resume' then
+ local tp = infer.type or 'any'
+ if not mark[tp] then
+ types[#types+1] = tp
+ end
+ mark[tp] = true
+ end
+ end
+ else
+ for i = 1, #infers do
+ local tp = infers[i].type or 'any'
+ if not mark[tp] then
+ types[#types+1] = tp
+ end
+ mark[tp] = true
+ end
+ end
+ return m.mergeTypes(types)
+end
+
+function m.checkTrue(status, source)
+ local newStatus = m.status(status)
+ m.searchInfer(newStatus, source)
+ -- 当前认为的结果
+ local current
+ for _, infer in ipairs(newStatus.results) do
+ -- 新的结果
+ local new
+ if infer.type == 'nil' then
+ new = false
+ elseif infer.type == 'boolean' then
+ if infer.value == true then
+ new = true
+ elseif infer.value == false then
+ new = false
+ end
+ end
+ if new ~= nil then
+ if current == nil then
+ current = new
+ else
+ -- 如果2个结果完全相反,则返回 nil 表示不确定
+ if new ~= current then
+ return nil
+ end
+ end
+ end
+ end
+ return current
+end
+
+--- 获取特定类型的字面量值
+function m.getInferLiteral(status, source, type)
+ local newStatus = m.status(status)
+ m.searchInfer(newStatus, source)
+ for _, infer in ipairs(newStatus.results) do
+ if infer.value ~= nil then
+ if type == nil or infer.type == type then
+ return infer.value
+ end
+ end
+ end
+ return nil
+end
+
+--- 是否包含某种类型
+function m.hasType(status, source, type)
+ m.searchInfer(status, source)
+ for _, infer in ipairs(status.results) do
+ if infer.type == type then
+ return true
+ end
+ end
+ return false
+end
+
+function m.isSameValue(status, a, b)
+ local statusA = m.status(status)
+ m.searchInfer(statusA, a)
+ local statusB = m.status(status)
+ m.searchInfer(statusB, b)
+ local infers = {}
+ for _, infer in ipairs(statusA.results) do
+ local literal = infer.value
+ if literal then
+ infers[literal] = false
+ end
+ end
+ for _, infer in ipairs(statusB.results) do
+ local literal = infer.value
+ if literal then
+ if infers[literal] == nil then
+ return false
+ end
+ infers[literal] = true
+ end
+ end
+ for k, v in pairs(infers) do
+ if v == false then
+ return false
+ end
+ end
+ return true
+end
+
+function m.inferCheckLiteralTableWithDocVararg(status, source)
+ if #source ~= 1 then
+ return
+ end
+ local vararg = source[1]
+ if vararg.type ~= 'varargs' then
+ return
+ end
+ local results = m.getVarargDocType(status, source)
+ status.results[#status.results+1] = {
+ type = m.viewInferType(results) .. '[]',
+ source = source,
+ }
+ return true
+end
+
+function m.inferCheckLiteral(status, source)
+ if source.type == 'string' then
+ status.results = m.allocInfer {
+ type = 'string',
+ value = source[1],
+ source = source,
+ }
+ return true
+ elseif source.type == 'nil' then
+ status.results = m.allocInfer {
+ type = 'nil',
+ value = NIL,
+ source = source,
+ }
+ return true
+ elseif source.type == 'boolean' then
+ status.results = m.allocInfer {
+ type = 'boolean',
+ value = source[1],
+ source = source,
+ }
+ return true
+ elseif source.type == 'number' then
+ if mathType(source[1]) == 'integer' then
+ status.results = m.allocInfer {
+ type = 'integer',
+ value = source[1],
+ source = source,
+ }
+ return true
+ else
+ status.results = m.allocInfer {
+ type = 'number',
+ value = source[1],
+ source = source,
+ }
+ return true
+ end
+ elseif source.type == 'integer' then
+ status.results = m.allocInfer {
+ type = 'integer',
+ source = source,
+ }
+ return true
+ elseif source.type == 'table' then
+ if m.inferCheckLiteralTableWithDocVararg(status, source) then
+ return true
+ end
+ status.results = m.allocInfer {
+ type = 'table',
+ source = source,
+ }
+ return true
+ elseif source.type == 'function' then
+ status.results = m.allocInfer {
+ type = 'function',
+ source = source,
+ }
+ return true
+ elseif source.type == '...' then
+ status.results = m.allocInfer {
+ type = '...',
+ source = source,
+ }
+ return true
+ end
+end
+
+local function getDocAliasExtends(status, name)
+ if not status.interface.docType then
+ return nil
+ end
+ for _, doc in ipairs(status.interface.docType(name)) do
+ if doc.type == 'doc.alias.name' then
+ return m.viewInferType(m.getDocTypeNames(status, doc.parent.extends))
+ end
+ end
+ return nil
+end
+
+local function getDocTypeUnitName(status, unit, genericCallback)
+ local typeName
+ if unit.type == 'doc.type.name' then
+ typeName = getDocAliasExtends(status, unit[1]) or unit[1]
+ elseif unit.type == 'doc.type.function' then
+ typeName = 'function'
+ elseif unit.type == 'doc.type.array' then
+ typeName = getDocTypeUnitName(status, unit.node, genericCallback) .. '[]'
+ elseif unit.type == 'doc.type.generic' then
+ typeName = ('%s<%s, %s>'):format(
+ getDocTypeUnitName(status, unit.node, genericCallback),
+ m.viewInferType(m.getDocTypeNames(status, unit.key, genericCallback)),
+ m.viewInferType(m.getDocTypeNames(status, unit.value, genericCallback))
+ )
+ end
+ if unit.typeGeneric then
+ if genericCallback then
+ typeName = genericCallback(typeName, unit)
+ or ('<%s>'):format(typeName)
+ else
+ typeName = ('<%s>'):format(typeName)
+ end
+ end
+ return typeName
+end
+
+function m.getDocTypeNames(status, doc, genericCallback)
+ local results = {}
+ if not doc then
+ return results
+ end
+ for _, unit in ipairs(doc.types) do
+ local typeName = getDocTypeUnitName(status, unit, genericCallback)
+ results[#results+1] = {
+ type = typeName,
+ source = unit,
+ }
+ end
+ for _, enum in ipairs(doc.enums) do
+ results[#results+1] = {
+ type = enum[1],
+ source = enum,
+ }
+ end
+ for _, resume in ipairs(doc.resumes) do
+ if not resume.additional then
+ results[#results+1] = {
+ type = resume[1],
+ source = resume,
+ }
+ end
+ end
+ return results
+end
+
+function m.inferCheckDoc(status, source)
+ if source.type == 'doc.class.name' then
+ status.results[#status.results+1] = {
+ type = source[1],
+ source = source,
+ }
+ return true
+ end
+ if source.type == 'doc.class' then
+ status.results[#status.results+1] = {
+ type = source.class[1],
+ source = source,
+ }
+ return true
+ end
+ if source.type == 'doc.type' then
+ local results = m.getDocTypeNames(status, source)
+ for _, res in ipairs(results) do
+ status.results[#status.results+1] = res
+ end
+ return true
+ end
+ if source.type == 'doc.field' then
+ local results = m.getDocTypeNames(status, source.extends)
+ for _, res in ipairs(results) do
+ status.results[#status.results+1] = res
+ end
+ return true
+ end
+end
+
+function m.getVarargDocType(status, source)
+ local func = m.getParentFunction(source)
+ if not func then
+ return
+ end
+ if not func.args then
+ return
+ end
+ for _, arg in ipairs(func.args) do
+ if arg.type == '...' then
+ if arg.bindDocs then
+ for _, doc in ipairs(arg.bindDocs) do
+ if doc.type == 'doc.vararg' then
+ return m.getDocTypeNames(status, doc.vararg)
+ end
+ end
+ end
+ end
+ end
+end
+
+function m.inferCheckUpDocOfVararg(status, source)
+ if not source.vararg then
+ return
+ end
+ local results = m.getVarargDocType(status, source)
+ if not results then
+ return
+ end
+ for _, res in ipairs(results) do
+ status.results[#status.results+1] = res
+ end
+ return true
+end
+
+function m.inferCheckUpDoc(status, source)
+ if m.inferCheckUpDocOfVararg(status, source) then
+ return true
+ end
+ local parent = source.parent
+ if parent then
+ if parent.type == 'local'
+ or parent.type == 'setlocal'
+ or parent.type == 'setglobal' then
+ source = parent
+ end
+ if parent.type == 'setfield'
+ or parent.type == 'tablefield' then
+ if parent.field == source
+ or parent.value == source then
+ source = parent
+ end
+ end
+ if parent.type == 'setmethod' then
+ if parent.method == source
+ or parent.value == source then
+ source = parent
+ end
+ end
+ if parent.type == 'setindex'
+ or parent.type == 'tableindex' then
+ if parent.index == source
+ or parent.value == source then
+ source = parent
+ end
+ end
+ end
+ local binds = source.bindDocs
+ if not binds then
+ return
+ end
+ status.results = {}
+ for _, doc in ipairs(binds) do
+ if doc.type == 'doc.class' then
+ status.results[#status.results+1] = {
+ type = doc.class[1],
+ source = doc,
+ }
+ -- ---@class Class
+ -- local x = { field = 1 }
+ -- 这种情况下,将字面量表接受为Class的定义
+ if source.value and source.value.type == 'table' then
+ status.results[#status.results+1] = {
+ type = source.value.type,
+ source = source.value,
+ }
+ end
+ return true
+ elseif doc.type == 'doc.type' then
+ local results = m.getDocTypeNames(status, doc)
+ for _, res in ipairs(results) do
+ status.results[#status.results+1] = res
+ end
+ return true
+ elseif doc.type == 'doc.param' then
+ -- function (x) 的情况
+ if source.type == 'local'
+ and m.getName(source) == doc.param[1] then
+ if source.parent.type == 'funcargs'
+ or source.parent.type == 'in'
+ or source.parent.type == 'loop' then
+ local results = m.getDocTypeNames(status, doc.extends)
+ for _, res in ipairs(results) do
+ status.results[#status.results+1] = res
+ end
+ return true
+ end
+ end
+ end
+ end
+end
+
+function m.inferCheckFieldDoc(status, source)
+ -- 检查 string[] 的情况
+ if source.type == 'getindex' then
+ local node = source.node
+ if not node then
+ return
+ end
+ local newStatus = m.status(status)
+ m.searchInfer(newStatus, node)
+ local ok
+ for _, infer in ipairs(newStatus.results) do
+ local src = infer.source
+ if src.type == 'doc.type.array' then
+ ok = true
+ status.results[#status.results+1] = {
+ type = infer.type:gsub('%[%]$', ''),
+ source = src.node,
+ }
+ end
+ end
+ return ok
+ end
+end
+
+function m.inferCheckUnary(status, source)
+ if source.type ~= 'unary' then
+ return
+ end
+ local op = source.op
+ if op.type == 'not' then
+ local checkTrue = m.checkTrue(status, source[1])
+ local value = nil
+ if checkTrue == true then
+ value = false
+ elseif checkTrue == false then
+ value = true
+ end
+ status.results = m.allocInfer {
+ type = 'boolean',
+ value = value,
+ source = source,
+ }
+ return true
+ elseif op.type == '#' then
+ status.results = m.allocInfer {
+ type = 'integer',
+ source = source,
+ }
+ return true
+ elseif op.type == '~' then
+ local l = m.getInferLiteral(status, source[1], 'integer')
+ status.results = m.allocInfer {
+ type = 'integer',
+ value = l and ~l or nil,
+ source = source,
+ }
+ return true
+ elseif op.type == '-' then
+ local v = m.getInferLiteral(status, source[1], 'integer')
+ if v then
+ status.results = m.allocInfer {
+ type = 'integer',
+ value = - v,
+ source = source,
+ }
+ return true
+ end
+ v = m.getInferLiteral(status, source[1], 'number')
+ status.results = m.allocInfer {
+ type = 'number',
+ value = v and -v or nil,
+ source = source,
+ }
+ return true
+ end
+end
+
+local function mathCheck(status, a, b)
+ local v1 = m.getInferLiteral(status, a, 'integer')
+ or m.getInferLiteral(status, a, 'number')
+ local v2 = m.getInferLiteral(status, b, 'integer')
+ or m.getInferLiteral(status, a, 'number')
+ local int = m.hasType(status, a, 'integer')
+ and m.hasType(status, b, 'integer')
+ and not m.hasType(status, a, 'number')
+ and not m.hasType(status, b, 'number')
+ return int and 'integer' or 'number', v1, v2
+end
+
+function m.inferCheckBinary(status, source)
+ if source.type ~= 'binary' then
+ return
+ end
+ local op = source.op
+ if op.type == 'and' then
+ local isTrue = m.checkTrue(status, source[1])
+ if isTrue == true then
+ m.searchInfer(status, source[2])
+ return true
+ elseif isTrue == false then
+ m.searchInfer(status, source[1])
+ return true
+ else
+ m.searchInfer(status, source[1])
+ m.searchInfer(status, source[2])
+ return true
+ end
+ elseif op.type == 'or' then
+ local isTrue = m.checkTrue(status, source[1])
+ if isTrue == true then
+ m.searchInfer(status, source[1])
+ return true
+ elseif isTrue == false then
+ m.searchInfer(status, source[2])
+ return true
+ else
+ m.searchInfer(status, source[1])
+ m.searchInfer(status, source[2])
+ return true
+ end
+ elseif op.type == '==' then
+ local value = m.isSameValue(status, source[1], source[2])
+ if value ~= nil then
+ status.results = m.allocInfer {
+ type = 'boolean',
+ value = value,
+ source = source,
+ }
+ return true
+ end
+ --local isSame = m.isSameDef(status, source[1], source[2])
+ --if isSame == true then
+ -- value = true
+ --else
+ -- value = nil
+ --end
+ status.results = m.allocInfer {
+ type = 'boolean',
+ value = value,
+ source = source,
+ }
+ return true
+ elseif op.type == '~=' then
+ local value = m.isSameValue(status, source[1], source[2])
+ if value ~= nil then
+ status.results = m.allocInfer {
+ type = 'boolean',
+ value = not value,
+ source = source,
+ }
+ return true
+ end
+ --local isSame = m.isSameDef(status, source[1], source[2])
+ --if isSame == true then
+ -- value = false
+ --else
+ -- value = nil
+ --end
+ status.results = m.allocInfer {
+ type = 'boolean',
+ value = value,
+ source = source,
+ }
+ return true
+ elseif op.type == '<=' then
+ local v1 = m.getInferLiteral(status, source[1], 'integer')
+ or m.getInferLiteral(status, source[1], 'number')
+ local v2 = m.getInferLiteral(status, source[2], 'integer')
+ or m.getInferLiteral(status, source[2], 'number')
+ local v
+ if v1 and v2 then
+ v = v1 <= v2
+ end
+ status.results = m.allocInfer {
+ type = 'boolean',
+ value = v,
+ source = source,
+ }
+ return true
+ elseif op.type == '>=' then
+ local v1 = m.getInferLiteral(status, source[1], 'integer')
+ or m.getInferLiteral(status, source[1], 'number')
+ local v2 = m.getInferLiteral(status, source[2], 'integer')
+ or m.getInferLiteral(status, source[2], 'number')
+ local v
+ if v1 and v2 then
+ v = v1 >= v2
+ end
+ status.results = m.allocInfer {
+ type = 'boolean',
+ value = v,
+ source = source,
+ }
+ return true
+ elseif op.type == '<' then
+ local v1 = m.getInferLiteral(status, source[1], 'integer')
+ or m.getInferLiteral(status, source[1], 'number')
+ local v2 = m.getInferLiteral(status, source[2], 'integer')
+ or m.getInferLiteral(status, source[2], 'number')
+ local v
+ if v1 and v2 then
+ v = v1 < v2
+ end
+ status.results = m.allocInfer {
+ type = 'boolean',
+ value = v,
+ source = source,
+ }
+ return true
+ elseif op.type == '>' then
+ local v1 = m.getInferLiteral(status, source[1], 'integer')
+ or m.getInferLiteral(status, source[1], 'number')
+ local v2 = m.getInferLiteral(status, source[2], 'integer')
+ or m.getInferLiteral(status, source[2], 'number')
+ local v
+ if v1 and v2 then
+ v = v1 > v2
+ end
+ status.results = m.allocInfer {
+ type = 'boolean',
+ value = v,
+ source = source,
+ }
+ return true
+ elseif op.type == '|' then
+ local v1 = m.getInferLiteral(status, source[1], 'integer')
+ local v2 = m.getInferLiteral(status, source[2], 'integer')
+ local v
+ if v1 and v2 then
+ v = v1 | v2
+ end
+ status.results = m.allocInfer {
+ type = 'integer',
+ value = v,
+ source = source,
+ }
+ return true
+ elseif op.type == '~' then
+ local v1 = m.getInferLiteral(status, source[1], 'integer')
+ local v2 = m.getInferLiteral(status, source[2], 'integer')
+ local v
+ if v1 and v2 then
+ v = v1 ~ v2
+ end
+ status.results = m.allocInfer {
+ type = 'integer',
+ value = v,
+ source = source,
+ }
+ return true
+ elseif op.type == '&' then
+ local v1 = m.getInferLiteral(status, source[1], 'integer')
+ local v2 = m.getInferLiteral(status, source[2], 'integer')
+ local v
+ if v1 and v2 then
+ v = v1 & v2
+ end
+ status.results = m.allocInfer {
+ type = 'integer',
+ value = v,
+ source = source,
+ }
+ return true
+ elseif op.type == '<<' then
+ local v1 = m.getInferLiteral(status, source[1], 'integer')
+ local v2 = m.getInferLiteral(status, source[2], 'integer')
+ local v
+ if v1 and v2 then
+ v = v1 << v2
+ end
+ status.results = m.allocInfer {
+ type = 'integer',
+ value = v,
+ source = source,
+ }
+ return true
+ elseif op.type == '>>' then
+ local v1 = m.getInferLiteral(status, source[1], 'integer')
+ local v2 = m.getInferLiteral(status, source[2], 'integer')
+ local v
+ if v1 and v2 then
+ v = v1 >> v2
+ end
+ status.results = m.allocInfer {
+ type = 'integer',
+ value = v,
+ source = source,
+ }
+ return true
+ elseif op.type == '..' then
+ local v1 = m.getInferLiteral(status, source[1], 'string')
+ local v2 = m.getInferLiteral(status, source[2], 'string')
+ local v
+ if v1 and v2 then
+ v = v1 .. v2
+ end
+ status.results = m.allocInfer {
+ type = 'string',
+ value = v,
+ source = source,
+ }
+ return true
+ elseif op.type == '^' then
+ local v1 = m.getInferLiteral(status, source[1], 'integer')
+ or m.getInferLiteral(status, source[1], 'number')
+ local v2 = m.getInferLiteral(status, source[2], 'integer')
+ or m.getInferLiteral(status, source[2], 'number')
+ local v
+ if v1 and v2 then
+ v = v1 ^ v2
+ end
+ status.results = m.allocInfer {
+ type = 'number',
+ value = v,
+ source = source,
+ }
+ return true
+ elseif op.type == '/' then
+ local v1 = m.getInferLiteral(status, source[1], 'integer')
+ or m.getInferLiteral(status, source[1], 'number')
+ local v2 = m.getInferLiteral(status, source[2], 'integer')
+ or m.getInferLiteral(status, source[2], 'number')
+ local v
+ if v1 and v2 then
+ v = v1 > v2
+ end
+ status.results = m.allocInfer {
+ type = 'number',
+ value = v,
+ source = source,
+ }
+ return true
+ -- 其他数学运算根据2侧的值决定,当2侧的值均为整数时返回整数
+ elseif op.type == '+' then
+ local int, v1, v2 = mathCheck(status, source[1], source[2])
+ status.results = m.allocInfer{
+ type = int,
+ value = (v1 and v2) and (v1 + v2) or nil,
+ source = source,
+ }
+ return true
+ elseif op.type == '-' then
+ local int, v1, v2 = mathCheck(status, source[1], source[2])
+ status.results = m.allocInfer{
+ type = int,
+ value = (v1 and v2) and (v1 - v2) or nil,
+ source = source,
+ }
+ return true
+ elseif op.type == '*' then
+ local int, v1, v2 = mathCheck(status, source[1], source[2])
+ status.results = m.allocInfer {
+ type = int,
+ value = (v1 and v2) and (v1 * v2) or nil,
+ source = source,
+ }
+ return true
+ elseif op.type == '%' then
+ local int, v1, v2 = mathCheck(status, source[1], source[2])
+ status.results = m.allocInfer {
+ type = int,
+ value = (v1 and v2) and (v1 % v2) or nil,
+ source = source,
+ }
+ return true
+ elseif op.type == '//' then
+ local int, v1, v2 = mathCheck(status, source[1], source[2])
+ status.results = m.allocInfer {
+ type = int,
+ value = (v1 and v2) and (v1 // v2) or nil,
+ source = source,
+ }
+ return true
+ end
+end
+
+function m.inferByDef(status, obj)
+ if not status.cache.inferedDef then
+ status.cache.inferedDef = {}
+ end
+ if status.cache.inferedDef[obj] then
+ return
+ end
+ status.cache.inferedDef[obj] = true
+ local mark = {}
+ local newStatus = m.status(status, status.interface)
+ m.searchRefs(newStatus, obj, 'def')
+ for _, src in ipairs(newStatus.results) do
+ local inferStatus = m.status(newStatus)
+ m.searchInfer(inferStatus, src)
+ for _, infer in ipairs(inferStatus.results) do
+ if not mark[infer.source] then
+ mark[infer.source] = true
+ status.results[#status.results+1] = infer
+ end
+ end
+ end
+end
+
+local function inferBySetOfLocal(status, source)
+ if status.cache[source] then
+ return
+ end
+ status.cache[source] = true
+ local newStatus = m.status(status)
+ if source.value then
+ m.searchInfer(newStatus, source.value)
+ end
+ if source.ref then
+ for _, ref in ipairs(source.ref) do
+ if ref.type == 'setlocal' then
+ break
+ end
+ m.searchInfer(newStatus, ref)
+ end
+ for _, infer in ipairs(newStatus.results) do
+ status.results[#status.results+1] = infer
+ end
+ end
+end
+
+function m.inferBySet(status, source)
+ if #status.results ~= 0 then
+ return
+ end
+ if source.type == 'local' then
+ inferBySetOfLocal(status, source)
+ elseif source.type == 'setlocal'
+ or source.type == 'getlocal' then
+ inferBySetOfLocal(status, source.node)
+ end
+end
+
+function m.inferByCall(status, source)
+ if #status.results ~= 0 then
+ return
+ end
+ if not source.parent then
+ return
+ end
+ if source.parent.type ~= 'call' then
+ return
+ end
+ if source.parent.node == source then
+ status.results[#status.results+1] = {
+ type = 'function',
+ source = source,
+ }
+ return
+ end
+end
+
+function m.inferByGetTable(status, source)
+ if #status.results ~= 0 then
+ return
+ end
+ if source.type == 'field'
+ or source.type == 'method' then
+ source = source.parent
+ end
+ local next = source.next
+ if not next then
+ return
+ end
+ if next.type == 'getfield'
+ or next.type == 'getindex'
+ or next.type == 'getmethod'
+ or next.type == 'setfield'
+ or next.type == 'setindex'
+ or next.type == 'setmethod' then
+ status.results[#status.results+1] = {
+ type = 'table',
+ source = source,
+ }
+ end
+end
+
+function m.inferByUnary(status, source)
+ if #status.results ~= 0 then
+ return
+ end
+ local parent = source.parent
+ if not parent or parent.type ~= 'unary' then
+ return
+ end
+ local op = parent.op
+ if op.type == '#' then
+ status.results[#status.results+1] = {
+ type = 'string',
+ source = source
+ }
+ status.results[#status.results+1] = {
+ type = 'table',
+ source = source
+ }
+ elseif op.type == '~' then
+ status.results[#status.results+1] = {
+ type = 'integer',
+ source = source
+ }
+ elseif op.type == '-' then
+ status.results[#status.results+1] = {
+ type = 'number',
+ source = source
+ }
+ end
+end
+
+function m.inferByBinary(status, source)
+ if #status.results ~= 0 then
+ return
+ end
+ local parent = source.parent
+ if not parent or parent.type ~= 'binary' then
+ return
+ end
+ local op = parent.op
+ if op.type == '<='
+ or op.type == '>='
+ or op.type == '<'
+ or op.type == '>'
+ or op.type == '^'
+ or op.type == '/'
+ or op.type == '+'
+ or op.type == '-'
+ or op.type == '*'
+ or op.type == '%' then
+ status.results[#status.results+1] = {
+ type = 'number',
+ source = source,
+ }
+ elseif op.type == '|'
+ or op.type == '~'
+ or op.type == '&'
+ or op.type == '<<'
+ or op.type == '>>'
+ -- 整数的可能性比较高
+ or op.type == '//' then
+ status.results[#status.results+1] = {
+ type = 'integer',
+ source = source,
+ }
+ elseif op.type == '..' then
+ status.results[#status.results+1] = {
+ type = 'string',
+ source = source,
+ }
+ end
+end
+
+local function mergeFunctionReturnsByDoc(status, source, index, call)
+ if not source or source.type ~= 'function' then
+ return
+ end
+ if not source.bindDocs then
+ return
+ end
+ local returns = {}
+ for _, doc in ipairs(source.bindDocs) do
+ if doc.type == 'doc.return' then
+ for _, rtn in ipairs(doc.returns) do
+ returns[#returns+1] = rtn
+ end
+ end
+ end
+ local rtn = returns[index]
+ if not rtn then
+ return
+ end
+ local results = m.getDocTypeNames(status, rtn, function (typeName, typeUnit)
+ if not source.args or not call.args then
+ return
+ end
+ local name = typeUnit[1]
+ local generics = typeUnit.typeGeneric[name]
+ if not generics then
+ return
+ end
+ local first = generics[1]
+ if not first or first == typeUnit then
+ return
+ end
+ local docParam = m.getParentType(first, 'doc.param')
+ local paramName = docParam.param[1]
+ for i, arg in ipairs(source.args) do
+ if arg[1] == paramName then
+ local callArg = call.args[i]
+ if not callArg then
+ return
+ end
+ return m.viewInferType(m.searchInfer(status, callArg))
+ end
+ end
+ end)
+ if #results == 0 then
+ return
+ end
+ for _, res in ipairs(results) do
+ status.results[#status.results+1] = res
+ end
+ return true
+end
+
+local function mergeDocTypeFunctionReturns(status, source, index)
+ if not source.bindDocs then
+ return
+ end
+ for _, doc in ipairs(source.bindDocs) do
+ if doc.type == 'doc.type' then
+ for _, typeUnit in ipairs(doc.types) do
+ if typeUnit.type == 'doc.type.function' then
+ local rtn = typeUnit.returns[index]
+ if rtn then
+ local results = m.getDocTypeNames(status, rtn)
+ for _, res in ipairs(results) do
+ status.results[#status.results+1] = res
+ end
+ end
+ end
+ end
+ end
+ end
+end
+
+local function mergeFunctionReturns(status, source, index, call)
+ local returns = source.returns
+ if not returns then
+ return
+ end
+ for i = 1, #returns do
+ local rtn = returns[i]
+ if rtn[index] then
+ if rtn[index].type == 'call' then
+ if not m.checkReturnMark(status, rtn[index]) then
+ m.checkReturnMark(status, rtn[index], true)
+ m.inferByCallReturnAndIndex(status, rtn[index], index)
+ end
+ else
+ local newStatus = m.status(status)
+ m.searchInfer(newStatus, rtn[index])
+ if #newStatus.results == 0 then
+ status.results[#status.results+1] = {
+ type = 'any',
+ source = rtn[index],
+ }
+ else
+ for _, infer in ipairs(newStatus.results) do
+ status.results[#status.results+1] = infer
+ end
+ end
+ end
+ end
+ end
+end
+
+function m.inferByCallReturnAndIndex(status, call, index)
+ local node = call.node
+ local newStatus = m.status(nil, status.interface)
+ m.searchRefs(newStatus, node, 'def')
+ local hasDocReturn
+ for _, src in ipairs(newStatus.results) do
+ if mergeDocTypeFunctionReturns(status, src, index) then
+ hasDocReturn = true
+ elseif mergeFunctionReturnsByDoc(status, src.value, index, call) then
+ hasDocReturn = true
+ end
+ end
+ if not hasDocReturn then
+ for _, src in ipairs(newStatus.results) do
+ if src.value and src.value.type == 'function' then
+ if not m.checkReturnMark(status, src.value, true) then
+ mergeFunctionReturns(status, src.value, index, call)
+ end
+ end
+ end
+ end
+end
+
+function m.inferByCallReturn(status, source)
+ if source.type == 'call' then
+ m.inferByCallReturnAndIndex(status, source, 1)
+ return
+ end
+ if source.type ~= 'select' then
+ if source.value and source.value.type == 'select' then
+ source = source.value
+ else
+ return
+ end
+ end
+ if not source.vararg or source.vararg.type ~= 'call' then
+ return
+ end
+ m.inferByCallReturnAndIndex(status, source.vararg, source.index)
+end
+
+function m.inferByPCallReturn(status, source)
+ if source.type ~= 'select' then
+ if source.value and source.value.type == 'select' then
+ source = source.value
+ else
+ return
+ end
+ end
+ local call = source.vararg
+ if not call or call.type ~= 'call' then
+ return
+ end
+ local node = call.node
+ local specialName = node.special
+ local func, index
+ if specialName == 'pcall' then
+ func = call.args[1]
+ index = source.index - 1
+ elseif specialName == 'xpcall' then
+ func = call.args[1]
+ index = source.index - 2
+ else
+ return
+ end
+ local newStatus = m.status(nil, status.interface)
+ m.searchRefs(newStatus, func, 'def')
+ for _, src in ipairs(newStatus.results) do
+ if src.value and src.value.type == 'function' then
+ mergeFunctionReturns(status, src.value, index)
+ end
+ end
+end
+
+function m.cleanInfers(infers)
+ local mark = {}
+ for i = #infers, 1, -1 do
+ local infer = infers[i]
+ local key = ('%s|%p'):format(infer.type, infer.source)
+ if mark[key] then
+ infers[i] = infers[#infers]
+ infers[#infers] = nil
+ else
+ mark[key] = true
+ end
+ end
+end
+
+function m.searchInfer(status, obj)
+ while obj.type == 'paren' do
+ obj = obj.exp
+ if not obj then
+ return
+ end
+ end
+ while true do
+ local value = m.getObjectValue(obj)
+ if not value or value == obj then
+ break
+ end
+ obj = value
+ end
+
+ local cache, makeCache = m.getRefCache(status, obj, 'infer')
+ if cache then
+ for i = 1, #cache do
+ status.results[#status.results+1] = cache[i]
+ end
+ return
+ end
+
+ if DEVELOP then
+ status.cache.clock = status.cache.clock or osClock()
+ end
+
+ if not status.cache.lockInfer then
+ status.cache.lockInfer = {}
+ end
+ if status.cache.lockInfer[obj] then
+ return
+ end
+ status.cache.lockInfer[obj] = true
+
+ local checked = m.inferCheckDoc(status, obj)
+ or m.inferCheckUpDoc(status, obj)
+ or m.inferCheckFieldDoc(status, obj)
+ or m.inferCheckLiteral(status, obj)
+ or m.inferCheckUnary(status, obj)
+ or m.inferCheckBinary(status, obj)
+ if checked then
+ m.cleanInfers(status.results)
+ if makeCache then
+ makeCache(status.results)
+ end
+ return
+ end
+
+ if status.deep then
+ m.inferByDef(status, obj)
+ end
+ m.inferBySet(status, obj)
+ m.inferByCall(status, obj)
+ m.inferByGetTable(status, obj)
+ m.inferByUnary(status, obj)
+ m.inferByBinary(status, obj)
+ m.inferByCallReturn(status, obj)
+ m.inferByPCallReturn(status, obj)
+ m.cleanInfers(status.results)
+ if makeCache then
+ makeCache(status.results)
+ end
+end
+
+--- 请求对象的引用,包括 `a.b.c` 形式
+--- 与 `return function` 形式。
+--- 不穿透 `setmetatable` ,考虑由
+--- 业务层进行反向 def 搜索。
+function m.requestReference(obj, interface, deep)
+ local status = m.status(nil, interface)
+ status.deep = deep
+ -- 根据 field 搜索引用
+ m.searchRefs(status, obj, 'ref')
+
+ m.searchRefsAsFunction(status, obj, 'ref')
+
+ if m.debugMode then
+ print('count:', status.cache.count)
+ end
+
+ return status.results, status.cache.count
+end
+
+--- 请求对象的定义,包括 `a.b.c` 形式
+--- 与 `return function` 形式。
+--- 穿透 `setmetatable` 。
+function m.requestDefinition(obj, interface, deep)
+ local status = m.status(nil, interface)
+ status.deep = deep
+ -- 根据 field 搜索定义
+ m.searchRefs(status, obj, 'def')
+
+ return status.results, status.cache.count
+end
+
+--- 请求对象的域
+function m.requestFields(obj, interface, deep)
+ local status = m.status(nil, interface)
+ status.deep = deep
+
+ m.searchFields(status, obj)
+
+ return status.results, status.cache.count
+end
+
+--- 请求对象的类型推测
+function m.requestInfer(obj, interface, deep)
+ local status = m.status(nil, interface)
+ status.deep = deep
+ m.searchInfer(status, obj)
+
+ return status.results, status.cache.count
+end
+
+return m
diff --git a/script/parser/init.lua b/script/parser/init.lua
new file mode 100644
index 00000000..ba40d145
--- /dev/null
+++ b/script/parser/init.lua
@@ -0,0 +1,12 @@
+local api = {
+ grammar = require 'parser.grammar',
+ parse = require 'parser.parse',
+ compile = require 'parser.compile',
+ split = require 'parser.split',
+ calcline = require 'parser.calcline',
+ lines = require 'parser.lines',
+ guide = require 'parser.guide',
+ luadoc = require 'parser.luadoc',
+}
+
+return api
diff --git a/script/parser/lines.lua b/script/parser/lines.lua
new file mode 100644
index 00000000..ee6b4f41
--- /dev/null
+++ b/script/parser/lines.lua
@@ -0,0 +1,45 @@
+local m = require 'lpeglabel'
+
+_ENV = nil
+
+local function Line(start, line, range, finish)
+ line.start = start
+ line.finish = finish - 1
+ line.range = range - 1
+ return line
+end
+
+local function Space(...)
+ local line = {...}
+ local sp = 0
+ local tab = 0
+ for i = 1, #line do
+ if line[i] == ' ' then
+ sp = sp + 1
+ elseif line[i] == '\t' then
+ tab = tab + 1
+ end
+ line[i] = nil
+ end
+ line.sp = sp
+ line.tab = tab
+ return line
+end
+
+local parser = m.P{
+'Lines',
+Lines = m.Ct(m.V'Line'^0 * m.V'LastLine'),
+Line = m.Cp() * m.V'Indent' * (1 - m.V'Nl')^0 * m.Cp() * m.V'Nl' * m.Cp() / Line,
+LastLine= m.Cp() * m.V'Indent' * (1 - m.V'Nl')^0 * m.Cp() * m.Cp() / Line,
+Nl = m.P'\r\n' + m.S'\r\n',
+Indent = m.C(m.S' \t')^0 / Space,
+}
+
+return function (self, text)
+ local lines, err = parser:match(text)
+ if not lines then
+ return nil, err
+ end
+
+ return lines
+end
diff --git a/script/parser/luadoc.lua b/script/parser/luadoc.lua
new file mode 100644
index 00000000..b31c4baf
--- /dev/null
+++ b/script/parser/luadoc.lua
@@ -0,0 +1,991 @@
+local m = require 'lpeglabel'
+local re = require 'parser.relabel'
+local lines = require 'parser.lines'
+local guide = require 'parser.guide'
+
+local TokenTypes, TokenStarts, TokenFinishs, TokenContents
+local Ci, Offset, pushError, Ct, NextComment
+local parseType
+local Parser = re.compile([[
+Main <- (Token / Sp)*
+Sp <- %s+
+X16 <- [a-fA-F0-9]
+Word <- [a-zA-Z0-9_]
+Token <- Name / String / Symbol
+Name <- ({} {[a-zA-Z0-9_] [a-zA-Z0-9_.*]*} {})
+ -> Name
+String <- ({} StringDef {})
+ -> String
+StringDef <- '"'
+ {~(Esc / !'"' .)*~} -> 1
+ ('"'?)
+ / "'"
+ {~(Esc / !"'" .)*~} -> 1
+ ("'"?)
+ / ('[' {:eq: '='* :} '['
+ {(!StringClose .)*} -> 1
+ (StringClose?))
+StringClose <- ']' =eq ']'
+Esc <- '\' -> ''
+ EChar
+EChar <- 'a' -> ea
+ / 'b' -> eb
+ / 'f' -> ef
+ / 'n' -> en
+ / 'r' -> er
+ / 't' -> et
+ / 'v' -> ev
+ / '\'
+ / '"'
+ / "'"
+ / %nl
+ / ('z' (%nl / %s)*) -> ''
+ / ('x' {X16 X16}) -> Char16
+ / ([0-9] [0-9]? [0-9]?) -> Char10
+ / ('u{' {Word*} '}') -> CharUtf8
+Symbol <- ({} {
+ ':'
+ / '|'
+ / ','
+ / '[]'
+ / '<'
+ / '>'
+ / '('
+ / ')'
+ / '?'
+ / '...'
+ / '+'
+ } {})
+ -> Symbol
+]], {
+ s = m.S' \t',
+ ea = '\a',
+ eb = '\b',
+ ef = '\f',
+ en = '\n',
+ er = '\r',
+ et = '\t',
+ ev = '\v',
+ Char10 = function (char)
+ char = tonumber(char)
+ if not char or char < 0 or char > 255 then
+ return ''
+ end
+ return string.char(char)
+ end,
+ Char16 = function (char)
+ return string.char(tonumber(char, 16))
+ end,
+ CharUtf8 = function (char)
+ if #char == 0 then
+ return ''
+ end
+ local v = tonumber(char, 16)
+ if not v then
+ return ''
+ end
+ if v >= 0 and v <= 0x10FFFF then
+ return utf8.char(v)
+ end
+ return ''
+ end,
+ Name = function (start, content, finish)
+ Ci = Ci + 1
+ TokenTypes[Ci] = 'name'
+ TokenStarts[Ci] = start
+ TokenFinishs[Ci] = finish - 1
+ TokenContents[Ci] = content
+ end,
+ String = function (start, content, finish)
+ Ci = Ci + 1
+ TokenTypes[Ci] = 'string'
+ TokenStarts[Ci] = start
+ TokenFinishs[Ci] = finish - 1
+ TokenContents[Ci] = content
+ end,
+ Symbol = function (start, content, finish)
+ Ci = Ci + 1
+ TokenTypes[Ci] = 'symbol'
+ TokenStarts[Ci] = start
+ TokenFinishs[Ci] = finish - 1
+ TokenContents[Ci] = content
+ end,
+})
+
+local function trim(str)
+ return str:match '^%s*(%S+)%s*$'
+end
+
+local function parseTokens(text, offset)
+ Ct = offset
+ Ci = 0
+ Offset = offset
+ TokenTypes = {}
+ TokenStarts = {}
+ TokenFinishs = {}
+ TokenContents = {}
+ Parser:match(text)
+ Ci = 0
+end
+
+local function peekToken()
+ return TokenTypes[Ci+1], TokenContents[Ci+1]
+end
+
+local function nextToken()
+ Ci = Ci + 1
+ if not TokenTypes[Ci] then
+ Ci = Ci - 1
+ return nil
+ end
+ return TokenTypes[Ci], TokenContents[Ci]
+end
+
+local function checkToken(tp, content, offset)
+ offset = offset or 0
+ return TokenTypes[Ci + offset] == tp
+ and TokenContents[Ci + offset] == content
+end
+
+local function getStart()
+ if Ci == 0 then
+ return Offset
+ end
+ return TokenStarts[Ci] + Offset
+end
+
+local function getFinish()
+ if Ci == 0 then
+ return Offset
+ end
+ return TokenFinishs[Ci] + Offset
+end
+
+local function try(callback)
+ local savePoint = Ci
+ -- rollback
+ local suc = callback()
+ if not suc then
+ Ci = savePoint
+ end
+ return suc
+end
+
+local function parseName(tp, parent)
+ local nameTp, nameText = peekToken()
+ if nameTp ~= 'name' then
+ return nil
+ end
+ nextToken()
+ local class = {
+ type = tp,
+ start = getStart(),
+ finish = getFinish(),
+ parent = parent,
+ [1] = nameText,
+ }
+ return class
+end
+
+local function parseClass(parent)
+ local result = {
+ type = 'doc.class',
+ parent = parent,
+ }
+ result.class = parseName('doc.class.name', result)
+ if not result.class then
+ pushError {
+ type = 'LUADOC_MISS_CLASS_NAME',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return nil
+ end
+ result.start = getStart()
+ result.finish = getFinish()
+ if not peekToken() then
+ return result
+ end
+ nextToken()
+ if not checkToken('symbol', ':') then
+ pushError {
+ type = 'LUADOC_MISS_EXTENDS_SYMBOL',
+ start = result.finish + 1,
+ finish = getStart() - 1,
+ }
+ return result
+ end
+ result.extends = parseName('doc.extends.name', result)
+ if not result.extends then
+ pushError {
+ type = 'LUADOC_MISS_CLASS_EXTENDS_NAME',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return result
+ end
+ result.finish = getFinish()
+ return result
+end
+
+local function nextSymbolOrError(symbol)
+ if checkToken('symbol', symbol, 1) then
+ nextToken()
+ return true
+ end
+ pushError {
+ type = 'LUADOC_MISS_SYMBOL',
+ start = getFinish(),
+ finish = getFinish(),
+ info = {
+ symbol = symbol,
+ }
+ }
+ return false
+end
+
+local function parseTypeUnitArray(node)
+ if not checkToken('symbol', '[]', 1) then
+ return nil
+ end
+ nextToken()
+ local result = {
+ type = 'doc.type.array',
+ start = node.start,
+ finish = getFinish(),
+ node = node,
+ }
+ return result
+end
+
+local function parseTypeUnitGeneric(node)
+ if not checkToken('symbol', '<', 1) then
+ return nil
+ end
+ if not nextSymbolOrError('<') then
+ return nil
+ end
+ local key = parseType(node)
+ if not key or not nextSymbolOrError(',') then
+ return nil
+ end
+ local value = parseType(node)
+ if not value then
+ return nil
+ end
+ nextSymbolOrError('>')
+ local result = {
+ type = 'doc.type.generic',
+ start = node.start,
+ finish = getFinish(),
+ node = node,
+ key = key,
+ value = value,
+ }
+ return result
+end
+
+local function parseTypeUnitFunction()
+ local typeUnit = {
+ type = 'doc.type.function',
+ start = getStart(),
+ args = {},
+ returns = {},
+ }
+ if not nextSymbolOrError('(') then
+ return nil
+ end
+ while true do
+ if checkToken('symbol', ')', 1) then
+ nextToken()
+ break
+ end
+ local arg = {
+ type = 'doc.type.arg',
+ parent = typeUnit,
+ }
+ arg.name = parseName('doc.type.name', arg)
+ if not arg.name then
+ pushError {
+ type = 'LUADOC_MISS_ARG_NAME',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ break
+ end
+ if not arg.start then
+ arg.start = arg.name.start
+ end
+ if checkToken('symbol', '?', 1) then
+ nextToken()
+ arg.optional = true
+ end
+ arg.finish = getFinish()
+ if not nextSymbolOrError(':') then
+ break
+ end
+ arg.extends = parseType(arg)
+ if not arg.extends then
+ break
+ end
+ arg.finish = getFinish()
+ typeUnit.args[#typeUnit.args+1] = arg
+ if checkToken('symbol', ',', 1) then
+ nextToken()
+ else
+ nextSymbolOrError(')')
+ break
+ end
+ end
+ if checkToken('symbol', ':', 1) then
+ nextToken()
+ while true do
+ local rtn = parseType(typeUnit)
+ if not rtn then
+ break
+ end
+ if checkToken('symbol', '?', 1) then
+ nextToken()
+ rtn.optional = true
+ end
+ typeUnit.returns[#typeUnit.returns+1] = rtn
+ if checkToken('symbol', ',', 1) then
+ nextToken()
+ else
+ break
+ end
+ end
+ end
+ typeUnit.finish = getFinish()
+ return typeUnit
+end
+
+local function parseTypeUnit(parent, content)
+ local result
+ if content == 'fun' then
+ result = parseTypeUnitFunction()
+ end
+ if not result then
+ result = {
+ type = 'doc.type.name',
+ start = getStart(),
+ finish = getFinish(),
+ [1] = content,
+ }
+ end
+ if not result then
+ return nil
+ end
+ result.parent = parent
+ while true do
+ local newResult = parseTypeUnitArray(result)
+ or parseTypeUnitGeneric(result)
+ if not newResult then
+ break
+ end
+ result = newResult
+ end
+ return result
+end
+
+local function parseResume()
+ local result = {
+ type = 'doc.resume'
+ }
+
+ if checkToken('symbol', '>', 1) then
+ nextToken()
+ result.default = true
+ end
+
+ if checkToken('symbol', '+', 1) then
+ nextToken()
+ result.additional = true
+ end
+
+ local tp = peekToken()
+ if tp ~= 'string' then
+ pushError {
+ type = 'LUADOC_MISS_STRING',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return nil
+ end
+ local _, str = nextToken()
+ result[1] = str
+ result.start = getStart()
+ result.finish = getFinish()
+ return result
+end
+
+function parseType(parent)
+ local result = {
+ type = 'doc.type',
+ parent = parent,
+ types = {},
+ enums = {},
+ resumes = {},
+ }
+ result.start = getStart()
+ while true do
+ local tp, content = peekToken()
+ if not tp then
+ break
+ end
+ if tp == 'name' then
+ nextToken()
+ local typeUnit = parseTypeUnit(result, content)
+ if not typeUnit then
+ break
+ end
+ result.types[#result.types+1] = typeUnit
+ if not result.start then
+ result.start = typeUnit.start
+ end
+ elseif tp == 'string' then
+ nextToken()
+ local typeEnum = {
+ type = 'doc.type.enum',
+ start = getStart(),
+ finish = getFinish(),
+ parent = result,
+ [1] = content,
+ }
+ result.enums[#result.enums+1] = typeEnum
+ if not result.start then
+ result.start = typeEnum.start
+ end
+ elseif tp == 'symbol' and content == '...' then
+ nextToken()
+ local vararg = {
+ type = 'doc.type.name',
+ start = getStart(),
+ finish = getFinish(),
+ parent = result,
+ [1] = content,
+ }
+ result.types[#result.types+1] = vararg
+ if not result.start then
+ result.start = vararg.start
+ end
+ end
+ if not checkToken('symbol', '|', 1) then
+ break
+ end
+ nextToken()
+ end
+ result.finish = getFinish()
+
+ while true do
+ local nextComm = NextComment('peek')
+ if nextComm and nextComm.text:sub(1, 2) == '-|' then
+ NextComment()
+ local finishPos = nextComm.text:find('#', 3) or #nextComm.text
+ parseTokens(nextComm.text:sub(3, finishPos), nextComm.start + 1)
+ local resume = parseResume()
+ if resume then
+ resume.comment = nextComm.text:match('#%s*(.+)', 3)
+ result.resumes[#result.resumes+1] = resume
+ result.finish = resume.finish
+ end
+ else
+ break
+ end
+ end
+
+ if #result.types == 0 and #result.enums == 0 and #result.resumes == 0 then
+ pushError {
+ type = 'LUADOC_MISS_TYPE_NAME',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return nil
+ end
+ return result
+end
+
+local function parseAlias()
+ local result = {
+ type = 'doc.alias',
+ }
+ result.alias = parseName('doc.alias.name', result)
+ if not result.alias then
+ pushError {
+ type = 'LUADOC_MISS_ALIAS_NAME',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return nil
+ end
+ result.start = getStart()
+ result.extends = parseType(result)
+ if not result.extends then
+ pushError {
+ type = 'LUADOC_MISS_ALIAS_EXTENDS',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return nil
+ end
+ result.finish = getFinish()
+ return result
+end
+
+local function parseParam()
+ local result = {
+ type = 'doc.param',
+ }
+ result.param = parseName('doc.param.name', result)
+ if not result.param then
+ pushError {
+ type = 'LUADOC_MISS_PARAM_NAME',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return nil
+ end
+ if checkToken('symbol', '?', 1) then
+ nextToken()
+ result.optional = true
+ end
+ result.start = result.param.start
+ result.finish = getFinish()
+ result.extends = parseType(result)
+ if not result.extends then
+ pushError {
+ type = 'LUADOC_MISS_PARAM_EXTENDS',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return result
+ end
+ result.finish = getFinish()
+ return result
+end
+
+local function parseReturn()
+ local result = {
+ type = 'doc.return',
+ returns = {},
+ }
+ while true do
+ local docType = parseType(result)
+ if not docType then
+ break
+ end
+ if not result.start then
+ result.start = docType.start
+ end
+ if checkToken('symbol', '?', 1) then
+ nextToken()
+ docType.optional = true
+ end
+ docType.name = parseName('doc.return.name', docType)
+ result.returns[#result.returns+1] = docType
+ if not checkToken('symbol', ',', 1) then
+ break
+ end
+ nextToken()
+ end
+ if #result.returns == 0 then
+ return nil
+ end
+ result.finish = getFinish()
+ return result
+end
+
+local function parseField()
+ local result = {
+ type = 'doc.field',
+ }
+ try(function ()
+ local tp, value = nextToken()
+ if tp == 'name' then
+ if value == 'public'
+ or value == 'protected'
+ or value == 'private' then
+ result.visible = value
+ result.start = getStart()
+ return true
+ end
+ end
+ return false
+ end)
+ result.field = parseName('doc.field.name', result)
+ if not result.field then
+ pushError {
+ type = 'LUADOC_MISS_FIELD_NAME',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return nil
+ end
+ if not result.start then
+ result.start = result.field.start
+ end
+ result.extends = parseType(result)
+ if not result.extends then
+ pushError {
+ type = 'LUADOC_MISS_FIELD_EXTENDS',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return nil
+ end
+ result.finish = getFinish()
+ return result
+end
+
+local function parseGeneric()
+ local result = {
+ type = 'doc.generic',
+ generics = {},
+ }
+ while true do
+ local object = {
+ type = 'doc.generic.object',
+ parent = result,
+ }
+ object.generic = parseName('doc.generic.name', object)
+ if not object.generic then
+ pushError {
+ type = 'LUADOC_MISS_GENERIC_NAME',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return nil
+ end
+ object.start = object.generic.start
+ if not result.start then
+ result.start = object.start
+ end
+ if checkToken('symbol', ':', 1) then
+ nextToken()
+ object.extends = parseType(object)
+ end
+ object.finish = getFinish()
+ result.generics[#result.generics+1] = object
+ if not checkToken('symbol', ',', 1) then
+ break
+ end
+ nextToken()
+ end
+ result.finish = getFinish()
+ return result
+end
+
+local function parseVararg()
+ local result = {
+ type = 'doc.vararg',
+ }
+ result.vararg = parseType(result)
+ if not result.vararg then
+ pushError {
+ type = 'LUADOC_MISS_VARARG_TYPE',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return
+ end
+ result.start = result.vararg.start
+ result.finish = result.vararg.finish
+ return result
+end
+
+local function parseOverload()
+ local tp, name = peekToken()
+ if tp ~= 'name' or name ~= 'fun' then
+ pushError {
+ type = 'LUADOC_MISS_FUN_AFTER_OVERLOAD',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return nil
+ end
+ nextToken()
+ local result = {
+ type = 'doc.overload',
+ }
+ result.overload = parseTypeUnitFunction()
+ if not result.overload then
+ return nil
+ end
+ result.overload.parent = result
+ result.start = result.overload.start
+ result.finish = result.overload.finish
+ return result
+end
+
+local function parseDeprecated()
+ return {
+ type = 'doc.deprecated',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+end
+
+local function parseMeta()
+ return {
+ type = 'doc.meta',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+end
+
+local function parseVersion()
+ local result = {
+ type = 'doc.version',
+ versions = {},
+ }
+ while true do
+ local tp, text = nextToken()
+ if not tp then
+ pushError {
+ type = 'LUADOC_MISS_VERSION',
+ start = getStart(),
+ finish = getFinish(),
+ }
+ break
+ end
+ if not result.start then
+ result.start = getStart()
+ end
+ local version = {
+ type = 'doc.version.unit',
+ start = getStart(),
+ }
+ if tp == 'symbol' then
+ if text == '>' then
+ version.ge = true
+ elseif text == '<' then
+ version.le = true
+ end
+ tp, text = nextToken()
+ end
+ if tp ~= 'name' then
+ pushError {
+ type = 'LUADOC_MISS_VERSION',
+ start = getStart(),
+ finish = getFinish(),
+ }
+ break
+ end
+ version.version = tonumber(text) or text
+ version.finish = getFinish()
+ result.versions[#result.versions+1] = version
+ if not checkToken('symbol', ',', 1) then
+ break
+ end
+ nextToken()
+ end
+ if #result.versions == 0 then
+ return nil
+ end
+ result.finish = getFinish()
+ return result
+end
+
+local function convertTokens()
+ local tp, text = nextToken()
+ if not tp then
+ return
+ end
+ if tp ~= 'name' then
+ pushError {
+ type = 'LUADOC_MISS_CATE_NAME',
+ start = getStart(),
+ finish = getFinish(),
+ }
+ return nil
+ end
+ if text == 'class' then
+ return parseClass()
+ elseif text == 'type' then
+ return parseType()
+ elseif text == 'alias' then
+ return parseAlias()
+ elseif text == 'param' then
+ return parseParam()
+ elseif text == 'return' then
+ return parseReturn()
+ elseif text == 'field' then
+ return parseField()
+ elseif text == 'generic' then
+ return parseGeneric()
+ elseif text == 'vararg' then
+ return parseVararg()
+ elseif text == 'overload' then
+ return parseOverload()
+ elseif text == 'deprecated' then
+ return parseDeprecated()
+ elseif text == 'meta' then
+ return parseMeta()
+ elseif text == 'version' then
+ return parseVersion()
+ end
+end
+
+local function buildLuaDoc(comment)
+ local text = comment.text
+ if text:sub(1, 1) ~= '-' then
+ return
+ end
+ if text:sub(2, 2) ~= '@' then
+ return {
+ type = 'doc.comment',
+ start = comment.start,
+ finish = comment.finish,
+ comment = comment,
+ }
+ end
+ local finishPos = text:find('@', 3)
+ local doc, lastComment
+ if finishPos then
+ doc = text:sub(3, finishPos - 1)
+ lastComment = text:sub(finishPos)
+ else
+ doc = text:sub(3)
+ end
+
+ parseTokens(doc, comment.start + 1)
+ local result = convertTokens()
+ if result then
+ result.comment = lastComment
+ end
+
+ return result
+end
+
+local function isNextLine(lns, binded, doc)
+ if not binded then
+ return false
+ end
+ local lastDoc = binded[#binded]
+ local lastRow = guide.positionOf(lns, lastDoc.finish)
+ local newRow = guide.positionOf(lns, doc.start)
+ return newRow - lastRow == 1
+end
+
+local function bindGeneric(binded)
+ local generics = {}
+ for _, doc in ipairs(binded) do
+ if doc.type == 'doc.generic' then
+ for _, obj in ipairs(doc.generics) do
+ local name = obj.generic[1]
+ generics[name] = {}
+ end
+ elseif doc.type == 'doc.param'
+ or doc.type == 'doc.return' then
+ guide.eachSourceType(doc, 'doc.type.name', function (src)
+ local name = src[1]
+ if generics[name] then
+ generics[name][#generics[name]+1] = src
+ src.typeGeneric = generics
+ end
+ end)
+ end
+ end
+end
+
+local function bindDoc(state, lns, binded)
+ if not binded then
+ return
+ end
+ local lastDoc = binded[#binded]
+ if not lastDoc then
+ return
+ end
+ local bindSources = {}
+ for _, doc in ipairs(binded) do
+ doc.bindGroup = binded
+ doc.bindSources = bindSources
+ end
+ bindGeneric(binded)
+ local row = guide.positionOf(lns, lastDoc.finish)
+ local start, finish = guide.lineRange(lns, row + 1)
+ if start >= finish then
+ -- 空行
+ return
+ end
+ guide.eachSourceBetween(state.ast, start, finish, function (src)
+ if src.start and src.start < start then
+ return
+ end
+ if src.type == 'local'
+ or src.type == 'setlocal'
+ or src.type == 'setglobal'
+ or src.type == 'setfield'
+ or src.type == 'setmethod'
+ or src.type == 'setindex'
+ or src.type == 'tablefield'
+ or src.type == 'tableindex'
+ or src.type == 'function'
+ or src.type == '...' then
+ src.bindDocs = binded
+ bindSources[#bindSources+1] = src
+ end
+ end)
+end
+
+local function bindDocs(state)
+ local lns = lines(nil, state.lua)
+ local binded
+ for _, doc in ipairs(state.ast.docs) do
+ if not isNextLine(lns, binded, doc) then
+ bindDoc(state, lns, binded)
+ binded = {}
+ state.ast.docs.groups[#state.ast.docs.groups+1] = binded
+ end
+ binded[#binded+1] = doc
+ end
+ bindDoc(state, lns, binded)
+end
+
+return function (_, state)
+ local ast = state.ast
+ local comments = state.comms
+ table.sort(comments, function (a, b)
+ return a.start < b.start
+ end)
+ ast.docs = {
+ type = 'doc',
+ parent = ast,
+ groups = {},
+ }
+
+ pushError = state.pushError
+
+ local ci = 1
+ NextComment = function (peek)
+ local comment = comments[ci]
+ if not peek then
+ ci = ci + 1
+ end
+ return comment
+ end
+
+ while true do
+ local comment = NextComment()
+ if not comment then
+ break
+ end
+ local doc = buildLuaDoc(comment)
+ if doc then
+ ast.docs[#ast.docs+1] = doc
+ doc.parent = ast.docs
+ if ast.start > doc.start then
+ ast.start = doc.start
+ end
+ if ast.finish < doc.finish then
+ ast.finish = doc.finish
+ end
+ end
+ end
+
+ if #ast.docs == 0 then
+ return
+ end
+
+ bindDocs(state)
+end
diff --git a/script/parser/parse.lua b/script/parser/parse.lua
new file mode 100644
index 00000000..f813cc59
--- /dev/null
+++ b/script/parser/parse.lua
@@ -0,0 +1,49 @@
+local ast = require 'parser.ast'
+
+return function (self, lua, mode, version)
+ local errs = {}
+ local diags = {}
+ local comms = {}
+ local state = {
+ version = version,
+ lua = lua,
+ root = {},
+ errs = errs,
+ diags = diags,
+ comms = comms,
+ pushError = function (err)
+ if err.finish < err.start then
+ err.finish = err.start
+ end
+ local last = errs[#errs]
+ if last then
+ if last.start <= err.start and last.finish >= err.finish then
+ return
+ end
+ end
+ err.level = err.level or 'error'
+ errs[#errs+1] = err
+ return err
+ end,
+ pushDiag = function (code, info)
+ if not diags[code] then
+ diags[code] = {}
+ end
+ diags[code][#diags[code]+1] = info
+ end,
+ pushComment = function (comment)
+ comms[#comms+1] = comment
+ end
+ }
+ ast.init(state)
+ local suc, res, err = xpcall(self.grammar, debug.traceback, self, lua, mode)
+ ast.close()
+ if not suc then
+ return nil, res
+ end
+ if not res then
+ state.pushError(err)
+ end
+ state.ast = res
+ return state
+end
diff --git a/script/parser/relabel.lua b/script/parser/relabel.lua
new file mode 100644
index 00000000..ac902403
--- /dev/null
+++ b/script/parser/relabel.lua
@@ -0,0 +1,361 @@
+-- $Id: re.lua,v 1.44 2013/03/26 20:11:40 roberto Exp $
+
+-- imported functions and modules
+local tonumber, type, print, error = tonumber, type, print, error
+local pcall = pcall
+local setmetatable = setmetatable
+local tinsert, concat = table.insert, table.concat
+local rep = string.rep
+local m = require"lpeglabel"
+
+-- 'm' will be used to parse expressions, and 'mm' will be used to
+-- create expressions; that is, 're' runs on 'm', creating patterns
+-- on 'mm'
+local mm = m
+
+-- pattern's metatable
+local mt = getmetatable(mm.P(0))
+
+
+
+-- No more global accesses after this point
+_ENV = nil
+
+
+local any = m.P(1)
+local dummy = mm.P(false)
+
+
+local errinfo = {
+ NoPatt = "no pattern found",
+ ExtraChars = "unexpected characters after the pattern",
+
+ ExpPatt1 = "expected a pattern after '/'",
+
+ ExpPatt2 = "expected a pattern after '&'",
+ ExpPatt3 = "expected a pattern after '!'",
+
+ ExpPatt4 = "expected a pattern after '('",
+ ExpPatt5 = "expected a pattern after ':'",
+ ExpPatt6 = "expected a pattern after '{~'",
+ ExpPatt7 = "expected a pattern after '{|'",
+
+ ExpPatt8 = "expected a pattern after '<-'",
+
+ ExpPattOrClose = "expected a pattern or closing '}' after '{'",
+
+ ExpNumName = "expected a number, '+', '-' or a name (no space) after '^'",
+ ExpCap = "expected a string, number, '{}' or name after '->'",
+
+ ExpName1 = "expected the name of a rule after '=>'",
+ ExpName2 = "expected the name of a rule after '=' (no space)",
+ ExpName3 = "expected the name of a rule after '<' (no space)",
+
+ ExpLab1 = "expected a label after '{'",
+
+ ExpNameOrLab = "expected a name or label after '%' (no space)",
+
+ ExpItem = "expected at least one item after '[' or '^'",
+
+ MisClose1 = "missing closing ')'",
+ MisClose2 = "missing closing ':}'",
+ MisClose3 = "missing closing '~}'",
+ MisClose4 = "missing closing '|}'",
+ MisClose5 = "missing closing '}'", -- for the captures
+
+ MisClose6 = "missing closing '>'",
+ MisClose7 = "missing closing '}'", -- for the labels
+
+ MisClose8 = "missing closing ']'",
+
+ MisTerm1 = "missing terminating single quote",
+ MisTerm2 = "missing terminating double quote",
+}
+
+local function expect (pattern, label)
+ return pattern + m.T(label)
+end
+
+
+-- Pre-defined names
+local Predef = { nl = m.P"\n" }
+
+
+local mem
+local fmem
+local gmem
+
+
+local function updatelocale ()
+ mm.locale(Predef)
+ Predef.a = Predef.alpha
+ Predef.c = Predef.cntrl
+ Predef.d = Predef.digit
+ Predef.g = Predef.graph
+ Predef.l = Predef.lower
+ Predef.p = Predef.punct
+ Predef.s = Predef.space
+ Predef.u = Predef.upper
+ Predef.w = Predef.alnum
+ Predef.x = Predef.xdigit
+ Predef.A = any - Predef.a
+ Predef.C = any - Predef.c
+ Predef.D = any - Predef.d
+ Predef.G = any - Predef.g
+ Predef.L = any - Predef.l
+ Predef.P = any - Predef.p
+ Predef.S = any - Predef.s
+ Predef.U = any - Predef.u
+ Predef.W = any - Predef.w
+ Predef.X = any - Predef.x
+ mem = {} -- restart memoization
+ fmem = {}
+ gmem = {}
+ local mt = {__mode = "v"}
+ setmetatable(mem, mt)
+ setmetatable(fmem, mt)
+ setmetatable(gmem, mt)
+end
+
+
+updatelocale()
+
+
+
+local I = m.P(function (s,i) print(i, s:sub(1, i-1)); return i end)
+
+
+local function getdef (id, defs)
+ local c = defs and defs[id]
+ if not c then
+ error("undefined name: " .. id)
+ end
+ return c
+end
+
+
+local function mult (p, n)
+ local np = mm.P(true)
+ while n >= 1 do
+ if n%2 >= 1 then np = np * p end
+ p = p * p
+ n = n/2
+ end
+ return np
+end
+
+local function equalcap (s, i, c)
+ if type(c) ~= "string" then return nil end
+ local e = #c + i
+ if s:sub(i, e - 1) == c then return e else return nil end
+end
+
+
+local S = (Predef.space + "--" * (any - Predef.nl)^0)^0
+
+local name = m.C(m.R("AZ", "az", "__") * m.R("AZ", "az", "__", "09")^0)
+
+local arrow = S * "<-"
+
+-- a defined name only have meaning in a given environment
+local Def = name * m.Carg(1)
+
+local num = m.C(m.R"09"^1) * S / tonumber
+
+local String = "'" * m.C((any - "'" - m.P"\n")^0) * expect("'", "MisTerm1")
+ + '"' * m.C((any - '"' - m.P"\n")^0) * expect('"', "MisTerm2")
+
+
+local defined = "%" * Def / function (c,Defs)
+ local cat = Defs and Defs[c] or Predef[c]
+ if not cat then
+ error("name '" .. c .. "' undefined")
+ end
+ return cat
+end
+
+local Range = m.Cs(any * (m.P"-"/"") * (any - "]")) / mm.R
+
+local item = defined + Range + m.C(any - m.P"\n")
+
+local Class =
+ "["
+ * (m.C(m.P"^"^-1)) -- optional complement symbol
+ * m.Cf(expect(item, "ExpItem") * (item - "]")^0, mt.__add)
+ / function (c, p) return c == "^" and any - p or p end
+ * expect("]", "MisClose8")
+
+local function adddef (t, k, exp)
+ if t[k] then
+ -- TODO 改了一下这里的代码,重复定义不会抛错
+ --error("'"..k.."' already defined as a rule")
+ else
+ t[k] = exp
+ end
+ return t
+end
+
+local function firstdef (n, r) return adddef({n}, n, r) end
+
+
+local function NT (n, b)
+ if not b then
+ error("rule '"..n.."' used outside a grammar")
+ else return mm.V(n)
+ end
+end
+
+
+local exp = m.P{ "Exp",
+ Exp = S * ( m.V"Grammar"
+ + m.Cf(m.V"Seq" * (S * "/" * expect(S * m.V"Seq", "ExpPatt1"))^0, mt.__add) );
+ Seq = m.Cf(m.Cc(m.P"") * m.V"Prefix" * (S * m.V"Prefix")^0, mt.__mul);
+ Prefix = "&" * expect(S * m.V"Prefix", "ExpPatt2") / mt.__len
+ + "!" * expect(S * m.V"Prefix", "ExpPatt3") / mt.__unm
+ + m.V"Suffix";
+ Suffix = m.Cf(m.V"Primary" *
+ ( S * ( m.P"+" * m.Cc(1, mt.__pow)
+ + m.P"*" * m.Cc(0, mt.__pow)
+ + m.P"?" * m.Cc(-1, mt.__pow)
+ + "^" * expect( m.Cg(num * m.Cc(mult))
+ + m.Cg(m.C(m.S"+-" * m.R"09"^1) * m.Cc(mt.__pow)
+ + name * m.Cc"lab"
+ ),
+ "ExpNumName")
+ + "->" * expect(S * ( m.Cg((String + num) * m.Cc(mt.__div))
+ + m.P"{}" * m.Cc(nil, m.Ct)
+ + m.Cg(Def / getdef * m.Cc(mt.__div))
+ ),
+ "ExpCap")
+ + "=>" * expect(S * m.Cg(Def / getdef * m.Cc(m.Cmt)),
+ "ExpName1")
+ )
+ )^0, function (a,b,f) if f == "lab" then return a + mm.T(b) else return f(a,b) end end );
+ Primary = "(" * expect(m.V"Exp", "ExpPatt4") * expect(S * ")", "MisClose1")
+ + String / mm.P
+ + Class
+ + defined
+ + "%" * expect(m.P"{", "ExpNameOrLab")
+ * expect(S * m.V"Label", "ExpLab1")
+ * expect(S * "}", "MisClose7") / mm.T
+ + "{:" * (name * ":" + m.Cc(nil)) * expect(m.V"Exp", "ExpPatt5")
+ * expect(S * ":}", "MisClose2")
+ / function (n, p) return mm.Cg(p, n) end
+ + "=" * expect(name, "ExpName2")
+ / function (n) return mm.Cmt(mm.Cb(n), equalcap) end
+ + m.P"{}" / mm.Cp
+ + "{~" * expect(m.V"Exp", "ExpPatt6")
+ * expect(S * "~}", "MisClose3") / mm.Cs
+ + "{|" * expect(m.V"Exp", "ExpPatt7")
+ * expect(S * "|}", "MisClose4") / mm.Ct
+ + "{" * expect(m.V"Exp", "ExpPattOrClose")
+ * expect(S * "}", "MisClose5") / mm.C
+ + m.P"." * m.Cc(any)
+ + (name * -arrow + "<" * expect(name, "ExpName3")
+ * expect(">", "MisClose6")) * m.Cb("G") / NT;
+ Label = num + name;
+ Definition = name * arrow * expect(m.V"Exp", "ExpPatt8");
+ Grammar = m.Cg(m.Cc(true), "G")
+ * m.Cf(m.V"Definition" / firstdef * (S * m.Cg(m.V"Definition"))^0,
+ adddef) / mm.P;
+}
+
+local pattern = S * m.Cg(m.Cc(false), "G") * expect(exp, "NoPatt") / mm.P
+ * S * expect(-any, "ExtraChars")
+
+local function lineno (s, i)
+ if i == 1 then return 1, 1 end
+ local adjustment = 0
+ -- report the current line if at end of line, not the next
+ if s:sub(i,i) == '\n' then
+ i = i-1
+ adjustment = 1
+ end
+ local rest, num = s:sub(1,i):gsub("[^\n]*\n", "")
+ local r = #rest
+ return 1 + num, (r ~= 0 and r or 1) + adjustment
+end
+
+local function calcline (s, i)
+ if i == 1 then return 1, 1 end
+ local rest, line = s:sub(1,i):gsub("[^\n]*\n", "")
+ local col = #rest
+ return 1 + line, col ~= 0 and col or 1
+end
+
+
+local function splitlines(str)
+ local t = {}
+ local function helper(line) tinsert(t, line) return "" end
+ helper((str:gsub("(.-)\r?\n", helper)))
+ return t
+end
+
+local function compile (p, defs)
+ if mm.type(p) == "pattern" then return p end -- already compiled
+ p = p .. " " -- for better reporting of column numbers in errors when at EOF
+ local ok, cp, label, poserr = pcall(function() return pattern:match(p, 1, defs) end)
+ if not ok and cp then
+ if type(cp) == "string" then
+ cp = cp:gsub("^[^:]+:[^:]+: ", "")
+ end
+ error(cp, 3)
+ end
+ if not cp then
+ local lines = splitlines(p)
+ local line, col = lineno(p, poserr)
+ local err = {}
+ tinsert(err, "L" .. line .. ":C" .. col .. ": " .. errinfo[label])
+ tinsert(err, lines[line])
+ tinsert(err, rep(" ", col-1) .. "^")
+ error("syntax error(s) in pattern\n" .. concat(err, "\n"), 3)
+ end
+ return cp
+end
+
+local function match (s, p, i)
+ local cp = mem[p]
+ if not cp then
+ cp = compile(p)
+ mem[p] = cp
+ end
+ return cp:match(s, i or 1)
+end
+
+local function find (s, p, i)
+ local cp = fmem[p]
+ if not cp then
+ cp = compile(p) / 0
+ cp = mm.P{ mm.Cp() * cp * mm.Cp() + 1 * mm.V(1) }
+ fmem[p] = cp
+ end
+ local i, e = cp:match(s, i or 1)
+ if i then return i, e - 1
+ else return i
+ end
+end
+
+local function gsub (s, p, rep)
+ local g = gmem[p] or {} -- ensure gmem[p] is not collected while here
+ gmem[p] = g
+ local cp = g[rep]
+ if not cp then
+ cp = compile(p)
+ cp = mm.Cs((cp / rep + 1)^0)
+ g[rep] = cp
+ end
+ return cp:match(s)
+end
+
+
+-- exported names
+local re = {
+ compile = compile,
+ match = match,
+ find = find,
+ gsub = gsub,
+ updatelocale = updatelocale,
+ calcline = calcline
+}
+
+return re
diff --git a/script/parser/split.lua b/script/parser/split.lua
new file mode 100644
index 00000000..6ce4a4e7
--- /dev/null
+++ b/script/parser/split.lua
@@ -0,0 +1,9 @@
+local m = require 'lpeglabel'
+
+local NL = m.P'\r\n' + m.S'\r\n'
+local LINE = m.C(1 - NL)
+
+return function (str)
+ local MATCH = m.Ct((LINE * NL)^0 * LINE)
+ return MATCH:match(str)
+end