summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--script/core/hover/table.lua100
-rw-r--r--script/parser/guide.lua1
-rw-r--r--script/vm/compiler.lua12
-rw-r--r--script/vm/runner.lua3
-rw-r--r--script/vm/vm.lua1
-rw-r--r--test/hover/init.lua2
-rw-r--r--test/type_inference/init.lua15
7 files changed, 80 insertions, 54 deletions
diff --git a/script/core/hover/table.lua b/script/core/hover/table.lua
index 31036edd..68272745 100644
--- a/script/core/hover/table.lua
+++ b/script/core/hover/table.lua
@@ -16,22 +16,34 @@ local function formatKey(key)
return ('[%s]'):format(key)
end
-local function buildAsHash(keys, typeMap, literalMap, optMap, reachMax)
+---@param uri uri
+---@param keys string[]
+---@param nodeMap table<string, vm.node>
+---@param reachMax integer
+local function buildAsHash(uri, keys, nodeMap, reachMax)
local lines = {}
lines[#lines+1] = '{'
for _, key in ipairs(keys) do
- local typeView = typeMap[key]
- local literalView = literalMap[key]
+ local node = nodeMap[key]
+ local isOptional = node:isOptional()
+ if isOptional then
+ node = node:copy()
+ node:removeOptional()
+ end
+ local ifr = infer.getInfer(node)
+ local typeView = ifr:view('unknown', uri)
+ local literalView = ifr:viewLiterals()
if literalView then
lines[#lines+1] = (' %s%s: %s = %s,'):format(
formatKey(key),
- optMap[key] and '?' or '',
+ isOptional and '?' or '',
typeView,
- literalView)
+ literalView
+ )
else
lines[#lines+1] = (' %s%s: %s,'):format(
formatKey(key),
- optMap[key] and '?' or '',
+ isOptional and '?' or '',
typeView
)
end
@@ -43,26 +55,40 @@ local function buildAsHash(keys, typeMap, literalMap, optMap, reachMax)
return table.concat(lines, '\n')
end
-local function buildAsConst(keys, typeMap, literalMap, optMap, reachMax)
+---@param uri uri
+---@param keys string[]
+---@param nodeMap table<string, vm.node>
+---@param reachMax integer
+local function buildAsConst(uri, keys, nodeMap, reachMax)
+ local literalMap = {}
+ for _, key in ipairs(keys) do
+ literalMap[key] = infer.getInfer(nodeMap[key]):viewLiterals()
+ end
table.sort(keys, function (a, b)
return tonumber(literalMap[a]) < tonumber(literalMap[b])
end)
local lines = {}
lines[#lines+1] = '{'
for _, key in ipairs(keys) do
- local typeView = typeMap[key]
+ local node = nodeMap[key]
+ local isOptional = node:isOptional()
+ if isOptional then
+ node = node:copy()
+ node:removeOptional()
+ end
+ local typeView = infer.getInfer(node):view('unknown', uri)
local literalView = literalMap[key]
if literalView then
lines[#lines+1] = (' %s%s: %s = %s,'):format(
formatKey(key),
- optMap[key] and '?' or '',
+ isOptional and '?' or '',
typeView,
literalView
)
else
lines[#lines+1] = (' %s%s: %s,'):format(
formatKey(key),
- optMap[key] and '?' or '',
+ isOptional and '?' or '',
typeView
)
end
@@ -110,48 +136,25 @@ local function getKeyMap(fields)
return keys, map
end
-local function getOptMap(fields, keyMap)
- local optMap = {}
- for _, field in ipairs(fields) do
- if field.type == 'doc.field' then
- if field.optional then
- local key = vm.getKeyName(field)
- if keyMap[key] then
- optMap[key] = true
- end
- end
- end
- if field.type == 'doc.type.field' then
- if field.optional then
- local key = vm.getKeyName(field)
- if keyMap[key] then
- optMap[key] = true
- end
- end
- end
- end
- return optMap
-end
-
---@async
-local function getInferMap(fields, keyMap)
- ---@type table<string, vm.infer>
- local inferMap = {}
+local function getNodeMap(fields, keyMap)
+ ---@type table<string, vm.node>
+ local nodeMap = {}
for _, field in ipairs(fields) do
local key = vm.getKeyName(field)
if not keyMap[key] then
goto CONTINUE
end
await.delay()
- local ifr = infer.getInfer(field)
- if inferMap[key] then
- inferMap[key] = inferMap[key]:merge(ifr)
+ local node = vm.compileNode(field)
+ if nodeMap[key] then
+ nodeMap[key]:merge(node)
else
- inferMap[key] = ifr
+ nodeMap[key] = node:copy()
end
::CONTINUE::
end
- return inferMap
+ return nodeMap
end
---@async
@@ -184,19 +187,14 @@ return function (source)
end
end
- local optMap = getOptMap(fields, map)
- local inferMap = getInferMap(fields, map)
+ local nodeMap = getNodeMap(fields, map)
- local typeMap = {}
- local literalMap = {}
local isConsts = true
for i = 1, #keys do
await.delay()
local key = keys[i]
-
- typeMap[key] = inferMap[key]:view('unknown', uri)
- literalMap[key] = inferMap[key]:viewLiterals()
- if not tonumber(literalMap[key]) then
+ local literal = infer.getInfer(nodeMap[key]):viewLiterals()
+ if not tonumber(literal) then
isConsts = false
end
end
@@ -204,9 +202,9 @@ return function (source)
local result
if isConsts then
- result = buildAsConst(keys, typeMap, literalMap, optMap, reachMax)
+ result = buildAsConst(uri, keys, nodeMap, reachMax)
else
- result = buildAsHash(keys, typeMap, literalMap, optMap, reachMax)
+ result = buildAsHash(uri, keys, nodeMap, reachMax)
end
--if timeUp then
diff --git a/script/parser/guide.lua b/script/parser/guide.lua
index e28a25e9..b008c69b 100644
--- a/script/parser/guide.lua
+++ b/script/parser/guide.lua
@@ -925,6 +925,7 @@ function m.getKeyNameOfLiteral(obj)
end
end
+---@return string?
function m.getKeyName(obj)
if not obj then
return nil
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index 7bb9ba1b..76766734 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -1253,14 +1253,22 @@ local compilerSwitch = util.switch()
if not source.extends then
return
end
- vm.setNode(source, vm.compileNode(source.extends))
+ local fieldNode = vm.compileNode(source.extends)
+ if source.optional then
+ fieldNode:addOptional()
+ end
+ vm.setNode(source, fieldNode)
end)
: case 'doc.type.field'
: call(function (source)
if not source.extends then
return
end
- vm.setNode(source, vm.compileNode(source.extends))
+ local fieldNode = vm.compileNode(source.extends)
+ if source.optional then
+ fieldNode:addOptional()
+ end
+ vm.setNode(source, fieldNode)
end)
: case 'doc.param'
: call(function (source)
diff --git a/script/vm/runner.lua b/script/vm/runner.lua
index ed881d94..54b1d839 100644
--- a/script/vm/runner.lua
+++ b/script/vm/runner.lua
@@ -26,6 +26,9 @@ mt.index = 1
---@param filter parser.object
---@param pos integer
function mt:_compileNarrowByFilter(filter, pos)
+ if not filter then
+ return
+ end
if filter.type == 'unary' then
elseif filter.type == 'binary' then
else
diff --git a/script/vm/vm.lua b/script/vm/vm.lua
index 3c1762bf..8117d311 100644
--- a/script/vm/vm.lua
+++ b/script/vm/vm.lua
@@ -23,6 +23,7 @@ function m.getSpecial(source)
return source.special
end
+---@return string?
function m.getKeyName(source)
if not source then
return nil
diff --git a/test/hover/init.lua b/test/hover/init.lua
index 1a05d78e..79af0154 100644
--- a/test/hover/init.lua
+++ b/test/hover/init.lua
@@ -772,7 +772,7 @@ local <?t?> = {
]]
[[
local t: {
- f: file*?,
+ f?: file*,
}
]]
diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua
index 7a4fb754..e7560a12 100644
--- a/test/type_inference/init.lua
+++ b/test/type_inference/init.lua
@@ -1795,3 +1795,18 @@ TEST 'integer' [[
local x = 1
x = <?x?>
]]
+
+TEST 'integer?' [[
+---@class A
+---@field x? integer
+local t
+
+t.<?x?>
+]]
+
+TEST 'integer?' [[
+---@type { x?: integer }
+local t
+
+t.<?x?>
+]]