// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package gc

import (
	"cmd/compile/internal/types"
	"cmd/internal/src"
	"sort"
)

// typecheckswitch typechecks a switch statement: first its init
// statements, then the switch itself. A switch whose n.Left is an
// OTYPESW guard is a type switch; anything else is an expression switch.
func typecheckswitch(n *Node) {
	typecheckslice(n.Ninit.Slice(), ctxStmt)

	isTypeSwitch := n.Left != nil && n.Left.Op == OTYPESW
	if !isTypeSwitch {
		typecheckExprSwitch(n)
		return
	}
	typecheckTypeSwitch(n)
}

// typecheckTypeSwitch typechecks a type switch. n.Left is the OTYPESW
// guard: n.Left.Right is the expression being switched on and
// n.Left.Left, if non-nil, is the guard variable. For each clause in
// n.List, ncase.List holds the case expressions (empty for default) and
// ncase.Rlist, when non-empty, holds the clause's variable.
//
// Errors reported here: switching on a non-interface value, an unused
// guard variable in a switch with no clauses, multiple defaults,
// multiple nil cases, non-type case expressions, impossible concrete
// cases, and (via typeSet) duplicate case types.
func typecheckTypeSwitch(n *Node) {
	n.Left.Right = typecheck(n.Left.Right, ctxExpr)
	t := n.Left.Right.Type
	if t != nil && !t.IsInterface() {
		yyerrorl(n.Pos, "cannot type switch on non-interface value %L", n.Left.Right)
		// Clearing t disables the per-case checks below, so one bad
		// switch value does not cascade into an error per case.
		t = nil
	}

	// We don't actually declare the type switch's guarded
	// declaration itself. So if there are no cases, we won't
	// notice that it went unused.
	if v := n.Left.Left; v != nil && !v.isBlank() && n.List.Len() == 0 {
		yyerrorl(v.Pos, "%v declared but not used", v.Sym)
	}

	var defCase, nilCase *Node
	var ts typeSet
	for _, ncase := range n.List.Slice() {
		ls := ncase.List.Slice()
		if len(ls) == 0 { // default:
			if defCase != nil {
				yyerrorl(ncase.Pos, "multiple defaults in switch (first at %v)", defCase.Line())
			} else {
				defCase = ncase
			}
		}

		for i := range ls {
			// Case expressions may be types or expressions (e.g. nil).
			ls[i] = typecheck(ls[i], ctxExpr|ctxType)
			n1 := ls[i]
			if t == nil || n1.Type == nil {
				continue
			}

			var missing, have *types.Field
			var ptr int
			switch {
			case n1.isNil(): // case nil:
				if nilCase != nil {
					yyerrorl(ncase.Pos, "multiple nil cases in type switch (first at %v)", nilCase.Line())
				} else {
					nilCase = ncase
				}
			case n1.Op != OTYPE:
				yyerrorl(ncase.Pos, "%L is not a type", n1)
			case !n1.Type.IsInterface() && !implements(n1.Type, t, &missing, &have, &ptr) && !missing.Broke():
				// A concrete case type that does not implement the
				// switched-on interface can never match. The message
				// explains why: wrong method signature, pointer
				// receiver, or missing method.
				if have != nil && !have.Broke() {
					yyerrorl(ncase.Pos, "impossible type switch case: %L cannot have dynamic type %v"+
						" (wrong type for %v method)\n\thave %v%S\n\twant %v%S", n.Left.Right, n1.Type, missing.Sym, have.Sym, have.Type, missing.Sym, missing.Type)
				} else if ptr != 0 {
					yyerrorl(ncase.Pos, "impossible type switch case: %L cannot have dynamic type %v"+
						" (%v method has pointer receiver)", n.Left.Right, n1.Type, missing.Sym)
				} else {
					yyerrorl(ncase.Pos, "impossible type switch case: %L cannot have dynamic type %v"+
						" (missing %v method)", n.Left.Right, n1.Type, missing.Sym)
				}
			}

			// Record every type case, including impossible ones, so
			// duplicate case types are reported.
			if n1.Op == OTYPE {
				ts.add(ncase.Pos, n1.Type)
			}
		}

		if ncase.Rlist.Len() != 0 {
			// Assign the clause variable's type.
			// A clause listing exactly one type gives the variable that
			// type. Otherwise (default, several types, or a single nil)
			// the variable keeps the switched-on interface type t.
			vt := t
			if len(ls) == 1 {
				if ls[0].Op == OTYPE {
					vt = ls[0].Type
				} else if ls[0].Op != OLITERAL { // TODO(mdempsky): Should be !ls[0].isNil()
					// Invalid single-type case;
					// mark variable as broken.
					vt = nil
				}
			}

			// TODO(mdempsky): It should be possible to
			// still typecheck the case body.
			// Note that this continue also skips typechecking
			// ncase.Nbody below.
			if vt == nil {
				continue
			}

			nvar := ncase.Rlist.First()
			nvar.Type = vt
			nvar = typecheck(nvar, ctxExpr|ctxAssign)
			ncase.Rlist.SetFirst(nvar)
		}

		typecheckslice(ncase.Nbody.Slice(), ctxStmt)
	}
}

