summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author最萌小汐 <sumneko@hotmail.com>2024-08-15 15:20:52 +0800
committerGitHub <noreply@github.com>2024-08-15 15:20:52 +0800
commit7c9a24fc2e80952e5045a2855f37b997ab93ee80 (patch)
treeaa192cf6d6ed02d02dda8c8b7a4c240dded64efb
parentb71cb7aecd9337c9463a4dfbdb9d06cac7b825fd (diff)
parent2c798703ca854d670fb28f51adc85c2b41f08f37 (diff)
downloadlua-language-server-7c9a24fc2e80952e5045a2855f37b997ab93ee80.zip
Merge pull request #2768 from NeOzay/cast-table-to-class
check that the shape of the table corresponds to the class
-rw-r--r--changelog.md1
-rw-r--r--script/vm/type.lua93
-rw-r--r--test/diagnostics/cast-local-type.lua66
-rw-r--r--test/diagnostics/param-type-mismatch.lua84
4 files changed, 239 insertions, 5 deletions
diff --git a/changelog.md b/changelog.md
index 070f0b8f..d9c15e7b 100644
--- a/changelog.md
+++ b/changelog.md
@@ -2,6 +2,7 @@
## Unreleased
<!-- Add all new changes here. They will be moved under a version at release -->
+* `NEW` Add matching checks between the shape of tables and classes, during type checking. [#2768](https://github.com/LuaLS/lua-language-server/pull/2768
* `FIX` Error `attempt to index a nil value` when `Lua.hint.semicolon == 'All'` [#2788](https://github.com/LuaLS/lua-language-server/issues/2788)
## 3.10.3
diff --git a/script/vm/type.lua b/script/vm/type.lua
index d3ce7a92..4835065a 100644
--- a/script/vm/type.lua
+++ b/script/vm/type.lua
@@ -148,7 +148,7 @@ end
---@param mark table
---@param errs? typecheck.err[]
---@return boolean?
-local function checkChildEnum(childName, parent , uri, mark, errs)
+local function checkChildEnum(childName, parent, uri, mark, errs)
if mark[childName] then
return
end
@@ -168,7 +168,7 @@ local function checkChildEnum(childName, parent , uri, mark, errs)
end
mark[childName] = true
for _, enum in ipairs(enums) do
- if not vm.isSubType(uri, vm.compileNode(enum), parent, mark ,errs) then
+ if not vm.isSubType(uri, vm.compileNode(enum), parent, mark, errs) then
mark[childName] = nil
return false
end
@@ -325,10 +325,30 @@ function vm.isSubType(uri, child, parent, mark, errs)
return true
else
local weakNil = config.get(uri, 'Lua.type.weakNilCheck')
+ local skipTable
for n in child:eachObject() do
+ if skipTable == nil and n.type == "table" and parent.type == "vm.node" then -- skip table type check if child has class
+ ---@cast parent vm.node
+ for _, c in ipairs(child) do
+ if c.type == 'global' and c.cate == 'type' then
+ for _, set in ipairs(c:getSets(uri)) do
+ if set.type == 'doc.class' then
+ skipTable = true
+ break
+ end
+ end
+ end
+ if skipTable then
+ break
+ end
+ end
+ if not skipTable then
+ skipTable = false
+ end
+ end
local nodeName = vm.getNodeName(n)
if nodeName
- and not (nodeName == 'nil' and weakNil)
+ and not (nodeName == 'nil' and weakNil) and not (skipTable and n.type == 'table')
and vm.isSubType(uri, n, parent, mark, errs) == false then
if errs then
errs[#errs+1] = 'TYPE_ERROR_UNION_DISMATCH'
@@ -463,6 +483,67 @@ function vm.isSubType(uri, child, parent, mark, errs)
return true
end
if childName == 'table' and not guide.isBasicType(parentName) then
+ local set = parent:getSets(uri)
+ local missedKeys = {}
+ local failedCheck
+ local myKeys
+ for _, def in ipairs(set) do
+ if not def.fields or #def.fields == 0 then
+ goto continue
+ end
+ if not myKeys then
+ myKeys = {}
+ for _, field in ipairs(child) do
+ local key = vm.getKeyName(field) or field.tindex
+ if key then
+ myKeys[key] = vm.compileNode(field)
+ end
+ end
+ end
+
+ for _, field in ipairs(def.fields) do
+ local key = vm.getKeyName(field)
+ if not key then
+ local fieldnode = vm.compileNode(field.field)[1]
+ if fieldnode and fieldnode.type == 'doc.type.integer' then
+ ---@cast fieldnode parser.object
+ key = vm.getKeyName(fieldnode)
+ end
+ end
+ if not key then
+ goto continue
+ end
+
+ local ok
+ local nodeField = vm.compileNode(field)
+ if myKeys[key] then
+ ok = vm.isSubType(uri, myKeys[key], nodeField, mark, errs)
+ if ok == false then
+ errs[#errs+1] = 'TYPE_ERROR_PARENT_ALL_DISMATCH' -- error display can be greatly improved
+ errs[#errs+1] = myKeys[key]
+ errs[#errs+1] = nodeField
+ failedCheck = true
+ end
+ elseif not nodeField:isNullable() then
+ if type(key) == "number" then
+ missedKeys[#missedKeys+1] = ('`[%s]`'):format(key)
+ else
+ missedKeys[#missedKeys+1] = ('`%s`'):format(key)
+ end
+ failedCheck = true
+ end
+ end
+ ::continue::
+ end
+ if #missedKeys > 0 then
+ errs[#errs+1] = 'DIAG_MISSING_FIELDS'
+ errs[#errs+1] = parent
+ errs[#errs+1] = table.concat(missedKeys, ', ')
+ end
+ if failedCheck then
+ return false
+ end
+
return true
end
@@ -570,11 +651,11 @@ function vm.getTableValue(uri, tnode, knode, inversion)
and field.value
and field.tindex == 1 then
if inversion then
- if vm.isSubType(uri, 'integer', knode) then
+ if vm.isSubType(uri, 'integer', knode) then
result:merge(vm.compileNode(field.value))
end
else
- if vm.isSubType(uri, knode, 'integer') then
+ if vm.isSubType(uri, knode, 'integer') then
result:merge(vm.compileNode(field.value))
end
end
@@ -692,6 +773,7 @@ function vm.canCastType(uri, defNode, refNode, errs)
return true
end
+
return false
end
@@ -713,6 +795,7 @@ local ErrorMessageMap = {
TYPE_ERROR_NUMBER_LITERAL_TO_INTEGER = {'child'},
TYPE_ERROR_NUMBER_TYPE_TO_INTEGER = {},
TYPE_ERROR_DISMATCH = {'child', 'parent'},
+ DIAG_MISSING_FIELDS = {"1", "2"},
}
---@param uri uri
diff --git a/test/diagnostics/cast-local-type.lua b/test/diagnostics/cast-local-type.lua
index f79bf48d..93452a92 100644
--- a/test/diagnostics/cast-local-type.lua
+++ b/test/diagnostics/cast-local-type.lua
@@ -332,3 +332,69 @@ local x
- 类型 `nil` 无法匹配 `'B'`
- 类型 `nil` 无法匹配 `'A'`]])
end)
+
+TEST [[
+---@class A
+---@field x string
+---@field y number
+
+local a = {x = "", y = 0}
+
+---@type A
+local v
+v = a
+]]
+
+TEST [[
+---@class A
+---@field x string
+---@field y number
+
+local a = {x = ""}
+
+---@type A
+local v
+<!v!> = a
+]]
+
+TEST [[
+---@class A
+---@field x string
+---@field y number
+
+local a = {x = "", y = ""}
+
+---@type A
+local v
+<!v!> = a
+]]
+
+TEST [[
+---@class A
+---@field x string
+---@field y? B
+
+---@class B
+---@field x string
+
+local a = {x = "b", y = {x = "c"}}
+
+---@type A
+local v
+v = a
+]]
+
+TEST [[
+---@class A
+---@field x string
+---@field y B
+
+---@class B
+---@field x string
+
+local a = {x = "b", y = {}}
+
+---@type A
+local v
+<!v!> = a
+]]
diff --git a/test/diagnostics/param-type-mismatch.lua b/test/diagnostics/param-type-mismatch.lua
index b11068db..bb602cab 100644
--- a/test/diagnostics/param-type-mismatch.lua
+++ b/test/diagnostics/param-type-mismatch.lua
@@ -264,3 +264,87 @@ local function f(v) end
f 'x'
f 'y'
]]
+
+TEST [[
+---@class A
+---@field x string
+---@field y number
+
+local a = {x = "", y = 0}
+
+---@param a A
+function f(a) end
+
+f(a)
+]]
+
+TEST [[
+---@class A
+---@field x string
+---@field y number
+
+local a = {x = ""}
+
+---@param a A
+function f(a) end
+
+f(<!a!>)
+]]
+
+TEST [[
+---@class A
+---@field x string
+---@field y number
+
+local a = {x = "", y = ""}
+
+---@param a A
+function f(a) end
+
+f(<!a!>)
+]]
+
+TEST [[
+---@class A
+---@field x string
+---@field y? B
+
+---@class B
+---@field x string
+
+local a = {x = "b", y = {x = "c"}}
+
+---@param a A
+function f(a) end
+
+f(a)
+]]
+
+TEST [[
+---@class A
+---@field x string
+---@field y B
+
+---@class B
+---@field x string
+
+local a = {x = "b", y = {}}
+
+---@param a A
+function f(a) end
+
+f(<!a!>)
+]]
+
+TEST [[
+---@class A
+---@field x string
+
+---@type A
+local a = {}
+
+---@param a A
+function f(a) end
+
+f(a)
+]] \ No newline at end of file