Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 50 additions & 3 deletions Ix/Aiur.lean
Original file line number Diff line number Diff line change
@@ -1,18 +1,35 @@
module

-- Stage 1 (Source) IR
public import Ix.Aiur.Goldilocks
public import Ix.Aiur.Meta
public import Ix.Aiur.Stages.Source
public import Ix.Aiur.Semantics.SourceEval
public import Ix.Aiur.Interpret

-- Stage 2 (Typed) IR
public import Ix.Aiur.Stages.Typed
public import Ix.Aiur.Semantics.TypedEval

-- Stage 3 (Simple) IR
public import Ix.Aiur.Stages.Simple

-- Stage 4 (Concrete) IR
public import Ix.Aiur.Stages.Concrete
public import Ix.Aiur.Semantics.ConcreteEval

-- Stage 5 (Bytecode)
public import Ix.Aiur.Stages.Bytecode
public import Ix.Aiur.Semantics.Flatten
public import Ix.Aiur.Semantics.BytecodeFfi
public import Ix.Aiur.Semantics.SourceEval
public import Ix.Aiur.Semantics.BytecodeEval
public import Ix.Aiur.Protocol
public import Ix.Aiur.Interpret

-- Semantic relation layer
public import Ix.Aiur.Semantics.Flatten
public import Ix.Aiur.Semantics.Relation
public import Ix.Aiur.Semantics.Compatible

-- Compiler pipeline
public import Ix.Aiur.Compiler.Check
public import Ix.Aiur.Compiler.Match
public import Ix.Aiur.Compiler.Simple
Expand All @@ -22,3 +39,33 @@ public import Ix.Aiur.Compiler.Lower
public import Ix.Aiur.Compiler.Dedup
public import Ix.Aiur.Compiler
public import Ix.Aiur.Statistics