// A typeSet records the case types of a type switch so that duplicate
// cases can be reported.
type typeSet struct {
	// m buckets entries by typ.LongString(). One bucket may hold several
	// non-identical types whose long strings happen to collide.
	m map[string][]typeSetEntry
}

// A typeSetEntry is one recorded case type and the position of the
// clause that introduced it.
type typeSetEntry struct {
	pos src.XPos
	typ *types.Type
}

// add records typ as seen at pos. If an identical type is already in
// the set, add reports a duplicate-case error that points at the earlier
// position and leaves the set unchanged.
func (s *typeSet) add(pos src.XPos, typ *types.Type) {
	if s.m == nil {
		s.m = make(map[string][]typeSetEntry)
	}

	// LongString is not a unique key for a type, so entries sharing a
	// key still have to be compared with types.Identical.
	// TODO(mdempsky): Add a method that *is* unique.
	key := typ.LongString()
	bucket := s.m[key]
	for i := range bucket {
		if prev := bucket[i]; types.Identical(typ, prev.typ) {
			yyerrorl(pos, "duplicate case %v in type switch\n\tprevious case at %s", typ, linestr(prev.pos))
			return
		}
	}
	s.m[key] = append(bucket, typeSetEntry{pos: pos, typ: typ})
}

// typecheckExprSwitch typechecks an expression switch. n.Left is the
// switch tag, or nil for a tagless "switch {", which is checked as a
// switch on a bool. Each clause in n.List holds its case expressions in
// ncase.List (empty for default) and its body in ncase.Nbody.
//
// Errors reported here: an incomparable tag, non-nil cases on a
// map/func/slice tag, incomparable case values on an interface tag,
// case values whose types do not match the tag, multiple defaults, and
// (via constSet) duplicate non-boolean constant cases.
func typecheckExprSwitch(n *Node) {
	t := types.Types[TBOOL]
	if n.Left != nil {
		n.Left = typecheck(n.Left, ctxExpr)
		// NOTE(review): defaultlit with a nil type presumably gives an
		// untyped constant tag its default type; confirm.
		n.Left = defaultlit(n.Left, nil)
		t = n.Left.Type
	}

	// nilonly names the tag's kind when the tag can only be compared
	// against nil.
	var nilonly string
	if t != nil {
		switch {
		case t.IsMap():
			nilonly = "map"
		case t.Etype == TFUNC:
			nilonly = "func"
		case t.IsSlice():
			nilonly = "slice"

		case !IsComparable(t):
			if t.IsStruct() {
				yyerrorl(n.Pos, "cannot switch on %L (struct containing %v cannot be compared)", n.Left, IncomparableField(t).Type)
			} else {
				yyerrorl(n.Pos, "cannot switch on %L", n.Left)
			}
			// Clearing t disables the per-case checks below.
			t = nil
		}
	}

	var defCase *Node
	var cs constSet
	for _, ncase := range n.List.Slice() {
		ls := ncase.List.Slice()
		if len(ls) == 0 { // default:
			if defCase != nil {
				yyerrorl(ncase.Pos, "multiple defaults in switch (first at %v)", defCase.Line())
			} else {
				defCase = ncase
			}
		}

		for i := range ls {
			setlineno(ncase)
			ls[i] = typecheck(ls[i], ctxExpr)
			// Untyped case constants are converted toward the tag's type.
			ls[i] = defaultlit(ls[i], t)
			n1 := ls[i]
			if t == nil || n1.Type == nil {
				continue
			}

			switch {
			case nilonly != "" && !n1.isNil():
				yyerrorl(ncase.Pos, "invalid case %v in switch (can only compare %s %v to nil)", n1, nilonly, n.Left)
			case t.IsInterface() && !n1.Type.IsInterface() && !IsComparable(n1.Type):
				yyerrorl(ncase.Pos, "invalid case %L in switch (incomparable type)", n1)
			case assignop(n1.Type, t, nil) == 0 && assignop(t, n1.Type, nil) == 0:
				// The case and tag types must be assignable in at least
				// one direction to be comparable.
				if n.Left != nil {
					yyerrorl(ncase.Pos, "invalid case %v in switch on %v (mismatched types %v and %v)", n1, n.Left, n1.Type, t)
				} else {
					yyerrorl(ncase.Pos, "invalid case %v in switch (mismatched types %v and bool)", n1, n1.Type)
				}
			}

			// Don't check for duplicate bools. Although the spec allows it,
			// (1) the compiler hasn't checked it in the past, so compatibility mandates it, and
			// (2) it would disallow useful things like
			//       case GOARCH == "arm" && GOARM == "5":
			//       case GOARCH == "arm":
			//     which would both evaluate to false for non-ARM compiles.
			if !n1.Type.IsBoolean() {
				cs.add(ncase.Pos, n1, "case", "switch")
			}
		}

		typecheckslice(ncase.Nbody.Slice(), ctxStmt)
	}
}

