summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--changelog.md6
-rw-r--r--script/core/diagnostics/global-in-nil-env.lua2
-rw-r--r--script/parser/guide.lua42
-rw-r--r--script/parser/luadoc.lua68
-rw-r--r--script/parser/newparser.lua9
-rw-r--r--script/vm/global-manager.lua3
-rw-r--r--script/vm/node.lua9
-rw-r--r--script/vm/runner.lua66
-rw-r--r--test/type_inference/init.lua64
9 files changed, 242 insertions, 27 deletions
diff --git a/changelog.md b/changelog.md
index e6ecca52..8e329946 100644
--- a/changelog.md
+++ b/changelog.md
@@ -23,6 +23,12 @@
local x = true
local y = x--[[@as integer]] -- y is `integer` here
```
+* `NEW` add `---@cast`
+ * `---@cast localname type`
+ * `---@cast localname +type`
+ * `---@cast localname -type`
+ * `---@cast localname +?`
+ * `---@cast localname -?`
* `NEW` generic: resolve `T[]` by `table<integer, type>` or `---@field [integer] type`
* `NEW` resolve `class[1]` by `---@field [integer] type`
* `NEW` diagnostic: `missing-parameter`
diff --git a/script/core/diagnostics/global-in-nil-env.lua b/script/core/diagnostics/global-in-nil-env.lua
index d95963e4..334fd81a 100644
--- a/script/core/diagnostics/global-in-nil-env.lua
+++ b/script/core/diagnostics/global-in-nil-env.lua
@@ -16,7 +16,7 @@ return function (uri, callback)
local env = guide.getENV(root)
local nilDefs = {}
- if not env.ref then
+ if not env or not env.ref then
return
end
for _, ref in ipairs(env.ref) do
diff --git a/script/parser/guide.lua b/script/parser/guide.lua
index 060c5018..fffba639 100644
--- a/script/parser/guide.lua
+++ b/script/parser/guide.lua
@@ -58,6 +58,10 @@ local type = type
---@field step parser.object
---@field redundant { max: integer, passed: integer }
---@field filter parser.object
+---@field loc string
+---@field keyword integer[]
+---@field casts parser.object[]
+---@field mode? '+' | '-'
---@field hasGoTo? true
---@field hasReturn? true
---@field hasBreak? true
@@ -148,6 +152,8 @@ local childMap = {
['doc.version'] = {'#versions'},
['doc.diagnostic'] = {'#names'},
['doc.as'] = {'as'},
+ ['doc.cast'] = {'loc', '#casts'},
+ ['doc.cast.block'] = {'extends'},
}
---@type table<string, fun(obj: parser.object, list: parser.object[])>
@@ -420,7 +426,7 @@ function m.getUri(obj)
return ''
end
----@return parser.object
+---@return parser.object?
function m.getENV(source, start)
if not start then
start = 1
@@ -454,19 +460,17 @@ function m.getFunctionVarArgs(func)
end
--- 获取指定区块中可见的局部变量
----@param block table
----@param name string {comment = '变量名'}
----@param pos integer {comment = '可见位置'}
-function m.getLocal(block, name, pos)
- block = m.getBlock(block)
- for _ = 1, 10000 do
- if not block then
- return nil
- end
- local locals = block.locals
- local res
+---@param source parser.object
+---@param name string # 变量名
+---@param pos integer # 可见位置
+---@return parser.object?
+function m.getLocal(source, name, pos)
+ local root = m.getRoot(source)
+ local res
+ m.eachSourceContain(root, pos, function (src)
+ local locals = src.locals
if not locals then
- goto CONTINUE
+ return
end
for i = 1, #locals do
local loc = locals[i]
@@ -479,13 +483,8 @@ function m.getLocal(block, name, pos)
end
end
end
- if res then
- return res, res
- end
- ::CONTINUE::
- block = m.getParentBlock(block)
- end
- error('guide.getLocal overstack')
+ end)
+ return res
end
--- 获取指定区块中所有的可见局部变量名称
@@ -610,6 +609,9 @@ local function addChilds(list, obj)
end
--- 遍历所有包含position的source
+---@param ast parser.object
+---@param position integer
+---@param callback fun(src: parser.object)
function m.eachSourceContain(ast, position, callback)
local list = { ast }
local mark = {}
diff --git a/script/parser/luadoc.lua b/script/parser/luadoc.lua
index 3b50db34..e10ef356 100644
--- a/script/parser/luadoc.lua
+++ b/script/parser/luadoc.lua
@@ -53,6 +53,7 @@ Symbol <- ({} {
/ '...'
/ '['
/ ']'
+ / '-' !'-'
} {})
-> Symbol
]], {
@@ -1205,6 +1206,70 @@ local docSwitch = util.switch()
result.finish = getFinish()
return result
end)
+ : case 'cast'
+ : call(function ()
+ local result = {
+ type = 'doc.cast',
+ start = getFinish(),
+ finish = getFinish(),
+ casts = {},
+ }
+
+ local loc = parseName('doc.cast.name', result)
+ if not loc then
+ pushWarning {
+ type = 'LUADOC_MISS_LOCAL_NAME',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ return result
+ end
+
+ result.loc = loc
+ result.finish = loc.finish
+
+ while true do
+ local block = {
+ type = 'doc.cast.block',
+ parent = result,
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ result.casts[#result.casts+1] = block
+ if checkToken('symbol', '+', 1) then
+ block.mode = '+'
+ nextToken()
+ block.start = getStart()
+ block.finish = getFinish()
+ elseif checkToken('symbol', '-', 1) then
+ block.mode = '-'
+ nextToken()
+ block.start = getStart()
+ block.finish = getFinish()
+ end
+
+ if checkToken('symbol', '?', 1) then
+ block.optional = true
+ nextToken()
+ block.start = block.start or getStart()
+ block.finish = block.finish
+ else
+ block.extends = parseType(block)
+ if block.extends then
+ block.start = block.start or block.extends.start
+ block.finish = block.extends.finish
+ end
+ end
+
+ if checkToken('symbol', ',', 1) then
+ nextToken()
+ else
+ break
+ end
+ end
+
+ return result
+ end)
local function convertTokens()
local tp, text = nextToken()
@@ -1313,6 +1378,9 @@ local function isNextLine(binded, doc)
return false
end
end
+ if doc.type == 'doc.cast' then
+ return false
+ end
local lastRow = guide.rowColOf(lastDoc.finish)
local newRow = guide.rowColOf(doc.start)
return newRow - lastRow == 1
diff --git a/script/parser/newparser.lua b/script/parser/newparser.lua
index 4c58ead6..630c12c2 100644
--- a/script/parser/newparser.lua
+++ b/script/parser/newparser.lua
@@ -691,9 +691,6 @@ local function parseLocalAttrs()
end
local function createLocal(obj, attrs)
- if not obj then
- return nil
- end
obj.type = 'local'
obj.effect = obj.finish
@@ -2893,7 +2890,11 @@ local function parseLocal()
pushActionIntoCurrentChunk(loc)
skipSpace()
parseMultiVars(loc, parseName, true)
- loc.effect = lastRightPosition()
+ if loc.value then
+ loc.effect = loc.value.finish
+ else
+ loc.effect = loc.finish
+ end
return loc
end
diff --git a/script/vm/global-manager.lua b/script/vm/global-manager.lua
index fbf12197..07c8950e 100644
--- a/script/vm/global-manager.lua
+++ b/script/vm/global-manager.lua
@@ -358,6 +358,9 @@ end
---@param source parser.object
function m.compileAst(source)
local env = guide.getENV(source)
+ if not env then
+ return
+ end
m.compileObject(env)
guide.eachSpecialOf(source, 'rawset', function (src)
m.compileObject(src.parent)
diff --git a/script/vm/node.lua b/script/vm/node.lua
index 3d1a24bc..8b2bafd7 100644
--- a/script/vm/node.lua
+++ b/script/vm/node.lua
@@ -204,6 +204,15 @@ function mt:remove(name)
end
end
+---@param node vm.node
+function mt:removeNode(node)
+ for _, c in ipairs(node) do
+ if c.type == 'global' and c.cate == 'type' then
+ self:remove(c.name)
+ end
+ end
+end
+
---@return fun():vm.object
function mt:eachObject()
local i = 0
diff --git a/script/vm/runner.lua b/script/vm/runner.lua
index 19233964..75610ee5 100644
--- a/script/vm/runner.lua
+++ b/script/vm/runner.lua
@@ -13,15 +13,16 @@ mt.__index = mt
mt.index = 1
---@class parser.object
----@field _hasSorted boolean
+---@field _casts parser.object[]
---@class vm.runner.step
----@field type 'truthy' | 'falsy' | 'as' | 'add' | 'remove' | 'object' | 'save' | 'push' | 'merge'
+---@field type 'truthy' | 'falsy' | 'as' | 'add' | 'remove' | 'object' | 'save' | 'push' | 'merge' | 'cast'
---@field pos integer
---@field order? integer
---@field node? vm.node
---@field object? parser.object
---@field name? string
+---@field cast? parser.object
---@field tag? string
---@field copy? boolean
---@field new? boolean
@@ -250,21 +251,58 @@ function mt:_compileBlock(block)
end
end
+---@return parser.object[]
+function mt:_getCasts()
+ local root = guide.getRoot(self.loc)
+ if not root._casts then
+ root._casts = {}
+ local docs = root.docs
+ for _, doc in ipairs(docs) do
+ if doc.type == 'doc.cast' and doc.loc then
+ root._casts[#root._casts+1] = doc
+ end
+ end
+ end
+ return root._casts
+end
+
function mt:_preCompile()
+ local startPos = self.loc.start
+ local finishPos = 0
+
for _, ref in ipairs(self.loc.ref) do
self.steps[#self.steps+1] = {
type = 'object',
object = ref,
pos = ref.range or ref.start,
}
+ if ref.start > finishPos then
+ finishPos = ref.start
+ end
local block = guide.getParentBlock(ref)
self:_compileBlock(block)
end
+
for i, step in ipairs(self.steps) do
if step.type ~= 'object' then
step.order = i
end
end
+
+ local casts = self:_getCasts()
+ for _, cast in ipairs(casts) do
+ if cast.loc[1] == self.loc[1]
+ and cast.start > startPos
+ and cast.finish < finishPos
+ and guide.getLocal(self.loc, self.loc[1], cast.start) == self.loc then
+ self.steps[#self.steps+1] = {
+ type = 'cast',
+ cast = cast,
+ pos = cast.start,
+ }
+ end
+ end
+
table.sort(self.steps, function (a, b)
if a.pos == b.pos then
return (a.order or 0) < (b.order or 0)
@@ -363,6 +401,30 @@ function mt:launch(callback)
topNode = node
elseif step.type == 'merge' then
node:merge(step.ref2.node)
+ elseif step.type == 'cast' then
+ topNode = node:copy()
+ for _, cast in ipairs(step.cast.casts) do
+ if cast.mode == '+' then
+ if cast.optional then
+ topNode:addOptional()
+ end
+ if cast.extends then
+ topNode:merge(vm.compileNode(cast.extends))
+ end
+ elseif cast.mode == '-' then
+ if cast.optional then
+ topNode:removeOptional()
+ end
+ if cast.extends then
+ topNode:removeNode(vm.compileNode(cast.extends))
+ end
+ else
+ if cast.extends then
+ topNode:clear()
+ topNode:merge(vm.compileNode(cast.extends))
+ end
+ end
+ end
end
end
end
diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua
index 2f440e9d..59ef8d1a 100644
--- a/test/type_inference/init.lua
+++ b/test/type_inference/init.lua
@@ -2127,3 +2127,67 @@ for _ = 1, 999 do
local <?x?>
end
]]
+
+TEST 'integer' [[
+local x
+
+---@cast x integer
+
+print(<?x?>)
+]]
+
+TEST 'unknown' [[
+local x
+
+---@cast x integer
+
+local x
+print(<?x?>)
+]]
+
+TEST 'unknown' [[
+local x
+
+if true then
+ local x
+ ---@cast x integer
+ print(x)
+end
+
+print(<?x?>)
+]]
+
+TEST 'boolean|integer' [[
+local x = 1
+
+---@cast x +boolean
+
+print(<?x?>)
+]]
+
+TEST 'boolean' [[
+---@type integer|boolean
+local x
+
+---@cast x -integer
+
+print(<?x?>)
+]]
+
+TEST 'boolean?' [[
+---@type boolean
+local x
+
+---@cast x +?
+
+print(<?x?>)
+]]
+
+TEST 'boolean' [[
+---@type boolean?
+local x
+
+---@cast x -?
+
+print(<?x?>)
+]]