Skip to content

Commit

Permalink
allow using function and table function
Browse files Browse the repository at this point in the history
  • Loading branch information
jennifersp committed Dec 17, 2024
1 parent 09a7e80 commit eca5bf7
Show file tree
Hide file tree
Showing 8 changed files with 154 additions and 35 deletions.
6 changes: 3 additions & 3 deletions enginetest/join_stats_tests.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,12 +360,12 @@ func (t TestProvider) Function(ctx *sql.Context, name string) (sql.Function, boo
return nil, false
}

func (t TestProvider) TableFunction(_ *sql.Context, name string) (sql.TableFunction, error) {
func (t TestProvider) TableFunction(_ *sql.Context, name string) (sql.TableFunction, bool) {
if tf, ok := t.tableFunctions[strings.ToLower(name)]; ok {
return tf, nil
return tf, true
}

return nil, sql.ErrTableFunctionNotFound.New(name)
return nil, false
}

func (t TestProvider) WithTableFunctions(fns ...sql.TableFunction) (sql.TableFunctionProvider, error) {
Expand Down
6 changes: 3 additions & 3 deletions memory/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,10 +194,10 @@ func (pro *DbProvider) ExternalStoredProcedures(_ *sql.Context, name string) ([]
}

// TableFunction implements sql.TableFunctionProvider
func (pro *DbProvider) TableFunction(_ *sql.Context, name string) (sql.TableFunction, error) {
func (pro *DbProvider) TableFunction(_ *sql.Context, name string) (sql.TableFunction, bool) {
if tableFunction, ok := pro.tableFunctions[name]; ok {
return tableFunction, nil
return tableFunction, true
}

return nil, sql.ErrTableFunctionNotFound.New(name)
return nil, false
}
13 changes: 5 additions & 8 deletions sql/analyzer/catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -384,17 +384,14 @@ func (c *Catalog) ExternalStoredProcedures(ctx *sql.Context, name string) ([]sql
}

// TableFunction implements the TableFunctionProvider interface
func (c *Catalog) TableFunction(ctx *sql.Context, name string) (sql.TableFunction, error) {
func (c *Catalog) TableFunction(ctx *sql.Context, name string) (sql.TableFunction, bool) {
if fp, ok := c.DbProvider.(sql.TableFunctionProvider); ok {
tf, err := fp.TableFunction(ctx, name)
if err != nil {
return nil, err
} else if tf != nil {
return tf, nil
tf, found := fp.TableFunction(ctx, name)
if found && tf != nil {
return tf, true
}
}

return nil, sql.ErrTableFunctionNotFound.New(name)
return nil, false
}

func (c *Catalog) RefreshTableStats(ctx *sql.Context, table sql.Table, db string) error {
Expand Down
6 changes: 3 additions & 3 deletions sql/catalog_map.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ func (t MapCatalog) Function(ctx *Context, name string) (Function, bool) {
return nil, false
}

func (t MapCatalog) TableFunction(ctx *Context, name string) (TableFunction, error) {
func (t MapCatalog) TableFunction(ctx *Context, name string) (TableFunction, bool) {
if f, ok := t.tabFuncs[name]; ok {
return f, nil
return f, true
}
return nil, fmt.Errorf("table func not found")
return nil, false
}

func (t MapCatalog) ExternalStoredProcedure(ctx *Context, name string, numOfParams int) (*ExternalStoredProcedureDetails, error) {
Expand Down
2 changes: 1 addition & 1 deletion sql/databases.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ type CollatedDatabaseProvider interface {
// always) implemented by a DatabaseProvider.
type TableFunctionProvider interface {
// TableFunction returns the table function with the name provided, case-insensitive
TableFunction(ctx *Context, name string) (TableFunction, error)
TableFunction(ctx *Context, name string) (TableFunction, bool)
// WithTableFunctions returns a new provider with (only) the list of table functions arguments
WithTableFunctions(fns ...TableFunction) (TableFunctionProvider, error)
}
Expand Down
125 changes: 125 additions & 0 deletions sql/expression/tablefunction/table_function.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package dtablefunctions

import (
"fmt"
"strings"

"github.com/dolthub/go-mysql-server/sql"
)

var _ sql.TableFunction = &TableFunction{}
var _ sql.ExecSourceRel = &TableFunction{}

type TableFunction struct {
underlyingFunc sql.Function

args []sql.Expression
database sql.Database
funcExpr sql.Expression
}

func NewTableFunction(f sql.Function) sql.TableFunction {
return &TableFunction{
underlyingFunc: f,
}
}

func (t *TableFunction) NewInstance(ctx *sql.Context, db sql.Database, args []sql.Expression) (sql.Node, error) {
nt := *t
nt.database = db
nt.args = args
f, err := nt.underlyingFunc.NewInstance(args)
if err != nil {
return nil, err
}
nt.funcExpr = f
return &nt, nil
}

func (t *TableFunction) Children() []sql.Node {
return nil
}

func (t *TableFunction) Database() sql.Database {
return t.database
}

func (t *TableFunction) Expressions() []sql.Expression {
return t.funcExpr.Children()
}

func (t *TableFunction) IsReadOnly() bool {
return true
}

func (t *TableFunction) Name() string {
return t.underlyingFunc.FunctionName()
}

func (t *TableFunction) RowIter(ctx *sql.Context, r sql.Row) (sql.RowIter, error) {
v, err := t.funcExpr.Eval(ctx, r)
if err != nil {
return nil, err
}
return sql.RowsToRowIter(sql.Row{v}), nil
}

func (t *TableFunction) Resolved() bool {
for _, expr := range t.args {
return expr.Resolved()
}
return true
}

func (t *TableFunction) Schema() sql.Schema {
return sql.Schema{&sql.Column{Name: t.underlyingFunc.FunctionName(), Type: t.funcExpr.Type()}}
}

func (t *TableFunction) String() string {
var args []string
for _, expr := range t.args {
args = append(args, expr.String())
}
return fmt.Sprintf("%s(%s)", t.underlyingFunc.FunctionName(), strings.Join(args, ", "))
}

func (t *TableFunction) WithChildren(children ...sql.Node) (sql.Node, error) {
if len(children) != 0 {
return nil, fmt.Errorf("unexpected children")
}
return t, nil
}

func (t *TableFunction) WithDatabase(database sql.Database) (sql.Node, error) {
nt := *t
nt.database = database
return &nt, nil
}

func (t *TableFunction) WithExpressions(exprs ...sql.Expression) (sql.Node, error) {
l := len(t.funcExpr.Children())
if len(exprs) != l {
return nil, sql.ErrInvalidChildrenNumber.New(t, len(exprs), l)
}
nt := *t
nf, err := nt.funcExpr.WithChildren(exprs...)
if err != nil {
return nil, err
}
nt.funcExpr = nf
return &nt, nil
}
29 changes: 13 additions & 16 deletions sql/planbuilder/from.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package planbuilder

import (
"fmt"
dtablefunctions "github.com/dolthub/go-mysql-server/sql/expression/tablefunction"
"strings"

ast "github.com/dolthub/vitess/go/vt/sqlparser"
Expand Down Expand Up @@ -447,30 +448,26 @@ func (b *Builder) resolveTable(tab, db string, asOf interface{}) *plan.ResolvedT
func (b *Builder) buildTableFunc(inScope *scope, t *ast.TableFuncExpr) (outScope *scope) {
//TODO what are valid mysql table arguments
args := make([]sql.Expression, 0, len(t.Exprs))
for _, e := range t.Exprs {
switch e := e.(type) {
for _, expr := range t.Exprs {
switch e := expr.(type) {
case *ast.AliasedExpr:
expr := b.buildScalar(inScope, e.Expr)

if !e.As.IsEmpty() {
b.handleErr(sql.ErrUnsupportedSyntax.New(ast.String(e)))
}

if selectExprNeedsAlias(e, expr) {
b.handleErr(sql.ErrUnsupportedSyntax.New(ast.String(e)))
}

args = append(args, expr)
scalarExpr := b.buildScalar(inScope, e.Expr)
args = append(args, scalarExpr)
default:
b.handleErr(sql.ErrUnsupportedSyntax.New(ast.String(e)))
}
}

utf := expression.NewUnresolvedTableFunction(t.Name, args)

tableFunction, err := b.cat.TableFunction(b.ctx, utf.Name())
if err != nil {
b.handleErr(err)
tableFunction, found := b.cat.TableFunction(b.ctx, utf.Name())
if !found {
// try getting regular function
f, funcFound := b.cat.Function(b.ctx, utf.Name())
if !funcFound {
b.handleErr(sql.ErrTableFunctionNotFound.New(utf.Name()))
}
tableFunction = dtablefunctions.NewTableFunction(f)
}

database := b.currentDb()
Expand Down
2 changes: 1 addition & 1 deletion test/test_catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func (c *Catalog) UnlockTables(ctx *sql.Context, id uint32) error {
return nil
}

func (c *Catalog) TableFunction(ctx *sql.Context, name string) (sql.TableFunction, error) {
func (c *Catalog) TableFunction(ctx *sql.Context, name string) (sql.TableFunction, bool) {
//TODO implement me
panic("implement me")
}
Expand Down

0 comments on commit eca5bf7

Please sign in to comment.