// walkswitch walks a switch statement, lowering it to explicit dispatch
// code through walkTypeSwitch or walkExprSwitch.
func walkswitch(sw *Node) {
	// Both walkers clear sw.List and fill sw.Nbody, so a switch in that
	// state has already been walked (see #25776). A second walk is
	// silently ignored: it was fatal, but eliminating every possible
	// source of double-walking is hard.
	alreadyWalked := sw.List.Len() == 0 && sw.Nbody.Len() > 0
	if alreadyWalked {
		return
	}

	switch {
	case sw.Left != nil && sw.Left.Op == OTYPESW:
		walkTypeSwitch(sw)
	default:
		walkExprSwitch(sw)
	}
}

// walkExprSwitch generates an AST implementing sw, where sw is an
// expression switch.
//
// The lowered sw.Nbody has the shape:
//
//	tmp := cond                    (omitted when cond is a constant)
//	<dispatch: if-tests that goto a case label>
//	goto default                   (or break when there is no default)
//	.sN: <case body>; break        (break omitted after fallthrough)
//	...
//
// sw.Left and sw.List are cleared.
func walkExprSwitch(sw *Node) {
	lno := setlineno(sw)

	cond := sw.Left
	sw.Left = nil

	// convert switch {...} to switch true {...}
	if cond == nil {
		cond = nodbool(true)
		cond = typecheck(cond, ctxExpr)
		cond = defaultlit(cond, nil)
	}

	// Given "switch string(byteslice)",
	// with all cases being side-effect free,
	// use a zero-cost alias of the byte slice.
	// Do this before calling walkexpr on cond,
	// because walkexpr will lower the string
	// conversion into a runtime call.
	// See issue 24937 for more discussion.
	if cond.Op == OBYTES2STR && allCaseExprsAreSideEffectFree(sw) {
		cond.Op = OBYTES2STRTMP
	}

	// Copy a non-constant cond once so every dispatch test compares
	// against the same value. A constant cond stays a literal, which
	// lets exprClause.test fold "switch true"/"switch false".
	cond = walkexpr(cond, &sw.Ninit)
	if cond.Op != OLITERAL {
		cond = copyexpr(cond, cond.Type, &sw.Nbody)
	}

	lineno = lno

	s := exprSwitch{
		exprname: cond,
	}

	var defaultGoto *Node
	var body Nodes
	for _, ncase := range sw.List.Slice() {
		label := autolabel(".s")
		jmp := npos(ncase.Pos, nodSym(OGOTO, nil, label))

		// Process case dispatch.
		if ncase.List.Len() == 0 {
			if defaultGoto != nil {
				Fatalf("duplicate default case not detected during typechecking")
			}
			defaultGoto = jmp
		}

		for _, n1 := range ncase.List.Slice() {
			s.Add(ncase.Pos, n1, jmp)
		}

		// Process body.
		// Each case ends with a break unless its last statement is a
		// fallthrough, which continues into the next case's label.
		body.Append(npos(ncase.Pos, nodSym(OLABEL, nil, label)))
		body.Append(ncase.Nbody.Slice()...)
		if fall, pos := hasFall(ncase.Nbody.Slice()); !fall {
			br := nod(OBREAK, nil, nil)
			br.Pos = pos
			body.Append(br)
		}
	}
	sw.List.Set(nil)

	// With no default case, falling out of the dispatch leaves the
	// switch. This break is marked as not a statement boundary.
	if defaultGoto == nil {
		br := nod(OBREAK, nil, nil)
		br.Pos = br.Pos.WithNotStmt()
		defaultGoto = br
	}

	s.Emit(&sw.Nbody)
	sw.Nbody.Append(defaultGoto)
	sw.Nbody.AppendNodes(&body)
	walkstmtlist(sw.Nbody.Slice())
}

