summaryrefslogtreecommitdiff
path: root/script
diff options
context:
space:
mode:
Diffstat (limited to 'script')
-rw-r--r--script/core/semantic-tokens.lua8
-rw-r--r--script/parser/compile.lua2
-rw-r--r--script/parser/guide.lua2
-rw-r--r--script/parser/luadoc.lua4
-rw-r--r--script/provider/diagnostic.lua9
-rw-r--r--script/vm/compiler.lua67
-rw-r--r--script/vm/generic.lua15
-rw-r--r--script/vm/init.lua2
-rw-r--r--script/vm/node.lua17
9 files changed, 79 insertions, 47 deletions
diff --git a/script/core/semantic-tokens.lua b/script/core/semantic-tokens.lua
index c892aaac..8be70198 100644
--- a/script/core/semantic-tokens.lua
+++ b/script/core/semantic-tokens.lua
@@ -667,6 +667,14 @@ local Care = util.switch()
type = define.TokenTypes.keyword,
}
end)
+ : case 'doc.cast.block'
+ : call(function (source, options, results)
+ results[#results+1] = {
+ start = source.start,
+ finish = source.finish,
+ type = define.TokenTypes.keyword,
+ }
+ end)
: case 'doc.cast.name'
: call(function (source, options, results)
results[#results+1] = {
diff --git a/script/parser/compile.lua b/script/parser/compile.lua
index 20546e4a..cc142dfa 100644
--- a/script/parser/compile.lua
+++ b/script/parser/compile.lua
@@ -1024,7 +1024,7 @@ local function parseShortString()
fastForwardToken(currentOffset)
local right = getPosition(currentOffset - 1, 'right')
local byte = tointeger(numbers)
- if byte <= 255 then
+ if byte and byte <= 255 then
stringIndex = stringIndex + 1
stringPool[stringIndex] = schar(byte)
else
diff --git a/script/parser/guide.lua b/script/parser/guide.lua
index 969eb386..56239fb1 100644
--- a/script/parser/guide.lua
+++ b/script/parser/guide.lua
@@ -240,7 +240,7 @@ local function formatNumber(n)
end
--- 是否是字面量
----@param obj parser.object
+---@param obj table
---@return boolean
function m.isLiteral(obj)
local tp = obj.type
diff --git a/script/parser/luadoc.lua b/script/parser/luadoc.lua
index 847b7d37..fdf7718c 100644
--- a/script/parser/luadoc.lua
+++ b/script/parser/luadoc.lua
@@ -1273,8 +1273,7 @@ local docSwitch = util.switch()
if checkToken('symbol', '?', 1) then
block.optional = true
nextToken()
- block.start = block.start or getStart()
- block.finish = block.finish
+ block.finish = getFinish()
else
block.extends = parseType(block)
if block.extends then
@@ -1286,6 +1285,7 @@ local docSwitch = util.switch()
if block.optional or block.extends then
result.casts[#result.casts+1] = block
end
+ result.finish = block.finish
if checkToken('symbol', ',', 1) then
nextToken()
diff --git a/script/provider/diagnostic.lua b/script/provider/diagnostic.lua
index 8085713e..90421c3e 100644
--- a/script/provider/diagnostic.lua
+++ b/script/provider/diagnostic.lua
@@ -275,7 +275,12 @@ function m.doDiagnostic(uri, isScopeDiag)
local diags = {}
local lastDiag = copyDiagsWithoutSyntax(m.cache[uri])
- local function pushResult()
+ local function pushResult(isComplete)
+ -- Disable incremental diagnosis.
+ -- The current diagnosis speed is good.
+ if not isComplete then
+ return
+ end
tracy.ZoneBeginN 'mergeSyntaxAndDiags'
local _ <close> = tracy.ZoneEnd
local full = mergeDiags(syntax, lastDiag, diags)
@@ -322,7 +327,7 @@ function m.doDiagnostic(uri, isScopeDiag)
end)
lastDiag = nil
- pushResult()
+ pushResult(true)
end
---@param uri uri
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index e3c08cb9..9ed7f271 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -460,7 +460,7 @@ end
---@param func parser.object
---@param index integer
----@return vm.object?
+---@return (parser.object|vm.generic)?
function vm.getReturnOfFunction(func, index)
if func.type == 'function' then
if not func._returns then
@@ -600,6 +600,7 @@ local function getReturn(func, index, args)
for mfunc in funcNode:eachObject() do
if mfunc.type == 'function'
or mfunc.type == 'doc.type.function' then
+ ---@cast mfunc parser.object
local returnObject = vm.getReturnOfFunction(mfunc, index)
if returnObject then
local returnNode = vm.compileNode(returnObject)
@@ -680,7 +681,7 @@ local function bindAs(source)
return false
end
----@param source vm.node
+---@param source parser.object
---@param key? string|vm.global
---@param pushResult fun(source: parser.object)
function vm.compileByParentNode(source, key, ref, pushResult)
@@ -692,6 +693,7 @@ function vm.compileByParentNode(source, key, ref, pushResult)
for node in parentNode:eachObject() do
if node.type == 'global'
and node.cate == 'type'
+ ---@cast node vm.global
and not guide.isBasicType(node.name) then
hasClass = true
break
@@ -702,6 +704,7 @@ function vm.compileByParentNode(source, key, ref, pushResult)
or (
node.type == 'global'
and node.cate == 'type'
+ ---@cast node vm.global
and not guide.isBasicType(node.name)
) then
searchFieldSwitch(node.type, suri, node, key, ref, function (res, markDoc)
@@ -747,7 +750,7 @@ local function selectNode(source, list, index)
end
end
if not exp then
- vm.setNode(source, vm.getGlobal('type', 'nil'))
+ vm.setNode(source, vm.declareGlobal('type', 'nil'))
return vm.getNode(source)
end
---@type vm.node?
@@ -755,14 +758,14 @@ local function selectNode(source, list, index)
if exp.type == 'call' then
result = getReturn(exp.node, index, exp.args)
if not result then
- vm.setNode(source, vm.getGlobal('type', 'unknown'))
+ vm.setNode(source, vm.declareGlobal('type', 'unknown'))
return vm.getNode(source)
end
else
---@type vm.node
result = vm.compileNode(exp)
if result and exp.type == 'varargs' and result:isEmpty() then
- result:merge(vm.getGlobal('type', 'unknown'))
+ result:merge(vm.declareGlobal('type', 'unknown'))
end
end
if source.type == 'function.return' then
@@ -788,7 +791,7 @@ local function selectNode(source, list, index)
end
---@param source parser.object
----@param node vm.object
+---@param node vm.node.object
---@return boolean
local function isValidCallArgNode(source, node)
if source.type == 'function' then
@@ -796,7 +799,11 @@ local function isValidCallArgNode(source, node)
end
if source.type == 'table' then
return node.type == 'doc.type.table'
- or (node.type == 'global' and node.cate == 'type' and not guide.isBasicType(node.name))
+ or ( node.type == 'global'
+ and node.cate == 'type'
+ ---@cast node vm.global
+ and not guide.isBasicType(node.name)
+ )
end
if source.type == 'dummyarg' then
return true
@@ -845,12 +852,14 @@ local function compileCallArgNode(arg, call, callNode, fixIndex, myIndex)
for n in callNode:eachObject() do
if n.type == 'function' then
+ ---@cast n parser.object
local sign = getObjectSign(n)
local farg = getFuncArg(n, myIndex)
if farg then
for fn in vm.compileNode(farg):eachObject() do
if isValidCallArgNode(arg, fn) then
if fn.type == 'doc.type.function' then
+ ---@cast fn parser.object
if sign then
local generic = vm.createGeneric(fn, sign)
local args = {}
@@ -866,6 +875,7 @@ local function compileCallArgNode(arg, call, callNode, fixIndex, myIndex)
end
end
if n.type == 'doc.type.function' then
+ ---@cast n parser.object
local myEvent
if n.args[eventIndex] then
local argNode = vm.compileNode(n.args[eventIndex])
@@ -892,6 +902,7 @@ end
---@param arg parser.object
---@param call parser.object
---@param index? integer
+---@return vm.node?
function vm.compileCallArg(arg, call, index)
if not index then
for i, carg in ipairs(call.args) do
@@ -900,6 +911,9 @@ function vm.compileCallArg(arg, call, index)
break
end
end
+ if not index then
+ return nil
+ end
end
local callNode = vm.compileNode(call.node)
@@ -1261,6 +1275,7 @@ local compilerSwitch = util.switch()
for _, pn in ipairs(parentNode) do
if pn.type == 'global'
and pn.cate == 'type' then
+ ---@cast pn vm.global
if not guide.isBasicType(pn.name) then
vm.setNode(source, pn)
end
@@ -1349,7 +1364,10 @@ local compilerSwitch = util.switch()
for _, ref in ipairs(source.ref) do
if ref.type == 'setlocal'
and guide.getParentFunction(ref) == parentFunc then
- vm.setNode(source, vm.getNode(ref))
+ local refNode = vm.getNode(ref)
+ if refNode then
+ vm.setNode(source, refNode)
+ end
end
end
end
@@ -1490,6 +1508,7 @@ local compilerSwitch = util.switch()
end)
end
if hasGeneric then
+ ---@cast sign -?
vm.setNode(source, vm.createGeneric(rtn, sign))
else
vm.setNode(source, vm.compileNode(rtn))
@@ -1500,7 +1519,7 @@ local compilerSwitch = util.switch()
end
if lastReturn and not hasMarkDoc and lastReturn.types[1][1] == '...' then
hasMarkDoc = true
- vm.setNode(source, vm.getGlobal('type', 'unknown'))
+ vm.setNode(source, vm.declareGlobal('type', 'unknown'))
end
end
local hasReturn
@@ -1532,12 +1551,12 @@ local compilerSwitch = util.switch()
::CONTINUE::
end
if not hasKnownType and hasUnknownType then
- vm.setNode(source, vm.getGlobal('type', 'unknown'))
+ vm.setNode(source, vm.declareGlobal('type', 'unknown'))
end
end
end
if not hasMarkDoc and not hasReturn then
- vm.setNode(source, vm.getGlobal('type', 'nil'))
+ vm.setNode(source, vm.declareGlobal('type', 'nil'))
end
end)
: case 'main'
@@ -1710,7 +1729,7 @@ local compilerSwitch = util.switch()
: call(function (source)
local type = vm.getGlobal('type', source[1])
if type then
- vm.setNode(source, vm.compileNode(type))
+ vm.setNode(source, type)
end
end)
: case 'doc.type.arg'
@@ -1724,10 +1743,6 @@ local compilerSwitch = util.switch()
vm.getNode(source):addOptional()
end
end)
- : case 'generic'
- : call(function (source)
- vm.setNode(source, source)
- end)
: case 'unary'
: call(function (source)
if bindAs(source) then
@@ -1805,17 +1820,18 @@ local compilerSwitch = util.switch()
binarySwich(source.op.type, source)
end)
----@param source vm.object
+---@param source parser.object
local function compileByNode(source)
compilerSwitch(source.type, source)
end
----@param source vm.object
+---@param source parser.object
local function compileByGlobal(source)
local global = source._globalNode
if not global then
return
end
+ ---@cast source parser.object
local root = guide.getRoot(source)
local uri = guide.getUri(source)
if not root._globalBase then
@@ -1900,22 +1916,23 @@ function vm.compileNode(source)
end
end
- if source.type == 'global' then
- return source
- end
-
local cache = vm.getNode(source)
if cache ~= nil then
return cache
end
- local node = vm.createNode()
- vm.setNode(source, node, true)
+ if source.type == 'generic' then
+ vm.setNode(source, source)
+ return vm.getNode(source)
+ end
+
+ ---@cast source parser.object
+ vm.setNode(source, vm.createNode(), true)
compileByGlobal(source)
compileByNode(source)
matchCall(source)
- node = vm.getNode(source)
+ local node = vm.getNode(source)
return node
end
diff --git a/script/vm/generic.lua b/script/vm/generic.lua
index e1dcc5a4..16965fe3 100644
--- a/script/vm/generic.lua
+++ b/script/vm/generic.lua
@@ -11,13 +11,10 @@ local mt = {}
mt.__index = mt
mt.type = 'generic'
----@param source parser.object
+---@param source vm.object
---@param resolved? table<string, vm.node>
---@return parser.object
local function cloneObject(source, resolved)
- if not source then
- return nil
- end
if not resolved then
return source
end
@@ -124,8 +121,14 @@ function mt:resolve(uri, args)
local protoNode = vm.compileNode(self.proto)
local result = vm.createNode()
for nd in protoNode:eachObject() do
- local clonedNode = vm.compileNode(cloneObject(nd, resolved))
- result:merge(clonedNode)
+ if nd.type == 'global' then
+ ---@cast nd vm.global
+ result:merge(nd)
+ else
+ ---@cast nd -vm.global
+ local clonedNode = vm.compileNode(cloneObject(nd, resolved))
+ result:merge(clonedNode)
+ end
end
return result
end
diff --git a/script/vm/init.lua b/script/vm/init.lua
index 4fa65766..24e75f95 100644
--- a/script/vm/init.lua
+++ b/script/vm/init.lua
@@ -1,6 +1,6 @@
local vm = require 'vm.vm'
----@alias vm.object parser.object | vm.global | vm.generic
+---@alias vm.object parser.object | vm.generic
require 'vm.compiler'
require 'vm.value'
diff --git a/script/vm/node.lua b/script/vm/node.lua
index 24f87a45..fce8c642 100644
--- a/script/vm/node.lua
+++ b/script/vm/node.lua
@@ -7,8 +7,10 @@ local guide = require 'parser.guide'
---@type table<vm.object, vm.node>
vm.nodeCache = {}
+---@alias vm.node.object vm.object | vm.global
+
---@class vm.node
----@field [integer] vm.object
+---@field [integer] vm.node.object
---@field [vm.object] true
local mt = {}
mt.__index = mt
@@ -17,7 +19,7 @@ mt.type = 'vm.node'
mt.optional = nil
mt.data = nil
----@param node vm.node | vm.object
+---@param node vm.node | vm.node.object
function mt:merge(node)
if not node then
return
@@ -319,7 +321,7 @@ function mt:hasName(name)
return false
end
----@return fun():vm.object
+---@return fun():vm.node.object
function mt:eachObject()
local i = 0
return function ()
@@ -334,7 +336,7 @@ function mt:copy()
end
---@param source vm.object
----@param node vm.node | vm.object
+---@param node vm.node | vm.node.object
---@param cover? boolean
---@return vm.node
function vm.setNode(source, node, cover)
@@ -345,9 +347,6 @@ function vm.setNode(source, node, cover)
log.error('Can not set nil node')
end
end
- if source.type == 'global' then
- error('Can not set node to global')
- end
if cover then
---@cast node vm.node
vm.nodeCache[source] = node
@@ -403,8 +402,8 @@ end
local ID = 0
----@param a? vm.node | vm.object
----@param b? vm.node | vm.object
+---@param a? vm.node | vm.node.object
+---@param b? vm.node | vm.node.object
---@return vm.node
function vm.createNode(a, b)
ID = ID + 1