-- Proofs
public import Ix.Aiur.Proofs.ValueEqFlatten
public import Ix.Aiur.Proofs.ConcreteEvalInversion
public import Ix.Aiur.Proofs.StructCompatible
public import Ix.Aiur.Semantics.WellFormed
public import Ix.Aiur.Proofs.DedupSound
public import Ix.Aiur.Proofs.LowerShared
public import Ix.Aiur.Proofs.LowerCalleesFromLayout
public import Ix.Aiur.Proofs.LowerSoundCore
public import Ix.Aiur.Proofs.LowerSoundControl
public import Ix.Aiur.Proofs.ConcretizeSound
public import Ix.Aiur.Proofs.ConcretizeSound.FnFree
public import Ix.Aiur.Proofs.ConcretizeSound.SizeBound
public import Ix.Aiur.Proofs.ConcretizeSound.RefClosed
public import Ix.Aiur.Proofs.ConcretizeSound.Phase4
public import Ix.Aiur.Proofs.ConcretizeSound.CtorKind
public import Ix.Aiur.Proofs.ConcretizeSound.Shapes
public import Ix.Aiur.Proofs.ConcretizeSound.Layout
public import Ix.Aiur.Proofs.ConcretizeSound.StageExtract
public import Ix.Aiur.Proofs.ConcretizeSound.RefsDt
public import Ix.Aiur.Proofs.ConcretizeSound.FirstOrder
public import Ix.Aiur.Proofs.ConcretizeSound.MonoInvariants
public import Ix.Aiur.Proofs.ConcretizeSound.TypesNotFunction
public import Ix.Aiur.Proofs.ConcretizeProgress
public import Ix.Aiur.Proofs.SimplifySound
public import Ix.Aiur.Proofs.CheckSound
public import Ix.Aiur.Proofs.CompilerPreservation
public import Ix.Aiur.Proofs.CompilerProgress
public import Ix.Aiur.Proofs.CompilerCorrect
160 changes: 95 additions & 65 deletions Ix/Aiur/Compiler/Check.lean
Original file line number Diff line number Diff line change
Expand Up @@ -149,53 +149,72 @@ def expandTypeM (visited : Std.HashSet Global) (toplevelAliases : Array TypeAlia
(t : Typ) : StateT (Std.HashMap Global Typ) (Except CheckError) Typ :=
expandTypeMBound (toplevelAliases.size + 1) visited toplevelAliases t

/-- Alias-name duplicate check. Pure fold over `typeAliases` building up the
name set; throws on first collision. -/
def mkDecls_checkAliases (typeAliases : Array TypeAlias) :
Except CheckError (Std.HashSet Global) :=
typeAliases.foldlM (init := (∅ : Std.HashSet Global))
fun allNames alias => do
if allNames.contains alias.name then
throw (.duplicatedDefinition alias.name)
pure (allNames.insert alias.name)

/-- Per-function step of `mkDecls`: duplicate-check, expand input/output types,
insert the function declaration. -/
def mkDecls_functionStep
(expandTyp : Typ → Except CheckError Typ)
(acc : Std.HashSet Global × Source.Decls) (function : Function) :
Except CheckError (Std.HashSet Global × Source.Decls) := do
let (allNames, decls) := acc
if allNames.contains function.name then
throw (.duplicatedDefinition function.name)
let inputs' ← function.inputs.mapM fun (loc, typ) => do
let typ' ← expandTyp typ
pure (loc, typ')
let output' ← expandTyp function.output
let function' := { function with inputs := inputs', output := output' }
pure (allNames.insert function.name,
decls.insert function.name (.function function'))

/-- Per-datatype step of `mkDecls`: duplicate-check the datatype + each
constructor name, expand argtypes, insert the datatype + each constructor. -/
def mkDecls_dataTypeStep
(expandTyp : Typ → Except CheckError Typ)
(acc : Std.HashSet Global × Source.Decls) (dataType : DataType) :
Except CheckError (Std.HashSet Global × Source.Decls) := do
let (allNames, decls) := acc
if allNames.contains dataType.name then
throw (.duplicatedDefinition dataType.name)
let constructors ← dataType.constructors.foldlM (init := ([] : List Constructor))
fun ctors ctor => do
let argTypes' ← ctor.argTypes.mapM expandTyp
pure (ctors.concat { ctor with argTypes := argTypes' })
let dataType' := { dataType with constructors }
let allNames' := allNames.insert dataType.name
let decls' := decls.insert dataType.name (.dataType dataType')
constructors.foldlM (init := (allNames', decls'))
fun (allNames, decls) ctor => do
let ctorName := dataType.name.pushNamespace ctor.nameHead
if allNames.contains ctorName then
throw (.duplicatedDefinition ctorName)
pure (allNames.insert ctorName,
decls.insert ctorName (.constructor dataType' ctor))

/-- Constructs a map of declarations from a toplevel, expanding all type aliases. -/
def Source.Toplevel.mkDecls (toplevel : Source.Toplevel) : Except CheckError Source.Decls := do
let mut allNames : Std.HashSet Global := {}
for alias in toplevel.typeAliases do
if allNames.contains alias.name then
throw $ .duplicatedDefinition alias.name
allNames := allNames.insert alias.name

let aliasNames ← mkDecls_checkAliases toplevel.typeAliases
let initAliasMap := {}
let (_, finalAliasMap) ← (toplevel.typeAliases.mapM fun (alias : TypeAlias) => do
let expanded ← expandTypeM {} toplevel.typeAliases alias.expansion
modify fun (aliasMap : Std.HashMap Global Typ) => aliasMap.insert alias.name expanded
).run initAliasMap

let expandTyp (typ : Typ) : Except CheckError Typ :=
(expandTypeM {} toplevel.typeAliases typ).run' finalAliasMap

let mut decls : Decls := default
for function in toplevel.functions do
if allNames.contains function.name then
throw $ .duplicatedDefinition function.name
allNames := allNames.insert function.name
let inputs' ← function.inputs.mapM fun (loc, typ) => do
let typ' ← expandTyp typ
pure (loc, typ')
let output' ← expandTyp function.output
let function' := { function with inputs := inputs', output := output' }
decls := decls.insert function.name (.function function')

for dataType in toplevel.dataTypes do
if allNames.contains dataType.name then
throw $ .duplicatedDefinition dataType.name
allNames := allNames.insert dataType.name
let mut constructors : List Constructor := []
for ctor in dataType.constructors do
let argTypes' ← ctor.argTypes.mapM expandTyp
constructors := constructors.concat { ctor with argTypes := argTypes' }
let dataType' := { dataType with constructors }
decls := decls.insert dataType.name (.dataType dataType')
for ctor in constructors do
let ctorName := dataType.name.pushNamespace ctor.nameHead
if allNames.contains ctorName then
throw $ .duplicatedDefinition ctorName
allNames := allNames.insert ctorName
decls := decls.insert ctorName (.constructor dataType' ctor)

pure decls
let afterFns ← toplevel.functions.foldlM
(init := (aliasNames, (default : Source.Decls))) (mkDecls_functionStep expandTyp)
let afterDts ← toplevel.dataTypes.foldlM
(init := afterFns) (mkDecls_dataTypeStep expandTyp)
pure afterDts.2

/-! ## Inference monad and unification -/

Expand Down Expand Up @@ -926,52 +945,62 @@ def getFunctionContext (function : Function) (decls : Decls) : CheckContext :=
typeParams := function.params }

def wellFormedDecls (decls : Decls) : Except CheckError Unit := do
let mut visited := default
for (_, decl) in decls.pairs do
match EStateM.run (wellFormedDecl decl) visited with
| .error e _ => throw e
| .ok () visited' => visited := visited'
let _ ← decls.pairs.foldlM (init := (default : Std.HashSet Global))
fun visited (_, decl) => wellFormedDecl visited decl
pure ()
where
checkUniqueParams (name : Global) (params : List String) :
EStateM CheckError (Std.HashSet Global) Unit :=
let rec go : List String → Std.HashSet String → EStateM CheckError (Std.HashSet Global) Unit
| [], _ => pure ()
Except CheckError Unit :=
let rec go : List String → Std.HashSet String → Except CheckError Unit
| [], _ => .ok ()
| p :: ps, seen =>
if seen.contains p then throw $ .duplicatedTypeParam name p
if seen.contains p then .error (.duplicatedTypeParam name p)
else go ps (seen.insert p)
go params {}
wellFormedDecl : Declaration → EStateM CheckError (Std.HashSet Global) Unit
wellFormedDecl (visited : Std.HashSet Global) :
Declaration → Except CheckError (Std.HashSet Global)
| .dataType dataType => do
let map ← get
if !map.contains dataType.name then
set $ map.insert dataType.name
if !visited.contains dataType.name then
checkUniqueParams dataType.name dataType.params
dataType.constructors.flatMap (·.argTypes) |>.forM (wellFormedType dataType.params)
dataType.constructors.flatMap (·.argTypes)
|>.forM (wellFormedType dataType.params)
.ok (visited.insert dataType.name)
else
.ok visited
| .function function => do
checkUniqueParams function.name function.params
wellFormedType function.params function.output
function.inputs.forM fun (_, typ) => wellFormedType function.params typ
| .constructor .. => pure ()
wellFormedType (params : List String) : Typ → EStateM CheckError (Std.HashSet Global) Unit
.ok visited
| .constructor .. => .ok visited
wellFormedType (params : List String) : Typ → Except CheckError Unit
| .tuple typs =>
typs.attach.forM (fun ⟨t, _⟩ => wellFormedType params t)
| .pointer pointerTyp => wellFormedType params pointerTyp
| .array t _ => wellFormedType params t
| .ref ref =>
if params.any (· == ref.toName.toString) then pure ()
-- Type-param refs are produced by the parser as `Global.init p` (single-
-- component name). Compare via `Global.init p == ref` so the predicate
-- aligns with `mkParamSubst` (which keys on `Global.init p` exactly).
if params.any (fun p => Global.init p == ref) then .ok ()
else match decls.getByKey ref with
| some (.dataType dt) =>
unless dt.params.isEmpty do throw $ .wrongNumTypeArgs ref 0 dt.params.length
| some _ => throw $ .notADataType ref
| none => throw $ .unboundGlobal ref
if dt.params.isEmpty then .ok ()
else .error (.wrongNumTypeArgs ref 0 dt.params.length)
| some _ => .error (.notADataType ref)
| none => .error (.unboundGlobal ref)
| .app g args => match decls.getByKey g with
| some (.dataType dt) => do
unless args.size == dt.params.length do
throw $ .wrongNumTypeArgs g args.size dt.params.length
args.attach.forM (fun ⟨t, _⟩ => wellFormedType params t)
| some _ => throw $ .notADataType g
| none => throw $ .unboundGlobal g
| _ => pure ()
if args.size == dt.params.length then
args.attach.forM (fun ⟨t, _⟩ => wellFormedType params t)
else
.error (.wrongNumTypeArgs g args.size dt.params.length)
| some _ => .error (.notADataType g)
| none => .error (.unboundGlobal g)
| .function ins out => do
ins.attach.forM (fun ⟨t, _⟩ => wellFormedType params t)
wellFormedType params out
| _ => .ok ()
termination_by t => sizeOf t

/-- Check a function (infer + zonk). -/
Expand All @@ -981,7 +1010,8 @@ def checkFunction (function : Function) : CheckM Typed.Function := do
unless ← unifyTyp body.typ function.output do
throw $ .typeMismatch body.typ function.output
let body ← zonkTypedTerm body
pure ⟨function.name, function.params, function.inputs, function.output, body, function.entry⟩
pure ⟨function.name, function.params, function.inputs, function.output, body, function.entry,
function.notPolyEntry⟩

end Aiur

Expand Down
64 changes: 35 additions & 29 deletions Ix/Aiur/Compiler/Concretize.lean
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
module
public import Ix.Lib
public import Ix.Aiur.Proofs.Lib
public import Ix.Aiur.Compiler.Simple
public import Ix.Aiur.Stages.Concrete

Expand Down Expand Up @@ -91,7 +91,7 @@ def Typ.toFlatName : Typ → String
termination_by t => sizeOf t
decreasing_by all_goals first | decreasing_tactic | grind

def Typ.appendNameLimbs (g : Global) : Typ → Global
@[expose, reducible] def Typ.appendNameLimbs (g : Global) : Typ → Global
| .field => g.pushNamespace "G"
| .unit => g.pushNamespace "Unit"
| Typ.ref g' =>
Expand All @@ -118,7 +118,7 @@ decreasing_by
-- 1 + sizeOf name + sizeOf args, so we need sizeOf args > 0.
have := Array.two_le_sizeOf ‹Array Typ›; grind)

def concretizeName (templateName : Global) (args : Array Typ) : Global :=
@[expose] def concretizeName (templateName : Global) (args : Array Typ) : Global :=
args.foldl Typ.appendNameLimbs templateName

/-! ## Source → Concrete pattern translation — direct, non-nested subset only.
Expand Down Expand Up @@ -195,31 +195,27 @@ tuple, a list of sub-patterns (one per field), the element types, and a
body `cb`, produce the nested `.letVar`/`.letWild` + `.proj` sequence. Used
by the single-arm tuple pattern special case of `termToConcrete`'s `.match`. -/
def destructureTuple (scrutTerm : Concrete.Term) (pats : Array Pattern)
(ts : Array Concrete.Typ) (cb : Concrete.Term) : Concrete.Term := Id.run do
let mut acc := cb
for i in [:pats.size] do
(ts : Array Concrete.Typ) (cb : Concrete.Term) : Concrete.Term :=
(List.range pats.size).foldl (init := cb) fun acc i =>
let j := pats.size - 1 - i
let p := pats[j]?.getD .wildcard
let eltTyp := ts[j]?.getD .unit
let projTerm : Concrete.Term := .proj eltTyp false scrutTerm j
acc := match p with
| .var x => .letVar acc.typ acc.escapes x projTerm acc
| _ => .letWild acc.typ acc.escapes projTerm acc
acc
match p with
| .var x => .letVar acc.typ acc.escapes x projTerm acc
| _ => .letWild acc.typ acc.escapes projTerm acc

/-- Irrefutable array destructuring: analogous to `destructureTuple` but over
a homogeneous array scrutinee, using `.get` for each element. -/
def destructureArray (scrutTerm : Concrete.Term) (pats : Array Pattern)
(eltTyp : Concrete.Typ) (cb : Concrete.Term) : Concrete.Term := Id.run do
let mut acc := cb
for i in [:pats.size] do
(eltTyp : Concrete.Typ) (cb : Concrete.Term) : Concrete.Term :=
(List.range pats.size).foldl (init := cb) fun acc i =>
let j := pats.size - 1 - i
let p := pats[j]?.getD .wildcard
let getTerm : Concrete.Term := .get eltTyp false scrutTerm j
acc := match p with
| .var x => .letVar acc.typ acc.escapes x getTerm acc
| _ => .letWild acc.typ acc.escapes getTerm acc
acc
match p with
| .var x => .letVar acc.typ acc.escapes x getTerm acc
| _ => .letWild acc.typ acc.escapes getTerm acc

/-! ## The main pass

Expand Down Expand Up @@ -950,18 +946,16 @@ substitution)`. For fully-monomorphic programs, 1 suffices. Pick a generous
bound: `decls.size + 1`. Caller can raise if polymorphism hits the ceiling. -/
def concretizeDrainFuel (decls : Typed.Decls) : Nat := decls.size + 1

/-- Specialise every polymorphic template reachable from concrete decls into a
concrete monomorphic copy, then lower the whole table to `Concrete.Decls`. -/
def Typed.Decls.concretize (decls : Typed.Decls) :
Except ConcretizeError Concrete.Decls := do
let pending := concretizeSeed decls
let initState : DrainState :=
{ pending, seen := {}, mono := {}, newFunctions := #[], newDataTypes := #[] }
let drained ← concretizeDrain decls (concretizeDrainFuel decls) initState
let monoDecls := concretizeBuild decls drained.mono
drained.newFunctions drained.newDataTypes
let emptyMono : Std.HashMap (Global × Array Typ) Global := {}
monoDecls.foldlM (init := default) fun acc (name, d) => do match d with
/-- The Step-4 lowering step: lowers one `(name, Typed.Declaration)` entry to
`Concrete.Decls` with an empty mono-map (all template instantiation is baked
into the keys by `concretizeBuild`). Named so downstream proofs can manipulate
the final `foldlM` equationally instead of through an anonymous lambda. -/
def step4Lower :
Concrete.Decls → Global × Typed.Declaration →
Except ConcretizeError Concrete.Decls :=
fun acc (name, d) => do
let emptyMono : Std.HashMap (Global × Array Typ) Global := {}
match d with
| .function f =>
let inputs ← f.inputs.mapM fun (l, t) => do
let t' ← typToConcrete emptyMono t
Expand All @@ -986,6 +980,18 @@ def Typed.Decls.concretize (decls : Typed.Decls) :
let concC : Concrete.Constructor := { nameHead := c.nameHead, argTypes }
pure (acc.insert name (.constructor concDt concC))

/-- Specialise every polymorphic template reachable from concrete decls into a
concrete monomorphic copy, then lower the whole table to `Concrete.Decls`. -/
def Typed.Decls.concretize (decls : Typed.Decls) :
Except ConcretizeError Concrete.Decls := do
let pending := concretizeSeed decls
let initState : DrainState :=
{ pending, seen := {}, mono := {}, newFunctions := #[], newDataTypes := #[] }
let drained ← concretizeDrain decls (concretizeDrainFuel decls) initState
let monoDecls := concretizeBuild decls drained.mono
drained.newFunctions drained.newDataTypes
monoDecls.foldlM (init := default) step4Lower

end Aiur

end -- @[expose] section
Expand Down
Loading