// An exprSwitch accumulates the dispatch code for an expression switch.
// Clauses that may be reordered are buffered in clauses and later
// flushed into done as a sorted binary search.
type exprSwitch struct {
	// exprname is the value being switched on.
	exprname *Node

	// done holds the dispatch statements generated so far.
	done Nodes
	// clauses holds buffered clauses not yet flushed into done.
	clauses []exprClause
}

// An exprClause is one case test of an expression switch. For a single
// case expression lo and hi are the same node. After integer merging
// they bound an inclusive range of constants. jmp is the statement
// executed when the test matches.
type exprClause struct {
	pos src.XPos
	lo  *Node
	hi  *Node
	jmp *Node
}

// Add registers case expression expr, found at pos, so that jmp runs
// when expr equals the switched-on value. Constant cases whose type
// passes okforcmp are buffered so they can be sorted and searched.
// Any other case acts as a barrier: pending clauses are flushed before
// it and it is emitted on its own, which keeps source order.
func (s *exprSwitch) Add(pos src.XPos, expr, jmp *Node) {
	clause := exprClause{pos: pos, lo: expr, hi: expr, jmp: jmp}

	if !(okforcmp[s.exprname.Type.Etype] && expr.Op == OLITERAL) {
		s.flush()
		s.clauses = append(s.clauses, clause)
		s.flush()
		return
	}
	s.clauses = append(s.clauses, clause)
}

// Emit flushes any buffered clauses and then moves all generated
// dispatch code onto the end of out.
func (s *exprSwitch) Emit(out *Nodes) {
	s.flush()
	out.AppendNodes(&s.done)
}

