summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--changelog.md2
-rw-r--r--script/vm/node.lua13
-rw-r--r--script/vm/runner.lua31
-rw-r--r--script/vm/type.lua16
-rw-r--r--test/type_inference/init.lua9
5 files changed, 45 insertions, 26 deletions
diff --git a/changelog.md b/changelog.md
index b33ece88..38ea7d92 100644
--- a/changelog.md
+++ b/changelog.md
@@ -4,10 +4,12 @@
* `FIX` incorrect type check for generic with nil
* `FIX` [#1676]
* `FIX` [#1677]
+* `FIX` [#1679]
* `FIX` [#1680]
[#1676]: https://github.com/sumneko/lua-language-server/issues/1676
[#1677]: https://github.com/sumneko/lua-language-server/issues/1677
+[#1679]: https://github.com/sumneko/lua-language-server/issues/1679
[#1680]: https://github.com/sumneko/lua-language-server/issues/1680
## 3.6.1
diff --git a/script/vm/node.lua b/script/vm/node.lua
index 3866e56a..c07269ab 100644
--- a/script/vm/node.lua
+++ b/script/vm/node.lua
@@ -246,7 +246,8 @@ function mt:remove(name)
or (c.type == 'doc.type.table' and name == 'table')
or (c.type == 'doc.type.array' and name == 'table')
or (c.type == 'doc.type.sign' and name == c.node[1])
- or (c.type == 'doc.type.function' and name == 'function') then
+ or (c.type == 'doc.type.function' and name == 'function')
+ or (c.type == 'doc.type.string' and name == 'string') then
table.remove(self, index)
self[c] = nil
end
@@ -254,9 +255,10 @@ function mt:remove(name)
return self
end
+---@param uri uri
---@param name string
-function mt:narrow(name)
- if name ~= 'nil' and self.optional == true then
+function mt:narrow(uri, name)
+ if self.optional == true then
self.optional = nil
end
for index = #self, 1, -1 do
@@ -267,12 +269,13 @@ function mt:narrow(name)
or (c.type == 'doc.type.table' and name == 'table')
or (c.type == 'doc.type.array' and name == 'table')
or (c.type == 'doc.type.sign' and name == c.node[1])
- or (c.type == 'doc.type.function' and name == 'function') then
+ or (c.type == 'doc.type.function' and name == 'function')
+ or (c.type == 'doc.type.string' and name == 'string') then
goto CONTINUE
end
if c.type == 'global' and c.cate == 'type' then
if (c.name == name)
- or (c.name == 'integer' and name == 'number') then
+ or (vm.isSubType(uri, c.name, name)) then
goto CONTINUE
end
end
diff --git a/script/vm/runner.lua b/script/vm/runner.lua
index 7363f0e5..250be481 100644
--- a/script/vm/runner.lua
+++ b/script/vm/runner.lua
@@ -11,6 +11,7 @@ local guide = require 'parser.guide'
---@field _mark table
---@field _has table<parser.object, true>
---@field _main parser.object
+---@field _uri uri
local mt = {}
mt.__index = mt
mt._index = 1
@@ -168,15 +169,18 @@ function mt:_lookIntoChild(action, topNode, outNode)
-- if x == y then
topNode = self:_lookIntoChild(handler, topNode, outNode)
local checkerNode = vm.compileNode(checker)
- if action.op.type == '==' then
- topNode = checkerNode
- if outNode then
- outNode:removeNode(topNode)
- end
- else
- topNode:removeNode(checkerNode)
- if outNode then
- outNode = checkerNode
+ local checkerName = vm.getNodeName(checker)
+ if checkerName then
+ if action.op.type == '==' then
+ topNode:narrow(self._uri, checkerName)
+ if outNode then
+ outNode:removeNode(checkerNode)
+ end
+ else
+ topNode:removeNode(checkerNode)
+ if outNode then
+ outNode:narrow(self._uri, checkerName)
+ end
end
end
elseif handler.type == 'call'
@@ -189,14 +193,14 @@ function mt:_lookIntoChild(action, topNode, outNode)
-- if type(x) == 'string' then
self:_lookIntoChild(handler, topNode:copy())
if action.op.type == '==' then
- topNode:narrow(checker[1])
+ topNode:narrow(self._uri, checker[1])
if outNode then
outNode:remove(checker[1])
end
else
topNode:remove(checker[1])
if outNode then
- outNode:narrow(checker[1])
+ outNode:narrow(self._uri, checker[1])
end
end
elseif handler.type == 'getlocal'
@@ -215,14 +219,14 @@ function mt:_lookIntoChild(action, topNode, outNode)
and call.args[1].node == self._loc then
-- `local tp = type(x);if tp == 'string' then`
if action.op.type == '==' then
- topNode:narrow(checker[1])
+ topNode:narrow(self._uri, checker[1])
if outNode then
outNode:remove(checker[1])
end
else
topNode:remove(checker[1])
if outNode then
- outNode:narrow(checker[1])
+ outNode:narrow(self._uri, checker[1])
end
end
end
@@ -352,6 +356,7 @@ function vm.launchRunner(loc, callback)
_mark = {},
_has = {},
_main = main,
+ _uri = guide.getUri(loc),
_callback = callback,
}, mt)
diff --git a/script/vm/type.lua b/script/vm/type.lua
index 32b8e50b..126f0a83 100644
--- a/script/vm/type.lua
+++ b/script/vm/type.lua
@@ -9,7 +9,7 @@ local lang = require 'language'
---@param object vm.node.object
---@return string?
-local function getNodeName(object)
+function vm.getNodeName(object)
if object.type == 'global' and object.cate == 'type' then
---@cast object vm.global
return object.name
@@ -90,14 +90,14 @@ local function checkParentEnum(parentName, child, uri, mark, errs)
return false
else
---@cast child parser.object
- local childName = getNodeName(child)
+ local childName = vm.getNodeName(child)
if childName == 'number'
or childName == 'integer'
or childName == 'boolean'
or childName == 'string' then
for _, enum in ipairs(enums) do
for nd in vm.compileNode(enum):eachObject() do
- if childName == getNodeName(nd) and nd[1] == child[1] then
+ if childName == vm.getNodeName(nd) and nd[1] == child[1] then
return true
end
end
@@ -284,7 +284,7 @@ function vm.isSubType(uri, child, parent, mark, errs)
if config.get(uri, 'Lua.type.weakUnionCheck') then
local hasKnownType = 0
for n in child:eachObject() do
- if getNodeName(n) then
+ if vm.getNodeName(n) then
hasKnownType = hasKnownType + 1
if vm.isSubType(uri, n, parent, mark, errs) == true then
return true
@@ -303,7 +303,7 @@ function vm.isSubType(uri, child, parent, mark, errs)
else
local weakNil = config.get(uri, 'Lua.type.weakNilCheck')
for n in child:eachObject() do
- local nodeName = getNodeName(n)
+ local nodeName = vm.getNodeName(n)
if nodeName
and not (nodeName == 'nil' and weakNil)
and vm.isSubType(uri, n, parent, mark, errs) == false then
@@ -329,7 +329,7 @@ function vm.isSubType(uri, child, parent, mark, errs)
end
---@cast child vm.node.object
- local childName = getNodeName(child)
+ local childName = vm.getNodeName(child)
if childName == 'any'
or childName == 'unknown' then
return true
@@ -349,7 +349,7 @@ function vm.isSubType(uri, child, parent, mark, errs)
elseif parent.type == 'vm.node' then
local hasKnownType = 0
for n in parent:eachObject() do
- if getNodeName(n) then
+ if vm.getNodeName(n) then
hasKnownType = hasKnownType + 1
if vm.isSubType(uri, child, n, mark, errs) == true then
return true
@@ -377,7 +377,7 @@ function vm.isSubType(uri, child, parent, mark, errs)
---@cast parent vm.node.object
- local parentName = getNodeName(parent)
+ local parentName = vm.getNodeName(parent)
if parentName == 'any'
or parentName == 'unknown' then
return true
diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua
index 40f9891d..4a60766b 100644
--- a/test/type_inference/init.lua
+++ b/test/type_inference/init.lua
@@ -3965,3 +3965,12 @@ local t
local <?n?> = t[2]
]]
+
+TEST 'N' [[
+---@class N: number
+local x
+
+if x == 0.1 then
+ print(<?x?>)
+end
+]]