diff options
-rw-r--r-- | changelog.md | 6 | ||||
-rw-r--r-- | script/core/diagnostics/global-in-nil-env.lua | 2 | ||||
-rw-r--r-- | script/parser/guide.lua | 42 | ||||
-rw-r--r-- | script/parser/luadoc.lua | 68 | ||||
-rw-r--r-- | script/parser/newparser.lua | 9 | ||||
-rw-r--r-- | script/vm/global-manager.lua | 3 | ||||
-rw-r--r-- | script/vm/node.lua | 9 | ||||
-rw-r--r-- | script/vm/runner.lua | 66 | ||||
-rw-r--r-- | test/type_inference/init.lua | 64 |
9 files changed, 242 insertions, 27 deletions
diff --git a/changelog.md b/changelog.md index e6ecca52..8e329946 100644 --- a/changelog.md +++ b/changelog.md @@ -23,6 +23,12 @@ local x = true local y = x--[[@as integer]] -- y is `integer` here ``` +* `NEW` add `---@cast` + * `---@cast localname type` + * `---@cast localname +type` + * `---@cast localname -type` + * `---@cast localname +?` + * `---@cast localname -?` * `NEW` generic: resolve `T[]` by `table<integer, type>` or `---@field [integer] type` * `NEW` resolve `class[1]` by `---@field [integer] type` * `NEW` diagnostic: `missing-parameter` diff --git a/script/core/diagnostics/global-in-nil-env.lua b/script/core/diagnostics/global-in-nil-env.lua index d95963e4..334fd81a 100644 --- a/script/core/diagnostics/global-in-nil-env.lua +++ b/script/core/diagnostics/global-in-nil-env.lua @@ -16,7 +16,7 @@ return function (uri, callback) local env = guide.getENV(root) local nilDefs = {} - if not env.ref then + if not env or not env.ref then return end for _, ref in ipairs(env.ref) do diff --git a/script/parser/guide.lua b/script/parser/guide.lua index 060c5018..fffba639 100644 --- a/script/parser/guide.lua +++ b/script/parser/guide.lua @@ -58,6 +58,10 @@ local type = type ---@field step parser.object ---@field redundant { max: integer, passed: integer } ---@field filter parser.object +---@field loc string +---@field keyword integer[] +---@field casts parser.object[] +---@field mode? '+' | '-' ---@field hasGoTo? true ---@field hasReturn? true ---@field hasBreak? true @@ -148,6 +152,8 @@ local childMap = { ['doc.version'] = {'#versions'}, ['doc.diagnostic'] = {'#names'}, ['doc.as'] = {'as'}, + ['doc.cast'] = {'loc', '#casts'}, + ['doc.cast.block'] = {'extends'}, } ---@type table<string, fun(obj: parser.object, list: parser.object[])> @@ -420,7 +426,7 @@ function m.getUri(obj) return '' end ----@return parser.object +---@return parser.object? function m.getENV(source, start) if not start then start = 1 @@ -454,19 +460,17 @@ function m.getFunctionVarArgs(func) end --- 获取指定区块中可见的局部变量 ----@param block table ----@param name string {comment = '变量名'} ----@param pos integer {comment = '可见位置'} -function m.getLocal(block, name, pos) - block = m.getBlock(block) - for _ = 1, 10000 do - if not block then - return nil - end - local locals = block.locals - local res +---@param source parser.object +---@param name string # 变量名 +---@param pos integer # 可见位置 +---@return parser.object? +function m.getLocal(source, name, pos) + local root = m.getRoot(source) + local res + m.eachSourceContain(root, pos, function (src) + local locals = src.locals if not locals then - goto CONTINUE + return end for i = 1, #locals do local loc = locals[i] @@ -479,13 +483,8 @@ function m.getLocal(block, name, pos) end end end - if res then - return res, res - end - ::CONTINUE:: - block = m.getParentBlock(block) - end - error('guide.getLocal overstack') + end) + return res end --- 获取指定区块中所有的可见局部变量名称 @@ -610,6 +609,9 @@ local function addChilds(list, obj) end --- 遍历所有包含position的source +---@param ast parser.object +---@param position integer +---@param callback fun(src: parser.object) function m.eachSourceContain(ast, position, callback) local list = { ast } local mark = {} diff --git a/script/parser/luadoc.lua b/script/parser/luadoc.lua index 3b50db34..e10ef356 100644 --- a/script/parser/luadoc.lua +++ b/script/parser/luadoc.lua @@ -53,6 +53,7 @@ Symbol <- ({} { / '...' / '[' / ']' + / '-' !'-' } {}) -> Symbol ]], { @@ -1205,6 +1206,70 @@ local docSwitch = util.switch() result.finish = getFinish() return result end) + : case 'cast' + : call(function () + local result = { + type = 'doc.cast', + start = getFinish(), + finish = getFinish(), + casts = {}, + } + + local loc = parseName('doc.cast.name', result) + if not loc then + pushWarning { + type = 'LUADOC_MISS_LOCAL_NAME', + start = getFinish(), + finish = getFinish(), + } + return result + end + + result.loc = loc + result.finish = loc.finish + + while true do + local block = { + type = 'doc.cast.block', + parent = result, + start = getFinish(), + finish = getFinish(), + } + result.casts[#result.casts+1] = block + if checkToken('symbol', '+', 1) then + block.mode = '+' + nextToken() + block.start = getStart() + block.finish = getFinish() + elseif checkToken('symbol', '-', 1) then + block.mode = '-' + nextToken() + block.start = getStart() + block.finish = getFinish() + end + + if checkToken('symbol', '?', 1) then + block.optional = true + nextToken() + block.start = block.start or getStart() + block.finish = block.finish + else + block.extends = parseType(block) + if block.extends then + block.start = block.start or block.extends.start + block.finish = block.extends.finish + end + end + + if checkToken('symbol', ',', 1) then + nextToken() + else + break + end + end + + return result + end) local function convertTokens() local tp, text = nextToken() @@ -1313,6 +1378,9 @@ local function isNextLine(binded, doc) return false end end + if doc.type == 'doc.cast' then + return false + end local lastRow = guide.rowColOf(lastDoc.finish) local newRow = guide.rowColOf(doc.start) return newRow - lastRow == 1 diff --git a/script/parser/newparser.lua b/script/parser/newparser.lua index 4c58ead6..630c12c2 100644 --- a/script/parser/newparser.lua +++ b/script/parser/newparser.lua @@ -691,9 +691,6 @@ local function parseLocalAttrs() end local function createLocal(obj, attrs) - if not obj then - return nil - end obj.type = 'local' obj.effect = obj.finish @@ -2893,7 +2890,11 @@ local function parseLocal() pushActionIntoCurrentChunk(loc) skipSpace() parseMultiVars(loc, parseName, true) - loc.effect = lastRightPosition() + if loc.value then + loc.effect = loc.value.finish + else + loc.effect = loc.finish + end return loc end diff --git a/script/vm/global-manager.lua b/script/vm/global-manager.lua index fbf12197..07c8950e 100644 --- a/script/vm/global-manager.lua +++ b/script/vm/global-manager.lua @@ -358,6 +358,9 @@ end ---@param source parser.object function m.compileAst(source) local env = guide.getENV(source) + if not env then + return + end m.compileObject(env) guide.eachSpecialOf(source, 'rawset', function (src) m.compileObject(src.parent) diff --git a/script/vm/node.lua b/script/vm/node.lua index 3d1a24bc..8b2bafd7 100644 --- a/script/vm/node.lua +++ b/script/vm/node.lua @@ -204,6 +204,15 @@ function mt:remove(name) end end +---@param node vm.node +function mt:removeNode(node) + for _, c in ipairs(node) do + if c.type == 'global' and c.cate == 'type' then + self:remove(c.name) + end + end +end + ---@return fun():vm.object function mt:eachObject() local i = 0 diff --git a/script/vm/runner.lua b/script/vm/runner.lua index 19233964..75610ee5 100644 --- a/script/vm/runner.lua +++ b/script/vm/runner.lua @@ -13,15 +13,16 @@ mt.__index = mt mt.index = 1 ---@class parser.object ----@field _hasSorted boolean +---@field _casts parser.object[] ---@class vm.runner.step ----@field type 'truthy' | 'falsy' | 'as' | 'add' | 'remove' | 'object' | 'save' | 'push' | 'merge' +---@field type 'truthy' | 'falsy' | 'as' | 'add' | 'remove' | 'object' | 'save' | 'push' | 'merge' | 'cast' ---@field pos integer ---@field order? integer ---@field node? vm.node ---@field object? parser.object ---@field name? string +---@field cast? parser.object ---@field tag? string ---@field copy? boolean ---@field new? boolean @@ -250,21 +251,58 @@ function mt:_compileBlock(block) end end +---@return parser.object[] +function mt:_getCasts() + local root = guide.getRoot(self.loc) + if not root._casts then + root._casts = {} + local docs = root.docs + for _, doc in ipairs(docs) do + if doc.type == 'doc.cast' and doc.loc then + root._casts[#root._casts+1] = doc + end + end + end + return root._casts +end + function mt:_preCompile() + local startPos = self.loc.start + local finishPos = 0 + for _, ref in ipairs(self.loc.ref) do self.steps[#self.steps+1] = { type = 'object', object = ref, pos = ref.range or ref.start, } + if ref.start > finishPos then + finishPos = ref.start + end local block = guide.getParentBlock(ref) self:_compileBlock(block) end + for i, step in ipairs(self.steps) do if step.type ~= 'object' then step.order = i end end + + local casts = self:_getCasts() + for _, cast in ipairs(casts) do + if cast.loc[1] == self.loc[1] + and cast.start > startPos + and cast.finish < finishPos + and guide.getLocal(self.loc, self.loc[1], cast.start) == self.loc then + self.steps[#self.steps+1] = { + type = 'cast', + cast = cast, + pos = cast.start, + } + end + end + table.sort(self.steps, function (a, b) if a.pos == b.pos then return (a.order or 0) < (b.order or 0) @@ -363,6 +401,30 @@ function mt:launch(callback) topNode = node elseif step.type == 'merge' then node:merge(step.ref2.node) + elseif step.type == 'cast' then + topNode = node:copy() + for _, cast in ipairs(step.cast.casts) do + if cast.mode == '+' then + if cast.optional then + topNode:addOptional() + end + if cast.extends then + topNode:merge(vm.compileNode(cast.extends)) + end + elseif cast.mode == '-' then + if cast.optional then + topNode:removeOptional() + end + if cast.extends then + topNode:removeNode(vm.compileNode(cast.extends)) + end + else + if cast.extends then + topNode:clear() + topNode:merge(vm.compileNode(cast.extends)) + end + end + end end end end diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua index 2f440e9d..59ef8d1a 100644 --- a/test/type_inference/init.lua +++ b/test/type_inference/init.lua @@ -2127,3 +2127,67 @@ for _ = 1, 999 do local <?x?> end ]] + +TEST 'integer' [[ +local x + +---@cast x integer + +print(<?x?>) +]] + +TEST 'unknown' [[ +local x + +---@cast x integer + +local x +print(<?x?>) +]] + +TEST 'unknown' [[ +local x + +if true then + local x + ---@cast x integer + print(x) +end + +print(<?x?>) +]] + +TEST 'boolean|integer' [[ +local x = 1 + +---@cast x +boolean + +print(<?x?>) +]] + +TEST 'boolean' [[ +---@type integer|boolean +local x + +---@cast x -integer + +print(<?x?>) +]] + +TEST 'boolean?' [[ +---@type boolean +local x + +---@cast x +? + +print(<?x?>) +]] + +TEST 'boolean' [[ +---@type boolean? +local x + +---@cast x -? + +print(<?x?>) +]] |