// flush turns the buffered clauses into dispatch code appended to
// s.done, then clears the buffer.
//
// String switches use a two-level binary search: first on the string's
// length, then on its value among cases of that length. Other switches
// sort clauses by constant value, merge adjacent integer cases that
// share a jump target into ranges, and binary search over the result.
func (s *exprSwitch) flush() {
	cc := s.clauses
	s.clauses = nil
	if len(cc) == 0 {
		return
	}

	// Caution: If len(cc) == 1, then cc[0] might not be an OLITERAL.
	// The code below is structured to implicitly handle this case
	// (e.g., sort.Slice doesn't need to invoke the less function
	// when there's only a single slice element).

	if s.exprname.Type.IsString() && len(cc) >= 2 {
		// Sort strings by length and then by value. It is
		// much cheaper to compare lengths than values, and
		// all we need here is consistency. We respect this
		// sorting below.
		sort.Slice(cc, func(i, j int) bool {
			si := strlit(cc[i].lo)
			sj := strlit(cc[j].lo)
			if len(si) != len(sj) {
				return len(si) < len(sj)
			}
			return si < sj
		})

		// runLen returns the string length associated with a
		// particular run of exprClauses.
		runLen := func(run []exprClause) int64 { return int64(len(strlit(run[0].lo))) }

		// Collapse runs of consecutive strings with the same length.
		var runs [][]exprClause
		start := 0
		for i := 1; i < len(cc); i++ {
			if runLen(cc[start:]) != runLen(cc[i:]) {
				runs = append(runs, cc[start:i])
				start = i
			}
		}
		runs = append(runs, cc[start:])

		// Perform two-level binary search.
		// The outer search picks the run whose length equals
		// len(exprname). Inside it, s.search compares values.
		nlen := nod(OLEN, s.exprname, nil)
		binarySearch(len(runs), &s.done,
			func(i int) *Node {
				return nod(OLE, nlen, nodintconst(runLen(runs[i-1])))
			},
			func(i int, nif *Node) {
				run := runs[i]
				nif.Left = nod(OEQ, nlen, nodintconst(runLen(run)))
				s.search(run, &nif.Nbody)
			},
		)
		return
	}

	sort.Slice(cc, func(i, j int) bool {
		return compareOp(cc[i].lo.Val(), OLT, cc[j].lo.Val())
	})

	// Merge consecutive integer cases.
	// Two clauses merge only if they jump to the same target and the
	// next low value is exactly one past the current high value.
	// The merge is done in place: merged aliases the front of cc.
	if s.exprname.Type.IsInteger() {
		merged := cc[:1]
		for _, c := range cc[1:] {
			last := &merged[len(merged)-1]
			if last.jmp == c.jmp && last.hi.Int64()+1 == c.lo.Int64() {
				last.hi = c.lo
			} else {
				merged = append(merged, c)
			}
		}
		cc = merged
	}

	s.search(cc, &s.done)
}

// search appends to out a binary search over the sorted clauses cc.
// Inner nodes compare exprname against the upper bound (hi) of the
// clause just before the split point. Leaves test a single clause and
// jump to its target on a match.
func (s *exprSwitch) search(cc []exprClause, out *Nodes) {
	less := func(i int) *Node {
		return nod(OLE, s.exprname, cc[i-1].hi)
	}
	leaf := func(i int, nif *Node) {
		c := &cc[i]
		nif.Left = c.test(s.exprname)
		nif.Nbody.Set1(c.jmp)
	}
	binarySearch(len(cc), out, less, leaf)
}

// test returns a boolean expression reporting whether exprname matches
// clause c.
//
// A merged integer range (lo != hi) yields lo <= exprname && exprname <= hi.
// When exprname is a boolean constant, for example the "true" that a
// tagless "switch {" is rewritten to, the comparison is folded: the case
// expression itself is the condition for true, or its negation for
// false. In every other case the result is exprname == c.lo.
func (c *exprClause) test(exprname *Node) *Node {
	// Integer range.
	if c.hi != c.lo {
		low := nodl(c.pos, OGE, exprname, c.lo)
		high := nodl(c.pos, OLE, exprname, c.hi)
		return nodl(c.pos, OANDAND, low, high)
	}

	// Optimize "switch true { ...}" and "switch false { ... }".
	// An interface-typed case expression is not a bool, so it cannot be
	// used as a condition directly and still needs the explicit
	// comparison below.
	if Isconst(exprname, CTBOOL) && !c.lo.Type.IsInterface() {
		if exprname.Val().U.(bool) {
			return c.lo
		}
		return nodl(c.pos, ONOT, c.lo, nil)
	}

	return nodl(c.pos, OEQ, exprname, c.lo)
}

