summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author最萌小汐 <sumneko@hotmail.com>2022-06-20 20:58:56 +0800
committer最萌小汐 <sumneko@hotmail.com>2022-06-20 20:58:56 +0800
commitc65de666e0a3706b921993648910ef531c47c607 (patch)
tree034c07b07b3a4c8109e2f3adb4c6b89b83ae0358
parenta04cffb43132645f63e5a319f6ca69e0df87dcdb (diff)
downloadlua-language-server-c65de666e0a3706b921993648910ef531c47c607.zip
update
-rw-r--r--script/core/diagnostics/cast-field-type.lua9
-rw-r--r--script/vm/compiler.lua216
-rw-r--r--test/crossfile/hover.lua2
-rw-r--r--test/diagnostics/type-check.lua34
-rw-r--r--test/type_inference/init.lua22
5 files changed, 174 insertions, 109 deletions
diff --git a/script/core/diagnostics/cast-field-type.lua b/script/core/diagnostics/cast-field-type.lua
index eaab0bbe..5cf36a5a 100644
--- a/script/core/diagnostics/cast-field-type.lua
+++ b/script/core/diagnostics/cast-field-type.lua
@@ -6,6 +6,7 @@ local await = require 'await'
---@async
return function (uri, callback)
+ do return end
if not PREVIEW and not TEST then
return
end
@@ -31,7 +32,10 @@ return function (uri, callback)
for _, class in ipairs(vm.getDefs(parent)) do
if class.type == 'doc.class' then
vm.getClassFields(uri, vm.getGlobal('type', class.class[1]), key, false, function (def)
- if def.type == 'doc.field' then
+ if def.type == 'doc.field'
+ or def.type == 'setfield'
+ or def.type == 'setmethod'
+ or def.type == 'setindex' then
fieldNode:merge(vm.compileNode(def))
end
end)
@@ -50,6 +54,9 @@ return function (uri, callback)
if not key then
return nil
end
+ if not vm.canCastType(uri, vm.compileNode(ref), vm.compileNode(ref.value)) then
+ return
+ end
local parent = ref.node
local fieldNode = getParentField(parent, key)
if not fieldNode or fieldNode:isEmpty() then
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index 8a9dbfbb..51999266 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -11,6 +11,61 @@ local vm = require 'vm.vm'
---@field _node vm.node
---@field _globalBase table
+-- 该函数有副作用,会给source绑定node!
+local function bindDocs(source)
+ local isParam = source.parent.type == 'funcargs'
+ or (source.parent.type == 'in' and source.finish <= source.parent.keys.finish)
+ local docs = source.bindDocs
+ for i = #docs, 1, -1 do
+ local doc = docs[i]
+ if doc.type == 'doc.type' then
+ if not isParam then
+ vm.setNode(source, vm.compileNode(doc))
+ return true
+ end
+ end
+ if doc.type == 'doc.class' then
+ if (source.type == 'local' and not isParam)
+ or (source._globalNode and guide.isSet(source))
+ or source.type == 'tablefield'
+ or source.type == 'tableindex' then
+ vm.setNode(source, vm.compileNode(doc))
+ return true
+ end
+ end
+ if doc.type == 'doc.param' then
+ if isParam and source[1] == doc.param[1] then
+ local node = vm.compileNode(doc)
+ if doc.optional then
+ node:addOptional()
+ end
+ vm.setNode(source, node)
+ return true
+ end
+ end
+ if doc.type == 'doc.module' then
+ local name = doc.module
+ local uri = rpath.findUrisByRequirePath(guide.getUri(source), name)[1]
+ if not uri then
+ return true
+ end
+ local state = files.getState(uri)
+ local ast = state and state.ast
+ if not ast then
+ return true
+ end
+ vm.setNode(source, vm.compileNode(ast))
+ return true
+ end
+ if doc.type == 'doc.overload' then
+ if not isParam then
+ vm.setNode(source, vm.compileNode(doc))
+ end
+ end
+ end
+ return false
+end
+
local searchFieldSwitch = util.switch()
: case 'table'
: call(function (suri, source, key, ref, pushResult)
@@ -63,13 +118,23 @@ local searchFieldSwitch = util.switch()
else
fields = vm.getLocalFields(node, false)
end
- if fields then
- for _, src in ipairs(fields) do
- if ref or guide.isSet(src) then
- pushResult(src)
+ if not fields then
+ return
+ end
+ local hasMarkDoc
+ for _, src in ipairs(fields) do
+ if src.bindDocs then
+ if bindDocs(src) then
+ hasMarkDoc = true
+ pushResult(src, hasMarkDoc)
end
end
end
+ if not hasMarkDoc then
+ for _, src in ipairs(fields) do
+ pushResult(src)
+ end
+ end
end)
: case 'doc.type.array'
: call(function (suri, source, key, ref, pushResult)
@@ -206,7 +271,7 @@ function vm.getClassFields(suri, object, key, ref, pushResult)
if key == nil
or fieldKey == key then
if not searchedFields[fieldKey] then
- pushResult(field)
+ pushResult(field, true)
hasFounded[fieldKey] = true
end
end
@@ -219,7 +284,7 @@ function vm.getClassFields(suri, object, key, ref, pushResult)
if vm.isSubType(suri, key.name, fieldNode) then
local nkey = '|' .. key.name
if not searchedFields[nkey] then
- pushResult(field)
+ pushResult(field, true)
hasFounded[nkey] = true
end
end
@@ -250,7 +315,7 @@ function vm.getClassFields(suri, object, key, ref, pushResult)
end
end
-- check local field and global field
- if set.bindSources then
+ if not hasFounded[key] and set.bindSources then
for _, src in ipairs(set.bindSources) do
searchFieldSwitch(src.type, suri, src, key, ref, function (field)
local fieldKey = guide.getKeyName(field)
@@ -568,97 +633,47 @@ local function bindAs(source)
return false
end
--- 该函数有副作用,会给source绑定node!
-local function bindDocs(source)
- local isParam = source.parent.type == 'funcargs'
- or (source.parent.type == 'in' and source.finish <= source.parent.keys.finish)
- local docs = source.bindDocs
- for i = #docs, 1, -1 do
- local doc = docs[i]
- if doc.type == 'doc.type' then
- if not isParam then
- vm.setNode(source, vm.compileNode(doc))
- return true
- end
- end
- if doc.type == 'doc.class' then
- if (source.type == 'local' and not isParam)
- or (source._globalNode and guide.isSet(source))
- or source.type == 'tablefield'
- or source.type == 'tableindex' then
- vm.setNode(source, vm.compileNode(doc))
- return true
- end
- end
- if doc.type == 'doc.param' then
- if isParam and source[1] == doc.param[1] then
- local node = vm.compileNode(doc)
- if doc.optional then
- node:addOptional()
- end
- vm.setNode(source, node)
- return true
- end
- end
- if doc.type == 'doc.module' then
- local name = doc.module
- local uri = rpath.findUrisByRequirePath(guide.getUri(source), name)[1]
- if not uri then
- return true
- end
- local state = files.getState(uri)
- local ast = state and state.ast
- if not ast then
- return true
- end
- vm.setNode(source, vm.compileNode(ast))
- return true
- end
- if doc.type == 'doc.overload' then
- if not isParam then
- vm.setNode(source, vm.compileNode(doc))
- end
- end
- end
- return false
-end
-
-local function compileByLocalID(source)
- local sources = vm.getLocalSourcesSets(source)
- if not sources then
- return
- end
- local hasMarkDoc
- for _, src in ipairs(sources) do
- if src.bindDocs then
- if bindDocs(src) then
- hasMarkDoc = true
- vm.setNode(source, vm.compileNode(src))
- end
- end
- end
- if not hasMarkDoc then
- for _, src in ipairs(sources) do
- if src.value then
- local valueNode = vm.compileNode(src.value)
- if valueNode:hasType 'unknown' then
- vm.setNode(source, valueNode:copy():remove 'unknown')
- else
- vm.setNode(source, valueNode)
- end
- end
- end
- end
-end
-
---@param source vm.node
---@param key? any
---@param pushResult fun(source: parser.object)
function vm.compileByParentNode(source, key, ref, pushResult)
local parentNode = vm.compileNode(source)
+ local docedResults = {}
+ local commonResults = {}
local suri = guide.getUri(source)
+ local hasClass
+ for node in parentNode:eachObject() do
+ if node.type == 'global'
+ and node.cate == 'type'
+ and not guide.isBasicType(node.name) then
+ hasClass = true
+ end
+ end
for node in parentNode:eachObject() do
- searchFieldSwitch(node.type, suri, node, key, ref, pushResult)
+ if not hasClass
+ or (
+ node.type == 'global'
+ and node.cate == 'type'
+ and not guide.isBasicType(node.name)
+ ) then
+ searchFieldSwitch(node.type, suri, node, key, ref, function (res, markDoc)
+ if markDoc then
+ docedResults[#docedResults+1] = res
+ else
+ commonResults[#commonResults+1] = res
+ end
+ end)
+ end
+ end
+ if #docedResults > 0 then
+ for _, res in ipairs(docedResults) do
+ pushResult(res)
+ end
+ end
+ if #docedResults == 0 or key == nil then
+ for _, res in ipairs(commonResults) do
+ pushResult(res)
+ end
end
end
@@ -1289,27 +1304,13 @@ local compilerSwitch = util.switch()
: case 'setfield'
: case 'setmethod'
: case 'setindex'
- : call(function (source)
- compileByLocalID(source)
- local key = guide.getKeyName(source)
- if key == nil then
- return
- end
- vm.compileByParentNode(source.node, key, false, function (src)
- if src.type == 'doc.type.field'
- or src.type == 'doc.field' then
- vm.setNode(source, vm.compileNode(src))
- end
- end)
- end)
: case 'getfield'
: case 'getmethod'
: case 'getindex'
: call(function (source)
- if bindAs(source) then
+ if guide.isGet(source) and bindAs(source) then
return
end
- compileByLocalID(source)
---@type string|vm.node
local key = guide.getKeyName(source)
if key == nil and source.index then
@@ -1327,6 +1328,9 @@ local compilerSwitch = util.switch()
else
vm.compileByParentNode(source.node, key, false, function (src)
vm.setNode(source, vm.compileNode(src))
+ if src.value then
+ vm.setNode(source, vm.compileNode(src.value))
+ end
end)
end
end)
diff --git a/test/crossfile/hover.lua b/test/crossfile/hover.lua
index 08cde574..b9ffd59b 100644
--- a/test/crossfile/hover.lua
+++ b/test/crossfile/hover.lua
@@ -811,7 +811,7 @@ food.secondField = 2
},
hover = [[
```lua
-(field) Food.firstField: number = 0
+(field) Food.firstField: number
```]]}
TEST {{ path = 'a.lua', content = '', }, {
diff --git a/test/diagnostics/type-check.lua b/test/diagnostics/type-check.lua
index 30e727d3..295fb32d 100644
--- a/test/diagnostics/type-check.lua
+++ b/test/diagnostics/type-check.lua
@@ -212,6 +212,40 @@ local y
TEST [[
---@class A
+local m
+
+m.x = 1
+
+---@type A
+local t
+
+<!t.x!> = true
+]]
+
+TEST [[
+---@class A
+local m
+
+---@type integer
+m.x = 1
+
+<!m.x!> = true
+]]
+
+TEST [[
+---@class A
+local mt
+
+---@type integer
+mt.x = 1
+
+function mt:init()
+ <!self.x!> = true
+end
+]]
+
+TEST [[
+---@class A
---@field x integer
---@type A
diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua
index 3a7b1227..b448e837 100644
--- a/test/type_inference/init.lua
+++ b/test/type_inference/init.lua
@@ -2675,7 +2675,7 @@ local t
local <?x?> = t.x
]]
-TEST 'integer' [[
+TEST 'integer|unknown' [[
local function f()
return GG
end
@@ -2719,3 +2719,23 @@ end
local <?n?> = f()
]]
+
+TEST 'integer' [[
+---@class A
+---@field x integer
+local m
+
+m.<?x?> = true
+
+print(m.x)
+]]
+
+TEST 'integer' [[
+---@class A
+---@field x integer
+local m
+
+m.x = true
+
+print(m.<?x?>)
+]]