// allCaseExprsAreSideEffectFree reports whether every case expression of
// sw is a constant (OLITERAL). It is a fatal error for any clause of sw
// not to be an OCASE node.
//
// In theory, we could be more aggressive, allowing any
// side-effect-free expressions in cases, but it's a bit
// tricky because some of that information is unavailable due
// to the introduction of temporaries during order.
// Restricting to constants is simple and probably powerful
// enough.
func allCaseExprsAreSideEffectFree(sw *Node) bool {
	for _, clause := range sw.List.Slice() {
		if op := clause.Op; op != OCASE {
			Fatalf("switch string(byteslice) bad op: %v", op)
		}
		for _, expr := range clause.List.Slice() {
			if expr.Op != OLITERAL {
				return false
			}
		}
	}
	return true
}

// hasFall reports whether stmts ends with a "fallthrough" statement. It
// also returns the position of the last statement it examined, or
// src.NoXPos if stmts has no such statement.
//
// Trailing OVARKILL nodes are skipped, because the fallthrough need not
// be in the last position: when the statement list contains autotmp_
// variables, one or more OVARKILL nodes can follow it.
func hasFall(stmts []*Node) (bool, src.XPos) {
	for i := len(stmts) - 1; i >= 0; i-- {
		if last := stmts[i]; last.Op != OVARKILL {
			return last.Op == OFALL, last.Pos
		}
	}
	return false, src.NoXPos
}

// walkTypeSwitch generates an AST that implements sw, where sw is a
// type switch.
//
// The lowered sw.Nbody has the shape:
//
//	tmp := <switch value>
//	if itab(tmp) == nil { goto nil case, or default, or break }
//	h := itab(tmp).hash
//	<dispatch built by typeSwitch: search on h plus comma-ok assertions>
//	goto default                   (or break when there is no default)
//	.sN: <case variable setup>; <case body>; break
//	...
//
// itab(tmp) is the type word of an empty interface, or the itab of a
// non-empty one. sw.Left and sw.List are cleared.
func walkTypeSwitch(sw *Node) {
	var s typeSwitch
	s.facename = sw.Left.Right
	sw.Left = nil

	s.facename = walkexpr(s.facename, &sw.Ninit)
	s.facename = copyexpr(s.facename, s.facename.Type, &sw.Nbody)
	s.okname = temp(types.Types[TBOOL])

	// Get interface descriptor word.
	// For empty interfaces this will be the type.
	// For non-empty interfaces this will be the itab.
	itab := nod(OITAB, s.facename, nil)

	// For empty interfaces, do:
	//     if e._type == nil {
	//         do nil case if it exists, otherwise default
	//     }
	//     h := e._type.hash
	// Use a similar strategy for non-empty interfaces.
	ifNil := nod(OIF, nil, nil)
	ifNil.Left = nod(OEQ, itab, nodnil())
	lineno = lineno.WithNotStmt() // disable statement marks after the first check.
	ifNil.Left = typecheck(ifNil.Left, ctxExpr)
	ifNil.Left = defaultlit(ifNil.Left, nil)
	// ifNil.Nbody assigned at end.
	sw.Nbody.Append(ifNil)

	// Load hash from type or itab.
	dotHash := nodSym(ODOTPTR, itab, nil)
	dotHash.Type = types.Types[TUINT32]
	dotHash.SetTypecheck(1)
	// runtime._type (empty interfaces) and runtime.itab (non-empty
	// interfaces) both keep their uint32 hash at offset 2*Widthptr, so a
	// single offset covers both cases. Keep this in sync with the
	// runtime if either layout changes.
	dotHash.Xoffset = int64(2 * Widthptr)
	dotHash.SetBounded(true) // guaranteed not to fault
	s.hashname = copyexpr(dotHash, dotHash.Type, &sw.Nbody)

	// A single break node ends every case body and is also the fallback
	// target when there is no default case.
	br := nod(OBREAK, nil, nil)
	var defaultGoto, nilGoto *Node
	var body Nodes
	for _, ncase := range sw.List.Slice() {
		var caseVar *Node
		if ncase.Rlist.Len() != 0 {
			caseVar = ncase.Rlist.First()
		}

		// For single-type cases, we initialize the case
		// variable as part of the type assertion; but in
		// other cases, we initialize it in the body.
		singleType := ncase.List.Len() == 1 && ncase.List.First().Op == OTYPE

		label := autolabel(".s")
		jmp := npos(ncase.Pos, nodSym(OGOTO, nil, label))

		if ncase.List.Len() == 0 { // default:
			if defaultGoto != nil {
				Fatalf("duplicate default case not detected during typechecking")
			}
			defaultGoto = jmp
		}

		for _, n1 := range ncase.List.Slice() {
			if n1.isNil() { // case nil:
				if nilGoto != nil {
					Fatalf("duplicate nil case not detected during typechecking")
				}
				nilGoto = jmp
				continue
			}

			if singleType {
				s.Add(n1.Type, caseVar, jmp)
			} else {
				s.Add(n1.Type, nil, jmp)
			}
		}

		body.Append(npos(ncase.Pos, nodSym(OLABEL, nil, label)))
		if caseVar != nil && !singleType {
			// The case variable keeps the interface value itself.
			l := []*Node{
				nodl(ncase.Pos, ODCL, caseVar, nil),
				nodl(ncase.Pos, OAS, caseVar, s.facename),
			}
			typecheckslice(l, ctxStmt)
			body.Append(l...)
		}
		body.Append(ncase.Nbody.Slice()...)
		body.Append(br)
	}
	sw.List.Set(nil)

	if defaultGoto == nil {
		defaultGoto = br
	}
	if nilGoto == nil {
		nilGoto = defaultGoto
	}
	ifNil.Nbody.Set1(nilGoto)

	s.Emit(&sw.Nbody)
	sw.Nbody.Append(defaultGoto)
	sw.Nbody.AppendNodes(&body)

	walkstmtlist(sw.Nbody.Slice())
}

// A typeSwitch accumulates the dispatch code for a type switch.
type typeSwitch struct {
	// Temporary variables (ONAMEs) used by the dispatch logic.
	facename *Node // the interface value being type-switched on
	hashname *Node // the type hash loaded from facename's type word or itab
	okname   *Node // bool result of the comma-ok type assertions

	done    Nodes        // dispatch statements generated so far
	clauses []typeClause // buffered concrete-type clauses not yet flushed to done
}

// A typeClause is a buffered test for one concrete case type: the
// type's hash and the statements (type assertion plus conditional jump)
// that check for it.
type typeClause struct {
	hash uint32 // typehash of the case type
	body Nodes  // assertion and jump statements for the case
}

// Add registers a case for type typ that jumps to jmp when the
// switched-on value holds that type. If caseVar is non-nil, it is
// declared and then assigned by the type assertion. Otherwise the
// asserted value is assigned to the blank identifier.
//
// The generated statements are:
//
//	cv, ok = iface.(typ)
//	if ok { goto jmp }
//
// Concrete types are buffered in s.clauses for a hash-based search.
// Interface case types are matched by implementation, not by hash
// equality: the pending clauses are flushed and the interface test is
// emitted in place, which keeps source order.
func (s *typeSwitch) Add(typ *types.Type, caseVar *Node, jmp *Node) {
	var body Nodes
	if caseVar != nil {
		l := []*Node{
			nod(ODCL, caseVar, nil),
			nod(OAS, caseVar, nil),
		}
		typecheckslice(l, ctxStmt)
		body.Append(l...)
	} else {
		caseVar = nblank
	}

	// cv, ok = iface.(type)
	as := nod(OAS2, nil, nil)
	as.List.Set2(caseVar, s.okname) // cv, ok =
	dot := nod(ODOTTYPE, s.facename, nil)
	dot.Type = typ // iface.(type)
	as.Rlist.Set1(dot)
	as = typecheck(as, ctxStmt)
	// walkexpr may append setup statements to body before as itself.
	as = walkexpr(as, &body)
	body.Append(as)

	// if ok { goto label }
	nif := nod(OIF, nil, nil)
	nif.Left = s.okname
	nif.Nbody.Set1(jmp)
	body.Append(nif)

	if !typ.IsInterface() {
		s.clauses = append(s.clauses, typeClause{
			hash: typehash(typ),
			body: body,
		})
		return
	}

	s.flush()
	s.done.AppendNodes(&body)
}

// Emit flushes any buffered clauses and then moves all generated
// dispatch code onto the end of out.
func (s *typeSwitch) Emit(out *Nodes) {
	s.flush()
	out.AppendNodes(&s.done)
}

// flush emits the buffered concrete-type clauses into s.done as a
// binary search over their type hashes, then clears the buffer.
// Clauses with the same hash end up under one hash-equality branch and
// are tested there one after another.
func (s *typeSwitch) flush() {
	pending := s.clauses
	s.clauses = nil
	if len(pending) == 0 {
		return
	}

	sort.Slice(pending, func(i, j int) bool { return pending[i].hash < pending[j].hash })

	// Fold each clause into the earlier clause with the same hash,
	// compacting the slice in place. next is a copy, so appending its
	// body does not disturb pending.
	uniq := pending[:1]
	for _, next := range pending[1:] {
		tail := &uniq[len(uniq)-1]
		if tail.hash != next.hash {
			uniq = append(uniq, next)
			continue
		}
		tail.body.AppendNodes(&next.body)
	}

	less := func(i int) *Node {
		return nod(OLE, s.hashname, nodintconst(int64(uniq[i-1].hash)))
	}
	leaf := func(i int, nif *Node) {
		// TODO(mdempsky): Omit hash equality check if
		// there's only one type.
		c := uniq[i]
		nif.Left = nod(OEQ, s.hashname, nodintconst(int64(c.hash)))
		nif.Nbody.AppendNodes(&c.body)
	}
	binarySearch(len(uniq), &s.done, less, leaf)
}

// binarySearch constructs a binary search tree for handling n cases,
// and appends it to out. It's used for efficiently implementing
// switch statements.
//
// less(i) should return a boolean expression. If it evaluates true,
// then cases before i will be tested; otherwise, cases i and later.
//
// base(i, nif) should setup nif (an OIF node) to test case i. In
// particular, it should set nif.Left and nif.Nbody.
//
// A range with fewer than binarySearchMin cases becomes a linear chain
// of OIF nodes, each nested in the previous node's Rlist.
func binarySearch(n int, out *Nodes, less func(i int) *Node, base func(i int, nif *Node)) {
	const binarySearchMin = 4 // minimum number of cases for binary search

	// finish typechecks nif's condition. Statement marks are disabled
	// first so the generated tests are not treated as statements.
	finish := func(nif *Node) {
		lineno = lineno.WithNotStmt()
		nif.Left = typecheck(nif.Left, ctxExpr)
		nif.Left = defaultlit(nif.Left, nil)
	}

	var build func(lo, hi int, dst *Nodes)
	build = func(lo, hi int, dst *Nodes) {
		if hi-lo >= binarySearchMin {
			mid := lo + (hi-lo)/2
			nif := nod(OIF, nil, nil)
			nif.Left = less(mid)
			finish(nif)
			build(lo, mid, &nif.Nbody)
			build(mid, hi, &nif.Rlist)
			dst.Append(nif)
			return
		}

		for i := lo; i < hi; i++ {
			nif := nod(OIF, nil, nil)
			base(i, nif)
			finish(nif)
			dst.Append(nif)
			dst = &nif.Rlist
		}
	}

	build(0, n, out)
}
